Skip to content

Commit 90fb03c

Browse files
committed
black
1 parent b54e2c5 commit 90fb03c

File tree

1 file changed

+18
-12
lines changed

1 file changed

+18
-12
lines changed

array_api_compat/common/_helpers.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
788788
return x
789789
return x.to_device(device, stream=stream)
790790

791+
791792
def size(x):
792793
"""
793794
Return the total number of elements of x.
@@ -802,6 +803,7 @@ def size(x):
802803
return None
803804
return math.prod(x.shape)
804805

806+
805807
def is_writeable_array(x):
806808
"""
807809
Return False if x.__setitem__ is expected to raise; True otherwise
@@ -812,6 +814,7 @@ def is_writeable_array(x):
812814
return False
813815
return True
814816

817+
815818
def _parse_copy_param(x, copy: bool | None | Literal["_force_false"]) -> bool:
816819
"""Preprocess and validate a copy parameter, in line with the same
817820
parameter in np.asarray(), np.astype(), etc.
@@ -827,8 +830,10 @@ def _parse_copy_param(x, copy: bool | None | Literal["_force_false"]) -> bool:
827830
raise ValueError(f"Invalid value for copy: {copy!r}")
828831
return copy
829832

833+
830834
_undef = object()
831835

836+
832837
class at:
833838
"""
834839
Update operations for read-only arrays.
@@ -897,6 +902,7 @@ class at:
897902
--------
898903
https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
899904
"""
905+
900906
__slots__ = ("x", "idx")
901907

902908
def __init__(self, x, idx=_undef):
@@ -945,7 +951,7 @@ def _common(
945951
if copy and is_jax_array(self.x):
946952
# Use JAX's at[]
947953
at_ = self.x.at[self.idx]
948-
args = (y, ) if y is not _undef else ()
954+
args = (y,) if y is not _undef else ()
949955
return getattr(at_, at_op)(*args, **kwargs), None
950956

951957
# Emulate at[] behaviour for non-JAX arrays
@@ -958,12 +964,9 @@ def get(self, copy: bool | None = True, **kwargs):
958964
# Special case when xp=numpy and idx is a fancy index
959965
# If copy is not False, avoid an unnecessary double copy.
960966
# if copy is forced to False, raise.
961-
if (
962-
is_numpy_array(self.x)
963-
and (
964-
isinstance(self.idx, (list, tuple))
965-
or (is_numpy_array(self.idx) and self.idx.dtype.kind in "biu")
966-
)
967+
if is_numpy_array(self.x) and (
968+
isinstance(self.idx, (list, tuple))
969+
or (is_numpy_array(self.idx) and self.idx.dtype.kind in "biu")
967970
):
968971
if copy is False:
969972
raise ValueError(
@@ -994,12 +997,14 @@ def apply(self, ufunc, /, **kwargs):
994997
ufunc.at(x, self.idx)
995998
return x
996999

997-
def _iop(self, at_op: str, elwise_op: Callable[[Array, Array], Array], y: Array, **kwargs):
1000+
def _iop(
1001+
self, at_op: str, elwise_op: Callable[[Array, Array], Array], y: Array, **kwargs
1002+
):
9981003
"""x[idx] += y or equivalent in-place operation on a subset of x
9991004
10001005
which is the same as saying
10011006
x[idx] = x[idx] + y
1002-
Note that this is not the same as
1007+
Note that this is not the same as
10031008
operator.iadd(x[idx], y)
10041009
Consider for example when x is a numpy array and idx is a fancy index, which
10051010
triggers a deep copy on __getitem__.
@@ -1017,11 +1022,11 @@ def add(self, y, /, **kwargs):
10171022
def subtract(self, y, /, **kwargs):
10181023
"""x[idx] -= y"""
10191024
return self._iop("subtract", operator.sub, y, **kwargs)
1020-
1025+
10211026
def multiply(self, y, /, **kwargs):
10221027
"""x[idx] *= y"""
10231028
return self._iop("multiply", operator.mul, y, **kwargs)
1024-
1029+
10251030
def divide(self, y, /, **kwargs):
10261031
"""x[idx] /= y"""
10271032
return self._iop("divide", operator.truediv, y, **kwargs)
@@ -1040,9 +1045,10 @@ def max(self, y, /, **kwargs):
10401045
xp = array_namespace(self.x)
10411046
return self._iop("max", xp.maximum, y, **kwargs)
10421047

1048+
10431049
def where(condition, x=None, y=None, /, copy: bool | None = True):
10441050
"""Return elements from x when condition is True and from y when
1045-
it is False.
1051+
it is False.
10461052
10471053
This is a wrapper around xp.where that adds the copy parameter:
10481054

0 commit comments

Comments
 (0)