@@ -517,6 +517,37 @@ def test_stop_words_remover(self):
         remover2 = StopWordsRemover.load(d)
         self.assertEqual(str(remover), str(remover2))
 
+    def test_stop_words_remover_II(self):
+        dataset = self.spark.createDataFrame([Row(input=["a", "panda"])])
+        stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output")
+        # Default
+        self.assertEqual(stopWordRemover.getInputCol(), "input")
+        transformedDF = stopWordRemover.transform(dataset)
+        self.assertEqual(transformedDF.head().output, ["panda"])
+        self.assertEqual(type(stopWordRemover.getStopWords()), list)
+        self.assertTrue(isinstance(stopWordRemover.getStopWords()[0], str))
+        # Custom
+        stopwords = ["panda"]
+        stopWordRemover.setStopWords(stopwords)
+        self.assertEqual(stopWordRemover.getInputCol(), "input")
+        self.assertEqual(stopWordRemover.getStopWords(), stopwords)
+        transformedDF = stopWordRemover.transform(dataset)
+        self.assertEqual(transformedDF.head().output, ["a"])
+        # with language selection
+        stopwords = StopWordsRemover.loadDefaultStopWords("turkish")
+        dataset = self.spark.createDataFrame([Row(input=["acaba", "ama", "biri"])])
+        stopWordRemover.setStopWords(stopwords)
+        self.assertEqual(stopWordRemover.getStopWords(), stopwords)
+        transformedDF = stopWordRemover.transform(dataset)
+        self.assertEqual(transformedDF.head().output, [])
+        # with locale
+        stopwords = ["BELKİ"]
+        dataset = self.spark.createDataFrame([Row(input=["belki"])])
+        stopWordRemover.setStopWords(stopwords).setLocale("tr")
+        self.assertEqual(stopWordRemover.getStopWords(), stopwords)
+        transformedDF = stopWordRemover.transform(dataset)
+        self.assertEqual(transformedDF.head().output, [])
+
     def test_binarizer(self):
         b0 = Binarizer()
         self.assertListEqual(
@@ -570,37 +601,6 @@ def test_ngram(self):
         transformedDF = ngram0.transform(dataset)
         self.assertEqual(transformedDF.head().output, ["a b c d", "b c d e"])
 
-    def test_stopwordsremover(self):
-        dataset = self.spark.createDataFrame([Row(input=["a", "panda"])])
-        stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output")
-        # Default
-        self.assertEqual(stopWordRemover.getInputCol(), "input")
-        transformedDF = stopWordRemover.transform(dataset)
-        self.assertEqual(transformedDF.head().output, ["panda"])
-        self.assertEqual(type(stopWordRemover.getStopWords()), list)
-        self.assertTrue(isinstance(stopWordRemover.getStopWords()[0], str))
-        # Custom
-        stopwords = ["panda"]
-        stopWordRemover.setStopWords(stopwords)
-        self.assertEqual(stopWordRemover.getInputCol(), "input")
-        self.assertEqual(stopWordRemover.getStopWords(), stopwords)
-        transformedDF = stopWordRemover.transform(dataset)
-        self.assertEqual(transformedDF.head().output, ["a"])
-        # with language selection
-        stopwords = StopWordsRemover.loadDefaultStopWords("turkish")
-        dataset = self.spark.createDataFrame([Row(input=["acaba", "ama", "biri"])])
-        stopWordRemover.setStopWords(stopwords)
-        self.assertEqual(stopWordRemover.getStopWords(), stopwords)
-        transformedDF = stopWordRemover.transform(dataset)
-        self.assertEqual(transformedDF.head().output, [])
-        # with locale
-        stopwords = ["BELKİ"]
-        dataset = self.spark.createDataFrame([Row(input=["belki"])])
-        stopWordRemover.setStopWords(stopwords).setLocale("tr")
-        self.assertEqual(stopWordRemover.getStopWords(), stopwords)
-        transformedDF = stopWordRemover.transform(dataset)
-        self.assertEqual(transformedDF.head().output, [])
-
     def test_count_vectorizer_with_binary(self):
         dataset = self.spark.createDataFrame(
             [