diff --git a/connectors/config.py b/connectors/config.py index 0fd988695..f16fb01fc 100644 --- a/connectors/config.py +++ b/connectors/config.py @@ -121,6 +121,7 @@ def _default_config(): "network_drive": "connectors.sources.network_drive:NASDataSource", "notion": "connectors.sources.notion:NotionDataSource", "onedrive": "connectors.sources.onedrive:OneDriveDataSource", + "onelake": "connectors.sources.onelake:OneLakeDataSource", "oracle": "connectors.sources.oracle:OracleDataSource", "outlook": "connectors.sources.outlook:OutlookDataSource", "postgresql": "connectors.sources.postgresql:PostgreSQLDataSource", diff --git a/connectors/sources/onelake.py b/connectors/sources/onelake.py new file mode 100644 index 000000000..a81807e49 --- /dev/null +++ b/connectors/sources/onelake.py @@ -0,0 +1,388 @@ +"""OneLake connector to retrieve data from datalakes""" + +import asyncio +from functools import partial + +from azure.identity import ClientSecretCredential +from azure.storage.filedatalake import DataLakeServiceClient + +from connectors.source import BaseDataSource + +ACCOUNT_NAME = "onelake" + + +class OneLakeDataSource(BaseDataSource): + """OneLake""" + + name = "OneLake" + service_type = "onelake" + incremental_sync_enabled = True + + def __init__(self, configuration): + """Set up the connection to the azure base client + + Args: + configuration (DataSourceConfiguration): Object of DataSourceConfiguration class. + """ + super().__init__(configuration=configuration) + self.tenant_id = self.configuration["tenant_id"] + self.client_id = self.configuration["client_id"] + self.client_secret = self.configuration["client_secret"] + self.workspace_name = self.configuration["workspace_name"] + self.data_path = self.configuration["data_path"] + self.account_url = ( + f"https://{self.configuration['account_name']}.dfs.fabric.microsoft.com" + ) + self.service_client = None + self.file_system_client = None + self.directory_client = None + + @classmethod + def get_default_configuration(cls): + """Get the default configuration for OneLake + + Returns: + dictionary: Default configuration + """ + return { + "tenant_id": { + "label": "OneLake tenant id", + "order": 1, + "type": "str", + }, + "client_id": { + "label": "OneLake client id", + "order": 2, + "type": "str", + }, + "client_secret": { + "label": "OneLake client secret", + "order": 3, + "type": "str", + "sensitive": True, + }, + "workspace_name": { + "label": "OneLake workspace name", + "order": 4, + "type": "str", + }, + "data_path": { + "label": "OneLake data path", + "tooltip": "Path in format .Lakehouse/files/ this is uppercase sensitive, make sure to insert the correct path", + "order": 5, + "type": "str", + }, + "account_name": { + "tooltip": "In the most cases is 'onelake'", + "default_value": ACCOUNT_NAME, + "label": "Account name", + "required": False, + "order": 6, + "type": "str", + }, + } + + async def initialize(self): + """Initialize the Azure clients asynchronously""" + + if not self.service_client: + self.service_client = await self._get_service_client() + self.file_system_client = await self._get_file_system_client() + self.directory_client = await self._get_directory_client() + + async def ping(self): + """Verify the connection with OneLake""" + + self._logger.info("Generating file system client...") + + try: + await self.initialize() # Initialize the clients + + await self._get_directory_paths( + self.configuration["data_path"] + ) # Condition to check if the connection is successful + self._logger.info( + f"Connection to OneLake successful to {self.configuration['data_path']}" + ) + except Exception: + self._logger.exception("Error while connecting to OneLake.") + raise + + async def _process_items_concurrently( + self, items, process_item_func, max_concurrency=50 + ): + """Process a list of items concurrently using a semaphore for concurrency control. + + This function applies the `process_item_func` to each item in the `items` list + using a semaphore to control the level of concurrency. + + Args: + items (list): List of items to process. + process_item_func (function): The function to be called for each item. + max_concurrency (int): Maximum number of concurrent items to process. + + Returns: + list: A list containing the results of processing each item. + """ + + async def process_item(item, semaphore): + async with semaphore: + return await process_item_func(item) + + semaphore = asyncio.Semaphore(max_concurrency) + tasks = [process_item(item, semaphore) for item in items] + return await asyncio.gather(*tasks) + + def _get_token_credentials(self): + """Get the token credentials for OneLake + + Returns: + obj: Token credentials + """ + + tenant_id = self.configuration["tenant_id"] + client_id = self.configuration["client_id"] + client_secret = self.configuration["client_secret"] + + try: + return ClientSecretCredential(tenant_id, client_id, client_secret) + except Exception as e: + self._logger.error(f"Error while getting token credentials: {e}") + raise + + async def _get_service_client(self): + """Get the service client for OneLake. The service client is the client that allows to interact with the OneLake service. + + Returns: + obj: Service client + """ + + try: + return DataLakeServiceClient( + account_url=self.account_url, + credential=self._get_token_credentials(), + ) + except Exception as e: + self._logger.error(f"Error while getting service client: {e}") + raise + + async def _get_file_system_client(self): + """Get the file system client for OneLake. This client is used to interact with the file system of the OneLake service. + + Returns: + obj: File system client + """ + + try: + return self.service_client.get_file_system_client( + self.configuration["workspace_name"] + ) + except Exception as e: + self._logger.error(f"Error while getting file system client: {e}") + raise + + async def _get_directory_client(self): + """Get the directory client for OneLake + + Returns: + obj: Directory client + """ + + try: + return self.file_system_client.get_directory_client( + self.configuration["data_path"] + ) + except Exception as e: + self._logger.error(f"Error while getting directory client: {e}") + raise + + async def _get_file_client(self, file_name): + """Get file client from OneLake + + Args: + file_name (str): name of the file + + Returns: + obj: File client + """ + + try: + return self.directory_client.get_file_client(file_name) + except Exception as e: + self._logger.error(f"Error while getting file client: {e}") + raise + + async def get_files_properties(self, file_clients): + """Get the properties of a list of file clients + + Args: + file_clients (list): List of file clients + + Returns: + list: List of file properties + """ + + async def get_properties(file_client): + return file_client.get_file_properties() + + return await self._process_items_concurrently(file_clients, get_properties) + + async def _get_directory_paths(self, directory_path): + """List directory paths from data lake + + Args: + directory_path (str): Directory path + + Returns: + list: List of paths + """ + + try: + if not self.file_system_client: + await self.initialize() + + loop = asyncio.get_running_loop() + paths = await loop.run_in_executor( + None, self.file_system_client.get_paths, directory_path + ) + + return paths + except Exception as e: + self._logger.error(f"Error while getting directory paths: {e}") + raise + + async def format_file(self, file_client): + """Format file_client to be processed + + Args: + file_client (obj): File object + + Returns: + dict: Formatted file + """ + + try: + loop = asyncio.get_running_loop() + file_properties = await loop.run_in_executor( + None, file_client.get_file_properties + ) + + return { + "_id": f"{file_client.file_system_name}_{file_properties.name.split('/')[-1]}", + "name": file_properties.name.split("/")[-1], + "created_at": file_properties.creation_time.isoformat(), + "_timestamp": file_properties.last_modified.isoformat(), + "size": file_properties.size, + } + except Exception as e: + self._logger.error( + f"Error while formatting file or getting file properties: {e}" + ) + raise + + async def download_file(self, file_client): + """Download file from OneLake + + Args: + file_client (obj): File client + + Returns: + generator: File stream + """ + + try: + loop = asyncio.get_running_loop() + download = await loop.run_in_executor(None, file_client.download_file) + + stream = download.chunks() + for chunk in stream: + yield chunk + except Exception as e: + self._logger.error(f"Error while downloading file: {e}") + raise + + async def get_content(self, file_name, doit=None, timestamp=None): + """Obtains the file content for the specified file in `file_name`. + + Args: + file_name (obj): The file name to process to obtain the content. + timestamp (timestamp, optional): Timestamp of blob last modified. Defaults to None. + doit (boolean, optional): Boolean value for whether to get content or not. Defaults to None. + + Returns: + str: Content of the file or None if not applicable. + """ + + if not doit: + return + + file_client = await self._get_file_client(file_name) + file_properties = file_client.get_file_properties() + file_extension = self.get_file_extension(file_name) + + doc = { + "_id": f"{file_client.file_system_name}_{file_properties.name}", # id in format _ + } + + can_be_downloaded = self.can_file_be_downloaded( + file_extension=file_extension, + filename=file_properties.name, + file_size=file_properties.size, + ) + + if not can_be_downloaded: + self._logger.warning( + f"File {file_properties.name} cannot be downloaded. Skipping." + ) + return doc + + self._logger.debug(f"Downloading file {file_properties.name}...") + extracted_doc = await self.download_and_extract_file( + doc=doc, + source_filename=file_properties.name.split("/")[-1], + file_extension=file_extension, + download_func=partial(self.download_file, file_client), + ) + + return extracted_doc if extracted_doc is not None else doc + + async def prepare_files(self, doc_paths): + """Prepare files for processing + + Args: + doc_paths (list): List of paths extracted from OneLake + + Returns: + list: List of files + """ + + async def prepare_single_file(path): + file_name = path.name.split("/")[-1] + field_client = await self._get_file_client(file_name) + return await self.format_file(field_client) + + files = await self._process_items_concurrently(doc_paths, prepare_single_file) + + for file in files: + yield file + + async def get_docs(self, filtering=None): + """Get documents from OneLake and index them + + Yields: + tuple: dictionary with meta-data of each file and a partial function to get the file content. + """ + + self._logger.info(f"Fetching files from OneLake datalake {self.data_path}") + + directory_paths = await self._get_directory_paths( + self.configuration["data_path"] + ) + directory_paths = list(directory_paths) + + self._logger.debug(f"Found {len(directory_paths)} files in {self.data_path}") + + async for file in self.prepare_files(directory_paths): + file_dict = file + + yield file_dict, partial(self.get_content, file_dict["name"]) diff --git a/requirements/framework.txt b/requirements/framework.txt index 775ad7ac4..94da2fc48 100644 --- a/requirements/framework.txt +++ b/requirements/framework.txt @@ -44,3 +44,5 @@ notion-client==2.2.1 certifi==2024.7.4 aioboto3==12.4.0 pyasn1<0.6.1 +azure-identity==1.19.0 +azure-storage-file-datalake==12.14.0 \ No newline at end of file diff --git a/tests/sources/test_onelake.py b/tests/sources/test_onelake.py new file mode 100644 index 000000000..6215b40e5 --- /dev/null +++ b/tests/sources/test_onelake.py @@ -0,0 +1,813 @@ +from contextlib import asynccontextmanager +from datetime import datetime +from functools import partial +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import pytest + +from connectors.sources.onelake import OneLakeDataSource +from tests.sources.support import create_source + + +@asynccontextmanager +async def create_onelake_source( + use_text_extraction_service=False, +): + async with create_source( + OneLakeDataSource, + tenant_id="fake-tenant", + client_id="fake-client", + client_secret="fake-client-secret", + workspace_name="FakeWorkspace", + data_path="FakeDatalake.Lakehouse/Files/Data", + account_name="onelake", + use_text_extraction_service=use_text_extraction_service, + ) as source: + yield source + + +@pytest.mark.asyncio +async def test_init(): + """Test OneLakeDataSource initialization""" + + async with create_onelake_source() as source: + # Check that all configuration values are set correctly + assert source.tenant_id == source.configuration["tenant_id"] + assert source.client_id == source.configuration["client_id"] + assert source.client_secret == source.configuration["client_secret"] + assert source.workspace_name == source.configuration["workspace_name"] + assert source.data_path == source.configuration["data_path"] + assert ( + source.account_url + == f"https://{source.configuration['account_name']}.dfs.fabric.microsoft.com" + ) + + # Check that clients are initially None + assert source.service_client is None + assert source.file_system_client is None + assert source.directory_client is None + + +def test_get_default_configuration(): + """Test get_default_configuration class method""" + + config = OneLakeDataSource.get_default_configuration() + + # Check that all required configuration fields are present + required_fields = [ + "tenant_id", + "client_id", + "client_secret", + "workspace_name", + "data_path", + "account_name", + ] + + for field in required_fields: + assert field in config + assert "label" in config[field] + assert "order" in config[field] + assert "type" in config[field] + + # Check specific configurations + assert config["account_name"]["default_value"] == "onelake" + assert config["client_secret"]["sensitive"] is True + assert config["account_name"]["required"] is False + + +@pytest.mark.asyncio +async def test_initialize(): + """Test initialize method""" + + async with create_onelake_source() as source: + mock_service_client = Mock() + mock_file_system_client = Mock() + mock_directory_client = Mock() + + with patch.object( + source, "_get_service_client", new_callable=AsyncMock + ) as mock_get_service: + with patch.object( + source, "_get_file_system_client", new_callable=AsyncMock + ) as mock_get_fs: + with patch.object( + source, "_get_directory_client", new_callable=AsyncMock + ) as mock_get_dir: + mock_get_service.return_value = mock_service_client + mock_get_fs.return_value = mock_file_system_client + mock_get_dir.return_value = mock_directory_client + + # Test first initialization + await source.initialize() + + assert source.service_client == mock_service_client + assert source.file_system_client == mock_file_system_client + assert source.directory_client == mock_directory_client + + mock_get_service.assert_called_once() + mock_get_fs.assert_called_once() + mock_get_dir.assert_called_once() + + # Test that it doesn't re-initialize if service_client already exists + mock_get_service.reset_mock() + mock_get_fs.reset_mock() + mock_get_dir.reset_mock() + + await source.initialize() + + # Should not be called again + mock_get_service.assert_not_called() + mock_get_fs.assert_not_called() + mock_get_dir.assert_not_called() + + +@pytest.mark.asyncio +async def test_ping_for_successful_connection(): + """Test ping method of OneLakeDataSource class""" + + # Setup + async with create_onelake_source() as source: + with patch.object( + source, "_get_directory_paths", new_callable=AsyncMock + ) as mock_get_paths: + mock_get_paths.return_value = [] + + await source.ping() + + mock_get_paths.assert_called_once_with(source.configuration["data_path"]) + + +@pytest.mark.asyncio +async def test_ping_for_failed_connection(): + """Test ping method of OneLakeDataSource class with negative case""" + + # Setup + async with create_onelake_source() as source: + with patch.object( + source, "_get_directory_paths", new_callable=AsyncMock + ) as mock_get_paths: + mock_get_paths.side_effect = Exception("Something went wrong") + + # Run & Check + with pytest.raises(Exception, match="Something went wrong"): + await source.ping() + + mock_get_paths.assert_called_once_with(source.configuration["data_path"]) + + +@pytest.mark.asyncio +async def test_get_token_credentials(): + """Test _get_token_credentials method of OneLakeDataSource class""" + + # Setup + async with create_onelake_source() as source: + tenant_id = source.configuration["tenant_id"] + client_id = source.configuration["client_id"] + client_secret = source.configuration["client_secret"] + + with patch( + "connectors.sources.onelake.ClientSecretCredential", autospec=True + ) as mock_credential: + mock_instance = mock_credential.return_value + + # Run + credentials = source._get_token_credentials() + + # Check + mock_credential.assert_called_once_with(tenant_id, client_id, client_secret) + assert credentials is mock_instance + + +@pytest.mark.asyncio +async def test_get_token_credentials_error(): + """Test _get_token_credentials method when credential creation fails""" + + async with create_onelake_source() as source: + with patch( + "connectors.sources.onelake.ClientSecretCredential", autospec=True + ) as mock_credential: + mock_credential.side_effect = Exception("Credential error") + + with pytest.raises(Exception, match="Credential error"): + source._get_token_credentials() + + +@pytest.mark.asyncio +async def test_get_service_client(): + """Test _get_service_client method of OneLakeDataSource class""" + + # Setup + async with create_onelake_source() as source: + mock_service_client = Mock() + mock_credentials = Mock() + + with patch( + "connectors.sources.onelake.DataLakeServiceClient", + autospec=True, + ) as mock_client, patch.object( + source, "_get_token_credentials", return_value=mock_credentials + ): + mock_client.return_value = mock_service_client + + # Run + service_client = await source._get_service_client() + + # Check + mock_client.assert_called_once_with( + account_url=source.account_url, + credential=mock_credentials, + ) + assert service_client is mock_service_client + + +@pytest.mark.asyncio +async def test_get_service_client_error(): + """Test _get_service_client method when client creation fails""" + + async with create_onelake_source() as source: + with patch( + "connectors.sources.onelake.DataLakeServiceClient", + side_effect=Exception("Service client error"), + ): + with pytest.raises(Exception, match="Service client error"): + await source._get_service_client() + + +@pytest.mark.asyncio +async def test_get_file_system_client(): + """Test _get_file_system_client method of OneLakeDataSource class""" + + # Setup + async with create_onelake_source() as source: + mock_file_system_client = Mock() + workspace_name = source.configuration["workspace_name"] + + # Set up the service_client that _get_file_system_client depends on + mock_service_client = Mock() + mock_service_client.get_file_system_client.return_value = ( + mock_file_system_client + ) + source.service_client = mock_service_client + + # Run + file_system_client = await source._get_file_system_client() + + # Check + mock_service_client.get_file_system_client.assert_called_once_with( + workspace_name + ) + assert file_system_client == mock_file_system_client + + +@pytest.mark.asyncio +async def test_get_file_system_client_error(): + """Test _get_file_system_client method when client creation fails""" + + async with create_onelake_source() as source: + mock_service_client = Mock() + mock_service_client.get_file_system_client.side_effect = Exception("Test error") + source.service_client = mock_service_client + + with pytest.raises(Exception, match="Test error"): + await source._get_file_system_client() + + +@pytest.mark.asyncio +async def test_get_directory_client(): + """Test _get_directory_client method of OneLakeDataSource class""" + + # Setup + async with create_onelake_source() as source: + mock_directory_client = Mock() + data_path = source.configuration["data_path"] + + # Set up the file_system_client that _get_directory_client depends on + mock_file_system_client = Mock() + mock_file_system_client.get_directory_client.return_value = ( + mock_directory_client + ) + source.file_system_client = mock_file_system_client + + # Run + directory_client = await source._get_directory_client() + + # Check + mock_file_system_client.get_directory_client.assert_called_once_with(data_path) + assert directory_client == mock_directory_client + + +@pytest.mark.asyncio +async def test_get_directory_client_error(): + """Test _get_directory_client method when client creation fails""" + + async with create_onelake_source() as source: + mock_file_system_client = Mock() + mock_file_system_client.get_directory_client.side_effect = Exception( + "Test error" + ) + source.file_system_client = mock_file_system_client + + with pytest.raises(Exception, match="Test error"): + await source._get_directory_client() + + +@pytest.mark.asyncio +async def test_get_file_client_success(): + """Test successful file client retrieval""" + + mock_file_client = Mock() + mock_directory_client = Mock() + mock_directory_client.get_file_client.return_value = mock_file_client + + async with create_onelake_source() as source: + # Mock the directory_client directly since that's what _get_file_client uses + source.directory_client = mock_directory_client + + result = await source._get_file_client("test.txt") + + assert result == mock_file_client + mock_directory_client.get_file_client.assert_called_once_with("test.txt") + + +@pytest.mark.asyncio +async def test_get_file_client_error(): + """Test file client retrieval with error""" + + async with create_onelake_source() as source: + mock_directory_client = Mock() + mock_directory_client.get_file_client.side_effect = Exception( + "Error while getting file client" + ) + source.directory_client = mock_directory_client + + with pytest.raises(Exception, match="Error while getting file client"): + await source._get_file_client("test.txt") + + +@pytest.mark.asyncio +async def test_get_directory_paths(): + """Test _get_directory_paths method of OneLakeDataSource class""" + + # Setup + async with create_onelake_source() as source: + mock_paths = ["path1", "path2"] + directory_path = "mock_directory_path" + + # Set up the file_system_client so initialize() is not called + mock_file_system_client = Mock() + source.file_system_client = mock_file_system_client + + # Mock the run_in_executor call + with patch("asyncio.get_running_loop") as mock_loop: + mock_loop.return_value.run_in_executor = AsyncMock(return_value=mock_paths) + + # Run + paths = await source._get_directory_paths(directory_path) + + # Check + assert paths == mock_paths + mock_loop.return_value.run_in_executor.assert_called_once_with( + None, mock_file_system_client.get_paths, directory_path + ) + + +@pytest.mark.asyncio +async def test_get_directory_paths_with_initialize(): + """Test _get_directory_paths method when file_system_client is None""" + + async with create_onelake_source() as source: + mock_paths = ["path1", "path2"] + directory_path = "mock_directory_path" + + # Ensure file_system_client is None to trigger initialize() + source.file_system_client = None + + # Mock initialize to set up file_system_client + async def mock_initialize(): + mock_file_system_client = Mock() + source.file_system_client = mock_file_system_client + + with patch.object( + source, "initialize", side_effect=mock_initialize + ) as mock_init: + with patch("asyncio.get_running_loop") as mock_loop: + mock_loop.return_value.run_in_executor = AsyncMock( + return_value=mock_paths + ) + + # Run + paths = await source._get_directory_paths(directory_path) + + # Check + mock_init.assert_called_once() + assert paths == mock_paths + + +@pytest.mark.asyncio +async def test_get_directory_paths_error(): + """Test _get_directory_paths method when getting paths fails""" + + async with create_onelake_source() as source: + directory_path = "mock_directory_path" + + # Set up the file_system_client so initialize() is not called + mock_file_system_client = Mock() + source.file_system_client = mock_file_system_client + + with patch("asyncio.get_running_loop") as mock_loop: + mock_loop.return_value.run_in_executor = AsyncMock( + side_effect=Exception("Error while getting directory paths") + ) + + with pytest.raises(Exception, match="Error while getting directory paths"): + await source._get_directory_paths(directory_path) + + +@pytest.mark.asyncio +async def test_format_file(): + """Test format_file method of OneLakeDataSource class""" + + # Setup + async with create_onelake_source() as source: + mock_file_client = MagicMock() + mock_file_properties = MagicMock( + creation_time=datetime(2022, 4, 21, 12, 12, 30), + last_modified=datetime(2022, 4, 22, 15, 45, 10), + size=2048, + name="path/to/file.txt", + ) + + mock_file_properties.name.split.return_value = ["path", "to", "file.txt"] + mock_file_client.file_system_name = "my_file_system" + + expected_output = { + "_id": "my_file_system_file.txt", + "name": "file.txt", + "created_at": "2022-04-21T12:12:30", + "_timestamp": "2022-04-22T15:45:10", + "size": 2048, + } + + # Mock the run_in_executor call since format_file is now async + with patch("asyncio.get_running_loop") as mock_loop: + mock_loop.return_value.run_in_executor = AsyncMock( + return_value=mock_file_properties + ) + + # Execute + actual_output = await source.format_file(mock_file_client) + + # Assert + assert actual_output == expected_output + + +@pytest.mark.asyncio +async def test_format_file_error(): + """Test format_file method when getting properties fails""" + + async with create_onelake_source() as source: + mock_file_client = MagicMock() + mock_file_client.file_system_name = "my_file_system" + + with patch("asyncio.get_running_loop") as mock_loop: + mock_loop.return_value.run_in_executor = AsyncMock( + side_effect=Exception("Test error") + ) + + with pytest.raises(Exception, match="Test error"): + await source.format_file(mock_file_client) + + +@pytest.mark.asyncio +async def test_format_file_empty_name(): + """Test format_file method with empty file name""" + + async with create_onelake_source() as source: + mock_file_client = MagicMock() + mock_file_properties = MagicMock( + creation_time=datetime(2022, 4, 21, 12, 12, 30), + last_modified=datetime(2022, 4, 22, 15, 45, 10), + size=2048, + name="", + ) + mock_file_properties.name.split.return_value = [""] + mock_file_client.file_system_name = "my_file_system" + + with patch("asyncio.get_running_loop") as mock_loop: + mock_loop.return_value.run_in_executor = AsyncMock( + return_value=mock_file_properties + ) + + result = await source.format_file(mock_file_client) + assert result["name"] == "" + assert result["_id"] == "my_file_system_" + + +@pytest.mark.asyncio +async def test_download_file(): + """Test download_file method of OneLakeDataSource class""" + + # Setup + mock_file_client = Mock() + mock_download = Mock() + mock_chunks = ["chunk1", "chunk2", "chunk3"] + + async with create_onelake_source() as source: + with patch("asyncio.get_running_loop") as mock_loop: + mock_loop.return_value.run_in_executor = AsyncMock( + return_value=mock_download + ) + mock_download.chunks.return_value = iter(mock_chunks) + + # Run + chunks = [] + async for chunk in source.download_file(mock_file_client): + chunks.append(chunk) + + # Check + assert chunks == mock_chunks + mock_loop.return_value.run_in_executor.assert_called_once() + + +@pytest.mark.asyncio +async def test_download_file_with_error(): + """Test download_file method of OneLakeDataSource class with exception handling""" + + # Setup + mock_file_client = Mock() + + async with create_onelake_source() as source: + with patch("asyncio.get_running_loop") as mock_loop: + mock_loop.return_value.run_in_executor = AsyncMock( + side_effect=Exception("Test error") + ) + + # Run & Check + with pytest.raises(Exception, match="Test error"): + async for _ in source.download_file(mock_file_client): + pass + + +@pytest.mark.asyncio +async def test_get_content_with_download(): + """Test get_content method when doit=True""" + + async with create_onelake_source() as source: + + class FileClientMock: + file_system_name = "mockfilesystem" + + class FileProperties: + def __init__(self, name, size): + self.name = name + self.size = size + + def get_file_properties(self): + return self.FileProperties(name="file1.txt", size=2000) + + with patch.object( + source, + "_get_file_client", + new_callable=AsyncMock, + return_value=FileClientMock(), + ), patch.object( + source, "can_file_be_downloaded", return_value=True + ), patch.object( + source, "get_file_extension", return_value="txt" + ), patch.object( + source, + "download_and_extract_file", + new_callable=AsyncMock, + return_value={ + "_id": "mockfilesystem_file1.txt", + "_attachment": "TW9jayBjb250ZW50", + }, + ): + actual_response = await source.get_content("file1.txt", doit=True) + assert actual_response == { + "_id": "mockfilesystem_file1.txt", + "_attachment": "TW9jayBjb250ZW50", + } + + +@pytest.mark.asyncio +async def test_get_content_without_download(): + """Test get_content method when doit=False""" + + async with create_onelake_source() as source: + actual_response = await source.get_content("file1.txt", doit=False) + assert actual_response is None + + +@pytest.mark.asyncio +async def test_prepare_files(): + """Test prepare_files method of OneLakeDataSource class""" + + # Setup + doc_paths = [ + Mock(name="doc1.txt"), + Mock(name="doc2.txt"), + ] + + async with create_onelake_source() as source: + mock_file_results = [ + {"name": "doc1.txt", "id": "1"}, + {"name": "doc2.txt", "id": "2"}, + ] + + with patch.object( + source, "_process_items_concurrently", new_callable=AsyncMock + ) as mock_process: + mock_process.return_value = mock_file_results + + result = [] + async for item in source.prepare_files(doc_paths): + result.append(item) + + # Check results + assert result == mock_file_results + # Check that _process_items_concurrently was called with the paths + mock_process.assert_called_once() + call_args = mock_process.call_args[0] + assert call_args[0] == doc_paths # First argument should be the paths + + +@pytest.mark.asyncio +async def test_get_docs(): + """Test get_docs method of OneLakeDataSource class""" + + mock_paths = [ + Mock(name="doc1", path="folder/doc1"), + Mock(name="doc2", path="folder/doc2"), + ] + + mock_file_docs = [{"name": "doc1", "id": "1"}, {"name": "doc2", "id": "2"}] + + async def mock_prepare_files_impl(paths): + for doc in mock_file_docs: + yield doc + + async with create_onelake_source() as source: + with patch.object( + source, "_get_directory_paths", new_callable=AsyncMock + ) as mock_get_paths: + mock_get_paths.return_value = mock_paths + + with patch.object( + source, "prepare_files", side_effect=mock_prepare_files_impl + ): + result = [] + async for doc, get_content in source.get_docs(): + result.append((doc, get_content)) + + mock_get_paths.assert_called_once_with( + source.configuration["data_path"] + ) + assert len(result) == 2 + for (doc, get_content), expected_doc in zip(result, mock_file_docs): + assert doc == expected_doc + assert isinstance(get_content, partial) + assert get_content.func == source.get_content + assert get_content.args == (doc["name"],) + + +@pytest.mark.asyncio +async def test_process_items_concurrently(): + """Test _process_items_concurrently method""" + + async with create_onelake_source() as source: + items = ["item1", "item2", "item3"] + + async def mock_process_item(item): + return f"processed_{item}" + + result = await source._process_items_concurrently(items, mock_process_item) + + expected = ["processed_item1", "processed_item2", "processed_item3"] + assert result == expected + + +@pytest.mark.asyncio +async def test_process_items_concurrently_with_custom_concurrency(): + """Test _process_items_concurrently method with custom max_concurrency""" + + async with create_onelake_source() as source: + items = ["item1", "item2"] + + async def mock_process_item(item): + return f"processed_{item}" + + result = await source._process_items_concurrently( + items, mock_process_item, max_concurrency=1 + ) + + expected = ["processed_item1", "processed_item2"] + assert result == expected + + +@pytest.mark.asyncio +async def test_get_files_properties(): + """Test get_files_properties method""" + + async with create_onelake_source() as source: + mock_file_client1 = Mock() + mock_file_client2 = Mock() + mock_properties1 = Mock() + mock_properties2 = Mock() + + mock_file_client1.get_file_properties.return_value = mock_properties1 + mock_file_client2.get_file_properties.return_value = mock_properties2 + + file_clients = [mock_file_client1, mock_file_client2] + + with patch.object( + source, "_process_items_concurrently", new_callable=AsyncMock + ) as mock_process: + mock_process.return_value = [mock_properties1, mock_properties2] + + result = await source.get_files_properties(file_clients) + + assert result == [mock_properties1, mock_properties2] + mock_process.assert_called_once() + # Check that the first argument is the file_clients list + call_args = mock_process.call_args[0] + assert call_args[0] == file_clients + + +@pytest.mark.asyncio +async def test_get_content_file_cannot_be_downloaded(): + """Test get_content method when file cannot be downloaded""" + + async with create_onelake_source() as source: + + class FileClientMock: + file_system_name = "mockfilesystem" + + class FileProperties: + def __init__(self, name, size): + self.name = name + self.size = size + + def get_file_properties(self): + return self.FileProperties( + name="large_file.exe", size=200000000 + ) # Very large file + + with patch.object( + source, + "_get_file_client", + new_callable=AsyncMock, + return_value=FileClientMock(), + ), patch.object( + source, "can_file_be_downloaded", return_value=False # Cannot download + ), patch.object( + source, "get_file_extension", return_value="exe" + ): + result = await source.get_content("large_file.exe", doit=True) + + # Should return only the basic doc without content + expected = { + "_id": "mockfilesystem_large_file.exe", + } + assert result == expected + + +@pytest.mark.asyncio +async def test_get_content_extracted_doc_is_none(): + """Test get_content method when download_and_extract_file returns None""" + + async with create_onelake_source() as source: + + class FileClientMock: + file_system_name = "mockfilesystem" + + class FileProperties: + def __init__(self, name, size): + self.name = name + self.size = size + + def get_file_properties(self): + return self.FileProperties(name="file.txt", size=1000) + + with patch.object( + source, + "_get_file_client", + new_callable=AsyncMock, + return_value=FileClientMock(), + ), patch.object( + source, "can_file_be_downloaded", return_value=True + ), patch.object( + source, "get_file_extension", return_value="txt" + ), patch.object( + source, + "download_and_extract_file", + new_callable=AsyncMock, + return_value=None, + ): + result = await source.get_content("file.txt", doit=True) + + # Should return the basic doc when extracted_doc is None + expected = { + "_id": "mockfilesystem_file.txt", + } + assert result == expected