diff --git a/Makefile b/Makefile
index d2715edd..409438d0 100644
--- a/Makefile
+++ b/Makefile
@@ -53,10 +53,11 @@ tests:
 tests_nccl:
 	mpiexec -n $(NUM_PROCESSES) pytest tests_nccl/ --with-mpi
 
+# sphinx-build does not work well with NCCL
 doc:
 	cd docs && rm -rf source/api/generated && rm -rf source/gallery &&\
 	rm -rf source/tutorials && rm -rf build &&\
-	cd .. && sphinx-build -b html docs/source docs/build
+	cd .. && NCCL_PYLOPS_MPI=0 sphinx-build -b html docs/source docs/build
 
 doc_cupy:
 	cp tutorials_cupy/* tutorials/
diff --git a/docs/source/benchmarking.rst b/docs/source/benchmarking.rst
new file mode 100644
index 00000000..eae99dff
--- /dev/null
+++ b/docs/source/benchmarking.rst
@@ -0,0 +1,41 @@
+.. _benchmarkutility:
+
+Benchmark Utility in PyLops-MPI
+===============================
+PyLops-MPI users can conveniently benchmark the performance of their code with a simple decorator.
+:py:func:`pylops_mpi.utils.benchmark` and :py:func:`pylops_mpi.utils.mark` support the various
+function calling patterns that may arise when benchmarking distributed code.
+
+- :py:func:`pylops_mpi.utils.benchmark` is a **decorator** used to time the execution of entire functions.
+- :py:func:`pylops_mpi.utils.mark` is a **function** used inside decorated functions to insert fine-grained time measurements.
+
+.. note::
+    This benchmark utility is enabled by default, i.e., if the user decorates a function with :py:func:`@benchmark`, the function will go through
+    the time measurements, adding overhead. Users can turn off the benchmark while leaving the decorator in place with
+
+    .. code-block:: bash
+
+        export BENCH_PYLOPS_MPI=0
+
+The usage can be as simple as:
+
+.. code-block:: python
+
+    @benchmark
+    def function_to_time():
+        # Your computation
+        ...
+
+The result is printed to standard output.
+For fine-grained time measurements, :py:func:`pylops_mpi.utils.mark` can be inserted in the code regions of benchmarked functions:
+
+.. code-block:: python
+
+    @benchmark
+    def function_to_time():
+        # Your computation that you may want to exclude from the benchmark
+        mark("Begin Region")
+        # Your computation
+        mark("Finish Region")
+
+You can also nest benchmarked functions to track execution times across layers of function calls, with the output correctly formatted.
+Additionally, the result can be exported to a text file.
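+
+As a minimal sketch combining both features (the function names ``inner`` and ``outer`` and the logger setup are purely illustrative):
+
+.. code-block:: python
+
+    import logging
+
+    from pylops_mpi.utils.benchmark import benchmark, mark
+
+    # A logger set up to write the benchmark output to a file
+    logger = logging.getLogger("bench")
+    logger.addHandler(logging.FileHandler("benchmark.log", mode="w"))
+    logger.setLevel(logging.INFO)
+
+    @benchmark(description="inner")
+    def inner():
+        mark("Begin work")
+        # Your computation
+        mark("Finish work")
+
+    @benchmark(description="outer", logger=logger)
+    def outer():
+        inner()  # nested benchmarked call; its timings appear indented
+
+    outer()  # the formatted tree of timings is written to benchmark.log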
+For complete and runnable examples, visit :ref:`sphx_glr_tutorials_benchmarking.py`
diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py
index 5499d963..5807a518 100644
--- a/pylops_mpi/DistributedArray.py
+++ b/pylops_mpi/DistributedArray.py
@@ -452,7 +452,7 @@ def _check_local_shapes(self, local_shapes):
         elif self.partition is Partition.SCATTER:
             local_shape = local_shapes[self.rank]
             # Check if local shape sum up to global shape and other dimensions align with global shape
-            if self._allreduce(local_shape[self.axis]) != self.global_shape[self.axis] or \
+            if self.base_comm.allreduce(local_shape[self.axis]) != self.global_shape[self.axis] or \
                     not np.array_equal(np.delete(local_shape, self.axis), np.delete(self.global_shape, self.axis)):
                 raise ValueError(f"Local shapes don't align with the global shape;"
                                  f"{local_shapes} != {self.global_shape}")
diff --git a/pylops_mpi/utils/_nccl.py b/pylops_mpi/utils/_nccl.py
index 508893ea..19c09922 100644
--- a/pylops_mpi/utils/_nccl.py
+++ b/pylops_mpi/utils/_nccl.py
@@ -1,6 +1,7 @@
 __all__ = [
     "_prepare_nccl_allgather_inputs",
     "_unroll_nccl_allgather_recv",
+    "_nccl_sync",
     "initialize_nccl_comm",
     "nccl_split",
     "nccl_allgather",
@@ -19,7 +20,6 @@
 import cupy as cp
 import cupy.cuda.nccl as nccl
 
-
 cupy_to_nccl_dtype = {
     "float32": nccl.NCCL_FLOAT32,
     "float64": nccl.NCCL_FLOAT64,
@@ -63,6 +63,13 @@ def _nccl_buf_size(buf, count=None):
     return count if count else buf.size
 
 
+def _nccl_sync():
+    """A thin wrapper around CuPy's device synchronization, kept behind the protected import"""
+    if cp.cuda.runtime.getDeviceCount() == 0:
+        return
+    cp.cuda.runtime.deviceSynchronize()
+
+
 def _prepare_nccl_allgather_inputs(send_buf, send_buf_shapes) -> Tuple[cp.ndarray, cp.ndarray]:
     r"""Prepare send_buf and recv_buf for NCCL allgather (nccl_allgather)
diff --git a/pylops_mpi/utils/benchmark.py b/pylops_mpi/utils/benchmark.py
new file mode 100644
index 00000000..2e7be83d
--- /dev/null
+++ b/pylops_mpi/utils/benchmark.py
@@ -0,0 +1,167 @@
+import functools
+import logging
+import os
+import time
+from typing import Callable, Optional, List
+from mpi4py import MPI
+
+from pylops.utils import deps as pylops_deps  # avoid namespace crashes with pylops_mpi.utils
+from pylops_mpi.utils import deps
+
+cupy_message = pylops_deps.cupy_import("the benchmark module")
+nccl_message = deps.nccl_import("the benchmark module")
+
+if nccl_message is None and cupy_message is None:
+    from pylops_mpi.utils._nccl import _nccl_sync
+else:
+    def _nccl_sync():
+        pass
+
+# Benchmark is enabled by default
+ENABLE_BENCHMARK = int(os.getenv("BENCH_PYLOPS_MPI", "1")) == 1
+
+# Stack of active mark functions for nested support
+_mark_func_stack = []
+_markers = []
+
+
+def _parse_output_tree(markers: List[tuple]):
+    """This function parses the list of markers gathered during the benchmark call and outputs them
+    as one properly formatted string. The format of the output string follows the hierarchy of function calls,
+    i.e., nested function calls are indented.
+
+    Parameters
+    ----------
+    markers : :obj:`list`
+        A list of markers/labels generated from the benchmark call
+    """
+    global _markers
+    output = []
+    stack = []
+    i = 0
+    while i < len(markers):
+        label, time, level = markers[i]
+        if label.startswith("[decorator]"):
+            indent = "\t" * (level - 1)
+            output.append(f"{indent}{label}: total runtime: {time:6f} s\n")
+        else:
+            if stack:
+                prev_label, prev_time, prev_level = stack[-1]
+                if prev_level == level:
+                    indent = "\t" * level
+                    output.append(f"{indent}{prev_label}-->{label}: {time - prev_time:6f} s\n")
+                    stack.pop()
+
+            # Push to the stack only if it is going deeper or still at the same level
+            if i + 1 <= len(markers) - 1:
+                _, _, next_level = markers[i + 1]
+                if next_level >= level:
+                    stack.append(markers[i])
+        i += 1
+    # reset markers, allowing other benchmarked functions to start fresh
+    _markers = []
+    return output
+
+
+def _sync():
+    """Synchronize all MPI processes or CUDA devices"""
+    _nccl_sync()
+    MPI.COMM_WORLD.Barrier()
+
+
+def mark(label: str):
+    """This function allows users to time arbitrary lines of the benchmarked function
+
+    Parameters
+    ----------
+    label : :obj:`str`
+        A label of the mark. This signifies both 1) the end of the
+        previous mark and 2) the beginning of the new mark
+    """
+    if not ENABLE_BENCHMARK:
+        return
+    if not _mark_func_stack:
+        raise RuntimeError("mark() called outside of a benchmarked region")
+    _mark_func_stack[-1](label)
+
+
+def benchmark(func: Optional[Callable] = None,
+              description: Optional[str] = "",
+              logger: Optional[logging.Logger] = None,
+              ):
+    """A wrapper for code injection for time measurement.
+
+    This wrapper measures the start-to-end time of the wrapped function when
+    decorated without any argument.
+
+    It also allows users to put a call to mark() anywhere inside the wrapped function
+    for fine-grained time benchmarking. This wrapper defines local_mark() and pushes it
+    onto _mark_func_stack for isolation in case of nested calls.
+    The user-facing mark() always calls the function at the top of _mark_func_stack.
+
+    Parameters
+    ----------
+    func : :obj:`callable`, optional
+        Function to be decorated. Defaults to ``None``.
+    description : :obj:`str`, optional
+        Description for the output text. Defaults to ``''``.
+    logger : :obj:`logging.Logger`, optional
+        A `logging.Logger` object for logging the benchmark text output. This logger must be set up
+        before being passed to this function, to either write output to a file or log to stdout. If `logger`
+        is not provided, the output is printed to stdout.
+    """
+
+    def noop_decorator(func):
+        @functools.wraps(func)
+        def wrapped(*args, **kwargs):
+            return func(*args, **kwargs)
+        return wrapped
+
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            rank = MPI.COMM_WORLD.Get_rank()
+
+            level = len(_mark_func_stack) + 1
+            # The header is needed for later tree parsing. Here it is reserving its spot;
+            # the tuple at this index will be replaced after the elapsed time is calculated.
+            _markers.append((f"[decorator]{description or func.__name__}", None, level))
+            header_index = len(_markers) - 1
+
+            def local_mark(label):
+                _markers.append((label, time.perf_counter(), level))
+
+            _mark_func_stack.append(local_mark)
+
+            _sync()
+            start_time = time.perf_counter()
+            # mark() calls in the wrapped function will now invoke local_mark
+            result = func(*args, **kwargs)
+            _sync()
+            end_time = time.perf_counter()
+
+            elapsed = end_time - start_time
+            _markers[header_index] = (f"[decorator]{description or func.__name__}", elapsed, level)
+
+            # In case of nesting, the wrapped callee must pop its closure from the stack so that,
+            # when the callee returns, the wrapped caller operates on its own closure (and its level label),
+            # which is now the top of the stack.
+            _mark_func_stack.pop()
+
+            # all the calls have finished
+            if not _mark_func_stack:
+                if rank == 0:
+                    output = _parse_output_tree(_markers)
+                    if logger:
+                        logger.info("".join(output))
+                    else:
+                        print("".join(output))
+            return result
+        return wrapper
+
+    # The code still has to return decorator so that an in-place decorator with arguments,
+    # like @benchmark(logger=logger), does not throw an error and can be kept untouched.
+    if not ENABLE_BENCHMARK:
+        return noop_decorator if func is None else noop_decorator(func)
+
+    return decorator if func is None else decorator(func)
diff --git a/tutorials/benchmarking.py b/tutorials/benchmarking.py
new file mode 100644
index 00000000..15a84ef8
--- /dev/null
+++ b/tutorials/benchmarking.py
@@ -0,0 +1,134 @@
+r"""
+Benchmark Utility in PyLops-MPI
+===============================
+This tutorial demonstrates how to use the :py:func:`pylops_mpi.utils.benchmark` and
+:py:func:`pylops_mpi.utils.mark` utility methods in PyLops-MPI. It covers various
+function calling patterns that may come up when benchmarking distributed code.
+
+:py:func:`pylops_mpi.utils.benchmark` is a decorator used to decorate any
+function to measure its execution time from start to finish.
+:py:func:`pylops_mpi.utils.mark` is a function used inside the benchmark-decorated
+function to provide fine-grained time measurements.
+"""
+
+import sys
+import logging
+import numpy as np
+from mpi4py import MPI
+from pylops_mpi import DistributedArray, Partition
+
+np.random.seed(42)
+rank = MPI.COMM_WORLD.Get_rank()
+
+par = {'global_shape': (500, 501),
+       'partition': Partition.SCATTER, 'dtype': np.float64,
+       'axis': 1}
+
+###############################################################################
+# Let's start by importing the utility and running a simple example
+from pylops_mpi.utils.benchmark import benchmark, mark
+
+
+@benchmark
+def inner_func(par):
+    dist_arr = DistributedArray(global_shape=par['global_shape'],
+                                partition=par['partition'],
+                                dtype=par['dtype'], axis=par['axis'])
+    # may perform computation here
+    dist_arr.dot(dist_arr)
+
+
+###############################################################################
+# When we call :py:func:`inner_func`, we will see the result
+# of the benchmark printed to standard output. If we want to customize the
+# function name in the printout, we can pass the parameter `description`
+# to :py:func:`benchmark`,
+# i.e., :py:func:`@benchmark(description="printout_name")`
+
+inner_func(par)
+
+###############################################################################
+# We may want to get fine-grained time measurements by timing the execution
+# of arbitrary lines of code. :py:func:`pylops_mpi.utils.mark` provides such a utility.
+
+
+@benchmark
+def inner_func_with_mark(par):
+    mark("Begin array constructor")
+    dist_arr = DistributedArray(global_shape=par['global_shape'],
+                                partition=par['partition'],
+                                dtype=par['dtype'], axis=par['axis'])
+    mark("Begin dot")
+    dist_arr.dot(dist_arr)
+    mark("Finish dot")
+
+
+###############################################################################
+# Now when we run, we get the detailed time measurements. Note that there is a tag
+# [decorator] next to the function name to distinguish the start-to-end time
+# measurement of the top-level function from those that come from :py:func:`pylops_mpi.utils.mark`
inner_func_with_mark(par)
+
+###############################################################################
+# These benchmarking routines can also be nested. Let's define
+# an outer function that internally calls the decorated :py:func:`inner_func_with_mark`
+
+
+@benchmark
+def outer_func_with_mark(par):
+    mark("Outer func start")
+    inner_func_with_mark(par)
+    dist_arr = DistributedArray(global_shape=par['global_shape'],
+                                partition=par['partition'],
+                                dtype=par['dtype'], axis=par['axis'])
+    dist_arr + dist_arr
+    mark("Outer func ends")
+
+
+###############################################################################
+# If we run :py:func:`outer_func_with_mark`, we get the time measurements nicely
+# printed out with nested indentation to indicate the nested calls.
+outer_func_with_mark(par)
+
+
+###############################################################################
+# In some cases, we may want to write the benchmark output to a text file.
+# :py:func:`pylops_mpi.utils.benchmark` also takes a :py:class:`logging.Logger`
+# as an argument.
+# Here we define a simple :py:func:`make_logger()`. We set ``logger.propagate = False``
+# to isolate the logging of our benchmark from that of the rest of the code
+
+save_file = True
+file_path = "benchmark.log"
+
+
+def make_logger(save_file=False, file_path=''):
+    logger = logging.getLogger(__name__)
+    logging.basicConfig(filename=file_path if save_file else None, filemode='w', level=logging.INFO, force=True)
+    logger.propagate = False
+    if save_file:
+        handler = logging.FileHandler(file_path, mode='w')
+    else:
+        handler = logging.StreamHandler(sys.stdout)
+    logger.addHandler(handler)
+    return logger
+
+
+logger = make_logger(save_file, file_path)
+
+
+###############################################################################
+# Then we can pass the logger to :py:func:`pylops_mpi.utils.benchmark`
+
+@benchmark(logger=logger)
+def inner_func_with_logger(par):
+    dist_arr = DistributedArray(global_shape=par['global_shape'],
+                                partition=par['partition'],
+                                dtype=par['dtype'], axis=par['axis'])
+    # may perform computation here
+    dist_arr.dot(dist_arr)
+
+
+###############################################################################
+# Run this function and observe that the file `benchmark.log` is written.
+inner_func_with_logger(par)
diff --git a/tutorials/lsm.py b/tutorials/lsm.py
index bee504bd..c83fcce2 100644
--- a/tutorials/lsm.py
+++ b/tutorials/lsm.py
@@ -209,3 +209,6 @@
 axs[2].set_title(r"$d_{inv}$")
 axs[2].axis("tight")
 plt.tight_layout()
+
+###############################################################################
+# To run this tutorial with our NCCL backend, refer to `Least-squares Migration with NCCL tutorial `_ in the repository.
diff --git a/tutorials/mdd.py b/tutorials/mdd.py
index 7913f3f5..d4fd4571 100644
--- a/tutorials/mdd.py
+++ b/tutorials/mdd.py
@@ -231,3 +231,6 @@
 )
 ax3.set_ylim([t2[-1], t2[0]])
 fig.tight_layout()
+
+###############################################################################
+# To run this tutorial with our NCCL backend, refer to `Multi-Dimensional Deconvolution with NCCL tutorial `_ in the repository.
diff --git a/tutorials/poststack.py b/tutorials/poststack.py
index ff42ba20..8401fd31 100644
--- a/tutorials/poststack.py
+++ b/tutorials/poststack.py
@@ -279,4 +279,4 @@
 axs[5][2].axis('tight')
 
 ###############################################################################
-# To run this tutorial with our NCCL backend, refer to :ref:`sphx_glr_tutorials_poststack_nccl.py`
+# To run this tutorial with our NCCL backend, refer to `Post Stack Inversion with NCCL tutorial `_ in the repository.
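Both environment switches used in this patch are consulted at import time: `ENABLE_BENCHMARK` is evaluated at the top of `pylops_mpi/utils/benchmark.py`, and the Makefile's `doc` target sets `NCCL_PYLOPS_MPI=0` before invoking sphinx-build for the same reason. A minimal sketch of toggling them from Python (assuming, as the patch suggests, that both variables are only read during import):

    import os

    # Set before importing pylops_mpi.utils.benchmark: ENABLE_BENCHMARK is
    # evaluated once at module import, so later changes have no effect.
    os.environ["BENCH_PYLOPS_MPI"] = "0"
    # Assumption: NCCL_PYLOPS_MPI is likewise read at import, as the Makefile's
    # doc target implies; 0 disables the NCCL backend.
    os.environ["NCCL_PYLOPS_MPI"] = "0"

    from pylops_mpi.utils.benchmark import benchmark, mark  # decorators are now no-ops

From the shell, the same effect is obtained by exporting the variables before launching, e.g. `BENCH_PYLOPS_MPI=0 mpiexec -n 4 python tutorials/benchmarking.py`.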