28
28
29
29
from pyworkflow .protocol import params
30
30
31
- from pwchem .utils import performBatchThreading , findThreadFiles , concatFiles
31
+ from pwchem .utils import performBatchThreading , findThreadFiles , concatFiles , splitFile
32
32
from pwchem import Plugin as pwchemPlugin
33
33
from pwchem .constants import RDKIT_DIC
34
+ from pwchem .objects import SetOfSmallMolecules , SmallMolecule
34
35
35
36
from autodock import Plugin as autodockPlugin
36
37
from autodock .protocols import ProtChemAutodockGPU
@@ -62,8 +63,15 @@ def _defineParams(self, form):
62
63
63
64
form .addSection (label = "Prediction" )
64
65
group = form .addGroup ('Input' )
66
+ group .addParam ('useLibrary' , params .BooleanParam , label = 'Use library as input : ' , default = False ,
67
+ expertLevel = params .LEVEL_ADVANCED ,
68
+ help = 'Whether to use a SMI library SmallMoleculesLibrary object as input' )
69
+
70
+ group .addParam ('inputLibrary' , params .PointerParam , pointerClass = "SmallMoleculesLibrary" ,
71
+ label = 'Input library: ' , condition = 'useLibrary' ,
72
+ help = "Input Small molecules library to predict" )
65
73
group .addParam ('inputSmallMolecules' , params .PointerParam , pointerClass = "SetOfSmallMolecules" ,
66
- label = 'Input small molecules: ' , allowsNull = False ,
74
+ label = 'Input small molecules: ' , allowsNull = False , condition = 'not useLibrary' ,
67
75
help = "Input small molecules to be scored with the model" )
68
76
69
77
group = form .addGroup ('Training' )
@@ -154,8 +162,15 @@ def predictionStep(self):
154
162
sysName = self .getSystemName ()
155
163
scriptName = self .getScriptPath ()
156
164
157
- inMols = self .inputSmallMolecules .get ()
158
- smisFiles = self .buildSMIsFile (inMols , writeScores = False )
165
+ if not self .useLibrary .get ():
166
+ inMols = self .inputSmallMolecules .get ()
167
+ smisFiles = self .buildSMIsFile (inMols , writeScores = False )
168
+ else :
169
+ nt = self .numberOfThreads .get ()
170
+ inSMIFile = self .inputLibrary .get ().getFileName ()
171
+ smiFile = self .getInputSMIFile (writeScores = False )
172
+ os .link (inSMIFile , smiFile )
173
+ smisFiles = splitFile (smiFile , n = nt , remove = True , pref = 'inputSMIs_predict' )
159
174
160
175
modelsPath = os .path .abspath (autodockPlugin .getPluginHome ('models' ))
161
176
shutil .copytree (os .path .join (modelsPath , sysName ), os .path .abspath (self ._getPath (sysName )), dirs_exist_ok = True )
@@ -166,47 +181,74 @@ def predictionStep(self):
166
181
167
182
pwchemPlugin .runCondaCommand (self , args , GCR_DIC , f'python { scriptName } ' , cwd = self ._getPath ())
168
183
184
+ def writeSMIOutput (self , smi , smiName , oDir ):
185
+ oFile = os .path .join (oDir , f'{ smiName } .smi' )
186
+ with open (oFile , 'w' ) as f :
187
+ f .write (f'{ smi } { smiName } \n ' )
188
+ return oFile
169
189
170
190
def createOutputStep (self ):
171
- scoreDic = self .getScoreDic ()
172
- outputSet = self .inputSmallMolecules .get ().createCopy (self ._getPath (), copyInfo = True )
173
- for mol in self .inputSmallMolecules .get ():
174
- nMol = mol .clone ()
175
- molFile = nMol .getFileName ()
176
- if molFile in scoreDic :
177
- setattr (nMol , '_gcrScore' , params .Float (scoreDic [molFile ]))
178
- outputSet .append (nMol )
191
+ smiScoreDic = self .getScoreDic ()
179
192
193
+ if self .useLibrary .get ():
194
+ oDir = self ._getPath ('outputMolecules' )
195
+ if not os .path .exists (oDir ):
196
+ os .mkdir (oDir )
180
197
181
- outputSet .updateMolClass ()
182
- self ._defineOutputs (outputSmallMolecules = outputSet )
183
- self ._defineSourceRelation (self .inputSmallMolecules , outputSet )
198
+ inLib = self .inputLibrary .get ()
199
+ mapDic = inLib .getLibraryMap ()
200
+
201
+ outputSet = SetOfSmallMolecules ().create (outputPath = self ._getPath ())
202
+ for smi , score in smiScoreDic .items ():
203
+ smiName = mapDic [smi ]
204
+ oFile = self .writeSMIOutput (smi , smiName , oDir )
205
+
206
+ smallMolecule = SmallMolecule (smallMolFilename = oFile )
207
+ smallMolecule .setMolName (smiName )
208
+ setattr (smallMolecule , '_gcrScore' , params .Float (score ))
184
209
210
+ outputSet .append (smallMolecule )
185
211
212
+ else :
213
+ scoreDic = self .mapMolScoreDic (smiScoreDic )
214
+ outputSet = self .inputSmallMolecules .get ().createCopy (self ._getPath (), copyInfo = True )
215
+ for mol in self .inputSmallMolecules .get ():
216
+ nMol = mol .clone ()
217
+ molFile = nMol .getFileName ()
218
+ if molFile in scoreDic :
219
+ setattr (nMol , '_gcrScore' , params .Float (scoreDic [molFile ]))
220
+ outputSet .append (nMol )
221
+ outputSet .updateMolClass ()
222
+ self ._defineSourceRelation (self .inputSmallMolecules , outputSet )
223
+
224
+ self ._defineOutputs (outputSmallMolecules = outputSet )
186
225
187
226
188
227
def getOutputCSV (self ):
189
228
sysName = self .getSystemName ()
190
229
oFile = os .path .abspath (self ._getPath (os .path .join (sysName , 'results/predictions.csv' )))
191
230
if not os .path .exists (oFile ):
192
231
threadFiles = findThreadFiles (oFile )
193
- concatFiles (threadFiles , oFile , remove = True )
232
+ concatFiles (threadFiles , oFile , remove = True , skipHead = 1 )
194
233
195
234
return oFile
196
235
197
-
198
- def getScoreDic ( self ):
236
+ def mapMolScoreDic ( self , smiScoreDic ):
237
+ '''Maps the smi to the roiginal files and retuns: {molFile: score}'''
199
238
mapDic = self .parseCSVDic (self .getMapSMIFile (writeScores = False ))
200
- smiScoreDic = self .parseCSVDic (self .getOutputCSV ())
201
- scoreDic = {molFile : float (eval (smiScoreDic [smi ])[0 ]) for molFile , smi in mapDic .items () if smi in smiScoreDic }
239
+ scoreDic = {molFile : float (smiScoreDic [smi ]) for molFile , smi in mapDic .items () if smi in smiScoreDic }
202
240
return scoreDic
203
241
242
+ def getScoreDic (self ):
243
+ '''Return a dic as {smi: score}'''
244
+ return self .parseCSVDic (self .getOutputCSV ())
245
+
204
246
def parseCSVDic (self , csvFile ):
205
247
smiDic = {}
206
248
with open (csvFile ) as f :
207
249
for line in f :
208
250
sline = line .strip ().split (',' )
209
- smiDic [sline [0 ]] = sline [1 ]
251
+ smiDic [sline [0 ]] = eval ( sline [1 ])[ 0 ]
210
252
return smiDic
211
253
212
254
def getScriptPath (self ):
0 commit comments