Skip to content

Commit ee6c166

Browse files
Style fix for Dataloader (#1838)
1 parent 904392a commit ee6c166

File tree

1 file changed

+25 -11 lines changed

1 file changed

+25 -11 lines changed

dspy/datasets/dataloader.py

Lines changed: 25 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -10,9 +10,7 @@
1010

1111

1212
class DataLoader(Dataset):
13-
def __init__(
14-
self,
15-
):
13+
def __init__(self):
1614
pass
1715

1816
def from_huggingface(
@@ -97,8 +95,7 @@ def from_parquet(self, file_path: str, fields: List[str] = None, input_keys: Tup
9795

9896
return [dspy.Example({field: row[field] for field in fields}).with_inputs(input_keys) for row in dataset]
9997

100-
def from_rm(self, num_samples: int,
101-
fields: List[str], input_keys: List[str]) -> List[dspy.Example]:
98+
def from_rm(self, num_samples: int, fields: List[str], input_keys: List[str]) -> List[dspy.Example]:
10299
try:
103100
rm = dspy.settings.rm
104101
try:
@@ -107,9 +104,13 @@ def from_rm(self, num_samples: int,
107104
for row in rm.get_objects(num_samples=num_samples, fields=fields)
108105
]
109106
except AttributeError:
110-
raise ValueError("Retrieval module does not support `get_objects`. Please use a different retrieval module.")
107+
raise ValueError(
108+
"Retrieval module does not support `get_objects`. Please use a different retrieval module."
109+
)
111110
except AttributeError:
112-
raise ValueError("Retrieval module not found. Please set a retrieval module using `dspy.settings.configure`.")
111+
raise ValueError(
112+
"Retrieval module not found. Please set a retrieval module using `dspy.settings.configure`."
113+
)
113114

114115
def sample(
115116
self,
@@ -119,7 +120,9 @@ def sample(
119120
**kwargs,
120121
) -> List[dspy.Example]:
121122
if not isinstance(dataset, list):
122-
raise ValueError(f"Invalid dataset provided of type {type(dataset)}. Please provide a list of examples.")
123+
raise ValueError(
124+
f"Invalid dataset provided of type {type(dataset)}. Please provide a list of `dspy.Example`s."
125+
)
123126

124127
return random.sample(dataset, n, *args, **kwargs)
125128

@@ -141,17 +144,28 @@ def train_test_split(
141144
elif train_size is not None and isinstance(train_size, int):
142145
train_end = train_size
143146
else:
144-
raise ValueError("Invalid train_size. Please provide a float between 0 and 1 or an int.")
147+
raise ValueError(
148+
"Invalid `train_size`. Please provide a float between 0 and 1 to represent the proportion of the "
149+
"dataset to include in the train split or an int to represent the absolute number of samples to "
150+
f"include in the train split. Received `train_size`: {train_size}."
151+
)
145152

146153
if test_size is not None:
147154
if isinstance(test_size, float) and (0 < test_size < 1):
148155
test_end = int(len(dataset_shuffled) * test_size)
149156
elif isinstance(test_size, int):
150157
test_end = test_size
151158
else:
152-
raise ValueError("Invalid test_size. Please provide a float between 0 and 1 or an int.")
159+
raise ValueError(
160+
"Invalid `test_size`. Please provide a float between 0 and 1 to represent the proportion of the "
161+
"dataset to include in the test split or an int to represent the absolute number of samples to "
162+
f"include in the test split. Received `test_size`: {test_size}."
163+
)
153164
if train_end + test_end > len(dataset_shuffled):
154-
raise ValueError("train_size + test_size cannot exceed the total number of samples.")
165+
raise ValueError(
166+
"`train_size` + `test_size` cannot exceed the total number of samples. Received "
167+
f"`train_size`: {train_end}, `test_size`: {test_end}, and `dataset_size`: {len(dataset_shuffled)}."
168+
)
155169
else:
156170
test_end = len(dataset_shuffled) - train_end
157171

0 commit comments

Comments (0)