Skip to content

Commit d6d9eff

Browse files
committed
feat(spanner): Google Spanner Driver
1 parent 7c98ba9 commit d6d9eff

File tree

7 files changed

+1278
-0
lines changed

7 files changed

+1278
-0
lines changed

sqlspec/adapters/spanner/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .config import SpannerConfig, SpannerPoolConfig
2+
from .driver import SpannerConnection, SpannerDriver
3+
4+
__all__ = ("SpannerConfig", "SpannerPoolConfig", "SpannerConnection", "SpannerDriver")
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from ._sync import SpannerConfig, SpannerPoolConfig
2+
3+
__all__ = ("SpannerConfig", "SpannerPoolConfig")
Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
import logging
2+
import threading
3+
from contextlib import contextmanager
4+
from dataclasses import dataclass, field
5+
from typing import TYPE_CHECKING, Any, Optional, Union
6+
7+
from google.cloud.spanner_v1 import Client
8+
from google.cloud.spanner_v1.database import Database
9+
from google.cloud.spanner_v1.pool import AbstractSessionPool, FixedSizePool, PingingPool, TransactionPingingPool
10+
from google.cloud.spanner_v1.snapshot import Snapshot
11+
from google.cloud.spanner_v1.transaction import Transaction
12+
13+
from sqlspec.adapters.spanner.driver import SpannerDriver
14+
from sqlspec.base import SyncDatabaseConfig
15+
from sqlspec.exceptions import ImproperConfigurationError
16+
from sqlspec.typing import dataclass_to_dict
17+
18+
if TYPE_CHECKING:
19+
from collections.abc import Generator
20+
21+
from google.auth.credentials import Credentials
22+
23+
# Define the Connection Type alias
24+
SpannerSyncConnection = Union[Snapshot, Transaction]
25+
26+
# Get logger instance
27+
logger = logging.getLogger("sqlspec")
28+
29+
__all__ = ("SpannerConfig", "SpannerPoolConfig")
30+
31+
32+
@dataclass
33+
class SpannerPoolConfig:
34+
"""Configuration for the Spanner session pool.
35+
36+
Ref: https://cloud.google.com/python/docs/reference/spanner/latest/advanced-session-pool-topics
37+
"""
38+
39+
pool_type: type[AbstractSessionPool] = FixedSizePool
40+
"""The type of session pool to use. Defaults to FixedSizePool."""
41+
min_sessions: int = 1
42+
"""The minimum number of sessions to keep in the pool."""
43+
max_sessions: int = 10
44+
"""The maximum number of sessions allowed in the pool."""
45+
labels: Optional[dict[str, str]] = None
46+
"""Labels to apply to sessions created by the pool."""
47+
ping_interval: int = 300 # Default 5 minutes
48+
"""Interval (in seconds) for pinging sessions in PingingPool/TransactionPingingPool."""
49+
# Add other pool-specific configs as needed, e.g., ping_interval for PingingPool
50+
51+
52+
@dataclass
53+
class SpannerConfig(
54+
SyncDatabaseConfig[SpannerSyncConnection, AbstractSessionPool, SpannerDriver]
55+
): # Replace Any with actual Connection/Driver types later
56+
"""Synchronous Google Cloud Spanner database Configuration.
57+
58+
This class provides the configuration for Spanner database connections.
59+
"""
60+
61+
project: Optional[str] = None
62+
"""Google Cloud project ID."""
63+
instance_id: Optional[str] = None
64+
"""Spanner instance ID."""
65+
database_id: Optional[str] = None
66+
"""Spanner database ID."""
67+
credentials: Optional["Credentials"] = None
68+
"""Optional Google Cloud credentials. If None, uses Application Default Credentials."""
69+
client_options: Optional[dict[str, Any]] = None
70+
"""Optional dictionary of client options for the Spanner client."""
71+
pool_config: Optional[SpannerPoolConfig] = field(default_factory=SpannerPoolConfig)
72+
"""Spanner session pool configuration."""
73+
pool_instance: Optional[AbstractSessionPool] = None
74+
"""Optional pre-configured pool instance to use."""
75+
76+
# Define actual types
77+
connection_type: "type[SpannerSyncConnection]" = field(init=False, default=Union[Snapshot, Transaction]) # type: ignore
78+
driver_type: "type[SpannerDriver]" = field(init=False, default=SpannerDriver)
79+
80+
_client: Optional[Client] = field(init=False, default=None, repr=False, hash=False)
81+
_database: Optional[Database] = field(init=False, default=None, repr=False, hash=False)
82+
_ping_thread: "Optional[threading.Thread]" = field(init=False, default=None, repr=False, hash=False)
83+
84+
def __post_init__(self) -> None:
85+
# Basic check, more robust checks might be needed later
86+
if self.pool_instance and not self.pool_config:
87+
# If a pool instance is provided, we might not need pool_config
88+
pass
89+
elif not self.pool_config:
90+
# Create default if not provided and pool_instance is also None
91+
self.pool_config = SpannerPoolConfig()
92+
93+
@property
94+
def client(self) -> Client:
95+
"""Provides the Spanner Client, creating it if necessary."""
96+
if self._client is None:
97+
self._client = Client(
98+
project=self.project,
99+
credentials=self.credentials,
100+
client_options=self.client_options,
101+
)
102+
return self._client
103+
104+
@property
105+
def database(self) -> Database:
106+
"""Provides the Spanner Database instance, creating client, pool, and database if necessary.
107+
108+
This method ensures that the database instance is created and configured correctly.
109+
It also handles any additional configuration options that may be needed for the database.
110+
111+
Args:
112+
*args: Additional positional arguments to pass to the database constructor.
113+
**kwargs: Additional keyword arguments to pass to the database constructor.
114+
115+
Raises:
116+
ImproperConfigurationError: If project, instance, and database IDs are not configured.
117+
118+
Returns:
119+
The configured database instance.
120+
"""
121+
if self._database is None:
122+
if not self.project or not self.instance_id or not self.database_id:
123+
msg = "Project, instance, and database IDs must be configured."
124+
raise ImproperConfigurationError(msg)
125+
126+
# Ensure client exists
127+
spanner_client = self.client
128+
# Ensure pool exists (this will create it if needed)
129+
pool = self.provide_pool()
130+
131+
# Get instance object
132+
instance = spanner_client.instance(self.instance_id) # type: ignore[no-untyped-call]
133+
134+
# Create the final Database object using the created pool
135+
self._database = instance.database(database_id=self.database_id, pool=pool)
136+
return self._database
137+
138+
def provide_pool(self, *args: Any, **kwargs: Any) -> AbstractSessionPool:
139+
"""Provides the configured session pool, creating it if necessary .
140+
141+
This method ensures that the session pool is created and configured correctly.
142+
It also handles any additional configuration options that may be needed for the pool.
143+
144+
Args:
145+
*args: Additional positional arguments to pass to the pool constructor.
146+
**kwargs: Additional keyword arguments to pass to the pool constructor.
147+
148+
Raises:
149+
ImproperConfigurationError: If pool_config is not set or project, instance, and database IDs are not configured.
150+
151+
Returns:
152+
The configured session pool.
153+
"""
154+
if self.pool_instance:
155+
return self.pool_instance
156+
157+
if not self.pool_config:
158+
# This should be handled by __post_init__, but double-check
159+
msg = "pool_config must be set if pool_instance is not provided."
160+
raise ImproperConfigurationError(msg)
161+
162+
if not self.project or not self.instance_id or not self.database_id:
163+
msg = "Project, instance, and database IDs must be configured to create pool."
164+
raise ImproperConfigurationError(msg)
165+
166+
instance = self.client.instance(self.instance_id)
167+
168+
pool_kwargs = dataclass_to_dict(self.pool_config, exclude_empty=True, exclude={"pool_type"})
169+
170+
# Only include ping_interval if using a relevant pool type
171+
if not issubclass(self.pool_config.pool_type, (PingingPool, TransactionPingingPool)):
172+
pool_kwargs.pop("ping_interval", None)
173+
174+
self.pool_instance = self.pool_config.pool_type(
175+
database=Database(database_id=self.database_id, instance=instance), # pyright: ignore
176+
**pool_kwargs,
177+
)
178+
179+
# Start pinging thread if applicable and not already running
180+
if isinstance(self.pool_instance, (PingingPool, TransactionPingingPool)) and self._ping_thread is None:
181+
self._ping_thread = threading.Thread(
182+
target=self.pool_instance.ping,
183+
daemon=True, # Ensure thread exits with application
184+
name=f"spanner-ping-{self.project}-{self.instance_id}-{self.database_id}",
185+
)
186+
self._ping_thread.start()
187+
logger.debug("Started Spanner background ping thread for %s", self.pool_instance)
188+
189+
return self.pool_instance
190+
191+
@contextmanager
192+
def provide_connection(self, *args: Any, **kwargs: Any) -> "Generator[SpannerSyncConnection, None, None]":
193+
"""Provides a Spanner snapshot context (suitable for reads).
194+
195+
This method ensures that the connection is created and configured correctly.
196+
It also handles any additional configuration options that may be needed for the connection.
197+
198+
Args:
199+
*args: Additional positional arguments to pass to the connection constructor.
200+
**kwargs: Additional keyword arguments to pass to the connection constructor.
201+
202+
Yields:
203+
The configured connection.
204+
"""
205+
db = self.database # Ensure database and pool are initialized
206+
with db.snapshot() as snapshot:
207+
yield snapshot # Replace with actual connection object later
208+
209+
@contextmanager
210+
def provide_session(self, *args: Any, **kwargs: Any) -> "Generator[SpannerDriver, None, None]":
211+
"""Provides a driver instance initialized with a connection context (Snapshot).
212+
213+
This method ensures that the driver is created and configured correctly.
214+
It also handles any additional configuration options that may be needed for the driver.
215+
216+
Args:
217+
*args: Additional positional arguments to pass to the driver constructor.
218+
**kwargs: Additional keyword arguments to pass to the driver constructor.
219+
220+
Yields:
221+
The configured driver.
222+
"""
223+
with self.provide_connection(*args, **kwargs) as connection:
224+
yield self.driver_type(connection) # pyright: ignore
225+
226+
def close_pool(self) -> None:
227+
"""Clears internal references to the pool, database, and client."""
228+
# Spanner pool doesn't require explicit closing usually.
229+
self.pool_instance = None
230+
self._database = None
231+
self._client = None
232+
# Clear thread reference, but don't need to join (it's daemon)
233+
self._ping_thread = None
234+
235+
@property
236+
def connection_config_dict(self) -> "dict[str, Any]":
237+
"""Returns connection-related parameters."""
238+
config = {
239+
"project": self.project,
240+
"instance_id": self.instance_id,
241+
"database_id": self.database_id,
242+
"credentials": self.credentials,
243+
"client_options": self.client_options,
244+
}
245+
return {k: v for k, v in config.items() if v is not None}
246+
247+
@property
248+
def pool_config_dict(self) -> "dict[str, Any]":
249+
"""Returns pool configuration parameters.
250+
251+
This method ensures that the pool configuration is returned correctly.
252+
It also handles any additional configuration options that may be needed for the pool.
253+
254+
Args:
255+
*args: Additional positional arguments to pass to the pool constructor.
256+
**kwargs: Additional keyword arguments to pass to the pool constructor.
257+
258+
Raises:
259+
ImproperConfigurationError: If pool_config is not set or project, instance, and database IDs are not configured.
260+
261+
Returns:
262+
The pool configuration parameters.
263+
"""
264+
if self.pool_config:
265+
return dataclass_to_dict(self.pool_config, exclude_empty=True)
266+
# If pool_config was not initially provided but pool_instance was,
267+
# this method might be called unexpectedly. Add check.
268+
if self.pool_instance:
269+
# We can't reconstruct the config dict from the instance easily.
270+
msg = "Cannot retrieve pool_config_dict when initialized with pool_instance."
271+
raise ImproperConfigurationError(msg)
272+
# Should not be reachable if __post_init__ runs correctly
273+
msg = "pool_config is not set."
274+
raise ImproperConfigurationError(msg)

0 commit comments

Comments
 (0)