numba_mpi.api.allreduce

This module contains the MPI_Allreduce() implementations.

  1"""file contains MPI_Allreduce() implementations"""
  2
  3import ctypes
  4from numbers import Number
  5
  6import numpy as np
  7from numba.core import types
  8from numba.extending import overload
  9
 10from numba_mpi.api.operator import Operator
 11from numba_mpi.common import _MPI_Comm_World_ptr, _MpiComm, _MpiDatatype, _MpiOp, libmpi
 12from numba_mpi.utils import _mpi_addr, _mpi_dtype
 13
 14_MPI_Allreduce = libmpi.MPI_Allreduce
 15_MPI_Allreduce.restype = ctypes.c_int
 16_MPI_Allreduce.argtypes = [
 17    ctypes.c_void_p,
 18    ctypes.c_void_p,
 19    ctypes.c_int,
 20    _MpiDatatype,
 21    _MpiOp,
 22    _MpiComm,
 23]
 24
 25
 26def allreduce(
 27    sendobj, recvobj, operator=Operator.SUM
 28):  # pylint: disable=unused-argument
 29    """wrapper for MPI_Allreduce
 30    Note that complex datatypes and user-defined functions are not properly supported.
 31    Returns integer status code (0 == MPI_SUCCESS)
 32    """
 33    if isinstance(sendobj, Number):
 34        # reduce a single number
 35        sendobj = np.array([sendobj])
 36        status = _MPI_Allreduce(
 37            sendobj.ctypes.data,
 38            recvobj.ctypes.data,
 39            sendobj.size,
 40            _mpi_dtype(sendobj),
 41            _mpi_addr(operator),
 42            _mpi_addr(_MPI_Comm_World_ptr),
 43        )
 44
 45    elif isinstance(sendobj, np.ndarray):
 46        # reduce an array
 47        sendobj = np.ascontiguousarray(sendobj)
 48        status = _MPI_Allreduce(
 49            sendobj.ctypes.data,
 50            recvobj.ctypes.data,
 51            sendobj.size,
 52            _mpi_dtype(sendobj),
 53            _mpi_addr(operator),
 54            _mpi_addr(_MPI_Comm_World_ptr),
 55        )
 56
 57    else:
 58        raise TypeError(f"Unsupported type {sendobj.__class__.__name__}")
 59
 60    return status
 61
 62
 63@overload(allreduce)
 64def ol_allreduce(
 65    sendobj, recvobj, operator=Operator.SUM
 66):  # pylint: disable=unused-argument
 67    """wrapper for MPI_Allreduce
 68    Note that complex datatypes and user-defined functions are not properly supported.
 69    Returns integer status code (0 == MPI_SUCCESS)
 70    """
 71    if isinstance(sendobj, types.Number):
 72        # reduce a single number
 73
 74        def impl(sendobj, recvobj, operator=Operator.SUM):
 75            sendobj = np.array([sendobj])
 76
 77            status = _MPI_Allreduce(
 78                sendobj.ctypes.data,
 79                recvobj.ctypes.data,
 80                sendobj.size,
 81                _mpi_dtype(sendobj),
 82                _mpi_addr(operator),
 83                _mpi_addr(_MPI_Comm_World_ptr),
 84            )
 85
 86            # The following no-op prevents numba from too aggressive optimizations
 87            # This looks like a bug in numba (tested for version 0.55)
 88            sendobj[0]  # pylint: disable=pointless-statement
 89
 90            return status
 91
 92    elif isinstance(sendobj, types.Array):
 93        # reduce an array
 94
 95        def impl(sendobj, recvobj, operator=Operator.SUM):
 96            sendobj = np.ascontiguousarray(sendobj)
 97
 98            status = _MPI_Allreduce(
 99                sendobj.ctypes.data,
100                recvobj.ctypes.data,
101                sendobj.size,
102                _mpi_dtype(sendobj),
103                _mpi_addr(operator),
104                _mpi_addr(_MPI_Comm_World_ptr),
105            )
106
107            return status
108
109    else:
110        raise TypeError(f"Unsupported type {sendobj.__class__.__name__}")
111
112    return impl
def allreduce(sendobj, recvobj, operator=Operator.SUM):
def allreduce(
    sendobj, recvobj, operator=Operator.SUM
):  # pylint: disable=unused-argument
    """Wrapper for MPI_Allreduce on the world communicator (``_MPI_Comm_World_ptr``).

    Note that complex datatypes and user-defined functions are not properly supported.

    Parameters
    ----------
    sendobj : numbers.Number or numpy.ndarray
        Data contributed by this rank. A scalar is wrapped in a one-element
        array; an array is made C-contiguous (copied only if necessary).
    recvobj : numpy.ndarray
        Buffer that receives the reduced result. MPI writes ``sendobj.size``
        elements of ``sendobj``'s dtype starting at ``recvobj.ctypes.data``,
        so ``recvobj`` must have that dtype, be C-contiguous and hold at
        least that many elements.
    operator : Operator
        Reduction operation (``Operator.SUM`` by default).

    Returns
    -------
    int
        Status code returned by MPI_Allreduce (0 == MPI_SUCCESS).

    Raises
    ------
    TypeError
        If ``sendobj`` is neither a number nor an ndarray, if ``recvobj`` is
        not an ndarray, or if the dtypes of the two buffers differ.
    ValueError
        If ``recvobj`` is not C-contiguous or is smaller than ``sendobj``.
    """
    if isinstance(sendobj, Number):
        # reduce a single number
        sendobj = np.array([sendobj])
    elif isinstance(sendobj, np.ndarray):
        # reduce an array; MPI reads one contiguous block of memory
        sendobj = np.ascontiguousarray(sendobj)
    else:
        raise TypeError(f"Unsupported type {sendobj.__class__.__name__}")

    # MPI writes straight into recvobj's memory without knowing its shape,
    # strides or dtype, so a mismatched buffer would be silently corrupted
    # or overrun. Validate it before handing over the raw pointer.
    if not isinstance(recvobj, np.ndarray):
        raise TypeError(f"Unsupported recvobj type {recvobj.__class__.__name__}")
    if recvobj.dtype != sendobj.dtype:
        raise TypeError(
            f"recvobj dtype {recvobj.dtype} does not match sendobj dtype {sendobj.dtype}"
        )
    if not recvobj.flags.c_contiguous:
        raise ValueError("recvobj must be C-contiguous")
    if recvobj.size < sendobj.size:
        raise ValueError("recvobj is smaller than sendobj")

    status = _MPI_Allreduce(
        sendobj.ctypes.data,
        recvobj.ctypes.data,
        sendobj.size,
        _mpi_dtype(sendobj),
        _mpi_addr(operator),
        _mpi_addr(_MPI_Comm_World_ptr),
    )

    return status

