numba_mpi.api.scatter_gather

MPI_Scatter(), MPI_Gather() and MPI_Allgather() wrappers

 1"""MPI_Scatter() implementation"""
 2
 3import ctypes
 4
 5import numba
 6
 7from numba_mpi.common import _MPI_Comm_World_ptr, libmpi
 8from numba_mpi.utils import _mpi_addr, _mpi_dtype, _MpiComm, _MpiDatatype
 9
# ctypes prototypes for the three MPI collectives wrapped in this module;
# each foreign function is declared to return a C ``int`` status code.
_MPI_Scatter = libmpi.MPI_Scatter
_MPI_Scatter.argtypes = [
    # pylint:disable=duplicate-code
    ctypes.c_void_p, ctypes.c_int, _MpiDatatype,  # send buffer, count, datatype
    ctypes.c_void_p, ctypes.c_int, _MpiDatatype,  # recv buffer, count, datatype
    ctypes.c_int,  # root
    _MpiComm,  # communicator
]
_MPI_Scatter.restype = ctypes.c_int

# MPI_Gather() reuses MPI_Scatter()'s parameter list verbatim.
_MPI_Gather = libmpi.MPI_Gather
_MPI_Gather.argtypes = _MPI_Scatter.argtypes
_MPI_Gather.restype = ctypes.c_int

# MPI_Allgather() takes the same parameters except ``root`` (position 6).
_MPI_Allgather = libmpi.MPI_Allgather
_MPI_Allgather.argtypes = [
    *_MPI_Scatter.argtypes[:6],  # send/recv buffers, counts and datatypes
    _MPI_Scatter.argtypes[7],  # communicator
]
_MPI_Allgather.restype = ctypes.c_int


@numba.njit()
def scatter(send_data, recv_data, count, root):
    """Wrapper for MPI_Scatter(). Returns the integer status code (0 == MPI_SUCCESS).

    ``count`` is forwarded as the send count, and the receive count is taken
    from ``recv_data.size``. ``root`` is passed through unchanged. The call
    uses the communicator referenced by ``_MPI_Comm_World_ptr``.
    """
    # MPI receives raw data pointers, so both buffers must be C-contiguous.
    assert send_data.flags.c_contiguous  # TODO #60
    assert recv_data.flags.c_contiguous  # TODO #60

    send_buf = send_data.ctypes.data
    send_type = _mpi_dtype(send_data)
    recv_buf = recv_data.ctypes.data
    recv_type = _mpi_dtype(recv_data)
    comm = _mpi_addr(_MPI_Comm_World_ptr)

    return _MPI_Scatter(
        send_buf, count, send_type, recv_buf, recv_data.size, recv_type, root, comm
    )


@numba.njit()
def gather(send_data, recv_data, count, root):
    """Wrapper for MPI_Gather(). Returns the integer status code (0 == MPI_SUCCESS).

    The send count is taken from ``send_data.size``, and ``count`` is
    forwarded as the receive count. ``root`` is passed through unchanged.
    The call uses the communicator referenced by ``_MPI_Comm_World_ptr``.
    """
    # MPI receives raw data pointers, so both buffers must be C-contiguous.
    assert send_data.flags.c_contiguous  # TODO #60
    assert recv_data.flags.c_contiguous  # TODO #60

    send_buf = send_data.ctypes.data
    send_type = _mpi_dtype(send_data)
    recv_buf = recv_data.ctypes.data
    recv_type = _mpi_dtype(recv_data)
    comm = _mpi_addr(_MPI_Comm_World_ptr)

    return _MPI_Gather(
        send_buf, send_data.size, send_type, recv_buf, count, recv_type, root, comm
    )


@numba.njit()
def allgather(send_data, recv_data, count):
    """Wrapper for MPI_Allgather(). Returns the integer status code (0 == MPI_SUCCESS).

    The send count is taken from ``send_data.size``, and ``count`` is
    forwarded as the receive count. There is no ``root``. The call uses the
    communicator referenced by ``_MPI_Comm_World_ptr``.
    """
    # MPI receives raw data pointers, so both buffers must be C-contiguous.
    assert send_data.flags.c_contiguous  # TODO #60
    assert recv_data.flags.c_contiguous  # TODO #60

    send_buf = send_data.ctypes.data
    send_type = _mpi_dtype(send_data)
    recv_buf = recv_data.ctypes.data
    recv_type = _mpi_dtype(recv_data)
    comm = _mpi_addr(_MPI_Comm_World_ptr)

    return _MPI_Allgather(
        send_buf, send_data.size, send_type, recv_buf, count, recv_type, comm
    )
@numba.njit()
def scatter(send_data, recv_data, count, root):
    """Call MPI_Scatter() and return its integer status (0 == MPI_SUCCESS).

    ``count`` is passed as the send count and ``recv_data.size`` as the
    receive count. The communicator is the one referenced by
    ``_MPI_Comm_World_ptr``.
    """
    # Both arrays are handed to MPI as raw pointers; require C-contiguity.
    assert send_data.flags.c_contiguous  # TODO #60
    assert recv_data.flags.c_contiguous  # TODO #60

    return _MPI_Scatter(
        send_data.ctypes.data,  # send buffer
        count,  # send count
        _mpi_dtype(send_data),  # send datatype
        recv_data.ctypes.data,  # receive buffer
        recv_data.size,  # receive count
        _mpi_dtype(recv_data),  # receive datatype
        root,
        _mpi_addr(_MPI_Comm_World_ptr),  # communicator
    )

Wrapper for MPI_Scatter(). Returns an integer status code (0 == MPI_SUCCESS).

@numba.njit()
def gather(send_data, recv_data, count, root):
    """Call MPI_Gather() and return its integer status (0 == MPI_SUCCESS).

    ``send_data.size`` is passed as the send count and ``count`` as the
    receive count. The communicator is the one referenced by
    ``_MPI_Comm_World_ptr``.
    """
    # Both arrays are handed to MPI as raw pointers; require C-contiguity.
    assert send_data.flags.c_contiguous  # TODO #60
    assert recv_data.flags.c_contiguous  # TODO #60

    return _MPI_Gather(
        send_data.ctypes.data,  # send buffer
        send_data.size,  # send count
        _mpi_dtype(send_data),  # send datatype
        recv_data.ctypes.data,  # receive buffer
        count,  # receive count
        _mpi_dtype(recv_data),  # receive datatype
        root,
        _mpi_addr(_MPI_Comm_World_ptr),  # communicator
    )

Wrapper for MPI_Gather(). Returns an integer status code (0 == MPI_SUCCESS).

@numba.njit()
def allgather(send_data, recv_data, count):
    """Call MPI_Allgather() and return its integer status (0 == MPI_SUCCESS).

    ``send_data.size`` is passed as the send count and ``count`` as the
    receive count. No root rank is involved. The communicator is the one
    referenced by ``_MPI_Comm_World_ptr``.
    """
    # Both arrays are handed to MPI as raw pointers; require C-contiguity.
    assert send_data.flags.c_contiguous  # TODO #60
    assert recv_data.flags.c_contiguous  # TODO #60

    return _MPI_Allgather(
        send_data.ctypes.data,  # send buffer
        send_data.size,  # send count
        _mpi_dtype(send_data),  # send datatype
        recv_data.ctypes.data,  # receive buffer
        count,  # receive count
        _mpi_dtype(recv_data),  # receive datatype
        _mpi_addr(_MPI_Comm_World_ptr),  # communicator
    )

Wrapper for MPI_Allgather(). Returns an integer status code (0 == MPI_SUCCESS).