13
13
14
14
import click
15
15
import sqlitedict
16
+ import numpy as np
16
17
17
18
from profold2 .data .dataset import ProteinStructureDataset
18
19
from profold2 .data .parsers import parse_fasta , parse_a3m
@@ -615,6 +616,7 @@ def attr_update_weight_and_task(**args):
615
616
)
616
617
@click .option ("--mask" , type = str , default = "-" , hidden = True )
617
618
@click .option ("--task_def" , type = str , default = json .dumps (task .make_def ()), hidden = True )
619
+ @click .option ("--task_pid_prefix" , type = str , default = "tcr_pmhc_" , hidden = True )
618
620
@click .option (
619
621
"--chunksize" ,
620
622
type = int ,
@@ -642,7 +644,23 @@ def predict(**args):
642
644
max_var_depth = None
643
645
)
644
646
647
+ def _parse_result (a3m_string ):
648
+ _ , descriptions = parse_fasta (a3m_string )
649
+ for fields in map (lambda x : x .split ("\t " ), descriptions ):
650
+ pid , fields = fields [0 ], fields [1 :]
651
+ if pid .startswith (args .task_pid_prefix ):
652
+ pred = [None ] * task .task_num
653
+ for field in fields :
654
+ i = field .find (":" )
655
+ if i != - 1 :
656
+ if field [:i ] == "Elo_score" :
657
+ pred = json .loads (field [i + 1 :])
658
+ break
659
+ yield pid , pred
660
+
645
661
for model in args .model :
662
+ pred_dict = defaultdict (list )
663
+
646
664
for ref_pkl in glob .glob (os .path .join (args .ref_pkl , f"{ model } _*.pkl" )):
647
665
pdb_id = os .path .basename (ref_pkl )
648
666
assert pdb_id .startswith (f"{ model } _" )
@@ -653,21 +671,45 @@ def predict(**args):
653
671
if args .verbose :
654
672
print (f"predict affinity ranking score with { model } :{ ref_pkl } " )
655
673
674
+ # model params
656
675
setattr (args , "model_file" , ref_pkl )
657
676
setattr (args , "model_ckpt" , os .path .join (args .ref_pkl , f"{ model } _model.pth" ))
658
677
678
+ # prepare variants
659
679
feat = data .get_multimer (compose_pid (pid , "P" ), chains )
660
680
assert len (feat ["str_var" ]) == len (feat ["variant_pid" ])
661
681
a3m_string = "\n " .join (
662
682
f">{ pid } \n { var } " for pid , var in zip (feat ["variant_pid" ], feat ["str_var" ])
663
683
)
664
684
with io .StringIO (a3m_string ) as a3m_file :
665
685
setattr (args , "a3m_file" , [a3m_file ])
666
- with open (
667
- os .path .join (args .output_dir , f"{ model } _{ pid } .a3m" ), "w"
668
- ) as output_file :
686
+
687
+ output_file_path = os .path .join (args .output_dir , f"{ model } _{ pid } .a3m" )
688
+ with open ( output_file_path , "w" ) as output_file :
669
689
setattr (args , "output_file" , output_file )
670
- energy .main (args )
690
+ energy .main (args ) # calc the Elo-score
691
+ with open (output_file_path , "r" ) as output_file :
692
+ a3m_string = output_file .read ()
693
+ for pid , pred in _parse_result (a3m_string ):
694
+ pred_dict [pid ].append (pred )
695
+
696
+ # write results to csv
697
+ with open (os .path .join (args .output_dir , f"{ model } _pred.csv" ), "w" ) as f :
698
+ writer = csv .DictWriter (f , fieldnames = ["id" , "chains" ] + task .task_name_list )
699
+ writer .writeheader ()
700
+ for pid , pred_list in pred_dict .items ():
701
+ chain_list , * _ = data .chain_list [pid ] # FIX: data.get_chain_list(protein_id)
702
+ assert chain_list , (pid , pid in data .chain_list , len (data .chain_list ))
703
+ _ , pred_mask = task .make_label (0 , chain_list )
704
+
705
+ pred_list , pred_mask = np .asarray (pred_list ), np .asarray (pred_mask )
706
+ pred_list = np .sum (pred_list * pred_mask [None ], axis = 0 ) / pred_list .shape [0 ]
707
+
708
+ row = {"id" : pid , "chains" : "_" .join (chain_list )}
709
+ for idx , (pred , mask ) in enumerate (zip (pred_list , pred_mask )):
710
+ if mask :
711
+ row [task .task_name_list [idx ]] = pred
712
+ writer .writerow (row )
671
713
672
714
673
715
if __name__ == "__main__" :
0 commit comments