numba_mpi.api.scatter_gather
MPI_Scatter(), MPI_Gather() and MPI_Allgather() implementations
"""MPI_Scatter(), MPI_Gather() and MPI_Allgather() wrappers callable from numba-jitted code"""

import ctypes

import numba

from numba_mpi.common import _MPI_Comm_World_ptr, libmpi
from numba_mpi.utils import _mpi_addr, _mpi_dtype, _MpiComm, _MpiDatatype

_MPI_Scatter = libmpi.MPI_Scatter
_MPI_Scatter.restype = ctypes.c_int
_MPI_Scatter.argtypes = [
    # pylint:disable=duplicate-code
    ctypes.c_void_p,  # send_data
    ctypes.c_int,  # send_count
    _MpiDatatype,  # send_data_type
    ctypes.c_void_p,  # recv_data
    ctypes.c_int,  # recv_count
    _MpiDatatype,  # recv_data_type
    ctypes.c_int,  # root
    _MpiComm,  # communicator
]

# MPI_Gather() has exactly the same C signature as MPI_Scatter()
_MPI_Gather = libmpi.MPI_Gather
_MPI_Gather.restype = ctypes.c_int
_MPI_Gather.argtypes = _MPI_Scatter.argtypes

# MPI_Allgather() takes the same arguments minus `root` (the second-to-last entry)
_MPI_Allgather = libmpi.MPI_Allgather
_MPI_Allgather.restype = ctypes.c_int
_MPI_Allgather.argtypes = [*_MPI_Scatter.argtypes[:-2], _MPI_Scatter.argtypes[-1]]


@numba.njit()
def scatter(send_data, recv_data, count, root):
    """wrapper for MPI_Scatter(). Returns integer status code (0 == MPI_SUCCESS)

    Args:
        send_data: C-contiguous array split among the processes; per the MPI
            standard its contents only matter on rank `root`.
        recv_data: C-contiguous array receiving this rank's chunk; its `size`
            is passed to MPI as the receive count.
        count: number of elements sent to each process (the MPI send count).
        root: rank of the sending process.
    """
    # buffers are handed to MPI as raw pointers, so only C-contiguous arrays are accepted
    assert send_data.flags.c_contiguous  # TODO #60
    assert recv_data.flags.c_contiguous  # TODO #60

    status = _MPI_Scatter(
        send_data.ctypes.data,
        count,
        _mpi_dtype(send_data),
        recv_data.ctypes.data,
        recv_data.size,
        _mpi_dtype(recv_data),
        root,
        _mpi_addr(_MPI_Comm_World_ptr),
    )
    return status


@numba.njit()
def gather(send_data, recv_data, count, root):
    """wrapper for MPI_Gather(). Returns integer status code (0 == MPI_SUCCESS)

    Args:
        send_data: C-contiguous array sent by this rank; its `size` is passed
            to MPI as the send count.
        recv_data: C-contiguous array collecting the chunks of all ranks; per
            the MPI standard it is only significant on rank `root`.
        count: number of elements received from each process (the MPI
            receive count).
        root: rank of the receiving process.
    """
    # buffers are handed to MPI as raw pointers, so only C-contiguous arrays are accepted
    assert send_data.flags.c_contiguous  # TODO #60
    assert recv_data.flags.c_contiguous  # TODO #60

    status = _MPI_Gather(
        send_data.ctypes.data,
        send_data.size,
        _mpi_dtype(send_data),
        recv_data.ctypes.data,
        count,
        _mpi_dtype(recv_data),
        root,
        _mpi_addr(_MPI_Comm_World_ptr),
    )
    return status


@numba.njit()
def allgather(send_data, recv_data, count):
    """wrapper for MPI_Allgather(). Returns integer status code (0 == MPI_SUCCESS)

    Args:
        send_data: C-contiguous array sent by this rank; its `size` is passed
            to MPI as the send count.
        recv_data: C-contiguous array that receives the chunks of all ranks
            (on every rank, as there is no root).
        count: number of elements received from each process (the MPI
            receive count).
    """
    # buffers are handed to MPI as raw pointers, so only C-contiguous arrays are accepted
    assert send_data.flags.c_contiguous  # TODO #60
    assert recv_data.flags.c_contiguous  # TODO #60

    status = _MPI_Allgather(
        send_data.ctypes.data,
        send_data.size,
        _mpi_dtype(send_data),
        recv_data.ctypes.data,
        count,
        _mpi_dtype(recv_data),
        _mpi_addr(_MPI_Comm_World_ptr),
    )
    return status
@numba.njit()
def scatter(send_data, recv_data, count, root):
@numba.njit()
def scatter(send_data, recv_data, count, root):
    """Call MPI_Scatter() on the world communicator.

    `count` elements are sent to each rank, and `recv_data.size` is used as
    the receive count. Returns the integer MPI status code
    (0 == MPI_SUCCESS).
    """
    # MPI receives raw pointers, so both buffers have to be C-contiguous
    assert send_data.flags.c_contiguous  # TODO #60
    assert recv_data.flags.c_contiguous  # TODO #60

    return _MPI_Scatter(
        send_data.ctypes.data, count, _mpi_dtype(send_data),  # send side
        recv_data.ctypes.data, recv_data.size, _mpi_dtype(recv_data),  # receive side
        root,
        _mpi_addr(_MPI_Comm_World_ptr),
    )
wrapper for MPI_Scatter(). Returns integer status code (0 == MPI_SUCCESS)
@numba.njit()
def gather(send_data, recv_data, count, root):
@numba.njit()
def gather(send_data, recv_data, count, root):
    """Call MPI_Gather() on the world communicator.

    The whole of `send_data` (its `size` is the send count) is contributed by
    this rank, and `count` elements per rank are collected into `recv_data`
    on `root`. Returns the integer MPI status code (0 == MPI_SUCCESS).
    """
    # MPI receives raw pointers, so both buffers have to be C-contiguous
    assert send_data.flags.c_contiguous  # TODO #60
    assert recv_data.flags.c_contiguous  # TODO #60

    return _MPI_Gather(
        send_data.ctypes.data, send_data.size, _mpi_dtype(send_data),  # send side
        recv_data.ctypes.data, count, _mpi_dtype(recv_data),  # receive side
        root,
        _mpi_addr(_MPI_Comm_World_ptr),
    )
wrapper for MPI_Gather(). Returns integer status code (0 == MPI_SUCCESS)
@numba.njit()
def allgather(send_data, recv_data, count):
@numba.njit()
def allgather(send_data, recv_data, count):
    """Call MPI_Allgather() on the world communicator.

    Each rank contributes all of `send_data` (its `size` is the send count)
    and receives `count` elements from every rank into `recv_data`.
    Returns the integer MPI status code (0 == MPI_SUCCESS).
    """
    # MPI receives raw pointers, so both buffers have to be C-contiguous
    assert send_data.flags.c_contiguous  # TODO #60
    assert recv_data.flags.c_contiguous  # TODO #60

    return _MPI_Allgather(
        send_data.ctypes.data, send_data.size, _mpi_dtype(send_data),  # send side
        recv_data.ctypes.data, count, _mpi_dtype(recv_data),  # receive side
        _mpi_addr(_MPI_Comm_World_ptr),
    )
wrapper for MPI_Allgather(). Returns integer status code (0 == MPI_SUCCESS)