numba_mpi.api.reduce

This file contains the MPI_Reduce() implementations.

  1"""file contains MPI_Reduce() implementations"""
  2
  3import ctypes
  4from numbers import Number
  5
  6import numpy as np
  7from numba.core import types
  8from numba.extending import overload
  9
 10from numba_mpi.common import _MPI_Comm_World_ptr, _MpiComm, _MpiDatatype, _MpiOp, libmpi
 11from numba_mpi.utils import _mpi_addr, _mpi_dtype
 12
 13_MPI_Reduce = libmpi.MPI_Reduce
 14_MPI_Reduce.restype = ctypes.c_int
 15_MPI_Reduce.argtypes = [
 16    ctypes.c_void_p,
 17    ctypes.c_void_p,
 18    ctypes.c_int,
 19    _MpiDatatype,
 20    _MpiOp,
 21    ctypes.c_int,
 22    _MpiComm,
 23]
 24
 25
 26def reduce(sendobj, recvobj, operator, root):  # pylint: disable=unused-argument
 27    """wrapper for MPI_Reduce
 28    Note that complex datatypes and user-defined functions are not properly supported.
 29    Returns integer status code (0 == MPI_SUCCESS)
 30    """
 31    if isinstance(sendobj, Number):
 32        # reduce a single number
 33        sendobj = np.array([sendobj])
 34        status = _MPI_Reduce(
 35            sendobj.ctypes.data,
 36            recvobj.ctypes.data,
 37            sendobj.size,
 38            _mpi_dtype(sendobj),
 39            _mpi_addr(operator),
 40            root,
 41            _mpi_addr(_MPI_Comm_World_ptr),
 42        )
 43
 44    elif isinstance(sendobj, np.ndarray):
 45        # reduce an array
 46        sendobj = np.ascontiguousarray(sendobj)
 47        status = _MPI_Reduce(
 48            sendobj.ctypes.data,
 49            recvobj.ctypes.data,
 50            sendobj.size,
 51            _mpi_dtype(sendobj),
 52            _mpi_addr(operator),
 53            root,
 54            _mpi_addr(_MPI_Comm_World_ptr),
 55        )
 56
 57    else:
 58        raise TypeError(f"Unsupported type {sendobj.__class__.__name__}")
 59
 60    return status
 61
 62
 63@overload(reduce)
 64def ol_reduce(sendobj, recvobj, operator, root):  # pylint: disable=unused-argument
 65    """wrapper for MPI_Reduce
 66    Note that complex datatypes and user-defined functions are not properly supported.
 67    Returns integer status code (0 == MPI_SUCCESS)
 68    """
 69    if isinstance(sendobj, types.Number):
 70        # reduce a single number
 71
 72        def impl(sendobj, recvobj, operator, root):
 73            sendobj = np.array([sendobj])
 74
 75            status = _MPI_Reduce(
 76                sendobj.ctypes.data,
 77                recvobj.ctypes.data,
 78                sendobj.size,
 79                _mpi_dtype(sendobj),
 80                _mpi_addr(operator),
 81                root,
 82                _mpi_addr(_MPI_Comm_World_ptr),
 83            )
 84
 85            # The following no-op prevents numba from too aggressive optimizations
 86            # This looks like a bug in numba (tested for version 0.55)
 87            sendobj[0]  # pylint: disable=pointless-statement
 88
 89            return status
 90
 91    elif isinstance(sendobj, types.Array):
 92        # reduce an array
 93
 94        def impl(sendobj, recvobj, operator, root):
 95            sendobj = np.ascontiguousarray(sendobj)
 96
 97            status = _MPI_Reduce(
 98                sendobj.ctypes.data,
 99                recvobj.ctypes.data,
100                sendobj.size,
101                _mpi_dtype(sendobj),
102                _mpi_addr(operator),
103                root,
104                _mpi_addr(_MPI_Comm_World_ptr),
105            )
106
107            return status
108
109    else:
110        raise TypeError(f"Unsupported type {sendobj.__class__.__name__}")
111
112    return impl
def reduce(sendobj, recvobj, operator, root):
def reduce(sendobj, recvobj, operator, root):  # pylint: disable=unused-argument
    """Wrapper for MPI_Reduce on MPI_COMM_WORLD (pure-Python path).

    Combines ``sendobj`` from all ranks using ``operator`` and stores the
    result in ``recvobj`` on rank ``root``.
    Note that complex datatypes and user-defined functions are not properly
    supported.

    Parameters
    ----------
    sendobj : numbers.Number or numpy.ndarray
        Data contributed by this rank. A scalar is wrapped in a one-element
        array; an array is made C-contiguous before its buffer is passed on.
    recvobj : numpy.ndarray
        Result buffer (per MPI, significant only at ``root``). It is expected
        to hold ``sendobj.size`` elements of the same dtype as the (wrapped)
        ``sendobj``; this is not checked here. Non-C-contiguous arrays are
        supported via a temporary contiguous copy.
    operator
        Reduction operation, forwarded (via ``_mpi_addr``) as the ``op``
        argument of MPI_Reduce.
    root : int
        Rank that receives the result.

    Returns
    -------
    int
        MPI status code (0 == MPI_SUCCESS).

    Raises
    ------
    TypeError
        If ``sendobj`` is neither a number nor a numpy array.
    """
    if isinstance(sendobj, Number):
        # reduce a single number: wrap it so it has a readable buffer
        send_buf = np.array([sendobj])
    elif isinstance(sendobj, np.ndarray):
        # reduce an array: MPI reads `count` consecutive elements
        send_buf = np.ascontiguousarray(sendobj)
    else:
        raise TypeError(f"Unsupported type {sendobj.__class__.__name__}")

    # MPI writes the result as consecutive C-ordered elements starting at the
    # buffer address. For a strided, reversed or Fortran-ordered `recvobj`
    # that would scramble the result and clobber (or overrun) memory, so
    # reduce into a C-contiguous copy and write it back afterwards. The copy
    # is seeded with recvobj's current contents, so on ranks where MPI does
    # not touch the receive buffer the write-back leaves the values unchanged.
    if recvobj.flags.c_contiguous:
        recv_buf = recvobj
    else:
        recv_buf = np.ascontiguousarray(recvobj)

    status = _MPI_Reduce(
        send_buf.ctypes.data,
        recv_buf.ctypes.data,
        send_buf.size,
        _mpi_dtype(send_buf),
        _mpi_addr(operator),
        root,
        _mpi_addr(_MPI_Comm_World_ptr),
    )

    if recv_buf is not recvobj:
        recvobj[...] = recv_buf

    return status

