Skip to content

Commit d845483

Browse files
Merge branch 'fani-lab:main' into main
2 parents 55c8ab7 + 3c7ca1e commit d845483

File tree

1 file changed

+12
-25
lines changed

1 file changed

+12
-25
lines changed

src/cmn/team.py

Lines changed: 12 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
import os, scipy.sparse, pickle, numpy as np, logging
22
from collections import Counter
3-
from functools import partial
3+
from functools import partial, reduce
44
log = logging.getLogger(__name__)
55

66
import pkgmgr as opentf
@@ -476,36 +476,23 @@ def merge_teams_by_skills(cls, teamsvecs, inplace=False, distinct=False): #https
476476
log.info(f'Merging teams whose subset of skills are the same ...')
477477

478478
vecs = teamsvecs if inplace else copy.deepcopy(teamsvecs)
479-
merge_list = {}
479+
skills_rows_map = {} # {skill: [row indices]}
480480

481-
# in the following loop rows that have similar skills are founded
482-
for i in range(len(vecs['skill'].rows)):
483-
merge_list[f'{i}'] = set()
484-
for j in range(i + 1, len(vecs['skill'].rows)):
485-
if vecs['skill'].rows[i] == vecs['skill'].rows[j]: merge_list[f'{i}'].add(j)
486-
if len(merge_list[f'{i}']) < 1: del merge_list[f'{i}']
487-
488-
delete_set = set()
489-
for key in merge_list.keys():
490-
for item in merge_list[key]: delete_set.add(item)
491-
492-
for item in delete_set:
493-
try: del merge_list[f'{item}']
494-
except KeyError: pass
481+
for i in range(vecs['skill'].shape[0]):
482+
current_skills = tuple(vecs['skill'].rows[i])
483+
skills_rows_map.setdefault(current_skills, []).append(i)
495484

496485
del_list = []
497-
for key_ in merge_list.keys():
498-
for value_ in merge_list[key_]:
499-
del_list.append(value_)
500-
vec1 = vecs['member'].getrow(int(key_))
501-
vec2 = vecs['member'].getrow(value_)
502-
result = np.add(vec1, vec2)
503-
result[result != 0] = 1
504-
vecs['member'][int(key_), :] = scipy.sparse.lil_matrix(result)
505-
vecs['member'][int(value_), :] = scipy.sparse.lil_matrix(result)
486+
487+
for _, rows in skills_rows_map.items():
488+
if len(rows) < 2: continue
489+
del_list.extend(rows[1:])
490+
new_members = reduce(lambda x, y: x.maximum(y), (vecs['member'].getrow(r) for r in rows))
491+
for row in rows: vecs['member'][row, :] = new_members
506492
if distinct:
507493
vecs['skill'] = scipy.sparse.lil_matrix(np.delete(vecs['skill'].toarray(), del_list, axis=0))
508494
vecs['member'] = scipy.sparse.lil_matrix(np.delete(vecs['member'].toarray(), del_list, axis=0))
495+
509496
return vecs
510497

511498
@staticmethod

0 commit comments

Comments
 (0)