22import tempfile
33import uuid
44from pathlib import Path
5- from typing import Generator
5+ from typing import Generator , Optional
66
77import pytest
88from delta import configure_spark_with_delta_pip
@@ -24,7 +24,7 @@ def spark() -> Generator[SparkSession, None, None]:
2424 yield spark
2525 else :
2626 # If databricks-connect is not installed, we use use local Spark session
27- warehouse_dir = tempfile .TemporaryDirectory (). name
27+ warehouse_dir = tempfile .mkdtemp ()
2828 _builder = (
2929 SparkSession .builder .master ("local[*]" )
3030 .config ("spark.hive.metastore.warehouse.dir" , Path (warehouse_dir ).as_uri ())
@@ -47,7 +47,7 @@ def spark() -> Generator[SparkSession, None, None]:
4747
4848
4949@pytest .fixture (scope = "session" )
50- def catalog_name () -> str :
50+ def catalog_name () -> Optional [ str ] :
5151 """Fixture to provide the catalog name for tests.
5252
5353 In Databricks, we use the "lake_dev" catalog.
@@ -67,11 +67,15 @@ def create_schema(spark, catalog_name, request) -> Generator[str, None, None]:
6767 """
6868 module_name = request .module .__name__ .split ("." )[- 1 ] # Get just the module name without path
6969 schema_name = f"pytest_{ module_name } _{ uuid .uuid4 ().hex [:8 ]} "
70+
7071 if catalog_name is not None :
7172 full_schema_name = f"{ catalog_name } .{ schema_name } "
73+ else :
74+ full_schema_name = schema_name
75+
7276 spark .sql (f"CREATE SCHEMA IF NOT EXISTS { full_schema_name } " )
7377 yield schema_name
74- spark .sql (f"DROP SCHEMA { full_schema_name } CASCADE" )
78+ spark .sql (f"DROP SCHEMA IF EXISTS { full_schema_name } CASCADE" )
7579
7680
7781@pytest .fixture (scope = "function" )
0 commit comments