From 5bdf0baf86b0a69a51e7b7f75a3dd9e8e8e75e01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guido=20Ple=C3=9Fmann?= Date: Mon, 23 Jan 2023 08:19:58 +0100 Subject: [PATCH 1/7] Allow to use dialect mssql+pymssql --- mssql/setup.py | 2 ++ mssql/testcontainers/mssql/__init__.py | 12 +++++++++--- mssql/tests/test_mssql.py | 13 +++++++------ 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/mssql/setup.py b/mssql/setup.py index c1fd74855..148b71a19 100644 --- a/mssql/setup.py +++ b/mssql/setup.py @@ -13,7 +13,9 @@ install_requires=[ "testcontainers-core", "sqlalchemy", + # TODO: convert these to extras "pymssql", + "pyodbc", ], python_requires=">=3.7", ) diff --git a/mssql/testcontainers/mssql/__init__.py b/mssql/testcontainers/mssql/__init__.py index 678d12891..034fdefb3 100644 --- a/mssql/testcontainers/mssql/__init__.py +++ b/mssql/testcontainers/mssql/__init__.py @@ -1,5 +1,5 @@ from os import environ -from typing import Optional +from typing import Optional, Literal from testcontainers.core.generic import DbContainer @@ -21,7 +21,8 @@ class SqlServerContainer(DbContainer): def __init__(self, image: str = "mcr.microsoft.com/mssql/server:2019-latest", user: str = "SA", password: Optional[str] = None, port: int = 1433, dbname: str = "tempdb", - dialect: str = 'mssql+pymssql', **kwargs) -> None: + dialect: Literal['mssql+pymssql', 'mssql+pyodbc'] = 'mssql+pymssql', **kwargs) -> None: + # TODO: add documentation about dialect super(SqlServerContainer, self).__init__(image, **kwargs) self.port_to_expose = port @@ -39,7 +40,12 @@ def _configure(self) -> None: self.with_env("ACCEPT_EULA", 'Y') def get_connection_url(self) -> str: - return super()._create_connection_url( + url = super()._create_connection_url( dialect=self.dialect, username=self.SQLSERVER_USER, password=self.SQLSERVER_PASSWORD, db_name=self.SQLSERVER_DBNAME, port=self.port_to_expose ) + if self.dialect == "mssql+pyodbc": + # TODO: get ODBC version from installed package + url += "?driver=ODBC+Driver+17+for+SQL+Server" + return url + diff --git a/mssql/tests/test_mssql.py b/mssql/tests/test_mssql.py index 63b0a0135..5c531f431 100644 --- a/mssql/tests/test_mssql.py +++ b/mssql/tests/test_mssql.py @@ -4,12 +4,13 @@ def test_docker_run_mssql(): image = 'mcr.microsoft.com/azure-sql-edge' - dialect = 'mssql+pymssql' - with SqlServerContainer(image, dialect=dialect) as mssql: - e = sqlalchemy.create_engine(mssql.get_connection_url()) - result = e.execute('select @@servicename') - for row in result: - assert row[0] == 'MSSQLSERVER' + dialects = ['mssql+pymssql', 'mssql+pyodbc'] + for dialect in dialects: + with SqlServerContainer(dialect=dialect) as mssql: + e = sqlalchemy.create_engine(mssql.get_connection_url()) + result = e.execute('select @@servicename') + for row in result: + assert row[0] == 'MSSQLSERVER' with SqlServerContainer(image, password="1Secure*Password2", dialect=dialect) as mssql: e = sqlalchemy.create_engine(mssql.get_connection_url()) From 43e28d47f5d6525d1658043738ad872848a22d60 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guido=20Ple=C3=9Fmann?= Date: Mon, 23 Jan 2023 09:08:56 +0100 Subject: [PATCH 2/7] Add docstring for SqlServerContainer.__init__ --- mssql/testcontainers/mssql/__init__.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/mssql/testcontainers/mssql/__init__.py b/mssql/testcontainers/mssql/__init__.py index 034fdefb3..88ebbc8b3 100644 --- a/mssql/testcontainers/mssql/__init__.py +++ b/mssql/testcontainers/mssql/__init__.py @@ -22,6 +22,21 @@ class SqlServerContainer(DbContainer): def __init__(self, image: str = "mcr.microsoft.com/mssql/server:2019-latest", user: str = "SA", password: Optional[str] = None, port: int = 1433, dbname: str = "tempdb", dialect: Literal['mssql+pymssql', 'mssql+pyodbc'] = 'mssql+pymssql', **kwargs) -> None: + """ + Initialize SqlServerContainer + + Args: + image: MSSQL Server image. For example, use a specific version + user: DB user name + password: DB password + port: Port to be exposed + dbname: Database name + dialect: SQLAlchemy database dialect. Allowed values are + * 'mssql+pymssql': Uses `pymssql `_ driver + * 'mssql+pyodbc': Uses `pyodbc `_ driver + This also defines the driver that is used to connect to the database. + kwargs: Keyword arguments passed to initialization of underlying docker container + """ # TODO: add documentation about dialect super(SqlServerContainer, self).__init__(image, **kwargs) From c9d4d5a0971e94adf51a8599f50680c45271e7d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guido=20Ple=C3=9Fmann?= Date: Mon, 23 Jan 2023 09:09:54 +0100 Subject: [PATCH 3/7] Remove TODO notes --- mssql/testcontainers/mssql/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mssql/testcontainers/mssql/__init__.py b/mssql/testcontainers/mssql/__init__.py index 88ebbc8b3..5fec5191a 100644 --- a/mssql/testcontainers/mssql/__init__.py +++ b/mssql/testcontainers/mssql/__init__.py @@ -37,7 +37,6 @@ def __init__(self, image: str = "mcr.microsoft.com/mssql/server:2019-latest", us This also defines the driver that is used to connect to the database. kwargs: Keyword arguments passed to initialization of underlying docker container """ - # TODO: add documentation about dialect super(SqlServerContainer, self).__init__(image, **kwargs) self.port_to_expose = port From fc9f96c8c036db46750b3bf499d839c8ad052327 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guido=20Ple=C3=9Fmann?= Date: Mon, 23 Jan 2023 17:04:36 +0100 Subject: [PATCH 4/7] Use pyodbc version for the driver is available --- mssql/testcontainers/mssql/__init__.py | 13 +++++++++++-- mssql/tests/test_mssql.py | 7 +++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/mssql/testcontainers/mssql/__init__.py b/mssql/testcontainers/mssql/__init__.py index 5fec5191a..d2f8d9348 100644 --- a/mssql/testcontainers/mssql/__init__.py +++ b/mssql/testcontainers/mssql/__init__.py @@ -59,7 +59,16 @@ def get_connection_url(self) -> str: db_name=self.SQLSERVER_DBNAME, port=self.port_to_expose ) if self.dialect == "mssql+pyodbc": - # TODO: get ODBC version from installed package - url += "?driver=ODBC+Driver+17+for+SQL+Server" + import pyodbc + import re + r = re.compile('ODBC Driver \d{1,2} for SQL Server') + # We sort drivers in reversed order to get the latest + drivers = sorted(list(filter(r.match, pyodbc.drivers())), reverse=True) + + if len(drivers) > 0: + driver_str = drivers[0].replace(" ", "+") + else: + raise ImportError(f"No driver available for using dialect {self.dialect}") + url += f"?driver={driver_str}" return url diff --git a/mssql/tests/test_mssql.py b/mssql/tests/test_mssql.py index 5c531f431..73ffc9ccc 100644 --- a/mssql/tests/test_mssql.py +++ b/mssql/tests/test_mssql.py @@ -5,9 +5,12 @@ def test_docker_run_mssql(): image = 'mcr.microsoft.com/azure-sql-edge' dialects = ['mssql+pymssql', 'mssql+pyodbc'] - for dialect in dialects: + ends_withs = ["tempdb", "for+SQL+Server"] + for dialect, end_with in zip(dialects, ends_withs): with SqlServerContainer(dialect=dialect) as mssql: - e = sqlalchemy.create_engine(mssql.get_connection_url()) + url = mssql.get_connection_url() + assert url.endswith(end_with) + e = sqlalchemy.create_engine(url) result = e.execute('select @@servicename') for row in result: assert row[0] == 'MSSQLSERVER' From 9e0f60319a569a4530e80d6cc3cf388e7bf0b6cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guido=20Ple=C3=9Fmann?= Date: Sat, 4 Feb 2023 16:13:37 +0100 Subject: [PATCH 5/7] Extract method to find latest pyodbc ...and verify that version sorting does not work correctly. --- mssql/testcontainers/mssql/__init__.py | 51 +++++++++++++++++--------- mssql/tests/test_mssql.py | 13 +++++++ 2 files changed, 46 insertions(+), 18 deletions(-) diff --git a/mssql/testcontainers/mssql/__init__.py b/mssql/testcontainers/mssql/__init__.py index d2f8d9348..49690b95a 100644 --- a/mssql/testcontainers/mssql/__init__.py +++ b/mssql/testcontainers/mssql/__init__.py @@ -19,9 +19,16 @@ class SqlServerContainer(DbContainer): ... result = e.execute("select @@VERSION") """ - def __init__(self, image: str = "mcr.microsoft.com/mssql/server:2019-latest", user: str = "SA", - password: Optional[str] = None, port: int = 1433, dbname: str = "tempdb", - dialect: Literal['mssql+pymssql', 'mssql+pyodbc'] = 'mssql+pymssql', **kwargs) -> None: + def __init__( + self, + image: str = "mcr.microsoft.com/mssql/server:2019-latest", + user: str = "SA", + password: Optional[str] = None, + port: int = 1433, + dbname: str = "tempdb", + dialect: Literal["mssql+pymssql", "mssql+pyodbc"] = "mssql+pymssql", + **kwargs, + ) -> None: """ Initialize SqlServerContainer @@ -42,7 +49,9 @@ def __init__(self, image: str = "mcr.microsoft.com/mssql/server:2019-latest", us self.port_to_expose = port self.with_exposed_ports(self.port_to_expose) - self.SQLSERVER_PASSWORD = password or environ.get("SQLSERVER_PASSWORD", "1Secure*Password1") + self.SQLSERVER_PASSWORD = password or environ.get( + "SQLSERVER_PASSWORD", "1Secure*Password1" + ) self.SQLSERVER_USER = user self.SQLSERVER_DBNAME = dbname self.dialect = dialect @@ -51,24 +60,30 @@ def _configure(self) -> None: self.with_env("SA_PASSWORD", self.SQLSERVER_PASSWORD) self.with_env("SQLSERVER_USER", self.SQLSERVER_USER) self.with_env("SQLSERVER_DBNAME", self.SQLSERVER_DBNAME) - self.with_env("ACCEPT_EULA", 'Y') + self.with_env("ACCEPT_EULA", "Y") def get_connection_url(self) -> str: url = super()._create_connection_url( - dialect=self.dialect, username=self.SQLSERVER_USER, password=self.SQLSERVER_PASSWORD, - db_name=self.SQLSERVER_DBNAME, port=self.port_to_expose + dialect=self.dialect, + username=self.SQLSERVER_USER, + password=self.SQLSERVER_PASSWORD, + db_name=self.SQLSERVER_DBNAME, + port=self.port_to_expose, ) if self.dialect == "mssql+pyodbc": - import pyodbc - import re - r = re.compile('ODBC Driver \d{1,2} for SQL Server') - # We sort drivers in reversed order to get the latest - drivers = sorted(list(filter(r.match, pyodbc.drivers())), reverse=True) - - if len(drivers) > 0: - driver_str = drivers[0].replace(" ", "+") - else: - raise ImportError(f"No driver available for using dialect {self.dialect}") - url += f"?driver={driver_str}" + url += f"?driver={self._get_url_suffix_for_latest_pyodbc_version()}" return url + def _get_url_suffix_for_latest_pyodbc_version(self): + import pyodbc + import re + + r = re.compile("ODBC Driver \d{1,2} for SQL Server") + # We sort drivers in reversed order to get the latest + drivers = sorted(list(filter(r.match, pyodbc.drivers())), reverse=True) + if len(drivers) > 0: + driver_str = drivers[0].replace(" ", "+") + else: + raise ImportError(f"No driver available for using dialect {self.dialect}") + + return driver_str diff --git a/mssql/tests/test_mssql.py b/mssql/tests/test_mssql.py index 73ffc9ccc..397f07804 100644 --- a/mssql/tests/test_mssql.py +++ b/mssql/tests/test_mssql.py @@ -1,5 +1,8 @@ +import re + import sqlalchemy from testcontainers.mssql import SqlServerContainer +from unittest.mock import patch def test_docker_run_mssql(): @@ -20,3 +23,13 @@ def test_docker_run_mssql(): result = e.execute('select @@servicename') for row in result: assert row[0] == 'MSSQLSERVER' + + +def test_get_url_suffix_for_latest_pyodbc_version(): + container = SqlServerContainer() + + version_numbers = [10, 8] + with patch("pyodbc.drivers", return_value=[f'ODBC Driver {v} for SQL Server' for v in version_numbers]) as mock_method: + driver_str = container._get_url_suffix_for_latest_pyodbc_version() + latest_version = int(re.findall('\d{1,2}', driver_str)[0]) + assert latest_version == max(version_numbers) From 57cfc9eeb6e447922959524ff6d3384be777ca80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guido=20Ple=C3=9Fmann?= Date: Mon, 6 Feb 2023 22:55:04 +0100 Subject: [PATCH 6/7] Select latest driver by sorted integer version number --- mssql/testcontainers/mssql/__init__.py | 4 +++- mssql/tests/test_mssql.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/mssql/testcontainers/mssql/__init__.py b/mssql/testcontainers/mssql/__init__.py index 49690b95a..68fe0e9a4 100644 --- a/mssql/testcontainers/mssql/__init__.py +++ b/mssql/testcontainers/mssql/__init__.py @@ -81,8 +81,10 @@ def _get_url_suffix_for_latest_pyodbc_version(self): r = re.compile("ODBC Driver \d{1,2} for SQL Server") # We sort drivers in reversed order to get the latest drivers = sorted(list(filter(r.match, pyodbc.drivers())), reverse=True) + version_numbers = [int(v) for v in re.findall("\d{1,2}", "".join(drivers))] + max_version_index = version_numbers.index(max(version_numbers)) if len(drivers) > 0: - driver_str = drivers[0].replace(" ", "+") + driver_str = drivers[max_version_index].replace(" ", "+") else: raise ImportError(f"No driver available for using dialect {self.dialect}") diff --git a/mssql/tests/test_mssql.py b/mssql/tests/test_mssql.py index 397f07804..cafb8ea99 100644 --- a/mssql/tests/test_mssql.py +++ b/mssql/tests/test_mssql.py @@ -29,7 +29,7 @@ def test_get_url_suffix_for_latest_pyodbc_version(): container = SqlServerContainer() version_numbers = [10, 8] - with patch("pyodbc.drivers", return_value=[f'ODBC Driver {v} for SQL Server' for v in version_numbers]) as mock_method: + with patch("pyodbc.drivers", return_value=[f'ODBC Driver {v} for SQL Server' for v in version_numbers]): driver_str = container._get_url_suffix_for_latest_pyodbc_version() latest_version = int(re.findall('\d{1,2}', driver_str)[0]) assert latest_version == max(version_numbers) From ed5a0c17e57715d284da9d4b1f16002ff406ebcf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guido=20Ple=C3=9Fmann?= Date: Tue, 7 Feb 2023 08:14:42 +0100 Subject: [PATCH 7/7] Satisfy flake8 --- mssql/testcontainers/mssql/__init__.py | 8 +++----- mssql/tests/test_mssql.py | 20 +++++++++++--------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/mssql/testcontainers/mssql/__init__.py b/mssql/testcontainers/mssql/__init__.py index 68fe0e9a4..bfdf379b2 100644 --- a/mssql/testcontainers/mssql/__init__.py +++ b/mssql/testcontainers/mssql/__init__.py @@ -49,9 +49,7 @@ def __init__( self.port_to_expose = port self.with_exposed_ports(self.port_to_expose) - self.SQLSERVER_PASSWORD = password or environ.get( - "SQLSERVER_PASSWORD", "1Secure*Password1" - ) + self.SQLSERVER_PASSWORD = password or environ.get("SQLSERVER_PASSWORD", "1Secure*Password1") self.SQLSERVER_USER = user self.SQLSERVER_DBNAME = dbname self.dialect = dialect @@ -78,10 +76,10 @@ def _get_url_suffix_for_latest_pyodbc_version(self): import pyodbc import re - r = re.compile("ODBC Driver \d{1,2} for SQL Server") + r = re.compile(r"ODBC Driver \d{1,2} for SQL Server") # We sort drivers in reversed order to get the latest drivers = sorted(list(filter(r.match, pyodbc.drivers())), reverse=True) - version_numbers = [int(v) for v in re.findall("\d{1,2}", "".join(drivers))] + version_numbers = [int(v) for v in re.findall(r"\d{1,2}", "".join(drivers))] max_version_index = version_numbers.index(max(version_numbers)) if len(drivers) > 0: driver_str = drivers[max_version_index].replace(" ", "+") diff --git a/mssql/tests/test_mssql.py b/mssql/tests/test_mssql.py index cafb8ea99..aa67020cb 100644 --- a/mssql/tests/test_mssql.py +++ b/mssql/tests/test_mssql.py @@ -6,30 +6,32 @@ def test_docker_run_mssql(): - image = 'mcr.microsoft.com/azure-sql-edge' - dialects = ['mssql+pymssql', 'mssql+pyodbc'] + image = "mcr.microsoft.com/azure-sql-edge" + dialects = ["mssql+pymssql", "mssql+pyodbc"] ends_withs = ["tempdb", "for+SQL+Server"] for dialect, end_with in zip(dialects, ends_withs): with SqlServerContainer(dialect=dialect) as mssql: url = mssql.get_connection_url() assert url.endswith(end_with) e = sqlalchemy.create_engine(url) - result = e.execute('select @@servicename') + result = e.execute("select @@servicename") for row in result: - assert row[0] == 'MSSQLSERVER' + assert row[0] == "MSSQLSERVER" with SqlServerContainer(image, password="1Secure*Password2", dialect=dialect) as mssql: e = sqlalchemy.create_engine(mssql.get_connection_url()) - result = e.execute('select @@servicename') + result = e.execute("select @@servicename") for row in result: - assert row[0] == 'MSSQLSERVER' + assert row[0] == "MSSQLSERVER" def test_get_url_suffix_for_latest_pyodbc_version(): container = SqlServerContainer() version_numbers = [10, 8] - with patch("pyodbc.drivers", return_value=[f'ODBC Driver {v} for SQL Server' for v in version_numbers]): - driver_str = container._get_url_suffix_for_latest_pyodbc_version() - latest_version = int(re.findall('\d{1,2}', driver_str)[0]) + with patch( + "pyodbc.drivers", return_value=[f"ODBC Driver {v} for SQL Server" for v in version_numbers] + ): + driver_str = container._get_url_suffix_for_latest_pyodbc_version() + latest_version = int(re.findall(r"\d{1,2}", driver_str)[0]) assert latest_version == max(version_numbers)