Skip to content

Commit 6955bd5

Browse files
committed
[SPARK-50931][ML][PYTHON][CONNECT] Support Binarizer on connect
### What changes were proposed in this pull request?
Support Binarizer on connect

### Why are the changes needed?
feature parity

### Does this PR introduce _any_ user-facing change?
yes, new algorithm

### How was this patch tested?
added test

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #49637 from zhengruifeng/ml_connect_binarizer.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
1 parent 3377962 commit 6955bd5

File tree

3 files changed

+41
-4
lines changed

3 files changed

+41
-4
lines changed

mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
# So register the supported transformer here if you're trying to add a new one.
2020
########### Transformers
2121
org.apache.spark.ml.feature.DCT
22+
org.apache.spark.ml.feature.Binarizer
2223
org.apache.spark.ml.feature.VectorAssembler
2324
org.apache.spark.ml.feature.Tokenizer
2425
org.apache.spark.ml.feature.RegexTokenizer

python/pyspark/ml/tests/connect/test_parity_feature.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,6 @@
2222

2323

2424
class FeatureParityTests(FeatureTestsMixin, ReusedConnectTestCase):
25-
@unittest.skip("Need to support.")
26-
def test_binarizer(self):
27-
super().test_binarizer()
28-
2925
@unittest.skip("Need to support.")
3026
def test_idf(self):
3127
super().test_idf()

python/pyspark/ml/tests/test_feature.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,46 @@ def test_binarizer(self):
574574
self.assertEqual(b1.getInputCol(), "input")
575575
self.assertEqual(b1.getOutputCol(), "output")
576576

577+
df = self.spark.createDataFrame(
578+
[
579+
(0.1, 0.0),
580+
(0.4, 1.0),
581+
(1.2, 1.3),
582+
(1.5, float("nan")),
583+
(float("nan"), 1.0),
584+
(float("nan"), 0.0),
585+
],
586+
["v1", "v2"],
587+
)
588+
589+
bucketizer = Binarizer(threshold=1.0, inputCol="v1", outputCol="f1")
590+
output = bucketizer.transform(df)
591+
self.assertEqual(output.columns, ["v1", "v2", "f1"])
592+
self.assertEqual(output.count(), 6)
593+
self.assertEqual(
594+
[r.f1 for r in output.select("f1").collect()],
595+
[0.0, 0.0, 1.0, 1.0, 0.0, 0.0],
596+
)
597+
598+
bucketizer = Binarizer(threshold=1.0, inputCols=["v1", "v2"], outputCols=["f1", "f2"])
599+
output = bucketizer.transform(df)
600+
self.assertEqual(output.columns, ["v1", "v2", "f1", "f2"])
601+
self.assertEqual(output.count(), 6)
602+
self.assertEqual(
603+
[r.f1 for r in output.select("f1").collect()],
604+
[0.0, 0.0, 1.0, 1.0, 0.0, 0.0],
605+
)
606+
self.assertEqual(
607+
[r.f2 for r in output.select("f2").collect()],
608+
[0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
609+
)
610+
611+
# save & load
612+
with tempfile.TemporaryDirectory(prefix="binarizer") as d:
613+
bucketizer.write().overwrite().save(d)
614+
bucketizer2 = Binarizer.load(d)
615+
self.assertEqual(str(bucketizer), str(bucketizer2))
616+
577617
def test_idf(self):
578618
dataset = self.spark.createDataFrame(
579619
[(DenseVector([1.0, 2.0]),), (DenseVector([0.0, 1.0]),), (DenseVector([3.0, 0.2]),)],

0 commit comments

Comments (0)