Wrapper for MPI_Reduce. Note that complex datatypes and user-defined functions are not properly supported. Returns an integer status code (0 == MPI_SUCCESS).

@overload(reduce)
def ol_reduce(sendobj, recvobj, operator, root):
 64@overload(reduce)
 65def ol_reduce(sendobj, recvobj, operator, root):  # pylint: disable=unused-argument
 66    """wrapper for MPI_Reduce
 67    Note that complex datatypes and user-defined functions are not properly supported.
 68    Returns integer status code (0 == MPI_SUCCESS)
 69    """
 70    if isinstance(sendobj, types.Number):
 71        # reduce a single number
 72
 73        def impl(sendobj, recvobj, operator, root):
 74            sendobj = np.array([sendobj])
 75
 76            status = _MPI_Reduce(
 77                sendobj.ctypes.data,
 78                recvobj.ctypes.data,
 79                sendobj.size,
 80                _mpi_dtype(sendobj),
 81                _mpi_addr(operator),
 82                root,
 83                _mpi_addr(_MPI_Comm_World_ptr),
 84            )
 85
 86            # The following no-op prevents numba from too aggressive optimizations
 87            # This looks like a bug in numba (tested for version 0.55)
 88            sendobj[0]  # pylint: disable=pointless-statement
 89
 90            return status
 91
 92    elif isinstance(sendobj, types.Array):
 93        # reduce an array
 94
 95        def impl(sendobj, recvobj, operator, root):
 96            sendobj = np.ascontiguousarray(sendobj)
 97
 98            status = _MPI_Reduce(
 99                sendobj.ctypes.data,
100                recvobj.ctypes.data,
101                sendobj.size,
102                _mpi_dtype(sendobj),
103                _mpi_addr(operator),
104                root,
105                _mpi_addr(_MPI_Comm_World_ptr),
106            )
107
108            return status
109
110    else:
111        raise TypeError(f"Unsupported type {sendobj.__class__.__name__}")
112
113    return impl

Wrapper for MPI_Reduce. Note that complex datatypes and user-defined functions are not properly supported. Returns an integer status code (0 == MPI_SUCCESS).