Skip to content

Commit e1ce7fb

Browse files
author
Holger Kohr
committed
WIP: fix cupy->numpy transfers
1 parent dd2b7c4 commit e1ce7fb

File tree

1 file changed

+101
-39
lines changed

1 file changed

+101
-39
lines changed

odl/space/cupy_tensors.py

+101-39
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@
1616
import numpy as np
1717
import warnings
1818

19-
from odl.set import RealNumbers
19+
from odl.set import RealNumbers, ComplexNumbers
2020
from odl.space.base_tensors import TensorSpace, Tensor
2121
from odl.space.weighting import (
2222
Weighting, ArrayWeighting, ConstWeighting,
2323
CustomInner, CustomNorm, CustomDist)
24-
from odl.util import dtype_str, is_floating_dtype, signature_string
24+
from odl.util import (
25+
array_str, dtype_str, is_floating_dtype, signature_string, indent)
2526

2627
try:
2728
import cupy
@@ -599,13 +600,14 @@ def __eq__(self, other):
599600
>>> same_space == space
600601
True
601602
"""
602-
return (super().__eq__(other) and
603+
return (super(CupyTensorSpace, self).__eq__(other) and
603604
self.device == other.device and
604605
self.weighting == other.weighting)
605606

606607
def __hash__(self):
607608
"""Return ``hash(self)``."""
608-
return hash((super().__hash__(), self.device, self.weighting))
609+
return hash((super(CupyTensorSpace, self).__hash__(), self.device,
610+
self.weighting))
609611

610612
def _lincomb(self, a, x1, b, x2, out):
611613
"""Linear combination of ``x1`` and ``x2``.
@@ -874,12 +876,16 @@ def default_dtype(field=None):
874876
dtype : `numpy.dtype`
875877
Numpy data type specifier. The returned defaults are:
876878
877-
``RealNumbers()`` : ``np.dtype('float64')``
879+
- ``RealNumbers()`` or ``None`` : ``np.dtype('float64')``
880+
- ``ComplexNumbers()`` : ``np.dtype('complex128')``
878881
879-
``ComplexNumbers()`` : not supported
882+
These choices correspond to the defaults of the ``cupy``
883+
library.
880884
"""
881885
if field is None or field == RealNumbers():
882886
return np.dtype('float64')
887+
elif field == ComplexNumbers():
888+
return np.dtype('complex128')
883889
else:
884890
raise ValueError('no default data type defined for field {}.'
885891
''.format(field))
@@ -1029,7 +1035,7 @@ def copy(self):
10291035
10301036
Returns
10311037
-------
1032-
copy : `pygpu._array.ndgpuarray`
1038+
copy : `CupyTensor`
10331039
A deep copy.
10341040
10351041
Examples
@@ -1056,7 +1062,7 @@ def __getitem__(self, indices):
10561062
10571063
Returns
10581064
-------
1059-
values : scalar or `pygpu._array.ndgpuarray`
1065+
values : scalar or `cupy.core.core.ndarray`
10601066
The value(s) at the index (indices).
10611067
10621068
Examples
@@ -1107,11 +1113,11 @@ def __getitem__(self, indices):
11071113
arr = self.data[indices]
11081114
if arr.shape == ():
11091115
if arr.dtype.kind == 'f':
1110-
return float(np.asarray(arr))
1116+
return float(cupy.asnumpy(arr))
11111117
elif arr.dtype.kind == 'c':
1112-
return complex(np.asarray(arr))
1118+
return complex(cupy.asnumpy(arr))
11131119
elif arr.dtype.kind in ('u', 'i'):
1114-
return int(np.asarray(arr))
1120+
return int(cupy.asnumpy(arr))
11151121
else:
11161122
raise RuntimeError("no conversion for dtype {}"
11171123
"".format(arr.dtype))
@@ -1280,12 +1286,22 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
12801286
for further details. See also the `general documentation on
12811287
Numpy ufuncs`_.
12821288
1283-
.. note::
1284-
This implementation looks for native ufuncs in ``pygpu.ufuncs``
1285-
and falls back to the basic implementation with Numpy arrays
1286-
in case no native ufunc is available. That fallback version
1287-
comes with significant overhead due to data copies between
1288-
host and device.
1289+
.. warning::
1290+
Apart from ``'__call__'`` (invoked by, e.g., ``np.add(x, y))``,
1291+
CuPy has no native implementation of ufunc methods like
1292+
``'reduce'`` or ``'accumulate'``. We manually implement the
1293+
mappings (covering most use cases)
1294+
1295+
- ``np.add.reduce`` -> ``cupy.sum``
1296+
- ``np.add.accumulate`` -> ``cupy.cumsum``
1297+
- ``np.multiply.reduce`` -> ``cupy.prod``
1298+
- ``np.multiply.reduce`` -> ``cupy.cumprod``.
1299+
1300+
**All other such methods will run Numpy code and be slow**!
1301+
1302+
Please consult the `CuPy documentation on ufuncs
1303+
<https://docs-cupy.chainer.org/en/stable/reference/ufunc.html>`_
1304+
to check the current state of the library.
12891305
12901306
.. note::
12911307
When an ``out`` parameter is specified, and (one of) it has
@@ -1507,10 +1523,10 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
15071523
inp.data if isinstance(inp, type(self)) else inp
15081524
for inp in inputs)
15091525

1510-
# For native ufuncs, we turn non-scalar inputs into cupy arrays,
1511-
# as a workaround for https://github.com/cupy/cupy/issues/594
1512-
# TODO: remove code when the upstream issue is fixed
15131526
if use_native:
1527+
# TODO: remove when upstream issue is fixed
1528+
# For native ufuncs, we turn non-scalar inputs into cupy arrays,
1529+
# as a workaround for https://github.com/cupy/cupy/issues/594
15141530
inputs, orig_inputs = [], inputs
15151531
for inp in orig_inputs:
15161532
if (isinstance(inp, cupy.ndarray) or
@@ -1519,6 +1535,20 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
15191535
inputs.append(inp)
15201536
else:
15211537
inputs.append(cupy.array(inp))
1538+
elif method != 'at':
1539+
# TODO: remove when upstream issue is fixed
1540+
# For non-native ufuncs (except `at`), we need ot cast our tensors
1541+
# and Cupy arrays to Numpy arrays explicitly, since `__array__`
1542+
# and friends are not implemented. See
1543+
# https://github.com/cupy/cupy/issues/589
1544+
inputs, orig_inputs = [], inputs
1545+
for inp in orig_inputs:
1546+
if isinstance(inp, cupy.ndarray):
1547+
inputs.append(cupy.asnumpy(inp))
1548+
elif isinstance(inp, CupyTensor):
1549+
inputs.append(cupy.asnumpy(inp.data))
1550+
else:
1551+
inputs.append(inp)
15221552

15231553
# --- Get some parameters for later --- #
15241554

@@ -1595,20 +1625,19 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
15951625
def eval_at_via_npy(*inputs, **kwargs):
15961626
import ctypes
15971627
cupy_arr = inputs[0]
1598-
npy_arr = np.asarray(cupy_arr)
1628+
npy_arr = cupy.asnumpy(cupy_arr)
15991629
new_inputs = (npy_arr,) + inputs[1:]
16001630
super(CupyTensor, self).__array_ufunc__(
16011631
ufunc, method, *new_inputs, **kwargs)
16021632
# Workaround for https://github.com/cupy/cupy/issues/593
1603-
# TODO: use cupy_arr[:] = npy_arr when it's fixed and not
1604-
# slower
1633+
# TODO: use cupy_arr[:] = npy_arr when available
16051634
cupy_arr.data.copy_from_host(
16061635
npy_arr.ctypes.data_as(ctypes.c_void_p), npy_arr.nbytes)
16071636

16081637
if use_native:
16091638
# Native method could exist but raise `NotImplementedError`
1610-
# or return `NotImplemented`, falling back to Numpy case
1611-
# then, too
1639+
# or return `NotImplemented`. We fall back to Numpy also in
1640+
# that situation.
16121641
try:
16131642
res = native_method(*inputs, **kwargs)
16141643
except NotImplementedError:
@@ -1626,8 +1655,8 @@ def eval_at_via_npy(*inputs, **kwargs):
16261655

16271656
if use_native:
16281657
# Native method could exist but raise `NotImplementedError`
1629-
# or return `NotImplemented`, falling back to base case
1630-
# then, too
1658+
# or return `NotImplemented`. We fall back to Numpy also in
1659+
# that situation.
16311660
try:
16321661
res = native_method(*inputs, **kwargs)
16331662
except NotImplementedError:
@@ -1653,8 +1682,8 @@ def eval_at_via_npy(*inputs, **kwargs):
16531682
if is_floating_dtype(res.dtype):
16541683
if res.shape != self.shape:
16551684
# Don't propagate weighting if shape changes
1656-
weighting = CupyTensorSpaceConstWeighting(1.0,
1657-
exponent)
1685+
weighting = CupyTensorSpaceConstWeighting(
1686+
1.0, exponent)
16581687
spc_kwargs = {'weighting': weighting}
16591688
else:
16601689
spc_kwargs = {}
@@ -1723,8 +1752,10 @@ def real(self):
17231752
real : `CupyTensor` view with real dtype
17241753
The real part of this tensor as an element of an `rn` space.
17251754
"""
1726-
# Only real dtypes currently
1727-
return self
1755+
if self.space.is_real:
1756+
return self
1757+
else:
1758+
return self.space.real_space.element(self.data.real)
17281759

