|
1 | 1 | import os, scipy.sparse, pickle, numpy as np, logging
|
2 | 2 | from collections import Counter
|
3 |
| -from functools import partial |
| 3 | +from functools import partial, reduce |
4 | 4 | log = logging.getLogger(__name__)
|
5 | 5 |
|
6 | 6 | import pkgmgr as opentf
|
@@ -476,36 +476,23 @@ def merge_teams_by_skills(cls, teamsvecs, inplace=False, distinct=False): #https
|
476 | 476 | log.info(f'Merging teams whose subset of skills are the same ...')
|
477 | 477 |
|
478 | 478 | vecs = teamsvecs if inplace else copy.deepcopy(teamsvecs)
|
479 |
| - merge_list = {} |
| 479 | + skills_rows_map = {} # {skill: [row indices]} |
480 | 480 |
|
481 |
| - # in the following loop rows that have similar skills are founded |
482 |
| - for i in range(len(vecs['skill'].rows)): |
483 |
| - merge_list[f'{i}'] = set() |
484 |
| - for j in range(i + 1, len(vecs['skill'].rows)): |
485 |
| - if vecs['skill'].rows[i] == vecs['skill'].rows[j]: merge_list[f'{i}'].add(j) |
486 |
| - if len(merge_list[f'{i}']) < 1: del merge_list[f'{i}'] |
487 |
| - |
488 |
| - delete_set = set() |
489 |
| - for key in merge_list.keys(): |
490 |
| - for item in merge_list[key]: delete_set.add(item) |
491 |
| - |
492 |
| - for item in delete_set: |
493 |
| - try: del merge_list[f'{item}'] |
494 |
| - except KeyError: pass |
| 481 | + for i in range(vecs['skill'].shape[0]): |
| 482 | + current_skills = tuple(vecs['skill'].rows[i]) |
| 483 | + skills_rows_map.setdefault(current_skills, []).append(i) |
495 | 484 |
|
496 | 485 | del_list = []
|
497 |
| - for key_ in merge_list.keys(): |
498 |
| - for value_ in merge_list[key_]: |
499 |
| - del_list.append(value_) |
500 |
| - vec1 = vecs['member'].getrow(int(key_)) |
501 |
| - vec2 = vecs['member'].getrow(value_) |
502 |
| - result = np.add(vec1, vec2) |
503 |
| - result[result != 0] = 1 |
504 |
| - vecs['member'][int(key_), :] = scipy.sparse.lil_matrix(result) |
505 |
| - vecs['member'][int(value_), :] = scipy.sparse.lil_matrix(result) |
| 486 | + |
| 487 | + for _, rows in skills_rows_map.items(): |
| 488 | + if len(rows) < 2: continue |
| 489 | + del_list.extend(rows[1:]) |
| 490 | + new_members = reduce(lambda x, y: x.maximum(y), (vecs['member'].getrow(r) for r in rows)) |
| 491 | + for row in rows: vecs['member'][row, :] = new_members |
506 | 492 | if distinct:
|
507 | 493 | vecs['skill'] = scipy.sparse.lil_matrix(np.delete(vecs['skill'].toarray(), del_list, axis=0))
|
508 | 494 | vecs['member'] = scipy.sparse.lil_matrix(np.delete(vecs['member'].toarray(), del_list, axis=0))
|
| 495 | + |
509 | 496 | return vecs
|
510 | 497 |
|
511 | 498 | @staticmethod
|
|
0 commit comments