Wrapper for MPI_Allreduce. Note that complex datatypes and user-defined functions are not properly supported. Returns an integer status code (0 == MPI_SUCCESS).

@overload(allreduce)
def ol_allreduce(sendobj, recvobj, operator=Operator.SUM):
 64@overload(allreduce)
 65def ol_allreduce(
 66    sendobj, recvobj, operator=Operator.SUM
 67):  # pylint: disable=unused-argument
 68    """wrapper for MPI_Allreduce
 69    Note that complex datatypes and user-defined functions are not properly supported.
 70    Returns integer status code (0 == MPI_SUCCESS)
 71    """
 72    if isinstance(sendobj, types.Number):
 73        # reduce a single number
 74
 75        def impl(sendobj, recvobj, operator=Operator.SUM):
 76            sendobj = np.array([sendobj])
 77
 78            status = _MPI_Allreduce(
 79                sendobj.ctypes.data,
 80                recvobj.ctypes.data,
 81                sendobj.size,
 82                _mpi_dtype(sendobj),
 83                _mpi_addr(operator),
 84                _mpi_addr(_MPI_Comm_World_ptr),
 85            )
 86
 87            # The following no-op prevents numba from too aggressive optimizations
 88            # This looks like a bug in numba (tested for version 0.55)
 89            sendobj[0]  # pylint: disable=pointless-statement
 90
 91            return status
 92
 93    elif isinstance(sendobj, types.Array):
 94        # reduce an array
 95
 96        def impl(sendobj, recvobj, operator=Operator.SUM):
 97            sendobj = np.ascontiguousarray(sendobj)
 98
 99            status = _MPI_Allreduce(
100                sendobj.ctypes.data,
101                recvobj.ctypes.data,
102                sendobj.size,
103                _mpi_dtype(sendobj),
104                _mpi_addr(operator),
105                _mpi_addr(_MPI_Comm_World_ptr),
106            )
107
108            return status
109
110    else:
111        raise TypeError(f"Unsupported type {sendobj.__class__.__name__}")
112
113    return impl

Wrapper for MPI_Allreduce. Note that complex datatypes and user-defined functions are not properly supported. Returns an integer status code (0 == MPI_SUCCESS).