numba_mpi.api.allreduce
This file contains the MPI_Allreduce() implementations.
"""This file contains the MPI_Allreduce() implementations."""

import ctypes
from numbers import Number

import numpy as np
from numba.core import types
from numba.extending import overload

from numba_mpi.api.operator import Operator
from numba_mpi.common import _MPI_Comm_World_ptr, _MpiComm, _MpiDatatype, _MpiOp, libmpi
from numba_mpi.utils import _mpi_addr, _mpi_dtype

# ctypes binding for the C function
#   int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
#                     MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
_MPI_Allreduce = libmpi.MPI_Allreduce
_MPI_Allreduce.restype = ctypes.c_int
_MPI_Allreduce.argtypes = [
    ctypes.c_void_p,  # sendbuf
    ctypes.c_void_p,  # recvbuf
    ctypes.c_int,  # count
    _MpiDatatype,  # datatype
    _MpiOp,  # op
    _MpiComm,  # comm
]


def allreduce(sendobj, recvobj, operator=Operator.SUM):
    """Wrapper for MPI_Allreduce on the world communicator.

    Note that complex datatypes and user-defined functions are not properly
    supported.

    Parameters
    ----------
    sendobj : numbers.Number or numpy.ndarray
        Value(s) contributed by this rank. A scalar is wrapped in a
        one-element array; an array is made C-contiguous (copied only if
        needed) before its data pointer is handed to MPI.
    recvobj : numpy.ndarray
        Buffer receiving the reduced result; its data pointer is passed
        directly to MPI_Allreduce.
    operator : Operator
        Reduction operator, ``Operator.SUM`` by default.

    Returns
    -------
    int
        Status code returned by MPI_Allreduce (0 == MPI_SUCCESS).

    Raises
    ------
    TypeError
        If ``sendobj`` is neither a number nor a NumPy array.
    """
    # Normalise the input into a contiguous NumPy buffer first, so a single
    # MPI call serves both the scalar and the array case.
    if isinstance(sendobj, Number):
        sendobj = np.array([sendobj])
    elif isinstance(sendobj, np.ndarray):
        sendobj = np.ascontiguousarray(sendobj)
    else:
        raise TypeError(f"Unsupported type {sendobj.__class__.__name__}")

    # NOTE(review): recvobj is not checked for size, dtype or contiguity;
    # presumably callers must pass a C-contiguous array holding at least
    # sendobj.size elements of sendobj's dtype -- confirm.
    return _MPI_Allreduce(
        sendobj.ctypes.data,
        recvobj.ctypes.data,
        sendobj.size,
        _mpi_dtype(sendobj),
        _mpi_addr(operator),
        _mpi_addr(_MPI_Comm_World_ptr),
    )


@overload(allreduce)
def ol_allreduce(
    sendobj, recvobj, operator=Operator.SUM
):  # pylint: disable=unused-argument
    """Numba ``@overload`` of ``allreduce`` (wrapper for MPI_Allreduce).

    Numba calls this typing function with the Numba types of the arguments;
    it returns the jitted implementation matching ``sendobj``. That
    implementation calls MPI_Allreduce on the world communicator and returns
    its integer status code (0 == MPI_SUCCESS). ``recvobj`` and ``operator``
    are only used inside the returned implementations, hence the pylint
    suppression on the signature.

    Note that complex datatypes and user-defined functions are not properly
    supported.

    Raises
    ------
    TypeError
        If ``sendobj`` is neither a ``types.Number`` nor a ``types.Array``.
    """
    if isinstance(sendobj, types.Number):
        # reduce a single number

        def _impl_number(sendobj, recvobj, operator=Operator.SUM):
            sendobj = np.array([sendobj])

            status = _MPI_Allreduce(
                sendobj.ctypes.data,
                recvobj.ctypes.data,
                sendobj.size,
                _mpi_dtype(sendobj),
                _mpi_addr(operator),
                _mpi_addr(_MPI_Comm_World_ptr),
            )

            # The following no-op prevents numba from too aggressive optimizations
            # This looks like a bug in numba (tested for version 0.55)
            sendobj[0]  # pylint: disable=pointless-statement

            return status

        return _impl_number

    if isinstance(sendobj, types.Array):
        # reduce an array

        def _impl_array(sendobj, recvobj, operator=Operator.SUM):
            sendobj = np.ascontiguousarray(sendobj)

            # NOTE(review): for non-contiguous input ascontiguousarray yields a
            # temporary copy, and unlike the scalar branch there is no
            # keep-alive after the call; if the scalar-branch issue is a
            # premature release of the temporary, it may apply here too --
            # confirm.
            status = _MPI_Allreduce(
                sendobj.ctypes.data,
                recvobj.ctypes.data,
                sendobj.size,
                _mpi_dtype(sendobj),
                _mpi_addr(operator),
                _mpi_addr(_MPI_Comm_World_ptr),
            )

            return status

        return _impl_array

    raise TypeError(f"Unsupported type {sendobj.__class__.__name__}")
