Skip to content

Commit 33415c6

Browse files
support matching of a sample against a group of selected samples only
1 parent 4e3ca41 commit 33415c6

File tree

11 files changed

+200
-8
lines changed

11 files changed

+200
-8
lines changed

mcrit/Worker.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from mcrit.matchers.MatcherQuery import MatcherQuery
3333
from mcrit.matchers.MatcherSample import MatcherSample
3434
from mcrit.matchers.MatcherVs import MatcherVs
35+
from mcrit.matchers.MatcherVsGroup import MatcherVsGroup
3536
from mcrit.minhash.MinHasher import MinHasher
3637
from mcrit.queue.LocalQueue import Job
3738
from mcrit.queue.QueueFactory import QueueFactory
@@ -499,6 +500,27 @@ def getMatchesForSampleVs(
499500
match_report = matcher.getMatchesForSample(sample_id, other_sample_id)
500501
return match_report
501502

503+
# Reports PROGRESS
@Remote(progress=True)
def getMatchesForSampleVsGroup(
    self,
    sample_id,
    other_sample_ids:List[int],
    minhash_threshold=None,
    pichash_size=None,
    band_matches_required=None,
    progress_reporter=NoProgressReporter()
):
    """Match one sample against an explicit group of other samples.

    A MatcherVsGroup is created with this worker and the given matching
    settings, and its report for `sample_id` versus `other_sample_ids`
    is returned unchanged.

    Args:
        sample_id: id of the sample whose functions are matched.
        other_sample_ids: ids of the samples forming the comparison group.
        minhash_threshold: forwarded to the matcher as-is.
        pichash_size: forwarded to the matcher as-is.
        band_matches_required: forwarded to the matcher as-is.
        progress_reporter: reporter object handed to the matcher.

    Returns:
        The match report produced by MatcherVsGroup.getMatchesForSample.
    """
    # Collect the tunables once so the constructor call stays compact.
    matcher_settings = {
        "minhash_threshold": minhash_threshold,
        "pichash_size": pichash_size,
        "band_matches_required": band_matches_required,
        "progress_reporter": progress_reporter,
    }
    group_matcher = MatcherVsGroup(self, **matcher_settings)
    return group_matcher.getMatchesForSample(sample_id, other_sample_ids)
523+
502524

503525
@Remote()
504526
def combineMatchesToCross(self, sample_to_job_id):

mcrit/client/McritClient.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def setUsername(self, username):
7171
self.headers.update({"username": username})
7272

7373
def _getMatchingRequestParams(
74-
self, minhash_threshold=None, pichash_size=None, force_recalculation=None, band_matches_required=None, exclude_self_matches=False
74+
self, minhash_threshold=None, pichash_size=None, force_recalculation=None, band_matches_required=None, exclude_self_matches=False, sample_group_only=False
7575
):
7676
params = {}
7777
if minhash_threshold is not None:
@@ -84,6 +84,8 @@ def _getMatchingRequestParams(
8484
params["band_matches_required"] = band_matches_required
8585
if exclude_self_matches:
8686
params["exclude_self_matches"] = True
87+
if sample_group_only:
88+
params["sample_group_only"] = True
8789
return params
8890

8991
def respawn(self):
@@ -471,13 +473,14 @@ def requestMatchesForSampleVs(
471473
def requestMatchesCross(
472474
self,
473475
sample_ids,
476+
sample_group_only=False,
474477
minhash_threshold=None,
475478
pichash_size=None,
476479
band_matches_required=None,
477480
force_recalculation=False,
478481
) -> None:
479482
params = self._getMatchingRequestParams(
480-
minhash_threshold, pichash_size, force_recalculation, band_matches_required
483+
minhash_threshold, pichash_size, force_recalculation, band_matches_required, sample_group_only=sample_group_only
481484
)
482485
response = requests.get(
483486
f"{self.mcrit_server}/matches/sample/cross/{','.join([str(id) for id in sample_ids])}",

mcrit/index/MinHashIndex.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,7 @@ def getMatchesForMappedBinary(self, binary, base_address, minhash_threshold=None
291291
def getMatchesForUnmappedBinary(self, binary, minhash_threshold=None):
292292
def getMatchesForSample(self, sample_id, minhash_threshold=None):
293293
def getMatchesForSampleVs(self, sample_id, other_sample_id, minhash_threshold=None):
294+
def getMatchesForSampleVsGroup(self, sample_id, other_sample_ids, minhash_threshold=None):
294295
def getAggregatedMatchesForSample(self, sample_id, minhash_threshold=None):
295296
def getUniqueBlocks(self, sample_ids):
296297
def addBinarySample(self, binary, is_dump, bitness, base_address,):
@@ -333,11 +334,14 @@ def addReportFile(self, report_filepath, calculate_hashes=True, calculate_matche
333334
report = SmdaReport.fromDict(report_json)
334335
return self.addReport(report, calculate_hashes=calculate_hashes, calculate_matches=calculate_matches)
335336

336-
def getMatchesCross(self, sample_ids:List[int], force_recalculation=False, **params):
337+
def getMatchesCross(self, sample_ids:List[int], sample_group_only=False, force_recalculation=False, **params):
337338
storage = self.getStorage()
338339
sample_to_job_id = {}
339340
for id in sample_ids:
340-
job_id = self.getMatchesForSample(id, force_recalculation=force_recalculation, **params)
341+
if sample_group_only:
342+
job_id = self.getMatchesForSampleVsGroup(id, [sid for sid in sample_ids if sid != id], force_recalculation=force_recalculation, **params)
343+
else:
344+
job_id = self.getMatchesForSample(id, force_recalculation=force_recalculation, **params)
341345
sample_to_job_id[id] = job_id
342346
return self.combineMatchesToCross(sample_to_job_id, await_jobs=[*sample_to_job_id.values()], force_recalculation=force_recalculation)
343347

mcrit/matchers/MatcherInterface.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,7 @@ def _craftResultDict(
581581
"info": {
582582
"job": None,
583583
"sample": self._sample_info,
584+
"type": "",
584585
},
585586
"matches": {
586587
"aggregation": {

mcrit/matchers/MatcherQuery.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ def getMatchesForSmdaReport(self, smda_report: "SmdaReport"):
3232
tmp_function_entries_dict[minhash.function_id].minhash_shingle_composition = minhash.getComposition()
3333
self._function_entries = list(tmp_function_entries_dict.values())
3434

35-
return self._getMatchesRoutine()
35+
match_report = self._getMatchesRoutine()
36+
match_report["info"]["type"] = "matcher_query"
37+
return match_report
3638

3739
def _getPicHashMatches(self) -> Dict[int, Set[Tuple[int, int, int]]]:
3840
pichash_matches = {}

mcrit/matchers/MatcherQueryFunction.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@ def getMatchesForSmdaFunction(self, smda_report: "SmdaReport"):
3434
minhash = self._worker.minhasher._calculateMinHash(smda_function)
3535
function_entry = FunctionEntry(self._sample_entry, smda_function, -1, minhash)
3636
self._function_entries = [function_entry]
37-
return self._getMatchesRoutine()
37+
match_report = self._getMatchesRoutine()
38+
match_report["info"]["type"] = "matcher_query_function"
39+
return match_report
3840

3941
def _getPicHashMatches(self) -> Dict[int, Set[Tuple[int, int, int]]]:
4042
pichash_matches = {}

mcrit/matchers/MatcherSample.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@ def getMatchesForSample(self, sample_id: int):
1313
self._sample_id = sample_id
1414
sample_entry = self._storage.getSampleById(sample_id)
1515
self._sample_info = sample_entry.toDict()
16-
17-
return self._getMatchesRoutine()
16+
match_report = self._getMatchesRoutine()
17+
match_report["info"]["type"] = "matcher_sample"
18+
return match_report
1819

1920
def _getPicHashMatches(self) -> Dict[int, Set[Tuple[int, int, int]]]:
2021
return self._storage.getPicHashMatchesBySampleId(self._sample_id)

mcrit/matchers/MatcherVs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def getMatchesForSample(self, sample_id:int, other_sample_id:int):
2222
other_sample_entry = self._storage.getSampleById(other_sample_id)
2323
other_sample_info = other_sample_entry.toDict()
2424
matching_report["other_sample_info"] = other_sample_info
25+
matching_report["info"]["type"] = "matcher_vs"
2526
return matching_report
2627

2728
def _getPicHashMatches(self) -> Dict[int, Set[Tuple[int, int, int]]]:

mcrit/matchers/MatcherVsGroup.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import logging
2+
from typing import Dict, Set, Tuple, List
3+
4+
from mcrit.matchers.MatcherInterface import MatcherInterface, add_duration
5+
6+
7+
# Configure root logging only when no handler has been registered yet,
# so an application-level logging setup is never overridden.
if not logging._handlerList:
    logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
# Module-level logger named after this module.
LOGGER = logging.getLogger(__name__)
11+
12+
13+
class MatcherVsGroup(MatcherInterface):
    """Matcher to compare functions from one sample against functions from a group of other samples.

    The generic matching routine of the base class is reused; this class
    only restricts PicHash matches and MinHash candidates to functions
    that belong to the queried sample or to the selected group.
    """

    def _additional_setup(self):
        """Initialise the per-instance state used by this matcher."""
        self._other_function_entries = []
        self._sample_to_lib_info = {}
        self._sample_id_to_entry = {}
        self._sample_id = None

    @add_duration
    def getMatchesForSample(self, sample_id:int, other_sample_ids:List[int]):
        """Match `sample_id` against the samples listed in `other_sample_ids`.

        Args:
            sample_id: id of the sample whose functions are matched.
            other_sample_ids: ids of the samples forming the comparison
                group. Repeated ids are considered only once.

        Returns:
            The report produced by the base class matching routine, with
            `info.type` set to "matcher_vs_group" and an additional
            `other_sample_infos` list holding one sample dict per distinct
            group member, in first-occurrence order.
        """
        # Drop repeated ids (order preserved) so no function entry, PicHash
        # match tuple or sample info is produced twice for the same sample.
        unique_other_ids = list(dict.fromkeys(other_sample_ids))

        self._function_entries = self._storage.getFunctionsBySampleId(sample_id)
        # Start from a fresh list: extending the list created in
        # _additional_setup would leak entries from an earlier call when
        # the same matcher instance is used more than once.
        self._other_function_entries = []
        for other_sample_id in unique_other_ids:
            self._other_function_entries.extend(self._storage.getFunctionsBySampleId(other_sample_id))
        self._sample_id = sample_id
        sample_entry = self._storage.getSampleById(sample_id)
        self._sample_info = sample_entry.toDict()

        LOGGER.info("Performing matching of sample %d against %d other samples, with %d functions total.", sample_id, len(unique_other_ids), len(self._function_entries) + len(self._other_function_entries))

        matching_report = self._getMatchesRoutine()

        matching_report["info"]["type"] = "matcher_vs_group"
        matching_report["other_sample_infos"] = [
            self._storage.getSampleById(other_sample_id).toDict()
            for other_sample_id in unique_other_ids
        ]
        return matching_report

    def _getPicHashMatches(self) -> Dict[int, Set[Tuple[int, int, int]]]:
        """Group functions of the sample and of the group by PicHash.

        Only PicHashes that occur in the queried sample become keys;
        functions from the group are appended to an existing key and never
        create a new one.

        NOTE: the values built here are lists of
        (family_id, sample_id, function_id) tuples, although the inherited
        annotation says Set.
        """
        by_pichash = {}
        for function_entry in self._function_entries:
            by_pichash.setdefault(function_entry.pichash, []).append(
                (function_entry.family_id, function_entry.sample_id, function_entry.function_id)
            )
        for function_entry in self._other_function_entries:
            pic_entry = by_pichash.get(function_entry.pichash)
            if pic_entry is not None:
                pic_entry.append(
                    (function_entry.family_id, function_entry.sample_id, function_entry.function_id)
                )
        return by_pichash

    def _createMinHashCandidateGroups(self, start=0, end=None) -> Dict[int, Set[int]]:
        """Return the base class candidate groups, limited to sample and group functions."""
        # find candidates based on bands
        candidate_groups = super()._createMinHashCandidateGroups(start, end)

        allowed_function_ids = {entry.function_id for entry in self._other_function_entries}
        # NOTE Also include function ids of entry a to allow self-matches
        allowed_function_ids.update(entry.function_id for entry in self._function_entries)
        # Only values of existing keys are replaced, so the dict size is
        # unchanged and iterating while assigning is safe.
        for fid, candidates in candidate_groups.items():
            candidate_groups[fid] = candidates.intersection(allowed_function_ids)
        return candidate_groups

mcrit/server/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ def getMatchingParams(req_params):
2929
if key == "force_recalculation":
3030
if value.lower() == "true":
3131
parameters["force_recalculation"] = True
32+
if key == "sample_group_only":
33+
if value.lower() == "true":
34+
parameters["sample_group_only"] = True
3235
if key == "band_matches_required":
3336
band_matches_required = int(value)
3437
band_matches_required = max(0, band_matches_required)

0 commit comments

Comments
 (0)