numba_mpi.api.reduce
This file contains the MPI_Reduce() implementations.
"""MPI_Reduce() wrappers callable from plain Python and from numba-jitted code."""

import ctypes
from numbers import Number

import numpy as np
from numba.core import types
from numba.extending import overload

from numba_mpi.common import _MPI_Comm_World_ptr, _MpiComm, _MpiDatatype, _MpiOp, libmpi
from numba_mpi.utils import _mpi_addr, _mpi_dtype

_MPI_Reduce = libmpi.MPI_Reduce
_MPI_Reduce.restype = ctypes.c_int
# Argument order follows the C prototype of MPI_Reduce.
_MPI_Reduce.argtypes = [
    ctypes.c_void_p,  # sendbuf
    ctypes.c_void_p,  # recvbuf
    ctypes.c_int,  # count
    _MpiDatatype,  # datatype
    _MpiOp,  # op
    ctypes.c_int,  # root
    _MpiComm,  # comm
]


def reduce(sendobj, recvobj, operator, root):
    """Call MPI_Reduce on MPI_COMM_WORLD from plain Python.

    A scalar ``sendobj`` (any ``numbers.Number``) is wrapped in a one-element
    NumPy array. An ndarray ``sendobj`` is made C-contiguous first.
    ``recvobj`` is passed through unchanged and must expose ``.ctypes.data``.
    Complex datatypes and user-defined operators are not properly supported.

    Returns the integer status code from MPI (0 == MPI_SUCCESS).
    Raises TypeError if ``sendobj`` is neither a Number nor an ndarray.
    """
    if isinstance(sendobj, Number):
        buffer = np.array([sendobj])
    elif isinstance(sendobj, np.ndarray):
        buffer = np.ascontiguousarray(sendobj)
    else:
        raise TypeError(f"Unsupported type {sendobj.__class__.__name__}")

    # The element count and datatype both come from the normalised buffer.
    return _MPI_Reduce(
        buffer.ctypes.data,
        recvobj.ctypes.data,
        buffer.size,
        _mpi_dtype(buffer),
        _mpi_addr(operator),
        root,
        _mpi_addr(_MPI_Comm_World_ptr),
    )


@overload(reduce)
def ol_reduce(sendobj, recvobj, operator, root):  # pylint: disable=unused-argument
    """numba overload of ``reduce`` (the MPI_Reduce wrapper) for jitted code.

    Selects an implementation from the numba type of ``sendobj``: one for
    scalars (``types.Number``) and one for arrays (``types.Array``). Any other
    type raises TypeError while typing.
    Complex datatypes and user-defined operators are not properly supported.
    The selected implementation returns the integer status code from MPI
    (0 == MPI_SUCCESS).
    """
    if isinstance(sendobj, types.Number):

        # numba requires these argument names to match the overloaded signature.
        def impl_scalar(sendobj, recvobj, operator, root):
            # MPI needs an address, so box the scalar in a one-element array.
            buffer = np.array([sendobj])
            status = _MPI_Reduce(
                buffer.ctypes.data,
                recvobj.ctypes.data,
                buffer.size,
                _mpi_dtype(buffer),
                _mpi_addr(operator),
                root,
                _mpi_addr(_MPI_Comm_World_ptr),
            )
            # Touching the buffer after the call is a deliberate no-op. Without
            # it, numba optimised too aggressively (apparently a numba bug,
            # observed with version 0.55).
            buffer[0]  # pylint: disable=pointless-statement
            return status

        return impl_scalar

    if isinstance(sendobj, types.Array):

        def impl_array(sendobj, recvobj, operator, root):
            # NOTE(review): this path has no post-call keep-alive, unlike the
            # scalar path. If ascontiguousarray makes a temporary copy, confirm
            # it cannot be released before MPI reads it.
            buffer = np.ascontiguousarray(sendobj)
            return _MPI_Reduce(
                buffer.ctypes.data,
                recvobj.ctypes.data,
                buffer.size,
                _mpi_dtype(buffer),
                _mpi_addr(operator),
                root,
                _mpi_addr(_MPI_Comm_World_ptr),
            )

        return impl_array

    raise TypeError(f"Unsupported type {sendobj.__class__.__name__}")
