Module numba_mpi.common

Variables used across the API implementation.

Expand source code
"""variables used across API implementation"""

import ctypes
import os
from ctypes.util import find_library
from pathlib import Path

import numba
import numpy as np
import psutil
from mpi4py import MPI

# pylint: disable=protected-access
# Address of mpi4py's COMM_WORLD handle, for passing to the C API via ctypes.
_MPI_Comm_World_ptr = MPI._addressof(MPI.COMM_WORLD)

# Maps NumPy dtypes to the address of the corresponding MPI datatype handle.
# Keys use explicit bit widths on purpose: np.dtype("float") is float64 (the
# same key as np.dtype("double")), so spelling it "float" made the MPI.FLOAT
# entry silently vanish from the dict and left float32 unmapped.
_MPI_DTYPES = {
    # NOTE(review): MPI_CHAR is not a valid type for arithmetic reductions;
    # MPI.UINT8_T / MPI.UNSIGNED_CHAR may be more appropriate -- confirm.
    np.dtype("uint8"): MPI._addressof(MPI.CHAR),
    np.dtype("int32"): MPI._addressof(MPI.INT32_T),
    np.dtype("int64"): MPI._addressof(MPI.INT64_T),
    np.dtype("float32"): MPI._addressof(MPI.FLOAT),
    np.dtype("float64"): MPI._addressof(MPI.DOUBLE),
    np.dtype("complex64"): MPI._addressof(MPI.C_FLOAT_COMPLEX),
    np.dtype("complex128"): MPI._addressof(MPI.C_DOUBLE_COMPLEX),
}

# MPI handle types are implementation-defined: MPICH-family libraries use
# plain ``int`` handles, Open MPI uses pointers. Choose the ctypes type for
# each handle kind by comparing its size (as reported by mpi4py) with C int.
if MPI._sizeof(MPI.Comm) == ctypes.sizeof(ctypes.c_int):
    _MpiComm = ctypes.c_int
else:
    _MpiComm = ctypes.c_void_p

if MPI._sizeof(MPI.Datatype) == ctypes.sizeof(ctypes.c_int):
    _MpiDatatype = ctypes.c_int
else:
    _MpiDatatype = ctypes.c_void_p

# MPI_Op is checked on its own size rather than assumed to match MPI_Datatype.
if MPI._sizeof(MPI.Op) == ctypes.sizeof(ctypes.c_int):
    _MpiOp = ctypes.c_int
else:
    _MpiOp = ctypes.c_void_p

# NumPy dtype for arrays holding MPI_Request handles.
if MPI._sizeof(MPI.Request) == ctypes.sizeof(ctypes.c_int):
    RequestType = np.intc
else:
    RequestType = np.uintp

# pylint: enable=protected-access
# Status and request arguments are passed to the C API as untyped pointers.
_MpiStatusPtr = ctypes.c_void_p
_MpiRequestPtr = ctypes.c_void_p


# pylint: disable=protected-access
# Number of C ints needed to hold one MPI_Status. The struct layout is
# implementation-defined (Open MPI's is larger than MPICH's 5 ints), so it is
# derived from mpi4py (rounded up) instead of hard-coded; it never drops below
# the 5 ints the buffer historically used.
_MPI_STATUS_INTS = max(
    5, -(-MPI._sizeof(MPI.Status) // ctypes.sizeof(ctypes.c_int))
)
# pylint: enable=protected-access


# TODO: add proper handling of status objects
@numba.njit
def create_status_buffer(count=1):
    """Create an uninitialised buffer with room for ``count`` MPI_Status structs.

    Returns a flat ``np.intc`` array of ``count * _MPI_STATUS_INTS`` elements,
    large enough for each status under the loaded MPI library's layout
    (presumably passed to MPI calls as an ``MPI_Status *``).
    """
    return np.empty(count * _MPI_STATUS_INTS, dtype=np.intc)


# Locate the MPI shared library to bind with ctypes. Where psutil can list
# this process's memory maps, use a library that is already mapped (mpi4py.MPI
# is imported above), so ctypes binds to the library actually loaded in this
# process; otherwise fall back to ctypes.util.find_library.
LIB = None
names = ("mpich", "mpi", "msmpi", "impi")

ps = psutil.Process(os.getpid())
windows = os.name == "nt"
if hasattr(ps, "memory_maps"):
    for dll in ps.memory_maps():
        path = Path(dll.path)
        # On POSIX only "lib*" files whose stem contains "<name>." qualify
        # (e.g. "libmpi.so.40" has stem "libmpi.so"); on Windows any stem
        # containing one of the names does.
        if not (windows or path.stem.startswith("lib")):
            continue
        if any(name + ("" if windows else ".") in path.stem for name in names):
            LIB = str(path)
            # Stop at the first match. The previous nested ``break`` only left
            # the inner loop, so later mappings overwrote LIB.
            break
else:
    for name in names:
        LIB = find_library(name)
        if LIB is not None:
            break

if LIB is None:
    raise RuntimeError("no MPI library found")

libmpi = ctypes.CDLL(LIB)

# ctypes ``argtypes`` shared by point-to-point wrappers. The labels below
# assume the MPI_Send / MPI_Recv parameter order
# (buf, count, datatype, dest/source, tag, comm) -- confirm at call sites.
send_recv_args = list(
    (
        ctypes.c_void_p,  # data buffer
        ctypes.c_int,  # element count
        _MpiDatatype,  # datatype handle
        ctypes.c_int,  # peer rank
        ctypes.c_int,  # message tag
        _MpiComm,  # communicator handle
    )
)

# Same arguments followed by an MPI_Request output pointer (a fresh list).
send_recv_async_args = [*send_recv_args, _MpiRequestPtr]

Functions

def create_status_buffer(count=1)

Helper function for creating a NumPy array that serves as storage for MPI_Status results.

Expand source code
# pylint: disable=protected-access
# Number of C ints needed to hold one MPI_Status. The struct layout is
# implementation-defined (Open MPI's is larger than MPICH's 5 ints), so it is
# derived from mpi4py (rounded up) instead of hard-coded; it never drops below
# the 5 ints the buffer historically used.
_MPI_STATUS_INTS = max(
    5, -(-MPI._sizeof(MPI.Status) // ctypes.sizeof(ctypes.c_int))
)
# pylint: enable=protected-access


@numba.njit
def create_status_buffer(count=1):
    """Create an uninitialised buffer with room for ``count`` MPI_Status structs.

    Returns a flat ``np.intc`` array of ``count * _MPI_STATUS_INTS`` elements,
    large enough for each status under the loaded MPI library's layout
    (presumably passed to MPI calls as an ``MPI_Status *``).
    """
    return np.empty(count * _MPI_STATUS_INTS, dtype=np.intc)