@@ -127,7 +127,7 @@ def __getitem__(self, idx):
         if self.target not in ["processid", "bin_uri", "dna_bin"]:
             label = torch.tensor(self.labels[idx], dtype=torch.int64)
         else:
-            label = self.labels[idx]
+            label = self.labels[idx]


         return processed_barcode, label, att_mask
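For string-valued targets such as "processid" or "bin_uri" the label stays a plain Python string rather than being wrapped in an int64 tensor. A minimal sketch, not part of the diff (ToyDataset is a hypothetical stand-in for DNADataset), of how the two label kinds behave under PyTorch's default collate_fn:

# Illustrative sketch only: integer vs. string labels through a DataLoader.
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical stand-in for DNADataset, reduced to the label handling."""

    def __init__(self, labels, as_tensor):
        self.labels = labels
        self.as_tensor = as_tensor

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        if self.as_tensor:
            return torch.tensor(self.labels[idx], dtype=torch.int64)
        return self.labels[idx]


int_batch = next(iter(DataLoader(ToyDataset([0, 1, 2, 3], True), batch_size=4)))
str_batch = next(iter(DataLoader(ToyDataset(["ID-0001", "ID-0002"], False), batch_size=2)))
print(int_batch)  # tensor([0, 1, 2, 3]) -- int labels stack into one int64 tensor
print(str_batch)  # ['ID-0001', 'ID-0002'] -- strings are collated into a plain list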
@@ -136,7 +136,8 @@ def representations_from_df(
     filename,
     embedder,
     batch_size=128,
-    save_embeddings=True,
+    save_embeddings=False,
+    load_embeddings=False,
     dataset="BIOSCAN-5M",
     embeddings_folder="embeddings/",
     target="species",
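Both caching flags now default to False, so representations_from_df computes embeddings fresh unless the caller opts in. A usage sketch of the three call modes; the embedder object and the CSV path are assumptions, and the cache path layout follows the hunk below:

# Call-site sketch (assumed names, not part of this PR).
# 1) Default: compute embeddings in memory, no pickle is written.
emb = representations_from_df("data/test.csv", embedder, batch_size=128)

# 2) Opt in to writing the pickle at embeddings/<dataset>/<backbone>/test.pickle.
emb = representations_from_df("data/test.csv", embedder, save_embeddings=True)

# 3) Load the cached pickle; raises FileNotFoundError if it does not exist.
emb = representations_from_df("data/test.csv", embedder, load_embeddings=True)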
@@ -155,24 +156,32 @@ def representations_from_df(
     print(f"Calculating embeddings for {backbone}")

     # create a folder for a specific backbone within embeddings
-    backbone_folder = os.path.join(embeddings_path, backbone)
-    if not os.path.isdir(backbone_folder):
-        os.mkdir(backbone_folder)

-    # Check if the embeddings have been saved for that file
-    prefix = filename.split("/")[-1].split(".")[0]
-    out_fname = f"{os.path.join(backbone_folder, prefix)}.pickle"
-    print(out_fname)
-
-    if os.path.exists(out_fname):
-        print(f"We found the file {out_fname}. It seems that we have computed the embeddings ... \n")
-        print("Loading the embeddings from that file")
-
-        with open(out_fname, "rb") as handle:
-            embeddings = pickle.load(handle)
-
-        return embeddings

+    if save_embeddings or load_embeddings:
+        embeddings_path = f"{embeddings_folder}/{dataset}"
+        os.makedirs(embeddings_path, exist_ok=True)
+
+        backbone_folder = os.path.join(embeddings_path, backbone)
+        if not os.path.isdir(backbone_folder):
+            os.mkdir(backbone_folder)
+
+
+        prefix = filename.split("/")[-1].split(".")[0]
+        out_fname = f"{os.path.join(backbone_folder, prefix)}.pickle"
+        print(out_fname)
+
+        if load_embeddings:
+            if os.path.exists(out_fname):
+                print(f"We found the file {out_fname}. It seems that we have computed the embeddings ... \n")
+                print("Loading the embeddings from that file")
+
+                with open(out_fname, "rb") as handle:
+                    embeddings = pickle.load(handle)
+
+                return embeddings
+            else:
+                raise FileNotFoundError(f"We could not find file {out_fname}")
     else:
         print(f"Just making sure that dataset is {dataset}")
         dataset_val = DNADataset(
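Because the load path now raises FileNotFoundError instead of silently recomputing, a caller that only wants to reuse a cache can guard the call. A small sketch under the same assumed names as above:

# Sketch: prefer the cached embeddings, fall back to computing them fresh.
try:
    embeddings = representations_from_df("data/test.csv", embedder, load_embeddings=True)
except FileNotFoundError:
    # No cached pickle yet: compute the embeddings in memory instead.
    embeddings = representations_from_df("data/test.csv", embedder)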
@@ -252,12 +261,14 @@ def representations_from_df(
     # print(all_embeddings.shape)
     # print(all_ids.shape)

-    save_embeddings = {"data": all_embeddings, "ids": all_ids}
-
-    with open(out_fname, "wb") as handle:
-        pickle.dump(save_embeddings, handle, protocol=pickle.HIGHEST_PROTOCOL)
-
-    return save_embeddings
+    to_save_embeddings = {"data": all_embeddings, "ids": all_ids}
+    if save_embeddings:
+        print(f"Saving embeddings to {out_fname}")
+        with open(out_fname, "wb") as handle:
+            pickle.dump(to_save_embeddings, handle, protocol=pickle.HIGHEST_PROTOCOL)
+    else:
+        print("save_embeddings flag set to False. skipping pickle saving")
+    return to_save_embeddings


 def labels_from_df(filename, target_level, label_pipeline):
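Whether returned or pickled, the embeddings travel as a dict with "data" and "ids" keys. A sketch of reading a cached file directly; the backbone and file names here are assumptions following the path construction in the diff:

# Sketch only: inspect a cached pickle written by the save path above.
import pickle

# Assumed path: embeddings/<dataset>/<backbone>/<csv prefix>.pickle
with open("embeddings/BIOSCAN-5M/BarcodeBERT/test.pickle", "rb") as handle:
    cached = pickle.load(handle)

print(cached["data"].shape)  # one embedding row per input record
print(cached["ids"][:5])     # identifiers aligned with the rows of "data"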