def allreduce(sendobj, recvobj, operator=Operator.SUM):
def allreduce(sendobj, recvobj, operator=Operator.SUM):
    """Wrapper for MPI_Allreduce on the world communicator.

    Note that complex datatypes and user-defined functions are not properly
    supported.

    Parameters
    ----------
    sendobj : numbers.Number or numpy.ndarray
        Value(s) contributed by this rank. A scalar is wrapped in a
        one-element array; an array is made C-contiguous (copied only if
        needed) before its data pointer is handed to MPI.
    recvobj : numpy.ndarray
        Buffer receiving the reduced result; its data pointer is passed
        directly to MPI_Allreduce.
    operator : Operator
        Reduction operator, ``Operator.SUM`` by default.

    Returns
    -------
    int
        Status code returned by MPI_Allreduce (0 == MPI_SUCCESS).

    Raises
    ------
    TypeError
        If ``sendobj`` is neither a number nor a NumPy array.
    """
    # Normalise the input into a contiguous NumPy buffer first, so a single
    # MPI call serves both the scalar and the array case.
    if isinstance(sendobj, Number):
        sendobj = np.array([sendobj])
    elif isinstance(sendobj, np.ndarray):
        sendobj = np.ascontiguousarray(sendobj)
    else:
        raise TypeError(f"Unsupported type {sendobj.__class__.__name__}")

    # NOTE(review): recvobj is not checked for size, dtype or contiguity;
    # presumably callers must pass a C-contiguous array holding at least
    # sendobj.size elements of sendobj's dtype -- confirm.
    return _MPI_Allreduce(
        sendobj.ctypes.data,
        recvobj.ctypes.data,
        sendobj.size,
        _mpi_dtype(sendobj),
        _mpi_addr(operator),
        _mpi_addr(_MPI_Comm_World_ptr),
    )
Wrapper for MPI_Allreduce. Note that complex datatypes and user-defined functions are not properly supported. Returns an integer status code (0 == MPI_SUCCESS).
@overload(allreduce)
def ol_allreduce(sendobj, recvobj, operator=Operator.SUM):
@overload(allreduce)
def ol_allreduce(
    sendobj, recvobj, operator=Operator.SUM
):  # pylint: disable=unused-argument
    """Numba ``@overload`` of ``allreduce`` (wrapper for MPI_Allreduce).

    Numba calls this typing function with the Numba types of the arguments;
    it returns the jitted implementation matching ``sendobj``. That
    implementation calls MPI_Allreduce on the world communicator and returns
    its integer status code (0 == MPI_SUCCESS). ``recvobj`` and ``operator``
    are only used inside the returned implementations, hence the pylint
    suppression on the signature.

    Note that complex datatypes and user-defined functions are not properly
    supported.

    Raises
    ------
    TypeError
        If ``sendobj`` is neither a ``types.Number`` nor a ``types.Array``.
    """
    if isinstance(sendobj, types.Number):
        # reduce a single number

        def _impl_number(sendobj, recvobj, operator=Operator.SUM):
            sendobj = np.array([sendobj])

            status = _MPI_Allreduce(
                sendobj.ctypes.data,
                recvobj.ctypes.data,
                sendobj.size,
                _mpi_dtype(sendobj),
                _mpi_addr(operator),
                _mpi_addr(_MPI_Comm_World_ptr),
            )

            # The following no-op prevents numba from too aggressive optimizations
            # This looks like a bug in numba (tested for version 0.55)
            sendobj[0]  # pylint: disable=pointless-statement

            return status

        return _impl_number

    if isinstance(sendobj, types.Array):
        # reduce an array

        def _impl_array(sendobj, recvobj, operator=Operator.SUM):
            sendobj = np.ascontiguousarray(sendobj)

            # NOTE(review): for non-contiguous input ascontiguousarray yields a
            # temporary copy, and unlike the scalar branch there is no
            # keep-alive after the call; if the scalar-branch issue is a
            # premature release of the temporary, it may apply here too --
            # confirm.
            status = _MPI_Allreduce(
                sendobj.ctypes.data,
                recvobj.ctypes.data,
                sendobj.size,
                _mpi_dtype(sendobj),
                _mpi_addr(operator),
                _mpi_addr(_MPI_Comm_World_ptr),
            )

            return status

        return _impl_array

    raise TypeError(f"Unsupported type {sendobj.__class__.__name__}")
Wrapper for MPI_Allreduce. Note that complex datatypes and user-defined functions are not properly supported. Returns an integer status code (0 == MPI_SUCCESS).