Skip to content

Commit 88d4aec

Browse files
authored
test: update tests with latest commits from main (#220)
* test: update tests to branch to ensure they pass * fix: use in-place ops properly * fix: buggy thing * fix: correct FFT shape for inplace * fix: avoid multiple extra calls to set things inplace * test: fix more test warnings * test: fix more test warnings * test: latest commit * fix: tests for implements maybe * test: make tests robust across versions * fix: ensure implements removes leading spaces in older pythons * fix: make sure coeffs is cast to an array * debug: stop on first failure for now * test: use latest testing branch * fix: ensure we cast to native byte order * test: use latest testing code * fix: more native byte order casts * fix: make a copy * style: pre-commit * Apply suggestion from @beckermr * test: try a new test that might be more robust * test: run it all via split * test: try this test * test: use latest changes * fix: do not store durations for float32 tests * test: update tests submodule
1 parent d9f0155 commit 88d4aec

14 files changed

Lines changed: 252 additions & 117 deletions

File tree

.github/workflows/python_package.yaml

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,17 @@ jobs:
6161
cat .test_durations*
6262
fi
6363
64+
- name: Test with pytest in float32
65+
run: |
66+
pytest \
67+
-vv \
68+
--durations=100 \
69+
--randomly-seed=42 \
70+
--splits ${NUM_SPLITS} --group ${{ matrix.group }} \
71+
--splitting-algorithm least_duration \
72+
--retries 1 \
73+
--test-in-float32
74+
6475
- name: Test with pytest
6576
run: |
6677
pytest \
@@ -74,13 +85,6 @@ jobs:
7485
--clean-durations \
7586
--retries 1
7687
77-
- name: Test with pytest in float32
78-
if: ${{ matrix.group == '1' }}
79-
run: |
80-
pytest \
81-
-vv \
82-
--test-in-float32
83-
8488
- name: Upload test durations
8589
uses: actions/upload-artifact@v7
8690
with:

jax_galsim/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from .errors import GalSimKeyError, GalSimIndexError, GalSimNotImplementedError
99
from .errors import GalSimBoundsError, GalSimUndefinedBoundsError, GalSimImmutableError
1010
from .errors import GalSimIncompatibleValuesError, GalSimSEDError, GalSimHSMError
11-
from .errors import GalSimFFTSizeError
11+
from .errors import GalSimFFTSizeError, GalSimFFTSizeWarning
1212
from .errors import GalSimConfigError, GalSimConfigValueError
1313
from .errors import GalSimWarning, GalSimDeprecationWarning
1414

jax_galsim/convolve.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -486,10 +486,12 @@ def _kValue(self, pos):
486486

487487
def _drawKImage(self, image, jac=None):
488488
image = self.orig_obj._drawKImage(image, jac)
489-
image._array = jnp.where(
490-
jnp.abs(image.array) > self._min_acc_kvalue,
491-
1.0 / image.array,
492-
self._inv_min_acc_kvalue,
489+
image._array = image._array.at[...].set(
490+
jnp.where(
491+
jnp.abs(image.array) > self._min_acc_kvalue,
492+
1.0 / image.array,
493+
self._inv_min_acc_kvalue,
494+
)
493495
)
494496
kx, ky = image.get_pixel_centers()
495497
_jac = jnp.eye(2) if jac is None else jac
@@ -500,10 +502,12 @@ def _drawKImage(self, image, jac=None):
500502
)
501503
ksq = (kx**2 + ky**2) * image.scale**2
502504
# Set to zero outside of nominal maxk so as not to amplify high frequencies.
503-
image._array = jnp.where(
504-
ksq > self.maxk**2,
505-
0.0,
506-
image.array,
505+
image._array = image._array.at[...].set(
506+
jnp.where(
507+
ksq > self.maxk**2,
508+
0.0,
509+
image.array,
510+
)
507511
)
508512
return image
509513

jax_galsim/core/utils.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,17 @@
99
from jax.tree_util import tree_flatten
1010

1111

12+
def cast_numpy_array_to_native_byte_order(arr):
13+
"""Cast an array to native byte order."""
14+
if not isinstance(arr, np.ndarray):
15+
return arr
16+
17+
if arr.dtype.isnative:
18+
return arr
19+
20+
return arr.astype(arr.dtype.newbyteorder("="))
21+
22+
1223
def has_tracers(x):
1324
"""Return True if the input item is a JAX tracer or object, False otherwise."""
1425
for item in tree_flatten(x)[0]:
@@ -296,7 +307,7 @@ class ParsedDoc(NamedTuple):
296307
sections: dict[str, str] = {}
297308

298309

299-
def _break_off_body_section_by_newline(body):
310+
def _break_off_body_section_by_newline(body, double_check_first_indent=False):
300311
first_lines = []
301312
body_lines = []
302313
found_first_break = False
@@ -314,7 +325,14 @@ def _break_off_body_section_by_newline(body):
314325
else:
315326
first_lines.append(line)
316327

328+
if double_check_first_indent and len(first_lines) > 1:
329+
len_first_indent = len(first_lines[1]) - len(first_lines[1].lstrip())
330+
if len_first_indent > 0:
331+
first_indent = first_lines[1][:len_first_indent]
332+
first_lines[0] = first_indent + first_lines[0].lstrip()
333+
317334
firstline = "\n".join(first_lines)
335+
firstline = textwrap.dedent(firstline)
318336
body = "\n".join(body_lines)
319337
body = textwrap.dedent(body.lstrip("\n"))
320338

@@ -337,7 +355,9 @@ def _parse_galsimdoc(docstr):
337355

338356
signature, body = "", docstr
339357

340-
firstline, body = _break_off_body_section_by_newline(body)
358+
firstline, body = _break_off_body_section_by_newline(
359+
body, double_check_first_indent=True
360+
)
341361

342362
summary = firstline
343363
if not summary:

jax_galsim/errors.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
GalSimDeprecationWarning,
66
GalSimError,
77
GalSimFFTSizeError,
8+
GalSimFFTSizeWarning,
89
GalSimHSMError,
910
GalSimImmutableError,
1011
GalSimIncompatibleValuesError,

jax_galsim/gsobject.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -809,7 +809,7 @@ def drawReal(self, image, add_to_image=False):
809809
im1 = self._drawReal(image)
810810
temp = im1.subImage(image.bounds)
811811
if add_to_image:
812-
image._array = image._array + temp._array
812+
image._array = image._array.at[...].add(temp._array)
813813
else:
814814
image._array = temp._array
815815

@@ -929,7 +929,7 @@ def drawFFT_finish(self, image, kimage, wrap_size, add_to_image):
929929
# Add (a portion of) this to the original image.
930930
temp = real_image.subImage(image.bounds)
931931
if add_to_image:
932-
image._array = image._array + temp._array
932+
image._array = image._array.at[...].add(temp._array)
933933
else:
934934
image._array = temp._array
935935

@@ -1043,7 +1043,7 @@ def drawKImage(
10431043
if not add_to_image:
10441044
image._array = im2._array
10451045
else:
1046-
image._array = im2._array + image._array
1046+
image._array = image._array.at[...].add(im2._array)
10471047

10481048
image_in._array = image._array
10491049
image_in._bounds = image._bounds

0 commit comments

Comments
 (0)