13
13
14
14
import click
15
15
import sqlitedict
16
+ import numpy as np
16
17
17
18
from profold2 .data .dataset import ProteinStructureDataset
18
19
from profold2 .data .parsers import parse_fasta , parse_a3m
@@ -615,6 +616,7 @@ def attr_update_weight_and_task(**args):
615
616
)
616
617
@click .option ("--mask" , type = str , default = "-" , hidden = True )
617
618
@click .option ("--task_def" , type = str , default = json .dumps (task .make_def ()), hidden = True )
619
+ @click .option ("--task_pid_prefix" , type = str , default = "tcr_pmhc_" , hidden = True )
618
620
@click .option (
619
621
"--chunksize" ,
620
622
type = int ,
@@ -642,7 +644,23 @@ def predict(**args):
642
644
max_var_depth = None
643
645
)
644
646
647
def _parse_result(a3m_string):
    """Yield ``(pid, pred)`` pairs for the task records in an a3m string.

    The description lines returned by ``parse_fasta`` are split into
    fields; the first field is taken as the record id (``pid``).  Only
    records whose id starts with ``args.task_pid_prefix`` are reported.

    For a reported record, the first field of the form
    ``Elo_score:<json>`` supplies the prediction (the JSON payload is
    decoded with ``json.loads``).  When no such field is present the
    prediction is a list of ``None`` with ``task.task_num`` entries.

    Args:
        a3m_string: text of an a3m/fasta file, as accepted by
            ``parse_fasta``.

    Yields:
        Tuples ``(pid, pred)``.
    """
    _, descriptions = parse_fasta(a3m_string)
    for description in descriptions:
        # NOTE(review): the separator is a tab followed by a space, exactly
        # as written in the source -- confirm it matches the writer's format.
        pid, *extra_fields = description.split("\t ")
        if not pid.startswith(args.task_pid_prefix):
            continue

        # Placeholder used when the record carries no "Elo_score" field.
        pred = [None] * task.task_num
        for entry in extra_fields:
            key, colon, payload = entry.partition(":")
            # `colon` is empty when the field has no ":" at all; such
            # fields are ignored even if their text equals "Elo_score".
            if colon and key == "Elo_score":
                pred = json.loads(payload)
                break
        yield pid, pred
660
+
645
661
for model in args .model :
662
+ pred_dict = defaultdict (list )
663
+
646
664
for ref_pkl in glob .glob (os .path .join (args .ref_pkl , f"{ model } _*.pkl" )):
647
665
pdb_id = os .path .basename (ref_pkl )
648
666
assert pdb_id .startswith (f"{ model } _" )
@@ -663,11 +681,32 @@ def predict(**args):
663
681
)
664
682
with io .StringIO (a3m_string ) as a3m_file :
665
683
setattr (args , "a3m_file" , [a3m_file ])
666
- with open (
667
- os .path .join (args .output_dir , f"{ model } _{ pid } .a3m" ), "w"
668
- ) as output_file :
684
+
685
+ output_file_path = os .path .join (args .output_dir , f"{ model } _{ pid } .a3m" )
686
+ with open ( output_file_path , "w" ) as output_file :
669
687
setattr (args , "output_file" , output_file )
670
688
energy .main (args )
689
+ with open (output_file_path , "r" ) as output_file :
690
+ a3m_string = output_file .read ()
691
+ for pid , pred in _parse_result (a3m_string ):
692
+ pred_dict [pid ].append (pred )
693
+
694
+ with open (os .path .join (args .output_dir , f"{ model } _pred.csv" ), "w" ) as f :
695
+ writer = csv .DictWriter (f , fieldnames = ["id" , "chains" ] + task .task_name_list )
696
+ writer .writeheader ()
697
+ for pid , pred_list in pred_dict .items ():
698
+ chain_list , * _ = data .chain_list [pid ] # FIX: data.get_chain_list(protein_id)
699
+ assert chain_list , (pid , pid in data .chain_list , len (data .chain_list ))
700
+ _ , pred_mask = task .make_label (0 , chain_list )
701
+
702
+ pred_list , pred_mask = np .asarray (pred_list ), np .asarray (pred_mask )
703
+ pred_list = np .sum (pred_list * pred_mask [None ], axis = 0 ) / pred_list .shape [0 ]
704
+
705
+ row = {"id" : pid , "chains" : "_" .join (chain_list )}
706
+ for idx , (pred , mask ) in enumerate (zip (pred_list , pred_mask )):
707
+ if mask :
708
+ row [task .task_name_list [idx ]] = pred
709
+ writer .writerow (row )
671
710
672
711
673
712
if __name__ == "__main__" :
0 commit comments