Skip to content

Commit 1176303

Browse files
committed
Refactor DataQuery equality checks
1 parent 57b36a9 commit 1176303

File tree

1 file changed

+49
-45
lines changed

1 file changed

+49
-45
lines changed

satpy/dataset/dataid.py

Lines changed: 49 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -299,58 +299,15 @@ def equal(self, other: Any, shared_keys: bool = False) -> bool:
299299

300300
# if other is a DataID then must match this query exactly
301301
o_is_id = hasattr(other, "id_keys")
302-
keys_to_match = self._keys_to_compare(sdict, odict, o_is_id, shared_keys)
302+
keys_to_match = _keys_to_compare(sdict, odict, o_is_id, shared_keys)
303303
if not keys_to_match:
304304
return False
305305

306306
for key in keys_to_match:
307-
if not self._compare_key_equality(sdict, odict, key, o_is_id):
307+
if not _compare_key_equality(sdict, odict, key, o_is_id):
308308
return False
309309
return True
310310

311-
@staticmethod
312-
def _keys_to_compare(sdict: dict, odict: dict, o_is_id: bool, shared_keys: bool) -> set:
313-
keys_to_match = set(sdict.keys())
314-
if not o_is_id:
315-
# if another DataQuery, then compare both sets of keys
316-
keys_to_match |= set(odict.keys())
317-
if shared_keys:
318-
# only compare with the keys that both objects share
319-
keys_to_match &= set(odict.keys())
320-
return keys_to_match
321-
322-
@staticmethod
323-
def _compare_key_equality(sdict: dict, odict: dict, key: str, o_is_id: bool) -> bool:
324-
if key not in sdict:
325-
return False
326-
sval = sdict[key]
327-
if sval == "*":
328-
return True
329-
330-
if key not in odict:
331-
return False
332-
oval = odict[key]
333-
if oval == "*":
334-
# Gotcha: if a DataID contains a "*" this could cause
335-
# unexpected matches. A DataID is not expected to use "*"
336-
return True
337-
338-
if isinstance(sval, list) or isinstance(oval, list):
339-
# multiple options to match
340-
if not isinstance(sval, list):
341-
# query to query comparison, make a list to iterate over
342-
sval = [sval]
343-
if o_is_id:
344-
return oval in sval
345-
346-
# we're matching against a DataQuery who could have its own list
347-
if not isinstance(oval, list):
348-
oval = [oval]
349-
s_in_o = any(_sval in oval for _sval in sval)
350-
o_in_s = any(_oval in sval for _oval in oval)
351-
return s_in_o or o_in_s
352-
return oval == sval
353-
354311
def __hash__(self):
355312
"""Hash."""
356313
fields = []
@@ -559,3 +516,50 @@ def update_id_with_query(orig_id: DataID, query: DataQuery) -> DataID:
559516
id_keys = orig_id_keys if all(key in orig_id_keys for key in new_id_dict) else default_id_keys_config
560517
new_id = DataID(id_keys, **new_id_dict)
561518
return new_id
519+
520+
521+
def _keys_to_compare(sdict: dict, odict: dict, o_is_id: bool, shared_keys: bool) -> set:
522+
keys_to_match = set(sdict.keys())
523+
if not o_is_id:
524+
# if another DataQuery, then compare both sets of keys
525+
keys_to_match |= set(odict.keys())
526+
if shared_keys:
527+
# only compare with the keys that both objects share
528+
keys_to_match &= set(odict.keys())
529+
return keys_to_match
530+
531+
532+
def _compare_key_equality(sdict: dict, odict: dict, key: str, o_is_id: bool) -> bool:
533+
if key not in sdict:
534+
return False
535+
sval = sdict[key]
536+
if sval == "*":
537+
return True
538+
539+
if key not in odict:
540+
return False
541+
oval = odict[key]
542+
if oval == "*":
543+
# Gotcha: if a DataID contains a "*" this could cause
544+
# unexpected matches. A DataID is not expected to use "*"
545+
return True
546+
547+
return _compare_values(sval, oval, o_is_id)
548+
549+
550+
def _compare_values(sval: Any, oval: Any, o_is_id: bool) -> bool:
551+
if isinstance(sval, list) or isinstance(oval, list):
552+
# multiple options to match
553+
if not isinstance(sval, list):
554+
# query to query comparison, make a list to iterate over
555+
sval = [sval]
556+
if o_is_id:
557+
return oval in sval
558+
559+
# we're matching against a DataQuery who could have its own list
560+
if not isinstance(oval, list):
561+
oval = [oval]
562+
s_in_o = any(_sval in oval for _sval in sval)
563+
o_in_s = any(_oval in sval for _oval in oval)
564+
return s_in_o or o_in_s
565+
return oval == sval

0 commit comments

Comments
 (0)