6
6
import weakref
7
7
from abc import ABC , abstractmethod
8
8
from functools import wraps
9
- from typing import Optional , Union
9
+ from typing import Literal , Optional , Union
10
10
11
11
import numpy as np
12
12
from more_itertools import always_iterable
@@ -445,6 +445,7 @@ def _get_by_attribute(
445
445
attribute : str ,
446
446
value : Union [unyt_quantity , tuple [float , str ]],
447
447
tolerance : Union [None , unyt_quantity , tuple [float , str ]] = None ,
448
+ side : Union [Literal ["nearest" ], Literal ["left" ], Literal ["right" ]] = "nearest" ,
448
449
) -> "Dataset" :
449
450
r"""
450
451
Get a dataset at or near to a given value.
@@ -462,8 +463,14 @@ def _get_by_attribute(
462
463
within the tolerance value. If None, simply return the
463
464
nearest dataset.
464
465
Default: None.
466
+ side : str
467
+ The side of the value to return. Can be 'nearest', 'left' or 'right'.
468
+ Default: 'nearest'.
465
469
"""
466
470
471
+ if side not in ("nearest" , "left" , "right" ):
472
+ raise ValueError (f"side must be 'nearest', 'left' or 'right', got { side } ." )
473
+
467
474
# Use a binary search to find the closest value
468
475
iL = 0
469
476
iH = len (self ._pre_outputs ) - 1
@@ -518,7 +525,13 @@ def _get_by_attribute(
518
525
dsL = dsH = dsM
519
526
break
520
527
521
- if abs (value - getattr (dsL , attribute )) < abs (value - getattr (dsH , attribute )):
528
+ if side == "left" :
529
+ ds_best = dsL
530
+ elif side == "right" :
531
+ ds_best = dsH
532
+ elif abs (value - getattr (dsL , attribute )) < abs (
533
+ value - getattr (dsH , attribute )
534
+ ):
522
535
ds_best = dsL
523
536
else :
524
537
ds_best = dsH
@@ -534,6 +547,7 @@ def get_by_time(
534
547
self ,
535
548
time : Union [unyt_quantity , tuple ],
536
549
tolerance : Union [None , unyt_quantity , tuple ] = None ,
550
+ side : Union [Literal ["nearest" ], Literal ["left" ], Literal ["right" ]] = "nearest" ,
537
551
):
538
552
"""
539
553
Get a dataset at or near to a given time.
@@ -547,16 +561,26 @@ def get_by_time(
547
561
within the tolerance value. If None, simply return the
548
562
nearest dataset.
549
563
Default: None.
564
+ side : str
565
+ The side of the value to return. Can be 'nearest', 'left' or 'right'.
566
+ Default: 'nearest'.
550
567
551
568
Examples
552
569
--------
553
570
>>> ds = ts.get_by_time((12, "Gyr"))
554
571
>>> t = ts[0].quan(12, "Gyr")
555
572
... ds = ts.get_by_time(t, tolerance=(100, "Myr"))
556
573
"""
557
- return self ._get_by_attribute ("current_time" , time , tolerance = tolerance )
574
+ return self ._get_by_attribute (
575
+ "current_time" , time , tolerance = tolerance , side = side
576
+ )
558
577
559
- def get_by_redshift (self , redshift : float , tolerance : Optional [float ] = None ):
578
+ def get_by_redshift (
579
+ self ,
580
+ redshift : float ,
581
+ tolerance : Optional [float ] = None ,
582
+ side : Union [Literal ["nearest" ], Literal ["left" ], Literal ["right" ]] = "nearest" ,
583
+ ):
560
584
"""
561
585
Get a dataset at or near to a given time.
562
586
@@ -569,6 +593,9 @@ def get_by_redshift(self, redshift: float, tolerance: Optional[float] = None):
569
593
within the tolerance value. If None, simply return the
570
594
nearest dataset.
571
595
Default: None.
596
+ side : str
597
+ The side of the value to return. Can be 'nearest', 'left' or 'right'.
598
+ Default: 'nearest'.
572
599
573
600
Examples
574
601
--------
0 commit comments