Module numba_mpi.utils
helper functions used across API implementation
Expand source code
"""helper functions used across API implementation"""
import numba
import numpy as np
from numba.core import cgutils, types
from .common import _MPI_DTYPES, _MpiComm, _MpiDatatype
def _mpi_dtype(arr):
    """Pure-Python path: wrap the MPI datatype address matching ``arr.dtype``.

    The address is looked up with ``_get_dtype_numpy_to_mpi_ptr`` and handed
    to ``_MpiDatatype.from_address``; lookup failures propagate unchanged.
    """
    return _MpiDatatype.from_address(_get_dtype_numpy_to_mpi_ptr(arr))
@numba.extending.overload(_mpi_dtype)
def _mpi_dtype_njit(arr):
    """Numba overload of ``_mpi_dtype``.

    The address for ``arr``'s dtype is resolved here, outside the compiled
    implementation, and captured by the returned ``impl`` closure.
    """
    dtype_address = _get_dtype_numba_to_mpi_ptr(arr)

    # pylint: disable-next=unused-argument
    def impl(arr):
        # Reinterpret the captured integer address as a void pointer ...
        # pylint: disable-next=no-value-for-parameter
        void_ptr = _address_as_void_pointer(dtype_address)
        # ... and read the single intp-sized value stored there.
        view = numba.carray(void_ptr, shape=(1,), dtype=np.intp)
        return view[0]

    return impl
def _get_dtype_numba_to_mpi_ptr(arr):
    """Return the ``_MPI_DTYPES`` entry matching the Numba dtype of ``arr``.

    Each NumPy dtype key is converted with ``numba.from_dtype`` and compared
    against ``arr.dtype``; the value of the first matching key is returned.

    Raises:
        NotImplementedError: if no key of ``_MPI_DTYPES`` matches.
    """
    not_found = object()
    address = next(
        (
            mpi_address
            for dtype_key, mpi_address in _MPI_DTYPES.items()
            if arr.dtype == numba.from_dtype(dtype_key)
        ),
        not_found,
    )
    if address is not_found:
        raise NotImplementedError(f"Type: {arr.dtype}")
    return address
def _get_dtype_numpy_to_mpi_ptr(arr):
    """Return the ``_MPI_DTYPES`` entry matching the NumPy dtype of ``arr``.

    A key matches when ``np.can_cast`` accepts the conversion from
    ``arr.dtype`` under ``casting="equiv"``; the value of the first matching
    key is returned.

    Raises:
        NotImplementedError: if no key of ``_MPI_DTYPES`` matches.
    """
    for candidate, address in _MPI_DTYPES.items():
        if np.can_cast(arr.dtype, candidate, casting="equiv"):
            break
    else:
        # Loop finished without a break: nothing in the table matched.
        raise NotImplementedError(f"Type: {arr.dtype}")
    return address
def _mpi_addr(ptr):
    """Pure-Python path: wrap the memory address ``ptr`` in an ``_MpiComm``."""
    communicator = _MpiComm.from_address(ptr)
    return communicator
@numba.extending.overload(_mpi_addr)
def _mpi_addr_njit(ptr):
    """Numba overload of ``_mpi_addr``.

    The returned implementation reads one ``np.intp`` value from the memory
    address given by ``ptr``.
    """

    def impl(ptr):
        # Reinterpret the integer address as a void pointer ...
        # pylint: disable-next=no-value-for-parameter
        void_ptr = _address_as_void_pointer(ptr)
        # ... and read the single intp-sized value stored there.
        view = numba.carray(void_ptr, shape=(1,), dtype=np.intp)
        return view[0]

    return impl
# https://stackoverflow.com/questions/61509903/how-to-pass-array-pointer-to-numba-function
@numba.extending.intrinsic
def _address_as_void_pointer(_, src):
    """Numba intrinsic turning an integer memory address into a void pointer.

    Returns the ``(signature, codegen)`` pair expected by
    ``numba.extending.intrinsic``: the signature maps the type of ``src`` to
    ``types.voidptr`` and the code generator emits an ``inttoptr`` cast.
    """

    def emit_inttoptr(_context, builder, _signature, args):
        address = args[0]
        return builder.inttoptr(address, cgutils.voidptr_t)

    return types.voidptr(src), emit_inttoptr