14
14
# You should have received a copy of the GNU General Public License along with
15
15
# satpy. If not, see <http://www.gnu.org/licenses/>.
16
16
"""Dataset identifying objects."""
17
+ from __future__ import annotations
17
18
18
19
import logging
19
20
import numbers
20
21
from copy import copy , deepcopy
21
22
from enum import Enum
23
+ from functools import partial
22
24
from typing import Any , NoReturn
23
25
24
26
import numpy as np
@@ -261,7 +263,7 @@ def __getitem__(self, key):
261
263
"""Get an item."""
262
264
return self ._dict [key ]
263
265
264
- def __eq__ (self , other ) :
266
+ def __eq__ (self , other : Any ) -> bool :
265
267
"""Compare the DataQuerys.
266
268
267
269
A DataQuery is considered equal to another DataQuery if all keys
@@ -271,6 +273,20 @@ def __eq__(self, other):
271
273
contains additional elements. Any DataQuery elements with the value
272
274
``"*"`` are ignored.
273
275
276
+ """
277
+ return self .equal (other , shared_keys = False )
278
+
279
+ def equal (self , other : Any , shared_keys : bool = False ) -> bool :
280
+ """Compare this DataQuery to another DataQuery or a DataID.
281
+
282
+ Args:
283
+ other: Other DataQuery or DataID to compare against.
284
+ shared_keys: Limit keys being compared to those shared
285
+ by both objects. If False (default), then all of the
286
+ current query's keys are used when compared against
287
+ a DataID. If compared against another DataQuery then
288
+ all keys are compared between the two queries.
289
+
274
290
"""
275
291
sdict = self ._asdict ()
276
292
try :
@@ -287,6 +303,9 @@ def __eq__(self, other):
287
303
if not o_is_id :
288
304
# if another DataQuery, then compare both sets of keys
289
305
keys_to_match |= set (odict .keys ())
306
+ if shared_keys :
307
+ # only compare with the keys that both objects share
308
+ keys_to_match &= set (odict .keys ())
290
309
if not keys_to_match :
291
310
return False
292
311
@@ -374,9 +393,10 @@ def __repr__(self):
374
393
items = ("{}={}" .format (key , repr (val )) for key , val in zip (self ._fields , self ._values ))
375
394
return self .__class__ .__name__ + "(" + ", " .join (items ) + ")"
376
395
377
- def filter_dataids (self , dataid_container ):
396
+ def filter_dataids (self , dataid_container , shared_keys : bool = False ):
378
397
"""Filter DataIDs based on this query."""
379
- keys = list (filter (self .__eq__ , dataid_container ))
398
+ func = partial (self .equal , shared_keys = shared_keys )
399
+ keys = list (filter (func , dataid_container ))
380
400
return keys
381
401
382
402
def sort_dataids_with_preference (self , all_ids , preference ):
0 commit comments