def
reduce(sendobj, recvobj, operator, root):
def reduce(sendobj, recvobj, operator, root):
    """Call MPI_Reduce on MPI_COMM_WORLD from plain Python.

    A scalar ``sendobj`` (any ``numbers.Number``) is wrapped in a one-element
    NumPy array. An ndarray ``sendobj`` is made C-contiguous first.
    ``recvobj`` is passed through unchanged and must expose ``.ctypes.data``.
    Complex datatypes and user-defined operators are not properly supported.

    Returns the integer status code from MPI (0 == MPI_SUCCESS).
    Raises TypeError if ``sendobj`` is neither a Number nor an ndarray.
    """
    if isinstance(sendobj, Number):
        buffer = np.array([sendobj])
    elif isinstance(sendobj, np.ndarray):
        buffer = np.ascontiguousarray(sendobj)
    else:
        raise TypeError(f"Unsupported type {sendobj.__class__.__name__}")

    # The element count and datatype both come from the normalised buffer.
    return _MPI_Reduce(
        buffer.ctypes.data,
        recvobj.ctypes.data,
        buffer.size,
        _mpi_dtype(buffer),
        _mpi_addr(operator),
        root,
        _mpi_addr(_MPI_Comm_World_ptr),
    )
Wrapper for MPI_Reduce. Note that complex datatypes and user-defined functions are not properly supported. Returns integer status code (0 == MPI_SUCCESS).
@overload(reduce)
def
ol_reduce(sendobj, recvobj, operator, root):
@overload(reduce)
def ol_reduce(sendobj, recvobj, operator, root):  # pylint: disable=unused-argument
    """numba overload of ``reduce`` (the MPI_Reduce wrapper) for jitted code.

    Selects an implementation from the numba type of ``sendobj``: one for
    scalars (``types.Number``) and one for arrays (``types.Array``). Any other
    type raises TypeError while typing.
    Complex datatypes and user-defined operators are not properly supported.
    The selected implementation returns the integer status code from MPI
    (0 == MPI_SUCCESS).
    """
    if isinstance(sendobj, types.Number):

        # numba requires these argument names to match the overloaded signature.
        def impl_scalar(sendobj, recvobj, operator, root):
            # MPI needs an address, so box the scalar in a one-element array.
            buffer = np.array([sendobj])
            status = _MPI_Reduce(
                buffer.ctypes.data,
                recvobj.ctypes.data,
                buffer.size,
                _mpi_dtype(buffer),
                _mpi_addr(operator),
                root,
                _mpi_addr(_MPI_Comm_World_ptr),
            )
            # Touching the buffer after the call is a deliberate no-op. Without
            # it, numba optimised too aggressively (apparently a numba bug,
            # observed with version 0.55).
            buffer[0]  # pylint: disable=pointless-statement
            return status

        return impl_scalar

    if isinstance(sendobj, types.Array):

        def impl_array(sendobj, recvobj, operator, root):
            # NOTE(review): this path has no post-call keep-alive, unlike the
            # scalar path. If ascontiguousarray makes a temporary copy, confirm
            # it cannot be released before MPI reads it.
            buffer = np.ascontiguousarray(sendobj)
            return _MPI_Reduce(
                buffer.ctypes.data,
                recvobj.ctypes.data,
                buffer.size,
                _mpi_dtype(buffer),
                _mpi_addr(operator),
                root,
                _mpi_addr(_MPI_Comm_World_ptr),
            )

        return impl_array

    raise TypeError(f"Unsupported type {sendobj.__class__.__name__}")
Wrapper for MPI_Reduce. Note that complex datatypes and user-defined functions are not properly supported. Returns integer status code (0 == MPI_SUCCESS).