Skip to content

Commit aa24a9a

Browse files
wbo4958 authored and zhengruifeng committed
[SPARK-50812][ML][PYTHON][CONNECT] Add support PolynomialExpansion
### What changes were proposed in this pull request?
Support PolynomialExpansion on connect

### Why are the changes needed?
feature parity

### Does this PR introduce _any_ user-facing change?
Yes

### How was this patch tested?
CI passes

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #49702 from wbo4958/px.

Authored-by: Bobby Wang <wbo4958@gmail.com>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
1 parent e891627 commit aa24a9a

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ org.apache.spark.ml.feature.FeatureHasher
3636
org.apache.spark.ml.feature.ElementwiseProduct
3737
org.apache.spark.ml.feature.HashingTF
3838
org.apache.spark.ml.feature.IndexToString
39+
org.apache.spark.ml.feature.PolynomialExpansion
3940

4041
########### Model for loading
4142
# classification

python/pyspark/ml/tests/test_feature.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
MinHashLSH,
7878
MinHashLSHModel,
7979
IndexToString,
80+
PolynomialExpansion,
8081
)
8182
from pyspark.ml.linalg import DenseVector, SparseVector, Vectors
8283
from pyspark.sql import Row
@@ -85,6 +86,31 @@
8586

8687

8788
class FeatureTestsMixin:
89+
def test_polynomial_expansion(self):
    """Check PolynomialExpansion output, its params, and a save/load round trip."""
    rows = [(Vectors.dense([0.5, 2.0]),)]
    dataset = self.spark.createDataFrame(rows, ["dense"])

    expander = PolynomialExpansion(degree=2)
    expander.setInputCol("dense")
    expander.setOutputCol("expanded")

    # For x=0.5, y=2.0 the expected values are (x, x^2, y, x*y, y^2).
    expected = [0.5, 0.25, 2.0, 1.0, 4.0]
    first_row = expander.transform(dataset).head()
    actual = first_row.expanded.toArray()
    self.assertTrue(np.allclose(actual, expected, atol=1e-4))

    def assert_params(stage: PolynomialExpansion) -> None:
        # The same param checks are applied to the original and the reloaded stage.
        self.assertEqual(stage.getInputCol(), "dense")
        self.assertEqual(stage.getOutputCol(), "expanded")
        self.assertEqual(stage.getDegree(), 2)

    assert_params(expander)

    # Persist the stage to a temporary directory and load it back.
    with tempfile.TemporaryDirectory(prefix="px") as save_dir:
        expander.write().overwrite().save(save_dir)
        restored = PolynomialExpansion.load(save_dir)
        self.assertEqual(str(expander), str(restored))
        assert_params(restored)
113+
88114
def test_index_string(self):
89115
dataset = self.spark.createDataFrame(
90116
[

0 commit comments

Comments
 (0)