33from __future__ import annotations
44
55import itertools
6- from functools import cached_property
6+ import re
7+ from collections .abc import Callable
8+ from functools import cached_property , lru_cache
79from typing import TYPE_CHECKING , cast
810
911import numpy as np
5254 from cudf .core .dtypes import DecimalDtype
5355
5456
# Python ``re`` flags that can be forwarded to libcudf, keyed by the ``re``
# flag and mapped to the pylibcudf ``RegexFlags`` member of the same name.
# Only MULTILINE and DOTALL are listed; any other flag has no entry here.
_FLAG_MAP = {
    re.MULTILINE: plc.strings.regex_flags.RegexFlags.MULTILINE,
    re.DOTALL: plc.strings.regex_flags.RegexFlags.DOTALL,
}
63+
64+
@lru_cache
def plc_flags_from_re_flags(
    flags: re.RegexFlag,
) -> plc.strings.regex_flags.RegexFlags:
    """
    Translate Python ``re`` flags into pylibcudf ``RegexFlags``.

    Parameters
    ----------
    flags : re.RegexFlag
        Bitwise combination of ``re`` flags to translate.

    Returns
    -------
    plc.strings.regex_flags.RegexFlags
        The union of the pylibcudf flags that ``_FLAG_MAP`` associates
        with the bits set in ``flags``; ``RegexFlags(0)`` when none are set.

    Raises
    ------
    ValueError
        If ``flags`` contains any bit that has no entry in ``_FLAG_MAP``.
    """
    translated = plc.strings.regex_flags.RegexFlags(0)
    # Work on a local copy and clear each bit as it is translated, so that
    # whatever is left over afterwards is exactly the unsupported part.
    leftover = flags
    for py_flag, libcudf_flag in _FLAG_MAP.items():
        if not leftover & py_flag:
            continue
        translated |= libcudf_flag
        leftover &= ~py_flag
    if leftover:
        raise ValueError(f"Unsupported re flags: {leftover}")
    return translated
