Skip to content

Commit 86a263c

Browse files
committed
refactor: comments
1 parent 9daaf37 commit 86a263c

File tree

3 files changed

+6
-2
lines changed

3 files changed

+6
-2
lines changed

main.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -671,9 +671,11 @@ def _parse_result(a3m_string):
671671
if args.verbose:
672672
print(f"predict affinity ranking score with {model}:{ref_pkl}")
673673

674+
# model params
674675
setattr(args, "model_file", ref_pkl)
675676
setattr(args, "model_ckpt", os.path.join(args.ref_pkl, f"{model}_model.pth"))
676677

678+
# prepare variants
677679
feat = data.get_multimer(compose_pid(pid, "P"), chains)
678680
assert len(feat["str_var"]) == len(feat["variant_pid"])
679681
a3m_string = "\n".join(
@@ -685,12 +687,13 @@ def _parse_result(a3m_string):
685687
output_file_path = os.path.join(args.output_dir, f"{model}_{pid}.a3m")
686688
with open(output_file_path, "w") as output_file:
687689
setattr(args, "output_file", output_file)
688-
energy.main(args)
690+
energy.main(args) # calc the Elo-score
689691
with open(output_file_path, "r") as output_file:
690692
a3m_string = output_file.read()
691693
for pid, pred in _parse_result(a3m_string):
692694
pred_dict[pid].append(pred)
693695

696+
# write results to csv
694697
with open(os.path.join(args.output_dir, f"{model}_pred.csv"), "w") as f:
695698
writer = csv.DictWriter(f, fieldnames=["id", "chains"] + task.task_name_list)
696699
writer.writeheader()

predict.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ done
134134
python ${CWD}/main.py peptide_align \
135135
--output_dir ${output_dir}/a3m \
136136
--target_db ${output_dir}/tcr_pmhc_P.fa \
137+
--target_db ${CWD}/data/tcr_pmhc_db_P.fa \
137138
--verbose \
138139
${CWD}/data/tcr_pmhc_db/fasta/*_P.fasta \
139140

0 commit comments

Comments
 (0)