Skip to content

Commit 321ea16

Browse files
committed
reformatted files with black 24.x
1 parent b006dd7 commit 321ea16

File tree

4 files changed

+31
-26
lines changed

4 files changed

+31
-26
lines changed

environment-ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: pyart-dev
1+
name: tobac-dev
22
channels:
33
- conda-forge
44
dependencies:

tobac/feature_detection.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1015,7 +1015,6 @@ def feature_detection_multithreshold_timestep(
10151015
raise ValueError(
10161016
"Please provide the input parameter statistic to determine what statistics to calculate."
10171017
)
1018-
10191018

10201019
track_data = gaussian_filter(
10211020
track_data, sigma=sigma_threshold

tobac/tests/test_utils_bulk_statistics.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -153,51 +153,51 @@ def test_bulk_statistics_missing_segments():
153153
### Test 2D data with time dimension
154154
test_data = tb_test.make_simple_sample_data_2D().core_data()
155155
common_dset_opts = {
156-
"in_arr": test_data,
157-
"data_type": "iris",}
158-
156+
"in_arr": test_data,
157+
"data_type": "iris",
158+
}
159+
159160
test_data_iris = tb_test.make_dataset_from_arr(
160-
time_dim_num=0, y_dim_num=1, x_dim_num=2, **common_dset_opts)
161+
time_dim_num=0, y_dim_num=1, x_dim_num=2, **common_dset_opts
162+
)
161163

162164
# detect features
163165
threshold = 7
164166
# test_data_iris = testing.make_dataset_from_arr(test_data, data_type="iris")
165167
fd_output = tobac.feature_detection.feature_detection_multithreshold(
166-
test_data_iris,
167-
dxy=1000,
168-
threshold=[threshold],
169-
n_min_threshold=100,
170-
target="maximum",)
168+
test_data_iris,
169+
dxy=1000,
170+
threshold=[threshold],
171+
n_min_threshold=100,
172+
target="maximum",
173+
)
171174

172175
# perform segmentation with bulk statistics
173176
stats = {
174-
"segment_max": np.max,
175-
"segment_min": min,
176-
"percentiles": (np.percentile, {"q": 95}),}
177+
"segment_max": np.max,
178+
"segment_min": min,
179+
"percentiles": (np.percentile, {"q": 95}),
180+
}
177181

178182
out_seg_mask, out_df = tobac.segmentation.segmentation_2D(
179-
fd_output, test_data_iris, dxy=1000, threshold=threshold)
183+
fd_output, test_data_iris, dxy=1000, threshold=threshold
184+
)
180185

181-
# specify some timesteps we set to zero
186+
# specify some timesteps we set to zero
182187
timesteps_to_zero = [1, 3, 10] # 0-based indexing
183-
modified_data = out_seg_mask.data.copy()
188+
modified_data = out_seg_mask.data.copy()
184189
# Set values to zero for the specified timesteps
185190
for timestep in timesteps_to_zero:
186191
modified_data[timestep, :, :] = 0 # Set all values for this timestep to zero
187192

188193
# assure that bulk statistics in postprocessing give same result
189194
out_segmentation = tb_utils.get_statistics_from_mask(
190-
out_df, out_seg_mask, test_data_iris, statistic=stats)
195+
out_df, out_seg_mask, test_data_iris, statistic=stats
196+
)
191197

192198
assert out_df.time.unique().size == out_segmentation.time.unique().size
193199

194200

195-
196-
197-
198-
199-
200-
201201
def test_bulk_statistics_multiple_fields():
202202
"""
203203
Test that multiple field input to bulk_statistics works as intended

tobac/utils/bulk_statistics.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -301,9 +301,15 @@ def get_statistics_from_mask(
301301

302302
for tt in pd.to_datetime(segmentation_mask.time):
303303
# select specific timestep
304-
segmentation_mask_t = segmentation_mask.sel(time=tt, method = 'nearest').data
304+
segmentation_mask_t = segmentation_mask.sel(time=tt, method="nearest").data
305305
fields_t = (
306-
field.sel(time=tt, method = 'nearest', tolerance = np.timedelta64(1000, 'us')).values if "time" in field.coords else field.values
306+
(
307+
field.sel(
308+
time=tt, method="nearest", tolerance=np.timedelta64(1000, "us")
309+
).values
310+
if "time" in field.coords
311+
else field.values
312+
)
307313
for field in fields
308314
)
309315

0 commit comments

Comments
 (0)