Skip to content

Commit 3377962

Browse files
committed
[SPARK-50963][ML][PYTHON][TESTS][FOLLOW-UP] Enable a parity test
### What changes were proposed in this pull request? Enable an existing test on Connect, move it after `test_stop_words_remover`, and rename it from `test_stopwordsremover` to `test_stop_words_remover_II`. ### Why are the changes needed? For test coverage. ### Does this PR introduce _any_ user-facing change? No, test-only. ### How was this patch tested? Parity test. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #49642 from zhengruifeng/ml_remover_test. Authored-by: Ruifeng Zheng <ruifengz@apache.org> Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
1 parent 560dd5e commit 3377962

File tree

2 files changed

+31
-35
lines changed

2 files changed

+31
-35
lines changed

python/pyspark/ml/tests/connect/test_parity_feature.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,6 @@ def test_idf(self):
3434
def test_ngram(self):
3535
super().test_ngram()
3636

37-
@unittest.skip("Need to support.")
38-
def test_stopwordsremover(self):
39-
super().test_stopwordsremover()
40-
4137
@unittest.skip("Need to support.")
4238
def test_count_vectorizer_with_binary(self):
4339
super().test_count_vectorizer_with_binary()

python/pyspark/ml/tests/test_feature.py

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,37 @@ def test_stop_words_remover(self):
517517
remover2 = StopWordsRemover.load(d)
518518
self.assertEqual(str(remover), str(remover2))
519519

520+
def test_stop_words_remover_II(self):
521+
dataset = self.spark.createDataFrame([Row(input=["a", "panda"])])
522+
stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output")
523+
# Default
524+
self.assertEqual(stopWordRemover.getInputCol(), "input")
525+
transformedDF = stopWordRemover.transform(dataset)
526+
self.assertEqual(transformedDF.head().output, ["panda"])
527+
self.assertEqual(type(stopWordRemover.getStopWords()), list)
528+
self.assertTrue(isinstance(stopWordRemover.getStopWords()[0], str))
529+
# Custom
530+
stopwords = ["panda"]
531+
stopWordRemover.setStopWords(stopwords)
532+
self.assertEqual(stopWordRemover.getInputCol(), "input")
533+
self.assertEqual(stopWordRemover.getStopWords(), stopwords)
534+
transformedDF = stopWordRemover.transform(dataset)
535+
self.assertEqual(transformedDF.head().output, ["a"])
536+
# with language selection
537+
stopwords = StopWordsRemover.loadDefaultStopWords("turkish")
538+
dataset = self.spark.createDataFrame([Row(input=["acaba", "ama", "biri"])])
539+
stopWordRemover.setStopWords(stopwords)
540+
self.assertEqual(stopWordRemover.getStopWords(), stopwords)
541+
transformedDF = stopWordRemover.transform(dataset)
542+
self.assertEqual(transformedDF.head().output, [])
543+
# with locale
544+
stopwords = ["BELKİ"]
545+
dataset = self.spark.createDataFrame([Row(input=["belki"])])
546+
stopWordRemover.setStopWords(stopwords).setLocale("tr")
547+
self.assertEqual(stopWordRemover.getStopWords(), stopwords)
548+
transformedDF = stopWordRemover.transform(dataset)
549+
self.assertEqual(transformedDF.head().output, [])
550+
520551
def test_binarizer(self):
521552
b0 = Binarizer()
522553
self.assertListEqual(
@@ -570,37 +601,6 @@ def test_ngram(self):
570601
transformedDF = ngram0.transform(dataset)
571602
self.assertEqual(transformedDF.head().output, ["a b c d", "b c d e"])
572603

573-
def test_stopwordsremover(self):
574-
dataset = self.spark.createDataFrame([Row(input=["a", "panda"])])
575-
stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output")
576-
# Default
577-
self.assertEqual(stopWordRemover.getInputCol(), "input")
578-
transformedDF = stopWordRemover.transform(dataset)
579-
self.assertEqual(transformedDF.head().output, ["panda"])
580-
self.assertEqual(type(stopWordRemover.getStopWords()), list)
581-
self.assertTrue(isinstance(stopWordRemover.getStopWords()[0], str))
582-
# Custom
583-
stopwords = ["panda"]
584-
stopWordRemover.setStopWords(stopwords)
585-
self.assertEqual(stopWordRemover.getInputCol(), "input")
586-
self.assertEqual(stopWordRemover.getStopWords(), stopwords)
587-
transformedDF = stopWordRemover.transform(dataset)
588-
self.assertEqual(transformedDF.head().output, ["a"])
589-
# with language selection
590-
stopwords = StopWordsRemover.loadDefaultStopWords("turkish")
591-
dataset = self.spark.createDataFrame([Row(input=["acaba", "ama", "biri"])])
592-
stopWordRemover.setStopWords(stopwords)
593-
self.assertEqual(stopWordRemover.getStopWords(), stopwords)
594-
transformedDF = stopWordRemover.transform(dataset)
595-
self.assertEqual(transformedDF.head().output, [])
596-
# with locale
597-
stopwords = ["BELKİ"]
598-
dataset = self.spark.createDataFrame([Row(input=["belki"])])
599-
stopWordRemover.setStopWords(stopwords).setLocale("tr")
600-
self.assertEqual(stopWordRemover.getStopWords(), stopwords)
601-
transformedDF = stopWordRemover.transform(dataset)
602-
self.assertEqual(transformedDF.head().output, [])
603-
604604
def test_count_vectorizer_with_binary(self):
605605
dataset = self.spark.createDataFrame(
606606
[

0 commit comments

Comments
 (0)