Skip to content

Commit 7d98494

Browse files
committed
feat: write ranking affinity score to csv file
1 parent e8d8059 commit 7d98494

File tree

1 file changed

+42
-3
lines changed

1 file changed

+42
-3
lines changed

main.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import click
1515
import sqlitedict
16+
import numpy as np
1617

1718
from profold2.data.dataset import ProteinStructureDataset
1819
from profold2.data.parsers import parse_fasta, parse_a3m
@@ -615,6 +616,7 @@ def attr_update_weight_and_task(**args):
615616
)
616617
@click.option("--mask", type=str, default="-", hidden=True)
617618
@click.option("--task_def", type=str, default=json.dumps(task.make_def()), hidden=True)
619+
@click.option("--task_pid_prefix", type=str, default="tcr_pmhc_", hidden=True)
618620
@click.option(
619621
"--chunksize",
620622
type=int,
@@ -642,7 +644,23 @@ def predict(**args):
642644
max_var_depth=None
643645
)
644646

647+
def _parse_result(a3m_string):
    """Yield ``(pid, pred)`` for every task record found in an A3M result.

    The descriptions returned by ``parse_fasta`` are split on tab
    characters: the first column is taken as the record id and the
    remaining columns as ``key:value`` annotations. Only ids that start
    with ``args.task_pid_prefix`` are reported.

    ``pred`` is the JSON-decoded value of the first ``Elo_score``
    annotation; when a record carries no such annotation it is a list of
    ``task.task_num`` ``None`` placeholders.
    """
    score_key = "Elo_score"
    _, descriptions = parse_fasta(a3m_string)
    for description in descriptions:
        pid, *annotations = description.split("\t")
        if not pid.startswith(args.task_pid_prefix):
            # Not a record produced for this task; nothing to report.
            continue
        pred = [None] * task.task_num
        for annotation in annotations:
            # Split at the first ":" only; the value may itself contain ":".
            key, sep, value = annotation.partition(":")
            if sep and key == score_key:
                pred = json.loads(value)
                break
        yield pid, pred
660+
645661
for model in args.model:
662+
pred_dict = defaultdict(list)
663+
646664
for ref_pkl in glob.glob(os.path.join(args.ref_pkl, f"{model}_*.pkl")):
647665
pdb_id = os.path.basename(ref_pkl)
648666
assert pdb_id.startswith(f"{model}_")
@@ -663,11 +681,32 @@ def predict(**args):
663681
)
664682
with io.StringIO(a3m_string) as a3m_file:
665683
setattr(args, "a3m_file", [a3m_file])
666-
with open(
667-
os.path.join(args.output_dir, f"{model}_{pid}.a3m"), "w"
668-
) as output_file:
684+
685+
output_file_path = os.path.join(args.output_dir, f"{model}_{pid}.a3m")
686+
with open(output_file_path, "w") as output_file:
669687
setattr(args, "output_file", output_file)
670688
energy.main(args)
689+
with open(output_file_path, "r") as output_file:
690+
a3m_string = output_file.read()
691+
for pid, pred in _parse_result(a3m_string):
692+
pred_dict[pid].append(pred)
693+
694+
with open(os.path.join(args.output_dir, f"{model}_pred.csv"), "w") as f:
695+
writer = csv.DictWriter(f, fieldnames=["id", "chains"] + task.task_name_list)
696+
writer.writeheader()
697+
for pid, pred_list in pred_dict.items():
698+
chain_list, *_ = data.chain_list[pid] # FIX: data.get_chain_list(protein_id)
699+
assert chain_list, (pid, pid in data.chain_list, len(data.chain_list))
700+
_, pred_mask = task.make_label(0, chain_list)
701+
702+
pred_list, pred_mask = np.asarray(pred_list), np.asarray(pred_mask)
703+
pred_list = np.sum(pred_list * pred_mask[None], axis=0) / pred_list.shape[0]
704+
705+
row = {"id": pid, "chains": "_".join(chain_list)}
706+
for idx, (pred, mask) in enumerate(zip(pred_list, pred_mask)):
707+
if mask:
708+
row[task.task_name_list[idx]] = pred
709+
writer.writerow(row)
671710

672711

673712
if __name__ == "__main__":

0 commit comments

Comments
 (0)