Skip to content

Commit fbb8716

Browse files
committed
refactor: comments
1 parent e8d8059 commit fbb8716

File tree

3 files changed

+48
-4
lines changed

3 files changed

+48
-4
lines changed

main.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import click
1515
import sqlitedict
16+
import numpy as np
1617

1718
from profold2.data.dataset import ProteinStructureDataset
1819
from profold2.data.parsers import parse_fasta, parse_a3m
@@ -615,6 +616,7 @@ def attr_update_weight_and_task(**args):
615616
)
616617
@click.option("--mask", type=str, default="-", hidden=True)
617618
@click.option("--task_def", type=str, default=json.dumps(task.make_def()), hidden=True)
619+
@click.option("--task_pid_prefix", type=str, default="tcr_pmhc_", hidden=True)
618620
@click.option(
619621
"--chunksize",
620622
type=int,
@@ -642,7 +644,23 @@ def predict(**args):
642644
max_var_depth=None
643645
)
644646

647+
def _parse_result(a3m_string):
    """Yield ``(pid, pred)`` for each task record found in an a3m/FASTA string.

    Only the description lines returned by ``parse_fasta`` are inspected.
    Each description is split on tab characters: the first column is taken
    as the record id and the remaining columns are scanned, in order, for
    the first ``key:value`` entry whose key is ``Elo_score``. The value of
    that entry is decoded with ``json.loads`` and used as the prediction.

    Records whose id does not start with ``args.task_pid_prefix`` are
    skipped. When a record carries no ``Elo_score`` column, the prediction
    is a list of ``task.task_num`` ``None`` placeholders.
    """
    # NOTE(review): block nesting was reconstructed from a diff view that had
    # lost its indentation; the ``yield`` is assumed to belong to the prefix
    # check, since ``pred`` is only bound there -- confirm against main.py.
    _, descriptions = parse_fasta(a3m_string)
    for description in descriptions:
        pid, *columns = description.split("\t")
        if not pid.startswith(args.task_pid_prefix):
            continue

        # Placeholder kept when no Elo_score column is present.
        pred = [None] * task.task_num
        for column in columns:
            key, sep, value = column.partition(":")
            # ``sep`` is empty when the column holds no ":" at all.
            if sep and key == "Elo_score":
                pred = json.loads(value)
                break
        yield pid, pred
660+
645661
for model in args.model:
662+
pred_dict = defaultdict(list)
663+
646664
for ref_pkl in glob.glob(os.path.join(args.ref_pkl, f"{model}_*.pkl")):
647665
pdb_id = os.path.basename(ref_pkl)
648666
assert pdb_id.startswith(f"{model}_")
@@ -653,21 +671,45 @@ def predict(**args):
653671
if args.verbose:
654672
print(f"predict affinity ranking score with {model}:{ref_pkl}")
655673

674+
# model params
656675
setattr(args, "model_file", ref_pkl)
657676
setattr(args, "model_ckpt", os.path.join(args.ref_pkl, f"{model}_model.pth"))
658677

678+
# prepare variants
659679
feat = data.get_multimer(compose_pid(pid, "P"), chains)
660680
assert len(feat["str_var"]) == len(feat["variant_pid"])
661681
a3m_string = "\n".join(
662682
f">{pid}\n{var}" for pid, var in zip(feat["variant_pid"], feat["str_var"])
663683
)
664684
with io.StringIO(a3m_string) as a3m_file:
665685
setattr(args, "a3m_file", [a3m_file])
666-
with open(
667-
os.path.join(args.output_dir, f"{model}_{pid}.a3m"), "w"
668-
) as output_file:
686+
687+
output_file_path = os.path.join(args.output_dir, f"{model}_{pid}.a3m")
688+
with open(output_file_path, "w") as output_file:
669689
setattr(args, "output_file", output_file)
670-
energy.main(args)
690+
energy.main(args) # calc the Elo-score
691+
with open(output_file_path, "r") as output_file:
692+
a3m_string = output_file.read()
693+
for pid, pred in _parse_result(a3m_string):
694+
pred_dict[pid].append(pred)
695+
696+
# write results to csv
697+
with open(os.path.join(args.output_dir, f"{model}_pred.csv"), "w") as f:
698+
writer = csv.DictWriter(f, fieldnames=["id", "chains"] + task.task_name_list)
699+
writer.writeheader()
700+
for pid, pred_list in pred_dict.items():
701+
chain_list, *_ = data.chain_list[pid] # FIX: data.get_chain_list(protein_id)
702+
assert chain_list, (pid, pid in data.chain_list, len(data.chain_list))
703+
_, pred_mask = task.make_label(0, chain_list)
704+
705+
pred_list, pred_mask = np.asarray(pred_list), np.asarray(pred_mask)
706+
pred_list = np.sum(pred_list * pred_mask[None], axis=0) / pred_list.shape[0]
707+
708+
row = {"id": pid, "chains": "_".join(chain_list)}
709+
for idx, (pred, mask) in enumerate(zip(pred_list, pred_mask)):
710+
if mask:
711+
row[task.task_name_list[idx]] = pred
712+
writer.writerow(row)
671713

672714

673715
if __name__ == "__main__":

predict.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ done
134134
python ${CWD}/main.py peptide_align \
135135
--output_dir ${output_dir}/a3m \
136136
--target_db ${output_dir}/tcr_pmhc_P.fa \
137+
--target_db ${CWD}/data/tcr_pmhc_db_P.fa \
137138
--verbose \
138139
${CWD}/data/tcr_pmhc_db/fasta/*_P.fasta \
139140

task.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
task_num = len(task_mapping.keys())
1111

12+
task_name_list = ["pMHC", "pTCR", "TCR_pMHC"]
1213

1314
def make_def():
1415
task_def = defaultdict(list)

0 commit comments

Comments
 (0)