numba_mpi.utils

Helper functions used across the API implementation.

 1"""helper functions used across API implementation"""
 2
 3import numba
 4import numpy as np
 5from numba.core import cgutils, types
 6
 7from .common import _MPI_DTYPES, _MpiComm, _MpiDatatype
 8
 9
def _mpi_dtype(arr):
    """Return the MPI datatype object matching ``arr.dtype`` (pure-Python path).

    The datatype address is looked up from the array's NumPy dtype. The
    memory at that address is then viewed as an ``_MpiDatatype`` via
    ``from_address``. Inside jitted code, ``_mpi_dtype_njit`` is used
    instead.

    Raises:
        NotImplementedError: if ``arr.dtype`` has no registered MPI datatype.
    """
    return _MpiDatatype.from_address(_get_dtype_numpy_to_mpi_ptr(arr))
13
14
@numba.extending.overload(_mpi_dtype)
def _mpi_dtype_njit(arr):
    """Numba overload of ``_mpi_dtype`` for use inside jitted functions.

    ``arr`` is the numba type of the argument here, not a value. The MPI
    datatype address is therefore resolved once, when the call is typed.
    The returned implementation reads the single ``np.intp`` stored at that
    address (presumably the MPI datatype handle -- confirm against
    ``_MPI_DTYPES``).
    """
    dtype_address = _get_dtype_numba_to_mpi_ptr(arr)

    # pylint: disable-next=unused-argument
    def impl(arr):
        # pylint: disable-next=no-value-for-parameter
        handle_ptr = _address_as_void_pointer(dtype_address)
        # View the address as a one-element intp array and read its only entry.
        handle_view = numba.carray(handle_ptr, shape=(1,), dtype=np.intp)
        return handle_view[0]

    return impl
29
30
def _get_dtype_numba_to_mpi_ptr(arr):
    """Map the numba dtype of ``arr`` to its MPI datatype address.

    Entries of ``_MPI_DTYPES`` are scanned in insertion order. The address
    of the first NumPy dtype whose numba equivalent equals ``arr.dtype`` is
    returned.

    Raises:
        NotImplementedError: if no registered dtype matches.
    """
    for candidate, address in _MPI_DTYPES.items():
        if arr.dtype == numba.from_dtype(candidate):
            break
    else:
        raise NotImplementedError(f"Type: {arr.dtype}")
    return address
36
37
def _get_dtype_numpy_to_mpi_ptr(arr):
    """Map the NumPy dtype of ``arr`` to its MPI datatype address.

    Entries of ``_MPI_DTYPES`` are scanned in insertion order. The first
    dtype that ``arr.dtype`` can be cast to with ``casting="equiv"`` wins;
    that mode only permits byte-order differences.

    Raises:
        NotImplementedError: if no registered dtype matches.
    """
    for candidate, address in _MPI_DTYPES.items():
        if np.can_cast(arr.dtype, candidate, casting="equiv"):
            break
    else:
        raise NotImplementedError(f"Type: {arr.dtype}")
    return address
43
44
def _mpi_addr(ptr):
    """View the memory at address ``ptr`` as an ``_MpiComm`` (pure-Python path).

    Inside jitted code, ``_mpi_addr_njit`` is used instead.
    """
    comm = _MpiComm.from_address(ptr)
    return comm
47
48
@numba.extending.overload(_mpi_addr)
def _mpi_addr_njit(ptr):
    """Numba overload of ``_mpi_addr`` for use inside jitted functions.

    The returned implementation reads the single ``np.intp`` stored at the
    memory address ``ptr``.
    """

    def impl(ptr):
        # pylint: disable-next=no-value-for-parameter
        raw_ptr = _address_as_void_pointer(ptr)
        # View the address as a one-element intp array and read its only entry.
        return numba.carray(raw_ptr, shape=(1,), dtype=np.intp)[0]

    return impl
60
61
# https://stackoverflow.com/questions/61509903/how-to-pass-array-pointer-to-numba-function
@numba.extending.intrinsic
def _address_as_void_pointer(_, src):
    """Return a ``void*`` built from the integer memory address ``src``.

    The first parameter is numba's typing context and is unused. The
    intrinsic is typed as taking ``src``'s type and returning ``voidptr``.
    """

    def codegen(__, builder, ___, args):
        # Reinterpret the integer address as an LLVM i8* (void pointer).
        address = args[0]
        return builder.inttoptr(address, cgutils.voidptr_t)

    return types.voidptr(src), codegen