78+
79+
5580class StringColumn (ColumnBase ):
5681 """
5782 Implements operations for Columns of String type
@@ -323,7 +348,9 @@ def as_numerical_column(self, dtype: np.dtype) -> NumericalColumn:
323348 if not is_pandas_nullable_extension_dtype (dtype ):
324349 result = result .fillna (False )
325350 return result ._with_type_metadata (dtype ) # type: ignore[return-value]
326- elif dtype .kind in {"i" , "u" }:
351+
352+ cast_func : Callable [[plc .Column , plc .DataType ], plc .Column ]
353+ if dtype .kind in {"i" , "u" }:
327354 if not self .is_integer ().all ():
328355 raise ValueError (
329356 "Could not convert strings to integer "
@@ -362,7 +389,9 @@ def strptime(
362389 raise ValueError (
363390 "Cannot convert `None` value to datetime or timedelta."
364391 )
365- elif dtype .kind == "M" : # type: ignore[union-attr]
392+
393+ casting_func : Callable [[plc .Column , plc .DataType , str ], plc .Column ]
394+ if dtype .kind == "M" : # type: ignore[union-attr]
366395 if format .endswith ("%z" ):
367396 raise NotImplementedError (
368397 "cuDF does not yet support timezone-aware datetimes"
@@ -587,10 +616,10 @@ def _binaryop(self, other: ColumnBinaryOperand, op: str) -> ColumnBase:
587616 }:
588617 if isinstance (other , pa .Scalar ):
589618 other = pa_scalar_to_plc_scalar (other )
590- lhs , rhs = (other , self ) if reflect else (self , other )
619+ lhs_op , rhs_op = (other , self ) if reflect else (self , other )
591620 return binaryop .binaryop (
592- lhs = lhs ,
593- rhs = rhs ,
621+ lhs = lhs_op ,
622+ rhs = rhs_op ,
594623 op = op ,
595624 dtype = get_dtype_of_same_kind (
596625 self .dtype , np .dtype (np .bool_ )
@@ -1062,7 +1091,7 @@ def _split(
10621091 self ,
10631092 delimiter : plc .Scalar ,
10641093 maxsplit : int ,
1065- method : Callable [[plc .Column , plc .Scalar , int ], plc .Column ],
1094+ method : Callable [[plc .Column , plc .Scalar , int ], plc .Table ],
10661095 ) -> dict [int , Self ]:
10671096 plc_table = method (
10681097 self .to_pylibcudf (mode = "read" ),
@@ -1086,7 +1115,7 @@ def rsplit(self, delimiter: plc.Scalar, maxsplit: int) -> dict[int, Self]:
10861115 def _partition (
10871116 self ,
10881117 delimiter : plc .Scalar ,
1089- method : Callable [[plc .Column , plc .Scalar ], plc .Column ],
1118+ method : Callable [[plc .Column , plc .Scalar ], plc .Table ],
10901119 ) -> dict [int , Self ]:
10911120 plc_table = method (
10921121 self .to_pylibcudf (mode = "read" ),
@@ -1180,7 +1209,10 @@ def concatenate(
11801209 def extract (self , pattern : str , flags : int ) -> dict [int , Self ]:
11811210 plc_table = plc .strings .extract .extract (
11821211 self .to_pylibcudf (mode = "read" ),
1183- plc .strings .regex_program .RegexProgram .create (pattern , flags ),
1212+ plc .strings .regex_program .RegexProgram .create (
1213+ pattern ,
1214+ plc_flags_from_re_flags (flags ),
1215+ ),
11841216 )
11851217 return dict (
11861218 enumerate (
@@ -1192,7 +1224,10 @@ def extract(self, pattern: str, flags: int) -> dict[int, Self]:
def contains_re(self, pattern: str, flags: int) -> Self:
    """
    Apply ``plc.strings.contains.contains_re`` to this column.

    ``flags`` are Python ``re`` flags; they are translated with
    ``plc_flags_from_re_flags`` (which raises ``ValueError`` for flags it
    cannot translate) before the regex program is created from ``pattern``.
    The pylibcudf result is wrapped back into a column of this type.
    """
    # Materialize the read-only pylibcudf view first, then build the program.
    source = self.to_pylibcudf(mode="read")
    program = plc.strings.regex_program.RegexProgram.create(
        pattern,
        plc_flags_from_re_flags(flags),
    )
    plc_column = plc.strings.contains.contains_re(source, program)
    return type(self).from_pylibcudf(plc_column)  # type: ignore[return-value]
11981233
@@ -1400,7 +1435,9 @@ def wrap(self, width: int) -> Self:
def count_re(self, pattern: str, flags: int) -> NumericalColumn:
    """
    Apply ``plc.strings.contains.count_re`` to this column.

    ``flags`` are Python ``re`` flags; they are translated with
    ``plc_flags_from_re_flags`` (which raises ``ValueError`` for flags it
    cannot translate) before the regex program is created from ``pattern``.
    The pylibcudf result is wrapped via ``from_pylibcudf``.
    """
    # Materialize the read-only pylibcudf view first, then build the program.
    source = self.to_pylibcudf(mode="read")
    program = plc.strings.regex_program.RegexProgram.create(
        pattern, plc_flags_from_re_flags(flags)
    )
    plc_result = plc.strings.contains.count_re(source, program)
    return type(self).from_pylibcudf(plc_result)  # type: ignore[return-value]
14061443
@@ -1415,7 +1452,9 @@ def findall(
14151452 ) -> Self :
14161453 plc_result = method (
14171454 self .to_pylibcudf (mode = "read" ),
1418- plc .strings .regex_program .RegexProgram .create (pat , flags ),
1455+ plc .strings .regex_program .RegexProgram .create (
1456+ pat , plc_flags_from_re_flags (flags )
1457+ ),
14191458 )
14201459 return type (self ).from_pylibcudf (plc_result ) # type: ignore[return-value]
14211460
@@ -1464,7 +1503,9 @@ def find(
def matches_re(self, pattern: str, flags: int) -> Self:
    """
    Apply ``plc.strings.contains.matches_re`` to this column.

    ``flags`` are Python ``re`` flags; they are translated with
    ``plc_flags_from_re_flags`` (which raises ``ValueError`` for flags it
    cannot translate) before the regex program is created from ``pattern``.
    The pylibcudf result is wrapped back into a column of this type.
    """
    # Materialize the read-only pylibcudf view first, then build the program.
    source = self.to_pylibcudf(mode="read")
    program = plc.strings.regex_program.RegexProgram.create(
        pattern, plc_flags_from_re_flags(flags)
    )
    plc_result = plc.strings.contains.matches_re(source, program)
    return type(self).from_pylibcudf(plc_result)  # type: ignore[return-value]
14701511
0 commit comments