17291760
@real.setter
17301761
def real(self, newreal):
@@ -1737,7 +1768,7 @@ def real(self, newreal):
17371768
newreal : `array-like` or scalar
17381769
The new real part for this tensor.
17391770
"""
1740-
self.real.data[:] = newreal
1771+
self.data.real[:] = newreal
17411772

17421773
@property
17431774
def imag(self):
@@ -1748,8 +1779,10 @@ def imag(self):
17481779
imag : `CupyTensor`
17491780
The imaginary part of this tensor as an element of an `rn` space.
17501781
"""
1751-
# Only real dtypes currently
1752-
return self.space.zero()
1782+
if self.space.is_real:
1783+
return self.space.zero()
1784+
else:
1785+
return self.space.real_space.element(self.data.imag)
17531786

17541787
@imag.setter
17551788
def imag(self, newimag):
@@ -1762,7 +1795,7 @@ def imag(self, newimag):
17621795
newimag : `array-like` or scalar
17631796
The new imaginary part for this tensor.
17641797
"""
1765-
raise NotImplementedError('complex dtypes not supported')
1798+
self.data.imag[:] = newimag
17661799

17671800
def conj(self, out=None):
17681801
"""Complex conjugate of this tensor.
@@ -1779,11 +1812,17 @@ def conj(self, out=None):
17791812
The complex conjugate tensor. If ``out`` was provided,
17801813
the returned object is a reference to it.
17811814
"""
1782-
# Only real dtypes currently
17831815
if out is None:
1784-
return self.copy()
1816+
if self.space.is_real:
1817+
return self.copy()
1818+
else:
1819+
return self.space.element(self.data.conj())
17851820
else:
1786-
self.assign(out)
1821+
if self.space.is_real:
1822+
self.assign(out)
1823+
else:
1824+
# In-place not available as it seems
1825+
out[:] = self.data.conj()
17871826
return out
17881827

17891828
def __ipow__(self, other):
@@ -1806,7 +1845,8 @@ def _weighting(weights, exponent):
18061845
if np.isscalar(weights):
18071846
weighting = CupyTensorSpaceConstWeighting(weights, exponent=exponent)
18081847
else:
1809-
# TODO: sequence of 1D array-likes
1848+
# TODO: sequence of 1D array-likes, see
1849+
# https://github.com/odlgroup/odl/pull/1238
18101850
weights = cupy.array(weights, copy=False)
18111851
weighting = CupyTensorSpaceArrayWeighting(weights, exponent=exponent)
18121852
return weighting
@@ -2065,6 +2105,28 @@ def dist(self, x1, x2):
20652105
else:
20662106
return float(distpw(x1.data, x2.data, self.exponent, self.array))
20672107

2108+
# TODO: remove repr_part and __repr__ when cupy.ndarray.__array__
2109+
# is implemented. See
2110+
# https://github.com/cupy/cupy/issues/589
2111+
@property
2112+
def repr_part(self):
2113+
"""String usable in a space's ``__repr__`` method."""
2114+
# TODO: use edgeitems
2115+
arr_str = array_str(cupy.asnumpy(self.array), nprint=10)
2116+
optargs = [('weighting', arr_str, ''),
2117+
('exponent', self.exponent, 2.0)]
2118+
return signature_string([], optargs, sep=',\n',
2119+
mod=[[], ['!s', ':.4']])
2120+
2121+
def __repr__(self):
2122+
"""Return ``repr(self)``."""
2123+
# TODO: use edgeitems
2124+
posargs = [array_str(cupy.asnumpy(self.array), nprint=10)]
2125+
optargs = [('exponent', self.exponent, 2.0)]
2126+
inner_str = signature_string(posargs, optargs, sep=',\n',
2127+
mod=['!s', ':.4'])
2128+
return '{}(\n{}\n)'.format(self.__class__.__name__, indent(inner_str))
2129+
20682130

20692131
class CupyTensorSpaceConstWeighting(ConstWeighting):
20702132

0 commit comments

Comments
 (0)