@@ -234,16 +234,19 @@ def check_superset_subset(subset=None, superset=None, name_subset=None, name_sup
234
234
235
235
# df checking functions
236
236
def check_df (name = "df" , df = None , accept_none = False , accept_nan = True , check_all_positive = False ,
237
- cols_requiered = None , cols_forbidden = None , cols_nan_check = None , str_add = None ):
237
+ check_series = False , cols_requiered = None , cols_forbidden = None , cols_nan_check = None ,
238
+ str_add = None ):
238
239
"""Check if the provided DataFrame meets various criteria such as NaN values, required/forbidden columns, etc."""
239
240
# Check DataFrame and values
240
241
if df is None :
241
242
if not accept_none :
242
243
raise ValueError (f"'{ name } ' should not be None" )
243
244
else :
244
245
return None
245
- if not isinstance (df , pd .DataFrame ):
246
- str_error = add_str (str_error = f"'{ name } ' ({ type (df )} ) should be DataFrame" ,
246
+ _check_dtype = pd .DataFrame if not check_series else pd .Series
247
+ _str_check_dtype = "DataFrame" if not check_series else "Series"
248
+ if not isinstance (df , _check_dtype ):
249
+ str_error = add_str (str_error = f"'{ name } ' ({ type (df )} ) should be { _str_check_dtype } " ,
247
250
str_add = str_add )
248
251
raise ValueError (str_error )
249
252
if not accept_nan and df .isna ().any ().any ():
@@ -278,11 +281,12 @@ def check_df(name="df", df=None, accept_none=False, accept_nan=True, check_all_p
278
281
str_error = add_str (str_error = f"NaN values are not allowed in '{ cols_nan_check } '." ,
279
282
str_add = str_add )
280
283
raise ValueError (str_error )
281
- columns = list (df )
282
- cols_duplicated = [x for x in columns if columns .count (x ) > 1 ]
283
- if len (cols_duplicated ) > 0 :
284
- str_error = add_str (str_error = f"The following columns are duplicated '{ cols_duplicated } '." , str_add = str_add )
285
- raise ValueError (str_error )
284
+ if _check_dtype == pd .DataFrame :
285
+ columns = list (df )
286
+ cols_duplicated = [x for x in columns if columns .count (x ) > 1 ]
287
+ if len (cols_duplicated ) > 0 :
288
+ str_error = add_str (str_error = f"The following columns are duplicated '{ cols_duplicated } '." , str_add = str_add )
289
+ raise ValueError (str_error )
286
290
287
291
288
292
def check_warning_consecutive_index (name = "df" , df = None ):
0 commit comments