numba_mpi.utils
helper functions used across API implementation
"""helper functions used across API implementation"""

import numba
import numpy as np
from numba.core import cgutils, types

from .common import _MPI_DTYPES, _MpiComm, _MpiDatatype


def _lookup_mpi_dtype_ptr(arr, dtype_matches):
    """Return the ``_MPI_DTYPES`` value of the first key accepted by a predicate.

    ``_MPI_DTYPES`` is keyed by NumPy dtypes; its values are the addresses
    that ``_mpi_dtype`` hands to ``_MpiDatatype.from_address``. Keys are
    visited in the mapping's iteration order and ``dtype_matches(np_dtype)``
    is asked about each one; the value belonging to the first accepted key
    is returned.

    Raises:
        NotImplementedError: if no key is accepted; the message names
            ``arr.dtype``.
    """
    for np_dtype, mpi_ptr in _MPI_DTYPES.items():
        if dtype_matches(np_dtype):
            return mpi_ptr
    raise NotImplementedError(f"Type: {arr.dtype}")


def _get_dtype_numba_to_mpi_ptr(arr):
    """Resolve the MPI datatype address for a Numba-typed array ``arr``.

    Called from the ``_mpi_dtype`` overload while Numba types the call, so
    ``arr.dtype`` is expected to be a Numba type: each NumPy key is converted
    with ``numba.from_dtype`` and compared for equality.
    """

    def same_numba_type(np_dtype):
        return arr.dtype == numba.from_dtype(np_dtype)

    return _lookup_mpi_dtype_ptr(arr, same_numba_type)


def _get_dtype_numpy_to_mpi_ptr(arr):
    """Resolve the MPI datatype address for a NumPy array ``arr``.

    A key matches when ``np.can_cast(arr.dtype, key, casting="equiv")`` holds.
    NOTE(review): per NumPy's docs, "equiv" also accepts dtypes that differ
    only in byte order -- confirm that non-native-endian arrays are meant to
    map to the native MPI datatype.
    """

    def equivalent_dtype(np_dtype):
        return np.can_cast(arr.dtype, np_dtype, casting="equiv")

    return _lookup_mpi_dtype_ptr(arr, equivalent_dtype)


def _mpi_dtype(arr):
    """Return the ``_MpiDatatype`` handle matching the dtype of NumPy array ``arr``.

    This is the pure-Python path; jitted callers are served by the overload
    registered right below.
    """
    return _MpiDatatype.from_address(_get_dtype_numpy_to_mpi_ptr(arr))


@numba.extending.overload(_mpi_dtype)
def _mpi_dtype_njit(arr):
    """Numba overload of ``_mpi_dtype``.

    The datatype address is looked up once, from the Numba type of ``arr``,
    when the call is typed; the compiled implementation then just reads the
    ``np.intp`` stored at that address.
    """
    dtype_address = _get_dtype_numba_to_mpi_ptr(arr)

    # pylint: disable-next=unused-argument
    def impl(arr):
        # pylint: disable-next=no-value-for-parameter
        raw = _address_as_void_pointer(dtype_address)
        return numba.carray(raw, shape=(1,), dtype=np.intp)[0]

    return impl


def _mpi_addr(ptr):
    """Return ``_MpiComm.from_address(ptr)``: the communicator handle at ``ptr``.

    Jitted callers are served by the overload registered right below, which
    yields the raw ``np.intp`` value stored at ``ptr`` instead of an
    ``_MpiComm`` object.
    """
    return _MpiComm.from_address(ptr)


@numba.extending.overload(_mpi_addr)
def _mpi_addr_njit(ptr):
    """Numba overload of ``_mpi_addr``: dereference ``ptr`` as one ``np.intp``."""

    def impl(ptr):
        # pylint: disable-next=no-value-for-parameter
        raw = _address_as_void_pointer(ptr)
        return numba.carray(raw, shape=(1,), dtype=np.intp)[0]

    return impl


# https://stackoverflow.com/questions/61509903/how-to-pass-array-pointer-to-numba-function
@numba.extending.intrinsic
def _address_as_void_pointer(_typingctx, src):
    """returns a void pointer from a given memory address

    The call is typed as ``voidptr(src)``; the generated code is a single
    LLVM ``inttoptr`` applied to the integer argument.
    """

    def codegen(_context, builder, _signature, args):
        (address,) = args
        return builder.inttoptr(address, cgutils.voidptr_t)

    return types.voidptr(src), codegen