Skip to content

Commit 9e2e04a

Browse files
authored
Add tests for collections and tools; improve numpy2 support (#139)
* Soft-reset FastADT changes to 1 commit for easier maintenance * Fix `PartialFormatter` needlessly stripping `format_spec` * Add tests for all members of `instamatic._collections` * Improve gatansocket3.py longarray padding to numpy-2 compatible * Improve gatansocket3.py longarray padding to numpy-2 compatible * Add simple tests for, clear duplicates in instamatic.tools * Temporarily require numpy > 2 for GitHub testing * Revert "Temporarily require numpy > 2 for GitHub testing" This reverts commit 3e7de09. * Lift the numpy < 2 requirement * With numpy 2, can we allow Python 3.13 tests? * LabelFrame -> super, Make # noqa more specific to appease ruff & PyCharm
1 parent 63155a2 commit 9e2e04a

File tree

11 files changed

+177
-63
lines changed

11 files changed

+177
-63
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ jobs:
2222
strategy:
2323
fail-fast: false
2424
matrix:
25-
python-version: ['3.9', '3.10', '3.11', '3.12', ]
25+
python-version: ['3.9', '3.10', '3.11', '3.12', '3.13', ]
2626

2727
steps:
2828
- uses: actions/checkout@v3

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ classifiers = [
3131
"Programming Language :: Python :: 3.10",
3232
"Programming Language :: Python :: 3.11",
3333
"Programming Language :: Python :: 3.12",
34+
"Programming Language :: Python :: 3.13",
3435
"Development Status :: 5 - Production/Stable",
3536
"Intended Audience :: Science/Research",
3637
"License :: OSI Approved :: BSD License",
@@ -46,7 +47,7 @@ dependencies = [
4647
"lmfit >= 1.0.0",
4748
"matplotlib >= 3.1.2",
4849
"mrcfile >= 1.1.2",
49-
"numpy >= 1.17.3, <2",
50+
"numpy >= 1.17.3",
5051
"pandas >= 1.0.0",
5152
"pillow >= 7.0.0",
5253
"pywinauto >= 0.6.8; sys_platform == 'windows'",

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ ipython >= 7.11.1
55
lmfit >= 1.0.0
66
matplotlib >= 3.1.2
77
mrcfile >= 1.1.2
8-
numpy >= 1.17.3, <2
8+
numpy >= 1.17.3
99
pandas >= 1.0.0
1010
pillow >= 7.0.0
1111
pywinauto >= 0.6.8; sys_platform == 'windows'

scripts/process_dm.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from skimage.exposure import rescale_intensity
99

1010
from instamatic.processing.ImgConversionDM import ImgConversionDM as ImgConversion
11+
from instamatic.tools import relativistic_wavelength
1112

1213
# Script to process cRED data collecting using the DigitalMicrograph script `insteaDMatic`
1314
# https://github.com/instamatic-dev/InsteaDMatic
@@ -24,21 +25,6 @@
2425
# all `cred_log.txt` files in the subdirectories, and iterate over those.
2526

2627

27-
def relativistic_wavelength(voltage: float = 200):
28-
"""Calculate the relativistic wavelength of electrons Voltage in kV Return
29-
wavelength in Angstrom."""
30-
voltage *= 1000 # -> V
31-
32-
h = 6.626070150e-34 # planck constant J.s
33-
m = 9.10938356e-31 # electron rest mass kg
34-
e = 1.6021766208e-19 # elementary charge C
35-
c = 299792458 # speed of light m/s
36-
37-
wl = h / (2 * m * voltage * e * (1 + (e * voltage) / (2 * m * c**2))) ** 0.5
38-
39-
return round(wl * 1e10, 6) # m -> Angstrom
40-
41-
4228
def img_convert(credlog, tiff_path='tiff2', mrc_path='RED', smv_path='SMV'):
4329
credlog = Path(credlog)
4430
drc = credlog.parent
@@ -90,7 +76,7 @@ def img_convert(credlog, tiff_path='tiff2', mrc_path='RED', smv_path='SMV'):
9076
if line.startswith('Resolution:'):
9177
resolution = line.split()[-1]
9278

93-
wavelength = relativistic_wavelength(high_tension)
79+
wavelength = relativistic_wavelength(high_tension * 1000)
9480

9581
# convert from um to mm
9682
physical_pixelsize = physical_pixelsize[0] / 1000

src/instamatic/_collections.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
from __future__ import annotations
22

3-
import contextlib
43
import logging
54
import string
6-
import time
75
from collections import UserDict
8-
from typing import Any, Callable
6+
from dataclasses import dataclass
7+
from typing import Any
98

109

1110
class NoOverwriteDict(UserDict):
@@ -18,6 +17,8 @@ def __setitem__(self, key: Any, value: Any) -> None:
1817

1918

2019
class NullLogger(logging.Logger):
20+
"""A logger mock that ignores all logging, to be used in headless mode."""
21+
2122
def __init__(self, name='null'):
2223
super().__init__(name)
2324
self.addHandler(logging.NullHandler())
@@ -27,24 +28,28 @@ def __init__(self, name='null'):
2728
class PartialFormatter(string.Formatter):
2829
"""`str.format` alternative, allows for partial replacement of {fields}"""
2930

31+
@dataclass(frozen=True)
32+
class Missing:
33+
name: str
34+
3035
def __init__(self, missing: str = '{{{}}}') -> None:
3136
super().__init__()
3237
self.missing: str = missing # used instead of missing values
3338

3439
def get_field(self, field_name: str, args, kwargs) -> tuple[Any, str]:
3540
"""When field can't be found, return placeholder text instead."""
3641
try:
37-
obj, used_key = super().get_field(field_name, args, kwargs)
38-
return obj, used_key
42+
return super().get_field(field_name, args, kwargs)
3943
except (KeyError, AttributeError, IndexError, TypeError):
40-
return self.missing.format(field_name), field_name
44+
return PartialFormatter.Missing(field_name), field_name
4145

4246
def format_field(self, value: Any, format_spec: str) -> str:
4347
"""If the field was not found, format placeholder as string instead."""
44-
try:
45-
return super().format_field(value, format_spec)
46-
except (ValueError, TypeError):
47-
return str(value)
48+
if isinstance(value, PartialFormatter.Missing):
49+
if format_spec:
50+
return self.missing.format(f'{value.name}:{format_spec}')
51+
return self.missing.format(f'{value.name}')
52+
return super().format_field(value, format_spec)
4853

4954

5055
partial_formatter = PartialFormatter()

src/instamatic/calibrate/calibrate_stage_rotation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def to_file(self, outdir: Optional[str] = None) -> None:
163163
outdir = calibration_drc
164164
yaml_path = Path(outdir) / CALIB_STAGE_ROTATION
165165
with open(yaml_path, 'w') as yaml_file:
166-
yaml.safe_dump(asdict(self), yaml_file) # noqa: correct type
166+
yaml.safe_dump(asdict(self), yaml_file) # type: ignore[arg-type]
167167
log(f'{self} saved to {yaml_path}.')
168168

169169
def plot(self, sst: Optional[list[SpanSpeedTime]] = None) -> None:

src/instamatic/camera/gatansocket3.py

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,17 @@
8686
sArgsBuffer = np.zeros(ARGS_BUFFER_SIZE, dtype=np.byte)
8787

8888

89+
def string_to_longarray(string: str, *, dtype: np.dtype = np.int_) -> np.ndarray:
90+
"""Convert the string to a 1D np array of dtype (default np.int_ - C long)
91+
with numpy2-save padding to ensure length is a multiple of dtype.itemsize.
92+
"""
93+
s_bytes = string.encode('utf-8')
94+
dtype_size = np.dtype(dtype).itemsize
95+
if extra := len(s_bytes) % dtype_size:
96+
s_bytes += b'\0' * (dtype_size - extra)
97+
return np.frombuffer(s_bytes, dtype=dtype)
98+
99+
89100
class Message:
90101
"""Information packet to send and receive on the socket.
91102
@@ -335,14 +346,7 @@ def SetK2Parameters(
335346
funcCode = enum_gs['GS_SetK2Parameters']
336347

337348
self.save_frames = saveFrames
338-
339-
# filter name
340-
filt_str = filt + '\0'
341-
extra = len(filt_str) % 4
342-
if extra:
343-
npad = 4 - extra
344-
filt_str = filt_str + npad * '\0'
345-
longarray = np.frombuffer(filt_str.encode(), dtype=np.int_)
349+
longarray = string_to_longarray(filt + '\0', dtype=np.int_) # filter name
346350

347351
longs = [
348352
funcCode,
@@ -397,12 +401,7 @@ def SetupFileSaving(
397401
longs = [enum_gs['GS_SetupFileSaving'], rotationFlip]
398402
dbls = [pixelSize]
399403
bools = [filePerImage]
400-
names_str = dirname + '\0' + rootname + '\0'
401-
extra = len(names_str) % 4
402-
if extra:
403-
npad = 4 - extra
404-
names_str = names_str + npad * '\0'
405-
longarray = np.frombuffer(names_str.encode(), dtype=np.int_)
404+
longarray = string_to_longarray(dirname + '\0' + rootname + '\0', dtype=np.int_)
406405
message_send = Message(
407406
longargs=longs, boolargs=bools, dblargs=dbls, longarray=longarray
408407
)
@@ -664,24 +663,18 @@ def ExecuteScript(
664663
select_camera=0,
665664
recv_longargs_init=(0,),
666665
recv_dblargs_init=(0.0,),
667-
recv_longarray_init=[],
666+
recv_longarray_init=None,
668667
):
668+
"""Send the command string as a 1D longarray of np.int_ dtype."""
669669
funcCode = enum_gs['GS_ExecuteScript']
670-
cmd_str = command_line + '\0'
671-
extra = len(cmd_str) % 4
672-
if extra:
673-
npad = 4 - extra
674-
cmd_str = cmd_str + (npad) * '\0'
675-
# send the command string as 1D longarray
676-
longarray = np.frombuffer(cmd_str.encode(), dtype=np.int_)
677-
# print(longaray)
670+
longarray = string_to_longarray(command_line + '\0', dtype=np.int_)
678671
message_send = Message(
679672
longargs=(funcCode,), boolargs=(select_camera,), longarray=longarray
680673
)
681674
message_recv = Message(
682675
longargs=recv_longargs_init,
683676
dblargs=recv_dblargs_init,
684-
longarray=recv_longarray_init,
677+
longarray=[] if recv_longarray_init is None else recv_longarray_init,
685678
)
686679
self.ExchangeMessages(message_send, message_recv)
687680
return message_recv

src/instamatic/gui/fast_adt_frame.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ def as_dict(self):
7070
class ExperimentalFastADT(LabelFrame):
7171
"""GUI panel to perform selected FastADT-style (c)RED & PED experiments."""
7272

73-
def __init__(self, parent): # noqa: parent.__init__ is called
74-
LabelFrame.__init__(self, parent, text='Experiment with a priori tracking options')
73+
def __init__(self, parent):
74+
super().__init__(parent, text='Experiment with a priori tracking options')
7575
self.parent = parent
7676
self.var = ExperimentalFastADTVariables()
7777
self.q: Optional[Queue] = None

src/instamatic/tools.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
11
from __future__ import annotations
22

3-
import glob
4-
import os
53
import sys
64
from pathlib import Path
7-
from typing import Tuple
5+
from typing import Iterator
86

97
import numpy as np
108
from scipy import interpolate, ndimage
11-
from skimage import exposure
129
from skimage.measure import regionprops
1310

1411

@@ -71,9 +68,13 @@ def to_xds_untrusted_area(kind: str, coords: list) -> str:
7168
raise ValueError('Only quadrilaterals are supported for now')
7269

7370

74-
def find_subranges(lst: list) -> Tuple[int, int]:
71+
def find_subranges(lst: list[int]) -> Iterator[tuple[int, int]]:
7572
"""Takes a range of sequential numbers (possibly with gaps) and splits them
76-
in sequential sub-ranges defined by the minimum and maximum value."""
73+
in sequential sub-ranges defined by the minimum and maximum value.
74+
75+
Example:
76+
[1,2,3,7,8,10] --> (1,3), (7,8), (10,10)
77+
"""
7778
from itertools import groupby
7879
from operator import itemgetter
7980

@@ -274,7 +275,7 @@ def get_acquisition_time(
274275

275276

276277
def relativistic_wavelength(voltage: float = 200_000) -> float:
277-
"""Calculate the relativistic wavelength of electrons from the accelarating
278+
"""Calculate the relativistic wavelength of electrons from the accelerating
278279
voltage.
279280
280281
Input: Voltage in V

tests/test_collections.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
from contextlib import nullcontext
5+
from dataclasses import dataclass, field
6+
from typing import Any, Optional, Type
7+
8+
import pytest
9+
10+
import instamatic._collections as ic
11+
from tests.utils import InstanceAutoTracker
12+
13+
14+
def test_no_overwrite_dict() -> None:
15+
"""Should work as normal dict unless key exists, in which case raises."""
16+
nod = ic.NoOverwriteDict({1: 2})
17+
nod.update({3: 4})
18+
nod[5] = 6
19+
del nod[1]
20+
nod[1] = 6
21+
assert nod == {1: 6, 3: 4, 5: 6}
22+
with pytest.raises(KeyError):
23+
nod[1] = 2
24+
with pytest.raises(KeyError):
25+
nod.update({3: 4})
26+
27+
28+
def test_null_logger(caplog) -> None:
29+
"""NullLogger should void and not propagate messages to root logger."""
30+
31+
messages = []
32+
handler = logging.StreamHandler()
33+
handler.emit = lambda record: messages.append(record.getMessage())
34+
null_logger = ic.NullLogger()
35+
root_logger = logging.getLogger()
36+
root_logger.addHandler(handler)
37+
38+
with caplog.at_level(logging.DEBUG):
39+
null_logger.debug('debug message that should be ignored')
40+
null_logger.info('info message that should be ignored')
41+
null_logger.warning('warning message that should be ignored')
42+
null_logger.error('error message that should be ignored')
43+
null_logger.critical('critical message that should be ignored')
44+
45+
# Nothing should have been captured by pytest's caplog
46+
root_logger.removeHandler(handler)
47+
assert caplog.records == []
48+
assert caplog.text == ''
49+
assert messages == []
50+
51+
52+
@dataclass
53+
class PartialFormatterTestCase(InstanceAutoTracker):
54+
template: str = '{s} & {f:06.2f}'
55+
args: list[Any] = field(default_factory=list)
56+
kwargs: dict[str, Any] = field(default_factory=dict)
57+
returns: str = ''
58+
raises: Optional[Type[Exception]] = None
59+
60+
61+
PartialFormatterTestCase(returns='{s} & {f:06.2f}')
62+
PartialFormatterTestCase(kwargs={'s': 'Text'}, returns='Text & {f:06.2f}')
63+
PartialFormatterTestCase(kwargs={'f': 3.1415}, returns='{s} & 003.14')
64+
PartialFormatterTestCase(kwargs={'x': 'test'}, returns='{s} & {f:06.2f}')
65+
PartialFormatterTestCase(kwargs={'f': 'Text'}, raises=ValueError)
66+
PartialFormatterTestCase(template='{0}{1}', args=[5], returns='5{1}')
67+
PartialFormatterTestCase(template='{0}{1}', args=[5, 6], returns='56')
68+
PartialFormatterTestCase(template='{0}{1}', args=[5, 6, 7], returns='56')
69+
70+
71+
@pytest.mark.parametrize('test_case', PartialFormatterTestCase.INSTANCES)
72+
def test_partial_formatter(test_case) -> None:
73+
"""Should replace only some {words}, but still fail if format is wrong."""
74+
c = test_case
75+
with pytest.raises(r) if (r := c.raises) else nullcontext():
76+
assert ic.partial_formatter.format(c.template, *c.args, **c.kwargs) == c.returns

0 commit comments

Comments
 (0)