Skip to content

Commit d3cddf0

Browse files
committed
ENH: improve dataset loading
1 parent e8d6361 commit d3cddf0

File tree

1 file changed

+35
-24
lines changed

1 file changed

+35
-24
lines changed

baselines/datasets.py

Lines changed: 35 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -127,7 +127,7 @@ def __getitem__(self, idx):
127127
if self.target not in ["processid", "bin_uri", "dna_bin"]:
128128
label = torch.tensor(self.labels[idx], dtype=torch.int64)
129129
else:
130-
label = self.labels[idx]
130+
label = self.labels[idx]
131131

132132
return processed_barcode, label, att_mask
133133

@@ -136,7 +136,8 @@ def representations_from_df(
136136
filename,
137137
embedder,
138138
batch_size=128,
139-
save_embeddings=True,
139+
save_embeddings=False,
140+
load_embeddings=False,
140141
dataset="BIOSCAN-5M",
141142
embeddings_folder="embeddings/",
142143
target="species",
@@ -155,24 +156,32 @@ def representations_from_df(
155156
print(f"Calculating embeddings for {backbone}")
156157

157158
# create a folder for a specific backbone within embeddings
158-
backbone_folder = os.path.join(embeddings_path, backbone)
159-
if not os.path.isdir(backbone_folder):
160-
os.mkdir(backbone_folder)
161159

162-
# Check if the embeddings have been saved for that file
163-
prefix = filename.split("/")[-1].split(".")[0]
164-
out_fname = f"{os.path.join(backbone_folder, prefix)}.pickle"
165-
print(out_fname)
166-
167-
if os.path.exists(out_fname):
168-
print(f"We found the file {out_fname}. It seems that we have computed the embeddings ... \n")
169-
print("Loading the embeddings from that file")
170-
171-
with open(out_fname, "rb") as handle:
172-
embeddings = pickle.load(handle)
173-
174-
return embeddings
175160

161+
if save_embeddings or load_embeddings:
162+
embeddings_path = f"{embeddings_folder}/{dataset}"
163+
os.makedirs(embeddings_path, exist_ok=True)
164+
165+
backbone_folder = os.path.join(embeddings_path, backbone)
166+
if not os.path.isdir(backbone_folder):
167+
os.mkdir(backbone_folder)
168+
169+
170+
prefix = filename.split("/")[-1].split(".")[0]
171+
out_fname = f"{os.path.join(backbone_folder, prefix)}.pickle"
172+
print(out_fname)
173+
174+
if load_embeddings:
175+
if os.path.exists(out_fname):
176+
print(f"We found the file {out_fname}. It seems that we have computed the embeddings ... \n")
177+
print("Loading the embeddings from that file")
178+
179+
with open(out_fname, "rb") as handle:
180+
embeddings = pickle.load(handle)
181+
182+
return embeddings
183+
else:
184+
raise FileNotFoundError(f"We could not find file {out_fname}")
176185
else:
177186
print(f"Just making sure that dataset is {dataset}")
178187
dataset_val = DNADataset(
@@ -252,12 +261,14 @@ def representations_from_df(
252261
# print(all_embeddings.shape)
253262
# print(all_ids.shape)
254263

255-
save_embeddings = {"data": all_embeddings, "ids": all_ids}
256-
257-
with open(out_fname, "wb") as handle:
258-
pickle.dump(save_embeddings, handle, protocol=pickle.HIGHEST_PROTOCOL)
259-
260-
return save_embeddings
264+
to_save_embeddings = {"data": all_embeddings, "ids": all_ids}
265+
if save_embeddings:
266+
print(f"Saving embeddings to {out_fname}")
267+
with open(out_fname, "wb") as handle:
268+
pickle.dump(to_save_embeddings, handle, protocol=pickle.HIGHEST_PROTOCOL)
269+
else:
270+
print("save_embeddings flag set to False. skipping pickle saving")
271+
return to_save_embeddings
261272

262273

263274
def labels_from_df(filename, target_level, label_pipeline):

0 commit comments

Comments
 (0)