Skip to content

Commit e0e7136

Browse files
committed
Provide 'side' to pick whether we want the closest, smaller or larger value
1 parent d50bd5d commit e0e7136

File tree

1 file changed

+39
-4
lines changed

1 file changed

+39
-4
lines changed

yt/data_objects/time_series.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import weakref
77
from abc import ABC, abstractmethod
88
from functools import wraps
9-
from typing import Optional, Union
9+
from typing import Literal, Optional, Union
1010

1111
import numpy as np
1212
from more_itertools import always_iterable
@@ -445,6 +445,9 @@ def _get_by_attribute(
445445
attribute: str,
446446
value: Union[unyt_quantity, tuple[float, str]],
447447
tolerance: Union[None, unyt_quantity, tuple[float, str]] = None,
448+
side: Union[
449+
Literal["nearest"], Literal["smaller"], Literal["larger"]
450+
] = "nearest",
448451
) -> "Dataset":
449452
r"""
450453
Get a dataset at or near to a given value.
@@ -462,8 +465,16 @@ def _get_by_attribute(
462465
within the tolerance value. If None, simply return the
463466
nearest dataset.
464467
Default: None.
468+
side : str
469+
The side of the value to return. Can be 'nearest', 'smaller' or 'larger'.
470+
Default: 'nearest'.
465471
"""
466472

473+
if side not in ("nearest", "smaller", "larger"):
474+
raise ValueError(
475+
f"side must be 'nearest', 'smaller' or 'larger', got {side}"
476+
)
477+
467478
# Use a binary search to find the closest value
468479
iL = 0
469480
iH = len(self._pre_outputs) - 1
@@ -518,7 +529,13 @@ def _get_by_attribute(
518529
dsL = dsH = dsM
519530
break
520531

521-
if abs(value - getattr(dsL, attribute)) < abs(value - getattr(dsH, attribute)):
532+
if side == "smaller":
533+
ds_best = dsL if sign > 0 else dsH
534+
elif side == "larger":
535+
ds_best = dsH if sign > 0 else dsL
536+
elif abs(value - getattr(dsL, attribute)) < abs(
537+
value - getattr(dsH, attribute)
538+
):
522539
ds_best = dsL
523540
else:
524541
ds_best = dsH
@@ -534,6 +551,9 @@ def get_by_time(
534551
self,
535552
time: Union[unyt_quantity, tuple],
536553
tolerance: Union[None, unyt_quantity, tuple] = None,
554+
side: Union[
555+
Literal["nearest"], Literal["smaller"], Literal["larger"]
556+
] = "nearest",
537557
):
538558
"""
539559
Get a dataset at or near to a given time.
@@ -547,16 +567,28 @@ def get_by_time(
547567
within the tolerance value. If None, simply return the
548568
nearest dataset.
549569
Default: None.
570+
side : str
571+
The side of the value to return. Can be 'nearest', 'smaller' or 'larger'.
572+
Default: 'nearest'.
550573
551574
Examples
552575
--------
553576
>>> ds = ts.get_by_time((12, "Gyr"))
554577
>>> t = ts[0].quan(12, "Gyr")
555578
... ds = ts.get_by_time(t, tolerance=(100, "Myr"))
556579
"""
557-
return self._get_by_attribute("current_time", time, tolerance=tolerance)
580+
return self._get_by_attribute(
581+
"current_time", time, tolerance=tolerance, side=side
582+
)
558583

559-
def get_by_redshift(self, redshift: float, tolerance: Optional[float] = None):
584+
def get_by_redshift(
585+
self,
586+
redshift: float,
587+
tolerance: Optional[float] = None,
588+
side: Union[
589+
Literal["nearest"], Literal["smaller"], Literal["larger"]
590+
] = "nearest",
591+
):
560592
"""
561593
Get a dataset at or near to a given time.
562594
@@ -569,6 +601,9 @@ def get_by_redshift(self, redshift: float, tolerance: Optional[float] = None):
569601
within the tolerance value. If None, simply return the
570602
nearest dataset.
571603
Default: None.
604+
side : str
605+
The side of the value to return. Can be 'nearest', 'smaller' or 'larger'.
606+
Default: 'nearest'.
572607
573608
Examples
574609
--------

0 commit comments

Comments
 (0)