numba_mpi.api.bcast

MPI_Bcast() implementation

 1"""MPI_Bcast() implementation"""
 2
 3import ctypes
 4
 5import numba
 6import numpy as np
 7from numba.core import types
 8from numba.core.extending import overload
 9
10from numba_mpi.common import _MPI_Comm_World_ptr, libmpi
11from numba_mpi.utils import _mpi_addr, _mpi_dtype, _MpiComm, _MpiDatatype
12
_MPI_Bcast = libmpi.MPI_Bcast
# ctypes prototype matching the C signature
#   int MPI_Bcast(void *buffer, int count, MPI_Datatype datatype,
#                 int root, MPI_Comm comm)
_MPI_Bcast.argtypes = [
    # pylint:disable=duplicate-code
    ctypes.c_void_p,  # buffer: raw data pointer of the array being broadcast
    ctypes.c_int,  # count: number of elements in the buffer
    _MpiDatatype,  # datatype: MPI datatype handle of one element
    ctypes.c_int,  # root: rank of the broadcasting process
    _MpiComm,  # comm: communicator handle
]
_MPI_Bcast.restype = ctypes.c_int  # MPI status code, 0 == MPI_SUCCESS
23
24
@numba.njit()
def impl_ndarray(data, root):
    """MPI_Bcast implementation for ndarray datatype.

    Broadcasts ``data.size`` elements of ``data`` from rank ``root`` over
    the MPI_COMM_WORLD communicator; ``data`` is the send buffer on ``root``
    and the receive buffer on every other rank.

    Parameters
    ----------
    data : numpy.ndarray
        C-contiguous array to broadcast (modified in place on non-root ranks).
    root : int
        Rank of the broadcasting process.

    Returns
    -------
    int
        Status code returned by MPI_Bcast (0 == MPI_SUCCESS).

    Raises
    ------
    AssertionError
        If ``data`` is not C-contiguous.
    """
    # MPI only receives a raw pointer plus an element count, so the array
    # must occupy one contiguous C-ordered block. An explicit raise (instead
    # of ``assert``) keeps this memory-safety check active under
    # ``python -O``; AssertionError is kept for backward compatibility.
    if not data.flags.c_contiguous:  # TODO #60
        raise AssertionError("bcast: data must be a C-contiguous array")

    status = _MPI_Bcast(
        data.ctypes.data,
        data.size,
        _mpi_dtype(data),
        root,
        _mpi_addr(_MPI_Comm_World_ptr),
    )
    return status
38
39
def impl_chararray(data, root):
    """MPI_Bcast implementation for chararray datatype.

    Runs as plain Python (it is not njit-compiled). The 1-byte character
    array is reinterpreted as ``uint8`` (a view, no copy) and broadcast from
    rank ``root`` over the MPI_COMM_WORLD communicator.

    Parameters
    ----------
    data : numpy.ndarray
        C-contiguous array of dtype ``S1`` (modified in place on non-root
        ranks, through the uint8 view).
    root : int
        Rank of the broadcasting process.

    Returns
    -------
    int
        Status code returned by MPI_Bcast (0 == MPI_SUCCESS).

    Raises
    ------
    AssertionError
        If ``data`` is not C-contiguous.
    """
    # Explicit raise instead of ``assert`` so the check survives
    # ``python -O``; a non-contiguous buffer would otherwise be handed to
    # MPI as if it were contiguous. AssertionError kept for compatibility.
    if not data.flags.c_contiguous:  # TODO #60
        raise AssertionError("bcast: data must be a C-contiguous array")
    # Reinterpret the characters as raw bytes; presumably _mpi_dtype has no
    # mapping for 'S1' itself — TODO confirm.
    data = data.view(np.uint8)

    status = _MPI_Bcast(
        data.ctypes.data,
        data.size,
        _mpi_dtype(data),
        root,
        _mpi_addr(_MPI_Comm_World_ptr),
    )
    return status
53
54
def bcast(data, root):
    """wrapper for MPI_Bcast(). Returns integer status code (0 == MPI_SUCCESS)

    Arrays of dtype ``S1`` are dispatched to impl_chararray; every other
    numpy.ndarray is dispatched to impl_ndarray.

    Parameters
    ----------
    data : numpy.ndarray
        Buffer to broadcast (send buffer on ``root``, receive buffer elsewhere).
    root : int
        Rank of the broadcasting process.

    Returns
    -------
    int
        Status code returned by MPI_Bcast (0 == MPI_SUCCESS).

    Raises
    ------
    TypeError
        If ``data`` is not a numpy.ndarray.
    """
    # Type-check before reading ``data.dtype``: objects without a ``dtype``
    # attribute (lists, Python scalars, None) would otherwise raise
    # AttributeError instead of the intended TypeError.
    if not isinstance(data, np.ndarray):
        raise TypeError(f"Unsupported type {data.__class__.__name__}")
    if data.dtype == np.dtype("S1"):
        return impl_chararray(data, root)
    return impl_ndarray(data, root)
