3 changes: 2 additions & 1 deletion Makefile
@@ -53,10 +53,11 @@ tests:
tests_nccl:
mpiexec -n $(NUM_PROCESSES) pytest tests_nccl/ --with-mpi

# sphinx-build does not work well with NCCL
doc:
cd docs && rm -rf source/api/generated && rm -rf source/gallery &&\
rm -rf source/tutorials && rm -rf build &&\
cd .. && sphinx-build -b html docs/source docs/build
cd .. && NCCL_PYLOPS_MPI=0 sphinx-build -b html docs/source docs/build

doc_cupy:
cp tutorials_cupy/* tutorials/
41 changes: 41 additions & 0 deletions docs/source/benchmarking.rst
@@ -0,0 +1,41 @@
.. _benchmarkutility:

Benchmark Utility in PyLops-MPI
===============================
PyLops-MPI users can conveniently benchmark the performance of their code with a simple decorator.
:py:func:`pylops_mpi.utils.benchmark` and :py:func:`pylops_mpi.utils.mark` support various
function calling patterns that may arise when benchmarking distributed code.

- :py:func:`pylops_mpi.utils.benchmark` is a **decorator** used to time the execution of entire functions.
- :py:func:`pylops_mpi.utils.mark` is a **function** used inside decorated functions to insert fine-grained time measurements.

.. note::
    This benchmark utility is enabled by default, i.e., if the user decorates a function with :py:func:`@benchmark`, the function will go through
    the time measurements, adding overhead. Users can turn off the benchmark while leaving the decorator in place with

.. code-block:: bash

    export BENCH_PYLOPS_MPI=0

The usage can be as simple as:

.. code-block:: python

    @benchmark
    def function_to_time():
        # your computation
        ...

The result will be printed to standard output.
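
The printout will look roughly like the following (actual times will vary):

.. code-block:: text

    [decorator]function_to_time: total runtime: 0.012345 s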
For fine-grained time measurements, :py:func:`pylops_mpi.utils.mark` can be inserted inside the benchmarked functions:

.. code-block:: python

    @benchmark
    def function_to_time():
        ...  # computation that you may want to exclude from the benchmark
        mark("Begin Region")
        ...  # your computation
        mark("Finish Region")

You can also nest benchmarked functions to track execution times across layers of function calls, with the output correctly formatted.
Additionally, the result can be exported to a text file by passing a :py:class:`logging.Logger` (see the sketch below). For complete and runnable examples, visit :ref:`sphx_glr_tutorials_benchmarking.py`
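
Below is a minimal sketch combining nesting and export to a file; the function
and label names, as well as the logger setup, are illustrative only:

.. code-block:: python

    import logging

    from pylops_mpi.utils.benchmark import benchmark, mark

    # a logger configured by the user; any handler/level setup works
    logger = logging.getLogger("pylops_mpi_benchmark")
    logger.addHandler(logging.FileHandler("benchmark.log", mode="w"))
    logger.setLevel(logging.INFO)

    @benchmark
    def inner():
        mark("Begin inner region")
        ...  # your computation
        mark("Finish inner region")

    @benchmark(description="outer", logger=logger)
    def outer():
        mark("Begin outer region")
        inner()  # nested benchmarked call, indented in the output
        mark("Finish outer region")

    outer()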
2 changes: 1 addition & 1 deletion pylops_mpi/DistributedArray.py
@@ -452,7 +452,7 @@ def _check_local_shapes(self, local_shapes):
elif self.partition is Partition.SCATTER:
local_shape = local_shapes[self.rank]
            # Check if local shapes sum up to the global shape and other dimensions align with the global shape
if self._allreduce(local_shape[self.axis]) != self.global_shape[self.axis] or \
if self.base_comm.allreduce(local_shape[self.axis]) != self.global_shape[self.axis] or \
not np.array_equal(np.delete(local_shape, self.axis), np.delete(self.global_shape, self.axis)):
raise ValueError(f"Local shapes don't align with the global shape;"
f"{local_shapes} != {self.global_shape}")
9 changes: 8 additions & 1 deletion pylops_mpi/utils/_nccl.py
@@ -1,6 +1,7 @@
__all__ = [
"_prepare_nccl_allgather_inputs",
"_unroll_nccl_allgather_recv",
"_nccl_sync",
"initialize_nccl_comm",
"nccl_split",
"nccl_allgather",
@@ -19,7 +20,6 @@
import cupy as cp
import cupy.cuda.nccl as nccl


cupy_to_nccl_dtype = {
"float32": nccl.NCCL_FLOAT32,
"float64": nccl.NCCL_FLOAT64,
@@ -63,6 +63,13 @@ def _nccl_buf_size(buf, count=None):
return count if count else buf.size


def _nccl_sync():
"""A thin wrapper of CuPy's synchronization for protected import"""
if cp.cuda.runtime.getDeviceCount() == 0:
return
cp.cuda.runtime.deviceSynchronize()


def _prepare_nccl_allgather_inputs(send_buf, send_buf_shapes) -> Tuple[cp.ndarray, cp.ndarray]:
r""" Prepare send_buf and recv_buf for NCCL allgather (nccl_allgather)

167 changes: 167 additions & 0 deletions pylops_mpi/utils/benchmark.py
@@ -0,0 +1,167 @@
import functools
import logging
import os
import time
from typing import Callable, List, Optional, Tuple
from mpi4py import MPI

from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils
from pylops_mpi.utils import deps

cupy_message = pylops_deps.cupy_import("the benchmark module")
nccl_message = deps.nccl_import("the benchmark module")

if nccl_message is None and cupy_message is None:
from pylops_mpi.utils._nccl import _nccl_sync
else:
def _nccl_sync():
pass

# Benchmark is enabled by default
ENABLE_BENCHMARK = int(os.getenv("BENCH_PYLOPS_MPI", 1)) == 1

# Stack of active mark functions for nested support
_mark_func_stack = []
_markers = []


def _parse_output_tree(markers: List[Tuple[str, float, int]]):
    """This function parses the list of (label, time, level) tuples gathered during
    the benchmark call and outputs them as a list of properly formatted strings.
    The format of the output follows the hierarchy of function calls,
    i.e., nested function calls are indented.

    Parameters
    ----------
    markers : :obj:`list`
        A list of (label, time, level) markers generated from the benchmark call
    """
global _markers
output = []
stack = []
i = 0
while i < len(markers):
label, time, level = markers[i]
if label.startswith("[decorator]"):
indent = "\t" * (level - 1)
output.append(f"{indent}{label}: total runtime: {time:6f} s\n")
else:
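            # A user mark: pair it with the previous mark at the same level to
            # report the elapsed time of the region between the two marks.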
if stack:
prev_label, prev_time, prev_level = stack[-1]
if prev_level == level:
indent = "\t" * level
output.append(f"{indent}{prev_label}-->{label}: {time - prev_time:6f} s\n")
stack.pop()

# Push to the stack only if it is going deeper or still at the same level
        if i + 1 < len(markers):
            _, _, next_level = markers[i + 1]
if next_level >= level:
stack.append(markers[i])
i += 1
    # reset markers, allowing other benchmarked functions to start fresh
_markers = []
return output


def _sync():
"""Synchronize all MPI processes or CUDA Devices"""
_nccl_sync()
MPI.COMM_WORLD.Barrier()


def mark(label: str):
"""This function allows users to measure time arbitary lines of the function

Parameters
----------
    label : :obj:`str`
        A label for the mark. This signifies both 1) the end of the
        previous mark and 2) the beginning of the new mark
"""
if not ENABLE_BENCHMARK:
return
if not _mark_func_stack:
raise RuntimeError("mark() called outside of a benchmarked region")
_mark_func_stack[-1](label)


def benchmark(func: Optional[Callable] = None,
description: Optional[str] = "",
logger: Optional[logging.Logger] = None,
):
"""A wrapper for code injection for time measurement.

This wrapper measures the start-to-end time of the wrapped function when
decorated without any argument.

    It also allows users to put a call to mark() anywhere inside the wrapped function
    for fine-grained time benchmarking. This wrapper defines local_mark() and pushes it
    onto the _mark_func_stack for isolation in case of nested calls.
    The user-facing mark() always calls the function at the top of the _mark_func_stack.

Parameters
----------
func : :obj:`callable`, optional
Function to be decorated. Defaults to ``None``.
description : :obj:`str`, optional
Description for the output text. Defaults to ``''``.
    logger : :obj:`logging.Logger`, optional
        A :py:class:`logging.Logger` object for logging the benchmark text output. This logger must be
        set up before being passed to this function, configured to either write output to a file or
        log to stdout. If ``logger`` is not provided, the output is printed to stdout.
"""

def noop_decorator(func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
return func(*args, **kwargs)
return wrapped

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
rank = MPI.COMM_WORLD.Get_rank()

level = len(_mark_func_stack) + 1
            # The header is needed for later tree parsing; reserve its spot here.
            # The tuple at this index will be replaced once the elapsed time is known.
_markers.append((f"[decorator]{description or func.__name__}", None, level))
header_index = len(_markers) - 1

def local_mark(label):
_markers.append((label, time.perf_counter(), level))

_mark_func_stack.append(local_mark)

_sync()
start_time = time.perf_counter()
            # mark() called inside the wrapped function will now call local_mark
result = func(*args, **kwargs)
_sync()
end_time = time.perf_counter()

elapsed = end_time - start_time
_markers[header_index] = (f"[decorator]{description or func.__name__}", elapsed, level)

            # In case of nesting, the wrapped callee must pop its closure from the stack so that,
            # when the callee returns, the wrapped caller operates on its own closure (and its level
            # label), which now becomes the top of the stack.
_mark_func_stack.pop()

            # all the calls have finished
if not _mark_func_stack:
if rank == 0:
output = _parse_output_tree(_markers)
if logger:
logger.info("".join(output))
else:
print("".join(output))
return result
return wrapper

    # The code still has to return the decorator so that the decorator-with-arguments form,
    # e.g., @benchmark(logger=logger), does not throw an error and can be left in place.
if not ENABLE_BENCHMARK:
return noop_decorator if func is None else noop_decorator(func)

return decorator if func is None else decorator(func)
134 changes: 134 additions & 0 deletions tutorials/benchmarking.py
@@ -0,0 +1,134 @@
r"""
Benchmark Utility in PyLops-MPI
===============================
This tutorial demonstrates how to use the :py:func:`pylops_mpi.utils.benchmark` and
:py:func:`pylops_mpi.utils.mark` utility methods in PyLops-MPI. It covers various
function calling patterns that may come up when benchmarking distributed code.

:py:func:`pylops_mpi.utils.benchmark` is a decorator used to decorate any
function in order to measure its execution time from start to finish.
:py:func:`pylops_mpi.utils.mark` is a function used inside a benchmark-decorated
function to provide fine-grained time measurements.
"""

import sys
import logging
import numpy as np
from mpi4py import MPI
from pylops_mpi import DistributedArray, Partition

np.random.seed(42)
rank = MPI.COMM_WORLD.Get_rank()

par = {'global_shape': (500, 501),
'partition': Partition.SCATTER, 'dtype': np.float64,
'axis': 1}

###############################################################################
# Let's start by importing the utility and running a simple example
from pylops_mpi.utils.benchmark import benchmark, mark


@benchmark
def inner_func(par):
dist_arr = DistributedArray(global_shape=par['global_shape'],
partition=par['partition'],
dtype=par['dtype'], axis=par['axis'])
# may perform computation here
dist_arr.dot(dist_arr)


###############################################################################
# When we call :py:func:`inner_func`, we will see the result
# of the benchmark printed to standard output. If we want to customize the
# function name in the printout, we can pass the parameter ``description``
# to :py:func:`benchmark`,
# i.e., :py:func:`@benchmark(description="printout_name")`, as sketched after this call

inner_func(par)
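

###############################################################################
# As a minimal sketch of that variant (the name ``scaled_dot`` is purely
# illustrative), the ``description`` argument changes only the label shown in
# the printout


@benchmark(description="scaled_dot")
def inner_func_with_description(par):
    dist_arr = DistributedArray(global_shape=par['global_shape'],
                                partition=par['partition'],
                                dtype=par['dtype'], axis=par['axis'])
    dist_arr.dot(dist_arr)


inner_func_with_description(par)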

###############################################################################
# We may want to get fine-grained time measurements by timing the execution
# time of arbitrary lines of code. :py:func:`pylops_mpi.utils.mark` provides such a utility.


@benchmark
def inner_func_with_mark(par):
mark("Begin array constructor")
dist_arr = DistributedArray(global_shape=par['global_shape'],
partition=par['partition'],
dtype=par['dtype'], axis=par['axis'])
mark("Begin dot")
dist_arr.dot(dist_arr)
mark("Finish dot")


###############################################################################
# Now when we run it, we get the detailed time measurements. Note that there is a
# [decorator] tag next to the function name to distinguish the start-to-end time
# measurement of the top-level function from those that come from :py:func:`pylops_mpi.utils.mark`
inner_func_with_mark(par)
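
###############################################################################
# The printout will look roughly like the following (actual times will vary):
#
# .. code-block:: text
#
#    [decorator]inner_func_with_mark: total runtime: 0.012345 s
#        Begin array constructor-->Begin dot: 0.004567 s
#        Begin dot-->Finish dot: 0.007890 s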

###############################################################################
# These benchmarking routines can also be nested. Let's define
# an outer function that internally calls the decorated :py:func:`inner_func_with_mark`


@benchmark
def outer_func_with_mark(par):
mark("Outer func start")
inner_func_with_mark(par)
dist_arr = DistributedArray(global_shape=par['global_shape'],
partition=par['partition'],
dtype=par['dtype'], axis=par['axis'])
dist_arr + dist_arr
mark("Outer func ends")


###############################################################################
# If we run :py:func:`outer_func_with_mark`, we get the time measurements nicely
# printed with nested indentation to indicate the nested calls.
outer_func_with_mark(par)


###############################################################################
# In some cases, we may want to write the benchmark output to a text file.
# :py:func:`pylops_mpi.utils.benchmark` also accepts a :py:class:`logging.Logger`
# as an argument.
# Here we define a simple :py:func:`make_logger`. We set ``logger.propagate = False``
# to isolate the logging of our benchmark from that of the rest of the code

save_file = True
file_path = "benchmark.log"


def make_logger(save_file=False, file_path=''):
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    # isolate the benchmark logging from any root-logger configuration
    logger.propagate = False
    if save_file:
        handler = logging.FileHandler(file_path, mode='w')
    else:
        handler = logging.StreamHandler(sys.stdout)
    logger.addHandler(handler)
    return logger


logger = make_logger(save_file, file_path)


###############################################################################
# Then we can pass the logger to :py:func:`pylops_mpi.utils.benchmark`

@benchmark(logger=logger)
def inner_func_with_logger(par):
dist_arr = DistributedArray(global_shape=par['global_shape'],
partition=par['partition'],
dtype=par['dtype'], axis=par['axis'])
# may perform computation here
dist_arr.dot(dist_arr)


###############################################################################
# Run this function and observe that the file `benchmark.log` is written.
inner_func_with_logger(par)
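
###############################################################################
# Finally, recall that the benchmark utility can be turned off without removing
# the decorators by setting the environment variable ``BENCH_PYLOPS_MPI=0``
# before launching the script, for example (launch command shown for illustration):
# ``BENCH_PYLOPS_MPI=0 mpiexec -n 2 python benchmarking.py``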