From 4b8da14b605716181ec36d80859957f192207b2e Mon Sep 17 00:00:00 2001 From: Joseph Schuchart Date: Thu, 25 Apr 2024 10:53:07 -0400 Subject: [PATCH 01/12] Add CUDA/HIP implementations of reduction operators The operators are generated from macros. Function pointers to kernel launch functions are stored inside the ompi_op_t as a pointer to a struct that is filled if accelerator support is available. The ompi_op* API is extended to include versions taking streams and device IDs to allow enqueuing operators on streams. The old functions map to the stream versions with a NULL stream. Signed-off-by: Joseph Schuchart --- config/opal_check_cudart.m4 | 120 ++ ompi/mca/op/base/op_base_frame.c | 4 +- ompi/mca/op/base/op_base_op_select.c | 60 +- ompi/mca/op/cuda/Makefile.am | 84 + ompi/mca/op/cuda/configure.m4 | 41 + ompi/mca/op/cuda/help-ompi-mca-op-cuda.txt | 15 + ompi/mca/op/cuda/op_cuda.h | 80 + ompi/mca/op/cuda/op_cuda_component.c | 195 ++ ompi/mca/op/cuda/op_cuda_functions.c | 1897 ++++++++++++++++++++ ompi/mca/op/cuda/op_cuda_impl.cu | 1080 +++++++++++ ompi/mca/op/cuda/op_cuda_impl.h | 695 +++++++ ompi/mca/op/op.h | 66 +- ompi/mca/op/rocm/Makefile.am | 82 + ompi/mca/op/rocm/configure.m4 | 36 + ompi/mca/op/rocm/help-ompi-mca-op-rocm.txt | 15 + ompi/mca/op/rocm/op_rocm.h | 79 + ompi/mca/op/rocm/op_rocm_component.c | 207 +++ ompi/mca/op/rocm/op_rocm_functions.c | 1897 ++++++++++++++++++++ ompi/mca/op/rocm/op_rocm_impl.h | 706 ++++++++ ompi/mca/op/rocm/op_rocm_impl.hip | 1085 +++++++++++ ompi/op/Makefile.am | 2 + ompi/op/help-ompi-op.txt | 15 + ompi/op/op.c | 16 + ompi/op/op.h | 249 ++- 24 files changed, 8678 insertions(+), 48 deletions(-) create mode 100644 config/opal_check_cudart.m4 create mode 100644 ompi/mca/op/cuda/Makefile.am create mode 100644 ompi/mca/op/cuda/configure.m4 create mode 100644 ompi/mca/op/cuda/help-ompi-mca-op-cuda.txt create mode 100644 ompi/mca/op/cuda/op_cuda.h create mode 100644 ompi/mca/op/cuda/op_cuda_component.c create mode 100644 
ompi/mca/op/cuda/op_cuda_functions.c create mode 100644 ompi/mca/op/cuda/op_cuda_impl.cu create mode 100644 ompi/mca/op/cuda/op_cuda_impl.h create mode 100644 ompi/mca/op/rocm/Makefile.am create mode 100644 ompi/mca/op/rocm/configure.m4 create mode 100644 ompi/mca/op/rocm/help-ompi-mca-op-rocm.txt create mode 100644 ompi/mca/op/rocm/op_rocm.h create mode 100644 ompi/mca/op/rocm/op_rocm_component.c create mode 100644 ompi/mca/op/rocm/op_rocm_functions.c create mode 100644 ompi/mca/op/rocm/op_rocm_impl.h create mode 100644 ompi/mca/op/rocm/op_rocm_impl.hip create mode 100644 ompi/op/help-ompi-op.txt diff --git a/config/opal_check_cudart.m4 b/config/opal_check_cudart.m4 new file mode 100644 index 00000000000..0e3fced8065 --- /dev/null +++ b/config/opal_check_cudart.m4 @@ -0,0 +1,120 @@ +dnl -*- autoconf -*- +dnl +dnl Copyright (c) 2004-2010 The Trustees of Indiana University and Indiana +dnl University Research and Technology +dnl Corporation. All rights reserved. +dnl Copyright (c) 2004-2005 The University of Tennessee and The University +dnl of Tennessee Research Foundation. All rights +dnl reserved. +dnl Copyright (c) 2004-2005 High Performance Computing Center Stuttgart, +dnl University of Stuttgart. All rights reserved. +dnl Copyright (c) 2004-2005 The Regents of the University of California. +dnl All rights reserved. +dnl Copyright (c) 2006-2016 Cisco Systems, Inc. All rights reserved. +dnl Copyright (c) 2007 Sun Microsystems, Inc. All rights reserved. +dnl Copyright (c) 2009 IBM Corporation. All rights reserved. +dnl Copyright (c) 2009 Los Alamos National Security, LLC. All rights +dnl reserved. +dnl Copyright (c) 2009-2011 Oak Ridge National Labs. All rights reserved. +dnl Copyright (c) 2011-2015 NVIDIA Corporation. All rights reserved. +dnl Copyright (c) 2015 Research Organization for Information Science +dnl and Technology (RIST). All rights reserved. +dnl Copyright (c) 2022 Amazon.com, Inc. or its affiliates. All Rights reserved. 
+dnl $COPYRIGHT$
+dnl
+dnl Additional copyrights may follow
+dnl
+dnl $HEADER$
+dnl
+
+
+# OPAL_CHECK_CUDART(prefix, [action-if-found], [action-if-not-found])
+# --------------------------------------------------------
+# check if CUDA runtime library support can be found.  sets prefix_{CPPFLAGS,
+# LDFLAGS, LIBS} as needed and runs action-if-found if there is
+# support, otherwise executes action-if-not-found
+#
+# The toolkit location is taken from --with-cudart=DIR when DIR holds
+# include/cuda_runtime.h; otherwise it is derived from the nvcc binary
+# found in PATH.  --with-cudart=no skips both lookups.
+
+#
+# Check for CUDA support
+#
+AC_DEFUN([OPAL_CHECK_CUDART],[
+OPAL_VAR_SCOPE_PUSH([cudart_save_CPPFLAGS cudart_save_LDFLAGS cudart_save_LIBS])
+
+cudart_save_CPPFLAGS="$CPPFLAGS"
+cudart_save_LDFLAGS="$LDFLAGS"
+cudart_save_LIBS="$LIBS"
+
+#
+# Check to see if the user provided paths for CUDART
+#
+AC_ARG_WITH([cudart],
+            [AS_HELP_STRING([--with-cudart=DIR],
+            [Path to the CUDA runtime library and header files])])
+AC_ARG_WITH([cudart-libdir],
+            [AS_HELP_STRING([--with-cudart-libdir=DIR],
+            [Search for CUDA runtime libraries in DIR])])
+
+####################################
+#### Check for CUDA runtime library
+####################################
+# Start pessimistic so that every failure path leaves a defined value.
+opal_check_cudart_happy=no
+AC_MSG_CHECKING([if --with-cudart is set])
+# An empty value, yes or no carries no directory that could be inspected.
+AS_IF([test "x$with_cudart" = "x" || test "x$with_cudart" = "xyes" || test "x$with_cudart" = "xno"],
+      [AC_MSG_RESULT([not set (--with-cudart=$with_cudart)])],
+      [AS_IF([test ! -d "$with_cudart"],
+             [AC_MSG_RESULT([not found])
+              AC_MSG_WARN([Directory $with_cudart not found])],
+             [AS_IF([test ! -f "$with_cudart/include/cuda_runtime.h"],
+                    [AC_MSG_RESULT([not found])
+                     AC_MSG_WARN([Could not find cuda_runtime.h in $with_cudart/include])],
+                    [AC_MSG_RESULT([found ($with_cudart)])
+                     opal_check_cudart_happy=yes
+                     opal_cudart_incdir="$with_cudart/include"])])])
+
+# Fall back to locating the toolkit through nvcc unless CUDART was disabled.
+AS_IF([test "$opal_check_cudart_happy" = "no" && test "x$with_cudart" != "xno"],
+      [AC_PATH_PROG([nvcc_bin], [nvcc], ["not-found"])
+       AS_IF([test "$nvcc_bin" = "not-found"],
+             [AC_MSG_WARN([Could not find nvcc binary])],
+             [nvcc_dirname=`AS_DIRNAME([$nvcc_bin])`
+              with_cudart=$nvcc_dirname/../
+              opal_cudart_incdir=$nvcc_dirname/../include
+              opal_check_cudart_happy=yes])],
+      [])
+
+AS_IF([test "x$with_cudart_libdir" = "x"],
+      [with_cudart_libdir=$with_cudart/lib64/],
+      [])
+
+AS_IF([test "$opal_check_cudart_happy" = "yes"],
+      [OAC_CHECK_PACKAGE([cudart],
+                         [$1],
+                         [cuda_runtime.h],
+                         [cudart],
+                         [cudaMalloc],
+                         [opal_check_cudart_happy="yes"],
+                         [opal_check_cudart_happy="no"])],
+      [])
+
+
+AC_MSG_CHECKING([if have cuda runtime library support])
+if test "$opal_check_cudart_happy" = "yes"; then
+    AC_MSG_RESULT([yes (-I$opal_cudart_incdir)])
+    CUDART_SUPPORT=1
+    common_cudart_CPPFLAGS="-I$opal_cudart_incdir"
+    AC_SUBST([common_cudart_CPPFLAGS])
+else
+    AC_MSG_RESULT([no])
+    CUDART_SUPPORT=0
+fi
+
+
+OPAL_SUMMARY_ADD([Accelerators], [CUDART support], [], [$opal_check_cudart_happy])
+AM_CONDITIONAL([OPAL_cudart_support], [test "x$CUDART_SUPPORT" = "x1"])
+AC_DEFINE_UNQUOTED([OPAL_CUDART_SUPPORT],$CUDART_SUPPORT,
+                   [Whether we have cuda runtime library support])
+
+CPPFLAGS=${cudart_save_CPPFLAGS}
+LDFLAGS=${cudart_save_LDFLAGS}
+LIBS=${cudart_save_LIBS}
+
+# Run the caller supplied actions; both are optional and may be empty.
+AS_IF([test "$opal_check_cudart_happy" = "yes"],
+      [$2],
+      [$3])
+OPAL_VAR_SCOPE_POP
+])dnl
diff --git a/ompi/mca/op/base/op_base_frame.c b/ompi/mca/op/base/op_base_frame.c
index 90167300851..1a7d6dc1320 100644
--- a/ompi/mca/op/base/op_base_frame.c
+++ b/ompi/mca/op/base/op_base_frame.c
@@ -2,7 +2,7 @@
  * Copyright (c) 
2004-2005 The Trustees of Indiana University and Indiana * University Research and Technology * Corporation. All rights reserved. - * Copyright (c) 2004-2005 The University of Tennessee and The University + * Copyright (c) 2004-2023 The University of Tennessee and The University * of Tennessee Research Foundation. All rights * reserved. * Copyright (c) 2004-2005 High Performance Computing Center Stuttgart, @@ -42,6 +42,7 @@ static void module_constructor(ompi_op_base_module_t *m) { m->opm_enable = NULL; m->opm_op = NULL; + m->opm_device_enabled = false; memset(&(m->opm_fns), 0, sizeof(m->opm_fns)); memset(&(m->opm_3buff_fns), 0, sizeof(m->opm_3buff_fns)); } @@ -50,6 +51,7 @@ static void module_constructor_1_0_0(ompi_op_base_module_1_0_0_t *m) { m->opm_enable = NULL; m->opm_op = NULL; + m->opm_device_enabled = false; memset(&(m->opm_fns), 0, sizeof(m->opm_fns)); memset(&(m->opm_3buff_fns), 0, sizeof(m->opm_3buff_fns)); } diff --git a/ompi/mca/op/base/op_base_op_select.c b/ompi/mca/op/base/op_base_op_select.c index c032172bf19..261bc93e5a7 100644 --- a/ompi/mca/op/base/op_base_op_select.c +++ b/ompi/mca/op/base/op_base_op_select.c @@ -3,7 +3,7 @@ * Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana * University Research and Technology * Corporation. All rights reserved. - * Copyright (c) 2004-2009 The University of Tennessee and The University + * Copyright (c) 2004-2023 The University of Tennessee and The University * of Tennessee Research Foundation. All rights * reserved. 
* Copyright (c) 2004-2005 High Performance Computing Center Stuttgart, @@ -152,22 +152,50 @@ int ompi_op_base_op_select(ompi_op_t *op) } /* Copy over the non-NULL pointers */ - for (i = 0; i < OMPI_OP_BASE_TYPE_MAX; ++i) { - /* 2-buffer variants */ - if (NULL != avail->ao_module->opm_fns[i]) { - OBJ_RELEASE(op->o_func.intrinsic.modules[i]); - op->o_func.intrinsic.fns[i] = avail->ao_module->opm_fns[i]; - op->o_func.intrinsic.modules[i] = avail->ao_module; - OBJ_RETAIN(avail->ao_module); + if (avail->ao_module->opm_device_enabled) { + if (NULL == op->o_device_op) { + op->o_device_op = calloc(1, sizeof(*op->o_device_op)); } - - /* 3-buffer variants */ - if (NULL != avail->ao_module->opm_3buff_fns[i]) { - OBJ_RELEASE(op->o_3buff_intrinsic.modules[i]); - op->o_3buff_intrinsic.fns[i] = - avail->ao_module->opm_3buff_fns[i]; - op->o_3buff_intrinsic.modules[i] = avail->ao_module; - OBJ_RETAIN(avail->ao_module); + for (i = 0; i < OMPI_OP_BASE_TYPE_MAX; ++i) { + /* 2-buffer variants */ + if (NULL != avail->ao_module->opm_stream_fns[i]) { + if (NULL != op->o_device_op->do_intrinsic.modules[i]) { + OBJ_RELEASE(op->o_device_op->do_intrinsic.modules[i]); + } + op->o_device_op->do_intrinsic.fns[i] = avail->ao_module->opm_stream_fns[i]; + op->o_device_op->do_intrinsic.modules[i] = avail->ao_module; + OBJ_RETAIN(avail->ao_module); + } + + /* 3-buffer variants */ + if (NULL != avail->ao_module->opm_3buff_stream_fns[i]) { + if (NULL != op->o_device_op->do_3buff_intrinsic.modules[i]) { + OBJ_RELEASE(op->o_device_op->do_3buff_intrinsic.modules[i]); + } + op->o_device_op->do_3buff_intrinsic.fns[i] = + avail->ao_module->opm_3buff_stream_fns[i]; + op->o_device_op->do_3buff_intrinsic.modules[i] = avail->ao_module; + OBJ_RETAIN(avail->ao_module); + } + } + } else { + for (i = 0; i < OMPI_OP_BASE_TYPE_MAX; ++i) { + /* 2-buffer variants */ + if (NULL != avail->ao_module->opm_fns[i]) { + OBJ_RELEASE(op->o_func.intrinsic.modules[i]); + op->o_func.intrinsic.fns[i] = avail->ao_module->opm_fns[i]; 
+ op->o_func.intrinsic.modules[i] = avail->ao_module; + OBJ_RETAIN(avail->ao_module); + } + + /* 3-buffer variants */ + if (NULL != avail->ao_module->opm_3buff_fns[i]) { + OBJ_RELEASE(op->o_3buff_intrinsic.modules[i]); + op->o_3buff_intrinsic.fns[i] = + avail->ao_module->opm_3buff_fns[i]; + op->o_3buff_intrinsic.modules[i] = avail->ao_module; + OBJ_RETAIN(avail->ao_module); + } } } diff --git a/ompi/mca/op/cuda/Makefile.am b/ompi/mca/op/cuda/Makefile.am new file mode 100644 index 00000000000..7075d26301c --- /dev/null +++ b/ompi/mca/op/cuda/Makefile.am @@ -0,0 +1,84 @@ +# +# Copyright (c) 2023 The University of Tennessee and The University +# of Tennessee Research Foundation. All rights +# reserved. +# $COPYRIGHT$ +# +# Additional copyrights may follow +# +# $HEADER$ +# + +# This component provides support for offloading reduce ops to CUDA devices. +# +# See https://github.com/open-mpi/ompi/wiki/devel-CreateComponent +# for more details on how to make Open MPI components. + +# First, list all .h and .c sources. It is necessary to list all .h +# files so that they will be picked up in the distribution tarball. + +AM_CPPFLAGS = $(op_cuda_CPPFLAGS) $(op_cudart_CPPFLAGS) + +dist_ompidata_DATA = help-ompi-mca-op-cuda.txt + +sources = op_cuda_component.c op_cuda.h op_cuda_functions.c op_cuda_impl.h +#sources_extended = op_cuda_functions.cu +cu_sources = op_cuda_impl.cu + +NVCC = nvcc -g +NVCCFLAGS= --std c++17 --gpu-architecture=compute_52 + +.cu.l$(OBJEXT): + $(LIBTOOL) $(AM_V_lt) --tag=CC $(AM_LIBTOOLFLAGS) \ + $(LIBTOOLFLAGS) --mode=compile $(NVCC) -prefer-non-pic $(NVCCFLAGS) -Wc,-Xcompiler,-fPIC,-g -c $< + +# -o $($@.o:.lo) + +# Open MPI components can be compiled two ways: +# +# 1. As a standalone dynamic shared object (DSO), sometimes called a +# dynamically loadable library (DLL). +# +# 2. As a static library that is slurped up into the upper-level +# libmpi library (regardless of whether libmpi is a static or dynamic +# library). 
This is called a "Libtool convenience library". +# +# The component needs to create an output library in this top-level +# component directory, and named either mca_<type>_<name>.la (for DSO +# builds) or libmca_<type>_<name>.la (for static builds). The OMPI +# build system will have set the +# MCA_BUILD_ompi_<framework>_<component>_DSO AM_CONDITIONAL to indicate +# which way this component should be built. + +if MCA_BUILD_ompi_op_cuda_DSO +component_install = mca_op_cuda.la +else +component_install = +component_noinst = libmca_op_cuda.la +endif + +# Specific information for DSO builds. +# +# The DSO should install itself in $(ompilibdir) (by default, +# $prefix/lib/openmpi). + +mcacomponentdir = $(ompilibdir) +mcacomponent_LTLIBRARIES = $(component_install) +mca_op_cuda_la_SOURCES = $(sources) +mca_op_cuda_la_LIBADD = $(cu_sources:.cu=.lo) +mca_op_cuda_la_LDFLAGS = -module -avoid-version $(top_builddir)/ompi/lib@OMPI_LIBMPI_NAME@.la \ + $(op_cuda_LIBS) $(op_cudart_LDFLAGS) $(op_cudart_LIBS) +EXTRA_mca_op_cuda_la_SOURCES = $(cu_sources) + +# Specific information for static builds. +# +# Note that we *must* "noinst"; the upper-layer Makefile.am's will +# slurp in the resulting .la library into libmpi. + +noinst_LTLIBRARIES = $(component_noinst) +libmca_op_cuda_la_SOURCES = $(sources) +libmca_op_cuda_la_LIBADD = $(cu_sources:.cu=.lo) +libmca_op_cuda_la_LDFLAGS = -module -avoid-version\ + $(op_cuda_LIBS) $(op_cudart_LDFLAGS) $(op_cudart_LIBS) +EXTRA_libmca_op_cuda_la_SOURCES = $(cu_sources) + diff --git a/ompi/mca/op/cuda/configure.m4 b/ompi/mca/op/cuda/configure.m4 new file mode 100644 index 00000000000..0974e3aaf31 --- /dev/null +++ b/ompi/mca/op/cuda/configure.m4 @@ -0,0 +1,41 @@ +# -*- shell-script -*- +# +# Copyright (c) 2011-2013 NVIDIA Corporation. All rights reserved. +# Copyright (c) 2023 The University of Tennessee and The University +# of Tennessee Research Foundation. All rights +# reserved. +# Copyright (c) 2022 Amazon.com, Inc. or its affiliates. +# All Rights reserved. 
+# $COPYRIGHT$
+#
+# Additional copyrights may follow
+#
+# $HEADER$
+#
+
+#
+# If CUDA support was requested, then build the CUDA op component.
+# This code runs the OPAL_CHECK_CUDA and OPAL_CHECK_CUDART checks and
+# exports the resulting flags and libs as op_cuda_{CPPFLAGS,LDFLAGS,LIBS}
+# and op_cudart_{CPPFLAGS,LDFLAGS,LIBS}.  The component is only built
+# when both the CUDA driver API and the CUDA runtime were found, because
+# its Makefile.am consumes both sets of flags.
+
+AC_DEFUN([MCA_ompi_op_cuda_CONFIG],[
+
+    AC_CONFIG_FILES([ompi/mca/op/cuda/Makefile])
+
+    OPAL_CHECK_CUDA([op_cuda])
+    OPAL_CHECK_CUDART([op_cudart])
+
+    AS_IF([test "x$CUDA_SUPPORT" = "x1" && test "x$CUDART_SUPPORT" = "x1"],
+          [$1],
+          [$2])
+
+    AC_SUBST([op_cuda_CPPFLAGS])
+    AC_SUBST([op_cuda_LDFLAGS])
+    AC_SUBST([op_cuda_LIBS])
+
+    AC_SUBST([op_cudart_CPPFLAGS])
+    AC_SUBST([op_cudart_LDFLAGS])
+    AC_SUBST([op_cudart_LIBS])
+
+])dnl
diff --git a/ompi/mca/op/cuda/help-ompi-mca-op-cuda.txt b/ompi/mca/op/cuda/help-ompi-mca-op-cuda.txt
new file mode 100644
index 00000000000..f999ebc939c
--- /dev/null
+++ b/ompi/mca/op/cuda/help-ompi-mca-op-cuda.txt
@@ -0,0 +1,15 @@
+# -*- text -*-
+#
+# Copyright (c) 2023      The University of Tennessee and The University
+#                         of Tennessee Research Foundation.  All rights
+#                         reserved.
+# $COPYRIGHT$
+#
+# Additional copyrights may follow
+#
+# $HEADER$
+#
+# This is the US/English help file for Open MPI's CUDA operator component
+#
+[CUDA call failed]
+"CUDA call %s failed: %s: %s\n"
diff --git a/ompi/mca/op/cuda/op_cuda.h b/ompi/mca/op/cuda/op_cuda.h
new file mode 100644
index 00000000000..ab349d48ee4
--- /dev/null
+++ b/ompi/mca/op/cuda/op_cuda.h
@@ -0,0 +1,80 @@
+/*
+ * Copyright (c) 2019-2023 The University of Tennessee and The University
+ *                         of Tennessee Research Foundation.  All rights
+ *                         reserved.
+ * $COPYRIGHT$
+ *
+ * Additional copyrights may follow
+ *
+ * $HEADER$
+ */
+
+#ifndef MCA_OP_CUDA_EXPORT_H
+#define MCA_OP_CUDA_EXPORT_H
+
+#include "ompi_config.h"
+
+#include "ompi/mca/mca.h"
+#include "opal/class/opal_object.h"
+#include "opal/util/show_help.h"
+
+#include "ompi/mca/op/op.h"
+#include "ompi/runtime/mpiruntime.h"
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+BEGIN_C_DECLS
+
+
+#define xstr(x) #x
+#define str(x) xstr(x)
+
+/**
+ * Invoke the CUDA driver API function @p fn with the parenthesised
+ * argument list @p args and abort the job if it does not return
+ * CUDA_SUCCESS.  The failing call, the error name and the error
+ * description are reported through the "CUDA call failed" help topic.
+ *
+ * Every use visible in this component passes a cu* driver function, so
+ * the result is kept as a CUresult and decoded with the driver API.
+ * NOTE(review): callers outside this patch section that pass cuda*
+ * runtime functions would need a runtime flavoured variant -- confirm.
+ */
+#define CHECK(fn, args)                                         \
+    do {                                                        \
+        CUresult err = fn args;                                 \
+        if (err != CUDA_SUCCESS) {                              \
+            const char *err_name = NULL;                        \
+            const char *err_desc = NULL;                        \
+            /* both lookups leave the pointer NULL on failure */ \
+            (void) cuGetErrorName(err, &err_name);              \
+            (void) cuGetErrorString(err, &err_desc);            \
+            opal_show_help("help-ompi-mca-op-cuda.txt",         \
+                           "CUDA call failed", true,            \
+                           str(fn),                             \
+                           NULL != err_name ? err_name : "unknown error", \
+                           NULL != err_desc ? err_desc : "no description"); \
+            ompi_mpi_abort(MPI_COMM_WORLD, 1);                  \
+        }                                                       \
+    } while (0)
+
+
+/**
+ * Derive a struct from the base op component struct, allowing us to
+ * cache some component-specific information on our well-known
+ * component struct.
+ */
+typedef struct {
+    /** The base op component struct */
+    ompi_op_base_component_1_0_0_t super;
+    /** MCA parameter max_num_blocks: cap on thread blocks per kernel (-1: device limit) */
+    int cu_max_num_blocks;
+    /** MCA parameter max_num_threads: cap on threads per block (-1: device limit) */
+    int cu_max_num_threads;
+    /** Per-device threads-per-block limit, indexed by device ordinal */
+    int *cu_max_threads_per_block;
+    /** Per-device block-count limit, indexed by device ordinal */
+    int *cu_max_blocks;
+    /** Driver device handles, indexed by device ordinal */
+    CUdevice *cu_devices;
+    /** Number of entries in the three arrays above */
+    int cu_num_devices;
+} ompi_op_cuda_component_t;
+
+/**
+ * Globally exported variable.  Note that it is a *cuda* component
+ * (defined above), which has the ompi_op_base_component_t as its
+ * first member.  Hence, the MCA/op framework will find the data that
+ * it expects in the first memory locations, but then the component
+ * itself can cache additional information after that that can be used
+ * by both the component and modules.
+ */
+OMPI_DECLSPEC extern ompi_op_cuda_component_t
+    mca_op_cuda_component;
+
+OMPI_DECLSPEC extern
+ompi_op_base_stream_handler_fn_t ompi_op_cuda_functions[OMPI_OP_BASE_FORTRAN_OP_MAX][OMPI_OP_BASE_TYPE_MAX];
+
+OMPI_DECLSPEC extern
+ompi_op_base_3buff_stream_handler_fn_t ompi_op_cuda_3buff_functions[OMPI_OP_BASE_FORTRAN_OP_MAX][OMPI_OP_BASE_TYPE_MAX];
+
+END_C_DECLS
+
+#endif /* MCA_OP_CUDA_EXPORT_H */
diff --git a/ompi/mca/op/cuda/op_cuda_component.c b/ompi/mca/op/cuda/op_cuda_component.c
new file mode 100644
index 00000000000..3ead710bd1d
--- /dev/null
+++ b/ompi/mca/op/cuda/op_cuda_component.c
@@ -0,0 +1,195 @@
+/*
+ * Copyright (c) 2019-2023 The University of Tennessee and The University
+ *                         of Tennessee Research Foundation.  All rights
+ *                         reserved.
+ * Copyright (c) 2020      Research Organization for Information Science
+ *                         and Technology (RIST).  All rights reserved.
+ * Copyright (c) 2021      Cisco Systems, Inc.  All rights reserved.
+ * $COPYRIGHT$
+ *
+ * Additional copyrights may follow
+ *
+ * $HEADER$
+ */
+
+/** @file
+ *
+ * This is the "cuda" op component source code.
+ *
+ */
+
+#include "ompi_config.h"
+
+#include "opal/util/printf.h"
+
+#include "ompi/constants.h"
+#include "ompi/op/op.h"
+#include "ompi/mca/op/op.h"
+#include "ompi/mca/op/base/base.h"
+#include "ompi/mca/op/cuda/op_cuda.h"
+
+#include <stdlib.h>
+#include <cuda.h>
+
+static int cuda_component_open(void);
+static int cuda_component_close(void);
+static int cuda_component_init_query(bool enable_progress_threads,
+                                     bool enable_mpi_thread_multiple);
+static struct ompi_op_base_module_1_0_0_t *
+    cuda_component_op_query(struct ompi_op_t *op, int *priority);
+static int cuda_component_register(void);
+
+ompi_op_cuda_component_t mca_op_cuda_component = {
+    {
+        .opc_version = {
+            OMPI_OP_BASE_VERSION_1_0_0,
+
+            .mca_component_name = "cuda",
+            MCA_BASE_MAKE_VERSION(component, OMPI_MAJOR_VERSION, OMPI_MINOR_VERSION,
+                                  OMPI_RELEASE_VERSION),
+            .mca_open_component = cuda_component_open,
+            .mca_close_component = cuda_component_close,
+            .mca_register_component_params = cuda_component_register,
+        },
+        .opc_data = {
+            /* The component is checkpoint ready */
+            MCA_BASE_METADATA_PARAM_CHECKPOINT
+        },
+
+        .opc_init_query = cuda_component_init_query,
+        .opc_op_query = cuda_component_op_query,
+    },
+    .cu_max_num_blocks = -1,
+    .cu_max_num_threads = -1,
+    .cu_max_threads_per_block = NULL,
+    .cu_max_blocks = NULL,
+    .cu_devices = NULL,
+    .cu_num_devices = 0,
+};
+
+/*
+ * Component open
+ */
+static int cuda_component_open(void)
+{
+    return OMPI_SUCCESS;
+}
+
+/*
+ * Component close: drop the per-device tables built by init_query.
+ * free(NULL) is a no-op, so this is safe even if init_query never ran
+ * or bailed out early.
+ */
+static int cuda_component_close(void)
+{
+    free(mca_op_cuda_component.cu_max_threads_per_block);
+    mca_op_cuda_component.cu_max_threads_per_block = NULL;
+    free(mca_op_cuda_component.cu_max_blocks);
+    mca_op_cuda_component.cu_max_blocks = NULL;
+    free(mca_op_cuda_component.cu_devices);
+    mca_op_cuda_component.cu_devices = NULL;
+    mca_op_cuda_component.cu_num_devices = 0;
+
+    return OMPI_SUCCESS;
+}
+
+/*
+ * Register MCA params.
+ */
+static int
+cuda_component_register(void)
+{
+    /* Both variables are plain integers, so no enumerator is passed.  The
+     * previous code formed &(ptr->super) from a NULL pointer instead. */
+    (void) mca_base_component_var_register(&mca_op_cuda_component.super.opc_version,
+                                           "max_num_blocks",
+                                           "Maximum number of thread blocks in kernels (-1: device limit)",
+                                           MCA_BASE_VAR_TYPE_INT,
+                                           NULL, 0, 0,
+                                           OPAL_INFO_LVL_4,
+                                           MCA_BASE_VAR_SCOPE_LOCAL,
+                                           &mca_op_cuda_component.cu_max_num_blocks);
+
+    (void) mca_base_component_var_register(&mca_op_cuda_component.super.opc_version,
+                                           "max_num_threads",
+                                           "Maximum number of threads per block in kernels (-1: device limit)",
+                                           MCA_BASE_VAR_TYPE_INT,
+                                           NULL, 0, 0,
+                                           OPAL_INFO_LVL_4,
+                                           MCA_BASE_VAR_SCOPE_LOCAL,
+                                           &mca_op_cuda_component.cu_max_num_threads);
+
+    return OMPI_SUCCESS;
+}
+
+
+/*
+ * Query whether this component wants to be used in this process.
+ *
+ * Builds the per-device launch limit tables.  Returns
+ * OMPI_ERR_NOT_AVAILABLE when the driver cannot be initialised or no
+ * device is present, and OMPI_ERR_OUT_OF_RESOURCE when the tables cannot
+ * be allocated, so that the process keeps running without this component
+ * instead of aborting.
+ */
+static int
+cuda_component_init_query(bool enable_progress_threads,
+                          bool enable_mpi_thread_multiple)
+{
+    const int user_max_threads = mca_op_cuda_component.cu_max_num_threads;
+    const int user_max_blocks  = mca_op_cuda_component.cu_max_num_blocks;
+    int num_devices = 0;
+    CUdevice *devices = NULL;
+    int *max_threads = NULL;
+    int *max_blocks = NULL;
+
+    // TODO: is this init needed here?
+    if (CUDA_SUCCESS != cuInit(0) ||
+        CUDA_SUCCESS != cuDeviceGetCount(&num_devices) ||
+        num_devices <= 0) {
+        return OMPI_ERR_NOT_AVAILABLE;
+    }
+
+    devices     = calloc((size_t)num_devices, sizeof(*devices));
+    max_threads = calloc((size_t)num_devices, sizeof(*max_threads));
+    max_blocks  = calloc((size_t)num_devices, sizeof(*max_blocks));
+    if (NULL == devices || NULL == max_threads || NULL == max_blocks) {
+        free(devices);
+        free(max_threads);
+        free(max_blocks);
+        return OMPI_ERR_OUT_OF_RESOURCE;
+    }
+
+    for (int i = 0; i < num_devices; ++i) {
+        CHECK(cuDeviceGet, (&devices[i], i));
+
+        if (CUDA_SUCCESS != cuDeviceGetAttribute(&max_threads[i],
+                                                 CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X,
+                                                 devices[i])) {
+            /* fall-back to value that should work on every device */
+            max_threads[i] = 512;
+        }
+        /* only a positive user limit can lower the device limit; zero or
+         * negative values would produce an empty launch configuration */
+        if (0 < user_max_threads && max_threads[i] > user_max_threads) {
+            max_threads[i] = user_max_threads;
+        }
+
+        if (CUDA_SUCCESS != cuDeviceGetAttribute(&max_blocks[i],
+                                                 CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X,
+                                                 devices[i])) {
+            /* fall-back to value that should work on every device */
+            max_blocks[i] = 512;
+        }
+        if (0 < user_max_blocks && max_blocks[i] > user_max_blocks) {
+            max_blocks[i] = user_max_blocks;
+        }
+    }
+
+    /* publish the tables only once they are completely filled in */
+    mca_op_cuda_component.cu_devices = devices;
+    mca_op_cuda_component.cu_max_threads_per_block = max_threads;
+    mca_op_cuda_component.cu_max_blocks = max_blocks;
+    mca_op_cuda_component.cu_num_devices = num_devices;
+
+    return OMPI_SUCCESS;
+}
+
+/*
+ * Query whether this component can be used for a specific op.
+ *
+ * Returns a new device-enabled module whose stream function tables are
+ * copied from the row of ompi_op_cuda_functions and
+ * ompi_op_cuda_3buff_functions selected by op->o_f_to_c_index, or NULL
+ * if that index is outside the tables or the module cannot be created.
+ */
+static struct ompi_op_base_module_1_0_0_t*
+cuda_component_op_query(struct ompi_op_t *op, int *priority)
+{
+    ompi_op_base_module_t *module = NULL;
+
+    /* the function tables only have OMPI_OP_BASE_FORTRAN_OP_MAX rows */
+    if (op->o_f_to_c_index < 0 || op->o_f_to_c_index >= OMPI_OP_BASE_FORTRAN_OP_MAX) {
+        return NULL;
+    }
+
+    module = OBJ_NEW(ompi_op_base_module_t);
+    if (NULL == module) {
+        return NULL;
+    }
+    module->opm_device_enabled = true;
+    /* The host tables opm_fns and opm_3buff_fns are never filled in by this
+     * component, so the former OBJ_RETAIN calls guarded by them could not
+     * trigger and have been dropped. */
+    for (int i = 0; i < OMPI_OP_BASE_TYPE_MAX; ++i) {
+        module->opm_stream_fns[i] = ompi_op_cuda_functions[op->o_f_to_c_index][i];
+        module->opm_3buff_stream_fns[i] = ompi_op_cuda_3buff_functions[op->o_f_to_c_index][i];
+    }
+    *priority = 50;
+    return (ompi_op_base_module_1_0_0_t *) module;
+}
diff --git a/ompi/mca/op/cuda/op_cuda_functions.c b/ompi/mca/op/cuda/op_cuda_functions.c
new file mode 100644
index 00000000000..904595147cb
--- /dev/null
+++ b/ompi/mca/op/cuda/op_cuda_functions.c
@@ -0,0 +1,1897 @@
+/*
+ * Copyright (c) 2019-2023 The University of Tennessee and The University
+ *                         of Tennessee Research Foundation.  All rights
+ *                         reserved.
+ * Copyright (c) 2020      Research Organization for Information Science
+ *                         and Technology (RIST).  All rights reserved.
+ * $COPYRIGHT$
+ *
+ * Additional copyrights may follow
+ *
+ * $HEADER$
+ */
+
+#include "ompi_config.h"
+
+#ifdef HAVE_SYS_TYPES_H
+#include <sys/types.h>
+#endif
+#include "opal/util/output.h"
+
+
+#include "ompi/op/op.h"
+#include "ompi/mca/op/op.h"
+#include "ompi/mca/op/base/base.h"
+#include "ompi/mca/op/cuda/op_cuda.h"
+#include "opal/mca/accelerator/accelerator.h"
+
+#include "ompi/mca/op/cuda/op_cuda.h"
+#include "ompi/mca/op/cuda/op_cuda_impl.h"
+
+/**
+ * Disable warning about empty macro var-args.
+ * We use varargs to suppress expansion of typenames
+ * (e.g., int32_t -> int) which could lead to collisions
+ * for similar base types. 
*/
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wgnu-zero-variadic-macro-arguments"
+
+/**
+ * Prepare the buffers of a reduction for execution on one device.
+ *
+ * If the caller passes a device in @p device, all buffers are taken to
+ * live there and nothing is staged.  Otherwise every buffer is looked up
+ * with opal_accelerator.check_addr, a device is selected (target first,
+ * then source 1, then source 2, device 0 if none is on a device) and any
+ * buffer that is not on that device is replaced by a stream-allocated
+ * copy.  The out-parameters @p source1, @p source2 and @p target receive
+ * the pointers to use for the kernel; the *_device out-parameters tell
+ * device_op_post() which of them are temporaries.  @p source2,
+ * @p orig_source2 and @p source2_device are NULL for 2-buffer operations.
+ *
+ * A check_addr result of 0 is treated as host memory here.
+ * NOTE(review): the meaning of the other return values is taken from the
+ * way this function uses them -- confirm against accelerator.h.
+ */
+static inline void device_op_pre(const void *orig_source1,
+                                 void **source1,
+                                 int *source1_device,
+                                 const void *orig_source2,
+                                 void **source2,
+                                 int *source2_device,
+                                 void *orig_target,
+                                 void **target,
+                                 int *target_device,
+                                 int count,
+                                 struct ompi_datatype_t *dtype,
+                                 int *threads_per_block,
+                                 int *max_blocks,
+                                 int *device,
+                                 opal_accelerator_stream_t *stream)
+{
+    uint64_t target_flags = -1, source1_flags = -1, source2_flags = -1;
+    /* source2_rc stays 0 when there is no second source, so that an absent
+     * buffer cannot prevent the "everything is on the host" case below */
+    int target_rc = 0, source1_rc = 0, source2_rc = 0;
+
+    *target = orig_target;
+    *source1 = (void*)orig_source1;
+    if (NULL != orig_source2) {
+        *source2 = (void*)orig_source2;
+    }
+
+    if (*device != MCA_ACCELERATOR_NO_DEVICE_ID) {
+        /* we got the device from the caller, just adjust the output parameters */
+        *target_device = *device;
+        *source1_device = *device;
+        if (NULL != source2_device) {
+            *source2_device = *device;
+        }
+    } else {
+
+        target_rc = opal_accelerator.check_addr(*target, target_device, &target_flags);
+        source1_rc = opal_accelerator.check_addr(*source1, source1_device, &source1_flags);
+        *device = *target_device;
+
+        if (NULL != orig_source2) {
+            source2_rc = opal_accelerator.check_addr(*source2, source2_device, &source2_flags);
+        }
+
+        if (0 == target_rc && 0 == source1_rc && 0 == source2_rc) {
+            /* no buffers are on any device, select device 0 */
+            *device = 0;
+        } else if (*target_device == -1) {
+            if (*source1_device == -1 && NULL != orig_source2) {
+                *device = *source2_device;
+            } else {
+                *device = *source1_device;
+            }
+        }
+        if (MCA_ACCELERATOR_NO_DEVICE_ID == *device) {
+            /* no usable device could be derived from the buffers; the
+             * per-device tables below must never be indexed with -1 */
+            *device = 0;
+        }
+
+        if (0 == target_rc || 0 == source1_rc || *target_device != *source1_device) {
+            size_t nbytes;
+            ompi_datatype_type_size(dtype, &nbytes);
+            nbytes *= count;
+
+            if (0 == target_rc) {
+                // allocate memory on the device for the target buffer
+                opal_accelerator.mem_alloc_stream(*device, target, nbytes, stream);
+                CHECK(cuMemcpyHtoDAsync, ((CUdeviceptr)*target, orig_target, nbytes, *(CUstream*)stream->stream));
+                *target_device = -1; // mark target device as host
+            }
+
+            if (0 == source1_rc || *device != *source1_device) {
+                // allocate memory on the device for the source buffer
+                opal_accelerator.mem_alloc_stream(*device, source1, nbytes, stream);
+                if (0 == source1_rc) {
+                    /* copy from host to device */
+                    CHECK(cuMemcpyHtoDAsync, ((CUdeviceptr)*source1, orig_source1, nbytes, *(CUstream*)stream->stream));
+                } else {
+                    /* copy from one device to another device */
+                    /* TODO: does this actually work? Can we enable P2P? */
+                    CHECK(cuMemcpyDtoDAsync, ((CUdeviceptr)*source1, (CUdeviceptr)orig_source1, nbytes, *(CUstream*)stream->stream));
+                }
+            }
+
+        }
+        /* Compare against the selected device, like for source 1.  The
+         * target device may have been reset to -1 above, which used to
+         * trigger a copy of a buffer that already was on the right device
+         * and that device_op_post() then never released. */
+        if (NULL != source2_device && *device != *source2_device) {
+            // allocate memory on the device for the source buffer
+            size_t nbytes;
+            ompi_datatype_type_size(dtype, &nbytes);
+            nbytes *= count;
+
+            opal_accelerator.mem_alloc_stream(*device, source2, nbytes, stream);
+            if (0 == source2_rc) {
+                /* copy from host to device */
+                CHECK(cuMemcpyHtoDAsync, ((CUdeviceptr)*source2, orig_source2, nbytes, *(CUstream*)stream->stream));
+            } else {
+                /* copy from one device to another device */
+                /* TODO: does this actually work? Can we enable P2P? */
+                CHECK(cuMemcpyDtoDAsync, ((CUdeviceptr)*source2, (CUdeviceptr)orig_source2, nbytes, *(CUstream*)stream->stream));
+            }
+        }
+    }
+    *threads_per_block = mca_op_cuda_component.cu_max_threads_per_block[*device];
+    *max_blocks = mca_op_cuda_component.cu_max_blocks[*device];
+}
+
+/**
+ * Undo the staging done by device_op_pre().
+ *
+ * A target marked as host (device id MCA_ACCELERATOR_NO_DEVICE_ID) is
+ * copied back into @p orig_target and its temporary is released; source
+ * buffers whose device differs from @p device are temporaries and are
+ * released as well.  All operations are enqueued on @p stream.
+ */
+static inline void device_op_post(void *source1,
+                                  int source1_device,
+                                  void *source2,
+                                  int source2_device,
+                                  void *orig_target,
+                                  void *target,
+                                  int target_device,
+                                  int count,
+                                  struct ompi_datatype_t *dtype,
+                                  int device,
+                                  opal_accelerator_stream_t *stream)
+{
+    if (MCA_ACCELERATOR_NO_DEVICE_ID == target_device) {
+
+        size_t nbytes;
+        ompi_datatype_type_size(dtype, &nbytes);
+        nbytes *= count;
+
+        CHECK(cuMemcpyDtoHAsync, (orig_target, (CUdeviceptr)target, nbytes, *(CUstream *)stream->stream));
+        /* the release is enqueued on the same stream as the copy-back */
+        opal_accelerator.mem_release_stream(device, target, stream);
+    }
+    if (source1_device != device) {
+        opal_accelerator.mem_release_stream(device, source1, stream);
+    }
+    if (NULL != source2 && source2_device != device) {
+        opal_accelerator.mem_release_stream(device, source2, stream);
+    }
+}
+
+/*
+ * Generate the 2-buffer stream handler ompi_op_cuda_2buff_<name>_<type_name>:
+ * stage the buffers with device_op_pre(), enqueue the matching
+ * ..._submit kernel launcher on the CUDA stream and release the staging
+ * buffers with device_op_post().
+ */
+#define FUNC(name, type_name, type)                                                                         \
+    static                                                                                                  \
+    void ompi_op_cuda_2buff_##name##_##type_name(const void *in, void *inout, int *count,                   \
+                                                 struct ompi_datatype_t **dtype,                            \
+                                                 int device,                                                \
+                                                 opal_accelerator_stream_t *stream,                         \
+                                                 struct ompi_op_base_module_1_0_0_t *module) __opal_attribute_unused__; \
+    static                                                                                                  \
+    void ompi_op_cuda_2buff_##name##_##type_name(const void *in, void *inout, int *count,                   \
+                                                 struct ompi_datatype_t **dtype,                            \
+                                                 int device,                                                \
+                                                 opal_accelerator_stream_t *stream,                         \
+                                                 struct ompi_op_base_module_1_0_0_t *module) {              \
+        int threads_per_block, max_blocks;                                                                  \
+        int source_device, target_device;                                                                   \
+        type *source, *target;                                                                              \
+        int n = *count;                                                                                     \
+        device_op_pre(in, (void**)&source, &source_device, NULL, NULL, NULL,                                \
+                      inout, (void**)&target, &target_device,                                               \
+                      n, *dtype,                                                                            \
+                      &threads_per_block, &max_blocks, &device, stream);                                    \
+        CUstream *custream = (CUstream*)stream->stream;                                                     \
+        ompi_op_cuda_2buff_##name##_##type_name##_submit(source, target, n, threads_per_block, max_blocks, *custream); \
+        device_op_post(source, source_device, NULL, -1, inout, target, target_device, n, *dtype, device, stream); \
+    }
+
+#define OP_FUNC(name, type_name, type, ...) FUNC(name, __VA_ARGS__##type_name, __VA_ARGS__##type)
+
+/* reuse the macro above, no work is actually done so we don't care about the func */
+#define FUNC_FUNC(name, type_name, type, ...) FUNC(name, __VA_ARGS__##type_name, __VA_ARGS__##type)
+
+/*
+ * Since all the functions in this file are essentially identical, we
+ * use a macro to substitute in names and types.  The core operation
+ * in all functions that use this macro is the same. 
+ *
+ * This macro is for minloc and maxloc
+ */
+#define LOC_FUNC(name, type_name) FUNC(name, type_name, ompi_op_predefined_##type_name##_t)
+
+/*
+ * Dispatch Fortran types to C types
+ *
+ * Each generated handler forwards to the handler of the C type that has
+ * the same size as the Fortran type.  The selection is a chain of
+ * comparisons of compile-time constants rather than a switch: case
+ * labels built from sizeof() collide, and then fail to compile, on ABIs
+ * where two of the candidate C types have the same size (for example
+ * double and long double).
+ */
+#define FORT_INT_FUNC(name, type_name, type)                                                          \
+    static                                                                                            \
+    void ompi_op_cuda_2buff_##name##_##type_name(const void *in, void *inout, int *count,             \
+                                                 struct ompi_datatype_t **dtype,                      \
+                                                 int device,                                          \
+                                                 opal_accelerator_stream_t *stream,                   \
+                                                 struct ompi_op_base_module_1_0_0_t *module) {        \
+                                                                                                      \
+        _Static_assert(sizeof(type) >= sizeof(int8_t) && sizeof(type) <= sizeof(int64_t),             \
+                       "Unsupported Fortran integer size");                                           \
+        if (sizeof(type) == sizeof(int8_t)) {                                                         \
+            ompi_op_cuda_2buff_##name##_int8_t(in, inout, count, dtype, device, stream, module);      \
+        } else if (sizeof(type) == sizeof(int16_t)) {                                                 \
+            ompi_op_cuda_2buff_##name##_int16_t(in, inout, count, dtype, device, stream, module);     \
+        } else if (sizeof(type) == sizeof(int32_t)) {                                                 \
+            ompi_op_cuda_2buff_##name##_int32_t(in, inout, count, dtype, device, stream, module);     \
+        } else if (sizeof(type) == sizeof(int64_t)) {                                                 \
+            ompi_op_cuda_2buff_##name##_int64_t(in, inout, count, dtype, device, stream, module);     \
+        }                                                                                             \
+    }
+
+#define FORT_FLOAT_FUNC(name, type_name, type)                                                        \
+    static                                                                                            \
+    void ompi_op_cuda_2buff_##name##_##type_name(const void *in, void *inout, int *count,             \
+                                                 struct ompi_datatype_t **dtype,                      \
+                                                 int device,                                          \
+                                                 opal_accelerator_stream_t *stream,                   \
+                                                 struct ompi_op_base_module_1_0_0_t *module) {        \
+        _Static_assert(sizeof(type) >= sizeof(float) && sizeof(type) <= sizeof(long double),          \
+                       "Unsupported Fortran float size");                                             \
+        if (sizeof(type) == sizeof(float)) {                                                          \
+            ompi_op_cuda_2buff_##name##_float(in, inout, count, dtype, device, stream, module);       \
+        } else if (sizeof(type) == sizeof(double)) {                                                  \
+            ompi_op_cuda_2buff_##name##_double(in, inout, count, dtype, device, stream, module);      \
+        } else if (sizeof(type) == sizeof(long double)) {                                             \
+            ompi_op_cuda_2buff_##name##_long_double(in, inout, count, dtype, device, stream, module); \
+        }                                                                                             \
+    }
+
+/*
+ * Dispatch a Fortran integer pair type (used below for minloc/maxloc) to the
+ * 2int8/2int16/2int32/2int64 variant of the same operation whose element size
+ * equals sizeof(type). All arguments are forwarded unchanged. The
+ * _Static_assert rejects element types smaller than int8_t or larger than
+ * int64_t at compile time.
+ */
+#define FORT_LOC_INT_FUNC(name, type_name, type) \
+    static \
+    void ompi_op_cuda_2buff_##name##_##type_name(const void *in, void *inout, int *count, \
+                                                 struct ompi_datatype_t **dtype, \
+                                                 int device, \
+                                                 opal_accelerator_stream_t *stream, \
+                                                 struct ompi_op_base_module_1_0_0_t *module) { \
+        \
+        _Static_assert(sizeof(type) >= sizeof(int8_t) && sizeof(type) <= sizeof(int64_t), \
+                       "Unsupported Fortran integer type"); \
+        switch(sizeof(type)) { \
+            case sizeof(int8_t): \
+                ompi_op_cuda_2buff_##name##_2int8(in, inout, count, dtype, device, stream, module); \
+                break; \
+            case sizeof(int16_t): \
+                ompi_op_cuda_2buff_##name##_2int16(in, inout, count, dtype, device, stream, module); \
+                break; \
+            case sizeof(int32_t): \
+                ompi_op_cuda_2buff_##name##_2int32(in, inout, count, dtype, device, stream, module); \
+                break; \
+            case sizeof(int64_t): \
+                ompi_op_cuda_2buff_##name##_2int64(in, inout, count, dtype, device, stream, module); \
+                break; \
+        } \
+    }
+
+/*
+ * Dispatch a Fortran floating-point pair type (used below for minloc/maxloc)
+ * to the 2float or 2double variant of the same operation whose element size
+ * equals sizeof(type). All arguments are forwarded unchanged. The
+ * _Static_assert rejects element types smaller than float or larger than
+ * double at compile time.
+ */
+#define FORT_LOC_FLOAT_FUNC(name, type_name, type) \
+    static \
+    void ompi_op_cuda_2buff_##name##_##type_name(const void *in, void *inout, int *count, \
+                                                 struct ompi_datatype_t **dtype, \
+                                                 int device, \
+                                                 opal_accelerator_stream_t *stream, \
+                                                 struct ompi_op_base_module_1_0_0_t *module) { \
+        _Static_assert(sizeof(type) >= sizeof(float) && sizeof(type) <= sizeof(double), \
+                       "Unsupported Fortran float size"); \
+        switch(sizeof(type)) { \
+            case sizeof(float): \
+                ompi_op_cuda_2buff_##name##_2float(in, inout, count, dtype, device, stream, module); \
+                break; \
+            case sizeof(double): \
+                ompi_op_cuda_2buff_##name##_2double(in, inout, count, dtype, device, stream, module); \
+                break; \
+        } \
+    }
+
+/*************************************************************************
+ * Max
+ *************************************************************************/
+
+/* C integer */
+FUNC_FUNC(max, int8_t, int8_t)
+FUNC_FUNC(max, uint8_t, uint8_t)
+FUNC_FUNC(max, int16_t, int16_t)
+FUNC_FUNC(max, uint16_t, uint16_t)
+FUNC_FUNC(max, int32_t, int32_t) +FUNC_FUNC(max, uint32_t, uint32_t) +FUNC_FUNC(max, int64_t, int64_t) +FUNC_FUNC(max, uint64_t, uint64_t) +FUNC_FUNC(max, long, long) +FUNC_FUNC(max, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(max, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(max, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(max, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(max, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(max, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(max, fortran_integer16, ompi_fortran_integer16_t) +#endif + +#if 0 +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC(max, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC(max, short_float, opal_short_float_t) +#endif +#endif // 0 +FUNC_FUNC(max, float, float) +FUNC_FUNC(max, double, double) +FUNC_FUNC(max, long_double, long double) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC(max, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC(max, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC(max, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC(max, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC(max, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC(max, fortran_real16, ompi_fortran_real16_t) +#endif + + +/************************************************************************* + * Min + *************************************************************************/ + +/* C integer */ 
+FUNC_FUNC(min, int8_t, int8_t) +FUNC_FUNC(min, uint8_t, uint8_t) +FUNC_FUNC(min, int16_t, int16_t) +FUNC_FUNC(min, uint16_t, uint16_t) +FUNC_FUNC(min, int32_t, int32_t) +FUNC_FUNC(min, uint32_t, uint32_t) +FUNC_FUNC(min, int64_t, int64_t) +FUNC_FUNC(min, uint64_t, uint64_t) +FUNC_FUNC(min, long, long) +FUNC_FUNC(min, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(min, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(min, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(min, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(min, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(min, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(min, fortran_integer16, ompi_fortran_integer16_t) +#endif + +#if 0 +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC(min, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC(min, short_float, opal_short_float_t) +#endif +#endif // 0 + +FUNC_FUNC(min, float, float) +FUNC_FUNC(min, double, double) +FUNC_FUNC(min, long_double, long double) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC(min, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC(min, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC(min, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC(min, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC(min, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC(min, fortran_real16, ompi_fortran_real16_t) +#endif + 
+/************************************************************************* + * Sum + *************************************************************************/ + +/* C integer */ +OP_FUNC(sum, int8_t, int8_t) +OP_FUNC(sum, uint8_t, uint8_t) +OP_FUNC(sum, int16_t, int16_t) +OP_FUNC(sum, uint16_t, uint16_t) +OP_FUNC(sum, int32_t, int32_t) +OP_FUNC(sum, uint32_t, uint32_t) +OP_FUNC(sum, int64_t, int64_t) +OP_FUNC(sum, uint64_t, uint64_t) +OP_FUNC(sum, long, long) +OP_FUNC(sum, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(sum, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(sum, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(sum, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(sum, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(sum, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(sum, fortran_integer16, ompi_fortran_integer16_t) +#endif + +#if 0 +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC(sum, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +OP_FUNC(sum, short_float, opal_short_float_t) +#endif +#endif // 0 + +OP_FUNC(sum, float, float) +OP_FUNC(sum, double, double) +OP_FUNC(sum, long_double, long double) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC(sum, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC(sum, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC(sum, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC(sum, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC(sum, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && 
OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC(sum, fortran_real16, ompi_fortran_real16_t) +#endif +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC(sum, c_short_float_complex, short float _Complex) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_SUM_FUNC(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC(sum, c_long_double_complex, long double _Complex) +#endif // 0 + +FUNC_FUNC(sum, c_float_complex, cuFloatComplex) +FUNC_FUNC(sum, c_double_complex, cuDoubleComplex) + +/************************************************************************* + * Product + *************************************************************************/ + +/* C integer */ +OP_FUNC(prod, int8_t, int8_t) +OP_FUNC(prod, uint8_t, uint8_t) +OP_FUNC(prod, int16_t, int16_t) +OP_FUNC(prod, uint16_t, uint16_t) +OP_FUNC(prod, int32_t, int32_t) +OP_FUNC(prod, uint32_t, uint32_t) +OP_FUNC(prod, int64_t, int64_t) +OP_FUNC(prod, uint64_t, uint64_t) +OP_FUNC(prod, long, long) +OP_FUNC(prod, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(prod, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(prod, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(prod, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(prod, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(prod, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(prod, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Floating point */ + +#if 0 +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC(prod, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +OP_FUNC(prod, short_float, opal_short_float_t) +#endif +#endif // 0 + +OP_FUNC(prod, float, float) +OP_FUNC(prod, double, double) +OP_FUNC(prod, long_double, long double) +#if 
OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC(prod, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC(prod, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC(prod, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC(prod, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC(prod, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC(prod, fortran_real16, ompi_fortran_real16_t) +#endif +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC(prod, c_short_float_complex, short float _Complex) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_PROD_FUNC(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC(prod, c_long_double_complex, long double _Complex) +#endif // 0 + +FUNC_FUNC(prod, c_float_complex, cuFloatComplex) +FUNC_FUNC(prod, c_double_complex, cuDoubleComplex) + +/************************************************************************* + * Logical AND + *************************************************************************/ + +/* C integer */ +FUNC_FUNC(land, int8_t, int8_t) +FUNC_FUNC(land, uint8_t, uint8_t) +FUNC_FUNC(land, int16_t, int16_t) +FUNC_FUNC(land, uint16_t, uint16_t) +FUNC_FUNC(land, int32_t, int32_t) +FUNC_FUNC(land, uint32_t, uint32_t) +FUNC_FUNC(land, int64_t, int64_t) +FUNC_FUNC(land, uint64_t, uint64_t) +FUNC_FUNC(land, long, long) +FUNC_FUNC(land, ulong, unsigned long) + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FORT_INT_FUNC(land, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC(land, bool, bool) + +/************************************************************************* + * Logical OR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC(lor, int8_t, int8_t) +FUNC_FUNC(lor, 
uint8_t, uint8_t) +FUNC_FUNC(lor, int16_t, int16_t) +FUNC_FUNC(lor, uint16_t, uint16_t) +FUNC_FUNC(lor, int32_t, int32_t) +FUNC_FUNC(lor, uint32_t, uint32_t) +FUNC_FUNC(lor, int64_t, int64_t) +FUNC_FUNC(lor, uint64_t, uint64_t) +FUNC_FUNC(lor, long, long) +FUNC_FUNC(lor, ulong, unsigned long) + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FORT_INT_FUNC(lor, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC(lor, bool, bool) + +/************************************************************************* + * Logical XOR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC(lxor, int8_t, int8_t) +FUNC_FUNC(lxor, uint8_t, uint8_t) +FUNC_FUNC(lxor, int16_t, int16_t) +FUNC_FUNC(lxor, uint16_t, uint16_t) +FUNC_FUNC(lxor, int32_t, int32_t) +FUNC_FUNC(lxor, uint32_t, uint32_t) +FUNC_FUNC(lxor, int64_t, int64_t) +FUNC_FUNC(lxor, uint64_t, uint64_t) +FUNC_FUNC(lxor, long, long) +FUNC_FUNC(lxor, ulong, unsigned long) + + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FORT_INT_FUNC(lxor, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC(lxor, bool, bool) + +/************************************************************************* + * Bitwise AND + *************************************************************************/ + +/* C integer */ +FUNC_FUNC(band, int8_t, int8_t) +FUNC_FUNC(band, uint8_t, uint8_t) +FUNC_FUNC(band, int16_t, int16_t) +FUNC_FUNC(band, uint16_t, uint16_t) +FUNC_FUNC(band, int32_t, int32_t) +FUNC_FUNC(band, uint32_t, uint32_t) +FUNC_FUNC(band, int64_t, int64_t) +FUNC_FUNC(band, uint64_t, uint64_t) +FUNC_FUNC(band, long, long) +FUNC_FUNC(band, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(band, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(band, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(band, 
fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(band, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(band, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(band, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FUNC_FUNC(band, byte, char) + +/************************************************************************* + * Bitwise OR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC(bor, int8_t, int8_t) +FUNC_FUNC(bor, uint8_t, uint8_t) +FUNC_FUNC(bor, int16_t, int16_t) +FUNC_FUNC(bor, uint16_t, uint16_t) +FUNC_FUNC(bor, int32_t, int32_t) +FUNC_FUNC(bor, uint32_t, uint32_t) +FUNC_FUNC(bor, int64_t, int64_t) +FUNC_FUNC(bor, uint64_t, uint64_t) +FUNC_FUNC(bor, long, long) +FUNC_FUNC(bor, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(bor, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(bor, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(bor, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(bor, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(bor, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(bor, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FUNC_FUNC(bor, byte, char) + +/************************************************************************* + * Bitwise XOR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC(bxor, int8_t, int8_t) +FUNC_FUNC(bxor, uint8_t, uint8_t) +FUNC_FUNC(bxor, int16_t, int16_t) +FUNC_FUNC(bxor, uint16_t, uint16_t) +FUNC_FUNC(bxor, int32_t, int32_t) 
+FUNC_FUNC(bxor, uint32_t, uint32_t) +FUNC_FUNC(bxor, int64_t, int64_t) +FUNC_FUNC(bxor, uint64_t, uint64_t) +FUNC_FUNC(bxor, long, long) +FUNC_FUNC(bxor, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(bxor, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(bxor, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(bxor, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(bxor, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(bxor, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(bxor, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FUNC_FUNC(bxor, byte, char) + +/************************************************************************* + * Max location + *************************************************************************/ + +LOC_FUNC(maxloc, float_int) +LOC_FUNC(maxloc, double_int) +LOC_FUNC(maxloc, long_int) +LOC_FUNC(maxloc, 2int) +LOC_FUNC(maxloc, short_int) +LOC_FUNC(maxloc, long_double_int) + +/* Fortran compat types */ +LOC_FUNC(maxloc, 2float) +LOC_FUNC(maxloc, 2double) +LOC_FUNC(maxloc, 2int8) +LOC_FUNC(maxloc, 2int16) +LOC_FUNC(maxloc, 2int32) +LOC_FUNC(maxloc, 2int64) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_LOC_INT_FUNC(maxloc, 2integer, ompi_fortran_integer_t) +#endif +/* Fortran float */ +#if OMPI_HAVE_FORTRAN_REAL +FORT_LOC_FLOAT_FUNC(maxloc, 2real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_LOC_FLOAT_FUNC(maxloc, 2double_precision, ompi_fortran_double_precision_t) +#endif + +/************************************************************************* + * Min location + *************************************************************************/ + +LOC_FUNC(minloc, float_int) +LOC_FUNC(minloc, 
double_int)
+LOC_FUNC(minloc, long_int)
+LOC_FUNC(minloc, 2int)
+LOC_FUNC(minloc, short_int)
+LOC_FUNC(minloc, long_double_int)
+
+/* Fortran compat types */
+LOC_FUNC(minloc, 2float)
+LOC_FUNC(minloc, 2double)
+LOC_FUNC(minloc, 2int8)
+LOC_FUNC(minloc, 2int16)
+LOC_FUNC(minloc, 2int32)
+LOC_FUNC(minloc, 2int64)
+
+/* Fortran integer */
+#if OMPI_HAVE_FORTRAN_INTEGER
+FORT_LOC_INT_FUNC(minloc, 2integer, ompi_fortran_integer_t)
+#endif
+/* Fortran float */
+#if OMPI_HAVE_FORTRAN_REAL
+FORT_LOC_FLOAT_FUNC(minloc, 2real, ompi_fortran_real_t)
+#endif
+#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION
+FORT_LOC_FLOAT_FUNC(minloc, 2double_precision, ompi_fortran_double_precision_t)
+#endif
+
+/*
+ * This is a three buffer (2 input and 1 output) version of the reduction
+ * routines, needed for some optimizations.
+ *
+ * FUNC_3BUF generates the static host-side wrapper
+ * ompi_op_cuda_3buff_<name>_<type_name>. The wrapper:
+ *   1. calls device_op_pre() with both inputs and the output to obtain the
+ *      source1/source2/target pointers, their device ids, and the
+ *      threads_per_block/max_blocks launch parameters;
+ *   2. reads the CUstream from stream->stream and hands everything to
+ *      ompi_op_cuda_3buff_<name>_<type_name>_submit();
+ *   3. calls device_op_post() with the same pointers and device ids.
+ *
+ * NOTE(review): stream is dereferenced unconditionally -- confirm that callers
+ * never pass a NULL stream to these wrappers.
+ */
+#define FUNC_3BUF(name, type_name, type) \
+    static \
+    void ompi_op_cuda_3buff_##name##_##type_name(const void *in1, const void *in2, void *out, int *count, \
+                                                 struct ompi_datatype_t **dtype, \
+                                                 int device, \
+                                                 opal_accelerator_stream_t *stream, \
+                                                 struct ompi_op_base_module_1_0_0_t *module) { \
+        int threads_per_block, max_blocks; \
+        int source1_device, source2_device, target_device; \
+        type *source1, *source2, *target; \
+        int n = *count; \
+        device_op_pre(in1, (void**)&source1, &source1_device, \
+                      in2, (void**)&source2, &source2_device, \
+                      out, (void**)&target, &target_device, \
+                      n, *dtype, \
+                      &threads_per_block, &max_blocks, &device, stream); \
+        CUstream *custream = (CUstream*)stream->stream; \
+        ompi_op_cuda_3buff_##name##_##type_name##_submit(source1, source2, target, n, threads_per_block, max_blocks, *custream);\
+        device_op_post(source1, source1_device, source2, source2_device, out, target, target_device, n, *dtype, device, stream);\
+    }
+
+
+/*
+ * OP_FUNC_3BUF forwards to FUNC_3BUF. The optional variadic argument is
+ * token-pasted in front of both type_name and type; the invocations visible
+ * below pass no extra argument, so the paste leaves both names unchanged.
+ */
+#define OP_FUNC_3BUF(name, type_name, type, ...) FUNC_3BUF(name, __VA_ARGS__##type_name, __VA_ARGS__##type)
+
+/*
+ * FUNC_FUNC_3BUF has the same expansion as OP_FUNC_3BUF: the generated host
+ * wrapper only forwards to the matching ..._submit function, so it does not
+ * matter here how the element-wise operation itself is expressed.
+ */
+#define FUNC_FUNC_3BUF(name, type_name, type, ...) FUNC_3BUF(name, __VA_ARGS__##type_name, __VA_ARGS__##type)
+
+/*
+ * Since all the functions in this file are essentially identical, we
+ * use a macro to substitute in names and types. The core operation
+ * in all functions that use this macro is the same.
+ *
+ * This macro is for minloc and maxloc: it expands FUNC_3BUF with the element
+ * type ompi_op_predefined_<type_name>_t.
+ */
+#define LOC_FUNC_3BUF(name, type_name) FUNC_3BUF(name, type_name, ompi_op_predefined_##type_name##_t)
+
+
+/*
+ * Dispatch Fortran integer types to C types (three-buffer variant).
+ *
+ * Forwards all arguments unchanged to the int8_t/int16_t/int32_t/int64_t
+ * variant of the same operation whose size equals sizeof(type). The
+ * _Static_assert rejects types smaller than int8_t or larger than int64_t.
+ *
+ * NOTE(review): the fortran_integer16 instantiations further down pass
+ * ompi_fortran_integer16_t; if that type is wider than int64_t the assertion
+ * fails whenever OMPI_HAVE_FORTRAN_INTEGER16 is set -- confirm.
+ */
+#define FORT_INT_FUNC_3BUF(name, type_name, type) \
+    static \
+    void ompi_op_cuda_3buff_##name##_##type_name(const void *in1, const void *in2, void *out, int *count, \
+                                                 struct ompi_datatype_t **dtype, \
+                                                 int device, \
+                                                 opal_accelerator_stream_t *stream, \
+                                                 struct ompi_op_base_module_1_0_0_t *module) { \
+        \
+        _Static_assert(sizeof(type) >= sizeof(int8_t) && sizeof(type) <= sizeof(int64_t), \
+                       "Unsupported Fortran integer size"); \
+        switch(sizeof(type)) { \
+            case sizeof(int8_t): \
+                ompi_op_cuda_3buff_##name##_int8_t(in1, in2, out, count, dtype, device, stream, module); \
+                break; \
+            case sizeof(int16_t): \
+                ompi_op_cuda_3buff_##name##_int16_t(in1, in2, out, count, dtype, device, stream, module); \
+                break; \
+            case sizeof(int32_t): \
+                ompi_op_cuda_3buff_##name##_int32_t(in1, in2, out, count, dtype, device, stream, module); \
+                break; \
+            case sizeof(int64_t): \
+                ompi_op_cuda_3buff_##name##_int64_t(in1, in2, out, count, dtype, device, stream, module); \
+                break; \
+        } \
+    }
+
+/*
+ * Dispatch Fortran floating-point types to C types (three-buffer variant).
+ *
+ * Forwards all arguments unchanged to the float/double/long_double variant of
+ * the same operation whose size equals sizeof(type). The _Static_assert rejects
+ * types smaller than float or larger than long double.
+ */
+#define FORT_FLOAT_FUNC_3BUF(name, type_name, type) \
+    static \
+    void ompi_op_cuda_3buff_##name##_##type_name(const void *in1, const void *in2, void *out, int *count, \
+                                                 struct ompi_datatype_t **dtype, \
+                                                 int device, \
+                                                 opal_accelerator_stream_t *stream, \
+                                                 struct ompi_op_base_module_1_0_0_t *module) { \
+        _Static_assert(sizeof(type) >= sizeof(float) && sizeof(type) <= sizeof(long double), \
+                       "Unsupported Fortran float size"); \
+        switch(sizeof(type)) { \
+            case sizeof(float): \
+                ompi_op_cuda_3buff_##name##_float(in1, in2, out, count, dtype, device, stream, module); \
+                break; \
+            case sizeof(double): \
+                ompi_op_cuda_3buff_##name##_double(in1, in2, out, count, dtype, device, stream, module); \
+                break; \
+            case sizeof(long double): \
+                ompi_op_cuda_3buff_##name##_long_double(in1, in2, out, count, dtype, device, stream, module); \
+                break; \
+        } \
+    }
+
+/*
+ * Dispatch a Fortran integer pair type (minloc/maxloc) to the
+ * 2int8/2int16/2int32/2int64 three-buffer variant whose element size equals
+ * sizeof(type). All arguments are forwarded unchanged.
+ */
+#define FORT_LOC_INT_FUNC_3BUF(name, type_name, type) \
+    static \
+    void ompi_op_cuda_3buff_##name##_##type_name(const void *in1, const void *in2, void *inout, int *count, \
+                                                 struct ompi_datatype_t **dtype, \
+                                                 int device, \
+                                                 opal_accelerator_stream_t *stream, \
+                                                 struct ompi_op_base_module_1_0_0_t *module) { \
+        \
+        _Static_assert(sizeof(type) >= sizeof(int8_t) && sizeof(type) <= sizeof(int64_t), \
+                       "Unsupported Fortran integer size"); \
+        switch(sizeof(type)) { \
+            case sizeof(int8_t): \
+                ompi_op_cuda_3buff_##name##_2int8(in1, in2, inout, count, dtype, device, stream, module); \
+                break; \
+            case sizeof(int16_t): \
+                ompi_op_cuda_3buff_##name##_2int16(in1, in2, inout, count, dtype, device, stream, module); \
+                break; \
+            case sizeof(int32_t): \
+                ompi_op_cuda_3buff_##name##_2int32(in1, in2, inout, count, dtype, device, stream, module); \
+                break; \
+            case sizeof(int64_t): \
+                ompi_op_cuda_3buff_##name##_2int64(in1, in2, inout, count, dtype, device, stream, module); \
+                break; \
+        } \
+    }
+
+/*
+ * Dispatch a Fortran floating-point pair type (minloc/maxloc) to the 2float or
+ * 2double three-buffer variant whose element size equals sizeof(type). All
+ * arguments are forwarded unchanged. The _Static_assert rejects element types
+ * smaller than float or larger than double.
+ */
+#define FORT_LOC_FLOAT_FUNC_3BUF(name, type_name, type) \
+    static \
+    void ompi_op_cuda_3buff_##name##_##type_name(const void *in1, const void *in2, void *inout, int *count, \
+                                                 struct ompi_datatype_t **dtype, \
+                                                 int device, \
+                                                 opal_accelerator_stream_t *stream, \
+                                                 struct ompi_op_base_module_1_0_0_t *module) { \
+        _Static_assert(sizeof(type) >= sizeof(float) && sizeof(type) <= sizeof(double), \
+                       "Unsupported Fortran float size"); \
+        switch(sizeof(type)) { \
+            case sizeof(float): \
+                ompi_op_cuda_3buff_##name##_2float(in1, in2, inout, count, dtype, device, stream, module); \
+                break; \
+            case sizeof(double): \
+                ompi_op_cuda_3buff_##name##_2double(in1, in2, inout, count, dtype, device, stream, module); \
+                break; \
+        } \
+    }
+
+
+/*************************************************************************
+ * Max
+ *************************************************************************/
+
+/* C integer */
+FUNC_FUNC_3BUF(max, int8_t, int8_t)
+FUNC_FUNC_3BUF(max, uint8_t, uint8_t)
+FUNC_FUNC_3BUF(max, int16_t, int16_t)
+FUNC_FUNC_3BUF(max, uint16_t, uint16_t)
+FUNC_FUNC_3BUF(max, int32_t, int32_t)
+FUNC_FUNC_3BUF(max, uint32_t, uint32_t)
+FUNC_FUNC_3BUF(max, int64_t, int64_t)
+FUNC_FUNC_3BUF(max, uint64_t, uint64_t)
+FUNC_FUNC_3BUF(max, long, long)
+FUNC_FUNC_3BUF(max, ulong, unsigned long)
+
+/* Fortran integer */
+/* NOTE(review): unlike the min/sum/prod three-buffer sections, no
+ * fortran_integer16 variant is generated for max here -- confirm intended. */
+#if OMPI_HAVE_FORTRAN_INTEGER
+FORT_INT_FUNC_3BUF(max, fortran_integer, ompi_fortran_integer_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER1
+FORT_INT_FUNC_3BUF(max, fortran_integer1, ompi_fortran_integer1_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER2
+FORT_INT_FUNC_3BUF(max, fortran_integer2, ompi_fortran_integer2_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER4
+FORT_INT_FUNC_3BUF(max, fortran_integer4, ompi_fortran_integer4_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER8
+FORT_INT_FUNC_3BUF(max, fortran_integer8, ompi_fortran_integer8_t)
+#endif
+/* Floating point */
+#if 0
+#if defined(HAVE_SHORT_FLOAT)
+FUNC_FUNC_3BUF(max, short_float, short float)
+#elif defined(HAVE_OPAL_SHORT_FLOAT_T)
+FUNC_FUNC_3BUF(max, short_float, opal_short_float_t)
+#endif
+#endif // 0
+FUNC_FUNC_3BUF(max, float, float)
+FUNC_FUNC_3BUF(max, double, double)
+FUNC_FUNC_3BUF(max, long_double, long double)
+#if OMPI_HAVE_FORTRAN_REAL
+FORT_FLOAT_FUNC_3BUF(max, fortran_real, ompi_fortran_real_t)
+#endif
+#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION
+FORT_FLOAT_FUNC_3BUF(max,
fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC_3BUF(max, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC_3BUF(max, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC_3BUF(max, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC_3BUF(max, fortran_real16, ompi_fortran_real16_t) +#endif + + +/************************************************************************* + * Min + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF(min, int8_t, int8_t) +FUNC_FUNC_3BUF(min, uint8_t, uint8_t) +FUNC_FUNC_3BUF(min, int16_t, int16_t) +FUNC_FUNC_3BUF(min, uint16_t, uint16_t) +FUNC_FUNC_3BUF(min, int32_t, int32_t) +FUNC_FUNC_3BUF(min, uint32_t, uint32_t) +FUNC_FUNC_3BUF(min, int64_t, int64_t) +FUNC_FUNC_3BUF(min, uint64_t, uint64_t) +FUNC_FUNC_3BUF(min, long, long) +FUNC_FUNC_3BUF(min, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(min, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(min, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(min, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(min, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(min, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(min, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Floating point */ +#if 0 +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC_3BUF(min, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_3BUF(min, short_float, opal_short_float_t) +#endif +#endif // 0 +FUNC_FUNC_3BUF(min, float, 
float) +FUNC_FUNC_3BUF(min, double, double) +FUNC_FUNC_3BUF(min, long_double, long double) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC_3BUF(min, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC_3BUF(min, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC_3BUF(min, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC_3BUF(min, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC_3BUF(min, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC_3BUF(min, fortran_real16, ompi_fortran_real16_t) +#endif + +/************************************************************************* + * Sum + *************************************************************************/ + +/* C integer */ +OP_FUNC_3BUF(sum, int8_t, int8_t) +OP_FUNC_3BUF(sum, uint8_t, uint8_t) +OP_FUNC_3BUF(sum, int16_t, int16_t) +OP_FUNC_3BUF(sum, uint16_t, uint16_t) +OP_FUNC_3BUF(sum, int32_t, int32_t) +OP_FUNC_3BUF(sum, uint32_t, uint32_t) +OP_FUNC_3BUF(sum, int64_t, int64_t) +OP_FUNC_3BUF(sum, uint64_t, uint64_t) +OP_FUNC_3BUF(sum, long, long) +OP_FUNC_3BUF(sum, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(sum, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(sum, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(sum, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(sum, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(sum, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(sum, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Floating 
point */ +#if 0 +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC_3BUF(sum, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +OP_FUNC_3BUF(sum, short_float, opal_short_float_t) +#endif +#endif // 0 +OP_FUNC_3BUF(sum, float, float) +OP_FUNC_3BUF(sum, double, double) +OP_FUNC_3BUF(sum, long_double, long double) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC_3BUF(sum, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC_3BUF(sum, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC_3BUF(sum, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC_3BUF(sum, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC_3BUF(sum, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC_3BUF(sum, fortran_real16, ompi_fortran_real16_t) +#endif +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF(sum, c_short_float_complex, short float _Complex) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_SUM_FUNC_3BUF(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC_3BUF(sum, c_float_complex, float _Complex) +OP_FUNC_3BUF(sum, c_double_complex, double _Complex) +OP_FUNC_3BUF(sum, c_long_double_complex, long double _Complex) +#endif // 0 + +FUNC_FUNC_3BUF(sum, c_float_complex, cuFloatComplex) +FUNC_FUNC_3BUF(sum, c_double_complex, cuDoubleComplex) + +/************************************************************************* + * Product + *************************************************************************/ + +/* C integer */ +OP_FUNC_3BUF(prod, int8_t, int8_t) +OP_FUNC_3BUF(prod, uint8_t, uint8_t) +OP_FUNC_3BUF(prod, int16_t, int16_t) +OP_FUNC_3BUF(prod, uint16_t, uint16_t) +OP_FUNC_3BUF(prod, int32_t, int32_t) +OP_FUNC_3BUF(prod, uint32_t, uint32_t) +OP_FUNC_3BUF(prod, int64_t, int64_t) +OP_FUNC_3BUF(prod, 
uint64_t, uint64_t) +OP_FUNC_3BUF(prod, long, long) +OP_FUNC_3BUF(prod, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(prod, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(prod, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(prod, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(prod, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(prod, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(prod, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Floating point */ +#if 0 +#if defined(HAVE_SHORT_FLOAT) +FORT_FLOAT_FUNC_3BUF(prod, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FORT_FLOAT_FUNC_3BUF(prod, short_float, opal_short_float_t) +#endif +#endif // 0 +OP_FUNC_3BUF(prod, float, float) +OP_FUNC_3BUF(prod, double, double) +OP_FUNC_3BUF(prod, long_double, long double) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC_3BUF(prod, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC_3BUF(prod, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC_3BUF(prod, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC_3BUF(prod, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC_3BUF(prod, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC_3BUF(prod, fortran_real16, ompi_fortran_real16_t) +#endif +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF(prod, c_short_float_complex, short float _Complex) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) 
+COMPLEX_PROD_FUNC_3BUF(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC_3BUF(prod, c_long_double_complex, long double _Complex) +#endif // 0 + +FUNC_FUNC_3BUF(prod, c_float_complex, cuFloatComplex) +FUNC_FUNC_3BUF(prod, c_double_complex, cuDoubleComplex) + +/************************************************************************* + * Logical AND + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF(land, int8_t, int8_t) +FUNC_FUNC_3BUF(land, uint8_t, uint8_t) +FUNC_FUNC_3BUF(land, int16_t, int16_t) +FUNC_FUNC_3BUF(land, uint16_t, uint16_t) +FUNC_FUNC_3BUF(land, int32_t, int32_t) +FUNC_FUNC_3BUF(land, uint32_t, uint32_t) +FUNC_FUNC_3BUF(land, int64_t, int64_t) +FUNC_FUNC_3BUF(land, uint64_t, uint64_t) +FUNC_FUNC_3BUF(land, long, long) +FUNC_FUNC_3BUF(land, ulong, unsigned long) + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FORT_INT_FUNC_3BUF(land, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC_3BUF(land, bool, bool) + +/************************************************************************* + * Logical OR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF(lor, int8_t, int8_t) +FUNC_FUNC_3BUF(lor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(lor, int16_t, int16_t) +FUNC_FUNC_3BUF(lor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(lor, int32_t, int32_t) +FUNC_FUNC_3BUF(lor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(lor, int64_t, int64_t) +FUNC_FUNC_3BUF(lor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(lor, long, long) +FUNC_FUNC_3BUF(lor, ulong, unsigned long) + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FORT_INT_FUNC_3BUF(lor, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC_3BUF(lor, bool, bool) + +/************************************************************************* + * Logical XOR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF(lxor, 
int8_t, int8_t) +FUNC_FUNC_3BUF(lxor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(lxor, int16_t, int16_t) +FUNC_FUNC_3BUF(lxor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(lxor, int32_t, int32_t) +FUNC_FUNC_3BUF(lxor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(lxor, int64_t, int64_t) +FUNC_FUNC_3BUF(lxor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(lxor, long, long) +FUNC_FUNC_3BUF(lxor, ulong, unsigned long) + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FORT_INT_FUNC_3BUF(lxor, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC_3BUF(lxor, bool, bool) + +/************************************************************************* + * Bitwise AND + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF(band, int8_t, int8_t) +FUNC_FUNC_3BUF(band, uint8_t, uint8_t) +FUNC_FUNC_3BUF(band, int16_t, int16_t) +FUNC_FUNC_3BUF(band, uint16_t, uint16_t) +FUNC_FUNC_3BUF(band, int32_t, int32_t) +FUNC_FUNC_3BUF(band, uint32_t, uint32_t) +FUNC_FUNC_3BUF(band, int64_t, int64_t) +FUNC_FUNC_3BUF(band, uint64_t, uint64_t) +FUNC_FUNC_3BUF(band, long, long) +FUNC_FUNC_3BUF(band, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(band, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(band, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(band, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(band, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(band, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(band, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FORT_INT_FUNC_3BUF(band, byte, char) + +/************************************************************************* + * Bitwise OR + 
*************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF(bor, int8_t, int8_t) +FUNC_FUNC_3BUF(bor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(bor, int16_t, int16_t) +FUNC_FUNC_3BUF(bor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(bor, int32_t, int32_t) +FUNC_FUNC_3BUF(bor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(bor, int64_t, int64_t) +FUNC_FUNC_3BUF(bor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(bor, long, long) +FUNC_FUNC_3BUF(bor, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(bor, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(bor, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(bor, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(bor, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(bor, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(bor, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FORT_INT_FUNC_3BUF(bor, byte, char) + +/************************************************************************* + * Bitwise XOR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF(bxor, int8_t, int8_t) +FUNC_FUNC_3BUF(bxor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(bxor, int16_t, int16_t) +FUNC_FUNC_3BUF(bxor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(bxor, int32_t, int32_t) +FUNC_FUNC_3BUF(bxor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(bxor, int64_t, int64_t) +FUNC_FUNC_3BUF(bxor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(bxor, long, long) +FUNC_FUNC_3BUF(bxor, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(bxor, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(bxor, 
fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(bxor, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(bxor, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(bxor, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(bxor, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FORT_INT_FUNC_3BUF(bxor, byte, char) + +/************************************************************************* + * Max location + *************************************************************************/ + +LOC_FUNC_3BUF(maxloc, float_int) +LOC_FUNC_3BUF(maxloc, double_int) +LOC_FUNC_3BUF(maxloc, long_int) +LOC_FUNC_3BUF(maxloc, 2int) +LOC_FUNC_3BUF(maxloc, short_int) +LOC_FUNC_3BUF(maxloc, long_double_int) + +/* Fortran compat types */ +LOC_FUNC_3BUF(maxloc, 2float) +LOC_FUNC_3BUF(maxloc, 2double) +LOC_FUNC_3BUF(maxloc, 2int8) +LOC_FUNC_3BUF(maxloc, 2int16) +LOC_FUNC_3BUF(maxloc, 2int32) +LOC_FUNC_3BUF(maxloc, 2int64) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_LOC_INT_FUNC_3BUF(maxloc, 2integer, ompi_fortran_integer_t) +#endif +/* Fortran float */ +#if OMPI_HAVE_FORTRAN_REAL +FORT_LOC_FLOAT_FUNC_3BUF(maxloc, 2real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_LOC_FLOAT_FUNC_3BUF(maxloc, 2double_precision, ompi_fortran_double_precision_t) +#endif + + +/************************************************************************* + * Min location + *************************************************************************/ + +LOC_FUNC_3BUF(minloc, float_int) +LOC_FUNC_3BUF(minloc, double_int) +LOC_FUNC_3BUF(minloc, long_int) +LOC_FUNC_3BUF(minloc, 2int) +LOC_FUNC_3BUF(minloc, short_int) +LOC_FUNC_3BUF(minloc, long_double_int) + +/* Fortran compat types */ +LOC_FUNC_3BUF(minloc, 2float) +LOC_FUNC_3BUF(minloc, 2double) 
+LOC_FUNC_3BUF(minloc, 2int8) +LOC_FUNC_3BUF(minloc, 2int16) +LOC_FUNC_3BUF(minloc, 2int32) +LOC_FUNC_3BUF(minloc, 2int64) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_LOC_INT_FUNC_3BUF(minloc, 2integer, ompi_fortran_integer_t) +#endif +/* Fortran float */ +#if OMPI_HAVE_FORTRAN_REAL +FORT_LOC_FLOAT_FUNC_3BUF(minloc, 2real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_LOC_FLOAT_FUNC_3BUF(minloc, 2double_precision, ompi_fortran_double_precision_t) +#endif + +/* + * Helpful defines, because there's soooo many names! + * + * **NOTE** These #define's used to be strictly ordered but the use of + * designated initializers removed this restriction. When adding new + * operators ALWAYS use a designated initializer! + */ + +/** C integer ***********************************************************/ +#define C_INTEGER(name, ftype) \ + [OMPI_OP_BASE_TYPE_INT8_T] = ompi_op_cuda_##ftype##_##name##_int8_t, \ + [OMPI_OP_BASE_TYPE_UINT8_T] = ompi_op_cuda_##ftype##_##name##_uint8_t, \ + [OMPI_OP_BASE_TYPE_INT16_T] = ompi_op_cuda_##ftype##_##name##_int16_t, \ + [OMPI_OP_BASE_TYPE_UINT16_T] = ompi_op_cuda_##ftype##_##name##_uint16_t, \ + [OMPI_OP_BASE_TYPE_INT32_T] = ompi_op_cuda_##ftype##_##name##_int32_t, \ + [OMPI_OP_BASE_TYPE_UINT32_T] = ompi_op_cuda_##ftype##_##name##_uint32_t, \ + [OMPI_OP_BASE_TYPE_INT64_T] = ompi_op_cuda_##ftype##_##name##_int64_t, \ + [OMPI_OP_BASE_TYPE_LONG] = ompi_op_cuda_##ftype##_##name##_long, \ + [OMPI_OP_BASE_TYPE_UNSIGNED_LONG] = ompi_op_cuda_##ftype##_##name##_ulong, \ + [OMPI_OP_BASE_TYPE_UINT64_T] = ompi_op_cuda_##ftype##_##name##_uint64_t + +/** All the Fortran integers ********************************************/ + +#if OMPI_HAVE_FORTRAN_INTEGER +#define FORTRAN_INTEGER_PLAIN(name, ftype) ompi_op_cuda_##ftype##_##name##_fortran_integer +#else +#define FORTRAN_INTEGER_PLAIN(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +#define FORTRAN_INTEGER1(name, ftype) 
ompi_op_cuda_##ftype##_##name##_fortran_integer1 +#else +#define FORTRAN_INTEGER1(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +#define FORTRAN_INTEGER2(name, ftype) ompi_op_cuda_##ftype##_##name##_fortran_integer2 +#else +#define FORTRAN_INTEGER2(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +#define FORTRAN_INTEGER4(name, ftype) ompi_op_cuda_##ftype##_##name##_fortran_integer4 +#else +#define FORTRAN_INTEGER4(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +#define FORTRAN_INTEGER8(name, ftype) ompi_op_cuda_##ftype##_##name##_fortran_integer8 +#else +#define FORTRAN_INTEGER8(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +#define FORTRAN_INTEGER16(name, ftype) ompi_op_cuda_##ftype##_##name##_fortran_integer16 +#else +#define FORTRAN_INTEGER16(name, ftype) NULL +#endif + +#define FORTRAN_INTEGER(name, ftype) \ + [OMPI_OP_BASE_TYPE_INTEGER] = FORTRAN_INTEGER_PLAIN(name, ftype), \ + [OMPI_OP_BASE_TYPE_INTEGER1] = FORTRAN_INTEGER1(name, ftype), \ + [OMPI_OP_BASE_TYPE_INTEGER2] = FORTRAN_INTEGER2(name, ftype), \ + [OMPI_OP_BASE_TYPE_INTEGER4] = FORTRAN_INTEGER4(name, ftype), \ + [OMPI_OP_BASE_TYPE_INTEGER8] = FORTRAN_INTEGER8(name, ftype), \ + [OMPI_OP_BASE_TYPE_INTEGER16] = FORTRAN_INTEGER16(name, ftype) + +/** All the Fortran reals ***********************************************/ + +#if OMPI_HAVE_FORTRAN_REAL +#define FLOATING_POINT_FORTRAN_REAL_PLAIN(name, ftype) ompi_op_cuda_##ftype##_##name##_fortran_real +#else +#define FLOATING_POINT_FORTRAN_REAL_PLAIN(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +#define FLOATING_POINT_FORTRAN_REAL2(name, ftype) ompi_op_cuda_##ftype##_##name##_fortran_real2 +#else +#define FLOATING_POINT_FORTRAN_REAL2(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +#define FLOATING_POINT_FORTRAN_REAL4(name, ftype) ompi_op_cuda_##ftype##_##name##_fortran_real4 +#else +#define FLOATING_POINT_FORTRAN_REAL4(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +#define 
FLOATING_POINT_FORTRAN_REAL8(name, ftype) ompi_op_cuda_##ftype##_##name##_fortran_real8 +#else +#define FLOATING_POINT_FORTRAN_REAL8(name, ftype) NULL +#endif +/* If: + - we have fortran REAL*16, *and* + - fortran REAL*16 matches the bit representation of the + corresponding C type + Only then do we put in function pointers for REAL*16 reductions. + Otherwise, just put in NULL. */ +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +#define FLOATING_POINT_FORTRAN_REAL16(name, ftype) ompi_op_cuda_##ftype##_##name##_fortran_real16 +#else +#define FLOATING_POINT_FORTRAN_REAL16(name, ftype) NULL +#endif + +#define FLOATING_POINT_FORTRAN_REAL(name, ftype) \ + [OMPI_OP_BASE_TYPE_REAL] = FLOATING_POINT_FORTRAN_REAL_PLAIN(name, ftype), \ + [OMPI_OP_BASE_TYPE_REAL2] = FLOATING_POINT_FORTRAN_REAL2(name, ftype), \ + [OMPI_OP_BASE_TYPE_REAL4] = FLOATING_POINT_FORTRAN_REAL4(name, ftype), \ + [OMPI_OP_BASE_TYPE_REAL8] = FLOATING_POINT_FORTRAN_REAL8(name, ftype), \ + [OMPI_OP_BASE_TYPE_REAL16] = FLOATING_POINT_FORTRAN_REAL16(name, ftype) + +/** Fortran double precision ********************************************/ + +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +#define FLOATING_POINT_FORTRAN_DOUBLE_PRECISION(name, ftype) \ + ompi_op_cuda_##ftype##_##name##_fortran_double_precision +#else +#define FLOATING_POINT_FORTRAN_DOUBLE_PRECISION(name, ftype) NULL +#endif + +/** Floating point, including all the Fortran reals *********************/ + +//#if defined(HAVE_SHORT_FLOAT) || defined(HAVE_OPAL_SHORT_FLOAT_T) +//#define SHORT_FLOAT(name, ftype) ompi_op_cuda_##ftype##_##name##_short_float +//#else +#define SHORT_FLOAT(name, ftype) NULL +//#endif +#define FLOAT(name, ftype) ompi_op_cuda_##ftype##_##name##_float +#define DOUBLE(name, ftype) ompi_op_cuda_##ftype##_##name##_double +#define LONG_DOUBLE(name, ftype) ompi_op_cuda_##ftype##_##name##_long_double + +#define FLOATING_POINT(name, ftype) \ + [OMPI_OP_BASE_TYPE_SHORT_FLOAT] = SHORT_FLOAT(name, ftype), \ + [OMPI_OP_BASE_TYPE_FLOAT] = 
FLOAT(name, ftype), \ + [OMPI_OP_BASE_TYPE_DOUBLE] = DOUBLE(name, ftype), \ + FLOATING_POINT_FORTRAN_REAL(name, ftype), \ + [OMPI_OP_BASE_TYPE_DOUBLE_PRECISION] = FLOATING_POINT_FORTRAN_DOUBLE_PRECISION(name, ftype), \ + [OMPI_OP_BASE_TYPE_LONG_DOUBLE] = LONG_DOUBLE(name, ftype) + +/** Fortran logical *****************************************************/ + +#if OMPI_HAVE_FORTRAN_LOGICAL +#define FORTRAN_LOGICAL(name, ftype) \ + ompi_op_cuda_##ftype##_##name##_fortran_logical /* OMPI_OP_CUDA_TYPE_LOGICAL */ +#else +#define FORTRAN_LOGICAL(name, ftype) NULL +#endif + +#define LOGICAL(name, ftype) \ + [OMPI_OP_BASE_TYPE_LOGICAL] = FORTRAN_LOGICAL(name, ftype), \ + [OMPI_OP_BASE_TYPE_BOOL] = ompi_op_cuda_##ftype##_##name##_bool + +/** Complex *****************************************************/ +#if 0 + +#if defined(HAVE_SHORT_FLOAT__COMPLEX) || defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +#define SHORT_FLOAT_COMPLEX(name, ftype) ompi_op_cuda_##ftype##_##name##_c_short_float_complex +#else +#define SHORT_FLOAT_COMPLEX(name, ftype) NULL +#endif +#define LONG_DOUBLE_COMPLEX(name, ftype) ompi_op_cuda_##ftype##_##name##_c_long_double_complex +#else +#define SHORT_FLOAT_COMPLEX(name, ftype) NULL +#define LONG_DOUBLE_COMPLEX(name, ftype) NULL +#endif // 0 +#define FLOAT_COMPLEX(name, ftype) ompi_op_cuda_##ftype##_##name##_c_float_complex +#define DOUBLE_COMPLEX(name, ftype) ompi_op_cuda_##ftype##_##name##_c_double_complex + +#define COMPLEX(name, ftype) \ + [OMPI_OP_BASE_TYPE_C_SHORT_FLOAT_COMPLEX] = SHORT_FLOAT_COMPLEX(name, ftype), \ + [OMPI_OP_BASE_TYPE_C_FLOAT_COMPLEX] = FLOAT_COMPLEX(name, ftype), \ + [OMPI_OP_BASE_TYPE_C_DOUBLE_COMPLEX] = DOUBLE_COMPLEX(name, ftype), \ + [OMPI_OP_BASE_TYPE_C_LONG_DOUBLE_COMPLEX] = LONG_DOUBLE_COMPLEX(name, ftype) + +/** Byte ****************************************************************/ + +#define BYTE(name, ftype) \ + [OMPI_OP_BASE_TYPE_BYTE] = ompi_op_cuda_##ftype##_##name##_byte + +/** Fortran complex 
*****************************************************/ +/** Fortran "2" types ***************************************************/ + +#if OMPI_HAVE_FORTRAN_REAL +#define TWOLOC_FORTRAN_2REAL(name, ftype) ompi_op_cuda_##ftype##_##name##_2real +#else +#define TWOLOC_FORTRAN_2REAL(name, ftype) NULL +#endif + +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +#define TWOLOC_FORTRAN_2DOUBLE_PRECISION(name, ftype) ompi_op_cuda_##ftype##_##name##_2double_precision +#else +#define TWOLOC_FORTRAN_2DOUBLE_PRECISION(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER +#define TWOLOC_FORTRAN_2INTEGER(name, ftype) ompi_op_cuda_##ftype##_##name##_2integer +#else +#define TWOLOC_FORTRAN_2INTEGER(name, ftype) NULL +#endif + +/** All "2" types *******************************************************/ + +#define TWOLOC(name, ftype) \ + [OMPI_OP_BASE_TYPE_2REAL] = TWOLOC_FORTRAN_2REAL(name, ftype), \ + [OMPI_OP_BASE_TYPE_2DOUBLE_PRECISION] = TWOLOC_FORTRAN_2DOUBLE_PRECISION(name, ftype), \ + [OMPI_OP_BASE_TYPE_2INTEGER] = TWOLOC_FORTRAN_2INTEGER(name, ftype), \ + [OMPI_OP_BASE_TYPE_FLOAT_INT] = ompi_op_cuda_##ftype##_##name##_float_int, \ + [OMPI_OP_BASE_TYPE_DOUBLE_INT] = ompi_op_cuda_##ftype##_##name##_double_int, \ + [OMPI_OP_BASE_TYPE_LONG_INT] = ompi_op_cuda_##ftype##_##name##_long_int, \ + [OMPI_OP_BASE_TYPE_2INT] = ompi_op_cuda_##ftype##_##name##_2int, \ + [OMPI_OP_BASE_TYPE_SHORT_INT] = ompi_op_cuda_##ftype##_##name##_short_int, \ + [OMPI_OP_BASE_TYPE_LONG_DOUBLE_INT] = ompi_op_cuda_##ftype##_##name##_long_double_int + +/* + * MPI_OP_NULL + * All types + */ +#define FLAGS_NO_FLOAT \ + (OMPI_OP_FLAGS_INTRINSIC | OMPI_OP_FLAGS_ASSOC | OMPI_OP_FLAGS_COMMUTE) +#define FLAGS \ + (OMPI_OP_FLAGS_INTRINSIC | OMPI_OP_FLAGS_ASSOC | \ + OMPI_OP_FLAGS_FLOAT_ASSOC | OMPI_OP_FLAGS_COMMUTE) + +ompi_op_base_stream_handler_fn_t ompi_op_cuda_functions[OMPI_OP_BASE_FORTRAN_OP_MAX][OMPI_OP_BASE_TYPE_MAX] = + { + /* Corresponds to MPI_OP_NULL */ + [OMPI_OP_BASE_FORTRAN_NULL] = { + /* Leaving this empty 
puts in NULL for all entries */ + NULL, + }, + /* Corresponds to MPI_MAX */ + [OMPI_OP_BASE_FORTRAN_MAX] = { + C_INTEGER(max, 2buff), + FORTRAN_INTEGER(max, 2buff), + FLOATING_POINT(max, 2buff), + }, + /* Corresponds to MPI_MIN */ + [OMPI_OP_BASE_FORTRAN_MIN] = { + C_INTEGER(min, 2buff), + FORTRAN_INTEGER(min, 2buff), + FLOATING_POINT(min, 2buff), + }, + /* Corresponds to MPI_SUM */ + [OMPI_OP_BASE_FORTRAN_SUM] = { + C_INTEGER(sum, 2buff), + FORTRAN_INTEGER(sum, 2buff), + FLOATING_POINT(sum, 2buff), + COMPLEX(sum, 2buff), + }, + /* Corresponds to MPI_PROD */ + [OMPI_OP_BASE_FORTRAN_PROD] = { + C_INTEGER(prod, 2buff), + FORTRAN_INTEGER(prod, 2buff), + FLOATING_POINT(prod, 2buff), + COMPLEX(prod, 2buff), + }, + /* Corresponds to MPI_LAND */ + [OMPI_OP_BASE_FORTRAN_LAND] = { + C_INTEGER(land, 2buff), + LOGICAL(land, 2buff), + }, + /* Corresponds to MPI_BAND */ + [OMPI_OP_BASE_FORTRAN_BAND] = { + C_INTEGER(band, 2buff), + FORTRAN_INTEGER(band, 2buff), + BYTE(band, 2buff), + }, + /* Corresponds to MPI_LOR */ + [OMPI_OP_BASE_FORTRAN_LOR] = { + C_INTEGER(lor, 2buff), + LOGICAL(lor, 2buff), + }, + /* Corresponds to MPI_BOR */ + [OMPI_OP_BASE_FORTRAN_BOR] = { + C_INTEGER(bor, 2buff), + FORTRAN_INTEGER(bor, 2buff), + BYTE(bor, 2buff), + }, + /* Corresponds to MPI_LXOR */ + [OMPI_OP_BASE_FORTRAN_LXOR] = { + C_INTEGER(lxor, 2buff), + LOGICAL(lxor, 2buff), + }, + /* Corresponds to MPI_BXOR */ + [OMPI_OP_BASE_FORTRAN_BXOR] = { + C_INTEGER(bxor, 2buff), + FORTRAN_INTEGER(bxor, 2buff), + BYTE(bxor, 2buff), + }, + /* Corresponds to MPI_MAXLOC */ + [OMPI_OP_BASE_FORTRAN_MAXLOC] = { + TWOLOC(maxloc, 2buff), + }, + /* Corresponds to MPI_MINLOC */ + [OMPI_OP_BASE_FORTRAN_MINLOC] = { + TWOLOC(minloc, 2buff), + }, + /* Corresponds to MPI_REPLACE */ + [OMPI_OP_BASE_FORTRAN_REPLACE] = { + /* (MPI_ACCUMULATE is handled differently than the other + reductions, so just zero out its function + implementations here to ensure that users don't invoke + MPI_REPLACE with any reduction operations 
other than + ACCUMULATE) */ + NULL, + }, + + }; + +ompi_op_base_3buff_stream_handler_fn_t ompi_op_cuda_3buff_functions[OMPI_OP_BASE_FORTRAN_OP_MAX][OMPI_OP_BASE_TYPE_MAX] = + { + /* Corresponds to MPI_OP_NULL */ + [OMPI_OP_BASE_FORTRAN_NULL] = { + /* Leaving this empty puts in NULL for all entries */ + NULL, + }, + /* Corresponds to MPI_MAX */ + [OMPI_OP_BASE_FORTRAN_MAX] = { + C_INTEGER(max, 3buff), + FORTRAN_INTEGER(max, 3buff), + FLOATING_POINT(max, 3buff), + }, + /* Corresponds to MPI_MIN */ + [OMPI_OP_BASE_FORTRAN_MIN] = { + C_INTEGER(min, 3buff), + FORTRAN_INTEGER(min, 3buff), + FLOATING_POINT(min, 3buff), + }, + /* Corresponds to MPI_SUM */ + [OMPI_OP_BASE_FORTRAN_SUM] = { + C_INTEGER(sum, 3buff), + FORTRAN_INTEGER(sum, 3buff), + FLOATING_POINT(sum, 3buff), + COMPLEX(sum, 3buff), + }, + /* Corresponds to MPI_PROD */ + [OMPI_OP_BASE_FORTRAN_PROD] = { + C_INTEGER(prod, 3buff), + FORTRAN_INTEGER(prod, 3buff), + FLOATING_POINT(prod, 3buff), + COMPLEX(prod, 3buff), + }, + /* Corresponds to MPI_LAND */ + [OMPI_OP_BASE_FORTRAN_LAND] ={ + C_INTEGER(land, 3buff), + LOGICAL(land, 3buff), + }, + /* Corresponds to MPI_BAND */ + [OMPI_OP_BASE_FORTRAN_BAND] = { + C_INTEGER(band, 3buff), + FORTRAN_INTEGER(band, 3buff), + BYTE(band, 3buff), + }, + /* Corresponds to MPI_LOR */ + [OMPI_OP_BASE_FORTRAN_LOR] = { + C_INTEGER(lor, 3buff), + LOGICAL(lor, 3buff), + }, + /* Corresponds to MPI_BOR */ + [OMPI_OP_BASE_FORTRAN_BOR] = { + C_INTEGER(bor, 3buff), + FORTRAN_INTEGER(bor, 3buff), + BYTE(bor, 3buff), + }, + /* Corresponds to MPI_LXOR */ + [OMPI_OP_BASE_FORTRAN_LXOR] = { + C_INTEGER(lxor, 3buff), + LOGICAL(lxor, 3buff), + }, + /* Corresponds to MPI_BXOR */ + [OMPI_OP_BASE_FORTRAN_BXOR] = { + C_INTEGER(bxor, 3buff), + FORTRAN_INTEGER(bxor, 3buff), + BYTE(bxor, 3buff), + }, + /* Corresponds to MPI_MAXLOC */ + [OMPI_OP_BASE_FORTRAN_MAXLOC] = { + TWOLOC(maxloc, 3buff), + }, + /* Corresponds to MPI_MINLOC */ + [OMPI_OP_BASE_FORTRAN_MINLOC] = { + TWOLOC(minloc, 3buff), + }, + /* 
Corresponds to MPI_REPLACE */ + [OMPI_OP_BASE_FORTRAN_REPLACE] = { + /* MPI_ACCUMULATE is handled differently than the other + reductions, so just zero out its function + implementations here to ensure that users don't invoke + MPI_REPLACE with any reduction operations other than + ACCUMULATE */ + NULL, + }, + }; diff --git a/ompi/mca/op/cuda/op_cuda_impl.cu b/ompi/mca/op/cuda/op_cuda_impl.cu new file mode 100644 index 00000000000..3daf7f56fbb --- /dev/null +++ b/ompi/mca/op/cuda/op_cuda_impl.cu @@ -0,0 +1,1080 @@ +/* + * Copyright (c) 2019-2023 The University of Tennessee and The University + * of Tennessee Research Foundation. All rights + * reserved. + * Copyright (c) 2020 Research Organization for Information Science + * and Technology (RIST). All rights reserved. + * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "op_cuda_impl.h" + +#include + +#include + +#define ISSIGNED(x) std::is_signed_v + +template +static inline __device__ constexpr T tmax(T a, T b) { + return (a > b) ? a : b; +} + +template +static inline __device__ constexpr T tmin(T a, T b) { + return (a < b) ? 
a : b; +} + +template <typename T> +static inline __device__ constexpr T tsum(T a, T b) { + return a+b; +} + +template <typename T> +static inline __device__ constexpr T tprod(T a, T b) { + return a*b; +} + +template <typename T> /* T: vector type with members x, y, z, w */ +static inline __device__ T vmax(const T& a, const T& b) { + return T{tmax(a.x, b.x), tmax(a.y, b.y), tmax(a.z, b.z), tmax(a.w, b.w)}; +} + +template <typename T> /* T: vector type with members x, y, z, w */ +static inline __device__ T vmin(const T& a, const T& b) { + return T{tmin(a.x, b.x), tmin(a.y, b.y), tmin(a.z, b.z), tmin(a.w, b.w)}; +} + +template <typename T> /* T: vector type with members x, y, z, w */ +static inline __device__ T vsum(const T& a, const T& b) { + return T{tsum(a.x, b.x), tsum(a.y, b.y), tsum(a.z, b.z), tsum(a.w, b.w)}; +} + +template <typename T> /* T: vector type with members x, y, z, w */ +static inline __device__ T vprod(const T& a, const T& b) { + return T{(a.x * b.x), (a.y * b.y), (a.z * b.z), (a.w * b.w)}; +} + + +/* TODO: missing support for + * - short float (conditional on whether short float is available) + * - some Fortran types + * - some complex types + */ + +#define USE_VECTORS 1 + +#define OP_FUNC(name, type_name, type, op) \ + static __global__ void \ + ompi_op_cuda_2buff_##name##_##type_name##_kernel(const type *__restrict__ in, \ + type *__restrict__ inout, int n) { \ + const int index = blockIdx.x * blockDim.x + threadIdx.x; \ + const int stride = blockDim.x * gridDim.x; \ + for (int i = index; i < n; i += stride) { \ + /*if (index < n) { int i = index;*/ \ + inout[i] = inout[i] op in[i]; \ + } \ + } \ + void ompi_op_cuda_2buff_##name##_##type_name##_submit(const type *in, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + CUstream stream) { \ + int threads = min(count, threads_per_block); \ + int blocks = min((count + threads-1) / threads, max_blocks); \ + int n = count; \ + CUstream s = stream; \ + ompi_op_cuda_2buff_##name##_##type_name##_kernel<<<blocks, threads, 0, s>>>(in, inout, n); /* enqueue on stream s, no dynamic shared memory */ \ + } + + +#if defined(USE_VECTORS) +#define OPV_FUNC(name, type_name, type, vtype, vlen, op) \ + static __global__ void \ + ompi_op_cuda_2buff_##name##_##type_name##_kernel(const type *__restrict__ in, \ + 
type *__restrict__ inout, int n) { \ + const int index = blockIdx.x * blockDim.x + threadIdx.x; \ + const int stride = blockDim.x * gridDim.x; \ + for (int i = index; i < n/vlen; i += stride) { /* NOTE(review): assumes in/inout are aligned for vtype - confirm with callers */ \ + vtype vin = ((vtype*)in)[i]; \ + vtype vinout = ((vtype*)inout)[i]; \ + vinout.x = vinout.x op vin.x; \ + vinout.y = vinout.y op vin.y; \ + vinout.z = vinout.z op vin.z; \ + vinout.w = vinout.w op vin.w; \ + ((vtype*)inout)[i] = vinout; \ + } \ + int remainder = n%vlen; \ + if (index == 0 && remainder != 0) { /* thread 0 always exists; index n/vlen may not when the grid is capped by max_blocks */ \ + while(remainder) { \ + int idx = n - remainder--; \ + inout[idx] = inout[idx] op in[idx]; \ + } \ + } \ + } \ + void ompi_op_cuda_2buff_##name##_##type_name##_submit(const type *in, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + CUstream stream) { \ + int vcount = (count + vlen-1)/vlen; \ + int threads = min(threads_per_block, vcount); \ + int blocks = min((vcount + threads-1) / threads, max_blocks); \ + int n = count; \ + CUstream s = stream; \ + ompi_op_cuda_2buff_##name##_##type_name##_kernel<<<blocks, threads, 0, s>>>(in, inout, n); /* enqueue on stream s, no dynamic shared memory */ \ + } +#else // USE_VECTORS +#define OPV_FUNC(name, type_name, type, vtype, vlen, op) OP_FUNC(name, type_name, type, op) +#endif // USE_VECTORS + +#define FUNC_FUNC_FN(name, type_name, type, fn) \ + static __global__ void \ + ompi_op_cuda_2buff_##name##_##type_name##_kernel(const type *__restrict__ in, \ + type *__restrict__ inout, int n) { \ + const int index = blockIdx.x * blockDim.x + threadIdx.x; \ + const int stride = blockDim.x * gridDim.x; \ + for (int i = index; i < n; i += stride) { \ + inout[i] = fn(inout[i], in[i]); \ + } \ + } \ + void \ + ompi_op_cuda_2buff_##name##_##type_name##_submit(const type *in, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + CUstream stream) { \ + int threads = min(count, threads_per_block); \ + int blocks = min((count + threads-1) / threads, max_blocks); \ + int n = count; \ + CUstream s = stream; \ + 
ompi_op_cuda_2buff_##name##_##type_name##_kernel<<<blocks, threads, 0, s>>>(in, inout, n); /* enqueue on stream s, no dynamic shared memory */ \ + } + +#define FUNC_FUNC(name, type_name, type) FUNC_FUNC_FN(name, type_name, type, current_func) + +#if defined(USE_VECTORS) +#define VFUNC_FUNC(name, type_name, type, vtype, vlen, vfn, fn) \ + static __global__ void \ + ompi_op_cuda_2buff_##name##_##type_name##_kernel(const type *__restrict__ in, \ + type *__restrict__ inout, int n) { \ + const int index = blockIdx.x * blockDim.x + threadIdx.x; \ + const int stride = blockDim.x * gridDim.x; \ + for (int i = index; i < n/vlen; i += stride) { /* NOTE(review): assumes in/inout are aligned for vtype - confirm with callers */ \ + ((vtype*)inout)[i] = vfn(((vtype*)inout)[i], ((vtype*)in)[i]); \ + } \ + int remainder = n%vlen; \ + if (index == 0 && remainder != 0) { /* thread 0 always exists; index n/vlen may not when the grid is capped by max_blocks */ \ + while(remainder) { \ + int idx = n - remainder--; \ + inout[idx] = fn(inout[idx], in[idx]); \ + } \ + } \ + } \ + static void \ + ompi_op_cuda_2buff_##name##_##type_name##_submit(const type *in, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + CUstream stream) { \ + int vcount = (count + vlen-1)/vlen; \ + int threads = min(threads_per_block, vcount); \ + int blocks = min((vcount + threads-1) / threads, max_blocks); \ + int n = count; \ + CUstream s = stream; \ + ompi_op_cuda_2buff_##name##_##type_name##_kernel<<<blocks, threads, 0, s>>>(in, inout, n); /* enqueue on stream s, no dynamic shared memory */ \ + } +#else +#define VFUNC_FUNC(name, type_name, type, vtype, vlen, vfn, fn) FUNC_FUNC_FN(name, type_name, type, fn) +#endif // defined(USE_VECTORS) + +/* + * Since all the functions in this file are essentially identical, we + * use a macro to substitute in names and types. The core operation + * in all functions that use this macro is the same. 
+ * + * This macro is for minloc and maxloc + */ + +#define LOC_FUNC(name, type_name, op) \ + static __global__ void \ + ompi_op_cuda_2buff_##name##_##type_name##_kernel(const ompi_op_predefined_##type_name##_t *__restrict__ in, \ + ompi_op_predefined_##type_name##_t *__restrict__ inout, \ + int n) \ + { \ + const int index = blockIdx.x * blockDim.x + threadIdx.x; \ + const int stride = blockDim.x * gridDim.x; \ + for (int i = index; i < n; i += stride) { \ + const ompi_op_predefined_##type_name##_t *a = &in[i]; \ + ompi_op_predefined_##type_name##_t *b = &inout[i]; \ + if (a->v op b->v) { \ + b->v = a->v; \ + b->k = a->k; \ + } else if (a->v == b->v) { /* equal values: keep the smaller key */ \ + b->k = (b->k < a->k ? b->k : a->k); \ + } \ + } \ + } \ + void \ + ompi_op_cuda_2buff_##name##_##type_name##_submit(const ompi_op_predefined_##type_name##_t *a, \ + ompi_op_predefined_##type_name##_t *b, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + CUstream stream) { \ + int threads = min(count, threads_per_block); \ + int blocks = min((count + threads-1) / threads, max_blocks); \ + CUstream s = stream; \ + ompi_op_cuda_2buff_##name##_##type_name##_kernel<<<blocks, threads, 0, s>>>(a, b, count); /* enqueue on stream s, no dynamic shared memory */ \ + } + +#define OPV_DISPATCH(name, type_name, type) \ + void ompi_op_cuda_2buff_##name##_##type_name##_submit(const type *in, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + CUstream stream) { \ + static_assert(sizeof(type_name) <= sizeof(unsigned long long), "Unknown size type"); \ + if constexpr(!ISSIGNED(type)) { \ + if constexpr(sizeof(type_name) == sizeof(unsigned char)) { \ + ompi_op_cuda_2buff_##name##_uchar_submit((const unsigned char*)in, (unsigned char*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } else if constexpr(sizeof(type_name) == sizeof(unsigned short)) { \ + ompi_op_cuda_2buff_##name##_ushort_submit((const unsigned short*)in, (unsigned short*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } else if constexpr(sizeof(type_name) == 
sizeof(unsigned int)) { \ + ompi_op_cuda_2buff_##name##_uint_submit((const unsigned int*)in, (unsigned int*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } else if constexpr(sizeof(type_name) == sizeof(unsigned long)) { \ + ompi_op_cuda_2buff_##name##_ulong_submit((const unsigned long*)in, (unsigned long*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } else if constexpr(sizeof(type_name) == sizeof(unsigned long long)) { \ + ompi_op_cuda_2buff_##name##_ulonglong_submit((const unsigned long long*)in, (unsigned long long*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } \ + } else { \ + if constexpr(sizeof(type_name) == sizeof(char)) { \ + ompi_op_cuda_2buff_##name##_char_submit((const char*)in, (char*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } else if constexpr(sizeof(type_name) == sizeof(short)) { \ + ompi_op_cuda_2buff_##name##_short_submit((const short*)in, (short*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } else if constexpr(sizeof(type_name) == sizeof(int)) { \ + ompi_op_cuda_2buff_##name##_int_submit((const int*)in, (int*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } else if constexpr(sizeof(type_name) == sizeof(long)) { \ + ompi_op_cuda_2buff_##name##_long_submit((const long*)in, (long*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } else if constexpr(sizeof(type_name) == sizeof(long long)) { \ + ompi_op_cuda_2buff_##name##_longlong_submit((const long long*)in, (long long*)inout, count,\ + threads_per_block, \ + max_blocks, stream); \ + } \ + } \ + } + +/************************************************************************* + * Max + *************************************************************************/ + +/* C integer */ +VFUNC_FUNC(max, char, char, char4, 4, vmax, max) +VFUNC_FUNC(max, uchar, unsigned char, uchar4, 4, vmax, max) +VFUNC_FUNC(max, short, short, short4, 4, vmax, max) +VFUNC_FUNC(max, 
ushort, unsigned short, ushort4, 4, vmax, max)
+VFUNC_FUNC(max, int, int, int4, 4, vmax, max)
+VFUNC_FUNC(max, uint, unsigned int, uint4, 4, vmax, max)
+/* current_func is redefined before each group; presumably FUNC_FUNC (defined earlier, not visible here) applies it element-wise */
+#undef current_func
+#define current_func(a, b) max(a, b)
+FUNC_FUNC(max, long, long)
+FUNC_FUNC(max, ulong, unsigned long)
+FUNC_FUNC(max, longlong, long long)
+FUNC_FUNC(max, ulonglong, unsigned long long)
+
+/* dispatch fixed-size types */
+OPV_DISPATCH(max, int8_t, int8_t)
+OPV_DISPATCH(max, uint8_t, uint8_t)
+OPV_DISPATCH(max, int16_t, int16_t)
+OPV_DISPATCH(max, uint16_t, uint16_t)
+OPV_DISPATCH(max, int32_t, int32_t)
+OPV_DISPATCH(max, uint32_t, uint32_t)
+OPV_DISPATCH(max, int64_t, int64_t)
+OPV_DISPATCH(max, uint64_t, uint64_t)
+
+#undef current_func
+#define current_func(a, b) ((a) > (b) ? (a) : (b))
+FUNC_FUNC(max, long_double, long double)
+/* with DO_NOT_USE_INTRINSICS defined, float and double keep the ternary current_func from above */
+#if !defined(DO_NOT_USE_INTRINSICS)
+#undef current_func
+#define current_func(a, b) fmaxf(a, b)
+#endif // DO_NOT_USE_INTRINSICS
+FUNC_FUNC(max, float, float)
+
+#if !defined(DO_NOT_USE_INTRINSICS)
+#undef current_func
+#define current_func(a, b) fmax(a, b)
+#endif // DO_NOT_USE_INTRINSICS
+FUNC_FUNC(max, double, double)
+
+// __CUDA_ARCH__ is only defined when compiling device code
+#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530
+#undef current_func
+#define current_func(a, b) __hmax2(a, b)
+//VFUNC_FUNC(max, halfx, half, half2, 2, __hmax2, __hmax)
+#endif // __CUDA_ARCH__
+
+/*************************************************************************
+ * Min
+ *************************************************************************/
+
+/* C integer */
+VFUNC_FUNC(min, char, char, char4, 4, vmin, min)
+VFUNC_FUNC(min, uchar, unsigned char, uchar4, 4, vmin, min)
+VFUNC_FUNC(min, short, short, short4, 4, vmin, min)
+VFUNC_FUNC(min, ushort, unsigned short, ushort4, 4, vmin, min)
+VFUNC_FUNC(min, int, int, int4, 4, vmin, min)
+VFUNC_FUNC(min, uint, unsigned int, uint4, 4, vmin, min)
+
+#undef current_func
+#define current_func(a, b) min(a, b)
+FUNC_FUNC(min, long, long)
+FUNC_FUNC(min, ulong, unsigned long)
+FUNC_FUNC(min, longlong, long long)
+FUNC_FUNC(min, ulonglong, unsigned long long)
+OPV_DISPATCH(min, int8_t, int8_t)
+OPV_DISPATCH(min, uint8_t, uint8_t)
+OPV_DISPATCH(min, int16_t, int16_t)
+OPV_DISPATCH(min, uint16_t, uint16_t)
+OPV_DISPATCH(min, int32_t, int32_t)
+OPV_DISPATCH(min, uint32_t, uint32_t)
+OPV_DISPATCH(min, int64_t, int64_t)
+OPV_DISPATCH(min, uint64_t, uint64_t)
+
+
+/* with DO_NOT_USE_INTRINSICS defined, float and double keep current_func == min(a, b) from above */
+#if !defined(DO_NOT_USE_INTRINSICS)
+#undef current_func
+#define current_func(a, b) fminf(a, b)
+#endif // DO_NOT_USE_INTRINSICS
+FUNC_FUNC(min, float, float)
+
+#if !defined(DO_NOT_USE_INTRINSICS)
+#undef current_func
+#define current_func(a, b) fmin(a, b)
+#endif // DO_NOT_USE_INTRINSICS
+FUNC_FUNC(min, double, double)
+
+#undef current_func
+#define current_func(a, b) ((a) < (b) ? (a) : (b))
+FUNC_FUNC(min, long_double, long double)
+
+// __CUDA_ARCH__ is only defined when compiling device code
+#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530
+#undef current_func
+#define current_func(a, b) __hmin2(a, b)
+//VFUNC_FUNC(min, half, half, half2, 2, __hmin2, __hmin)
+#endif // __CUDA_ARCH__
+
+/*************************************************************************
+ * Sum
+ *************************************************************************/
+
+/* C integer */
+VFUNC_FUNC(sum, char, char, char4, 4, vsum, tsum)
+VFUNC_FUNC(sum, uchar, unsigned char, uchar4, 4, vsum, tsum)
+VFUNC_FUNC(sum, short, short, short4, 4, vsum, tsum)
+VFUNC_FUNC(sum, ushort, unsigned short, ushort4, 4, vsum, tsum)
+VFUNC_FUNC(sum, int, int, int4, 4, vsum, tsum)
+VFUNC_FUNC(sum, uint, unsigned int, uint4, 4, vsum, tsum)
+
+#undef current_func
+#define current_func(a, b) tsum(a, b)
+FUNC_FUNC(sum, long, long)
+FUNC_FUNC(sum, ulong, unsigned long)
+FUNC_FUNC(sum, longlong, long long)
+FUNC_FUNC(sum, ulonglong, unsigned long long)
+
+OPV_DISPATCH(sum, int8_t, int8_t)
+OPV_DISPATCH(sum, uint8_t, uint8_t)
+OPV_DISPATCH(sum, int16_t, int16_t)
+OPV_DISPATCH(sum, uint16_t, uint16_t)
+OPV_DISPATCH(sum, int32_t, int32_t)
+OPV_DISPATCH(sum, uint32_t, uint32_t)
+OPV_DISPATCH(sum, int64_t, int64_t)
+OPV_DISPATCH(sum, uint64_t, uint64_t)
+
+OPV_FUNC(sum, float, float, float4, 4, +)
+OPV_FUNC(sum, double, double, double4, 4, +)
+OP_FUNC(sum, long_double, long double, +)
+
+// __CUDA_ARCH__ is only defined when compiling device code
+#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530
+#undef current_func
+#define current_func(a, b) __hadd2(a, b)
+//VFUNC_FUNC(sum, half, half, half2, 2, __hadd2, __hadd)
+#endif // __CUDA_ARCH__
+
+/* Complex */
+#if 0
+#if defined(HAVE_SHORT_FLOAT__COMPLEX)
+OP_FUNC(sum, c_short_float_complex, short float _Complex, +=)
+#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T)
+COMPLEX_SUM_FUNC(c_short_float_complex, opal_short_float_t)
+#endif
+OP_FUNC(sum, c_long_double_complex, cuLongDoubleComplex, +=)
+#endif // 0
+#undef current_func
+#define current_func(a, b) (cuCaddf(a,b))
+FUNC_FUNC(sum, c_float_complex, cuFloatComplex)
+#undef current_func
+#define current_func(a, b) (cuCadd(a,b))
+FUNC_FUNC(sum, c_double_complex, cuDoubleComplex)
+
+/*************************************************************************
+ * Product
+ *************************************************************************/
+
+/* C integer */
+#undef current_func
+#define current_func(a, b) tprod(a, b)
+FUNC_FUNC(prod, char, char)
+FUNC_FUNC(prod, uchar, unsigned char)
+FUNC_FUNC(prod, short, short)
+FUNC_FUNC(prod, ushort, unsigned short)
+FUNC_FUNC(prod, int, int)
+FUNC_FUNC(prod, uint, unsigned int)
+FUNC_FUNC(prod, long, long)
+FUNC_FUNC(prod, ulong, unsigned long)
+FUNC_FUNC(prod, longlong, long long)
+FUNC_FUNC(prod, ulonglong, unsigned long long)
+
+OPV_DISPATCH(prod, int8_t, int8_t)
+OPV_DISPATCH(prod, uint8_t, uint8_t)
+OPV_DISPATCH(prod, int16_t, int16_t)
+OPV_DISPATCH(prod, uint16_t, uint16_t)
+OPV_DISPATCH(prod, int32_t, int32_t)
+OPV_DISPATCH(prod, uint32_t, uint32_t)
+OPV_DISPATCH(prod, int64_t, int64_t)
+OPV_DISPATCH(prod, uint64_t, uint64_t)
+
+
+OPV_FUNC(prod, float, float, float4, 4, *)
+OPV_FUNC(prod, double, double, double4, 4, *)
+OP_FUNC(prod, long_double, long double, *)
+
+// __CUDA_ARCH__ is only defined when compiling device code
+#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530
+#undef current_func
+#define current_func(a, b) __hmul2(a, b)
+//VFUNC_FUNC(prod, half, half, half2, 2, __hmul2, __hmul)
+#endif // __CUDA_ARCH__
+
+/* Complex */
+#if 0
+#if defined(HAVE_SHORT_FLOAT__COMPLEX)
+OP_FUNC(prod, c_short_float_complex, short float _Complex, *=)
+#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T)
+COMPLEX_PROD_FUNC(c_short_float_complex, opal_short_float_t)
+#endif
+OP_FUNC(prod, c_long_double_complex, long double _Complex, *=)
+#endif // 0
+#undef current_func
+#define current_func(a, b) (cuCmulf(a,b))
+FUNC_FUNC(prod, c_float_complex, cuFloatComplex)
+#undef current_func
+#define current_func(a, b) (cuCmul(a,b))
+FUNC_FUNC(prod, c_double_complex, cuDoubleComplex)
+
+/*************************************************************************
+ * Logical AND
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) && (b))
+/* C integer */
+FUNC_FUNC(land, int8_t, int8_t)
+FUNC_FUNC(land, uint8_t, uint8_t)
+FUNC_FUNC(land, int16_t, int16_t)
+FUNC_FUNC(land, uint16_t, uint16_t)
+FUNC_FUNC(land, int32_t, int32_t)
+FUNC_FUNC(land, uint32_t, uint32_t)
+FUNC_FUNC(land, int64_t, int64_t)
+FUNC_FUNC(land, uint64_t, uint64_t)
+FUNC_FUNC(land, long, long)
+FUNC_FUNC(land, ulong, unsigned long)
+
+/* C++ bool */
+FUNC_FUNC(land, bool, bool)
+
+/*************************************************************************
+ * Logical OR
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) || (b))
+/* C integer */
+FUNC_FUNC(lor, int8_t, int8_t)
+FUNC_FUNC(lor, uint8_t, uint8_t)
+FUNC_FUNC(lor, int16_t, int16_t)
+FUNC_FUNC(lor, uint16_t, uint16_t)
+FUNC_FUNC(lor, int32_t, int32_t)
+FUNC_FUNC(lor, uint32_t, uint32_t)
+FUNC_FUNC(lor, int64_t, int64_t)
+FUNC_FUNC(lor, uint64_t, uint64_t)
+FUNC_FUNC(lor, long, long)
+FUNC_FUNC(lor, ulong, unsigned long)
+
+/* C++ bool */
+FUNC_FUNC(lor, bool, bool)
+
+/*************************************************************************
+ * Logical XOR
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) (((a) ? 1 : 0) ^ ((b) ? 1 : 0))
+/* C integer */
+FUNC_FUNC(lxor, int8_t, int8_t)
+FUNC_FUNC(lxor, uint8_t, uint8_t)
+FUNC_FUNC(lxor, int16_t, int16_t)
+FUNC_FUNC(lxor, uint16_t, uint16_t)
+FUNC_FUNC(lxor, int32_t, int32_t)
+FUNC_FUNC(lxor, uint32_t, uint32_t)
+FUNC_FUNC(lxor, int64_t, int64_t)
+FUNC_FUNC(lxor, uint64_t, uint64_t)
+FUNC_FUNC(lxor, long, long)
+FUNC_FUNC(lxor, ulong, unsigned long)
+
+/* C++ bool */
+FUNC_FUNC(lxor, bool, bool)
+
+/*************************************************************************
+ * Bitwise AND
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) & (b))
+/* C integer */
+FUNC_FUNC(band, int8_t, int8_t)
+FUNC_FUNC(band, uint8_t, uint8_t)
+FUNC_FUNC(band, int16_t, int16_t)
+FUNC_FUNC(band, uint16_t, uint16_t)
+FUNC_FUNC(band, int32_t, int32_t)
+FUNC_FUNC(band, uint32_t, uint32_t)
+FUNC_FUNC(band, int64_t, int64_t)
+FUNC_FUNC(band, uint64_t, uint64_t)
+FUNC_FUNC(band, long, long)
+FUNC_FUNC(band, ulong, unsigned long)
+
+/* Byte */
+FUNC_FUNC(band, byte, char)
+
+/*************************************************************************
+ * Bitwise OR
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) | (b))
+/* C integer */
+FUNC_FUNC(bor, int8_t, int8_t)
+FUNC_FUNC(bor, uint8_t, uint8_t)
+FUNC_FUNC(bor, int16_t, int16_t)
+FUNC_FUNC(bor, uint16_t, uint16_t)
+FUNC_FUNC(bor, int32_t, int32_t)
+FUNC_FUNC(bor, uint32_t, uint32_t)
+FUNC_FUNC(bor, int64_t, int64_t)
+FUNC_FUNC(bor, uint64_t, uint64_t)
+FUNC_FUNC(bor, long, long)
+FUNC_FUNC(bor, ulong, unsigned long)
+
+/* Byte */
+FUNC_FUNC(bor, byte, char)
+
+/*************************************************************************
+ * Bitwise XOR
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) ^ (b))
+/* C integer */
+FUNC_FUNC(bxor, int8_t, int8_t)
+FUNC_FUNC(bxor, uint8_t, uint8_t)
+FUNC_FUNC(bxor, int16_t, int16_t)
+FUNC_FUNC(bxor, uint16_t, uint16_t)
+FUNC_FUNC(bxor, int32_t, int32_t)
+FUNC_FUNC(bxor, uint32_t, uint32_t)
+FUNC_FUNC(bxor, int64_t, int64_t)
+FUNC_FUNC(bxor, uint64_t, uint64_t)
+FUNC_FUNC(bxor, long, long)
+FUNC_FUNC(bxor, ulong, unsigned long)
+
+/* Byte */
+FUNC_FUNC(bxor, byte, char)
+
+/*************************************************************************
+ * Max location
+ *************************************************************************/
+
+LOC_FUNC(maxloc, float_int, >)
+LOC_FUNC(maxloc, double_int, >)
+LOC_FUNC(maxloc, long_int, >)
+LOC_FUNC(maxloc, 2int, >)
+LOC_FUNC(maxloc, short_int, >)
+LOC_FUNC(maxloc, long_double_int, >)
+
+/* Fortran compat types */
+LOC_FUNC(maxloc, 2float, >)
+LOC_FUNC(maxloc, 2double, >)
+LOC_FUNC(maxloc, 2int8, >)
+LOC_FUNC(maxloc, 2int16, >)
+LOC_FUNC(maxloc, 2int32, >)
+LOC_FUNC(maxloc, 2int64, >)
+
+/*************************************************************************
+ * Min location
+ *************************************************************************/
+
+LOC_FUNC(minloc, float_int, <)
+LOC_FUNC(minloc, double_int, <)
+LOC_FUNC(minloc, long_int, <)
+LOC_FUNC(minloc, 2int, <)
+LOC_FUNC(minloc, short_int, <)
+LOC_FUNC(minloc, long_double_int, <)
+
+/* Fortran compat types */
+LOC_FUNC(minloc, 2float, <)
+LOC_FUNC(minloc, 2double, <)
+LOC_FUNC(minloc, 2int8, <)
+LOC_FUNC(minloc, 2int16, <)
+LOC_FUNC(minloc,
2int32, <)
+LOC_FUNC(minloc, 2int64, <)
+
+/*
+ * Three buffer (2 input and 1 output) versions of the reduction routines:
+ * out[i] = in1[i] op in2[i]. The submit wrapper clamps threads/blocks to >= 1 so count == 0 is safe.
+ */
+#define OP_FUNC_3BUF(name, type_name, type, op) \
+    static __global__ void \
+    ompi_op_cuda_3buff_##name##_##type_name##_kernel(const type *__restrict__ in1, \
+                                                     const type *__restrict__ in2, \
+                                                     type *__restrict__ out, int n) { \
+        const int index = blockIdx.x * blockDim.x + threadIdx.x; /* global thread id */ \
+        const int stride = blockDim.x * gridDim.x; /* total number of launched threads */ \
+        for (int i = index; i < n; i += stride) { \
+            out[i] = in1[i] op in2[i]; \
+        } \
+    } \
+    void ompi_op_cuda_3buff_##name##_##type_name##_submit(const type *in1, const type *in2, \
+                                                          type *out, int count, \
+                                                          int threads_per_block, \
+                                                          int max_blocks, \
+                                                          CUstream stream) { \
+        int threads = max(1, min(count, threads_per_block)); /* >= 1: no division by zero below */ \
+        int blocks = max(1, min((count + threads-1) / threads, max_blocks)); /* >= 1: valid grid */ \
+        ompi_op_cuda_3buff_##name##_##type_name##_kernel<<<blocks, threads, 0, stream>>>(in1, in2, out, count); \
+    }
+
+
+/*
+ * Three buffer variant for operations expressed as a function-like macro:
+ * the kernel applies whatever current_func(a, b) is defined as at the point
+ * where FUNC_FUNC_3BUF is instantiated.
+ *
+ * This macro is for (out = op(in1, in2))
+ */
+#define FUNC_FUNC_3BUF(name, type_name, type) \
+    static __global__ void \
+    ompi_op_cuda_3buff_##name##_##type_name##_kernel(const type *__restrict__ in1, \
+                                                     const type *__restrict__ in2, \
+                                                     type *__restrict__ out, int n) { \
+        const int index = blockIdx.x * blockDim.x + threadIdx.x; /* global thread id */ \
+        const int stride = blockDim.x * gridDim.x; /* total number of launched threads */ \
+        for (int i = index; i < n; i += stride) { \
+            out[i] = current_func(in1[i], in2[i]); \
+        } \
+    } \
+    void \
+    ompi_op_cuda_3buff_##name##_##type_name##_submit(const type *in1, const type *in2, \
+                                                     type *out, int count, \
+                                                     int threads_per_block, \
+                                                     int max_blocks, \
+                                                     CUstream stream) { \
+        int threads = max(1, min(count, threads_per_block)); /* >= 1: no division by zero below */ \
+        int blocks = max(1, min((count + threads-1) / threads, max_blocks)); /* >= 1: valid grid */ \
+        ompi_op_cuda_3buff_##name##_##type_name##_kernel<<<blocks, threads, 0, stream>>>(in1, in2, out, count); \
+    }
+
+/*
+ * Three buffer variant for (value v, index k) pairs: out[i] receives the
+ * pair whose value wins "in1.v op in2.v"; when both values compare equal
+ * the value is kept and the smaller index k is stored.
+ *
+ * This macro is for minloc and maxloc
+ */
+#define LOC_FUNC_3BUF(name, type_name, op) \
+    static __global__ void \
+    ompi_op_cuda_3buff_##name##_##type_name##_kernel(const ompi_op_predefined_##type_name##_t *__restrict__ in1, \
+                                                     const ompi_op_predefined_##type_name##_t *__restrict__ in2, \
+                                                     ompi_op_predefined_##type_name##_t *__restrict__ out, \
+                                                     int n) \
+    { \
+        const int index = blockIdx.x * blockDim.x + threadIdx.x; /* global thread id */ \
+        const int stride = blockDim.x * gridDim.x; /* total number of launched threads */ \
+        for (int i = index; i < n; i += stride) { \
+            const ompi_op_predefined_##type_name##_t *a1 = &in1[i]; \
+            const ompi_op_predefined_##type_name##_t *a2 = &in2[i]; \
+            ompi_op_predefined_##type_name##_t *b = &out[i]; \
+            if (a1->v op a2->v) { \
+                b->v = a1->v; \
+                b->k = a1->k; \
+            } else if (a1->v == a2->v) { \
+                b->v = a1->v; \
+                b->k = (a2->k < a1->k ? a2->k : a1->k); \
+            } else { \
+                b->v = a2->v; \
+                b->k = a2->k; \
+            } \
+        } \
+    } \
+    void \
+    ompi_op_cuda_3buff_##name##_##type_name##_submit(const ompi_op_predefined_##type_name##_t *in1, \
+                                                     const ompi_op_predefined_##type_name##_t *in2, \
+                                                     ompi_op_predefined_##type_name##_t *out, \
+                                                     int count, \
+                                                     int threads_per_block, \
+                                                     int max_blocks, \
+                                                     CUstream stream) \
+    { \
+        int threads = max(1, min(count, threads_per_block)); /* >= 1: no division by zero below */ \
+        int blocks = max(1, min((count + threads-1) / threads, max_blocks)); /* >= 1: valid grid */ \
+        ompi_op_cuda_3buff_##name##_##type_name##_kernel<<<blocks, threads, 0, stream>>>(in1, in2, out, count); \
+    }
+
+
+/*************************************************************************
+ * Max
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) > (b) ? (a) : (b))
+/* C integer */
+FUNC_FUNC_3BUF(max, int8_t, int8_t)
+FUNC_FUNC_3BUF(max, uint8_t, uint8_t)
+FUNC_FUNC_3BUF(max, int16_t, int16_t)
+FUNC_FUNC_3BUF(max, uint16_t, uint16_t)
+FUNC_FUNC_3BUF(max, int32_t, int32_t)
+FUNC_FUNC_3BUF(max, uint32_t, uint32_t)
+FUNC_FUNC_3BUF(max, int64_t, int64_t)
+FUNC_FUNC_3BUF(max, uint64_t, uint64_t)
+FUNC_FUNC_3BUF(max, long, long)
+FUNC_FUNC_3BUF(max, ulong, unsigned long)
+
+/* Floating point */
+#if defined(HAVE_SHORT_FLOAT)
+FUNC_FUNC_3BUF(max, short_float, short float)
+#elif defined(HAVE_OPAL_SHORT_FLOAT_T)
+FUNC_FUNC_3BUF(max, short_float, opal_short_float_t)
+#endif
+FUNC_FUNC_3BUF(max, float, float)
+FUNC_FUNC_3BUF(max, double, double)
+FUNC_FUNC_3BUF(max, long_double, long double)
+
+
+/*************************************************************************
+ * Min
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) < (b) ?
(a) : (b))
+/* C integer */
+FUNC_FUNC_3BUF(min, int8_t, int8_t)
+FUNC_FUNC_3BUF(min, uint8_t, uint8_t)
+FUNC_FUNC_3BUF(min, int16_t, int16_t)
+FUNC_FUNC_3BUF(min, uint16_t, uint16_t)
+FUNC_FUNC_3BUF(min, int32_t, int32_t)
+FUNC_FUNC_3BUF(min, uint32_t, uint32_t)
+FUNC_FUNC_3BUF(min, int64_t, int64_t)
+FUNC_FUNC_3BUF(min, uint64_t, uint64_t)
+FUNC_FUNC_3BUF(min, long, long)
+FUNC_FUNC_3BUF(min, ulong, unsigned long)
+
+/* Floating point */
+#if defined(HAVE_SHORT_FLOAT)
+FUNC_FUNC_3BUF(min, short_float, short float)
+#elif defined(HAVE_OPAL_SHORT_FLOAT_T)
+FUNC_FUNC_3BUF(min, short_float, opal_short_float_t)
+#endif
+FUNC_FUNC_3BUF(min, float, float)
+FUNC_FUNC_3BUF(min, double, double)
+FUNC_FUNC_3BUF(min, long_double, long double)
+
+/*************************************************************************
+ * Sum
+ *************************************************************************/
+
+/* C integer */
+OP_FUNC_3BUF(sum, int8_t, int8_t, +)
+OP_FUNC_3BUF(sum, uint8_t, uint8_t, +)
+OP_FUNC_3BUF(sum, int16_t, int16_t, +)
+OP_FUNC_3BUF(sum, uint16_t, uint16_t, +)
+OP_FUNC_3BUF(sum, int32_t, int32_t, +)
+OP_FUNC_3BUF(sum, uint32_t, uint32_t, +)
+OP_FUNC_3BUF(sum, int64_t, int64_t, +)
+OP_FUNC_3BUF(sum, uint64_t, uint64_t, +)
+OP_FUNC_3BUF(sum, long, long, +)
+OP_FUNC_3BUF(sum, ulong, unsigned long, +)
+
+/* Floating point */
+#if defined(HAVE_SHORT_FLOAT)
+OP_FUNC_3BUF(sum, short_float, short float, +)
+#elif defined(HAVE_OPAL_SHORT_FLOAT_T)
+OP_FUNC_3BUF(sum, short_float, opal_short_float_t, +)
+#endif
+OP_FUNC_3BUF(sum, float, float, +)
+OP_FUNC_3BUF(sum, double, double, +)
+OP_FUNC_3BUF(sum, long_double, long double, +)
+
+/* Complex */
+#if 0
+#if defined(HAVE_SHORT_FLOAT__COMPLEX)
+OP_FUNC_3BUF(sum, c_short_float_complex, short float _Complex, +)
+#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T)
+COMPLEX_SUM_FUNC_3BUF(c_short_float_complex, opal_short_float_t)
+#endif
+OP_FUNC_3BUF(sum, c_long_double_complex, cuLongDoubleComplex, +)
+#endif // 0
+#undef current_func
+#define current_func(a, b) (cuCaddf(a,b))
+FUNC_FUNC_3BUF(sum, c_float_complex, cuFloatComplex)
+#undef current_func
+#define current_func(a, b) (cuCadd(a,b))
+FUNC_FUNC_3BUF(sum, c_double_complex, cuDoubleComplex)
+
+/*************************************************************************
+ * Product
+ *************************************************************************/
+
+/* C integer */
+OP_FUNC_3BUF(prod, int8_t, int8_t, *)
+OP_FUNC_3BUF(prod, uint8_t, uint8_t, *)
+OP_FUNC_3BUF(prod, int16_t, int16_t, *)
+OP_FUNC_3BUF(prod, uint16_t, uint16_t, *)
+OP_FUNC_3BUF(prod, int32_t, int32_t, *)
+OP_FUNC_3BUF(prod, uint32_t, uint32_t, *)
+OP_FUNC_3BUF(prod, int64_t, int64_t, *)
+OP_FUNC_3BUF(prod, uint64_t, uint64_t, *)
+OP_FUNC_3BUF(prod, long, long, *)
+OP_FUNC_3BUF(prod, ulong, unsigned long, *)
+
+/* Floating point */
+#if defined(HAVE_SHORT_FLOAT)
+OP_FUNC_3BUF(prod, short_float, short float, *)
+#elif defined(HAVE_OPAL_SHORT_FLOAT_T)
+OP_FUNC_3BUF(prod, short_float, opal_short_float_t, *)
+#endif
+OP_FUNC_3BUF(prod, float, float, *)
+OP_FUNC_3BUF(prod, double, double, *)
+OP_FUNC_3BUF(prod, long_double, long double, *)
+
+/* Complex */
+#if 0
+#if defined(HAVE_SHORT_FLOAT__COMPLEX)
+OP_FUNC_3BUF(prod, c_short_float_complex, short float _Complex, *)
+#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T)
+COMPLEX_PROD_FUNC_3BUF(c_short_float_complex, opal_short_float_t)
+#endif
+OP_FUNC_3BUF(prod, c_long_double_complex, long double _Complex, *)
+#endif // 0
+#undef current_func
+#define current_func(a, b) (cuCmulf(a,b))
+FUNC_FUNC_3BUF(prod, c_float_complex, cuFloatComplex)
+#undef current_func
+#define current_func(a, b) (cuCmul(a,b))
+FUNC_FUNC_3BUF(prod, c_double_complex, cuDoubleComplex)
+
+/*************************************************************************
+ * Logical AND
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) && (b))
+/* C integer */
+FUNC_FUNC_3BUF(land, int8_t, int8_t)
+FUNC_FUNC_3BUF(land, uint8_t, uint8_t)
+FUNC_FUNC_3BUF(land, int16_t, int16_t)
+FUNC_FUNC_3BUF(land, uint16_t, uint16_t)
+FUNC_FUNC_3BUF(land, int32_t, int32_t)
+FUNC_FUNC_3BUF(land, uint32_t, uint32_t)
+FUNC_FUNC_3BUF(land, int64_t, int64_t)
+FUNC_FUNC_3BUF(land, uint64_t, uint64_t)
+FUNC_FUNC_3BUF(land, long, long)
+FUNC_FUNC_3BUF(land, ulong, unsigned long)
+
+/* C++ bool */
+FUNC_FUNC_3BUF(land, bool, bool)
+
+/*************************************************************************
+ * Logical OR
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) || (b))
+/* C integer */
+FUNC_FUNC_3BUF(lor, int8_t, int8_t)
+FUNC_FUNC_3BUF(lor, uint8_t, uint8_t)
+FUNC_FUNC_3BUF(lor, int16_t, int16_t)
+FUNC_FUNC_3BUF(lor, uint16_t, uint16_t)
+FUNC_FUNC_3BUF(lor, int32_t, int32_t)
+FUNC_FUNC_3BUF(lor, uint32_t, uint32_t)
+FUNC_FUNC_3BUF(lor, int64_t, int64_t)
+FUNC_FUNC_3BUF(lor, uint64_t, uint64_t)
+FUNC_FUNC_3BUF(lor, long, long)
+FUNC_FUNC_3BUF(lor, ulong, unsigned long)
+
+/* C++ bool */
+FUNC_FUNC_3BUF(lor, bool, bool)
+
+/*************************************************************************
+ * Logical XOR
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) (((a) ? 1 : 0) ^ ((b) ? 1 : 0))
+/* C integer */
+FUNC_FUNC_3BUF(lxor, int8_t, int8_t)
+FUNC_FUNC_3BUF(lxor, uint8_t, uint8_t)
+FUNC_FUNC_3BUF(lxor, int16_t, int16_t)
+FUNC_FUNC_3BUF(lxor, uint16_t, uint16_t)
+FUNC_FUNC_3BUF(lxor, int32_t, int32_t)
+FUNC_FUNC_3BUF(lxor, uint32_t, uint32_t)
+FUNC_FUNC_3BUF(lxor, int64_t, int64_t)
+FUNC_FUNC_3BUF(lxor, uint64_t, uint64_t)
+FUNC_FUNC_3BUF(lxor, long, long)
+FUNC_FUNC_3BUF(lxor, ulong, unsigned long)
+
+/* C++ bool */
+FUNC_FUNC_3BUF(lxor, bool, bool)
+
+/*************************************************************************
+ * Bitwise AND
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) & (b))
+/* C integer */
+FUNC_FUNC_3BUF(band, int8_t, int8_t)
+FUNC_FUNC_3BUF(band, uint8_t, uint8_t)
+FUNC_FUNC_3BUF(band, int16_t, int16_t)
+FUNC_FUNC_3BUF(band, uint16_t, uint16_t)
+FUNC_FUNC_3BUF(band, int32_t, int32_t)
+FUNC_FUNC_3BUF(band, uint32_t, uint32_t)
+FUNC_FUNC_3BUF(band, int64_t, int64_t)
+FUNC_FUNC_3BUF(band, uint64_t, uint64_t)
+FUNC_FUNC_3BUF(band, long, long)
+FUNC_FUNC_3BUF(band, ulong, unsigned long)
+
+/* Byte */
+FUNC_FUNC_3BUF(band, byte, char)
+
+/*************************************************************************
+ * Bitwise OR
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) | (b))
+/* C integer */
+FUNC_FUNC_3BUF(bor, int8_t, int8_t)
+FUNC_FUNC_3BUF(bor, uint8_t, uint8_t)
+FUNC_FUNC_3BUF(bor, int16_t, int16_t)
+FUNC_FUNC_3BUF(bor, uint16_t, uint16_t)
+FUNC_FUNC_3BUF(bor, int32_t, int32_t)
+FUNC_FUNC_3BUF(bor, uint32_t, uint32_t)
+FUNC_FUNC_3BUF(bor, int64_t, int64_t)
+FUNC_FUNC_3BUF(bor, uint64_t, uint64_t)
+FUNC_FUNC_3BUF(bor, long, long)
+FUNC_FUNC_3BUF(bor, ulong, unsigned long)
+
+/* Byte */
+FUNC_FUNC_3BUF(bor, byte, char)
+
+/*************************************************************************
+ * Bitwise XOR
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) ^ (b))
+/* C integer */
+FUNC_FUNC_3BUF(bxor, int8_t, int8_t)
+FUNC_FUNC_3BUF(bxor, uint8_t, uint8_t)
+FUNC_FUNC_3BUF(bxor, int16_t, int16_t)
+FUNC_FUNC_3BUF(bxor, uint16_t, uint16_t)
+FUNC_FUNC_3BUF(bxor, int32_t, int32_t)
+FUNC_FUNC_3BUF(bxor, uint32_t, uint32_t)
+FUNC_FUNC_3BUF(bxor, int64_t, int64_t)
+FUNC_FUNC_3BUF(bxor, uint64_t, uint64_t)
+FUNC_FUNC_3BUF(bxor, long, long)
+FUNC_FUNC_3BUF(bxor, ulong, unsigned long)
+
+/* Byte */
+FUNC_FUNC_3BUF(bxor, byte, char)
+
+/*************************************************************************
+ * Max location
+ *************************************************************************/
+
+LOC_FUNC_3BUF(maxloc, float_int, >)
+LOC_FUNC_3BUF(maxloc, double_int, >)
+LOC_FUNC_3BUF(maxloc, long_int, >)
+LOC_FUNC_3BUF(maxloc, 2int, >)
+LOC_FUNC_3BUF(maxloc, short_int, >)
+LOC_FUNC_3BUF(maxloc, long_double_int, >)
+
+/* Fortran compat types */
+LOC_FUNC_3BUF(maxloc, 2float, >)
+LOC_FUNC_3BUF(maxloc, 2double, >)
+LOC_FUNC_3BUF(maxloc, 2int8, >)
+LOC_FUNC_3BUF(maxloc, 2int16, >)
+LOC_FUNC_3BUF(maxloc, 2int32, >)
+LOC_FUNC_3BUF(maxloc, 2int64, >)
+
+/*************************************************************************
+ * Min location
+ *************************************************************************/
+
+LOC_FUNC_3BUF(minloc, float_int, <)
+LOC_FUNC_3BUF(minloc, double_int, <)
+LOC_FUNC_3BUF(minloc, long_int, <)
+LOC_FUNC_3BUF(minloc, 2int, <)
+LOC_FUNC_3BUF(minloc, short_int, <)
+LOC_FUNC_3BUF(minloc, long_double_int, <)
+
+/* Fortran compat types */
+LOC_FUNC_3BUF(minloc, 2float, <)
+LOC_FUNC_3BUF(minloc, 2double, <)
+LOC_FUNC_3BUF(minloc, 2int8, <)
+LOC_FUNC_3BUF(minloc, 2int16, <)
+LOC_FUNC_3BUF(minloc, 2int32, <)
+LOC_FUNC_3BUF(minloc, 2int64, <)
diff --git a/ompi/mca/op/cuda/op_cuda_impl.h b/ompi/mca/op/cuda/op_cuda_impl.h
new file mode 100644
index 00000000000..43209581bab
--- /dev/null
+++ b/ompi/mca/op/cuda/op_cuda_impl.h
@@ -0,0 +1,695 @@
+/*
+ * Copyright (c) 2019-2023 The University of Tennessee and The University
+ *                         of Tennessee Research Foundation.  All rights
+ *                         reserved.
+ * Copyright (c) 2020      Research Organization for Information Science
+ *                         and Technology (RIST).  All rights reserved.
+ * $COPYRIGHT$
+ *
+ * Additional copyrights may follow
+ *
+ * $HEADER$
+ */
+#pragma once /* the LOC_STRUCT typedefs below must not be seen twice in one translation unit */
+#include <stdint.h> /* int8_t ... uint64_t */
+
+#include <cuda.h> /* CUstream -- NOTE(review): the three CUDA header names were reconstructed; confirm */
+#include <cuda_runtime.h>
+#include <cuComplex.h> /* cuFloatComplex, cuDoubleComplex */
+
+#ifndef BEGIN_C_DECLS
+#if defined(c_plusplus) || defined(__cplusplus)
+# define BEGIN_C_DECLS extern "C" {
+# define END_C_DECLS }
+#else
+# define BEGIN_C_DECLS /* empty */
+# define END_C_DECLS /* empty */
+#endif
+#endif
+
+BEGIN_C_DECLS
+/* Declares the 2-buffer submit function for reduction `name` on `type`: `count` elements of `in` and `inout`, enqueued on `stream`. */
+#define OP_FUNC_SIG(name, type_name, type) \
+    void ompi_op_cuda_2buff_##name##_##type_name##_submit(const type *in, \
+                                                          type *inout, \
+                                                          int count, \
+                                                          int threads_per_block, \
+                                                          int max_blocks, \
+                                                          CUstream stream);
+/* Same prototype as OP_FUNC_SIG; kept as a separate name to mirror the FUNC_FUNC_* macro family. */
+#define FUNC_FUNC_SIG(name, type_name, type) \
+    void ompi_op_cuda_2buff_##name##_##type_name##_submit(const type *in, \
+                                                          type *inout, \
+                                                          int count, \
+                                                          int threads_per_block, \
+                                                          int max_blocks, \
+                                                          CUstream stream);
+
+/*
+ * Pair types for MINLOC and MAXLOC. LOC_STRUCT declares
+ * ompi_op_predefined_<type_name>_t with members `v` (type1) and `k`
+ * (type2); LOC_FUNC_SIG declares the matching 2-buffer submit function.
+ *
+ * This macro is for minloc and maxloc
+ */
+#define LOC_STRUCT(type_name, type1, type2) \
+    typedef struct { \
+        type1 v; \
+        type2 k; \
+    } ompi_op_predefined_##type_name##_t;
+
+#define LOC_FUNC_SIG(name, type_name) \
+    void ompi_op_cuda_2buff_##name##_##type_name##_submit(const ompi_op_predefined_##type_name##_t *a, \
+                                                          ompi_op_predefined_##type_name##_t *b, \
+                                                          int count, \
+                                                          int threads_per_block, \
+                                                          int max_blocks, \
+                                                          CUstream stream);
+
+/*************************************************************************
+ * Max
+ *************************************************************************/
+
+/* C integer */
+FUNC_FUNC_SIG(max, int8_t, int8_t)
+FUNC_FUNC_SIG(max, uint8_t, uint8_t)
+FUNC_FUNC_SIG(max, int16_t, int16_t)
+FUNC_FUNC_SIG(max, uint16_t, uint16_t)
+FUNC_FUNC_SIG(max, int32_t, int32_t)
+FUNC_FUNC_SIG(max, uint32_t, uint32_t)
+FUNC_FUNC_SIG(max, int64_t, int64_t)
+FUNC_FUNC_SIG(max, uint64_t, uint64_t)
+FUNC_FUNC_SIG(max, long, long)
+FUNC_FUNC_SIG(max, ulong, unsigned long)
+
+#if 0
+/* Floating point */
+#if defined(HAVE_SHORT_FLOAT)
+FUNC_FUNC_SIG(max, short_float, short float)
+#elif defined(HAVE_OPAL_SHORT_FLOAT_T)
+FUNC_FUNC_SIG(max, short_float, opal_short_float_t)
+#endif
+#endif // 0
+
+FUNC_FUNC_SIG(max, float, float)
+FUNC_FUNC_SIG(max, double, double)
+FUNC_FUNC_SIG(max, long_double, long double)
+
+/*************************************************************************
+ * Min
+ *************************************************************************/
+
+/* C integer */
+FUNC_FUNC_SIG(min, int8_t, int8_t)
+FUNC_FUNC_SIG(min, uint8_t, uint8_t)
+FUNC_FUNC_SIG(min, int16_t, int16_t)
+FUNC_FUNC_SIG(min, uint16_t, uint16_t)
+FUNC_FUNC_SIG(min, int32_t, int32_t)
+FUNC_FUNC_SIG(min, uint32_t, uint32_t)
+FUNC_FUNC_SIG(min, int64_t, int64_t)
+FUNC_FUNC_SIG(min, uint64_t, uint64_t)
+FUNC_FUNC_SIG(min, long, long)
+FUNC_FUNC_SIG(min, ulong, unsigned long)
+
+#if 0
+/* Floating point */
+#if defined(HAVE_SHORT_FLOAT)
+FUNC_FUNC_SIG(min, short_float, short float)
+#elif defined(HAVE_OPAL_SHORT_FLOAT_T)
+FUNC_FUNC_SIG(min, short_float, opal_short_float_t)
+#endif
+#endif // 0
+
+FUNC_FUNC_SIG(min, float, float)
+FUNC_FUNC_SIG(min, double, double)
+FUNC_FUNC_SIG(min, long_double, long double)
+
+/*************************************************************************
+ * Sum
+ *************************************************************************/
+
+/* C integer */
+OP_FUNC_SIG(sum, int8_t, int8_t)
+OP_FUNC_SIG(sum, uint8_t, uint8_t)
+OP_FUNC_SIG(sum, int16_t, int16_t)
+OP_FUNC_SIG(sum, uint16_t, uint16_t)
+OP_FUNC_SIG(sum, int32_t, int32_t)
+OP_FUNC_SIG(sum, uint32_t, uint32_t)
+OP_FUNC_SIG(sum, int64_t, int64_t)
+OP_FUNC_SIG(sum, uint64_t, uint64_t)
+OP_FUNC_SIG(sum, long, long)
+OP_FUNC_SIG(sum, ulong, unsigned long)
+
+//#if __CUDA_ARCH__ >= 530
+//OP_FUNC_SIG(sum, half, half)
+//#endif // __CUDA_ARCH__
+
+#if 0
+/* Floating point */
+#if defined(HAVE_SHORT_FLOAT)
+OP_FUNC_SIG(sum, short_float, short float)
+#elif defined(HAVE_OPAL_SHORT_FLOAT_T)
+OP_FUNC_SIG(sum, short_float, opal_short_float_t)
+#endif
+#endif // 0
+
+OP_FUNC_SIG(sum, float, float)
+OP_FUNC_SIG(sum, double, double)
+OP_FUNC_SIG(sum, long_double, long double)
+
+/* Complex */
+#if 0
+#if defined(HAVE_SHORT_FLOAT__COMPLEX)
+OP_FUNC_SIG(sum, c_short_float_complex, short float _Complex)
+#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T)
+COMPLEX_SUM_FUNC(c_short_float_complex, opal_short_float_t)
+#endif
+OP_FUNC_SIG(sum, c_long_double_complex, long double _Complex)
+#endif // 0
+FUNC_FUNC_SIG(sum, c_float_complex, cuFloatComplex)
+FUNC_FUNC_SIG(sum, c_double_complex, cuDoubleComplex)
+
+/*************************************************************************
+ * Product
+ *************************************************************************/
+
+/* C integer */
+OP_FUNC_SIG(prod, int8_t, int8_t)
+OP_FUNC_SIG(prod, uint8_t, uint8_t)
+OP_FUNC_SIG(prod, int16_t, int16_t)
+OP_FUNC_SIG(prod, uint16_t, uint16_t)
+OP_FUNC_SIG(prod, int32_t, int32_t)
+OP_FUNC_SIG(prod, uint32_t, uint32_t)
+OP_FUNC_SIG(prod, int64_t, int64_t)
+OP_FUNC_SIG(prod, uint64_t, uint64_t)
+OP_FUNC_SIG(prod, long, long)
+OP_FUNC_SIG(prod, ulong, unsigned long)
+
+#if 0
+/* Floating point */
+#if defined(HAVE_SHORT_FLOAT)
+OP_FUNC_SIG(prod, short_float, short float)
+#elif defined(HAVE_OPAL_SHORT_FLOAT_T)
+OP_FUNC_SIG(prod, short_float, opal_short_float_t)
+#endif
+#endif // 0
+
+/* Floating point (each type is declared exactly once) */
+OP_FUNC_SIG(prod, float, float)
+OP_FUNC_SIG(prod, double, double)
+OP_FUNC_SIG(prod, long_double, long double)
+
+/* Complex */
+#if 0
+#if defined(HAVE_SHORT_FLOAT__COMPLEX)
+OP_FUNC_SIG(prod, c_short_float_complex, short float _Complex)
+#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T)
+COMPLEX_PROD_FUNC(c_short_float_complex, opal_short_float_t)
+#endif
+OP_FUNC_SIG(prod, c_long_double_complex, long double _Complex)
+#endif // 0
+
+FUNC_FUNC_SIG(prod, c_float_complex, cuFloatComplex)
+FUNC_FUNC_SIG(prod, c_double_complex, cuDoubleComplex)
+
+/*************************************************************************
+ * Logical AND
+ *************************************************************************/
+
+/* C integer */
+FUNC_FUNC_SIG(land, int8_t, int8_t)
+FUNC_FUNC_SIG(land, uint8_t, uint8_t)
+FUNC_FUNC_SIG(land, int16_t, int16_t)
+FUNC_FUNC_SIG(land, uint16_t, uint16_t)
+FUNC_FUNC_SIG(land, int32_t, int32_t)
+FUNC_FUNC_SIG(land, uint32_t, uint32_t)
+FUNC_FUNC_SIG(land, int64_t, int64_t)
+FUNC_FUNC_SIG(land, uint64_t, uint64_t)
+FUNC_FUNC_SIG(land, long, long)
+FUNC_FUNC_SIG(land, ulong, unsigned long)
+
+/* C++ bool */
+FUNC_FUNC_SIG(land, bool, bool)
+
+/*************************************************************************
+ * Logical OR
+ *************************************************************************/
+
+/* C integer */
+FUNC_FUNC_SIG(lor, int8_t, int8_t)
+FUNC_FUNC_SIG(lor, uint8_t, uint8_t)
+FUNC_FUNC_SIG(lor, int16_t, int16_t)
+FUNC_FUNC_SIG(lor, uint16_t, uint16_t)
+FUNC_FUNC_SIG(lor, int32_t, int32_t)
+FUNC_FUNC_SIG(lor, uint32_t, uint32_t)
+FUNC_FUNC_SIG(lor, int64_t, int64_t)
+FUNC_FUNC_SIG(lor, uint64_t, uint64_t)
+FUNC_FUNC_SIG(lor, long, long)
+FUNC_FUNC_SIG(lor, ulong, unsigned long)
+
+/* C++ bool */
+FUNC_FUNC_SIG(lor, bool, bool)
+
+/*************************************************************************
+ * Logical XOR
+ *************************************************************************/
+
+/* C integer */
+FUNC_FUNC_SIG(lxor, int8_t, int8_t)
+FUNC_FUNC_SIG(lxor, uint8_t, uint8_t)
+FUNC_FUNC_SIG(lxor, int16_t, int16_t)
+FUNC_FUNC_SIG(lxor, uint16_t, uint16_t)
+FUNC_FUNC_SIG(lxor, int32_t, int32_t)
+FUNC_FUNC_SIG(lxor, uint32_t, uint32_t)
+FUNC_FUNC_SIG(lxor, int64_t, int64_t)
+FUNC_FUNC_SIG(lxor, uint64_t, uint64_t)
+FUNC_FUNC_SIG(lxor, long, long)
+FUNC_FUNC_SIG(lxor, ulong, unsigned long)
+
+/* C++ bool */
+FUNC_FUNC_SIG(lxor, bool, bool)
+
+/*************************************************************************
+ * Bitwise AND
+ *************************************************************************/
+
+/* C integer */
+FUNC_FUNC_SIG(band, int8_t, int8_t)
+FUNC_FUNC_SIG(band, uint8_t, uint8_t)
+FUNC_FUNC_SIG(band, int16_t, int16_t)
+FUNC_FUNC_SIG(band, uint16_t, uint16_t)
+FUNC_FUNC_SIG(band, int32_t, int32_t)
+FUNC_FUNC_SIG(band, uint32_t, uint32_t)
+FUNC_FUNC_SIG(band, int64_t, int64_t)
+FUNC_FUNC_SIG(band, uint64_t, uint64_t)
+FUNC_FUNC_SIG(band, long, long)
+FUNC_FUNC_SIG(band, ulong, unsigned long)
+
+/* Byte */
+FUNC_FUNC_SIG(band, byte, char)
+
+/*************************************************************************
+ * Bitwise OR
+ *************************************************************************/
+
+/* C integer */
+FUNC_FUNC_SIG(bor, int8_t, int8_t)
+FUNC_FUNC_SIG(bor, uint8_t, uint8_t)
+FUNC_FUNC_SIG(bor, int16_t, int16_t)
+FUNC_FUNC_SIG(bor, uint16_t, uint16_t)
+FUNC_FUNC_SIG(bor, int32_t, int32_t)
+FUNC_FUNC_SIG(bor, uint32_t, uint32_t)
+FUNC_FUNC_SIG(bor, int64_t, int64_t)
+FUNC_FUNC_SIG(bor, uint64_t, uint64_t)
+FUNC_FUNC_SIG(bor, long, long)
+FUNC_FUNC_SIG(bor, ulong, unsigned long)
+
+/* Byte */
+FUNC_FUNC_SIG(bor, byte, char)
+
+/*************************************************************************
+ * Bitwise XOR
+ *************************************************************************/
+
+/* C integer */
+FUNC_FUNC_SIG(bxor, int8_t, int8_t)
+FUNC_FUNC_SIG(bxor, uint8_t, uint8_t)
+FUNC_FUNC_SIG(bxor, int16_t, int16_t)
+FUNC_FUNC_SIG(bxor, uint16_t, uint16_t)
+FUNC_FUNC_SIG(bxor, int32_t, int32_t)
+FUNC_FUNC_SIG(bxor, uint32_t, uint32_t)
+FUNC_FUNC_SIG(bxor, int64_t, int64_t)
+FUNC_FUNC_SIG(bxor, uint64_t, uint64_t)
+FUNC_FUNC_SIG(bxor, long, long)
+FUNC_FUNC_SIG(bxor, ulong, unsigned long)
+
+/* Byte */
+FUNC_FUNC_SIG(bxor, byte, char)
+
+/*************************************************************************
+ * Min and max location "pair" datatypes
+ *************************************************************************/
+
+LOC_STRUCT(float_int, float, int)
+LOC_STRUCT(double_int, double, int)
+LOC_STRUCT(long_int, long, int)
+LOC_STRUCT(2int, int, int)
+LOC_STRUCT(short_int, short, int)
+LOC_STRUCT(long_double_int, long double, int)
+LOC_STRUCT(ulong, unsigned long, int)
+/* compat types for Fortran */
+LOC_STRUCT(2float, float, float)
+LOC_STRUCT(2double, double, double)
+LOC_STRUCT(2int8, int8_t, int8_t)
+LOC_STRUCT(2int16, int16_t, int16_t)
+LOC_STRUCT(2int32, int32_t, int32_t)
+LOC_STRUCT(2int64, int64_t, int64_t)
+
+/*************************************************************************
+ * Max location
+ *************************************************************************/
+
+LOC_FUNC_SIG(maxloc, 2float)
+LOC_FUNC_SIG(maxloc, 2double)
+LOC_FUNC_SIG(maxloc, 2int8)
+LOC_FUNC_SIG(maxloc, 2int16)
+LOC_FUNC_SIG(maxloc, 2int32)
+LOC_FUNC_SIG(maxloc, 2int64)
+
+LOC_FUNC_SIG(maxloc, float_int)
+LOC_FUNC_SIG(maxloc, double_int)
+LOC_FUNC_SIG(maxloc, long_int)
+LOC_FUNC_SIG(maxloc, 2int) +LOC_FUNC_SIG(maxloc, short_int) +LOC_FUNC_SIG(maxloc, long_double_int) + +/************************************************************************* + * Min location + *************************************************************************/ + +LOC_FUNC_SIG(minloc, 2float) +LOC_FUNC_SIG(minloc, 2double) +LOC_FUNC_SIG(minloc, 2int8) +LOC_FUNC_SIG(minloc, 2int16) +LOC_FUNC_SIG(minloc, 2int32) +LOC_FUNC_SIG(minloc, 2int64) + +LOC_FUNC_SIG(minloc, float_int) +LOC_FUNC_SIG(minloc, double_int) +LOC_FUNC_SIG(minloc, long_int) +LOC_FUNC_SIG(minloc, 2int) +LOC_FUNC_SIG(minloc, short_int) +LOC_FUNC_SIG(minloc, long_double_int) + + + +#define OP_FUNC_3BUF_SIG(name, type_name, type) \ + void ompi_op_cuda_3buff_##name##_##type_name##_submit(const type *in1, \ + const type *in2, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + CUstream stream); + +#define FUNC_FUNC_3BUF_SIG(name, type_name, type) \ + void ompi_op_cuda_3buff_##name##_##type_name##_submit(const type *in1, \ + const type *in2, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + CUstream stream); + +#define LOC_FUNC_3BUF_SIG(name, type_name) \ + void ompi_op_cuda_3buff_##name##_##type_name##_submit(const ompi_op_predefined_##type_name##_t *a1, \ + const ompi_op_predefined_##type_name##_t *a2, \ + ompi_op_predefined_##type_name##_t *b, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + CUstream stream); + + +/************************************************************************* + * Max + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(max, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(max, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(max, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(max, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(max, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(max, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(max, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(max, 
uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(max, long, long) +FUNC_FUNC_3BUF_SIG(max, ulong, unsigned long) + +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC_3BUF_SIG(max, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_3BUF_SIG(max, short_float, opal_short_float_t) +#endif +FUNC_FUNC_3BUF_SIG(max, float, float) +FUNC_FUNC_3BUF_SIG(max, double, double) +FUNC_FUNC_3BUF_SIG(max, long_double, long double) + +/************************************************************************* + * Min + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(min, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(min, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(min, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(min, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(min, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(min, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(min, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(min, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(min, long, long) +FUNC_FUNC_3BUF_SIG(min, ulong, unsigned long) + +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC_3BUF_SIG(min, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_3BUF_SIG(min, short_float, opal_short_float_t) +#endif +FUNC_FUNC_3BUF_SIG(min, float, float) +FUNC_FUNC_3BUF_SIG(min, double, double) +FUNC_FUNC_3BUF_SIG(min, long_double, long double) + +/************************************************************************* + * Sum + *************************************************************************/ + +/* C integer */ +OP_FUNC_3BUF_SIG(sum, int8_t, int8_t) +OP_FUNC_3BUF_SIG(sum, uint8_t, uint8_t) +OP_FUNC_3BUF_SIG(sum, int16_t, int16_t) +OP_FUNC_3BUF_SIG(sum, uint16_t, uint16_t) +OP_FUNC_3BUF_SIG(sum, int32_t, int32_t) +OP_FUNC_3BUF_SIG(sum, uint32_t, uint32_t) +OP_FUNC_3BUF_SIG(sum, int64_t, int64_t) +OP_FUNC_3BUF_SIG(sum, uint64_t, uint64_t) +OP_FUNC_3BUF_SIG(sum, long, long) +OP_FUNC_3BUF_SIG(sum, ulong, unsigned long) + 
+/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC_3BUF_SIG(sum, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +OP_FUNC_3BUF_SIG(sum, short_float, opal_short_float_t) +#endif +OP_FUNC_3BUF_SIG(sum, float, float) +OP_FUNC_3BUF_SIG(sum, double, double) +OP_FUNC_3BUF_SIG(sum, long_double, long double) + +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF_SIG(sum, c_short_float_complex, short float _Complex) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_SUM_FUNC_3BUF(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC_3BUF_SIG(sum, c_long_double_complex, long double _Complex) +#endif // 0 +FUNC_FUNC_3BUF_SIG(sum, c_float_complex, cuFloatComplex) +FUNC_FUNC_3BUF_SIG(sum, c_double_complex, cuDoubleComplex) + +/************************************************************************* + * Product + *************************************************************************/ + +/* C integer */ +OP_FUNC_3BUF_SIG(prod, int8_t, int8_t) +OP_FUNC_3BUF_SIG(prod, uint8_t, uint8_t) +OP_FUNC_3BUF_SIG(prod, int16_t, int16_t) +OP_FUNC_3BUF_SIG(prod, uint16_t, uint16_t) +OP_FUNC_3BUF_SIG(prod, int32_t, int32_t) +OP_FUNC_3BUF_SIG(prod, uint32_t, uint32_t) +OP_FUNC_3BUF_SIG(prod, int64_t, int64_t) +OP_FUNC_3BUF_SIG(prod, uint64_t, uint64_t) +OP_FUNC_3BUF_SIG(prod, long, long) +OP_FUNC_3BUF_SIG(prod, ulong, unsigned long) + +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC_3BUF_SIG(prod, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +OP_FUNC_3BUF_SIG(prod, short_float, opal_short_float_t) +#endif +OP_FUNC_3BUF_SIG(prod, float, float) +OP_FUNC_3BUF_SIG(prod, double, double) +OP_FUNC_3BUF_SIG(prod, long_double, long double) + +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF_SIG(prod, c_short_float_complex, short float _Complex) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_PROD_FUNC_3BUF(c_short_float_complex, opal_short_float_t) +#endif 
+OP_FUNC_3BUF_SIG(prod, c_float_complex, float _Complex) +OP_FUNC_3BUF_SIG(prod, c_double_complex, double _Complex) +OP_FUNC_3BUF_SIG(prod, c_long_double_complex, long double _Complex) +#endif // 0 +FUNC_FUNC_3BUF_SIG(prod, c_float_complex, cuFloatComplex) +FUNC_FUNC_3BUF_SIG(prod, c_double_complex, cuDoubleComplex) + +/************************************************************************* + * Logical AND + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(land, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(land, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(land, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(land, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(land, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(land, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(land, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(land, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(land, long, long) +FUNC_FUNC_3BUF_SIG(land, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_3BUF_SIG(land, bool, bool) + +/************************************************************************* + * Logical OR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(lor, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(lor, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(lor, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(lor, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(lor, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(lor, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(lor, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(lor, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(lor, long, long) +FUNC_FUNC_3BUF_SIG(lor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_3BUF_SIG(lor, bool, bool) + +/************************************************************************* + * Logical XOR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(lxor, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(lxor, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(lxor, int16_t, 
int16_t) +FUNC_FUNC_3BUF_SIG(lxor, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(lxor, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(lxor, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(lxor, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(lxor, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(lxor, long, long) +FUNC_FUNC_3BUF_SIG(lxor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_3BUF_SIG(lxor, bool, bool) + +/************************************************************************* + * Bitwise AND + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(band, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(band, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(band, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(band, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(band, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(band, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(band, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(band, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(band, long, long) +FUNC_FUNC_3BUF_SIG(band, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_3BUF_SIG(band, byte, char) + +/************************************************************************* + * Bitwise OR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(bor, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(bor, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(bor, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(bor, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(bor, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(bor, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(bor, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(bor, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(bor, long, long) +FUNC_FUNC_3BUF_SIG(bor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_3BUF_SIG(bor, byte, char) + +/************************************************************************* + * Bitwise XOR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(bxor, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(bxor, uint8_t, 
uint8_t) +FUNC_FUNC_3BUF_SIG(bxor, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(bxor, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(bxor, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(bxor, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(bxor, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(bxor, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(bxor, long, long) +FUNC_FUNC_3BUF_SIG(bxor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_3BUF_SIG(bxor, byte, char) + +/************************************************************************* + * Max location + *************************************************************************/ + +LOC_FUNC_3BUF_SIG(maxloc, float_int) +LOC_FUNC_3BUF_SIG(maxloc, double_int) +LOC_FUNC_3BUF_SIG(maxloc, long_int) +LOC_FUNC_3BUF_SIG(maxloc, 2int) +LOC_FUNC_3BUF_SIG(maxloc, short_int) +LOC_FUNC_3BUF_SIG(maxloc, long_double_int) + +LOC_FUNC_3BUF_SIG(maxloc, 2float) +LOC_FUNC_3BUF_SIG(maxloc, 2double) +LOC_FUNC_3BUF_SIG(maxloc, 2int8) +LOC_FUNC_3BUF_SIG(maxloc, 2int16) +LOC_FUNC_3BUF_SIG(maxloc, 2int32) +LOC_FUNC_3BUF_SIG(maxloc, 2int64) + +/************************************************************************* + * Min location + *************************************************************************/ + +LOC_FUNC_3BUF_SIG(minloc, float_int) +LOC_FUNC_3BUF_SIG(minloc, double_int) +LOC_FUNC_3BUF_SIG(minloc, long_int) +LOC_FUNC_3BUF_SIG(minloc, 2int) +LOC_FUNC_3BUF_SIG(minloc, short_int) +LOC_FUNC_3BUF_SIG(minloc, long_double_int) + +LOC_FUNC_3BUF_SIG(minloc, 2float) +LOC_FUNC_3BUF_SIG(minloc, 2double) +LOC_FUNC_3BUF_SIG(minloc, 2int8) +LOC_FUNC_3BUF_SIG(minloc, 2int16) +LOC_FUNC_3BUF_SIG(minloc, 2int32) +LOC_FUNC_3BUF_SIG(minloc, 2int64) + + +END_C_DECLS diff --git a/ompi/mca/op/op.h b/ompi/mca/op/op.h index 34d26376ab9..097c2a109b4 100644 --- a/ompi/mca/op/op.h +++ b/ompi/mca/op/op.h @@ -85,6 +85,7 @@ #include "ompi_config.h" #include "opal/class/opal_object.h" +#include "opal/mca/accelerator/accelerator.h" #include "ompi/mca/mca.h" /* @@ -266,6 +267,22 @@ typedef void 
(*ompi_op_base_handler_fn_1_0_0_t)(const void *, void *, int *, typedef ompi_op_base_handler_fn_1_0_0_t ompi_op_base_handler_fn_t; +/** + * Typedef for 2-buffer op functions on streams/devices. + * + * We don't use MPI_User_function because this would create a + * confusing dependency loop between this file and mpi.h. So this is + * repeated code, but it's better this way (and this typedef will + * never change, so there's not much of a maintenance worry). + */ +typedef void (*ompi_op_base_stream_handler_fn_1_0_0_t)(const void *, void *, int *, + struct ompi_datatype_t **, + int device, + opal_accelerator_stream_t *stream, + struct ompi_op_base_module_1_0_0_t *); + +typedef ompi_op_base_stream_handler_fn_1_0_0_t ompi_op_base_stream_handler_fn_t; + /* * Typedef for 3-buffer (two input and one output) op functions. */ @@ -277,6 +294,19 @@ typedef void (*ompi_op_base_3buff_handler_fn_1_0_0_t)(const void *, typedef ompi_op_base_3buff_handler_fn_1_0_0_t ompi_op_base_3buff_handler_fn_t; +/* + * Typedef for 3-buffer (two input and one output) op functions on streams. 
+ */ +typedef void (*ompi_op_base_3buff_stream_handler_fn_1_0_0_t)(const void *, + const void *, + void *, int *, + struct ompi_datatype_t **, + int device, + opal_accelerator_stream_t*, + struct ompi_op_base_module_1_0_0_t *); + +typedef ompi_op_base_3buff_stream_handler_fn_1_0_0_t ompi_op_base_3buff_stream_handler_fn_t; + /** * Op component initialization * @@ -376,10 +406,18 @@ typedef struct ompi_op_base_module_1_0_0_t { is being used for */ struct ompi_op_t *opm_op; + bool opm_device_enabled; + /** Function pointers for all the different datatypes to be used with the MPI_Op that this module is used with */ - ompi_op_base_handler_fn_1_0_0_t opm_fns[OMPI_OP_BASE_TYPE_MAX]; - ompi_op_base_3buff_handler_fn_1_0_0_t opm_3buff_fns[OMPI_OP_BASE_TYPE_MAX]; + union { + ompi_op_base_handler_fn_1_0_0_t opm_fns[OMPI_OP_BASE_TYPE_MAX]; + ompi_op_base_stream_handler_fn_1_0_0_t opm_stream_fns[OMPI_OP_BASE_TYPE_MAX]; + }; + union { + ompi_op_base_3buff_handler_fn_1_0_0_t opm_3buff_fns[OMPI_OP_BASE_TYPE_MAX]; + ompi_op_base_3buff_stream_handler_fn_1_0_0_t opm_3buff_stream_fns[OMPI_OP_BASE_TYPE_MAX]; + }; } ompi_op_base_module_1_0_0_t; /** @@ -404,6 +442,18 @@ typedef struct ompi_op_base_op_fns_1_0_0_t { typedef ompi_op_base_op_fns_1_0_0_t ompi_op_base_op_fns_t; +/** + * Struct that is used in op.h to hold all the function pointers and + * pointers to the corresopnding modules (so that we can properly + * RETAIN/RELEASE them) + */ +typedef struct ompi_op_base_op_stream_fns_1_0_0_t { + ompi_op_base_stream_handler_fn_1_0_0_t fns[OMPI_OP_BASE_TYPE_MAX]; + ompi_op_base_module_t *modules[OMPI_OP_BASE_TYPE_MAX]; +} ompi_op_base_op_stream_fns_1_0_0_t; + +typedef ompi_op_base_op_stream_fns_1_0_0_t ompi_op_base_op_stream_fns_t; + /** * Struct that is used in op.h to hold all the function pointers and * pointers to the corresopnding modules (so that we can properly @@ -416,6 +466,18 @@ typedef struct ompi_op_base_op_3buff_fns_1_0_0_t { typedef ompi_op_base_op_3buff_fns_1_0_0_t 
ompi_op_base_op_3buff_fns_t; +/** + * Struct that is used in op.h to hold all the function pointers and + * pointers to the corresponding modules (so that we can properly + * RETAIN/RELEASE them) + */ +typedef struct ompi_op_base_op_3buff_stream_fns_1_0_0_t { + ompi_op_base_3buff_stream_handler_fn_1_0_0_t fns[OMPI_OP_BASE_TYPE_MAX]; + ompi_op_base_module_t *modules[OMPI_OP_BASE_TYPE_MAX]; +} ompi_op_base_op_3buff_stream_fns_1_0_0_t; + +typedef ompi_op_base_op_3buff_stream_fns_1_0_0_t ompi_op_base_op_3buff_stream_fns_t; + /* * Macro for use in modules that are of type op v2.0.0 */ diff --git a/ompi/mca/op/rocm/Makefile.am b/ompi/mca/op/rocm/Makefile.am new file mode 100644 index 00000000000..a4d999e25f9 --- /dev/null +++ b/ompi/mca/op/rocm/Makefile.am @@ -0,0 +1,82 @@ +# +# Copyright (c) 2023 The University of Tennessee and The University +# of Tennessee Research Foundation. All rights +# reserved. +# $COPYRIGHT$ +# +# Additional copyrights may follow +# +# $HEADER$ +# + +# This component provides support for offloading reduce ops to ROCM devices. +# +# See https://github.com/open-mpi/ompi/wiki/devel-CreateComponent +# for more details on how to make Open MPI components. + +# First, list all .h and .c sources. It is necessary to list all .h +# files so that they will be picked up in the distribution tarball. + +AM_CPPFLAGS = $(op_rocm_CPPFLAGS) + +dist_ompidata_DATA = help-ompi-mca-op-rocm.txt + +sources = op_rocm_component.c op_rocm.h op_rocm_functions.c op_rocm_impl.h +rocm_sources = op_rocm_impl.hip + +HIPCC = hipcc + +.hip.l$(OBJEXT): + $(LIBTOOL) $(AM_V_lt) --tag=CC $(AM_LIBTOOLFLAGS) \ + $(LIBTOOLFLAGS) --mode=compile $(HIPCC) -O2 -std=c++17 -fvectorize -prefer-non-pic -Wc,-fPIC,-g -c $< -o $@ + +# Name the output explicitly: libtool may not infer it from a .hip source. + +# Open MPI components can be compiled two ways: +# +# 1. As a standalone dynamic shared object (DSO), sometimes called a +# dynamically loadable library (DLL). +# +# 2.
As a static library that is slurped up into the upper-level +# libmpi library (regardless of whether libmpi is a static or dynamic +# library). This is called a "Libtool convenience library". +# +# The component needs to create an output library in this top-level +# component directory, and named either mca_<type>_<name>.la (for DSO +# builds) or libmca_<type>_<name>.la (for static builds). The OMPI +# build system will have set the +# MCA_BUILD_ompi_<type>_<name>_DSO AM_CONDITIONAL to indicate +# which way this component should be built. + +if MCA_BUILD_ompi_op_rocm_DSO +component_install = mca_op_rocm.la +else +component_install = +component_noinst = libmca_op_rocm.la +endif + +# Specific information for DSO builds. +# +# The DSO should install itself in $(ompilibdir) (by default, +# $prefix/lib/openmpi). + +mcacomponentdir = $(ompilibdir) +mcacomponent_LTLIBRARIES = $(component_install) +mca_op_rocm_la_SOURCES = $(sources) +mca_op_rocm_la_LIBADD = $(rocm_sources:.hip=.lo) +mca_op_rocm_la_LDFLAGS = -module -avoid-version $(top_builddir)/ompi/lib@OMPI_LIBMPI_NAME@.la \ + $(op_rocm_LIBS) +EXTRA_mca_op_rocm_la_SOURCES = $(rocm_sources) + +# Specific information for static builds. +# +# Note that we *must* "noinst"; the upper-layer Makefile.am's will +# slurp in the resulting .la library into libmpi. + +noinst_LTLIBRARIES = $(component_noinst) +libmca_op_rocm_la_SOURCES = $(sources) +libmca_op_rocm_la_LIBADD = $(rocm_sources:.hip=.lo) +libmca_op_rocm_la_LDFLAGS = -module -avoid-version\ + $(op_rocm_LIBS) +EXTRA_libmca_op_rocm_la_SOURCES = $(rocm_sources) + diff --git a/ompi/mca/op/rocm/configure.m4 b/ompi/mca/op/rocm/configure.m4 new file mode 100644 index 00000000000..ffd88698be0 --- /dev/null +++ b/ompi/mca/op/rocm/configure.m4 @@ -0,0 +1,36 @@ +# -*- shell-script -*- +# +# Copyright (c) 2011-2013 NVIDIA Corporation. All rights reserved. +# Copyright (c) 2023 The University of Tennessee and The University +# of Tennessee Research Foundation. All rights +# reserved.
+# Copyright (c) 2022 Amazon.com, Inc. or its affiliates. +# All Rights reserved. +# $COPYRIGHT$ +# +# Additional copyrights may follow +# +# $HEADER$ +# + +# +# If ROCm support was requested, then build the ROCm support library. +# This code makes sure the check was done earlier by the +# opal_check_rocm.m4 code. It also copies the flags and libs under +# op_rocm_CPPFLAGS, op_rocm_LDFLAGS, and op_rocm_LIBS + +AC_DEFUN([MCA_ompi_op_rocm_CONFIG],[ + + AC_CONFIG_FILES([ompi/mca/op/rocm/Makefile]) + + OPAL_CHECK_ROCM([op_rocm]) + + AS_IF([test "x$ROCM_SUPPORT" = "x1"], + [$1], + [$2]) + + AC_SUBST([op_rocm_CPPFLAGS]) + AC_SUBST([op_rocm_LDFLAGS]) + AC_SUBST([op_rocm_LIBS]) + +])dnl diff --git a/ompi/mca/op/rocm/help-ompi-mca-op-rocm.txt b/ompi/mca/op/rocm/help-ompi-mca-op-rocm.txt new file mode 100644 index 00000000000..848afbb663d --- /dev/null +++ b/ompi/mca/op/rocm/help-ompi-mca-op-rocm.txt @@ -0,0 +1,15 @@ +# -*- text -*- +# +# Copyright (c) 2023 The University of Tennessee and The University +# of Tennessee Research Foundation. All rights +# reserved. +# $COPYRIGHT$ +# +# Additional copyrights may follow +# +# $HEADER$ +# +# This is the US/English help file for Open MPI's HIP operator component +# +[HIP call failed] +"HIP call %s failed: %s: %s\n" diff --git a/ompi/mca/op/rocm/op_rocm.h b/ompi/mca/op/rocm/op_rocm.h new file mode 100644 index 00000000000..0f86f44c41e --- /dev/null +++ b/ompi/mca/op/rocm/op_rocm.h @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2019-2023 The University of Tennessee and The University + * of Tennessee Research Foundation. All rights + * reserved.
+ * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#ifndef MCA_OP_ROCM_EXPORT_H +#define MCA_OP_ROCM_EXPORT_H + +#include "ompi_config.h" + +#include "ompi/mca/mca.h" +#include "opal/class/opal_object.h" + +#include "ompi/mca/op/op.h" +#include "ompi/runtime/mpiruntime.h" + +#include <hip/hip_runtime.h> + +BEGIN_C_DECLS + + +#define xstr(x) #x +#define str(x) xstr(x) + +#define CHECK(fn, args) \ + do { \ + hipError_t err = fn args; \ + if (err != hipSuccess) { \ + opal_show_help("help-ompi-mca-op-rocm.txt", \ + "HIP call failed", true, \ + str(fn), hipGetErrorName(err), \ + hipGetErrorString(err)); \ + ompi_mpi_abort(MPI_COMM_WORLD, 1); \ + } \ + } while (0) + + +/** + * Derive a struct from the base op component struct, allowing us to + * cache some component-specific information on our well-known + * component struct. + */ +typedef struct { + /** The base op component struct */ + ompi_op_base_component_1_0_0_t super; + int ro_max_num_blocks; /**< MCA param max_num_blocks (-1: device limit) */ + int ro_max_num_threads; /**< MCA param max_num_threads (-1: device limit) */ + int *ro_max_threads_per_block; /**< threads-per-block limit, one entry per device */ + int *ro_max_blocks; /**< block-count limit, one entry per device */ + hipDevice_t *ro_devices; /**< device handles, one entry per device */ + int ro_num_devices; /**< number of entries in the arrays above */ +} ompi_op_rocm_component_t; + +/** + * Globally exported variable. Note that it is a *rocm* component + * (defined above), which has the ompi_op_base_component_t as its + * first member. Hence, the MCA/op framework will find the data that + * it expects in the first memory locations, but then the component + * itself can cache additional information after that that can be used + * by both the component and modules.
+ */ +OMPI_DECLSPEC extern ompi_op_rocm_component_t + mca_op_rocm_component; + +OMPI_DECLSPEC extern +ompi_op_base_stream_handler_fn_t ompi_op_rocm_functions[OMPI_OP_BASE_FORTRAN_OP_MAX][OMPI_OP_BASE_TYPE_MAX]; + +OMPI_DECLSPEC extern +ompi_op_base_3buff_stream_handler_fn_t ompi_op_rocm_3buff_functions[OMPI_OP_BASE_FORTRAN_OP_MAX][OMPI_OP_BASE_TYPE_MAX]; + +END_C_DECLS + +#endif /* MCA_OP_ROCM_EXPORT_H */ diff --git a/ompi/mca/op/rocm/op_rocm_component.c b/ompi/mca/op/rocm/op_rocm_component.c new file mode 100644 index 00000000000..79ad2d36a66 --- /dev/null +++ b/ompi/mca/op/rocm/op_rocm_component.c @@ -0,0 +1,207 @@ +/* + * Copyright (c) 2019-2023 The University of Tennessee and The University + * of Tennessee Research Foundation. All rights + * reserved. + * Copyright (c) 2020 Research Organization for Information Science + * and Technology (RIST). All rights reserved. + * Copyright (c) 2021 Cisco Systems, Inc. All rights reserved. + * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +/** @file + * + * This is the "rocm" op component source code. 
+ * + */ + +#include "ompi_config.h" + +#include "opal/util/printf.h" + +#include "ompi/constants.h" +#include "ompi/op/op.h" +#include "ompi/mca/op/op.h" +#include "ompi/mca/op/base/base.h" +#include "ompi/mca/op/rocm/op_rocm.h" + +#include + +static int rocm_component_open(void); +static int rocm_component_close(void); +static int rocm_component_init_query(bool enable_progress_threads, + bool enable_mpi_thread_multiple); +static struct ompi_op_base_module_1_0_0_t * + rocm_component_op_query(struct ompi_op_t *op, int *priority); +static int rocm_component_register(void); + +ompi_op_rocm_component_t mca_op_rocm_component = { + { + .opc_version = { + OMPI_OP_BASE_VERSION_1_0_0, + + .mca_component_name = "rocm", + MCA_BASE_MAKE_VERSION(component, OMPI_MAJOR_VERSION, OMPI_MINOR_VERSION, + OMPI_RELEASE_VERSION), + .mca_open_component = rocm_component_open, + .mca_close_component = rocm_component_close, + .mca_register_component_params = rocm_component_register, + }, + .opc_data = { + /* The component is checkpoint ready */ + MCA_BASE_METADATA_PARAM_CHECKPOINT + }, + + .opc_init_query = rocm_component_init_query, + .opc_op_query = rocm_component_op_query, + }, + .ro_max_num_blocks = -1, + .ro_max_num_threads = -1, + .ro_max_threads_per_block = NULL, + .ro_max_blocks = NULL, + .ro_devices = NULL, + .ro_num_devices = 0, +}; + +/* + * Component open + */ +static int rocm_component_open(void) +{ + /* Nothing to do when the component is opened: the MCA parameters are + * registered in rocm_component_register() and the HIP devices and + * their launch limits are queried in rocm_component_init_query(). + * + * No hardware capability check is performed here; this function + * always succeeds. + * + * Note that if this function returns non-OMPI_SUCCESS, then this + * component won't even be shown in ompi_info output (which is + * probably not what you want).
+ */ + return OMPI_SUCCESS; +} + +/* + * Component close + */ +static int rocm_component_close(void) +{ + if (mca_op_rocm_component.ro_num_devices > 0) { + //hipStreamDestroy(mca_op_rocm_component.ro_stream); + free(mca_op_rocm_component.ro_max_threads_per_block); + mca_op_rocm_component.ro_max_threads_per_block = NULL; + free(mca_op_rocm_component.ro_max_blocks); + mca_op_rocm_component.ro_max_blocks = NULL; + free(mca_op_rocm_component.ro_devices); + mca_op_rocm_component.ro_devices = NULL; + mca_op_rocm_component.ro_num_devices = 0; + } + + return OMPI_SUCCESS; +} + +/* + * Register MCA params. + */ +static int +rocm_component_register(void) +{ + /* TODO: add mca parameters */ + + /* both parameters are plain integers: no value enumerator is attached */ + (void) mca_base_component_var_register(&mca_op_rocm_component.super.opc_version, + "max_num_blocks", + "Maximum number of thread blocks in kernels (-1: device limit)", + MCA_BASE_VAR_TYPE_INT, + NULL, 0, 0, + OPAL_INFO_LVL_4, + MCA_BASE_VAR_SCOPE_LOCAL, + &mca_op_rocm_component.ro_max_num_blocks); + + (void) mca_base_component_var_register(&mca_op_rocm_component.super.opc_version, + "max_num_threads", + "Maximum number of threads per block in kernels (-1: device limit)", + MCA_BASE_VAR_TYPE_INT, + NULL, 0, 0, + OPAL_INFO_LVL_4, + MCA_BASE_VAR_SCOPE_LOCAL, + &mca_op_rocm_component.ro_max_num_threads); + + return OMPI_SUCCESS; +} + + +/* + * Query whether this component wants to be used in this process.
+ */ +static int +rocm_component_init_query(bool enable_progress_threads, + bool enable_mpi_thread_multiple) +{ + int num_devices; + int rc; + CHECK(hipGetDeviceCount, (&num_devices)); + mca_op_rocm_component.ro_num_devices = num_devices; + mca_op_rocm_component.ro_devices = (hipDevice_t*)malloc(num_devices*sizeof(hipDevice_t)); + mca_op_rocm_component.ro_max_threads_per_block = (int*)malloc(num_devices*sizeof(int)); + mca_op_rocm_component.ro_max_blocks = (int*)malloc(num_devices*sizeof(int)); + for (int i = 0; i < num_devices; ++i) { + CHECK(hipDeviceGet, (&mca_op_rocm_component.ro_devices[i], i)); + rc = hipDeviceGetAttribute(&mca_op_rocm_component.ro_max_threads_per_block[i], + hipDeviceAttributeMaxBlockDimX, + mca_op_rocm_component.ro_devices[i]); + if (hipSuccess != rc) { + /* fall-back to value that should work on every device */ + mca_op_rocm_component.ro_max_threads_per_block[i] = 512; + } + if (0 < mca_op_rocm_component.ro_max_num_threads) { /* non-positive: keep the device limit */ + if (mca_op_rocm_component.ro_max_threads_per_block[i] > mca_op_rocm_component.ro_max_num_threads) { + mca_op_rocm_component.ro_max_threads_per_block[i] = mca_op_rocm_component.ro_max_num_threads; + } + } + + rc = hipDeviceGetAttribute(&mca_op_rocm_component.ro_max_blocks[i], + hipDeviceAttributeMaxGridDimX, + mca_op_rocm_component.ro_devices[i]); + if (hipSuccess != rc) { + /* we'll try to max out the blocks */ + mca_op_rocm_component.ro_max_blocks[i] = 512; + } + if (0 < mca_op_rocm_component.ro_max_num_blocks) { /* non-positive: keep the device limit */ + if (mca_op_rocm_component.ro_max_blocks[i] > mca_op_rocm_component.ro_max_num_blocks) { + mca_op_rocm_component.ro_max_blocks[i] = mca_op_rocm_component.ro_max_num_blocks; + } + } + } + + return OMPI_SUCCESS; +} + +/* + * Query whether this component can be used for a specific op + */ +static struct ompi_op_base_module_1_0_0_t* +rocm_component_op_query(struct ompi_op_t *op, int *priority) +{ + ompi_op_base_module_t *module = NULL; + + module = OBJ_NEW(ompi_op_base_module_t); +
module->opm_device_enabled = true; + for (int i = 0; i < OMPI_OP_BASE_TYPE_MAX; ++i) { + module->opm_stream_fns[i] = ompi_op_rocm_functions[op->o_f_to_c_index][i]; + module->opm_3buff_stream_fns[i] = ompi_op_rocm_3buff_functions[op->o_f_to_c_index][i]; + + if( NULL != module->opm_fns[i] ) { + OBJ_RETAIN(module); + } + if( NULL != module->opm_3buff_fns[i] ) { + OBJ_RETAIN(module); + } + } + *priority = 50; + return (ompi_op_base_module_1_0_0_t *) module; +} diff --git a/ompi/mca/op/rocm/op_rocm_functions.c b/ompi/mca/op/rocm/op_rocm_functions.c new file mode 100644 index 00000000000..46cd02709bd --- /dev/null +++ b/ompi/mca/op/rocm/op_rocm_functions.c @@ -0,0 +1,1897 @@ +/* + * Copyright (c) 2019-2023 The University of Tennessee and The University + * of Tennessee Research Foundation. All rights + * reserved. + * Copyright (c) 2020 Research Organization for Information Science + * and Technology (RIST). All rights reserved. + * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "ompi_config.h" + +#ifdef HAVE_SYS_TYPES_H +#include +#endif +#include "opal/util/output.h" + + +#include "ompi/op/op.h" +#include "ompi/mca/op/op.h" +#include "ompi/mca/op/base/base.h" +#include "ompi/mca/op/rocm/op_rocm.h" +#include "opal/mca/accelerator/accelerator.h" + +#include "ompi/mca/op/rocm/op_rocm.h" +#include "ompi/mca/op/rocm/op_rocm_impl.h" + +/** + * Disable warning about empty macro var-args. + * We use varargs to suppress expansion of typenames + * (e.g., int32_t -> int) which could lead to collisions + * for similar base types. 
+ */
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wgnu-zero-variadic-macro-arguments"
+
+/*
+ * Prepare the operands of a reduction so that a kernel can run on one device.
+ *
+ * orig_source1/orig_source2/orig_target are the caller's buffers
+ * (orig_source2 is NULL for 2-buffer reductions, in which case source2 and
+ * source2_device may be NULL as well). On return *source1, *source2 and
+ * *target hold the pointers to hand to the kernel and the *_device outputs
+ * hold the device recorded for each buffer.
+ *
+ * device is in/out:
+ *   - if *device is not MCA_ACCELERATOR_NO_DEVICE_ID the caller's choice is
+ *     taken as is and copied to all *_device outputs; nothing is staged;
+ *   - otherwise every buffer is inspected with opal_accelerator.check_addr
+ *     (a return of 0 is treated as "not on a device") and a device is
+ *     picked: device 0 if no buffer is on a device, else the target's
+ *     device, falling back to a source's device when the target reports -1.
+ *     Buffers that are not on the selected device get a temporary allocation
+ *     on it (mem_alloc_stream) that is filled with an asynchronous copy on
+ *     `stream`. A staged target has *target_device set to -1, which is how
+ *     device_op_post recognizes that it has to copy the result back.
+ *
+ * count and dtype give the number of elements and their datatype; they are
+ * only used to compute the number of bytes to stage.
+ * *threads_per_block and *max_blocks receive the component's limits for the
+ * selected device.
+ *
+ * Return values of the accelerator calls are not checked.
+ */
+static inline void device_op_pre(const void *orig_source1,
+                                 void **source1,
+                                 int *source1_device,
+                                 const void *orig_source2,
+                                 void **source2,
+                                 int *source2_device,
+                                 void *orig_target,
+                                 void **target,
+                                 int *target_device,
+                                 int count,
+                                 struct ompi_datatype_t *dtype,
+                                 int *threads_per_block,
+                                 int *max_blocks,
+                                 int *device,
+                                 opal_accelerator_stream_t *stream)
+{
+    uint64_t target_flags = -1, source1_flags = -1, source2_flags = -1;
+    int target_rc = -1, source1_rc = -1, source2_rc = -1;
+
+    *target = orig_target;
+    *source1 = (void*)orig_source1;
+    if (NULL != orig_source2) {
+        *source2 = (void*)orig_source2;
+    }
+
+    if (*device != MCA_ACCELERATOR_NO_DEVICE_ID) {
+        /* we got the device from the caller, just adjust the output parameters */
+        *target_device = *device;
+        *source1_device = *device;
+        if (NULL != source2_device) {
+            *source2_device = *device;
+        }
+    } else {
+
+        target_rc = opal_accelerator.check_addr(*target, target_device, &target_flags);
+        source1_rc = opal_accelerator.check_addr(*source1, source1_device, &source1_flags);
+        *device = *target_device;
+
+        if (NULL != orig_source2) {
+            source2_rc = opal_accelerator.check_addr(*source2, source2_device, &source2_flags);
+        }
+
+        /* A second source that does not exist (2-buffer reduction) cannot
+         * live on a device. It must count as "on the host" here; otherwise
+         * source2_rc keeps its initial -1, the all-host case is never
+         * recognized for 2-buffer reductions and *device may stay -1, which
+         * is then used as an index into the per-device tables below. */
+        const int source2_on_host = (NULL == orig_source2) || (0 == source2_rc);
+
+        if (0 == target_rc && 0 == source1_rc && source2_on_host) {
+            /* no buffers are on any device, select device 0 */
+            *device = 0;
+        } else if (*target_device == -1) {
+            if (*source1_device == -1 && NULL != orig_source2) {
+                *device = *source2_device;
+            } else {
+                *device = *source1_device;
+            }
+        }
+
+        if (0 == target_rc || 0 == source1_rc || *target_device != *source1_device) {
+            size_t nbytes;
+            ompi_datatype_type_size(dtype, &nbytes);
+            nbytes *= count;
+
+            if (0 == target_rc) {
+                // allocate memory on the device for the target buffer
+                opal_accelerator.mem_alloc_stream(*device, target, nbytes, stream);
+                opal_accelerator.mem_copy_async(*target_device, *source1_device,
+                                                *target, orig_target, nbytes, stream,
+                                                MCA_ACCELERATOR_TRANSFER_UNSPEC);
+                /* NOTE(review): device_op_post tests for
+                 * MCA_ACCELERATOR_NO_DEVICE_ID, so this relies on that
+                 * constant being -1 -- confirm. */
+                *target_device = -1; // mark target device as host
+            }
+
+            if (0 == source1_rc || *device != *source1_device) {
+                // allocate memory on the device for the source buffer
+                opal_accelerator.mem_alloc_stream(*device, source1, nbytes, stream);
+                opal_accelerator.mem_copy_async(*target_device, *source1_device,
+                                                *source1, orig_source1, nbytes, stream,
+                                                MCA_ACCELERATOR_TRANSFER_UNSPEC);
+            }
+
+        }
+        if (NULL != source2_device && *target_device != *source2_device) {
+            // allocate memory on the device for the source buffer
+            size_t nbytes;
+            ompi_datatype_type_size(dtype, &nbytes);
+            nbytes *= count;
+
+            opal_accelerator.mem_alloc_stream(*device, source2, nbytes, stream);
+            opal_accelerator.mem_copy_async(*target_device, *source2_device,
+                                            *source2, orig_source2, nbytes, stream,
+                                            MCA_ACCELERATOR_TRANSFER_UNSPEC);
+        }
+    }
+
+    *threads_per_block = mca_op_rocm_component.ro_max_threads_per_block[*device];
+    *max_blocks = mca_op_rocm_component.ro_max_blocks[*device];
+
+}
+
+/*
+ * Undo the staging done by device_op_pre after the kernel was enqueued.
+ *
+ * If target_device is MCA_ACCELERATOR_NO_DEVICE_ID the result is copied
+ * from `target` back to `orig_target` with hipMemcpyDtoHAsync on the HIP
+ * stream stored in stream->stream, and `target` is released on `device`.
+ * source1 is released when source1_device differs from `device`; the same
+ * holds for source2 when it is non-NULL. count and dtype are only used to
+ * compute the number of bytes to copy back.
+ */
+static inline void device_op_post(void *source1,
+                                  int source1_device,
+                                  void *source2,
+                                  int source2_device,
+                                  void *orig_target,
+                                  void *target,
+                                  int target_device,
+                                  int count,
+                                  struct ompi_datatype_t *dtype,
+                                  int device,
+                                  opal_accelerator_stream_t *stream)
+{
+    if (MCA_ACCELERATOR_NO_DEVICE_ID == target_device) {
+
+        size_t nbytes;
+        ompi_datatype_type_size(dtype, &nbytes);
+        nbytes *= count;
+
+        /* copy the result back first, then hand the staging buffer back;
+         * both operations are enqueued on the same stream */
+        CHECK(hipMemcpyDtoHAsync, (orig_target, (hipDeviceptr_t)target, nbytes, *(hipStream_t *)stream->stream));
+        opal_accelerator.mem_release_stream(device, target, stream);
+    }
+
+    if (source1_device != device) {
+        opal_accelerator.mem_release_stream(device, source1, stream);
+    }
+    if (NULL != source2 && source2_device != device) {
+        opal_accelerator.mem_release_stream(device, source2, stream);
+    }
+}
+
+/*
+ * FUNC(name, type_name, type) declares (with __opal_attribute_unused__) and
+ * defines the static 2-buffer wrapper ompi_op_rocm_2buff_<name>_<type_name>.
+ * The wrapper stages the operands with device_op_pre, calls
+ * ompi_op_rocm_2buff_<name>_<type_name>_submit with the prepared pointers,
+ * the element count, the launch limits and the HIP stream taken from
+ * stream->stream, and finishes with device_op_post.
+ */
+#define FUNC(name, type_name, type) \
+    static \
+    void ompi_op_rocm_2buff_##name##_##type_name(const void *in, void *inout, int *count, \
+                                                 struct ompi_datatype_t **dtype, \
+                                                 int device, \
+                                                 opal_accelerator_stream_t *stream, \
+                                                 struct ompi_op_base_module_1_0_0_t *module) __opal_attribute_unused__; \
+    static \
+    void ompi_op_rocm_2buff_##name##_##type_name(const void *in, void *inout, int *count, \
+                                                 struct ompi_datatype_t **dtype, \
+                                                 int device, \
+                                                 opal_accelerator_stream_t *stream, \
+                                                 struct ompi_op_base_module_1_0_0_t *module) { \
+        int threads_per_block, max_blocks; \
+        int source_device, target_device; \
+        type *source, *target; \
+        int n = *count; \
+        device_op_pre(in, (void**)&source, &source_device, NULL, NULL, NULL, \
+                      inout, (void**)&target, &target_device, \
+                      n, *dtype, \
+                      &threads_per_block, &max_blocks, &device, stream); \
+        hipStream_t *custream = (hipStream_t*)stream->stream; \
+        ompi_op_rocm_2buff_##name##_##type_name##_submit(source, target, n, threads_per_block, max_blocks, *custream);\
+        device_op_post(source, source_device, NULL, -1, inout, target, target_device, n, *dtype, device, stream); \
+    }
+
+/* concatenate type_name and type to avoid expansion (e.g., int32_t -> int) */
+#define OP_FUNC(name, type_name, type, ...) FUNC(name, __VA_ARGS__##type_name, __VA_ARGS__##type)
+
+/* reuse the macro above */
+#define FUNC_FUNC(name, type_name, type, ...) FUNC(name, __VA_ARGS__##type_name, __VA_ARGS__##type)
+
+/*
+ * Since all the functions in this file are essentially identical, we
+ * use a macro to substitute in names and types.  The core operation
+ * in all functions that use this macro is the same.
+ *
+ * This macro is for minloc and maxloc
+ *
+ * The element type is derived from the type name:
+ * ompi_op_predefined_<type_name>_t.
+ */
+#define LOC_FUNC(name, type_name) FUNC(name, type_name, ompi_op_predefined_##type_name##_t)
+
+/* Dispatch Fortran types to C types */
+/*
+ * FORT_INT_FUNC defines ompi_op_rocm_2buff_<name>_<type_name> as a pure
+ * dispatcher: it switches on sizeof(type) and forwards all arguments
+ * unchanged to the int8_t/int16_t/int32_t/int64_t wrapper of the same size.
+ * The _Static_assert rejects types smaller than int8_t or larger than
+ * int64_t at compile time.
+ * NOTE(review): because of that assert, an instantiation with a Fortran
+ * integer type wider than int64_t (see the fortran_integer16 instantiations)
+ * would not compile -- confirm whether such builds are expected.
+ * Unlike FUNC, no prototype with __opal_attribute_unused__ is emitted.
+ */
+#define FORT_INT_FUNC(name, type_name, type) \
+    static \
+    void ompi_op_rocm_2buff_##name##_##type_name(const void *in, void *inout, int *count, \
+                                                 struct ompi_datatype_t **dtype, \
+                                                 int device, \
+                                                 opal_accelerator_stream_t *stream, \
+                                                 struct ompi_op_base_module_1_0_0_t *module) { \
+        \
+        _Static_assert(sizeof(type) >= sizeof(int8_t) && sizeof(type) <= sizeof(int64_t), "Unsupported integer size (<1B, >8B)"); \
+        switch(sizeof(type)) { \
+            case sizeof(int8_t): \
+                ompi_op_rocm_2buff_##name##_int8_t(in, inout, count, dtype, device, stream, module); \
+                break; \
+            case sizeof(int16_t): \
+                ompi_op_rocm_2buff_##name##_int16_t(in, inout, count, dtype, device, stream, module); \
+                break; \
+            case sizeof(int32_t): \
+                ompi_op_rocm_2buff_##name##_int32_t(in, inout, count, dtype, device, stream, module); \
+                break; \
+            case sizeof(int64_t): \
+                ompi_op_rocm_2buff_##name##_int64_t(in, inout, count, dtype, device, stream, module); \
+                break; \
+        } \
+    }
+
+/* Dispatch Fortran types to C types */
+/*
+ * FORT_FLOAT_FUNC is the floating-point counterpart: it switches on
+ * sizeof(type) and forwards to the float, double or long_double wrapper.
+ * The _Static_assert bounds the size by sizeof(float) and
+ * sizeof(long double) (its message mentions 4B/8B).
+ * NOTE(review): the case labels are sizeof(double) and sizeof(long double);
+ * on a platform where both have the same size these would be duplicate
+ * case values -- confirm the supported platforms.
+ * NOTE(review): a Fortran real type smaller than float (see the
+ * fortran_real2 instantiations) would fail the lower bound -- confirm.
+ */
+#define FORT_FLOAT_FUNC(name, type_name, type) \
+    static \
+    void ompi_op_rocm_2buff_##name##_##type_name(const void *in, void *inout, int *count, \
+                                                 struct ompi_datatype_t **dtype, \
+                                                 int device, \
+                                                 opal_accelerator_stream_t *stream, \
+                                                 struct ompi_op_base_module_1_0_0_t *module) { \
+        _Static_assert(sizeof(type) >= sizeof(float) && sizeof(type) <= sizeof(long double), "Unsupported float size (<4B, >8B)"); \
+        switch(sizeof(type)) { \
+            case sizeof(float): \
+                ompi_op_rocm_2buff_##name##_float(in, inout, count, dtype, device, stream, module); \
+                break; \
+            case sizeof(double): \
+                ompi_op_rocm_2buff_##name##_double(in, inout, count, dtype, device, stream, module); \
+                break; \
+            case sizeof(long double): \
+                ompi_op_rocm_2buff_##name##_long_double(in, inout, count, dtype, device, stream, module); \
+                break; \
+        } \
+    }
+
+
+/*
+ * FORT_LOC_INT_FUNC dispatches a Fortran integer pair type used by
+ * minloc/maxloc: it switches on sizeof(type) and forwards to the
+ * 2int8/2int16/2int32/2int64 wrapper of the matching element size.
+ */
+#define FORT_LOC_INT_FUNC(name, type_name, type) \
+    static \
+    void ompi_op_rocm_2buff_##name##_##type_name(const void *in, void *inout, int *count, \
+                                                 struct ompi_datatype_t **dtype, \
+                                                 int device, \
+                                                 opal_accelerator_stream_t *stream, \
+                                                 struct ompi_op_base_module_1_0_0_t *module) { \
+        \
+        _Static_assert(sizeof(type) >= sizeof(int8_t) && sizeof(type) <= sizeof(int64_t), \
+                       "Unsupported Fortran integer size"); \
+        switch(sizeof(type)) { \
+            case sizeof(int8_t): \
+                ompi_op_rocm_2buff_##name##_2int8(in, inout, count, dtype, device, stream, module); \
+                break; \
+            case sizeof(int16_t): \
+                ompi_op_rocm_2buff_##name##_2int16(in, inout, count, dtype, device, stream, module); \
+                break; \
+            case sizeof(int32_t): \
+                ompi_op_rocm_2buff_##name##_2int32(in, inout, count, dtype, device, stream, module); \
+                break; \
+            case sizeof(int64_t): \
+                ompi_op_rocm_2buff_##name##_2int64(in, inout, count, dtype, device, stream, module); \
+                break; \
+        } \
+    }
+
+/*
+ * FORT_LOC_FLOAT_FUNC dispatches a Fortran real pair type used by
+ * minloc/maxloc to the 2float or 2double wrapper, selected by sizeof(type).
+ * Only sizes between sizeof(float) and sizeof(double) pass the assert.
+ */
+#define FORT_LOC_FLOAT_FUNC(name, type_name, type) \
+    static \
+    void ompi_op_rocm_2buff_##name##_##type_name(const void *in, void *inout, int *count, \
+                                                 struct ompi_datatype_t **dtype, \
+                                                 int device, \
+                                                 opal_accelerator_stream_t *stream, \
+                                                 struct ompi_op_base_module_1_0_0_t *module) { \
+        _Static_assert(sizeof(type) >= sizeof(float) && sizeof(type) <= sizeof(double), \
+                       "Unsupported Fortran float size"); \
+        switch(sizeof(type)) { \
+            case sizeof(float): \
+                ompi_op_rocm_2buff_##name##_2float(in, inout, count, dtype, device, stream, module); \
+                break; \
+            case sizeof(double): \
+                ompi_op_rocm_2buff_##name##_2double(in, inout, count, dtype, device, stream, module); \
+                break; \
+        } \
+    }
+
+
+/*************************************************************************
+ * Max
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) > (b) ? 
(a) : (b))
+/*
+ * NOTE: the FUNC wrapper defined in this file does not reference
+ * current_func; the wrappers only forward to the *_submit functions.
+ * Presumably the definitions are kept for parity with the file this one was
+ * derived from -- confirm before relying on them.
+ */
+/* C integer */
+FUNC_FUNC(max,   int8_t,   int8_t)
+FUNC_FUNC(max,  uint8_t,  uint8_t)
+FUNC_FUNC(max,  int16_t,  int16_t)
+FUNC_FUNC(max, uint16_t, uint16_t)
+FUNC_FUNC(max,  int32_t,  int32_t)
+FUNC_FUNC(max, uint32_t, uint32_t)
+FUNC_FUNC(max,  int64_t,  int64_t)
+FUNC_FUNC(max, uint64_t, uint64_t)
+FUNC_FUNC(max,  long,  long)
+FUNC_FUNC(max,  ulong,  unsigned long)
+
+/* Fortran integer */
+/* each Fortran type forwards to the C integer wrapper of the same size */
+#if OMPI_HAVE_FORTRAN_INTEGER
+FORT_INT_FUNC(max, fortran_integer, ompi_fortran_integer_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER1
+FORT_INT_FUNC(max, fortran_integer1, ompi_fortran_integer1_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER2
+FORT_INT_FUNC(max, fortran_integer2, ompi_fortran_integer2_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER4
+FORT_INT_FUNC(max, fortran_integer4, ompi_fortran_integer4_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER8
+FORT_INT_FUNC(max, fortran_integer8, ompi_fortran_integer8_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER16
+FORT_INT_FUNC(max, fortran_integer16, ompi_fortran_integer16_t)
+#endif
+
+/* Floating point; Fortran reals forward to float/double/long double by size */
+FUNC_FUNC(max, float, float)
+FUNC_FUNC(max, double, double)
+FUNC_FUNC(max, long_double, long double)
+#if OMPI_HAVE_FORTRAN_REAL
+FORT_FLOAT_FUNC(max, fortran_real, ompi_fortran_real_t)
+#endif
+#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION
+FORT_FLOAT_FUNC(max, fortran_double_precision, ompi_fortran_double_precision_t)
+#endif
+#if OMPI_HAVE_FORTRAN_REAL2
+FORT_FLOAT_FUNC(max, fortran_real2, ompi_fortran_real2_t)
+#endif
+#if OMPI_HAVE_FORTRAN_REAL4
+FORT_FLOAT_FUNC(max, fortran_real4, ompi_fortran_real4_t)
+#endif
+#if OMPI_HAVE_FORTRAN_REAL8
+FORT_FLOAT_FUNC(max, fortran_real8, ompi_fortran_real8_t)
+#endif
+#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C
+FORT_FLOAT_FUNC(max, fortran_real16, ompi_fortran_real16_t)
+#endif
+
+
+/*************************************************************************
+ * Min
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) < (b) ? (a) : (b))
+/* C integer */
+FUNC_FUNC(min,   int8_t,   int8_t)
+FUNC_FUNC(min,  uint8_t,  uint8_t)
+FUNC_FUNC(min,  int16_t,  int16_t)
+FUNC_FUNC(min, uint16_t, uint16_t)
+FUNC_FUNC(min,  int32_t,  int32_t)
+FUNC_FUNC(min, uint32_t, uint32_t)
+FUNC_FUNC(min,  int64_t,  int64_t)
+FUNC_FUNC(min, uint64_t, uint64_t)
+FUNC_FUNC(min,  long,  long)
+FUNC_FUNC(min,  ulong,  unsigned long)
+
+/* Fortran integer */
+#if OMPI_HAVE_FORTRAN_INTEGER
+FORT_INT_FUNC(min, fortran_integer, ompi_fortran_integer_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER1
+FORT_INT_FUNC(min, fortran_integer1, ompi_fortran_integer1_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER2
+FORT_INT_FUNC(min, fortran_integer2, ompi_fortran_integer2_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER4
+FORT_INT_FUNC(min, fortran_integer4, ompi_fortran_integer4_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER8
+FORT_INT_FUNC(min, fortran_integer8, ompi_fortran_integer8_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER16
+FORT_INT_FUNC(min, fortran_integer16, ompi_fortran_integer16_t)
+#endif
+
+/* Floating point */
+FUNC_FUNC(min, float, float)
+FUNC_FUNC(min, double, double)
+FUNC_FUNC(min, long_double, long double)
+#if OMPI_HAVE_FORTRAN_REAL
+FORT_FLOAT_FUNC(min, fortran_real, ompi_fortran_real_t)
+#endif
+#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION
+FORT_FLOAT_FUNC(min, fortran_double_precision, ompi_fortran_double_precision_t)
+#endif
+#if OMPI_HAVE_FORTRAN_REAL2
+FORT_FLOAT_FUNC(min, fortran_real2, ompi_fortran_real2_t)
+#endif
+#if OMPI_HAVE_FORTRAN_REAL4
+FORT_FLOAT_FUNC(min, fortran_real4, ompi_fortran_real4_t)
+#endif
+#if OMPI_HAVE_FORTRAN_REAL8
+FORT_FLOAT_FUNC(min, fortran_real8, ompi_fortran_real8_t)
+#endif
+#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C
+FORT_FLOAT_FUNC(min, fortran_real16, ompi_fortran_real16_t)
+#endif
+
+/*************************************************************************
+ * Sum
+ *************************************************************************/
+
+/* OP_FUNC expands to the same FUNC wrapper as FUNC_FUNC */
+/* C integer */
+OP_FUNC(sum,   int8_t,   int8_t)
+OP_FUNC(sum,  uint8_t,  uint8_t)
+OP_FUNC(sum,  int16_t,  int16_t)
+OP_FUNC(sum, uint16_t, uint16_t)
+OP_FUNC(sum,  int32_t,  int32_t)
+OP_FUNC(sum, uint32_t, uint32_t)
+OP_FUNC(sum,  int64_t,  int64_t)
+OP_FUNC(sum, uint64_t, uint64_t)
+OP_FUNC(sum,  long,  long)
+OP_FUNC(sum,  ulong,  unsigned long)
+
+/* Fortran integer */
+#if OMPI_HAVE_FORTRAN_INTEGER
+FORT_INT_FUNC(sum, fortran_integer, ompi_fortran_integer_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER1
+FORT_INT_FUNC(sum, fortran_integer1, ompi_fortran_integer1_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER2
+FORT_INT_FUNC(sum, fortran_integer2, ompi_fortran_integer2_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER4
+FORT_INT_FUNC(sum, fortran_integer4, ompi_fortran_integer4_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER8
+FORT_INT_FUNC(sum, fortran_integer8, ompi_fortran_integer8_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER16
+FORT_INT_FUNC(sum, fortran_integer16, ompi_fortran_integer16_t)
+#endif
+
+/* Floating point */
+OP_FUNC(sum, float, float)
+OP_FUNC(sum, double, double)
+OP_FUNC(sum, long_double, long double)
+#if OMPI_HAVE_FORTRAN_REAL
+FORT_FLOAT_FUNC(sum, fortran_real, ompi_fortran_real_t)
+#endif
+#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION
+FORT_FLOAT_FUNC(sum, fortran_double_precision, ompi_fortran_double_precision_t)
+#endif
+#if OMPI_HAVE_FORTRAN_REAL2
+FORT_FLOAT_FUNC(sum, fortran_real2, ompi_fortran_real2_t)
+#endif
+#if OMPI_HAVE_FORTRAN_REAL4
+FORT_FLOAT_FUNC(sum, fortran_real4, ompi_fortran_real4_t)
+#endif
+#if OMPI_HAVE_FORTRAN_REAL8
+FORT_FLOAT_FUNC(sum, fortran_real8, ompi_fortran_real8_t)
+#endif
+#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C
+FORT_FLOAT_FUNC(sum, fortran_real16, ompi_fortran_real16_t)
+#endif
+/* Complex */
+/* The short-float and long-double complex variants are compiled out.
+ * COMPLEX_SUM_FUNC is not defined in the visible part of this file. */
+#if 0
+#if defined(HAVE_SHORT_FLOAT__COMPLEX)
+OP_FUNC(sum, c_short_float_complex, short float _Complex)
+#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T)
+COMPLEX_SUM_FUNC(c_short_float_complex, opal_short_float_t)
+#endif
+OP_FUNC(sum, c_long_double_complex, long double _Complex)
+#endif // 0
+/* float and double complex use the HIP complex types */
+#undef current_func
+#define current_func(a, b) (hipCaddf(a,b))
+FUNC_FUNC(sum, c_float_complex, hipFloatComplex)
+#undef current_func
+#define current_func(a, b) (hipCadd(a,b))
+FUNC_FUNC(sum, c_double_complex, hipDoubleComplex)
+
+/*************************************************************************
+ * Product
+ *************************************************************************/
+
+/* C integer */
+OP_FUNC(prod,   int8_t,   int8_t)
+OP_FUNC(prod,  uint8_t,  uint8_t)
+OP_FUNC(prod,  int16_t,  int16_t)
+OP_FUNC(prod, uint16_t, uint16_t)
+OP_FUNC(prod,  int32_t,  int32_t)
+OP_FUNC(prod, uint32_t, uint32_t)
+OP_FUNC(prod,  int64_t,  int64_t)
+OP_FUNC(prod, uint64_t, uint64_t)
+OP_FUNC(prod,  long,  long)
+OP_FUNC(prod,  ulong,  unsigned long)
+
+/* Fortran integer */
+#if OMPI_HAVE_FORTRAN_INTEGER
+FORT_INT_FUNC(prod, fortran_integer, ompi_fortran_integer_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER1
+FORT_INT_FUNC(prod, fortran_integer1, ompi_fortran_integer1_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER2
+FORT_INT_FUNC(prod, fortran_integer2, ompi_fortran_integer2_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER4
+FORT_INT_FUNC(prod, fortran_integer4, ompi_fortran_integer4_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER8
+FORT_INT_FUNC(prod, fortran_integer8, ompi_fortran_integer8_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER16
+FORT_INT_FUNC(prod, fortran_integer16, ompi_fortran_integer16_t)
+#endif
+/* Floating point */
+
+OP_FUNC(prod, float, float)
+OP_FUNC(prod, double, double)
+OP_FUNC(prod, long_double, long double)
+#if OMPI_HAVE_FORTRAN_REAL
+FORT_FLOAT_FUNC(prod, fortran_real, ompi_fortran_real_t)
+#endif
+#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION
+FORT_FLOAT_FUNC(prod, fortran_double_precision, ompi_fortran_double_precision_t)
+#endif
+#if OMPI_HAVE_FORTRAN_REAL2
+FORT_FLOAT_FUNC(prod, fortran_real2, ompi_fortran_real2_t)
+#endif
+#if OMPI_HAVE_FORTRAN_REAL4
+FORT_FLOAT_FUNC(prod, fortran_real4, ompi_fortran_real4_t)
+#endif
+#if OMPI_HAVE_FORTRAN_REAL8
+FORT_FLOAT_FUNC(prod, fortran_real8, ompi_fortran_real8_t)
+#endif
+#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C
+FORT_FLOAT_FUNC(prod, fortran_real16, ompi_fortran_real16_t)
+#endif
+
+/* Complex */
+/* Compiled out, like the corresponding sum variants.
+ * COMPLEX_PROD_FUNC is not defined in the visible part of this file. */
+#if 0
+#if defined(HAVE_SHORT_FLOAT__COMPLEX)
+OP_FUNC(prod, c_short_float_complex, short float _Complex)
+#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T)
+COMPLEX_PROD_FUNC(c_short_float_complex, opal_short_float_t)
+#endif
+OP_FUNC(prod, c_long_double_complex, long double _Complex)
+#endif // 0
+#undef current_func
+#define current_func(a, b) (hipCmulf(a,b))
+FUNC_FUNC(prod, c_float_complex, hipFloatComplex)
+#undef current_func
+#define current_func(a, b) (hipCmul(a,b))
+FUNC_FUNC(prod, c_double_complex, hipDoubleComplex)
+
+/*************************************************************************
+ * Logical AND
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) && (b))
+/* C integer */
+FUNC_FUNC(land,   int8_t,   int8_t)
+FUNC_FUNC(land,  uint8_t,  uint8_t)
+FUNC_FUNC(land,  int16_t,  int16_t)
+FUNC_FUNC(land, uint16_t, uint16_t)
+FUNC_FUNC(land,  int32_t,  int32_t)
+FUNC_FUNC(land, uint32_t, uint32_t)
+FUNC_FUNC(land,  int64_t,  int64_t)
+FUNC_FUNC(land, uint64_t, uint64_t)
+FUNC_FUNC(land,  long,  long)
+FUNC_FUNC(land,  ulong,  unsigned long)
+
+/* Logical */
+/* the Fortran logical is dispatched by size through the integer dispatcher */
+#if OMPI_HAVE_FORTRAN_LOGICAL
+FORT_INT_FUNC(land, fortran_logical, ompi_fortran_logical_t)
+#endif
+/* C++ bool */
+FUNC_FUNC(land, bool, bool)
+
+/*************************************************************************
+ * Logical OR
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) || (b))
+/* C integer */
+FUNC_FUNC(lor,   int8_t,   int8_t)
+FUNC_FUNC(lor,  uint8_t,  uint8_t)
+FUNC_FUNC(lor,  int16_t,  int16_t)
+FUNC_FUNC(lor, uint16_t, uint16_t)
+FUNC_FUNC(lor,  int32_t,  int32_t)
+FUNC_FUNC(lor, uint32_t, uint32_t)
+FUNC_FUNC(lor,  int64_t,  int64_t)
+FUNC_FUNC(lor, uint64_t, uint64_t)
+FUNC_FUNC(lor,  long,  long)
+FUNC_FUNC(lor,  ulong,  unsigned long)
+
+/* Logical */
+#if OMPI_HAVE_FORTRAN_LOGICAL
+FORT_INT_FUNC(lor, fortran_logical, ompi_fortran_logical_t)
+#endif
+/* C++ bool */
+FUNC_FUNC(lor, bool, bool)
+
+/*************************************************************************
+ * Logical XOR
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a ? 1 : 0) ^ (b ? 1: 0))
+/* C integer */
+FUNC_FUNC(lxor,   int8_t,   int8_t)
+FUNC_FUNC(lxor,  uint8_t,  uint8_t)
+FUNC_FUNC(lxor,  int16_t,  int16_t)
+FUNC_FUNC(lxor, uint16_t, uint16_t)
+FUNC_FUNC(lxor,  int32_t,  int32_t)
+FUNC_FUNC(lxor, uint32_t, uint32_t)
+FUNC_FUNC(lxor,  int64_t,  int64_t)
+FUNC_FUNC(lxor, uint64_t, uint64_t)
+FUNC_FUNC(lxor,  long,  long)
+FUNC_FUNC(lxor,  ulong,  unsigned long)
+
+
+/* Logical */
+#if OMPI_HAVE_FORTRAN_LOGICAL
+FORT_INT_FUNC(lxor, fortran_logical, ompi_fortran_logical_t)
+#endif
+/* C++ bool */
+FUNC_FUNC(lxor, bool, bool)
+
+/*************************************************************************
+ * Bitwise AND
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) & (b))
+/* C integer */
+FUNC_FUNC(band,   int8_t,   int8_t)
+FUNC_FUNC(band,  uint8_t,  uint8_t)
+FUNC_FUNC(band,  int16_t,  int16_t)
+FUNC_FUNC(band, uint16_t, uint16_t)
+FUNC_FUNC(band,  int32_t,  int32_t)
+FUNC_FUNC(band, uint32_t, uint32_t)
+FUNC_FUNC(band,  int64_t,  int64_t)
+FUNC_FUNC(band, uint64_t, uint64_t)
+FUNC_FUNC(band,  long,  long)
+FUNC_FUNC(band,  ulong,  unsigned long)
+
+/* Fortran integer */
+#if OMPI_HAVE_FORTRAN_INTEGER
+FORT_INT_FUNC(band, fortran_integer, ompi_fortran_integer_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER1
+FORT_INT_FUNC(band, fortran_integer1, ompi_fortran_integer1_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER2
+FORT_INT_FUNC(band, fortran_integer2, ompi_fortran_integer2_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER4
+FORT_INT_FUNC(band, fortran_integer4, ompi_fortran_integer4_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER8
+FORT_INT_FUNC(band, fortran_integer8, ompi_fortran_integer8_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER16
+FORT_INT_FUNC(band, fortran_integer16, ompi_fortran_integer16_t)
+#endif
+/* Byte */
+/* the byte variant is instantiated with plain char as element type */
+FUNC_FUNC(band, byte, char)
+
+/*************************************************************************
+ * Bitwise OR
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) | (b))
+/* C integer */
+FUNC_FUNC(bor,   int8_t,   int8_t)
+FUNC_FUNC(bor,  uint8_t,  uint8_t)
+FUNC_FUNC(bor,  int16_t,  int16_t)
+FUNC_FUNC(bor, uint16_t, uint16_t)
+FUNC_FUNC(bor,  int32_t,  int32_t)
+FUNC_FUNC(bor, uint32_t, uint32_t)
+FUNC_FUNC(bor,  int64_t,  int64_t)
+FUNC_FUNC(bor, uint64_t, uint64_t)
+FUNC_FUNC(bor,  long,  long)
+FUNC_FUNC(bor,  ulong,  unsigned long)
+
+/* Fortran integer */
+#if OMPI_HAVE_FORTRAN_INTEGER
+FORT_INT_FUNC(bor, fortran_integer, ompi_fortran_integer_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER1
+FORT_INT_FUNC(bor, fortran_integer1, ompi_fortran_integer1_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER2
+FORT_INT_FUNC(bor, fortran_integer2, ompi_fortran_integer2_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER4
+FORT_INT_FUNC(bor, fortran_integer4, ompi_fortran_integer4_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER8
+FORT_INT_FUNC(bor, fortran_integer8, ompi_fortran_integer8_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER16
+FORT_INT_FUNC(bor, fortran_integer16, ompi_fortran_integer16_t)
+#endif
+/* Byte */
+FUNC_FUNC(bor, byte, char)
+
+/*************************************************************************
+ * Bitwise XOR
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) ^ (b))
+/* C integer */
+FUNC_FUNC(bxor,   int8_t,   int8_t)
+FUNC_FUNC(bxor,  uint8_t,  uint8_t)
+FUNC_FUNC(bxor,  int16_t,  int16_t)
+FUNC_FUNC(bxor, uint16_t, uint16_t)
+FUNC_FUNC(bxor,  int32_t,  int32_t)
+FUNC_FUNC(bxor, uint32_t, uint32_t)
+FUNC_FUNC(bxor,  int64_t,  int64_t)
+FUNC_FUNC(bxor, uint64_t, uint64_t)
+FUNC_FUNC(bxor,  long,  long)
+FUNC_FUNC(bxor,  ulong,  unsigned long)
+
+/* Fortran integer */
+#if OMPI_HAVE_FORTRAN_INTEGER
+FORT_INT_FUNC(bxor, fortran_integer, ompi_fortran_integer_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER1
+FORT_INT_FUNC(bxor, fortran_integer1, ompi_fortran_integer1_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER2
+FORT_INT_FUNC(bxor, fortran_integer2, ompi_fortran_integer2_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER4
+FORT_INT_FUNC(bxor, fortran_integer4, ompi_fortran_integer4_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER8
+FORT_INT_FUNC(bxor, fortran_integer8, ompi_fortran_integer8_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER16
+FORT_INT_FUNC(bxor, fortran_integer16, ompi_fortran_integer16_t)
+#endif
+/* Byte */
+FUNC_FUNC(bxor, byte, char)
+
+/*************************************************************************
+ * Max location
+ *************************************************************************/
+
+/* LOC_FUNC uses ompi_op_predefined_<type_name>_t as the element type */
+LOC_FUNC(maxloc, float_int)
+LOC_FUNC(maxloc, double_int)
+LOC_FUNC(maxloc, long_int)
+LOC_FUNC(maxloc, 2int)
+LOC_FUNC(maxloc, short_int)
+LOC_FUNC(maxloc, long_double_int)
+
+/* Fortran compat types */
+/* these are the targets of the FORT_LOC_*_FUNC dispatchers below */
+LOC_FUNC(maxloc, 2float)
+LOC_FUNC(maxloc, 2double)
+LOC_FUNC(maxloc, 2int8)
+LOC_FUNC(maxloc, 2int16)
+LOC_FUNC(maxloc, 2int32)
+LOC_FUNC(maxloc, 2int64)
+
+/* Fortran integer */
+#if OMPI_HAVE_FORTRAN_INTEGER
+FORT_LOC_INT_FUNC(maxloc, 2integer, ompi_fortran_integer_t)
+#endif
+/* Fortran float */
+#if OMPI_HAVE_FORTRAN_REAL
+FORT_LOC_FLOAT_FUNC(maxloc, 2real, ompi_fortran_real_t)
+#endif
+#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION
+FORT_LOC_FLOAT_FUNC(maxloc, 2double_precision, ompi_fortran_double_precision_t)
+#endif
+
+
+/*************************************************************************
+ * Min location
+ *************************************************************************/
+
+LOC_FUNC(minloc, float_int)
+LOC_FUNC(minloc, 
double_int)
+LOC_FUNC(minloc, long_int)
+LOC_FUNC(minloc, 2int)
+LOC_FUNC(minloc, short_int)
+LOC_FUNC(minloc, long_double_int)
+
+/* Fortran compat types */
+LOC_FUNC(minloc, 2float)
+LOC_FUNC(minloc, 2double)
+LOC_FUNC(minloc, 2int8)
+LOC_FUNC(minloc, 2int16)
+LOC_FUNC(minloc, 2int32)
+LOC_FUNC(minloc, 2int64)
+
+/* Fortran integer */
+#if OMPI_HAVE_FORTRAN_INTEGER
+FORT_LOC_INT_FUNC(minloc, 2integer, ompi_fortran_integer_t)
+#endif
+/* Fortran float */
+#if OMPI_HAVE_FORTRAN_REAL
+FORT_LOC_FLOAT_FUNC(minloc, 2real, ompi_fortran_real_t)
+#endif
+#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION
+FORT_LOC_FLOAT_FUNC(minloc, 2double_precision, ompi_fortran_double_precision_t)
+#endif
+
+/*
+ * This is a three buffer (2 input and 1 output) version of the reduction
+ * routines, needed for some optimizations.
+ *
+ * FUNC_3BUF(name, type_name, type) defines the static wrapper
+ * ompi_op_rocm_3buff_<name>_<type_name>. The wrapper stages in1, in2 and out
+ * with device_op_pre, calls ompi_op_rocm_3buff_<name>_<type_name>_submit
+ * with the prepared pointers, the element count, the launch limits and the
+ * HIP stream taken from stream->stream, and finishes with device_op_post.
+ * Unlike the 2-buffer FUNC macro, no prototype carrying
+ * __opal_attribute_unused__ is emitted.
+ */
+#define FUNC_3BUF(name, type_name, type) \
+    static \
+    void ompi_op_rocm_3buff_##name##_##type_name(const void *in1, const void *in2, void *out, int *count, \
+                                                 struct ompi_datatype_t **dtype, \
+                                                 int device, \
+                                                 opal_accelerator_stream_t *stream, \
+                                                 struct ompi_op_base_module_1_0_0_t *module) { \
+        int threads_per_block, max_blocks; \
+        int source1_device, source2_device, target_device; \
+        type *source1, *source2, *target; \
+        int n = *count; \
+        device_op_pre(in1, (void**)&source1, &source1_device, \
+                      in2, (void**)&source2, &source2_device, \
+                      out, (void**)&target, &target_device, \
+                      n, *dtype, \
+                      &threads_per_block, &max_blocks, &device, stream); \
+        hipStream_t *hipstream = (hipStream_t*)stream->stream; \
+        ompi_op_rocm_3buff_##name##_##type_name##_submit(source1, source2, target, n, threads_per_block, max_blocks, *hipstream);\
+        device_op_post(source1, source1_device, source2, source2_device, out, target, target_device, n, *dtype, device, stream);\
+    }
+
+
+/* the variadic tail is pasted in front of type_name/type; it is used to
+ * keep type names such as int32_t from being macro-expanded */
+#define OP_FUNC_3BUF(name, type_name, type, ...) FUNC_3BUF(name, __VA_ARGS__##type_name, __VA_ARGS__##type)
+
+/* reuse the macro above, no work is actually done so we don't care about the func */
+#define FUNC_FUNC_3BUF(name, type_name, type, ...) FUNC_3BUF(name, __VA_ARGS__##type_name, __VA_ARGS__##type)
+
+/*
+ * Since all the functions in this file are essentially identical, we
+ * use a macro to substitute in names and types.  The core operation
+ * in all functions that use this macro is the same.
+ *
+ * This macro is for minloc and maxloc
+ */
+#define LOC_FUNC_3BUF(name, type_name) FUNC_3BUF(name, type_name, ompi_op_predefined_##type_name##_t)
+
+
+/* Dispatch Fortran types to C types */
+/*
+ * Switches on sizeof(type) and forwards all arguments unchanged to the
+ * 3-buffer int8_t/int16_t/int32_t/int64_t wrapper of the same size. Types
+ * outside the int8_t..int64_t size range are rejected by the _Static_assert.
+ */
+#define FORT_INT_FUNC_3BUF(name, type_name, type) \
+    static \
+    void ompi_op_rocm_3buff_##name##_##type_name(const void *in1, const void *in2, void *out, int *count, \
+                                                 struct ompi_datatype_t **dtype, \
+                                                 int device, \
+                                                 opal_accelerator_stream_t *stream, \
+                                                 struct ompi_op_base_module_1_0_0_t *module) { \
+        \
+        _Static_assert(sizeof(type) >= sizeof(int8_t) && sizeof(type) <= sizeof(int64_t), "Unsupported integer size (<1B, >8B)"); \
+        switch(sizeof(type)) { \
+            case sizeof(int8_t): \
+                ompi_op_rocm_3buff_##name##_int8_t(in1, in2, out, count, dtype, device, stream, module); \
+                break; \
+            case sizeof(int16_t): \
+                ompi_op_rocm_3buff_##name##_int16_t(in1, in2, out, count, dtype, device, stream, module); \
+                break; \
+            case sizeof(int32_t): \
+                ompi_op_rocm_3buff_##name##_int32_t(in1, in2, out, count, dtype, device, stream, module); \
+                break; \
+            case sizeof(int64_t): \
+                ompi_op_rocm_3buff_##name##_int64_t(in1, in2, out, count, dtype, device, stream, module); \
+                break; \
+        } \
+    }
+
+/* Dispatch Fortran types to C types */
+/*
+ * Floating-point counterpart: forwards to the 3-buffer float, double or
+ * long_double wrapper selected by sizeof(type).
+ * NOTE(review): sizeof(double) and sizeof(long double) are both used as
+ * case labels; if they are equal on a platform the labels are duplicates
+ * -- confirm the supported platforms.
+ */
+#define FORT_FLOAT_FUNC_3BUF(name, type_name, type) \
+    static \
+    void ompi_op_rocm_3buff_##name##_##type_name(const void *in1, const void *in2, void *out, int *count, \
+                                                 struct ompi_datatype_t **dtype, \
+                                                 int device, \
+                                                 opal_accelerator_stream_t *stream, \
+                                                 struct ompi_op_base_module_1_0_0_t *module) { \
+        _Static_assert(sizeof(type) >= sizeof(float) && sizeof(type) <= sizeof(long double), "Unsupported float size (<4B, >8B)"); \
+        switch(sizeof(type)) { \
+            case sizeof(float): \
+                ompi_op_rocm_3buff_##name##_float(in1, in2, out, count, dtype, device, stream, module); \
+                break; \
+            case sizeof(double): \
+                ompi_op_rocm_3buff_##name##_double(in1, in2, out, count, dtype, device, stream, module); \
+                break; \
+            case sizeof(long double): \
+                ompi_op_rocm_3buff_##name##_long_double(in1, in2, out, count, dtype, device, stream, module); \
+                break; \
+        } \
+    }
+
+/*
+ * Dispatches a Fortran integer pair type (minloc/maxloc) to the 3-buffer
+ * 2int8/2int16/2int32/2int64 wrapper of the matching element size. The
+ * third buffer is named inout here but is passed on in the position the
+ * other 3-buffer macros call out.
+ */
+#define FORT_LOC_INT_FUNC_3BUF(name, type_name, type) \
+    static \
+    void ompi_op_rocm_3buff_##name##_##type_name(const void *in1, const void *in2, void *inout, int *count, \
+                                                 struct ompi_datatype_t **dtype, \
+                                                 int device, \
+                                                 opal_accelerator_stream_t *stream, \
+                                                 struct ompi_op_base_module_1_0_0_t *module) { \
+        \
+        _Static_assert(sizeof(type) >= sizeof(int8_t) && sizeof(type) <= sizeof(int64_t), \
+                       "Unsupported Fortran integer size"); \
+        switch(sizeof(type)) { \
+            case sizeof(int8_t): \
+                ompi_op_rocm_3buff_##name##_2int8(in1, in2, inout, count, dtype, device, stream, module); \
+                break; \
+            case sizeof(int16_t): \
+                ompi_op_rocm_3buff_##name##_2int16(in1, in2, inout, count, dtype, device, stream, module); \
+                break; \
+            case sizeof(int32_t): \
+                ompi_op_rocm_3buff_##name##_2int32(in1, in2, inout, count, dtype, device, stream, module); \
+                break; \
+            case sizeof(int64_t): \
+                ompi_op_rocm_3buff_##name##_2int64(in1, in2, inout, count, dtype, device, stream, module); \
+                break; \
+        } \
+    }
+
+/*
+ * Dispatches a Fortran real pair type (minloc/maxloc) to the 3-buffer
+ * 2float or 2double wrapper, selected by sizeof(type).
+ */
+#define FORT_LOC_FLOAT_FUNC_3BUF(name, type_name, type) \
+    static \
+    void ompi_op_rocm_3buff_##name##_##type_name(const void *in1, const void *in2, void *inout, int *count, \
+                                                 struct ompi_datatype_t **dtype, \
+                                                 int device, \
+                                                 opal_accelerator_stream_t *stream, \
+                                                 struct ompi_op_base_module_1_0_0_t *module) { \
+        _Static_assert(sizeof(type) >= sizeof(float) && sizeof(type) <= sizeof(double), \
+                       "Unsupported Fortran float size"); \
+        switch(sizeof(type)) { \
+            case sizeof(float): \
+                ompi_op_rocm_3buff_##name##_2float(in1, in2, inout, count, dtype, device, stream, module); \
+                break; \
+            case sizeof(double): \
+                ompi_op_rocm_3buff_##name##_2double(in1, in2, inout, count, dtype, device, stream, module); \
+                break; \
+        } \
+    }
+
+
+/*************************************************************************
+ * Max
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) > (b) ? (a) : (b))
+/* C integer */
+FUNC_FUNC_3BUF(max,   int8_t,   int8_t)
+FUNC_FUNC_3BUF(max,  uint8_t,  uint8_t)
+FUNC_FUNC_3BUF(max,  int16_t,  int16_t)
+FUNC_FUNC_3BUF(max, uint16_t, uint16_t)
+FUNC_FUNC_3BUF(max,  int32_t,  int32_t)
+FUNC_FUNC_3BUF(max, uint32_t, uint32_t)
+FUNC_FUNC_3BUF(max,  int64_t,  int64_t)
+FUNC_FUNC_3BUF(max, uint64_t, uint64_t)
+FUNC_FUNC_3BUF(max,  long,  long)
+FUNC_FUNC_3BUF(max,  ulong,  unsigned long)
+
+/* Fortran integer */
+/* NOTE(review): there is no fortran_integer16 instantiation in this list,
+ * although the 2-buffer max list has one -- confirm whether that is
+ * intentional. */
+#if OMPI_HAVE_FORTRAN_INTEGER
+FORT_INT_FUNC_3BUF(max, fortran_integer, ompi_fortran_integer_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER1
+FORT_INT_FUNC_3BUF(max, fortran_integer1, ompi_fortran_integer1_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER2
+FORT_INT_FUNC_3BUF(max, fortran_integer2, ompi_fortran_integer2_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER4
+FORT_INT_FUNC_3BUF(max, fortran_integer4, ompi_fortran_integer4_t)
+#endif
+#if OMPI_HAVE_FORTRAN_INTEGER8
+FORT_INT_FUNC_3BUF(max, fortran_integer8, ompi_fortran_integer8_t)
+#endif
+/* Floating point */
+/* the short-float variants are compiled out */
+#if 0
+#if defined(HAVE_SHORT_FLOAT)
+FUNC_FUNC_3BUF(max, short_float, short float)
+#elif defined(HAVE_OPAL_SHORT_FLOAT_T)
+FUNC_FUNC_3BUF(max, short_float, opal_short_float_t)
+#endif
+#endif // 0
+FUNC_FUNC_3BUF(max, float, float)
+FUNC_FUNC_3BUF(max, double, double)
+FUNC_FUNC_3BUF(max, long_double, long double)
+#if OMPI_HAVE_FORTRAN_REAL
+FORT_FLOAT_FUNC_3BUF(max, fortran_real, ompi_fortran_real_t)
+#endif
+#if 
OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC_3BUF(max, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC_3BUF(max, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC_3BUF(max, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC_3BUF(max, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC_3BUF(max, fortran_real16, ompi_fortran_real16_t) +#endif + + +/************************************************************************* + * Min + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) < (b) ? (a) : (b)) +/* C integer */ +FUNC_FUNC_3BUF(min, int8_t, int8_t) +FUNC_FUNC_3BUF(min, uint8_t, uint8_t) +FUNC_FUNC_3BUF(min, int16_t, int16_t) +FUNC_FUNC_3BUF(min, uint16_t, uint16_t) +FUNC_FUNC_3BUF(min, int32_t, int32_t) +FUNC_FUNC_3BUF(min, uint32_t, uint32_t) +FUNC_FUNC_3BUF(min, int64_t, int64_t) +FUNC_FUNC_3BUF(min, uint64_t, uint64_t) +FUNC_FUNC_3BUF(min, long, long) +FUNC_FUNC_3BUF(min, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(min, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(min, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(min, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(min, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(min, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(min, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Floating point */ +#if 0 +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC_3BUF(min, short_float, short float) +#elif 
defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_3BUF(min, short_float, opal_short_float_t) +#endif +#endif // 0 +FUNC_FUNC_3BUF(min, float, float) +FUNC_FUNC_3BUF(min, double, double) +FUNC_FUNC_3BUF(min, long_double, long double) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC_3BUF(min, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC_3BUF(min, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC_3BUF(min, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC_3BUF(min, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC_3BUF(min, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC_3BUF(min, fortran_real16, ompi_fortran_real16_t) +#endif + +/************************************************************************* + * Sum + *************************************************************************/ + +/* C integer */ +OP_FUNC_3BUF(sum, int8_t, int8_t) +OP_FUNC_3BUF(sum, uint8_t, uint8_t) +OP_FUNC_3BUF(sum, int16_t, int16_t) +OP_FUNC_3BUF(sum, uint16_t, uint16_t) +OP_FUNC_3BUF(sum, int32_t, int32_t) +OP_FUNC_3BUF(sum, uint32_t, uint32_t) +OP_FUNC_3BUF(sum, int64_t, int64_t) +OP_FUNC_3BUF(sum, uint64_t, uint64_t) +OP_FUNC_3BUF(sum, long, long) +OP_FUNC_3BUF(sum, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(sum, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(sum, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(sum, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(sum, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(sum, fortran_integer8, ompi_fortran_integer8_t) 
+#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(sum, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Floating point */ +#if 0 +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC_3BUF(sum, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +OP_FUNC_3BUF(sum, short_float, opal_short_float_t) +#endif +#endif // 0 +OP_FUNC_3BUF(sum, float, float) +OP_FUNC_3BUF(sum, double, double) +OP_FUNC_3BUF(sum, long_double, long double) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC_3BUF(sum, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC_3BUF(sum, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC_3BUF(sum, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC_3BUF(sum, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC_3BUF(sum, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC_3BUF(sum, fortran_real16, ompi_fortran_real16_t) +#endif +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF(sum, c_short_float_complex, short float _Complex) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_SUM_FUNC_3BUF(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC_3BUF(sum, c_long_double_complex, long double _Complex) +#endif // 0 +#undef current_func +#define current_func(a, b) (hipCaddf(a,b)) +FUNC_FUNC_3BUF(sum, c_float_complex, hipFloatComplex) +#undef current_func +#define current_func(a, b) (hipCadd(a,b)) +FUNC_FUNC_3BUF(sum, c_double_complex, hipDoubleComplex) + +/************************************************************************* + * Product + *************************************************************************/ + +/* C integer */ +OP_FUNC_3BUF(prod, int8_t, int8_t) +OP_FUNC_3BUF(prod, uint8_t, uint8_t) +OP_FUNC_3BUF(prod, int16_t, int16_t) +OP_FUNC_3BUF(prod, 
uint16_t, uint16_t) +OP_FUNC_3BUF(prod, int32_t, int32_t) +OP_FUNC_3BUF(prod, uint32_t, uint32_t) +OP_FUNC_3BUF(prod, int64_t, int64_t) +OP_FUNC_3BUF(prod, uint64_t, uint64_t) +OP_FUNC_3BUF(prod, long, long) +OP_FUNC_3BUF(prod, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(prod, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(prod, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(prod, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(prod, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(prod, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(prod, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Floating point */ +#if 0 +#if defined(HAVE_SHORT_FLOAT) +FORT_FLOAT_FUNC_3BUF(prod, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FORT_FLOAT_FUNC_3BUF(prod, short_float, opal_short_float_t) +#endif +#endif // 0 +OP_FUNC_3BUF(prod, float, float) +OP_FUNC_3BUF(prod, double, double) +OP_FUNC_3BUF(prod, long_double, long double) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC_3BUF(prod, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC_3BUF(prod, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC_3BUF(prod, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC_3BUF(prod, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC_3BUF(prod, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC_3BUF(prod, fortran_real16, ompi_fortran_real16_t) +#endif +/* Complex */ +#if 0 +#if 
defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF(prod, c_short_float_complex, short float _Complex) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_PROD_FUNC_3BUF(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC_3BUF(prod, c_long_double_complex, long double _Complex) +#endif // 0 +#undef current_func +#define current_func(a, b) (hipCmulf(a,b)) +FUNC_FUNC_3BUF(prod, c_float_complex, hipFloatComplex) +#undef current_func +#define current_func(a, b) (hipCmul(a,b)) +FUNC_FUNC_3BUF(prod, c_double_complex, hipDoubleComplex) + +/************************************************************************* + * Logical AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) && (b)) +/* C integer */ +FUNC_FUNC_3BUF(land, int8_t, int8_t) +FUNC_FUNC_3BUF(land, uint8_t, uint8_t) +FUNC_FUNC_3BUF(land, int16_t, int16_t) +FUNC_FUNC_3BUF(land, uint16_t, uint16_t) +FUNC_FUNC_3BUF(land, int32_t, int32_t) +FUNC_FUNC_3BUF(land, uint32_t, uint32_t) +FUNC_FUNC_3BUF(land, int64_t, int64_t) +FUNC_FUNC_3BUF(land, uint64_t, uint64_t) +FUNC_FUNC_3BUF(land, long, long) +FUNC_FUNC_3BUF(land, ulong, unsigned long) + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FORT_INT_FUNC_3BUF(land, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC_3BUF(land, bool, bool) + +/************************************************************************* + * Logical OR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) || (b)) +/* C integer */ +FUNC_FUNC_3BUF(lor, int8_t, int8_t) +FUNC_FUNC_3BUF(lor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(lor, int16_t, int16_t) +FUNC_FUNC_3BUF(lor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(lor, int32_t, int32_t) +FUNC_FUNC_3BUF(lor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(lor, int64_t, int64_t) +FUNC_FUNC_3BUF(lor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(lor, long, long) +FUNC_FUNC_3BUF(lor, ulong, 
unsigned long) + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FORT_INT_FUNC_3BUF(lor, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC_3BUF(lor, bool, bool) + +/************************************************************************* + * Logical XOR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a ? 1 : 0) ^ (b ? 1: 0)) +/* C integer */ +FUNC_FUNC_3BUF(lxor, int8_t, int8_t) +FUNC_FUNC_3BUF(lxor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(lxor, int16_t, int16_t) +FUNC_FUNC_3BUF(lxor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(lxor, int32_t, int32_t) +FUNC_FUNC_3BUF(lxor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(lxor, int64_t, int64_t) +FUNC_FUNC_3BUF(lxor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(lxor, long, long) +FUNC_FUNC_3BUF(lxor, ulong, unsigned long) + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FORT_INT_FUNC_3BUF(lxor, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC_3BUF(lxor, bool, bool) + +/************************************************************************* + * Bitwise AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) & (b)) +/* C integer */ +FUNC_FUNC_3BUF(band, int8_t, int8_t) +FUNC_FUNC_3BUF(band, uint8_t, uint8_t) +FUNC_FUNC_3BUF(band, int16_t, int16_t) +FUNC_FUNC_3BUF(band, uint16_t, uint16_t) +FUNC_FUNC_3BUF(band, int32_t, int32_t) +FUNC_FUNC_3BUF(band, uint32_t, uint32_t) +FUNC_FUNC_3BUF(band, int64_t, int64_t) +FUNC_FUNC_3BUF(band, uint64_t, uint64_t) +FUNC_FUNC_3BUF(band, long, long) +FUNC_FUNC_3BUF(band, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(band, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(band, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(band, fortran_integer2, 
ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(band, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(band, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(band, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FORT_INT_FUNC_3BUF(band, byte, char) + +/************************************************************************* + * Bitwise OR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) | (b)) +/* C integer */ +FUNC_FUNC_3BUF(bor, int8_t, int8_t) +FUNC_FUNC_3BUF(bor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(bor, int16_t, int16_t) +FUNC_FUNC_3BUF(bor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(bor, int32_t, int32_t) +FUNC_FUNC_3BUF(bor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(bor, int64_t, int64_t) +FUNC_FUNC_3BUF(bor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(bor, long, long) +FUNC_FUNC_3BUF(bor, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(bor, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(bor, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(bor, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(bor, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(bor, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(bor, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FORT_INT_FUNC_3BUF(bor, byte, char) + +/************************************************************************* + * Bitwise XOR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) 
((a) ^ (b)) +/* C integer */ +FUNC_FUNC_3BUF(bxor, int8_t, int8_t) +FUNC_FUNC_3BUF(bxor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(bxor, int16_t, int16_t) +FUNC_FUNC_3BUF(bxor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(bxor, int32_t, int32_t) +FUNC_FUNC_3BUF(bxor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(bxor, int64_t, int64_t) +FUNC_FUNC_3BUF(bxor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(bxor, long, long) +FUNC_FUNC_3BUF(bxor, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(bxor, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(bxor, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(bxor, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(bxor, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(bxor, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(bxor, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FORT_INT_FUNC_3BUF(bxor, byte, char) + +/************************************************************************* + * Max location + *************************************************************************/ + +LOC_FUNC_3BUF(maxloc, float_int) +LOC_FUNC_3BUF(maxloc, double_int) +LOC_FUNC_3BUF(maxloc, long_int) +LOC_FUNC_3BUF(maxloc, 2int) +LOC_FUNC_3BUF(maxloc, short_int) +LOC_FUNC_3BUF(maxloc, long_double_int) + +/* Fortran compat types */ +LOC_FUNC_3BUF(maxloc, 2float) +LOC_FUNC_3BUF(maxloc, 2double) +LOC_FUNC_3BUF(maxloc, 2int8) +LOC_FUNC_3BUF(maxloc, 2int16) +LOC_FUNC_3BUF(maxloc, 2int32) +LOC_FUNC_3BUF(maxloc, 2int64) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_LOC_INT_FUNC_3BUF(maxloc, 2integer, ompi_fortran_integer_t) +#endif +/* Fortran float */ +#if OMPI_HAVE_FORTRAN_REAL +FORT_LOC_FLOAT_FUNC_3BUF(maxloc, 2real, ompi_fortran_real_t) 
+#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_LOC_FLOAT_FUNC_3BUF(maxloc, 2double_precision, ompi_fortran_double_precision_t) +#endif + +/************************************************************************* + * Min location + *************************************************************************/ + +LOC_FUNC_3BUF(minloc, float_int) +LOC_FUNC_3BUF(minloc, double_int) +LOC_FUNC_3BUF(minloc, long_int) +LOC_FUNC_3BUF(minloc, 2int) +LOC_FUNC_3BUF(minloc, short_int) +LOC_FUNC_3BUF(minloc, long_double_int) + +/* Fortran compat types */ +LOC_FUNC_3BUF(minloc, 2float) +LOC_FUNC_3BUF(minloc, 2double) +LOC_FUNC_3BUF(minloc, 2int8) +LOC_FUNC_3BUF(minloc, 2int16) +LOC_FUNC_3BUF(minloc, 2int32) +LOC_FUNC_3BUF(minloc, 2int64) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_LOC_INT_FUNC_3BUF(minloc, 2integer, ompi_fortran_integer_t) +#endif +/* Fortran float */ +#if OMPI_HAVE_FORTRAN_REAL +FORT_LOC_FLOAT_FUNC_3BUF(minloc, 2real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_LOC_FLOAT_FUNC_3BUF(minloc, 2double_precision, ompi_fortran_double_precision_t) +#endif + + +/* + * Helpful defines, because there's soooo many names! + * + * **NOTE** These #define's used to be strictly ordered but the use of + * designated initializers removed this restriction. When adding new + * operators ALWAYS use a designated initializer!
+ */ + +/** C integer ***********************************************************/ +#define C_INTEGER(name, ftype) \ + [OMPI_OP_BASE_TYPE_INT8_T] = ompi_op_rocm_##ftype##_##name##_int8_t, \ + [OMPI_OP_BASE_TYPE_UINT8_T] = ompi_op_rocm_##ftype##_##name##_uint8_t, \ + [OMPI_OP_BASE_TYPE_INT16_T] = ompi_op_rocm_##ftype##_##name##_int16_t, \ + [OMPI_OP_BASE_TYPE_UINT16_T] = ompi_op_rocm_##ftype##_##name##_uint16_t, \ + [OMPI_OP_BASE_TYPE_INT32_T] = ompi_op_rocm_##ftype##_##name##_int32_t, \ + [OMPI_OP_BASE_TYPE_UINT32_T] = ompi_op_rocm_##ftype##_##name##_uint32_t, \ + [OMPI_OP_BASE_TYPE_INT64_T] = ompi_op_rocm_##ftype##_##name##_int64_t, \ + [OMPI_OP_BASE_TYPE_LONG] = ompi_op_rocm_##ftype##_##name##_long, \ + [OMPI_OP_BASE_TYPE_UNSIGNED_LONG] = ompi_op_rocm_##ftype##_##name##_ulong, \ + [OMPI_OP_BASE_TYPE_UINT64_T] = ompi_op_rocm_##ftype##_##name##_uint64_t + +/** All the Fortran integers ********************************************/ + +#if OMPI_HAVE_FORTRAN_INTEGER +#define FORTRAN_INTEGER_PLAIN(name, ftype) ompi_op_rocm_##ftype##_##name##_fortran_integer +#else +#define FORTRAN_INTEGER_PLAIN(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +#define FORTRAN_INTEGER1(name, ftype) ompi_op_rocm_##ftype##_##name##_fortran_integer1 +#else +#define FORTRAN_INTEGER1(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +#define FORTRAN_INTEGER2(name, ftype) ompi_op_rocm_##ftype##_##name##_fortran_integer2 +#else +#define FORTRAN_INTEGER2(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +#define FORTRAN_INTEGER4(name, ftype) ompi_op_rocm_##ftype##_##name##_fortran_integer4 +#else +#define FORTRAN_INTEGER4(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +#define FORTRAN_INTEGER8(name, ftype) ompi_op_rocm_##ftype##_##name##_fortran_integer8 +#else +#define FORTRAN_INTEGER8(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +#define FORTRAN_INTEGER16(name, ftype) ompi_op_rocm_##ftype##_##name##_fortran_integer16 +#else +#define 
FORTRAN_INTEGER16(name, ftype) NULL +#endif + +#define FORTRAN_INTEGER(name, ftype) \ + [OMPI_OP_BASE_TYPE_INTEGER] = FORTRAN_INTEGER_PLAIN(name, ftype), \ + [OMPI_OP_BASE_TYPE_INTEGER1] = FORTRAN_INTEGER1(name, ftype), \ + [OMPI_OP_BASE_TYPE_INTEGER2] = FORTRAN_INTEGER2(name, ftype), \ + [OMPI_OP_BASE_TYPE_INTEGER4] = FORTRAN_INTEGER4(name, ftype), \ + [OMPI_OP_BASE_TYPE_INTEGER8] = FORTRAN_INTEGER8(name, ftype), \ + [OMPI_OP_BASE_TYPE_INTEGER16] = FORTRAN_INTEGER16(name, ftype) + +/** All the Fortran reals ***********************************************/ + +#if OMPI_HAVE_FORTRAN_REAL +#define FLOATING_POINT_FORTRAN_REAL_PLAIN(name, ftype) ompi_op_rocm_##ftype##_##name##_fortran_real +#else +#define FLOATING_POINT_FORTRAN_REAL_PLAIN(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +#define FLOATING_POINT_FORTRAN_REAL2(name, ftype) ompi_op_rocm_##ftype##_##name##_fortran_real2 +#else +#define FLOATING_POINT_FORTRAN_REAL2(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +#define FLOATING_POINT_FORTRAN_REAL4(name, ftype) ompi_op_rocm_##ftype##_##name##_fortran_real4 +#else +#define FLOATING_POINT_FORTRAN_REAL4(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +#define FLOATING_POINT_FORTRAN_REAL8(name, ftype) ompi_op_rocm_##ftype##_##name##_fortran_real8 +#else +#define FLOATING_POINT_FORTRAN_REAL8(name, ftype) NULL +#endif +/* If: + - we have fortran REAL*16, *and* + - fortran REAL*16 matches the bit representation of the + corresponding C type + Only then do we put in function pointers for REAL*16 reductions. + Otherwise, just put in NULL.
*/ +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +#define FLOATING_POINT_FORTRAN_REAL16(name, ftype) ompi_op_rocm_##ftype##_##name##_fortran_real16 +#else +#define FLOATING_POINT_FORTRAN_REAL16(name, ftype) NULL +#endif + +#define FLOATING_POINT_FORTRAN_REAL(name, ftype) \ + [OMPI_OP_BASE_TYPE_REAL] = FLOATING_POINT_FORTRAN_REAL_PLAIN(name, ftype), \ + [OMPI_OP_BASE_TYPE_REAL2] = FLOATING_POINT_FORTRAN_REAL2(name, ftype), \ + [OMPI_OP_BASE_TYPE_REAL4] = FLOATING_POINT_FORTRAN_REAL4(name, ftype), \ + [OMPI_OP_BASE_TYPE_REAL8] = FLOATING_POINT_FORTRAN_REAL8(name, ftype), \ + [OMPI_OP_BASE_TYPE_REAL16] = FLOATING_POINT_FORTRAN_REAL16(name, ftype) + +/** Fortran double precision ********************************************/ + +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +#define FLOATING_POINT_FORTRAN_DOUBLE_PRECISION(name, ftype) \ + ompi_op_rocm_##ftype##_##name##_fortran_double_precision +#else +#define FLOATING_POINT_FORTRAN_DOUBLE_PRECISION(name, ftype) NULL +#endif + +/** Floating point, including all the Fortran reals *********************/ + +//#if defined(HAVE_SHORT_FLOAT) || defined(HAVE_OPAL_SHORT_FLOAT_T) +//#define SHORT_FLOAT(name, ftype) ompi_op_rocm_##ftype##_##name##_short_float +//#else +#define SHORT_FLOAT(name, ftype) NULL +//#endif +#define FLOAT(name, ftype) ompi_op_rocm_##ftype##_##name##_float +#define DOUBLE(name, ftype) ompi_op_rocm_##ftype##_##name##_double +#define LONG_DOUBLE(name, ftype) ompi_op_rocm_##ftype##_##name##_long_double + +#define FLOATING_POINT(name, ftype) \ + [OMPI_OP_BASE_TYPE_SHORT_FLOAT] = SHORT_FLOAT(name, ftype), \ + [OMPI_OP_BASE_TYPE_FLOAT] = FLOAT(name, ftype), \ + [OMPI_OP_BASE_TYPE_DOUBLE] = DOUBLE(name, ftype), \ + FLOATING_POINT_FORTRAN_REAL(name, ftype), \ + [OMPI_OP_BASE_TYPE_DOUBLE_PRECISION] = FLOATING_POINT_FORTRAN_DOUBLE_PRECISION(name, ftype), \ + [OMPI_OP_BASE_TYPE_LONG_DOUBLE] = LONG_DOUBLE(name, ftype) + +/** Fortran logical *****************************************************/ + +#if 
OMPI_HAVE_FORTRAN_LOGICAL +#define FORTRAN_LOGICAL(name, ftype) \ + ompi_op_rocm_##ftype##_##name##_fortran_logical /* OMPI_OP_ROCM_TYPE_LOGICAL */ +#else +#define FORTRAN_LOGICAL(name, ftype) NULL +#endif + +#define LOGICAL(name, ftype) \ + [OMPI_OP_BASE_TYPE_LOGICAL] = FORTRAN_LOGICAL(name, ftype), \ + [OMPI_OP_BASE_TYPE_BOOL] = ompi_op_rocm_##ftype##_##name##_bool + +/** Complex *****************************************************/ +#if 0 + +#if defined(HAVE_SHORT_FLOAT__COMPLEX) || defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +#define SHORT_FLOAT_COMPLEX(name, ftype) ompi_op_rocm_##ftype##_##name##_c_short_float_complex +#else +#define SHORT_FLOAT_COMPLEX(name, ftype) NULL +#endif +#define LONG_DOUBLE_COMPLEX(name, ftype) ompi_op_rocm_##ftype##_##name##_c_long_double_complex +#else +#define SHORT_FLOAT_COMPLEX(name, ftype) NULL +#define LONG_DOUBLE_COMPLEX(name, ftype) NULL +#endif // 0 +#define FLOAT_COMPLEX(name, ftype) ompi_op_rocm_##ftype##_##name##_c_float_complex +#define DOUBLE_COMPLEX(name, ftype) ompi_op_rocm_##ftype##_##name##_c_double_complex + +#define COMPLEX(name, ftype) \ + [OMPI_OP_BASE_TYPE_C_SHORT_FLOAT_COMPLEX] = SHORT_FLOAT_COMPLEX(name, ftype), \ + [OMPI_OP_BASE_TYPE_C_FLOAT_COMPLEX] = FLOAT_COMPLEX(name, ftype), \ + [OMPI_OP_BASE_TYPE_C_DOUBLE_COMPLEX] = DOUBLE_COMPLEX(name, ftype), \ + [OMPI_OP_BASE_TYPE_C_LONG_DOUBLE_COMPLEX] = LONG_DOUBLE_COMPLEX(name, ftype) + +/** Byte ****************************************************************/ + +#define BYTE(name, ftype) \ + [OMPI_OP_BASE_TYPE_BYTE] = ompi_op_rocm_##ftype##_##name##_byte + +/** Fortran complex *****************************************************/ +/** Fortran "2" types ***************************************************/ + +#if OMPI_HAVE_FORTRAN_REAL +#define TWOLOC_FORTRAN_2REAL(name, ftype) ompi_op_rocm_##ftype##_##name##_2real +#else +#define TWOLOC_FORTRAN_2REAL(name, ftype) NULL +#endif + +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +#define
TWOLOC_FORTRAN_2DOUBLE_PRECISION(name, ftype) ompi_op_rocm_##ftype##_##name##_2double_precision +#else +#define TWOLOC_FORTRAN_2DOUBLE_PRECISION(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER +#define TWOLOC_FORTRAN_2INTEGER(name, ftype) ompi_op_rocm_##ftype##_##name##_2integer +#else +#define TWOLOC_FORTRAN_2INTEGER(name, ftype) NULL +#endif + +/** All "2" types *******************************************************/ + +#define TWOLOC(name, ftype) \ + [OMPI_OP_BASE_TYPE_2REAL] = TWOLOC_FORTRAN_2REAL(name, ftype), \ + [OMPI_OP_BASE_TYPE_2DOUBLE_PRECISION] = TWOLOC_FORTRAN_2DOUBLE_PRECISION(name, ftype), \ + [OMPI_OP_BASE_TYPE_2INTEGER] = TWOLOC_FORTRAN_2INTEGER(name, ftype), \ + [OMPI_OP_BASE_TYPE_FLOAT_INT] = ompi_op_rocm_##ftype##_##name##_float_int, \ + [OMPI_OP_BASE_TYPE_DOUBLE_INT] = ompi_op_rocm_##ftype##_##name##_double_int, \ + [OMPI_OP_BASE_TYPE_LONG_INT] = ompi_op_rocm_##ftype##_##name##_long_int, \ + [OMPI_OP_BASE_TYPE_2INT] = ompi_op_rocm_##ftype##_##name##_2int, \ + [OMPI_OP_BASE_TYPE_SHORT_INT] = ompi_op_rocm_##ftype##_##name##_short_int, \ + [OMPI_OP_BASE_TYPE_LONG_DOUBLE_INT] = ompi_op_rocm_##ftype##_##name##_long_double_int + +/* + * MPI_OP_NULL + * All types + */ +#define FLAGS_NO_FLOAT \ + (OMPI_OP_FLAGS_INTRINSIC | OMPI_OP_FLAGS_ASSOC | OMPI_OP_FLAGS_COMMUTE) +#define FLAGS \ + (OMPI_OP_FLAGS_INTRINSIC | OMPI_OP_FLAGS_ASSOC | \ + OMPI_OP_FLAGS_FLOAT_ASSOC | OMPI_OP_FLAGS_COMMUTE) + +ompi_op_base_stream_handler_fn_t ompi_op_rocm_functions[OMPI_OP_BASE_FORTRAN_OP_MAX][OMPI_OP_BASE_TYPE_MAX] = + { + /* Corresponds to MPI_OP_NULL */ + [OMPI_OP_BASE_FORTRAN_NULL] = { + /* Leaving this empty puts in NULL for all entries */ + NULL, + }, + /* Corresponds to MPI_MAX */ + [OMPI_OP_BASE_FORTRAN_MAX] = { + C_INTEGER(max, 2buff), + FORTRAN_INTEGER(max, 2buff), + FLOATING_POINT(max, 2buff), + }, + /* Corresponds to MPI_MIN */ + [OMPI_OP_BASE_FORTRAN_MIN] = { + C_INTEGER(min, 2buff), + FORTRAN_INTEGER(min, 2buff), + FLOATING_POINT(min, 2buff), + }, + /*
Corresponds to MPI_SUM */ + [OMPI_OP_BASE_FORTRAN_SUM] = { + C_INTEGER(sum, 2buff), + FORTRAN_INTEGER(sum, 2buff), + FLOATING_POINT(sum, 2buff), + COMPLEX(sum, 2buff), + }, + /* Corresponds to MPI_PROD */ + [OMPI_OP_BASE_FORTRAN_PROD] = { + C_INTEGER(prod, 2buff), + FORTRAN_INTEGER(prod, 2buff), + FLOATING_POINT(prod, 2buff), + COMPLEX(prod, 2buff), + }, + /* Corresponds to MPI_LAND */ + [OMPI_OP_BASE_FORTRAN_LAND] = { + C_INTEGER(land, 2buff), + LOGICAL(land, 2buff), + }, + /* Corresponds to MPI_BAND */ + [OMPI_OP_BASE_FORTRAN_BAND] = { + C_INTEGER(band, 2buff), + FORTRAN_INTEGER(band, 2buff), + BYTE(band, 2buff), + }, + /* Corresponds to MPI_LOR */ + [OMPI_OP_BASE_FORTRAN_LOR] = { + C_INTEGER(lor, 2buff), + LOGICAL(lor, 2buff), + }, + /* Corresponds to MPI_BOR */ + [OMPI_OP_BASE_FORTRAN_BOR] = { + C_INTEGER(bor, 2buff), + FORTRAN_INTEGER(bor, 2buff), + BYTE(bor, 2buff), + }, + /* Corresponds to MPI_LXOR */ + [OMPI_OP_BASE_FORTRAN_LXOR] = { + C_INTEGER(lxor, 2buff), + LOGICAL(lxor, 2buff), + }, + /* Corresponds to MPI_BXOR */ + [OMPI_OP_BASE_FORTRAN_BXOR] = { + C_INTEGER(bxor, 2buff), + FORTRAN_INTEGER(bxor, 2buff), + BYTE(bxor, 2buff), + }, + /* Corresponds to MPI_MAXLOC */ + [OMPI_OP_BASE_FORTRAN_MAXLOC] = { + TWOLOC(maxloc, 2buff), + }, + /* Corresponds to MPI_MINLOC */ + [OMPI_OP_BASE_FORTRAN_MINLOC] = { + TWOLOC(minloc, 2buff), + }, + /* Corresponds to MPI_REPLACE */ + [OMPI_OP_BASE_FORTRAN_REPLACE] = { + /* (MPI_ACCUMULATE is handled differently than the other + reductions, so just zero out its function + implementations here to ensure that users don't invoke + MPI_REPLACE with any reduction operations other than + ACCUMULATE) */ + NULL, + }, + + }; + +ompi_op_base_3buff_stream_handler_fn_t ompi_op_rocm_3buff_functions[OMPI_OP_BASE_FORTRAN_OP_MAX][OMPI_OP_BASE_TYPE_MAX] = + { + /* Corresponds to MPI_OP_NULL */ + [OMPI_OP_BASE_FORTRAN_NULL] = { + /* Leaving this empty puts in NULL for all entries */ + NULL, + }, + /* Corresponds to MPI_MAX */ + 
[OMPI_OP_BASE_FORTRAN_MAX] = { + C_INTEGER(max, 3buff), + FORTRAN_INTEGER(max, 3buff), + FLOATING_POINT(max, 3buff), + }, + /* Corresponds to MPI_MIN */ + [OMPI_OP_BASE_FORTRAN_MIN] = { + C_INTEGER(min, 3buff), + FORTRAN_INTEGER(min, 3buff), + FLOATING_POINT(min, 3buff), + }, + /* Corresponds to MPI_SUM */ + [OMPI_OP_BASE_FORTRAN_SUM] = { + C_INTEGER(sum, 3buff), + FORTRAN_INTEGER(sum, 3buff), + FLOATING_POINT(sum, 3buff), + COMPLEX(sum, 3buff), + }, + /* Corresponds to MPI_PROD */ + [OMPI_OP_BASE_FORTRAN_PROD] = { + C_INTEGER(prod, 3buff), + FORTRAN_INTEGER(prod, 3buff), + FLOATING_POINT(prod, 3buff), + COMPLEX(prod, 3buff), + }, + /* Corresponds to MPI_LAND */ + [OMPI_OP_BASE_FORTRAN_LAND] ={ + C_INTEGER(land, 3buff), + LOGICAL(land, 3buff), + }, + /* Corresponds to MPI_BAND */ + [OMPI_OP_BASE_FORTRAN_BAND] = { + C_INTEGER(band, 3buff), + FORTRAN_INTEGER(band, 3buff), + BYTE(band, 3buff), + }, + /* Corresponds to MPI_LOR */ + [OMPI_OP_BASE_FORTRAN_LOR] = { + C_INTEGER(lor, 3buff), + LOGICAL(lor, 3buff), + }, + /* Corresponds to MPI_BOR */ + [OMPI_OP_BASE_FORTRAN_BOR] = { + C_INTEGER(bor, 3buff), + FORTRAN_INTEGER(bor, 3buff), + BYTE(bor, 3buff), + }, + /* Corresponds to MPI_LXOR */ + [OMPI_OP_BASE_FORTRAN_LXOR] = { + C_INTEGER(lxor, 3buff), + LOGICAL(lxor, 3buff), + }, + /* Corresponds to MPI_BXOR */ + [OMPI_OP_BASE_FORTRAN_BXOR] = { + C_INTEGER(bxor, 3buff), + FORTRAN_INTEGER(bxor, 3buff), + BYTE(bxor, 3buff), + }, + /* Corresponds to MPI_MAXLOC */ + [OMPI_OP_BASE_FORTRAN_MAXLOC] = { + TWOLOC(maxloc, 3buff), + }, + /* Corresponds to MPI_MINLOC */ + [OMPI_OP_BASE_FORTRAN_MINLOC] = { + TWOLOC(minloc, 3buff), + }, + /* Corresponds to MPI_REPLACE */ + [OMPI_OP_BASE_FORTRAN_REPLACE] = { + /* MPI_ACCUMULATE is handled differently than the other + reductions, so just zero out its function + implementations here to ensure that users don't invoke + MPI_REPLACE with any reduction operations other than + ACCUMULATE */ + NULL, + }, + }; diff --git 
a/ompi/mca/op/rocm/op_rocm_impl.h b/ompi/mca/op/rocm/op_rocm_impl.h
new file mode 100644
index 00000000000..907a19fd4fa
--- /dev/null
+++ b/ompi/mca/op/rocm/op_rocm_impl.h
@@ -0,0 +1,706 @@
+/*
+ * Copyright (c) 2019-2023 The University of Tennessee and The University
+ *                         of Tennessee Research Foundation.  All rights
+ *                         reserved.
+ * Copyright (c) 2020      Research Organization for Information Science
+ *                         and Technology (RIST).  All rights reserved.
+ * $COPYRIGHT$
+ *
+ * Additional copyrights may follow
+ *
+ * $HEADER$
+ */
+
+/*
+ * Prototypes of the host-side "submit" functions that enqueue the ROCm
+ * reduction kernels on a HIP stream. The header is shared between the C
+ * component code and the HIP implementation file.
+ *
+ * LOC_STRUCT below emits typedefs of anonymous structs, which must not be
+ * seen twice in one translation unit, so protect against double inclusion.
+ */
+#pragma once
+
+/*
+ * NOTE(review): the include targets were lost when this patch was extracted.
+ * They are reconstructed from what the declarations below require (fixed-width
+ * integers, hipStream_t, hipFloatComplex/hipDoubleComplex) -- confirm against
+ * the upstream file.
+ */
+#include <stdint.h>
+
+#include <hip/hip_runtime.h>
+#include <hip/hip_complex.h>
+
+#ifndef BEGIN_C_DECLS
+#if defined(c_plusplus) || defined(__cplusplus)
+# define BEGIN_C_DECLS extern "C" {
+# define END_C_DECLS }
+#else
+# define BEGIN_C_DECLS /* empty */
+# define END_C_DECLS /* empty */
+#endif
+#endif
+
+BEGIN_C_DECLS
+
+/*
+ * Declares ompi_op_rocm_2buff_<name>_<type_name>_submit(), the two-buffer
+ * entry point for operator <name> on elements of C type <type>:
+ *   in                 read-only operand buffer
+ *   inout              second operand buffer, overwritten with the result
+ *   count              number of elements
+ *   threads_per_block  upper bound on the threads used per block
+ *   max_blocks         upper bound on the number of blocks launched
+ *   stream             HIP stream the kernel is enqueued on
+ */
+#define OP_FUNC_SIG(name, type_name, type) \
+    void ompi_op_rocm_2buff_##name##_##type_name##_submit(const type *in, \
+                                                          type *inout, \
+                                                          int count, \
+                                                          int threads_per_block, \
+                                                          int max_blocks, \
+                                                          hipStream_t stream);
+
+/*
+ * Same prototype as OP_FUNC_SIG. It is intentionally spelled out instead of
+ * forwarding to OP_FUNC_SIG: forwarding would macro-expand the arguments
+ * before token pasting, which changes the generated name whenever a type
+ * name is itself a macro (e.g. bool in C with <stdbool.h>).
+ */
+#define FUNC_FUNC_SIG(name, type_name, type) \
+    void ompi_op_rocm_2buff_##name##_##type_name##_submit(const type *in, \
+                                                          type *inout, \
+                                                          int count, \
+                                                          int threads_per_block, \
+                                                          int max_blocks, \
+                                                          hipStream_t stream);
+
+/*
+ * Since all the functions in this file are essentially identical, we
+ * use a macro to substitute in names and types.  The core operation
+ * in all functions that use this macro is the same.
+ *
+ * This macro is for minloc and maxloc: it defines the value/index pair
+ * type ompi_op_predefined_<type_name>_t with value member v and index
+ * member k.
+ */
+#define LOC_STRUCT(type_name, type1, type2) \
+    typedef struct { \
+        type1 v; \
+        type2 k; \
+    } ompi_op_predefined_##type_name##_t;
+
+/*
+ * Declares the two-buffer minloc/maxloc entry point for the pair type
+ * ompi_op_predefined_<type_name>_t; a is the read-only operand, b is the
+ * operand that receives the result.
+ */
+#define LOC_FUNC_SIG(name, type_name) \
+    void ompi_op_rocm_2buff_##name##_##type_name##_submit(const ompi_op_predefined_##type_name##_t *a, \
+                                                          ompi_op_predefined_##type_name##_t *b, \
+                                                          int count, \
+                                                          int threads_per_block, \
+                                                          int max_blocks, \
+                                                          hipStream_t stream);
+
+/*************************************************************************
+ * Max
+ *************************************************************************/
+
+/* C integer */
+FUNC_FUNC_SIG(max, int8_t, int8_t)
+FUNC_FUNC_SIG(max, uint8_t, uint8_t)
+FUNC_FUNC_SIG(max, int16_t, int16_t)
+FUNC_FUNC_SIG(max, uint16_t, uint16_t)
+FUNC_FUNC_SIG(max, int32_t, int32_t)
+FUNC_FUNC_SIG(max, uint32_t, uint32_t)
+FUNC_FUNC_SIG(max, int64_t, int64_t)
+FUNC_FUNC_SIG(max, uint64_t, uint64_t)
+FUNC_FUNC_SIG(max, long, long)
+FUNC_FUNC_SIG(max, ulong, unsigned long)
+
+#if 0
+/* Floating point */
+#if defined(HAVE_SHORT_FLOAT)
+FUNC_FUNC_SIG(max, short_float, short float)
+#elif defined(HAVE_OPAL_SHORT_FLOAT_T)
+FUNC_FUNC_SIG(max, short_float, opal_short_float_t)
+#endif
+#endif // 0
+
+FUNC_FUNC_SIG(max, float, float)
+FUNC_FUNC_SIG(max, double, double)
+FUNC_FUNC_SIG(max, long_double, long double)
+
+/*************************************************************************
+ * Min
+ *************************************************************************/
+
+/* C integer */
+FUNC_FUNC_SIG(min, int8_t, int8_t)
+FUNC_FUNC_SIG(min, uint8_t, uint8_t)
+FUNC_FUNC_SIG(min, int16_t, int16_t)
+FUNC_FUNC_SIG(min, uint16_t, uint16_t)
+FUNC_FUNC_SIG(min, int32_t, int32_t)
+FUNC_FUNC_SIG(min, uint32_t, uint32_t)
+FUNC_FUNC_SIG(min, int64_t, int64_t)
+FUNC_FUNC_SIG(min, uint64_t, uint64_t)
+FUNC_FUNC_SIG(min, long, long)
+FUNC_FUNC_SIG(min, ulong, unsigned long)
+
+#if 0
+/* Floating point */
+#if defined(HAVE_SHORT_FLOAT)
+FUNC_FUNC_SIG(min, short_float, short float)
+#elif defined(HAVE_OPAL_SHORT_FLOAT_T)
+FUNC_FUNC_SIG(min, short_float, opal_short_float_t)
+#endif
+#endif // 0
+
+FUNC_FUNC_SIG(min, float, float)
+FUNC_FUNC_SIG(min, double, double)
+FUNC_FUNC_SIG(min, long_double, long double)
+
+/*************************************************************************
+ * Sum
+ *************************************************************************/
+
+/* C integer */
+OP_FUNC_SIG(sum, int8_t, int8_t)
+OP_FUNC_SIG(sum, uint8_t, uint8_t)
+OP_FUNC_SIG(sum, int16_t, int16_t)
+OP_FUNC_SIG(sum, uint16_t, uint16_t)
+OP_FUNC_SIG(sum, int32_t, int32_t)
+OP_FUNC_SIG(sum, uint32_t, uint32_t)
+OP_FUNC_SIG(sum, int64_t, int64_t)
+OP_FUNC_SIG(sum, uint64_t, uint64_t)
+OP_FUNC_SIG(sum, long, long)
+OP_FUNC_SIG(sum, ulong, unsigned long)
+
+#if 0
+/* Floating point */
+#if defined(HAVE_SHORT_FLOAT)
+OP_FUNC_SIG(sum, short_float, short float)
+#elif defined(HAVE_OPAL_SHORT_FLOAT_T)
+OP_FUNC_SIG(sum, short_float, opal_short_float_t)
+#endif
+#endif // 0
+
+OP_FUNC_SIG(sum, float, float)
+OP_FUNC_SIG(sum, double, double)
+OP_FUNC_SIG(sum, long_double, long double)
+
+/* Complex */
+#if 0
+#if defined(HAVE_SHORT_FLOAT__COMPLEX)
+OP_FUNC_SIG(sum, c_short_float_complex, short float _Complex)
+#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T)
+COMPLEX_SUM_FUNC(c_short_float_complex, opal_short_float_t)
+#endif
+#endif // 0
+FUNC_FUNC_SIG(sum, c_float_complex, hipFloatComplex)
+FUNC_FUNC_SIG(sum, c_double_complex, hipDoubleComplex)
+//OP_FUNC_SIG(sum, c_float_complex, float _Complex)
+//OP_FUNC_SIG(sum, c_double_complex, double _Complex)
+//OP_FUNC_SIG(sum, c_long_double_complex, long double _Complex)
+
+/*************************************************************************
+ * Product
+ *************************************************************************/
+
+/* C integer */
+OP_FUNC_SIG(prod, int8_t, int8_t)
+OP_FUNC_SIG(prod, uint8_t, uint8_t)
+OP_FUNC_SIG(prod, int16_t, int16_t)
+OP_FUNC_SIG(prod, uint16_t, uint16_t) +OP_FUNC_SIG(prod, int32_t, int32_t) +OP_FUNC_SIG(prod, uint32_t, uint32_t) +OP_FUNC_SIG(prod, int64_t, int64_t) +OP_FUNC_SIG(prod, uint64_t, uint64_t) +OP_FUNC_SIG(prod, long, long) +OP_FUNC_SIG(prod, ulong, unsigned long) + +#if 0 +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC_SIG(prod, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +OP_FUNC_SIG(prod, short_float, opal_short_float_t) +#endif +#endif // 0 + +OP_FUNC_SIG(prod, float, float) +OP_FUNC_SIG(prod, double, double) +OP_FUNC_SIG(prod, long_double, long double) + +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_SIG(prod, c_short_float_complex, short float _Complex) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_PROD_FUNC(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC_SIG(prod, c_long_double_complex, long double _Complex) +#endif // 0 +FUNC_FUNC_SIG(prod, c_float_complex, hipFloatComplex) +FUNC_FUNC_SIG(prod, c_double_complex, hipDoubleComplex) + +/************************************************************************* + * Logical AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) && (b)) +/* C integer */ +FUNC_FUNC_SIG(land, int8_t, int8_t) +FUNC_FUNC_SIG(land, uint8_t, uint8_t) +FUNC_FUNC_SIG(land, int16_t, int16_t) +FUNC_FUNC_SIG(land, uint16_t, uint16_t) +FUNC_FUNC_SIG(land, int32_t, int32_t) +FUNC_FUNC_SIG(land, uint32_t, uint32_t) +FUNC_FUNC_SIG(land, int64_t, int64_t) +FUNC_FUNC_SIG(land, uint64_t, uint64_t) +FUNC_FUNC_SIG(land, long, long) +FUNC_FUNC_SIG(land, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_SIG(land, bool, bool) + +/************************************************************************* + * Logical OR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) || (b)) +/* C integer */ +FUNC_FUNC_SIG(lor, 
int8_t, int8_t) +FUNC_FUNC_SIG(lor, uint8_t, uint8_t) +FUNC_FUNC_SIG(lor, int16_t, int16_t) +FUNC_FUNC_SIG(lor, uint16_t, uint16_t) +FUNC_FUNC_SIG(lor, int32_t, int32_t) +FUNC_FUNC_SIG(lor, uint32_t, uint32_t) +FUNC_FUNC_SIG(lor, int64_t, int64_t) +FUNC_FUNC_SIG(lor, uint64_t, uint64_t) +FUNC_FUNC_SIG(lor, long, long) +FUNC_FUNC_SIG(lor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_SIG(lor, bool, bool) + +/************************************************************************* + * Logical XOR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a ? 1 : 0) ^ (b ? 1: 0)) +/* C integer */ +FUNC_FUNC_SIG(lxor, int8_t, int8_t) +FUNC_FUNC_SIG(lxor, uint8_t, uint8_t) +FUNC_FUNC_SIG(lxor, int16_t, int16_t) +FUNC_FUNC_SIG(lxor, uint16_t, uint16_t) +FUNC_FUNC_SIG(lxor, int32_t, int32_t) +FUNC_FUNC_SIG(lxor, uint32_t, uint32_t) +FUNC_FUNC_SIG(lxor, int64_t, int64_t) +FUNC_FUNC_SIG(lxor, uint64_t, uint64_t) +FUNC_FUNC_SIG(lxor, long, long) +FUNC_FUNC_SIG(lxor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_SIG(lxor, bool, bool) + +/************************************************************************* + * Bitwise AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) & (b)) +/* C integer */ +FUNC_FUNC_SIG(band, int8_t, int8_t) +FUNC_FUNC_SIG(band, uint8_t, uint8_t) +FUNC_FUNC_SIG(band, int16_t, int16_t) +FUNC_FUNC_SIG(band, uint16_t, uint16_t) +FUNC_FUNC_SIG(band, int32_t, int32_t) +FUNC_FUNC_SIG(band, uint32_t, uint32_t) +FUNC_FUNC_SIG(band, int64_t, int64_t) +FUNC_FUNC_SIG(band, uint64_t, uint64_t) +FUNC_FUNC_SIG(band, long, long) +FUNC_FUNC_SIG(band, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_SIG(band, byte, char) + +/************************************************************************* + * Bitwise OR + *************************************************************************/ + +#undef 
current_func +#define current_func(a, b) ((a) | (b)) +/* C integer */ +FUNC_FUNC_SIG(bor, int8_t, int8_t) +FUNC_FUNC_SIG(bor, uint8_t, uint8_t) +FUNC_FUNC_SIG(bor, int16_t, int16_t) +FUNC_FUNC_SIG(bor, uint16_t, uint16_t) +FUNC_FUNC_SIG(bor, int32_t, int32_t) +FUNC_FUNC_SIG(bor, uint32_t, uint32_t) +FUNC_FUNC_SIG(bor, int64_t, int64_t) +FUNC_FUNC_SIG(bor, uint64_t, uint64_t) +FUNC_FUNC_SIG(bor, long, long) +FUNC_FUNC_SIG(bor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_SIG(bor, byte, char) + +/************************************************************************* + * Bitwise XOR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) ^ (b)) +/* C integer */ +FUNC_FUNC_SIG(bxor, int8_t, int8_t) +FUNC_FUNC_SIG(bxor, uint8_t, uint8_t) +FUNC_FUNC_SIG(bxor, int16_t, int16_t) +FUNC_FUNC_SIG(bxor, uint16_t, uint16_t) +FUNC_FUNC_SIG(bxor, int32_t, int32_t) +FUNC_FUNC_SIG(bxor, uint32_t, uint32_t) +FUNC_FUNC_SIG(bxor, int64_t, int64_t) +FUNC_FUNC_SIG(bxor, uint64_t, uint64_t) +FUNC_FUNC_SIG(bxor, long, long) +FUNC_FUNC_SIG(bxor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_SIG(bxor, byte, char) + +/************************************************************************* + * Min and max location "pair" datatypes + *************************************************************************/ + +LOC_STRUCT(float_int, float, int) +LOC_STRUCT(double_int, double, int) +LOC_STRUCT(long_int, long, int) +LOC_STRUCT(2int, int, int) +LOC_STRUCT(short_int, short, int) +LOC_STRUCT(long_double_int, long double, int) +LOC_STRUCT(ulong, unsigned long, int) +/* compat types for Fortran */ +LOC_STRUCT(2float, float, float) +LOC_STRUCT(2double, double, double) +LOC_STRUCT(2int8, int8_t, int8_t) +LOC_STRUCT(2int16, int16_t, int16_t) +LOC_STRUCT(2int32, int32_t, int32_t) +LOC_STRUCT(2int64, int64_t, int64_t) + +/************************************************************************* + * Max location + 
*************************************************************************/ + +LOC_FUNC_SIG(maxloc, 2float) +LOC_FUNC_SIG(maxloc, 2double) +LOC_FUNC_SIG(maxloc, 2int8) +LOC_FUNC_SIG(maxloc, 2int16) +LOC_FUNC_SIG(maxloc, 2int32) +LOC_FUNC_SIG(maxloc, 2int64) + +LOC_FUNC_SIG(maxloc, float_int) +LOC_FUNC_SIG(maxloc, double_int) +LOC_FUNC_SIG(maxloc, long_int) +LOC_FUNC_SIG(maxloc, 2int) +LOC_FUNC_SIG(maxloc, short_int) +LOC_FUNC_SIG(maxloc, long_double_int) + +/************************************************************************* + * Min location + *************************************************************************/ + +LOC_FUNC_SIG(minloc, 2float) +LOC_FUNC_SIG(minloc, 2double) +LOC_FUNC_SIG(minloc, 2int8) +LOC_FUNC_SIG(minloc, 2int16) +LOC_FUNC_SIG(minloc, 2int32) +LOC_FUNC_SIG(minloc, 2int64) + +LOC_FUNC_SIG(minloc, float_int) +LOC_FUNC_SIG(minloc, double_int) +LOC_FUNC_SIG(minloc, long_int) +LOC_FUNC_SIG(minloc, 2int) +LOC_FUNC_SIG(minloc, short_int) +LOC_FUNC_SIG(minloc, long_double_int) + + + +#define OP_FUNC_3BUF_SIG(name, type_name, type) \ + void ompi_op_rocm_3buff_##name##_##type_name##_submit(const type *in1, \ + const type *in2, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + hipStream_t stream); + +#define FUNC_FUNC_3BUF_SIG(name, type_name, type) \ + void ompi_op_rocm_3buff_##name##_##type_name##_submit(const type *in1, \ + const type *in2, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + hipStream_t stream); + +#define LOC_FUNC_3BUF_SIG(name, type_name) \ + void ompi_op_rocm_3buff_##name##_##type_name##_submit(const ompi_op_predefined_##type_name##_t *a1, \ + const ompi_op_predefined_##type_name##_t *a2, \ + ompi_op_predefined_##type_name##_t *b, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + hipStream_t stream); + + +/************************************************************************* + * Max + 
*************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(max, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(max, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(max, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(max, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(max, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(max, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(max, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(max, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(max, long, long) +FUNC_FUNC_3BUF_SIG(max, ulong, unsigned long) + +/* Floating point */ +#if 0 +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC_3BUF_SIG(max, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_3BUF_SIG(max, short_float, opal_short_float_t) +#endif +#endif // 0 +FUNC_FUNC_3BUF_SIG(max, float, float) +FUNC_FUNC_3BUF_SIG(max, double, double) +FUNC_FUNC_3BUF_SIG(max, long_double, long double) + +/************************************************************************* + * Min + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(min, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(min, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(min, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(min, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(min, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(min, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(min, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(min, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(min, long, long) +FUNC_FUNC_3BUF_SIG(min, ulong, unsigned long) + +/* Floating point */ +#if 0 +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC_3BUF_SIG(min, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_3BUF_SIG(min, short_float, opal_short_float_t) +#endif +#endif // 0 +FUNC_FUNC_3BUF_SIG(min, float, float) +FUNC_FUNC_3BUF_SIG(min, double, double) +FUNC_FUNC_3BUF_SIG(min, long_double, long double) + +/************************************************************************* + * Sum + 
*************************************************************************/ + +/* C integer */ +OP_FUNC_3BUF_SIG(sum, int8_t, int8_t) +OP_FUNC_3BUF_SIG(sum, uint8_t, uint8_t) +OP_FUNC_3BUF_SIG(sum, int16_t, int16_t) +OP_FUNC_3BUF_SIG(sum, uint16_t, uint16_t) +OP_FUNC_3BUF_SIG(sum, int32_t, int32_t) +OP_FUNC_3BUF_SIG(sum, uint32_t, uint32_t) +OP_FUNC_3BUF_SIG(sum, int64_t, int64_t) +OP_FUNC_3BUF_SIG(sum, uint64_t, uint64_t) +OP_FUNC_3BUF_SIG(sum, long, long) +OP_FUNC_3BUF_SIG(sum, ulong, unsigned long) + +/* Floating point */ +#if 0 +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC_3BUF_SIG(sum, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +OP_FUNC_3BUF_SIG(sum, short_float, opal_short_float_t) +#endif +#endif // 0 +OP_FUNC_3BUF_SIG(sum, float, float) +OP_FUNC_3BUF_SIG(sum, double, double) +OP_FUNC_3BUF_SIG(sum, long_double, long double) +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF_SIG(sum, c_short_float_complex, short float _Complex) +#endif +#endif // 0 +FUNC_FUNC_3BUF_SIG(sum, c_float_complex, hipFloatComplex) +FUNC_FUNC_3BUF_SIG(sum, c_double_complex, hipDoubleComplex) + +/************************************************************************* + * Product + *************************************************************************/ + +/* C integer */ +OP_FUNC_3BUF_SIG(prod, int8_t, int8_t) +OP_FUNC_3BUF_SIG(prod, uint8_t, uint8_t) +OP_FUNC_3BUF_SIG(prod, int16_t, int16_t) +OP_FUNC_3BUF_SIG(prod, uint16_t, uint16_t) +OP_FUNC_3BUF_SIG(prod, int32_t, int32_t) +OP_FUNC_3BUF_SIG(prod, uint32_t, uint32_t) +OP_FUNC_3BUF_SIG(prod, int64_t, int64_t) +OP_FUNC_3BUF_SIG(prod, uint64_t, uint64_t) +OP_FUNC_3BUF_SIG(prod, long, long) +OP_FUNC_3BUF_SIG(prod, ulong, unsigned long) + +/* Floating point */ +#if 0 +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC_3BUF_SIG(prod, short_float, short float) +#endif +#endif // 0 +OP_FUNC_3BUF_SIG(prod, float, float) +OP_FUNC_3BUF_SIG(prod, double, double) +OP_FUNC_3BUF_SIG(prod, long_double, long double) 
+ +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF_SIG(prod, c_short_float_complex, short float _Complex) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_PROD_FUNC_3BUF(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC_3BUF_SIG(prod, c_float_complex, float _Complex) +OP_FUNC_3BUF_SIG(prod, c_double_complex, double _Complex) +OP_FUNC_3BUF_SIG(prod, c_long_double_complex, long double _Complex) +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF_SIG(prod, c_short_float_complex, short float _Complex) +#endif +#endif // 0 +FUNC_FUNC_3BUF_SIG(prod, c_float_complex, hipFloatComplex) +FUNC_FUNC_3BUF_SIG(prod, c_double_complex, hipDoubleComplex) + +/************************************************************************* + * Logical AND + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(land, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(land, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(land, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(land, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(land, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(land, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(land, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(land, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(land, long, long) +FUNC_FUNC_3BUF_SIG(land, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_3BUF_SIG(land, bool, bool) + +/************************************************************************* + * Logical OR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(lor, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(lor, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(lor, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(lor, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(lor, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(lor, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(lor, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(lor, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(lor, long, long) +FUNC_FUNC_3BUF_SIG(lor, ulong, unsigned 
long) + +/* C++ bool */ +FUNC_FUNC_3BUF_SIG(lor, bool, bool) + +/************************************************************************* + * Logical XOR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(lxor, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(lxor, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(lxor, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(lxor, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(lxor, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(lxor, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(lxor, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(lxor, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(lxor, long, long) +FUNC_FUNC_3BUF_SIG(lxor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_3BUF_SIG(lxor, bool, bool) + +/************************************************************************* + * Bitwise AND + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(band, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(band, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(band, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(band, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(band, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(band, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(band, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(band, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(band, long, long) +FUNC_FUNC_3BUF_SIG(band, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_3BUF_SIG(band, byte, char) + +/************************************************************************* + * Bitwise OR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(bor, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(bor, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(bor, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(bor, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(bor, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(bor, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(bor, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(bor, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(bor, 
long, long) +FUNC_FUNC_3BUF_SIG(bor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_3BUF_SIG(bor, byte, char) + +/************************************************************************* + * Bitwise XOR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(bxor, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(bxor, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(bxor, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(bxor, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(bxor, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(bxor, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(bxor, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(bxor, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(bxor, long, long) +FUNC_FUNC_3BUF_SIG(bxor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_3BUF_SIG(bxor, byte, char) + +/************************************************************************* + * Max location + *************************************************************************/ + +LOC_FUNC_3BUF_SIG(maxloc, float_int) +LOC_FUNC_3BUF_SIG(maxloc, double_int) +LOC_FUNC_3BUF_SIG(maxloc, long_int) +LOC_FUNC_3BUF_SIG(maxloc, 2int) +LOC_FUNC_3BUF_SIG(maxloc, short_int) +LOC_FUNC_3BUF_SIG(maxloc, long_double_int) + +LOC_FUNC_3BUF_SIG(maxloc, 2float) +LOC_FUNC_3BUF_SIG(maxloc, 2double) +LOC_FUNC_3BUF_SIG(maxloc, 2int8) +LOC_FUNC_3BUF_SIG(maxloc, 2int16) +LOC_FUNC_3BUF_SIG(maxloc, 2int32) +LOC_FUNC_3BUF_SIG(maxloc, 2int64) + +/************************************************************************* + * Min location + *************************************************************************/ + +LOC_FUNC_3BUF_SIG(minloc, float_int) +LOC_FUNC_3BUF_SIG(minloc, double_int) +LOC_FUNC_3BUF_SIG(minloc, long_int) +LOC_FUNC_3BUF_SIG(minloc, 2int) +LOC_FUNC_3BUF_SIG(minloc, short_int) +LOC_FUNC_3BUF_SIG(minloc, long_double_int) + +LOC_FUNC_3BUF_SIG(minloc, 2float) +LOC_FUNC_3BUF_SIG(minloc, 2double) +LOC_FUNC_3BUF_SIG(minloc, 2int8) +LOC_FUNC_3BUF_SIG(minloc, 2int16) +LOC_FUNC_3BUF_SIG(minloc, 2int32) 
+LOC_FUNC_3BUF_SIG(minloc, 2int64)
+
+END_C_DECLS
diff --git a/ompi/mca/op/rocm/op_rocm_impl.hip b/ompi/mca/op/rocm/op_rocm_impl.hip
new file mode 100644
index 00000000000..45a6eee4349
--- /dev/null
+++ b/ompi/mca/op/rocm/op_rocm_impl.hip
@@ -0,0 +1,1085 @@
+/*
+ * Copyright (c) 2019-2023 The University of Tennessee and The University
+ *                         of Tennessee Research Foundation.  All rights
+ *                         reserved.
+ * Copyright (c) 2020      Research Organization for Information Science
+ *                         and Technology (RIST).  All rights reserved.
+ * $COPYRIGHT$
+ *
+ * Additional copyrights may follow
+ *
+ * $HEADER$
+ */
+
+/*
+ * NOTE(review): the targets of the angle-bracket includes were lost when this
+ * patch was extracted. <type_traits> (std::is_signed_v) and
+ * <hip/hip_complex.h> (hipFloatComplex) are required by the code; <stdint.h>
+ * and <stdio.h> are best guesses -- confirm against the upstream file.
+ */
+#include "hip/hip_runtime.h"
+#include <stdint.h>
+#include <stdio.h>
+
+#include <type_traits>
+
+#include "op_rocm_impl.h"
+
+//#define DO_NOT_USE_INTRINSICS 1
+#define USE_VECTORS 1
+
+#include <hip/hip_complex.h>
+
+/* true when the integer type x is signed */
+#define ISSIGNED(x) std::is_signed_v<x>
+
+/* Scalar element operations used by the kernels below. */
+template<typename T>
+static inline __device__ constexpr T tmax(T a, T b) {
+    return (a > b) ? a : b;
+}
+
+template<typename T>
+static inline __device__ constexpr T tmin(T a, T b) {
+    return (a < b) ? a : b;
+}
+
+template<typename T>
+static inline __device__ constexpr T tsum(T a, T b) {
+    return a+b;
+}
+
+template<typename T>
+static inline __device__ constexpr T tprod(T a, T b) {
+    return a*b;
+}
+
+/* Component-wise operations on vector types with members x, y, z and w. */
+template<typename T>
+static inline __device__ T vmax(const T& a, const T& b) {
+    return T{tmax(a.x, b.x), tmax(a.y, b.y), tmax(a.z, b.z), tmax(a.w, b.w)};
+}
+
+template<typename T>
+static inline __device__ T vmin(const T& a, const T& b) {
+    return T{tmin(a.x, b.x), tmin(a.y, b.y), tmin(a.z, b.z), tmin(a.w, b.w)};
+}
+
+template<typename T>
+static inline __device__ T vsum(const T& a, const T& b) {
+    return T{tsum(a.x, b.x), tsum(a.y, b.y), tsum(a.z, b.z), tsum(a.w, b.w)};
+}
+
+template<typename T>
+static inline __device__ T vprod(const T& a, const T& b) {
+    return T{tprod(a.x, b.x), tprod(a.y, b.y), tprod(a.z, b.z), tprod(a.w, b.w)};
+}
+
+/*
+ * Compute the launch geometry shared by all submit functions.
+ *
+ * work_items is the number of loop iterations to distribute (elements, or
+ * vectors for the vectorized kernels). On success *threads is
+ * min(threads_per_block, work_items) and *blocks is the number of blocks
+ * needed to cover work_items, capped at max_blocks.
+ *
+ * Returns false, leaving the outputs untouched, when there is no work. The
+ * caller must then skip the launch: a zero-sized reduction is a no-op, and
+ * computing the block count would otherwise divide by zero.
+ * threads_per_block and max_blocks are expected to be positive.
+ */
+static inline bool ompi_op_rocm_launch_dims(int work_items, int threads_per_block,
+                                            int max_blocks, int *threads, int *blocks) {
+    if (work_items <= 0) {
+        return false;
+    }
+    *threads = min(threads_per_block, work_items);
+    *blocks = min((work_items + *threads - 1) / *threads, max_blocks);
+    return true;
+}
+
+/* TODO: missing support for
+ * - short float (conditional on whether short float is available)
+ * - complex
+ */
+
+#define VECLEN 2
+#define VECTYPE(t) t##VECLEN
+
+/*
+ * OP_FUNC: kernel + submit function for an infix operator:
+ * inout[i] = inout[i] op in[i] for i in [0, n). Every thread starts at its
+ * global index and advances by the total number of launched threads, so the
+ * result is complete even when the grid is capped by max_blocks.
+ */
+#define OP_FUNC(name, type_name, type, op) \
+    static __global__ void \
+    ompi_op_rocm_2buff_##name##_##type_name##_kernel(const type *__restrict__ in, \
+                                                     type *__restrict__ inout, int n) { \
+        const int index = blockIdx.x * blockDim.x + threadIdx.x; \
+        const int stride = blockDim.x * gridDim.x; \
+        for (int i = index; i < n; i += stride) { \
+            inout[i] = inout[i] op in[i]; \
+        } \
+    } \
+    void ompi_op_rocm_2buff_##name##_##type_name##_submit(const type *in, \
+                                                          type *inout, \
+                                                          int count, \
+                                                          int threads_per_block, \
+                                                          int max_blocks, \
+                                                          hipStream_t stream) { \
+        int threads = 0; \
+        int blocks = 0; \
+        if (!ompi_op_rocm_launch_dims(count, threads_per_block, max_blocks, \
+                                      &threads, &blocks)) { \
+            return; /* nothing to reduce */ \
+        } \
+        hipLaunchKernelGGL(ompi_op_rocm_2buff_##name##_##type_name##_kernel, \
+                           dim3(blocks), dim3(threads), 0, stream, \
+                           in, inout, count); \
+    }
+
+/*
+ * OPV_FUNC: like OP_FUNC, but the bulk of the data is processed vlen
+ * elements at a time through the vector type vtype.
+ * NOTE(review): the vtype casts assume in/inout are suitably aligned for
+ * vtype -- confirm that callers guarantee this.
+ */
+#if defined(USE_VECTORS)
+#define OPV_FUNC(name, type_name, type, vtype, vlen, op) \
+    static __global__ void \
+    ompi_op_rocm_2buff_##name##_##type_name##_kernel(const type *__restrict__ in, \
+                                                     type *__restrict__ inout, int n) { \
+        const int index = blockIdx.x * blockDim.x + threadIdx.x; \
+        const int stride = blockDim.x * gridDim.x; \
+        const int nvec = n / (vlen); \
+        for (int i = index; i < nvec; i += stride) { \
+            ((vtype*)inout)[i] = ((vtype*)inout)[i] op ((vtype*)in)[i]; \
+        } \
+        /* The n % vlen trailing scalars are handled by exactly one thread. \
+         * Its index is taken modulo the grid size so that the thread exists \
+         * even when the grid was capped by max_blocks. */ \
+        if (index == nvec % stride) { \
+            for (int idx = nvec * (vlen); idx < n; ++idx) { \
+                inout[idx] = inout[idx] op in[idx]; \
+            } \
+        } \
+    } \
+    void ompi_op_rocm_2buff_##name##_##type_name##_submit(const type *in, \
+                                                          type *inout, \
+                                                          int count, \
+                                                          int threads_per_block, \
+                                                          int max_blocks, \
+                                                          hipStream_t stream) { \
+        int threads = 0; \
+        int blocks = 0; \
+        /* one work item per (possibly partial) vector */ \
+        int vcount = (count + (vlen)-1)/(vlen); \
+        if (!ompi_op_rocm_launch_dims(vcount, threads_per_block, max_blocks, \
+                                      &threads, &blocks)) { \
+            return; /* nothing to reduce */ \
+        } \
+        hipLaunchKernelGGL(ompi_op_rocm_2buff_##name##_##type_name##_kernel, \
+                           dim3(blocks), dim3(threads), 0, stream, \
+                           in, inout, count); \
+    }
+#else // USE_VECTORS
+#define OPV_FUNC(name, type_name, type, vtype, vlen, op) OP_FUNC(name, type_name, type, op)
+#endif // USE_VECTORS
+
+
+/*
+ * FUNC_FUNC: kernel + submit function that combines elements with whatever
+ * current_func(a, b) is defined as at the point of instantiation.
+ */
+#define FUNC_FUNC(name, type_name, type) \
+    static __global__ void \
+    ompi_op_rocm_2buff_##name##_##type_name##_kernel(const type *__restrict__ in, \
+                                                     type *__restrict__ inout, int n) { \
+        const int index = blockIdx.x * blockDim.x + threadIdx.x; \
+        const int stride = blockDim.x * gridDim.x; \
+        for (int i = index; i < n; i += stride) { \
+            inout[i] = current_func(inout[i], in[i]); \
+        } \
+    } \
+    void \
+    ompi_op_rocm_2buff_##name##_##type_name##_submit(const type *in, \
+                                                     type *inout, \
+                                                     int count, \
+                                                     int threads_per_block, \
+                                                     int max_blocks, \
+                                                     hipStream_t stream) { \
+        int threads = 0; \
+        int blocks = 0; \
+        if (!ompi_op_rocm_launch_dims(count, threads_per_block, max_blocks, \
+                                      &threads, &blocks)) { \
+            return; /* nothing to reduce */ \
+        } \
+        hipLaunchKernelGGL(ompi_op_rocm_2buff_##name##_##type_name##_kernel, \
+                           dim3(blocks), dim3(threads), 0, stream, \
+                           in, inout, count); \
+    }
+
+
+/*
+ * VFUNC_FUNC: vectorized variant for function-style operators. vfn combines
+ * two vtype values, fn combines the scalar tail. The submit function is
+ * file-local; the exported fixed-width entry points reach it through
+ * OPV_DISPATCH.
+ * NOTE(review): the vtype casts assume in/inout are suitably aligned for
+ * vtype -- confirm that callers guarantee this.
+ */
+#if defined(USE_VECTORS)
+#define VFUNC_FUNC(name, type_name, type, vtype, vlen, vfn, fn) \
+    static __global__ void \
+    ompi_op_rocm_2buff_##name##_##type_name##_kernel(const type *__restrict__ in, \
+                                                     type *__restrict__ inout, int n) { \
+        const int index = blockIdx.x * blockDim.x + threadIdx.x; \
+        const int stride = blockDim.x * gridDim.x; \
+        const int nvec = n / (vlen); \
+        for (int i = index; i < nvec; i += stride) { \
+            ((vtype*)inout)[i] = vfn(((vtype*)inout)[i], ((vtype*)in)[i]); \
+        } \
+        /* The n % vlen trailing scalars are handled by exactly one thread. \
+         * Its index is taken modulo the grid size so that the thread exists \
+         * even when the grid was capped by max_blocks. */ \
+        if (index == nvec % stride) { \
+            for (int idx = nvec * (vlen); idx < n; ++idx) { \
+                inout[idx] = fn(inout[idx], in[idx]); \
+            } \
+        } \
+    } \
+    static void \
+    ompi_op_rocm_2buff_##name##_##type_name##_submit(const type *in, \
+                                                     type *inout, \
+                                                     int count, \
+                                                     int threads_per_block, \
+                                                     int max_blocks, \
+                                                     hipStream_t stream) { \
+        int threads = 0; \
+        int blocks = 0; \
+        /* one work item per (possibly partial) vector */ \
+        int vcount = (count + (vlen)-1)/(vlen); \
+        if (!ompi_op_rocm_launch_dims(vcount, threads_per_block, max_blocks, \
+                                      &threads, &blocks)) { \
+            return; /* nothing to reduce */ \
+        } \
+        hipLaunchKernelGGL(ompi_op_rocm_2buff_##name##_##type_name##_kernel, \
+                           dim3(blocks), dim3(threads), 0, stream, \
+                           in, inout, count); \
+    }
+#else
+#define VFUNC_FUNC(name, type_name, type, vtype, vlen, vfn, fn) FUNC_FUNC_FN(name, type_name, type, fn)
+#endif // defined(USE_VECTORS)
+
+/*
+ * FUNC_FUNC_FN: scalar kernel + submit function for a function-style
+ * operator: inout[i] = fn(inout[i], in[i]).
+ */
+#define FUNC_FUNC_FN(name, type_name, type, fn) \
+    static __global__ void \
+    ompi_op_rocm_2buff_##name##_##type_name##_kernel(const type *__restrict__ in, \
+                                                     type *__restrict__ inout, int n) { \
+        const int index = blockIdx.x * blockDim.x + threadIdx.x; \
+        const int stride = blockDim.x * gridDim.x; \
+        for (int i = index; i < n; i += stride) { \
+            inout[i] = fn(inout[i], in[i]); \
+        } \
+    } \
+    void \
+    ompi_op_rocm_2buff_##name##_##type_name##_submit(const type *in, \
+                                                     type *inout, \
+                                                     int count, \
+                                                     int threads_per_block, \
+                                                     int max_blocks, \
+                                                     hipStream_t stream) { \
+        int threads = 0; \
+        int blocks = 0; \
+        if (!ompi_op_rocm_launch_dims(count, threads_per_block, max_blocks, \
+                                      &threads, &blocks)) { \
+            return; /* nothing to reduce */ \
+        } \
+        hipLaunchKernelGGL(ompi_op_rocm_2buff_##name##_##type_name##_kernel, \
+                           dim3(blocks), dim3(threads), 0, stream, \
+                           in, inout, count); \
+    }
+
+/*
+ * Since all the functions in this file are essentially identical, we
+ * use a macro to substitute in names and types.  The core operation
+ * in all functions that use this macro is the same.
+ *
+ * This macro is for minloc and maxloc: the input pair replaces the inout
+ * pair when "in.v op inout.v" holds; on equal values the smaller index k
+ * is kept.
+ */
+
+#define LOC_FUNC(name, type_name, op) \
+    static __global__ void \
+    ompi_op_rocm_2buff_##name##_##type_name##_kernel(const ompi_op_predefined_##type_name##_t *__restrict__ in, \
+                                                     ompi_op_predefined_##type_name##_t *__restrict__ inout, \
+                                                     int n) \
+    { \
+        const int index = blockIdx.x * blockDim.x + threadIdx.x; \
+        const int stride = blockDim.x * gridDim.x; \
+        for (int i = index; i < n; i += stride) { \
+            const ompi_op_predefined_##type_name##_t *a = &in[i]; \
+            ompi_op_predefined_##type_name##_t *b = &inout[i]; \
+            if (a->v op b->v) { \
+                b->v = a->v; \
+                b->k = a->k; \
+            } else if (a->v == b->v) { \
+                b->k = (b->k < a->k ? b->k : a->k); \
+            } \
+        } \
+    } \
+    void \
+    ompi_op_rocm_2buff_##name##_##type_name##_submit(const ompi_op_predefined_##type_name##_t *a, \
+                                                     ompi_op_predefined_##type_name##_t *b, \
+                                                     int count, \
+                                                     int threads_per_block, \
+                                                     int max_blocks, \
+                                                     hipStream_t stream) { \
+        int threads = 0; \
+        int blocks = 0; \
+        if (!ompi_op_rocm_launch_dims(count, threads_per_block, max_blocks, \
+                                      &threads, &blocks)) { \
+            return; /* nothing to reduce */ \
+        } \
+        hipLaunchKernelGGL(ompi_op_rocm_2buff_##name##_##type_name##_kernel, \
+                           dim3(blocks), dim3(threads), 0, stream, \
+                           a, b, count); \
+    }
+
+
+/*
+ * OPV_DISPATCH: defines the exported submit function for a fixed-width
+ * integer type by forwarding to the submit function of the builtin integer
+ * type with the same signedness and size. The choice is made at compile time
+ * from the element type itself (type), not from the name token.
+ */
+#define OPV_DISPATCH(name, type_name, type) \
+    void ompi_op_rocm_2buff_##name##_##type_name##_submit(const type *in, \
+                                                          type *inout, \
+                                                          int count, \
+                                                          int threads_per_block, \
+                                                          int max_blocks, \
+                                                          hipStream_t stream) { \
+        static_assert(sizeof(type) <= sizeof(unsigned long long), "Unknown size type"); \
+        if constexpr(!ISSIGNED(type)) { \
+            if constexpr(sizeof(type) == sizeof(unsigned char)) { \
+                ompi_op_rocm_2buff_##name##_uchar_submit((const unsigned char*)in, (unsigned char*)inout, count, \
+                                                         threads_per_block, \
+                                                         max_blocks, stream); \
+            } else if constexpr(sizeof(type) == sizeof(unsigned short)) { \
+                ompi_op_rocm_2buff_##name##_ushort_submit((const unsigned short*)in, (unsigned short*)inout, count, \
+                                                          threads_per_block, \
+                                                          max_blocks, stream); \
+            } else if constexpr(sizeof(type) == sizeof(unsigned int)) { \
+                ompi_op_rocm_2buff_##name##_uint_submit((const unsigned int*)in, (unsigned int*)inout, count, \
+                                                        threads_per_block, \
+                                                        max_blocks, stream); \
+            } else if constexpr(sizeof(type) == sizeof(unsigned long)) { \
+                ompi_op_rocm_2buff_##name##_ulong_submit((const unsigned long*)in, (unsigned long*)inout, count, \
+                                                         threads_per_block, \
+                                                         max_blocks, stream); \
+            } else if constexpr(sizeof(type) == sizeof(unsigned long long)) { \
+                ompi_op_rocm_2buff_##name##_ulonglong_submit((const unsigned long long*)in, (unsigned long long*)inout, count, \
+                                                             threads_per_block, \
+                                                             max_blocks, stream); \
+            } \
+        } else { \
+            if constexpr(sizeof(type) == sizeof(char)) { \
+                ompi_op_rocm_2buff_##name##_char_submit((const char*)in, (char*)inout, count, \
+                                                        threads_per_block, \
+                                                        max_blocks, stream); \
+            } else if constexpr(sizeof(type) == sizeof(short)) { \
+                ompi_op_rocm_2buff_##name##_short_submit((const short*)in, (short*)inout, count, \
+                                                         threads_per_block, \
+                                                         max_blocks, stream); \
+            } else if constexpr(sizeof(type) == sizeof(int)) { \
+                ompi_op_rocm_2buff_##name##_int_submit((const int*)in, (int*)inout, count, \
+                                                       threads_per_block, \
+                                                       max_blocks, stream); \
+            } else if constexpr(sizeof(type) == sizeof(long)) { \
+                ompi_op_rocm_2buff_##name##_long_submit((const long*)in, (long*)inout, count, \
+                                                        threads_per_block, \
+                                                        max_blocks, stream); \
+            } else if constexpr(sizeof(type) == sizeof(long long)) { \
+                ompi_op_rocm_2buff_##name##_longlong_submit((const long long*)in, (long long*)inout, count, \
+                                                            threads_per_block, \
+                                                            max_blocks, stream); \
+            } \
+        } \
+    }
+
+/*************************************************************************
+ * Max
+ *************************************************************************/
+
+/* C integer */
+VFUNC_FUNC(max, char, char, char4, 4, vmax, max)
+VFUNC_FUNC(max, uchar, unsigned char, uchar4, 4, vmax, max)
+VFUNC_FUNC(max, short, short, short4, 4, vmax, max)
+VFUNC_FUNC(max, ushort, unsigned short, ushort4, 4, vmax, max)
+VFUNC_FUNC(max, int, int, int4, 4, vmax, max)
+VFUNC_FUNC(max, uint, unsigned int, uint4, 4, vmax, max)
+
+#undef current_func
+#define current_func(a, b) max(a, b)
+FUNC_FUNC(max, long, long)
+FUNC_FUNC(max, ulong, unsigned long)
+FUNC_FUNC(max, longlong, long long)
+FUNC_FUNC(max, ulonglong, unsigned long long)
+
+
+/* dispatch fixed-size types */
+OPV_DISPATCH(max, int8_t, int8_t)
+OPV_DISPATCH(max, uint8_t, uint8_t)
+OPV_DISPATCH(max, int16_t, int16_t)
+OPV_DISPATCH(max, uint16_t, uint16_t)
+OPV_DISPATCH(max, int32_t, int32_t)
+OPV_DISPATCH(max, uint32_t, uint32_t)
+OPV_DISPATCH(max, int64_t, int64_t) +OPV_DISPATCH(max, uint64_t, uint64_t) + +#if !defined(DO_NOT_USE_INTRINSICS) +#undef current_func +#define current_func(a, b) fmaxf(a, b) +#endif // DO_NOT_USE_INTRINSICS +FUNC_FUNC(max, float, float) + +#if !defined(DO_NOT_USE_INTRINSICS) +#undef current_func +#define current_func(a, b) fmax(a, b) +#endif // DO_NOT_USE_INTRINSICS +FUNC_FUNC(max, double, double) + +#undef current_func +#define current_func(a, b) ((a) > (b) ? (a) : (b)) +FUNC_FUNC(max, long_double, long double) + +/************************************************************************* + * Min + *************************************************************************/ + +/* C integer */ +VFUNC_FUNC(min, char, char, char4, 4, vmin, min) +VFUNC_FUNC(min, uchar, unsigned char, uchar4, 4, vmin, min) +VFUNC_FUNC(min, short, short, short4, 4, vmin, min) +VFUNC_FUNC(min, ushort, unsigned short, ushort4, 4, vmin, min) +VFUNC_FUNC(min, int, int, int4, 4, vmin, min) +VFUNC_FUNC(min, uint, unsigned int, uint4, 4, vmin, min) + +#undef current_func +#define current_func(a, b) min(a, b) +FUNC_FUNC(min, long, long) +FUNC_FUNC(min, ulong, unsigned long) +FUNC_FUNC(min, longlong, long long) +FUNC_FUNC(min, ulonglong, unsigned long long) +OPV_DISPATCH(min, int8_t, int8_t) +OPV_DISPATCH(min, uint8_t, uint8_t) +OPV_DISPATCH(min, int16_t, int16_t) +OPV_DISPATCH(min, uint16_t, uint16_t) +OPV_DISPATCH(min, int32_t, int32_t) +OPV_DISPATCH(min, uint32_t, uint32_t) +OPV_DISPATCH(min, int64_t, int64_t) +OPV_DISPATCH(min, uint64_t, uint64_t) + +#if !defined(DO_NOT_USE_INTRINSICS) +#undef current_func +#define current_func(a, b) fminf(a, b) +#endif // DO_NOT_USE_INTRINSICS +FUNC_FUNC(min, float, float) + +#if !defined(DO_NOT_USE_INTRINSICS) +#undef current_func +#define current_func(a, b) fmin(a, b) +#endif // DO_NOT_USE_INTRINSICS +FUNC_FUNC(min, double, double) + +#undef current_func +#define current_func(a, b) ((a) < (b) ? 
(a) : (b)) +FUNC_FUNC(min, long_double, long double) + + +/************************************************************************* + * Sum + *************************************************************************/ + +/* C integer */ +VFUNC_FUNC(sum, char, char, char4, 4, vsum, tsum) +VFUNC_FUNC(sum, uchar, unsigned char, uchar4, 4, vsum, tsum) +VFUNC_FUNC(sum, short, short, short4, 4, vsum, tsum) +VFUNC_FUNC(sum, ushort, unsigned short, ushort4, 4, vsum, tsum) +VFUNC_FUNC(sum, int, int, int4, 4, vsum, tsum) +VFUNC_FUNC(sum, uint, unsigned int, uint4, 4, vsum, tsum) + +#undef current_func +#define current_func(a, b) tsum(a, b) +FUNC_FUNC(sum, long, long) +FUNC_FUNC(sum, ulong, unsigned long) +FUNC_FUNC(sum, longlong, long long) +FUNC_FUNC(sum, ulonglong, unsigned long long) + +OPV_DISPATCH(sum, int8_t, int8_t) +OPV_DISPATCH(sum, uint8_t, uint8_t) +OPV_DISPATCH(sum, int16_t, int16_t) +OPV_DISPATCH(sum, uint16_t, uint16_t) +OPV_DISPATCH(sum, int32_t, int32_t) +OPV_DISPATCH(sum, uint32_t, uint32_t) +OPV_DISPATCH(sum, int64_t, int64_t) +OPV_DISPATCH(sum, uint64_t, uint64_t) + +OPV_FUNC(sum, float, float, float4, 4, +) +OPV_FUNC(sum, double, double, double4, 4, +) +OP_FUNC(sum, long_double, long double, +) + +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC(sum, c_short_float_complex, short float _Complex, +=) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_SUM_FUNC(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC(sum, c_long_double_complex, cuLongDoubleComplex, +=) +#endif // 0 +#undef current_func +#define current_func(a, b) (hipCaddf(a,b)) +FUNC_FUNC(sum, c_float_complex, hipFloatComplex) +#undef current_func +#define current_func(a, b) (hipCadd(a,b)) +FUNC_FUNC(sum, c_double_complex, hipDoubleComplex) + +/************************************************************************* + * Product + *************************************************************************/ + +/* C integer */ +#undef current_func +#define 
current_func(a, b) tprod(a, b) +FUNC_FUNC(prod, char, char) +FUNC_FUNC(prod, uchar, unsigned char) +FUNC_FUNC(prod, short, short) +FUNC_FUNC(prod, ushort, unsigned short) +FUNC_FUNC(prod, int, int) +FUNC_FUNC(prod, uint, unsigned int) +FUNC_FUNC(prod, long, long) +FUNC_FUNC(prod, ulong, unsigned long) +FUNC_FUNC(prod, longlong, long long) +FUNC_FUNC(prod, ulonglong, unsigned long long) + +OPV_DISPATCH(prod, int8_t, int8_t) +OPV_DISPATCH(prod, uint8_t, uint8_t) +OPV_DISPATCH(prod, int16_t, int16_t) +OPV_DISPATCH(prod, uint16_t, uint16_t) +OPV_DISPATCH(prod, int32_t, int32_t) +OPV_DISPATCH(prod, uint32_t, uint32_t) +OPV_DISPATCH(prod, int64_t, int64_t) +OPV_DISPATCH(prod, uint64_t, uint64_t) + + +OPV_FUNC(prod, float, float, float4, 4, *) +OPV_FUNC(prod, double, double, double4, 4, *) +OP_FUNC(prod, long_double, long double, *) + +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC(sum, c_short_float_complex, short float _Complex, +=) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_SUM_FUNC(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC(sum, c_long_double_complex, cuLongDoubleComplex, +=) +#endif // 0 +#undef current_func +#define current_func(a, b) (hipCmulf(a,b)) +FUNC_FUNC(prod, c_float_complex, hipFloatComplex) +#undef current_func +#define current_func(a, b) (hipCmul(a,b)) +FUNC_FUNC(prod, c_double_complex, hipDoubleComplex) + +/************************************************************************* + * Logical AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) && (b)) +/* C integer */ +FUNC_FUNC(land, int8_t, int8_t) +FUNC_FUNC(land, uint8_t, uint8_t) +FUNC_FUNC(land, int16_t, int16_t) +FUNC_FUNC(land, uint16_t, uint16_t) +FUNC_FUNC(land, int32_t, int32_t) +FUNC_FUNC(land, uint32_t, uint32_t) +FUNC_FUNC(land, int64_t, int64_t) +FUNC_FUNC(land, uint64_t, uint64_t) +FUNC_FUNC(land, long, long) +FUNC_FUNC(land, ulong, unsigned long) + +/* 
C++ bool */ +FUNC_FUNC(land, bool, bool) + +/************************************************************************* + * Logical OR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) || (b)) +/* C integer */ +FUNC_FUNC(lor, int8_t, int8_t) +FUNC_FUNC(lor, uint8_t, uint8_t) +FUNC_FUNC(lor, int16_t, int16_t) +FUNC_FUNC(lor, uint16_t, uint16_t) +FUNC_FUNC(lor, int32_t, int32_t) +FUNC_FUNC(lor, uint32_t, uint32_t) +FUNC_FUNC(lor, int64_t, int64_t) +FUNC_FUNC(lor, uint64_t, uint64_t) +FUNC_FUNC(lor, long, long) +FUNC_FUNC(lor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC(lor, bool, bool) + +/************************************************************************* + * Logical XOR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a ? 1 : 0) ^ (b ? 1: 0)) +/* C integer */ +FUNC_FUNC(lxor, int8_t, int8_t) +FUNC_FUNC(lxor, uint8_t, uint8_t) +FUNC_FUNC(lxor, int16_t, int16_t) +FUNC_FUNC(lxor, uint16_t, uint16_t) +FUNC_FUNC(lxor, int32_t, int32_t) +FUNC_FUNC(lxor, uint32_t, uint32_t) +FUNC_FUNC(lxor, int64_t, int64_t) +FUNC_FUNC(lxor, uint64_t, uint64_t) +FUNC_FUNC(lxor, long, long) +FUNC_FUNC(lxor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC(lxor, bool, bool) + +/************************************************************************* + * Bitwise AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) & (b)) +/* C integer */ +FUNC_FUNC(band, int8_t, int8_t) +FUNC_FUNC(band, uint8_t, uint8_t) +FUNC_FUNC(band, int16_t, int16_t) +FUNC_FUNC(band, uint16_t, uint16_t) +FUNC_FUNC(band, int32_t, int32_t) +FUNC_FUNC(band, uint32_t, uint32_t) +FUNC_FUNC(band, int64_t, int64_t) +FUNC_FUNC(band, uint64_t, uint64_t) +FUNC_FUNC(band, long, long) +FUNC_FUNC(band, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC(band, byte, char) + 
+/************************************************************************* + * Bitwise OR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) | (b)) +/* C integer */ +FUNC_FUNC(bor, int8_t, int8_t) +FUNC_FUNC(bor, uint8_t, uint8_t) +FUNC_FUNC(bor, int16_t, int16_t) +FUNC_FUNC(bor, uint16_t, uint16_t) +FUNC_FUNC(bor, int32_t, int32_t) +FUNC_FUNC(bor, uint32_t, uint32_t) +FUNC_FUNC(bor, int64_t, int64_t) +FUNC_FUNC(bor, uint64_t, uint64_t) +FUNC_FUNC(bor, long, long) +FUNC_FUNC(bor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC(bor, byte, char) + +/************************************************************************* + * Bitwise XOR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) ^ (b)) +/* C integer */ +FUNC_FUNC(bxor, int8_t, int8_t) +FUNC_FUNC(bxor, uint8_t, uint8_t) +FUNC_FUNC(bxor, int16_t, int16_t) +FUNC_FUNC(bxor, uint16_t, uint16_t) +FUNC_FUNC(bxor, int32_t, int32_t) +FUNC_FUNC(bxor, uint32_t, uint32_t) +FUNC_FUNC(bxor, int64_t, int64_t) +FUNC_FUNC(bxor, uint64_t, uint64_t) +FUNC_FUNC(bxor, long, long) +FUNC_FUNC(bxor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC(bxor, byte, char) + +/************************************************************************* + * Max location + *************************************************************************/ + +LOC_FUNC(maxloc, float_int, >) +LOC_FUNC(maxloc, double_int, >) +LOC_FUNC(maxloc, long_int, >) +LOC_FUNC(maxloc, 2int, >) +LOC_FUNC(maxloc, short_int, >) +LOC_FUNC(maxloc, long_double_int, >) + +/* Fortran compat types */ +LOC_FUNC(maxloc, 2float, >) +LOC_FUNC(maxloc, 2double, >) +LOC_FUNC(maxloc, 2int8, >) +LOC_FUNC(maxloc, 2int16, >) +LOC_FUNC(maxloc, 2int32, >) +LOC_FUNC(maxloc, 2int64, >) + +/************************************************************************* + * Min location + 
*************************************************************************/
+
+LOC_FUNC(minloc, float_int, <)
+LOC_FUNC(minloc, double_int, <)
+LOC_FUNC(minloc, long_int, <)
+LOC_FUNC(minloc, 2int, <)
+LOC_FUNC(minloc, short_int, <)
+LOC_FUNC(minloc, long_double_int, <)
+
+/* Fortran compat types */
+LOC_FUNC(minloc, 2float, <)
+LOC_FUNC(minloc, 2double, <)
+LOC_FUNC(minloc, 2int8, <)
+LOC_FUNC(minloc, 2int16, <)
+LOC_FUNC(minloc, 2int32, <)
+LOC_FUNC(minloc, 2int64, <)
+
+
+/*
+ * Three buffer (2 input and 1 output) version of the reduction routines: generates a kernel computing
+ * out[i] = in1[i] op in2[i] plus a host-side _submit launcher. A count <= 0 is a no-op in _submit.
+ */
+#define OP_FUNC_3BUF(name, type_name, type, op) \
+    static __global__ void \
+    ompi_op_rocm_3buff_##name##_##type_name##_kernel(const type *__restrict__ in1, \
+                                                     const type *__restrict__ in2, \
+                                                     type *__restrict__ out, int n) { \
+        const int index = blockIdx.x * blockDim.x + threadIdx.x; \
+        const int stride = blockDim.x * gridDim.x; /* total number of launched threads */ \
+        for (int i = index; i < n; i += stride) { /* grid-stride loop: covers n even if blocks were capped */ \
+            out[i] = in1[i] op in2[i]; \
+        } \
+    } \
+    void ompi_op_rocm_3buff_##name##_##type_name##_submit(const type *in1, const type *in2, \
+                                                          type *out, int count, \
+                                                          int threads_per_block, \
+                                                          int max_blocks, \
+                                                          hipStream_t stream) { \
+        if (count <= 0) return; /* nothing to reduce; also avoids threads == 0 and the division below */ \
+        int threads = min(threads_per_block, count); \
+        int blocks = min((count + threads-1) / threads, max_blocks); /* ceil(count/threads), capped */ \
+        hipLaunchKernelGGL(ompi_op_rocm_3buff_##name##_##type_name##_kernel, \
+                           dim3(blocks), dim3(threads), 0, stream, in1, in2, out, count); \
+    }
+
+
+/*
+ * Since all the functions in this file are essentially identical, we
+ * use a macro to substitute in names and types. The core operation
+ * in all functions that use this macro is the same.
+ *
+ * This macro is for (out = current_func(in1, in2)); current_func is whatever is #defined at the point of use
+ */
+#define FUNC_FUNC_3BUF(name, type_name, type) \
+    static __global__ void \
+    ompi_op_rocm_3buff_##name##_##type_name##_kernel(const type *__restrict__ in1, \
+                                                     const type *__restrict__ in2, \
+                                                     type *__restrict__ out, int n) { \
+        const int index = blockIdx.x * blockDim.x + threadIdx.x; \
+        const int stride = blockDim.x * gridDim.x; /* total number of launched threads */ \
+        for (int i = index; i < n; i += stride) { /* grid-stride loop: covers n even if blocks were capped */ \
+            out[i] = current_func(in1[i], in2[i]); \
+        } \
+    } \
+    void \
+    ompi_op_rocm_3buff_##name##_##type_name##_submit(const type *in1, const type *in2, \
+                                                     type *out, int count, \
+                                                     int threads_per_block, \
+                                                     int max_blocks, \
+                                                     hipStream_t stream) { \
+        if (count <= 0) return; /* nothing to reduce; also avoids threads == 0 and the division below */ \
+        int threads = min(threads_per_block, count); \
+        int blocks = min((count + threads-1) / threads, max_blocks); /* ceil(count/threads), capped */ \
+        hipLaunchKernelGGL(ompi_op_rocm_3buff_##name##_##type_name##_kernel, \
+                           dim3(blocks), dim3(threads), 0, stream, in1, in2, out, count); \
+    }
+
+/*
+ * Since all the functions in this file are essentially identical, we
+ * use a macro to substitute in names and types. The core operation
+ * in all functions that use this macro is the same.
+ *
+ * This macro is for minloc and maxloc: out gets the winning (v, k) pair; equal values keep the smaller k
+ */
+/*
+#define LOC_STRUCT(type_name, type1, type2) \
+  typedef struct { \
+      type1 v; \
+      type2 k; \
+  } ompi_op_predefined_##type_name##_t;
+*/
+
+#define LOC_FUNC_3BUF(name, type_name, op) \
+    static __global__ void \
+    ompi_op_rocm_3buff_##name##_##type_name##_kernel(const ompi_op_predefined_##type_name##_t *__restrict__ in1, \
+                                                     const ompi_op_predefined_##type_name##_t *__restrict__ in2, \
+                                                     ompi_op_predefined_##type_name##_t *__restrict__ out, \
+                                                     int n) \
+    { \
+        const int index = blockIdx.x * blockDim.x + threadIdx.x; \
+        const int stride = blockDim.x * gridDim.x; /* total number of launched threads */ \
+        for (int i = index; i < n; i += stride) { /* grid-stride loop: covers n even if blocks were capped */ \
+            const ompi_op_predefined_##type_name##_t *a1 = &in1[i]; \
+            const ompi_op_predefined_##type_name##_t *a2 = &in2[i]; \
+            ompi_op_predefined_##type_name##_t *b = &out[i]; \
+            if (a1->v op a2->v) { /* first operand wins outright */ \
+                b->v = a1->v; \
+                b->k = a1->k; \
+            } else if (a1->v == a2->v) { /* tie on value: keep the smaller index */ \
+                b->v = a1->v; \
+                b->k = (a2->k < a1->k ? a2->k : a1->k); \
+            } else { /* second operand wins */ \
+                b->v = a2->v; \
+                b->k = a2->k; \
+            } \
+        } \
+    } \
+    void \
+    ompi_op_rocm_3buff_##name##_##type_name##_submit(const ompi_op_predefined_##type_name##_t *__restrict__ in1, \
+                                                     const ompi_op_predefined_##type_name##_t *__restrict__ in2, \
+                                                     ompi_op_predefined_##type_name##_t *__restrict__ out, \
+                                                     int count, \
+                                                     int threads_per_block, \
+                                                     int max_blocks, \
+                                                     hipStream_t stream) \
+    { \
+        if (count <= 0) return; /* nothing to reduce; also avoids threads == 0 and the division below */ \
+        int threads = min(threads_per_block, count); \
+        int blocks = min((count + threads-1) / threads, max_blocks); /* ceil(count/threads), capped */ \
+        hipLaunchKernelGGL(ompi_op_rocm_3buff_##name##_##type_name##_kernel, \
+                           dim3(blocks), dim3(threads), 0, stream, in1, in2, out, count); \
+    }
+
+
+/*************************************************************************
+ * Max
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) > (b) ?
(a) : (b)) +/* C integer */ +FUNC_FUNC_3BUF(max, int8_t, int8_t) +FUNC_FUNC_3BUF(max, uint8_t, uint8_t) +FUNC_FUNC_3BUF(max, int16_t, int16_t) +FUNC_FUNC_3BUF(max, uint16_t, uint16_t) +FUNC_FUNC_3BUF(max, int32_t, int32_t) +FUNC_FUNC_3BUF(max, uint32_t, uint32_t) +FUNC_FUNC_3BUF(max, int64_t, int64_t) +FUNC_FUNC_3BUF(max, uint64_t, uint64_t) +FUNC_FUNC_3BUF(max, long, long) +FUNC_FUNC_3BUF(max, ulong, unsigned long) + +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC_3BUF(max, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_3BUF(max, short_float, opal_short_float_t) +#endif +FUNC_FUNC_3BUF(max, float, float) +FUNC_FUNC_3BUF(max, double, double) +FUNC_FUNC_3BUF(max, long_double, long double) + +/************************************************************************* + * Min + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) < (b) ? (a) : (b)) +/* C integer */ +FUNC_FUNC_3BUF(min, int8_t, int8_t) +FUNC_FUNC_3BUF(min, uint8_t, uint8_t) +FUNC_FUNC_3BUF(min, int16_t, int16_t) +FUNC_FUNC_3BUF(min, uint16_t, uint16_t) +FUNC_FUNC_3BUF(min, int32_t, int32_t) +FUNC_FUNC_3BUF(min, uint32_t, uint32_t) +FUNC_FUNC_3BUF(min, int64_t, int64_t) +FUNC_FUNC_3BUF(min, uint64_t, uint64_t) +FUNC_FUNC_3BUF(min, long, long) +FUNC_FUNC_3BUF(min, ulong, unsigned long) + +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC_3BUF(min, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_3BUF(min, short_float, opal_short_float_t) +#endif +FUNC_FUNC_3BUF(min, float, float) +FUNC_FUNC_3BUF(min, double, double) +FUNC_FUNC_3BUF(min, long_double, long double) + +/************************************************************************* + * Sum + *************************************************************************/ + +/* C integer */ +OP_FUNC_3BUF(sum, int8_t, int8_t, +) +OP_FUNC_3BUF(sum, uint8_t, uint8_t, +) +OP_FUNC_3BUF(sum, 
int16_t, int16_t, +) +OP_FUNC_3BUF(sum, uint16_t, uint16_t, +) +OP_FUNC_3BUF(sum, int32_t, int32_t, +) +OP_FUNC_3BUF(sum, uint32_t, uint32_t, +) +OP_FUNC_3BUF(sum, int64_t, int64_t, +) +OP_FUNC_3BUF(sum, uint64_t, uint64_t, +) +OP_FUNC_3BUF(sum, long, long, +) +OP_FUNC_3BUF(sum, ulong, unsigned long, +) + +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC_3BUF(sum, short_float, short float, +) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +OP_FUNC_3BUF(sum, short_float, opal_short_float_t, +) +#endif +OP_FUNC_3BUF(sum, float, float, +) +OP_FUNC_3BUF(sum, double, double, +) +OP_FUNC_3BUF(sum, long_double, long double, +) + +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF(sum, c_short_float_complex, short float _Complex, +) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_SUM_FUNC_3BUF(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC_3BUF(sum, c_long_double_complex, cuLongDoubleComplex, +) +#endif // 0 +#undef current_func +#define current_func(a, b) (hipCaddf(a,b)) +FUNC_FUNC_3BUF(sum, c_float_complex, hipFloatComplex) +#undef current_func +#define current_func(a, b) (hipCadd(a,b)) +FUNC_FUNC_3BUF(sum, c_double_complex, hipDoubleComplex) + +/************************************************************************* + * Product + *************************************************************************/ + +/* C integer */ +OP_FUNC_3BUF(prod, int8_t, int8_t, *) +OP_FUNC_3BUF(prod, uint8_t, uint8_t, *) +OP_FUNC_3BUF(prod, int16_t, int16_t, *) +OP_FUNC_3BUF(prod, uint16_t, uint16_t, *) +OP_FUNC_3BUF(prod, int32_t, int32_t, *) +OP_FUNC_3BUF(prod, uint32_t, uint32_t, *) +OP_FUNC_3BUF(prod, int64_t, int64_t, *) +OP_FUNC_3BUF(prod, uint64_t, uint64_t, *) +OP_FUNC_3BUF(prod, long, long, *) +OP_FUNC_3BUF(prod, ulong, unsigned long, *) + +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF(prod, c_short_float_complex, short float _Complex, *) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) 
+COMPLEX_PROD_FUNC_3BUF(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC_3BUF(prod, c_long_double_complex, long double _Complex, *) +#endif // 0 +#undef current_func +#define current_func(a, b) (hipCmulf(a,b)) +FUNC_FUNC_3BUF(prod, c_float_complex, hipFloatComplex) +#undef current_func +#define current_func(a, b) (hipCmul(a,b)) +FUNC_FUNC_3BUF(prod, c_double_complex, hipDoubleComplex) + +/************************************************************************* + * Logical AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) && (b)) +/* C integer */ +FUNC_FUNC_3BUF(land, int8_t, int8_t) +FUNC_FUNC_3BUF(land, uint8_t, uint8_t) +FUNC_FUNC_3BUF(land, int16_t, int16_t) +FUNC_FUNC_3BUF(land, uint16_t, uint16_t) +FUNC_FUNC_3BUF(land, int32_t, int32_t) +FUNC_FUNC_3BUF(land, uint32_t, uint32_t) +FUNC_FUNC_3BUF(land, int64_t, int64_t) +FUNC_FUNC_3BUF(land, uint64_t, uint64_t) +FUNC_FUNC_3BUF(land, long, long) +FUNC_FUNC_3BUF(land, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_3BUF(land, bool, bool) + +/************************************************************************* + * Logical OR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) || (b)) +/* C integer */ +FUNC_FUNC_3BUF(lor, int8_t, int8_t) +FUNC_FUNC_3BUF(lor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(lor, int16_t, int16_t) +FUNC_FUNC_3BUF(lor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(lor, int32_t, int32_t) +FUNC_FUNC_3BUF(lor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(lor, int64_t, int64_t) +FUNC_FUNC_3BUF(lor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(lor, long, long) +FUNC_FUNC_3BUF(lor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_3BUF(lor, bool, bool) + +/************************************************************************* + * Logical XOR + *************************************************************************/ + +#undef current_func 
+#define current_func(a, b) ((a ? 1 : 0) ^ (b ? 1: 0)) +/* C integer */ +FUNC_FUNC_3BUF(lxor, int8_t, int8_t) +FUNC_FUNC_3BUF(lxor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(lxor, int16_t, int16_t) +FUNC_FUNC_3BUF(lxor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(lxor, int32_t, int32_t) +FUNC_FUNC_3BUF(lxor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(lxor, int64_t, int64_t) +FUNC_FUNC_3BUF(lxor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(lxor, long, long) +FUNC_FUNC_3BUF(lxor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_3BUF(lxor, bool, bool) + +/************************************************************************* + * Bitwise AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) & (b)) +/* C integer */ +FUNC_FUNC_3BUF(band, int8_t, int8_t) +FUNC_FUNC_3BUF(band, uint8_t, uint8_t) +FUNC_FUNC_3BUF(band, int16_t, int16_t) +FUNC_FUNC_3BUF(band, uint16_t, uint16_t) +FUNC_FUNC_3BUF(band, int32_t, int32_t) +FUNC_FUNC_3BUF(band, uint32_t, uint32_t) +FUNC_FUNC_3BUF(band, int64_t, int64_t) +FUNC_FUNC_3BUF(band, uint64_t, uint64_t) +FUNC_FUNC_3BUF(band, long, long) +FUNC_FUNC_3BUF(band, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_3BUF(band, byte, char) + +/************************************************************************* + * Bitwise OR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) | (b)) +/* C integer */ +FUNC_FUNC_3BUF(bor, int8_t, int8_t) +FUNC_FUNC_3BUF(bor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(bor, int16_t, int16_t) +FUNC_FUNC_3BUF(bor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(bor, int32_t, int32_t) +FUNC_FUNC_3BUF(bor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(bor, int64_t, int64_t) +FUNC_FUNC_3BUF(bor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(bor, long, long) +FUNC_FUNC_3BUF(bor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_3BUF(bor, byte, char) + +/************************************************************************* + * 
Bitwise XOR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) ^ (b)) +/* C integer */ +FUNC_FUNC_3BUF(bxor, int8_t, int8_t) +FUNC_FUNC_3BUF(bxor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(bxor, int16_t, int16_t) +FUNC_FUNC_3BUF(bxor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(bxor, int32_t, int32_t) +FUNC_FUNC_3BUF(bxor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(bxor, int64_t, int64_t) +FUNC_FUNC_3BUF(bxor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(bxor, long, long) +FUNC_FUNC_3BUF(bxor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_3BUF(bxor, byte, char) + +/************************************************************************* + * Max location + *************************************************************************/ + +LOC_FUNC_3BUF(maxloc, float_int, >) +LOC_FUNC_3BUF(maxloc, double_int, >) +LOC_FUNC_3BUF(maxloc, long_int, >) +LOC_FUNC_3BUF(maxloc, 2int, >) +LOC_FUNC_3BUF(maxloc, short_int, >) +LOC_FUNC_3BUF(maxloc, long_double_int, >) + +/* Fortran compat types */ +LOC_FUNC_3BUF(maxloc, 2float, >) +LOC_FUNC_3BUF(maxloc, 2double, >) +LOC_FUNC_3BUF(maxloc, 2int8, >) +LOC_FUNC_3BUF(maxloc, 2int16, >) +LOC_FUNC_3BUF(maxloc, 2int32, >) +LOC_FUNC_3BUF(maxloc, 2int64, >) + +/************************************************************************* + * Min location + *************************************************************************/ + +LOC_FUNC_3BUF(minloc, float_int, <) +LOC_FUNC_3BUF(minloc, double_int, <) +LOC_FUNC_3BUF(minloc, long_int, <) +LOC_FUNC_3BUF(minloc, 2int, <) +LOC_FUNC_3BUF(minloc, short_int, <) +LOC_FUNC_3BUF(minloc, long_double_int, <) + +/* Fortran compat types */ +LOC_FUNC_3BUF(minloc, 2float, <) +LOC_FUNC_3BUF(minloc, 2double, <) +LOC_FUNC_3BUF(minloc, 2int8, <) +LOC_FUNC_3BUF(minloc, 2int16, <) +LOC_FUNC_3BUF(minloc, 2int32, <) +LOC_FUNC_3BUF(minloc, 2int64, <) diff --git a/ompi/op/Makefile.am b/ompi/op/Makefile.am index 5599c31311b..f0ba89c5618 100644 --- a/ompi/op/Makefile.am 
+++ b/ompi/op/Makefile.am @@ -22,6 +22,8 @@ # This makefile.am does not stand on its own - it is included from # ompi/Makefile.am +dist_ompidata_DATA += op/help-ompi-op.txt + headers += op/op.h lib@OMPI_LIBMPI_NAME@_la_SOURCES += op/op.c diff --git a/ompi/op/help-ompi-op.txt b/ompi/op/help-ompi-op.txt new file mode 100644 index 00000000000..5cfb60b8f9f --- /dev/null +++ b/ompi/op/help-ompi-op.txt @@ -0,0 +1,15 @@ +# -*- text -*- +# +# Copyright (c) 2004-2023 The University of Tennessee and The University +# of Tennessee Research Foundation. All rights +# reserved. +# $COPYRIGHT$ +# +# Additional copyrights may follow +# +# $HEADER$ +# +# This is the US/English help file for Open MPI's MPI_Op (reduction operation) support +# +[missing implementation] +ERROR: No suitable module for op %s on type %s found for device memory! diff --git a/ompi/op/op.c b/ompi/op/op.c index 3977fa8b97b..a75d6b33d5b 100644 --- a/ompi/op/op.c +++ b/ompi/op/op.c @@ -475,6 +475,7 @@ static void ompi_op_construct(ompi_op_t *new_op) new_op->o_3buff_intrinsic.fns[i] = NULL; new_op->o_3buff_intrinsic.modules[i] = NULL; } + new_op->o_device_op = NULL; } @@ -506,4 +507,19 @@ static void ompi_op_destruct(ompi_op_t *op) op->o_3buff_intrinsic.modules[i] = NULL; } } + + if (op->o_device_op != NULL) { + for (i = 0; i < OMPI_OP_BASE_TYPE_MAX; ++i) { + if( NULL != op->o_device_op->do_intrinsic.modules[i] ) { + OBJ_RELEASE(op->o_device_op->do_intrinsic.modules[i]); + op->o_device_op->do_intrinsic.modules[i] = NULL; + } + if( NULL != op->o_device_op->do_3buff_intrinsic.modules[i] ) { + OBJ_RELEASE(op->o_device_op->do_3buff_intrinsic.modules[i]); + op->o_device_op->do_3buff_intrinsic.modules[i] = NULL; + } + } + free(op->o_device_op); + op->o_device_op = NULL; + } } diff --git a/ompi/op/op.h b/ompi/op/op.h index f3cf5b53636..05a4c0c89e3 100644 --- a/ompi/op/op.h +++ b/ompi/op/op.h @@ -3,7 +3,7 @@ * Copyright (c) 2004-2006 The Trustees of Indiana University and Indiana * University Research and Technology * Corporation.
All rights reserved. - * Copyright (c) 2004-2007 The University of Tennessee and The University + * Copyright (c) 2004-2023 The University of Tennessee and The University * of Tennessee Research Foundation. All rights * reserved. * Copyright (c) 2004-2007 High Performance Computing Center Stuttgart, @@ -44,6 +44,7 @@ #include "opal/class/opal_object.h" #include "opal/util/printf.h" +#include "opal/util/show_help.h" #include "ompi/datatype/ompi_datatype.h" #include "ompi/mpi/fortran/base/fint_2_int.h" @@ -122,6 +123,15 @@ enum ompi_op_type { OMPI_OP_REPLACE, OMPI_OP_NUM_OF_TYPES }; + +/* device op information */ +struct ompi_device_op_t { + opal_accelerator_stream_t *do_stream; + ompi_op_base_op_stream_fns_t do_intrinsic; + ompi_op_base_op_3buff_stream_fns_t do_3buff_intrinsic; +}; +typedef struct ompi_device_op_t ompi_device_op_t; + /** * Back-end type of MPI_Op */ @@ -167,6 +177,10 @@ struct ompi_op_t { /** 3-buffer functions, which is only for intrinsic ops. No need for the C/C++/Fortran user-defined functions. */ ompi_op_base_op_3buff_fns_t o_3buff_intrinsic; + + /** device functions, only for intrinsic ops. + Provided if device support is detected. */ + ompi_device_op_t *o_device_op; }; /** @@ -376,7 +390,7 @@ OMPI_DECLSPEC void ompi_op_set_java_callback(ompi_op_t *op, void *jnienv, * this function is provided to hide the internal structure field * names. */ -static inline bool ompi_op_is_intrinsic(ompi_op_t * op) +static inline bool ompi_op_is_intrinsic(const ompi_op_t * op) { return (bool) (0 != (op->o_flags & OMPI_OP_FLAGS_INTRINSIC)); } @@ -500,9 +514,11 @@ static inline bool ompi_op_is_valid(ompi_op_t * op, ompi_datatype_t * ddt, * optimization). If you give it an intrinsic op with a datatype that * is not defined to have that operation, it is likely to seg fault. 
*/ -static inline void ompi_op_reduce(ompi_op_t * op, const void *source, - void *target, size_t full_count, - ompi_datatype_t * dtype) +static inline void ompi_op_reduce_stream(ompi_op_t * op, const void *source, + void *target, size_t full_count, + ompi_datatype_t * dtype, + int device, + opal_accelerator_stream_t *stream) { MPI_Fint f_dtype, f_count; int count = full_count; @@ -531,7 +547,7 @@ static inline void ompi_op_reduce(ompi_op_t * op, const void *source, } shift = done_count * ext; // Recurse one level in iterations of 'int' - ompi_op_reduce(op, (const char*)source + shift, (char*)target + shift, iter_count, dtype); + ompi_op_reduce_stream(op, (char*)source + shift, (char*)target + shift, iter_count, dtype, device, stream); done_count += iter_count; } return; @@ -560,6 +576,44 @@ static inline void ompi_op_reduce(ompi_op_t * op, const void *source, * :-) */ + bool use_device_op = false; + /* check if either of the buffers is on a device and if so make sure we can + * handle it properly */ + if (device != MCA_ACCELERATOR_NO_DEVICE_ID && + ompi_datatype_is_predefined(dtype) && + 0 != (op->o_flags & OMPI_OP_FLAGS_INTRINSIC) && + NULL != op->o_device_op) { + use_device_op = true; + } + + if (!use_device_op) { + /* query the accelerator for whether we can still execute */ + int source_dev_id, target_dev_id; + uint64_t source_flags, target_flags; + int target_check_addr = opal_accelerator.check_addr(target, &target_dev_id, &target_flags); + int source_check_addr = opal_accelerator.check_addr(source, &source_dev_id, &source_flags); + if (target_check_addr > 0 && + source_check_addr > 0 && + ompi_datatype_is_predefined(dtype) && + 0 != (op->o_flags & OMPI_OP_FLAGS_INTRINSIC) && + NULL != op->o_device_op) { + use_device_op = true; + if (target_dev_id == source_dev_id) { + /* both inputs are on the same device; if not, the op will take care of that */ + device = target_dev_id; + } + } else { + /* check whether we can access the memory from the host */ + if
((target_check_addr == 0 || (target_flags & MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY)) && + (source_check_addr == 0 || (source_flags & MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY))) { + /* nothing to be done, we won't need device-capable ops */ + } else { + opal_show_help("help-ompi-op.txt", "missing implementation", true, op->o_name, dtype->name); + abort(); + } + } + } + /* For intrinsics, we also pass the corresponding op module */ if (0 != (op->o_flags & OMPI_OP_FLAGS_INTRINSIC)) { int dtype_id; @@ -569,9 +623,28 @@ static inline void ompi_op_reduce(ompi_op_t * op, const void *source, } else { dtype_id = ompi_op_ddt_map[dtype->id]; } - op->o_func.intrinsic.fns[dtype_id](source, target, - &count, &dtype, - op->o_func.intrinsic.modules[dtype_id]); + if (use_device_op) { + if (NULL == op->o_device_op) { + fprintf(stderr, "no suitable device op module found!"); + abort(); // TODO: be more graceful! + } + opal_accelerator_stream_t *actual_stream = stream; + bool flush_stream = false; + if (NULL == stream) { + actual_stream = MCA_ACCELERATOR_STREAM_DEFAULT; + flush_stream = true; + } + op->o_device_op->do_intrinsic.fns[dtype_id]((void*)source, target, + &count, &dtype, device, actual_stream, + op->o_device_op->do_intrinsic.modules[dtype_id]); + if (flush_stream) { + opal_accelerator.sync_stream(actual_stream); + } + } else { + op->o_func.intrinsic.fns[dtype_id]((void*)source, target, + &count, &dtype, + op->o_func.intrinsic.modules[dtype_id]); + } return; } @@ -579,24 +652,31 @@ static inline void ompi_op_reduce(ompi_op_t * op, const void *source, if (0 != (op->o_flags & OMPI_OP_FLAGS_FORTRAN_FUNC)) { f_dtype = OMPI_INT_2_FINT(dtype->d_f_to_c_index); f_count = OMPI_INT_2_FINT(count); - op->o_func.fort_fn(source, target, &f_count, &f_dtype); + op->o_func.fort_fn((void*)source, target, &f_count, &f_dtype); return; } else if (0 != (op->o_flags & OMPI_OP_FLAGS_JAVA_FUNC)) { - op->o_func.java_data.intercept_fn(source, target, &count, &dtype, + 
op->o_func.java_data.intercept_fn((void*)source, target, &count, &dtype, op->o_func.java_data.baseType, op->o_func.java_data.jnienv, op->o_func.java_data.object); return; } - op->o_func.c_fn(source, target, &count, &dtype); + op->o_func.c_fn((void*)source, target, &count, &dtype); return; } -static inline void ompi_3buff_op_user (ompi_op_t *op, void * restrict source1, void * restrict source2, +static inline void ompi_op_reduce(ompi_op_t * op, const void *source, + void *target, size_t full_count, + ompi_datatype_t * dtype) +{ + ompi_op_reduce_stream(op, source, target, full_count, dtype, MCA_ACCELERATOR_NO_DEVICE_ID, NULL); +} + +static inline void ompi_3buff_op_user (ompi_op_t *op, const void * source1, const void * source2, void * restrict result, int count, struct ompi_datatype_t *dtype) { - ompi_datatype_copy_content_same_ddt (dtype, count, (char*)result, (char*)source1); - op->o_func.c_fn (source2, result, &count, &dtype); + ompi_datatype_copy_content_same_ddt (dtype, count, result, (void*)source1); + op->o_func.c_fn ((void*)source2, result, &count, &dtype); } /** @@ -622,24 +702,135 @@ static inline void ompi_3buff_op_user (ompi_op_t *op, void * restrict source1, v * * Otherwise, this function is the same as ompi_op_reduce. 
*/ -static inline void ompi_3buff_op_reduce(ompi_op_t * op, void *source1, - void *source2, void *target, - int count, ompi_datatype_t * dtype) +static inline void ompi_3buff_op_reduce_stream(ompi_op_t * op, const void *source1, + const void *source2, void *target, + int count, ompi_datatype_t * dtype, + int device, + opal_accelerator_stream_t *stream) { - void *restrict src1; - void *restrict src2; - void *restrict tgt; - src1 = source1; - src2 = source2; - tgt = target; + bool use_device_op = false; + if (OPAL_UNLIKELY(!ompi_op_is_intrinsic (op))) { + /* no 3buff variants for user-defined ops */ + ompi_3buff_op_user (op, source1, source2, target, count, dtype); + return; + } + + if (device != MCA_ACCELERATOR_NO_DEVICE_ID && + ompi_datatype_is_predefined(dtype) && + op->o_flags & OMPI_OP_FLAGS_INTRINSIC && + NULL != op->o_device_op) { + use_device_op = true; + } + if (!use_device_op) { + int source1_dev_id, source2_dev_id, target_dev_id; + uint64_t source1_flags, source2_flags, target_flags; + int target_check_addr = opal_accelerator.check_addr(target, &target_dev_id, &target_flags); + int source1_check_addr = opal_accelerator.check_addr(source1, &source1_dev_id, &source1_flags); + int source2_check_addr = opal_accelerator.check_addr(source2, &source2_dev_id, &source2_flags); + /* check if either of the buffers is on a device and if so make sure we can + * access handle it properly */ + if (target_check_addr > 0 || source1_check_addr > 0 || source2_check_addr > 0) { + if (ompi_datatype_is_predefined(dtype) && + op->o_flags & OMPI_OP_FLAGS_INTRINSIC && + NULL != op->o_device_op) { + use_device_op = true; + device = target_dev_id; + } else { + /* check whether we can access the memory from the host */ + if ((target_check_addr == 0 || (target_flags & MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY)) && + (source1_check_addr == 0 || (source1_flags & MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY)) && + (source2_check_addr == 0 || (source2_flags & MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY))) { 
+ /* nothing to be done, we won't need device-capable ops */ + } else { + fprintf(stderr, "3buff op: no suitable op module found for device memory!\n"); + abort(); + } + } + } + } + + /* For intrinsics, we also pass the corresponding op module */ + if (0 != (op->o_flags & OMPI_OP_FLAGS_INTRINSIC)) { + int dtype_id; + if (!ompi_datatype_is_predefined(dtype)) { + ompi_datatype_t *dt = ompi_datatype_get_single_predefined_type_from_args(dtype); + dtype_id = ompi_op_ddt_map[dt->id]; + } else { + dtype_id = ompi_op_ddt_map[dtype->id]; + } + if (use_device_op) { + opal_accelerator_stream_t *actual_stream = stream; + bool flush_stream = false; + if (NULL == stream) { + actual_stream = MCA_ACCELERATOR_STREAM_DEFAULT; + flush_stream = true; + } + op->o_device_op->do_3buff_intrinsic.fns[dtype_id]((void*)source1, (void*)source2, target, + &count, &dtype, device, actual_stream, + op->o_device_op->do_3buff_intrinsic.modules[dtype_id]); + if (flush_stream) { + opal_accelerator.sync_stream(actual_stream); + } + } else { + op->o_3buff_intrinsic.fns[dtype_id]((void*)source1, (void*)source2, target, + &count, &dtype, + op->o_func.intrinsic.modules[dtype_id]); + } + } +} + + +static inline void ompi_3buff_op_reduce(ompi_op_t * op, const void *source1, + const void *source2, void *target, + int count, ompi_datatype_t * dtype) +{ if (OPAL_LIKELY(ompi_op_is_intrinsic (op))) { - op->o_3buff_intrinsic.fns[ompi_op_ddt_map[dtype->id]](src1, src2, - tgt, &count, - &dtype, - op->o_3buff_intrinsic.modules[ompi_op_ddt_map[dtype->id]]); + ompi_3buff_op_reduce_stream(op, source1, source2, target, count, dtype, MCA_ACCELERATOR_NO_DEVICE_ID, NULL); } else { - ompi_3buff_op_user (op, src1, src2, tgt, count, dtype); + ompi_3buff_op_user (op, source1, source2, target, count, dtype); + } +} + +static inline void ompi_op_preferred_device(ompi_op_t *op, int source_dev, + int target_dev, size_t count, + ompi_datatype_t *dtype, int *op_device) +{ + /* default to host */ + *op_device = -1; + if 
(!ompi_op_is_intrinsic (op)) { + return; + } + /* quick check: can we execute on the device? */ + int dtype_id = ompi_op_ddt_map[dtype->id]; + if (NULL == op->o_device_op || NULL == op->o_device_op->do_intrinsic.fns[dtype_id]) { + /* not available on the gpu, must select host */ + return; + } + + size_t size_type; + ompi_datatype_type_size(dtype, &size_type); + + float device_bw; + if (target_dev >= 0) { + opal_accelerator.get_mem_bw(target_dev, &device_bw); + } else if (source_dev >= 0) { + opal_accelerator.get_mem_bw(source_dev, &device_bw); + } + + // assume we reach 50% of theoretical peak on the device + device_bw /= 2.0; + + // TODO: determine at runtime (?) + const float host_bw = 10.0; // 10GB/s + + float host_startup_cost = 0.0; // host has no startup cost + float host_compute_cost = (count*size_type) / (host_bw*1024); // assume 10GB/s memory bandwidth on host + float device_startup_cost = 10.0; // 10us startup cost on device + float device_compute_cost = (count*size_type) / (device_bw*1024); + + if ((host_startup_cost + host_compute_cost) > (device_startup_cost + device_compute_cost)) { + *op_device = (target_dev >= 0) ? target_dev : source_dev; } } From 13aeecf58aa902cc2e4bc2de67bb135793def936 Mon Sep 17 00:00:00 2001 From: Joseph Schuchart Date: Wed, 19 Jun 2024 16:32:37 -0400 Subject: [PATCH 02/12] Build op/cuda and op/rocm as dso by default Signed-off-by: Joseph Schuchart --- config/opal_mca.m4 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/opal_mca.m4 b/config/opal_mca.m4 index b425fe63bf2..76ff8f4222f 100644 --- a/config/opal_mca.m4 +++ b/config/opal_mca.m4 @@ -186,7 +186,7 @@ of type-component pairs. 
For example, --enable-mca-no-build=pml-ob1]) else msg= if test -z "$enable_mca_dso"; then - enable_mca_dso="accelerator-cuda,accelerator-rocm,accelerator-ze,btl-smcuda,rcache-gpusm,rcache-rgpusm" + enable_mca_dso="accelerator-cuda,accelerator-rocm,accelerator-ze,btl-smcuda,rcache-gpusm,rcache-rgpusm,op-cuda,op-rocm" msg="(default)" fi DSO_all=0 From bc5c3a1599611ef3f23b9aecdc071add2dbe6d1c Mon Sep 17 00:00:00 2001 From: Joseph Schuchart Date: Wed, 19 Jun 2024 19:22:18 -0400 Subject: [PATCH 03/12] Remove DECLSPEC from internal functions Signed-off-by: Joseph Schuchart --- ompi/mca/op/cuda/op_cuda.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ompi/mca/op/cuda/op_cuda.h b/ompi/mca/op/cuda/op_cuda.h index ab349d48ee4..11417b28550 100644 --- a/ompi/mca/op/cuda/op_cuda.h +++ b/ompi/mca/op/cuda/op_cuda.h @@ -69,10 +69,10 @@ typedef struct { OMPI_DECLSPEC extern ompi_op_cuda_component_t mca_op_cuda_component; -OMPI_DECLSPEC extern +extern ompi_op_base_stream_handler_fn_t ompi_op_cuda_functions[OMPI_OP_BASE_FORTRAN_OP_MAX][OMPI_OP_BASE_TYPE_MAX]; -OMPI_DECLSPEC extern +extern ompi_op_base_3buff_stream_handler_fn_t ompi_op_cuda_3buff_functions[OMPI_OP_BASE_FORTRAN_OP_MAX][OMPI_OP_BASE_TYPE_MAX]; END_C_DECLS From c2c5aec65a91f8cd8d955a999c8d6226697557f1 Mon Sep 17 00:00:00 2001 From: Joseph Schuchart Date: Wed, 19 Jun 2024 19:24:05 -0400 Subject: [PATCH 04/12] op/cuda: Lazily initialize the CUDA information Signed-off-by: Joseph Schuchart --- ompi/mca/op/cuda/op_cuda.h | 2 + ompi/mca/op/cuda/op_cuda_component.c | 94 +++++++++++++++++----------- ompi/mca/op/cuda/op_cuda_functions.c | 2 + 3 files changed, 60 insertions(+), 38 deletions(-) diff --git a/ompi/mca/op/cuda/op_cuda.h b/ompi/mca/op/cuda/op_cuda.h index 11417b28550..a88fb49c0ef 100644 --- a/ompi/mca/op/cuda/op_cuda.h +++ b/ompi/mca/op/cuda/op_cuda.h @@ -75,6 +75,8 @@ ompi_op_base_stream_handler_fn_t ompi_op_cuda_functions[OMPI_OP_BASE_FORTRAN_OP_ extern ompi_op_base_3buff_stream_handler_fn_t 
ompi_op_cuda_3buff_functions[OMPI_OP_BASE_FORTRAN_OP_MAX][OMPI_OP_BASE_TYPE_MAX]; +void ompi_op_cuda_lazy_init(); + END_C_DECLS #endif /* MCA_OP_CUDA_EXPORT_H */ diff --git a/ompi/mca/op/cuda/op_cuda_component.c b/ompi/mca/op/cuda/op_cuda_component.c index 3ead710bd1d..9d36bdd52df 100644 --- a/ompi/mca/op/cuda/op_cuda_component.c +++ b/ompi/mca/op/cuda/op_cuda_component.c @@ -38,6 +38,9 @@ static struct ompi_op_base_module_1_0_0_t * cuda_component_op_query(struct ompi_op_t *op, int *priority); static int cuda_component_register(void); +static opal_mutex_t init_lock = OPAL_MUTEX_STATIC_INIT; +static bool init_complete = false; + ompi_op_cuda_component_t mca_op_cuda_component = { { .opc_version = { @@ -128,44 +131,6 @@ static int cuda_component_init_query(bool enable_progress_threads, bool enable_mpi_thread_multiple) { - int num_devices; - int rc; - // TODO: is this init needed here? - cuInit(0); - CHECK(cuDeviceGetCount, (&num_devices)); - mca_op_cuda_component.cu_num_devices = num_devices; - mca_op_cuda_component.cu_devices = (CUdevice*)malloc(num_devices*sizeof(CUdevice)); - mca_op_cuda_component.cu_max_threads_per_block = (int*)malloc(num_devices*sizeof(int)); - mca_op_cuda_component.cu_max_blocks = (int*)malloc(num_devices*sizeof(int)); - for (int i = 0; i < num_devices; ++i) { - CHECK(cuDeviceGet, (&mca_op_cuda_component.cu_devices[i], i)); - rc = cuDeviceGetAttribute(&mca_op_cuda_component.cu_max_threads_per_block[i], - CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X, - mca_op_cuda_component.cu_devices[i]); - if (CUDA_SUCCESS != rc) { - /* fall-back to value that should work on every device */ - mca_op_cuda_component.cu_max_threads_per_block[i] = 512; - } - if (-1 < mca_op_cuda_component.cu_max_num_threads) { - if (mca_op_cuda_component.cu_max_threads_per_block[i] >= mca_op_cuda_component.cu_max_num_threads) { - mca_op_cuda_component.cu_max_threads_per_block[i] = mca_op_cuda_component.cu_max_num_threads; - } - } - - rc = 
cuDeviceGetAttribute(&mca_op_cuda_component.cu_max_blocks[i], - CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X, - mca_op_cuda_component.cu_devices[i]); - if (CUDA_SUCCESS != rc) { - /* fall-back to value that should work on every device */ - mca_op_cuda_component.cu_max_blocks[i] = 512; - } - if (-1 < mca_op_cuda_component.cu_max_num_blocks) { - if (mca_op_cuda_component.cu_max_blocks[i] >= mca_op_cuda_component.cu_max_num_blocks) { - mca_op_cuda_component.cu_max_blocks[i] = mca_op_cuda_component.cu_max_num_blocks; - } - } - } - return OMPI_SUCCESS; } @@ -193,3 +158,56 @@ cuda_component_op_query(struct ompi_op_t *op, int *priority) *priority = 50; return (ompi_op_base_module_1_0_0_t *) module; } + +void ompi_op_cuda_lazy_init() +{ + /* Double checked locking to avoid having to + * grab locks post lazy-initialization. */ + opal_atomic_rmb(); + if (init_complete) return; + + OPAL_THREAD_LOCK(&init_lock); + + if (!init_complete) { + int num_devices; + int rc; + // TODO: is this init needed here? + cuInit(0); + CHECK(cuDeviceGetCount, (&num_devices)); + mca_op_cuda_component.cu_num_devices = num_devices; + mca_op_cuda_component.cu_devices = (CUdevice*)malloc(num_devices*sizeof(CUdevice)); + mca_op_cuda_component.cu_max_threads_per_block = (int*)malloc(num_devices*sizeof(int)); + mca_op_cuda_component.cu_max_blocks = (int*)malloc(num_devices*sizeof(int)); + for (int i = 0; i < num_devices; ++i) { + CHECK(cuDeviceGet, (&mca_op_cuda_component.cu_devices[i], i)); + rc = cuDeviceGetAttribute(&mca_op_cuda_component.cu_max_threads_per_block[i], + CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X, + mca_op_cuda_component.cu_devices[i]); + if (CUDA_SUCCESS != rc) { + /* fall-back to value that should work on every device */ + mca_op_cuda_component.cu_max_threads_per_block[i] = 512; + } + if (-1 < mca_op_cuda_component.cu_max_num_threads) { + if (mca_op_cuda_component.cu_max_threads_per_block[i] >= mca_op_cuda_component.cu_max_num_threads) { + mca_op_cuda_component.cu_max_threads_per_block[i] = 
mca_op_cuda_component.cu_max_num_threads; + } + } + + rc = cuDeviceGetAttribute(&mca_op_cuda_component.cu_max_blocks[i], + CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X, + mca_op_cuda_component.cu_devices[i]); + if (CUDA_SUCCESS != rc) { + /* fall-back to value that should work on every device */ + mca_op_cuda_component.cu_max_blocks[i] = 512; + } + if (-1 < mca_op_cuda_component.cu_max_num_blocks) { + if (mca_op_cuda_component.cu_max_blocks[i] >= mca_op_cuda_component.cu_max_num_blocks) { + mca_op_cuda_component.cu_max_blocks[i] = mca_op_cuda_component.cu_max_num_blocks; + } + } + } + opal_atomic_wmb(); + init_complete = true; + } + OPAL_THREAD_UNLOCK(&init_lock); +} \ No newline at end of file diff --git a/ompi/mca/op/cuda/op_cuda_functions.c b/ompi/mca/op/cuda/op_cuda_functions.c index 904595147cb..27361cee6a3 100644 --- a/ompi/mca/op/cuda/op_cuda_functions.c +++ b/ompi/mca/op/cuda/op_cuda_functions.c @@ -55,6 +55,8 @@ static inline void device_op_pre(const void *orig_source1, uint64_t target_flags = -1, source1_flags = -1, source2_flags = -1; int target_rc, source1_rc, source2_rc = -1; + ompi_op_cuda_lazy_init(); + *target = orig_target; *source1 = (void*)orig_source1; if (NULL != orig_source2) { From 606f778661484428a899f96bb4790841a195c462 Mon Sep 17 00:00:00 2001 From: Joseph Schuchart Date: Thu, 20 Jun 2024 18:17:07 -0400 Subject: [PATCH 05/12] op/cuda: Add flexible vector type CUDA provides only limited vector widths and only for variable width integer types. We use our own vector type and some C++ templates to get more flexible vectors. We aim to get 128bit loads by adjusting the width based on the type size. 
Signed-off-by: Joseph Schuchart --- ompi/mca/op/cuda/op_cuda_impl.cu | 1090 +++++++++++++----------------- 1 file changed, 480 insertions(+), 610 deletions(-) diff --git a/ompi/mca/op/cuda/op_cuda_impl.cu b/ompi/mca/op/cuda/op_cuda_impl.cu index 3daf7f56fbb..afbb84b5071 100644 --- a/ompi/mca/op/cuda/op_cuda_impl.cu +++ b/ompi/mca/op/cuda/op_cuda_impl.cu @@ -18,6 +18,24 @@ #include #define ISSIGNED(x) std::is_signed_v +#define ALIGN(x,a,t) (((x)+((t)(a)-1)) & ~(((t)(a)-1))) +#define ALIGN_PTR(x,a,t) ((t)ALIGN((uintptr_t)x, a, uintptr_t)) +#define ALIGN_PAD_AMOUNT(x,s) ((~((uintptr_t)(x))+1) & ((uintptr_t)(s)+(!(uintptr_t)(s))-1)) + +template +struct __align__(sizeof(T)*N) Vec { + T v[N]; + + template + __device__ Vec(S... l) + : v{std::forward(l)...} + { } + + __device__ + T& operator[](size_t i) { return v[i]; } + __device__ + const T& operator[](size_t i) const { return v[i]; } +}; template static inline __device__ constexpr T tmax(T a, T b) { @@ -40,100 +58,100 @@ static inline __device__ constexpr T tprod(T a, T b) { } template -static inline __device__ T vmax(const T& a, const T& b) { - return T{tmax(a.x, b.x), tmax(a.y, b.y), tmax(a.z, b.z), tmax(a.w, b.w)}; +static inline __device__ constexpr T tband(T a, T b) { + return a&b; +} + +template +static inline __device__ constexpr T tbor(T a, T b) { + return a|b; +} + +template +static inline __device__ constexpr T tbxor(T a, T b) { + return a^b; } template -static inline __device__ T vmin(const T& a, const T& b) { - return T{tmin(a.x, b.x), tmin(a.y, b.y), tmin(a.z, b.z), tmin(a.w, b.w)}; +static inline __device__ constexpr T tland(T a, T b) { + return a&&b; } template -static inline __device__ T vsum(const T& a, const T& b) { - return T{tsum(a.x, b.x), tsum(a.y, b.y), tsum(a.z, b.z), tsum(a.w, b.w)}; +static inline __device__ constexpr T tlor(T a, T b) { + return a||b; } template -static inline __device__ T vprod(const T& a, const T& b) { - return T{(a.x * b.x), (a.y * b.y), (a.z * b.z), (a.w * b.w)}; +static 
inline __device__ constexpr T tlxor(T a, T b) { + return ((!!a) ^ (!!b)) ? 1 : 0; +} + +template +__device__ +static inline V apply(const V& a, const V& b, Fn&& fn, std::index_sequence) { + /* apply fn to all members of the vector and return a new vector */ + return {fn(a[Ns], b[Ns])...}; +} + + +template +static inline __device__ Vec vmax(const Vec& a, const Vec& b) { + return apply(a, b, [](const T&a, const T&b) -> T { return (a > b) ? a : b; }, std::make_index_sequence{}); +} + +template +static inline __device__ Vec vmin(const Vec& a, const Vec& b) { + return apply(a, b, [](const T&a, const T&b) -> T { return (a < b) ? a : b; }, std::make_index_sequence{}); +} + +template +static inline __device__ Vec vsum(const Vec& a, const Vec& b) { + return apply(a, b, [](const T&a, const T&b) -> T { return a + b; }, std::make_index_sequence{}); +} + +template +static inline __device__ Vec vprod(const Vec& a, const Vec& b) { + return apply(a, b, [](const T&a, const T&b) -> T { return a * b; }, std::make_index_sequence{}); +} + +template +static inline __device__ Vec vband(const Vec& a, const Vec& b) { + return apply(a, b, [](const T&a, const T&b) -> T { return a & b; }, std::make_index_sequence{}); +} + +template +static inline __device__ Vec vbor(const Vec& a, const Vec& b) { + return apply(a, b, [](const T&a, const T&b) -> T { return a | b; }, std::make_index_sequence{}); +} + +template +static inline __device__ Vec vbxor(const Vec& a, const Vec& b) { + return apply(a, b, [](const T&a, const T&b) -> T { return a ^ b; }, std::make_index_sequence{}); +} + +template +static inline __device__ Vec vland(const Vec& a, const Vec& b) { + return apply(a, b, [](const T&a, const T&b) -> T { return a && b; }, std::make_index_sequence{}); +} + +template +static inline __device__ Vec vlor(const Vec& a, const Vec& b) { + return apply(a, b, [](const T&a, const T&b) -> T { return a || b; }, std::make_index_sequence{}); +} + +template +static inline __device__ Vec vlxor(const Vec& a, const 
Vec& b) { + return apply(a, b, [](const T&a, const T&b) -> T { return ((!!a) ^ (!!b)) ? 1 : 0; }, std::make_index_sequence{}); } /* TODO: missing support for * - short float (conditional on whether short float is available) - * - some Fortran types - * - some complex types */ #define USE_VECTORS 1 -#define OP_FUNC(name, type_name, type, op) \ - static __global__ void \ - ompi_op_cuda_2buff_##name##_##type_name##_kernel(const type *__restrict__ in, \ - type *__restrict__ inout, int n) { \ - const int index = blockIdx.x * blockDim.x + threadIdx.x; \ - const int stride = blockDim.x * gridDim.x; \ - for (int i = index; i < n; i += stride) { \ - /*if (index < n) { int i = index;*/ \ - inout[i] = inout[i] op in[i]; \ - } \ - } \ - void ompi_op_cuda_2buff_##name##_##type_name##_submit(const type *in, \ - type *inout, \ - int count, \ - int threads_per_block, \ - int max_blocks, \ - CUstream stream) { \ - int threads = min(count, threads_per_block); \ - int blocks = min((count + threads-1) / threads, max_blocks); \ - int n = count; \ - CUstream s = stream; \ - ompi_op_cuda_2buff_##name##_##type_name##_kernel<<>>(in, inout, n); \ - } - - -#if defined(USE_VECTORS) -#define OPV_FUNC(name, type_name, type, vtype, vlen, op) \ - static __global__ void \ - ompi_op_cuda_2buff_##name##_##type_name##_kernel(const type *__restrict__ in, \ - type *__restrict__ inout, int n) { \ - const int index = blockIdx.x * blockDim.x + threadIdx.x; \ - const int stride = blockDim.x * gridDim.x; \ - for (int i = index; i < n/vlen; i += stride) { \ - vtype vin = ((vtype*)in)[i]; \ - vtype vinout = ((vtype*)inout)[i]; \ - vinout.x = vinout.x op vin.x; \ - vinout.y = vinout.y op vin.y; \ - vinout.z = vinout.z op vin.z; \ - vinout.w = vinout.w op vin.w; \ - ((vtype*)inout)[i] = vinout; \ - } \ - int remainder = n%vlen; \ - if (index == (n/vlen) && remainder != 0) { \ - while(remainder) { \ - int idx = n - remainder--; \ - inout[idx] = inout[idx] op in[idx]; \ - } \ - } \ - } \ - void 
ompi_op_cuda_2buff_##name##_##type_name##_submit(const type *in, \ - type *inout, \ - int count, \ - int threads_per_block, \ - int max_blocks, \ - CUstream stream) { \ - int vcount = (count + vlen-1)/vlen; \ - int threads = min(threads_per_block, vcount); \ - int blocks = min((vcount + threads-1) / threads, max_blocks); \ - int n = count; \ - CUstream s = stream; \ - ompi_op_cuda_2buff_##name##_##type_name##_kernel<<>>(in, inout, n); \ - } -#else // USE_VECTORS -#define OPV_FUNC(name, type_name, type, vtype, vlen, op) OP_FUNC(name, type_name, type, op) -#endif // USE_VECTORS - #define FUNC_FUNC_FN(name, type_name, type, fn) \ static __global__ void \ ompi_op_cuda_2buff_##name##_##type_name##_kernel(const type *__restrict__ in, \ @@ -161,39 +179,68 @@ static inline __device__ T vprod(const T& a, const T& b) { #define FUNC_FUNC(name, type_name, type) FUNC_FUNC_FN(name, type_name, type, current_func) #if defined(USE_VECTORS) -#define VFUNC_FUNC(name, type_name, type, vtype, vlen, vfn, fn) \ - static __global__ void \ - ompi_op_cuda_2buff_##name##_##type_name##_kernel(const type *__restrict__ in, \ - type *__restrict__ inout, int n) { \ - const int index = blockIdx.x * blockDim.x + threadIdx.x; \ - const int stride = blockDim.x * gridDim.x; \ - for (int i = index; i < n/vlen; i += stride) { \ - ((vtype*)inout)[i] = vfn(((vtype*)inout)[i], ((vtype*)in)[i]); \ - } \ - int remainder = n%vlen; \ - if (index == (n/vlen) && remainder != 0) { \ - while(remainder) { \ - int idx = n - remainder--; \ - inout[idx] = fn(inout[idx], in[idx]); \ - } \ - } \ - } \ - static void \ - ompi_op_cuda_2buff_##name##_##type_name##_submit(const type *in, \ - type *inout, \ - int count, \ - int threads_per_block, \ - int max_blocks, \ - CUstream stream) { \ - int vcount = (count + vlen-1)/vlen; \ - int threads = min(threads_per_block, vcount); \ - int blocks = min((vcount + threads-1) / threads, max_blocks); \ - int n = count; \ - CUstream s = stream; \ - 
ompi_op_cuda_2buff_##name##_##type_name##_kernel<<>>(in, inout, n); \ +#define VFUNC_FUNC(name, type_name, type, vlen, vfn, fn) \ + static __global__ void \ + ompi_op_cuda_2buff_##name##_##type_name##_kernel_v(const type *__restrict__ in, \ + type *__restrict__ inout, int n) { \ + using vtype = Vec; \ + constexpr const size_t alignment = sizeof(type)*vlen; \ + const int index = blockIdx.x * blockDim.x + threadIdx.x; \ + const int stride = blockDim.x * gridDim.x; \ + size_t in_pad = ALIGN_PAD_AMOUNT(in, alignment); \ + const vtype * inv = ALIGN_PTR(in, alignment, const vtype*); \ + vtype * inoutv = ALIGN_PTR(inout, alignment, vtype*); \ + for (int i = index; i < (n/vlen - in_pad/sizeof(type)); i += stride) { \ + inoutv[i] = vfn(inoutv[i], inv[i]); \ + } \ + if (in_pad > 0) { \ + /* manage front values */ \ + if (index < ((in_pad/sizeof(type)) - 1)) { \ + inout[index] = fn(inout[index], in[index]); \ + } \ + } \ + int remainder = (n%vlen); \ + if (remainder > 0) { \ + /* manage back values */ \ + if (index < (remainder-1)) { \ + size_t idx = n - remainder + index; \ + inout[idx] = fn(inout[idx], in[idx]); \ + } \ + } \ + } \ + static __global__ void \ + ompi_op_cuda_2buff_##name##_##type_name##_kernel(const type *__restrict__ in, \ + type *__restrict__ inout, int n) { \ + /* non-vectorized version (e.g., due to mismatching alignment) */ \ + const int index = blockIdx.x * blockDim.x + threadIdx.x; \ + const int stride = blockDim.x * gridDim.x; \ + for (int i = index; i < n; i += stride) { \ + inout[i] = fn(inout[i], in[i]); \ + } \ + } \ + void \ + ompi_op_cuda_2buff_##name##_##type_name##_submit(const type *in, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + CUstream stream) { \ + int vcount = (count + vlen-1)/vlen; \ + int threads = min(threads_per_block, vcount); \ + int blocks = min((vcount + threads-1) / threads, max_blocks); \ + int n = count; \ + CUstream s = stream; \ + constexpr const size_t alignment = sizeof(type)*vlen; \ 
+ size_t in_pad = ALIGN_PAD_AMOUNT(in, alignment); \ + size_t inout_pad = ALIGN_PAD_AMOUNT(inout, alignment); \ + if (in_pad == inout_pad) { \ + ompi_op_cuda_2buff_##name##_##type_name##_kernel_v<<>>(in, inout, n); \ + } else { \ + ompi_op_cuda_2buff_##name##_##type_name##_kernel<<>>(in, inout, n); \ + } \ } #else -#define VFUNC_FUNC(name, type_name, type, vtype, vlen, vfn, fn) FUNC_FUNC_FN(name, type_name, type, fn) +#define VFUNC_FUNC(name, type_name, type, vlen, vfn, fn) FUNC_FUNC_FN(name, type_name, type, fn) #endif // defined(USE_VECTORS) /* @@ -236,213 +283,74 @@ static inline __device__ T vprod(const T& a, const T& b) { ompi_op_cuda_2buff_##name##_##type_name##_kernel<<>>(a, b, count); \ } -#define OPV_DISPATCH(name, type_name, type) \ - void ompi_op_cuda_2buff_##name##_##type_name##_submit(const type *in, \ - type *inout, \ - int count, \ - int threads_per_block, \ - int max_blocks, \ - CUstream stream) { \ - static_assert(sizeof(type_name) <= sizeof(unsigned long long), "Unknown size type"); \ - if constexpr(!ISSIGNED(type)) { \ - if constexpr(sizeof(type_name) == sizeof(unsigned char)) { \ - ompi_op_cuda_2buff_##name##_uchar_submit((const unsigned char*)in, (unsigned char*)inout, count, \ - threads_per_block, \ - max_blocks, stream); \ - } else if constexpr(sizeof(type_name) == sizeof(unsigned short)) { \ - ompi_op_cuda_2buff_##name##_ushort_submit((const unsigned short*)in, (unsigned short*)inout, count, \ - threads_per_block, \ - max_blocks, stream); \ - } else if constexpr(sizeof(type_name) == sizeof(unsigned int)) { \ - ompi_op_cuda_2buff_##name##_uint_submit((const unsigned int*)in, (unsigned int*)inout, count, \ - threads_per_block, \ - max_blocks, stream); \ - } else if constexpr(sizeof(type_name) == sizeof(unsigned long)) { \ - ompi_op_cuda_2buff_##name##_ulong_submit((const unsigned long*)in, (unsigned long*)inout, count, \ - threads_per_block, \ - max_blocks, stream); \ - } else if constexpr(sizeof(type_name) == sizeof(unsigned long long)) { \ - 
ompi_op_cuda_2buff_##name##_ulonglong_submit((const unsigned long long*)in, (unsigned long long*)inout, count, \ - threads_per_block, \ - max_blocks, stream); \ - } \ - } else { \ - if constexpr(sizeof(type_name) == sizeof(char)) { \ - ompi_op_cuda_2buff_##name##_char_submit((const char*)in, (char*)inout, count, \ - threads_per_block, \ - max_blocks, stream); \ - } else if constexpr(sizeof(type_name) == sizeof(short)) { \ - ompi_op_cuda_2buff_##name##_short_submit((const short*)in, (short*)inout, count, \ - threads_per_block, \ - max_blocks, stream); \ - } else if constexpr(sizeof(type_name) == sizeof(int)) { \ - ompi_op_cuda_2buff_##name##_int_submit((const int*)in, (int*)inout, count, \ - threads_per_block, \ - max_blocks, stream); \ - } else if constexpr(sizeof(type_name) == sizeof(long)) { \ - ompi_op_cuda_2buff_##name##_long_submit((const long*)in, (long*)inout, count, \ - threads_per_block, \ - max_blocks, stream); \ - } else if constexpr(sizeof(type_name) == sizeof(long long)) { \ - ompi_op_cuda_2buff_##name##_longlong_submit((const long long*)in, (long long*)inout, count,\ - threads_per_block, \ - max_blocks, stream); \ - } \ - } \ - } - /************************************************************************* * Max *************************************************************************/ /* C integer */ -VFUNC_FUNC(max, char, char, char4, 4, vmax, max) -VFUNC_FUNC(max, uchar, unsigned char, uchar4, 4, vmax, max) -VFUNC_FUNC(max, short, short, short4, 4, vmax, max) -VFUNC_FUNC(max, ushort, unsigned short, ushort4, 4, vmax, max) -VFUNC_FUNC(max, int, int, int4, 4, vmax, max) -VFUNC_FUNC(max, uint, unsigned int, uint4, 4, vmax, max) -#undef current_func -#define current_func(a, b) max(a, b) -FUNC_FUNC(max, long, long) -FUNC_FUNC(max, ulong, unsigned long) -FUNC_FUNC(max, longlong, long long) -FUNC_FUNC(max, ulonglong, unsigned long long) - -/* dispatch fixed-size types */ -OPV_DISPATCH(max, int8_t, int8_t) -OPV_DISPATCH(max, uint8_t, uint8_t) 
-OPV_DISPATCH(max, int16_t, int16_t) -OPV_DISPATCH(max, uint16_t, uint16_t) -OPV_DISPATCH(max, int32_t, int32_t) -OPV_DISPATCH(max, uint32_t, uint32_t) -OPV_DISPATCH(max, int64_t, int64_t) -OPV_DISPATCH(max, uint64_t, uint64_t) - -#undef current_func -#define current_func(a, b) ((a) > (b) ? (a) : (b)) -FUNC_FUNC(max, long_double, long double) - -#if !defined(DO_NOT_USE_INTRINSICS) -#undef current_func -#define current_func(a, b) fmaxf(a, b) -#endif // DO_NOT_USE_INTRINSICS -FUNC_FUNC(max, float, float) - -#if !defined(DO_NOT_USE_INTRINSICS) -#undef current_func -#define current_func(a, b) fmax(a, b) -#endif // DO_NOT_USE_INTRINSICS -FUNC_FUNC(max, double, double) - -// __CUDA_ARCH__ is only defined when compiling device code -#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 -#undef current_func -#define current_func(a, b) __hmax2(a, b) -//VFUNC_FUNC(max, halfx, half, half2, 2, __hmax2, __hmax) -#endif // __CUDA_ARCH__ +/* fixed-size types: 16B vector sizes + * TODO: should this be fine-tuned to the architecture? 
*/ +VFUNC_FUNC(max, int8_t, int8_t, 16, vmax, tmax) +VFUNC_FUNC(max, uint8_t, uint8_t, 16, vmax, tmax) +VFUNC_FUNC(max, int16_t, int16_t, 8, vmax, tmax) +VFUNC_FUNC(max, uint16_t, uint16_t, 8, vmax, tmax) +VFUNC_FUNC(max, int32_t, int32_t, 4, vmax, tmax) +VFUNC_FUNC(max, uint32_t, uint32_t, 4, vmax, tmax) +VFUNC_FUNC(max, int64_t, int64_t, 2, vmax, tmax) +VFUNC_FUNC(max, uint64_t, uint64_t, 2, vmax, tmax) + +VFUNC_FUNC(max, long, long, 2, vmax, tmax) +VFUNC_FUNC(max, ulong, unsigned long, 2, vmax, tmax) + +/* float */ +VFUNC_FUNC(max, float, float, 4, vmax, tmax) +VFUNC_FUNC(max, double, double, 2, vmax, tmax) +VFUNC_FUNC(max, long_double, long double, 1, vmax, tmax) /************************************************************************* * Min *************************************************************************/ /* C integer */ -VFUNC_FUNC(min, char, char, char4, 4, vmin, min) -VFUNC_FUNC(min, uchar, unsigned char, uchar4, 4, vmin, min) -VFUNC_FUNC(min, short, short, short4, 4, vmin, min) -VFUNC_FUNC(min, ushort, unsigned short, ushort4, 4, vmin, min) -VFUNC_FUNC(min, int, int, int4, 4, vmin, min) -VFUNC_FUNC(min, uint, unsigned int, uint4, 4, vmin, min) - -#undef current_func -#define current_func(a, b) min(a, b) -FUNC_FUNC(min, long, long) -FUNC_FUNC(min, ulong, unsigned long) -FUNC_FUNC(min, longlong, long long) -FUNC_FUNC(min, ulonglong, unsigned long long) -OPV_DISPATCH(min, int8_t, int8_t) -OPV_DISPATCH(min, uint8_t, uint8_t) -OPV_DISPATCH(min, int16_t, int16_t) -OPV_DISPATCH(min, uint16_t, uint16_t) -OPV_DISPATCH(min, int32_t, int32_t) -OPV_DISPATCH(min, uint32_t, uint32_t) -OPV_DISPATCH(min, int64_t, int64_t) -OPV_DISPATCH(min, uint64_t, uint64_t) - - - -#if !defined(DO_NOT_USE_INTRINSICS) -#undef current_func -#define current_func(a, b) fminf(a, b) -#endif // DO_NOT_USE_INTRINSICS -FUNC_FUNC(min, float, float) - -#if !defined(DO_NOT_USE_INTRINSICS) -#undef current_func -#define current_func(a, b) fmin(a, b) -#endif // DO_NOT_USE_INTRINSICS 
-FUNC_FUNC(min, double, double) - -#undef current_func -#define current_func(a, b) ((a) < (b) ? (a) : (b)) -FUNC_FUNC(min, long_double, long double) - -// __CUDA_ARCH__ is only defined when compiling device code -#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 -#undef current_func -#define current_func(a, b) __hmin2(a, b) -//VFUNC_FUNC(min, half, half, half2, 2, __hmin2, __hmin) -#endif // __CUDA_ARCH__ +VFUNC_FUNC(min, int8_t, int8_t, 16, vmin, tmin) +VFUNC_FUNC(min, uint8_t, uint8_t, 16, vmin, tmin) +VFUNC_FUNC(min, int16_t, int16_t, 8, vmin, tmin) +VFUNC_FUNC(min, uint16_t, uint16_t, 8, vmin, tmin) +VFUNC_FUNC(min, int32_t, int32_t, 4, vmin, tmin) +VFUNC_FUNC(min, uint32_t, uint32_t, 4, vmin, tmin) +VFUNC_FUNC(min, int64_t, int64_t, 2, vmin, tmin) +VFUNC_FUNC(min, uint64_t, uint64_t, 2, vmin, tmin) +VFUNC_FUNC(min, long, long, 2, vmin, tmin) +VFUNC_FUNC(min, ulong, unsigned long, 2, vmin, tmin) + +/* float */ +VFUNC_FUNC(min, float, float, 4, vmin, tmin) +VFUNC_FUNC(min, double, double, 2, vmin, tmin) +VFUNC_FUNC(min, long_double, long double, 1, vmin, tmin) /************************************************************************* * Sum *************************************************************************/ /* C integer */ -VFUNC_FUNC(sum, char, char, char4, 4, vsum, tsum) -VFUNC_FUNC(sum, uchar, unsigned char, uchar4, 4, vsum, tsum) -VFUNC_FUNC(sum, short, short, short4, 4, vsum, tsum) -VFUNC_FUNC(sum, ushort, unsigned short, ushort4, 4, vsum, tsum) -VFUNC_FUNC(sum, int, int, int4, 4, vsum, tsum) -VFUNC_FUNC(sum, uint, unsigned int, uint4, 4, vsum, tsum) - -#undef current_func -#define current_func(a, b) tsum(a, b) -FUNC_FUNC(sum, long, long) -FUNC_FUNC(sum, ulong, unsigned long) -FUNC_FUNC(sum, longlong, long long) -FUNC_FUNC(sum, ulonglong, unsigned long long) - -OPV_DISPATCH(sum, int8_t, int8_t) -OPV_DISPATCH(sum, uint8_t, uint8_t) -OPV_DISPATCH(sum, int16_t, int16_t) -OPV_DISPATCH(sum, uint16_t, uint16_t) -OPV_DISPATCH(sum, int32_t, int32_t) 
-OPV_DISPATCH(sum, uint32_t, uint32_t) -OPV_DISPATCH(sum, int64_t, int64_t) -OPV_DISPATCH(sum, uint64_t, uint64_t) - -OPV_FUNC(sum, float, float, float4, 4, +) -OPV_FUNC(sum, double, double, double4, 4, +) -OP_FUNC(sum, long_double, long double, +) - -// __CUDA_ARCH__ is only defined when compiling device code -#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 -#undef current_func -#define current_func(a, b) __hadd2(a, b) -//VFUNC_FUNC(sum, half, half, half2, 2, __hadd2, __hadd) -#endif // __CUDA_ARCH__ +VFUNC_FUNC(sum, int8_t, int8_t, 16, vsum, tsum) +VFUNC_FUNC(sum, uint8_t, uint8_t, 16, vsum, tsum) +VFUNC_FUNC(sum, int16_t, int16_t, 8, vsum, tsum) +VFUNC_FUNC(sum, uint16_t, uint16_t, 8, vsum, tsum) +VFUNC_FUNC(sum, int32_t, int32_t, 4, vsum, tsum) +VFUNC_FUNC(sum, uint32_t, uint32_t, 4, vsum, tsum) +VFUNC_FUNC(sum, int64_t, int64_t, 2, vsum, tsum) +VFUNC_FUNC(sum, uint64_t, uint64_t, 2, vsum, tsum) +VFUNC_FUNC(sum, long, long, 2, vsum, tsum) +VFUNC_FUNC(sum, ulong, unsigned long, 2, vsum, tsum) + +/* float */ +VFUNC_FUNC(sum, float, float, 4, vsum, tsum) +VFUNC_FUNC(sum, double, double, 2, vsum, tsum) +VFUNC_FUNC(sum, long_double, long double, 1, vsum, tsum) /* Complex */ -#if 0 -#if defined(HAVE_SHORT_FLOAT__COMPLEX) -OP_FUNC(sum, c_short_float_complex, short float _Complex, +=) -#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) -COMPLEX_SUM_FUNC(c_short_float_complex, opal_short_float_t) -OP_FUNC(sum, c_long_double_complex, cuLongDoubleComplex, +=) -#endif -#endif // 0 #undef current_func #define current_func(a, b) (cuCaddf(a,b)) FUNC_FUNC(sum, c_float_complex, cuFloatComplex) @@ -455,49 +363,23 @@ FUNC_FUNC(sum, c_double_complex, cuDoubleComplex) *************************************************************************/ /* C integer */ -#undef current_func -#define current_func(a, b) tprod(a, b) -FUNC_FUNC(prod, char, char) -FUNC_FUNC(prod, uchar, unsigned char) -FUNC_FUNC(prod, short, short) -FUNC_FUNC(prod, ushort, unsigned short) -FUNC_FUNC(prod, int, int) 
-FUNC_FUNC(prod, uint, unsigned int) -FUNC_FUNC(prod, long, long) -FUNC_FUNC(prod, ulong, unsigned long) -FUNC_FUNC(prod, longlong, long long) -FUNC_FUNC(prod, ulonglong, unsigned long long) - -OPV_DISPATCH(prod, int8_t, int8_t) -OPV_DISPATCH(prod, uint8_t, uint8_t) -OPV_DISPATCH(prod, int16_t, int16_t) -OPV_DISPATCH(prod, uint16_t, uint16_t) -OPV_DISPATCH(prod, int32_t, int32_t) -OPV_DISPATCH(prod, uint32_t, uint32_t) -OPV_DISPATCH(prod, int64_t, int64_t) -OPV_DISPATCH(prod, uint64_t, uint64_t) - - -OPV_FUNC(prod, float, float, float4, 4, *) -OPV_FUNC(prod, double, double, double4, 4, *) -OP_FUNC(prod, long_double, long double, *) - -// __CUDA_ARCH__ is only defined when compiling device code -#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 -#undef current_func -#define current_func(a, b) __hmul2(a, b) -//VFUNC_FUNC(prod, half, half, half2, 2, __hmul2, __hmul) -#endif // __CUDA_ARCH__ +VFUNC_FUNC(prod, int8_t, int8_t, 16, vprod, tprod) +VFUNC_FUNC(prod, uint8_t, uint8_t, 16, vprod, tprod) +VFUNC_FUNC(prod, int16_t, int16_t, 8, vprod, tprod) +VFUNC_FUNC(prod, uint16_t, uint16_t, 8, vprod, tprod) +VFUNC_FUNC(prod, int32_t, int32_t, 4, vprod, tprod) +VFUNC_FUNC(prod, uint32_t, uint32_t, 4, vprod, tprod) +VFUNC_FUNC(prod, int64_t, int64_t, 2, vprod, tprod) +VFUNC_FUNC(prod, uint64_t, uint64_t, 2, vprod, tprod) +VFUNC_FUNC(prod, long, long, 2, vprod, tprod) +VFUNC_FUNC(prod, ulong, unsigned long, 2, vprod, tprod) + +/* float */ +VFUNC_FUNC(prod, float, float, 4, vprod, tprod) +VFUNC_FUNC(prod, double, double, 2, vprod, tprod) +VFUNC_FUNC(prod, long_double, long double, 1, vprod, tprod) /* Complex */ -#if 0 -#if defined(HAVE_SHORT_FLOAT__COMPLEX) -OP_FUNC(prod, c_short_float_complex, short float _Complex, *=) -#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) -COMPLEX_PROD_FUNC(c_short_float_complex, opal_short_float_t) -#endif -OP_FUNC(prod, c_long_double_complex, long double _Complex, *=) -#endif // 0 #undef current_func #define current_func(a, b) (cuCmulf(a,b)) 
FUNC_FUNC(prod, c_float_complex, cuFloatComplex) @@ -509,127 +391,116 @@ FUNC_FUNC(prod, c_double_complex, cuDoubleComplex) * Logical AND *************************************************************************/ -#undef current_func -#define current_func(a, b) ((a) && (b)) /* C integer */ -FUNC_FUNC(land, int8_t, int8_t) -FUNC_FUNC(land, uint8_t, uint8_t) -FUNC_FUNC(land, int16_t, int16_t) -FUNC_FUNC(land, uint16_t, uint16_t) -FUNC_FUNC(land, int32_t, int32_t) -FUNC_FUNC(land, uint32_t, uint32_t) -FUNC_FUNC(land, int64_t, int64_t) -FUNC_FUNC(land, uint64_t, uint64_t) -FUNC_FUNC(land, long, long) -FUNC_FUNC(land, ulong, unsigned long) +VFUNC_FUNC(land, int8_t, int8_t, 16, vland, tland) +VFUNC_FUNC(land, uint8_t, uint8_t, 16, vland, tland) +VFUNC_FUNC(land, int16_t, int16_t, 8, vland, tland) +VFUNC_FUNC(land, uint16_t, uint16_t, 8, vland, tland) +VFUNC_FUNC(land, int32_t, int32_t, 4, vland, tland) +VFUNC_FUNC(land, uint32_t, uint32_t, 4, vland, tland) +VFUNC_FUNC(land, int64_t, int64_t, 2, vland, tland) +VFUNC_FUNC(land, uint64_t, uint64_t, 2, vland, tland) +VFUNC_FUNC(land, long, long, 2, vland, tland) +VFUNC_FUNC(land, ulong, unsigned long, 2, vland, tland) /* C++ bool */ -FUNC_FUNC(land, bool, bool) +VFUNC_FUNC(land, bool, bool, 16, vland, tland) /************************************************************************* * Logical OR *************************************************************************/ -#undef current_func -#define current_func(a, b) ((a) || (b)) /* C integer */ -FUNC_FUNC(lor, int8_t, int8_t) -FUNC_FUNC(lor, uint8_t, uint8_t) -FUNC_FUNC(lor, int16_t, int16_t) -FUNC_FUNC(lor, uint16_t, uint16_t) -FUNC_FUNC(lor, int32_t, int32_t) -FUNC_FUNC(lor, uint32_t, uint32_t) -FUNC_FUNC(lor, int64_t, int64_t) -FUNC_FUNC(lor, uint64_t, uint64_t) -FUNC_FUNC(lor, long, long) -FUNC_FUNC(lor, ulong, unsigned long) +VFUNC_FUNC(lor, int8_t, int8_t, 16, vlor, tlor) +VFUNC_FUNC(lor, uint8_t, uint8_t, 16, vlor, tlor) +VFUNC_FUNC(lor, int16_t, int16_t, 8, vlor, 
tlor) +VFUNC_FUNC(lor, uint16_t, uint16_t, 8, vlor, tlor) +VFUNC_FUNC(lor, int32_t, int32_t, 4, vlor, tlor) +VFUNC_FUNC(lor, uint32_t, uint32_t, 4, vlor, tlor) +VFUNC_FUNC(lor, int64_t, int64_t, 2, vlor, tlor) +VFUNC_FUNC(lor, uint64_t, uint64_t, 2, vlor, tlor) +VFUNC_FUNC(lor, long, long, 2, vlor, tlor) +VFUNC_FUNC(lor, ulong, unsigned long, 2, vlor, tlor) /* C++ bool */ -FUNC_FUNC(lor, bool, bool) +VFUNC_FUNC(lor, bool, bool, 16, vlor, tlor) /************************************************************************* * Logical XOR *************************************************************************/ -#undef current_func -#define current_func(a, b) ((a ? 1 : 0) ^ (b ? 1: 0)) /* C integer */ -FUNC_FUNC(lxor, int8_t, int8_t) -FUNC_FUNC(lxor, uint8_t, uint8_t) -FUNC_FUNC(lxor, int16_t, int16_t) -FUNC_FUNC(lxor, uint16_t, uint16_t) -FUNC_FUNC(lxor, int32_t, int32_t) -FUNC_FUNC(lxor, uint32_t, uint32_t) -FUNC_FUNC(lxor, int64_t, int64_t) -FUNC_FUNC(lxor, uint64_t, uint64_t) -FUNC_FUNC(lxor, long, long) -FUNC_FUNC(lxor, ulong, unsigned long) +VFUNC_FUNC(lxor, int8_t, int8_t, 16, vlxor, tlxor) +VFUNC_FUNC(lxor, uint8_t, uint8_t, 16, vlxor, tlxor) +VFUNC_FUNC(lxor, int16_t, int16_t, 8, vlxor, tlxor) +VFUNC_FUNC(lxor, uint16_t, uint16_t, 8, vlxor, tlxor) +VFUNC_FUNC(lxor, int32_t, int32_t, 4, vlxor, tlxor) +VFUNC_FUNC(lxor, uint32_t, uint32_t, 4, vlxor, tlxor) +VFUNC_FUNC(lxor, int64_t, int64_t, 2, vlxor, tlxor) +VFUNC_FUNC(lxor, uint64_t, uint64_t, 2, vlxor, tlxor) +VFUNC_FUNC(lxor, long, long, 2, vlxor, tlxor) +VFUNC_FUNC(lxor, ulong, unsigned long, 2, vlxor, tlxor) /* C++ bool */ -FUNC_FUNC(lxor, bool, bool) +VFUNC_FUNC(lxor, bool, bool, 16, vlxor, tlxor) + /************************************************************************* * Bitwise AND *************************************************************************/ -#undef current_func -#define current_func(a, b) ((a) & (b)) /* C integer */ -FUNC_FUNC(band, int8_t, int8_t) -FUNC_FUNC(band, uint8_t, uint8_t) 
-FUNC_FUNC(band, int16_t, int16_t) -FUNC_FUNC(band, uint16_t, uint16_t) -FUNC_FUNC(band, int32_t, int32_t) -FUNC_FUNC(band, uint32_t, uint32_t) -FUNC_FUNC(band, int64_t, int64_t) -FUNC_FUNC(band, uint64_t, uint64_t) -FUNC_FUNC(band, long, long) -FUNC_FUNC(band, ulong, unsigned long) - -/* Byte */ -FUNC_FUNC(band, byte, char) +VFUNC_FUNC(band, int8_t, int8_t, 16, vband, tband) +VFUNC_FUNC(band, uint8_t, uint8_t, 16, vband, tband) +VFUNC_FUNC(band, int16_t, int16_t, 8, vband, tband) +VFUNC_FUNC(band, uint16_t, uint16_t, 8, vband, tband) +VFUNC_FUNC(band, int32_t, int32_t, 4, vband, tband) +VFUNC_FUNC(band, uint32_t, uint32_t, 4, vband, tband) +VFUNC_FUNC(band, int64_t, int64_t, 2, vband, tband) +VFUNC_FUNC(band, uint64_t, uint64_t, 2, vband, tband) +VFUNC_FUNC(band, long, long, 2, vband, tband) +VFUNC_FUNC(band, ulong, unsigned long, 2, vband, tband) + +/* C++ byte */ +VFUNC_FUNC(band, byte, char, 16, vband, tband) /************************************************************************* * Bitwise OR *************************************************************************/ -#undef current_func -#define current_func(a, b) ((a) | (b)) /* C integer */ -FUNC_FUNC(bor, int8_t, int8_t) -FUNC_FUNC(bor, uint8_t, uint8_t) -FUNC_FUNC(bor, int16_t, int16_t) -FUNC_FUNC(bor, uint16_t, uint16_t) -FUNC_FUNC(bor, int32_t, int32_t) -FUNC_FUNC(bor, uint32_t, uint32_t) -FUNC_FUNC(bor, int64_t, int64_t) -FUNC_FUNC(bor, uint64_t, uint64_t) -FUNC_FUNC(bor, long, long) -FUNC_FUNC(bor, ulong, unsigned long) - -/* Byte */ -FUNC_FUNC(bor, byte, char) +VFUNC_FUNC(bor, int8_t, int8_t, 16, vbor, tbor) +VFUNC_FUNC(bor, uint8_t, uint8_t, 16, vbor, tbor) +VFUNC_FUNC(bor, int16_t, int16_t, 8, vbor, tbor) +VFUNC_FUNC(bor, uint16_t, uint16_t, 8, vbor, tbor) +VFUNC_FUNC(bor, int32_t, int32_t, 4, vbor, tbor) +VFUNC_FUNC(bor, uint32_t, uint32_t, 4, vbor, tbor) +VFUNC_FUNC(bor, int64_t, int64_t, 2, vbor, tbor) +VFUNC_FUNC(bor, uint64_t, uint64_t, 2, vbor, tbor) +VFUNC_FUNC(bor, long, long, 2, vbor, 
tbor) +VFUNC_FUNC(bor, ulong, unsigned long, 2, vbor, tbor) + +/* C++ byte */ +VFUNC_FUNC(bor, byte, char, 16, vbor, tbor) /************************************************************************* * Bitwise XOR *************************************************************************/ -#undef current_func -#define current_func(a, b) ((a) ^ (b)) /* C integer */ -FUNC_FUNC(bxor, int8_t, int8_t) -FUNC_FUNC(bxor, uint8_t, uint8_t) -FUNC_FUNC(bxor, int16_t, int16_t) -FUNC_FUNC(bxor, uint16_t, uint16_t) -FUNC_FUNC(bxor, int32_t, int32_t) -FUNC_FUNC(bxor, uint32_t, uint32_t) -FUNC_FUNC(bxor, int64_t, int64_t) -FUNC_FUNC(bxor, uint64_t, uint64_t) -FUNC_FUNC(bxor, long, long) -FUNC_FUNC(bxor, ulong, unsigned long) - -/* Byte */ -FUNC_FUNC(bxor, byte, char) +VFUNC_FUNC(bxor, int8_t, int8_t, 16, vbxor, tbxor) +VFUNC_FUNC(bxor, uint8_t, uint8_t, 16, vbxor, tbxor) +VFUNC_FUNC(bxor, int16_t, int16_t, 8, vbxor, tbxor) +VFUNC_FUNC(bxor, uint16_t, uint16_t, 8, vbxor, tbxor) +VFUNC_FUNC(bxor, int32_t, int32_t, 4, vbxor, tbxor) +VFUNC_FUNC(bxor, uint32_t, uint32_t, 4, vbxor, tbxor) +VFUNC_FUNC(bxor, int64_t, int64_t, 2, vbxor, tbxor) +VFUNC_FUNC(bxor, uint64_t, uint64_t, 2, vbxor, tbxor) +VFUNC_FUNC(bxor, long, long, 2, vbxor, tbxor) +VFUNC_FUNC(bxor, ulong, unsigned long, 2, vbxor, tbxor) + +/* C++ byte */ +VFUNC_FUNC(bxor, byte, char, 16, vbxor, tbxor) /************************************************************************* * Max location @@ -726,6 +597,76 @@ LOC_FUNC(minloc, 2int64, <) 0, stream>>>(in1, in2, out, count); \ }
+
+/*
+ * Scalar three-buffer kernel: out[i] = fn(in1[i], in2[i]) for 0 <= i < n.
+ * Used by the vectorized build when the three buffers do not share the same
+ * alignment offset, and as the only kernel when USE_VECTORS is not defined.
+ */
+#define FUNC_3BUF_SCALAR_KERNEL(name, type_name, type, fn)                                          \
+    static __global__ void                                                                          \
+    ompi_op_cuda_3buff_##name##_##type_name##_kernel(const type *__restrict__ in1,                  \
+                                                     const type *__restrict__ in2,                  \
+                                                     type *__restrict__ out, int n) {               \
+        const int index = blockIdx.x * blockDim.x + threadIdx.x;                                    \
+        const int stride = blockDim.x * gridDim.x;                                                  \
+        for (int i = index; i < n; i += stride) {                                                   \
+            out[i] = fn(in1[i], in2[i]);                                                            \
+        }                                                                                           \
+    }
+
+#if defined(USE_VECTORS)
+/*
+ * VFUNC_FUNC_3BUF(name, type_name, type, vlen, vfn, fn) generates, for one
+ * (operation, type) pair:
+ *   ompi_op_cuda_3buff_<name>_<type_name>_kernel_v  vlen elements per step via vfn
+ *   ompi_op_cuda_3buff_<name>_<type_name>_kernel    one element per step via fn
+ *   ompi_op_cuda_3buff_<name>_<type_name>_submit    host entry point; enqueues one of
+ *                                                   the two kernels on `stream`
+ * The vector kernel is chosen only if in1, in2 and out have the same pad to the
+ * next sizeof(type)*vlen boundary; otherwise the scalar kernel is used.
+ *
+ * NOTE(review): ALIGN_PAD_AMOUNT is taken to return the number of bytes up to the
+ * next `alignment` boundary and ALIGN_PTR the pointer rounded up to it; both are
+ * defined outside this hunk -- confirm.
+ * NOTE(review): the template arguments of Vec and the <<<...>>> launch
+ * configurations were missing from the extracted text and are restored from the
+ * variables the code computes for them; confirm the parameter order of Vec.
+ */
+#define VFUNC_FUNC_3BUF(name, type_name, type, vlen, vfn, fn)                                       \
+    static __global__ void                                                                          \
+    ompi_op_cuda_3buff_##name##_##type_name##_kernel_v(const type *__restrict__ in1,                \
+                                                       const type *__restrict__ in2,                \
+                                                       type *__restrict__ out, int n) {             \
+        using vtype = Vec<type, vlen>;                                                              \
+        constexpr const size_t alignment = sizeof(type)*vlen;                                       \
+        const int index = blockIdx.x * blockDim.x + threadIdx.x;                                    \
+        const int stride = blockDim.x * gridDim.x;                                                  \
+        const size_t in_pad = ALIGN_PAD_AMOUNT(in1, alignment);                                     \
+        /* elements in front of the first aligned vector, never more than n */                     \
+        int front = (int)(in_pad / sizeof(type));                                                   \
+        if (front > n) {                                                                            \
+            front = n;                                                                              \
+        }                                                                                           \
+        /* whole vectors after the front part; all arithmetic stays in int so a */                 \
+        /* short buffer yields 0 instead of a wrapped-around unsigned bound      */                 \
+        const int nvec = (n - front) / vlen;                                                        \
+        const int back_start = front + nvec * vlen;                                                 \
+        const vtype * in1v = ALIGN_PTR(in1, alignment, const vtype*);                               \
+        const vtype * in2v = ALIGN_PTR(in2, alignment, const vtype*);                               \
+        vtype * outv = ALIGN_PTR(out, alignment, vtype*);                                           \
+        for (int i = index; i < nvec; i += stride) {                                                \
+            outv[i] = vfn(in1v[i], in2v[i]);                                                        \
+        }                                                                                           \
+        /* manage front values: loops, because the launch may have fewer */                        \
+        /* threads than there are leftover elements                      */                        \
+        for (int i = index; i < front; i += stride) {                                               \
+            out[i] = fn(in1[i], in2[i]);                                                            \
+        }                                                                                           \
+        /* manage back values */                                                                    \
+        for (int i = back_start + index; i < n; i += stride) {                                      \
+            out[i] = fn(in1[i], in2[i]);                                                            \
+        }                                                                                           \
+    }                                                                                               \
+    FUNC_3BUF_SCALAR_KERNEL(name, type_name, type, fn)                                              \
+    void                                                                                            \
+    ompi_op_cuda_3buff_##name##_##type_name##_submit(const type *in1,                               \
+                                                     const type *in2,                               \
+                                                     type *out,                                     \
+                                                     int count,                                     \
+                                                     int threads_per_block,                         \
+                                                     int max_blocks,                                \
+                                                     CUstream stream) {                             \
+        if (count <= 0) {                                                                           \
+            return; /* nothing to do; avoids a launch with zero threads */                          \
+        }                                                                                           \
+        int vcount = (count + vlen-1)/vlen;                                                         \
+        int threads = min(threads_per_block, vcount);                                               \
+        int blocks = min((vcount + threads-1) / threads, max_blocks);                               \
+        int n = count;                                                                              \
+        CUstream s = stream;                                                                        \
+        constexpr const size_t alignment = sizeof(type)*vlen;                                       \
+        size_t in1_pad = ALIGN_PAD_AMOUNT(in1, alignment);                                          \
+        size_t in2_pad = ALIGN_PAD_AMOUNT(in2, alignment);                                          \
+        size_t out_pad = ALIGN_PAD_AMOUNT(out, alignment);                                          \
+        if (in1_pad == in2_pad && in1_pad == out_pad) {                                             \
+            ompi_op_cuda_3buff_##name##_##type_name##_kernel_v<<<blocks, threads, 0, s>>>(in1, in2, out, n); \
+        } else {                                                                                    \
+            ompi_op_cuda_3buff_##name##_##type_name##_kernel<<<blocks, threads, 0, s>>>(in1, in2, out, n); \
+        }                                                                                           \
+    }
+#else
+/*
+ * Scalar-only build: same generated names and submit signature as above;
+ * vlen and vfn are accepted for source compatibility and ignored.
+ */
+#define VFUNC_FUNC_3BUF(name, type_name, type, vlen, vfn, fn)                                       \
+    FUNC_3BUF_SCALAR_KERNEL(name, type_name, type, fn)                                              \
+    void                                                                                            \
+    ompi_op_cuda_3buff_##name##_##type_name##_submit(const type *in1,                               \
+                                                     const type *in2,                               \
+                                                     type *out,                                     \
+                                                     int count,                                     \
+                                                     int threads_per_block,                         \
+                                                     int max_blocks,                                \
+                                                     CUstream stream) {                             \
+        if (count <= 0) {                                                                           \
+            return; /* nothing to do; avoids a launch with zero threads */                          \
+        }                                                                                           \
+        int threads = min(threads_per_block, count);                                                \
+        int blocks = min((count + threads-1) / threads, max_blocks);                                \
+        ompi_op_cuda_3buff_##name##_##type_name##_kernel<<<blocks, threads, 0, stream>>>(in1, in2, out, count); \
+    }
+#endif // defined(USE_VECTORS)
 /* * Since all the functions in this file are essentially identical, we * use a macro to substitute in names and types. The core operation @@ -778,94 +719,66 @@ LOC_FUNC(minloc, 2int64, <) * Max *************************************************************************/ -#undef current_func -#define current_func(a, b) ((a) > (b) ? (a) : (b)) -/* C integer */ -FUNC_FUNC_3BUF(max, int8_t, int8_t) -FUNC_FUNC_3BUF(max, uint8_t, uint8_t) -FUNC_FUNC_3BUF(max, int16_t, int16_t) -FUNC_FUNC_3BUF(max, uint16_t, uint16_t) -FUNC_FUNC_3BUF(max, int32_t, int32_t) -FUNC_FUNC_3BUF(max, uint32_t, uint32_t) -FUNC_FUNC_3BUF(max, int64_t, int64_t) -FUNC_FUNC_3BUF(max, uint64_t, uint64_t) -FUNC_FUNC_3BUF(max, long, long) -FUNC_FUNC_3BUF(max, ulong, unsigned long) - -/* Floating point */ -#if defined(HAVE_SHORT_FLOAT) -FUNC_FUNC_3BUF(max, short_float, short float) -#elif defined(HAVE_OPAL_SHORT_FLOAT_T) -FUNC_FUNC_3BUF(max, short_float, opal_short_float_t) -#endif -FUNC_FUNC_3BUF(max, float, float) -FUNC_FUNC_3BUF(max, double, double) -FUNC_FUNC_3BUF(max, long_double, long double) +/* fixed-size types: 16B vector sizes + * TODO: should this be fine-tuned to the architecture? 
*/ + VFUNC_FUNC_3BUF(max, int8_t, int8_t, 16, vmax, tmax) + VFUNC_FUNC_3BUF(max, uint8_t, uint8_t, 16, vmax, tmax) + VFUNC_FUNC_3BUF(max, int16_t, int16_t, 8, vmax, tmax) + VFUNC_FUNC_3BUF(max, uint16_t, uint16_t, 8, vmax, tmax) + VFUNC_FUNC_3BUF(max, int32_t, int32_t, 4, vmax, tmax) + VFUNC_FUNC_3BUF(max, uint32_t, uint32_t, 4, vmax, tmax) + VFUNC_FUNC_3BUF(max, int64_t, int64_t, 2, vmax, tmax) + VFUNC_FUNC_3BUF(max, uint64_t, uint64_t, 2, vmax, tmax) + /* long / unsigned long: the removed FUNC_FUNC_3BUF(max, long|ulong) need a replacement, as in the min, sum and prod lists */ + VFUNC_FUNC_3BUF(max, long, long, 2, vmax, tmax) + VFUNC_FUNC_3BUF(max, ulong, unsigned long, 2, vmax, tmax) + + /* float */ + VFUNC_FUNC_3BUF(max, float, float, 4, vmax, tmax) + VFUNC_FUNC_3BUF(max, double, double, 2, vmax, tmax) + VFUNC_FUNC_3BUF(max, long_double, long double, 1, vmax, tmax) /************************************************************************* * Min *************************************************************************/ -#undef current_func -#define current_func(a, b) ((a) < (b) ? (a) : (b)) /* C integer */ -FUNC_FUNC_3BUF(min, int8_t, int8_t) -FUNC_FUNC_3BUF(min, uint8_t, uint8_t) -FUNC_FUNC_3BUF(min, int16_t, int16_t) -FUNC_FUNC_3BUF(min, uint16_t, uint16_t) -FUNC_FUNC_3BUF(min, int32_t, int32_t) -FUNC_FUNC_3BUF(min, uint32_t, uint32_t) -FUNC_FUNC_3BUF(min, int64_t, int64_t) -FUNC_FUNC_3BUF(min, uint64_t, uint64_t) -FUNC_FUNC_3BUF(min, long, long) -FUNC_FUNC_3BUF(min, ulong, unsigned long) - -/* Floating point */ -#if defined(HAVE_SHORT_FLOAT) -FUNC_FUNC_3BUF(min, short_float, short float) -#elif defined(HAVE_OPAL_SHORT_FLOAT_T) -FUNC_FUNC_3BUF(min, short_float, opal_short_float_t) -#endif -FUNC_FUNC_3BUF(min, float, float) -FUNC_FUNC_3BUF(min, double, double) -FUNC_FUNC_3BUF(min, long_double, long double) +VFUNC_FUNC_3BUF(min, int8_t, int8_t, 16, vmin, tmin) +VFUNC_FUNC_3BUF(min, uint8_t, uint8_t, 16, vmin, tmin) +VFUNC_FUNC_3BUF(min, int16_t, int16_t, 8, vmin, tmin) +VFUNC_FUNC_3BUF(min, uint16_t, uint16_t, 8, vmin, tmin) +VFUNC_FUNC_3BUF(min, int32_t, int32_t, 4, vmin, tmin) +VFUNC_FUNC_3BUF(min, uint32_t, uint32_t, 4, vmin, tmin) +VFUNC_FUNC_3BUF(min, int64_t, int64_t, 2, vmin, tmin) 
+VFUNC_FUNC_3BUF(min, uint64_t, uint64_t, 2, vmin, tmin) +VFUNC_FUNC_3BUF(min, long, long, 2, vmin, tmin) +VFUNC_FUNC_3BUF(min, ulong, unsigned long, 2, vmin, tmin) + +/* float */ +VFUNC_FUNC_3BUF(min, float, float, 4, vmin, tmin) +VFUNC_FUNC_3BUF(min, double, double, 2, vmin, tmin) +VFUNC_FUNC_3BUF(min, long_double, long double, 1, vmin, tmin) /************************************************************************* * Sum *************************************************************************/ /* C integer */ -OP_FUNC_3BUF(sum, int8_t, int8_t, +) -OP_FUNC_3BUF(sum, uint8_t, uint8_t, +) -OP_FUNC_3BUF(sum, int16_t, int16_t, +) -OP_FUNC_3BUF(sum, uint16_t, uint16_t, +) -OP_FUNC_3BUF(sum, int32_t, int32_t, +) -OP_FUNC_3BUF(sum, uint32_t, uint32_t, +) -OP_FUNC_3BUF(sum, int64_t, int64_t, +) -OP_FUNC_3BUF(sum, uint64_t, uint64_t, +) -OP_FUNC_3BUF(sum, long, long, +) -OP_FUNC_3BUF(sum, ulong, unsigned long, +) - -/* Floating point */ -#if defined(HAVE_SHORT_FLOAT) -OP_FUNC_3BUF(sum, short_float, short float, +) -#elif defined(HAVE_OPAL_SHORT_FLOAT_T) -OP_FUNC_3BUF(sum, short_float, opal_short_float_t, +) -#endif -OP_FUNC_3BUF(sum, float, float, +) -OP_FUNC_3BUF(sum, double, double, +) -OP_FUNC_3BUF(sum, long_double, long double, +) +VFUNC_FUNC_3BUF(sum, int8_t, int8_t, 16, vsum, tsum) +VFUNC_FUNC_3BUF(sum, uint8_t, uint8_t, 16, vsum, tsum) +VFUNC_FUNC_3BUF(sum, int16_t, int16_t, 8, vsum, tsum) +VFUNC_FUNC_3BUF(sum, uint16_t, uint16_t, 8, vsum, tsum) +VFUNC_FUNC_3BUF(sum, int32_t, int32_t, 4, vsum, tsum) +VFUNC_FUNC_3BUF(sum, uint32_t, uint32_t, 4, vsum, tsum) +VFUNC_FUNC_3BUF(sum, int64_t, int64_t, 2, vsum, tsum) +VFUNC_FUNC_3BUF(sum, uint64_t, uint64_t, 2, vsum, tsum) +VFUNC_FUNC_3BUF(sum, long, long, 2, vsum, tsum) +VFUNC_FUNC_3BUF(sum, ulong, unsigned long, 2, vsum, tsum) + +/* float */ +VFUNC_FUNC_3BUF(sum, float, float, 4, vsum, tsum) +VFUNC_FUNC_3BUF(sum, double, double, 2, vsum, tsum) +VFUNC_FUNC_3BUF(sum, long_double, long double, 1, vsum, tsum) /* Complex */ 
-#if 0 -#if defined(HAVE_SHORT_FLOAT__COMPLEX) -OP_FUNC_3BUF(sum, c_short_float_complex, short float _Complex, +) -#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) -COMPLEX_SUM_FUNC_3BUF(c_short_float_complex, opal_short_float_t) -#endif -OP_FUNC_3BUF(sum, c_long_double_complex, cuLongDoubleComplex, +) -#endif // 0 #undef current_func #define current_func(a, b) (cuCaddf(a,b)) FUNC_FUNC_3BUF(sum, c_float_complex, cuFloatComplex) @@ -878,36 +791,23 @@ FUNC_FUNC_3BUF(sum, c_double_complex, cuDoubleComplex) *************************************************************************/ /* C integer */ -OP_FUNC_3BUF(prod, int8_t, int8_t, *) -OP_FUNC_3BUF(prod, uint8_t, uint8_t, *) -OP_FUNC_3BUF(prod, int16_t, int16_t, *) -OP_FUNC_3BUF(prod, uint16_t, uint16_t, *) -OP_FUNC_3BUF(prod, int32_t, int32_t, *) -OP_FUNC_3BUF(prod, uint32_t, uint32_t, *) -OP_FUNC_3BUF(prod, int64_t, int64_t, *) -OP_FUNC_3BUF(prod, uint64_t, uint64_t, *) -OP_FUNC_3BUF(prod, long, long, *) -OP_FUNC_3BUF(prod, ulong, unsigned long, *) - -/* Floating point */ -#if defined(HAVE_SHORT_FLOAT) -OP_FUNC_3BUF(prod, short_float, short float, *) -#elif defined(HAVE_OPAL_SHORT_FLOAT_T) -OP_FUNC_3BUF(prod, short_float, opal_short_float_t, *) -#endif -OP_FUNC_3BUF(prod, float, float, *) -OP_FUNC_3BUF(prod, double, double, *) -OP_FUNC_3BUF(prod, long_double, long double, *) +VFUNC_FUNC_3BUF(prod, int8_t, int8_t, 16, vprod, tprod) +VFUNC_FUNC_3BUF(prod, uint8_t, uint8_t, 16, vprod, tprod) +VFUNC_FUNC_3BUF(prod, int16_t, int16_t, 8, vprod, tprod) +VFUNC_FUNC_3BUF(prod, uint16_t, uint16_t, 8, vprod, tprod) +VFUNC_FUNC_3BUF(prod, int32_t, int32_t, 4, vprod, tprod) +VFUNC_FUNC_3BUF(prod, uint32_t, uint32_t, 4, vprod, tprod) +VFUNC_FUNC_3BUF(prod, int64_t, int64_t, 2, vprod, tprod) +VFUNC_FUNC_3BUF(prod, uint64_t, uint64_t, 2, vprod, tprod) +VFUNC_FUNC_3BUF(prod, long, long, 2, vprod, tprod) +VFUNC_FUNC_3BUF(prod, ulong, unsigned long, 2, vprod, tprod) + +/* float */ +VFUNC_FUNC_3BUF(prod, float, float, 4, vprod, tprod) 
+VFUNC_FUNC_3BUF(prod, double, double, 2, vprod, tprod) +VFUNC_FUNC_3BUF(prod, long_double, long double, 1, vprod, tprod) /* Complex */ -#if 0 -#if defined(HAVE_SHORT_FLOAT__COMPLEX) -OP_FUNC_3BUF(prod, c_short_float_complex, short float _Complex, *) -#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) -COMPLEX_PROD_FUNC_3BUF(c_short_float_complex, opal_short_float_t) -#endif -OP_FUNC_3BUF(prod, c_long_double_complex, long double _Complex, *) -#endif // 0 #undef current_func #define current_func(a, b) (cuCmulf(a,b)) FUNC_FUNC_3BUF(prod, c_float_complex, cuFloatComplex) @@ -919,127 +819,97 @@ FUNC_FUNC_3BUF(prod, c_double_complex, cuDoubleComplex) * Logical AND *************************************************************************/ -#undef current_func -#define current_func(a, b) ((a) && (b)) /* C integer */ -FUNC_FUNC_3BUF(land, int8_t, int8_t) -FUNC_FUNC_3BUF(land, uint8_t, uint8_t) -FUNC_FUNC_3BUF(land, int16_t, int16_t) -FUNC_FUNC_3BUF(land, uint16_t, uint16_t) -FUNC_FUNC_3BUF(land, int32_t, int32_t) -FUNC_FUNC_3BUF(land, uint32_t, uint32_t) -FUNC_FUNC_3BUF(land, int64_t, int64_t) -FUNC_FUNC_3BUF(land, uint64_t, uint64_t) -FUNC_FUNC_3BUF(land, long, long) -FUNC_FUNC_3BUF(land, ulong, unsigned long) +VFUNC_FUNC_3BUF(land, int8_t, int8_t, 16, vland, tland) +VFUNC_FUNC_3BUF(land, uint8_t, uint8_t, 16, vland, tland) +VFUNC_FUNC_3BUF(land, int16_t, int16_t, 8, vland, tland) +VFUNC_FUNC_3BUF(land, uint16_t, uint16_t, 8, vland, tland) +VFUNC_FUNC_3BUF(land, int32_t, int32_t, 4, vland, tland) +VFUNC_FUNC_3BUF(land, uint32_t, uint32_t, 4, vland, tland) +VFUNC_FUNC_3BUF(land, int64_t, int64_t, 2, vland, tland) +VFUNC_FUNC_3BUF(land, uint64_t, uint64_t, 2, vland, tland) +VFUNC_FUNC_3BUF(land, long, long, 2, vland, tland) +VFUNC_FUNC_3BUF(land, ulong, unsigned long, 2, vland, tland) /* C++ bool */ -FUNC_FUNC_3BUF(land, bool, bool) +VFUNC_FUNC_3BUF(land, bool, bool, 16, vland, tland) /************************************************************************* * Logical OR 
*************************************************************************/ -#undef current_func -#define current_func(a, b) ((a) || (b)) /* C integer */ -FUNC_FUNC_3BUF(lor, int8_t, int8_t) -FUNC_FUNC_3BUF(lor, uint8_t, uint8_t) -FUNC_FUNC_3BUF(lor, int16_t, int16_t) -FUNC_FUNC_3BUF(lor, uint16_t, uint16_t) -FUNC_FUNC_3BUF(lor, int32_t, int32_t) -FUNC_FUNC_3BUF(lor, uint32_t, uint32_t) -FUNC_FUNC_3BUF(lor, int64_t, int64_t) -FUNC_FUNC_3BUF(lor, uint64_t, uint64_t) -FUNC_FUNC_3BUF(lor, long, long) -FUNC_FUNC_3BUF(lor, ulong, unsigned long) +VFUNC_FUNC_3BUF(lor, int8_t, int8_t, 16, vlor, tlor) +VFUNC_FUNC_3BUF(lor, uint8_t, uint8_t, 16, vlor, tlor) +VFUNC_FUNC_3BUF(lor, int16_t, int16_t, 8, vlor, tlor) +VFUNC_FUNC_3BUF(lor, uint16_t, uint16_t, 8, vlor, tlor) +VFUNC_FUNC_3BUF(lor, int32_t, int32_t, 4, vlor, tlor) +VFUNC_FUNC_3BUF(lor, uint32_t, uint32_t, 4, vlor, tlor) +VFUNC_FUNC_3BUF(lor, int64_t, int64_t, 2, vlor, tlor) +VFUNC_FUNC_3BUF(lor, uint64_t, uint64_t, 2, vlor, tlor) +VFUNC_FUNC_3BUF(lor, long, long, 2, vlor, tlor) +VFUNC_FUNC_3BUF(lor, ulong, unsigned long, 2, vlor, tlor) /* C++ bool */ -FUNC_FUNC_3BUF(lor, bool, bool) +VFUNC_FUNC_3BUF(lor, bool, bool, 16, vlor, tlor) /************************************************************************* * Logical XOR *************************************************************************/ -#undef current_func -#define current_func(a, b) ((a ? 1 : 0) ^ (b ? 
1: 0)) /* C integer */ -FUNC_FUNC_3BUF(lxor, int8_t, int8_t) -FUNC_FUNC_3BUF(lxor, uint8_t, uint8_t) -FUNC_FUNC_3BUF(lxor, int16_t, int16_t) -FUNC_FUNC_3BUF(lxor, uint16_t, uint16_t) -FUNC_FUNC_3BUF(lxor, int32_t, int32_t) -FUNC_FUNC_3BUF(lxor, uint32_t, uint32_t) -FUNC_FUNC_3BUF(lxor, int64_t, int64_t) -FUNC_FUNC_3BUF(lxor, uint64_t, uint64_t) -FUNC_FUNC_3BUF(lxor, long, long) -FUNC_FUNC_3BUF(lxor, ulong, unsigned long) +VFUNC_FUNC_3BUF(lxor, int8_t, int8_t, 16, vlxor, tlxor) +VFUNC_FUNC_3BUF(lxor, uint8_t, uint8_t, 16, vlxor, tlxor) +VFUNC_FUNC_3BUF(lxor, int16_t, int16_t, 8, vlxor, tlxor) +VFUNC_FUNC_3BUF(lxor, uint16_t, uint16_t, 8, vlxor, tlxor) +VFUNC_FUNC_3BUF(lxor, int32_t, int32_t, 4, vlxor, tlxor) +VFUNC_FUNC_3BUF(lxor, uint32_t, uint32_t, 4, vlxor, tlxor) +VFUNC_FUNC_3BUF(lxor, int64_t, int64_t, 2, vlxor, tlxor) +VFUNC_FUNC_3BUF(lxor, uint64_t, uint64_t, 2, vlxor, tlxor) +VFUNC_FUNC_3BUF(lxor, long, long, 2, vlxor, tlxor) +VFUNC_FUNC_3BUF(lxor, ulong, unsigned long, 2, vlxor, tlxor) /* C++ bool */ -FUNC_FUNC_3BUF(lxor, bool, bool) +VFUNC_FUNC_3BUF(lxor, bool, bool, 16, vlxor, tlxor) + /************************************************************************* * Bitwise AND *************************************************************************/ -#undef current_func -#define current_func(a, b) ((a) & (b)) /* C integer */ -FUNC_FUNC_3BUF(band, int8_t, int8_t) -FUNC_FUNC_3BUF(band, uint8_t, uint8_t) -FUNC_FUNC_3BUF(band, int16_t, int16_t) -FUNC_FUNC_3BUF(band, uint16_t, uint16_t) -FUNC_FUNC_3BUF(band, int32_t, int32_t) -FUNC_FUNC_3BUF(band, uint32_t, uint32_t) -FUNC_FUNC_3BUF(band, int64_t, int64_t) -FUNC_FUNC_3BUF(band, uint64_t, uint64_t) -FUNC_FUNC_3BUF(band, long, long) -FUNC_FUNC_3BUF(band, ulong, unsigned long) - -/* Byte */ -FUNC_FUNC_3BUF(band, byte, char) +VFUNC_FUNC_3BUF(band, int8_t, int8_t, 16, vband, tband) +VFUNC_FUNC_3BUF(band, uint8_t, uint8_t, 16, vband, tband) +VFUNC_FUNC_3BUF(band, int16_t, int16_t, 8, vband, tband) 
+VFUNC_FUNC_3BUF(band, uint16_t, uint16_t, 8, vband, tband) +VFUNC_FUNC_3BUF(band, int32_t, int32_t, 4, vband, tband) +VFUNC_FUNC_3BUF(band, uint32_t, uint32_t, 4, vband, tband) +VFUNC_FUNC_3BUF(band, int64_t, int64_t, 2, vband, tband) +VFUNC_FUNC_3BUF(band, uint64_t, uint64_t, 2, vband, tband) +VFUNC_FUNC_3BUF(band, long, long, 2, vband, tband) +VFUNC_FUNC_3BUF(band, ulong, unsigned long, 2, vband, tband) + +/* C++ byte */ +VFUNC_FUNC_3BUF(band, byte, char, 16, vband, tband) /************************************************************************* * Bitwise OR *************************************************************************/ -#undef current_func -#define current_func(a, b) ((a) | (b)) -/* C integer */ -FUNC_FUNC_3BUF(bor, int8_t, int8_t) -FUNC_FUNC_3BUF(bor, uint8_t, uint8_t) -FUNC_FUNC_3BUF(bor, int16_t, int16_t) -FUNC_FUNC_3BUF(bor, uint16_t, uint16_t) -FUNC_FUNC_3BUF(bor, int32_t, int32_t) -FUNC_FUNC_3BUF(bor, uint32_t, uint32_t) -FUNC_FUNC_3BUF(bor, int64_t, int64_t) -FUNC_FUNC_3BUF(bor, uint64_t, uint64_t) -FUNC_FUNC_3BUF(bor, long, long) -FUNC_FUNC_3BUF(bor, ulong, unsigned long) - -/* Byte */ -FUNC_FUNC_3BUF(bor, byte, char) - -/************************************************************************* - * Bitwise XOR - *************************************************************************/ - -#undef current_func -#define current_func(a, b) ((a) ^ (b)) /* C integer */ -FUNC_FUNC_3BUF(bxor, int8_t, int8_t) -FUNC_FUNC_3BUF(bxor, uint8_t, uint8_t) -FUNC_FUNC_3BUF(bxor, int16_t, int16_t) -FUNC_FUNC_3BUF(bxor, uint16_t, uint16_t) -FUNC_FUNC_3BUF(bxor, int32_t, int32_t) -FUNC_FUNC_3BUF(bxor, uint32_t, uint32_t) -FUNC_FUNC_3BUF(bxor, int64_t, int64_t) -FUNC_FUNC_3BUF(bxor, uint64_t, uint64_t) -FUNC_FUNC_3BUF(bxor, long, long) -FUNC_FUNC_3BUF(bxor, ulong, unsigned long) - -/* Byte */ -FUNC_FUNC_3BUF(bxor, byte, char) +VFUNC_FUNC_3BUF(bor, int8_t, int8_t, 16, vbor, tbor) +VFUNC_FUNC_3BUF(bor, uint8_t, uint8_t, 16, vbor, tbor) +VFUNC_FUNC_3BUF(bor, 
int16_t, int16_t, 8, vbor, tbor) +VFUNC_FUNC_3BUF(bor, uint16_t, uint16_t, 8, vbor, tbor) +VFUNC_FUNC_3BUF(bor, int32_t, int32_t, 4, vbor, tbor) +VFUNC_FUNC_3BUF(bor, uint32_t, uint32_t, 4, vbor, tbor) +VFUNC_FUNC_3BUF(bor, int64_t, int64_t, 2, vbor, tbor) +VFUNC_FUNC_3BUF(bor, uint64_t, uint64_t, 2, vbor, tbor) +VFUNC_FUNC_3BUF(bor, long, long, 2, vbor, tbor) +VFUNC_FUNC_3BUF(bor, ulong, unsigned long, 2, vbor, tbor) + +/* C++ byte */ +VFUNC_FUNC_3BUF(bor, byte, char, 16, vbor, tbor) + +/************************************************************************* + * Bitwise XOR (three-buffer variants; re-added because this hunk removes the old FUNC_FUNC_3BUF(bxor, ...) list) + *************************************************************************/ + +/* C integer */ +VFUNC_FUNC_3BUF(bxor, int8_t, int8_t, 16, vbxor, tbxor) +VFUNC_FUNC_3BUF(bxor, uint8_t, uint8_t, 16, vbxor, tbxor) +VFUNC_FUNC_3BUF(bxor, int16_t, int16_t, 8, vbxor, tbxor) +VFUNC_FUNC_3BUF(bxor, uint16_t, uint16_t, 8, vbxor, tbxor) +VFUNC_FUNC_3BUF(bxor, int32_t, int32_t, 4, vbxor, tbxor) +VFUNC_FUNC_3BUF(bxor, uint32_t, uint32_t, 4, vbxor, tbxor) +VFUNC_FUNC_3BUF(bxor, int64_t, int64_t, 2, vbxor, tbxor) +VFUNC_FUNC_3BUF(bxor, uint64_t, uint64_t, 2, vbxor, tbxor) +VFUNC_FUNC_3BUF(bxor, long, long, 2, vbxor, tbxor) +VFUNC_FUNC_3BUF(bxor, ulong, unsigned long, 2, vbxor, tbxor) + +/* C++ byte */ +VFUNC_FUNC_3BUF(bxor, byte, char, 16, vbxor, tbxor) /************************************************************************* * Max location From 37c5dad85aa6507f44dde3586d753a959b1e9163 Mon Sep 17 00:00:00 2001 From: Joseph Schuchart Date: Thu, 20 Jun 2024 18:17:27 -0400 Subject: [PATCH 06/12] op/cuda: cleanup and remove short float remnants Signed-off-by: Joseph Schuchart --- ompi/mca/op/cuda/op_cuda_functions.c | 104 +-------------------------- ompi/mca/op/cuda/op_cuda_impl.h | 92 ------------------------ 2 files changed, 3 insertions(+), 193 deletions(-) diff --git a/ompi/mca/op/cuda/op_cuda_functions.c b/ompi/mca/op/cuda/op_cuda_functions.c index 27361cee6a3..b3a1e3d18ad 100644 --- a/ompi/mca/op/cuda/op_cuda_functions.c +++ b/ompi/mca/op/cuda/op_cuda_functions.c @@ -344,14 +344,6 @@ FORT_INT_FUNC(max, fortran_integer8, ompi_fortran_integer8_t) FORT_INT_FUNC(max, fortran_integer16, ompi_fortran_integer16_t) #endif -#if 0 -/* Floating point */ -#if defined(HAVE_SHORT_FLOAT) -FUNC_FUNC(max, short_float, short float) -#elif defined(HAVE_OPAL_SHORT_FLOAT_T) -FUNC_FUNC(max, short_float, opal_short_float_t) -#endif -#endif // 0 FUNC_FUNC(max, float, float) FUNC_FUNC(max, double, double) FUNC_FUNC(max, long_double, long double) @@ -411,15 +403,6 @@ FORT_INT_FUNC(min, fortran_integer8, ompi_fortran_integer8_t) FORT_INT_FUNC(min, fortran_integer16, ompi_fortran_integer16_t) #endif -#if 0 -/* Floating point */ -#if defined(HAVE_SHORT_FLOAT) -FUNC_FUNC(min, short_float, short float) -#elif defined(HAVE_OPAL_SHORT_FLOAT_T) -FUNC_FUNC(min, 
short_float, opal_short_float_t) -#endif -#endif // 0 - FUNC_FUNC(min, float, float) FUNC_FUNC(min, double, double) FUNC_FUNC(min, long_double, long double) @@ -478,15 +461,6 @@ FORT_INT_FUNC(sum, fortran_integer8, ompi_fortran_integer8_t) FORT_INT_FUNC(sum, fortran_integer16, ompi_fortran_integer16_t) #endif -#if 0 -/* Floating point */ -#if defined(HAVE_SHORT_FLOAT) -OP_FUNC(sum, short_float, short float) -#elif defined(HAVE_OPAL_SHORT_FLOAT_T) -OP_FUNC(sum, short_float, opal_short_float_t) -#endif -#endif // 0 - OP_FUNC(sum, float, float) OP_FUNC(sum, double, double) OP_FUNC(sum, long_double, long double) @@ -508,16 +482,8 @@ FORT_FLOAT_FUNC(sum, fortran_real8, ompi_fortran_real8_t) #if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C FORT_FLOAT_FUNC(sum, fortran_real16, ompi_fortran_real16_t) #endif -/* Complex */ -#if 0 -#if defined(HAVE_SHORT_FLOAT__COMPLEX) -OP_FUNC(sum, c_short_float_complex, short float _Complex) -#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) -COMPLEX_SUM_FUNC(c_short_float_complex, opal_short_float_t) -#endif -OP_FUNC(sum, c_long_double_complex, long double _Complex) -#endif // 0 +/* Complex */ FUNC_FUNC(sum, c_float_complex, cuFloatComplex) FUNC_FUNC(sum, c_double_complex, cuDoubleComplex) @@ -556,16 +522,8 @@ FORT_INT_FUNC(prod, fortran_integer8, ompi_fortran_integer8_t) #if OMPI_HAVE_FORTRAN_INTEGER16 FORT_INT_FUNC(prod, fortran_integer16, ompi_fortran_integer16_t) #endif -/* Floating point */ - -#if 0 -#if defined(HAVE_SHORT_FLOAT) -OP_FUNC(prod, short_float, short float) -#elif defined(HAVE_OPAL_SHORT_FLOAT_T) -OP_FUNC(prod, short_float, opal_short_float_t) -#endif -#endif // 0 +/* Floating point */ OP_FUNC(prod, float, float) OP_FUNC(prod, double, double) OP_FUNC(prod, long_double, long double) @@ -587,16 +545,8 @@ FORT_FLOAT_FUNC(prod, fortran_real8, ompi_fortran_real8_t) #if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C FORT_FLOAT_FUNC(prod, fortran_real16, ompi_fortran_real16_t) #endif -/* Complex */ -#if 0 -#if 
defined(HAVE_SHORT_FLOAT__COMPLEX) -OP_FUNC(prod, c_short_float_complex, short float _Complex) -#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) -COMPLEX_PROD_FUNC(c_short_float_complex, opal_short_float_t) -#endif -OP_FUNC(prod, c_long_double_complex, long double _Complex) -#endif // 0 +/* Complex */ FUNC_FUNC(prod, c_float_complex, cuFloatComplex) FUNC_FUNC(prod, c_double_complex, cuDoubleComplex) @@ -1016,13 +966,6 @@ FORT_INT_FUNC_3BUF(max, fortran_integer4, ompi_fortran_integer4_t) FORT_INT_FUNC_3BUF(max, fortran_integer8, ompi_fortran_integer8_t) #endif /* Floating point */ -#if 0 -#if defined(HAVE_SHORT_FLOAT) -FUNC_FUNC_3BUF(max, short_float, short float) -#elif defined(HAVE_OPAL_SHORT_FLOAT_T) -FUNC_FUNC_3BUF(max, short_float, opal_short_float_t) -#endif -#endif // 0 FUNC_FUNC_3BUF(max, float, float) FUNC_FUNC_3BUF(max, double, double) FUNC_FUNC_3BUF(max, long_double, long double) @@ -1082,13 +1025,6 @@ FORT_INT_FUNC_3BUF(min, fortran_integer8, ompi_fortran_integer8_t) FORT_INT_FUNC_3BUF(min, fortran_integer16, ompi_fortran_integer16_t) #endif /* Floating point */ -#if 0 -#if defined(HAVE_SHORT_FLOAT) -FUNC_FUNC_3BUF(min, short_float, short float) -#elif defined(HAVE_OPAL_SHORT_FLOAT_T) -FUNC_FUNC_3BUF(min, short_float, opal_short_float_t) -#endif -#endif // 0 FUNC_FUNC_3BUF(min, float, float) FUNC_FUNC_3BUF(min, double, double) FUNC_FUNC_3BUF(min, long_double, long double) @@ -1147,13 +1083,6 @@ FORT_INT_FUNC_3BUF(sum, fortran_integer8, ompi_fortran_integer8_t) FORT_INT_FUNC_3BUF(sum, fortran_integer16, ompi_fortran_integer16_t) #endif /* Floating point */ -#if 0 -#if defined(HAVE_SHORT_FLOAT) -OP_FUNC_3BUF(sum, short_float, short float) -#elif defined(HAVE_OPAL_SHORT_FLOAT_T) -OP_FUNC_3BUF(sum, short_float, opal_short_float_t) -#endif -#endif // 0 OP_FUNC_3BUF(sum, float, float) OP_FUNC_3BUF(sum, double, double) OP_FUNC_3BUF(sum, long_double, long double) @@ -1176,17 +1105,6 @@ FORT_FLOAT_FUNC_3BUF(sum, fortran_real8, ompi_fortran_real8_t) 
FORT_FLOAT_FUNC_3BUF(sum, fortran_real16, ompi_fortran_real16_t) #endif /* Complex */ -#if 0 -#if defined(HAVE_SHORT_FLOAT__COMPLEX) -OP_FUNC_3BUF(sum, c_short_float_complex, short float _Complex) -#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) -COMPLEX_SUM_FUNC_3BUF(c_short_float_complex, opal_short_float_t) -#endif -OP_FUNC_3BUF(sum, c_float_complex, float _Complex) -OP_FUNC_3BUF(sum, c_double_complex, double _Complex) -OP_FUNC_3BUF(sum, c_long_double_complex, long double _Complex) -#endif // 0 - FUNC_FUNC_3BUF(sum, c_float_complex, cuFloatComplex) FUNC_FUNC_3BUF(sum, c_double_complex, cuDoubleComplex) @@ -1226,13 +1144,6 @@ FORT_INT_FUNC_3BUF(prod, fortran_integer8, ompi_fortran_integer8_t) FORT_INT_FUNC_3BUF(prod, fortran_integer16, ompi_fortran_integer16_t) #endif /* Floating point */ -#if 0 -#if defined(HAVE_SHORT_FLOAT) -FORT_FLOAT_FUNC_3BUF(prod, short_float, short float) -#elif defined(HAVE_OPAL_SHORT_FLOAT_T) -FORT_FLOAT_FUNC_3BUF(prod, short_float, opal_short_float_t) -#endif -#endif // 0 OP_FUNC_3BUF(prod, float, float) OP_FUNC_3BUF(prod, double, double) OP_FUNC_3BUF(prod, long_double, long double) @@ -1255,15 +1166,6 @@ FORT_FLOAT_FUNC_3BUF(prod, fortran_real8, ompi_fortran_real8_t) FORT_FLOAT_FUNC_3BUF(prod, fortran_real16, ompi_fortran_real16_t) #endif /* Complex */ -#if 0 -#if defined(HAVE_SHORT_FLOAT__COMPLEX) -OP_FUNC_3BUF(prod, c_short_float_complex, short float _Complex) -#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) -COMPLEX_PROD_FUNC_3BUF(c_short_float_complex, opal_short_float_t) -#endif -OP_FUNC_3BUF(prod, c_long_double_complex, long double _Complex) -#endif // 0 - FUNC_FUNC_3BUF(prod, c_float_complex, cuFloatComplex) FUNC_FUNC_3BUF(prod, c_double_complex, cuDoubleComplex) diff --git a/ompi/mca/op/cuda/op_cuda_impl.h b/ompi/mca/op/cuda/op_cuda_impl.h index 43209581bab..3eb63daa32f 100644 --- a/ompi/mca/op/cuda/op_cuda_impl.h +++ b/ompi/mca/op/cuda/op_cuda_impl.h @@ -82,15 +82,7 @@ FUNC_FUNC_SIG(max, uint64_t, uint64_t) FUNC_FUNC_SIG(max, 
long, long) FUNC_FUNC_SIG(max, ulong, unsigned long) -#if 0 /* Floating point */ -#if defined(HAVE_SHORT_FLOAT) -FUNC_FUNC_SIG(max, short_float, short float) -#elif defined(HAVE_OPAL_SHORT_FLOAT_T) -FUNC_FUNC_SIG(max, short_float, opal_short_float_t) -#endif -#endif // 0 - FUNC_FUNC_SIG(max, float, float) FUNC_FUNC_SIG(max, double, double) FUNC_FUNC_SIG(max, long_double, long double) @@ -111,15 +103,7 @@ FUNC_FUNC_SIG(min, uint64_t, uint64_t) FUNC_FUNC_SIG(min, long, long) FUNC_FUNC_SIG(min, ulong, unsigned long) -#if 0 /* Floating point */ -#if defined(HAVE_SHORT_FLOAT) -FUNC_FUNC_SIG(min, short_float, short float) -#elif defined(HAVE_OPAL_SHORT_FLOAT_T) -FUNC_FUNC_SIG(min, short_float, opal_short_float_t) -#endif -#endif // 0 - FUNC_FUNC_SIG(min, float, float) FUNC_FUNC_SIG(min, double, double) FUNC_FUNC_SIG(min, long_double, long double) @@ -140,32 +124,13 @@ OP_FUNC_SIG(sum, uint64_t, uint64_t) OP_FUNC_SIG(sum, long, long) OP_FUNC_SIG(sum, ulong, unsigned long) -//#if __CUDA_ARCH__ >= 530 -//OP_FUNC_SIG(sum, half, half) -//#endif // __CUDA_ARCH__ -#if 0 /* Floating point */ -#if defined(HAVE_SHORT_FLOAT) -OP_FUNC_SIG(sum, short_float, short float) -#elif defined(HAVE_OPAL_SHORT_FLOAT_T) -OP_FUNC_SIG(sum, short_float, opal_short_float_t) -#endif -#endif // 0 - OP_FUNC_SIG(sum, float, float) OP_FUNC_SIG(sum, double, double) OP_FUNC_SIG(sum, long_double, long double) /* Complex */ -#if 0 -#if defined(HAVE_SHORT_FLOAT__COMPLEX) -OP_FUNC_SIG(sum, c_short_float_complex, short float _Complex) -#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) -COMPLEX_SUM_FUNC(c_short_float_complex, opal_short_float_t) -OP_FUNC_SIG(sum, c_long_double_complex, long double _Complex) -#endif -#endif // 0 FUNC_FUNC_SIG(sum, c_float_complex, cuFloatComplex) FUNC_FUNC_SIG(sum, c_double_complex, cuDoubleComplex) @@ -185,30 +150,11 @@ OP_FUNC_SIG(prod, uint64_t, uint64_t) OP_FUNC_SIG(prod, long, long) OP_FUNC_SIG(prod, ulong, unsigned long) -#if 0 /* Floating point */ -#if 
defined(HAVE_SHORT_FLOAT) -OP_FUNC_SIG(prod, short_float, short float) -#elif defined(HAVE_OPAL_SHORT_FLOAT_T) -OP_FUNC_SIG(prod, short_float, opal_short_float_t) -#endif -#endif // 0 - -OP_FUNC_SIG(prod, float, float) OP_FUNC_SIG(prod, float, float) OP_FUNC_SIG(prod, double, double) OP_FUNC_SIG(prod, long_double, long double) -/* Complex */ -#if 0 -#if defined(HAVE_SHORT_FLOAT__COMPLEX) -OP_FUNC_SIG(prod, c_short_float_complex, short float _Complex) -#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) -COMPLEX_PROD_FUNC(c_short_float_complex, opal_short_float_t) -#endif -OP_FUNC_SIG(prod, c_long_double_complex, long double _Complex) -#endif // 0 - FUNC_FUNC_SIG(prod, c_float_complex, cuFloatComplex) FUNC_FUNC_SIG(prod, c_double_complex, cuDoubleComplex) @@ -428,11 +374,6 @@ FUNC_FUNC_3BUF_SIG(max, long, long) FUNC_FUNC_3BUF_SIG(max, ulong, unsigned long) /* Floating point */ -#if defined(HAVE_SHORT_FLOAT) -FUNC_FUNC_3BUF_SIG(max, short_float, short float) -#elif defined(HAVE_OPAL_SHORT_FLOAT_T) -FUNC_FUNC_3BUF_SIG(max, short_float, opal_short_float_t) -#endif FUNC_FUNC_3BUF_SIG(max, float, float) FUNC_FUNC_3BUF_SIG(max, double, double) FUNC_FUNC_3BUF_SIG(max, long_double, long double) @@ -454,11 +395,6 @@ FUNC_FUNC_3BUF_SIG(min, long, long) FUNC_FUNC_3BUF_SIG(min, ulong, unsigned long) /* Floating point */ -#if defined(HAVE_SHORT_FLOAT) -FUNC_FUNC_3BUF_SIG(min, short_float, short float) -#elif defined(HAVE_OPAL_SHORT_FLOAT_T) -FUNC_FUNC_3BUF_SIG(min, short_float, opal_short_float_t) -#endif FUNC_FUNC_3BUF_SIG(min, float, float) FUNC_FUNC_3BUF_SIG(min, double, double) FUNC_FUNC_3BUF_SIG(min, long_double, long double) @@ -480,24 +416,11 @@ OP_FUNC_3BUF_SIG(sum, long, long) OP_FUNC_3BUF_SIG(sum, ulong, unsigned long) /* Floating point */ -#if defined(HAVE_SHORT_FLOAT) -OP_FUNC_3BUF_SIG(sum, short_float, short float) -#elif defined(HAVE_OPAL_SHORT_FLOAT_T) -OP_FUNC_3BUF_SIG(sum, short_float, opal_short_float_t) -#endif OP_FUNC_3BUF_SIG(sum, float, float) 
OP_FUNC_3BUF_SIG(sum, double, double) OP_FUNC_3BUF_SIG(sum, long_double, long double) /* Complex */ -#if 0 -#if defined(HAVE_SHORT_FLOAT__COMPLEX) -OP_FUNC_3BUF_SIG(sum, c_short_float_complex, short float _Complex) -#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) -COMPLEX_SUM_FUNC_3BUF(c_short_float_complex, opal_short_float_t) -#endif -OP_FUNC_3BUF_SIG(sum, c_long_double_complex, long double _Complex) -#endif // 0 FUNC_FUNC_3BUF_SIG(sum, c_float_complex, cuFloatComplex) FUNC_FUNC_3BUF_SIG(sum, c_double_complex, cuDoubleComplex) @@ -518,26 +441,11 @@ OP_FUNC_3BUF_SIG(prod, long, long) OP_FUNC_3BUF_SIG(prod, ulong, unsigned long) /* Floating point */ -#if defined(HAVE_SHORT_FLOAT) -OP_FUNC_3BUF_SIG(prod, short_float, short float) -#elif defined(HAVE_OPAL_SHORT_FLOAT_T) -OP_FUNC_3BUF_SIG(prod, short_float, opal_short_float_t) -#endif OP_FUNC_3BUF_SIG(prod, float, float) OP_FUNC_3BUF_SIG(prod, double, double) OP_FUNC_3BUF_SIG(prod, long_double, long double) /* Complex */ -#if 0 -#if defined(HAVE_SHORT_FLOAT__COMPLEX) -OP_FUNC_3BUF_SIG(prod, c_short_float_complex, short float _Complex) -#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) -COMPLEX_PROD_FUNC_3BUF(c_short_float_complex, opal_short_float_t) -#endif -OP_FUNC_3BUF_SIG(prod, c_float_complex, float _Complex) -OP_FUNC_3BUF_SIG(prod, c_double_complex, double _Complex) -OP_FUNC_3BUF_SIG(prod, c_long_double_complex, long double _Complex) -#endif // 0 FUNC_FUNC_3BUF_SIG(prod, c_float_complex, cuFloatComplex) FUNC_FUNC_3BUF_SIG(prod, c_double_complex, cuDoubleComplex) From 4d4d6291173708702cdb0038192cf98a66da904f Mon Sep 17 00:00:00 2001 From: Joseph Schuchart Date: Fri, 21 Jun 2024 11:07:15 -0400 Subject: [PATCH 07/12] Add LDFLAGS to op/rocm linker command Signed-off-by: Joseph Schuchart --- ompi/mca/op/rocm/Makefile.am | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ompi/mca/op/rocm/Makefile.am b/ompi/mca/op/rocm/Makefile.am index a4d999e25f9..93a110c7e23 100644 --- a/ompi/mca/op/rocm/Makefile.am 
+++ b/ompi/mca/op/rocm/Makefile.am @@ -65,7 +65,7 @@ mcacomponent_LTLIBRARIES = $(component_install) mca_op_rocm_la_SOURCES = $(sources) mca_op_rocm_la_LIBADD = $(rocm_sources:.cpp=.lo) mca_op_rocm_la_LDFLAGS = -module -avoid-version $(top_builddir)/ompi/lib@OMPI_LIBMPI_NAME@.la \ - $(op_rocm_LIBS) + $(op_rocm_LDFLAGS) $(op_rocm_LIBS) EXTRA_mca_op_rocm_la_SOURCES = $(rocm_sources) # Specific information for static builds. @@ -77,6 +77,6 @@ noinst_LTLIBRARIES = $(component_noinst) libmca_op_rocm_la_SOURCES = $(sources) libmca_op_rocm_la_LIBADD = $(rocm_sources:.cpp=.lo) libmca_op_rocm_la_LDFLAGS = -module -avoid-version\ - $(op_rocm_LIBS) + $(op_rocm_LDFLAGS) $(op_rocm_LIBS) EXTRA_libmca_op_rocm_la_SOURCES = $(rocm_sources) From 9fe635103ac11bcf451a526e65a731e84b16b96d Mon Sep 17 00:00:00 2001 From: Joseph Schuchart Date: Fri, 5 Jul 2024 16:32:16 -0400 Subject: [PATCH 08/12] First attempt to check for NVCC Signed-off-by: Joseph Schuchart --- config/opal_check_nvcc.m4 | 74 +++++++++++++++++++++++++++++++++++ ompi/mca/op/cuda/Makefile.am | 10 ++--- ompi/mca/op/cuda/configure.m4 | 6 ++- 3 files changed, 83 insertions(+), 7 deletions(-) create mode 100644 config/opal_check_nvcc.m4 diff --git a/config/opal_check_nvcc.m4 b/config/opal_check_nvcc.m4 new file mode 100644 index 00000000000..6443cdd01a2 --- /dev/null +++ b/config/opal_check_nvcc.m4 @@ -0,0 +1,74 @@ +dnl -*- autoconf -*- +dnl +dnl Copyright (c) 2020-2022 Cisco Systems, Inc. All rights reserved. +dnl Copyright (c) 2024 Jeffrey M. Squyres. All rights reserved. +dnl +dnl $COPYRIGHT$ +dnl +dnl Additional copyrights may follow +dnl +dnl $HEADER$ +dnl + +dnl Setup Sphinx for building HTML docs and man pages +dnl +dnl 1 -> sanity file to check if pre-built docs are already available +dnl You probably want to pass something like +dnl "$srcdir/docs/_build/man/foo.1" +dnl +dnl 2 -> (OPTIONAL) URL to display in AC_MSG_WARN when docs will not be installed +dnl If $2 is empty, nothing will be displayed. 
+dnl Note: if $2 contains a #, be sure to double quote it +dnl (e.g., [[https://example.com/foo.html#some-anchor]]) +dnl +dnl 3 -> (OPTIONAL) Filename of requirements.txt-like file containing +dnl the required Pip modules (to be displayed if rendering a +dnl simple RST project fails). +dnl +dnl This macro requires that OAC_PUSH_PREFIX was previously called. +dnl The pushed prefix may be used if this macro chooses to set {OAC +dnl prefix}_MAKEDIST_DISABLE. If set, it is a message indicating why +dnl "make dist" should be disabled, suitable for emitting via +dnl AC_MSG_WARN. +AC_DEFUN([OPAL_CHECK_NVCC],[ + + # This option is probably only helpful to developers: have + # configure fail if Sphinx is not found (i.e., if you don't have + # the ability to use Sphinx to build the HTML docs and man pages). + AC_ARG_ENABLE([nvcc], + [AS_HELP_STRING([--enable-nvcc], + [Force configure to fail if CUDA nvcc is not found (CUDA nvcc is used to build CUDA operator support).])]) + + AC_ARG_WITH([nvcc], + [AS_HELP_STRING([--with-nvcc=DIR], + [Path to the CUDA compiler])]) + + AC_ARG_WITH([nvcc_compute_arch], + [AS_HELP_STRING([--with-nvcc-compute-arch=ARCH], + [Compute architecture to use for CUDA (default: 52)])]) + + AS_IF([test -n "$with_nvcc"], + [OPAL_NVCC=$with_nvcc], + # no path specified, try to find nvcc + [AC_PATH_PROG([OPAL_NVCC], [nvcc], [])]) + + # If the user requested to disable sphinx, then pretend we didn't + # find it. + AS_IF([test "$enable_nvcc" = "no"], + [OPAL_NVCC=]) + + # default to CUDA compute architecture 52 + AS_IF([test -n "$with_nvcc_compute_arch"], + [OPAL_NVCC_COMPUTE_ARCH=$with_nvcc_compute_arch], + [OPAL_NVCC_COMPUTE_ARCH=52]) + + # If --enable-sphinx was specified and we did not find Sphinx, + # abort. This is likely only useful to prevent "oops!" moments + # from developers. 
+ AS_IF([test -z "$OPAL_NVCC" && test "$enable_nvcc" = "yes"], + [AC_MSG_WARN([A suitable CUDA compiler was not found, but --enable-nvcc was specified]) + AC_MSG_ERROR([Cannot continue])]) + + OPAL_SUMMARY_ADD([Accelerators], [NVCC compiler], [], [$OPAL_NVCC (compute arch: $OPAL_NVCC_COMPUTE_ARCH)]) + +]) diff --git a/ompi/mca/op/cuda/Makefile.am b/ompi/mca/op/cuda/Makefile.am index 7075d26301c..dc776220f6e 100644 --- a/ompi/mca/op/cuda/Makefile.am +++ b/ompi/mca/op/cuda/Makefile.am @@ -22,17 +22,15 @@ AM_CPPFLAGS = $(op_cuda_CPPFLAGS) $(op_cudart_CPPFLAGS) dist_ompidata_DATA = help-ompi-mca-op-cuda.txt sources = op_cuda_component.c op_cuda.h op_cuda_functions.c op_cuda_impl.h -#sources_extended = op_cuda_functions.cu cu_sources = op_cuda_impl.cu -NVCC = nvcc -g -NVCCFLAGS= --std c++17 --gpu-architecture=compute_52 +NVCCFLAGS = --std c++17 --gpu-architecture=compute_$(OPAL_NVCC_COMPUTE_ARCH) +# let the underlying compiler generate PIC code +PIC_FLAGS = -prefer-non-pic -Wc,-Xcompiler,-fPIC .cu.l$(OBJEXT): $(LIBTOOL) $(AM_V_lt) --tag=CC $(AM_LIBTOOLFLAGS) \ - $(LIBTOOLFLAGS) --mode=compile $(NVCC) -prefer-non-pic $(NVCCFLAGS) -Wc,-Xcompiler,-fPIC,-g -c $< - -# -o $($@.o:.lo) + $(LIBTOOLFLAGS) --mode=compile $(OPAL_NVCC) $(NVCCFLAGS) $(PIC_FLAGS) -c $< # Open MPI components can be compiled two ways: # diff --git a/ompi/mca/op/cuda/configure.m4 b/ompi/mca/op/cuda/configure.m4 index 0974e3aaf31..15b3fd7c2d1 100644 --- a/ompi/mca/op/cuda/configure.m4 +++ b/ompi/mca/op/cuda/configure.m4 @@ -25,8 +25,9 @@ AC_DEFUN([MCA_ompi_op_cuda_CONFIG],[ OPAL_CHECK_CUDA([op_cuda]) OPAL_CHECK_CUDART([op_cudart]) + OPAL_CHECK_NVCC([op_nvcc]) - AS_IF([test "x$CUDA_SUPPORT" = "x1"], + AS_IF([test "x$CUDA_SUPPORT" = "x1" -a "x$CUDART_SUPPORT" = "x1" -a -n "$OPAL_NVCC"], [$1], [$2]) @@ -38,4 +39,7 @@ AC_DEFUN([MCA_ompi_op_cuda_CONFIG],[ AC_SUBST([op_cudart_LDFLAGS]) AC_SUBST([op_cudart_LIBS]) + AC_SUBST([OPAL_NVCC]) + AC_SUBST([OPAL_NVCC_COMPUTE_ARCH]) + ])dnl From 
60cc5aa4e216b0e0ca5893ceabbdfc72296310bf Mon Sep 17 00:00:00 2001 From: Joseph Schuchart Date: Mon, 8 Jul 2024 11:03:35 -0400 Subject: [PATCH 09/12] Add check for hipcc Signed-off-by: Joseph Schuchart --- config/opal_check_hipcc.m4 | 51 +++++++++++++++++++++++++++++++++++ config/opal_check_nvcc.m4 | 29 +++++++------------- ompi/mca/op/rocm/Makefile.am | 2 +- ompi/mca/op/rocm/configure.m4 | 5 +++- 4 files changed, 65 insertions(+), 22 deletions(-) create mode 100644 config/opal_check_hipcc.m4 diff --git a/config/opal_check_hipcc.m4 b/config/opal_check_hipcc.m4 new file mode 100644 index 00000000000..568a7854ad2 --- /dev/null +++ b/config/opal_check_hipcc.m4 @@ -0,0 +1,51 @@ +dnl -*- autoconf -*- +dnl +dnl Copyright (c) 2024 Stony Brook University. All rights reserved. +dnl +dnl $COPYRIGHT$ +dnl +dnl Additional copyrights may follow +dnl +dnl $HEADER$ +dnl + +dnl +dnl Check for HIPCC and bail out if HIPCC was requested +dnl Options provided: +dnl --with-hipcc[=path/to/hipcc]: provide a path to HIPCC +dnl --enable-hipcc: require HIPCC, bail out if not found +dnl + +AC_DEFUN([OPAL_CHECK_HIPCC],[ + + # This option is probably only helpful to developers: have + # configure fail if Sphinx is not found (i.e., if you don't have + # the ability to use Sphinx to build the HTML docs and man pages). + AC_ARG_ENABLE([hipcc], + [AS_HELP_STRING([--enable-hipcc], + [Force configure to fail if CUDA hipcc is not found (CUDA hipcc is used to build CUDA operator support).])]) + + AC_ARG_WITH([hipcc], + [AS_HELP_STRING([--with-hipcc=DIR], + [Path to the CUDA compiler])]) + + AS_IF([test -n "$with_hipcc"], + [OPAL_HIPCC=$with_hipcc], + # no path specified, try to find hipcc + [AC_PATH_PROG([OPAL_HIPCC], [hipcc], [])]) + + # If the user requested to disable sphinx, then pretend we didn't + # find it. + AS_IF([test "$enable_hipcc" = "no"], + [OPAL_HIPCC=]) + + # If --enable-sphinx was specified and we did not find Sphinx, + # abort. This is likely only useful to prevent "oops!" 
moments + # from developers. + AS_IF([test -z "$OPAL_HIPCC" && test "$enable_hipcc" = "yes"], + [AC_MSG_WARN([A suitable CUDA compiler was not found, but --enable-hipcc was specified]) + AC_MSG_ERROR([Cannot continue])]) + + OPAL_SUMMARY_ADD([Accelerators], [HIPCC compiler], [], [$OPAL_HIPCC]) + +]) diff --git a/config/opal_check_nvcc.m4 b/config/opal_check_nvcc.m4 index 6443cdd01a2..fa29151bebc 100644 --- a/config/opal_check_nvcc.m4 +++ b/config/opal_check_nvcc.m4 @@ -1,7 +1,6 @@ dnl -*- autoconf -*- dnl -dnl Copyright (c) 2020-2022 Cisco Systems, Inc. All rights reserved. -dnl Copyright (c) 2024 Jeffrey M. Squyres. All rights reserved. +dnl Copyright (c) 2024 Stony Brook University. All rights reserved. dnl dnl $COPYRIGHT$ dnl @@ -10,26 +9,16 @@ dnl dnl $HEADER$ dnl -dnl Setup Sphinx for building HTML docs and man pages dnl -dnl 1 -> sanity file to check if pre-built docs are already available -dnl You probably want to pass something like -dnl "$srcdir/docs/_build/man/foo.1" +dnl Check for NVCC and bail out if NVCC was requested +dnl Options provided: +dnl --with-nvcc[=path/to/nvcc]: provide a path to NVCC +dnl --enable-nvcc: require NVCC, bail out if not found +dnl --nvcc-compute-arch: request a specific compute +dnl architecture for the operator +dnl kernels dnl -dnl 2 -> (OPTIONAL) URL to display in AC_MSG_WARN when docs will not be installed -dnl If $2 is empty, nothing will be displayed. -dnl Note: if $2 contains a #, be sure to double quote it -dnl (e.g., [[https://example.com/foo.html#some-anchor]]) -dnl -dnl 3 -> (OPTIONAL) Filename of requirements.txt-like file containing -dnl the required Pip modules (to be displayed if rendering a -dnl simple RST project fails). -dnl -dnl This macro requires that OAC_PUSH_PREFIX was previously called. -dnl The pushed prefix may be used if this macro chooses to set {OAC -dnl prefix}_MAKEDIST_DISABLE. If set, it is a message indicating why -dnl "make dist" should be disabled, suitable for emitting via -dnl AC_MSG_WARN. 
+ AC_DEFUN([OPAL_CHECK_NVCC],[ # This option is probably only helpful to developers: have diff --git a/ompi/mca/op/rocm/Makefile.am b/ompi/mca/op/rocm/Makefile.am index 93a110c7e23..c6177374178 100644 --- a/ompi/mca/op/rocm/Makefile.am +++ b/ompi/mca/op/rocm/Makefile.am @@ -26,7 +26,7 @@ rocm_sources = op_rocm_impl.hip HIPCC = hipcc -.cpp.l$(OBJEXT): +.hip.l$(OBJEXT): $(LIBTOOL) $(AM_V_lt) --tag=CC $(AM_LIBTOOLFLAGS) \ $(LIBTOOLFLAGS) --mode=compile $(HIPCC) -O2 -std=c++17 -fvectorize -prefer-non-pic -Wc,-fPIC,-g -c $< diff --git a/ompi/mca/op/rocm/configure.m4 b/ompi/mca/op/rocm/configure.m4 index ffd88698be0..d773060eed7 100644 --- a/ompi/mca/op/rocm/configure.m4 +++ b/ompi/mca/op/rocm/configure.m4 @@ -24,8 +24,9 @@ AC_DEFUN([MCA_ompi_op_rocm_CONFIG],[ AC_CONFIG_FILES([ompi/mca/op/rocm/Makefile]) OPAL_CHECK_ROCM([op_rocm]) + OPAL_CHECK_HIPCC([op_hipcc]) - AS_IF([test "x$ROCM_SUPPORT" = "x1"], + AS_IF([test "x$ROCM_SUPPORT" = "x1" -a -n "$OPAL_HIPCC"], [$1], [$2]) @@ -33,4 +34,6 @@ AC_DEFUN([MCA_ompi_op_rocm_CONFIG],[ AC_SUBST([op_rocm_LDFLAGS]) AC_SUBST([op_rocm_LIBS]) + AC_SUBST([OPAL_HIPCC]) + ])dnl From 46fbda1bc02eae9e42e45b78b9acc25690e20b72 Mon Sep 17 00:00:00 2001 From: Joseph Schuchart Date: Mon, 8 Jul 2024 18:12:00 -0400 Subject: [PATCH 10/12] Mark NVCC, NVCCFLAGS, HIPCC, and HIPCCFLAGS as precious Signed-off-by: Joseph Schuchart --- config/opal_check_hipcc.m4 | 31 ++++++++++++------------- config/opal_check_nvcc.m4 | 43 +++++++++++++++-------------------- ompi/mca/op/cuda/Makefile.am | 3 +-- ompi/mca/op/cuda/configure.m4 | 5 +--- ompi/mca/op/rocm/Makefile.am | 4 +--- ompi/mca/op/rocm/configure.m4 | 4 +--- 6 files changed, 36 insertions(+), 54 deletions(-) diff --git a/config/opal_check_hipcc.m4 b/config/opal_check_hipcc.m4 index 568a7854ad2..3ed9ecef8ac 100644 --- a/config/opal_check_hipcc.m4 +++ b/config/opal_check_hipcc.m4 @@ -18,34 +18,31 @@ dnl AC_DEFUN([OPAL_CHECK_HIPCC],[ - # This option is probably only helpful to developers: have - # configure 
fail if Sphinx is not found (i.e., if you don't have - # the ability to use Sphinx to build the HTML docs and man pages). AC_ARG_ENABLE([hipcc], [AS_HELP_STRING([--enable-hipcc], - [Force configure to fail if CUDA hipcc is not found (CUDA hipcc is used to build CUDA operator support).])]) + [Force configure to fail if hipcc is not found (hipcc is used to build HIP operator support).])]) AC_ARG_WITH([hipcc], [AS_HELP_STRING([--with-hipcc=DIR], - [Path to the CUDA compiler])]) + [Path to the HIP compiler])]) AS_IF([test -n "$with_hipcc"], - [OPAL_HIPCC=$with_hipcc], - # no path specified, try to find hipcc - [AC_PATH_PROG([OPAL_HIPCC], [hipcc], [])]) + [HIPCC=$with_hipcc]) + AS_IF([test -z "$HIPCC"], + # try to find hipcc in PATH + [AC_PATH_PROG([HIPCC], [hipcc], [])]) - # If the user requested to disable sphinx, then pretend we didn't - # find it. + # disable support if explicitly specified AS_IF([test "$enable_hipcc" = "no"], - [OPAL_HIPCC=]) + [HIPCC=]) - # If --enable-sphinx was specified and we did not find Sphinx, - # abort. This is likely only useful to prevent "oops!" moments - # from developers. 
- AS_IF([test -z "$OPAL_HIPCC" && test "$enable_hipcc" = "yes"], - [AC_MSG_WARN([A suitable CUDA compiler was not found, but --enable-hipcc was specified]) + AS_IF([test -z "$HIPCC" && test "$enable_hipcc" = "yes"], + [AC_MSG_WARN([A suitable HIP compiler was not found, but --enable-hipcc=yes was specified]) AC_MSG_ERROR([Cannot continue])]) - OPAL_SUMMARY_ADD([Accelerators], [HIPCC compiler], [], [$OPAL_HIPCC]) + OPAL_SUMMARY_ADD([Accelerators], [HIPCC compiler], [], [$HIPCC (flags: $HIPCCFLAGS)]) + + AC_ARG_VAR([HIPCC], [AMD HIP compiler]) + AC_ARG_VAR([HIPCCFLAGS], [AMD HIP compiler flags]) ]) diff --git a/config/opal_check_nvcc.m4 b/config/opal_check_nvcc.m4 index fa29151bebc..d4d8faab675 100644 --- a/config/opal_check_nvcc.m4 +++ b/config/opal_check_nvcc.m4 @@ -21,9 +21,6 @@ dnl AC_DEFUN([OPAL_CHECK_NVCC],[ - # This option is probably only helpful to developers: have - # configure fail if Sphinx is not found (i.e., if you don't have - # the ability to use Sphinx to build the HTML docs and man pages). AC_ARG_ENABLE([nvcc], [AS_HELP_STRING([--enable-nvcc], [Force configure to fail if CUDA nvcc is not found (CUDA nvcc is used to build CUDA operator support).])]) @@ -32,32 +29,28 @@ AC_DEFUN([OPAL_CHECK_NVCC],[ [AS_HELP_STRING([--with-nvcc=DIR], [Path to the CUDA compiler])]) - AC_ARG_WITH([nvcc_compute_arch], - [AS_HELP_STRING([--with-nvcc-compute-arch=ARCH], - [Compute architecture to use for CUDA (default: 52)])]) - AS_IF([test -n "$with_nvcc"], - [OPAL_NVCC=$with_nvcc], - # no path specified, try to find nvcc - [AC_PATH_PROG([OPAL_NVCC], [nvcc], [])]) + [NVCC=$with_nvcc]) + AS_IF([test -z "$NVCC"], + # try to find nvcc in PATH + [AC_PATH_PROG([NVCC], [nvcc], [])]) - # If the user requested to disable sphinx, then pretend we didn't - # find it. 
+ # disable usage of NVCC if explicitly specified AS_IF([test "$enable_nvcc" = "no"], - [OPAL_NVCC=]) - - # default to CUDA compute architecture 52 - AS_IF([test -n "$with_nvcc_compute_arch"], - [OPAL_NVCC_COMPUTE_ARCH=$with_nvcc_compute_arch], - [OPAL_NVCC_COMPUTE_ARCH=52]) - - # If --enable-sphinx was specified and we did not find Sphinx, - # abort. This is likely only useful to prevent "oops!" moments - # from developers. - AS_IF([test -z "$OPAL_NVCC" && test "$enable_nvcc" = "yes"], - [AC_MSG_WARN([A suitable CUDA compiler was not found, but --enable-nvcc was specified]) + [NVCC=]) + + # prepend C++17 standard, allow override by user + AS_IF([test -n "$NVCCFLAGS"], + [NVCCFLAGS="--std c++17 $NVCCFLAGS"], + [NVCCFLAGS="--std c++17"]) + + AS_IF([test -z "$NVCC" && test "$enable_nvcc" = "yes"], + [AC_MSG_WARN([A suitable CUDA compiler was not found, but --enable-nvcc=yes was specified]) AC_MSG_ERROR([Cannot continue])]) - OPAL_SUMMARY_ADD([Accelerators], [NVCC compiler], [], [$OPAL_NVCC (compute arch: $OPAL_NVCC_COMPUTE_ARCH)]) + OPAL_SUMMARY_ADD([Accelerators], [NVCC compiler], [], [$NVCC (flags: $NVCCFLAGS)]) + + AC_ARG_VAR([NVCC], [NVIDIA CUDA compiler]) + AC_ARG_VAR([NVCCFLAGS], [NVIDIA CUDA compiler flags]) ]) diff --git a/ompi/mca/op/cuda/Makefile.am b/ompi/mca/op/cuda/Makefile.am index dc776220f6e..68f8230fd35 100644 --- a/ompi/mca/op/cuda/Makefile.am +++ b/ompi/mca/op/cuda/Makefile.am @@ -24,13 +24,12 @@ dist_ompidata_DATA = help-ompi-mca-op-cuda.txt sources = op_cuda_component.c op_cuda.h op_cuda_functions.c op_cuda_impl.h cu_sources = op_cuda_impl.cu -NVCCFLAGS = --std c++17 --gpu-architecture=compute_$(OPAL_NVCC_COMPUTE_ARCH) # let the underlying compiler generate PIC code PIC_FLAGS = -prefer-non-pic -Wc,-Xcompiler,-fPIC .cu.l$(OBJEXT): $(LIBTOOL) $(AM_V_lt) --tag=CC $(AM_LIBTOOLFLAGS) \ - $(LIBTOOLFLAGS) --mode=compile $(OPAL_NVCC) $(NVCCFLAGS) $(PIC_FLAGS) -c $< + $(LIBTOOLFLAGS) --mode=compile $(NVCC) $(NVCCFLAGS) $(PIC_FLAGS) -c $< # Open MPI
components can be compiled two ways: # diff --git a/ompi/mca/op/cuda/configure.m4 b/ompi/mca/op/cuda/configure.m4 index 15b3fd7c2d1..51be0763a33 100644 --- a/ompi/mca/op/cuda/configure.m4 +++ b/ompi/mca/op/cuda/configure.m4 @@ -27,7 +27,7 @@ AC_DEFUN([MCA_ompi_op_cuda_CONFIG],[ OPAL_CHECK_CUDART([op_cudart]) OPAL_CHECK_NVCC([op_nvcc]) - AS_IF([test "x$CUDA_SUPPORT" = "x1" -a "x$CUDART_SUPPORT" = "x1" -a -n "$OPAL_NVCC"], + AS_IF([test "x$CUDA_SUPPORT" = "x1" -a "x$CUDART_SUPPORT" = "x1" -a -n "$NVCC"], [$1], [$2]) @@ -39,7 +39,4 @@ AC_DEFUN([MCA_ompi_op_cuda_CONFIG],[ AC_SUBST([op_cudart_LDFLAGS]) AC_SUBST([op_cudart_LIBS]) - AC_SUBST([OPAL_NVCC]) - AC_SUBST([OPAL_NVCC_COMPUTE_ARCH]) - ])dnl diff --git a/ompi/mca/op/rocm/Makefile.am b/ompi/mca/op/rocm/Makefile.am index c6177374178..676f2f37a14 100644 --- a/ompi/mca/op/rocm/Makefile.am +++ b/ompi/mca/op/rocm/Makefile.am @@ -24,11 +24,9 @@ dist_ompidata_DATA = help-ompi-mca-op-rocm.txt sources = op_rocm_component.c op_rocm.h op_rocm_functions.c op_rocm_impl.h rocm_sources = op_rocm_impl.hip -HIPCC = hipcc - .hip.l$(OBJEXT): $(LIBTOOL) $(AM_V_lt) --tag=CC $(AM_LIBTOOLFLAGS) \ - $(LIBTOOLFLAGS) --mode=compile $(HIPCC) -O2 -std=c++17 -fvectorize -prefer-non-pic -Wc,-fPIC,-g -c $< + $(LIBTOOLFLAGS) --mode=compile $(HIPCC) $(HIPCCFLAGS) -c $< # -o $($@.o:.lo) diff --git a/ompi/mca/op/rocm/configure.m4 b/ompi/mca/op/rocm/configure.m4 index d773060eed7..155b89f6a7f 100644 --- a/ompi/mca/op/rocm/configure.m4 +++ b/ompi/mca/op/rocm/configure.m4 @@ -26,7 +26,7 @@ AC_DEFUN([MCA_ompi_op_rocm_CONFIG],[ OPAL_CHECK_ROCM([op_rocm]) OPAL_CHECK_HIPCC([op_hipcc]) - AS_IF([test "x$ROCM_SUPPORT" = "x1" -a -n "$OPAL_HIPCC"], + AS_IF([test "x$ROCM_SUPPORT" = "x1" -a -n "$HIPCC"], [$1], [$2]) @@ -34,6 +34,4 @@ AC_DEFUN([MCA_ompi_op_rocm_CONFIG],[ AC_SUBST([op_rocm_LDFLAGS]) AC_SUBST([op_rocm_LIBS]) - AC_SUBST([OPAL_HIPCC]) - ])dnl From 730102baf1d1c8a1101ef7a0cd54a0c3fcc30814 Mon Sep 17 00:00:00 2001 From: Joseph Schuchart Date: Tue, 9 Jul 
2024 10:56:07 -0400 Subject: [PATCH 11/12] Point CI workflows to nvcc/hipcc Signed-off-by: Joseph Schuchart --- .github/workflows/compile-cuda.yaml | 2 +- .github/workflows/compile-rocm.yaml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/compile-cuda.yaml b/.github/workflows/compile-cuda.yaml index 0bddcd3c744..45bb78a7aa9 100644 --- a/.github/workflows/compile-cuda.yaml +++ b/.github/workflows/compile-cuda.yaml @@ -24,5 +24,5 @@ jobs: - name: Build Open MPI run: | ./autogen.pl - ./configure --prefix=${PWD}/install --with-cuda=${CUDA_PATH} --with-cuda-libdir=${CUDA_PATH}/lib64/stubs + ./configure --prefix=${PWD}/install --with-cuda=${CUDA_PATH} --with-cuda-libdir=${CUDA_PATH}/lib64/stubs --enable-nvcc NVCC=/usr/local/cuda/bin/nvcc make -j diff --git a/.github/workflows/compile-rocm.yaml b/.github/workflows/compile-rocm.yaml index 2ce2a80f01a..dbbe43dc4af 100644 --- a/.github/workflows/compile-rocm.yaml +++ b/.github/workflows/compile-rocm.yaml @@ -20,12 +20,12 @@ jobs: echo "deb [arch=amd64 signed-by=/etc/apt/keyrings/rocm.gpg] https://repo.radeon.com/rocm/apt/5.7.1 jammy main" | sudo tee --append /etc/apt/sources.list.d/rocm.list echo -e 'Package: *\nPin: release o=repo.radeon.com\nPin-Priority: 600' | sudo tee /etc/apt/preferences.d/rocm-pin-600 sudo apt update - sudo apt install -y rocm-hip-runtime + sudo apt install -y rocm-hip-runtime hipcc - uses: actions/checkout@v4 with: submodules: recursive - name: Build Open MPI run: | ./autogen.pl - ./configure --prefix=${PWD}/install --with-rocm=/opt/rocm --disable-mpi-fortran + ./configure --prefix=${PWD}/install --with-rocm=/opt/rocm --disable-mpi-fortran --enable-hipcc HIPCC=/opt/rocm-5.7.1/bin/hipcc LD_LIBRARY_PATH=/opt/rocm/lib make -j From c200c0292dd6e3d98e049d0f2a0baa57490bbba3 Mon Sep 17 00:00:00 2001 From: Joseph Schuchart Date: Sun, 15 Sep 2024 19:49:21 -0400 Subject: [PATCH 12/12] More robust find for cudart Signed-off-by: Joseph Schuchart --- 
config/opal_check_cudart.m4 | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/config/opal_check_cudart.m4 b/config/opal_check_cudart.m4 index 0e3fced8065..833f0435375 100644 --- a/config/opal_check_cudart.m4 +++ b/config/opal_check_cudart.m4 @@ -58,31 +58,36 @@ AC_ARG_WITH([cudart-libdir], #################################### #### Check for CUDA runtime library #################################### -AS_IF([test "x$with_cudart" != "xno" || test "x$with_cudart" = "x"], +AS_IF([test "x$with_cudart" = "xno" || test "x$with_cudart" = "x"], [opal_check_cudart_happy=no AC_MSG_RESULT([not set (--with-cudart=$with_cudart)])], [AS_IF([test ! -d "$with_cudart"], [AC_MSG_RESULT([not found]) - AC_MSG_WARN([Directory $with_cudart not found])] - [AS_IF([test "x`ls $with_cudart/include/cuda_runtime.h 2> /dev/null`" = "x"] - [AC_MSG_RESULT([not found]) - AC_MSG_WARN([Could not find cuda_runtime.h in $with_cudart/include])] - [opal_check_cudart_happy=yes - opal_cudart_incdir="$with_cudart/include"])])]) + AC_MSG_WARN([Directory $with_cudart not found])], + [OPAL_FLAGS_APPEND_UNIQ([CPPFLAGS], [-I$with_cudart/include]) + AC_CHECK_HEADERS([cuda_runtime.h], + [opal_check_cudart_happy=yes + opal_cudart_incdir="$with_cudart/include"], + [AC_MSG_RESULT([not found]) + AC_MSG_WARN([Could not find cuda_runtime.h in $with_cudart/include])])])]) +CPPFLAGS=${cudart_save_CPPFLAGS} +# try include path relative to nvcc when the --with-cudart check did not succeed AS_IF([test "$opal_check_cudart_happy" = "no" && test "$with_cudart" != "no"], [AC_PATH_PROG([nvcc_bin], [nvcc], ["not-found"]) AS_IF([test "$nvcc_bin" = "not-found"], [AC_MSG_WARN([Could not find nvcc binary])], [nvcc_dirname=`AS_DIRNAME([$nvcc_bin])` - with_cudart=$nvcc_dirname/../ - opal_cudart_incdir=$nvcc_dirname/../include - opal_check_cudart_happy=yes]) - ] + OPAL_FLAGS_APPEND_UNIQ([CPPFLAGS], [-I$nvcc_dirname/../include]) + AC_CHECK_HEADERS([cuda_runtime.h], + [opal_check_cudart_happy=yes + with_cudart=$nvcc_dirname/../ + 
opal_cudart_incdir="$with_cudart/include"])])], []) +CPPFLAGS=${cudart_save_CPPFLAGS} AS_IF([test x"$with_cudart_libdir" = "x"], - [with_cudart_libdir=$with_cudart/lib64/] + [with_cudart_libdir=$with_cudart/lib64/], []) AS_IF([test "$opal_check_cudart_happy" = "yes"],