6
6
import weakref
7
7
from abc import ABC , abstractmethod
8
8
from functools import wraps
9
- from typing import Optional , Union
9
+ from typing import Literal , Optional , Union
10
10
11
11
import numpy as np
12
12
from more_itertools import always_iterable
@@ -445,6 +445,9 @@ def _get_by_attribute(
445
445
attribute : str ,
446
446
value : Union [unyt_quantity , tuple [float , str ]],
447
447
tolerance : Union [None , unyt_quantity , tuple [float , str ]] = None ,
448
+ side : Union [
449
+ Literal ["nearest" ], Literal ["smaller" ], Literal ["larger" ]
450
+ ] = "nearest" ,
448
451
) -> "Dataset" :
449
452
r"""
450
453
Get a dataset at or near to a given value.
@@ -462,8 +465,16 @@ def _get_by_attribute(
462
465
within the tolerance value. If None, simply return the
463
466
nearest dataset.
464
467
Default: None.
468
+ side : str
469
+ The side of the value to return. Can be 'nearest', 'smaller' or 'larger'.
470
+ Default: 'nearest'.
465
471
"""
466
472
473
+ if side not in ("nearest" , "smaller" , "larger" ):
474
+ raise ValueError (
475
+ f"side must be 'nearest', 'smaller' or 'larger', got { side } "
476
+ )
477
+
467
478
# Use a binary search to find the closest value
468
479
iL = 0
469
480
iH = len (self ._pre_outputs ) - 1
@@ -518,7 +529,13 @@ def _get_by_attribute(
518
529
dsL = dsH = dsM
519
530
break
520
531
521
- if abs (value - getattr (dsL , attribute )) < abs (value - getattr (dsH , attribute )):
532
+ if side == "smaller" :
533
+ ds_best = dsL if sign > 0 else dsH
534
+ elif side == "larger" :
535
+ ds_best = dsH if sign > 0 else dsL
536
+ elif abs (value - getattr (dsL , attribute )) < abs (
537
+ value - getattr (dsH , attribute )
538
+ ):
522
539
ds_best = dsL
523
540
else :
524
541
ds_best = dsH
@@ -534,6 +551,9 @@ def get_by_time(
534
551
self ,
535
552
time : Union [unyt_quantity , tuple ],
536
553
tolerance : Union [None , unyt_quantity , tuple ] = None ,
554
+ side : Union [
555
+ Literal ["nearest" ], Literal ["smaller" ], Literal ["larger" ]
556
+ ] = "nearest" ,
537
557
):
538
558
"""
539
559
Get a dataset at or near to a given time.
@@ -547,16 +567,28 @@ def get_by_time(
547
567
within the tolerance value. If None, simply return the
548
568
nearest dataset.
549
569
Default: None.
570
+ side : str
571
+ The side of the value to return. Can be 'nearest', 'smaller' or 'larger'.
572
+ Default: 'nearest'.
550
573
551
574
Examples
552
575
--------
553
576
>>> ds = ts.get_by_time((12, "Gyr"))
554
577
>>> t = ts[0].quan(12, "Gyr")
555
578
... ds = ts.get_by_time(t, tolerance=(100, "Myr"))
556
579
"""
557
- return self ._get_by_attribute ("current_time" , time , tolerance = tolerance )
580
+ return self ._get_by_attribute (
581
+ "current_time" , time , tolerance = tolerance , side = side
582
+ )
558
583
559
- def get_by_redshift (self , redshift : float , tolerance : Optional [float ] = None ):
584
+ def get_by_redshift (
585
+ self ,
586
+ redshift : float ,
587
+ tolerance : Optional [float ] = None ,
588
+ side : Union [
589
+ Literal ["nearest" ], Literal ["smaller" ], Literal ["larger" ]
590
+ ] = "nearest" ,
591
+ ):
560
592
"""
561
593
Get a dataset at or near to a given time.
562
594
@@ -569,6 +601,9 @@ def get_by_redshift(self, redshift: float, tolerance: Optional[float] = None):
569
601
within the tolerance value. If None, simply return the
570
602
nearest dataset.
571
603
Default: None.
604
+ side : str
605
+ The side of the value to return. Can be 'nearest', 'smaller' or 'larger'.
606
+ Default: 'nearest'.
572
607
573
608
Examples
574
609
--------
0 commit comments