63
64
@overload(bcast)
def __bcast_njit(data, root):
    """Numba overload of bcast(), selected when bcast() is called from jitted code.

    Only arguments typed as numba arrays are accepted; the returned
    implementation forwards to impl_ndarray, whose result is the integer
    status code of MPI_Bcast (0 == MPI_SUCCESS).

    Raises
    ------
    TypeError
        During numba typing, if ``data`` is not a numba Array type (numba may
        surface it wrapped in its own typing error).
    """
    # NOTE(review): unlike the pure-Python bcast(), there is no separate 'S1'
    # branch here — every array type goes to impl_ndarray.
    if not isinstance(data, types.Array):
        # ``data`` is a numba type object here, so the message names its class.
        raise TypeError(f"Unsupported type {data.__class__.__name__}")

    def impl(data, root):
        return impl_ndarray(data, root)

    return impl
@numba.njit()
def impl_ndarray(data, root):
    """MPI_Bcast implementation for ndarray datatype.

    Broadcasts ``data.size`` elements of ``data`` from rank ``root`` over
    the MPI_COMM_WORLD communicator; ``data`` is the send buffer on ``root``
    and the receive buffer on every other rank.

    Parameters
    ----------
    data : numpy.ndarray
        C-contiguous array to broadcast (modified in place on non-root ranks).
    root : int
        Rank of the broadcasting process.

    Returns
    -------
    int
        Status code returned by MPI_Bcast (0 == MPI_SUCCESS).

    Raises
    ------
    AssertionError
        If ``data`` is not C-contiguous.
    """
    # MPI only receives a raw pointer plus an element count, so the array
    # must occupy one contiguous C-ordered block. An explicit raise (instead
    # of ``assert``) keeps this memory-safety check active under
    # ``python -O``; AssertionError is kept for backward compatibility.
    if not data.flags.c_contiguous:  # TODO #60
        raise AssertionError("bcast: data must be a C-contiguous array")

    status = _MPI_Bcast(
        data.ctypes.data,
        data.size,
        _mpi_dtype(data),
        root,
        _mpi_addr(_MPI_Comm_World_ptr),
    )
    return status

MPI_Bcast implementation for ndarray datatype

def impl_chararray(data, root):
    """MPI_Bcast implementation for chararray datatype.

    Runs as plain Python (it is not njit-compiled). The 1-byte character
    array is reinterpreted as ``uint8`` (a view, no copy) and broadcast from
    rank ``root`` over the MPI_COMM_WORLD communicator.

    Parameters
    ----------
    data : numpy.ndarray
        C-contiguous array of dtype ``S1`` (modified in place on non-root
        ranks, through the uint8 view).
    root : int
        Rank of the broadcasting process.

    Returns
    -------
    int
        Status code returned by MPI_Bcast (0 == MPI_SUCCESS).

    Raises
    ------
    AssertionError
        If ``data`` is not C-contiguous.
    """
    # Explicit raise instead of ``assert`` so the check survives
    # ``python -O``; a non-contiguous buffer would otherwise be handed to
    # MPI as if it were contiguous. AssertionError kept for compatibility.
    if not data.flags.c_contiguous:  # TODO #60
        raise AssertionError("bcast: data must be a C-contiguous array")
    # Reinterpret the characters as raw bytes; presumably _mpi_dtype has no
    # mapping for 'S1' itself — TODO confirm.
    data = data.view(np.uint8)

    status = _MPI_Bcast(
        data.ctypes.data,
        data.size,
        _mpi_dtype(data),
        root,
        _mpi_addr(_MPI_Comm_World_ptr),
    )
    return status

MPI_Bcast implementation for chararray datatype

def bcast(data, root):
    """wrapper for MPI_Bcast(). Returns integer status code (0 == MPI_SUCCESS)

    Arrays of dtype ``S1`` are dispatched to impl_chararray; every other
    numpy.ndarray is dispatched to impl_ndarray.

    Parameters
    ----------
    data : numpy.ndarray
        Buffer to broadcast (send buffer on ``root``, receive buffer elsewhere).
    root : int
        Rank of the broadcasting process.

    Returns
    -------
    int
        Status code returned by MPI_Bcast (0 == MPI_SUCCESS).

    Raises
    ------
    TypeError
        If ``data`` is not a numpy.ndarray.
    """
    # Type-check before reading ``data.dtype``: objects without a ``dtype``
    # attribute (lists, Python scalars, None) would otherwise raise
    # AttributeError instead of the intended TypeError.
    if not isinstance(data, np.ndarray):
        raise TypeError(f"Unsupported type {data.__class__.__name__}")
    if data.dtype == np.dtype("S1"):
        return impl_chararray(data, root)
    return impl_ndarray(data, root)

wrapper for MPI_Bcast(). Returns integer status code (0 == MPI_SUCCESS)