Skip to content

refactor: remove __future__ annotations and update type hints #17

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ repos:
- id: mixed-line-ending
- id: trailing-whitespace
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: "v0.9.4"
rev: "v0.9.10"
hooks:
- id: ruff
args: ["--fix"]
Expand Down
9 changes: 9 additions & 0 deletions base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from sqlspec.adapters.duckdb.config import DuckDBConfig
from sqlspec.base import ConfigManager

dbs = ConfigManager()

config = DuckDBConfig(database="test.duckdb", extensions=[{"name": "vss"}])
etl_db = dbs.add_config(config)

connection = dbs.get_connection(etl_db)
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ flask = ["flask"]
litestar = ["litestar"]
msgspec = ["msgspec"]
oracledb = ["oracledb"]
orjson = ["orjson"]
performance = ["sqlglot[rs]"]
psycopg = ["psycopg[binary,pool]"]
pydantic = ["pydantic"]
pydantic = ["pydantic", "pydantic-extra-types"]
pymssql = ["pymssql"]
pymysql = ["pymysql"]
spanner = ["google-cloud-spanner"]
Expand Down Expand Up @@ -211,6 +212,7 @@ lint.select = [
"UP", # pyupgrade
"W", # pycodestyle - warning
"YTT", # flake8-2020

]

line-length = 120
Expand All @@ -232,6 +234,8 @@ lint.ignore = [
"PLW2901", # pylint - for loop variable overwritten by assignment target
"RUF012", # Ruff-specific rule - annotated with classvar
"ISC001", # Ruff formatter incompatible
"A005", # flake8 - Module `x` shadows a Python standard-library module
"PLC0415", # pylint - `import` should be at the top of the file
]
src = ["sqlspec", "tests", "docs/examples"]
target-version = "py39"
Expand Down
1 change: 0 additions & 1 deletion sqlspec/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
from __future__ import annotations
2 changes: 0 additions & 2 deletions sqlspec/__metadata__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Metadata for the Project."""

from __future__ import annotations

from importlib.metadata import PackageNotFoundError, metadata, version

__all__ = ("__project__", "__version__")
Expand Down
61 changes: 53 additions & 8 deletions sqlspec/_serialization.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,69 @@
import datetime
import enum
from typing import Any

__all__ = ("decode_json", "encode_json")
from sqlspec._typing import PYDANTIC_INSTALLED, BaseModel


def _type_to_string(value: Any) -> str: # pragma: no cover
if isinstance(value, datetime.datetime):
return convert_datetime_to_gmt_iso(value)
if isinstance(value, datetime.date):
return convert_date_to_iso(value)
if isinstance(value, enum.Enum):
return str(value.value)
if PYDANTIC_INSTALLED and isinstance(value, BaseModel):
return value.model_dump_json()
try:
val = str(value)
except Exception as exc:
raise TypeError from exc
return val


try:
from msgspec.json import Decoder, Encoder # pyright: ignore[reportMissingImports]
from msgspec.json import Decoder, Encoder

encoder, decoder = Encoder(), Decoder()
encoder, decoder = Encoder(enc_hook=_type_to_string), Decoder()
decode_json = decoder.decode

def encode_json(data: Any) -> str:
def encode_json(data: Any) -> str: # pragma: no cover
return encoder.encode(data).decode("utf-8")

except ImportError:
try:
from orjson import dumps as _encode_json # pyright: ignore[reportMissingImports,reportUnknownVariableType]
from orjson import loads as decode_json # type: ignore[no-redef]
from orjson import ( # pyright: ignore[reportMissingImports]
OPT_NAIVE_UTC, # pyright: ignore[reportUnknownVariableType]
OPT_SERIALIZE_NUMPY, # pyright: ignore[reportUnknownVariableType]
OPT_SERIALIZE_UUID, # pyright: ignore[reportUnknownVariableType]
)
from orjson import dumps as _encode_json # pyright: ignore[reportUnknownVariableType,reportMissingImports]
from orjson import loads as decode_json # type: ignore[no-redef,assignment,unused-ignore]

def encode_json(data: Any) -> str:
return _encode_json(data).decode("utf-8") # type: ignore[no-any-return]
def encode_json(data: Any) -> str: # pragma: no cover
return _encode_json(
data, default=_type_to_string, option=OPT_SERIALIZE_NUMPY | OPT_NAIVE_UTC | OPT_SERIALIZE_UUID
).decode("utf-8")

except ImportError:
from json import dumps as encode_json # type: ignore[assignment]
from json import loads as decode_json # type: ignore[assignment]

__all__ = (
"convert_date_to_iso",
"convert_datetime_to_gmt_iso",
"decode_json",
"encode_json",
)


def convert_datetime_to_gmt_iso(dt: datetime.datetime) -> str: # pragma: no cover
"""Handle datetime serialization for nested timestamps."""
if not dt.tzinfo:
dt = dt.replace(tzinfo=datetime.timezone.utc)
return dt.isoformat().replace("+00:00", "Z")


def convert_date_to_iso(dt: datetime.date) -> str: # pragma: no cover
"""Handle datetime serialization for nested timestamps."""
return dt.isoformat()
89 changes: 69 additions & 20 deletions sqlspec/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
This is used to ensure compatibility when one or more of the libraries are installed.
"""

from __future__ import annotations

