Skip to content

Commit 6c9e8bf

Browse files
committed
[data] add test for preprocessor
Signed-off-by: Xingyu Long <xingyulong97@gmail.com>
1 parent 6a50837 commit 6c9e8bf

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

python/ray/data/tests/preprocessors/test_preprocessors.py

Lines changed: 39 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -165,6 +165,45 @@ def test_fit_twice(mocked_warn):
165165
mocked_warn.assert_called_once_with(msg)
166166

167167

168+
def test_ray_remote_args_and_fn():
    """Check that ``ray_remote_args`` and ``ray_remote_args_fn`` are both applied.

    A non-fittable preprocessor is built with a static ``ray_remote_args`` dict
    and a ``ray_remote_args_fn`` callable. Inside the transform the test asserts
    that the assigned CPU count equals the static ``num_cpus`` and that every
    batch has ``batch_size`` rows, then overwrites the ``value`` column with the
    integer read from the ``__MY_TEST__`` environment variable that the callable
    supplies through ``runtime_env``.
    """
    batch_size = 2
    ray_remote_args = {"num_cpus": 2}

    def remote_args_with_env():
        # Provides the environment variable that the transform reads back.
        return {"runtime_env": {"env_vars": {"__MY_TEST__": "69"}}}

    def overwrite_with_env_value(batch):
        # Imported locally so the helper carries its own dependency.
        import os

        env_value = int(os.environ["__MY_TEST__"])
        # Slice assignment updates the existing "value" column in place.
        batch["value"][:] = env_value
        return batch

    class DummyPreprocessor(Preprocessor):
        # No fitting step is needed for this preprocessor.
        _is_fittable = False

        def _determine_transform_to_use(self):
            return "numpy"

        def _get_transform_config(self):
            return {"batch_size": batch_size}

        def _transform_numpy(self, data):
            assigned = ray.get_runtime_context().get_assigned_resources()
            assert assigned["CPU"] == ray_remote_args["num_cpus"]
            assert len(data["value"]) == batch_size
            return overwrite_with_env_value(data)

    preprocessor = DummyPreprocessor(
        ray_remote_args=ray_remote_args,
        ray_remote_args_fn=remote_args_with_env,
    )
    dataset = ray.data.from_pandas(pd.DataFrame({"value": list(range(10))}))
    transformed = preprocessor.transform(dataset)

    sampled = sorted(row["value"] for row in transformed.take(5))
    assert sampled == [69] * 5
205+
206+
168207
def test_transform_config():
169208
"""Tests that the transform_config of
170209
the Preprocessor is respected during transform."""

0 commit comments

Comments
 (0)