diff --git a/mssql/setup.py b/mssql/setup.py index c1fd74855..148b71a19 100644 --- a/mssql/setup.py +++ b/mssql/setup.py @@ -13,7 +13,9 @@ install_requires=[ "testcontainers-core", "sqlalchemy", + # TODO: convert these to extras "pymssql", + "pyodbc", ], python_requires=">=3.7", ) diff --git a/mssql/testcontainers/mssql/__init__.py b/mssql/testcontainers/mssql/__init__.py index 678d12891..bfdf379b2 100644 --- a/mssql/testcontainers/mssql/__init__.py +++ b/mssql/testcontainers/mssql/__init__.py @@ -1,5 +1,5 @@ from os import environ -from typing import Optional +from typing import Optional, Literal from testcontainers.core.generic import DbContainer @@ -19,9 +19,31 @@ class SqlServerContainer(DbContainer): ... result = e.execute("select @@VERSION") """ - def __init__(self, image: str = "mcr.microsoft.com/mssql/server:2019-latest", user: str = "SA", - password: Optional[str] = None, port: int = 1433, dbname: str = "tempdb", - dialect: str = 'mssql+pymssql', **kwargs) -> None: + def __init__( + self, + image: str = "mcr.microsoft.com/mssql/server:2019-latest", + user: str = "SA", + password: Optional[str] = None, + port: int = 1433, + dbname: str = "tempdb", + dialect: Literal["mssql+pymssql", "mssql+pyodbc"] = "mssql+pymssql", + **kwargs, + ) -> None: + """ + Initialize SqlServerContainer + + Args: + image: MSSQL Server image. For example, use a specific version + user: DB user name + password: DB password + port: Port to be exposed + dbname: Database name + dialect: SQLAlchemy database dialect. Allowed values are + * 'mssql+pymssql': Uses `pymssql `_ driver + * 'mssql+pyodbc': Uses `pyodbc `_ driver + This also defines the driver that is used to connect to the database. + kwargs: Keyword arguments passed to initialization of underlying docker container + """ super(SqlServerContainer, self).__init__(image, **kwargs) self.port_to_expose = port @@ -36,10 +58,32 @@ def _configure(self) -> None: self.with_env("SA_PASSWORD", self.SQLSERVER_PASSWORD) self.with_env("SQLSERVER_USER", self.SQLSERVER_USER) self.with_env("SQLSERVER_DBNAME", self.SQLSERVER_DBNAME) - self.with_env("ACCEPT_EULA", 'Y') + self.with_env("ACCEPT_EULA", "Y") def get_connection_url(self) -> str: - return super()._create_connection_url( - dialect=self.dialect, username=self.SQLSERVER_USER, password=self.SQLSERVER_PASSWORD, - db_name=self.SQLSERVER_DBNAME, port=self.port_to_expose + url = super()._create_connection_url( + dialect=self.dialect, + username=self.SQLSERVER_USER, + password=self.SQLSERVER_PASSWORD, + db_name=self.SQLSERVER_DBNAME, + port=self.port_to_expose, ) + if self.dialect == "mssql+pyodbc": + url += f"?driver={self._get_url_suffix_for_latest_pyodbc_version()}" + return url + + def _get_url_suffix_for_latest_pyodbc_version(self): + import pyodbc + import re + + r = re.compile(r"ODBC Driver \d{1,2} for SQL Server") + # We sort drivers in reversed order to get the latest + drivers = sorted(list(filter(r.match, pyodbc.drivers())), reverse=True) + version_numbers = [int(v) for v in re.findall(r"\d{1,2}", "".join(drivers))] + max_version_index = version_numbers.index(max(version_numbers)) + if len(drivers) > 0: + driver_str = drivers[max_version_index].replace(" ", "+") + else: + raise ImportError(f"No driver available for using dialect {self.dialect}") + + return driver_str diff --git a/mssql/tests/test_mssql.py b/mssql/tests/test_mssql.py index 63b0a0135..aa67020cb 100644 --- a/mssql/tests/test_mssql.py +++ b/mssql/tests/test_mssql.py @@ -1,18 +1,37 @@ +import re + import sqlalchemy from testcontainers.mssql import SqlServerContainer +from unittest.mock import patch def test_docker_run_mssql(): - image = 'mcr.microsoft.com/azure-sql-edge' - dialect = 'mssql+pymssql' - with SqlServerContainer(image, dialect=dialect) as mssql: - e = sqlalchemy.create_engine(mssql.get_connection_url()) - result = e.execute('select @@servicename') - for row in result: - assert row[0] == 'MSSQLSERVER' + image = "mcr.microsoft.com/azure-sql-edge" + dialects = ["mssql+pymssql", "mssql+pyodbc"] + ends_withs = ["tempdb", "for+SQL+Server"] + for dialect, end_with in zip(dialects, ends_withs): + with SqlServerContainer(dialect=dialect) as mssql: + url = mssql.get_connection_url() + assert url.endswith(end_with) + e = sqlalchemy.create_engine(url) + result = e.execute("select @@servicename") + for row in result: + assert row[0] == "MSSQLSERVER" with SqlServerContainer(image, password="1Secure*Password2", dialect=dialect) as mssql: e = sqlalchemy.create_engine(mssql.get_connection_url()) - result = e.execute('select @@servicename') + result = e.execute("select @@servicename") for row in result: - assert row[0] == 'MSSQLSERVER' + assert row[0] == "MSSQLSERVER" + + +def test_get_url_suffix_for_latest_pyodbc_version(): + container = SqlServerContainer() + + version_numbers = [10, 8] + with patch( + "pyodbc.drivers", return_value=[f"ODBC Driver {v} for SQL Server" for v in version_numbers] + ): + driver_str = container._get_url_suffix_for_latest_pyodbc_version() + latest_version = int(re.findall(r"\d{1,2}", driver_str)[0]) + assert latest_version == max(version_numbers)