from enum import Enum
from typing import (
Any,
Expand All @@ -30,21 +28,58 @@ class DataclassProtocol(Protocol):
T_co = TypeVar("T_co", covariant=True)

try:
from pydantic import BaseModel, FailFast, TypeAdapter
from pydantic import (
BaseModel,
FailFast, # pyright: ignore[reportGeneralTypeIssues,reportAssignmentType]
TypeAdapter,
)

PYDANTIC_INSTALLED = True
except ImportError:
from dataclasses import dataclass

@runtime_checkable
class BaseModel(Protocol): # type: ignore[no-redef]
"""Placeholder Implementation"""

model_fields: ClassVar[dict[str, Any]]

def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
def model_dump(
self,
/,
*,
include: "Optional[Any]" = None,
exclude: "Optional[Any]" = None,
context: "Optional[Any]" = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: "Union[bool, Literal['none', 'warn', 'error']]" = True,
serialize_as_any: bool = False,
) -> "dict[str, Any]":
"""Placeholder"""
return {}

def model_dump_json(
self,
/,
*,
include: "Optional[Any]" = None,
exclude: "Optional[Any]" = None,
context: "Optional[Any]" = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: "Union[bool, Literal['none', 'warn', 'error']]" = True,
serialize_as_any: bool = False,
) -> str:
"""Placeholder"""
return ""

@runtime_checkable
class TypeAdapter(Protocol[T_co]): # type: ignore[no-redef]
"""Placeholder Implementation"""
Expand All @@ -53,9 +88,9 @@ def __init__(
self,
type: Any, # noqa: A002
*,
config: Any | None = None,
config: "Optional[Any]" = None,
_parent_depth: int = 2,
module: str | None = None,
module: "Optional[str]" = None,
) -> None:
"""Init"""

Expand All @@ -64,42 +99,56 @@ def validate_python(
object: Any, # noqa: A002
/,
*,
strict: bool | None = None,
from_attributes: bool | None = None,
context: dict[str, Any] | None = None,
) -> T_co:
strict: "Optional[bool]" = None,
from_attributes: "Optional[bool]" = None,
context: "Optional[dict[str, Any]]" = None,
experimental_allow_partial: "Union[bool, Literal['off', 'on', 'trailing-strings']]" = False,
) -> "T_co":
"""Stub"""
return cast("T_co", object)

@runtime_checkable
class FailFast(Protocol): # type: ignore[no-redef]
@dataclass
class FailFast: # type: ignore[no-redef]
"""Placeholder Implementation for FailFast"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Init"""
fail_fast: bool = True

PYDANTIC_INSTALLED = False # pyright: ignore[reportConstantRedefinition]

try:
from msgspec import (
UNSET,
Struct,
UnsetType, # pyright: ignore[reportAssignmentType]
UnsetType, # pyright: ignore[reportAssignmentType,reportGeneralTypeIssues]
convert,
)

MSGSPEC_INSTALLED: bool = True
except ImportError:
import enum
from collections.abc import Iterable
from typing import TYPE_CHECKING, Callable, Optional, Union

if TYPE_CHECKING:
from collections.abc import Iterable

@dataclass_transform()
@runtime_checkable
class Struct(Protocol): # type: ignore[no-redef]
"""Placeholder Implementation"""

__struct_fields__: ClassVar[tuple[str, ...]]

def convert(*args: Any, **kwargs: Any) -> Any: # type: ignore[no-redef]
__struct_fields__: "ClassVar[tuple[str, ...]]"

def convert( # type: ignore[no-redef]
obj: Any,
type: "Union[Any, type[T]]", # noqa: A002
*,
strict: bool = True,
from_attributes: bool = False,
dec_hook: "Optional[Callable[[type, Any], Any]]" = None,
builtin_types: "Union[Iterable[type], None]" = None,
str_keys: bool = False,
) -> "Union[T, Any]":
"""Placeholder implementation"""
return {}

Expand All @@ -124,11 +173,11 @@ class DTOData(Protocol[T]): # type: ignore[no-redef]
def __init__(self, backend: Any, data_as_builtins: Any) -> None:
"""Placeholder init"""

def create_instance(self, **kwargs: Any) -> T:
def create_instance(self, **kwargs: Any) -> "T":
"""Placeholder implementation"""
return cast("T", kwargs)

def update_instance(self, instance: T, **kwargs: Any) -> T:
def update_instance(self, instance: "T", **kwargs: Any) -> "T":
"""Placeholder implementation"""
return cast("T", kwargs)

Expand Down
14 changes: 6 additions & 8 deletions sqlspec/adapters/adbc/config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from __future__ import annotations

from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Optional, Union

from sqlspec.base import NoPoolSyncConfig
from sqlspec.typing import Empty, EmptyType
Expand All @@ -27,15 +25,15 @@ class AdbcDatabaseConfig(NoPoolSyncConfig["Connection"]):
__supports_connection_pooling = False
__is_async = False

uri: str | EmptyType = Empty
uri: "Union[str, EmptyType]" = Empty
"""Database URI"""
driver_name: str | EmptyType = Empty
driver_name: "Union[str, EmptyType]" = Empty
"""Name of the ADBC driver to use"""
db_kwargs: dict[str, Any] | None = None
db_kwargs: "Optional[dict[str, Any]]" = None
"""Additional database-specific connection parameters"""

@property
def connection_params(self) -> dict[str, Any]:
def connection_params(self) -> "dict[str, Any]":
"""Return the connection parameters as a dict."""
return {
k: v
Expand All @@ -44,7 +42,7 @@ def connection_params(self) -> dict[str, Any]:
}

@contextmanager
def provide_connection(self, *args: Any, **kwargs: Any) -> Generator[Connection, None, None]:
def provide_connection(self, *args: "Any", **kwargs: "Any") -> "Generator[Connection, None, None]":
"""Create and provide a database connection."""
from adbc_driver_manager.dbapi import connect

Expand Down
Loading