|
| 1 | +import logging |
| 2 | +from typing import Dict, Set, Tuple, List |
| 3 | + |
| 4 | +from mcrit.matchers.MatcherInterface import MatcherInterface, add_duration |
| 5 | + |
| 6 | + |
| 7 | +# Only do basicConfig if no handlers have been configured |
| 8 | +if len(logging._handlerList) == 0: |
| 9 | + logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s") |
| 10 | +LOGGER = logging.getLogger(__name__) |
| 11 | + |
| 12 | + |
| 13 | +class MatcherVsGroup(MatcherInterface): |
| 14 | + """ Matcher to compare functions from one sample against functions from a group of other samples.""" |
| 15 | + |
| 16 | + def _additional_setup(self): |
| 17 | + self._other_function_entries = [] |
| 18 | + self._sample_to_lib_info = {} |
| 19 | + self._sample_id_to_entry = {} |
| 20 | + self._sample_id = None |
| 21 | + |
| 22 | + @add_duration |
| 23 | + def getMatchesForSample(self, sample_id:int, other_sample_ids:List[int]): |
| 24 | + self._function_entries = self._storage.getFunctionsBySampleId(sample_id) |
| 25 | + for other_sample_id in other_sample_ids: |
| 26 | + self._other_function_entries.extend(self._storage.getFunctionsBySampleId(other_sample_id)) |
| 27 | + self._sample_id = sample_id |
| 28 | + sample_entry = self._storage.getSampleById(sample_id) |
| 29 | + self._sample_info = sample_entry.toDict() |
| 30 | + |
| 31 | + LOGGER.info("Performing matching of sample %d against %d other samples, with %d functions total.", sample_id, len(other_sample_ids), len(self._function_entries) + len(self._other_function_entries)) |
| 32 | + |
| 33 | + matching_report = self._getMatchesRoutine() |
| 34 | + |
| 35 | + matching_report["info"]["type"] = "matcher_vs_group" |
| 36 | + matching_report["other_sample_infos"] = [] |
| 37 | + for other_sample_id in other_sample_ids: |
| 38 | + other_sample_entry = self._storage.getSampleById(other_sample_id) |
| 39 | + other_sample_info = other_sample_entry.toDict() |
| 40 | + matching_report["other_sample_infos"].append(other_sample_info) |
| 41 | + return matching_report |
| 42 | + |
| 43 | + def _getPicHashMatches(self) -> Dict[int, Set[Tuple[int, int, int]]]: |
| 44 | + by_pichash = {} |
| 45 | + for function_entry in self._function_entries: |
| 46 | + pic_entry = by_pichash.get(function_entry.pichash, []) |
| 47 | + pic_entry.append((function_entry.family_id, function_entry.sample_id, function_entry.function_id)) |
| 48 | + by_pichash[function_entry.pichash] = pic_entry |
| 49 | + for function_entry in self._other_function_entries: |
| 50 | + if function_entry.pichash in by_pichash: |
| 51 | + pic_entry = by_pichash.get(function_entry.pichash, None) |
| 52 | + pic_entry.append((function_entry.family_id, function_entry.sample_id, function_entry.function_id)) |
| 53 | + by_pichash[function_entry.pichash] = pic_entry |
| 54 | + return by_pichash |
| 55 | + |
| 56 | + def _createMinHashCandidateGroups(self, start=0, end=None) -> Dict[int, Set[int]]: |
| 57 | + # find candidates based on bands |
| 58 | + candidate_groups = super()._createMinHashCandidateGroups(start, end) |
| 59 | + |
| 60 | + allowed_function_ids = set([entry.function_id for entry in self._other_function_entries]) |
| 61 | + # NOTE Also include function ids of entry a to allow self-matches |
| 62 | + allowed_function_ids.update([entry.function_id for entry in self._function_entries]) |
| 63 | + for fid, candidates in candidate_groups.items(): |
| 64 | + candidate_groups[fid] = candidates.intersection(allowed_function_ids) |
| 65 | + return candidate_groups |
0 commit comments