numba_mpi.api.bcast
MPI_Bcast() implementation
"""MPI_Bcast() implementation"""

import ctypes

import numba
import numpy as np
from numba.core import types
from numba.core.extending import overload

from numba_mpi.common import _MPI_Comm_World_ptr, libmpi
from numba_mpi.utils import _mpi_addr, _mpi_dtype, _MpiComm, _MpiDatatype

# ctypes binding for the C function MPI_Bcast(buffer, count, datatype, root, comm).
_MPI_Bcast = libmpi.MPI_Bcast
_MPI_Bcast.restype = ctypes.c_int
_MPI_Bcast.argtypes = [
    # pylint:disable=duplicate-code
    ctypes.c_void_p,  # buffer pointer
    ctypes.c_int,  # element count
    _MpiDatatype,  # MPI datatype handle
    ctypes.c_int,  # root rank
    _MpiComm,  # communicator handle
]


@numba.njit()
def impl_ndarray(data, root):
    """MPI_Bcast implementation for ndarray datatype.

    Passes the buffer of ``data`` (its element count and the MPI datatype
    derived from it) to MPI_Bcast on the world communicator and returns the
    integer status code unchanged.
    """
    assert data.flags.c_contiguous  # TODO #60

    status = _MPI_Bcast(
        data.ctypes.data,
        data.size,
        _mpi_dtype(data),
        root,
        _mpi_addr(_MPI_Comm_World_ptr),
    )
    return status


def impl_chararray(data, root):
    """MPI_Bcast implementation for chararray datatype.

    Runs in plain Python (not jitted). The array is reinterpreted as
    ``uint8`` bytes before its buffer is passed to MPI_Bcast on the world
    communicator; the integer status code is returned unchanged.
    """
    assert data.flags.c_contiguous  # TODO #60
    data = data.view(np.uint8)

    status = _MPI_Bcast(
        data.ctypes.data,
        data.size,
        _mpi_dtype(data),
        root,
        _mpi_addr(_MPI_Comm_World_ptr),
    )
    return status


def bcast(data, root):
    """Wrapper for MPI_Bcast(). Returns integer status code (0 == MPI_SUCCESS).

    Arrays of dtype ``S1`` are dispatched to ``impl_chararray``; every other
    ndarray goes to ``impl_ndarray``.

    Raises:
        TypeError: if ``data`` is not a ``numpy.ndarray``.
    """
    # Check the container type first: anything that is not an ndarray
    # (lists, ints, None, numpy scalars) must fail with TypeError rather
    # than AttributeError on ``.dtype`` or a misroute to impl_chararray.
    if not isinstance(data, np.ndarray):
        raise TypeError(f"Unsupported type {data.__class__.__name__}")
    if data.dtype == np.dtype("S1"):
        return impl_chararray(data, root)
    return impl_ndarray(data, root)


@overload(bcast)
def __bcast_njit(data, root):
    """wrapper for MPI_Bcast(). Returns integer status code (0 == MPI_SUCCESS)"""
    # Inside jitted code only numba Array types are supported; they are
    # forwarded to the jitted ndarray implementation.
    if isinstance(data, types.Array):

        def impl(data, root):
            return impl_ndarray(data, root)

    else:
        raise TypeError(f"Unsupported type {data.__class__.__name__}")

    return impl
@numba.njit()
def impl_ndarray(data, root):
    """Broadcast a C-contiguous ndarray from rank ``root`` via MPI_Bcast.

    The buffer address, element count and MPI datatype are all taken from
    ``data``; the world communicator is used and MPI_Bcast's integer status
    code is returned as-is.
    """
    assert data.flags.c_contiguous  # TODO #60

    buffer_ptr = data.ctypes.data
    n_items = data.size
    mpi_type = _mpi_dtype(data)
    world = _mpi_addr(_MPI_Comm_World_ptr)
    return _MPI_Bcast(buffer_ptr, n_items, mpi_type, root, world)
MPI_Bcast implementation for ndarray datatype
def impl_chararray(data, root):
    """Broadcast a C-contiguous character array from rank ``root``.

    Not jitted: the array is reinterpreted as raw ``uint8`` bytes and that
    byte buffer is handed to MPI_Bcast on the world communicator. The integer
    status code from MPI_Bcast is returned unchanged.
    """
    assert data.flags.c_contiguous  # TODO #60
    as_bytes = data.view(np.uint8)

    buffer_ptr = as_bytes.ctypes.data
    n_bytes = as_bytes.size
    mpi_type = _mpi_dtype(as_bytes)
    world = _mpi_addr(_MPI_Comm_World_ptr)
    return _MPI_Bcast(buffer_ptr, n_bytes, mpi_type, root, world)
MPI_Bcast implementation for chararray datatype
def bcast(data, root):
    """Wrapper for MPI_Bcast(). Returns integer status code (0 == MPI_SUCCESS).

    Arrays of dtype ``S1`` are dispatched to ``impl_chararray``; every other
    ndarray goes to ``impl_ndarray``.

    Raises:
        TypeError: if ``data`` is not a ``numpy.ndarray``.
    """
    # Check the container type first: anything that is not an ndarray
    # (lists, ints, None, numpy scalars) must fail with TypeError rather
    # than AttributeError on ``.dtype`` or a misroute to impl_chararray.
    if not isinstance(data, np.ndarray):
        raise TypeError(f"Unsupported type {data.__class__.__name__}")
    if data.dtype == np.dtype("S1"):
        return impl_chararray(data, root)
    return impl_ndarray(data, root)
wrapper for MPI_Bcast(). Returns integer status code (0 == MPI_SUCCESS)