1
+ from __future__ import annotations
2
+
3
+ import warnings
1
4
from types import ModuleType
2
5
from typing import Any
3
6
13
16
)
14
17
from xarray .namedarray .core import NamedArray
15
18
19
+ with warnings .catch_warnings ():
20
+ warnings .filterwarnings (
21
+ "ignore" ,
22
+ r"The numpy.array_api submodule is still experimental" ,
23
+ category = UserWarning ,
24
+ )
25
+ import numpy .array_api as nxp # noqa: F401
26
+
16
27
17
28
def _get_data_namespace (x : NamedArray [Any , Any ]) -> ModuleType :
18
29
if isinstance (x ._data , _arrayapi ):
19
30
return x ._data .__array_namespace__ ()
20
- else :
21
- return np
31
+
32
+ return np
33
+
34
+
35
+ # %% Creation Functions
22
36
23
37
24
38
def astype (
@@ -49,18 +63,25 @@ def astype(
49
63
50
64
Examples
51
65
--------
52
- >>> narr = NamedArray(("x",), np.array([1.5, 2.5]))
53
- >>> astype(narr, np.dtype(int)).data
54
- array([1, 2])
66
+ >>> narr = NamedArray(("x",), nxp.asarray([1.5, 2.5]))
67
+ >>> narr
68
+ <xarray.NamedArray (x: 2)>
69
+ Array([1.5, 2.5], dtype=float64)
70
+ >>> astype(narr, np.dtype(np.int32))
71
+ <xarray.NamedArray (x: 2)>
72
+ Array([1, 2], dtype=int32)
55
73
"""
56
74
if isinstance (x ._data , _arrayapi ):
57
75
xp = x ._data .__array_namespace__ ()
58
- return x ._new (data = xp .astype (x , dtype , copy = copy ))
76
+ return x ._new (data = xp .astype (x . _data , dtype , copy = copy ))
59
77
60
78
# np.astype doesn't exist yet:
61
79
return x ._new (data = x ._data .astype (dtype , copy = copy )) # type: ignore[attr-defined]
62
80
63
81
82
+ # %% Elementwise Functions
83
+
84
+
64
85
def imag (
65
86
x : NamedArray [_ShapeType , np .dtype [_SupportsImag [_ScalarType ]]], / # type: ignore[type-var]
66
87
) -> NamedArray [_ShapeType , np .dtype [_ScalarType ]]:
@@ -83,8 +104,9 @@ def imag(
83
104
84
105
Examples
85
106
--------
86
- >>> narr = NamedArray(("x",), np.array([1 + 2j, 2 + 4j]))
87
- >>> imag(narr).data
107
+ >>> narr = NamedArray(("x",), np.asarray([1.0 + 2j, 2 + 4j])) # TODO: Use nxp
108
+ >>> imag(narr)
109
+ <xarray.NamedArray (x: 2)>
88
110
array([2., 4.])
89
111
"""
90
112
xp = _get_data_namespace (x )
@@ -114,9 +136,11 @@ def real(
114
136
115
137
Examples
116
138
--------
117
- >>> narr = NamedArray(("x",), np.array([1 + 2j, 2 + 4j]))
118
- >>> real(narr).data
139
+ >>> narr = NamedArray(("x",), np.asarray([1.0 + 2j, 2 + 4j])) # TODO: Use nxp
140
+ >>> real(narr)
141
+ <xarray.NamedArray (x: 2)>
119
142
array([1., 2.])
120
143
"""
121
144
xp = _get_data_namespace (x )
122
- return x ._new (data = xp .real (x ._data ))
145
+ out = x ._new (data = xp .real (x ._data ))
146
+ return out
0 commit comments