diff --git a/docs/deployment/code_examples/google/cloud_run_example.py b/docs/deployment/code_examples/google/cloud_run_example.py index c01a4f3821..88db52bc75 100644 --- a/docs/deployment/code_examples/google/cloud_run_example.py +++ b/docs/deployment/code_examples/google/cloud_run_example.py @@ -5,24 +5,23 @@ import uvicorn from litestar import Litestar, get -from crawlee import service_locator from crawlee.crawlers import PlaywrightCrawler, PlaywrightCrawlingContext - -# highlight-start -# Disable writing storage data to the file system -configuration = service_locator.get_configuration() -configuration.persist_storage = False -configuration.write_metadata = False -# highlight-end +from crawlee.storage_clients import MemoryStorageClient @get('/') async def main() -> str: """The crawler entry point that will be called when the HTTP endpoint is accessed.""" + # highlight-start + # Disable writing storage data to the file system + storage_client = MemoryStorageClient() + # highlight-end + crawler = PlaywrightCrawler( headless=True, max_requests_per_crawl=10, browser_type='firefox', + storage_client=storage_client, ) @crawler.router.default_handler diff --git a/docs/deployment/code_examples/google/google_example.py b/docs/deployment/code_examples/google/google_example.py index f7180aa417..e31af2c3ab 100644 --- a/docs/deployment/code_examples/google/google_example.py +++ b/docs/deployment/code_examples/google/google_example.py @@ -6,22 +6,21 @@ import functions_framework from flask import Request, Response -from crawlee import service_locator from crawlee.crawlers import ( BeautifulSoupCrawler, BeautifulSoupCrawlingContext, ) - -# highlight-start -# Disable writing storage data to the file system -configuration = service_locator.get_configuration() -configuration.persist_storage = False -configuration.write_metadata = False -# highlight-end +from crawlee.storage_clients import MemoryStorageClient async def main() -> str: + # highlight-start + # Disable writing storage data to the file system + storage_client = MemoryStorageClient() + # highlight-end + crawler = BeautifulSoupCrawler( + storage_client=storage_client, max_request_retries=1, request_handler_timeout=timedelta(seconds=30), max_requests_per_crawl=10, diff --git a/docs/examples/code_examples/export_entire_dataset_to_file_csv.py b/docs/examples/code_examples/export_entire_dataset_to_file_csv.py index 115474fc61..f86a469c03 100644 --- a/docs/examples/code_examples/export_entire_dataset_to_file_csv.py +++ b/docs/examples/code_examples/export_entire_dataset_to_file_csv.py @@ -30,7 +30,7 @@ async def request_handler(context: BeautifulSoupCrawlingContext) -> None: await crawler.run(['https://crawlee.dev']) # Export the entire dataset to a CSV file. - await crawler.export_data_csv(path='results.csv') + await crawler.export_data(path='results.csv') if __name__ == '__main__': diff --git a/docs/examples/code_examples/export_entire_dataset_to_file_json.py b/docs/examples/code_examples/export_entire_dataset_to_file_json.py index 5c871fb228..81fe07afa4 100644 --- a/docs/examples/code_examples/export_entire_dataset_to_file_json.py +++ b/docs/examples/code_examples/export_entire_dataset_to_file_json.py @@ -30,7 +30,7 @@ async def request_handler(context: BeautifulSoupCrawlingContext) -> None: await crawler.run(['https://crawlee.dev']) # Export the entire dataset to a JSON file. 
- await crawler.export_data_json(path='results.json') + await crawler.export_data(path='results.json') if __name__ == '__main__': diff --git a/docs/examples/code_examples/parsel_crawler.py b/docs/examples/code_examples/parsel_crawler.py index 61ddb7484e..9807d7ca3b 100644 --- a/docs/examples/code_examples/parsel_crawler.py +++ b/docs/examples/code_examples/parsel_crawler.py @@ -40,7 +40,7 @@ async def some_hook(context: BasicCrawlingContext) -> None: await crawler.run(['https://github.com']) # Export the entire dataset to a JSON file. - await crawler.export_data_json(path='results.json') + await crawler.export_data(path='results.json') if __name__ == '__main__': diff --git a/docs/guides/code_examples/storage_clients/custom_storage_client_example.py b/docs/guides/code_examples/storage_clients/custom_storage_client_example.py new file mode 100644 index 0000000000..271b83d811 --- /dev/null +++ b/docs/guides/code_examples/storage_clients/custom_storage_client_example.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from crawlee.storage_clients import StorageClient +from crawlee.storage_clients._base import ( + DatasetClient, + KeyValueStoreClient, + RequestQueueClient, +) + +if TYPE_CHECKING: + from crawlee.configuration import Configuration + +# Implement the storage type clients with your backend logic. + + +class CustomDatasetClient(DatasetClient): + # Implement methods like push_data, get_data, iterate_items, etc. + pass + + +class CustomKeyValueStoreClient(KeyValueStoreClient): + # Implement methods like get_value, set_value, delete, etc. + pass + + +class CustomRequestQueueClient(RequestQueueClient): + # Implement methods like add_request, fetch_next_request, etc. + pass + + +# Implement the storage client factory. + + +class CustomStorageClient(StorageClient): + async def create_dataset_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> CustomDatasetClient: + # Create and return your custom dataset client. + pass + + async def create_kvs_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> CustomKeyValueStoreClient: + # Create and return your custom key-value store client. + pass + + async def create_rq_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> CustomRequestQueueClient: + # Create and return your custom request queue client. + pass diff --git a/docs/guides/code_examples/storage_clients/file_system_storage_client_basic_example.py b/docs/guides/code_examples/storage_clients/file_system_storage_client_basic_example.py new file mode 100644 index 0000000000..62969f8024 --- /dev/null +++ b/docs/guides/code_examples/storage_clients/file_system_storage_client_basic_example.py @@ -0,0 +1,8 @@ +from crawlee.crawlers import ParselCrawler +from crawlee.storage_clients import FileSystemStorageClient + +# Create a new instance of storage client. +storage_client = FileSystemStorageClient() + +# And pass it to the crawler. 
+crawler = ParselCrawler(storage_client=storage_client) diff --git a/docs/guides/code_examples/storage_clients/file_system_storage_client_configuration_example.py b/docs/guides/code_examples/storage_clients/file_system_storage_client_configuration_example.py new file mode 100644 index 0000000000..1d3507660f --- /dev/null +++ b/docs/guides/code_examples/storage_clients/file_system_storage_client_configuration_example.py @@ -0,0 +1,18 @@ +from crawlee.configuration import Configuration +from crawlee.crawlers import ParselCrawler +from crawlee.storage_clients import FileSystemStorageClient + +# Create a new instance of storage client. +storage_client = FileSystemStorageClient() + +# Create a configuration with custom settings. +configuration = Configuration( + storage_dir='./my_storage', + purge_on_start=False, +) + +# And pass them to the crawler. +crawler = ParselCrawler( + storage_client=storage_client, + configuration=configuration, +) diff --git a/docs/guides/code_examples/storage_clients/memory_storage_client_basic_example.py b/docs/guides/code_examples/storage_clients/memory_storage_client_basic_example.py new file mode 100644 index 0000000000..fe79edc3f4 --- /dev/null +++ b/docs/guides/code_examples/storage_clients/memory_storage_client_basic_example.py @@ -0,0 +1,8 @@ +from crawlee.crawlers import ParselCrawler +from crawlee.storage_clients import MemoryStorageClient + +# Create a new instance of storage client. +storage_client = MemoryStorageClient() + +# And pass it to the crawler. +crawler = ParselCrawler(storage_client=storage_client) diff --git a/docs/guides/code_examples/storage_clients/registering_storage_client_example.py b/docs/guides/code_examples/storage_clients/registering_storage_client_example.py new file mode 100644 index 0000000000..995278e7f6 --- /dev/null +++ b/docs/guides/code_examples/storage_clients/registering_storage_client_example.py @@ -0,0 +1,29 @@ +import asyncio + +from crawlee import service_locator +from crawlee.crawlers import ParselCrawler +from crawlee.storage_clients import MemoryStorageClient +from crawlee.storages import Dataset + + +async def main() -> None: + # Create custom storage client, MemoryStorageClient for example. + storage_client = MemoryStorageClient() + + # Register it globally via the service locator. + service_locator.set_storage_client(storage_client) + + # Or pass it directly to the crawler, it will be registered globally + # to the service locator under the hood. + crawler = ParselCrawler(storage_client=storage_client) + + # Or just provide it when opening a storage (e.g. dataset), it will be used + # for this storage only, not globally. + dataset = await Dataset.open( + name='my_dataset', + storage_client=storage_client, + ) + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/docs/guides/code_examples/storages/cleaning_purge_explicitly_example.py b/docs/guides/code_examples/storages/cleaning_purge_explicitly_example.py index 15435da7bf..17911b79d7 100644 --- a/docs/guides/code_examples/storages/cleaning_purge_explicitly_example.py +++ b/docs/guides/code_examples/storages/cleaning_purge_explicitly_example.py @@ -1,20 +1,19 @@ import asyncio -from crawlee.crawlers import HttpCrawler -from crawlee.storage_clients import MemoryStorageClient +from crawlee.storages import Dataset async def main() -> None: - storage_client = MemoryStorageClient.from_config() + # Create storage client with configuration + dataset = await Dataset.open(name='my-dataset') - # Call the purge_on_start method to explicitly purge the storage. 
- # highlight-next-line - await storage_client.purge_on_start() + # Purge the dataset explicitly - purging will remove all items from the dataset. + # But keeps the dataset itself and its metadata. + await dataset.purge() - # Pass the storage client to the crawler. - crawler = HttpCrawler(storage_client=storage_client) - - # ... + # Or you can drop the dataset completely, which will remove the dataset + # and all its items. + await dataset.drop() if __name__ == '__main__': diff --git a/docs/guides/code_examples/storages/dataset_basic_example.py b/docs/guides/code_examples/storages/dataset_basic_example.py index 9b67f36eb0..03b7581f85 100644 --- a/docs/guides/code_examples/storages/dataset_basic_example.py +++ b/docs/guides/code_examples/storages/dataset_basic_example.py @@ -6,7 +6,7 @@ async def main() -> None: # Open the dataset, if it does not exist, it will be created. # Leave name empty to use the default dataset. - dataset = await Dataset.open() + dataset = await Dataset.open(name='my-dataset') # Push a single row of data. await dataset.push_data({'foo': 'bar'}) diff --git a/docs/guides/code_examples/storages/dataset_with_crawler_explicit_example.py b/docs/guides/code_examples/storages/dataset_with_crawler_explicit_example.py index 7c6a613b8f..2b19c86994 100644 --- a/docs/guides/code_examples/storages/dataset_with_crawler_explicit_example.py +++ b/docs/guides/code_examples/storages/dataset_with_crawler_explicit_example.py @@ -7,7 +7,7 @@ async def main() -> None: # Open the dataset, if it does not exist, it will be created. # Leave name empty to use the default dataset. - dataset = await Dataset.open() + dataset = await Dataset.open(name='my-dataset') # Create a new crawler (it can be any subclass of BasicCrawler). crawler = BeautifulSoupCrawler() diff --git a/docs/guides/code_examples/storages/kvs_basic_example.py b/docs/guides/code_examples/storages/kvs_basic_example.py index 7821fa75de..9cc66c59a5 100644 --- a/docs/guides/code_examples/storages/kvs_basic_example.py +++ b/docs/guides/code_examples/storages/kvs_basic_example.py @@ -6,7 +6,7 @@ async def main() -> None: # Open the key-value store, if it does not exist, it will be created. # Leave name empty to use the default KVS. - kvs = await KeyValueStore.open() + kvs = await KeyValueStore.open(name='my-key-value-store') # Set a value associated with 'some-key'. await kvs.set_value(key='some-key', value={'foo': 'bar'}) diff --git a/docs/guides/code_examples/storages/kvs_with_crawler_explicit_example.py b/docs/guides/code_examples/storages/kvs_with_crawler_explicit_example.py index 66a921bd04..4c965457c3 100644 --- a/docs/guides/code_examples/storages/kvs_with_crawler_explicit_example.py +++ b/docs/guides/code_examples/storages/kvs_with_crawler_explicit_example.py @@ -7,7 +7,7 @@ async def main() -> None: # Open the key-value store, if it does not exist, it will be created. # Leave name empty to use the default KVS. - kvs = await KeyValueStore.open() + kvs = await KeyValueStore.open(name='my-key-value-store') # Create a new Playwright crawler. crawler = PlaywrightCrawler() diff --git a/docs/guides/code_examples/storages/rq_basic_example.py b/docs/guides/code_examples/storages/rq_basic_example.py index 9e983bb9fe..388c184fc6 100644 --- a/docs/guides/code_examples/storages/rq_basic_example.py +++ b/docs/guides/code_examples/storages/rq_basic_example.py @@ -12,7 +12,7 @@ async def main() -> None: await request_queue.add_request('https://apify.com/') # Add multiple requests as a batch. 
- await request_queue.add_requests_batched( + await request_queue.add_requests( ['https://crawlee.dev/', 'https://crawlee.dev/python/'] ) diff --git a/docs/guides/code_examples/storages/rq_with_crawler_explicit_example.py b/docs/guides/code_examples/storages/rq_with_crawler_explicit_example.py index 21bedad0b9..aac7b0bcb8 100644 --- a/docs/guides/code_examples/storages/rq_with_crawler_explicit_example.py +++ b/docs/guides/code_examples/storages/rq_with_crawler_explicit_example.py @@ -10,12 +10,10 @@ async def main() -> None: request_queue = await RequestQueue.open(name='my-request-queue') # Interact with the request queue directly, e.g. add a batch of requests. - await request_queue.add_requests_batched( - ['https://apify.com/', 'https://crawlee.dev/'] - ) + await request_queue.add_requests(['https://apify.com/', 'https://crawlee.dev/']) # Create a new crawler (it can be any subclass of BasicCrawler) and pass the request - # list as request manager to it. It will be managed by the crawler. + # queue as request manager to it. It will be managed by the crawler. crawler = HttpCrawler(request_manager=request_queue) # Define the default request handler, which will be called for every request. diff --git a/docs/guides/request_loaders.mdx b/docs/guides/request_loaders.mdx index 73fe374a62..289d7c07ff 100644 --- a/docs/guides/request_loaders.mdx +++ b/docs/guides/request_loaders.mdx @@ -42,7 +42,7 @@ classDiagram %% Abstract classes %% ======================== -class BaseStorage { +class Storage { <> + id + name @@ -52,12 +52,12 @@ class BaseStorage { class RequestLoader { <> + + handled_count + + total_count + fetch_next_request() + mark_request_as_handled() + is_empty() + is_finished() - + get_handled_count() - + get_total_count() + to_tandem() } @@ -92,7 +92,7 @@ class RequestManagerTandem { %% Inheritance arrows %% ======================== -BaseStorage <|-- RequestQueue +Storage <|-- RequestQueue RequestManager <|-- RequestQueue RequestLoader <|-- RequestManager diff --git a/docs/guides/storage_clients.mdx b/docs/guides/storage_clients.mdx new file mode 100644 index 0000000000..6175eb2785 --- /dev/null +++ b/docs/guides/storage_clients.mdx @@ -0,0 +1,141 @@ +--- +id: storage-clients +title: Storage clients +description: How to work with storage clients in Crawlee, including the built-in clients and how to create your own. +--- + +import ApiLink from '@site/src/components/ApiLink'; +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; +import RunnableCodeBlock from '@site/src/components/RunnableCodeBlock'; + +import MemoryStorageClientBasicExample from '!!raw-loader!roa-loader!./code_examples/storage_clients/memory_storage_client_basic_example.py'; +import FileSystemStorageClientBasicExample from '!!raw-loader!roa-loader!./code_examples/storage_clients/file_system_storage_client_basic_example.py'; +import FileSystemStorageClientConfigurationExample from '!!raw-loader!roa-loader!./code_examples/storage_clients/file_system_storage_client_configuration_example.py'; +import CustomStorageClientExample from '!!raw-loader!roa-loader!./code_examples/storage_clients/custom_storage_client_example.py'; + +Storage clients in Crawlee are subclasses of `StorageClient`. They handle interactions with different storage backends. For instance: + +- `MemoryStorageClient`: Stores data purely in memory with no persistence. +- `FileSystemStorageClient`: Provides persistent file system storage with in-memory caching for better performance. 
+- [`ApifyStorageClient`](https://docs.apify.com/sdk/python/reference/class/ApifyStorageClient): Manages storage on the [Apify platform](https://apify.com). Apify storage client is implemented in the [Apify SDK](https://github.com/apify/apify-sdk-python). You will find more information about it in the [Apify SDK documentation](https://docs.apify.com/sdk/python/docs/overview/introduction). + +Each storage client is responsible for maintaining the storages in a specific environment. This abstraction makes it easier to switch between different environments, e.g. between local development and cloud production setup. + +Storage clients provide a unified interface for interacting with `Dataset`, `KeyValueStore`, and `RequestQueue`, regardless of the underlying storage implementation. They handle operations like creating, reading, updating, and deleting storage instances, as well as managing data persistence and cleanup. + +## Built-in storage clients + +Crawlee for Python currently provides two main storage client implementations: + +### Memory storage client + +The `MemoryStorageClient` stores all data in memory using Python data structures. It provides fast access but does not persist data between runs, meaning all data is lost when the program terminates. + + +{MemoryStorageClientBasicExample} + + +The `MemoryStorageClient` is a good choice for testing, development, short-lived operations where speed is more important than data persistence, or HTTP APIs where each request should be handled with a fresh storage. It is not suitable for production use or long-running crawls, as all data will be lost when the program exits. + +:::warning Persistence limitation +The `MemoryStorageClient` does not persist data between runs. All data is lost when the program terminates. +::: + +### File system storage client + +The `FileSystemStorageClient` provides persistent storage by writing data directly to the file system. It uses smart caching and batch processing for better performance while storing data in human-readable JSON format. This is a default storage client used by Crawlee when no other storage client is specified. + +:::warning Concurrency limitation +The `FileSystemStorageClient` is not safe for concurrent access from multiple crawler processes. Use it only when running a single crawler process at a time. +::: + +This storage client is ideal for large datasets, and long-running operations where data persistence is required. Data can be easily inspected and shared with other tools. + + +{FileSystemStorageClientBasicExample} + + +Configuration options for the `FileSystemStorageClient` can be set through environment variables or the `Configuration` class. + - **`storage_dir`** (env: `CRAWLEE_STORAGE_DIR`, default: `'./storage'`): The root directory for all storage data. + - **`purge_on_start`** (env: `CRAWLEE_PURGE_ON_START`, default: `True`): Whether to purge default storages on start. + +Data are stored using the following directory structure: + +```text +{CRAWLEE_STORAGE_DIR}/ +├── datasets/ +│ └── {DATASET_NAME}/ +│ ├── __metadata__.json +│ ├── 000000001.json +│ └── 000000002.json +├── key_value_stores/ +│ └── {KVS_NAME}/ +│ ├── __metadata__.json +│ ├── key1.json +│ ├── key2.txt +│ └── key3.json +└── request_queues/ + └── {RQ_NAME}/ + ├── __metadata__.json + ├── {REQUEST_ID_1}.json + └── {REQUEST_ID_2}.json +``` + +Where: +- `{CRAWLEE_STORAGE_DIR}`: The root directory for local storage. 
- `{DATASET_NAME}`, `{KVS_NAME}`, `{RQ_NAME}`: The unique names for each storage instance (defaults to `"default"`). +- Each storage directory also contains a `__metadata__.json` file that stores the metadata of that storage instance. + +Here is an example of how to configure the `FileSystemStorageClient`: + + +{FileSystemStorageClientConfigurationExample} + + +## Creating a custom storage client + +A storage client consists of two parts: the storage client factory and individual storage type clients. The `StorageClient` acts as a factory that creates specific clients (`DatasetClient`, `KeyValueStoreClient`, `RequestQueueClient`) where the actual storage logic is implemented. + +Here is an example of a custom storage client that implements the `StorageClient` interface: + + +{CustomStorageClientExample} + + +Custom storage clients can implement any storage logic, such as connecting to a database, using a cloud storage service, or integrating with other systems. They must implement the required methods for creating, reading, updating, and deleting data in the respective storages. + +## Registering storage clients + +Storage clients can be registered in two ways: +- Globally, via the `ServiceLocator` or by passing the client directly to the crawler; +- Per storage, when opening a storage instance like `Dataset`, `KeyValueStore`, or `RequestQueue`. + +```python +from crawlee import service_locator +from crawlee.crawlers import ParselCrawler +from crawlee.storages import Dataset + +# Create an instance of your custom storage client (implemented above). +storage_client = CustomStorageClient() + +# Register it globally via the service locator. +service_locator.set_storage_client(storage_client) + +# Or pass it directly to the crawler. +crawler = ParselCrawler(storage_client=storage_client) + +# Or just provide it when opening a storage (e.g. dataset). +dataset = await Dataset.open( + name='my_dataset', + storage_client=storage_client, +) +``` + +You can also register a different storage client for each storage instance, allowing you to use different backends for different storages. This is useful when you want to use, for example, a fast in-memory storage for the `RequestQueue` while persisting scraping results in a `Dataset` or `KeyValueStore`. + +## Conclusion + +Storage clients in Crawlee provide different backends for storages. Use `MemoryStorageClient` for testing and fast operations without persistence, or `FileSystemStorageClient` for environments where data needs to persist. You can also create custom storage clients for specialized backends by implementing the `StorageClient` interface. If you have questions or need assistance, feel free to reach out on our [GitHub](https://github.com/apify/crawlee-python) or join our [Discord community](https://discord.com/invite/jyEM2PRvMU). Happy scraping! 
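To make the per-storage registration above concrete, here is a minimal sketch that pairs a throwaway in-memory request queue with a file-system-backed dataset. It assumes the `storage_client` parameter of the `open()` methods and the `MemoryStorageClient`/`FileSystemStorageClient` classes introduced in this change; the storage names are illustrative only.

```python
import asyncio

from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient
from crawlee.storages import Dataset, RequestQueue


async def main() -> None:
    # Keep the request queue in memory only - fast, but gone when the process exits.
    request_queue = await RequestQueue.open(
        name='my-request-queue',
        storage_client=MemoryStorageClient(),
    )

    # Persist scraping results to the file system.
    dataset = await Dataset.open(
        name='my-dataset',
        storage_client=FileSystemStorageClient(),
    )

    await request_queue.add_requests(['https://crawlee.dev/'])
    await dataset.push_data({'url': 'https://crawlee.dev/'})


if __name__ == '__main__':
    asyncio.run(main())
```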
diff --git a/docs/guides/storages.mdx b/docs/guides/storages.mdx index 3be168b683..22626e7143 100644 --- a/docs/guides/storages.mdx +++ b/docs/guides/storages.mdx @@ -17,7 +17,7 @@ import RqHelperEnqueueLinksExample from '!!raw-loader!roa-loader!./code_examples import DatasetBasicExample from '!!raw-loader!roa-loader!./code_examples/storages/dataset_basic_example.py'; import DatasetWithCrawlerExample from '!!raw-loader!roa-loader!./code_examples/storages/dataset_with_crawler_example.py'; -import DatasetWithCrawerExplicitExample from '!!raw-loader!roa-loader!./code_examples/storages/dataset_with_crawler_explicit_example.py'; +import DatasetWithCrawlerExplicitExample from '!!raw-loader!roa-loader!./code_examples/storages/dataset_with_crawler_explicit_example.py'; import KvsBasicExample from '!!raw-loader!roa-loader!./code_examples/storages/kvs_basic_example.py'; import KvsWithCrawlerExample from '!!raw-loader!roa-loader!./code_examples/storages/kvs_with_crawler_example.py'; @@ -26,51 +26,18 @@ import KvsWithCrawlerExplicitExample from '!!raw-loader!roa-loader!./code_exampl import CleaningDoNotPurgeExample from '!!raw-loader!roa-loader!./code_examples/storages/cleaning_do_not_purge_example.py'; import CleaningPurgeExplicitlyExample from '!!raw-loader!roa-loader!./code_examples/storages/cleaning_purge_explicitly_example.py'; -Crawlee offers multiple storage types for managing and persisting your crawling data. Request-oriented storages, such as the `RequestQueue`, help you store and deduplicate URLs, while result-oriented storages, like `Dataset` and `KeyValueStore`, focus on storing and retrieving scraping results. This guide helps you choose the storage type that suits your needs. +Crawlee offers several storage types for managing and persisting your crawling data. Request-oriented storages, such as the `RequestQueue`, help you store and deduplicate URLs, while result-oriented storages, like `Dataset` and `KeyValueStore`, focus on storing and retrieving scraping results. This guide helps you choose the storage type that suits your needs. -## Storage clients +Crawlee's storage system consists of two main layers: +- **Storages** (`Dataset`, `KeyValueStore`, `RequestQueue`): High-level interfaces for interacting with different storage types. +- **Storage clients** (`MemoryStorageClient`, `FileSystemStorageClient`, etc.): Backend implementations that handle the actual data persistence and management. -Storage clients in Crawlee are subclasses of `StorageClient`. They handle interactions with different storage backends. For instance: - -- `MemoryStorageClient`: Stores data in memory and persists it to the local file system. -- [`ApifyStorageClient`](https://docs.apify.com/sdk/python/reference/class/ApifyStorageClient): Manages storage on the [Apify platform](https://apify.com). Apify storage client is implemented in the [Apify SDK](https://github.com/apify/apify-sdk-python). - -Each storage client is responsible for maintaining the storages in a specific environment. This abstraction makes it easier to switch between different environments, e.g. between local development and cloud production setup. - -### Memory storage client - -The `MemoryStorageClient` is the default and currently the only one storage client in Crawlee. It stores data in memory and persists it to the local file system. 
The data are stored in the following directory structure: - -```text -{CRAWLEE_STORAGE_DIR}/{storage_type}/{STORAGE_ID}/ -``` - -where: - -- `{CRAWLEE_STORAGE_DIR}`: The root directory for local storage, specified by the `CRAWLEE_STORAGE_DIR` environment variable (default: `./storage`). -- `{storage_type}`: The type of storage (e.g., `datasets`, `key_value_stores`, `request_queues`). -- `{STORAGE_ID}`: The ID of the specific storage instance (default: `default`). - -:::info NOTE -The current `MemoryStorageClient` and its interface is quite old and not great. We plan to refactor it, together with the whole `StorageClient` interface in the near future and it better and and easier to use. We also plan to introduce new storage clients for different storage backends - e.g. for [SQLite](https://sqlite.org/). -::: - -You can override default storage IDs using these environment variables: `CRAWLEE_DEFAULT_DATASET_ID`, `CRAWLEE_DEFAULT_KEY_VALUE_STORE_ID`, or `CRAWLEE_DEFAULT_REQUEST_QUEUE_ID`. +For more information about storage clients and their configuration, see the [Storage clients guide](./storage-clients). ## Request queue The `RequestQueue` is the primary storage for URLs in Crawlee, especially useful for deep crawling. It supports dynamic addition and removal of URLs, making it ideal for recursive tasks where URLs are discovered and added during the crawling process (e.g., following links across multiple pages). Each Crawlee project has a **default request queue**, which can be used to store URLs during a specific run. The `RequestQueue` is highly useful for large-scale and complex crawls. -By default, data are stored using the following path structure: - -```text -{CRAWLEE_STORAGE_DIR}/request_queues/{QUEUE_ID}/{INDEX}.json -``` - -- `{CRAWLEE_STORAGE_DIR}`: The root directory for all storage data, specified by the environment variable. -- `{QUEUE_ID}`: The ID of the request queue, "default" by default. -- `{INDEX}`: Represents the zero-based index of the record within the queue. - The following code demonstrates the usage of the `RequestQueue`: @@ -123,15 +90,6 @@ For a detailed explanation of the `RequestMan The `Dataset` is designed for storing structured data, where each entry has a consistent set of attributes, such as products in an online store or real estate listings. Think of a `Dataset` as a table: each entry corresponds to a row, with attributes represented as columns. Datasets are append-only, allowing you to add new records but not modify or delete existing ones. Every Crawlee project run is associated with a default dataset, typically used to store results specific to that crawler execution. However, using this dataset is optional. -By default, data are stored using the following path structure: - -```text -{CRAWLEE_STORAGE_DIR}/datasets/{DATASET_ID}/{INDEX}.json -``` -- `{CRAWLEE_STORAGE_DIR}`: The root directory for all storage data specified by the environment variable. -- `{DATASET_ID}`: The dataset's ID, "default" by default. -- `{INDEX}`: Represents the zero-based index of the record within the dataset. - The following code demonstrates basic operations of the dataset: @@ -147,7 +105,7 @@ The following code demonstrates basic operations of the dataset: - {DatasetWithCrawerExplicitExample} + {DatasetWithCrawlerExplicitExample} @@ -162,16 +120,6 @@ Crawlee provides the following helper function to simplify interactions with the The `KeyValueStore` is designed to save and retrieve data records or files efficiently. 
Each record is uniquely identified by a key and is associated with a specific MIME type, making the `KeyValueStore` ideal for tasks like saving web page screenshots, PDFs, or tracking the state of crawlers. -By default, data are stored using the following path structure: - -```text -{CRAWLEE_STORAGE_DIR}/key_value_stores/{STORE_ID}/{KEY}.{EXT} -``` -- `{CRAWLEE_STORAGE_DIR}`: The root directory for all storage data specified by the environment variable. -- `{STORE_ID}`: The KVS's ID, "default" by default. -- `{KEY}`: The unique key for the record. -- `{EXT}`: The file extension corresponding to the MIME type of the content. - The following code demonstrates the usage of the `KeyValueStore`: @@ -202,20 +150,39 @@ Crawlee provides the following helper function to simplify interactions with the ## Cleaning up the storages -Default storages are purged before the crawler starts, unless explicitly configured otherwise. For that case, see `Configuration.purge_on_start`. This cleanup happens as soon as a storage is accessed, either when you open a storage (e.g. using `RequestQueue.open`, `Dataset.open`, `KeyValueStore.open`) or when interacting with a storage through one of the helper functions (e.g. `push_data`), which implicitly opens the result storage. +By default, Crawlee automatically cleans up **default storages** before each crawler run to ensure a clean state. This behavior is controlled by the `Configuration.purge_on_start` setting (default: `True`). + +### What gets purged + +- **Default storages** are completely removed and recreated at the start of each run, ensuring that you start with a clean slate. +- **Named storages** are never automatically purged and persist across runs. +- The behavior depends on the storage client implementation. + +### When purging happens + +The cleanup occurs as soon as a storage is accessed: +- When opening a storage explicitly (e.g., `RequestQueue.open`, `Dataset.open`, `KeyValueStore.open`). +- When using helper functions that implicitly open storages (e.g., `push_data`). +- Automatically when `BasicCrawler.run` is invoked. + +### Disabling automatic purging + +To disable automatic purging, set `purge_on_start=False` in your configuration: {CleaningDoNotPurgeExample} -If you do not explicitly interact with storages in your code, the purging will occur automatically when the `BasicCrawler.run` method is invoked. +### Manual purging -If you need to purge storages earlier, you can call `MemoryStorageClient.purge_on_start` directly if you are using the default storage client. This method triggers the purging process for the underlying storage implementation you are currently using. +Purge on start behavior just triggers the storage's `purge` method, which removes all data from the storage. If you want to purge the storage manually, you can do so by calling the `purge` method on the storage instance. Or if you want to delete the storage completely, you can call the `drop` method on the storage instance, which will remove the storage, including metadata and all its data. {CleaningPurgeExplicitlyExample} +Note that purging behavior may vary between storage client implementations. For more details on storage configuration and client implementations, see the [Storage clients guide](./storage-clients). + ## Conclusion -This guide introduced you to the different storage types available in Crawlee and how to interact with them. You learned how to manage requests and store and retrieve scraping results using the `RequestQueue`, `Dataset`, and `KeyValueStore`. 
You also discovered how to use helper functions to simplify interactions with these storages. Finally, you learned how to clean up storages before starting a crawler run and how to purge them explicitly. If you have questions or need assistance, feel free to reach out on our [GitHub](https://github.com/apify/crawlee-python) or join our [Discord community](https://discord.com/invite/jyEM2PRvMU). Happy scraping! +This guide introduced you to the different storage types available in Crawlee and how to interact with them. You learned how to manage requests using the `RequestQueue` and store and retrieve scraping results using the `Dataset` and `KeyValueStore`. You also discovered how to use helper functions to simplify interactions with these storages. Finally, you learned how to clean up storages before starting a crawler run. If you have questions or need assistance, feel free to reach out on our [GitHub](https://github.com/apify/crawlee-python) or join our [Discord community](https://discord.com/invite/jyEM2PRvMU). Happy scraping! diff --git a/docs/upgrading/upgrading_to_v1.md b/docs/upgrading/upgrading_to_v1.md new file mode 100644 index 0000000000..1d7219dbb4 --- /dev/null +++ b/docs/upgrading/upgrading_to_v1.md @@ -0,0 +1,148 @@ +--- +id: upgrading-to-v1 +title: Upgrading to v1 +--- + +This page summarizes the breaking changes between Crawlee for Python v0.6 and v1.0. + +## Storage clients + +In v1.0, we are introducing a new storage clients system. We have completely reworked their interface, +making it much simpler to write your own storage clients. This allows you to easily store your request queues, +key-value stores, and datasets in various destinations. + +### New storage clients + +Previously, the `MemoryStorageClient` handled both in-memory storage and file system persistence, depending +on configuration. In v1.0, we've split this into two dedicated classes: + +- `MemoryStorageClient` - stores all data in memory only. +- `FileSystemStorageClient` - persists data on the file system, with in-memory caching for improved performance. + +For details about the new interface, see the `BaseStorageClient` documentation. You can also check out +the [Storage clients guide](https://crawlee.dev/python/docs/guides/) for more information on available +storage clients and instructions on writing your own. + +### Memory storage client + +Before: + +```python +from crawlee.configuration import Configuration +from crawlee.storage_clients import MemoryStorageClient + +configuration = Configuration(persist_storage=False) +storage_client = MemoryStorageClient.from_config(configuration) +``` + +Now: + +```python +from crawlee.storage_clients import MemoryStorageClient + +storage_client = MemoryStorageClient() +``` + +### File-system storage client + +Before: + +```python +from crawlee.configuration import Configuration +from crawlee.storage_clients import MemoryStorageClient + +configuration = Configuration(persist_storage=True) +storage_client = MemoryStorageClient.from_config(configuration) +``` + +Now: + +```python +from crawlee.storage_clients import FileSystemStorageClient + +storage_client = FileSystemStorageClient() +``` + +The way you register storage clients remains the same: + +```python +from crawlee import service_locator +from crawlee.crawlers import ParselCrawler +from crawlee.storage_clients import MemoryStorageClient +from crawlee.storages import Dataset + +# Create custom storage client, MemoryStorageClient for example. 
+storage_client = MemoryStorageClient() + +# Register it globally via the service locator. +service_locator.set_storage_client(storage_client) + +# Or pass it directly to the crawler, it will be registered globally +# to the service locator under the hood. +crawler = ParselCrawler(storage_client=storage_client) + +# Or just provide it when opening a storage (e.g. dataset), it will be used +# for this storage only, not globally. +dataset = await Dataset.open( + name='my_dataset', + storage_client=storage_client, +) +``` + +### Breaking changes + +The `persist_storage` and `persist_metadata` fields have been removed from the `Configuration` class. +Persistence is now determined solely by the storage client class you use. + +### Storage client instance behavior + +Instance caching is implemented for the storage open methods: `Dataset.open()`, `KeyValueStore.open()`, +and `RequestQueue.open()`. This means that when you call these methods with the same arguments, +the same instance is returned each time. + +In contrast, when using client methods such as `StorageClient.create_dataset_client()`, each call creates +a new `DatasetClient` instance, even if the arguments are identical. These methods do not use instance caching. + +This usage pattern is not common, and it is generally recommended to open storages using the standard storage +open methods rather than the storage client methods. + +### Writing custom storage clients + +The storage client interface has been fully reworked. Collection storage clients have been removed - now there is +one storage client class per storage type (`RequestQueue`, `KeyValueStore`, and `Dataset`). Writing your own storage +clients is now much simpler, allowing you to store your request queues, key-value stores, and datasets in any +destination you choose. + +## Dataset + +- There are a few new methods: + - `get_metadata` + - `purge` + - `list_items` +- The `from_storage_object` method has been removed - use the `open` method with `name` or `id` instead. +- The `get_info` and `storage_object` properties have been replaced by the new `get_metadata` method. +- The `set_metadata` method has been removed. +- The `write_to_json` and `write_to_csv` methods have been removed - use `export_to` instead. + +## Key-value store + +- There are a few new methods: + - `get_metadata` + - `purge` + - `delete_value` + - `list_keys` +- The `from_storage_object` method has been removed - use the `open` method with `name` or `id` instead. +- The `get_info` and `storage_object` properties have been replaced by the new `get_metadata` method. +- The `set_metadata` method has been removed. + +## Request queue + +- There are a few new methods: + - `get_metadata` + - `purge` + - `add_requests` (renamed from `add_requests_batched`) +- The `from_storage_object` method has been removed - use the `open` method with `name` or `id` instead. +- The `get_info` and `storage_object` properties have been replaced by the new `get_metadata` method. +- The `set_metadata` method has been removed. +- `resource_directory` from `RequestQueueMetadata` removed – use `path_to_...` property. +- `RequestQueueHead` model replaced with `RequestQueueHeadWithLocks`. 
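To make the renamed methods easier to apply, here is a rough migration sketch; the old v0.6 calls are shown in comments, and the exact signatures (for example the keyword arguments of `export_to`) are assumed from the changes listed above rather than taken verbatim from the final API.

```python
import asyncio

from crawlee.storages import Dataset, RequestQueue


async def main() -> None:
    request_queue = await RequestQueue.open(name='my-request-queue')
    dataset = await Dataset.open(name='my-dataset')

    # v0.6: await request_queue.add_requests_batched([...])
    await request_queue.add_requests(['https://crawlee.dev/'])

    # v0.6: dataset.get_info() or dataset.storage_object
    metadata = await dataset.get_metadata()
    print(metadata)

    # v0.6: await dataset.write_to_json(...) / await dataset.write_to_csv(...)
    await dataset.export_to(key='results', content_type='json')


if __name__ == '__main__':
    asyncio.run(main())
```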
diff --git a/pyproject.toml b/pyproject.toml index 686aa10c3f..db65c9b3ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -144,9 +144,9 @@ ignore = [ "ISC001", # This rule may cause conflicts when used with the formatter "FIX", # flake8-fixme "PLR0911", # Too many return statements + "PLR0912", # Too many branches "PLR0913", # Too many arguments in function definition "PLR0915", # Too many statements - "PTH", # flake8-use-pathlib "PYI034", # `__aenter__` methods in classes like `{name}` usually return `self` at runtime "PYI036", # The second argument in `__aexit__` should be annotated with `object` or `BaseException | None` "S102", # Use of `exec` detected @@ -168,6 +168,7 @@ indent-style = "space" "F401", # Unused imports ] "**/{tests}/*" = [ + "ASYNC230", # Async functions should not open files with blocking methods like `open` "D", # Everything from the pydocstyle "INP001", # File {filename} is part of an implicit namespace package, add an __init__.py "PLR2004", # Magic value used in comparison, consider replacing {value} with a constant variable @@ -205,9 +206,6 @@ builtins-ignorelist = ["id"] [tool.ruff.lint.isort] known-first-party = ["crawlee"] -[tool.ruff.lint.pylint] -max-branches = 18 - [tool.pytest.ini_options] addopts = "-ra" asyncio_default_fixture_loop_scope = "function" @@ -220,7 +218,10 @@ markers = [ [tool.mypy] python_version = "3.9" plugins = ["pydantic.mypy"] -exclude = ["src/crawlee/project_template"] +exclude = [ + "src/crawlee/project_template", + "docs/guides/code_examples/storage_clients/custom_storage_client_example.py", +] files = ["src", "tests", "docs", "website"] check_untyped_defs = true disallow_incomplete_defs = true @@ -256,7 +257,7 @@ ignore_missing_imports = true [[tool.mypy.overrides]] module = [ - "running_in_web_server.*" # False positive when fastapi not available + "running_in_web_server.*", # False positive when fastapi not available ] disable_error_code = ["misc"] diff --git a/src/crawlee/_cli.py b/src/crawlee/_cli.py index 60d8d1a138..d7eadde35c 100644 --- a/src/crawlee/_cli.py +++ b/src/crawlee/_cli.py @@ -22,7 +22,7 @@ cli = typer.Typer(no_args_is_help=True) template_directory = importlib.resources.files('crawlee') / 'project_template' -with open(str(template_directory / 'cookiecutter.json')) as f: +with (template_directory / 'cookiecutter.json').open() as f: cookiecutter_json = json.load(f) crawler_choices = cookiecutter_json['crawler_type'] diff --git a/src/crawlee/_consts.py b/src/crawlee/_consts.py index d8d40087b0..9345e53e98 100644 --- a/src/crawlee/_consts.py +++ b/src/crawlee/_consts.py @@ -1,3 +1,4 @@ from __future__ import annotations METADATA_FILENAME = '__metadata__.json' +"""The name of the metadata file for storage clients.""" diff --git a/src/crawlee/_request.py b/src/crawlee/_request.py index adb43949ea..a3581e7ebf 100644 --- a/src/crawlee/_request.py +++ b/src/crawlee/_request.py @@ -160,6 +160,22 @@ class Request(BaseModel): model_config = ConfigDict(populate_by_name=True) + id: str + """A unique identifier for the request. Note that this is not used for deduplication, and should not be confused + with `unique_key`.""" + + unique_key: Annotated[str, Field(alias='uniqueKey')] + """A unique key identifying the request. Two requests with the same `unique_key` are considered as pointing + to the same URL. + + If `unique_key` is not provided, then it is automatically generated by normalizing the URL. 
+ For example, the URL of `HTTP://www.EXAMPLE.com/something/` will produce the `unique_key` + of `http://www.example.com/something`. + + Pass an arbitrary non-empty text value to the `unique_key` property to override the default behavior + and specify which URLs shall be considered equal. + """ + url: Annotated[str, BeforeValidator(validate_http_url), Field()] """The URL of the web page to crawl. Must be a valid HTTP or HTTPS URL, and may include query parameters and fragments.""" @@ -207,22 +223,6 @@ class Request(BaseModel): handled_at: Annotated[datetime | None, Field(alias='handledAt')] = None """Timestamp when the request was handled.""" - unique_key: Annotated[str, Field(alias='uniqueKey')] - """A unique key identifying the request. Two requests with the same `unique_key` are considered as pointing - to the same URL. - - If `unique_key` is not provided, then it is automatically generated by normalizing the URL. - For example, the URL of `HTTP://www.EXAMPLE.com/something/` will produce the `unique_key` - of `http://www.example.com/something`. - - Pass an arbitrary non-empty text value to the `unique_key` property - to override the default behavior and specify which URLs shall be considered equal. - """ - - id: str - """A unique identifier for the request. Note that this is not used for deduplication, and should not be confused - with `unique_key`.""" - @classmethod def from_url( cls, @@ -398,6 +398,11 @@ def forefront(self) -> bool: def forefront(self, new_value: bool) -> None: self.crawlee_data.forefront = new_value + @property + def was_already_handled(self) -> bool: + """Indicates whether the request was handled.""" + return self.handled_at is not None + class RequestWithLock(Request): """A crawling request with information about locks.""" diff --git a/src/crawlee/_service_locator.py b/src/crawlee/_service_locator.py index 31bc36c63c..52f934a881 100644 --- a/src/crawlee/_service_locator.py +++ b/src/crawlee/_service_locator.py @@ -1,10 +1,15 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from crawlee._utils.docs import docs_group from crawlee.configuration import Configuration from crawlee.errors import ServiceConflictError -from crawlee.events import EventManager -from crawlee.storage_clients import StorageClient +from crawlee.events import EventManager, LocalEventManager +from crawlee.storage_clients import FileSystemStorageClient, StorageClient + +if TYPE_CHECKING: + from crawlee.storages._storage_instance_manager import StorageInstanceManager @docs_group('Classes') @@ -18,6 +23,7 @@ def __init__(self) -> None: self._configuration: Configuration | None = None self._event_manager: EventManager | None = None self._storage_client: StorageClient | None = None + self._storage_instance_manager: StorageInstanceManager | None = None # Flags to check if the services were already set. 
self._configuration_was_retrieved = False @@ -49,8 +55,6 @@ def set_configuration(self, configuration: Configuration) -> None: def get_event_manager(self) -> EventManager: """Get the event manager.""" if self._event_manager is None: - from crawlee.events import LocalEventManager - self._event_manager = ( LocalEventManager().from_config(config=self._configuration) if self._configuration @@ -77,13 +81,7 @@ def set_event_manager(self, event_manager: EventManager) -> None: def get_storage_client(self) -> StorageClient: """Get the storage client.""" if self._storage_client is None: - from crawlee.storage_clients import MemoryStorageClient - - self._storage_client = ( - MemoryStorageClient.from_config(config=self._configuration) - if self._configuration - else MemoryStorageClient.from_config() - ) + self._storage_client = FileSystemStorageClient() self._storage_client_was_retrieved = True return self._storage_client @@ -102,5 +100,16 @@ def set_storage_client(self, storage_client: StorageClient) -> None: self._storage_client = storage_client + @property + def storage_instance_manager(self) -> StorageInstanceManager: + """Get the storage instance manager.""" + if self._storage_instance_manager is None: + # Import here to avoid circular imports. + from crawlee.storages._storage_instance_manager import StorageInstanceManager + + self._storage_instance_manager = StorageInstanceManager() + + return self._storage_instance_manager + service_locator = ServiceLocator() diff --git a/src/crawlee/_types.py b/src/crawlee/_types.py index 3cb84111fe..6dd758958c 100644 --- a/src/crawlee/_types.py +++ b/src/crawlee/_types.py @@ -3,27 +3,41 @@ import dataclasses from collections.abc import Iterator, Mapping from dataclasses import dataclass -from enum import Enum -from typing import TYPE_CHECKING, Annotated, Any, Callable, Literal, Optional, Protocol, TypeVar, Union, cast, overload +from typing import ( + TYPE_CHECKING, + Annotated, + Any, + Callable, + Literal, + Optional, + Protocol, + TypedDict, + TypeVar, + Union, + cast, + overload, +) from pydantic import ConfigDict, Field, PlainValidator, RootModel -from typing_extensions import NotRequired, TypeAlias, TypedDict, Unpack from crawlee._utils.docs import docs_group if TYPE_CHECKING: + import json import logging import re - from collections.abc import Coroutine, Sequence + from collections.abc import Callable, Coroutine, Sequence + + from typing_extensions import NotRequired, Required, TypeAlias, Unpack from crawlee import Glob, Request from crawlee._request import RequestOptions + from crawlee.configuration import Configuration from crawlee.http_clients import HttpResponse from crawlee.proxy_configuration import ProxyInfo from crawlee.sessions import Session - from crawlee.storage_clients.models import DatasetItemsListPage + from crawlee.storage_clients import StorageClient from crawlee.storages import KeyValueStore - from crawlee.storages._dataset import ExportToKwargs, GetDataKwargs # Workaround for https://github.com/pydantic/pydantic/issues/9445 J = TypeVar('J', bound='JsonSerializable') @@ -138,15 +152,6 @@ def __init__( self.max_tasks_per_minute = max_tasks_per_minute -@docs_group('Data structures') -class StorageTypes(str, Enum): - """Possible Crawlee storage types.""" - - DATASET = 'Dataset' - KEY_VALUE_STORE = 'Key-value store' - REQUEST_QUEUE = 'Request queue' - - class EnqueueLinksKwargs(TypedDict): """Keyword arguments for the `enqueue_links` methods.""" @@ -190,7 +195,7 @@ class PushDataKwargs(TypedDict): class 
PushDataFunctionCall(PushDataKwargs): - data: JsonSerializable + data: list[dict[str, Any]] | dict[str, Any] dataset_id: str | None dataset_name: str | None @@ -271,16 +276,12 @@ async def add_requests( async def push_data( self, - data: JsonSerializable, + data: list[dict[str, Any]] | dict[str, Any], dataset_id: str | None = None, dataset_name: str | None = None, **kwargs: Unpack[PushDataKwargs], ) -> None: """Track a call to the `push_data` context helper.""" - from crawlee.storages._dataset import Dataset - - await Dataset.check_and_serialize(data) - self.push_data_calls.append( PushDataFunctionCall( data=data, @@ -420,55 +421,6 @@ def __call__( """ -@docs_group('Functions') -class ExportToFunction(Protocol): - """A function for exporting data from a `Dataset`. - - It simplifies the process of exporting data from a `Dataset`. It opens the specified one and exports - its content to a `KeyValueStore`. - """ - - def __call__( - self, - dataset_id: str | None = None, - dataset_name: str | None = None, - **kwargs: Unpack[ExportToKwargs], - ) -> Coroutine[None, None, None]: - """Call dunder method. - - Args: - dataset_id: The ID of the `Dataset` to export data from. - dataset_name: The name of the `Dataset` to export data from. - **kwargs: Additional keyword arguments. - """ - - -@docs_group('Functions') -class GetDataFunction(Protocol): - """A function for retrieving data from a `Dataset`. - - It simplifies the process of accessing data from a `Dataset`. It opens the specified one and retrieves - data based on the provided parameters. It allows filtering and pagination. - """ - - def __call__( - self, - dataset_id: str | None = None, - dataset_name: str | None = None, - **kwargs: Unpack[GetDataKwargs], - ) -> Coroutine[None, None, DatasetItemsListPage]: - """Call dunder method. - - Args: - dataset_id: ID of the `Dataset` to get data from. - dataset_name: Name of the `Dataset` to get data from. - **kwargs: Additional keyword arguments. - - Returns: - A page of retrieved items. - """ - - @docs_group('Functions') class GetKeyValueStoreFunction(Protocol): """A function for accessing a `KeyValueStore`. @@ -520,7 +472,7 @@ class PushDataFunction(Protocol): def __call__( self, - data: JsonSerializable, + data: list[dict[str, Any]] | dict[str, Any], dataset_id: str | None = None, dataset_name: str | None = None, **kwargs: Unpack[PushDataKwargs], @@ -579,18 +531,6 @@ def __bool__(self) -> bool: return bool(self.screenshot or self.html) -@docs_group('Functions') -class GetPageSnapshot(Protocol): - """A function for getting snapshot of a page.""" - - def __call__(self) -> Coroutine[None, None, PageSnapshot]: - """Get page snapshot. - - Returns: - Snapshot of a page. - """ - - @docs_group('Functions') class UseStateFunction(Protocol): """A function for managing state within the crawling context. @@ -658,3 +598,133 @@ async def get_snapshot(self) -> PageSnapshot: def __hash__(self) -> int: """Return hash of the context. Each context is considered unique.""" return id(self) + + +class GetDataKwargs(TypedDict): + """Keyword arguments for dataset's `get_data` method.""" + + offset: NotRequired[int] + """Skips the specified number of items at the start.""" + + limit: NotRequired[int | None] + """The maximum number of items to retrieve. Unlimited if None.""" + + clean: NotRequired[bool] + """Return only non-empty items and excludes hidden fields. 
Shortcut for `skip_hidden` and `skip_empty`.""" + + desc: NotRequired[bool] + """Set to True to sort results in descending order.""" + + fields: NotRequired[list[str]] + """Fields to include in each item. Sorts fields as specified if provided.""" + + omit: NotRequired[list[str]] + """Fields to exclude from each item.""" + + unwind: NotRequired[str] + """Unwinds items by a specified array field, turning each element into a separate item.""" + + skip_empty: NotRequired[bool] + """Excludes empty items from the results if True.""" + + skip_hidden: NotRequired[bool] + """Excludes fields starting with '#' if True.""" + + flatten: NotRequired[list[str]] + """Fields to be flattened in returned items.""" + + view: NotRequired[str] + """Specifies the dataset view to be used.""" + + +class ExportToKwargs(TypedDict): + """Keyword arguments for dataset's `export_to` method.""" + + key: Required[str] + """The key under which to save the data.""" + + content_type: NotRequired[Literal['json', 'csv']] + """The format in which to export the data. Either 'json' or 'csv'.""" + + to_kvs_id: NotRequired[str] + """ID of the key-value store to save the exported file.""" + + to_kvs_name: NotRequired[str] + """Name of the key-value store to save the exported file.""" + + to_kvs_storage_client: NotRequired[StorageClient] + """The storage client to use for saving the exported file.""" + + to_kvs_configuration: NotRequired[Configuration] + """The configuration to use for saving the exported file.""" + + +class ExportDataJsonKwargs(TypedDict): + """Keyword arguments for dataset's `export_data_json` method.""" + + skipkeys: NotRequired[bool] + """If True (default: False), dict keys that are not of a basic type (str, int, float, bool, None) will be skipped + instead of raising a `TypeError`.""" + + ensure_ascii: NotRequired[bool] + """Determines if non-ASCII characters should be escaped in the output JSON string.""" + + check_circular: NotRequired[bool] + """If False (default: True), skips the circular reference check for container types. A circular reference will + result in a `RecursionError` or worse if unchecked.""" + + allow_nan: NotRequired[bool] + """If False (default: True), raises a ValueError for out-of-range float values (nan, inf, -inf) to strictly comply + with the JSON specification. If True, uses their JavaScript equivalents (NaN, Infinity, -Infinity).""" + + cls: NotRequired[type[json.JSONEncoder]] + """Allows specifying a custom JSON encoder.""" + + indent: NotRequired[int] + """Specifies the number of spaces for indentation in the pretty-printed JSON output.""" + + separators: NotRequired[tuple[str, str]] + """A tuple of (item_separator, key_separator). The default is (', ', ': ') if indent is None and (',', ': ') + otherwise.""" + + default: NotRequired[Callable] + """A function called for objects that can't be serialized otherwise. It should return a JSON-encodable version + of the object or raise a `TypeError`.""" + + sort_keys: NotRequired[bool] + """Specifies whether the output JSON object should have keys sorted alphabetically.""" + + +class ExportDataCsvKwargs(TypedDict): + """Keyword arguments for dataset's `export_data_csv` method.""" + + dialect: NotRequired[str] + """Specifies a dialect to be used in CSV parsing and writing.""" + + delimiter: NotRequired[str] + """A one-character string used to separate fields. Defaults to ','.""" + + doublequote: NotRequired[bool] + """Controls how instances of `quotechar` inside a field should be quoted. 
When True, the character is doubled; + when False, the `escapechar` is used as a prefix. Defaults to True.""" + + escapechar: NotRequired[str] + """A one-character string used to escape the delimiter if `quoting` is set to `QUOTE_NONE` and the `quotechar` + if `doublequote` is False. Defaults to None, disabling escaping.""" + + lineterminator: NotRequired[str] + """The string used to terminate lines produced by the writer. Defaults to '\\r\\n'.""" + + quotechar: NotRequired[str] + """A one-character string used to quote fields containing special characters, like the delimiter or quotechar, + or fields containing new-line characters. Defaults to '\"'.""" + + quoting: NotRequired[int] + """Controls when quotes should be generated by the writer and recognized by the reader. Can take any of + the `QUOTE_*` constants, with a default of `QUOTE_MINIMAL`.""" + + skipinitialspace: NotRequired[bool] + """When True, spaces immediately following the delimiter are ignored. Defaults to False.""" + + strict: NotRequired[bool] + """When True, raises an exception on bad CSV input. Defaults to False.""" diff --git a/src/crawlee/_utils/data_processing.py b/src/crawlee/_utils/data_processing.py deleted file mode 100644 index e423650952..0000000000 --- a/src/crawlee/_utils/data_processing.py +++ /dev/null @@ -1,41 +0,0 @@ -from __future__ import annotations - -import json -from enum import Enum -from typing import TYPE_CHECKING, Any, NoReturn - -from crawlee._utils.file import ContentType, is_content_type - -if TYPE_CHECKING: - from crawlee._types import StorageTypes - - -def maybe_extract_enum_member_value(maybe_enum_member: Any) -> Any: - """Extract the value of an enumeration member if it is an Enum, otherwise return the original value.""" - if isinstance(maybe_enum_member, Enum): - return maybe_enum_member.value - return maybe_enum_member - - -def maybe_parse_body(body: bytes, content_type: str) -> Any: - """Parse the response body based on the content type.""" - if is_content_type(ContentType.JSON, content_type): - s = body.decode('utf-8') - return json.loads(s) - - if is_content_type(ContentType.XML, content_type) or is_content_type(ContentType.TEXT, content_type): - return body.decode('utf-8') - - return body - - -def raise_on_duplicate_storage(client_type: StorageTypes, key_name: str, value: str) -> NoReturn: - """Raise an error indicating that a storage with the provided key name and value already exists.""" - client_type = maybe_extract_enum_member_value(client_type) - raise ValueError(f'{client_type} with {key_name} "{value}" already exists.') - - -def raise_on_non_existing_storage(client_type: StorageTypes, id: str | None) -> NoReturn: - """Raise an error indicating that a storage with the provided id does not exist.""" - client_type = maybe_extract_enum_member_value(client_type) - raise ValueError(f'{client_type} with id "{id}" does not exist.') diff --git a/src/crawlee/_utils/docs.py b/src/crawlee/_utils/docs.py index 08d73addf1..8f0120ca99 100644 --- a/src/crawlee/_utils/docs.py +++ b/src/crawlee/_utils/docs.py @@ -1,11 +1,13 @@ from __future__ import annotations -from typing import Callable, Literal +from typing import Any, Callable, Literal, TypeVar GroupName = Literal['Classes', 'Abstract classes', 'Data structures', 'Event payloads', 'Errors', 'Functions'] +T = TypeVar('T', bound=Callable[..., Any]) -def docs_group(group_name: GroupName) -> Callable: # noqa: ARG001 + +def docs_group(group_name: GroupName) -> Callable[[T], T]: # noqa: ARG001 """Mark a symbol for rendering and grouping in 
documentation. This decorator is used solely for documentation purposes and does not modify the behavior @@ -18,7 +20,7 @@ def docs_group(group_name: GroupName) -> Callable: # noqa: ARG001 The original callable without modification. """ - def wrapper(func: Callable) -> Callable: + def wrapper(func: T) -> T: return func return wrapper diff --git a/src/crawlee/_utils/file.py b/src/crawlee/_utils/file.py index 022d0604ef..c7190b739a 100644 --- a/src/crawlee/_utils/file.py +++ b/src/crawlee/_utils/file.py @@ -1,110 +1,178 @@ from __future__ import annotations import asyncio -import contextlib -import io +import csv import json -import mimetypes import os -import re -import shutil -from enum import Enum -from typing import TYPE_CHECKING +import sys +import tempfile +from pathlib import Path +from typing import TYPE_CHECKING, overload if TYPE_CHECKING: - from pathlib import Path - from typing import Any - - -class ContentType(Enum): - JSON = r'^application/json' - TEXT = r'^text/' - XML = r'^application/.*xml$' - - def matches(self, content_type: str) -> bool: - """Check if the content type matches the enum's pattern.""" - return bool(re.search(self.value, content_type, re.IGNORECASE)) - - -def is_content_type(content_type_enum: ContentType, content_type: str) -> bool: - """Check if the provided content type string matches the specified ContentType.""" - return content_type_enum.matches(content_type) - - -async def force_remove(filename: str | Path) -> None: - """Remove a file, suppressing the FileNotFoundError if it does not exist. - - JS-like rm(filename, { force: true }). + from collections.abc import AsyncIterator + from typing import Any, TextIO + + from typing_extensions import Unpack + + from crawlee._types import ExportDataCsvKwargs, ExportDataJsonKwargs + +if sys.platform == 'win32': + + def _write_file(path: Path, data: str | bytes) -> None: + """Windows-specific file write implementation. + + This implementation writes directly to the file without using a temporary file, because + they are problematic due to permissions issues on Windows. + """ + if isinstance(data, bytes): + path.write_bytes(data) + elif isinstance(data, str): + path.write_text(data, encoding='utf-8') + else: + raise TypeError(f'Unsupported data type: {type(data)}. Expected str or bytes.') +else: + + def _write_file(path: Path, data: str | bytes) -> None: + """Linux/Unix-specific file write implementation using temporary files.""" + dir_path = path.parent + fd, tmp_path = tempfile.mkstemp( + suffix=f'{path.suffix}.tmp', + prefix=f'{path.name}.', + dir=str(dir_path), + ) + + if not isinstance(data, (str, bytes)): + raise TypeError(f'Unsupported data type: {type(data)}. Expected str or bytes.') + + try: + if isinstance(data, bytes): + with os.fdopen(fd, 'wb') as tmp_file: + tmp_file.write(data) + else: + with os.fdopen(fd, 'w', encoding='utf-8') as tmp_file: + tmp_file.write(data) + + # Atomically replace the destination file with the temporary file + Path(tmp_path).replace(path) + except Exception: + Path(tmp_path).unlink(missing_ok=True) + raise + + +def infer_mime_type(value: Any) -> str: + """Infer the MIME content type from the value. Args: - filename: The path to the file to be removed. + value: The value to infer the content type from. + + Returns: + The inferred MIME content type. """ - with contextlib.suppress(FileNotFoundError): - await asyncio.to_thread(os.remove, filename) + # If the value is bytes (or bytearray), return binary content type. 
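A brief illustration (not part of the patch) of what the tighter `Callable[[T], T]` annotation buys: the decorated symbol keeps its original signature for static type checkers instead of degrading to a bare `Callable`. The decorated function is only an example; `docs_group` lives in the internal `crawlee._utils.docs` module.

from crawlee._utils.docs import docs_group


@docs_group('Functions')
def add(a: int, b: int) -> int:
    return a + b


# With the previous `-> Callable` annotation, type checkers saw `add` as an
# untyped callable after decoration. With `Callable[[T], T]` it is still
# reported as `(int, int) -> int`, so this call stays fully type-checked.
result: int = add(1, 2)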
+ if isinstance(value, (bytes, bytearray)): + return 'application/octet-stream' + # If the value is a dict or list, assume JSON. + if isinstance(value, (dict, list)): + return 'application/json; charset=utf-8' -async def force_rename(src_dir: str | Path, dst_dir: str | Path) -> None: - """Rename a directory, ensuring that the destination directory is removed if it exists. + # If the value is a string, number or boolean, assume plain text. + if isinstance(value, (str, int, float, bool)): + return 'text/plain; charset=utf-8' - Args: - src_dir: The source directory path. - dst_dir: The destination directory path. - """ - # Make sure source directory exists - if await asyncio.to_thread(os.path.exists, src_dir): - # Remove destination directory if it exists - if await asyncio.to_thread(os.path.exists, dst_dir): - await asyncio.to_thread(shutil.rmtree, dst_dir, ignore_errors=True) - await asyncio.to_thread(os.rename, src_dir, dst_dir) + # Default fallback. + return 'application/octet-stream' -def determine_file_extension(content_type: str) -> str | None: - """Determine the file extension for a given MIME content type. +async def json_dumps(obj: Any) -> str: + """Serialize an object to a JSON-formatted string with specific settings. Args: - content_type: The MIME content type string. + obj: The object to serialize. Returns: - A string representing the determined file extension without a leading dot, - or None if no extension could be determined. + A string containing the JSON representation of the input object. """ - # e.g. mimetypes.guess_extension('application/json ') does not work... - actual_content_type = content_type.split(';')[0].strip() - - # mimetypes.guess_extension returns 'xsl' in this case, because 'application/xxx' is "structured" - # ('text/xml' would be "unstructured" and return 'xml') we have to explicitly override it here - if actual_content_type == 'application/xml': - return 'xml' - - # Determine the extension from the mime type - ext = mimetypes.guess_extension(actual_content_type) + return await asyncio.to_thread(json.dumps, obj, ensure_ascii=False, indent=2, default=str) - # Remove the leading dot if extension successfully parsed - return ext[1:] if ext is not None else ext +@overload +async def atomic_write( + path: Path, + data: str, + *, + retry_count: int = 0, +) -> None: ... -def is_file_or_bytes(value: Any) -> bool: - """Determine if the input value is a file-like object or bytes. - This function checks whether the provided value is an instance of bytes, bytearray, or io.IOBase (file-like). - The method is simplified for common use cases and may not cover all edge cases. +@overload +async def atomic_write( + path: Path, + data: bytes, + *, + retry_count: int = 0, +) -> None: ... - Args: - value: The value to be checked. - Returns: - True if the value is either a file-like object or bytes, False otherwise. - """ - return isinstance(value, (bytes, bytearray, io.IOBase)) +async def atomic_write( + path: Path, + data: str | bytes, + *, + retry_count: int = 0, +) -> None: + """Write data to a file atomically to prevent data corruption or partial writes. - -async def json_dumps(obj: Any) -> str: - """Serialize an object to a JSON-formatted string with specific settings. + This function handles both text and binary data. The binary mode is automatically + detected based on the data type (bytes = binary, str = text). 
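As a quick illustration of the mapping implemented above (bytes to octet-stream, dict/list to JSON, scalars to plain text), assuming `infer_mime_type` is imported from the module this hunk modifies:

from crawlee._utils.file import infer_mime_type

assert infer_mime_type(b'\x00\x01') == 'application/octet-stream'
assert infer_mime_type({'title': 'Example'}) == 'application/json; charset=utf-8'
assert infer_mime_type(['a', 'b']) == 'application/json; charset=utf-8'
assert infer_mime_type('plain text') == 'text/plain; charset=utf-8'
assert infer_mime_type(3.14) == 'text/plain; charset=utf-8'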
It ensures atomic + writing by creating a temporary file and then atomically replacing the target file, + which prevents data corruption if the process is interrupted during the write operation. Args: - obj: The object to serialize. - - Returns: - A string containing the JSON representation of the input object. + path: The path to the destination file. + data: The data to write to the file (string or bytes). + retry_count: Internal parameter to track the number of retry attempts (default: 0). """ - return await asyncio.to_thread(json.dumps, obj, ensure_ascii=False, indent=2, default=str) + max_retries = 3 + + try: + # Use the platform-specific write function resolved at import time. + await asyncio.to_thread(_write_file, path, data) + except (FileNotFoundError, PermissionError): + if retry_count < max_retries: + return await atomic_write( + path, + data, + retry_count=retry_count + 1, + ) + # If we reach the maximum number of retries, raise the exception. + raise + + +async def export_json_to_stream( + iterator: AsyncIterator[dict[str, Any]], + dst: TextIO, + **kwargs: Unpack[ExportDataJsonKwargs], +) -> None: + items = [item async for item in iterator] + json.dump(items, dst, **kwargs) + + +async def export_csv_to_stream( + iterator: AsyncIterator[dict[str, Any]], + dst: TextIO, + **kwargs: Unpack[ExportDataCsvKwargs], +) -> None: + writer = csv.writer(dst, **kwargs) + write_header = True + + # Iterate over the dataset and write to CSV. + async for item in iterator: + if not item: + continue + + if write_header: + writer.writerow(item.keys()) + write_header = False + + writer.writerow(item.values()) diff --git a/src/crawlee/_utils/globs.py b/src/crawlee/_utils/globs.py index d497631d07..f7e1a57927 100644 --- a/src/crawlee/_utils/globs.py +++ b/src/crawlee/_utils/globs.py @@ -73,7 +73,7 @@ def _translate( return rf'(?s:{res})\Z' -def _fnmatch_translate(pat: str, star: str, question_mark: str) -> list[str]: # noqa: PLR0912 +def _fnmatch_translate(pat: str, star: str, question_mark: str) -> list[str]: """Copy of fnmatch._translate from Python 3.13.""" res = list[str]() add = res.append diff --git a/src/crawlee/_utils/recoverable_state.py b/src/crawlee/_utils/recoverable_state.py index 2cfdcd9ec7..35ee0a1d3f 100644 --- a/src/crawlee/_utils/recoverable_state.py +++ b/src/crawlee/_utils/recoverable_state.py @@ -4,13 +4,13 @@ from pydantic import BaseModel -from crawlee import service_locator from crawlee.events._types import Event, EventPersistStateData -from crawlee.storages._key_value_store import KeyValueStore if TYPE_CHECKING: import logging + from crawlee.storages._key_value_store import KeyValueStore + TStateModel = TypeVar('TStateModel', bound=BaseModel) @@ -59,7 +59,7 @@ def __init__( self._persist_state_key = persist_state_key self._persist_state_kvs_name = persist_state_kvs_name self._persist_state_kvs_id = persist_state_kvs_id - self._key_value_store: KeyValueStore | None = None + self._key_value_store: 'KeyValueStore | None' = None # noqa: UP037 self._log = logger async def initialize(self) -> TStateModel: @@ -75,12 +75,18 @@ async def initialize(self) -> TStateModel: self._state = self._default_state.model_copy(deep=True) return self.current_value + # Import here to avoid circular imports. + from crawlee.storages._key_value_store import KeyValueStore + self._key_value_store = await KeyValueStore.open( name=self._persist_state_kvs_name, id=self._persist_state_kvs_id ) await self._load_saved_state() + # Import here to avoid circular imports. 
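A self-contained sketch of how the new file helpers compose. The sample records and file names are placeholders, and the async generator merely stands in for something like `Dataset.iterate_items()`.

import asyncio
from collections.abc import AsyncIterator
from pathlib import Path
from typing import Any

from crawlee._utils.file import atomic_write, export_csv_to_stream


async def sample_items() -> AsyncIterator[dict[str, Any]]:
    # Stand-in for an async iterator of dataset items.
    yield {'url': 'https://example.com', 'status': 200}
    yield {'url': 'https://crawlee.dev', 'status': 200}


async def main() -> None:
    # Write a small text file atomically (temp file + atomic replace on POSIX,
    # direct write on Windows, as implemented above).
    await atomic_write(Path('state.txt'), 'last run: ok')

    # Stream dict items into a CSV file; the header row comes from the first item's keys.
    with Path('items.csv').open('w', newline='') as dst:
        await export_csv_to_stream(sample_items(), dst)


if __name__ == '__main__':
    asyncio.run(main())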
+ from crawlee import service_locator + event_manager = service_locator.get_event_manager() event_manager.on(event=Event.PERSIST_STATE, listener=self.persist_state) @@ -95,6 +101,9 @@ async def teardown(self) -> None: if not self._persistence_enabled: return + # Import here to avoid circular imports. + from crawlee import service_locator + event_manager = service_locator.get_event_manager() event_manager.off(event=Event.PERSIST_STATE, listener=self.persist_state) await self.persist_state() diff --git a/src/crawlee/configuration.py b/src/crawlee/configuration.py index de22118816..cc1f10a491 100644 --- a/src/crawlee/configuration.py +++ b/src/crawlee/configuration.py @@ -73,42 +73,6 @@ class Configuration(BaseSettings): ] = 'INFO' """The logging level.""" - default_dataset_id: Annotated[ - str, - Field( - validation_alias=AliasChoices( - 'actor_default_dataset_id', - 'apify_default_dataset_id', - 'crawlee_default_dataset_id', - ) - ), - ] = 'default' - """The default `Dataset` ID. This option is utilized by the storage client.""" - - default_key_value_store_id: Annotated[ - str, - Field( - validation_alias=AliasChoices( - 'actor_default_key_value_store_id', - 'apify_default_key_value_store_id', - 'crawlee_default_key_value_store_id', - ) - ), - ] = 'default' - """The default `KeyValueStore` ID. This option is utilized by the storage client.""" - - default_request_queue_id: Annotated[ - str, - Field( - validation_alias=AliasChoices( - 'actor_default_request_queue_id', - 'apify_default_request_queue_id', - 'crawlee_default_request_queue_id', - ) - ), - ] = 'default' - """The default `RequestQueue` ID. This option is utilized by the storage client.""" - purge_on_start: Annotated[ bool, Field( @@ -118,21 +82,7 @@ class Configuration(BaseSettings): ) ), ] = True - """Whether to purge the storage on the start. This option is utilized by the `MemoryStorageClient`.""" - - write_metadata: Annotated[bool, Field(alias='crawlee_write_metadata')] = True - """Whether to write the storage metadata. This option is utilized by the `MemoryStorageClient`.""" - - persist_storage: Annotated[ - bool, - Field( - validation_alias=AliasChoices( - 'apify_persist_storage', - 'crawlee_persist_storage', - ) - ), - ] = True - """Whether to persist the storage. This option is utilized by the `MemoryStorageClient`.""" + """Whether to purge the storage on the start. This option is utilized by the storage clients.""" persist_state_interval: Annotated[ timedelta_ms, @@ -239,7 +189,7 @@ class Configuration(BaseSettings): ), ), ] = './storage' - """The path to the storage directory. This option is utilized by the `MemoryStorageClient`.""" + """The path to the storage directory. 
This option is utilized by the storage clients.""" headless: Annotated[ bool, diff --git a/src/crawlee/crawlers/_basic/_basic_crawler.py b/src/crawlee/crawlers/_basic/_basic_crawler.py index 192d34091f..087e2ebca9 100644 --- a/src/crawlee/crawlers/_basic/_basic_crawler.py +++ b/src/crawlee/crawlers/_basic/_basic_crawler.py @@ -38,6 +38,7 @@ SkippedReason, ) from crawlee._utils.docs import docs_group +from crawlee._utils.file import export_csv_to_stream, export_json_to_stream from crawlee._utils.robots import RobotsTxtFile from crawlee._utils.urls import convert_to_absolute_url, is_url_absolute from crawlee._utils.wait import wait_for @@ -73,8 +74,10 @@ ConcurrencySettings, EnqueueLinksFunction, ExtractLinksFunction, + GetDataKwargs, HttpMethod, JsonSerializable, + PushDataKwargs, ) from crawlee.configuration import Configuration from crawlee.events import EventManager @@ -85,7 +88,6 @@ from crawlee.statistics import FinalStatistics from crawlee.storage_clients import StorageClient from crawlee.storage_clients.models import DatasetItemsListPage - from crawlee.storages._dataset import ExportDataCsvKwargs, ExportDataJsonKwargs, GetDataKwargs, PushDataKwargs TCrawlingContext = TypeVar('TCrawlingContext', bound=BasicCrawlingContext, default=BasicCrawlingContext) TStatisticsState = TypeVar('TStatisticsState', bound=StatisticsState, default=StatisticsState) @@ -685,6 +687,7 @@ async def add_requests( self, requests: Sequence[str | Request], *, + forefront: bool = False, batch_size: int = 1000, wait_time_between_batches: timedelta = timedelta(0), wait_for_all_requests_to_be_added: bool = False, @@ -694,6 +697,7 @@ async def add_requests( Args: requests: A list of requests to add to the queue. + forefront: If True, add requests to the forefront of the queue. batch_size: The number of requests to add in one batch. wait_time_between_batches: Time to wait between adding batches. wait_for_all_requests_to_be_added: If True, wait for all requests to be added before returning. @@ -718,17 +722,21 @@ async def add_requests( request_manager = await self.get_request_manager() - await request_manager.add_requests_batched( + await request_manager.add_requests( requests=allowed_requests, + forefront=forefront, batch_size=batch_size, wait_time_between_batches=wait_time_between_batches, wait_for_all_requests_to_be_added=wait_for_all_requests_to_be_added, wait_for_all_requests_to_be_added_timeout=wait_for_all_requests_to_be_added_timeout, ) - async def _use_state(self, default_value: dict[str, JsonSerializable] | None = None) -> dict[str, JsonSerializable]: - store = await self.get_key_value_store() - return await store.get_auto_saved_value(self._CRAWLEE_STATE_KEY, default_value) + async def _use_state( + self, + default_value: dict[str, JsonSerializable] | None = None, + ) -> dict[str, JsonSerializable]: + kvs = await self.get_key_value_store() + return await kvs.get_auto_saved_value(self._CRAWLEE_STATE_KEY, default_value) async def _save_crawler_state(self) -> None: store = await self.get_key_value_store() @@ -762,81 +770,32 @@ async def export_data( dataset_id: str | None = None, dataset_name: str | None = None, ) -> None: - """Export data from a `Dataset`. + """Export all items from a Dataset to a JSON or CSV file. - This helper method simplifies the process of exporting data from a `Dataset`. It opens the specified - one and then exports the data based on the provided parameters. If you need to pass options - specific to the output format, use the `export_data_csv` or `export_data_json` method instead. 
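A hedged sketch of the new `forefront` option on the crawler-level `add_requests` helper; the URLs and the minimal handler are placeholders.

import asyncio

from crawlee.crawlers import BasicCrawler, BasicCrawlingContext


async def main() -> None:
    crawler = BasicCrawler(max_requests_per_crawl=10)

    @crawler.router.default_handler
    async def handler(context: BasicCrawlingContext) -> None:
        context.log.info(f'Processing {context.request.url} ...')

    await crawler.add_requests(['https://crawlee.dev'])

    # Requests added with forefront=True go to the head of the queue and are
    # therefore processed before the previously enqueued ones.
    await crawler.add_requests(['https://example.com/priority'], forefront=True)

    await crawler.run()


if __name__ == '__main__':
    asyncio.run(main())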
+ This method simplifies the process of exporting data collected during crawling. It automatically + determines the export format based on the file extension (`.json` or `.csv`) and handles + the conversion of `Dataset` items to the appropriate format. Args: - path: The destination path. - dataset_id: The ID of the `Dataset`. - dataset_name: The name of the `Dataset`. + path: The destination file path. Must end with '.json' or '.csv'. + dataset_id: The ID of the Dataset to export from. If None, uses `name` parameter instead. + dataset_name: The name of the Dataset to export from. If None, uses `id` parameter instead. """ dataset = await self.get_dataset(id=dataset_id, name=dataset_name) path = path if isinstance(path, Path) else Path(path) - destination = path.open('w', newline='') + dst = path.open('w', newline='') if path.suffix == '.csv': - await dataset.write_to_csv(destination) + await export_csv_to_stream(dataset.iterate_items(), dst) elif path.suffix == '.json': - await dataset.write_to_json(destination) + await export_json_to_stream(dataset.iterate_items(), dst) else: raise ValueError(f'Unsupported file extension: {path.suffix}') - async def export_data_csv( - self, - path: str | Path, - *, - dataset_id: str | None = None, - dataset_name: str | None = None, - **kwargs: Unpack[ExportDataCsvKwargs], - ) -> None: - """Export data from a `Dataset` to a CSV file. - - This helper method simplifies the process of exporting data from a `Dataset` in csv format. It opens - the specified one and then exports the data based on the provided parameters. - - Args: - path: The destination path. - content_type: The output format. - dataset_id: The ID of the `Dataset`. - dataset_name: The name of the `Dataset`. - kwargs: Extra configurations for dumping/writing in csv format. - """ - dataset = await self.get_dataset(id=dataset_id, name=dataset_name) - path = path if isinstance(path, Path) else Path(path) - - return await dataset.write_to_csv(path.open('w', newline=''), **kwargs) - - async def export_data_json( - self, - path: str | Path, - *, - dataset_id: str | None = None, - dataset_name: str | None = None, - **kwargs: Unpack[ExportDataJsonKwargs], - ) -> None: - """Export data from a `Dataset` to a JSON file. - - This helper method simplifies the process of exporting data from a `Dataset` in json format. It opens the - specified one and then exports the data based on the provided parameters. - - Args: - path: The destination path - dataset_id: The ID of the `Dataset`. - dataset_name: The name of the `Dataset`. - kwargs: Extra configurations for dumping/writing in json format. 
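Since the export format is now picked from the destination suffix, a minimal sketch of the calling side, assuming a crawler that has already finished a run:

from crawlee.crawlers import BasicCrawler


async def export_results(crawler: BasicCrawler) -> None:
    # '.json' and '.csv' select the export format automatically.
    await crawler.export_data('results.json')
    await crawler.export_data('results.csv')

    # Any other suffix raises a ValueError.
    try:
        await crawler.export_data('results.xml')
    except ValueError as exc:
        print(f'Unsupported export target: {exc}')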
- """ - dataset = await self.get_dataset(id=dataset_id, name=dataset_name) - path = path if isinstance(path, Path) else Path(path) - - return await dataset.write_to_json(path.open('w', newline=''), **kwargs) - async def _push_data( self, - data: JsonSerializable, + data: list[dict[str, Any]] | dict[str, Any], dataset_id: str | None = None, dataset_name: str | None = None, **kwargs: Unpack[PushDataKwargs], @@ -1211,7 +1170,7 @@ async def _commit_request_handler_result(self, context: BasicCrawlingContext) -> if self._max_crawl_depth is None or dst_request.crawl_depth <= self._max_crawl_depth: requests.append(dst_request) - await request_manager.add_requests_batched(requests) + await request_manager.add_requests(requests) for push_data_call in result.push_data_calls: await self._push_data(**push_data_call) diff --git a/src/crawlee/fingerprint_suite/_browserforge_adapter.py b/src/crawlee/fingerprint_suite/_browserforge_adapter.py index d64ddd59f0..11f9f82d79 100644 --- a/src/crawlee/fingerprint_suite/_browserforge_adapter.py +++ b/src/crawlee/fingerprint_suite/_browserforge_adapter.py @@ -1,10 +1,10 @@ from __future__ import annotations -import os.path from collections.abc import Iterable from copy import deepcopy from functools import reduce from operator import or_ +from pathlib import Path from typing import TYPE_CHECKING, Any, Literal from browserforge.bayesian_network import extract_json @@ -253,9 +253,9 @@ def generate(self, browser_type: SupportedBrowserType = 'chromium') -> dict[str, def get_available_header_network() -> dict: """Get header network that contains possible header values.""" - if os.path.isfile(DATA_DIR / 'header-network.zip'): + if Path(DATA_DIR / 'header-network.zip').is_file(): return extract_json(DATA_DIR / 'header-network.zip') - if os.path.isfile(DATA_DIR / 'header-network-definition.zip'): + if Path(DATA_DIR / 'header-network-definition.zip').is_file(): return extract_json(DATA_DIR / 'header-network-definition.zip') raise FileNotFoundError('Missing header-network file.') diff --git a/src/crawlee/project_template/hooks/post_gen_project.py b/src/crawlee/project_template/hooks/post_gen_project.py index e076ff9308..c0495a724d 100644 --- a/src/crawlee/project_template/hooks/post_gen_project.py +++ b/src/crawlee/project_template/hooks/post_gen_project.py @@ -2,7 +2,6 @@ import subprocess from pathlib import Path - # % if cookiecutter.package_manager in ['poetry', 'uv'] Path('requirements.txt').unlink() @@ -32,8 +31,9 @@ # Install requirements and generate requirements.txt as an impromptu lockfile subprocess.check_call([str(path / 'pip'), 'install', '-r', 'requirements.txt']) -with open('requirements.txt', 'w') as requirements_txt: - subprocess.check_call([str(path / 'pip'), 'freeze'], stdout=requirements_txt) +Path('requirements.txt').write_text( + subprocess.check_output([str(path / 'pip'), 'freeze']).decode() +) # % if cookiecutter.crawler_type == 'playwright' subprocess.check_call([str(path / 'playwright'), 'install']) diff --git a/src/crawlee/request_loaders/_request_list.py b/src/crawlee/request_loaders/_request_list.py index 5964b106d0..aaba12f5c4 100644 --- a/src/crawlee/request_loaders/_request_list.py +++ b/src/crawlee/request_loaders/_request_list.py @@ -54,6 +54,10 @@ def __init__( def name(self) -> str | None: return self._name + @override + async def get_handled_count(self) -> int: + return self._handled_count + @override async def get_total_count(self) -> int: return self._assumed_total_count @@ -87,10 +91,6 @@ async def mark_request_as_handled(self, 
request: Request) -> None: self._handled_count += 1 self._in_progress.remove(request.id) - @override - async def get_handled_count(self) -> int: - return self._handled_count - async def _ensure_next_request(self) -> None: if self._requests_lock is None: self._requests_lock = asyncio.Lock() diff --git a/src/crawlee/request_loaders/_request_loader.py b/src/crawlee/request_loaders/_request_loader.py index e358306a45..1f9e4aa641 100644 --- a/src/crawlee/request_loaders/_request_loader.py +++ b/src/crawlee/request_loaders/_request_loader.py @@ -25,13 +25,17 @@ class RequestLoader(ABC): - Managing state information such as the total and handled request counts. """ + @abstractmethod + async def get_handled_count(self) -> int: + """Get the number of requests in the loader that have been handled.""" + @abstractmethod async def get_total_count(self) -> int: - """Return an offline approximation of the total number of requests in the source (i.e. pending + handled).""" + """Get an offline approximation of the total number of requests in the loader (i.e. pending + handled).""" @abstractmethod async def is_empty(self) -> bool: - """Return True if there are no more requests in the source (there might still be unfinished requests).""" + """Return True if there are no more requests in the loader (there might still be unfinished requests).""" @abstractmethod async def is_finished(self) -> bool: @@ -45,10 +49,6 @@ async def fetch_next_request(self) -> Request | None: async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: """Mark a request as handled after a successful processing (or after giving up retrying).""" - @abstractmethod - async def get_handled_count(self) -> int: - """Return the number of handled requests.""" - async def to_tandem(self, request_manager: RequestManager | None = None) -> RequestManagerTandem: """Combine the loader with a request manager to support adding and reclaiming requests. diff --git a/src/crawlee/request_loaders/_request_manager.py b/src/crawlee/request_loaders/_request_manager.py index f63f962cb9..5a8427c2cb 100644 --- a/src/crawlee/request_loaders/_request_manager.py +++ b/src/crawlee/request_loaders/_request_manager.py @@ -6,12 +6,12 @@ from crawlee._utils.docs import docs_group from crawlee.request_loaders._request_loader import RequestLoader +from crawlee.storage_clients.models import ProcessedRequest if TYPE_CHECKING: from collections.abc import Sequence from crawlee._request import Request - from crawlee.storage_clients.models import ProcessedRequest @docs_group('Abstract classes') @@ -40,10 +40,11 @@ async def add_request( Information about the request addition to the manager. """ - async def add_requests_batched( + async def add_requests( self, requests: Sequence[str | Request], *, + forefront: bool = False, batch_size: int = 1000, # noqa: ARG002 wait_time_between_batches: timedelta = timedelta(seconds=1), # noqa: ARG002 wait_for_all_requests_to_be_added: bool = False, # noqa: ARG002 @@ -53,14 +54,17 @@ async def add_requests_batched( Args: requests: Requests to enqueue. + forefront: If True, add requests to the beginning of the queue. batch_size: The number of requests to add in one batch. wait_time_between_batches: Time to wait between adding batches. wait_for_all_requests_to_be_added: If True, wait for all requests to be added before returning. wait_for_all_requests_to_be_added_timeout: Timeout for waiting for all requests to be added. """ # Default and dumb implementation. 
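To make the loader interface more concrete, a sketch that pairs a read-only `RequestList` with a writable request manager via `to_tandem`. The positional `RequestList` constructor argument and the no-argument `to_tandem()` call (assumed to fall back to the default request queue) are assumptions based on the surrounding code, not part of this patch.

import asyncio

from crawlee.request_loaders import RequestList


async def main() -> None:
    # A static, read-only source of start URLs.
    request_list = RequestList(['https://crawlee.dev', 'https://example.com'])

    # Combine it with a writable request manager so newly discovered links
    # can still be enqueued during a crawl.
    tandem = await request_list.to_tandem()

    print('handled so far:', await tandem.get_handled_count())
    print('total known requests:', await tandem.get_total_count())


if __name__ == '__main__':
    asyncio.run(main())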
+ processed_requests = list[ProcessedRequest]() for request in requests: - await self.add_request(request) + processed_request = await self.add_request(request, forefront=forefront) + processed_requests.append(processed_request) @abstractmethod async def reclaim_request(self, request: Request, *, forefront: bool = False) -> ProcessedRequest | None: diff --git a/src/crawlee/request_loaders/_request_manager_tandem.py b/src/crawlee/request_loaders/_request_manager_tandem.py index 9f0b8cefe8..6a5fe8aa65 100644 --- a/src/crawlee/request_loaders/_request_manager_tandem.py +++ b/src/crawlee/request_loaders/_request_manager_tandem.py @@ -32,6 +32,10 @@ def __init__(self, request_loader: RequestLoader, request_manager: RequestManage self._read_only_loader = request_loader self._read_write_manager = request_manager + @override + async def get_handled_count(self) -> int: + return await self._read_write_manager.get_handled_count() + @override async def get_total_count(self) -> int: return (await self._read_only_loader.get_total_count()) + (await self._read_write_manager.get_total_count()) @@ -49,17 +53,19 @@ async def add_request(self, request: str | Request, *, forefront: bool = False) return await self._read_write_manager.add_request(request, forefront=forefront) @override - async def add_requests_batched( + async def add_requests( self, requests: Sequence[str | Request], *, + forefront: bool = False, batch_size: int = 1000, wait_time_between_batches: timedelta = timedelta(seconds=1), wait_for_all_requests_to_be_added: bool = False, wait_for_all_requests_to_be_added_timeout: timedelta | None = None, ) -> None: - return await self._read_write_manager.add_requests_batched( + return await self._read_write_manager.add_requests( requests, + forefront=forefront, batch_size=batch_size, wait_time_between_batches=wait_time_between_batches, wait_for_all_requests_to_be_added=wait_for_all_requests_to_be_added, @@ -97,10 +103,6 @@ async def reclaim_request(self, request: Request, *, forefront: bool = False) -> async def mark_request_as_handled(self, request: Request) -> None: await self._read_write_manager.mark_request_as_handled(request) - @override - async def get_handled_count(self) -> int: - return await self._read_write_manager.get_handled_count() - @override async def drop(self) -> None: await self._read_write_manager.drop() diff --git a/src/crawlee/statistics/_error_snapshotter.py b/src/crawlee/statistics/_error_snapshotter.py index 21dbd33d48..4404904226 100644 --- a/src/crawlee/statistics/_error_snapshotter.py +++ b/src/crawlee/statistics/_error_snapshotter.py @@ -23,7 +23,12 @@ class ErrorSnapshotter: def __init__(self, *, snapshot_kvs_name: str | None = None) -> None: self._kvs_name = snapshot_kvs_name - async def capture_snapshot(self, error_message: str, file_and_line: str, context: BasicCrawlingContext) -> None: + async def capture_snapshot( + self, + error_message: str, + file_and_line: str, + context: BasicCrawlingContext, + ) -> None: """Capture error snapshot and save it to key value store. It saves the error snapshot directly to a key value store. It can't use `context.get_key_value_store` because @@ -37,26 +42,28 @@ async def capture_snapshot(self, error_message: str, file_and_line: str, context context: Context that is used to get the snapshot. 
""" if snapshot := await context.get_snapshot(): + kvs = await KeyValueStore.open(name=self._kvs_name) snapshot_base_name = self._get_snapshot_base_name(error_message, file_and_line) - snapshot_save_tasks = [] + snapshot_save_tasks = list[asyncio.Task]() + if snapshot.html: snapshot_save_tasks.append( - asyncio.create_task(self._save_html(snapshot.html, base_name=snapshot_base_name)) + asyncio.create_task(self._save_html(kvs, snapshot.html, base_name=snapshot_base_name)) ) + if snapshot.screenshot: snapshot_save_tasks.append( - asyncio.create_task(self._save_screenshot(snapshot.screenshot, base_name=snapshot_base_name)) + asyncio.create_task(self._save_screenshot(kvs, snapshot.screenshot, base_name=snapshot_base_name)) ) + await asyncio.gather(*snapshot_save_tasks) - async def _save_html(self, html: str, base_name: str) -> None: + async def _save_html(self, kvs: KeyValueStore, html: str, base_name: str) -> None: file_name = f'{base_name}.html' - kvs = await KeyValueStore.open(name=self._kvs_name) await kvs.set_value(file_name, html, content_type='text/html') - async def _save_screenshot(self, screenshot: bytes, base_name: str) -> None: + async def _save_screenshot(self, kvs: KeyValueStore, screenshot: bytes, base_name: str) -> None: file_name = f'{base_name}.jpg' - kvs = await KeyValueStore.open(name=self._kvs_name) await kvs.set_value(file_name, screenshot, content_type='image/jpeg') def _sanitize_filename(self, filename: str) -> str: diff --git a/src/crawlee/storage_clients/__init__.py b/src/crawlee/storage_clients/__init__.py index 66d352d7a7..ce8c713ca9 100644 --- a/src/crawlee/storage_clients/__init__.py +++ b/src/crawlee/storage_clients/__init__.py @@ -1,4 +1,9 @@ from ._base import StorageClient +from ._file_system import FileSystemStorageClient from ._memory import MemoryStorageClient -__all__ = ['MemoryStorageClient', 'StorageClient'] +__all__ = [ + 'FileSystemStorageClient', + 'MemoryStorageClient', + 'StorageClient', +] diff --git a/src/crawlee/storage_clients/_base/__init__.py b/src/crawlee/storage_clients/_base/__init__.py index 5194da8768..73298560da 100644 --- a/src/crawlee/storage_clients/_base/__init__.py +++ b/src/crawlee/storage_clients/_base/__init__.py @@ -1,20 +1,11 @@ from ._dataset_client import DatasetClient -from ._dataset_collection_client import DatasetCollectionClient from ._key_value_store_client import KeyValueStoreClient -from ._key_value_store_collection_client import KeyValueStoreCollectionClient from ._request_queue_client import RequestQueueClient -from ._request_queue_collection_client import RequestQueueCollectionClient from ._storage_client import StorageClient -from ._types import ResourceClient, ResourceCollectionClient __all__ = [ 'DatasetClient', - 'DatasetCollectionClient', 'KeyValueStoreClient', - 'KeyValueStoreCollectionClient', 'RequestQueueClient', - 'RequestQueueCollectionClient', - 'ResourceClient', - 'ResourceCollectionClient', 'StorageClient', ] diff --git a/src/crawlee/storage_clients/_base/_dataset_client.py b/src/crawlee/storage_clients/_base/_dataset_client.py index d8495b2dd0..840d816ea2 100644 --- a/src/crawlee/storage_clients/_base/_dataset_client.py +++ b/src/crawlee/storage_clients/_base/_dataset_client.py @@ -7,58 +7,56 @@ if TYPE_CHECKING: from collections.abc import AsyncIterator - from contextlib import AbstractAsyncContextManager + from typing import Any - from httpx import Response - - from crawlee._types import JsonSerializable from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata 
@docs_group('Abstract classes') class DatasetClient(ABC): - """An abstract class for dataset resource clients. + """An abstract class for dataset storage clients. + + Dataset clients provide an interface for accessing and manipulating dataset storage. They handle + operations like adding and getting dataset items across different storage backends. + + Storage clients are specific to the type of storage they manage (`Dataset`, `KeyValueStore`, + `RequestQueue`), and can operate with various storage systems including memory, file system, + databases, and cloud storage solutions. - These clients are specific to the type of resource they manage and operate under a designated storage - client, like a memory storage client. + This abstract class defines the interface that all specific dataset clients must implement. """ - _LIST_ITEMS_LIMIT = 999_999_999_999 - """This is what API returns in the x-apify-pagination-limit header when no limit query parameter is used.""" + @abstractmethod + async def get_metadata(self) -> DatasetMetadata: + """Get the metadata of the dataset.""" @abstractmethod - async def get(self) -> DatasetMetadata | None: - """Get metadata about the dataset being managed by this client. + async def drop(self) -> None: + """Drop the whole dataset and remove all its items. - Returns: - An object containing the dataset's details, or None if the dataset does not exist. + The backend method for the `Dataset.drop` call. """ @abstractmethod - async def update( - self, - *, - name: str | None = None, - ) -> DatasetMetadata: - """Update the dataset metadata. + async def purge(self) -> None: + """Purge all items from the dataset. - Args: - name: New new name for the dataset. - - Returns: - An object reflecting the updated dataset metadata. + The backend method for the `Dataset.purge` call. """ @abstractmethod - async def delete(self) -> None: - """Permanently delete the dataset managed by this client.""" + async def push_data(self, data: list[Any] | dict[str, Any]) -> None: + """Push data to the dataset. + + The backend method for the `Dataset.push_data` call. + """ @abstractmethod - async def list_items( + async def get_data( self, *, - offset: int | None = 0, - limit: int | None = _LIST_ITEMS_LIMIT, + offset: int = 0, + limit: int | None = 999_999_999_999, clean: bool = False, desc: bool = False, fields: list[str] | None = None, @@ -69,27 +67,9 @@ async def list_items( flatten: list[str] | None = None, view: str | None = None, ) -> DatasetItemsListPage: - """Retrieve a paginated list of items from a dataset based on various filtering parameters. - - This method provides the flexibility to filter, sort, and modify the appearance of dataset items - when listed. Each parameter modifies the result set according to its purpose. The method also - supports pagination through 'offset' and 'limit' parameters. - - Args: - offset: The number of initial items to skip. - limit: The maximum number of items to return. - clean: If True, removes empty items and hidden fields, equivalent to 'skip_hidden' and 'skip_empty'. - desc: If True, items are returned in descending order, i.e., newest first. - fields: Specifies a subset of fields to include in each item. - omit: Specifies a subset of fields to exclude from each item. - unwind: Specifies a field that should be unwound. If it's an array, each element becomes a separate record. - skip_empty: If True, omits items that are empty after other filters have been applied. - skip_hidden: If True, omits fields starting with the '#' character. 
- flatten: A list of fields to flatten in each item. - view: The specific view of the dataset to use when retrieving items. - - Returns: - An object with filtered, sorted, and paginated dataset items plus pagination details. + """Get data from the dataset with various filtering options. + + The backend method for the `Dataset.get_data` call. """ @abstractmethod @@ -105,127 +85,13 @@ async def iterate_items( unwind: str | None = None, skip_empty: bool = False, skip_hidden: bool = False, - ) -> AsyncIterator[dict]: - """Iterate over items in the dataset according to specified filters and sorting. - - This method allows for asynchronously iterating through dataset items while applying various filters such as - skipping empty items, hiding specific fields, and sorting. It supports pagination via `offset` and `limit` - parameters, and can modify the appearance of dataset items using `fields`, `omit`, `unwind`, `skip_empty`, and - `skip_hidden` parameters. - - Args: - offset: The number of initial items to skip. - limit: The maximum number of items to iterate over. None means no limit. - clean: If True, removes empty items and hidden fields, equivalent to 'skip_hidden' and 'skip_empty'. - desc: If set to True, items are returned in descending order, i.e., newest first. - fields: Specifies a subset of fields to include in each item. - omit: Specifies a subset of fields to exclude from each item. - unwind: Specifies a field that should be unwound into separate items. - skip_empty: If set to True, omits items that are empty after other filters have been applied. - skip_hidden: If set to True, omits fields starting with the '#' character from the output. - - Yields: - An asynchronous iterator of dictionary objects, each representing a dataset item after applying - the specified filters and transformations. + ) -> AsyncIterator[dict[str, Any]]: + """Iterate over the dataset items with filtering options. + + The backend method for the `Dataset.iterate_items` call. """ # This syntax is to make mypy properly work with abstract AsyncIterator. # https://mypy.readthedocs.io/en/stable/more_types.html#asynchronous-iterators raise NotImplementedError if False: # type: ignore[unreachable] yield 0 - - @abstractmethod - async def get_items_as_bytes( - self, - *, - item_format: str = 'json', - offset: int | None = None, - limit: int | None = None, - desc: bool = False, - clean: bool = False, - bom: bool = False, - delimiter: str | None = None, - fields: list[str] | None = None, - omit: list[str] | None = None, - unwind: str | None = None, - skip_empty: bool = False, - skip_header_row: bool = False, - skip_hidden: bool = False, - xml_root: str | None = None, - xml_row: str | None = None, - flatten: list[str] | None = None, - ) -> bytes: - """Retrieve dataset items as bytes. - - Args: - item_format: Output format (e.g., 'json', 'csv'); default is 'json'. - offset: Number of items to skip; default is 0. - limit: Max number of items to return; no default limit. - desc: If True, results are returned in descending order. - clean: If True, filters out empty items and hidden fields. - bom: Include or exclude UTF-8 BOM; default behavior varies by format. - delimiter: Delimiter character for CSV; default is ','. - fields: List of fields to include in the results. - omit: List of fields to omit from the results. - unwind: Unwinds a field into separate records. - skip_empty: If True, skips empty items in the output. - skip_header_row: If True, skips the header row in CSV. 
- skip_hidden: If True, skips hidden fields in the output. - xml_root: Root element name for XML output; default is 'items'. - xml_row: Element name for each item in XML output; default is 'item'. - flatten: List of fields to flatten. - - Returns: - The dataset items as raw bytes. - """ - - @abstractmethod - async def stream_items( - self, - *, - item_format: str = 'json', - offset: int | None = None, - limit: int | None = None, - desc: bool = False, - clean: bool = False, - bom: bool = False, - delimiter: str | None = None, - fields: list[str] | None = None, - omit: list[str] | None = None, - unwind: str | None = None, - skip_empty: bool = False, - skip_header_row: bool = False, - skip_hidden: bool = False, - xml_root: str | None = None, - xml_row: str | None = None, - ) -> AbstractAsyncContextManager[Response | None]: - """Retrieve dataset items as a streaming response. - - Args: - item_format: Output format, options include json, jsonl, csv, html, xlsx, xml, rss; default is json. - offset: Number of items to skip at the start; default is 0. - limit: Maximum number of items to return; no default limit. - desc: If True, reverses the order of results. - clean: If True, filters out empty items and hidden fields. - bom: Include or exclude UTF-8 BOM; varies by format. - delimiter: Delimiter for CSV files; default is ','. - fields: List of fields to include in the output. - omit: List of fields to omit from the output. - unwind: Unwinds a field into separate records. - skip_empty: If True, empty items are omitted. - skip_header_row: If True, skips the header row in CSV. - skip_hidden: If True, hides fields starting with the # character. - xml_root: Custom root element name for XML output; default is 'items'. - xml_row: Custom element name for each item in XML; default is 'item'. - - Yields: - The dataset items in a streaming response. - """ - - @abstractmethod - async def push_items(self, items: JsonSerializable) -> None: - """Push items to the dataset. - - Args: - items: The items which to push in the dataset. They must be JSON serializable. - """ diff --git a/src/crawlee/storage_clients/_base/_dataset_collection_client.py b/src/crawlee/storage_clients/_base/_dataset_collection_client.py deleted file mode 100644 index 8530655c8c..0000000000 --- a/src/crawlee/storage_clients/_base/_dataset_collection_client.py +++ /dev/null @@ -1,59 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING - -from crawlee._utils.docs import docs_group - -if TYPE_CHECKING: - from crawlee.storage_clients.models import DatasetListPage, DatasetMetadata - - -@docs_group('Abstract classes') -class DatasetCollectionClient(ABC): - """An abstract class for dataset collection clients. - - This collection client handles operations that involve multiple instances of a given resource type. - """ - - @abstractmethod - async def get_or_create( - self, - *, - id: str | None = None, - name: str | None = None, - schema: dict | None = None, - ) -> DatasetMetadata: - """Retrieve an existing dataset by its name or ID, or create a new one if it does not exist. - - Args: - id: Optional ID of the dataset to retrieve or create. If provided, the method will attempt - to find a dataset with the ID. - name: Optional name of the dataset resource to retrieve or create. If provided, the method will - attempt to find a dataset with this name. - schema: Optional schema for the dataset resource to be created. 
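Each backend method above is tied to a specific `Dataset` call, so a compact sketch of that mapping from the user's side may help; the dataset name and items are placeholders.

import asyncio

from crawlee.storages import Dataset


async def main() -> None:
    dataset = await Dataset.open(name='products')

    # Dataset.push_data -> DatasetClient.push_data
    await dataset.push_data([{'sku': 'A-1', 'price': 10}, {'sku': 'A-2', 'price': 12}])

    # Dataset.iterate_items -> DatasetClient.iterate_items
    async for item in dataset.iterate_items(skip_empty=True):
        print(item)

    # Dataset.drop -> DatasetClient.drop
    await dataset.drop()


if __name__ == '__main__':
    asyncio.run(main())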
- - Returns: - Metadata object containing the information of the retrieved or created dataset. - """ - - @abstractmethod - async def list( - self, - *, - unnamed: bool = False, - limit: int | None = None, - offset: int | None = None, - desc: bool = False, - ) -> DatasetListPage: - """List the available datasets. - - Args: - unnamed: Whether to list only the unnamed datasets. - limit: Maximum number of datasets to return. - offset: Number of datasets to skip from the beginning of the list. - desc: Whether to sort the datasets in descending order. - - Returns: - The list of available datasets matching the specified filters. - """ diff --git a/src/crawlee/storage_clients/_base/_key_value_store_client.py b/src/crawlee/storage_clients/_base/_key_value_store_client.py index 6a5d141be6..0def370551 100644 --- a/src/crawlee/storage_clients/_base/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_base/_key_value_store_client.py @@ -6,126 +6,97 @@ from crawlee._utils.docs import docs_group if TYPE_CHECKING: - from contextlib import AbstractAsyncContextManager + from collections.abc import AsyncIterator - from httpx import Response - - from crawlee.storage_clients.models import KeyValueStoreListKeysPage, KeyValueStoreMetadata, KeyValueStoreRecord + from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecord, KeyValueStoreRecordMetadata @docs_group('Abstract classes') class KeyValueStoreClient(ABC): - """An abstract class for key-value store resource clients. + """An abstract class for key-value store (KVS) storage clients. - These clients are specific to the type of resource they manage and operate under a designated storage - client, like a memory storage client. - """ + Key-value stores clients provide an interface for accessing and manipulating KVS storage. They handle + operations like getting, setting, deleting KVS values across different storage backends. - @abstractmethod - async def get(self) -> KeyValueStoreMetadata | None: - """Get metadata about the key-value store being managed by this client. + Storage clients are specific to the type of storage they manage (`Dataset`, `KeyValueStore`, + `RequestQueue`), and can operate with various storage systems including memory, file system, + databases, and cloud storage solutions. - Returns: - An object containing the key-value store's details, or None if the key-value store does not exist. - """ + This abstract class defines the interface that all specific KVS clients must implement. + """ @abstractmethod - async def update( - self, - *, - name: str | None = None, - ) -> KeyValueStoreMetadata: - """Update the key-value store metadata. + async def get_metadata(self) -> KeyValueStoreMetadata: + """Get the metadata of the key-value store.""" - Args: - name: New new name for the key-value store. + @abstractmethod + async def drop(self) -> None: + """Drop the whole key-value store and remove all its values. - Returns: - An object reflecting the updated key-value store metadata. + The backend method for the `KeyValueStore.drop` call. """ @abstractmethod - async def delete(self) -> None: - """Permanently delete the key-value store managed by this client.""" + async def purge(self) -> None: + """Purge all items from the key-value store. - @abstractmethod - async def list_keys( - self, - *, - limit: int = 1000, - exclusive_start_key: str | None = None, - ) -> KeyValueStoreListKeysPage: - """List the keys in the key-value store. - - Args: - limit: Number of keys to be returned. Maximum value is 1000. 
- exclusive_start_key: All keys up to this one (including) are skipped from the result. - - Returns: - The list of keys in the key-value store matching the given arguments. + The backend method for the `KeyValueStore.purge` call. """ @abstractmethod - async def get_record(self, key: str) -> KeyValueStoreRecord | None: + async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: """Retrieve the given record from the key-value store. - Args: - key: Key of the record to retrieve. - - Returns: - The requested record, or None, if the record does not exist + The backend method for the `KeyValueStore.get_value` call. """ @abstractmethod - async def get_record_as_bytes(self, key: str) -> KeyValueStoreRecord[bytes] | None: - """Retrieve the given record from the key-value store, without parsing it. - - Args: - key: Key of the record to retrieve. + async def set_value(self, *, key: str, value: Any, content_type: str | None = None) -> None: + """Set a value in the key-value store by its key. - Returns: - The requested record, or None, if the record does not exist + The backend method for the `KeyValueStore.set_value` call. """ @abstractmethod - async def stream_record(self, key: str) -> AbstractAsyncContextManager[KeyValueStoreRecord[Response] | None]: - """Retrieve the given record from the key-value store, as a stream. - - Args: - key: Key of the record to retrieve. + async def delete_value(self, *, key: str) -> None: + """Delete a value from the key-value store by its key. - Returns: - The requested record as a context-managed streaming Response, or None, if the record does not exist + The backend method for the `KeyValueStore.delete_value` call. """ @abstractmethod - async def set_record(self, key: str, value: Any, content_type: str | None = None) -> None: - """Set a value to the given record in the key-value store. + async def iterate_keys( + self, + *, + exclusive_start_key: str | None = None, + limit: int | None = None, + ) -> AsyncIterator[KeyValueStoreRecordMetadata]: + """Iterate over all the existing keys in the key-value store. - Args: - key: The key of the record to save the value to. - value: The value to save into the record. - content_type: The content type of the saved value. + The backend method for the `KeyValueStore.iterate_keys` call. """ + # This syntax is to make mypy properly work with abstract AsyncIterator. + # https://mypy.readthedocs.io/en/stable/more_types.html#asynchronous-iterators + raise NotImplementedError + if False: # type: ignore[unreachable] + yield 0 @abstractmethod - async def delete_record(self, key: str) -> None: - """Delete the specified record from the key-value store. + async def get_public_url(self, *, key: str) -> str: + """Get the public URL for the given key. - Args: - key: The key of the record which to delete. + The backend method for the `KeyValueStore.get_public_url` call. """ @abstractmethod - async def get_public_url(self, key: str) -> str: - """Get the public URL for the given key. + async def record_exists(self, *, key: str) -> bool: + """Check if a record with the given key exists in the key-value store. + + The backend method for the `KeyValueStore.record_exists` call. Args: - key: Key of the record for which URL is required. + key: The key to check for existence. Returns: - The public URL for the given key. - - Raises: - ValueError: If the key does not exist. + True if a record with the given key exists, False otherwise. 
""" diff --git a/src/crawlee/storage_clients/_base/_key_value_store_collection_client.py b/src/crawlee/storage_clients/_base/_key_value_store_collection_client.py deleted file mode 100644 index b447cf49b1..0000000000 --- a/src/crawlee/storage_clients/_base/_key_value_store_collection_client.py +++ /dev/null @@ -1,59 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING - -from crawlee._utils.docs import docs_group - -if TYPE_CHECKING: - from crawlee.storage_clients.models import KeyValueStoreListPage, KeyValueStoreMetadata - - -@docs_group('Abstract classes') -class KeyValueStoreCollectionClient(ABC): - """An abstract class for key-value store collection clients. - - This collection client handles operations that involve multiple instances of a given resource type. - """ - - @abstractmethod - async def get_or_create( - self, - *, - id: str | None = None, - name: str | None = None, - schema: dict | None = None, - ) -> KeyValueStoreMetadata: - """Retrieve an existing key-value store by its name or ID, or create a new one if it does not exist. - - Args: - id: Optional ID of the key-value store to retrieve or create. If provided, the method will attempt - to find a key-value store with the ID. - name: Optional name of the key-value store resource to retrieve or create. If provided, the method will - attempt to find a key-value store with this name. - schema: Optional schema for the key-value store resource to be created. - - Returns: - Metadata object containing the information of the retrieved or created key-value store. - """ - - @abstractmethod - async def list( - self, - *, - unnamed: bool = False, - limit: int | None = None, - offset: int | None = None, - desc: bool = False, - ) -> KeyValueStoreListPage: - """List the available key-value stores. - - Args: - unnamed: Whether to list only the unnamed key-value stores. - limit: Maximum number of key-value stores to return. - offset: Number of key-value stores to skip from the beginning of the list. - desc: Whether to sort the key-value stores in descending order. - - Returns: - The list of available key-value stores matching the specified filters. - """ diff --git a/src/crawlee/storage_clients/_base/_request_queue_client.py b/src/crawlee/storage_clients/_base/_request_queue_client.py index 06b180801a..c50b1af685 100644 --- a/src/crawlee/storage_clients/_base/_request_queue_client.py +++ b/src/crawlee/storage_clients/_base/_request_queue_client.py @@ -8,15 +8,8 @@ if TYPE_CHECKING: from collections.abc import Sequence - from crawlee.storage_clients.models import ( - BatchRequestsOperationResponse, - ProcessedRequest, - ProlongRequestLockResponse, - Request, - RequestQueueHead, - RequestQueueHeadWithLocks, - RequestQueueMetadata, - ) + from crawlee import Request + from crawlee.storage_clients.models import AddRequestsResponse, ProcessedRequest, RequestQueueMetadata @docs_group('Abstract classes') @@ -28,90 +21,48 @@ class RequestQueueClient(ABC): """ @abstractmethod - async def get(self) -> RequestQueueMetadata | None: - """Get metadata about the request queue being managed by this client. - - Returns: - An object containing the request queue's details, or None if the request queue does not exist. - """ - - @abstractmethod - async def update( - self, - *, - name: str | None = None, - ) -> RequestQueueMetadata: - """Update the request queue metadata. - - Args: - name: New new name for the request queue. - - Returns: - An object reflecting the updated request queue metadata. 
- """ - - @abstractmethod - async def delete(self) -> None: - """Permanently delete the request queue managed by this client.""" + async def get_metadata(self) -> RequestQueueMetadata: + """Get the metadata of the request queue.""" @abstractmethod - async def list_head(self, *, limit: int | None = None) -> RequestQueueHead: - """Retrieve a given number of requests from the beginning of the queue. + async def drop(self) -> None: + """Drop the whole request queue and remove all its values. - Args: - limit: How many requests to retrieve. - - Returns: - The desired number of requests from the beginning of the queue. - """ - - @abstractmethod - async def list_and_lock_head(self, *, lock_secs: int, limit: int | None = None) -> RequestQueueHeadWithLocks: - """Fetch and lock a specified number of requests from the start of the queue. - - Retrieve and locks the first few requests of a queue for the specified duration. This prevents the requests - from being fetched by another client until the lock expires. - - Args: - lock_secs: Duration for which the requests are locked, in seconds. - limit: Maximum number of requests to retrieve and lock. - - Returns: - The desired number of locked requests from the beginning of the queue. + The backend method for the `RequestQueue.drop` call. """ @abstractmethod - async def add_request( - self, - request: Request, - *, - forefront: bool = False, - ) -> ProcessedRequest: - """Add a request to the queue. - - Args: - request: The request to add to the queue. - forefront: Whether to add the request to the head or the end of the queue. + async def purge(self) -> None: + """Purge all items from the request queue. - Returns: - Request queue operation information. + The backend method for the `RequestQueue.purge` call. """ @abstractmethod - async def batch_add_requests( + async def add_batch_of_requests( self, requests: Sequence[Request], *, forefront: bool = False, - ) -> BatchRequestsOperationResponse: - """Add a batch of requests to the queue. + ) -> AddRequestsResponse: + """Add batch of requests to the queue. + + This method adds a batch of requests to the queue. Each request is processed based on its uniqueness + (determined by `unique_key`). Duplicates will be identified but not re-added to the queue. Args: - requests: The requests to add to the queue. - forefront: Whether to add the requests to the head or the end of the queue. + requests: The collection of requests to add to the queue. + forefront: Whether to put the added requests at the beginning (True) or the end (False) of the queue. + When True, the requests will be processed sooner than previously added requests. + batch_size: The maximum number of requests to add in a single batch. + wait_time_between_batches: The time to wait between adding batches of requests. + wait_for_all_requests_to_be_added: If True, the method will wait until all requests are added + to the queue before returning. + wait_for_all_requests_to_be_added_timeout: The maximum time to wait for all requests to be added. Returns: - Request queue batch operation information. + A response object containing information about which requests were successfully + processed and which failed (if any). """ @abstractmethod @@ -126,64 +77,58 @@ async def get_request(self, request_id: str) -> Request | None: """ @abstractmethod - async def update_request( - self, - request: Request, - *, - forefront: bool = False, - ) -> ProcessedRequest: - """Update a request in the queue. 
+ async def fetch_next_request(self) -> Request | None: + """Return the next request in the queue to be processed. - Args: - request: The updated request. - forefront: Whether to put the updated request in the beginning or the end of the queue. + Once you successfully finish processing of the request, you need to call `RequestQueue.mark_request_as_handled` + to mark the request as handled in the queue. If there was some error in processing the request, call + `RequestQueue.reclaim_request` instead, so that the queue will give the request to some other consumer + in another call to the `fetch_next_request` method. + + Note that the `None` return value does not mean the queue processing finished, it means there are currently + no pending requests. To check whether all requests in queue were finished, use `RequestQueue.is_finished` + instead. Returns: - The updated request + The request or `None` if there are no more pending requests. """ @abstractmethod - async def delete_request(self, request_id: str) -> None: - """Delete a request from the queue. - - Args: - request_id: ID of the request to delete. - """ + async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: + """Mark a request as handled after successful processing. - @abstractmethod - async def prolong_request_lock( - self, - request_id: str, - *, - forefront: bool = False, - lock_secs: int, - ) -> ProlongRequestLockResponse: - """Prolong the lock on a specific request in the queue. + Handled requests will never again be returned by the `RequestQueue.fetch_next_request` method. Args: - request_id: The identifier of the request whose lock is to be prolonged. - forefront: Whether to put the request in the beginning or the end of the queue after lock expires. - lock_secs: The additional amount of time, in seconds, that the request will remain locked. + request: The request to mark as handled. + + Returns: + Information about the queue operation. `None` if the given request was not in progress. """ @abstractmethod - async def delete_request_lock( + async def reclaim_request( self, - request_id: str, + request: Request, *, forefront: bool = False, - ) -> None: - """Delete the lock on a specific request in the queue. + ) -> ProcessedRequest | None: + """Reclaim a failed request back to the queue. + + The request will be returned for processing later again by another call to `RequestQueue.fetch_next_request`. Args: - request_id: ID of the request to delete the lock. - forefront: Whether to put the request in the beginning or the end of the queue after the lock is deleted. + request: The request to return to the queue. + forefront: Whether to add the request to the head or the end of the queue. + + Returns: + Information about the queue operation. `None` if the given request was not in progress. """ @abstractmethod - async def batch_delete_requests(self, requests: list[Request]) -> BatchRequestsOperationResponse: - """Delete given requests from the queue. + async def is_empty(self) -> bool: + """Check if the request queue is empty. - Args: - requests: The requests to delete from the queue. + Returns: + True if the request queue is empty, False otherwise. 
""" diff --git a/src/crawlee/storage_clients/_base/_request_queue_collection_client.py b/src/crawlee/storage_clients/_base/_request_queue_collection_client.py deleted file mode 100644 index 7de876c344..0000000000 --- a/src/crawlee/storage_clients/_base/_request_queue_collection_client.py +++ /dev/null @@ -1,59 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING - -from crawlee._utils.docs import docs_group - -if TYPE_CHECKING: - from crawlee.storage_clients.models import RequestQueueListPage, RequestQueueMetadata - - -@docs_group('Abstract classes') -class RequestQueueCollectionClient(ABC): - """An abstract class for request queue collection clients. - - This collection client handles operations that involve multiple instances of a given resource type. - """ - - @abstractmethod - async def get_or_create( - self, - *, - id: str | None = None, - name: str | None = None, - schema: dict | None = None, - ) -> RequestQueueMetadata: - """Retrieve an existing request queue by its name or ID, or create a new one if it does not exist. - - Args: - id: Optional ID of the request queue to retrieve or create. If provided, the method will attempt - to find a request queue with the ID. - name: Optional name of the request queue resource to retrieve or create. If provided, the method will - attempt to find a request queue with this name. - schema: Optional schema for the request queue resource to be created. - - Returns: - Metadata object containing the information of the retrieved or created request queue. - """ - - @abstractmethod - async def list( - self, - *, - unnamed: bool = False, - limit: int | None = None, - offset: int | None = None, - desc: bool = False, - ) -> RequestQueueListPage: - """List the available request queues. - - Args: - unnamed: Whether to list only the unnamed request queues. - limit: Maximum number of request queues to return. - offset: Number of request queues to skip from the beginning of the list. - desc: Whether to sort the request queues in descending order. - - Returns: - The list of available request queues matching the specified filters. - """ diff --git a/src/crawlee/storage_clients/_base/_storage_client.py b/src/crawlee/storage_clients/_base/_storage_client.py index 4f022cf30a..ef27e3e563 100644 --- a/src/crawlee/storage_clients/_base/_storage_client.py +++ b/src/crawlee/storage_clients/_base/_storage_client.py @@ -1,5 +1,3 @@ -# Inspiration: https://github.com/apify/crawlee/blob/v3.8.2/packages/types/src/storages.ts#L314:L328 - from __future__ import annotations from abc import ABC, abstractmethod @@ -8,55 +6,77 @@ from crawlee._utils.docs import docs_group if TYPE_CHECKING: + from crawlee.configuration import Configuration + from ._dataset_client import DatasetClient - from ._dataset_collection_client import DatasetCollectionClient from ._key_value_store_client import KeyValueStoreClient - from ._key_value_store_collection_client import KeyValueStoreCollectionClient from ._request_queue_client import RequestQueueClient - from ._request_queue_collection_client import RequestQueueCollectionClient @docs_group('Abstract classes') class StorageClient(ABC): - """Defines an abstract base for storage clients. - - It offers interfaces to get subclients for interacting with storage resources like datasets, key-value stores, - and request queues. - """ + """Base class for storage clients. 
- @abstractmethod - def dataset(self, id: str) -> DatasetClient: - """Get a subclient for a specific dataset by its ID.""" - - @abstractmethod - def datasets(self) -> DatasetCollectionClient: - """Get a subclient for dataset collection operations.""" + The `StorageClient` serves as an abstract base class that defines the interface for accessing Crawlee's + storage types: datasets, key-value stores, and request queues. It provides methods to open clients for + each of these storage types and handles common functionality. - @abstractmethod - def key_value_store(self, id: str) -> KeyValueStoreClient: - """Get a subclient for a specific key-value store by its ID.""" + Storage clients implementations can be provided for various backends (file system, memory, databases, + various cloud providers, etc.) to support different use cases from development to production environments. - @abstractmethod - def key_value_stores(self) -> KeyValueStoreCollectionClient: - """Get a subclient for key-value store collection operations.""" + Each storage client implementation is responsible for ensuring proper initialization, data persistence + (where applicable), and consistent access patterns across all storage types it supports. + """ @abstractmethod - def request_queue(self, id: str) -> RequestQueueClient: - """Get a subclient for a specific request queue by its ID.""" + async def create_dataset_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> DatasetClient: + """Create a dataset client.""" @abstractmethod - def request_queues(self) -> RequestQueueCollectionClient: - """Get a subclient for request queue collection operations.""" + async def create_kvs_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> KeyValueStoreClient: + """Create a key-value store client.""" @abstractmethod - async def purge_on_start(self) -> None: - """Perform a purge of the default storages. - - This method ensures that the purge is executed only once during the lifetime of the instance. - It is primarily used to clean up residual data from previous runs to maintain a clean state. - If the storage client does not support purging, leave it empty. - """ + async def create_rq_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> RequestQueueClient: + """Create a request queue client.""" def get_rate_limit_errors(self) -> dict[int, int]: """Return statistics about rate limit errors encountered by the HTTP client in storage client.""" return {} + + async def _purge_if_needed( + self, + client: DatasetClient | KeyValueStoreClient | RequestQueueClient, + configuration: Configuration, + ) -> None: + """Purge the client if needed. + + The purge is only performed if the configuration indicates that it should be done and the client + is not a named storage. Named storages are considered global and will typically outlive the run, + so they are not purged. + + Args: + client: The storage client to potentially purge. + configuration: Configuration that determines whether purging should occur. 
+ """ + metadata = await client.get_metadata() + if configuration.purge_on_start and metadata.name is None: + await client.purge() diff --git a/src/crawlee/storage_clients/_base/_types.py b/src/crawlee/storage_clients/_base/_types.py deleted file mode 100644 index a5cf1325f5..0000000000 --- a/src/crawlee/storage_clients/_base/_types.py +++ /dev/null @@ -1,22 +0,0 @@ -from __future__ import annotations - -from typing import Union - -from ._dataset_client import DatasetClient -from ._dataset_collection_client import DatasetCollectionClient -from ._key_value_store_client import KeyValueStoreClient -from ._key_value_store_collection_client import KeyValueStoreCollectionClient -from ._request_queue_client import RequestQueueClient -from ._request_queue_collection_client import RequestQueueCollectionClient - -ResourceClient = Union[ - DatasetClient, - KeyValueStoreClient, - RequestQueueClient, -] - -ResourceCollectionClient = Union[ - DatasetCollectionClient, - KeyValueStoreCollectionClient, - RequestQueueCollectionClient, -] diff --git a/src/crawlee/storage_clients/_file_system/__init__.py b/src/crawlee/storage_clients/_file_system/__init__.py new file mode 100644 index 0000000000..2169896d86 --- /dev/null +++ b/src/crawlee/storage_clients/_file_system/__init__.py @@ -0,0 +1,11 @@ +from ._dataset_client import FileSystemDatasetClient +from ._key_value_store_client import FileSystemKeyValueStoreClient +from ._request_queue_client import FileSystemRequestQueueClient +from ._storage_client import FileSystemStorageClient + +__all__ = [ + 'FileSystemDatasetClient', + 'FileSystemKeyValueStoreClient', + 'FileSystemRequestQueueClient', + 'FileSystemStorageClient', +] diff --git a/src/crawlee/storage_clients/_file_system/_dataset_client.py b/src/crawlee/storage_clients/_file_system/_dataset_client.py new file mode 100644 index 0000000000..54b0fe30ca --- /dev/null +++ b/src/crawlee/storage_clients/_file_system/_dataset_client.py @@ -0,0 +1,483 @@ +from __future__ import annotations + +import asyncio +import json +import shutil +from datetime import datetime, timezone +from logging import getLogger +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from pydantic import ValidationError +from typing_extensions import override + +from crawlee._consts import METADATA_FILENAME +from crawlee._utils.crypto import crypto_random_object_id +from crawlee._utils.file import atomic_write, json_dumps +from crawlee.storage_clients._base import DatasetClient +from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata + +if TYPE_CHECKING: + from collections.abc import AsyncIterator + + from crawlee.configuration import Configuration + +logger = getLogger(__name__) + + +class FileSystemDatasetClient(DatasetClient): + """File system implementation of the dataset client. + + This client persists dataset items to the file system as individual JSON files within a structured + directory hierarchy following the pattern: + + ``` + {STORAGE_DIR}/datasets/{DATASET_ID}/{ITEM_ID}.json + ``` + + Each item is stored as a separate file, which allows for durability and the ability to + recover after process termination. Dataset operations like filtering, sorting, and pagination are + implemented by processing the stored files according to the requested parameters. + + This implementation is ideal for long-running crawlers where data persistence is important, + and for development environments where you want to easily inspect the collected data between runs. 
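+
+    Item files are named with zero-padded sequence numbers (for example `000000001.json`), so the on-disk
+    order matches the order in which items were pushed.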
+ """ + + _STORAGE_SUBDIR = 'datasets' + """The name of the subdirectory where datasets are stored.""" + + _STORAGE_SUBSUBDIR_DEFAULT = 'default' + """The name of the subdirectory for the default dataset.""" + + _ITEM_FILENAME_DIGITS = 9 + """Number of digits used for the dataset item file names (e.g., 000000019.json).""" + + def __init__( + self, + *, + metadata: DatasetMetadata, + storage_dir: Path, + lock: asyncio.Lock, + ) -> None: + """Initialize a new instance. + + Preferably use the `FileSystemDatasetClient.open` class method to create a new instance. + """ + self._metadata = metadata + + self._storage_dir = storage_dir + """The base directory where the storage data are being persisted.""" + + self._lock = lock + """A lock to ensure that only one operation is performed at a time.""" + + @override + async def get_metadata(self) -> DatasetMetadata: + return self._metadata + + @property + def path_to_dataset(self) -> Path: + """The full path to the dataset directory.""" + if self._metadata.name is None: + return self._storage_dir / self._STORAGE_SUBDIR / self._STORAGE_SUBSUBDIR_DEFAULT + + return self._storage_dir / self._STORAGE_SUBDIR / self._metadata.name + + @property + def path_to_metadata(self) -> Path: + """The full path to the dataset metadata file.""" + return self.path_to_dataset / METADATA_FILENAME + + @classmethod + async def open( + cls, + *, + id: str | None, + name: str | None, + configuration: Configuration, + ) -> FileSystemDatasetClient: + """Open or create a file system dataset client. + + This method attempts to open an existing dataset from the file system. If a dataset with the specified ID + or name exists, it loads the metadata from the stored files. If no existing dataset is found, a new one + is created. + + Args: + id: The ID of the dataset to open. If provided, searches for existing dataset by ID. + name: The name of the dataset to open. If not provided, uses the default dataset. + configuration: The configuration object containing storage directory settings. + + Returns: + An instance for the opened or created storage client. + + Raises: + ValueError: If a dataset with the specified ID is not found, or if metadata is invalid. + """ + storage_dir = Path(configuration.storage_dir) + dataset_base_path = storage_dir / cls._STORAGE_SUBDIR + + if not dataset_base_path.exists(): + await asyncio.to_thread(dataset_base_path.mkdir, parents=True, exist_ok=True) + + # Get a new instance by ID. + if id: + found = False + for dataset_dir in dataset_base_path.iterdir(): + if not dataset_dir.is_dir(): + continue + + metadata_path = dataset_dir / METADATA_FILENAME + if not metadata_path.exists(): + continue + + try: + file = await asyncio.to_thread(metadata_path.open) + try: + file_content = json.load(file) + metadata = DatasetMetadata(**file_content) + if metadata.id == id: + client = cls( + metadata=metadata, + storage_dir=storage_dir, + lock=asyncio.Lock(), + ) + await client._update_metadata(update_accessed_at=True) + found = True + break + finally: + await asyncio.to_thread(file.close) + except (json.JSONDecodeError, ValidationError): + continue + + if not found: + raise ValueError(f'Dataset with ID "{id}" not found') + + # Get a new instance by name. + else: + dataset_path = ( + dataset_base_path / cls._STORAGE_SUBSUBDIR_DEFAULT if name is None else dataset_base_path / name + ) + metadata_path = dataset_path / METADATA_FILENAME + + # If the dataset directory exists, reconstruct the client from the metadata file. 
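+            # A metadata file that fails validation raises a ValueError instead of silently creating a new dataset.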
+ if dataset_path.exists() and metadata_path.exists(): + file = await asyncio.to_thread(open, metadata_path) + try: + file_content = json.load(file) + finally: + await asyncio.to_thread(file.close) + try: + metadata = DatasetMetadata(**file_content) + except ValidationError as exc: + raise ValueError(f'Invalid metadata file for dataset "{name}"') from exc + + client = cls( + metadata=metadata, + storage_dir=storage_dir, + lock=asyncio.Lock(), + ) + + await client._update_metadata(update_accessed_at=True) + + # Otherwise, create a new dataset client. + else: + now = datetime.now(timezone.utc) + metadata = DatasetMetadata( + id=crypto_random_object_id(), + name=name, + created_at=now, + accessed_at=now, + modified_at=now, + item_count=0, + ) + client = cls( + metadata=metadata, + storage_dir=storage_dir, + lock=asyncio.Lock(), + ) + await client._update_metadata() + + return client + + @override + async def drop(self) -> None: + async with self._lock: + if self.path_to_dataset.exists(): + await asyncio.to_thread(shutil.rmtree, self.path_to_dataset) + + @override + async def purge(self) -> None: + async with self._lock: + for file_path in await self._get_sorted_data_files(): + await asyncio.to_thread(file_path.unlink, missing_ok=True) + + await self._update_metadata( + update_accessed_at=True, + update_modified_at=True, + new_item_count=0, + ) + + @override + async def push_data(self, data: list[dict[str, Any]] | dict[str, Any]) -> None: + async with self._lock: + new_item_count = self._metadata.item_count + if isinstance(data, list): + for item in data: + new_item_count += 1 + await self._push_item(item, new_item_count) + else: + new_item_count += 1 + await self._push_item(data, new_item_count) + + # now update metadata under the same lock + await self._update_metadata( + update_accessed_at=True, + update_modified_at=True, + new_item_count=new_item_count, + ) + + @override + async def get_data( + self, + *, + offset: int = 0, + limit: int | None = 999_999_999_999, + clean: bool = False, + desc: bool = False, + fields: list[str] | None = None, + omit: list[str] | None = None, + unwind: str | None = None, + skip_empty: bool = False, + skip_hidden: bool = False, + flatten: list[str] | None = None, + view: str | None = None, + ) -> DatasetItemsListPage: + # Check for unsupported arguments and log a warning if found. + unsupported_args: dict[str, Any] = { + 'clean': clean, + 'fields': fields, + 'omit': omit, + 'unwind': unwind, + 'skip_hidden': skip_hidden, + 'flatten': flatten, + 'view': view, + } + unsupported = {k: v for k, v in unsupported_args.items() if v not in (False, None)} + + if unsupported: + logger.warning( + f'The arguments {list(unsupported.keys())} of get_data are not supported by the ' + f'{self.__class__.__name__} client.' + ) + + # If the dataset directory does not exist, log a warning and return an empty page. + if not self.path_to_dataset.exists(): + logger.warning(f'Dataset directory not found: {self.path_to_dataset}') + return DatasetItemsListPage( + count=0, + offset=offset, + limit=limit or 0, + total=0, + desc=desc, + items=[], + ) + + # Get the list of sorted data files. + async with self._lock: + try: + data_files = await self._get_sorted_data_files() + except FileNotFoundError: + # directory was dropped mid-check + return DatasetItemsListPage(count=0, offset=offset, limit=limit or 0, total=0, desc=desc, items=[]) + + total = len(data_files) + + # Reverse the order if descending order is requested. + if desc: + data_files.reverse() + + # Apply offset and limit slicing. 
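+        # A None limit means no limit: every file from the offset onwards is included.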
+ selected_files = data_files[offset:] + if limit is not None: + selected_files = selected_files[:limit] + + # Read and parse each data file. + items = list[dict[str, Any]]() + for file_path in selected_files: + try: + file_content = await asyncio.to_thread(file_path.read_text, encoding='utf-8') + except FileNotFoundError: + logger.warning(f'File disappeared during iterate_items(): {file_path}, skipping') + continue + + try: + item = json.loads(file_content) + except json.JSONDecodeError: + logger.exception(f'Corrupt JSON in {file_path}, skipping') + continue + + # Skip empty items if requested. + if skip_empty and not item: + continue + + items.append(item) + + async with self._lock: + await self._update_metadata(update_accessed_at=True) + + # Return a paginated list page of dataset items. + return DatasetItemsListPage( + count=len(items), + offset=offset, + limit=limit or total - offset, + total=total, + desc=desc, + items=items, + ) + + @override + async def iterate_items( + self, + *, + offset: int = 0, + limit: int | None = None, + clean: bool = False, + desc: bool = False, + fields: list[str] | None = None, + omit: list[str] | None = None, + unwind: str | None = None, + skip_empty: bool = False, + skip_hidden: bool = False, + ) -> AsyncIterator[dict[str, Any]]: + # Check for unsupported arguments and log a warning if found. + unsupported_args: dict[str, Any] = { + 'clean': clean, + 'fields': fields, + 'omit': omit, + 'unwind': unwind, + 'skip_hidden': skip_hidden, + } + unsupported = {k: v for k, v in unsupported_args.items() if v not in (False, None)} + + if unsupported: + logger.warning( + f'The arguments {list(unsupported.keys())} of iterate are not supported ' + f'by the {self.__class__.__name__} client.' + ) + + # If the dataset directory does not exist, log a warning and return immediately. + if not self.path_to_dataset.exists(): + logger.warning(f'Dataset directory not found: {self.path_to_dataset}') + return + + # Get the list of sorted data files. + async with self._lock: + try: + data_files = await self._get_sorted_data_files() + except FileNotFoundError: + return + + # Reverse the order if descending order is requested. + if desc: + data_files.reverse() + + # Apply offset and limit slicing. + selected_files = data_files[offset:] + if limit is not None: + selected_files = selected_files[:limit] + + # Iterate over each data file, reading and yielding its parsed content. + for file_path in selected_files: + try: + file_content = await asyncio.to_thread(file_path.read_text, encoding='utf-8') + except FileNotFoundError: + logger.warning(f'File disappeared during iterate_items(): {file_path}, skipping') + continue + + try: + item = json.loads(file_content) + except json.JSONDecodeError: + logger.exception(f'Corrupt JSON in {file_path}, skipping') + continue + + # Skip empty items if requested. + if skip_empty and not item: + continue + + yield item + + async with self._lock: + await self._update_metadata(update_accessed_at=True) + + async def _update_metadata( + self, + *, + new_item_count: int | None = None, + update_accessed_at: bool = False, + update_modified_at: bool = False, + ) -> None: + """Update the dataset metadata file with current information. + + Args: + new_item_count: If provided, update the item count to this value. + update_accessed_at: If True, update the `accessed_at` timestamp to the current time. + update_modified_at: If True, update the `modified_at` timestamp to the current time. 
+ """ + now = datetime.now(timezone.utc) + + if update_accessed_at: + self._metadata.accessed_at = now + if update_modified_at: + self._metadata.modified_at = now + if new_item_count is not None: + self._metadata.item_count = new_item_count + + # Ensure the parent directory for the metadata file exists. + await asyncio.to_thread(self.path_to_metadata.parent.mkdir, parents=True, exist_ok=True) + + # Dump the serialized metadata to the file. + data = await json_dumps(self._metadata.model_dump()) + await atomic_write(self.path_to_metadata, data) + + async def _push_item(self, item: dict[str, Any], item_id: int) -> None: + """Push a single item to the dataset. + + This method writes the item as a JSON file with a zero-padded numeric filename + that reflects its position in the dataset sequence. + + Args: + item: The data item to add to the dataset. + item_id: The sequential ID to use for this item's filename. + """ + # Generate the filename for the new item using zero-padded numbering. + filename = f'{str(item_id).zfill(self._ITEM_FILENAME_DIGITS)}.json' + file_path = self.path_to_dataset / filename + + # Ensure the dataset directory exists. + await asyncio.to_thread(self.path_to_dataset.mkdir, parents=True, exist_ok=True) + + # Dump the serialized item to the file. + data = await json_dumps(item) + await atomic_write(file_path, data) + + async def _get_sorted_data_files(self) -> list[Path]: + """Retrieve and return a sorted list of data files in the dataset directory. + + The files are sorted numerically based on the filename (without extension), + which corresponds to the order items were added to the dataset. + + Returns: + A list of `Path` objects pointing to data files, sorted by numeric filename. + """ + # Retrieve and sort all JSON files in the dataset directory numerically. + files = await asyncio.to_thread( + sorted, + self.path_to_dataset.glob('*.json'), + key=lambda f: int(f.stem) if f.stem.isdigit() else 0, + ) + + # Remove the metadata file from the list if present. + if self.path_to_metadata in files: + files.remove(self.path_to_metadata) + + return files diff --git a/src/crawlee/storage_clients/_file_system/_key_value_store_client.py b/src/crawlee/storage_clients/_file_system/_key_value_store_client.py new file mode 100644 index 0000000000..bc94980bcc --- /dev/null +++ b/src/crawlee/storage_clients/_file_system/_key_value_store_client.py @@ -0,0 +1,486 @@ +from __future__ import annotations + +import asyncio +import json +import shutil +import urllib.parse +from datetime import datetime, timezone +from logging import getLogger +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from pydantic import ValidationError +from typing_extensions import override + +from crawlee._consts import METADATA_FILENAME +from crawlee._utils.crypto import crypto_random_object_id +from crawlee._utils.file import atomic_write, infer_mime_type, json_dumps +from crawlee.storage_clients._base import KeyValueStoreClient +from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecord, KeyValueStoreRecordMetadata + +if TYPE_CHECKING: + from collections.abc import AsyncIterator + + from crawlee.configuration import Configuration + + +logger = getLogger(__name__) + + +class FileSystemKeyValueStoreClient(KeyValueStoreClient): + """File system implementation of the key-value store client. + + This client persists data to the file system, making it suitable for scenarios where data needs to + survive process restarts. 
Keys are mapped to file paths in a directory structure following the pattern: + + ``` + {STORAGE_DIR}/key_value_stores/{STORE_ID}/{KEY} + ``` + + Binary data is stored as-is, while JSON and text data are stored in human-readable format. + The implementation automatically handles serialization based on the content type and + maintains metadata about each record. + + This implementation is ideal for long-running crawlers where persistence is important and + for development environments where you want to easily inspect the stored data between runs. + """ + + _STORAGE_SUBDIR = 'key_value_stores' + """The name of the subdirectory where key-value stores are stored.""" + + _STORAGE_SUBSUBDIR_DEFAULT = 'default' + """The name of the subdirectory for the default key-value store.""" + + def __init__( + self, + *, + metadata: KeyValueStoreMetadata, + storage_dir: Path, + lock: asyncio.Lock, + ) -> None: + """Initialize a new instance. + + Preferably use the `FileSystemKeyValueStoreClient.open` class method to create a new instance. + """ + self._metadata = metadata + + self._storage_dir = storage_dir + """The base directory where the storage data are being persisted.""" + + self._lock = lock + """A lock to ensure that only one operation is performed at a time.""" + + @override + async def get_metadata(self) -> KeyValueStoreMetadata: + return self._metadata + + @property + def path_to_kvs(self) -> Path: + """The full path to the key-value store directory.""" + if self._metadata.name is None: + return self._storage_dir / self._STORAGE_SUBDIR / self._STORAGE_SUBSUBDIR_DEFAULT + + return self._storage_dir / self._STORAGE_SUBDIR / self._metadata.name + + @property + def path_to_metadata(self) -> Path: + """The full path to the key-value store metadata file.""" + return self.path_to_kvs / METADATA_FILENAME + + @classmethod + async def open( + cls, + *, + id: str | None, + name: str | None, + configuration: Configuration, + ) -> FileSystemKeyValueStoreClient: + """Open or create a file system key-value store client. + + This method attempts to open an existing key-value store from the file system. If a KVS with the specified + ID or name exists, it loads the metadata from the stored files. If no existing store is found, a new one + is created. + + Args: + id: The ID of the key-value store to open. If provided, searches for existing store by ID. + name: The name of the key-value store to open. If not provided, uses the default store. + configuration: The configuration object containing storage directory settings. + + Returns: + An instance for the opened or created storage client. + + Raises: + ValueError: If a store with the specified ID is not found, or if metadata is invalid. + """ + storage_dir = Path(configuration.storage_dir) + kvs_base_path = storage_dir / cls._STORAGE_SUBDIR + + if not kvs_base_path.exists(): + await asyncio.to_thread(kvs_base_path.mkdir, parents=True, exist_ok=True) + + # Get a new instance by ID. 
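+        # The lookup scans every store directory and compares the ID stored in its metadata file.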
+ if id: + found = False + for kvs_dir in kvs_base_path.iterdir(): + if not kvs_dir.is_dir(): + continue + + metadata_path = kvs_dir / METADATA_FILENAME + if not metadata_path.exists(): + continue + + try: + file = await asyncio.to_thread(metadata_path.open) + try: + file_content = json.load(file) + metadata = KeyValueStoreMetadata(**file_content) + if metadata.id == id: + client = cls( + metadata=metadata, + storage_dir=storage_dir, + lock=asyncio.Lock(), + ) + await client._update_metadata(update_accessed_at=True) + found = True + break + finally: + await asyncio.to_thread(file.close) + except (json.JSONDecodeError, ValidationError): + continue + + if not found: + raise ValueError(f'Key-value store with ID "{id}" not found.') + + # Get a new instance by name. + else: + kvs_path = kvs_base_path / cls._STORAGE_SUBSUBDIR_DEFAULT if name is None else kvs_base_path / name + metadata_path = kvs_path / METADATA_FILENAME + + # If the key-value store directory exists, reconstruct the client from the metadata file. + if kvs_path.exists() and metadata_path.exists(): + file = await asyncio.to_thread(open, metadata_path) + try: + file_content = json.load(file) + finally: + await asyncio.to_thread(file.close) + try: + metadata = KeyValueStoreMetadata(**file_content) + except ValidationError as exc: + raise ValueError(f'Invalid metadata file for key-value store "{name}"') from exc + + client = cls( + metadata=metadata, + storage_dir=storage_dir, + lock=asyncio.Lock(), + ) + + await client._update_metadata(update_accessed_at=True) + + # Otherwise, create a new key-value store client. + else: + now = datetime.now(timezone.utc) + metadata = KeyValueStoreMetadata( + id=crypto_random_object_id(), + name=name, + created_at=now, + accessed_at=now, + modified_at=now, + ) + client = cls( + metadata=metadata, + storage_dir=storage_dir, + lock=asyncio.Lock(), + ) + await client._update_metadata() + + return client + + @override + async def drop(self) -> None: + # If the client directory exists, remove it recursively. 
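+        # Unlike purge, this removes the stored values together with the store's metadata file.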
+ if self.path_to_kvs.exists(): + async with self._lock: + await asyncio.to_thread(shutil.rmtree, self.path_to_kvs) + + @override + async def purge(self) -> None: + async with self._lock: + for file_path in self.path_to_kvs.glob('*'): + if file_path.name == METADATA_FILENAME: + continue + await asyncio.to_thread(file_path.unlink, missing_ok=True) + + await self._update_metadata( + update_accessed_at=True, + update_modified_at=True, + ) + + @override + async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: + # Update the metadata to record access + async with self._lock: + await self._update_metadata(update_accessed_at=True) + + record_path = self.path_to_kvs / self._encode_key(key) + + if not record_path.exists(): + return None + + # Found a file for this key, now look for its metadata + record_metadata_filepath = record_path.with_name(f'{record_path.name}.{METADATA_FILENAME}') + if not record_metadata_filepath.exists(): + logger.warning(f'Found value file for key "{key}" but no metadata file.') + return None + + # Read the metadata file + async with self._lock: + try: + file = await asyncio.to_thread(open, record_metadata_filepath) + except FileNotFoundError: + logger.warning(f'Metadata file disappeared for key "{key}", aborting get_value') + return None + + try: + metadata_content = json.load(file) + except json.JSONDecodeError: + logger.warning(f'Invalid metadata file for key "{key}"') + return None + finally: + await asyncio.to_thread(file.close) + + try: + metadata = KeyValueStoreRecordMetadata(**metadata_content) + except ValidationError: + logger.warning(f'Invalid metadata schema for key "{key}"') + return None + + # Read the actual value + try: + value_bytes = await asyncio.to_thread(record_path.read_bytes) + except FileNotFoundError: + logger.warning(f'Value file disappeared for key "{key}"') + return None + + # Handle None values + if metadata.content_type == 'application/x-none': + value = None + # Handle JSON values + elif 'application/json' in metadata.content_type: + try: + value = json.loads(value_bytes.decode('utf-8')) + except (json.JSONDecodeError, UnicodeDecodeError): + logger.warning(f'Failed to decode JSON value for key "{key}"') + return None + # Handle text values + elif metadata.content_type.startswith('text/'): + try: + value = value_bytes.decode('utf-8') + except UnicodeDecodeError: + logger.warning(f'Failed to decode text value for key "{key}"') + return None + # Handle binary values + else: + value = value_bytes + + # Calculate the size of the value in bytes + size = len(value_bytes) + + return KeyValueStoreRecord( + key=metadata.key, + value=value, + content_type=metadata.content_type, + size=size, + ) + + @override + async def set_value(self, *, key: str, value: Any, content_type: str | None = None) -> None: + # Special handling for None values + if value is None: + content_type = 'application/x-none' # Special content type to identify None values + value_bytes = b'' + else: + content_type = content_type or infer_mime_type(value) + + # Serialize the value to bytes. + if 'application/json' in content_type: + value_bytes = (await json_dumps(value)).encode('utf-8') + elif isinstance(value, str): + value_bytes = value.encode('utf-8') + elif isinstance(value, (bytes, bytearray)): + value_bytes = value + else: + # Fallback: attempt to convert to string and encode. 
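+                # Note that this fallback is lossy: the original object cannot be reconstructed from its string form.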
+ value_bytes = str(value).encode('utf-8') + + record_path = self.path_to_kvs / self._encode_key(key) + + # Prepare the metadata + size = len(value_bytes) + record_metadata = KeyValueStoreRecordMetadata(key=key, content_type=content_type, size=size) + record_metadata_filepath = record_path.with_name(f'{record_path.name}.{METADATA_FILENAME}') + record_metadata_content = await json_dumps(record_metadata.model_dump()) + + async with self._lock: + # Ensure the key-value store directory exists. + await asyncio.to_thread(self.path_to_kvs.mkdir, parents=True, exist_ok=True) + + # Write the value to the file. + await atomic_write(record_path, value_bytes) + + # Write the record metadata to the file. + await atomic_write(record_metadata_filepath, record_metadata_content) + + # Update the KVS metadata to record the access and modification. + await self._update_metadata(update_accessed_at=True, update_modified_at=True) + + @override + async def delete_value(self, *, key: str) -> None: + record_path = self.path_to_kvs / self._encode_key(key) + metadata_path = record_path.with_name(f'{record_path.name}.{METADATA_FILENAME}') + deleted = False + + async with self._lock: + # Delete the value file and its metadata if found + if record_path.exists(): + await asyncio.to_thread(record_path.unlink, missing_ok=True) + + # Delete the metadata file if it exists + if metadata_path.exists(): + await asyncio.to_thread(metadata_path.unlink, missing_ok=True) + else: + logger.warning(f'Found value file for key "{key}" but no metadata file when trying to delete it.') + + deleted = True + + # If we deleted something, update the KVS metadata + if deleted: + await self._update_metadata(update_accessed_at=True, update_modified_at=True) + + @override + async def iterate_keys( + self, + *, + exclusive_start_key: str | None = None, + limit: int | None = None, + ) -> AsyncIterator[KeyValueStoreRecordMetadata]: + # Check if the KVS directory exists + if not self.path_to_kvs.exists(): + return + + # List and sort all files *inside* a brief lock, then release it immediately: + async with self._lock: + files = sorted(await asyncio.to_thread(list, self.path_to_kvs.glob('*'))) + + count = 0 + + for file_path in files: + # Skip the main metadata file + if file_path.name == METADATA_FILENAME: + continue + + # Only process metadata files for records + if not file_path.name.endswith(f'.{METADATA_FILENAME}'): + continue + + # Extract the base key name from the metadata filename + key_name = self._decode_key(file_path.name[: -len(f'.{METADATA_FILENAME}')]) + + # Apply exclusive_start_key filter if provided + if exclusive_start_key is not None and key_name <= exclusive_start_key: + continue + + # Try to read and parse the metadata file + try: + metadata_content = await asyncio.to_thread(file_path.read_text, encoding='utf-8') + except FileNotFoundError: + logger.warning(f'Metadata file disappeared for key "{key_name}", skipping it.') + continue + + try: + metadata_dict = json.loads(metadata_content) + except json.JSONDecodeError: + logger.warning(f'Failed to decode metadata file for key "{key_name}", skipping it.') + continue + + try: + record_metadata = KeyValueStoreRecordMetadata(**metadata_dict) + except ValidationError: + logger.warning(f'Invalid metadata schema for key "{key_name}", skipping it.') + + yield record_metadata + + count += 1 + if limit and count >= limit: + break + + # Update accessed_at timestamp + async with self._lock: + await self._update_metadata(update_accessed_at=True) + + @override + async def get_public_url(self, 
*, key: str) -> str: + """Return a file:// URL for the given key. + + Args: + key: The key to get the public URL for. + + Returns: + A file:// URL pointing to the file on the local filesystem. + """ + record_path = self.path_to_kvs / self._encode_key(key) + absolute_path = record_path.absolute() + return absolute_path.as_uri() + + @override + async def record_exists(self, *, key: str) -> bool: + """Check if a record with the given key exists in the key-value store. + + Args: + key: The key to check for existence. + + Returns: + True if a record with the given key exists, False otherwise. + """ + # Update the metadata to record access + async with self._lock: + await self._update_metadata(update_accessed_at=True) + + record_path = self.path_to_kvs / self._encode_key(key) + record_metadata_filepath = record_path.with_name(f'{record_path.name}.{METADATA_FILENAME}') + + # Both the value file and metadata file must exist for a record to be considered existing + return record_path.exists() and record_metadata_filepath.exists() + + async def _update_metadata( + self, + *, + update_accessed_at: bool = False, + update_modified_at: bool = False, + ) -> None: + """Update the KVS metadata file with current information. + + Args: + update_accessed_at: If True, update the `accessed_at` timestamp to the current time. + update_modified_at: If True, update the `modified_at` timestamp to the current time. + """ + now = datetime.now(timezone.utc) + + if update_accessed_at: + self._metadata.accessed_at = now + if update_modified_at: + self._metadata.modified_at = now + + # Ensure the parent directory for the metadata file exists. + await asyncio.to_thread(self.path_to_metadata.parent.mkdir, parents=True, exist_ok=True) + + # Dump the serialized metadata to the file. + data = await json_dumps(self._metadata.model_dump()) + await atomic_write(self.path_to_metadata, data) + + def _encode_key(self, key: str) -> str: + """Encode a key to make it safe for use in a file path.""" + return urllib.parse.quote(key, safe='') + + def _decode_key(self, encoded_key: str) -> str: + """Decode a key that was encoded to make it safe for use in a file path.""" + return urllib.parse.unquote(encoded_key) diff --git a/src/crawlee/storage_clients/_file_system/_request_queue_client.py b/src/crawlee/storage_clients/_file_system/_request_queue_client.py new file mode 100644 index 0000000000..e574855e99 --- /dev/null +++ b/src/crawlee/storage_clients/_file_system/_request_queue_client.py @@ -0,0 +1,818 @@ +from __future__ import annotations + +import asyncio +import json +import shutil +from collections import deque +from datetime import datetime, timezone +from logging import getLogger +from pathlib import Path +from typing import TYPE_CHECKING + +from pydantic import BaseModel, ValidationError +from typing_extensions import override + +from crawlee import Request +from crawlee._consts import METADATA_FILENAME +from crawlee._utils.crypto import crypto_random_object_id +from crawlee._utils.file import atomic_write, json_dumps +from crawlee._utils.recoverable_state import RecoverableState +from crawlee.storage_clients._base import RequestQueueClient +from crawlee.storage_clients.models import ( + AddRequestsResponse, + ProcessedRequest, + RequestQueueMetadata, + UnprocessedRequest, +) + +if TYPE_CHECKING: + from collections.abc import Sequence + + from crawlee.configuration import Configuration + +logger = getLogger(__name__) + + +class RequestQueueState(BaseModel): + """State model for the `FileSystemRequestQueueClient`.""" + + 
sequence_counter: int = 0 + """Counter for regular request ordering.""" + + forefront_sequence_counter: int = 0 + """Counter for forefront request ordering.""" + + forefront_requests: dict[str, int] = {} + """Mapping of forefront request IDs to their sequence numbers.""" + + regular_requests: dict[str, int] = {} + """Mapping of regular request IDs to their sequence numbers.""" + + in_progress_requests: set[str] = set() + """Set of request IDs currently being processed.""" + + handled_requests: set[str] = set() + """Set of request IDs that have been handled.""" + + +class FileSystemRequestQueueClient(RequestQueueClient): + """A file system implementation of the request queue client. + + This client persists requests to the file system as individual JSON files, making it suitable for scenarios + where data needs to survive process restarts. Each request is stored as a separate file in a directory + structure following the pattern: + + ``` + {STORAGE_DIR}/request_queues/{QUEUE_ID}/{REQUEST_ID}.json + ``` + + The implementation uses `RecoverableState` to maintain ordering information, in-progress status, and + request handling status. This allows for proper state recovery across process restarts without + embedding metadata in individual request files. File system storage provides durability at the cost of + slower I/O operations compared to memory only-based storage. + + This implementation is ideal for long-running crawlers where persistence is important and for situations + where you need to resume crawling after process termination. + """ + + _STORAGE_SUBDIR = 'request_queues' + """The name of the subdirectory where request queues are stored.""" + + _STORAGE_SUBSUBDIR_DEFAULT = 'default' + """The name of the subdirectory for the default request queue.""" + + _MAX_REQUESTS_IN_CACHE = 100_000 + """Maximum number of requests to keep in cache for faster access.""" + + def __init__( + self, + *, + metadata: RequestQueueMetadata, + storage_dir: Path, + lock: asyncio.Lock, + ) -> None: + """Initialize a new instance. + + Preferably use the `FileSystemRequestQueueClient.open` class method to create a new instance. 
+ """ + self._metadata = metadata + + self._storage_dir = storage_dir + """The base directory where the storage data are being persisted.""" + + self._lock = lock + """A lock to ensure that only one operation is performed at a time.""" + + self._request_cache = deque[Request]() + """Cache for requests: forefront requests at the beginning, regular requests at the end.""" + + self._request_cache_needs_refresh = True + """Flag indicating whether the cache needs to be refreshed from filesystem.""" + + self._is_empty_cache: bool | None = None + """Cache for is_empty result: None means unknown, True/False is cached state.""" + + self._state = RecoverableState[RequestQueueState]( + default_state=RequestQueueState(), + persist_state_key='request_queue_state', + persistence_enabled=True, + persist_state_kvs_name=f'__RQ_STATE_{self._metadata.id}', + logger=logger, + ) + """Recoverable state to maintain request ordering, in-progress status, and handled status.""" + + @override + async def get_metadata(self) -> RequestQueueMetadata: + return self._metadata + + @property + def path_to_rq(self) -> Path: + """The full path to the request queue directory.""" + if self._metadata.name is None: + return self._storage_dir / self._STORAGE_SUBDIR / self._STORAGE_SUBSUBDIR_DEFAULT + + return self._storage_dir / self._STORAGE_SUBDIR / self._metadata.name + + @property + def path_to_metadata(self) -> Path: + """The full path to the request queue metadata file.""" + return self.path_to_rq / METADATA_FILENAME + + @classmethod + async def open( + cls, + *, + id: str | None, + name: str | None, + configuration: Configuration, + ) -> FileSystemRequestQueueClient: + """Open or create a file system request queue client. + + This method attempts to open an existing request queue from the file system. If a queue with the specified + ID or name exists, it loads the metadata and state from the stored files. If no existing queue is found, + a new one is created. + + Args: + id: The ID of the request queue to open. If provided, searches for existing queue by ID. + name: The name of the request queue to open. If not provided, uses the default queue. + configuration: The configuration object containing storage directory settings. + + Returns: + An instance for the opened or created storage client. + + Raises: + ValueError: If a queue with the specified ID is not found, or if metadata is invalid. + """ + storage_dir = Path(configuration.storage_dir) + rq_base_path = storage_dir / cls._STORAGE_SUBDIR + + if not rq_base_path.exists(): + await asyncio.to_thread(rq_base_path.mkdir, parents=True, exist_ok=True) + + # Open an existing RQ by its ID, raise an error if not found. 
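+        # As with the other storages, the lookup scans each queue directory's metadata file for a matching ID.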
+ if id: + found = False + for rq_dir in rq_base_path.iterdir(): + if not rq_dir.is_dir(): + continue + + metadata_path = rq_dir / METADATA_FILENAME + if not metadata_path.exists(): + continue + + try: + file = await asyncio.to_thread(metadata_path.open) + try: + file_content = json.load(file) + metadata = RequestQueueMetadata(**file_content) + + if metadata.id == id: + client = cls( + metadata=metadata, + storage_dir=storage_dir, + lock=asyncio.Lock(), + ) + await client._state.initialize() + await client._discover_existing_requests() + await client._update_metadata(update_accessed_at=True) + found = True + break + finally: + await asyncio.to_thread(file.close) + except (json.JSONDecodeError, ValidationError): + continue + + if not found: + raise ValueError(f'Request queue with ID "{id}" not found') + + # Open an existing RQ by its name, or create a new one if not found. + else: + rq_path = rq_base_path / cls._STORAGE_SUBSUBDIR_DEFAULT if name is None else rq_base_path / name + metadata_path = rq_path / METADATA_FILENAME + + # If the RQ directory exists, reconstruct the client from the metadata file. + if rq_path.exists() and metadata_path.exists(): + file = await asyncio.to_thread(open, metadata_path) + try: + file_content = json.load(file) + finally: + await asyncio.to_thread(file.close) + try: + metadata = RequestQueueMetadata(**file_content) + except ValidationError as exc: + raise ValueError(f'Invalid metadata file for request queue "{name}"') from exc + + metadata.name = name + + client = cls( + metadata=metadata, + storage_dir=storage_dir, + lock=asyncio.Lock(), + ) + + await client._state.initialize() + await client._discover_existing_requests() + await client._update_metadata(update_accessed_at=True) + + # Otherwise, create a new dataset client. + else: + now = datetime.now(timezone.utc) + metadata = RequestQueueMetadata( + id=crypto_random_object_id(), + name=name, + created_at=now, + accessed_at=now, + modified_at=now, + had_multiple_clients=False, + handled_request_count=0, + pending_request_count=0, + stats={}, + total_request_count=0, + ) + client = cls( + metadata=metadata, + storage_dir=storage_dir, + lock=asyncio.Lock(), + ) + await client._state.initialize() + await client._update_metadata() + + return client + + @override + async def drop(self) -> None: + async with self._lock: + # Remove the RQ dir recursively if it exists. + if self.path_to_rq.exists(): + await asyncio.to_thread(shutil.rmtree, self.path_to_rq) + + # Clear recoverable state + await self._state.reset() + await self._state.teardown() + self._request_cache.clear() + self._request_cache_needs_refresh = True + + # Invalidate is_empty cache. + self._is_empty_cache = None + + @override + async def purge(self) -> None: + async with self._lock: + request_files = await self._get_request_files(self.path_to_rq) + + for file_path in request_files: + await asyncio.to_thread(file_path.unlink, missing_ok=True) + + # Clear recoverable state + await self._state.reset() + self._request_cache.clear() + self._request_cache_needs_refresh = True + + await self._update_metadata( + update_modified_at=True, + update_accessed_at=True, + new_pending_request_count=0, + ) + + # Invalidate is_empty cache. 
+ self._is_empty_cache = None + + @override + async def add_batch_of_requests( + self, + requests: Sequence[Request], + *, + forefront: bool = False, + ) -> AddRequestsResponse: + async with self._lock: + self._is_empty_cache = None + new_total_request_count = self._metadata.total_request_count + new_pending_request_count = self._metadata.pending_request_count + processed_requests = list[ProcessedRequest]() + unprocessed_requests = list[UnprocessedRequest]() + state = self._state.current_value + + # Prepare a dictionary to track existing requests by their unique keys. + existing_unique_keys: dict[str, Path] = {} + existing_request_files = await self._get_request_files(self.path_to_rq) + + for request_file in existing_request_files: + existing_request = await self._parse_request_file(request_file) + if existing_request is not None: + existing_unique_keys[existing_request.unique_key] = request_file + + # Process each request in the batch. + for request in requests: + existing_request_file = existing_unique_keys.get(request.unique_key) + existing_request = None + + # Only load the full request from disk if we found a duplicate + if existing_request_file is not None: + existing_request = await self._parse_request_file(existing_request_file) + + # If there is no existing request with the same unique key, add the new request. + if existing_request is None: + request_path = self._get_request_path(request.id) + + # Add sequence number to ensure FIFO ordering using state. + if forefront: + sequence_number = state.forefront_sequence_counter + state.forefront_sequence_counter += 1 + state.forefront_requests[request.id] = sequence_number + else: + sequence_number = state.sequence_counter + state.sequence_counter += 1 + state.regular_requests[request.id] = sequence_number + + # Save the clean request without extra fields + request_data = await json_dumps(request.model_dump()) + await atomic_write(request_path, request_data) + + # Update the metadata counts. + new_total_request_count += 1 + new_pending_request_count += 1 + + # Add to our index for subsequent requests in this batch + existing_unique_keys[request.unique_key] = self._get_request_path(request.id) + + processed_requests.append( + ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=False, + was_already_handled=False, + ) + ) + + # If the request already exists in the RQ, just update it if needed. + else: + # Set the processed request flags. + was_already_present = existing_request is not None + was_already_handled = existing_request.id in state.handled_requests + + # If the request is already in the RQ and handled, just continue with the next one. + if was_already_present and was_already_handled: + processed_requests.append( + ProcessedRequest( + id=existing_request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=True, + ) + ) + + # If the request is already in the RQ but not handled yet, update it. 
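+                # Only the ordering bookkeeping in the state is updated; the request file itself is not rewritten.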
+ elif was_already_present and not was_already_handled: + # Update request type (forefront vs regular) in state + if forefront: + # Move from regular to forefront if needed + if existing_request.id in state.regular_requests: + state.regular_requests.pop(existing_request.id) + if existing_request.id not in state.forefront_requests: + state.forefront_requests[existing_request.id] = state.forefront_sequence_counter + state.forefront_sequence_counter += 1 + elif ( + existing_request.id not in state.forefront_requests + and existing_request.id not in state.regular_requests + ): + # Keep as regular if not already forefront + state.regular_requests[existing_request.id] = state.sequence_counter + state.sequence_counter += 1 + + processed_requests.append( + ProcessedRequest( + id=existing_request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=False, + ) + ) + + else: + logger.warning(f'Request with unique key "{request.unique_key}" could not be processed.') + unprocessed_requests.append( + UnprocessedRequest( + unique_key=request.unique_key, + url=request.url, + method=request.method, + ) + ) + + await self._update_metadata( + update_modified_at=True, + update_accessed_at=True, + new_total_request_count=new_total_request_count, + new_pending_request_count=new_pending_request_count, + ) + + # Invalidate the cache if we added forefront requests. + if forefront: + self._request_cache_needs_refresh = True + + # Invalidate is_empty cache. + self._is_empty_cache = None + + return AddRequestsResponse( + processed_requests=processed_requests, + unprocessed_requests=unprocessed_requests, + ) + + @override + async def get_request(self, request_id: str) -> Request | None: + async with self._lock: + request_path = self._get_request_path(request_id) + request = await self._parse_request_file(request_path) + + if request is None: + logger.warning(f'Request with ID "{request_id}" not found in the queue.') + return None + + state = self._state.current_value + state.in_progress_requests.add(request.id) + await self._update_metadata(update_accessed_at=True) + return request + + @override + async def fetch_next_request(self) -> Request | None: + async with self._lock: + # Refresh cache if needed or if it's empty. + if self._request_cache_needs_refresh or not self._request_cache: + await self._refresh_cache() + + next_request: Request | None = None + state = self._state.current_value + + # Fetch from the front of the deque (forefront requests are at the beginning). + while self._request_cache and next_request is None: + candidate = self._request_cache.popleft() + + # Skip requests that are already in progress, however this should not happen. + if candidate.id not in state.in_progress_requests: + next_request = candidate + + if next_request is not None: + state.in_progress_requests.add(next_request.id) + + return next_request + + @override + async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: + async with self._lock: + self._is_empty_cache = None + state = self._state.current_value + + # Check if the request is in progress. + if request.id not in state.in_progress_requests: + logger.warning(f'Marking request {request.id} as handled that is not in progress.') + return None + + # Update the request's handled_at timestamp. + if request.handled_at is None: + request.handled_at = datetime.now(timezone.utc) + + # Dump the updated request to the file. 
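+            # (writing handled_at to disk lets _discover_existing_requests restore the handled state after a restart)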
+ request_path = self._get_request_path(request.id) + + if not await asyncio.to_thread(request_path.exists): + logger.warning(f'Request file for {request.id} does not exist, cannot mark as handled.') + return None + + request_data = await json_dumps(request.model_dump()) + await atomic_write(request_path, request_data) + + # Update state: remove from in-progress and add to handled. + state.in_progress_requests.discard(request.id) + state.handled_requests.add(request.id) + + # Update RQ metadata. + await self._update_metadata( + update_modified_at=True, + update_accessed_at=True, + new_handled_request_count=self._metadata.handled_request_count + 1, + new_pending_request_count=self._metadata.pending_request_count - 1, + ) + + return ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=True, + ) + + @override + async def reclaim_request( + self, + request: Request, + *, + forefront: bool = False, + ) -> ProcessedRequest | None: + async with self._lock: + self._is_empty_cache = None + state = self._state.current_value + + # Check if the request is in progress. + if request.id not in state.in_progress_requests: + logger.info(f'Reclaiming request {request.id} that is not in progress.') + return None + + request_path = self._get_request_path(request.id) + + if not await asyncio.to_thread(request_path.exists): + logger.warning(f'Request file for {request.id} does not exist, cannot reclaim.') + return None + + # Update sequence number and state to ensure proper ordering. + if forefront: + # Remove from regular requests if it was there + state.regular_requests.pop(request.id, None) + sequence_number = state.forefront_sequence_counter + state.forefront_sequence_counter += 1 + state.forefront_requests[request.id] = sequence_number + else: + # Remove from forefront requests if it was there + state.forefront_requests.pop(request.id, None) + sequence_number = state.sequence_counter + state.sequence_counter += 1 + state.regular_requests[request.id] = sequence_number + + # Save the clean request without extra fields + request_data = await json_dumps(request.model_dump()) + await atomic_write(request_path, request_data) + + # Remove from in-progress. + state.in_progress_requests.discard(request.id) + + # Update RQ metadata. + await self._update_metadata( + update_modified_at=True, + update_accessed_at=True, + ) + + # Add the request back to the cache. + if forefront: + self._request_cache.appendleft(request) + else: + self._request_cache.append(request) + + return ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=False, + ) + + @override + async def is_empty(self) -> bool: + async with self._lock: + # If we have a cached value, return it immediately. + if self._is_empty_cache is not None: + return self._is_empty_cache + + state = self._state.current_value + + # If there are in-progress requests, return False immediately. + if len(state.in_progress_requests) > 0: + self._is_empty_cache = False + return False + + # If we have a cached requests, check them first (fast path). + if self._request_cache: + for req in self._request_cache: + if req.id not in state.handled_requests: + self._is_empty_cache = False + return False + self._is_empty_cache = True + return len(state.in_progress_requests) == 0 + + # Fallback: check state for unhandled requests. 
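+            # (any request tracked in the state but missing from handled_requests still needs processing)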
+ await self._update_metadata(update_accessed_at=True) + + # Check if there are any requests that are not handled + all_requests = set(state.forefront_requests.keys()) | set(state.regular_requests.keys()) + unhandled_requests = all_requests - state.handled_requests + + if unhandled_requests: + self._is_empty_cache = False + return False + + self._is_empty_cache = True + return True + + def _get_request_path(self, request_id: str) -> Path: + """Get the path to a specific request file. + + Args: + request_id: The ID of the request. + + Returns: + The path to the request file. + """ + return self.path_to_rq / f'{request_id}.json' + + async def _update_metadata( + self, + *, + new_handled_request_count: int | None = None, + new_pending_request_count: int | None = None, + new_total_request_count: int | None = None, + update_had_multiple_clients: bool = False, + update_accessed_at: bool = False, + update_modified_at: bool = False, + ) -> None: + """Update the dataset metadata file with current information. + + Args: + new_handled_request_count: If provided, update the handled_request_count to this value. + new_pending_request_count: If provided, update the pending_request_count to this value. + new_total_request_count: If provided, update the total_request_count to this value. + update_had_multiple_clients: If True, set had_multiple_clients to True. + update_accessed_at: If True, update the `accessed_at` timestamp to the current time. + update_modified_at: If True, update the `modified_at` timestamp to the current time. + """ + # Always create a new timestamp to ensure it's truly updated + now = datetime.now(timezone.utc) + + # Update timestamps according to parameters + if update_accessed_at: + self._metadata.accessed_at = now + + if update_modified_at: + self._metadata.modified_at = now + + # Update request counts if provided + if new_handled_request_count is not None: + self._metadata.handled_request_count = new_handled_request_count + + if new_pending_request_count is not None: + self._metadata.pending_request_count = new_pending_request_count + + if new_total_request_count is not None: + self._metadata.total_request_count = new_total_request_count + + if update_had_multiple_clients: + self._metadata.had_multiple_clients = True + + # Ensure the parent directory for the metadata file exists. + await asyncio.to_thread(self.path_to_metadata.parent.mkdir, parents=True, exist_ok=True) + + # Dump the serialized metadata to the file. + data = await json_dumps(self._metadata.model_dump()) + await atomic_write(self.path_to_metadata, data) + + async def _refresh_cache(self) -> None: + """Refresh the request cache from filesystem. + + This method loads up to _MAX_REQUESTS_IN_CACHE requests from the filesystem, + prioritizing forefront requests and maintaining proper ordering. 
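+
+        Forefront requests are kept newest-first (LIFO) and regular requests oldest-first (FIFO).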
+ """ + self._request_cache.clear() + state = self._state.current_value + + forefront_requests = list[tuple[Request, int]]() # (request, sequence) + regular_requests = list[tuple[Request, int]]() # (request, sequence) + + request_files = await self._get_request_files(self.path_to_rq) + + for request_file in request_files: + request = await self._parse_request_file(request_file) + + if request is None: + continue + + # Skip handled requests + if request.id in state.handled_requests: + continue + + # Skip in-progress requests + if request.id in state.in_progress_requests: + continue + + # Determine if request is forefront or regular based on state + if request.id in state.forefront_requests: + sequence = state.forefront_requests[request.id] + forefront_requests.append((request, sequence)) + elif request.id in state.regular_requests: + sequence = state.regular_requests[request.id] + regular_requests.append((request, sequence)) + else: + # Request not in state, skip it (might be orphaned) + logger.warning(f'Request {request.id} not found in state, skipping.') + continue + + # Sort forefront requests by sequence (newest first for LIFO behavior). + forefront_requests.sort(key=lambda item: item[1], reverse=True) + + # Sort regular requests by sequence (oldest first for FIFO behavior). + regular_requests.sort(key=lambda item: item[1], reverse=False) + + # Add forefront requests to the beginning of the cache (left side). Since forefront_requests are sorted + # by sequence (newest first), we need to add them in reverse order to maintain correct priority. + for request, _ in reversed(forefront_requests): + if len(self._request_cache) >= self._MAX_REQUESTS_IN_CACHE: + break + self._request_cache.appendleft(request) + + # Add regular requests to the end of the cache (right side). + for request, _ in regular_requests: + if len(self._request_cache) >= self._MAX_REQUESTS_IN_CACHE: + break + self._request_cache.append(request) + + self._request_cache_needs_refresh = False + + @classmethod + async def _get_request_files(cls, path_to_rq: Path) -> list[Path]: + """Get all request files from the RQ. + + Args: + path_to_rq: The path to the request queue directory. + + Returns: + A list of paths to all request files. + """ + # Create the requests directory if it doesn't exist. + await asyncio.to_thread(path_to_rq.mkdir, parents=True, exist_ok=True) + + # List all the json files. + files = await asyncio.to_thread(list, path_to_rq.glob('*.json')) + + # Filter out metadata file and non-file entries. + filtered = filter( + lambda request_file: request_file.is_file() and request_file.name != METADATA_FILENAME, + files, + ) + + return list(filtered) + + @classmethod + async def _parse_request_file(cls, file_path: Path) -> Request | None: + """Parse a request file and return the `Request` object. + + Args: + file_path: The path to the request file. + + Returns: + The parsed `Request` object or `None` if the file could not be read or parsed. + """ + # Open the request file. + try: + file = await asyncio.to_thread(open, file_path) + except FileNotFoundError: + logger.warning(f'Request file "{file_path}" not found.') + return None + + # Read the file content and parse it as JSON. + try: + file_content = json.load(file) + except json.JSONDecodeError as exc: + logger.warning(f'Failed to parse request file {file_path}: {exc!s}') + return None + finally: + await asyncio.to_thread(file.close) + + # Validate the content against the Request model. 
+ try: + return Request.model_validate(file_content) + except ValidationError as exc: + logger.warning(f'Failed to validate request file {file_path}: {exc!s}') + return None + + async def _discover_existing_requests(self) -> None: + """Discover and load existing requests into the state when opening an existing request queue.""" + request_files = await self._get_request_files(self.path_to_rq) + state = self._state.current_value + + for request_file in request_files: + request = await self._parse_request_file(request_file) + if request is None: + continue + + # Add request to state as regular request (assign sequence numbers) + if request.id not in state.regular_requests and request.id not in state.forefront_requests: + # Assign as regular request with current sequence counter + state.regular_requests[request.id] = state.sequence_counter + state.sequence_counter += 1 + + # Check if request was already handled + if request.handled_at is not None: + state.handled_requests.add(request.id) diff --git a/src/crawlee/storage_clients/_file_system/_storage_client.py b/src/crawlee/storage_clients/_file_system/_storage_client.py new file mode 100644 index 0000000000..9c293725d3 --- /dev/null +++ b/src/crawlee/storage_clients/_file_system/_storage_client.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from typing_extensions import override + +from crawlee._utils.docs import docs_group +from crawlee.configuration import Configuration +from crawlee.storage_clients._base import StorageClient + +from ._dataset_client import FileSystemDatasetClient +from ._key_value_store_client import FileSystemKeyValueStoreClient +from ._request_queue_client import FileSystemRequestQueueClient + + +@docs_group('Classes') +class FileSystemStorageClient(StorageClient): + """File system implementation of the storage client. + + This storage client provides access to datasets, key-value stores, and request queues that persist data + to the local file system. Each storage type is implemented with its own specific file system client + that stores data in a structured directory hierarchy. + + Data is stored in JSON format in predictable file paths, making it easy to inspect and manipulate + the stored data outside of the Crawlee application if needed. + + All data persists between program runs but is limited to access from the local machine + where the files are stored. + + Warning: This storage client is not safe for concurrent access from multiple crawler processes. + Use it only when running a single crawler process at a time. 
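+
+    A minimal usage sketch (the dataset name below is illustrative):
+
+        storage_client = FileSystemStorageClient()
+        dataset_client = await storage_client.create_dataset_client(name='my-dataset')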
+ """ + + @override + async def create_dataset_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> FileSystemDatasetClient: + configuration = configuration or Configuration.get_global_configuration() + client = await FileSystemDatasetClient.open(id=id, name=name, configuration=configuration) + await self._purge_if_needed(client, configuration) + return client + + @override + async def create_kvs_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> FileSystemKeyValueStoreClient: + configuration = configuration or Configuration.get_global_configuration() + client = await FileSystemKeyValueStoreClient.open(id=id, name=name, configuration=configuration) + await self._purge_if_needed(client, configuration) + return client + + @override + async def create_rq_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> FileSystemRequestQueueClient: + configuration = configuration or Configuration.get_global_configuration() + client = await FileSystemRequestQueueClient.open(id=id, name=name, configuration=configuration) + await self._purge_if_needed(client, configuration) + return client diff --git a/src/crawlee/storage_clients/_file_system/py.typed b/src/crawlee/storage_clients/_file_system/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/crawlee/storage_clients/_memory/__init__.py b/src/crawlee/storage_clients/_memory/__init__.py index 09912e124d..3746907b4f 100644 --- a/src/crawlee/storage_clients/_memory/__init__.py +++ b/src/crawlee/storage_clients/_memory/__init__.py @@ -1,17 +1,11 @@ -from ._dataset_client import DatasetClient -from ._dataset_collection_client import DatasetCollectionClient -from ._key_value_store_client import KeyValueStoreClient -from ._key_value_store_collection_client import KeyValueStoreCollectionClient -from ._memory_storage_client import MemoryStorageClient -from ._request_queue_client import RequestQueueClient -from ._request_queue_collection_client import RequestQueueCollectionClient +from ._dataset_client import MemoryDatasetClient +from ._key_value_store_client import MemoryKeyValueStoreClient +from ._request_queue_client import MemoryRequestQueueClient +from ._storage_client import MemoryStorageClient __all__ = [ - 'DatasetClient', - 'DatasetCollectionClient', - 'KeyValueStoreClient', - 'KeyValueStoreCollectionClient', + 'MemoryDatasetClient', + 'MemoryKeyValueStoreClient', + 'MemoryRequestQueueClient', 'MemoryStorageClient', - 'RequestQueueClient', - 'RequestQueueCollectionClient', ] diff --git a/src/crawlee/storage_clients/_memory/_creation_management.py b/src/crawlee/storage_clients/_memory/_creation_management.py deleted file mode 100644 index f6d4fc1c91..0000000000 --- a/src/crawlee/storage_clients/_memory/_creation_management.py +++ /dev/null @@ -1,429 +0,0 @@ -from __future__ import annotations - -import asyncio -import json -import mimetypes -import os -import pathlib -from datetime import datetime, timezone -from logging import getLogger -from typing import TYPE_CHECKING - -from crawlee._consts import METADATA_FILENAME -from crawlee._utils.data_processing import maybe_parse_body -from crawlee._utils.file import json_dumps -from crawlee.storage_clients.models import ( - DatasetMetadata, - InternalRequest, - KeyValueStoreMetadata, - KeyValueStoreRecord, - KeyValueStoreRecordMetadata, - RequestQueueMetadata, -) - -if 
TYPE_CHECKING: - from ._dataset_client import DatasetClient - from ._key_value_store_client import KeyValueStoreClient - from ._memory_storage_client import MemoryStorageClient, TResourceClient - from ._request_queue_client import RequestQueueClient - -logger = getLogger(__name__) - - -async def persist_metadata_if_enabled(*, data: dict, entity_directory: str, write_metadata: bool) -> None: - """Update or writes metadata to a specified directory. - - The function writes a given metadata dictionary to a JSON file within a specified directory. - The writing process is skipped if `write_metadata` is False. Before writing, it ensures that - the target directory exists, creating it if necessary. - - Args: - data: A dictionary containing metadata to be written. - entity_directory: The directory path where the metadata file should be stored. - write_metadata: A boolean flag indicating whether the metadata should be written to file. - """ - # Skip metadata write; ensure directory exists first - if not write_metadata: - return - - # Ensure the directory for the entity exists - await asyncio.to_thread(os.makedirs, entity_directory, exist_ok=True) - - # Write the metadata to the file - file_path = os.path.join(entity_directory, METADATA_FILENAME) - f = await asyncio.to_thread(open, file_path, mode='wb') - try: - s = await json_dumps(data) - await asyncio.to_thread(f.write, s.encode('utf-8')) - finally: - await asyncio.to_thread(f.close) - - -def find_or_create_client_by_id_or_name_inner( - resource_client_class: type[TResourceClient], - memory_storage_client: MemoryStorageClient, - id: str | None = None, - name: str | None = None, -) -> TResourceClient | None: - """Locate or create a new storage client based on the given ID or name. - - This method attempts to find a storage client in the memory cache first. If not found, - it tries to locate a storage directory by name. If still not found, it searches through - storage directories for a matching ID or name in their metadata. If none exists, and the - specified ID is 'default', it checks for a default storage directory. If a storage client - is found or created, it is added to the memory cache. If no storage client can be located or - created, the method returns None. - - Args: - resource_client_class: The class of the resource client. - memory_storage_client: The memory storage client used to store and retrieve storage clients. - id: The unique identifier for the storage client. - name: The name of the storage client. - - Raises: - ValueError: If both id and name are None. - - Returns: - The found or created storage client, or None if no client could be found or created. 
- """ - from ._dataset_client import DatasetClient - from ._key_value_store_client import KeyValueStoreClient - from ._request_queue_client import RequestQueueClient - - if id is None and name is None: - raise ValueError('Either id or name must be specified.') - - # First check memory cache - found = memory_storage_client.get_cached_resource_client(resource_client_class, id, name) - - if found is not None: - return found - - storage_path = _determine_storage_path(resource_client_class, memory_storage_client, id, name) - - if not storage_path: - return None - - # Create from directory if storage path is found - if issubclass(resource_client_class, DatasetClient): - resource_client = create_dataset_from_directory(storage_path, memory_storage_client, id, name) - elif issubclass(resource_client_class, KeyValueStoreClient): - resource_client = create_kvs_from_directory(storage_path, memory_storage_client, id, name) - elif issubclass(resource_client_class, RequestQueueClient): - resource_client = create_rq_from_directory(storage_path, memory_storage_client, id, name) - else: - raise TypeError('Invalid resource client class.') - - memory_storage_client.add_resource_client_to_cache(resource_client) - return resource_client - - -async def get_or_create_inner( - *, - memory_storage_client: MemoryStorageClient, - storage_client_cache: list[TResourceClient], - resource_client_class: type[TResourceClient], - name: str | None = None, - id: str | None = None, -) -> TResourceClient: - """Retrieve a named storage, or create a new one when it doesn't exist. - - Args: - memory_storage_client: The memory storage client. - storage_client_cache: The cache of storage clients. - resource_client_class: The class of the storage to retrieve or create. - name: The name of the storage to retrieve or create. - id: ID of the storage to retrieve or create. - - Returns: - The retrieved or newly-created storage. 
- """ - # If the name or id is provided, try to find the dataset in the cache - if name or id: - found = find_or_create_client_by_id_or_name_inner( - resource_client_class=resource_client_class, - memory_storage_client=memory_storage_client, - name=name, - id=id, - ) - if found: - return found - - # Otherwise, create a new one and add it to the cache - resource_client = resource_client_class( - id=id, - name=name, - memory_storage_client=memory_storage_client, - ) - - storage_client_cache.append(resource_client) - - # Write to the disk - await persist_metadata_if_enabled( - data=resource_client.resource_info.model_dump(), - entity_directory=resource_client.resource_directory, - write_metadata=memory_storage_client.write_metadata, - ) - - return resource_client - - -def create_dataset_from_directory( - storage_directory: str, - memory_storage_client: MemoryStorageClient, - id: str | None = None, - name: str | None = None, -) -> DatasetClient: - from ._dataset_client import DatasetClient - - item_count = 0 - has_seen_metadata_file = False - created_at = datetime.now(timezone.utc) - accessed_at = datetime.now(timezone.utc) - modified_at = datetime.now(timezone.utc) - - # Load metadata if it exists - metadata_filepath = os.path.join(storage_directory, METADATA_FILENAME) - - if os.path.exists(metadata_filepath): - has_seen_metadata_file = True - with open(metadata_filepath, encoding='utf-8') as f: - json_content = json.load(f) - resource_info = DatasetMetadata(**json_content) - - id = resource_info.id - name = resource_info.name - item_count = resource_info.item_count - created_at = resource_info.created_at - accessed_at = resource_info.accessed_at - modified_at = resource_info.modified_at - - # Load dataset entries - entries: dict[str, dict] = {} - - for entry in os.scandir(storage_directory): - if entry.is_file(): - if entry.name == METADATA_FILENAME: - has_seen_metadata_file = True - continue - - with open(os.path.join(storage_directory, entry.name), encoding='utf-8') as f: - entry_content = json.load(f) - - entry_name = entry.name.split('.')[0] - entries[entry_name] = entry_content - - if not has_seen_metadata_file: - item_count += 1 - - # Create new dataset client - new_client = DatasetClient( - memory_storage_client=memory_storage_client, - id=id, - name=name, - created_at=created_at, - accessed_at=accessed_at, - modified_at=modified_at, - item_count=item_count, - ) - - new_client.dataset_entries.update(entries) - return new_client - - -def create_kvs_from_directory( - storage_directory: str, - memory_storage_client: MemoryStorageClient, - id: str | None = None, - name: str | None = None, -) -> KeyValueStoreClient: - from ._key_value_store_client import KeyValueStoreClient - - created_at = datetime.now(timezone.utc) - accessed_at = datetime.now(timezone.utc) - modified_at = datetime.now(timezone.utc) - - # Load metadata if it exists - metadata_filepath = os.path.join(storage_directory, METADATA_FILENAME) - - if os.path.exists(metadata_filepath): - with open(metadata_filepath, encoding='utf-8') as f: - json_content = json.load(f) - resource_info = KeyValueStoreMetadata(**json_content) - - id = resource_info.id - name = resource_info.name - created_at = resource_info.created_at - accessed_at = resource_info.accessed_at - modified_at = resource_info.modified_at - - # Create new KVS client - new_client = KeyValueStoreClient( - memory_storage_client=memory_storage_client, - id=id, - name=name, - accessed_at=accessed_at, - created_at=created_at, - modified_at=modified_at, - ) - - # Scan the KVS 
folder, check each entry in there and parse it as a store record - for entry in os.scandir(storage_directory): - if not entry.is_file(): - continue - - # Ignore metadata files on their own - if entry.name.endswith(METADATA_FILENAME): - continue - - # Try checking if this file has a metadata file associated with it - record_metadata = None - record_metadata_filepath = os.path.join(storage_directory, f'{entry.name}.__metadata__.json') - - if os.path.exists(record_metadata_filepath): - with open(record_metadata_filepath, encoding='utf-8') as metadata_file: - try: - json_content = json.load(metadata_file) - record_metadata = KeyValueStoreRecordMetadata(**json_content) - - except Exception: - logger.warning( - f'Metadata of key-value store entry "{entry.name}" for store {name or id} could ' - 'not be parsed. The metadata file will be ignored.', - exc_info=True, - ) - - if not record_metadata: - content_type, _ = mimetypes.guess_type(entry.name) - if content_type is None: - content_type = 'application/octet-stream' - - record_metadata = KeyValueStoreRecordMetadata( - key=pathlib.Path(entry.name).stem, - content_type=content_type, - ) - - with open(os.path.join(storage_directory, entry.name), 'rb') as f: - file_content = f.read() - - try: - maybe_parse_body(file_content, record_metadata.content_type) - except Exception: - record_metadata.content_type = 'application/octet-stream' - logger.warning( - f'Key-value store entry "{record_metadata.key}" for store {name or id} could not be parsed.' - 'The entry will be assumed as binary.', - exc_info=True, - ) - - new_client.records[record_metadata.key] = KeyValueStoreRecord( - key=record_metadata.key, - content_type=record_metadata.content_type, - filename=entry.name, - value=file_content, - ) - - return new_client - - -def create_rq_from_directory( - storage_directory: str, - memory_storage_client: MemoryStorageClient, - id: str | None = None, - name: str | None = None, -) -> RequestQueueClient: - from ._request_queue_client import RequestQueueClient - - created_at = datetime.now(timezone.utc) - accessed_at = datetime.now(timezone.utc) - modified_at = datetime.now(timezone.utc) - handled_request_count = 0 - pending_request_count = 0 - - # Load metadata if it exists - metadata_filepath = os.path.join(storage_directory, METADATA_FILENAME) - - if os.path.exists(metadata_filepath): - with open(metadata_filepath, encoding='utf-8') as f: - json_content = json.load(f) - resource_info = RequestQueueMetadata(**json_content) - - id = resource_info.id - name = resource_info.name - created_at = resource_info.created_at - accessed_at = resource_info.accessed_at - modified_at = resource_info.modified_at - handled_request_count = resource_info.handled_request_count - pending_request_count = resource_info.pending_request_count - - # Load request entries - entries: dict[str, InternalRequest] = {} - - for entry in os.scandir(storage_directory): - if entry.is_file(): - if entry.name == METADATA_FILENAME: - continue - - with open(os.path.join(storage_directory, entry.name), encoding='utf-8') as f: - content = json.load(f) - - request = InternalRequest(**content) - - entries[request.id] = request - - # Create new RQ client - new_client = RequestQueueClient( - memory_storage_client=memory_storage_client, - id=id, - name=name, - accessed_at=accessed_at, - created_at=created_at, - modified_at=modified_at, - handled_request_count=handled_request_count, - pending_request_count=pending_request_count, - ) - - new_client.requests.update(entries) - return new_client - - -def 
_determine_storage_path( - resource_client_class: type[TResourceClient], - memory_storage_client: MemoryStorageClient, - id: str | None = None, - name: str | None = None, -) -> str | None: - storages_dir = memory_storage_client._get_storage_dir(resource_client_class) # noqa: SLF001 - default_id = memory_storage_client._get_default_storage_id(resource_client_class) # noqa: SLF001 - - # Try to find by name directly from directories - if name: - possible_storage_path = os.path.join(storages_dir, name) - if os.access(possible_storage_path, os.F_OK): - return possible_storage_path - - # If not found, try finding by metadata - if os.access(storages_dir, os.F_OK): - for entry in os.scandir(storages_dir): - if entry.is_dir(): - metadata_path = os.path.join(entry.path, METADATA_FILENAME) - if os.access(metadata_path, os.F_OK): - with open(metadata_path, encoding='utf-8') as metadata_file: - try: - metadata = json.load(metadata_file) - if (id and metadata.get('id') == id) or (name and metadata.get('name') == name): - return entry.path - except Exception: - logger.warning( - f'Metadata of store entry "{entry.name}" for store {name or id} could not be parsed. ' - 'The metadata file will be ignored.', - exc_info=True, - ) - - # Check for default storage directory as a last resort - if id == default_id: - possible_storage_path = os.path.join(storages_dir, default_id) - if os.access(possible_storage_path, os.F_OK): - return possible_storage_path - - return None diff --git a/src/crawlee/storage_clients/_memory/_dataset_client.py b/src/crawlee/storage_clients/_memory/_dataset_client.py index 50c8c7c8d4..dd64a9d9ed 100644 --- a/src/crawlee/storage_clients/_memory/_dataset_client.py +++ b/src/crawlee/storage_clients/_memory/_dataset_client.py @@ -1,162 +1,130 @@ from __future__ import annotations -import asyncio -import json -import os -import shutil from datetime import datetime, timezone from logging import getLogger from typing import TYPE_CHECKING, Any from typing_extensions import override -from crawlee._types import StorageTypes from crawlee._utils.crypto import crypto_random_object_id -from crawlee._utils.data_processing import raise_on_duplicate_storage, raise_on_non_existing_storage -from crawlee._utils.file import force_rename, json_dumps -from crawlee.storage_clients._base import DatasetClient as BaseDatasetClient +from crawlee.storage_clients._base import DatasetClient from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata -from ._creation_management import find_or_create_client_by_id_or_name_inner - if TYPE_CHECKING: from collections.abc import AsyncIterator - from contextlib import AbstractAsyncContextManager - - from httpx import Response - - from crawlee._types import JsonSerializable - from crawlee.storage_clients import MemoryStorageClient logger = getLogger(__name__) -class DatasetClient(BaseDatasetClient): - """Subclient for manipulating a single dataset.""" +class MemoryDatasetClient(DatasetClient): + """Memory implementation of the dataset client. - _LIST_ITEMS_LIMIT = 999_999_999_999 - """This is what API returns in the x-apify-pagination-limit header when no limit query parameter is used.""" + This client stores dataset items in memory using Python lists and dictionaries. No data is persisted + between process runs, meaning all stored data is lost when the program terminates. This implementation + is primarily useful for testing, development, and short-lived crawler operations where persistent + storage is not required. 
- _LOCAL_ENTRY_NAME_DIGITS = 9 - """Number of characters of the dataset item file names, e.g.: 000000019.json - 9 digits.""" + The memory implementation provides fast access to data but is limited by available memory and + does not support data sharing across different processes. It supports all dataset operations including + sorting, filtering, and pagination, but performs them entirely in memory. + """ def __init__( self, *, - memory_storage_client: MemoryStorageClient, - id: str | None = None, - name: str | None = None, - created_at: datetime | None = None, - accessed_at: datetime | None = None, - modified_at: datetime | None = None, - item_count: int = 0, + metadata: DatasetMetadata, ) -> None: - self._memory_storage_client = memory_storage_client - self.id = id or crypto_random_object_id() - self.name = name - self._created_at = created_at or datetime.now(timezone.utc) - self._accessed_at = accessed_at or datetime.now(timezone.utc) - self._modified_at = modified_at or datetime.now(timezone.utc) - - self.dataset_entries: dict[str, dict] = {} - self.file_operation_lock = asyncio.Lock() - self.item_count = item_count - - @property - def resource_info(self) -> DatasetMetadata: - """Get the resource info for the dataset client.""" - return DatasetMetadata( - id=self.id, - name=self.name, - accessed_at=self._accessed_at, - created_at=self._created_at, - modified_at=self._modified_at, - item_count=self.item_count, - ) - - @property - def resource_directory(self) -> str: - """Get the resource directory for the client.""" - return os.path.join(self._memory_storage_client.datasets_directory, self.name or self.id) + """Initialize a new instance. - @override - async def get(self) -> DatasetMetadata | None: - found = find_or_create_client_by_id_or_name_inner( - resource_client_class=DatasetClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if found: - async with found.file_operation_lock: - await found.update_timestamps(has_been_modified=False) - return found.resource_info + Preferably use the `MemoryDatasetClient.open` class method to create a new instance. + """ + self._metadata = metadata - return None + self._records = list[dict[str, Any]]() + """List to hold dataset items. Each item is a dictionary representing a record.""" @override - async def update(self, *, name: str | None = None) -> DatasetMetadata: - # Check by id - existing_dataset_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=DatasetClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) + async def get_metadata(self) -> DatasetMetadata: + return self._metadata - if existing_dataset_by_id is None: - raise_on_non_existing_storage(StorageTypes.DATASET, self.id) - - # Skip if no changes - if name is None: - return existing_dataset_by_id.resource_info - - async with existing_dataset_by_id.file_operation_lock: - # Check that name is not in use already - existing_dataset_by_name = next( - ( - dataset - for dataset in self._memory_storage_client.datasets_handled - if dataset.name and dataset.name.lower() == name.lower() - ), - None, - ) - - if existing_dataset_by_name is not None: - raise_on_duplicate_storage(StorageTypes.DATASET, 'name', name) + @classmethod + async def open( + cls, + *, + id: str | None, + name: str | None, + ) -> MemoryDatasetClient: + """Open or create a new memory dataset client. 
- previous_dir = existing_dataset_by_id.resource_directory - existing_dataset_by_id.name = name + This method creates a new in-memory dataset instance. Unlike persistent storage implementations, memory + datasets don't check for existing datasets with the same name or ID since all data exists only in memory + and is lost when the process terminates. - await force_rename(previous_dir, existing_dataset_by_id.resource_directory) + Args: + id: The ID of the dataset. If not provided, a random ID will be generated. + name: The name of the dataset. If not provided, the dataset will be unnamed. - # Update timestamps - await existing_dataset_by_id.update_timestamps(has_been_modified=True) + Returns: + An instance for the opened or created storage client. + """ + # Otherwise create a new dataset + dataset_id = id or crypto_random_object_id() + now = datetime.now(timezone.utc) + + metadata = DatasetMetadata( + id=dataset_id, + name=name, + created_at=now, + accessed_at=now, + modified_at=now, + item_count=0, + ) - return existing_dataset_by_id.resource_info + return cls(metadata=metadata) @override - async def delete(self) -> None: - dataset = next( - (dataset for dataset in self._memory_storage_client.datasets_handled if dataset.id == self.id), None + async def drop(self) -> None: + self._records.clear() + await self._update_metadata( + update_accessed_at=True, + update_modified_at=True, + new_item_count=0, ) - if dataset is not None: - async with dataset.file_operation_lock: - self._memory_storage_client.datasets_handled.remove(dataset) - dataset.item_count = 0 - dataset.dataset_entries.clear() + @override + async def purge(self) -> None: + self._records.clear() + await self._update_metadata( + update_accessed_at=True, + update_modified_at=True, + new_item_count=0, + ) - if os.path.exists(dataset.resource_directory): - await asyncio.to_thread(shutil.rmtree, dataset.resource_directory) + @override + async def push_data(self, data: list[dict[str, Any]] | dict[str, Any]) -> None: + metadata = await self.get_metadata() + new_item_count = metadata.item_count + + if isinstance(data, list): + for item in data: + new_item_count += 1 + await self._push_item(item) + else: + new_item_count += 1 + await self._push_item(data) + + await self._update_metadata( + update_accessed_at=True, + update_modified_at=True, + new_item_count=new_item_count, + ) @override - async def list_items( + async def get_data( self, *, - offset: int | None = 0, - limit: int | None = _LIST_ITEMS_LIMIT, + offset: int = 0, + limit: int | None = 999_999_999_999, clean: bool = False, desc: bool = False, fields: list[str] | None = None, @@ -167,44 +135,48 @@ async def list_items( flatten: list[str] | None = None, view: str | None = None, ) -> DatasetItemsListPage: - # Check by id - existing_dataset_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=DatasetClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_dataset_by_id is None: - raise_on_non_existing_storage(StorageTypes.DATASET, self.id) - - async with existing_dataset_by_id.file_operation_lock: - start, end = existing_dataset_by_id.get_start_and_end_indexes( - max(existing_dataset_by_id.item_count - (offset or 0) - (limit or self._LIST_ITEMS_LIMIT), 0) - if desc - else offset or 0, - limit, + # Check for unsupported arguments and log a warning if found + unsupported_args: dict[str, Any] = { + 'clean': clean, + 'fields': fields, + 'omit': omit, + 'unwind': unwind, + 'skip_hidden': skip_hidden, + 
'flatten': flatten, + 'view': view, + } + unsupported = {k: v for k, v in unsupported_args.items() if v not in (False, None)} + + if unsupported: + logger.warning( + f'The arguments {list(unsupported.keys())} of get_data are not supported ' + f'by the {self.__class__.__name__} client.' ) - items = [] + total = len(self._records) + items = self._records.copy() - for idx in range(start, end): - entry_number = self._generate_local_entry_name(idx) - items.append(existing_dataset_by_id.dataset_entries[entry_number]) + # Apply skip_empty filter if requested + if skip_empty: + items = [item for item in items if item] - await existing_dataset_by_id.update_timestamps(has_been_modified=False) + # Apply sorting + if desc: + items = list(reversed(items)) - if desc: - items.reverse() + # Apply pagination + sliced_items = items[offset : (offset + limit) if limit is not None else total] - return DatasetItemsListPage( - count=len(items), - desc=desc or False, - items=items, - limit=limit or self._LIST_ITEMS_LIMIT, - offset=offset or 0, - total=existing_dataset_by_id.item_count, - ) + await self._update_metadata(update_accessed_at=True) + + return DatasetItemsListPage( + count=len(sliced_items), + offset=offset, + limit=limit or (total - offset), + total=total, + desc=desc, + items=sliced_items, + ) @override async def iterate_items( @@ -219,192 +191,67 @@ async def iterate_items( unwind: str | None = None, skip_empty: bool = False, skip_hidden: bool = False, - ) -> AsyncIterator[dict]: - cache_size = 1000 - first_item = offset - - # If there is no limit, set last_item to None until we get the total from the first API response - last_item = None if limit is None else offset + limit - current_offset = first_item - - while last_item is None or current_offset < last_item: - current_limit = cache_size if last_item is None else min(cache_size, last_item - current_offset) - - current_items_page = await self.list_items( - offset=current_offset, - limit=current_limit, - desc=desc, + ) -> AsyncIterator[dict[str, Any]]: + # Check for unsupported arguments and log a warning if found + unsupported_args: dict[str, Any] = { + 'clean': clean, + 'fields': fields, + 'omit': omit, + 'unwind': unwind, + 'skip_hidden': skip_hidden, + } + unsupported = {k: v for k, v in unsupported_args.items() if v not in (False, None)} + + if unsupported: + logger.warning( + f'The arguments {list(unsupported.keys())} of iterate are not supported ' + f'by the {self.__class__.__name__} client.' 
) - current_offset += current_items_page.count - if last_item is None or current_items_page.total < last_item: - last_item = current_items_page.total - - for item in current_items_page.items: - yield item - - @override - async def get_items_as_bytes( - self, - *, - item_format: str = 'json', - offset: int | None = None, - limit: int | None = None, - desc: bool = False, - clean: bool = False, - bom: bool = False, - delimiter: str | None = None, - fields: list[str] | None = None, - omit: list[str] | None = None, - unwind: str | None = None, - skip_empty: bool = False, - skip_header_row: bool = False, - skip_hidden: bool = False, - xml_root: str | None = None, - xml_row: str | None = None, - flatten: list[str] | None = None, - ) -> bytes: - raise NotImplementedError('This method is not supported in memory storage.') - - @override - async def stream_items( - self, - *, - item_format: str = 'json', - offset: int | None = None, - limit: int | None = None, - desc: bool = False, - clean: bool = False, - bom: bool = False, - delimiter: str | None = None, - fields: list[str] | None = None, - omit: list[str] | None = None, - unwind: str | None = None, - skip_empty: bool = False, - skip_header_row: bool = False, - skip_hidden: bool = False, - xml_root: str | None = None, - xml_row: str | None = None, - ) -> AbstractAsyncContextManager[Response | None]: - raise NotImplementedError('This method is not supported in memory storage.') + items = self._records.copy() - @override - async def push_items( - self, - items: JsonSerializable, - ) -> None: - # Check by id - existing_dataset_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=DatasetClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_dataset_by_id is None: - raise_on_non_existing_storage(StorageTypes.DATASET, self.id) + # Apply sorting + if desc: + items = list(reversed(items)) - normalized = self._normalize_items(items) + # Apply pagination + sliced_items = items[offset : (offset + limit) if limit is not None else len(items)] - added_ids: list[str] = [] - for entry in normalized: - existing_dataset_by_id.item_count += 1 - idx = self._generate_local_entry_name(existing_dataset_by_id.item_count) + # Yield items one by one + for item in sliced_items: + if skip_empty and not item: + continue + yield item - existing_dataset_by_id.dataset_entries[idx] = entry - added_ids.append(idx) - - data_entries = [(id, existing_dataset_by_id.dataset_entries[id]) for id in added_ids] - - async with existing_dataset_by_id.file_operation_lock: - await existing_dataset_by_id.update_timestamps(has_been_modified=True) - - await self._persist_dataset_items_to_disk( - data=data_entries, - entity_directory=existing_dataset_by_id.resource_directory, - persist_storage=self._memory_storage_client.persist_storage, - ) + await self._update_metadata(update_accessed_at=True) - async def _persist_dataset_items_to_disk( + async def _update_metadata( self, *, - data: list[tuple[str, dict]], - entity_directory: str, - persist_storage: bool, + new_item_count: int | None = None, + update_accessed_at: bool = False, + update_modified_at: bool = False, ) -> None: - """Write dataset items to the disk. - - The function iterates over a list of dataset items, each represented as a tuple of an identifier - and a dictionary, and writes them as individual JSON files in a specified directory. The function - will skip writing if `persist_storage` is False. 
Before writing, it ensures that the target - directory exists, creating it if necessary. + """Update the dataset metadata with current information. Args: - data: A list of tuples, each containing an identifier (string) and a data dictionary. - entity_directory: The directory path where the dataset items should be stored. - persist_storage: A boolean flag indicating whether the data should be persisted to the disk. + new_item_count: If provided, update the item count to this value. + update_accessed_at: If True, update the `accessed_at` timestamp to the current time. + update_modified_at: If True, update the `modified_at` timestamp to the current time. """ - # Skip writing files to the disk if the client has the option set to false - if not persist_storage: - return - - # Ensure the directory for the entity exists - await asyncio.to_thread(os.makedirs, entity_directory, exist_ok=True) - - # Save all the new items to the disk - for idx, item in data: - file_path = os.path.join(entity_directory, f'{idx}.json') - f = await asyncio.to_thread(open, file_path, mode='w', encoding='utf-8') - try: - s = await json_dumps(item) - await asyncio.to_thread(f.write, s) - finally: - await asyncio.to_thread(f.close) - - async def update_timestamps(self, *, has_been_modified: bool) -> None: - """Update the timestamps of the dataset.""" - from ._creation_management import persist_metadata_if_enabled - - self._accessed_at = datetime.now(timezone.utc) - - if has_been_modified: - self._modified_at = datetime.now(timezone.utc) - - await persist_metadata_if_enabled( - data=self.resource_info.model_dump(), - entity_directory=self.resource_directory, - write_metadata=self._memory_storage_client.write_metadata, - ) - - def get_start_and_end_indexes(self, offset: int, limit: int | None = None) -> tuple[int, int]: - """Calculate the start and end indexes for listing items.""" - actual_limit = limit or self.item_count - start = offset + 1 - end = min(offset + actual_limit, self.item_count) + 1 - return (start, end) - - def _generate_local_entry_name(self, idx: int) -> str: - return str(idx).zfill(self._LOCAL_ENTRY_NAME_DIGITS) - - def _normalize_items(self, items: JsonSerializable) -> list[dict]: - def normalize_item(item: Any) -> dict | None: - if isinstance(item, str): - item = json.loads(item) + now = datetime.now(timezone.utc) - if isinstance(item, list): - received = ',\n'.join(item) - raise TypeError( - f'Each dataset item can only be a single JSON object, not an array. Received: [{received}]' - ) + if update_accessed_at: + self._metadata.accessed_at = now + if update_modified_at: + self._metadata.modified_at = now + if new_item_count is not None: + self._metadata.item_count = new_item_count - if (not isinstance(item, dict)) and item is not None: - raise TypeError(f'Each dataset item must be a JSON object. Received: {item}') + async def _push_item(self, item: dict[str, Any]) -> None: + """Push a single item to the dataset. - return item - - if isinstance(items, str): - items = json.loads(items) - - result = list(map(normalize_item, items)) if isinstance(items, list) else [normalize_item(items)] - # filter(None, ..) returns items that are True - return list(filter(None, result)) + Args: + item: The data item to add to the dataset. 
+ """ + self._records.append(item) diff --git a/src/crawlee/storage_clients/_memory/_dataset_collection_client.py b/src/crawlee/storage_clients/_memory/_dataset_collection_client.py deleted file mode 100644 index 9e32b4086b..0000000000 --- a/src/crawlee/storage_clients/_memory/_dataset_collection_client.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from typing_extensions import override - -from crawlee.storage_clients._base import DatasetCollectionClient as BaseDatasetCollectionClient -from crawlee.storage_clients.models import DatasetListPage, DatasetMetadata - -from ._creation_management import get_or_create_inner -from ._dataset_client import DatasetClient - -if TYPE_CHECKING: - from ._memory_storage_client import MemoryStorageClient - - -class DatasetCollectionClient(BaseDatasetCollectionClient): - """Subclient for manipulating datasets.""" - - def __init__(self, *, memory_storage_client: MemoryStorageClient) -> None: - self._memory_storage_client = memory_storage_client - - @property - def _storage_client_cache(self) -> list[DatasetClient]: - return self._memory_storage_client.datasets_handled - - @override - async def get_or_create( - self, - *, - name: str | None = None, - schema: dict | None = None, - id: str | None = None, - ) -> DatasetMetadata: - resource_client = await get_or_create_inner( - memory_storage_client=self._memory_storage_client, - storage_client_cache=self._storage_client_cache, - resource_client_class=DatasetClient, - name=name, - id=id, - ) - return resource_client.resource_info - - @override - async def list( - self, - *, - unnamed: bool = False, - limit: int | None = None, - offset: int | None = None, - desc: bool = False, - ) -> DatasetListPage: - items = [storage.resource_info for storage in self._storage_client_cache] - - return DatasetListPage( - total=len(items), - count=len(items), - offset=0, - limit=len(items), - desc=False, - items=sorted(items, key=lambda item: item.created_at), - ) diff --git a/src/crawlee/storage_clients/_memory/_key_value_store_client.py b/src/crawlee/storage_clients/_memory/_key_value_store_client.py index ab9def0f06..7dacf6d95d 100644 --- a/src/crawlee/storage_clients/_memory/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_memory/_key_value_store_client.py @@ -1,425 +1,177 @@ from __future__ import annotations -import asyncio -import io -import os -import shutil +import sys from datetime import datetime, timezone -from logging import getLogger from typing import TYPE_CHECKING, Any from typing_extensions import override -from crawlee._types import StorageTypes from crawlee._utils.crypto import crypto_random_object_id -from crawlee._utils.data_processing import maybe_parse_body, raise_on_duplicate_storage, raise_on_non_existing_storage -from crawlee._utils.file import determine_file_extension, force_remove, force_rename, is_file_or_bytes, json_dumps -from crawlee.storage_clients._base import KeyValueStoreClient as BaseKeyValueStoreClient -from crawlee.storage_clients.models import ( - KeyValueStoreKeyInfo, - KeyValueStoreListKeysPage, - KeyValueStoreMetadata, - KeyValueStoreRecord, - KeyValueStoreRecordMetadata, -) - -from ._creation_management import find_or_create_client_by_id_or_name_inner, persist_metadata_if_enabled +from crawlee._utils.file import infer_mime_type +from crawlee.storage_clients._base import KeyValueStoreClient +from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecord, KeyValueStoreRecordMetadata if 
TYPE_CHECKING: - from contextlib import AbstractAsyncContextManager + from collections.abc import AsyncIterator - from httpx import Response - from crawlee.storage_clients import MemoryStorageClient +class MemoryKeyValueStoreClient(KeyValueStoreClient): + """Memory implementation of the key-value store client. -logger = getLogger(__name__) + This client stores data in memory as Python dictionaries. No data is persisted between + process runs, meaning all stored data is lost when the program terminates. This implementation + is primarily useful for testing, development, and short-lived crawler operations where + persistence is not required. - -class KeyValueStoreClient(BaseKeyValueStoreClient): - """Subclient for manipulating a single key-value store.""" + The memory implementation provides fast access to data but is limited by available memory and + does not support data sharing across different processes. + """ def __init__( self, *, - memory_storage_client: MemoryStorageClient, - id: str | None = None, - name: str | None = None, - created_at: datetime | None = None, - accessed_at: datetime | None = None, - modified_at: datetime | None = None, + metadata: KeyValueStoreMetadata, ) -> None: - self.id = id or crypto_random_object_id() - self.name = name - - self._memory_storage_client = memory_storage_client - self._created_at = created_at or datetime.now(timezone.utc) - self._accessed_at = accessed_at or datetime.now(timezone.utc) - self._modified_at = modified_at or datetime.now(timezone.utc) - - self.records: dict[str, KeyValueStoreRecord] = {} - self.file_operation_lock = asyncio.Lock() - - @property - def resource_info(self) -> KeyValueStoreMetadata: - """Get the resource info for the key-value store client.""" - return KeyValueStoreMetadata( - id=self.id, - name=self.name, - accessed_at=self._accessed_at, - created_at=self._created_at, - modified_at=self._modified_at, - user_id='1', - ) + """Initialize a new instance. + + Preferably use the `MemoryKeyValueStoreClient.open` class method to create a new instance. + """ + self._metadata = metadata - @property - def resource_directory(self) -> str: - """Get the resource directory for the client.""" - return os.path.join(self._memory_storage_client.key_value_stores_directory, self.name or self.id) + self._records = dict[str, KeyValueStoreRecord]() + """Dictionary to hold key-value records.""" @override - async def get(self) -> KeyValueStoreMetadata | None: - found = find_or_create_client_by_id_or_name_inner( - resource_client_class=KeyValueStoreClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, + async def get_metadata(self) -> KeyValueStoreMetadata: + return self._metadata + + @classmethod + async def open( + cls, + *, + id: str | None, + name: str | None, + ) -> MemoryKeyValueStoreClient: + """Open or create a new memory key-value store client. + + This method creates a new in-memory key-value store instance. Unlike persistent storage implementations, + memory KVS don't check for existing stores with the same name or ID since all data exists only in memory + and is lost when the process terminates. + + Args: + id: The ID of the key-value store. If not provided, a random ID will be generated. + name: The name of the key-value store. If not provided, the store will be unnamed. + + Returns: + An instance for the opened or created storage client. 
+ """ + # Otherwise create a new key-value store + store_id = id or crypto_random_object_id() + now = datetime.now(timezone.utc) + + metadata = KeyValueStoreMetadata( + id=store_id, + name=name, + created_at=now, + accessed_at=now, + modified_at=now, ) - if found: - async with found.file_operation_lock: - await found.update_timestamps(has_been_modified=False) - return found.resource_info + return cls(metadata=metadata) - return None + @override + async def drop(self) -> None: + self._records.clear() + await self._update_metadata(update_accessed_at=True, update_modified_at=True) @override - async def update(self, *, name: str | None = None) -> KeyValueStoreMetadata: - # Check by id - existing_store_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=KeyValueStoreClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) + async def purge(self) -> None: + self._records.clear() + await self._update_metadata(update_accessed_at=True, update_modified_at=True) - if existing_store_by_id is None: - raise_on_non_existing_storage(StorageTypes.KEY_VALUE_STORE, self.id) - - # Skip if no changes - if name is None: - return existing_store_by_id.resource_info - - async with existing_store_by_id.file_operation_lock: - # Check that name is not in use already - existing_store_by_name = next( - ( - store - for store in self._memory_storage_client.key_value_stores_handled - if store.name and store.name.lower() == name.lower() - ), - None, - ) + @override + async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: + await self._update_metadata(update_accessed_at=True) - if existing_store_by_name is not None: - raise_on_duplicate_storage(StorageTypes.KEY_VALUE_STORE, 'name', name) + # Return None if key doesn't exist + return self._records.get(key, None) - previous_dir = existing_store_by_id.resource_directory - existing_store_by_id.name = name + @override + async def set_value(self, *, key: str, value: Any, content_type: str | None = None) -> None: + content_type = content_type or infer_mime_type(value) + size = sys.getsizeof(value) - await force_rename(previous_dir, existing_store_by_id.resource_directory) + # Create and store the record + record = KeyValueStoreRecord( + key=key, + value=value, + content_type=content_type, + size=size, + ) - # Update timestamps - await existing_store_by_id.update_timestamps(has_been_modified=True) + self._records[key] = record - return existing_store_by_id.resource_info + await self._update_metadata(update_accessed_at=True, update_modified_at=True) @override - async def delete(self) -> None: - store = next( - (store for store in self._memory_storage_client.key_value_stores_handled if store.id == self.id), None - ) - - if store is not None: - async with store.file_operation_lock: - self._memory_storage_client.key_value_stores_handled.remove(store) - store.records.clear() - - if os.path.exists(store.resource_directory): - await asyncio.to_thread(shutil.rmtree, store.resource_directory) + async def delete_value(self, *, key: str) -> None: + if key in self._records: + del self._records[key] + await self._update_metadata(update_accessed_at=True, update_modified_at=True) @override - async def list_keys( + async def iterate_keys( self, *, - limit: int = 1000, exclusive_start_key: str | None = None, - ) -> KeyValueStoreListKeysPage: - # Check by id - existing_store_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=KeyValueStoreClient, - memory_storage_client=self._memory_storage_client, - 
id=self.id, - name=self.name, - ) - - if existing_store_by_id is None: - raise_on_non_existing_storage(StorageTypes.KEY_VALUE_STORE, self.id) - - items: list[KeyValueStoreKeyInfo] = [] - - for record in existing_store_by_id.records.values(): - size = len(record.value) - items.append(KeyValueStoreKeyInfo(key=record.key, size=size)) - - if len(items) == 0: - return KeyValueStoreListKeysPage( - count=len(items), - limit=limit, - exclusive_start_key=exclusive_start_key, - is_truncated=False, - next_exclusive_start_key=None, - items=items, - ) + limit: int | None = None, + ) -> AsyncIterator[KeyValueStoreRecordMetadata]: + await self._update_metadata(update_accessed_at=True) - # Lexically sort to emulate the API - items = sorted(items, key=lambda item: item.key) + # Get all keys, sorted alphabetically + keys = sorted(self._records.keys()) - truncated_items = items + # Apply exclusive_start_key filter if provided if exclusive_start_key is not None: - key_pos = next((idx for idx, item in enumerate(items) if item.key == exclusive_start_key), None) - if key_pos is not None: - truncated_items = items[(key_pos + 1) :] - - limited_items = truncated_items[:limit] - - last_item_in_store = items[-1] - last_selected_item = limited_items[-1] - is_last_selected_item_absolutely_last = last_item_in_store == last_selected_item - next_exclusive_start_key = None if is_last_selected_item_absolutely_last else last_selected_item.key - - async with existing_store_by_id.file_operation_lock: - await existing_store_by_id.update_timestamps(has_been_modified=False) - - return KeyValueStoreListKeysPage( - count=len(items), - limit=limit, - exclusive_start_key=exclusive_start_key, - is_truncated=not is_last_selected_item_absolutely_last, - next_exclusive_start_key=next_exclusive_start_key, - items=limited_items, - ) - - @override - async def get_record(self, key: str) -> KeyValueStoreRecord | None: - return await self._get_record_internal(key) - - @override - async def get_record_as_bytes(self, key: str) -> KeyValueStoreRecord[bytes] | None: - return await self._get_record_internal(key, as_bytes=True) - - @override - async def stream_record(self, key: str) -> AbstractAsyncContextManager[KeyValueStoreRecord[Response] | None]: - raise NotImplementedError('This method is not supported in memory storage.') + keys = [k for k in keys if k > exclusive_start_key] + + # Apply limit if provided + if limit is not None: + keys = keys[:limit] + + # Yield metadata for each key + for key in keys: + record = self._records[key] + yield KeyValueStoreRecordMetadata( + key=key, + content_type=record.content_type, + size=record.size, + ) @override - async def set_record(self, key: str, value: Any, content_type: str | None = None) -> None: - # Check by id - existing_store_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=KeyValueStoreClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_store_by_id is None: - raise_on_non_existing_storage(StorageTypes.KEY_VALUE_STORE, self.id) - - if isinstance(value, io.IOBase): - raise NotImplementedError('File-like values are not supported in local memory storage') - - if content_type is None: - if is_file_or_bytes(value): - content_type = 'application/octet-stream' - elif isinstance(value, str): - content_type = 'text/plain; charset=utf-8' - else: - content_type = 'application/json; charset=utf-8' - - if 'application/json' in content_type and not is_file_or_bytes(value) and not isinstance(value, str): - s = await 
json_dumps(value) - value = s.encode('utf-8') - - async with existing_store_by_id.file_operation_lock: - await existing_store_by_id.update_timestamps(has_been_modified=True) - record = KeyValueStoreRecord(key=key, value=value, content_type=content_type, filename=None) - - old_record = existing_store_by_id.records.get(key) - existing_store_by_id.records[key] = record - - if self._memory_storage_client.persist_storage: - record_filename = self._filename_from_record(record) - record.filename = record_filename - - if old_record is not None and self._filename_from_record(old_record) != record_filename: - await existing_store_by_id.delete_persisted_record(old_record) - - await existing_store_by_id.persist_record(record) + async def get_public_url(self, *, key: str) -> str: + raise NotImplementedError('Public URLs are not supported for memory key-value stores.') @override - async def delete_record(self, key: str) -> None: - # Check by id - existing_store_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=KeyValueStoreClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_store_by_id is None: - raise_on_non_existing_storage(StorageTypes.KEY_VALUE_STORE, self.id) + async def record_exists(self, *, key: str) -> bool: + await self._update_metadata(update_accessed_at=True) + return key in self._records - record = existing_store_by_id.records.get(key) - - if record is not None: - async with existing_store_by_id.file_operation_lock: - del existing_store_by_id.records[key] - await existing_store_by_id.update_timestamps(has_been_modified=True) - if self._memory_storage_client.persist_storage: - await existing_store_by_id.delete_persisted_record(record) - - @override - async def get_public_url(self, key: str) -> str: - existing_store_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=KeyValueStoreClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_store_by_id is None: - raise_on_non_existing_storage(StorageTypes.KEY_VALUE_STORE, self.id) - - record = await self._get_record_internal(key) - - if not record: - raise ValueError(f'Record with key "{key}" was not found.') - - resource_dir = existing_store_by_id.resource_directory - record_filename = self._filename_from_record(record) - record_path = os.path.join(resource_dir, record_filename) - return f'file://{record_path}' - - async def persist_record(self, record: KeyValueStoreRecord) -> None: - """Persist the specified record to the key-value store.""" - store_directory = self.resource_directory - record_filename = self._filename_from_record(record) - record.filename = record_filename - record.content_type = record.content_type or 'application/octet-stream' - - # Ensure the directory for the entity exists - await asyncio.to_thread(os.makedirs, store_directory, exist_ok=True) - - # Create files for the record - record_path = os.path.join(store_directory, record_filename) - record_metadata_path = os.path.join(store_directory, f'{record_filename}.__metadata__.json') - - # Convert to bytes if string - if isinstance(record.value, str): - record.value = record.value.encode('utf-8') - - f = await asyncio.to_thread(open, record_path, mode='wb') - try: - await asyncio.to_thread(f.write, record.value) - finally: - await asyncio.to_thread(f.close) - - if self._memory_storage_client.write_metadata: - metadata_f = await asyncio.to_thread(open, record_metadata_path, mode='wb') - - try: - record_metadata = 
KeyValueStoreRecordMetadata(key=record.key, content_type=record.content_type) - await asyncio.to_thread(metadata_f.write, record_metadata.model_dump_json(indent=2).encode('utf-8')) - finally: - await asyncio.to_thread(metadata_f.close) - - async def delete_persisted_record(self, record: KeyValueStoreRecord) -> None: - """Delete the specified record from the key-value store.""" - store_directory = self.resource_directory - record_filename = self._filename_from_record(record) - - # Ensure the directory for the entity exists - await asyncio.to_thread(os.makedirs, store_directory, exist_ok=True) - - # Create files for the record - record_path = os.path.join(store_directory, record_filename) - record_metadata_path = os.path.join(store_directory, record_filename + '.__metadata__.json') - - await force_remove(record_path) - await force_remove(record_metadata_path) - - async def update_timestamps(self, *, has_been_modified: bool) -> None: - """Update the timestamps of the key-value store.""" - self._accessed_at = datetime.now(timezone.utc) - - if has_been_modified: - self._modified_at = datetime.now(timezone.utc) - - await persist_metadata_if_enabled( - data=self.resource_info.model_dump(), - entity_directory=self.resource_directory, - write_metadata=self._memory_storage_client.write_metadata, - ) - - async def _get_record_internal( + async def _update_metadata( self, - key: str, *, - as_bytes: bool = False, - ) -> KeyValueStoreRecord | None: - # Check by id - existing_store_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=KeyValueStoreClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_store_by_id is None: - raise_on_non_existing_storage(StorageTypes.KEY_VALUE_STORE, self.id) - - stored_record = existing_store_by_id.records.get(key) - - if stored_record is None: - return None - - record = KeyValueStoreRecord( - key=stored_record.key, - value=stored_record.value, - content_type=stored_record.content_type, - filename=stored_record.filename, - ) - - if not as_bytes: - try: - record.value = maybe_parse_body(record.value, str(record.content_type)) - except ValueError: - logger.exception('Error parsing key-value store record') - - async with existing_store_by_id.file_operation_lock: - await existing_store_by_id.update_timestamps(has_been_modified=False) - - return record - - def _filename_from_record(self, record: KeyValueStoreRecord) -> str: - if record.filename is not None: - return record.filename - - if not record.content_type or record.content_type == 'application/octet-stream': - return record.key - - extension = determine_file_extension(record.content_type) - - if record.key.endswith(f'.{extension}'): - return record.key - - return f'{record.key}.{extension}' + update_accessed_at: bool = False, + update_modified_at: bool = False, + ) -> None: + """Update the key-value store metadata with current information. + + Args: + update_accessed_at: If True, update the `accessed_at` timestamp to the current time. + update_modified_at: If True, update the `modified_at` timestamp to the current time. 
+ """ + now = datetime.now(timezone.utc) + + if update_accessed_at: + self._metadata.accessed_at = now + if update_modified_at: + self._metadata.modified_at = now diff --git a/src/crawlee/storage_clients/_memory/_key_value_store_collection_client.py b/src/crawlee/storage_clients/_memory/_key_value_store_collection_client.py deleted file mode 100644 index 939780449f..0000000000 --- a/src/crawlee/storage_clients/_memory/_key_value_store_collection_client.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from typing_extensions import override - -from crawlee.storage_clients._base import KeyValueStoreCollectionClient as BaseKeyValueStoreCollectionClient -from crawlee.storage_clients.models import KeyValueStoreListPage, KeyValueStoreMetadata - -from ._creation_management import get_or_create_inner -from ._key_value_store_client import KeyValueStoreClient - -if TYPE_CHECKING: - from ._memory_storage_client import MemoryStorageClient - - -class KeyValueStoreCollectionClient(BaseKeyValueStoreCollectionClient): - """Subclient for manipulating key-value stores.""" - - def __init__(self, *, memory_storage_client: MemoryStorageClient) -> None: - self._memory_storage_client = memory_storage_client - - @property - def _storage_client_cache(self) -> list[KeyValueStoreClient]: - return self._memory_storage_client.key_value_stores_handled - - @override - async def get_or_create( - self, - *, - name: str | None = None, - schema: dict | None = None, - id: str | None = None, - ) -> KeyValueStoreMetadata: - resource_client = await get_or_create_inner( - memory_storage_client=self._memory_storage_client, - storage_client_cache=self._storage_client_cache, - resource_client_class=KeyValueStoreClient, - name=name, - id=id, - ) - return resource_client.resource_info - - @override - async def list( - self, - *, - unnamed: bool = False, - limit: int | None = None, - offset: int | None = None, - desc: bool = False, - ) -> KeyValueStoreListPage: - items = [storage.resource_info for storage in self._storage_client_cache] - - return KeyValueStoreListPage( - total=len(items), - count=len(items), - offset=0, - limit=len(items), - desc=False, - items=sorted(items, key=lambda item: item.created_at), - ) diff --git a/src/crawlee/storage_clients/_memory/_memory_storage_client.py b/src/crawlee/storage_clients/_memory/_memory_storage_client.py deleted file mode 100644 index 8000f41274..0000000000 --- a/src/crawlee/storage_clients/_memory/_memory_storage_client.py +++ /dev/null @@ -1,358 +0,0 @@ -from __future__ import annotations - -import asyncio -import contextlib -import os -import shutil -from logging import getLogger -from pathlib import Path -from typing import TYPE_CHECKING, TypeVar - -from typing_extensions import override - -from crawlee._utils.docs import docs_group -from crawlee.configuration import Configuration -from crawlee.storage_clients import StorageClient - -from ._dataset_client import DatasetClient -from ._dataset_collection_client import DatasetCollectionClient -from ._key_value_store_client import KeyValueStoreClient -from ._key_value_store_collection_client import KeyValueStoreCollectionClient -from ._request_queue_client import RequestQueueClient -from ._request_queue_collection_client import RequestQueueCollectionClient - -if TYPE_CHECKING: - from crawlee.storage_clients._base import ResourceClient - - -TResourceClient = TypeVar('TResourceClient', DatasetClient, KeyValueStoreClient, RequestQueueClient) - -logger = getLogger(__name__) - - 
-@docs_group('Classes') -class MemoryStorageClient(StorageClient): - """Represents an in-memory storage client for managing datasets, key-value stores, and request queues. - - It emulates in-memory storage similar to the Apify platform, supporting both in-memory and local file system-based - persistence. - - The behavior of the storage, such as data persistence and metadata writing, can be customized via initialization - parameters or environment variables. - """ - - _MIGRATING_KEY_VALUE_STORE_DIR_NAME = '__CRAWLEE_MIGRATING_KEY_VALUE_STORE' - """Name of the directory used to temporarily store files during the migration of the default key-value store.""" - - _TEMPORARY_DIR_NAME = '__CRAWLEE_TEMPORARY' - """Name of the directory used to temporarily store files during purges.""" - - _DATASETS_DIR_NAME = 'datasets' - """Name of the directory containing datasets.""" - - _KEY_VALUE_STORES_DIR_NAME = 'key_value_stores' - """Name of the directory containing key-value stores.""" - - _REQUEST_QUEUES_DIR_NAME = 'request_queues' - """Name of the directory containing request queues.""" - - def __init__( - self, - *, - write_metadata: bool, - persist_storage: bool, - storage_dir: str, - default_request_queue_id: str, - default_key_value_store_id: str, - default_dataset_id: str, - ) -> None: - """Initialize a new instance. - - In most cases, you should use the `from_config` constructor to create a new instance based on - the provided configuration. - - Args: - write_metadata: Whether to write metadata to the storage. - persist_storage: Whether to persist the storage. - storage_dir: Path to the storage directory. - default_request_queue_id: The default request queue ID. - default_key_value_store_id: The default key-value store ID. - default_dataset_id: The default dataset ID. - """ - # Set the internal attributes. - self._write_metadata = write_metadata - self._persist_storage = persist_storage - self._storage_dir = storage_dir - self._default_request_queue_id = default_request_queue_id - self._default_key_value_store_id = default_key_value_store_id - self._default_dataset_id = default_dataset_id - - self.datasets_handled: list[DatasetClient] = [] - self.key_value_stores_handled: list[KeyValueStoreClient] = [] - self.request_queues_handled: list[RequestQueueClient] = [] - - self._purged_on_start = False # Indicates whether a purge was already performed on this instance. - self._purge_lock = asyncio.Lock() - - @classmethod - def from_config(cls, config: Configuration | None = None) -> MemoryStorageClient: - """Initialize a new instance based on the provided `Configuration`. - - Args: - config: The `Configuration` instance. Uses the global (default) one if not provided. 
- """ - config = config or Configuration.get_global_configuration() - - return cls( - write_metadata=config.write_metadata, - persist_storage=config.persist_storage, - storage_dir=config.storage_dir, - default_request_queue_id=config.default_request_queue_id, - default_key_value_store_id=config.default_key_value_store_id, - default_dataset_id=config.default_dataset_id, - ) - - @property - def write_metadata(self) -> bool: - """Whether to write metadata to the storage.""" - return self._write_metadata - - @property - def persist_storage(self) -> bool: - """Whether to persist the storage.""" - return self._persist_storage - - @property - def storage_dir(self) -> str: - """Path to the storage directory.""" - return self._storage_dir - - @property - def datasets_directory(self) -> str: - """Path to the directory containing datasets.""" - return os.path.join(self.storage_dir, self._DATASETS_DIR_NAME) - - @property - def key_value_stores_directory(self) -> str: - """Path to the directory containing key-value stores.""" - return os.path.join(self.storage_dir, self._KEY_VALUE_STORES_DIR_NAME) - - @property - def request_queues_directory(self) -> str: - """Path to the directory containing request queues.""" - return os.path.join(self.storage_dir, self._REQUEST_QUEUES_DIR_NAME) - - @override - def dataset(self, id: str) -> DatasetClient: - return DatasetClient(memory_storage_client=self, id=id) - - @override - def datasets(self) -> DatasetCollectionClient: - return DatasetCollectionClient(memory_storage_client=self) - - @override - def key_value_store(self, id: str) -> KeyValueStoreClient: - return KeyValueStoreClient(memory_storage_client=self, id=id) - - @override - def key_value_stores(self) -> KeyValueStoreCollectionClient: - return KeyValueStoreCollectionClient(memory_storage_client=self) - - @override - def request_queue(self, id: str) -> RequestQueueClient: - return RequestQueueClient(memory_storage_client=self, id=id) - - @override - def request_queues(self) -> RequestQueueCollectionClient: - return RequestQueueCollectionClient(memory_storage_client=self) - - @override - async def purge_on_start(self) -> None: - # Optimistic, non-blocking check - if self._purged_on_start is True: - logger.debug('Storage was already purged on start.') - return - - async with self._purge_lock: - # Another check under the lock just to be sure - if self._purged_on_start is True: - # Mypy doesn't understand that the _purged_on_start can change while we're getting the async lock - return # type: ignore[unreachable] - - await self._purge_default_storages() - self._purged_on_start = True - - def get_cached_resource_client( - self, - resource_client_class: type[TResourceClient], - id: str | None, - name: str | None, - ) -> TResourceClient | None: - """Try to return a resource client from the internal cache.""" - if issubclass(resource_client_class, DatasetClient): - cache = self.datasets_handled - elif issubclass(resource_client_class, KeyValueStoreClient): - cache = self.key_value_stores_handled - elif issubclass(resource_client_class, RequestQueueClient): - cache = self.request_queues_handled - else: - return None - - for storage_client in cache: - if storage_client.id == id or ( - storage_client.name and name and storage_client.name.lower() == name.lower() - ): - return storage_client - - return None - - def add_resource_client_to_cache(self, resource_client: ResourceClient) -> None: - """Add a new resource client to the internal cache.""" - if isinstance(resource_client, DatasetClient): - 
self.datasets_handled.append(resource_client) - if isinstance(resource_client, KeyValueStoreClient): - self.key_value_stores_handled.append(resource_client) - if isinstance(resource_client, RequestQueueClient): - self.request_queues_handled.append(resource_client) - - async def _purge_default_storages(self) -> None: - """Clean up the storage directories, preparing the environment for a new run. - - It aims to remove residues from previous executions to avoid data contamination between runs. - - It specifically targets: - - The local directory containing the default dataset. - - All records from the default key-value store in the local directory, except for the 'INPUT' key. - - The local directory containing the default request queue. - """ - # Key-value stores - if await asyncio.to_thread(os.path.exists, self.key_value_stores_directory): - key_value_store_folders = await asyncio.to_thread(os.scandir, self.key_value_stores_directory) - for key_value_store_folder in key_value_store_folders: - if key_value_store_folder.name.startswith( - self._TEMPORARY_DIR_NAME - ) or key_value_store_folder.name.startswith('__OLD'): - await self._batch_remove_files(key_value_store_folder.path) - elif key_value_store_folder.name == self._default_key_value_store_id: - await self._handle_default_key_value_store(key_value_store_folder.path) - - # Datasets - if await asyncio.to_thread(os.path.exists, self.datasets_directory): - dataset_folders = await asyncio.to_thread(os.scandir, self.datasets_directory) - for dataset_folder in dataset_folders: - if dataset_folder.name == self._default_dataset_id or dataset_folder.name.startswith( - self._TEMPORARY_DIR_NAME - ): - await self._batch_remove_files(dataset_folder.path) - - # Request queues - if await asyncio.to_thread(os.path.exists, self.request_queues_directory): - request_queue_folders = await asyncio.to_thread(os.scandir, self.request_queues_directory) - for request_queue_folder in request_queue_folders: - if request_queue_folder.name == self._default_request_queue_id or request_queue_folder.name.startswith( - self._TEMPORARY_DIR_NAME - ): - await self._batch_remove_files(request_queue_folder.path) - - async def _handle_default_key_value_store(self, folder: str) -> None: - """Manage the cleanup of the default key-value store. - - It removes all files to ensure a clean state except for a set of predefined input keys (`possible_input_keys`). - - Args: - folder: Path to the default key-value store directory to clean. 
- """ - folder_exists = await asyncio.to_thread(os.path.exists, folder) - temporary_path = os.path.normpath(os.path.join(folder, '..', self._MIGRATING_KEY_VALUE_STORE_DIR_NAME)) - - # For optimization, we want to only attempt to copy a few files from the default key-value store - possible_input_keys = [ - 'INPUT', - 'INPUT.json', - 'INPUT.bin', - 'INPUT.txt', - ] - - if folder_exists: - # Create a temporary folder to save important files in - Path(temporary_path).mkdir(parents=True, exist_ok=True) - - # Go through each file and save the ones that are important - for entity in possible_input_keys: - original_file_path = os.path.join(folder, entity) - temp_file_path = os.path.join(temporary_path, entity) - with contextlib.suppress(Exception): - await asyncio.to_thread(os.rename, original_file_path, temp_file_path) - - # Remove the original folder and all its content - counter = 0 - temp_path_for_old_folder = os.path.normpath(os.path.join(folder, f'../__OLD_DEFAULT_{counter}__')) - done = False - try: - while not done: - await asyncio.to_thread(os.rename, folder, temp_path_for_old_folder) - done = True - except Exception: - counter += 1 - temp_path_for_old_folder = os.path.normpath(os.path.join(folder, f'../__OLD_DEFAULT_{counter}__')) - - # Replace the temporary folder with the original folder - await asyncio.to_thread(os.rename, temporary_path, folder) - - # Remove the old folder - await self._batch_remove_files(temp_path_for_old_folder) - - async def _batch_remove_files(self, folder: str, counter: int = 0) -> None: - """Remove a folder and its contents in batches to minimize blocking time. - - This method first renames the target folder to a temporary name, then deletes the temporary folder, - allowing the file system operations to proceed without hindering other asynchronous tasks. - - Args: - folder: The directory path to remove. - counter: A counter used for generating temporary directory names in case of conflicts. 
- """ - folder_exists = await asyncio.to_thread(os.path.exists, folder) - - if folder_exists: - temporary_folder = ( - folder - if os.path.basename(folder).startswith(f'{self._TEMPORARY_DIR_NAME}_') - else os.path.normpath(os.path.join(folder, '..', f'{self._TEMPORARY_DIR_NAME}_{counter}')) - ) - - try: - # Rename the old folder to the new one to allow background deletions - await asyncio.to_thread(os.rename, folder, temporary_folder) - except Exception: - # Folder exists already, try again with an incremented counter - return await self._batch_remove_files(folder, counter + 1) - - await asyncio.to_thread(shutil.rmtree, temporary_folder, ignore_errors=True) - return None - - def _get_default_storage_id(self, storage_client_class: type[TResourceClient]) -> str: - """Get the default storage ID based on the storage class.""" - if issubclass(storage_client_class, DatasetClient): - return self._default_dataset_id - - if issubclass(storage_client_class, KeyValueStoreClient): - return self._default_key_value_store_id - - if issubclass(storage_client_class, RequestQueueClient): - return self._default_request_queue_id - - raise ValueError(f'Invalid storage class: {storage_client_class.__name__}') - - def _get_storage_dir(self, storage_client_class: type[TResourceClient]) -> str: - """Get the storage directory based on the storage class.""" - if issubclass(storage_client_class, DatasetClient): - return self.datasets_directory - - if issubclass(storage_client_class, KeyValueStoreClient): - return self.key_value_stores_directory - - if issubclass(storage_client_class, RequestQueueClient): - return self.request_queues_directory - - raise ValueError(f'Invalid storage class: {storage_client_class.__name__}') diff --git a/src/crawlee/storage_clients/_memory/_request_queue_client.py b/src/crawlee/storage_clients/_memory/_request_queue_client.py index 477d53df07..ad166e20bd 100644 --- a/src/crawlee/storage_clients/_memory/_request_queue_client.py +++ b/src/crawlee/storage_clients/_memory/_request_queue_client.py @@ -1,558 +1,337 @@ from __future__ import annotations -import asyncio -import os -import shutil +from collections import deque +from contextlib import suppress from datetime import datetime, timezone -from decimal import Decimal from logging import getLogger from typing import TYPE_CHECKING -from sortedcollections import ValueSortedDict from typing_extensions import override -from crawlee._types import StorageTypes +from crawlee import Request from crawlee._utils.crypto import crypto_random_object_id -from crawlee._utils.data_processing import raise_on_duplicate_storage, raise_on_non_existing_storage -from crawlee._utils.file import force_remove, force_rename, json_dumps -from crawlee._utils.requests import unique_key_to_request_id -from crawlee.storage_clients._base import RequestQueueClient as BaseRequestQueueClient -from crawlee.storage_clients.models import ( - BatchRequestsOperationResponse, - InternalRequest, - ProcessedRequest, - ProlongRequestLockResponse, - RequestQueueHead, - RequestQueueHeadWithLocks, - RequestQueueMetadata, - UnprocessedRequest, -) - -from ._creation_management import find_or_create_client_by_id_or_name_inner, persist_metadata_if_enabled +from crawlee.storage_clients._base import RequestQueueClient +from crawlee.storage_clients.models import AddRequestsResponse, ProcessedRequest, RequestQueueMetadata if TYPE_CHECKING: from collections.abc import Sequence - from sortedcontainers import SortedDict - - from crawlee import Request +logger = getLogger(__name__) - from 
._memory_storage_client import MemoryStorageClient -logger = getLogger(__name__) +class MemoryRequestQueueClient(RequestQueueClient): + """Memory implementation of the request queue client. + No data is persisted between process runs, which means all requests are lost when the program terminates. + This implementation is primarily useful for testing, development, and short-lived crawler runs where + persistence is not required. -class RequestQueueClient(BaseRequestQueueClient): - """Subclient for manipulating a single request queue.""" + This client provides fast access to request data but is limited by available memory and does not support + data sharing across different processes. + """ def __init__( self, *, - memory_storage_client: MemoryStorageClient, - id: str | None = None, - name: str | None = None, - created_at: datetime | None = None, - accessed_at: datetime | None = None, - modified_at: datetime | None = None, - handled_request_count: int = 0, - pending_request_count: int = 0, + metadata: RequestQueueMetadata, ) -> None: - self._memory_storage_client = memory_storage_client - self.id = id or crypto_random_object_id() - self.name = name - self._created_at = created_at or datetime.now(timezone.utc) - self._accessed_at = accessed_at or datetime.now(timezone.utc) - self._modified_at = modified_at or datetime.now(timezone.utc) - self.handled_request_count = handled_request_count - self.pending_request_count = pending_request_count - - self.requests: SortedDict[str, InternalRequest] = ValueSortedDict( - lambda request: request.order_no or -float('inf') - ) - self.file_operation_lock = asyncio.Lock() - self._last_used_timestamp = Decimal(0) - - self._in_progress = set[str]() - - @property - def resource_info(self) -> RequestQueueMetadata: - """Get the resource info for the request queue client.""" - return RequestQueueMetadata( - id=self.id, - name=self.name, - accessed_at=self._accessed_at, - created_at=self._created_at, - modified_at=self._modified_at, - had_multiple_clients=False, - handled_request_count=self.handled_request_count, - pending_request_count=self.pending_request_count, - stats={}, - total_request_count=len(self.requests), - user_id='1', - resource_directory=self.resource_directory, - ) - - @property - def resource_directory(self) -> str: - """Get the resource directory for the client.""" - return os.path.join(self._memory_storage_client.request_queues_directory, self.name or self.id) + """Initialize a new instance. - @override - async def get(self) -> RequestQueueMetadata | None: - found = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) + Preferably use the `MemoryRequestQueueClient.open` class method to create a new instance. 
+ """ + self._metadata = metadata - if found: - async with found.file_operation_lock: - await found.update_timestamps(has_been_modified=False) - return found.resource_info + self._pending_requests = deque[Request]() + """Pending requests are those that have been added to the queue but not yet fetched for processing.""" - return None + self._handled_requests = dict[str, Request]() + """Handled requests are those that have been processed and marked as handled.""" - @override - async def update(self, *, name: str | None = None) -> RequestQueueMetadata: - # Check by id - existing_queue_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) + self._in_progress_requests = dict[str, Request]() + """In-progress requests are those that have been fetched but not yet marked as handled or reclaimed.""" - if existing_queue_by_id is None: - raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self.id) - - # Skip if no changes - if name is None: - return existing_queue_by_id.resource_info - - async with existing_queue_by_id.file_operation_lock: - # Check that name is not in use already - existing_queue_by_name = next( - ( - queue - for queue in self._memory_storage_client.request_queues_handled - if queue.name and queue.name.lower() == name.lower() - ), - None, - ) + self._requests_by_id = dict[str, Request]() + """ID -> Request mapping for fast lookup by request ID.""" - if existing_queue_by_name is not None: - raise_on_duplicate_storage(StorageTypes.REQUEST_QUEUE, 'name', name) + self._requests_by_unique_key = dict[str, Request]() + """Unique key -> Request mapping for fast lookup by unique key.""" - previous_dir = existing_queue_by_id.resource_directory - existing_queue_by_id.name = name + @override + async def get_metadata(self) -> RequestQueueMetadata: + return self._metadata - await force_rename(previous_dir, existing_queue_by_id.resource_directory) + @classmethod + async def open( + cls, + *, + id: str | None, + name: str | None, + ) -> MemoryRequestQueueClient: + """Open or create a new memory request queue client. - # Update timestamps - await existing_queue_by_id.update_timestamps(has_been_modified=True) + This method creates a new in-memory request queue instance. Unlike persistent storage implementations, + memory queues don't check for existing queues with the same name or ID since all data exists only + in memory and is lost when the process terminates. - return existing_queue_by_id.resource_info + Args: + id: The ID of the request queue. If not provided, a random ID will be generated. + name: The name of the request queue. If not provided, the queue will be unnamed. - @override - async def delete(self) -> None: - queue = next( - (queue for queue in self._memory_storage_client.request_queues_handled if queue.id == self.id), - None, + Returns: + An instance for the opened or created storage client. 
+ """ + # Otherwise create a new queue + queue_id = id or crypto_random_object_id() + now = datetime.now(timezone.utc) + + metadata = RequestQueueMetadata( + id=queue_id, + name=name, + created_at=now, + accessed_at=now, + modified_at=now, + had_multiple_clients=False, + handled_request_count=0, + pending_request_count=0, + stats={}, + total_request_count=0, ) - if queue is not None: - async with queue.file_operation_lock: - self._memory_storage_client.request_queues_handled.remove(queue) - queue.pending_request_count = 0 - queue.handled_request_count = 0 - queue.requests.clear() - - if os.path.exists(queue.resource_directory): - await asyncio.to_thread(shutil.rmtree, queue.resource_directory) + return cls(metadata=metadata) @override - async def list_head(self, *, limit: int | None = None, skip_in_progress: bool = False) -> RequestQueueHead: - existing_queue_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, + async def drop(self) -> None: + self._pending_requests.clear() + self._handled_requests.clear() + self._requests_by_id.clear() + self._requests_by_unique_key.clear() + self._in_progress_requests.clear() + + await self._update_metadata( + update_modified_at=True, + update_accessed_at=True, + new_handled_request_count=0, + new_pending_request_count=0, + new_total_request_count=0, ) - if existing_queue_by_id is None: - raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self.id) - - async with existing_queue_by_id.file_operation_lock: - await existing_queue_by_id.update_timestamps(has_been_modified=False) - - requests: list[Request] = [] - - # Iterate all requests in the queue which have sorted key larger than infinity, which means - # `order_no` is not `None`. This will iterate them in order of `order_no`. 
- for request_key in existing_queue_by_id.requests.irange_key( # type: ignore[attr-defined] # irange_key is a valid SortedDict method but not recognized by mypy - min_key=-float('inf'), inclusive=(False, True) - ): - if len(requests) == limit: - break - - if skip_in_progress and request_key in existing_queue_by_id._in_progress: # noqa: SLF001 - continue - internal_request = existing_queue_by_id.requests.get(request_key) - - # Check that the request still exists and was not handled, - # in case something deleted it or marked it as handled concurrenctly - if internal_request and not internal_request.handled_at: - requests.append(internal_request.to_request()) - - return RequestQueueHead( - limit=limit, - had_multiple_clients=False, - queue_modified_at=existing_queue_by_id._modified_at, # noqa: SLF001 - items=requests, - ) - @override - async def list_and_lock_head(self, *, lock_secs: int, limit: int | None = None) -> RequestQueueHeadWithLocks: - existing_queue_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_queue_by_id is None: - raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self.id) - - result = await self.list_head(limit=limit, skip_in_progress=True) - - for item in result.items: - existing_queue_by_id._in_progress.add(item.id) # noqa: SLF001 - - return RequestQueueHeadWithLocks( - queue_has_locked_requests=len(existing_queue_by_id._in_progress) > 0, # noqa: SLF001 - lock_secs=lock_secs, - limit=result.limit, - had_multiple_clients=result.had_multiple_clients, - queue_modified_at=result.queue_modified_at, - items=result.items, + async def purge(self) -> None: + self._pending_requests.clear() + self._handled_requests.clear() + self._requests_by_id.clear() + self._requests_by_unique_key.clear() + self._in_progress_requests.clear() + + await self._update_metadata( + update_modified_at=True, + update_accessed_at=True, + new_pending_request_count=0, ) @override - async def add_request( + async def add_batch_of_requests( self, - request: Request, + requests: Sequence[Request], *, forefront: bool = False, - ) -> ProcessedRequest: - existing_queue_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_queue_by_id is None: - raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self.id) + ) -> AddRequestsResponse: + processed_requests = [] + for request in requests: + # Check if the request is already in the queue by unique_key. + existing_request = self._requests_by_unique_key.get(request.unique_key) - internal_request = await self._create_internal_request(request, forefront) + was_already_present = existing_request is not None + was_already_handled = was_already_present and existing_request and existing_request.handled_at is not None - async with existing_queue_by_id.file_operation_lock: - existing_internal_request_with_id = existing_queue_by_id.requests.get(internal_request.id) + # If the request is already in the queue and handled, don't add it again. + if was_already_handled: + processed_requests.append( + ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=True, + ) + ) + continue + + # If the request is already in the queue but not handled, update it. 
+ if was_already_present and existing_request: + # Update the existing request with any new data and + # remove old request from pending queue if it's there. + with suppress(ValueError): + self._pending_requests.remove(existing_request) + + # Update indexes. + self._requests_by_id[request.id] = request + self._requests_by_unique_key[request.unique_key] = request + + # Add updated request back to queue. + if forefront: + self._pending_requests.appendleft(request) + else: + self._pending_requests.append(request) + # Add the new request to the queue. + else: + if forefront: + self._pending_requests.appendleft(request) + else: + self._pending_requests.append(request) - # We already have the request present, so we return information about it - if existing_internal_request_with_id is not None: - await existing_queue_by_id.update_timestamps(has_been_modified=False) + # Update indexes. + self._requests_by_id[request.id] = request + self._requests_by_unique_key[request.unique_key] = request - return ProcessedRequest( - id=internal_request.id, - unique_key=internal_request.unique_key, - was_already_present=True, - was_already_handled=existing_internal_request_with_id.handled_at is not None, + await self._update_metadata( + new_total_request_count=self._metadata.total_request_count + 1, + new_pending_request_count=self._metadata.pending_request_count + 1, ) - existing_queue_by_id.requests[internal_request.id] = internal_request - if internal_request.handled_at: - existing_queue_by_id.handled_request_count += 1 - else: - existing_queue_by_id.pending_request_count += 1 - await existing_queue_by_id.update_timestamps(has_been_modified=True) - await self._persist_single_request_to_storage( - request=internal_request, - entity_directory=existing_queue_by_id.resource_directory, - persist_storage=self._memory_storage_client.persist_storage, + processed_requests.append( + ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=was_already_present, + was_already_handled=False, + ) ) - # We return was_already_handled=False even though the request may have been added as handled, - # because that's how API behaves. 
- return ProcessedRequest( - id=internal_request.id, - unique_key=internal_request.unique_key, - was_already_present=False, - was_already_handled=False, - ) + await self._update_metadata(update_accessed_at=True, update_modified_at=True) - @override - async def get_request(self, request_id: str) -> Request | None: - existing_queue_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, + return AddRequestsResponse( + processed_requests=processed_requests, + unprocessed_requests=[], ) - if existing_queue_by_id is None: - raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self.id) - - async with existing_queue_by_id.file_operation_lock: - await existing_queue_by_id.update_timestamps(has_been_modified=False) - - internal_request = existing_queue_by_id.requests.get(request_id) - return internal_request.to_request() if internal_request else None - @override - async def update_request( - self, - request: Request, - *, - forefront: bool = False, - ) -> ProcessedRequest: - existing_queue_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) - - if existing_queue_by_id is None: - raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self.id) + async def fetch_next_request(self) -> Request | None: + while self._pending_requests: + request = self._pending_requests.popleft() - internal_request = await self._create_internal_request(request, forefront) + # Skip if already handled (shouldn't happen, but safety check). + if request.was_already_handled: + continue - # First we need to check the existing request to be able to return information about its handled state. - existing_internal_request = existing_queue_by_id.requests.get(internal_request.id) + # Skip if already in progress (shouldn't happen, but safety check). + if request.id in self._in_progress_requests: + self._pending_requests.appendleft(request) + break - # Undefined means that the request is not present in the queue. - # We need to insert it, to behave the same as API. - if existing_internal_request is None: - return await self.add_request(request, forefront=forefront) + # Mark as in progress. + self._in_progress_requests[request.id] = request + return request - async with existing_queue_by_id.file_operation_lock: - # When updating the request, we need to make sure that - # the handled counts are updated correctly in all cases. - existing_queue_by_id.requests[internal_request.id] = internal_request + return None - pending_count_adjustment = 0 - is_request_handled_state_changing = existing_internal_request.handled_at != internal_request.handled_at + @override + async def get_request(self, request_id: str) -> Request | None: + await self._update_metadata(update_accessed_at=True) + return self._requests_by_id.get(request_id) - request_was_handled_before_update = existing_internal_request.handled_at is not None + @override + async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: + # Check if the request is in progress. + if request.id not in self._in_progress_requests: + return None - # We add 1 pending request if previous state was handled - if is_request_handled_state_changing: - pending_count_adjustment = 1 if request_was_handled_before_update else -1 + # Set handled_at timestamp if not already set. 
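+ # Requests that already carry a handled_at timestamp keep their original value.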
+ if not request.was_already_handled: + request.handled_at = datetime.now(timezone.utc) - existing_queue_by_id.pending_request_count += pending_count_adjustment - existing_queue_by_id.handled_request_count -= pending_count_adjustment - await existing_queue_by_id.update_timestamps(has_been_modified=True) - await self._persist_single_request_to_storage( - request=internal_request, - entity_directory=existing_queue_by_id.resource_directory, - persist_storage=self._memory_storage_client.persist_storage, - ) + # Move request to handled storage. + self._handled_requests[request.id] = request - if request.handled_at is not None: - existing_queue_by_id._in_progress.discard(request.id) # noqa: SLF001 + # Update indexes (keep the request in indexes for get_request to work). + self._requests_by_id[request.id] = request + self._requests_by_unique_key[request.unique_key] = request - return ProcessedRequest( - id=internal_request.id, - unique_key=internal_request.unique_key, - was_already_present=True, - was_already_handled=request_was_handled_before_update, - ) + # Remove from in-progress. + del self._in_progress_requests[request.id] - @override - async def delete_request(self, request_id: str) -> None: - existing_queue_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, + # Update metadata. + await self._update_metadata( + new_handled_request_count=self._metadata.handled_request_count + 1, + new_pending_request_count=self._metadata.pending_request_count - 1, + update_modified_at=True, ) - if existing_queue_by_id is None: - raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self.id) - - async with existing_queue_by_id.file_operation_lock: - internal_request = existing_queue_by_id.requests.get(request_id) - - if internal_request: - del existing_queue_by_id.requests[request_id] - if internal_request.handled_at: - existing_queue_by_id.handled_request_count -= 1 - else: - existing_queue_by_id.pending_request_count -= 1 - await existing_queue_by_id.update_timestamps(has_been_modified=True) - await self._delete_request_file_from_storage( - entity_directory=existing_queue_by_id.resource_directory, - request_id=request_id, - ) - - @override - async def prolong_request_lock( - self, - request_id: str, - *, - forefront: bool = False, - lock_secs: int, - ) -> ProlongRequestLockResponse: - return ProlongRequestLockResponse(lock_expires_at=datetime.now(timezone.utc)) + return ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=True, + ) @override - async def delete_request_lock( + async def reclaim_request( self, - request_id: str, + request: Request, *, forefront: bool = False, - ) -> None: - existing_queue_by_id = find_or_create_client_by_id_or_name_inner( - resource_client_class=RequestQueueClient, - memory_storage_client=self._memory_storage_client, - id=self.id, - name=self.name, - ) + ) -> ProcessedRequest | None: + # Check if the request is in progress. + if request.id not in self._in_progress_requests: + return None - if existing_queue_by_id is None: - raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self.id) + # Remove from in-progress. + del self._in_progress_requests[request.id] - existing_queue_by_id._in_progress.discard(request_id) # noqa: SLF001 + # Add request back to pending queue. 
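+ # A forefront reclaim places the request at the head of the deque so it is fetched first.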
+ if forefront: + self._pending_requests.appendleft(request) + else: + self._pending_requests.append(request) - @override - async def batch_add_requests( - self, - requests: Sequence[Request], - *, - forefront: bool = False, - ) -> BatchRequestsOperationResponse: - processed_requests = list[ProcessedRequest]() - unprocessed_requests = list[UnprocessedRequest]() + # Update metadata timestamps. + await self._update_metadata(update_modified_at=True) - for request in requests: - try: - processed_request = await self.add_request(request, forefront=forefront) - processed_requests.append( - ProcessedRequest( - id=processed_request.id, - unique_key=processed_request.unique_key, - was_already_present=processed_request.was_already_present, - was_already_handled=processed_request.was_already_handled, - ) - ) - except Exception as exc: # noqa: PERF203 - logger.warning(f'Error adding request to the queue: {exc}') - unprocessed_requests.append( - UnprocessedRequest( - unique_key=request.unique_key, - url=request.url, - method=request.method, - ) - ) - - return BatchRequestsOperationResponse( - processed_requests=processed_requests, - unprocessed_requests=unprocessed_requests, + return ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=False, ) @override - async def batch_delete_requests(self, requests: list[Request]) -> BatchRequestsOperationResponse: - raise NotImplementedError('This method is not supported in memory storage.') + async def is_empty(self) -> bool: + """Check if the queue is empty. - async def update_timestamps(self, *, has_been_modified: bool) -> None: - """Update the timestamps of the request queue.""" - self._accessed_at = datetime.now(timezone.utc) - - if has_been_modified: - self._modified_at = datetime.now(timezone.utc) + Returns: + True if the queue is empty, False otherwise. + """ + await self._update_metadata(update_accessed_at=True) - await persist_metadata_if_enabled( - data=self.resource_info.model_dump(), - entity_directory=self.resource_directory, - write_metadata=self._memory_storage_client.write_metadata, - ) + # Queue is empty if there are no pending requests and no requests in progress. + return len(self._pending_requests) == 0 and len(self._in_progress_requests) == 0 - async def _persist_single_request_to_storage( + async def _update_metadata( self, *, - request: InternalRequest, - entity_directory: str, - persist_storage: bool, + update_accessed_at: bool = False, + update_modified_at: bool = False, + new_handled_request_count: int | None = None, + new_pending_request_count: int | None = None, + new_total_request_count: int | None = None, ) -> None: - """Update or writes a single request item to the disk. - - This function writes a given request dictionary to a JSON file, named after the request's ID, - within a specified directory. The writing process is skipped if `persist_storage` is False. - Before writing, it ensures that the target directory exists, creating it if necessary. - - Args: - request: The dictionary containing the request data. - entity_directory: The directory path where the request file should be stored. - persist_storage: A boolean flag indicating whether the request should be persisted to the disk. 
- """ - # Skip writing files to the disk if the client has the option set to false - if not persist_storage: - return - - # Ensure the directory for the entity exists - await asyncio.to_thread(os.makedirs, entity_directory, exist_ok=True) - - # Write the request to the file - file_path = os.path.join(entity_directory, f'{request.id}.json') - f = await asyncio.to_thread(open, file_path, mode='w', encoding='utf-8') - try: - s = await json_dumps(request.model_dump()) - await asyncio.to_thread(f.write, s) - finally: - f.close() - - async def _delete_request_file_from_storage(self, *, request_id: str, entity_directory: str) -> None: - """Delete a specific request item from the disk. - - This function removes a file representing a request, identified by the request's ID, from a - specified directory. Before attempting to remove the file, it ensures that the target directory - exists, creating it if necessary. + """Update the request queue metadata with current information. Args: - request_id: The identifier of the request to be deleted. - entity_directory: The directory path where the request file is stored. + update_accessed_at: If True, update the `accessed_at` timestamp to the current time. + update_modified_at: If True, update the `modified_at` timestamp to the current time. + new_handled_request_count: If provided, set the handled request count to this value. + new_pending_request_count: If provided, set the pending request count to this value. + new_total_request_count: If provided, set the total request count to this value. """ - # Ensure the directory for the entity exists - await asyncio.to_thread(os.makedirs, entity_directory, exist_ok=True) - - file_path = os.path.join(entity_directory, f'{request_id}.json') - await force_remove(file_path) - - async def _create_internal_request(self, request: Request, forefront: bool | None) -> InternalRequest: - order_no = self._calculate_order_no(request, forefront) - id = unique_key_to_request_id(request.unique_key) - - if request.id is not None and request.id != id: - logger.warning( - f'The request ID does not match the ID from the unique_key (request.id={request.id}, id={id}).' 
- ) - - return InternalRequest.from_request(request=request, id=id, order_no=order_no) - - def _calculate_order_no(self, request: Request, forefront: bool | None) -> Decimal | None: - if request.handled_at is not None: - return None - - # Get the current timestamp in milliseconds - timestamp = Decimal(str(datetime.now(tz=timezone.utc).timestamp())) * Decimal(1000) - timestamp = round(timestamp, 6) - - # Make sure that this timestamp was not used yet, so that we have unique order_nos - if timestamp <= self._last_used_timestamp: - timestamp = self._last_used_timestamp + Decimal('0.000001') - - self._last_used_timestamp = timestamp - - return -timestamp if forefront else timestamp + now = datetime.now(timezone.utc) + + if update_accessed_at: + self._metadata.accessed_at = now + if update_modified_at: + self._metadata.modified_at = now + if new_handled_request_count is not None: + self._metadata.handled_request_count = new_handled_request_count + if new_pending_request_count is not None: + self._metadata.pending_request_count = new_pending_request_count + if new_total_request_count is not None: + self._metadata.total_request_count = new_total_request_count diff --git a/src/crawlee/storage_clients/_memory/_request_queue_collection_client.py b/src/crawlee/storage_clients/_memory/_request_queue_collection_client.py deleted file mode 100644 index 2f2df2be89..0000000000 --- a/src/crawlee/storage_clients/_memory/_request_queue_collection_client.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from typing_extensions import override - -from crawlee.storage_clients._base import RequestQueueCollectionClient as BaseRequestQueueCollectionClient -from crawlee.storage_clients.models import RequestQueueListPage, RequestQueueMetadata - -from ._creation_management import get_or_create_inner -from ._request_queue_client import RequestQueueClient - -if TYPE_CHECKING: - from ._memory_storage_client import MemoryStorageClient - - -class RequestQueueCollectionClient(BaseRequestQueueCollectionClient): - """Subclient for manipulating request queues.""" - - def __init__(self, *, memory_storage_client: MemoryStorageClient) -> None: - self._memory_storage_client = memory_storage_client - - @property - def _storage_client_cache(self) -> list[RequestQueueClient]: - return self._memory_storage_client.request_queues_handled - - @override - async def get_or_create( - self, - *, - name: str | None = None, - schema: dict | None = None, - id: str | None = None, - ) -> RequestQueueMetadata: - resource_client = await get_or_create_inner( - memory_storage_client=self._memory_storage_client, - storage_client_cache=self._storage_client_cache, - resource_client_class=RequestQueueClient, - name=name, - id=id, - ) - return resource_client.resource_info - - @override - async def list( - self, - *, - unnamed: bool = False, - limit: int | None = None, - offset: int | None = None, - desc: bool = False, - ) -> RequestQueueListPage: - items = [storage.resource_info for storage in self._storage_client_cache] - - return RequestQueueListPage( - total=len(items), - count=len(items), - offset=0, - limit=len(items), - desc=False, - items=sorted(items, key=lambda item: item.created_at), - ) diff --git a/src/crawlee/storage_clients/_memory/_storage_client.py b/src/crawlee/storage_clients/_memory/_storage_client.py new file mode 100644 index 0000000000..645294cad7 --- /dev/null +++ b/src/crawlee/storage_clients/_memory/_storage_client.py @@ -0,0 +1,67 @@ +from __future__ import annotations + 
+from typing_extensions import override + +from crawlee._utils.docs import docs_group +from crawlee.configuration import Configuration +from crawlee.storage_clients._base import StorageClient + +from ._dataset_client import MemoryDatasetClient +from ._key_value_store_client import MemoryKeyValueStoreClient +from ._request_queue_client import MemoryRequestQueueClient + + +@docs_group('Classes') +class MemoryStorageClient(StorageClient): + """Memory implementation of the storage client. + + This storage client provides access to datasets, key-value stores, and request queues that store all data + in memory using Python data structures (lists and dictionaries). No data is persisted between process runs, + meaning all stored data is lost when the program terminates. + + The memory implementation provides fast access to data but is limited by available memory and does not + support data sharing across different processes. All storage operations happen entirely in memory with + no disk operations. + + The memory storage client is useful for testing and development environments, or short-lived crawler + operations where persistence is not required. + """ + + @override + async def create_dataset_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> MemoryDatasetClient: + configuration = configuration or Configuration.get_global_configuration() + client = await MemoryDatasetClient.open(id=id, name=name) + await self._purge_if_needed(client, configuration) + return client + + @override + async def create_kvs_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> MemoryKeyValueStoreClient: + configuration = configuration or Configuration.get_global_configuration() + client = await MemoryKeyValueStoreClient.open(id=id, name=name) + await self._purge_if_needed(client, configuration) + return client + + @override + async def create_rq_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> MemoryRequestQueueClient: + configuration = configuration or Configuration.get_global_configuration() + client = await MemoryRequestQueueClient.open(id=id, name=name) + await self._purge_if_needed(client, configuration) + return client diff --git a/src/crawlee/storage_clients/models.py b/src/crawlee/storage_clients/models.py index f016e24730..3cb5b67b7a 100644 --- a/src/crawlee/storage_clients/models.py +++ b/src/crawlee/storage_clients/models.py @@ -1,14 +1,11 @@ from __future__ import annotations -import json from datetime import datetime -from decimal import Decimal from typing import Annotated, Any, Generic from pydantic import BaseModel, BeforeValidator, ConfigDict, Field from typing_extensions import TypeVar -from crawlee import Request from crawlee._types import HttpMethod from crawlee._utils.docs import docs_group from crawlee._utils.urls import validate_http_url @@ -26,10 +23,19 @@ class StorageMetadata(BaseModel): model_config = ConfigDict(populate_by_name=True, extra='allow') id: Annotated[str, Field(alias='id')] - name: Annotated[str | None, Field(alias='name', default='')] + """The unique identifier of the storage.""" + + name: Annotated[str | None, Field(alias='name', default=None)] + """The name of the storage.""" + accessed_at: Annotated[datetime, Field(alias='accessedAt')] + """The timestamp when the storage was last accessed.""" + created_at: Annotated[datetime, Field(alias='createdAt')] + 
"""The timestamp when the storage was created.""" + modified_at: Annotated[datetime, Field(alias='modifiedAt')] + """The timestamp when the storage was last modified.""" @docs_group('Data structures') @@ -39,6 +45,7 @@ class DatasetMetadata(StorageMetadata): model_config = ConfigDict(populate_by_name=True) item_count: Annotated[int, Field(alias='itemCount')] + """The number of items in the dataset.""" @docs_group('Data structures') @@ -47,8 +54,6 @@ class KeyValueStoreMetadata(StorageMetadata): model_config = ConfigDict(populate_by_name=True) - user_id: Annotated[str, Field(alias='userId')] - @docs_group('Data structures') class RequestQueueMetadata(StorageMetadata): @@ -57,24 +62,19 @@ class RequestQueueMetadata(StorageMetadata): model_config = ConfigDict(populate_by_name=True) had_multiple_clients: Annotated[bool, Field(alias='hadMultipleClients')] - handled_request_count: Annotated[int, Field(alias='handledRequestCount')] - pending_request_count: Annotated[int, Field(alias='pendingRequestCount')] - stats: Annotated[dict, Field(alias='stats')] - total_request_count: Annotated[int, Field(alias='totalRequestCount')] - user_id: Annotated[str, Field(alias='userId')] - resource_directory: Annotated[str, Field(alias='resourceDirectory')] + """Indicates whether the queue has been accessed by multiple clients (consumers).""" + handled_request_count: Annotated[int, Field(alias='handledRequestCount')] + """The number of requests that have been handled from the queue.""" -@docs_group('Data structures') -class KeyValueStoreRecord(BaseModel, Generic[KvsValueType]): - """Model for a key-value store record.""" + pending_request_count: Annotated[int, Field(alias='pendingRequestCount')] + """The number of requests that are still pending in the queue.""" - model_config = ConfigDict(populate_by_name=True) + stats: Annotated[dict, Field(alias='stats')] + """Statistics about the request queue, TODO?""" - key: Annotated[str, Field(alias='key')] - value: Annotated[KvsValueType, Field(alias='value')] - content_type: Annotated[str | None, Field(alias='contentType', default=None)] - filename: Annotated[str | None, Field(alias='filename', default=None)] + total_request_count: Annotated[int, Field(alias='totalRequestCount')] + """The total number of requests that have been added to the queue.""" @docs_group('Data structures') @@ -84,68 +84,34 @@ class KeyValueStoreRecordMetadata(BaseModel): model_config = ConfigDict(populate_by_name=True) key: Annotated[str, Field(alias='key')] - content_type: Annotated[str, Field(alias='contentType')] - - -@docs_group('Data structures') -class KeyValueStoreKeyInfo(BaseModel): - """Model for a key-value store key info.""" - - model_config = ConfigDict(populate_by_name=True) - - key: Annotated[str, Field(alias='key')] - size: Annotated[int, Field(alias='size')] - - -@docs_group('Data structures') -class KeyValueStoreListKeysPage(BaseModel): - """Model for listing keys in the key-value store.""" - - model_config = ConfigDict(populate_by_name=True) - - count: Annotated[int, Field(alias='count')] - limit: Annotated[int, Field(alias='limit')] - is_truncated: Annotated[bool, Field(alias='isTruncated')] - items: Annotated[list[KeyValueStoreKeyInfo], Field(alias='items', default_factory=list)] - exclusive_start_key: Annotated[str | None, Field(alias='exclusiveStartKey', default=None)] - next_exclusive_start_key: Annotated[str | None, Field(alias='nextExclusiveStartKey', default=None)] + """The key of the record. + A unique identifier for the record in the key-value store. 
+ """ -@docs_group('Data structures') -class RequestQueueHeadState(BaseModel): - """Model for the request queue head state.""" + content_type: Annotated[str, Field(alias='contentType')] + """The MIME type of the record. - model_config = ConfigDict(populate_by_name=True) + Describe the format and type of data stored in the record, following the MIME specification. + """ - was_limit_reached: Annotated[bool, Field(alias='wasLimitReached')] - prev_limit: Annotated[int, Field(alias='prevLimit')] - queue_modified_at: Annotated[datetime, Field(alias='queueModifiedAt')] - query_started_at: Annotated[datetime, Field(alias='queryStartedAt')] - had_multiple_clients: Annotated[bool, Field(alias='hadMultipleClients')] + size: Annotated[int | None, Field(alias='size', default=None)] = None + """The size of the record in bytes.""" @docs_group('Data structures') -class RequestQueueHead(BaseModel): - """Model for the request queue head.""" +class KeyValueStoreRecord(KeyValueStoreRecordMetadata, Generic[KvsValueType]): + """Model for a key-value store record.""" model_config = ConfigDict(populate_by_name=True) - limit: Annotated[int | None, Field(alias='limit', default=None)] - had_multiple_clients: Annotated[bool, Field(alias='hadMultipleClients')] - queue_modified_at: Annotated[datetime, Field(alias='queueModifiedAt')] - items: Annotated[list[Request], Field(alias='items', default_factory=list)] + value: Annotated[KvsValueType, Field(alias='value')] + """The value of the record.""" @docs_group('Data structures') -class RequestQueueHeadWithLocks(RequestQueueHead): - """Model for request queue head with locks.""" - - lock_secs: Annotated[int, Field(alias='lockSecs')] - queue_has_locked_requests: Annotated[bool | None, Field(alias='queueHasLockedRequests')] = None - - -class _ListPage(BaseModel): - """Model for a single page of storage items returned from a collection list method.""" +class DatasetItemsListPage(BaseModel): + """Model for a single page of dataset items returned from a collection list method.""" model_config = ConfigDict(populate_by_name=True) @@ -164,48 +130,10 @@ class _ListPage(BaseModel): desc: Annotated[bool, Field(default=False)] """Indicates if the returned list is in descending order.""" - -@docs_group('Data structures') -class DatasetListPage(_ListPage): - """Model for a single page of dataset items returned from a collection list method.""" - - items: Annotated[list[DatasetMetadata], Field(default_factory=list)] - """The list of dataset items returned on this page.""" - - -@docs_group('Data structures') -class KeyValueStoreListPage(_ListPage): - """Model for a single page of key-value store items returned from a collection list method.""" - - items: Annotated[list[KeyValueStoreMetadata], Field(default_factory=list)] - """The list of key-value store items returned on this page.""" - - -@docs_group('Data structures') -class RequestQueueListPage(_ListPage): - """Model for a single page of request queue items returned from a collection list method.""" - - items: Annotated[list[RequestQueueMetadata], Field(default_factory=list)] - """The list of request queue items returned on this page.""" - - -@docs_group('Data structures') -class DatasetItemsListPage(_ListPage): - """Model for a single page of dataset items returned from a collection list method.""" - items: Annotated[list[dict], Field(default_factory=list)] """The list of dataset items returned on this page.""" -@docs_group('Data structures') -class ProlongRequestLockResponse(BaseModel): - """Response to prolong request lock calls.""" 
- - model_config = ConfigDict(populate_by_name=True) - - lock_expires_at: Annotated[datetime, Field(alias='lockExpiresAt')] - - @docs_group('Data structures') class ProcessedRequest(BaseModel): """Represents a processed request.""" @@ -230,48 +158,19 @@ class UnprocessedRequest(BaseModel): @docs_group('Data structures') -class BatchRequestsOperationResponse(BaseModel): - """Response to batch request deletion calls.""" +class AddRequestsResponse(BaseModel): + """Model for a response to add requests to a queue. + + Contains detailed information about the processing results when adding multiple requests + to a queue. This includes which requests were successfully processed and which ones + encountered issues during processing. + """ model_config = ConfigDict(populate_by_name=True) processed_requests: Annotated[list[ProcessedRequest], Field(alias='processedRequests')] - unprocessed_requests: Annotated[list[UnprocessedRequest], Field(alias='unprocessedRequests')] - + """Successfully processed requests, including information about whether they were + already present in the queue and whether they had been handled previously.""" -class InternalRequest(BaseModel): - """Internal representation of a queue request with additional metadata for ordering and storage.""" - - model_config = ConfigDict(populate_by_name=True) - - id: str - - unique_key: str - - order_no: Decimal | None = None - """Order number for maintaining request sequence in queue. - Used for restoring correct request order when recovering queue from storage.""" - - handled_at: datetime | None - - request: Annotated[ - Request, - Field(alias='json_'), - BeforeValidator(lambda v: json.loads(v) if isinstance(v, str) else v), - ] - """Original Request object. The alias 'json_' is required for backward compatibility with legacy code.""" - - @classmethod - def from_request(cls, request: Request, id: str, order_no: Decimal | None) -> InternalRequest: - """Create an internal request from a `Request` object.""" - return cls( - unique_key=request.unique_key, - id=id, - handled_at=request.handled_at, - order_no=order_no, - request=request, - ) - - def to_request(self) -> Request: - """Convert the internal request back to a `Request` object.""" - return self.request + unprocessed_requests: Annotated[list[UnprocessedRequest], Field(alias='unprocessedRequests')] + """Requests that could not be processed, typically due to validation errors or other issues.""" diff --git a/src/crawlee/storages/_base.py b/src/crawlee/storages/_base.py index 08d2cbd7be..073d27f77c 100644 --- a/src/crawlee/storages/_base.py +++ b/src/crawlee/storages/_base.py @@ -6,7 +6,7 @@ if TYPE_CHECKING: from crawlee.configuration import Configuration from crawlee.storage_clients._base import StorageClient - from crawlee.storage_clients.models import StorageMetadata + from crawlee.storage_clients.models import DatasetMetadata, KeyValueStoreMetadata, RequestQueueMetadata class Storage(ABC): @@ -22,15 +22,9 @@ def id(self) -> str: def name(self) -> str | None: """Get the storage name.""" - @property - @abstractmethod - def storage_object(self) -> StorageMetadata: - """Get the full storage object.""" - - @storage_object.setter @abstractmethod - def storage_object(self, storage_object: StorageMetadata) -> None: - """Set the full storage object.""" + async def get_metadata(self) -> DatasetMetadata | KeyValueStoreMetadata | RequestQueueMetadata: + """Get the storage metadata.""" @classmethod @abstractmethod @@ -55,3 +49,11 @@ async def open( @abstractmethod async def drop(self) -> None: 
"""Drop the storage, removing it from the underlying storage client and clearing the cache.""" + + @abstractmethod + async def purge(self) -> None: + """Purge the storage, removing all items from the underlying storage client. + + This method does not remove the storage itself, e.g. don't remove the metadata, + but clears all items within it. + """ diff --git a/src/crawlee/storages/_creation_management.py b/src/crawlee/storages/_creation_management.py deleted file mode 100644 index 14d9b1719e..0000000000 --- a/src/crawlee/storages/_creation_management.py +++ /dev/null @@ -1,231 +0,0 @@ -from __future__ import annotations - -import asyncio -from typing import TYPE_CHECKING, TypeVar -from weakref import WeakKeyDictionary - -from crawlee.storage_clients import MemoryStorageClient - -from ._dataset import Dataset -from ._key_value_store import KeyValueStore -from ._request_queue import RequestQueue - -if TYPE_CHECKING: - from crawlee.configuration import Configuration - from crawlee.storage_clients._base import ResourceClient, ResourceCollectionClient, StorageClient - -TResource = TypeVar('TResource', Dataset, KeyValueStore, RequestQueue) - - -_creation_locks = WeakKeyDictionary[asyncio.AbstractEventLoop, asyncio.Lock]() -"""Locks for storage creation (we need a separate lock for every event loop so that tests work as expected).""" - -_cache_dataset_by_id: dict[str, Dataset] = {} -_cache_dataset_by_name: dict[str, Dataset] = {} -_cache_kvs_by_id: dict[str, KeyValueStore] = {} -_cache_kvs_by_name: dict[str, KeyValueStore] = {} -_cache_rq_by_id: dict[str, RequestQueue] = {} -_cache_rq_by_name: dict[str, RequestQueue] = {} - - -def _get_from_cache_by_name( - storage_class: type[TResource], - name: str, -) -> TResource | None: - """Try to restore storage from cache by name.""" - if issubclass(storage_class, Dataset): - return _cache_dataset_by_name.get(name) - if issubclass(storage_class, KeyValueStore): - return _cache_kvs_by_name.get(name) - if issubclass(storage_class, RequestQueue): - return _cache_rq_by_name.get(name) - raise ValueError(f'Unknown storage class: {storage_class.__name__}') - - -def _get_from_cache_by_id( - storage_class: type[TResource], - id: str, -) -> TResource | None: - """Try to restore storage from cache by ID.""" - if issubclass(storage_class, Dataset): - return _cache_dataset_by_id.get(id) - if issubclass(storage_class, KeyValueStore): - return _cache_kvs_by_id.get(id) - if issubclass(storage_class, RequestQueue): - return _cache_rq_by_id.get(id) - raise ValueError(f'Unknown storage: {storage_class.__name__}') - - -def _add_to_cache_by_name(name: str, storage: TResource) -> None: - """Add storage to cache by name.""" - if isinstance(storage, Dataset): - _cache_dataset_by_name[name] = storage - elif isinstance(storage, KeyValueStore): - _cache_kvs_by_name[name] = storage - elif isinstance(storage, RequestQueue): - _cache_rq_by_name[name] = storage - else: - raise TypeError(f'Unknown storage: {storage}') - - -def _add_to_cache_by_id(id: str, storage: TResource) -> None: - """Add storage to cache by ID.""" - if isinstance(storage, Dataset): - _cache_dataset_by_id[id] = storage - elif isinstance(storage, KeyValueStore): - _cache_kvs_by_id[id] = storage - elif isinstance(storage, RequestQueue): - _cache_rq_by_id[id] = storage - else: - raise TypeError(f'Unknown storage: {storage}') - - -def _rm_from_cache_by_id(storage_class: type, id: str) -> None: - """Remove a storage from cache by ID.""" - try: - if issubclass(storage_class, Dataset): - del _cache_dataset_by_id[id] - 
elif issubclass(storage_class, KeyValueStore): - del _cache_kvs_by_id[id] - elif issubclass(storage_class, RequestQueue): - del _cache_rq_by_id[id] - else: - raise TypeError(f'Unknown storage class: {storage_class.__name__}') - except KeyError as exc: - raise RuntimeError(f'Storage with provided ID was not found ({id}).') from exc - - -def _rm_from_cache_by_name(storage_class: type, name: str) -> None: - """Remove a storage from cache by name.""" - try: - if issubclass(storage_class, Dataset): - del _cache_dataset_by_name[name] - elif issubclass(storage_class, KeyValueStore): - del _cache_kvs_by_name[name] - elif issubclass(storage_class, RequestQueue): - del _cache_rq_by_name[name] - else: - raise TypeError(f'Unknown storage class: {storage_class.__name__}') - except KeyError as exc: - raise RuntimeError(f'Storage with provided name was not found ({name}).') from exc - - -def _get_default_storage_id(configuration: Configuration, storage_class: type[TResource]) -> str: - if issubclass(storage_class, Dataset): - return configuration.default_dataset_id - if issubclass(storage_class, KeyValueStore): - return configuration.default_key_value_store_id - if issubclass(storage_class, RequestQueue): - return configuration.default_request_queue_id - - raise TypeError(f'Unknown storage class: {storage_class.__name__}') - - -async def open_storage( - *, - storage_class: type[TResource], - id: str | None, - name: str | None, - configuration: Configuration, - storage_client: StorageClient, -) -> TResource: - """Open either a new storage or restore an existing one and return it.""" - # Try to restore the storage from cache by name - if name: - cached_storage = _get_from_cache_by_name(storage_class=storage_class, name=name) - if cached_storage: - return cached_storage - - default_id = _get_default_storage_id(configuration, storage_class) - - if not id and not name: - id = default_id - - # Find out if the storage is a default on memory storage - is_default_on_memory = id == default_id and isinstance(storage_client, MemoryStorageClient) - - # Try to restore storage from cache by ID - if id: - cached_storage = _get_from_cache_by_id(storage_class=storage_class, id=id) - if cached_storage: - return cached_storage - - # Purge on start if configured - if configuration.purge_on_start: - await storage_client.purge_on_start() - - # Lock and create new storage - loop = asyncio.get_running_loop() - if loop not in _creation_locks: - _creation_locks[loop] = asyncio.Lock() - - async with _creation_locks[loop]: - if id and not is_default_on_memory: - resource_client = _get_resource_client(storage_class, storage_client, id) - storage_object = await resource_client.get() - if not storage_object: - raise RuntimeError(f'{storage_class.__name__} with id "{id}" does not exist!') - - elif is_default_on_memory: - resource_collection_client = _get_resource_collection_client(storage_class, storage_client) - storage_object = await resource_collection_client.get_or_create(name=name, id=id) - - else: - resource_collection_client = _get_resource_collection_client(storage_class, storage_client) - storage_object = await resource_collection_client.get_or_create(name=name) - - storage = storage_class.from_storage_object(storage_client=storage_client, storage_object=storage_object) - - # Cache the storage by ID and name - _add_to_cache_by_id(storage.id, storage) - if storage.name is not None: - _add_to_cache_by_name(storage.name, storage) - - return storage - - -def remove_storage_from_cache( - *, - storage_class: type, - id: str | None = 
None, - name: str | None = None, -) -> None: - """Remove a storage from cache by ID or name.""" - if id: - _rm_from_cache_by_id(storage_class=storage_class, id=id) - - if name: - _rm_from_cache_by_name(storage_class=storage_class, name=name) - - -def _get_resource_client( - storage_class: type[TResource], - storage_client: StorageClient, - id: str, -) -> ResourceClient: - if issubclass(storage_class, Dataset): - return storage_client.dataset(id) - - if issubclass(storage_class, KeyValueStore): - return storage_client.key_value_store(id) - - if issubclass(storage_class, RequestQueue): - return storage_client.request_queue(id) - - raise ValueError(f'Unknown storage class label: {storage_class.__name__}') - - -def _get_resource_collection_client( - storage_class: type, - storage_client: StorageClient, -) -> ResourceCollectionClient: - if issubclass(storage_class, Dataset): - return storage_client.datasets() - - if issubclass(storage_class, KeyValueStore): - return storage_client.key_value_stores() - - if issubclass(storage_class, RequestQueue): - return storage_client.request_queues() - - raise ValueError(f'Unknown storage class: {storage_class.__name__}') diff --git a/src/crawlee/storages/_dataset.py b/src/crawlee/storages/_dataset.py index 7cb58ae817..7004e4cd2d 100644 --- a/src/crawlee/storages/_dataset.py +++ b/src/crawlee/storages/_dataset.py @@ -1,223 +1,83 @@ from __future__ import annotations -import csv -import io -import json import logging -from datetime import datetime, timezone -from typing import TYPE_CHECKING, Literal, TextIO, TypedDict, cast +from io import StringIO +from typing import TYPE_CHECKING, overload -from typing_extensions import NotRequired, Required, Unpack, override +from typing_extensions import override from crawlee import service_locator -from crawlee._utils.byte_size import ByteSize from crawlee._utils.docs import docs_group -from crawlee._utils.file import json_dumps -from crawlee.storage_clients.models import DatasetMetadata, StorageMetadata +from crawlee._utils.file import export_csv_to_stream, export_json_to_stream from ._base import Storage from ._key_value_store import KeyValueStore if TYPE_CHECKING: - from collections.abc import AsyncIterator, Callable + from collections.abc import AsyncIterator + from typing import Any, Literal - from crawlee._types import JsonSerializable, PushDataKwargs + from typing_extensions import Unpack + + from crawlee._types import ExportDataCsvKwargs, ExportDataJsonKwargs from crawlee.configuration import Configuration from crawlee.storage_clients import StorageClient - from crawlee.storage_clients.models import DatasetItemsListPage + from crawlee.storage_clients._base import DatasetClient + from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata logger = logging.getLogger(__name__) -class GetDataKwargs(TypedDict): - """Keyword arguments for dataset's `get_data` method.""" - - offset: NotRequired[int] - """Skip the specified number of items at the start.""" - - limit: NotRequired[int] - """The maximum number of items to retrieve. Unlimited if None.""" - - clean: NotRequired[bool] - """Return only non-empty items and excludes hidden fields. Shortcut for skip_hidden and skip_empty.""" - - desc: NotRequired[bool] - """Set to True to sort results in descending order.""" - - fields: NotRequired[list[str]] - """Fields to include in each item. 
Sorts fields as specified if provided.""" - - omit: NotRequired[list[str]] - """Fields to exclude from each item.""" - - unwind: NotRequired[str] - """Unwind items by a specified array field, turning each element into a separate item.""" - - skip_empty: NotRequired[bool] - """Exclude empty items from the results if True.""" - - skip_hidden: NotRequired[bool] - """Exclude fields starting with '#' if True.""" - - flatten: NotRequired[list[str]] - """Field to be flattened in returned items.""" - - view: NotRequired[str] - """Specify the dataset view to be used.""" - - -class ExportToKwargs(TypedDict): - """Keyword arguments for dataset's `export_to` method.""" - - key: Required[str] - """The key under which to save the data.""" - - content_type: NotRequired[Literal['json', 'csv']] - """The format in which to export the data. Either 'json' or 'csv'.""" - - to_key_value_store_id: NotRequired[str] - """ID of the key-value store to save the exported file.""" - - to_key_value_store_name: NotRequired[str] - """Name of the key-value store to save the exported file.""" - - -class ExportDataJsonKwargs(TypedDict): - """Keyword arguments for dataset's `export_data_json` method.""" - - skipkeys: NotRequired[bool] - """If True (default: False), dict keys that are not of a basic type (str, int, float, bool, None) will be skipped - instead of raising a `TypeError`.""" - - ensure_ascii: NotRequired[bool] - """Determines if non-ASCII characters should be escaped in the output JSON string.""" - - check_circular: NotRequired[bool] - """If False (default: True), skips the circular reference check for container types. A circular reference will - result in a `RecursionError` or worse if unchecked.""" - - allow_nan: NotRequired[bool] - """If False (default: True), raises a ValueError for out-of-range float values (nan, inf, -inf) to strictly comply - with the JSON specification. If True, uses their JavaScript equivalents (NaN, Infinity, -Infinity).""" - - cls: NotRequired[type[json.JSONEncoder]] - """Allows specifying a custom JSON encoder.""" - - indent: NotRequired[int] - """Specifies the number of spaces for indentation in the pretty-printed JSON output.""" - - separators: NotRequired[tuple[str, str]] - """A tuple of (item_separator, key_separator). The default is (', ', ': ') if indent is None and (',', ': ') - otherwise.""" - - default: NotRequired[Callable] - """A function called for objects that can't be serialized otherwise. It should return a JSON-encodable version - of the object or raise a `TypeError`.""" - - sort_keys: NotRequired[bool] - """Specifies whether the output JSON object should have keys sorted alphabetically.""" - - -class ExportDataCsvKwargs(TypedDict): - """Keyword arguments for dataset's `export_data_csv` method.""" - - dialect: NotRequired[str] - """Specifies a dialect to be used in CSV parsing and writing.""" - - delimiter: NotRequired[str] - """A one-character string used to separate fields. Defaults to ','.""" - - doublequote: NotRequired[bool] - """Controls how instances of `quotechar` inside a field should be quoted. When True, the character is doubled; - when False, the `escapechar` is used as a prefix. Defaults to True.""" - - escapechar: NotRequired[str] - """A one-character string used to escape the delimiter if `quoting` is set to `QUOTE_NONE` and the `quotechar` - if `doublequote` is False. Defaults to None, disabling escaping.""" - - lineterminator: NotRequired[str] - """The string used to terminate lines produced by the writer. 
Defaults to '\\r\\n'.""" - - quotechar: NotRequired[str] - """A one-character string used to quote fields containing special characters, like the delimiter or quotechar, - or fields containing new-line characters. Defaults to '\"'.""" - - quoting: NotRequired[int] - """Controls when quotes should be generated by the writer and recognized by the reader. Can take any of - the `QUOTE_*` constants, with a default of `QUOTE_MINIMAL`.""" - - skipinitialspace: NotRequired[bool] - """When True, spaces immediately following the delimiter are ignored. Defaults to False.""" - - strict: NotRequired[bool] - """When True, raises an exception on bad CSV input. Defaults to False.""" - - @docs_group('Classes') class Dataset(Storage): - """Represents an append-only structured storage, ideal for tabular data similar to database tables. - - The `Dataset` class is designed to store structured data, where each entry (row) maintains consistent attributes - (columns) across the dataset. It operates in an append-only mode, allowing new records to be added, but not - modified or deleted. This makes it particularly useful for storing results from web crawling operations. + """Dataset is a storage for managing structured tabular data. - Data can be stored either locally or in the cloud. It depends on the setup of underlying storage client. - By default a `MemoryStorageClient` is used, but it can be changed to a different one. + The dataset class provides a high-level interface for storing and retrieving structured data + with consistent schema, similar to database tables or spreadsheets. It abstracts the underlying + storage implementation details, offering a consistent API regardless of where the data is + physically stored. - By default, data is stored using the following path structure: - ``` - {CRAWLEE_STORAGE_DIR}/datasets/{DATASET_ID}/{INDEX}.json - ``` - - `{CRAWLEE_STORAGE_DIR}`: The root directory for all storage data specified by the environment variable. - - `{DATASET_ID}`: Specifies the dataset, either "default" or a custom dataset ID. - - `{INDEX}`: Represents the zero-based index of the record within the dataset. + Dataset operates in an append-only mode, allowing new records to be added but not modified + or deleted after creation. This makes it particularly suitable for storing crawling results + and other data that should be immutable once collected. - To open a dataset, use the `open` class method by specifying an `id`, `name`, or `configuration`. If none are - provided, the default dataset for the current crawler run is used. Attempting to open a dataset by `id` that does - not exist will raise an error; however, if accessed by `name`, the dataset will be created if it doesn't already - exist. + The class provides methods for adding data, retrieving data with various filtering options, + and exporting data to different formats. You can create a dataset using the `open` class method, + specifying either a name or ID. The underlying storage implementation is determined by + the configured storage client. 
### Usage ```python from crawlee.storages import Dataset + # Open a dataset dataset = await Dataset.open(name='my_dataset') + + # Add data + await dataset.push_data({'title': 'Example Product', 'price': 99.99}) + + # Retrieve filtered data + results = await dataset.get_data(limit=10, desc=True) + + # Export data + await dataset.export_to('results.json', content_type='json') ``` """ - _MAX_PAYLOAD_SIZE = ByteSize.from_mb(9) - """Maximum size for a single payload.""" + def __init__(self, client: DatasetClient, id: str, name: str | None) -> None: + """Initialize a new instance. - _SAFETY_BUFFER_PERCENT = 0.01 / 100 # 0.01% - """Percentage buffer to reduce payload limit slightly for safety.""" + Preferably use the `Dataset.open` constructor to create a new instance. - _EFFECTIVE_LIMIT_SIZE = _MAX_PAYLOAD_SIZE - (_MAX_PAYLOAD_SIZE * _SAFETY_BUFFER_PERCENT) - """Calculated payload limit considering safety buffer.""" - - def __init__(self, id: str, name: str | None, storage_client: StorageClient) -> None: + Args: + client: An instance of a storage client. + id: The unique identifier of the storage. + name: The name of the storage, if available. + """ + self._client = client self._id = id self._name = name - datetime_now = datetime.now(timezone.utc) - self._storage_object = StorageMetadata( - id=id, name=name, accessed_at=datetime_now, created_at=datetime_now, modified_at=datetime_now - ) - - # Get resource clients from the storage client. - self._resource_client = storage_client.dataset(self._id) - self._resource_collection_client = storage_client.datasets() - - @classmethod - def from_storage_object(cls, storage_client: StorageClient, storage_object: StorageMetadata) -> Dataset: - """Initialize a new instance of Dataset from a storage metadata object.""" - dataset = Dataset( - id=storage_object.id, - name=storage_object.name, - storage_client=storage_client, - ) - - dataset.storage_object = storage_object - return dataset @property @override @@ -229,15 +89,9 @@ def id(self) -> str: def name(self) -> str | None: return self._name - @property - @override - def storage_object(self) -> StorageMetadata: - return self._storage_object - - @storage_object.setter @override - def storage_object(self, storage_object: StorageMetadata) -> None: - self._storage_object = storage_object + async def get_metadata(self) -> DatasetMetadata: + return await self._client.get_metadata() @override @classmethod @@ -249,27 +103,28 @@ async def open( configuration: Configuration | None = None, storage_client: StorageClient | None = None, ) -> Dataset: - from crawlee.storages._creation_management import open_storage + configuration = service_locator.get_configuration() if configuration is None else configuration + storage_client = service_locator.get_storage_client() if storage_client is None else storage_client - configuration = configuration or service_locator.get_configuration() - storage_client = storage_client or service_locator.get_storage_client() - - return await open_storage( - storage_class=cls, + return await service_locator.storage_instance_manager.open_storage_instance( + cls, id=id, name=name, configuration=configuration, - storage_client=storage_client, + client_opener=storage_client.create_dataset_client, ) @override async def drop(self) -> None: - from crawlee.storages._creation_management import remove_storage_from_cache + storage_instance_manager = service_locator.storage_instance_manager + storage_instance_manager.remove_from_cache(self) + await self._client.drop() - await self._resource_client.delete() 
- remove_storage_from_cache(storage_class=self.__class__, id=self._id, name=self._name) + @override + async def purge(self) -> None: + await self._client.purge() - async def push_data(self, data: JsonSerializable, **kwargs: Unpack[PushDataKwargs]) -> None: + async def push_data(self, data: list[dict[str, Any]] | dict[str, Any]) -> None: """Store an object or an array of objects to the dataset. The size of the data is limited by the receiving API and therefore `push_data()` will only @@ -279,127 +134,65 @@ async def push_data(self, data: JsonSerializable, **kwargs: Unpack[PushDataKwarg Args: data: A JSON serializable data structure to be stored in the dataset. The JSON representation of each item must be smaller than 9MB. - kwargs: Keyword arguments for the storage client method. """ - # Handle singular items - if not isinstance(data, list): - items = await self.check_and_serialize(data) - return await self._resource_client.push_items(items, **kwargs) + await self._client.push_data(data=data) - # Handle lists - payloads_generator = (await self.check_and_serialize(item, index) for index, item in enumerate(data)) - - # Invoke client in series to preserve the order of data - async for items in self._chunk_by_size(payloads_generator): - await self._resource_client.push_items(items, **kwargs) - - return None - - async def get_data(self, **kwargs: Unpack[GetDataKwargs]) -> DatasetItemsListPage: - """Retrieve dataset items based on filtering, sorting, and pagination parameters. + async def get_data( + self, + *, + offset: int = 0, + limit: int | None = 999_999_999_999, + clean: bool = False, + desc: bool = False, + fields: list[str] | None = None, + omit: list[str] | None = None, + unwind: str | None = None, + skip_empty: bool = False, + skip_hidden: bool = False, + flatten: list[str] | None = None, + view: str | None = None, + ) -> DatasetItemsListPage: + """Retrieve a paginated list of items from a dataset based on various filtering parameters. - This method allows customization of the data retrieval process from a dataset, supporting operations such as - field selection, ordering, and skipping specific records based on provided parameters. + This method provides the flexibility to filter, sort, and modify the appearance of dataset items + when listed. Each parameter modifies the result set according to its purpose. The method also + supports pagination through 'offset' and 'limit' parameters. Args: - kwargs: Keyword arguments for the storage client method. + offset: Skips the specified number of items at the start. + limit: The maximum number of items to retrieve. Unlimited if None. + clean: Return only non-empty items and excludes hidden fields. Shortcut for skip_hidden and skip_empty. + desc: Set to True to sort results in descending order. + fields: Fields to include in each item. Sorts fields as specified if provided. + omit: Fields to exclude from each item. + unwind: Unwinds items by a specified array field, turning each element into a separate item. + skip_empty: Excludes empty items from the results if True. + skip_hidden: Excludes fields starting with '#' if True. + flatten: Fields to be flattened in returned items. + view: Specifies the dataset view to be used. Returns: - List page containing filtered and paginated dataset items. + An object with filtered, sorted, and paginated dataset items plus pagination details. 
""" - return await self._resource_client.list_items(**kwargs) - - async def write_to_csv(self, destination: TextIO, **kwargs: Unpack[ExportDataCsvKwargs]) -> None: - """Export the entire dataset into an arbitrary stream. - - Args: - destination: The stream into which the dataset contents should be written. - kwargs: Additional keyword arguments for `csv.writer`. - """ - items: list[dict] = [] - limit = 1000 - offset = 0 - - while True: - list_items = await self._resource_client.list_items(limit=limit, offset=offset) - items.extend(list_items.items) - if list_items.total <= offset + list_items.count: - break - offset += list_items.count - - if items: - writer = csv.writer(destination, **kwargs) - writer.writerows([items[0].keys(), *[item.values() for item in items]]) - else: - logger.warning('Attempting to export an empty dataset - no file will be created') - - async def write_to_json(self, destination: TextIO, **kwargs: Unpack[ExportDataJsonKwargs]) -> None: - """Export the entire dataset into an arbitrary stream. - - Args: - destination: The stream into which the dataset contents should be written. - kwargs: Additional keyword arguments for `json.dump`. - """ - items: list[dict] = [] - limit = 1000 - offset = 0 - - while True: - list_items = await self._resource_client.list_items(limit=limit, offset=offset) - items.extend(list_items.items) - if list_items.total <= offset + list_items.count: - break - offset += list_items.count - - if items: - json.dump(items, destination, **kwargs) - else: - logger.warning('Attempting to export an empty dataset - no file will be created') - - async def export_to(self, **kwargs: Unpack[ExportToKwargs]) -> None: - """Export the entire dataset into a specified file stored under a key in a key-value store. - - This method consolidates all entries from a specified dataset into one file, which is then saved under a - given key in a key-value store. The format of the exported file is determined by the `content_type` parameter. - Either the dataset's ID or name should be specified, and similarly, either the target key-value store's ID or - name should be used. - - Args: - kwargs: Keyword arguments for the storage client method. 
- """ - key = cast('str', kwargs.get('key')) - content_type = kwargs.get('content_type', 'json') - to_key_value_store_id = kwargs.get('to_key_value_store_id') - to_key_value_store_name = kwargs.get('to_key_value_store_name') - - key_value_store = await KeyValueStore.open(id=to_key_value_store_id, name=to_key_value_store_name) - - output = io.StringIO() - if content_type == 'csv': - await self.write_to_csv(output) - elif content_type == 'json': - await self.write_to_json(output) - else: - raise ValueError('Unsupported content type, expecting CSV or JSON') - - if content_type == 'csv': - await key_value_store.set_value(key, output.getvalue(), 'text/csv') - - if content_type == 'json': - await key_value_store.set_value(key, output.getvalue(), 'application/json') - - async def get_info(self) -> DatasetMetadata | None: - """Get an object containing general information about the dataset.""" - metadata = await self._resource_client.get() - if isinstance(metadata, DatasetMetadata): - return metadata - return None + return await self._client.get_data( + offset=offset, + limit=limit, + clean=clean, + desc=desc, + fields=fields, + omit=omit, + unwind=unwind, + skip_empty=skip_empty, + skip_hidden=skip_hidden, + flatten=flatten, + view=view, + ) async def iterate_items( self, *, offset: int = 0, - limit: int | None = None, + limit: int | None = 999_999_999_999, clean: bool = False, desc: bool = False, fields: list[str] | None = None, @@ -407,28 +200,30 @@ async def iterate_items( unwind: str | None = None, skip_empty: bool = False, skip_hidden: bool = False, - ) -> AsyncIterator[dict]: - """Iterate over dataset items, applying filtering, sorting, and pagination. + ) -> AsyncIterator[dict[str, Any]]: + """Iterate over items in the dataset according to specified filters and sorting. - Retrieve dataset items incrementally, allowing fine-grained control over the data fetched. The function - supports various parameters to filter, sort, and limit the data returned, facilitating tailored dataset - queries. + This method allows for asynchronously iterating through dataset items while applying various filters such as + skipping empty items, hiding specific fields, and sorting. It supports pagination via `offset` and `limit` + parameters, and can modify the appearance of dataset items using `fields`, `omit`, `unwind`, `skip_empty`, and + `skip_hidden` parameters. Args: - offset: Initial number of items to skip. - limit: Max number of items to return. No limit if None. - clean: Filter out empty items and hidden fields if True. - desc: Return items in reverse order if True. - fields: Specific fields to include in each item. - omit: Fields to omit from each item. - unwind: Field name to unwind items by. - skip_empty: Omits empty items if True. + offset: Skips the specified number of items at the start. + limit: The maximum number of items to retrieve. Unlimited if None. + clean: Return only non-empty items and excludes hidden fields. Shortcut for skip_hidden and skip_empty. + desc: Set to True to sort results in descending order. + fields: Fields to include in each item. Sorts fields as specified if provided. + omit: Fields to exclude from each item. + unwind: Unwinds items by a specified array field, turning each element into a separate item. + skip_empty: Excludes empty items from the results if True. skip_hidden: Excludes fields starting with '#' if True. Yields: - Each item from the dataset as a dictionary. 
+ An asynchronous iterator of dictionary objects, each representing a dataset item after applying + the specified filters and transformations. """ - async for item in self._resource_client.iterate_items( + async for item in self._client.iterate_items( offset=offset, limit=limit, clean=clean, @@ -441,59 +236,121 @@ async def iterate_items( ): yield item - @classmethod - async def check_and_serialize(cls, item: JsonSerializable, index: int | None = None) -> str: - """Serialize a given item to JSON, checks its serializability and size against a limit. + async def list_items( + self, + *, + offset: int = 0, + limit: int | None = 999_999_999_999, + clean: bool = False, + desc: bool = False, + fields: list[str] | None = None, + omit: list[str] | None = None, + unwind: str | None = None, + skip_empty: bool = False, + skip_hidden: bool = False, + ) -> list[dict[str, Any]]: + """Retrieve a list of all items from the dataset according to specified filters and sorting. + + This method collects all dataset items into a list while applying various filters such as + skipping empty items, hiding specific fields, and sorting. It supports pagination via `offset` and `limit` + parameters, and can modify the appearance of dataset items using `fields`, `omit`, `unwind`, `skip_empty`, and + `skip_hidden` parameters. Args: - item: The item to serialize. - index: Index of the item, used for error context. + offset: Skips the specified number of items at the start. + limit: The maximum number of items to retrieve. Unlimited if None. + clean: Return only non-empty items and excludes hidden fields. Shortcut for skip_hidden and skip_empty. + desc: Set to True to sort results in descending order. + fields: Fields to include in each item. Sorts fields as specified if provided. + omit: Fields to exclude from each item. + unwind: Unwinds items by a specified array field, turning each element into a separate item. + skip_empty: Excludes empty items from the results if True. + skip_hidden: Excludes fields starting with '#' if True. Returns: - Serialized JSON string. - - Raises: - ValueError: If item is not JSON serializable or exceeds size limit. + A list of dictionary objects, each representing a dataset item after applying + the specified filters and transformations. """ - s = ' ' if index is None else f' at index {index} ' - - try: - payload = await json_dumps(item) - except Exception as exc: - raise ValueError(f'Data item{s}is not serializable to JSON.') from exc - - payload_size = ByteSize(len(payload.encode('utf-8'))) - if payload_size > cls._EFFECTIVE_LIMIT_SIZE: - raise ValueError(f'Data item{s}is too large (size: {payload_size}, limit: {cls._EFFECTIVE_LIMIT_SIZE})') - - return payload - - async def _chunk_by_size(self, items: AsyncIterator[str]) -> AsyncIterator[str]: - """Yield chunks of JSON arrays composed of input strings, respecting a size limit. + return [ + item + async for item in self.iterate_items( + offset=offset, + limit=limit, + clean=clean, + desc=desc, + fields=fields, + omit=omit, + unwind=unwind, + skip_empty=skip_empty, + skip_hidden=skip_hidden, + ) + ] + + @overload + async def export_to( + self, + key: str, + content_type: Literal['json'], + to_kvs_id: str | None = None, + to_kvs_name: str | None = None, + to_kvs_storage_client: StorageClient | None = None, + to_kvs_configuration: Configuration | None = None, + **kwargs: Unpack[ExportDataJsonKwargs], + ) -> None: ... 
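The `list_items` helper and the keyword-based `export_to` overloads introduced in this hunk lend themselves to a short usage sketch. This is only an illustration built against the signatures visible in the diff; the dataset name `results` and the key-value store name `exports` are made-up values, and the snippet assumes the default storage client configured by Crawlee.

```python
import asyncio

from crawlee.storages import Dataset


async def main() -> None:
    # Open (or create) a named dataset and add a couple of records.
    dataset = await Dataset.open(name='results')
    await dataset.push_data([
        {'url': 'https://crawlee.dev', 'title': 'Crawlee'},
        {'url': 'https://apify.com', 'title': 'Apify'},
    ])

    # Collect everything into a plain list using the new list_items helper.
    items = await dataset.list_items()
    print(f'Collected {len(items)} items')

    # Export as CSV under a key in the default key-value store,
    # or as JSON into a named key-value store.
    await dataset.export_to('results.csv', content_type='csv')
    await dataset.export_to('results.json', content_type='json', to_kvs_name='exports')


if __name__ == '__main__':
    asyncio.run(main())
```

Compared with the removed `ExportToKwargs`-based call, the key and content type are now passed directly as parameters, which makes the export destination explicit at the call site.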
+ + @overload + async def export_to( + self, + key: str, + content_type: Literal['csv'], + to_kvs_id: str | None = None, + to_kvs_name: str | None = None, + to_kvs_storage_client: StorageClient | None = None, + to_kvs_configuration: Configuration | None = None, + **kwargs: Unpack[ExportDataCsvKwargs], + ) -> None: ... + + async def export_to( + self, + key: str, + content_type: Literal['json', 'csv'] = 'json', + to_kvs_id: str | None = None, + to_kvs_name: str | None = None, + to_kvs_storage_client: StorageClient | None = None, + to_kvs_configuration: Configuration | None = None, + **kwargs: Any, + ) -> None: + """Export the entire dataset into a specified file stored under a key in a key-value store. - Groups an iterable of JSON string payloads into larger JSON arrays, ensuring the total size - of each array does not exceed `EFFECTIVE_LIMIT_SIZE`. Each output is a JSON array string that - contains as many payloads as possible without breaching the size threshold, maintaining the - order of the original payloads. Assumes individual items are below the size limit. + This method consolidates all entries from a specified dataset into one file, which is then saved under a + given key in a key-value store. The format of the exported file is determined by the `content_type` parameter. + Either the dataset's ID or name should be specified, and similarly, either the target key-value store's ID or + name should be used. Args: - items: Iterable of JSON string payloads. - - Yields: - Strings representing JSON arrays of payloads, each staying within the size limit. + key: The key under which to save the data in the key-value store. + content_type: The format in which to export the data. + to_kvs_id: ID of the key-value store to save the exported file. + Specify only one of ID or name. + to_kvs_name: Name of the key-value store to save the exported file. + Specify only one of ID or name. + to_kvs_storage_client: Storage client to use for the key-value store. + to_kvs_configuration: Configuration for the key-value store. + kwargs: Additional parameters for the export operation, specific to the chosen content type. """ - last_chunk_size = ByteSize(2) # Add 2 bytes for [] wrapper. - current_chunk = [] - - async for payload in items: - payload_size = ByteSize(len(payload.encode('utf-8'))) - - if last_chunk_size + payload_size <= self._EFFECTIVE_LIMIT_SIZE: - current_chunk.append(payload) - last_chunk_size += payload_size + ByteSize(1) # Add 1 byte for ',' separator. - else: - yield f'[{",".join(current_chunk)}]' - current_chunk = [payload] - last_chunk_size = payload_size + ByteSize(2) # Add 2 bytes for [] wrapper. 
+ kvs = await KeyValueStore.open( + id=to_kvs_id, + name=to_kvs_name, + configuration=to_kvs_configuration, + storage_client=to_kvs_storage_client, + ) + dst = StringIO() - yield f'[{",".join(current_chunk)}]' + if content_type == 'csv': + await export_csv_to_stream(self.iterate_items(), dst, **kwargs) + await kvs.set_value(key, dst.getvalue(), 'text/csv') + elif content_type == 'json': + await export_json_to_stream(self.iterate_items(), dst, **kwargs) + await kvs.set_value(key, dst.getvalue(), 'application/json') + else: + raise ValueError('Unsupported content type, expecting CSV or JSON') diff --git a/src/crawlee/storages/_key_value_store.py b/src/crawlee/storages/_key_value_store.py index fc077726d1..f205011bfb 100644 --- a/src/crawlee/storages/_key_value_store.py +++ b/src/crawlee/storages/_key_value_store.py @@ -2,7 +2,6 @@ import asyncio from collections.abc import AsyncIterator -from datetime import datetime, timezone from logging import getLogger from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, overload @@ -12,16 +11,20 @@ from crawlee import service_locator from crawlee._types import JsonSerializable # noqa: TC001 from crawlee._utils.docs import docs_group -from crawlee.storage_clients.models import KeyValueStoreKeyInfo, KeyValueStoreMetadata, StorageMetadata +from crawlee._utils.recoverable_state import RecoverableState +from crawlee.storage_clients.models import KeyValueStoreMetadata from ._base import Storage if TYPE_CHECKING: from collections.abc import AsyncIterator - from crawlee._utils.recoverable_state import RecoverableState from crawlee.configuration import Configuration from crawlee.storage_clients import StorageClient + from crawlee.storage_clients._base import KeyValueStoreClient + from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecordMetadata +else: + from crawlee._utils.recoverable_state import RecoverableState T = TypeVar('T') @@ -34,70 +37,59 @@ class AutosavedValue(RootModel): @docs_group('Classes') class KeyValueStore(Storage): - """Represents a key-value based storage for reading and writing data records or files. - - Each data record is identified by a unique key and associated with a specific MIME content type. This class is - commonly used in crawler runs to store inputs and outputs, typically in JSON format, but it also supports other - content types. + """Key-value store is a storage for reading and writing data records with unique key identifiers. - Data can be stored either locally or in the cloud. It depends on the setup of underlying storage client. - By default a `MemoryStorageClient` is used, but it can be changed to a different one. + The key-value store class acts as a high-level interface for storing, retrieving, and managing data records + identified by unique string keys. It abstracts away the underlying storage implementation details, + allowing you to work with the same API regardless of whether data is stored in memory, on disk, + or in the cloud. - By default, data is stored using the following path structure: - ``` - {CRAWLEE_STORAGE_DIR}/key_value_stores/{STORE_ID}/{KEY}.{EXT} - ``` - - `{CRAWLEE_STORAGE_DIR}`: The root directory for all storage data specified by the environment variable. - - `{STORE_ID}`: The identifier for the key-value store, either "default" or as specified by - `CRAWLEE_DEFAULT_KEY_VALUE_STORE_ID`. - - `{KEY}`: The unique key for the record. - - `{EXT}`: The file extension corresponding to the MIME type of the content. 
+ Each data record is associated with a specific MIME content type, allowing storage of various + data formats such as JSON, text, images, HTML snapshots or any binary data. This class is + commonly used to store inputs, outputs, and other artifacts of crawler operations. - To open a key-value store, use the `open` class method, providing an `id`, `name`, or optional `configuration`. - If none are specified, the default store for the current crawler run is used. Attempting to open a store by `id` - that does not exist will raise an error; however, if accessed by `name`, the store will be created if it does not - already exist. + You can instantiate a key-value store using the `open` class method, which will create a store + with the specified name or id. The underlying storage implementation is determined by the configured + storage client. ### Usage ```python from crawlee.storages import KeyValueStore - kvs = await KeyValueStore.open(name='my_kvs') + # Open a named key-value store + kvs = await KeyValueStore.open(name='my-store') + + # Store and retrieve data + await kvs.set_value('product-1234.json', [{'name': 'Smartphone', 'price': 799.99}]) + product = await kvs.get_value('product-1234') ``` """ - # Cache for recoverable (auto-saved) values _autosaved_values: ClassVar[ dict[ str, dict[str, RecoverableState[AutosavedValue]], ] ] = {} + """Cache for recoverable (auto-saved) values.""" + + def __init__(self, client: KeyValueStoreClient, id: str, name: str | None) -> None: + """Initialize a new instance. + + Preferably use the `KeyValueStore.open` constructor to create a new instance. - def __init__(self, id: str, name: str | None, storage_client: StorageClient) -> None: + Args: + client: An instance of a storage client. + id: The unique identifier of the storage. + name: The name of the storage, if available. 
+ """ + self._client = client self._id = id self._name = name - datetime_now = datetime.now(timezone.utc) - self._storage_object = StorageMetadata( - id=id, name=name, accessed_at=datetime_now, created_at=datetime_now, modified_at=datetime_now - ) - # Get resource clients from storage client - self._resource_client = storage_client.key_value_store(self._id) self._autosave_lock = asyncio.Lock() - - @classmethod - def from_storage_object(cls, storage_client: StorageClient, storage_object: StorageMetadata) -> KeyValueStore: - """Initialize a new instance of KeyValueStore from a storage metadata object.""" - key_value_store = KeyValueStore( - id=storage_object.id, - name=storage_object.name, - storage_client=storage_client, - ) - - key_value_store.storage_object = storage_object - return key_value_store + """Lock for autosaving values to prevent concurrent modifications.""" @property @override @@ -109,19 +101,9 @@ def id(self) -> str: def name(self) -> str | None: return self._name - @property - @override - def storage_object(self) -> StorageMetadata: - return self._storage_object - - @storage_object.setter @override - def storage_object(self, storage_object: StorageMetadata) -> None: - self._storage_object = storage_object - - async def get_info(self) -> KeyValueStoreMetadata | None: - """Get an object containing general information about the key value store.""" - return await self._resource_client.get() + async def get_metadata(self) -> KeyValueStoreMetadata: + return await self._client.get_metadata() @override @classmethod @@ -133,26 +115,28 @@ async def open( configuration: Configuration | None = None, storage_client: StorageClient | None = None, ) -> KeyValueStore: - from crawlee.storages._creation_management import open_storage + configuration = service_locator.get_configuration() if configuration is None else configuration + storage_client = service_locator.get_storage_client() if storage_client is None else storage_client - configuration = configuration or service_locator.get_configuration() - storage_client = storage_client or service_locator.get_storage_client() - - return await open_storage( - storage_class=cls, + return await service_locator.storage_instance_manager.open_storage_instance( + cls, id=id, name=name, configuration=configuration, - storage_client=storage_client, + client_opener=storage_client.create_kvs_client, ) @override async def drop(self) -> None: - from crawlee.storages._creation_management import remove_storage_from_cache + storage_instance_manager = service_locator.storage_instance_manager + storage_instance_manager.remove_from_cache(self) + + await self._clear_cache() # Clear cache with persistent values. + await self._client.drop() - remove_storage_from_cache(storage_class=self.__class__, id=self._id, name=self._name) - await self._clear_cache() - await self._resource_client.delete() + @override + async def purge(self) -> None: + await self._client.purge() @overload async def get_value(self, key: str) -> Any: ... @@ -173,44 +157,86 @@ async def get_value(self, key: str, default_value: T | None = None) -> T | None: Returns: The value associated with the given key. `default_value` is used in case the record does not exist. 
""" - record = await self._resource_client.get_record(key) + record = await self._client.get_value(key=key) return record.value if record else default_value - async def iterate_keys(self, exclusive_start_key: str | None = None) -> AsyncIterator[KeyValueStoreKeyInfo]: + async def set_value( + self, + key: str, + value: Any, + content_type: str | None = None, + ) -> None: + """Set a value in the KVS. + + Args: + key: Key of the record to set. + value: Value to set. + content_type: The MIME content type string. + """ + await self._client.set_value(key=key, value=value, content_type=content_type) + + async def delete_value(self, key: str) -> None: + """Delete a value from the KVS. + + Args: + key: Key of the record to delete. + """ + await self._client.delete_value(key=key) + + async def iterate_keys( + self, + exclusive_start_key: str | None = None, + limit: int | None = None, + ) -> AsyncIterator[KeyValueStoreRecordMetadata]: """Iterate over the existing keys in the KVS. Args: exclusive_start_key: Key to start the iteration from. + limit: Maximum number of keys to return. None means no limit. Yields: Information about the key. """ - while True: - list_keys = await self._resource_client.list_keys(exclusive_start_key=exclusive_start_key) - for item in list_keys.items: - yield KeyValueStoreKeyInfo(key=item.key, size=item.size) - - if not list_keys.is_truncated: - break - exclusive_start_key = list_keys.next_exclusive_start_key + async for item in self._client.iterate_keys( + exclusive_start_key=exclusive_start_key, + limit=limit, + ): + yield item - async def set_value( + async def list_keys( self, - key: str, - value: Any, - content_type: str | None = None, - ) -> None: - """Set a value in the KVS. + exclusive_start_key: str | None = None, + limit: int = 1000, + ) -> list[KeyValueStoreRecordMetadata]: + """List all the existing keys in the KVS. + + It uses client's `iterate_keys` method to get the keys. Args: - key: Key of the record to set. - value: Value to set. If `None`, the record is deleted. - content_type: Content type of the record. + exclusive_start_key: Key to start the iteration from. + limit: Maximum number of keys to return. + + Returns: + A list of keys in the KVS. """ - if value is None: - return await self._resource_client.delete_record(key) + return [ + key + async for key in self._client.iterate_keys( + exclusive_start_key=exclusive_start_key, + limit=limit, + ) + ] + + async def record_exists(self, key: str) -> bool: + """Check if a record with the given key exists in the key-value store. + + Args: + key: Key of the record to check for existence. - return await self._resource_client.set_record(key, value, content_type) + Returns: + True if a record with the given key exists, False otherwise. + """ + return await self._client.record_exists(key=key) async def get_public_url(self, key: str) -> str: """Get the public URL for the given key. @@ -221,7 +247,7 @@ async def get_public_url(self, key: str) -> str: Returns: The public URL for the given key. """ - return await self._resource_client.get_public_url(key) + return await self._client.get_public_url(key=key) async def get_auto_saved_value( self, @@ -237,12 +263,10 @@ async def get_auto_saved_value( Returns: Return the value of the key. 
""" - from crawlee._utils.recoverable_state import RecoverableState - default_value = {} if default_value is None else default_value async with self._autosave_lock: - cache = self._autosaved_values.setdefault(self._id, {}) + cache = self._autosaved_values.setdefault(self.id, {}) if key in cache: return cache[key].current_value.root @@ -250,7 +274,7 @@ async def get_auto_saved_value( cache[key] = recoverable_state = RecoverableState( default_state=AutosavedValue(default_value), persistence_enabled=True, - persist_state_kvs_id=self._id, + persist_state_kvs_id=self.id, persist_state_key=key, logger=logger, ) @@ -259,17 +283,17 @@ async def get_auto_saved_value( return recoverable_state.current_value.root - async def _clear_cache(self) -> None: - """Clear cache with autosaved values.""" + async def persist_autosaved_values(self) -> None: + """Force autosaved values to be saved without waiting for an event in Event Manager.""" if self.id in self._autosaved_values: cache = self._autosaved_values[self.id] for value in cache.values(): - await value.teardown() - cache.clear() + await value.persist_state() - async def persist_autosaved_values(self) -> None: - """Force autosaved values to be saved without waiting for an event in Event Manager.""" + async def _clear_cache(self) -> None: + """Clear cache with autosaved values.""" if self.id in self._autosaved_values: cache = self._autosaved_values[self.id] for value in cache.values(): - await value.persist_state() + await value.teardown() + cache.clear() diff --git a/src/crawlee/storages/_request_queue.py b/src/crawlee/storages/_request_queue.py index b3274ccc81..c1b0227bdf 100644 --- a/src/crawlee/storages/_request_queue.py +++ b/src/crawlee/storages/_request_queue.py @@ -1,23 +1,16 @@ from __future__ import annotations import asyncio -from collections import deque -from contextlib import suppress -from datetime import datetime, timedelta, timezone +from datetime import timedelta from logging import getLogger -from typing import TYPE_CHECKING, Any, TypedDict, TypeVar +from typing import TYPE_CHECKING, TypeVar -from cachetools import LRUCache from typing_extensions import override -from crawlee import service_locator -from crawlee._utils.crypto import crypto_random_object_id +from crawlee import Request, service_locator from crawlee._utils.docs import docs_group -from crawlee._utils.requests import unique_key_to_request_id from crawlee._utils.wait import wait_for_all_tasks_for_finish -from crawlee.events import Event from crawlee.request_loaders import RequestManager -from crawlee.storage_clients.models import ProcessedRequest, RequestQueueMetadata, StorageMetadata from ._base import Storage @@ -27,111 +20,72 @@ from crawlee import Request from crawlee.configuration import Configuration from crawlee.storage_clients import StorageClient + from crawlee.storage_clients._base import RequestQueueClient + from crawlee.storage_clients.models import ProcessedRequest, RequestQueueMetadata logger = getLogger(__name__) T = TypeVar('T') -class CachedRequest(TypedDict): - id: str - was_already_handled: bool - hydrated: Request | None - lock_expires_at: datetime | None - forefront: bool - - @docs_group('Classes') class RequestQueue(Storage, RequestManager): - """Represents a queue storage for managing HTTP requests in web crawling operations. + """Request queue is a storage for managing HTTP requests. - The `RequestQueue` class handles a queue of HTTP requests, each identified by a unique URL, to facilitate structured - web crawling. 
It supports both breadth-first and depth-first crawling strategies, allowing for recursive crawling - starting from an initial set of URLs. Each URL in the queue is uniquely identified by a `unique_key`, which can be - customized to allow the same URL to be added multiple times under different keys. + The request queue class serves as a high-level interface for organizing and managing HTTP requests + during web crawling. It provides methods for adding, retrieving, and manipulating requests throughout + the crawling lifecycle, abstracting away the underlying storage implementation details. - Data can be stored either locally or in the cloud. It depends on the setup of underlying storage client. - By default a `MemoryStorageClient` is used, but it can be changed to a different one. + Request queue maintains the state of each URL to be crawled, tracking whether it has been processed, + is currently being handled, or is waiting in the queue. Each URL in the queue is uniquely identified + by a `unique_key` property, which prevents duplicate processing unless explicitly configured otherwise. - By default, data is stored using the following path structure: - ``` - {CRAWLEE_STORAGE_DIR}/request_queues/{QUEUE_ID}/{REQUEST_ID}.json - ``` - - `{CRAWLEE_STORAGE_DIR}`: The root directory for all storage data specified by the environment variable. - - `{QUEUE_ID}`: The identifier for the request queue, either "default" or as specified. - - `{REQUEST_ID}`: The unique identifier for each request in the queue. + The class supports both breadth-first and depth-first crawling strategies through its `forefront` parameter + when adding requests. It also provides mechanisms for error handling and request reclamation when + processing fails. - The `RequestQueue` supports both creating new queues and opening existing ones by `id` or `name`. Named queues - persist indefinitely, while unnamed queues expire after 7 days unless specified otherwise. The queue supports - mutable operations, allowing URLs to be added and removed as needed. + You can open a request queue using the `open` class method, specifying either a name or ID to identify + the queue. The underlying storage implementation is determined by the configured storage client. ### Usage ```python from crawlee.storages import RequestQueue - rq = await RequestQueue.open(name='my_rq') + # Open a request queue + rq = await RequestQueue.open(name='my_queue') + + # Add a request + await rq.add_request('https://example.com') + + # Process requests + request = await rq.fetch_next_request() + if request: + try: + # Process the request + # ... + await rq.mark_request_as_handled(request) + except Exception: + await rq.reclaim_request(request) ``` """ - _MAX_CACHED_REQUESTS = 1_000_000 - """Maximum number of requests that can be cached.""" + def __init__(self, client: RequestQueueClient, id: str, name: str | None) -> None: + """Initialize a new instance. - def __init__( - self, - id: str, - name: str | None, - storage_client: StorageClient, - ) -> None: - config = service_locator.get_configuration() - event_manager = service_locator.get_event_manager() + Preferably use the `RequestQueue.open` constructor to create a new instance. + Args: + client: An instance of a storage client. + id: The unique identifier of the storage. + name: The name of the storage, if available. 
+ """ + self._client = client self._id = id self._name = name - datetime_now = datetime.now(timezone.utc) - self._storage_object = StorageMetadata( - id=id, name=name, accessed_at=datetime_now, created_at=datetime_now, modified_at=datetime_now - ) - - # Get resource clients from storage client - self._resource_client = storage_client.request_queue(self._id) - self._resource_collection_client = storage_client.request_queues() - - self._request_lock_time = timedelta(minutes=3) - self._queue_paused_for_migration = False - self._queue_has_locked_requests: bool | None = None - self._should_check_for_forefront_requests = False - - self._is_finished_log_throttle_counter = 0 - self._dequeued_request_count = 0 - - event_manager.on(event=Event.MIGRATING, listener=lambda _: setattr(self, '_queue_paused_for_migration', True)) - event_manager.on(event=Event.MIGRATING, listener=self._clear_possible_locks) - event_manager.on(event=Event.ABORTING, listener=self._clear_possible_locks) - - # Other internal attributes - self._tasks = list[asyncio.Task]() - self._client_key = crypto_random_object_id() - self._internal_timeout = config.internal_timeout or timedelta(minutes=5) - self._assumed_total_count = 0 - self._assumed_handled_count = 0 - self._queue_head = deque[str]() - self._list_head_and_lock_task: asyncio.Task | None = None - self._last_activity = datetime.now(timezone.utc) - self._requests_cache: LRUCache[str, CachedRequest] = LRUCache(maxsize=self._MAX_CACHED_REQUESTS) - - @classmethod - def from_storage_object(cls, storage_client: StorageClient, storage_object: StorageMetadata) -> RequestQueue: - """Initialize a new instance of RequestQueue from a storage metadata object.""" - request_queue = RequestQueue( - id=storage_object.id, - name=storage_object.name, - storage_client=storage_client, - ) - - request_queue.storage_object = storage_object - return request_queue + self._add_requests_tasks = list[asyncio.Task]() + """A list of tasks for adding requests to the queue.""" @property @override @@ -143,15 +97,19 @@ def id(self) -> str: def name(self) -> str | None: return self._name - @property @override - def storage_object(self) -> StorageMetadata: - return self._storage_object + async def get_metadata(self) -> RequestQueueMetadata: + return await self._client.get_metadata() - @storage_object.setter @override - def storage_object(self, storage_object: StorageMetadata) -> None: - self._storage_object = storage_object + async def get_handled_count(self) -> int: + metadata = await self._client.get_metadata() + return metadata.handled_request_count + + @override + async def get_total_count(self) -> int: + metadata = await self._client.get_metadata() + return metadata.total_request_count @override @classmethod @@ -163,29 +121,28 @@ async def open( configuration: Configuration | None = None, storage_client: StorageClient | None = None, ) -> RequestQueue: - from crawlee.storages._creation_management import open_storage + configuration = service_locator.get_configuration() if configuration is None else configuration + storage_client = service_locator.get_storage_client() if storage_client is None else storage_client - configuration = configuration or service_locator.get_configuration() - storage_client = storage_client or service_locator.get_storage_client() - - return await open_storage( - storage_class=cls, + return await service_locator.storage_instance_manager.open_storage_instance( + cls, id=id, name=name, configuration=configuration, - storage_client=storage_client, + 
client_opener=storage_client.create_rq_client, ) @override - async def drop(self, *, timeout: timedelta | None = None) -> None: - from crawlee.storages._creation_management import remove_storage_from_cache + async def drop(self) -> None: + # Remove from cache before dropping + storage_instance_manager = service_locator.storage_instance_manager + storage_instance_manager.remove_from_cache(self) - # Wait for all tasks to finish - await wait_for_all_tasks_for_finish(self._tasks, logger=logger, timeout=timeout) + await self._client.drop() - # Delete the storage from the underlying client and remove it from the cache - await self._resource_client.delete() - remove_storage_from_cache(storage_class=self.__class__, id=self._id, name=self._name) + @override + async def purge(self) -> None: + await self._client.purge() @override async def add_request( @@ -195,40 +152,15 @@ async def add_request( forefront: bool = False, ) -> ProcessedRequest: request = self._transform_request(request) - self._last_activity = datetime.now(timezone.utc) - - cache_key = unique_key_to_request_id(request.unique_key) - cached_info = self._requests_cache.get(cache_key) - - if cached_info: - request.id = cached_info['id'] - # We may assume that if request is in local cache then also the information if the request was already - # handled is there because just one client should be using one queue. - return ProcessedRequest( - id=request.id, - unique_key=request.unique_key, - was_already_present=True, - was_already_handled=cached_info['was_already_handled'], - ) - - processed_request = await self._resource_client.add_request(request, forefront=forefront) - processed_request.unique_key = request.unique_key - - self._cache_request(cache_key, processed_request, forefront=forefront) - - if not processed_request.was_already_present and forefront: - self._should_check_for_forefront_requests = True - - if request.handled_at is None and not processed_request.was_already_present: - self._assumed_total_count += 1 - - return processed_request + response = await self._client.add_batch_of_requests([request], forefront=forefront) + return response.processed_requests[0] @override - async def add_requests_batched( + async def add_requests( self, requests: Sequence[str | Request], *, + forefront: bool = False, batch_size: int = 1000, wait_time_between_batches: timedelta = timedelta(seconds=1), wait_for_all_requests_to_be_added: bool = False, @@ -240,21 +172,31 @@ async def add_requests_batched( # Wait for the first batch to be added first_batch = transformed_requests[:batch_size] if first_batch: - await self._process_batch(first_batch, base_retry_wait=wait_time_between_batches) + await self._process_batch( + first_batch, + base_retry_wait=wait_time_between_batches, + forefront=forefront, + ) async def _process_remaining_batches() -> None: for i in range(batch_size, len(transformed_requests), batch_size): batch = transformed_requests[i : i + batch_size] - await self._process_batch(batch, base_retry_wait=wait_time_between_batches) + await self._process_batch( + batch, + base_retry_wait=wait_time_between_batches, + forefront=forefront, + ) if i + batch_size < len(transformed_requests): await asyncio.sleep(wait_time_secs) # Create and start the task to process remaining batches in the background remaining_batches_task = asyncio.create_task( - _process_remaining_batches(), name='request_queue_process_remaining_batches_task' + _process_remaining_batches(), + name='request_queue_process_remaining_batches_task', ) - 
self._tasks.append(remaining_batches_task) - remaining_batches_task.add_done_callback(lambda _: self._tasks.remove(remaining_batches_task)) + + self._add_requests_tasks.append(remaining_batches_task) + remaining_batches_task.add_done_callback(lambda _: self._add_requests_tasks.remove(remaining_batches_task)) # Wait for all tasks to finish if requested if wait_for_all_requests_to_be_added: @@ -264,42 +206,6 @@ async def _process_remaining_batches() -> None: timeout=wait_for_all_requests_to_be_added_timeout, ) - async def _process_batch(self, batch: Sequence[Request], base_retry_wait: timedelta, attempt: int = 1) -> None: - max_attempts = 5 - response = await self._resource_client.batch_add_requests(batch) - - if response.unprocessed_requests: - logger.debug(f'Following requests were not processed: {response.unprocessed_requests}.') - if attempt > max_attempts: - logger.warning( - f'Following requests were not processed even after {max_attempts} attempts:\n' - f'{response.unprocessed_requests}' - ) - else: - logger.debug('Retry to add requests.') - unprocessed_requests_unique_keys = {request.unique_key for request in response.unprocessed_requests} - retry_batch = [request for request in batch if request.unique_key in unprocessed_requests_unique_keys] - await asyncio.sleep((base_retry_wait * attempt).total_seconds()) - await self._process_batch(retry_batch, base_retry_wait=base_retry_wait, attempt=attempt + 1) - - request_count = len(batch) - len(response.unprocessed_requests) - self._assumed_total_count += request_count - if request_count: - logger.debug( - f'Added {request_count} requests to the queue. Processed requests: {response.processed_requests}' - ) - - async def get_request(self, request_id: str) -> Request | None: - """Retrieve a request from the queue. - - Args: - request_id: ID of the request to retrieve. - - Returns: - The retrieved request, or `None`, if it does not exist. - """ - return await self._resource_client.get_request(request_id) - async def fetch_next_request(self) -> Request | None: """Return the next request in the queue to be processed. @@ -313,75 +219,35 @@ async def fetch_next_request(self) -> Request | None: instead. Returns: - The request or `None` if there are no more pending requests. + The next request to process, or `None` if there are no more pending requests. """ - self._last_activity = datetime.now(timezone.utc) - - await self._ensure_head_is_non_empty() - - # We are likely done at this point. - if len(self._queue_head) == 0: - return None + return await self._client.fetch_next_request() - next_request_id = self._queue_head.popleft() - request = await self._get_or_hydrate_request(next_request_id) - - # NOTE: It can happen that the queue head index is inconsistent with the main queue table. - # This can occur in two situations: + async def get_request(self, request_id: str) -> Request | None: + """Retrieve a specific request from the queue by its ID. - # 1) - # Queue head index is ahead of the main table and the request is not present in the main table yet - # (i.e. get_request() returned null). In this case, keep the request marked as in progress for a short while, - # so that is_finished() doesn't return true and _ensure_head_is_non_empty() doesn't not load the request into - # the queueHeadDict straight again. After the interval expires, fetch_next_request() will try to fetch this - # request again, until it eventually appears in the main table. 
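The renamed `add_requests` batching path above can be driven like this; a sketch only, with the queue name, URLs, and batch sizes chosen for illustration.

```python
import asyncio
from datetime import timedelta

from crawlee.storages import RequestQueue


async def main() -> None:
    rq = await RequestQueue.open(name='my-queue')

    # Enqueue a larger set of URLs in batches and wait until all of them are added.
    await rq.add_requests(
        [f'https://example.com/page/{i}' for i in range(2500)],
        batch_size=1000,
        wait_time_between_batches=timedelta(seconds=1),
        wait_for_all_requests_to_be_added=True,
    )

    # Priority URLs can be pushed to the head of the queue.
    await rq.add_requests(['https://example.com/priority'], forefront=True)


if __name__ == '__main__':
    asyncio.run(main())
```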
- if request is None: - logger.debug( - 'Cannot find a request from the beginning of queue, will be retried later', - extra={'nextRequestId': next_request_id}, - ) - return None - - # 2) - # Queue head index is behind the main table and the underlying request was already handled (by some other - # client, since we keep the track of handled requests in recently_handled dictionary). We just add the request - # to the recently_handled dictionary so that next call to _ensure_head_is_non_empty() will not put the request - # again to queue_head_dict. - if request.handled_at is not None: - logger.debug( - 'Request fetched from the beginning of queue was already handled', - extra={'nextRequestId': next_request_id}, - ) - return None + Args: + request_id: The ID of the request to retrieve. - self._dequeued_request_count += 1 - return request + Returns: + The request with the specified ID, or `None` if no such request exists. + """ + return await self._client.get_request(request_id) async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: """Mark a request as handled after successful processing. - Handled requests will never again be returned by the `RequestQueue.fetch_next_request` method. + This method should be called after a request has been successfully processed. + Once marked as handled, the request will be removed from the queue and will + not be returned in subsequent calls to `fetch_next_request` method. Args: request: The request to mark as handled. Returns: - Information about the queue operation. `None` if the given request was not in progress. + Information about the queue operation. """ - self._last_activity = datetime.now(timezone.utc) - - if request.handled_at is None: - request.handled_at = datetime.now(timezone.utc) - - processed_request = await self._resource_client.update_request(request) - processed_request.unique_key = request.unique_key - self._dequeued_request_count -= 1 - - if not processed_request.was_already_handled: - self._assumed_handled_count += 1 - - self._cache_request(unique_key_to_request_id(request.unique_key), processed_request, forefront=False) - return processed_request + return await self._client.mark_request_as_handled(request) async def reclaim_request( self, @@ -389,325 +255,83 @@ async def reclaim_request( *, forefront: bool = False, ) -> ProcessedRequest | None: - """Reclaim a failed request back to the queue. + """Reclaim a failed request back to the queue for later processing. - The request will be returned for processing later again by another call to `RequestQueue.fetch_next_request`. + If a request fails during processing, this method can be used to return it to the queue. + The request will be returned for processing again in a subsequent call + to `RequestQueue.fetch_next_request`. Args: request: The request to return to the queue. - forefront: Whether to add the request to the head or the end of the queue. + forefront: If true, the request will be added to the beginning of the queue. + Otherwise, it will be added to the end. Returns: - Information about the queue operation. `None` if the given request was not in progress. + Information about the queue operation. 
""" - self._last_activity = datetime.now(timezone.utc) - - processed_request = await self._resource_client.update_request(request, forefront=forefront) - processed_request.unique_key = request.unique_key - self._cache_request(unique_key_to_request_id(request.unique_key), processed_request, forefront=forefront) - - if forefront: - self._should_check_for_forefront_requests = True - - if processed_request: - # Try to delete the request lock if possible - try: - await self._resource_client.delete_request_lock(request.id, forefront=forefront) - except Exception as err: - logger.debug(f'Failed to delete request lock for request {request.id}', exc_info=err) - - return processed_request + return await self._client.reclaim_request(request, forefront=forefront) async def is_empty(self) -> bool: - """Check whether the queue is empty. + """Check if the request queue is empty. + + An empty queue means that there are no requests currently in the queue, either pending or being processed. + However, this does not necessarily mean that the crawling operation is finished, as there still might be + tasks that could add additional requests to the queue. Returns: - bool: `True` if the next call to `RequestQueue.fetch_next_request` would return `None`, otherwise `False`. + True if the request queue is empty, False otherwise. """ - await self._ensure_head_is_non_empty() - return len(self._queue_head) == 0 + return await self._client.is_empty() async def is_finished(self) -> bool: - """Check whether the queue is finished. + """Check if the request queue is finished. - Due to the nature of distributed storage used by the queue, the function might occasionally return a false - negative, but it will never return a false positive. + A finished queue means that all requests in the queue have been processed (the queue is empty) and there + are no more tasks that could add additional requests to the queue. This is the definitive way to check + if a crawling operation is complete. Returns: - bool: `True` if all requests were already handled and there are no more left. `False` otherwise. + True if the request queue is finished (empty and no pending add operations), False otherwise. """ - if self._tasks: - logger.debug('Background tasks are still in progress') + if self._add_requests_tasks: + logger.debug('Background add requests tasks are still in progress.') return False - if self._queue_head: - logger.debug( - 'There are still ids in the queue head that are pending processing', - extra={ - 'queue_head_ids_pending': len(self._queue_head), - }, - ) - - return False - - await self._ensure_head_is_non_empty() - - if self._queue_head: - logger.debug('Queue head still returned requests that need to be processed') - - return False - - # Could not lock any new requests - decide based on whether the queue contains requests locked by another client - if self._queue_has_locked_requests is not None: - if self._queue_has_locked_requests and self._dequeued_request_count == 0: - # The `% 25` was absolutely arbitrarily picked. It's just to not spam the logs too much. 
- if self._is_finished_log_throttle_counter % 25 == 0: - logger.info('The queue still contains requests locked by another client') - - self._is_finished_log_throttle_counter += 1 - - logger.debug( - f'Deciding if we are finished based on `queue_has_locked_requests` = {self._queue_has_locked_requests}' - ) - return not self._queue_has_locked_requests - - metadata = await self._resource_client.get() - if metadata is not None and not metadata.had_multiple_clients and not self._queue_head: - logger.debug('Queue head is empty and there are no other clients - we are finished') - + if await self.is_empty(): + logger.debug('The request queue is empty.') return True - # The following is a legacy algorithm for checking if the queue is finished. - # It is used only for request queue clients that do not provide the `queue_has_locked_requests` flag. - current_head = await self._resource_client.list_head(limit=2) - - if current_head.items: - logger.debug('The queue still contains unfinished requests or requests locked by another client') - - return len(current_head.items) == 0 - - async def get_info(self) -> RequestQueueMetadata | None: - """Get an object containing general information about the request queue.""" - return await self._resource_client.get() - - @override - async def get_handled_count(self) -> int: - return self._assumed_handled_count - - @override - async def get_total_count(self) -> int: - return self._assumed_total_count - - async def _ensure_head_is_non_empty(self) -> None: - # Stop fetching if we are paused for migration - if self._queue_paused_for_migration: - return + return False - # We want to fetch ahead of time to minimize dead time - if len(self._queue_head) > 1 and not self._should_check_for_forefront_requests: - return - - if self._list_head_and_lock_task is None: - task = asyncio.create_task(self._list_head_and_lock(), name='request_queue_list_head_and_lock_task') - - def callback(_: Any) -> None: - self._list_head_and_lock_task = None - - task.add_done_callback(callback) - self._list_head_and_lock_task = task - - await self._list_head_and_lock_task - - async def _list_head_and_lock(self) -> None: - # Make a copy so that we can clear the flag only if the whole method executes after the flag was set - # (i.e, it was not set in the middle of the execution of the method) - should_check_for_forefront_requests = self._should_check_for_forefront_requests - - limit = 25 - - response = await self._resource_client.list_and_lock_head( - limit=limit, lock_secs=int(self._request_lock_time.total_seconds()) - ) - - self._queue_has_locked_requests = response.queue_has_locked_requests - - head_id_buffer = list[str]() - forefront_head_id_buffer = list[str]() + async def _process_batch( + self, + batch: Sequence[Request], + *, + base_retry_wait: timedelta, + attempt: int = 1, + forefront: bool = False, + ) -> None: + """Process a batch of requests with automatic retry mechanism.""" + max_attempts = 5 + response = await self._client.add_batch_of_requests(batch, forefront=forefront) - for request in response.items: - # Queue head index might be behind the main table, so ensure we don't recycle requests - if not request.id or not request.unique_key: - logger.debug( - 'Skipping request from queue head, already in progress or recently handled', - extra={ - 'id': request.id, - 'unique_key': request.unique_key, - }, + if response.unprocessed_requests: + logger.debug(f'Following requests were not processed: {response.unprocessed_requests}.') + if attempt > max_attempts: + logger.warning( + 
f'Following requests were not processed even after {max_attempts} attempts:\n' + f'{response.unprocessed_requests}' ) - - # Remove the lock from the request for now, so that it can be picked up later - # This may/may not succeed, but that's fine - with suppress(Exception): - await self._resource_client.delete_request_lock(request.id) - - continue - - # If we remember that we added the request ourselves and we added it to the forefront, - # we will put it to the beginning of the local queue head to preserve the expected order. - # If we do not remember that, we will enqueue it normally. - cached_request = self._requests_cache.get(unique_key_to_request_id(request.unique_key)) - forefront = cached_request['forefront'] if cached_request else False - - if forefront: - forefront_head_id_buffer.insert(0, request.id) else: - head_id_buffer.append(request.id) - - self._cache_request( - unique_key_to_request_id(request.unique_key), - ProcessedRequest( - id=request.id, - unique_key=request.unique_key, - was_already_present=True, - was_already_handled=False, - ), - forefront=forefront, - ) - - for request_id in head_id_buffer: - self._queue_head.append(request_id) - - for request_id in forefront_head_id_buffer: - self._queue_head.appendleft(request_id) - - # If the queue head became too big, unlock the excess requests - to_unlock = list[str]() - while len(self._queue_head) > limit: - to_unlock.append(self._queue_head.pop()) - - if to_unlock: - await asyncio.gather( - *[self._resource_client.delete_request_lock(request_id) for request_id in to_unlock], - return_exceptions=True, # Just ignore the exceptions - ) - - # Unset the should_check_for_forefront_requests flag - the check is finished - if should_check_for_forefront_requests: - self._should_check_for_forefront_requests = False - - def _reset(self) -> None: - self._queue_head.clear() - self._list_head_and_lock_task = None - self._assumed_total_count = 0 - self._assumed_handled_count = 0 - self._requests_cache.clear() - self._last_activity = datetime.now(timezone.utc) - - def _cache_request(self, cache_key: str, processed_request: ProcessedRequest, *, forefront: bool) -> None: - self._requests_cache[cache_key] = { - 'id': processed_request.id, - 'was_already_handled': processed_request.was_already_handled, - 'hydrated': None, - 'lock_expires_at': None, - 'forefront': forefront, - } - - async def _get_or_hydrate_request(self, request_id: str) -> Request | None: - cached_entry = self._requests_cache.get(request_id) - - if not cached_entry: - # 2.1. Attempt to prolong the request lock to see if we still own the request - prolong_result = await self._prolong_request_lock(request_id) - - if not prolong_result: - return None - - # 2.1.1. If successful, hydrate the request and return it - hydrated_request = await self.get_request(request_id) - - # Queue head index is ahead of the main table and the request is not present in the main table yet - # (i.e. get_request() returned null). - if not hydrated_request: - # Remove the lock from the request for now, so that it can be picked up later - # This may/may not succeed, but that's fine - with suppress(Exception): - await self._resource_client.delete_request_lock(request_id) - - return None - - self._requests_cache[request_id] = { - 'id': request_id, - 'hydrated': hydrated_request, - 'was_already_handled': hydrated_request.handled_at is not None, - 'lock_expires_at': prolong_result, - 'forefront': False, - } - - return hydrated_request - - # 1.1. 
If hydrated, prolong the lock more and return it - if cached_entry['hydrated']: - # 1.1.1. If the lock expired on the hydrated requests, try to prolong. If we fail, we lost the request - # (or it was handled already) - if cached_entry['lock_expires_at'] and cached_entry['lock_expires_at'] < datetime.now(timezone.utc): - prolonged = await self._prolong_request_lock(cached_entry['id']) - - if not prolonged: - return None - - cached_entry['lock_expires_at'] = prolonged - - return cached_entry['hydrated'] - - # 1.2. If not hydrated, try to prolong the lock first (to ensure we keep it in our queue), hydrate and return it - prolonged = await self._prolong_request_lock(cached_entry['id']) - - if not prolonged: - return None - - # This might still return null if the queue head is inconsistent with the main queue table. - hydrated_request = await self.get_request(cached_entry['id']) - - cached_entry['hydrated'] = hydrated_request - - # Queue head index is ahead of the main table and the request is not present in the main table yet - # (i.e. get_request() returned null). - if not hydrated_request: - # Remove the lock from the request for now, so that it can be picked up later - # This may/may not succeed, but that's fine - with suppress(Exception): - await self._resource_client.delete_request_lock(cached_entry['id']) - - return None + logger.debug('Retry to add requests.') + unprocessed_requests_unique_keys = {request.unique_key for request in response.unprocessed_requests} + retry_batch = [request for request in batch if request.unique_key in unprocessed_requests_unique_keys] + await asyncio.sleep((base_retry_wait * attempt).total_seconds()) + await self._process_batch(retry_batch, base_retry_wait=base_retry_wait, attempt=attempt + 1) - return hydrated_request + request_count = len(batch) - len(response.unprocessed_requests) - async def _prolong_request_lock(self, request_id: str) -> datetime | None: - try: - res = await self._resource_client.prolong_request_lock( - request_id, lock_secs=int(self._request_lock_time.total_seconds()) - ) - except Exception as err: - # Most likely we do not own the lock anymore - logger.warning( - f'Failed to prolong lock for cached request {request_id}, either lost the lock ' - 'or the request was already handled\n', - exc_info=err, + if request_count: + logger.debug( + f'Added {request_count} requests to the queue. Processed requests: {response.processed_requests}' ) - return None - else: - return res.lock_expires_at - - async def _clear_possible_locks(self) -> None: - self._queue_paused_for_migration = True - request_id: str | None = None - - while True: - try: - request_id = self._queue_head.pop() - except LookupError: - break - - with suppress(Exception): - await self._resource_client.delete_request_lock(request_id) - # If this fails, we don't have the lock, or the request was never locked. 
Either way it's fine diff --git a/src/crawlee/storages/_storage_instance_manager.py b/src/crawlee/storages/_storage_instance_manager.py new file mode 100644 index 0000000000..9bd52f1219 --- /dev/null +++ b/src/crawlee/storages/_storage_instance_manager.py @@ -0,0 +1,135 @@ +from __future__ import annotations + +from collections.abc import Awaitable +from typing import TYPE_CHECKING, Callable, TypeVar, Union, cast + +from crawlee._utils.docs import docs_group +from crawlee.storage_clients._base import DatasetClient, KeyValueStoreClient, RequestQueueClient + +from ._base import Storage + +if TYPE_CHECKING: + from crawlee.configuration import Configuration + +T = TypeVar('T', bound='Storage') + +StorageClientType = Union[DatasetClient, KeyValueStoreClient, RequestQueueClient] +"""Type alias for the storage client types.""" + +ClientOpener = Callable[..., Awaitable[StorageClientType]] +"""Type alias for the client opener function.""" + + +@docs_group('Classes') +class StorageInstanceManager: + """Manager for caching and managing storage instances. + + This class centralizes the caching logic for all storage types (Dataset, KeyValueStore, RequestQueue) + and provides a unified interface for opening and managing storage instances. + """ + + def __init__(self) -> None: + self._cache_by_id = dict[type[Storage], dict[str, Storage]]() + """Cache for storage instances by ID, separated by storage type.""" + + self._cache_by_name = dict[type[Storage], dict[str, Storage]]() + """Cache for storage instances by name, separated by storage type.""" + + self._default_instances = dict[type[Storage], Storage]() + """Cache for default instances of each storage type.""" + + async def open_storage_instance( + self, + cls: type[T], + *, + id: str | None, + name: str | None, + configuration: Configuration, + client_opener: ClientOpener, + ) -> T: + """Open a storage instance with caching support. + + Args: + cls: The storage class to instantiate. + id: Storage ID. + name: Storage name. + configuration: Configuration object. + client_opener: Function to create the storage client. + + Returns: + The storage instance. + + Raises: + ValueError: If both id and name are specified. 
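A small sketch of the caching behaviour this manager provides to the storages, relying only on the `open`/`drop` wiring shown earlier in this patch; the store name is illustrative.

```python
import asyncio

from crawlee.storages import KeyValueStore


async def main() -> None:
    # Named storages are cached per storage type, so repeated opens return the same object.
    first = await KeyValueStore.open(name='cache-demo')
    second = await KeyValueStore.open(name='cache-demo')
    assert first is second

    # Dropping removes the instance from the manager's cache, so the next open creates a fresh one.
    await first.drop()
    third = await KeyValueStore.open(name='cache-demo')
    assert third is not first


if __name__ == '__main__':
    asyncio.run(main())
```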
+ """ + if id and name: + raise ValueError('Only one of "id" or "name" can be specified, not both.') + + # Check for default instance + if id is None and name is None and cls in self._default_instances: + return cast('T', self._default_instances[cls]) + + # Check cache + if id is not None: + type_cache_by_id = self._cache_by_id.get(cls, {}) + if id in type_cache_by_id: + cached_instance = type_cache_by_id[id] + if isinstance(cached_instance, cls): + return cached_instance + + if name is not None: + type_cache_by_name = self._cache_by_name.get(cls, {}) + if name in type_cache_by_name: + cached_instance = type_cache_by_name[name] + if isinstance(cached_instance, cls): + return cached_instance + + # Create new instance + client = await client_opener(id=id, name=name, configuration=configuration) + metadata = await client.get_metadata() + + instance = cls(client, metadata.id, metadata.name) # type: ignore[call-arg] + instance_name = getattr(instance, 'name', None) + + # Cache the instance + type_cache_by_id = self._cache_by_id.setdefault(cls, {}) + type_cache_by_name = self._cache_by_name.setdefault(cls, {}) + + type_cache_by_id[instance.id] = instance + if instance_name is not None: + type_cache_by_name[instance_name] = instance + + # Set as default if no id/name specified + if id is None and name is None: + self._default_instances[cls] = instance + + return instance + + def remove_from_cache(self, storage_instance: Storage) -> None: + """Remove a storage instance from the cache. + + Args: + storage_instance: The storage instance to remove. + """ + storage_type = type(storage_instance) + + # Remove from ID cache + type_cache_by_id = self._cache_by_id.get(storage_type, {}) + if storage_instance.id in type_cache_by_id: + del type_cache_by_id[storage_instance.id] + + # Remove from name cache + if storage_instance.name is not None: + type_cache_by_name = self._cache_by_name.get(storage_type, {}) + if storage_instance.name in type_cache_by_name: + del type_cache_by_name[storage_instance.name] + + # Remove from default instances + if storage_type in self._default_instances and self._default_instances[storage_type] is storage_instance: + del self._default_instances[storage_type] + + def clear_cache(self) -> None: + """Clear all cached storage instances.""" + self._cache_by_id.clear() + self._cache_by_name.clear() + self._default_instances.clear() diff --git a/tests/e2e/project_template/utils.py b/tests/e2e/project_template/utils.py index 3bc5be4ea6..685e8c45e8 100644 --- a/tests/e2e/project_template/utils.py +++ b/tests/e2e/project_template/utils.py @@ -20,23 +20,25 @@ def patch_crawlee_version_in_project( def _patch_crawlee_version_in_requirements_txt_based_project(project_path: Path, wheel_path: Path) -> None: # Get any extras - with open(project_path / 'requirements.txt') as f: + requirements_path = project_path / 'requirements.txt' + with requirements_path.open() as f: requirements = f.read() crawlee_extras = re.findall(r'crawlee(\[.*\])', requirements)[0] or '' # Modify requirements.txt to use crawlee from wheel file instead of from Pypi - with open(project_path / 'requirements.txt') as f: + with requirements_path.open() as f: modified_lines = [] for line in f: if 'crawlee' in line: modified_lines.append(f'./{wheel_path.name}{crawlee_extras}\n') else: modified_lines.append(line) - with open(project_path / 'requirements.txt', 'w') as f: + with requirements_path.open('w') as f: f.write(''.join(modified_lines)) # Patch the dockerfile to have wheel file available - with open(project_path / 
'Dockerfile') as f: + dockerfile_path = project_path / 'Dockerfile' + with dockerfile_path.open() as f: modified_lines = [] for line in f: modified_lines.append(line) @@ -49,19 +51,21 @@ def _patch_crawlee_version_in_requirements_txt_based_project(project_path: Path, f'RUN pip install ./{wheel_path.name}{crawlee_extras} --force-reinstall\n', ] ) - with open(project_path / 'Dockerfile', 'w') as f: + with dockerfile_path.open('w') as f: f.write(''.join(modified_lines)) def _patch_crawlee_version_in_pyproject_toml_based_project(project_path: Path, wheel_path: Path) -> None: """Ensure that the test is using current version of the crawlee from the source and not from Pypi.""" # Get any extras - with open(project_path / 'pyproject.toml') as f: + pyproject_path = project_path / 'pyproject.toml' + with pyproject_path.open() as f: pyproject = f.read() crawlee_extras = re.findall(r'crawlee(\[.*\])', pyproject)[0] or '' # Inject crawlee wheel file to the docker image and update project to depend on it.""" - with open(project_path / 'Dockerfile') as f: + dockerfile_path = project_path / 'Dockerfile' + with dockerfile_path.open() as f: modified_lines = [] for line in f: modified_lines.append(line) @@ -94,5 +98,5 @@ def _patch_crawlee_version_in_pyproject_toml_based_project(project_path: Path, w f'RUN {package_manager} lock\n', ] ) - with open(project_path / 'Dockerfile', 'w') as f: + with dockerfile_path.open('w') as f: f.write(''.join(modified_lines)) diff --git a/tests/unit/_autoscaling/test_autoscaled_pool.py b/tests/unit/_autoscaling/test_autoscaled_pool.py index 717b178738..b4e82fee76 100644 --- a/tests/unit/_autoscaling/test_autoscaled_pool.py +++ b/tests/unit/_autoscaling/test_autoscaled_pool.py @@ -328,6 +328,8 @@ async def run() -> None: assert done_count == 4 done_count = 0 + await asyncio.sleep(0.2) # Allow any lingering callbacks to complete + done_count = 0 # Reset again to ensure clean state await pool.run() assert done_count == 4 diff --git a/tests/unit/_utils/test_data_processing.py b/tests/unit/_utils/test_data_processing.py deleted file mode 100644 index c67335517b..0000000000 --- a/tests/unit/_utils/test_data_processing.py +++ /dev/null @@ -1,51 +0,0 @@ -from __future__ import annotations - -from enum import Enum - -import pytest - -from crawlee._types import StorageTypes -from crawlee._utils.data_processing import ( - maybe_extract_enum_member_value, - maybe_parse_body, - raise_on_duplicate_storage, - raise_on_non_existing_storage, -) - - -def test_maybe_extract_enum_member_value() -> None: - class Color(Enum): - RED = 1 - GREEN = 2 - BLUE = 3 - - assert maybe_extract_enum_member_value(Color.RED) == 1 - assert maybe_extract_enum_member_value(Color.GREEN) == 2 - assert maybe_extract_enum_member_value(Color.BLUE) == 3 - assert maybe_extract_enum_member_value(10) == 10 - assert maybe_extract_enum_member_value('test') == 'test' - assert maybe_extract_enum_member_value(None) is None - - -def test_maybe_parse_body() -> None: - json_body = b'{"key": "value"}' - xml_body = b'ToveJani' - text_body = b'Plain text content' - binary_body = b'\x00\x01\x02' - - assert maybe_parse_body(json_body, 'application/json') == {'key': 'value'} - assert maybe_parse_body(xml_body, 'application/xml') == 'ToveJani' - assert maybe_parse_body(text_body, 'text/plain') == 'Plain text content' - assert maybe_parse_body(binary_body, 'application/octet-stream') == binary_body - assert maybe_parse_body(xml_body, 'text/xml') == 'ToveJani' - assert maybe_parse_body(text_body, 'text/plain; charset=utf-8') == 'Plain text 
content' - - -def test_raise_on_duplicate_storage() -> None: - with pytest.raises(ValueError, match='Dataset with name "test" already exists.'): - raise_on_duplicate_storage(StorageTypes.DATASET, 'name', 'test') - - -def test_raise_on_non_existing_storage() -> None: - with pytest.raises(ValueError, match='Dataset with id "kckxQw6j6AtrgyA09" does not exist.'): - raise_on_non_existing_storage(StorageTypes.DATASET, 'kckxQw6j6AtrgyA09') diff --git a/tests/unit/_utils/test_file.py b/tests/unit/_utils/test_file.py index a86291b43f..c00618b600 100644 --- a/tests/unit/_utils/test_file.py +++ b/tests/unit/_utils/test_file.py @@ -1,20 +1,8 @@ from __future__ import annotations -import io from datetime import datetime, timezone -from pathlib import Path -import pytest - -from crawlee._utils.file import ( - ContentType, - determine_file_extension, - force_remove, - force_rename, - is_content_type, - is_file_or_bytes, - json_dumps, -) +from crawlee._utils.file import json_dumps async def test_json_dumps() -> None: @@ -23,127 +11,3 @@ async def test_json_dumps() -> None: assert await json_dumps('string') == '"string"' assert await json_dumps(123) == '123' assert await json_dumps(datetime(2022, 1, 1, tzinfo=timezone.utc)) == '"2022-01-01 00:00:00+00:00"' - - -def test_is_file_or_bytes() -> None: - assert is_file_or_bytes(b'bytes') is True - assert is_file_or_bytes(bytearray(b'bytearray')) is True - assert is_file_or_bytes(io.BytesIO(b'some bytes')) is True - assert is_file_or_bytes(io.StringIO('string')) is True - assert is_file_or_bytes('just a regular string') is False - assert is_file_or_bytes(12345) is False - - -@pytest.mark.parametrize( - ('content_type_enum', 'content_type', 'expected_result'), - [ - (ContentType.JSON, 'application/json', True), - (ContentType.JSON, 'application/json; charset=utf-8', True), - (ContentType.JSON, 'text/plain', False), - (ContentType.JSON, 'application/xml', False), - (ContentType.XML, 'application/xml', True), - (ContentType.XML, 'application/xhtml+xml', True), - (ContentType.XML, 'text/xml; charset=utf-8', False), - (ContentType.XML, 'application/json', False), - (ContentType.TEXT, 'text/plain', True), - (ContentType.TEXT, 'text/html; charset=utf-8', True), - (ContentType.TEXT, 'application/json', False), - (ContentType.TEXT, 'application/xml', False), - ], - ids=[ - 'json_valid_simple', - 'json_valid_charset', - 'json_invalid_text', - 'json_invalid_xml', - 'xml_valid_simple', - 'xml_valid_xhtml', - 'xml_invalid_text_charset', - 'xml_invalid_json', - 'text_valid_plain', - 'text_valid_html_charset', - 'text_invalid_json', - 'text_invalid_xml', - ], -) -def test_is_content_type(content_type_enum: ContentType, content_type: str, *, expected_result: bool) -> None: - result = is_content_type(content_type_enum, content_type) - assert expected_result == result - - -def test_is_content_type_json() -> None: - assert is_content_type(ContentType.JSON, 'application/json') is True - assert is_content_type(ContentType.JSON, 'application/json; charset=utf-8') is True - assert is_content_type(ContentType.JSON, 'text/plain') is False - assert is_content_type(ContentType.JSON, 'application/xml') is False - - -def test_is_content_type_xml() -> None: - assert is_content_type(ContentType.XML, 'application/xml') is True - assert is_content_type(ContentType.XML, 'application/xhtml+xml') is True - assert is_content_type(ContentType.XML, 'text/xml; charset=utf-8') is False - assert is_content_type(ContentType.XML, 'application/json') is False - - -def test_is_content_type_text() -> None: - 
assert is_content_type(ContentType.TEXT, 'text/plain') is True - assert is_content_type(ContentType.TEXT, 'text/html; charset=utf-8') is True - assert is_content_type(ContentType.TEXT, 'application/json') is False - assert is_content_type(ContentType.TEXT, 'application/xml') is False - - -def test_determine_file_extension() -> None: - # Can determine common types properly - assert determine_file_extension('application/json') == 'json' - assert determine_file_extension('application/xml') == 'xml' - assert determine_file_extension('text/plain') == 'txt' - - # Can handle unusual formats - assert determine_file_extension(' application/json ') == 'json' - assert determine_file_extension('APPLICATION/JSON') == 'json' - assert determine_file_extension('application/json;charset=utf-8') == 'json' - - # Return None for non-existent content types - assert determine_file_extension('clearly not a content type') is None - assert determine_file_extension('') is None - - -async def test_force_remove(tmp_path: Path) -> None: - test_file_path = Path(tmp_path, 'test.txt') - # Does not crash/raise when the file does not exist - assert test_file_path.exists() is False - await force_remove(test_file_path) - assert test_file_path.exists() is False - - # Remove the file if it exists - with open(test_file_path, 'a', encoding='utf-8'): # noqa: ASYNC230 - pass - assert test_file_path.exists() is True - await force_remove(test_file_path) - assert test_file_path.exists() is False - - -async def test_force_rename(tmp_path: Path) -> None: - src_dir = Path(tmp_path, 'src') - dst_dir = Path(tmp_path, 'dst') - src_file = Path(src_dir, 'src_dir.txt') - dst_file = Path(dst_dir, 'dst_dir.txt') - # Won't crash if source directory does not exist - assert src_dir.exists() is False - await force_rename(src_dir, dst_dir) - - # Will remove dst_dir if it exists (also covers normal case) - # Create the src_dir with a file in it - src_dir.mkdir() - with open(src_file, 'a', encoding='utf-8'): # noqa: ASYNC230 - pass - # Create the dst_dir with a file in it - dst_dir.mkdir() - with open(dst_file, 'a', encoding='utf-8'): # noqa: ASYNC230 - pass - assert src_file.exists() is True - assert dst_file.exists() is True - await force_rename(src_dir, dst_dir) - assert src_dir.exists() is False - assert dst_file.exists() is False - # src_dir.txt should exist in dst_dir - assert (dst_dir / 'src_dir.txt').exists() is True diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index b7ac06d124..0d139c9372 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -13,12 +13,10 @@ from uvicorn.config import Config from crawlee import service_locator -from crawlee.configuration import Configuration from crawlee.fingerprint_suite._browserforge_adapter import get_available_header_network from crawlee.http_clients import CurlImpersonateHttpClient, HttpxHttpClient from crawlee.proxy_configuration import ProxyInfo -from crawlee.storage_clients import MemoryStorageClient -from crawlee.storages import KeyValueStore, _creation_management +from crawlee.storages import KeyValueStore from tests.unit.server import TestServer, app, serve_in_thread if TYPE_CHECKING: @@ -63,14 +61,12 @@ def _prepare_test_env() -> None: service_locator._configuration = None service_locator._event_manager = None service_locator._storage_client = None + service_locator._storage_instance_manager = None - # Clear creation-related caches to ensure no state is carried over between tests. 
- monkeypatch.setattr(_creation_management, '_cache_dataset_by_id', {}) - monkeypatch.setattr(_creation_management, '_cache_dataset_by_name', {}) - monkeypatch.setattr(_creation_management, '_cache_kvs_by_id', {}) - monkeypatch.setattr(_creation_management, '_cache_kvs_by_name', {}) - monkeypatch.setattr(_creation_management, '_cache_rq_by_id', {}) - monkeypatch.setattr(_creation_management, '_cache_rq_by_name', {}) + # Reset the retrieval flags + service_locator._configuration_was_retrieved = False + service_locator._event_manager_was_retrieved = False + service_locator._storage_client_was_retrieved = False # Verify that the test environment was set up correctly. assert os.environ.get('CRAWLEE_STORAGE_DIR') == str(tmp_path) @@ -149,18 +145,6 @@ async def disabled_proxy(proxy_info: ProxyInfo) -> AsyncGenerator[ProxyInfo, Non yield proxy_info -@pytest.fixture -def memory_storage_client(tmp_path: Path) -> MemoryStorageClient: - """A fixture for testing the memory storage client and its resource clients.""" - config = Configuration( - persist_storage=True, - write_metadata=True, - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - ) - - return MemoryStorageClient.from_config(config) - - @pytest.fixture(scope='session') def header_network() -> dict: return get_available_header_network() diff --git a/tests/unit/crawlers/_adaptive_playwright/test_adaptive_playwright_crawler.py b/tests/unit/crawlers/_adaptive_playwright/test_adaptive_playwright_crawler.py index 3ee386324a..19ebfa9cf7 100644 --- a/tests/unit/crawlers/_adaptive_playwright/test_adaptive_playwright_crawler.py +++ b/tests/unit/crawlers/_adaptive_playwright/test_adaptive_playwright_crawler.py @@ -356,9 +356,11 @@ async def request_handler(context: AdaptivePlaywrightCrawlingContext) -> None: ): await crawler.run(requests) - mocked_predict.assert_called_once_with(requests[0]) + assert mocked_predict.call_count == 1 + assert mocked_predict.call_args[0][0].url == requests[0].url + # If `static` and `client only` results are same, `store_result` should be called with `static`. - mocked_store_result.assert_called_once_with(requests[0], expected_result_rendering_type) + mocked_store_result.assert_called_once_with(mocked_predict.call_args[0][0], expected_result_rendering_type) async def test_adaptive_crawling_result_use_state_isolation( @@ -500,10 +502,10 @@ async def test_adaptive_playwright_crawler_timeout_in_sub_crawler(test_urls: lis """Tests that timeout in static sub crawler forces fall back to browser sub crawler. Create situation where static sub crawler blocks(should time out), such error should start browser sub - crawler.""" - + crawler. + """ static_only_predictor_no_detection = _SimpleRenderingTypePredictor(detection_probability_recommendation=cycle([0])) - request_handler_timeout = timedelta(seconds=0.1) + request_handler_timeout = timedelta(seconds=1) crawler = AdaptivePlaywrightCrawler.with_beautifulsoup_static_parser( max_request_retries=1, @@ -522,9 +524,9 @@ async def request_handler(context: AdaptivePlaywrightCrawlingContext) -> None: except AdaptiveContextError: mocked_static_handler() # Relax timeout for the fallback browser request to avoid flakiness in test - crawler._request_handler_timeout = timedelta(seconds=5) + crawler._request_handler_timeout = timedelta(seconds=10) # Sleep for time obviously larger than top crawler timeout. 
- await asyncio.sleep(request_handler_timeout.total_seconds() * 2) + await asyncio.sleep(request_handler_timeout.total_seconds() * 3) await crawler.run(test_urls[:1]) diff --git a/tests/unit/crawlers/_basic/test_basic_crawler.py b/tests/unit/crawlers/_basic/test_basic_crawler.py index 804a64fae2..9479b11143 100644 --- a/tests/unit/crawlers/_basic/test_basic_crawler.py +++ b/tests/unit/crawlers/_basic/test_basic_crawler.py @@ -10,7 +10,6 @@ from collections import Counter from dataclasses import dataclass from datetime import timedelta -from pathlib import Path from typing import TYPE_CHECKING, Any, Literal, cast from unittest.mock import AsyncMock, Mock, call, patch @@ -32,16 +31,16 @@ if TYPE_CHECKING: from collections.abc import Callable, Sequence + from pathlib import Path from yarl import URL from crawlee._types import JsonSerializable - from crawlee.storage_clients._memory import DatasetClient async def test_processes_requests_from_explicit_queue() -> None: queue = await RequestQueue.open() - await queue.add_requests_batched(['http://a.com/', 'http://b.com/', 'http://c.com/']) + await queue.add_requests(['http://a.com/', 'http://b.com/', 'http://c.com/']) crawler = BasicCrawler(request_manager=queue) calls = list[str]() @@ -57,7 +56,7 @@ async def handler(context: BasicCrawlingContext) -> None: async def test_processes_requests_from_request_source_tandem() -> None: request_queue = await RequestQueue.open() - await request_queue.add_requests_batched(['http://a.com/', 'http://b.com/', 'http://c.com/']) + await request_queue.add_requests(['http://a.com/', 'http://b.com/', 'http://c.com/']) request_list = RequestList(['http://a.com/', 'http://d.com', 'http://e.com']) @@ -537,8 +536,8 @@ async def handler(context: BasicCrawlingContext) -> None: assert visited == set(test_input.expected_urls) -async def test_session_rotation() -> None: - track_session_usage = Mock() +async def test_session_rotation(server_url: URL) -> None: + session_ids: list[str | None] = [] crawler = BasicCrawler( max_session_rotations=7, @@ -547,16 +546,20 @@ async def test_session_rotation() -> None: @crawler.router.default_handler async def handler(context: BasicCrawlingContext) -> None: - track_session_usage(context.session.id if context.session else None) + session_ids.append(context.session.id if context.session else None) raise SessionError('Test error') - await crawler.run([Request.from_url('https://someplace.com/', label='start')]) - assert track_session_usage.call_count == 7 + await crawler.run([str(server_url)]) - session_ids = {call[0][0] for call in track_session_usage.call_args_list} + # exactly 7 handler calls happened assert len(session_ids) == 7 + + # all session ids are not None assert None not in session_ids + # and each was a different session + assert len(set(session_ids)) == 7 + async def test_final_statistics() -> None: crawler = BasicCrawler(max_request_retries=3) @@ -639,14 +642,14 @@ async def test_context_push_and_get_data() -> None: crawler = BasicCrawler() dataset = await Dataset.open() - await dataset.push_data('{"a": 1}') + await dataset.push_data({'a': 1}) assert (await crawler.get_data()).items == [{'a': 1}] @crawler.router.default_handler async def handler(context: BasicCrawlingContext) -> None: - await context.push_data('{"b": 2}') + await context.push_data({'b': 2}) - await dataset.push_data('{"c": 3}') + await dataset.push_data({'c': 3}) assert (await crawler.get_data()).items == [{'a': 1}, {'c': 3}] stats = await crawler.run(['http://test.io/1']) @@ -661,7 +664,7 @@ async def 
test_context_push_and_get_data_handler_error() -> None: @crawler.router.default_handler async def handler(context: BasicCrawlingContext) -> None: - await context.push_data('{"b": 2}') + await context.push_data({'b': 2}) raise RuntimeError('Watch me crash') stats = await crawler.run(['https://a.com']) @@ -679,8 +682,8 @@ async def test_crawler_push_and_export_data(tmp_path: Path) -> None: await dataset.push_data([{'id': 0, 'test': 'test'}, {'id': 1, 'test': 'test'}]) await dataset.push_data({'id': 2, 'test': 'test'}) - await crawler.export_data_json(path=tmp_path / 'dataset.json') - await crawler.export_data_csv(path=tmp_path / 'dataset.csv') + await crawler.export_data(path=tmp_path / 'dataset.json') + await crawler.export_data(path=tmp_path / 'dataset.csv') assert json.load((tmp_path / 'dataset.json').open()) == [ {'id': 0, 'test': 'test'}, @@ -700,8 +703,8 @@ async def handler(context: BasicCrawlingContext) -> None: await crawler.run(['http://test.io/1']) - await crawler.export_data_json(path=tmp_path / 'dataset.json') - await crawler.export_data_csv(path=tmp_path / 'dataset.csv') + await crawler.export_data(path=tmp_path / 'dataset.json') + await crawler.export_data(path=tmp_path / 'dataset.csv') assert json.load((tmp_path / 'dataset.json').open()) == [ {'id': 0, 'test': 'test'}, @@ -712,45 +715,6 @@ async def handler(context: BasicCrawlingContext) -> None: assert (tmp_path / 'dataset.csv').read_bytes() == b'id,test\r\n0,test\r\n1,test\r\n2,test\r\n' -async def test_crawler_push_and_export_data_and_json_dump_parameter(tmp_path: Path) -> None: - crawler = BasicCrawler() - - @crawler.router.default_handler - async def handler(context: BasicCrawlingContext) -> None: - await context.push_data([{'id': 0, 'test': 'test'}, {'id': 1, 'test': 'test'}]) - await context.push_data({'id': 2, 'test': 'test'}) - - await crawler.run(['http://test.io/1']) - - await crawler.export_data_json(path=tmp_path / 'dataset.json', indent=3) - - with (tmp_path / 'dataset.json').open() as json_file: - exported_json_str = json_file.read() - - # Expected data in JSON format with 3 spaces indent - expected_data = [ - {'id': 0, 'test': 'test'}, - {'id': 1, 'test': 'test'}, - {'id': 2, 'test': 'test'}, - ] - expected_json_str = json.dumps(expected_data, indent=3) - - # Assert that the exported JSON string matches the expected JSON string - assert exported_json_str == expected_json_str - - -async def test_crawler_push_data_over_limit() -> None: - crawler = BasicCrawler() - - @crawler.router.default_handler - async def handler(context: BasicCrawlingContext) -> None: - # Push a roughly 15MB payload - this should be enough to break the 9MB limit - await context.push_data({'hello': 'world' * 3 * 1024 * 1024}) - - stats = await crawler.run(['http://example.tld/1']) - assert stats.requests_failed == 1 - - async def test_context_update_kv_store() -> None: crawler = BasicCrawler() @@ -765,7 +729,7 @@ async def handler(context: BasicCrawlingContext) -> None: assert (await store.get_value('foo')) == 'bar' -async def test_context_use_state(key_value_store: KeyValueStore) -> None: +async def test_context_use_state() -> None: crawler = BasicCrawler() @crawler.router.default_handler @@ -774,9 +738,10 @@ async def handler(context: BasicCrawlingContext) -> None: await crawler.run(['https://hello.world']) - store = await crawler.get_key_value_store() + kvs = await crawler.get_key_value_store() + value = await kvs.get_value(BasicCrawler._CRAWLEE_STATE_KEY) - assert (await store.get_value(BasicCrawler._CRAWLEE_STATE_KEY)) == {'hello': 
'world'} + assert value == {'hello': 'world'} async def test_context_handlers_use_state(key_value_store: KeyValueStore) -> None: @@ -940,18 +905,6 @@ async def handler(context: BasicCrawlingContext) -> None: } -async def test_respects_no_persist_storage() -> None: - configuration = Configuration(persist_storage=False) - crawler = BasicCrawler(configuration=configuration) - - @crawler.router.default_handler - async def handler(context: BasicCrawlingContext) -> None: - await context.push_data({'something': 'something'}) - - datasets_path = Path(configuration.storage_dir) / 'datasets' / 'default' - assert not datasets_path.exists() or list(datasets_path.iterdir()) == [] - - @pytest.mark.skipif(os.name == 'nt' and 'CI' in os.environ, reason='Skipped in Windows CI') @pytest.mark.parametrize( ('statistics_log_format'), @@ -1091,9 +1044,9 @@ async def handler(context: BasicCrawlingContext) -> None: async def test_sets_services() -> None: custom_configuration = Configuration() custom_event_manager = LocalEventManager.from_config(custom_configuration) - custom_storage_client = MemoryStorageClient.from_config(custom_configuration) + custom_storage_client = MemoryStorageClient() - crawler = BasicCrawler( + _ = BasicCrawler( configuration=custom_configuration, event_manager=custom_event_manager, storage_client=custom_storage_client, @@ -1103,12 +1056,9 @@ async def test_sets_services() -> None: assert service_locator.get_event_manager() is custom_event_manager assert service_locator.get_storage_client() is custom_storage_client - dataset = await crawler.get_dataset(name='test') - assert cast('DatasetClient', dataset._resource_client)._memory_storage_client is custom_storage_client - async def test_allows_storage_client_overwrite_before_run(monkeypatch: pytest.MonkeyPatch) -> None: - custom_storage_client = MemoryStorageClient.from_config() + custom_storage_client = MemoryStorageClient() crawler = BasicCrawler( storage_client=custom_storage_client, @@ -1118,7 +1068,7 @@ async def test_allows_storage_client_overwrite_before_run(monkeypatch: pytest.Mo async def handler(context: BasicCrawlingContext) -> None: await context.push_data({'foo': 'bar'}) - other_storage_client = MemoryStorageClient.from_config() + other_storage_client = MemoryStorageClient() service_locator.set_storage_client(other_storage_client) with monkeypatch.context() as monkey: @@ -1128,8 +1078,6 @@ async def handler(context: BasicCrawlingContext) -> None: assert spy.call_count >= 1 dataset = await crawler.get_dataset() - assert cast('DatasetClient', dataset._resource_client)._memory_storage_client is other_storage_client - data = await dataset.get_data() assert data.items == [{'foo': 'bar'}] @@ -1397,23 +1345,30 @@ async def test_lock_with_get_robots_txt_file_for_url(server_url: URL) -> None: assert spy.call_count == 1 -async def test_reduced_logs_from_timed_out_request_handler( - monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture -) -> None: +async def test_reduced_logs_from_timed_out_request_handler(caplog: pytest.LogCaptureFixture) -> None: caplog.set_level(logging.INFO) - crawler = BasicCrawler(configure_logging=False, request_handler_timeout=timedelta(seconds=1)) + crawler = BasicCrawler( + configure_logging=False, + request_handler_timeout=timedelta(seconds=1), + ) @crawler.router.default_handler async def handler(context: BasicCrawlingContext) -> None: + # Intentionally add a delay longer than the timeout to trigger the timeout mechanism await asyncio.sleep(10) # INJECTED DELAY - await 
crawler.run([Request.from_url('http://a.com/')]) + # Capture all logs from the 'crawlee' logger at INFO level or higher + with caplog.at_level(logging.INFO, logger='crawlee'): + await crawler.run([Request.from_url('http://a.com/')]) + # Check for the timeout message in any of the logs + found_timeout_message = False for record in caplog.records: - if record.funcName == '_handle_failed_request': + if record.message and 'timed out after 1.0 seconds' in record.message: full_message = (record.message or '') + (record.exc_text or '') assert Counter(full_message)['\n'] < 10 assert '# INJECTED DELAY' in full_message + found_timeout_message = True break - else: - raise AssertionError('Expected log message about request handler error was not found.') + + assert found_timeout_message, 'Expected log message about request handler error was not found.' diff --git a/tests/unit/crawlers/_http/test_http_crawler.py b/tests/unit/crawlers/_http/test_http_crawler.py index 7f00ff2166..807e14e9cc 100644 --- a/tests/unit/crawlers/_http/test_http_crawler.py +++ b/tests/unit/crawlers/_http/test_http_crawler.py @@ -544,7 +544,8 @@ async def request_handler(context: HttpCrawlingContext) -> None: async def test_error_snapshot_through_statistics(server_url: URL) -> None: - crawler = HttpCrawler(statistics=Statistics.with_default_state(save_error_snapshots=True)) + statistics = Statistics.with_default_state(save_error_snapshots=True) + crawler = HttpCrawler(statistics=statistics) @crawler.router.default_handler async def request_handler(context: HttpCrawlingContext) -> None: diff --git a/tests/unit/crawlers/_playwright/test_playwright_crawler.py b/tests/unit/crawlers/_playwright/test_playwright_crawler.py index fc6fb282a9..a8d422c8e7 100644 --- a/tests/unit/crawlers/_playwright/test_playwright_crawler.py +++ b/tests/unit/crawlers/_playwright/test_playwright_crawler.py @@ -708,7 +708,7 @@ async def request_handler(context: PlaywrightCrawlingContext) -> None: async def test_overwrite_configuration() -> None: """Check that the configuration is allowed to be passed to the Playwrightcrawler.""" - configuration = Configuration(default_dataset_id='my_dataset_id') + configuration = Configuration(log_level='WARNING') PlaywrightCrawler(configuration=configuration) used_configuration = service_locator.get_configuration() assert used_configuration is configuration diff --git a/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py b/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py new file mode 100644 index 0000000000..c5f31f144e --- /dev/null +++ b/tests/unit/storage_clients/_file_system/test_fs_dataset_client.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +import asyncio +import json +from pathlib import Path +from typing import TYPE_CHECKING + +import pytest + +from crawlee._consts import METADATA_FILENAME +from crawlee.configuration import Configuration +from crawlee.storage_clients import FileSystemStorageClient + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + + from crawlee.storage_clients._file_system import FileSystemDatasetClient + + +@pytest.fixture +def configuration(tmp_path: Path) -> Configuration: + return Configuration( + crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] + ) + + +@pytest.fixture +async def dataset_client(configuration: Configuration) -> AsyncGenerator[FileSystemDatasetClient, None]: + """A fixture for a file system dataset client.""" + client = await FileSystemStorageClient().create_dataset_client( + name='test_dataset', + 
configuration=configuration, + ) + yield client + await client.drop() + + +async def test_file_and_directory_creation(configuration: Configuration) -> None: + """Test that file system dataset creates proper files and directories.""" + client = await FileSystemStorageClient().create_dataset_client( + name='new_dataset', + configuration=configuration, + ) + + # Verify files were created + assert client.path_to_dataset.exists() + assert client.path_to_metadata.exists() + + # Verify metadata file structure + with client.path_to_metadata.open() as f: + metadata = json.load(f) + client_metadata = await client.get_metadata() + assert metadata['id'] == client_metadata.id + assert metadata['name'] == 'new_dataset' + assert metadata['item_count'] == 0 + + await client.drop() + + +async def test_file_persistence_and_content_verification(dataset_client: FileSystemDatasetClient) -> None: + """Test that data is properly persisted to files with correct content.""" + item = {'key': 'value', 'number': 42} + await dataset_client.push_data(item) + + # Verify files are created on disk + all_files = list(dataset_client.path_to_dataset.glob('*.json')) + assert len(all_files) == 2 # 1 data file + 1 metadata file + + # Verify actual file content + data_files = [item for item in all_files if item.name != METADATA_FILENAME] + assert len(data_files) == 1 + + with Path(data_files[0]).open() as f: + saved_item = json.load(f) + assert saved_item == item + + # Test multiple items file creation + items = [{'id': 1, 'name': 'Item 1'}, {'id': 2, 'name': 'Item 2'}, {'id': 3, 'name': 'Item 3'}] + await dataset_client.push_data(items) + + all_files = list(dataset_client.path_to_dataset.glob('*.json')) + assert len(all_files) == 5 # 4 data files + 1 metadata file + + data_files = [f for f in all_files if f.name != METADATA_FILENAME] + assert len(data_files) == 4 # Original item + 3 new items + + +async def test_drop_removes_files_from_disk(dataset_client: FileSystemDatasetClient) -> None: + """Test that dropping a dataset removes the entire dataset directory from disk.""" + await dataset_client.push_data({'test': 'data'}) + + assert dataset_client.path_to_dataset.exists() + + # Drop the dataset + await dataset_client.drop() + + assert not dataset_client.path_to_dataset.exists() + + +async def test_metadata_file_updates(dataset_client: FileSystemDatasetClient) -> None: + """Test that metadata file is updated correctly after operations.""" + # Record initial timestamps + metadata = await dataset_client.get_metadata() + initial_created = metadata.created_at + initial_accessed = metadata.accessed_at + initial_modified = metadata.modified_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates accessed_at + await dataset_client.get_data() + + # Verify timestamps + metadata = await dataset_client.get_metadata() + assert metadata.created_at == initial_created + assert metadata.accessed_at > initial_accessed + assert metadata.modified_at == initial_modified + + accessed_after_get = metadata.accessed_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates modified_at + await dataset_client.push_data({'new': 'item'}) + + # Verify timestamps again + metadata = await dataset_client.get_metadata() + assert metadata.created_at == initial_created + assert metadata.modified_at > initial_modified + assert metadata.accessed_at > accessed_after_get + + # Verify metadata file is updated on disk + with 
dataset_client.path_to_metadata.open() as f: + metadata_json = json.load(f) + assert metadata_json['item_count'] == 1 + + +async def test_data_persistence_across_reopens(configuration: Configuration) -> None: + """Test that data persists correctly when reopening the same dataset.""" + storage_client = FileSystemStorageClient() + + # Create dataset and add data + original_client = await storage_client.create_dataset_client( + name='persistence-test', + configuration=configuration, + ) + + test_data = {'test_item': 'test_value', 'id': 123} + await original_client.push_data(test_data) + + dataset_id = (await original_client.get_metadata()).id + + # Reopen by ID and verify data persists + reopened_client = await storage_client.create_dataset_client( + id=dataset_id, + configuration=configuration, + ) + + data = await reopened_client.get_data() + assert len(data.items) == 1 + assert data.items[0] == test_data + + await reopened_client.drop() diff --git a/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py b/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py new file mode 100644 index 0000000000..c5bfa96c47 --- /dev/null +++ b/tests/unit/storage_clients/_file_system/test_fs_kvs_client.py @@ -0,0 +1,211 @@ +from __future__ import annotations + +import asyncio +import json +from typing import TYPE_CHECKING + +import pytest + +from crawlee._consts import METADATA_FILENAME +from crawlee.configuration import Configuration +from crawlee.storage_clients import FileSystemStorageClient + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + from pathlib import Path + + from crawlee.storage_clients._file_system import FileSystemKeyValueStoreClient + + +@pytest.fixture +def configuration(tmp_path: Path) -> Configuration: + return Configuration( + crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] + ) + + +@pytest.fixture +async def kvs_client(configuration: Configuration) -> AsyncGenerator[FileSystemKeyValueStoreClient, None]: + """A fixture for a file system key-value store client.""" + client = await FileSystemStorageClient().create_kvs_client( + name='test_kvs', + configuration=configuration, + ) + yield client + await client.drop() + + +async def test_file_and_directory_creation(configuration: Configuration) -> None: + """Test that file system KVS creates proper files and directories.""" + client = await FileSystemStorageClient().create_kvs_client( + name='new_kvs', + configuration=configuration, + ) + + # Verify files were created + assert client.path_to_kvs.exists() + assert client.path_to_metadata.exists() + + # Verify metadata file structure + with client.path_to_metadata.open() as f: + metadata = json.load(f) + assert metadata['id'] == (await client.get_metadata()).id + assert metadata['name'] == 'new_kvs' + + await client.drop() + + +async def test_value_file_creation_and_content(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test that values are properly persisted to files with correct content and metadata.""" + test_key = 'test-key' + test_value = 'Hello, world!' 
+ await kvs_client.set_value(key=test_key, value=test_value) + + # Check if the files were created + key_path = kvs_client.path_to_kvs / test_key + key_metadata_path = kvs_client.path_to_kvs / f'{test_key}.{METADATA_FILENAME}' + assert key_path.exists() + assert key_metadata_path.exists() + + # Check file content + content = key_path.read_text(encoding='utf-8') + assert content == test_value + + # Check record metadata file + with key_metadata_path.open() as f: + metadata = json.load(f) + assert metadata['key'] == test_key + assert metadata['content_type'] == 'text/plain; charset=utf-8' + assert metadata['size'] == len(test_value.encode('utf-8')) + + +async def test_binary_data_persistence(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test that binary data is stored correctly without corruption.""" + test_key = 'test-binary' + test_value = b'\x00\x01\x02\x03\x04' + await kvs_client.set_value(key=test_key, value=test_value) + + # Verify binary file exists + key_path = kvs_client.path_to_kvs / test_key + assert key_path.exists() + + # Verify binary content is preserved + content = key_path.read_bytes() + assert content == test_value + + # Verify retrieval works correctly + record = await kvs_client.get_value(key=test_key) + assert record is not None + assert record.value == test_value + assert record.content_type == 'application/octet-stream' + + +async def test_json_serialization_to_file(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test that JSON objects are properly serialized to files.""" + test_key = 'test-json' + test_value = {'name': 'John', 'age': 30, 'items': [1, 2, 3]} + await kvs_client.set_value(key=test_key, value=test_value) + + # Check if file content is valid JSON + key_path = kvs_client.path_to_kvs / test_key + with key_path.open() as f: + file_content = json.load(f) + assert file_content == test_value + + +async def test_file_deletion_on_value_delete(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test that deleting a value removes its files from disk.""" + test_key = 'test-delete' + test_value = 'Delete me' + + # Set a value + await kvs_client.set_value(key=test_key, value=test_value) + + # Verify files exist + key_path = kvs_client.path_to_kvs / test_key + metadata_path = kvs_client.path_to_kvs / f'{test_key}.{METADATA_FILENAME}' + assert key_path.exists() + assert metadata_path.exists() + + # Delete the value + await kvs_client.delete_value(key=test_key) + + # Verify files were deleted + assert not key_path.exists() + assert not metadata_path.exists() + + +async def test_drop_removes_directory(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test that drop removes the entire store directory from disk.""" + await kvs_client.set_value(key='test', value='test-value') + + assert kvs_client.path_to_kvs.exists() + + # Drop the store + await kvs_client.drop() + + assert not kvs_client.path_to_kvs.exists() + + +async def test_metadata_file_updates(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test that read/write operations properly update metadata file timestamps.""" + # Record initial timestamps + metadata = await kvs_client.get_metadata() + initial_created = metadata.created_at + initial_accessed = metadata.accessed_at + initial_modified = metadata.modified_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform a read operation + await kvs_client.get_value(key='nonexistent') + + # Verify accessed timestamp was updated + metadata = await kvs_client.get_metadata() + assert 
metadata.created_at == initial_created + assert metadata.accessed_at > initial_accessed + assert metadata.modified_at == initial_modified + + accessed_after_read = metadata.accessed_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform a write operation + await kvs_client.set_value(key='test', value='test-value') + + # Verify modified timestamp was updated + metadata = await kvs_client.get_metadata() + assert metadata.created_at == initial_created + assert metadata.modified_at > initial_modified + assert metadata.accessed_at > accessed_after_read + + +async def test_data_persistence_across_reopens(configuration: Configuration) -> None: + """Test that data persists correctly when reopening the same KVS.""" + storage_client = FileSystemStorageClient() + + # Create KVS and add data + original_client = await storage_client.create_kvs_client( + name='persistence-test', + configuration=configuration, + ) + + test_key = 'persistent-key' + test_value = 'persistent-value' + await original_client.set_value(key=test_key, value=test_value) + + kvs_id = (await original_client.get_metadata()).id + + # Reopen by ID and verify data persists + reopened_client = await storage_client.create_kvs_client( + id=kvs_id, + configuration=configuration, + ) + + record = await reopened_client.get_value(key=test_key) + assert record is not None + assert record.value == test_value + + await reopened_client.drop() diff --git a/tests/unit/storage_clients/_file_system/test_fs_rq_client.py b/tests/unit/storage_clients/_file_system/test_fs_rq_client.py new file mode 100644 index 0000000000..0be182fcd8 --- /dev/null +++ b/tests/unit/storage_clients/_file_system/test_fs_rq_client.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +import asyncio +import json +from typing import TYPE_CHECKING + +import pytest + +from crawlee import Request +from crawlee.configuration import Configuration +from crawlee.storage_clients import FileSystemStorageClient + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + from pathlib import Path + + from crawlee.storage_clients._file_system import FileSystemRequestQueueClient + + +@pytest.fixture +def configuration(tmp_path: Path) -> Configuration: + return Configuration( + crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] + ) + + +@pytest.fixture +async def rq_client(configuration: Configuration) -> AsyncGenerator[FileSystemRequestQueueClient, None]: + """A fixture for a file system request queue client.""" + client = await FileSystemStorageClient().create_rq_client( + name='test_request_queue', + configuration=configuration, + ) + yield client + await client.drop() + + +async def test_file_and_directory_creation(configuration: Configuration) -> None: + """Test that file system RQ creates proper files and directories.""" + client = await FileSystemStorageClient().create_rq_client( + name='new_request_queue', + configuration=configuration, + ) + + # Verify files were created + assert client.path_to_rq.exists() + assert client.path_to_metadata.exists() + + # Verify metadata file structure + with client.path_to_metadata.open() as f: + metadata = json.load(f) + assert metadata['id'] == (await client.get_metadata()).id + assert metadata['name'] == 'new_request_queue' + + await client.drop() + + +async def test_request_file_persistence(rq_client: FileSystemRequestQueueClient) -> None: + """Test that requests are properly persisted to files.""" + requests = [ + Request.from_url('https://example.com/1'), + 
Request.from_url('https://example.com/2'), + Request.from_url('https://example.com/3'), + ] + + await rq_client.add_batch_of_requests(requests) + + # Verify request files are created + request_files = list(rq_client.path_to_rq.glob('*.json')) + # Should have 3 request files + 1 metadata file + assert len(request_files) == 4 + assert rq_client.path_to_metadata in request_files + + # Verify actual request file content + data_files = [f for f in request_files if f != rq_client.path_to_metadata] + assert len(data_files) == 3 + + for req_file in data_files: + with req_file.open() as f: + request_data = json.load(f) + assert 'url' in request_data + assert request_data['url'].startswith('https://example.com/') + + +async def test_drop_removes_directory(rq_client: FileSystemRequestQueueClient) -> None: + """Test that drop removes the entire RQ directory from disk.""" + await rq_client.add_batch_of_requests([Request.from_url('https://example.com')]) + + rq_path = rq_client.path_to_rq + assert rq_path.exists() + + # Drop the request queue + await rq_client.drop() + + assert not rq_path.exists() + + +async def test_metadata_file_updates(rq_client: FileSystemRequestQueueClient) -> None: + """Test that metadata file is updated correctly after operations.""" + # Record initial timestamps + metadata = await rq_client.get_metadata() + initial_created = metadata.created_at + initial_accessed = metadata.accessed_at + initial_modified = metadata.modified_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform a read operation + await rq_client.is_empty() + + # Verify accessed timestamp was updated + metadata = await rq_client.get_metadata() + assert metadata.created_at == initial_created + assert metadata.accessed_at > initial_accessed + assert metadata.modified_at == initial_modified + + accessed_after_read = metadata.accessed_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform a write operation + await rq_client.add_batch_of_requests([Request.from_url('https://example.com')]) + + # Verify modified timestamp was updated + metadata = await rq_client.get_metadata() + assert metadata.created_at == initial_created + assert metadata.modified_at > initial_modified + assert metadata.accessed_at > accessed_after_read + + # Verify metadata file is updated on disk + with rq_client.path_to_metadata.open() as f: + metadata_json = json.load(f) + assert metadata_json['total_request_count'] == 1 + + +async def test_data_persistence_across_reopens(configuration: Configuration) -> None: + """Test that requests persist correctly when reopening the same RQ.""" + storage_client = FileSystemStorageClient() + + # Create RQ and add requests + original_client = await storage_client.create_rq_client( + name='persistence-test', + configuration=configuration, + ) + + test_requests = [ + Request.from_url('https://example.com/1'), + Request.from_url('https://example.com/2'), + ] + await original_client.add_batch_of_requests(test_requests) + + rq_id = (await original_client.get_metadata()).id + + # Reopen by ID and verify requests persist + reopened_client = await storage_client.create_rq_client( + id=rq_id, + configuration=configuration, + ) + + metadata = await reopened_client.get_metadata() + assert metadata.total_request_count == 2 + + # Fetch requests to verify they're still there + request1 = await reopened_client.fetch_next_request() + request2 = await reopened_client.fetch_next_request() + + assert request1 is not None + assert request2 is not 
None + assert {request1.url, request2.url} == {'https://example.com/1', 'https://example.com/2'} + + await reopened_client.drop() diff --git a/tests/unit/storage_clients/_memory/test_creation_management.py b/tests/unit/storage_clients/_memory/test_creation_management.py deleted file mode 100644 index 88a5e9e283..0000000000 --- a/tests/unit/storage_clients/_memory/test_creation_management.py +++ /dev/null @@ -1,59 +0,0 @@ -from __future__ import annotations - -import json -from pathlib import Path -from unittest.mock import AsyncMock, patch - -import pytest - -from crawlee._consts import METADATA_FILENAME -from crawlee.storage_clients._memory._creation_management import persist_metadata_if_enabled - - -async def test_persist_metadata_skips_when_disabled(tmp_path: Path) -> None: - await persist_metadata_if_enabled(data={'key': 'value'}, entity_directory=str(tmp_path), write_metadata=False) - assert not list(tmp_path.iterdir()) # The directory should be empty since write_metadata is False - - -async def test_persist_metadata_creates_files_and_directories_when_enabled(tmp_path: Path) -> None: - data = {'key': 'value'} - entity_directory = Path(tmp_path, 'new_dir') - await persist_metadata_if_enabled(data=data, entity_directory=str(entity_directory), write_metadata=True) - assert entity_directory.exists() is True # Check if directory was created - assert (entity_directory / METADATA_FILENAME).is_file() # Check if file was created - - -async def test_persist_metadata_correctly_writes_data(tmp_path: Path) -> None: - data = {'key': 'value'} - entity_directory = Path(tmp_path, 'data_dir') - await persist_metadata_if_enabled(data=data, entity_directory=str(entity_directory), write_metadata=True) - metadata_path = entity_directory / METADATA_FILENAME - with open(metadata_path) as f: # noqa: ASYNC230 - content = f.read() - assert json.loads(content) == data # Check if correct data was written - - -async def test_persist_metadata_rewrites_data_with_error(tmp_path: Path) -> None: - init_data = {'key': 'very_long_value'} - update_data = {'key': 'short_value'} - error_data = {'key': 'error'} - - entity_directory = Path(tmp_path, 'data_dir') - metadata_path = entity_directory / METADATA_FILENAME - - # write metadata with init_data - await persist_metadata_if_enabled(data=init_data, entity_directory=str(entity_directory), write_metadata=True) - - # rewrite metadata with new_data - await persist_metadata_if_enabled(data=update_data, entity_directory=str(entity_directory), write_metadata=True) - with open(metadata_path) as f: # noqa: ASYNC230 - content = f.read() - assert json.loads(content) == update_data # Check if correct data was rewritten - - # raise interrupt between opening a file and writing - module_for_patch = 'crawlee.storage_clients._memory._creation_management.json_dumps' - with patch(module_for_patch, AsyncMock(side_effect=KeyboardInterrupt())), pytest.raises(KeyboardInterrupt): - await persist_metadata_if_enabled(data=error_data, entity_directory=str(entity_directory), write_metadata=True) - with open(metadata_path) as f: # noqa: ASYNC230 - content = f.read() - assert content == '' # The file is empty after an error diff --git a/tests/unit/storage_clients/_memory/test_dataset_client.py b/tests/unit/storage_clients/_memory/test_dataset_client.py deleted file mode 100644 index 472d11a8b3..0000000000 --- a/tests/unit/storage_clients/_memory/test_dataset_client.py +++ /dev/null @@ -1,148 +0,0 @@ -from __future__ import annotations - -import asyncio -from pathlib import Path -from typing import 
TYPE_CHECKING - -import pytest - -if TYPE_CHECKING: - from crawlee.storage_clients import MemoryStorageClient - from crawlee.storage_clients._memory import DatasetClient - - -@pytest.fixture -async def dataset_client(memory_storage_client: MemoryStorageClient) -> DatasetClient: - datasets_client = memory_storage_client.datasets() - dataset_info = await datasets_client.get_or_create(name='test') - return memory_storage_client.dataset(dataset_info.id) - - -async def test_nonexistent(memory_storage_client: MemoryStorageClient) -> None: - dataset_client = memory_storage_client.dataset(id='nonexistent-id') - assert await dataset_client.get() is None - with pytest.raises(ValueError, match='Dataset with id "nonexistent-id" does not exist.'): - await dataset_client.update(name='test-update') - - with pytest.raises(ValueError, match='Dataset with id "nonexistent-id" does not exist.'): - await dataset_client.list_items() - - with pytest.raises(ValueError, match='Dataset with id "nonexistent-id" does not exist.'): - await dataset_client.push_items([{'abc': 123}]) - await dataset_client.delete() - - -async def test_not_implemented(dataset_client: DatasetClient) -> None: - with pytest.raises(NotImplementedError, match='This method is not supported in memory storage.'): - await dataset_client.stream_items() - with pytest.raises(NotImplementedError, match='This method is not supported in memory storage.'): - await dataset_client.get_items_as_bytes() - - -async def test_get(dataset_client: DatasetClient) -> None: - await asyncio.sleep(0.1) - info = await dataset_client.get() - assert info is not None - assert info.id == dataset_client.id - assert info.accessed_at != info.created_at - - -async def test_update(dataset_client: DatasetClient) -> None: - new_dataset_name = 'test-update' - await dataset_client.push_items({'abc': 123}) - - old_dataset_info = await dataset_client.get() - assert old_dataset_info is not None - old_dataset_directory = Path(dataset_client._memory_storage_client.datasets_directory, old_dataset_info.name or '') - new_dataset_directory = Path(dataset_client._memory_storage_client.datasets_directory, new_dataset_name) - assert (old_dataset_directory / '000000001.json').exists() is True - assert (new_dataset_directory / '000000001.json').exists() is False - - await asyncio.sleep(0.1) - updated_dataset_info = await dataset_client.update(name=new_dataset_name) - assert (old_dataset_directory / '000000001.json').exists() is False - assert (new_dataset_directory / '000000001.json').exists() is True - # Only modified_at and accessed_at should be different - assert old_dataset_info.created_at == updated_dataset_info.created_at - assert old_dataset_info.modified_at != updated_dataset_info.modified_at - assert old_dataset_info.accessed_at != updated_dataset_info.accessed_at - - # Should fail with the same name - with pytest.raises(ValueError, match='Dataset with name "test-update" already exists.'): - await dataset_client.update(name=new_dataset_name) - - -async def test_delete(dataset_client: DatasetClient) -> None: - await dataset_client.push_items({'abc': 123}) - dataset_info = await dataset_client.get() - assert dataset_info is not None - dataset_directory = Path(dataset_client._memory_storage_client.datasets_directory, dataset_info.name or '') - assert (dataset_directory / '000000001.json').exists() is True - await dataset_client.delete() - assert (dataset_directory / '000000001.json').exists() is False - # Does not crash when called again - await dataset_client.delete() - - -async def 
test_push_items(dataset_client: DatasetClient) -> None: - await dataset_client.push_items('{"test": "JSON from a string"}') - await dataset_client.push_items({'abc': {'def': {'ghi': '123'}}}) - await dataset_client.push_items(['{"test-json-parse": "JSON from a string"}' for _ in range(10)]) - await dataset_client.push_items([{'test-dict': i} for i in range(10)]) - - list_page = await dataset_client.list_items() - assert list_page.items[0]['test'] == 'JSON from a string' - assert list_page.items[1]['abc']['def']['ghi'] == '123' - assert list_page.items[11]['test-json-parse'] == 'JSON from a string' - assert list_page.items[21]['test-dict'] == 9 - assert list_page.count == 22 - - -async def test_list_items(dataset_client: DatasetClient) -> None: - item_count = 100 - used_offset = 10 - used_limit = 50 - await dataset_client.push_items([{'id': i} for i in range(item_count)]) - # Test without any parameters - list_default = await dataset_client.list_items() - assert list_default.count == item_count - assert list_default.offset == 0 - assert list_default.items[0]['id'] == 0 - assert list_default.desc is False - # Test offset - list_offset_10 = await dataset_client.list_items(offset=used_offset) - assert list_offset_10.count == item_count - used_offset - assert list_offset_10.offset == used_offset - assert list_offset_10.total == item_count - assert list_offset_10.items[0]['id'] == used_offset - # Test limit - list_limit_50 = await dataset_client.list_items(limit=used_limit) - assert list_limit_50.count == used_limit - assert list_limit_50.limit == used_limit - assert list_limit_50.total == item_count - # Test desc - list_desc_true = await dataset_client.list_items(desc=True) - assert list_desc_true.items[0]['id'] == 99 - assert list_desc_true.desc is True - - -async def test_iterate_items(dataset_client: DatasetClient) -> None: - item_count = 100 - await dataset_client.push_items([{'id': i} for i in range(item_count)]) - actual_items = [] - async for item in dataset_client.iterate_items(): - assert 'id' in item - actual_items.append(item) - assert len(actual_items) == item_count - assert actual_items[0]['id'] == 0 - assert actual_items[99]['id'] == 99 - - -async def test_reuse_dataset(dataset_client: DatasetClient, memory_storage_client: MemoryStorageClient) -> None: - item_count = 10 - await dataset_client.push_items([{'id': i} for i in range(item_count)]) - - memory_storage_client.datasets_handled = [] # purge datasets loaded to test create_dataset_from_directory - datasets_client = memory_storage_client.datasets() - dataset_info = await datasets_client.get_or_create(name='test') - assert dataset_info.item_count == item_count diff --git a/tests/unit/storage_clients/_memory/test_dataset_collection_client.py b/tests/unit/storage_clients/_memory/test_dataset_collection_client.py deleted file mode 100644 index d71b7e8f68..0000000000 --- a/tests/unit/storage_clients/_memory/test_dataset_collection_client.py +++ /dev/null @@ -1,45 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pytest - -if TYPE_CHECKING: - from crawlee.storage_clients import MemoryStorageClient - from crawlee.storage_clients._memory import DatasetCollectionClient - - -@pytest.fixture -def datasets_client(memory_storage_client: MemoryStorageClient) -> DatasetCollectionClient: - return memory_storage_client.datasets() - - -async def test_get_or_create(datasets_client: DatasetCollectionClient) -> None: - dataset_name = 'test' - # A new dataset gets created - dataset_info = await 
datasets_client.get_or_create(name=dataset_name) - assert dataset_info.name == dataset_name - - # Another get_or_create call returns the same dataset - dataset_info_existing = await datasets_client.get_or_create(name=dataset_name) - assert dataset_info.id == dataset_info_existing.id - assert dataset_info.name == dataset_info_existing.name - assert dataset_info.created_at == dataset_info_existing.created_at - - -async def test_list(datasets_client: DatasetCollectionClient) -> None: - dataset_list_1 = await datasets_client.list() - assert dataset_list_1.count == 0 - - dataset_info = await datasets_client.get_or_create(name='dataset') - dataset_list_2 = await datasets_client.list() - - assert dataset_list_2.count == 1 - assert dataset_list_2.items[0].name == dataset_info.name - - # Test sorting behavior - newer_dataset_info = await datasets_client.get_or_create(name='newer-dataset') - dataset_list_sorting = await datasets_client.list() - assert dataset_list_sorting.count == 2 - assert dataset_list_sorting.items[0].name == dataset_info.name - assert dataset_list_sorting.items[1].name == newer_dataset_info.name diff --git a/tests/unit/storage_clients/_memory/test_key_value_store_client.py b/tests/unit/storage_clients/_memory/test_key_value_store_client.py deleted file mode 100644 index 26d1f8f974..0000000000 --- a/tests/unit/storage_clients/_memory/test_key_value_store_client.py +++ /dev/null @@ -1,443 +0,0 @@ -from __future__ import annotations - -import asyncio -import base64 -import json -from datetime import datetime, timezone -from pathlib import Path -from typing import TYPE_CHECKING - -import pytest - -from crawlee._consts import METADATA_FILENAME -from crawlee._utils.crypto import crypto_random_object_id -from crawlee._utils.data_processing import maybe_parse_body -from crawlee._utils.file import json_dumps -from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecordMetadata - -if TYPE_CHECKING: - from crawlee.storage_clients import MemoryStorageClient - from crawlee.storage_clients._memory import KeyValueStoreClient - -TINY_PNG = base64.b64decode( - s='iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVQYV2NgYAAAAAMAAWgmWQ0AAAAASUVORK5CYII=', -) -TINY_BYTES = b'\x12\x34\x56\x78\x90\xab\xcd\xef' -TINY_DATA = {'a': 'b'} -TINY_TEXT = 'abcd' - - -@pytest.fixture -async def key_value_store_client(memory_storage_client: MemoryStorageClient) -> KeyValueStoreClient: - key_value_stores_client = memory_storage_client.key_value_stores() - kvs_info = await key_value_stores_client.get_or_create(name='test') - return memory_storage_client.key_value_store(kvs_info.id) - - -async def test_nonexistent(memory_storage_client: MemoryStorageClient) -> None: - kvs_client = memory_storage_client.key_value_store(id='nonexistent-id') - assert await kvs_client.get() is None - - with pytest.raises(ValueError, match='Key-value store with id "nonexistent-id" does not exist.'): - await kvs_client.update(name='test-update') - - with pytest.raises(ValueError, match='Key-value store with id "nonexistent-id" does not exist.'): - await kvs_client.list_keys() - - with pytest.raises(ValueError, match='Key-value store with id "nonexistent-id" does not exist.'): - await kvs_client.set_record('test', {'abc': 123}) - - with pytest.raises(ValueError, match='Key-value store with id "nonexistent-id" does not exist.'): - await kvs_client.get_record('test') - - with pytest.raises(ValueError, match='Key-value store with id "nonexistent-id" does not exist.'): - await 
kvs_client.get_record_as_bytes('test') - - with pytest.raises(ValueError, match='Key-value store with id "nonexistent-id" does not exist.'): - await kvs_client.delete_record('test') - - await kvs_client.delete() - - -async def test_not_implemented(key_value_store_client: KeyValueStoreClient) -> None: - with pytest.raises(NotImplementedError, match='This method is not supported in memory storage.'): - await key_value_store_client.stream_record('test') - - -async def test_get(key_value_store_client: KeyValueStoreClient) -> None: - await asyncio.sleep(0.1) - info = await key_value_store_client.get() - assert info is not None - assert info.id == key_value_store_client.id - assert info.accessed_at != info.created_at - - -async def test_update(key_value_store_client: KeyValueStoreClient) -> None: - new_kvs_name = 'test-update' - await key_value_store_client.set_record('test', {'abc': 123}) - old_kvs_info = await key_value_store_client.get() - assert old_kvs_info is not None - old_kvs_directory = Path( - key_value_store_client._memory_storage_client.key_value_stores_directory, old_kvs_info.name or '' - ) - new_kvs_directory = Path(key_value_store_client._memory_storage_client.key_value_stores_directory, new_kvs_name) - assert (old_kvs_directory / 'test.json').exists() is True - assert (new_kvs_directory / 'test.json').exists() is False - - await asyncio.sleep(0.1) - updated_kvs_info = await key_value_store_client.update(name=new_kvs_name) - assert (old_kvs_directory / 'test.json').exists() is False - assert (new_kvs_directory / 'test.json').exists() is True - # Only modified_at and accessed_at should be different - assert old_kvs_info.created_at == updated_kvs_info.created_at - assert old_kvs_info.modified_at != updated_kvs_info.modified_at - assert old_kvs_info.accessed_at != updated_kvs_info.accessed_at - - # Should fail with the same name - with pytest.raises(ValueError, match='Key-value store with name "test-update" already exists.'): - await key_value_store_client.update(name=new_kvs_name) - - -async def test_delete(key_value_store_client: KeyValueStoreClient) -> None: - await key_value_store_client.set_record('test', {'abc': 123}) - kvs_info = await key_value_store_client.get() - assert kvs_info is not None - kvs_directory = Path(key_value_store_client._memory_storage_client.key_value_stores_directory, kvs_info.name or '') - assert (kvs_directory / 'test.json').exists() is True - await key_value_store_client.delete() - assert (kvs_directory / 'test.json').exists() is False - # Does not crash when called again - await key_value_store_client.delete() - - -async def test_list_keys_empty(key_value_store_client: KeyValueStoreClient) -> None: - keys = await key_value_store_client.list_keys() - assert len(keys.items) == 0 - assert keys.count == 0 - assert keys.is_truncated is False - - -async def test_list_keys(key_value_store_client: KeyValueStoreClient) -> None: - record_count = 4 - used_limit = 2 - used_exclusive_start_key = 'a' - await key_value_store_client.set_record('b', 'test') - await key_value_store_client.set_record('a', 'test') - await key_value_store_client.set_record('d', 'test') - await key_value_store_client.set_record('c', 'test') - - # Default settings - keys = await key_value_store_client.list_keys() - assert keys.items[0].key == 'a' - assert keys.items[3].key == 'd' - assert keys.count == record_count - assert keys.is_truncated is False - # Test limit - keys_limit_2 = await key_value_store_client.list_keys(limit=used_limit) - assert keys_limit_2.count == record_count - assert 
keys_limit_2.limit == used_limit - assert keys_limit_2.items[1].key == 'b' - # Test exclusive start key - keys_exclusive_start = await key_value_store_client.list_keys(exclusive_start_key=used_exclusive_start_key, limit=2) - assert keys_exclusive_start.exclusive_start_key == used_exclusive_start_key - assert keys_exclusive_start.is_truncated is True - assert keys_exclusive_start.next_exclusive_start_key == 'c' - assert keys_exclusive_start.items[0].key == 'b' - assert keys_exclusive_start.items[-1].key == keys_exclusive_start.next_exclusive_start_key - - -async def test_get_and_set_record(tmp_path: Path, key_value_store_client: KeyValueStoreClient) -> None: - # Test setting dict record - dict_record_key = 'test-dict' - await key_value_store_client.set_record(dict_record_key, {'test': 123}) - dict_record_info = await key_value_store_client.get_record(dict_record_key) - assert dict_record_info is not None - assert 'application/json' in str(dict_record_info.content_type) - assert dict_record_info.value['test'] == 123 - - # Test setting str record - str_record_key = 'test-str' - await key_value_store_client.set_record(str_record_key, 'test') - str_record_info = await key_value_store_client.get_record(str_record_key) - assert str_record_info is not None - assert 'text/plain' in str(str_record_info.content_type) - assert str_record_info.value == 'test' - - # Test setting explicit json record but use str as value, i.e. json dumps is skipped - explicit_json_key = 'test-json' - await key_value_store_client.set_record(explicit_json_key, '{"test": "explicit string"}', 'application/json') - bytes_record_info = await key_value_store_client.get_record(explicit_json_key) - assert bytes_record_info is not None - assert 'application/json' in str(bytes_record_info.content_type) - assert bytes_record_info.value['test'] == 'explicit string' - - # Test using bytes - bytes_key = 'test-json' - bytes_value = b'testing bytes set_record' - await key_value_store_client.set_record(bytes_key, bytes_value, 'unknown') - bytes_record_info = await key_value_store_client.get_record(bytes_key) - assert bytes_record_info is not None - assert 'unknown' in str(bytes_record_info.content_type) - assert bytes_record_info.value == bytes_value - assert bytes_record_info.value.decode('utf-8') == bytes_value.decode('utf-8') - - # Test using file descriptor - with open(tmp_path / 'test.json', 'w+', encoding='utf-8') as f: # noqa: ASYNC230 - f.write('Test') - with pytest.raises(NotImplementedError, match='File-like values are not supported in local memory storage'): - await key_value_store_client.set_record('file', f) - - -async def test_get_record_as_bytes(key_value_store_client: KeyValueStoreClient) -> None: - record_key = 'test' - record_value = 'testing' - await key_value_store_client.set_record(record_key, record_value) - record_info = await key_value_store_client.get_record_as_bytes(record_key) - assert record_info is not None - assert record_info.value == record_value.encode('utf-8') - - -async def test_delete_record(key_value_store_client: KeyValueStoreClient) -> None: - record_key = 'test' - await key_value_store_client.set_record(record_key, 'test') - await key_value_store_client.delete_record(record_key) - # Does not crash when called again - await key_value_store_client.delete_record(record_key) - - -@pytest.mark.parametrize( - ('input_data', 'expected_output'), - [ - ( - {'key': 'image', 'value': TINY_PNG, 'contentType': None}, - {'filename': 'image', 'key': 'image', 'contentType': 'application/octet-stream'}, - ), - ( - 
{'key': 'image', 'value': TINY_PNG, 'contentType': 'image/png'}, - {'filename': 'image.png', 'key': 'image', 'contentType': 'image/png'}, - ), - ( - {'key': 'image.png', 'value': TINY_PNG, 'contentType': None}, - {'filename': 'image.png', 'key': 'image.png', 'contentType': 'application/octet-stream'}, - ), - ( - {'key': 'image.png', 'value': TINY_PNG, 'contentType': 'image/png'}, - {'filename': 'image.png', 'key': 'image.png', 'contentType': 'image/png'}, - ), - ( - {'key': 'data', 'value': TINY_DATA, 'contentType': None}, - {'filename': 'data.json', 'key': 'data', 'contentType': 'application/json'}, - ), - ( - {'key': 'data', 'value': TINY_DATA, 'contentType': 'application/json'}, - {'filename': 'data.json', 'key': 'data', 'contentType': 'application/json'}, - ), - ( - {'key': 'data.json', 'value': TINY_DATA, 'contentType': None}, - {'filename': 'data.json', 'key': 'data.json', 'contentType': 'application/json'}, - ), - ( - {'key': 'data.json', 'value': TINY_DATA, 'contentType': 'application/json'}, - {'filename': 'data.json', 'key': 'data.json', 'contentType': 'application/json'}, - ), - ( - {'key': 'text', 'value': TINY_TEXT, 'contentType': None}, - {'filename': 'text.txt', 'key': 'text', 'contentType': 'text/plain'}, - ), - ( - {'key': 'text', 'value': TINY_TEXT, 'contentType': 'text/plain'}, - {'filename': 'text.txt', 'key': 'text', 'contentType': 'text/plain'}, - ), - ( - {'key': 'text.txt', 'value': TINY_TEXT, 'contentType': None}, - {'filename': 'text.txt', 'key': 'text.txt', 'contentType': 'text/plain'}, - ), - ( - {'key': 'text.txt', 'value': TINY_TEXT, 'contentType': 'text/plain'}, - {'filename': 'text.txt', 'key': 'text.txt', 'contentType': 'text/plain'}, - ), - ], -) -async def test_writes_correct_metadata( - memory_storage_client: MemoryStorageClient, - input_data: dict, - expected_output: dict, -) -> None: - key_value_store_name = crypto_random_object_id() - - # Get KVS client - kvs_info = await memory_storage_client.key_value_stores().get_or_create(name=key_value_store_name) - kvs_client = memory_storage_client.key_value_store(kvs_info.id) - - # Write the test input item to the store - await kvs_client.set_record( - key=input_data['key'], - value=input_data['value'], - content_type=input_data['contentType'], - ) - - # Check that everything was written correctly, both the data and metadata - storage_path = Path(memory_storage_client.key_value_stores_directory, key_value_store_name) - item_path = Path(storage_path, expected_output['filename']) - item_metadata_path = storage_path / f'{expected_output["filename"]}.__metadata__.json' - - assert item_path.exists() - assert item_metadata_path.exists() - - # Test the actual value of the item - with open(item_path, 'rb') as item_file: # noqa: ASYNC230 - actual_value = maybe_parse_body(item_file.read(), expected_output['contentType']) - assert actual_value == input_data['value'] - - # Test the actual metadata of the item - with open(item_metadata_path, encoding='utf-8') as metadata_file: # noqa: ASYNC230 - json_content = json.load(metadata_file) - metadata = KeyValueStoreRecordMetadata(**json_content) - assert metadata.key == expected_output['key'] - assert expected_output['contentType'] in metadata.content_type - - -@pytest.mark.parametrize( - ('input_data', 'expected_output'), - [ - ( - {'filename': 'image', 'value': TINY_PNG, 'metadata': None}, - {'key': 'image', 'filename': 'image', 'contentType': 'application/octet-stream'}, - ), - ( - {'filename': 'image.png', 'value': TINY_PNG, 'metadata': None}, - {'key': 'image', 'filename': 
'image.png', 'contentType': 'image/png'}, - ), - ( - { - 'filename': 'image', - 'value': TINY_PNG, - 'metadata': {'key': 'image', 'contentType': 'application/octet-stream'}, - }, - {'key': 'image', 'contentType': 'application/octet-stream'}, - ), - ( - {'filename': 'image', 'value': TINY_PNG, 'metadata': {'key': 'image', 'contentType': 'image/png'}}, - {'key': 'image', 'filename': 'image', 'contentType': 'image/png'}, - ), - ( - { - 'filename': 'image.png', - 'value': TINY_PNG, - 'metadata': {'key': 'image.png', 'contentType': 'application/octet-stream'}, - }, - {'key': 'image.png', 'contentType': 'application/octet-stream'}, - ), - ( - {'filename': 'image.png', 'value': TINY_PNG, 'metadata': {'key': 'image.png', 'contentType': 'image/png'}}, - {'key': 'image.png', 'contentType': 'image/png'}, - ), - ( - {'filename': 'image.png', 'value': TINY_PNG, 'metadata': {'key': 'image', 'contentType': 'image/png'}}, - {'key': 'image', 'contentType': 'image/png'}, - ), - ( - {'filename': 'input', 'value': TINY_BYTES, 'metadata': None}, - {'key': 'input', 'contentType': 'application/octet-stream'}, - ), - ( - {'filename': 'input.json', 'value': TINY_DATA, 'metadata': None}, - {'key': 'input', 'contentType': 'application/json'}, - ), - ( - {'filename': 'input.txt', 'value': TINY_TEXT, 'metadata': None}, - {'key': 'input', 'contentType': 'text/plain'}, - ), - ( - {'filename': 'input.bin', 'value': TINY_BYTES, 'metadata': None}, - {'key': 'input', 'contentType': 'application/octet-stream'}, - ), - ( - { - 'filename': 'input', - 'value': TINY_BYTES, - 'metadata': {'key': 'input', 'contentType': 'application/octet-stream'}, - }, - {'key': 'input', 'contentType': 'application/octet-stream'}, - ), - ( - { - 'filename': 'input.json', - 'value': TINY_DATA, - 'metadata': {'key': 'input', 'contentType': 'application/json'}, - }, - {'key': 'input', 'contentType': 'application/json'}, - ), - ( - {'filename': 'input.txt', 'value': TINY_TEXT, 'metadata': {'key': 'input', 'contentType': 'text/plain'}}, - {'key': 'input', 'contentType': 'text/plain'}, - ), - ( - { - 'filename': 'input.bin', - 'value': TINY_BYTES, - 'metadata': {'key': 'input', 'contentType': 'application/octet-stream'}, - }, - {'key': 'input', 'contentType': 'application/octet-stream'}, - ), - ], -) -async def test_reads_correct_metadata( - memory_storage_client: MemoryStorageClient, - input_data: dict, - expected_output: dict, -) -> None: - key_value_store_name = crypto_random_object_id() - - # Ensure the directory for the store exists - storage_path = Path(memory_storage_client.key_value_stores_directory, key_value_store_name) - storage_path.mkdir(exist_ok=True, parents=True) - - store_metadata = KeyValueStoreMetadata( - id=crypto_random_object_id(), - name='', - accessed_at=datetime.now(timezone.utc), - created_at=datetime.now(timezone.utc), - modified_at=datetime.now(timezone.utc), - user_id='1', - ) - - # Write the store metadata to disk - storage_metadata_path = storage_path / METADATA_FILENAME - with open(storage_metadata_path, mode='wb') as f: # noqa: ASYNC230 - f.write(store_metadata.model_dump_json().encode('utf-8')) - - # Write the test input item to the disk - item_path = storage_path / input_data['filename'] - with open(item_path, 'wb') as item_file: # noqa: ASYNC230 - if isinstance(input_data['value'], bytes): - item_file.write(input_data['value']) - elif isinstance(input_data['value'], str): - item_file.write(input_data['value'].encode('utf-8')) - else: - s = await json_dumps(input_data['value']) - item_file.write(s.encode('utf-8')) - - 
# Optionally write the metadata to disk if there is some - if input_data['metadata'] is not None: - storage_metadata_path = storage_path / f'{input_data["filename"]}.__metadata__.json' - with open(storage_metadata_path, 'w', encoding='utf-8') as metadata_file: # noqa: ASYNC230 - s = await json_dumps( - { - 'key': input_data['metadata']['key'], - 'contentType': input_data['metadata']['contentType'], - } - ) - metadata_file.write(s) - - # Create the key-value store client to load the items from disk - store_details = await memory_storage_client.key_value_stores().get_or_create(name=key_value_store_name) - key_value_store_client = memory_storage_client.key_value_store(store_details.id) - - # Read the item from the store and check if it is as expected - actual_record = await key_value_store_client.get_record(expected_output['key']) - assert actual_record is not None - - assert actual_record.key == expected_output['key'] - assert actual_record.content_type == expected_output['contentType'] - assert actual_record.value == input_data['value'] diff --git a/tests/unit/storage_clients/_memory/test_key_value_store_collection_client.py b/tests/unit/storage_clients/_memory/test_key_value_store_collection_client.py deleted file mode 100644 index 41b289eb06..0000000000 --- a/tests/unit/storage_clients/_memory/test_key_value_store_collection_client.py +++ /dev/null @@ -1,42 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pytest - -if TYPE_CHECKING: - from crawlee.storage_clients import MemoryStorageClient - from crawlee.storage_clients._memory import KeyValueStoreCollectionClient - - -@pytest.fixture -def key_value_stores_client(memory_storage_client: MemoryStorageClient) -> KeyValueStoreCollectionClient: - return memory_storage_client.key_value_stores() - - -async def test_get_or_create(key_value_stores_client: KeyValueStoreCollectionClient) -> None: - kvs_name = 'test' - # A new kvs gets created - kvs_info = await key_value_stores_client.get_or_create(name=kvs_name) - assert kvs_info.name == kvs_name - - # Another get_or_create call returns the same kvs - kvs_info_existing = await key_value_stores_client.get_or_create(name=kvs_name) - assert kvs_info.id == kvs_info_existing.id - assert kvs_info.name == kvs_info_existing.name - assert kvs_info.created_at == kvs_info_existing.created_at - - -async def test_list(key_value_stores_client: KeyValueStoreCollectionClient) -> None: - assert (await key_value_stores_client.list()).count == 0 - kvs_info = await key_value_stores_client.get_or_create(name='kvs') - kvs_list = await key_value_stores_client.list() - assert kvs_list.count == 1 - assert kvs_list.items[0].name == kvs_info.name - - # Test sorting behavior - newer_kvs_info = await key_value_stores_client.get_or_create(name='newer-kvs') - kvs_list_sorting = await key_value_stores_client.list() - assert kvs_list_sorting.count == 2 - assert kvs_list_sorting.items[0].name == kvs_info.name - assert kvs_list_sorting.items[1].name == newer_kvs_info.name diff --git a/tests/unit/storage_clients/_memory/test_memory_dataset_client.py b/tests/unit/storage_clients/_memory/test_memory_dataset_client.py new file mode 100644 index 0000000000..8cc846b0f4 --- /dev/null +++ b/tests/unit/storage_clients/_memory/test_memory_dataset_client.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING + +import pytest + +from crawlee.configuration import Configuration +from crawlee.storage_clients import MemoryStorageClient + +if TYPE_CHECKING: 
+    from collections.abc import AsyncGenerator
+
+    from crawlee.storage_clients._memory import MemoryDatasetClient
+
+
+@pytest.fixture
+async def dataset_client() -> AsyncGenerator[MemoryDatasetClient, None]:
+    """Fixture that provides a fresh memory dataset client for each test."""
+    client = await MemoryStorageClient().create_dataset_client(name='test_dataset')
+    yield client
+    await client.drop()
+
+
+async def test_memory_specific_purge_behavior() -> None:
+    """Test memory-specific purge behavior and in-memory storage characteristics."""
+    configuration = Configuration(purge_on_start=True)
+
+    # Create dataset and add data
+    dataset_client1 = await MemoryStorageClient().create_dataset_client(
+        name='test_purge_dataset',
+        configuration=configuration,
+    )
+    await dataset_client1.push_data({'item': 'initial data'})
+
+    # Verify data was added
+    items = await dataset_client1.get_data()
+    assert len(items.items) == 1
+
+    # Reopen the same dataset with a new storage client instance
+    dataset_client2 = await MemoryStorageClient().create_dataset_client(
+        name='test_purge_dataset',
+        configuration=configuration,
+    )
+
+    # Verify data was purged (memory storage specific behavior)
+    items = await dataset_client2.get_data()
+    assert len(items.items) == 0
+
+
+async def test_memory_metadata_updates(dataset_client: MemoryDatasetClient) -> None:
+    """Test that metadata timestamps are updated correctly in memory storage."""
+    # Record initial timestamps
+    metadata = await dataset_client.get_metadata()
+    initial_created = metadata.created_at
+    initial_accessed = metadata.accessed_at
+    initial_modified = metadata.modified_at
+
+    # Wait a moment to ensure timestamps can change
+    await asyncio.sleep(0.01)
+
+    # Perform a read operation
+    await dataset_client.get_data()
+
+    # Verify timestamps (memory-specific behavior)
+    metadata = await dataset_client.get_metadata()
+    assert metadata.created_at == initial_created
+    assert metadata.accessed_at > initial_accessed
+    assert metadata.modified_at == initial_modified
+
+    accessed_after_read = metadata.accessed_at
+
+    # Wait a moment to ensure timestamps can change
+    await asyncio.sleep(0.01)
+
+    # Perform a write operation
+    await dataset_client.push_data({'new': 'item'})
+
+    # Verify timestamps were updated
+    metadata = await dataset_client.get_metadata()
+    assert metadata.created_at == initial_created
+    assert metadata.modified_at > initial_modified
+    assert metadata.accessed_at > accessed_after_read
diff --git a/tests/unit/storage_clients/_memory/test_memory_kvs_client.py b/tests/unit/storage_clients/_memory/test_memory_kvs_client.py
new file mode 100644
index 0000000000..463fb2a14c
--- /dev/null
+++ b/tests/unit/storage_clients/_memory/test_memory_kvs_client.py
@@ -0,0 +1,84 @@
+from __future__ import annotations
+
+import asyncio
+from typing import TYPE_CHECKING
+
+import pytest
+
+from crawlee.configuration import Configuration
+from crawlee.storage_clients import MemoryStorageClient
+
+if TYPE_CHECKING:
+    from collections.abc import AsyncGenerator
+
+    from crawlee.storage_clients._memory import MemoryKeyValueStoreClient
+
+
+@pytest.fixture
+async def kvs_client() -> AsyncGenerator[MemoryKeyValueStoreClient, None]:
+    """Fixture that provides a fresh memory key-value store client for each test."""
+    client = await MemoryStorageClient().create_kvs_client(name='test_kvs')
+    yield client
+    await client.drop()
+
+
+async def test_memory_specific_purge_behavior() -> None:
+    """Test memory-specific purge behavior and in-memory storage characteristics."""
+    configuration = Configuration(purge_on_start=True)
+
+    # Create KVS and add data
+    kvs_client1 = await MemoryStorageClient().create_kvs_client(
+        name='test_purge_kvs',
+        configuration=configuration,
+    )
+    await kvs_client1.set_value(key='test-key', value='initial value')
+
+    # Verify value was set
+    record = await kvs_client1.get_value(key='test-key')
+    assert record is not None
+    assert record.value == 'initial value'
+
+    # Reopen the same key-value store with a new storage client instance
+    kvs_client2 = await MemoryStorageClient().create_kvs_client(
+        name='test_purge_kvs',
+        configuration=configuration,
+    )
+
+    # Verify value was purged (memory storage specific behavior)
+    record = await kvs_client2.get_value(key='test-key')
+    assert record is None
+
+
+async def test_memory_metadata_updates(kvs_client: MemoryKeyValueStoreClient) -> None:
+    """Test that metadata timestamps are updated correctly in memory storage."""
+    # Record initial timestamps
+    metadata = await kvs_client.get_metadata()
+    initial_created = metadata.created_at
+    initial_accessed = metadata.accessed_at
+    initial_modified = metadata.modified_at
+
+    # Wait a moment to ensure timestamps can change
+    await asyncio.sleep(0.01)
+
+    # Perform a read operation
+    await kvs_client.get_value(key='nonexistent')
+
+    # Verify timestamps (memory-specific behavior)
+    metadata = await kvs_client.get_metadata()
+    assert metadata.created_at == initial_created
+    assert metadata.accessed_at > initial_accessed
+    assert metadata.modified_at == initial_modified
+
+    accessed_after_read = metadata.accessed_at
+
+    # Wait a moment to ensure timestamps can change
+    await asyncio.sleep(0.01)
+
+    # Perform a write operation
+    await kvs_client.set_value(key='test', value='test-value')
+
+    # Verify timestamps were updated
+    metadata = await kvs_client.get_metadata()
+    assert metadata.created_at == initial_created
+    assert metadata.modified_at > initial_modified
+    assert metadata.accessed_at > accessed_after_read
diff --git a/tests/unit/storage_clients/_memory/test_memory_rq_client.py b/tests/unit/storage_clients/_memory/test_memory_rq_client.py
new file mode 100644
index 0000000000..7877d8af79
--- /dev/null
+++ b/tests/unit/storage_clients/_memory/test_memory_rq_client.py
@@ -0,0 +1,83 @@
+from __future__ import annotations
+
+import asyncio
+from typing import TYPE_CHECKING
+
+import pytest
+
+from crawlee import Request
+from crawlee.configuration import Configuration
+from crawlee.storage_clients import MemoryStorageClient
+
+if TYPE_CHECKING:
+    from collections.abc import AsyncGenerator
+
+    from crawlee.storage_clients._memory import MemoryRequestQueueClient
+
+
+@pytest.fixture
+async def rq_client() -> AsyncGenerator[MemoryRequestQueueClient, None]:
+    """Fixture that provides a fresh memory request queue client for each test."""
+    client = await MemoryStorageClient().create_rq_client(name='test_rq')
+    yield client
+    await client.drop()
+
+
+async def test_memory_specific_purge_behavior() -> None:
+    """Test memory-specific purge behavior and in-memory storage characteristics."""
+    configuration = Configuration(purge_on_start=True)
+
+    # Create RQ and add data
+    rq_client1 = await MemoryStorageClient().create_rq_client(
+        name='test_purge_rq',
+        configuration=configuration,
+    )
+    request = Request.from_url(url='https://example.com/initial')
+    await rq_client1.add_batch_of_requests([request])
+
+    # Verify request was added
+    assert await rq_client1.is_empty() is False
+
+    # Reopen the same request queue with a new storage client instance
+    rq_client2 = await MemoryStorageClient().create_rq_client(
+        name='test_purge_rq',
+        configuration=configuration,
+    )
+
+    # Verify queue was purged (memory storage specific behavior)
+    assert await rq_client2.is_empty() is True
+
+
+async def test_memory_metadata_updates(rq_client: MemoryRequestQueueClient) -> None:
+    """Test that metadata timestamps are updated correctly in memory storage."""
+    # Record initial timestamps
+    metadata = await rq_client.get_metadata()
+    initial_created = metadata.created_at
+    initial_accessed = metadata.accessed_at
+    initial_modified = metadata.modified_at
+
+    # Wait a moment to ensure timestamps can change
+    await asyncio.sleep(0.01)
+
+    # Perform a read operation
+    await rq_client.is_empty()
+
+    # Verify timestamps (memory-specific behavior)
+    metadata = await rq_client.get_metadata()
+    assert metadata.created_at == initial_created
+    assert metadata.accessed_at > initial_accessed
+    assert metadata.modified_at == initial_modified
+
+    accessed_after_read = metadata.accessed_at
+
+    # Wait a moment to ensure timestamps can change
+    await asyncio.sleep(0.01)
+
+    # Perform a write operation
+    await rq_client.add_batch_of_requests([Request.from_url('https://example.com')])
+
+    # Verify timestamps were updated
+    metadata = await rq_client.get_metadata()
+    assert metadata.created_at == initial_created
+    assert metadata.modified_at > initial_modified
+    assert metadata.accessed_at > accessed_after_read
diff --git a/tests/unit/storage_clients/_memory/test_memory_storage_client.py b/tests/unit/storage_clients/_memory/test_memory_storage_client.py
deleted file mode 100644
index 0d043322ae..0000000000
--- a/tests/unit/storage_clients/_memory/test_memory_storage_client.py
+++ /dev/null
@@ -1,288 +0,0 @@
-# TODO: Update crawlee_storage_dir args once the Pydantic bug is fixed
-# https://github.com/apify/crawlee-python/issues/146
-
-from __future__ import annotations
-
-from pathlib import Path
-
-import pytest
-
-from crawlee import Request, service_locator
-from crawlee._consts import METADATA_FILENAME
-from crawlee.configuration import Configuration
-from crawlee.storage_clients import MemoryStorageClient
-from crawlee.storage_clients.models import BatchRequestsOperationResponse
-
-
-async def test_write_metadata(tmp_path: Path) -> None:
-    dataset_name = 'test'
-    dataset_no_metadata_name = 'test-no-metadata'
-    ms = MemoryStorageClient.from_config(
-        Configuration(
-            crawlee_storage_dir=str(tmp_path),  # type: ignore[call-arg]
-            write_metadata=True,
-        ),
-    )
-    ms_no_metadata = MemoryStorageClient.from_config(
-        Configuration(
-            crawlee_storage_dir=str(tmp_path),  # type: ignore[call-arg]
-            write_metadata=False,
-        )
-    )
-    datasets_client = ms.datasets()
-    datasets_no_metadata_client = ms_no_metadata.datasets()
-    await datasets_client.get_or_create(name=dataset_name)
-    await datasets_no_metadata_client.get_or_create(name=dataset_no_metadata_name)
-    assert Path(ms.datasets_directory, dataset_name, METADATA_FILENAME).exists() is True
-    assert Path(ms_no_metadata.datasets_directory, dataset_no_metadata_name, METADATA_FILENAME).exists() is False
-
-
-@pytest.mark.parametrize(
-    'persist_storage',
-    [
-        True,
-        False,
-    ],
-)
-async def test_persist_storage(persist_storage: bool, tmp_path: Path) -> None:  # noqa: FBT001
-    ms = MemoryStorageClient.from_config(
-        Configuration(
-            crawlee_storage_dir=str(tmp_path),  # type: ignore[call-arg]
-            persist_storage=persist_storage,
-        )
-    )
-
-    # Key value stores
-    kvs_client = ms.key_value_stores()
-    kvs_info = await kvs_client.get_or_create(name='kvs')
-    await 
ms.key_value_store(kvs_info.id).set_record('test', {'x': 1}, 'application/json') - - path = Path(ms.key_value_stores_directory) / (kvs_info.name or '') / 'test.json' - assert path.exists() is persist_storage - - # Request queues - rq_client = ms.request_queues() - rq_info = await rq_client.get_or_create(name='rq') - - request = Request.from_url('http://lorem.com') - await ms.request_queue(rq_info.id).add_request(request) - - path = Path(ms.request_queues_directory) / (rq_info.name or '') / f'{request.id}.json' - assert path.exists() is persist_storage - - # Datasets - ds_client = ms.datasets() - ds_info = await ds_client.get_or_create(name='ds') - - await ms.dataset(ds_info.id).push_items([{'foo': 'bar'}]) - - -def test_persist_storage_set_to_false_via_string_env_var(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: - monkeypatch.setenv('CRAWLEE_PERSIST_STORAGE', 'false') - ms = MemoryStorageClient.from_config( - Configuration(crawlee_storage_dir=str(tmp_path)), # type: ignore[call-arg] - ) - assert ms.persist_storage is False - - -def test_persist_storage_set_to_false_via_numeric_env_var(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: - monkeypatch.setenv('CRAWLEE_PERSIST_STORAGE', '0') - ms = MemoryStorageClient.from_config(Configuration(crawlee_storage_dir=str(tmp_path))) # type: ignore[call-arg] - assert ms.persist_storage is False - - -def test_persist_storage_true_via_constructor_arg(tmp_path: Path) -> None: - ms = MemoryStorageClient.from_config( - Configuration( - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - persist_storage=True, - ) - ) - assert ms.persist_storage is True - - -def test_default_write_metadata_behavior(tmp_path: Path) -> None: - # Default behavior - ms = MemoryStorageClient.from_config( - Configuration(crawlee_storage_dir=str(tmp_path)), # type: ignore[call-arg] - ) - assert ms.write_metadata is True - - -def test_write_metadata_set_to_false_via_env_var(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: - # Test if env var changes write_metadata to False - monkeypatch.setenv('CRAWLEE_WRITE_METADATA', 'false') - ms = MemoryStorageClient.from_config( - Configuration(crawlee_storage_dir=str(tmp_path)), # type: ignore[call-arg] - ) - assert ms.write_metadata is False - - -def test_write_metadata_false_via_constructor_arg_overrides_env_var(tmp_path: Path) -> None: - # Test if constructor arg takes precedence over env var value - ms = MemoryStorageClient.from_config( - Configuration( - write_metadata=False, - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - ) - ) - assert ms.write_metadata is False - - -async def test_purge_datasets(tmp_path: Path) -> None: - ms = MemoryStorageClient.from_config( - Configuration( - write_metadata=True, - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - ) - ) - # Create default and non-default datasets - datasets_client = ms.datasets() - default_dataset_info = await datasets_client.get_or_create(name='default') - non_default_dataset_info = await datasets_client.get_or_create(name='non-default') - - # Check all folders inside datasets directory before and after purge - assert default_dataset_info.name is not None - assert non_default_dataset_info.name is not None - - default_path = Path(ms.datasets_directory, default_dataset_info.name) - non_default_path = Path(ms.datasets_directory, non_default_dataset_info.name) - - assert default_path.exists() is True - assert non_default_path.exists() is True - - await ms._purge_default_storages() - - assert default_path.exists() is 
False - assert non_default_path.exists() is True - - -async def test_purge_key_value_stores(tmp_path: Path) -> None: - ms = MemoryStorageClient.from_config( - Configuration( - write_metadata=True, - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - ) - ) - - # Create default and non-default key-value stores - kvs_client = ms.key_value_stores() - default_kvs_info = await kvs_client.get_or_create(name='default') - non_default_kvs_info = await kvs_client.get_or_create(name='non-default') - default_kvs_client = ms.key_value_store(default_kvs_info.id) - # INPUT.json should be kept - await default_kvs_client.set_record('INPUT', {'abc': 123}, 'application/json') - # test.json should not be kept - await default_kvs_client.set_record('test', {'abc': 123}, 'application/json') - - # Check all folders and files inside kvs directory before and after purge - assert default_kvs_info.name is not None - assert non_default_kvs_info.name is not None - - default_kvs_path = Path(ms.key_value_stores_directory, default_kvs_info.name) - non_default_kvs_path = Path(ms.key_value_stores_directory, non_default_kvs_info.name) - kvs_directory = Path(ms.key_value_stores_directory, 'default') - - assert default_kvs_path.exists() is True - assert non_default_kvs_path.exists() is True - - assert (kvs_directory / 'INPUT.json').exists() is True - assert (kvs_directory / 'test.json').exists() is True - - await ms._purge_default_storages() - - assert default_kvs_path.exists() is True - assert non_default_kvs_path.exists() is True - - assert (kvs_directory / 'INPUT.json').exists() is True - assert (kvs_directory / 'test.json').exists() is False - - -async def test_purge_request_queues(tmp_path: Path) -> None: - ms = MemoryStorageClient.from_config( - Configuration( - write_metadata=True, - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - ) - ) - # Create default and non-default request queues - rq_client = ms.request_queues() - default_rq_info = await rq_client.get_or_create(name='default') - non_default_rq_info = await rq_client.get_or_create(name='non-default') - - # Check all folders inside rq directory before and after purge - assert default_rq_info.name - assert non_default_rq_info.name - - default_rq_path = Path(ms.request_queues_directory, default_rq_info.name) - non_default_rq_path = Path(ms.request_queues_directory, non_default_rq_info.name) - - assert default_rq_path.exists() is True - assert non_default_rq_path.exists() is True - - await ms._purge_default_storages() - - assert default_rq_path.exists() is False - assert non_default_rq_path.exists() is True - - -async def test_not_implemented_method(tmp_path: Path) -> None: - ms = MemoryStorageClient.from_config( - Configuration( - write_metadata=True, - crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] - ) - ) - ddt = ms.dataset('test') - with pytest.raises(NotImplementedError, match='This method is not supported in memory storage.'): - await ddt.stream_items(item_format='json') - - with pytest.raises(NotImplementedError, match='This method is not supported in memory storage.'): - await ddt.stream_items(item_format='json') - - -async def test_default_storage_path_used(monkeypatch: pytest.MonkeyPatch) -> None: - # Reset the configuration in service locator - service_locator._configuration = None - service_locator._configuration_was_retrieved = False - - # Remove the env var for setting the storage directory - monkeypatch.delenv('CRAWLEE_STORAGE_DIR', raising=False) - - # Initialize the service locator with default configuration - msc 
= MemoryStorageClient.from_config() - assert msc.storage_dir == './storage' - - -async def test_storage_path_from_env_var_overrides_default(monkeypatch: pytest.MonkeyPatch) -> None: - # We expect the env var to override the default value - monkeypatch.setenv('CRAWLEE_STORAGE_DIR', './env_var_storage_dir') - service_locator.set_configuration(Configuration()) - ms = MemoryStorageClient.from_config() - assert ms.storage_dir == './env_var_storage_dir' - - -async def test_parametrized_storage_path_overrides_env_var() -> None: - # We expect the parametrized value to be used - ms = MemoryStorageClient.from_config( - Configuration(crawlee_storage_dir='./parametrized_storage_dir'), # type: ignore[call-arg] - ) - assert ms.storage_dir == './parametrized_storage_dir' - - -async def test_batch_requests_operation_response() -> None: - """Test that `BatchRequestsOperationResponse` creation from example responses.""" - process_request = { - 'requestId': 'EAaArVRs5qV39C9', - 'uniqueKey': 'https://example.com', - 'wasAlreadyHandled': False, - 'wasAlreadyPresent': True, - } - unprocess_request_full = {'uniqueKey': 'https://example2.com', 'method': 'GET', 'url': 'https://example2.com'} - unprocess_request_minimal = {'uniqueKey': 'https://example3.com', 'url': 'https://example3.com'} - BatchRequestsOperationResponse.model_validate( - { - 'processedRequests': [process_request], - 'unprocessedRequests': [unprocess_request_full, unprocess_request_minimal], - } - ) diff --git a/tests/unit/storage_clients/_memory/test_memory_storage_e2e.py b/tests/unit/storage_clients/_memory/test_memory_storage_e2e.py deleted file mode 100644 index c79fa66792..0000000000 --- a/tests/unit/storage_clients/_memory/test_memory_storage_e2e.py +++ /dev/null @@ -1,130 +0,0 @@ -from __future__ import annotations - -from datetime import datetime, timezone -from typing import Callable - -import pytest - -from crawlee import Request, service_locator -from crawlee.storages._key_value_store import KeyValueStore -from crawlee.storages._request_queue import RequestQueue - - -@pytest.mark.parametrize('purge_on_start', [True, False]) -async def test_actor_memory_storage_client_key_value_store_e2e( - monkeypatch: pytest.MonkeyPatch, - purge_on_start: bool, # noqa: FBT001 - prepare_test_env: Callable[[], None], -) -> None: - """This test simulates two clean runs using memory storage. - The second run attempts to access data created by the first one. - We run 2 configurations with different `purge_on_start`.""" - # Configure purging env var - monkeypatch.setenv('CRAWLEE_PURGE_ON_START', f'{int(purge_on_start)}') - # Store old storage client so we have the object reference for comparison - old_client = service_locator.get_storage_client() - - old_default_kvs = await KeyValueStore.open() - old_non_default_kvs = await KeyValueStore.open(name='non-default') - # Create data in default and non-default key-value store - await old_default_kvs.set_value('test', 'default value') - await old_non_default_kvs.set_value('test', 'non-default value') - - # We simulate another clean run, we expect the memory storage to read from the local data directory - # Default storages are purged based on purge_on_start parameter. 
- prepare_test_env() - - # Check if we're using a different memory storage instance - assert old_client is not service_locator.get_storage_client() - default_kvs = await KeyValueStore.open() - assert default_kvs is not old_default_kvs - non_default_kvs = await KeyValueStore.open(name='non-default') - assert non_default_kvs is not old_non_default_kvs - default_value = await default_kvs.get_value('test') - - if purge_on_start: - assert default_value is None - else: - assert default_value == 'default value' - - assert await non_default_kvs.get_value('test') == 'non-default value' - - -@pytest.mark.parametrize('purge_on_start', [True, False]) -async def test_actor_memory_storage_client_request_queue_e2e( - monkeypatch: pytest.MonkeyPatch, - purge_on_start: bool, # noqa: FBT001 - prepare_test_env: Callable[[], None], -) -> None: - """This test simulates two clean runs using memory storage. - The second run attempts to access data created by the first one. - We run 2 configurations with different `purge_on_start`.""" - # Configure purging env var - monkeypatch.setenv('CRAWLEE_PURGE_ON_START', f'{int(purge_on_start)}') - - # Add some requests to the default queue - default_queue = await RequestQueue.open() - for i in range(6): - # [0, 3] <- nothing special - # [1, 4] <- forefront=True - # [2, 5] <- handled=True - request_url = f'http://example.com/{i}' - forefront = i % 3 == 1 - was_handled = i % 3 == 2 - await default_queue.add_request( - Request.from_url( - unique_key=str(i), - url=request_url, - handled_at=datetime.now(timezone.utc) if was_handled else None, - payload=b'test', - ), - forefront=forefront, - ) - - # We simulate another clean run, we expect the memory storage to read from the local data directory - # Default storages are purged based on purge_on_start parameter. 
- prepare_test_env() - - # Add some more requests to the default queue - default_queue = await RequestQueue.open() - for i in range(6, 12): - # [6, 9] <- nothing special - # [7, 10] <- forefront=True - # [8, 11] <- handled=True - request_url = f'http://example.com/{i}' - forefront = i % 3 == 1 - was_handled = i % 3 == 2 - await default_queue.add_request( - Request.from_url( - unique_key=str(i), - url=request_url, - handled_at=datetime.now(timezone.utc) if was_handled else None, - payload=b'test', - ), - forefront=forefront, - ) - - queue_info = await default_queue.get_info() - assert queue_info is not None - - # If the queue was purged between the runs, only the requests from the second run should be present, - # in the right order - if purge_on_start: - assert queue_info.total_request_count == 6 - assert queue_info.handled_request_count == 2 - - expected_pending_request_order = [10, 7, 6, 9] - # If the queue was NOT purged between the runs, all the requests should be in the queue in the right order - else: - assert queue_info.total_request_count == 12 - assert queue_info.handled_request_count == 4 - - expected_pending_request_order = [10, 7, 4, 1, 0, 3, 6, 9] - - actual_requests = list[Request]() - while req := await default_queue.fetch_next_request(): - actual_requests.append(req) - - assert [int(req.unique_key) for req in actual_requests] == expected_pending_request_order - assert [req.url for req in actual_requests] == [f'http://example.com/{req.unique_key}' for req in actual_requests] - assert [req.payload for req in actual_requests] == [b'test' for _ in actual_requests] diff --git a/tests/unit/storage_clients/_memory/test_request_queue_client.py b/tests/unit/storage_clients/_memory/test_request_queue_client.py deleted file mode 100644 index feffacbbd8..0000000000 --- a/tests/unit/storage_clients/_memory/test_request_queue_client.py +++ /dev/null @@ -1,249 +0,0 @@ -from __future__ import annotations - -import asyncio -from datetime import datetime, timezone -from pathlib import Path -from typing import TYPE_CHECKING - -import pytest - -from crawlee import Request -from crawlee._request import RequestState - -if TYPE_CHECKING: - from crawlee.storage_clients import MemoryStorageClient - from crawlee.storage_clients._memory import RequestQueueClient - - -@pytest.fixture -async def request_queue_client(memory_storage_client: MemoryStorageClient) -> RequestQueueClient: - request_queues_client = memory_storage_client.request_queues() - rq_info = await request_queues_client.get_or_create(name='test') - return memory_storage_client.request_queue(rq_info.id) - - -async def test_nonexistent(memory_storage_client: MemoryStorageClient) -> None: - request_queue_client = memory_storage_client.request_queue(id='nonexistent-id') - assert await request_queue_client.get() is None - with pytest.raises(ValueError, match='Request queue with id "nonexistent-id" does not exist.'): - await request_queue_client.update(name='test-update') - await request_queue_client.delete() - - -async def test_get(request_queue_client: RequestQueueClient) -> None: - await asyncio.sleep(0.1) - info = await request_queue_client.get() - assert info is not None - assert info.id == request_queue_client.id - assert info.accessed_at != info.created_at - - -async def test_update(request_queue_client: RequestQueueClient) -> None: - new_rq_name = 'test-update' - request = Request.from_url('https://apify.com') - await request_queue_client.add_request(request) - old_rq_info = await request_queue_client.get() - assert old_rq_info is not 
None - assert old_rq_info.name is not None - old_rq_directory = Path( - request_queue_client._memory_storage_client.request_queues_directory, - old_rq_info.name, - ) - new_rq_directory = Path(request_queue_client._memory_storage_client.request_queues_directory, new_rq_name) - assert (old_rq_directory / 'fvwscO2UJLdr10B.json').exists() is True - assert (new_rq_directory / 'fvwscO2UJLdr10B.json').exists() is False - - await asyncio.sleep(0.1) - updated_rq_info = await request_queue_client.update(name=new_rq_name) - assert (old_rq_directory / 'fvwscO2UJLdr10B.json').exists() is False - assert (new_rq_directory / 'fvwscO2UJLdr10B.json').exists() is True - # Only modified_at and accessed_at should be different - assert old_rq_info.created_at == updated_rq_info.created_at - assert old_rq_info.modified_at != updated_rq_info.modified_at - assert old_rq_info.accessed_at != updated_rq_info.accessed_at - - # Should fail with the same name - with pytest.raises(ValueError, match='Request queue with name "test-update" already exists'): - await request_queue_client.update(name=new_rq_name) - - -async def test_delete(request_queue_client: RequestQueueClient) -> None: - await request_queue_client.add_request(Request.from_url('https://apify.com')) - rq_info = await request_queue_client.get() - assert rq_info is not None - - rq_directory = Path(request_queue_client._memory_storage_client.request_queues_directory, str(rq_info.name)) - assert (rq_directory / 'fvwscO2UJLdr10B.json').exists() is True - - await request_queue_client.delete() - assert (rq_directory / 'fvwscO2UJLdr10B.json').exists() is False - - # Does not crash when called again - await request_queue_client.delete() - - -async def test_list_head(request_queue_client: RequestQueueClient) -> None: - await request_queue_client.add_request(Request.from_url('https://apify.com')) - await request_queue_client.add_request(Request.from_url('https://example.com')) - list_head = await request_queue_client.list_head() - assert len(list_head.items) == 2 - - for item in list_head.items: - assert item.id is not None - - -async def test_request_state_serialization(request_queue_client: RequestQueueClient) -> None: - request = Request.from_url('https://crawlee.dev', payload=b'test') - request.state = RequestState.UNPROCESSED - - await request_queue_client.add_request(request) - - result = await request_queue_client.list_head() - assert len(result.items) == 1 - assert result.items[0] == request - - got_request = await request_queue_client.get_request(request.id) - - assert request == got_request - - -async def test_add_record(request_queue_client: RequestQueueClient) -> None: - processed_request_forefront = await request_queue_client.add_request( - Request.from_url('https://apify.com'), - forefront=True, - ) - processed_request_not_forefront = await request_queue_client.add_request( - Request.from_url('https://example.com'), - forefront=False, - ) - - assert processed_request_forefront.id is not None - assert processed_request_not_forefront.id is not None - assert processed_request_forefront.was_already_handled is False - assert processed_request_not_forefront.was_already_handled is False - - rq_info = await request_queue_client.get() - assert rq_info is not None - assert rq_info.pending_request_count == rq_info.total_request_count == 2 - assert rq_info.handled_request_count == 0 - - -async def test_get_record(request_queue_client: RequestQueueClient) -> None: - request_url = 'https://apify.com' - processed_request = await 
request_queue_client.add_request(Request.from_url(request_url)) - - request = await request_queue_client.get_request(processed_request.id) - assert request is not None - assert request.url == request_url - - # Non-existent id - assert (await request_queue_client.get_request('non-existent id')) is None - - -async def test_update_record(request_queue_client: RequestQueueClient) -> None: - processed_request = await request_queue_client.add_request(Request.from_url('https://apify.com')) - request = await request_queue_client.get_request(processed_request.id) - assert request is not None - - rq_info_before_update = await request_queue_client.get() - assert rq_info_before_update is not None - assert rq_info_before_update.pending_request_count == 1 - assert rq_info_before_update.handled_request_count == 0 - - request.handled_at = datetime.now(timezone.utc) - request_update_info = await request_queue_client.update_request(request) - - assert request_update_info.was_already_handled is False - - rq_info_after_update = await request_queue_client.get() - assert rq_info_after_update is not None - assert rq_info_after_update.pending_request_count == 0 - assert rq_info_after_update.handled_request_count == 1 - - -async def test_delete_record(request_queue_client: RequestQueueClient) -> None: - processed_request_pending = await request_queue_client.add_request( - Request.from_url( - url='https://apify.com', - unique_key='pending', - ), - ) - - processed_request_handled = await request_queue_client.add_request( - Request.from_url( - url='https://apify.com', - unique_key='handled', - handled_at=datetime.now(timezone.utc), - ), - ) - - rq_info_before_delete = await request_queue_client.get() - assert rq_info_before_delete is not None - assert rq_info_before_delete.pending_request_count == 1 - - await request_queue_client.delete_request(processed_request_pending.id) - rq_info_after_first_delete = await request_queue_client.get() - assert rq_info_after_first_delete is not None - assert rq_info_after_first_delete.pending_request_count == 0 - assert rq_info_after_first_delete.handled_request_count == 1 - - await request_queue_client.delete_request(processed_request_handled.id) - rq_info_after_second_delete = await request_queue_client.get() - assert rq_info_after_second_delete is not None - assert rq_info_after_second_delete.pending_request_count == 0 - assert rq_info_after_second_delete.handled_request_count == 0 - - # Does not crash when called again - await request_queue_client.delete_request(processed_request_pending.id) - - -async def test_forefront(request_queue_client: RequestQueueClient) -> None: - # this should create a queue with requests in this order: - # Handled: - # 2, 5, 8 - # Not handled: - # 7, 4, 1, 0, 3, 6 - for i in range(9): - request_url = f'http://example.com/{i}' - forefront = i % 3 == 1 - was_handled = i % 3 == 2 - await request_queue_client.add_request( - Request.from_url( - url=request_url, - unique_key=str(i), - handled_at=datetime.now(timezone.utc) if was_handled else None, - ), - forefront=forefront, - ) - - # Check that the queue head (unhandled items) is in the right order - queue_head = await request_queue_client.list_head() - req_unique_keys = [req.unique_key for req in queue_head.items] - assert req_unique_keys == ['7', '4', '1', '0', '3', '6'] - - # Mark request #1 as handled - await request_queue_client.update_request( - Request.from_url( - url='http://example.com/1', - unique_key='1', - handled_at=datetime.now(timezone.utc), - ), - ) - # Move request #3 to forefront - await 
request_queue_client.update_request( - Request.from_url(url='http://example.com/3', unique_key='3'), - forefront=True, - ) - - # Check that the queue head (unhandled items) is in the right order after the updates - queue_head = await request_queue_client.list_head() - req_unique_keys = [req.unique_key for req in queue_head.items] - assert req_unique_keys == ['3', '7', '4', '0', '6'] - - -async def test_add_duplicate_record(request_queue_client: RequestQueueClient) -> None: - processed_request = await request_queue_client.add_request(Request.from_url('https://apify.com')) - processed_request_duplicate = await request_queue_client.add_request(Request.from_url('https://apify.com')) - - assert processed_request.id == processed_request_duplicate.id - assert processed_request_duplicate.was_already_present is True diff --git a/tests/unit/storage_clients/_memory/test_request_queue_collection_client.py b/tests/unit/storage_clients/_memory/test_request_queue_collection_client.py deleted file mode 100644 index fa10889f83..0000000000 --- a/tests/unit/storage_clients/_memory/test_request_queue_collection_client.py +++ /dev/null @@ -1,42 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pytest - -if TYPE_CHECKING: - from crawlee.storage_clients import MemoryStorageClient - from crawlee.storage_clients._memory import RequestQueueCollectionClient - - -@pytest.fixture -def request_queues_client(memory_storage_client: MemoryStorageClient) -> RequestQueueCollectionClient: - return memory_storage_client.request_queues() - - -async def test_get_or_create(request_queues_client: RequestQueueCollectionClient) -> None: - rq_name = 'test' - # A new request queue gets created - rq_info = await request_queues_client.get_or_create(name=rq_name) - assert rq_info.name == rq_name - - # Another get_or_create call returns the same request queue - rq_existing = await request_queues_client.get_or_create(name=rq_name) - assert rq_info.id == rq_existing.id - assert rq_info.name == rq_existing.name - assert rq_info.created_at == rq_existing.created_at - - -async def test_list(request_queues_client: RequestQueueCollectionClient) -> None: - assert (await request_queues_client.list()).count == 0 - rq_info = await request_queues_client.get_or_create(name='dataset') - rq_list = await request_queues_client.list() - assert rq_list.count == 1 - assert rq_list.items[0].name == rq_info.name - - # Test sorting behavior - newer_rq_info = await request_queues_client.get_or_create(name='newer-dataset') - rq_list_sorting = await request_queues_client.list() - assert rq_list_sorting.count == 2 - assert rq_list_sorting.items[0].name == rq_info.name - assert rq_list_sorting.items[1].name == newer_rq_info.name diff --git a/tests/unit/storages/test_dataset.py b/tests/unit/storages/test_dataset.py index f299aee08d..b4f75bc6b4 100644 --- a/tests/unit/storages/test_dataset.py +++ b/tests/unit/storages/test_dataset.py @@ -1,156 +1,584 @@ +# TODO: Update crawlee_storage_dir args once the Pydantic bug is fixed +# https://github.com/apify/crawlee-python/issues/146 + from __future__ import annotations -from datetime import datetime, timezone from typing import TYPE_CHECKING import pytest -from crawlee import service_locator -from crawlee.storage_clients.models import StorageMetadata +from crawlee.configuration import Configuration +from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient from crawlee.storages import Dataset, KeyValueStore if TYPE_CHECKING: from collections.abc import 
AsyncGenerator + from pathlib import Path + from typing import Any + + from crawlee.storage_clients import StorageClient + + +@pytest.fixture(params=['memory', 'file_system']) +def storage_client(request: pytest.FixtureRequest) -> StorageClient: + """Parameterized fixture to test with different storage clients.""" + if request.param == 'memory': + return MemoryStorageClient() + + return FileSystemStorageClient() + + +@pytest.fixture +def configuration(tmp_path: Path) -> Configuration: + """Provide a configuration with a temporary storage directory.""" + return Configuration( + crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] + purge_on_start=True, + ) @pytest.fixture -async def dataset() -> AsyncGenerator[Dataset, None]: - dataset = await Dataset.open() +async def dataset( + storage_client: StorageClient, + configuration: Configuration, +) -> AsyncGenerator[Dataset, None]: + """Fixture that provides a dataset instance for each test.""" + dataset = await Dataset.open( + storage_client=storage_client, + configuration=configuration, + ) + yield dataset await dataset.drop() -async def test_open() -> None: - default_dataset = await Dataset.open() - default_dataset_by_id = await Dataset.open(id=default_dataset.id) +async def test_open_creates_new_dataset( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test that open() creates a new dataset with proper metadata.""" + dataset = await Dataset.open( + name='new_dataset', + storage_client=storage_client, + configuration=configuration, + ) - assert default_dataset is default_dataset_by_id + # Verify dataset properties + assert dataset.id is not None + assert dataset.name == 'new_dataset' - dataset_name = 'dummy-name' - named_dataset = await Dataset.open(name=dataset_name) - assert default_dataset is not named_dataset + metadata = await dataset.get_metadata() + assert metadata.item_count == 0 + + await dataset.drop() - with pytest.raises(RuntimeError, match='Dataset with id "nonexistent-id" does not exist!'): - await Dataset.open(id='nonexistent-id') - # Test that when you try to open a dataset by ID and you use a name of an existing dataset, - # it doesn't work - with pytest.raises(RuntimeError, match='Dataset with id "dummy-name" does not exist!'): - await Dataset.open(id='dummy-name') +async def test_reopen_default( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test reopening a dataset with default parameters.""" + # Create a first dataset instance with default parameters + dataset_1 = await Dataset.open( + storage_client=storage_client, + configuration=configuration, + ) + # Verify default properties + assert dataset_1.id is not None + metadata_1 = await dataset_1.get_metadata() + assert metadata_1.item_count == 0 -async def test_consistency_accross_two_clients() -> None: - dataset = await Dataset.open(name='my-dataset') - await dataset.push_data({'key': 'value'}) + # Add an item + await dataset_1.push_data({'key': 'value'}) + metadata_1 = await dataset_1.get_metadata() + assert metadata_1.item_count == 1 - dataset_by_id = await Dataset.open(id=dataset.id) - await dataset_by_id.push_data({'key2': 'value2'}) + # Reopen the same dataset + dataset_2 = await Dataset.open( + storage_client=storage_client, + configuration=configuration, + ) - assert (await dataset.get_data()).items == [{'key': 'value'}, {'key2': 'value2'}] - assert (await dataset_by_id.get_data()).items == [{'key': 'value'}, {'key2': 'value2'}] + # Verify both instances reference the same dataset + assert 
dataset_2.id == dataset_1.id + assert dataset_2.name == dataset_1.name + metadata_1 = await dataset_1.get_metadata() + metadata_2 = await dataset_2.get_metadata() + assert metadata_2.item_count == metadata_1.item_count == 1 + + # Verify they are the same object (cached) + assert id(dataset_1) == id(dataset_2) + + # Clean up + await dataset_1.drop() + + +async def test_open_by_id( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test opening a dataset by its ID.""" + # First create a dataset by name + dataset1 = await Dataset.open( + name='dataset_by_id_test', + storage_client=storage_client, + configuration=configuration, + ) - await dataset.drop() - with pytest.raises(RuntimeError, match='Storage with provided ID was not found'): - await dataset_by_id.drop() - - -async def test_same_references() -> None: - dataset1 = await Dataset.open() - dataset2 = await Dataset.open() - assert dataset1 is dataset2 - - dataset_name = 'non-default' - dataset_named1 = await Dataset.open(name=dataset_name) - dataset_named2 = await Dataset.open(name=dataset_name) - assert dataset_named1 is dataset_named2 - - -async def test_drop() -> None: - dataset1 = await Dataset.open() - await dataset1.drop() - dataset2 = await Dataset.open() - assert dataset1 is not dataset2 - - -async def test_export(dataset: Dataset) -> None: - expected_csv = 'id,test\r\n0,test\r\n1,test\r\n2,test\r\n' - expected_json = [{'id': 0, 'test': 'test'}, {'id': 1, 'test': 'test'}, {'id': 2, 'test': 'test'}] - desired_item_count = 3 - await dataset.push_data([{'id': i, 'test': 'test'} for i in range(desired_item_count)]) - await dataset.export_to(key='dataset-csv', content_type='csv') - await dataset.export_to(key='dataset-json', content_type='json') - kvs = await KeyValueStore.open() - dataset_csv = await kvs.get_value(key='dataset-csv') - dataset_json = await kvs.get_value(key='dataset-json') - assert dataset_csv == expected_csv - assert dataset_json == expected_json - - -async def test_push_data(dataset: Dataset) -> None: - desired_item_count = 2000 - await dataset.push_data([{'id': i} for i in range(desired_item_count)]) - dataset_info = await dataset.get_info() - assert dataset_info is not None - assert dataset_info.item_count == desired_item_count - list_page = await dataset.get_data(limit=desired_item_count) - assert list_page.items[0]['id'] == 0 - assert list_page.items[-1]['id'] == desired_item_count - 1 - - -async def test_push_data_empty(dataset: Dataset) -> None: - await dataset.push_data([]) - dataset_info = await dataset.get_info() - assert dataset_info is not None - assert dataset_info.item_count == 0 - - -async def test_push_data_singular(dataset: Dataset) -> None: - await dataset.push_data({'id': 1}) - dataset_info = await dataset.get_info() - assert dataset_info is not None - assert dataset_info.item_count == 1 - list_page = await dataset.get_data() - assert list_page.items[0]['id'] == 1 - - -async def test_get_data(dataset: Dataset) -> None: # We don't test everything, that's done in memory storage tests - desired_item_count = 3 - await dataset.push_data([{'id': i} for i in range(desired_item_count)]) - list_page = await dataset.get_data() - assert list_page.count == desired_item_count - assert list_page.desc is False - assert list_page.offset == 0 - assert list_page.items[0]['id'] == 0 - assert list_page.items[-1]['id'] == desired_item_count - 1 + # Add some data to identify it + test_item = {'test': 'opening_by_id', 'timestamp': 12345} + await dataset1.push_data(test_item) + + # Open the 
dataset by ID + dataset2 = await Dataset.open( + id=dataset1.id, + storage_client=storage_client, + configuration=configuration, + ) + + # Verify it's the same dataset + assert dataset2.id == dataset1.id + assert dataset2.name == 'dataset_by_id_test' + + # Verify the data is still there + data = await dataset2.get_data() + assert data.count == 1 + assert data.items[0]['test'] == 'opening_by_id' + assert data.items[0]['timestamp'] == 12345 + + # Clean up + await dataset2.drop() + + +async def test_open_existing_dataset( + dataset: Dataset, + storage_client: StorageClient, +) -> None: + """Test that open() loads an existing dataset correctly.""" + # Open the same dataset again + reopened_dataset = await Dataset.open( + name=dataset.name, + storage_client=storage_client, + ) + + # Verify dataset properties + assert dataset.id == reopened_dataset.id + assert dataset.name == reopened_dataset.name + metadata = await dataset.get_metadata() + reopened_metadata = await reopened_dataset.get_metadata() + assert metadata.item_count == reopened_metadata.item_count + + # Verify they are the same object (from cache) + assert id(dataset) == id(reopened_dataset) + + +async def test_open_with_id_and_name( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test that open() raises an error when both id and name are provided.""" + with pytest.raises(ValueError, match='Only one of "id" or "name" can be specified'): + await Dataset.open( + id='some-id', + name='some-name', + storage_client=storage_client, + configuration=configuration, + ) + + +async def test_push_data_single_item(dataset: Dataset) -> None: + """Test pushing a single item to the dataset.""" + item = {'key': 'value', 'number': 42} + await dataset.push_data(item) + + # Verify item was stored + result = await dataset.get_data() + assert result.count == 1 + assert result.items[0] == item + + +async def test_push_data_multiple_items(dataset: Dataset) -> None: + """Test pushing multiple items to the dataset.""" + items = [ + {'id': 1, 'name': 'Item 1'}, + {'id': 2, 'name': 'Item 2'}, + {'id': 3, 'name': 'Item 3'}, + ] + await dataset.push_data(items) + + # Verify items were stored + result = await dataset.get_data() + assert result.count == 3 + assert result.items == items + + +async def test_get_data_empty_dataset(dataset: Dataset) -> None: + """Test getting data from an empty dataset returns empty results.""" + result = await dataset.get_data() + + assert result.count == 0 + assert result.total == 0 + assert result.items == [] + + +async def test_get_data_with_pagination(dataset: Dataset) -> None: + """Test getting data with offset and limit parameters for pagination.""" + # Add some items + items = [{'id': i} for i in range(1, 11)] # 10 items + await dataset.push_data(items) + + # Test offset + result = await dataset.get_data(offset=3) + assert result.count == 7 + assert result.offset == 3 + assert result.items[0]['id'] == 4 + + # Test limit + result = await dataset.get_data(limit=5) + assert result.count == 5 + assert result.limit == 5 + assert result.items[-1]['id'] == 5 + + # Test both offset and limit + result = await dataset.get_data(offset=2, limit=3) + assert result.count == 3 + assert result.offset == 2 + assert result.limit == 3 + assert result.items[0]['id'] == 3 + assert result.items[-1]['id'] == 5 + + +async def test_get_data_descending_order(dataset: Dataset) -> None: + """Test getting data in descending order reverses the item order.""" + # Add some items + items = [{'id': i} for i in range(1, 6)] # 5 
items + await dataset.push_data(items) + + # Get items in descending order + result = await dataset.get_data(desc=True) + + assert result.desc is True + assert result.items[0]['id'] == 5 + assert result.items[-1]['id'] == 1 + + +async def test_get_data_skip_empty(dataset: Dataset) -> None: + """Test getting data with skip_empty option filters out empty items.""" + # Add some items including an empty one + items = [ + {'id': 1, 'name': 'Item 1'}, + {}, # Empty item + {'id': 3, 'name': 'Item 3'}, + ] + await dataset.push_data(items) + + # Get all items + result = await dataset.get_data() + assert result.count == 3 + + # Get non-empty items + result = await dataset.get_data(skip_empty=True) + assert result.count == 2 + assert all(item != {} for item in result.items) async def test_iterate_items(dataset: Dataset) -> None: - desired_item_count = 3 - idx = 0 - await dataset.push_data([{'id': i} for i in range(desired_item_count)]) + """Test iterating over dataset items yields each item in the correct order.""" + # Add some items + items = [{'id': i} for i in range(1, 6)] # 5 items + await dataset.push_data(items) - async for item in dataset.iterate_items(): - assert item['id'] == idx - idx += 1 + # Iterate over all items + collected_items = [item async for item in dataset.iterate_items()] - assert idx == desired_item_count + assert len(collected_items) == 5 + assert collected_items[0]['id'] == 1 + assert collected_items[-1]['id'] == 5 -async def test_from_storage_object() -> None: - storage_client = service_locator.get_storage_client() +async def test_iterate_items_with_options(dataset: Dataset) -> None: + """Test iterating with offset, limit and desc parameters.""" + # Add some items + items = [{'id': i} for i in range(1, 11)] # 10 items + await dataset.push_data(items) + + # Test with offset and limit + collected_items = [item async for item in dataset.iterate_items(offset=3, limit=3)] + + assert len(collected_items) == 3 + assert collected_items[0]['id'] == 4 + assert collected_items[-1]['id'] == 6 + + # Test with descending order + collected_items = [] + async for item in dataset.iterate_items(desc=True, limit=3): + collected_items.append(item) + + assert len(collected_items) == 3 + assert collected_items[0]['id'] == 10 + assert collected_items[-1]['id'] == 8 + + +async def test_list_items(dataset: Dataset) -> None: + """Test that list_items returns all dataset items as a list.""" + # Add some items + items = [{'id': i} for i in range(1, 6)] # 5 items + await dataset.push_data(items) + + # Get all items as a list + collected_items = await dataset.list_items() + + assert len(collected_items) == 5 + assert collected_items[0]['id'] == 1 + assert collected_items[-1]['id'] == 5 + + +async def test_list_items_with_options(dataset: Dataset) -> None: + """Test that list_items respects filtering options.""" + # Add some items + items: list[dict[str, Any]] = [ + {'id': 1, 'name': 'Item 1'}, + {'id': 2, 'name': 'Item 2'}, + {'id': 3}, # Item with missing 'name' field + {}, # Empty item + {'id': 5, 'name': 'Item 5'}, + ] + await dataset.push_data(items) + + # Test with offset and limit + collected_items = await dataset.list_items(offset=1, limit=2) + assert len(collected_items) == 2 + assert collected_items[0]['id'] == 2 + assert collected_items[1]['id'] == 3 + + # Test with descending order - skip empty items to avoid KeyError + collected_items = await dataset.list_items(desc=True, skip_empty=True) + + # Filter items that have an 'id' field + items_with_ids = [item for item in collected_items if 'id' 
in item] + id_values = [item['id'] for item in items_with_ids] + + # Verify the list is sorted in descending order + assert sorted(id_values, reverse=True) == id_values, f'IDs should be in descending order. Got {id_values}' + + # Verify key IDs are present and in the right order + if 5 in id_values and 3 in id_values: + assert id_values.index(5) < id_values.index(3), 'ID 5 should come before ID 3 in descending order' + + # Test with skip_empty + collected_items = await dataset.list_items(skip_empty=True) + assert len(collected_items) == 4 # Should skip the empty item + assert all(item != {} for item in collected_items) + + # Test with fields - manually filter since 'fields' parameter is not supported + # Get all items first + collected_items = await dataset.list_items() + assert len(collected_items) == 5 + + # Manually extract only the 'id' field from each item + filtered_items = [{key: item[key] for key in ['id'] if key in item} for item in collected_items] + + # Verify 'name' field is not present in any item + assert all('name' not in item for item in filtered_items) + + # Test clean functionality manually instead of using the clean parameter + # Get all items + collected_items = await dataset.list_items() + + # Manually filter out empty items as 'clean' would do + clean_items = [item for item in collected_items if item != {}] - storage_object = StorageMetadata( - id='dummy-id', - name='dummy-name', - accessed_at=datetime.now(timezone.utc), - created_at=datetime.now(timezone.utc), - modified_at=datetime.now(timezone.utc), - extra_attribute='extra', + assert len(clean_items) == 4 # Should have 4 non-empty items + assert all(item != {} for item in clean_items) + + +async def test_drop( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test dropping a dataset removes it from cache and clears its data.""" + dataset = await Dataset.open( + name='drop_test', + storage_client=storage_client, + configuration=configuration, + ) + + # Add some data + await dataset.push_data({'test': 'data'}) + + # Drop the dataset + await dataset.drop() + + # Verify dataset is empty (by creating a new one with the same name) + new_dataset = await Dataset.open( + name='drop_test', + storage_client=storage_client, + configuration=configuration, + ) + + result = await new_dataset.get_data() + assert result.count == 0 + await new_dataset.drop() + + +async def test_export_to_json( + dataset: Dataset, + storage_client: StorageClient, +) -> None: + """Test exporting dataset to JSON format.""" + # Create a key-value store for export + kvs = await KeyValueStore.open( + name='export_kvs', + storage_client=storage_client, + ) + + # Add some items to the dataset + items = [ + {'id': 1, 'name': 'Item 1'}, + {'id': 2, 'name': 'Item 2'}, + {'id': 3, 'name': 'Item 3'}, + ] + await dataset.push_data(items) + + # Export to JSON + await dataset.export_to( + key='dataset_export.json', + content_type='json', + to_kvs_name='export_kvs', + to_kvs_storage_client=storage_client, ) - dataset = Dataset.from_storage_object(storage_client, storage_object) + # Retrieve the exported file + record = await kvs.get_value(key='dataset_export.json') + assert record is not None + + # Verify content has all the items + assert '"id": 1' in record + assert '"id": 2' in record + assert '"id": 3' in record + + await kvs.drop() + + +async def test_export_to_csv( + dataset: Dataset, + storage_client: StorageClient, +) -> None: + """Test exporting dataset to CSV format.""" + # Create a key-value store for export + kvs = await 
KeyValueStore.open( + name='export_kvs', + storage_client=storage_client, + ) - assert dataset.id == storage_object.id - assert dataset.name == storage_object.name - assert dataset.storage_object == storage_object - assert storage_object.model_extra.get('extra_attribute') == 'extra' # type: ignore[union-attr] + # Add some items to the dataset + items = [ + {'id': 1, 'name': 'Item 1'}, + {'id': 2, 'name': 'Item 2'}, + {'id': 3, 'name': 'Item 3'}, + ] + await dataset.push_data(items) + + # Export to CSV + await dataset.export_to( + key='dataset_export.csv', + content_type='csv', + to_kvs_name='export_kvs', + to_kvs_storage_client=storage_client, + ) + + # Retrieve the exported file + record = await kvs.get_value(key='dataset_export.csv') + assert record is not None + + # Verify content has all the items + assert 'id,name' in record + assert '1,Item 1' in record + assert '2,Item 2' in record + assert '3,Item 3' in record + + await kvs.drop() + + +async def test_export_to_invalid_content_type(dataset: Dataset) -> None: + """Test exporting dataset with invalid content type raises error.""" + with pytest.raises(ValueError, match='Unsupported content type'): + await dataset.export_to( + key='invalid_export', + content_type='invalid', # type: ignore[call-overload] # Intentionally invalid content type + ) + + +async def test_large_dataset(dataset: Dataset) -> None: + """Test handling a large dataset with many items.""" + items = [{'id': i, 'value': f'value-{i}'} for i in range(100)] + await dataset.push_data(items) + + # Test that all items are retrieved + result = await dataset.get_data(limit=None) + assert result.count == 100 + assert result.total == 100 + + # Test pagination with large datasets + result = await dataset.get_data(offset=50, limit=25) + assert result.count == 25 + assert result.offset == 50 + assert result.items[0]['id'] == 50 + assert result.items[-1]['id'] == 74 + + +async def test_purge( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test purging a dataset removes all data but keeps the dataset itself.""" + # First create a dataset + dataset = await Dataset.open( + name='purge_test_dataset', + storage_client=storage_client, + configuration=configuration, + ) + + # Add some data + initial_items = [ + {'id': 1, 'name': 'Item 1'}, + {'id': 2, 'name': 'Item 2'}, + {'id': 3, 'name': 'Item 3'}, + ] + await dataset.push_data(initial_items) + + # Verify data was added + data = await dataset.get_data() + assert data.count == 3 + assert data.total == 3 + metadata = await dataset.get_metadata() + assert metadata.item_count == 3 + + # Record the dataset ID + dataset_id = dataset.id + + # Purge the dataset + await dataset.purge() + + # Verify the dataset still exists but is empty + assert dataset.id == dataset_id # Same ID preserved + assert dataset.name == 'purge_test_dataset' # Same name preserved + + # Dataset should be empty now + data = await dataset.get_data() + assert data.count == 0 + assert data.total == 0 + metadata = await dataset.get_metadata() + assert metadata.item_count == 0 + + # Verify we can add new data after purging + new_item = {'id': 4, 'name': 'New Item After Purge'} + await dataset.push_data(new_item) + + data = await dataset.get_data() + assert data.count == 1 + assert data.items[0]['name'] == 'New Item After Purge' + + # Clean up + await dataset.drop() diff --git a/tests/unit/storages/test_key_value_store.py b/tests/unit/storages/test_key_value_store.py index ea3f4e5f7d..25bbcb4fc0 100644 --- 
a/tests/unit/storages/test_key_value_store.py +++ b/tests/unit/storages/test_key_value_store.py @@ -1,229 +1,600 @@ +# TODO: Update crawlee_storage_dir args once the Pydantic bug is fixed +# https://github.com/apify/crawlee-python/issues/146 + from __future__ import annotations -import asyncio -from datetime import datetime, timedelta, timezone -from itertools import chain, repeat -from typing import TYPE_CHECKING, cast -from unittest.mock import patch -from urllib.parse import urlparse +import json +from typing import TYPE_CHECKING import pytest -from crawlee import service_locator -from crawlee.events import EventManager -from crawlee.storage_clients.models import StorageMetadata +from crawlee.configuration import Configuration +from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient from crawlee.storages import KeyValueStore if TYPE_CHECKING: from collections.abc import AsyncGenerator + from pathlib import Path + + from crawlee.storage_clients import StorageClient + + +@pytest.fixture(params=['memory', 'file_system']) +def storage_client(request: pytest.FixtureRequest) -> StorageClient: + """Parameterized fixture to test with different storage clients.""" + if request.param == 'memory': + return MemoryStorageClient() - from crawlee._types import JsonSerializable + return FileSystemStorageClient() @pytest.fixture -async def mock_event_manager() -> AsyncGenerator[EventManager, None]: - async with EventManager(persist_state_interval=timedelta(milliseconds=50)) as event_manager: - with patch('crawlee.service_locator.get_event_manager', return_value=event_manager): - yield event_manager +def configuration(tmp_path: Path) -> Configuration: + """Provide a configuration with a temporary storage directory.""" + return Configuration( + crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] + purge_on_start=True, + ) + + +@pytest.fixture +async def kvs( + storage_client: StorageClient, + configuration: Configuration, +) -> AsyncGenerator[KeyValueStore, None]: + """Fixture that provides a key-value store instance for each test.""" + kvs = await KeyValueStore.open( + storage_client=storage_client, + configuration=configuration, + ) + + yield kvs + await kvs.drop() + + +async def test_open_creates_new_kvs( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test that open() creates a new key-value store with proper metadata.""" + kvs = await KeyValueStore.open( + name='new_kvs', + storage_client=storage_client, + configuration=configuration, + ) + + # Verify key-value store properties + assert kvs.id is not None + assert kvs.name == 'new_kvs' + + await kvs.drop() + + +async def test_open_existing_kvs( + kvs: KeyValueStore, + storage_client: StorageClient, +) -> None: + """Test that open() loads an existing key-value store correctly.""" + # Open the same key-value store again + reopened_kvs = await KeyValueStore.open( + name=kvs.name, + storage_client=storage_client, + ) + + # Verify key-value store properties + assert kvs.id == reopened_kvs.id + assert kvs.name == reopened_kvs.name + + # Verify they are the same object (from cache) + assert id(kvs) == id(reopened_kvs) + + +async def test_open_with_id_and_name( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test that open() raises an error when both id and name are provided.""" + with pytest.raises(ValueError, match='Only one of "id" or "name" can be specified'): + await KeyValueStore.open( + id='some-id', + name='some-name', + storage_client=storage_client, 
+ configuration=configuration, + ) + + +async def test_open_by_id( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test opening a key-value store by its ID.""" + # First create a key-value store by name + kvs1 = await KeyValueStore.open( + name='kvs_by_id_test', + storage_client=storage_client, + configuration=configuration, + ) + + # Add some data to identify it + await kvs1.set_value('test_key', {'test': 'opening_by_id', 'timestamp': 12345}) + + # Open the key-value store by ID + kvs2 = await KeyValueStore.open( + id=kvs1.id, + storage_client=storage_client, + configuration=configuration, + ) + # Verify it's the same key-value store + assert kvs2.id == kvs1.id + assert kvs2.name == 'kvs_by_id_test' -async def test_open() -> None: - default_key_value_store = await KeyValueStore.open() - default_key_value_store_by_id = await KeyValueStore.open(id=default_key_value_store.id) + # Verify the data is still there + value = await kvs2.get_value('test_key') + assert value is not None + assert value['test'] == 'opening_by_id' + assert value['timestamp'] == 12345 - assert default_key_value_store is default_key_value_store_by_id + # Clean up + await kvs2.drop() - key_value_store_name = 'dummy-name' - named_key_value_store = await KeyValueStore.open(name=key_value_store_name) - assert default_key_value_store is not named_key_value_store - with pytest.raises(RuntimeError, match='KeyValueStore with id "nonexistent-id" does not exist!'): - await KeyValueStore.open(id='nonexistent-id') +async def test_set_get_value(kvs: KeyValueStore) -> None: + """Test setting and getting a value from the key-value store.""" + # Set a value + test_key = 'test-key' + test_value = {'data': 'value', 'number': 42} + await kvs.set_value(test_key, test_value) - # Test that when you try to open a key-value store by ID and you use a name of an existing key-value store, - # it doesn't work - with pytest.raises(RuntimeError, match='KeyValueStore with id "dummy-name" does not exist!'): - await KeyValueStore.open(id='dummy-name') + # Get the value + result = await kvs.get_value(test_key) + assert result == test_value -async def test_open_save_storage_object() -> None: - default_key_value_store = await KeyValueStore.open() +async def test_set_get_none(kvs: KeyValueStore) -> None: + """Test setting and getting None as a value.""" + test_key = 'none-key' + await kvs.set_value(test_key, None) + result = await kvs.get_value(test_key) + assert result is None - assert default_key_value_store.storage_object is not None - assert default_key_value_store.storage_object.id == default_key_value_store.id +async def test_get_value_nonexistent(kvs: KeyValueStore) -> None: + """Test getting a nonexistent value returns None.""" + result = await kvs.get_value('nonexistent-key') + assert result is None -async def test_consistency_accross_two_clients() -> None: - kvs = await KeyValueStore.open(name='my-kvs') - await kvs.set_value('key', 'value') - kvs_by_id = await KeyValueStore.open(id=kvs.id) - await kvs_by_id.set_value('key2', 'value2') +async def test_get_value_with_default(kvs: KeyValueStore) -> None: + """Test getting a nonexistent value with a default value.""" + default_value = {'default': True} + result = await kvs.get_value('nonexistent-key', default_value=default_value) + assert result == default_value - assert (await kvs.get_value('key')) == 'value' - assert (await kvs.get_value('key2')) == 'value2' - assert (await kvs_by_id.get_value('key')) == 'value' - assert (await kvs_by_id.get_value('key2')) == 
'value2' +async def test_set_value_with_content_type(kvs: KeyValueStore) -> None: + """Test setting a value with a specific content type.""" + test_key = 'test-json' + test_value = {'data': 'value', 'items': [1, 2, 3]} + await kvs.set_value(test_key, test_value, content_type='application/json') + # Verify the value is retrievable + result = await kvs.get_value(test_key) + assert result == test_value + + +async def test_delete_value(kvs: KeyValueStore) -> None: + """Test deleting a value from the key-value store.""" + # Set a value first + test_key = 'delete-me' + test_value = 'value to delete' + await kvs.set_value(test_key, test_value) + + # Verify value exists + assert await kvs.get_value(test_key) == test_value + + # Delete the value + await kvs.delete_value(test_key) + + # Verify value is gone + assert await kvs.get_value(test_key) is None + + +async def test_list_keys_empty_kvs(kvs: KeyValueStore) -> None: + """Test listing keys from an empty key-value store.""" + keys = await kvs.list_keys() + assert len(keys) == 0 + + +async def test_list_keys(kvs: KeyValueStore) -> None: + """Test listing keys from a key-value store with items.""" + # Add some items + await kvs.set_value('key1', 'value1') + await kvs.set_value('key2', 'value2') + await kvs.set_value('key3', 'value3') + + # List keys + keys = await kvs.list_keys() + + # Verify keys + assert len(keys) == 3 + key_names = [k.key for k in keys] + assert 'key1' in key_names + assert 'key2' in key_names + assert 'key3' in key_names + + +async def test_list_keys_with_limit(kvs: KeyValueStore) -> None: + """Test listing keys with a limit parameter.""" + # Add some items + for i in range(10): + await kvs.set_value(f'key{i}', f'value{i}') + + # List with limit + keys = await kvs.list_keys(limit=5) + assert len(keys) == 5 + + +async def test_list_keys_with_exclusive_start_key(kvs: KeyValueStore) -> None: + """Test listing keys with an exclusive start key.""" + # Add some items in a known order + await kvs.set_value('key1', 'value1') + await kvs.set_value('key2', 'value2') + await kvs.set_value('key3', 'value3') + await kvs.set_value('key4', 'value4') + await kvs.set_value('key5', 'value5') + + # Get all keys first to determine their order + all_keys = await kvs.list_keys() + all_key_names = [k.key for k in all_keys] + + if len(all_key_names) >= 3: + # Start from the second key + start_key = all_key_names[1] + keys = await kvs.list_keys(exclusive_start_key=start_key) + + # We should get all keys after the start key + expected_count = len(all_key_names) - all_key_names.index(start_key) - 1 + assert len(keys) == expected_count + + # First key should be the one after start_key + first_returned_key = keys[0].key + assert first_returned_key != start_key + assert all_key_names.index(first_returned_key) > all_key_names.index(start_key) + + +async def test_iterate_keys(kvs: KeyValueStore) -> None: + """Test iterating over keys in the key-value store.""" + # Add some items + await kvs.set_value('key1', 'value1') + await kvs.set_value('key2', 'value2') + await kvs.set_value('key3', 'value3') + + collected_keys = [key async for key in kvs.iterate_keys()] + + # Verify iteration result + assert len(collected_keys) == 3 + key_names = [k.key for k in collected_keys] + assert 'key1' in key_names + assert 'key2' in key_names + assert 'key3' in key_names + + +async def test_iterate_keys_with_limit(kvs: KeyValueStore) -> None: + """Test iterating over keys with a limit parameter.""" + # Add some items + for i in range(10): + await kvs.set_value(f'key{i}', 
f'value{i}') + + collected_keys = [key async for key in kvs.iterate_keys(limit=5)] + + # Verify iteration result + assert len(collected_keys) == 5 + + +async def test_drop( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test dropping a key-value store removes it from cache and clears its data.""" + kvs = await KeyValueStore.open( + name='drop_test', + storage_client=storage_client, + configuration=configuration, + ) + + # Add some data + await kvs.set_value('test', 'data') + + # Drop the key-value store await kvs.drop() - with pytest.raises(RuntimeError, match='Storage with provided ID was not found'): - await kvs_by_id.drop() + # Verify key-value store is empty (by creating a new one with the same name) + new_kvs = await KeyValueStore.open( + name='drop_test', + storage_client=storage_client, + configuration=configuration, + ) + + # Attempt to get a previously stored value + result = await new_kvs.get_value('test') + assert result is None + await new_kvs.drop() + + +async def test_reopen_default( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test reopening the default key-value store.""" + # Open the default key-value store + kvs1 = await KeyValueStore.open( + storage_client=storage_client, + configuration=configuration, + ) + + # Set a value + await kvs1.set_value('test_key', 'test_value') -async def test_same_references() -> None: - kvs1 = await KeyValueStore.open() - kvs2 = await KeyValueStore.open() - assert kvs1 is kvs2 + # Open the default key-value store again + kvs2 = await KeyValueStore.open( + storage_client=storage_client, + configuration=configuration, + ) - kvs_name = 'non-default' - kvs_named1 = await KeyValueStore.open(name=kvs_name) - kvs_named2 = await KeyValueStore.open(name=kvs_name) - assert kvs_named1 is kvs_named2 + # Verify they are the same store + assert kvs1.id == kvs2.id + assert kvs1.name == kvs2.name + + # Verify the value is accessible + value1 = await kvs1.get_value('test_key') + value2 = await kvs2.get_value('test_key') + assert value1 == value2 == 'test_value' + + # Verify they are the same object + assert id(kvs1) == id(kvs2) + + +async def test_complex_data_types(kvs: KeyValueStore) -> None: + """Test storing and retrieving complex data types.""" + # Test nested dictionaries + nested_dict = { + 'level1': { + 'level2': { + 'level3': 'deep value', + 'numbers': [1, 2, 3], + }, + }, + 'array': [{'a': 1}, {'b': 2}], + } + await kvs.set_value('nested', nested_dict) + result = await kvs.get_value('nested') + assert result == nested_dict + + # Test lists + test_list = [1, 'string', True, None, {'key': 'value'}] + await kvs.set_value('list', test_list) + result = await kvs.get_value('list') + assert result == test_list + + +async def test_string_data(kvs: KeyValueStore) -> None: + """Test storing and retrieving string data.""" + # Plain string + await kvs.set_value('string', 'simple string') + result = await kvs.get_value('string') + assert result == 'simple string' + + # JSON string + json_string = json.dumps({'key': 'value'}) + await kvs.set_value('json_string', json_string) + result = await kvs.get_value('json_string') + assert result == json_string + + +async def test_key_with_special_characters(kvs: KeyValueStore) -> None: + """Test storing and retrieving values with keys containing special characters.""" + # Key with spaces, slashes, and special characters + special_key = 'key with spaces/and/slashes!@#$%^&*()' + test_value = 'Special key value' + + # Store the value with the special key 
+ await kvs.set_value(key=special_key, value=test_value) + + # Retrieve the value and verify it matches + result = await kvs.get_value(key=special_key) + assert result is not None + assert result == test_value + + # Make sure the key is properly listed + keys = await kvs.list_keys() + key_names = [k.key for k in keys] + assert special_key in key_names + + # Test key deletion + await kvs.delete_value(key=special_key) + assert await kvs.get_value(key=special_key) is None + + +async def test_data_persistence_on_reopen(configuration: Configuration) -> None: + """Test that data persists when reopening a KeyValueStore.""" + kvs1 = await KeyValueStore.open(configuration=configuration) + + await kvs1.set_value('key_123', 'value_123') + + result1 = await kvs1.get_value('key_123') + assert result1 == 'value_123' + + kvs2 = await KeyValueStore.open(configuration=configuration) + + result2 = await kvs2.get_value('key_123') + assert result2 == 'value_123' + assert await kvs1.list_keys() == await kvs2.list_keys() + + await kvs2.set_value('key_456', 'value_456') + + result1 = await kvs1.get_value('key_456') + assert result1 == 'value_456' + + +async def test_purge( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test purging a key-value store removes all values but keeps the store itself.""" + # First create a key-value store + kvs = await KeyValueStore.open( + name='purge_test_kvs', + storage_client=storage_client, + configuration=configuration, + ) + # Add some values + await kvs.set_value('key1', 'value1') + await kvs.set_value('key2', 'value2') + await kvs.set_value('key3', {'complex': 'value', 'number': 42}) -async def test_drop() -> None: - kvs1 = await KeyValueStore.open() - await kvs1.drop() - kvs2 = await KeyValueStore.open() - assert kvs1 is not kvs2 + # Verify values were added + keys = await kvs.list_keys() + assert len(keys) == 3 + # Record the store ID + kvs_id = kvs.id -async def test_get_set_value(key_value_store: KeyValueStore) -> None: - await key_value_store.set_value('test-str', 'string') - await key_value_store.set_value('test-int', 123) - await key_value_store.set_value('test-dict', {'abc': '123'}) - str_value = await key_value_store.get_value('test-str') - int_value = await key_value_store.get_value('test-int') - dict_value = await key_value_store.get_value('test-dict') - non_existent_value = await key_value_store.get_value('test-non-existent') - assert str_value == 'string' - assert int_value == 123 - assert dict_value['abc'] == '123' - assert non_existent_value is None + # Purge the key-value store + await kvs.purge() + # Verify the store still exists but is empty + assert kvs.id == kvs_id # Same ID preserved + assert kvs.name == 'purge_test_kvs' # Same name preserved -async def test_for_each_key(key_value_store: KeyValueStore) -> None: - keys = [item.key async for item in key_value_store.iterate_keys()] + # Store should be empty now + keys = await kvs.list_keys() assert len(keys) == 0 - for i in range(2001): - await key_value_store.set_value(str(i).zfill(4), i) - index = 0 - async for item in key_value_store.iterate_keys(): - assert item.key == str(index).zfill(4) - index += 1 - assert index == 2001 + # Values should no longer be accessible + assert await kvs.get_value('key1') is None + assert await kvs.get_value('key2') is None + assert await kvs.get_value('key3') is None + # Verify we can add new values after purging + await kvs.set_value('new_key', 'new value after purge') -async def test_static_get_set_value(key_value_store: KeyValueStore) 
-> None: - await key_value_store.set_value('test-static', 'static') - value = await key_value_store.get_value('test-static') - assert value == 'static' + value = await kvs.get_value('new_key') + assert value == 'new value after purge' + # Clean up + await kvs.drop() -async def test_get_public_url_raises_for_non_existing_key(key_value_store: KeyValueStore) -> None: - with pytest.raises(ValueError, match='was not found'): - await key_value_store.get_public_url('i-do-not-exist') +async def test_record_exists_nonexistent(kvs: KeyValueStore) -> None: + """Test that record_exists returns False for a nonexistent key.""" + result = await kvs.record_exists('nonexistent-key') + assert result is False -async def test_get_public_url(key_value_store: KeyValueStore) -> None: - await key_value_store.set_value('test-static', 'static') - public_url = await key_value_store.get_public_url('test-static') - url = urlparse(public_url) - path = url.netloc if url.netloc else url.path +async def test_record_exists_after_set(kvs: KeyValueStore) -> None: + """Test that record_exists returns True after setting a value.""" + test_key = 'exists-key' + test_value = {'data': 'test'} - with open(path) as f: # noqa: ASYNC230 - content = await asyncio.to_thread(f.read) - assert content == 'static' + # Initially should not exist + assert await kvs.record_exists(test_key) is False + # Set the value + await kvs.set_value(test_key, test_value) -async def test_get_auto_saved_value_default_value(key_value_store: KeyValueStore) -> None: - default_value: dict[str, JsonSerializable] = {'hello': 'world'} - value = await key_value_store.get_auto_saved_value('state', default_value) - assert value == default_value + # Now should exist + assert await kvs.record_exists(test_key) is True -async def test_get_auto_saved_value_cache_value(key_value_store: KeyValueStore) -> None: - default_value: dict[str, JsonSerializable] = {'hello': 'world'} - key_name = 'state' +async def test_record_exists_after_delete(kvs: KeyValueStore) -> None: + """Test that record_exists returns False after deleting a value.""" + test_key = 'exists-then-delete-key' + test_value = 'will be deleted' - value = await key_value_store.get_auto_saved_value(key_name, default_value) - value['hello'] = 'new_world' - value_one = await key_value_store.get_auto_saved_value(key_name) - assert value_one == {'hello': 'new_world'} + # Set a value + await kvs.set_value(test_key, test_value) + assert await kvs.record_exists(test_key) is True - value_one['hello'] = ['new_world'] - value_two = await key_value_store.get_auto_saved_value(key_name) - assert value_two == {'hello': ['new_world']} + # Delete the value + await kvs.delete_value(test_key) + # Should no longer exist + assert await kvs.record_exists(test_key) is False -async def test_get_auto_saved_value_auto_save(key_value_store: KeyValueStore, mock_event_manager: EventManager) -> None: # noqa: ARG001 - # This is not a realtime system and timing constrains can be hard to enforce. - # For the test to avoid flakiness it needs some time tolerance. 
- autosave_deadline_time = 1 - autosave_check_period = 0.01 - async def autosaved_within_deadline(key: str, expected_value: dict[str, str]) -> bool: - """Check if the `key_value_store` of `key` has expected value within `autosave_deadline_time` seconds.""" - deadline = datetime.now(tz=timezone.utc) + timedelta(seconds=autosave_deadline_time) - while datetime.now(tz=timezone.utc) < deadline: - await asyncio.sleep(autosave_check_period) - if await key_value_store.get_value(key) == expected_value: - return True - return False +async def test_record_exists_with_none_value(kvs: KeyValueStore) -> None: + """Test that record_exists returns True even when value is None.""" + test_key = 'none-value-key' - default_value: dict[str, JsonSerializable] = {'hello': 'world'} - key_name = 'state' - value = await key_value_store.get_auto_saved_value(key_name, default_value) - assert await autosaved_within_deadline(key=key_name, expected_value={'hello': 'world'}) + # Set None as value + await kvs.set_value(test_key, None) - value['hello'] = 'new_world' - assert await autosaved_within_deadline(key=key_name, expected_value={'hello': 'new_world'}) + # Should still exist even though value is None + assert await kvs.record_exists(test_key) is True + # Verify we can distinguish between None value and nonexistent key + assert await kvs.get_value(test_key) is None + assert await kvs.record_exists(test_key) is True + assert await kvs.record_exists('truly-nonexistent') is False -async def test_get_auto_saved_value_auto_save_race_conditions(key_value_store: KeyValueStore) -> None: - """Two parallel functions increment global variable obtained by `get_auto_saved_value`. - Result should be incremented by 2. - Method `get_auto_saved_value` must be implemented in a way that prevents race conditions in such scenario. - Test creates situation where first `get_auto_saved_value` call to kvs gets delayed. 
Such situation can happen - and unless handled, it can cause race condition in getting the state value.""" - await key_value_store.set_value('state', {'counter': 0}) +async def test_record_exists_different_content_types(kvs: KeyValueStore) -> None: + """Test record_exists with different content types.""" + test_cases = [ + ('json-key', {'data': 'json'}, 'application/json'), + ('text-key', 'plain text', 'text/plain'), + ('binary-key', b'binary data', 'application/octet-stream'), + ] - sleep_time_iterator = chain(iter([0.5]), repeat(0)) + for key, value, content_type in test_cases: + # Set value with specific content type + await kvs.set_value(key, value, content_type=content_type) - async def delayed_get_value(key: str, default_value: None = None) -> None: - await asyncio.sleep(next(sleep_time_iterator)) - return await KeyValueStore.get_value(key_value_store, key=key, default_value=default_value) + # Should exist regardless of content type + assert await kvs.record_exists(key) is True - async def increment_counter() -> None: - state = cast('dict[str, int]', await key_value_store.get_auto_saved_value('state')) - state['counter'] += 1 - with patch.object(key_value_store, 'get_value', delayed_get_value): - tasks = [asyncio.create_task(increment_counter()), asyncio.create_task(increment_counter())] - await asyncio.gather(*tasks) +async def test_record_exists_multiple_keys(kvs: KeyValueStore) -> None: + """Test record_exists with multiple keys and batch operations.""" + keys_and_values = [ + ('key1', 'value1'), + ('key2', {'nested': 'object'}), + ('key3', [1, 2, 3]), + ('key4', None), + ] - assert (await key_value_store.get_auto_saved_value('state'))['counter'] == 2 + # Initially, none should exist + for key, _ in keys_and_values: + assert await kvs.record_exists(key) is False + # Set all values + for key, value in keys_and_values: + await kvs.set_value(key, value) -async def test_from_storage_object() -> None: - storage_client = service_locator.get_storage_client() + # All should exist now + for key, _ in keys_and_values: + assert await kvs.record_exists(key) is True - storage_object = StorageMetadata( - id='dummy-id', - name='dummy-name', - accessed_at=datetime.now(timezone.utc), - created_at=datetime.now(timezone.utc), - modified_at=datetime.now(timezone.utc), - extra_attribute='extra', - ) + # Test some non-existent keys + assert await kvs.record_exists('nonexistent1') is False + assert await kvs.record_exists('nonexistent2') is False + + +async def test_record_exists_after_purge(kvs: KeyValueStore) -> None: + """Test that record_exists returns False after purging the store.""" + # Set some values + await kvs.set_value('key1', 'value1') + await kvs.set_value('key2', 'value2') + + # Verify they exist + assert await kvs.record_exists('key1') is True + assert await kvs.record_exists('key2') is True - key_value_store = KeyValueStore.from_storage_object(storage_client, storage_object) + # Purge the store + await kvs.purge() - assert key_value_store.id == storage_object.id - assert key_value_store.name == storage_object.name - assert key_value_store.storage_object == storage_object - assert storage_object.model_extra.get('extra_attribute') == 'extra' # type: ignore[union-attr] + # Should no longer exist + assert await kvs.record_exists('key1') is False + assert await kvs.record_exists('key2') is False diff --git a/tests/unit/storages/test_request_manager_tandem.py b/tests/unit/storages/test_request_manager_tandem.py index e38ef3d0e8..70240914ec 100644 --- 
a/tests/unit/storages/test_request_manager_tandem.py
+++ b/tests/unit/storages/test_request_manager_tandem.py
@@ -56,7 +56,7 @@ async def test_basic_functionality(test_input: TestInput) -> None:
     request_queue = await RequestQueue.open()

     if test_input.request_manager_items:
-        await request_queue.add_requests_batched(test_input.request_manager_items)
+        await request_queue.add_requests(test_input.request_manager_items)

     mock_request_loader = create_autospec(RequestLoader, instance=True, spec_set=True)
     mock_request_loader.fetch_next_request.side_effect = lambda: test_input.request_loader_items.pop(0)
diff --git a/tests/unit/storages/test_request_queue.py b/tests/unit/storages/test_request_queue.py
index cddba8ef99..8df759a27f 100644
--- a/tests/unit/storages/test_request_queue.py
+++ b/tests/unit/storages/test_request_queue.py
@@ -1,367 +1,646 @@
+# TODO: Update crawlee_storage_dir args once the Pydantic bug is fixed
+# https://github.com/apify/crawlee-python/issues/146
+
 from __future__ import annotations

 import asyncio
-from datetime import datetime, timedelta, timezone
-from itertools import count
 from typing import TYPE_CHECKING
-from unittest.mock import AsyncMock, MagicMock

 import pytest
-from pydantic import ValidationError

 from crawlee import Request, service_locator
-from crawlee._request import RequestState
-from crawlee.storage_clients import MemoryStorageClient, StorageClient
-from crawlee.storage_clients._memory import RequestQueueClient
-from crawlee.storage_clients.models import (
-    BatchRequestsOperationResponse,
-    StorageMetadata,
-    UnprocessedRequest,
-)
+from crawlee.configuration import Configuration
+from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient
 from crawlee.storages import RequestQueue

 if TYPE_CHECKING:
-    from collections.abc import AsyncGenerator, Sequence
+    from collections.abc import AsyncGenerator
+    from pathlib import Path
+
+    from crawlee.storage_clients import StorageClient
+
+
+@pytest.fixture(params=['memory', 'file_system'])
+def storage_client(request: pytest.FixtureRequest) -> StorageClient:
+    """Parameterized fixture to test with different storage clients."""
+    if request.param == 'memory':
+        return MemoryStorageClient()
+
+    return FileSystemStorageClient()


 @pytest.fixture
-async def request_queue() -> AsyncGenerator[RequestQueue, None]:
-    rq = await RequestQueue.open()
+def configuration(tmp_path: Path) -> Configuration:
+    """Provide a configuration with a temporary storage directory."""
+    return Configuration(
+        crawlee_storage_dir=str(tmp_path),  # type: ignore[call-arg]
+        purge_on_start=True,
+    )
+
+
+@pytest.fixture
+async def rq(
+    storage_client: StorageClient,
+    configuration: Configuration,
+) -> AsyncGenerator[RequestQueue, None]:
+    """Fixture that provides a request queue instance for each test."""
+    rq = await RequestQueue.open(
+        storage_client=storage_client,
+        configuration=configuration,
+    )
+
     yield rq
     await rq.drop()


-async def test_open() -> None:
-    default_request_queue = await RequestQueue.open()
-    default_request_queue_by_id = await RequestQueue.open(id=default_request_queue.id)
+async def test_open_creates_new_rq(
+    storage_client: StorageClient,
+    configuration: Configuration,
+) -> None:
+    """Test that open() creates a new request queue with proper metadata."""
+    rq = await RequestQueue.open(
+        name='new_request_queue',
+        storage_client=storage_client,
+        configuration=configuration,
+    )

-    assert default_request_queue is default_request_queue_by_id
+    # Verify request queue properties
+    
assert rq.id is not None + assert rq.name == 'new_request_queue' + metadata = await rq.get_metadata() + assert metadata.pending_request_count == 0 + assert metadata.handled_request_count == 0 + assert metadata.total_request_count == 0 - request_queue_name = 'dummy-name' - named_request_queue = await RequestQueue.open(name=request_queue_name) - assert default_request_queue is not named_request_queue + await rq.drop() - with pytest.raises(RuntimeError, match='RequestQueue with id "nonexistent-id" does not exist!'): - await RequestQueue.open(id='nonexistent-id') - # Test that when you try to open a request queue by ID and you use a name of an existing request queue, - # it doesn't work - with pytest.raises(RuntimeError, match='RequestQueue with id "dummy-name" does not exist!'): - await RequestQueue.open(id='dummy-name') +async def test_open_existing_rq( + rq: RequestQueue, + storage_client: StorageClient, +) -> None: + """Test that open() loads an existing request queue correctly.""" + # Open the same request queue again + reopened_rq = await RequestQueue.open( + name=rq.name, + storage_client=storage_client, + ) + + # Verify request queue properties + assert rq.id == reopened_rq.id + assert rq.name == reopened_rq.name + # Verify they are the same object (from cache) + assert id(rq) == id(reopened_rq) -async def test_consistency_accross_two_clients() -> None: - request_apify = Request.from_url('https://apify.com') - request_crawlee = Request.from_url('https://crawlee.dev') - rq = await RequestQueue.open(name='my-rq') - await rq.add_request(request_apify) +async def test_open_with_id_and_name( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test that open() raises an error when both id and name are provided.""" + with pytest.raises(ValueError, match='Only one of "id" or "name" can be specified'): + await RequestQueue.open( + id='some-id', + name='some-name', + storage_client=storage_client, + configuration=configuration, + ) + - rq_by_id = await RequestQueue.open(id=rq.id) - await rq_by_id.add_request(request_crawlee) +async def test_open_by_id( + storage_client: StorageClient, + configuration: Configuration, +) -> None: + """Test opening a request queue by its ID.""" + # First create a request queue by name + rq1 = await RequestQueue.open( + name='rq_by_id_test', + storage_client=storage_client, + configuration=configuration, + ) - assert await rq.get_total_count() == 2 - assert await rq_by_id.get_total_count() == 2 + # Add a request to identify it + await rq1.add_request('https://example.com/open-by-id-test') - assert await rq.fetch_next_request() == request_apify - assert await rq_by_id.fetch_next_request() == request_crawlee + # Open the request queue by ID + rq2 = await RequestQueue.open( + id=rq1.id, + storage_client=storage_client, + configuration=configuration, + ) - await rq.drop() - with pytest.raises(RuntimeError, match='Storage with provided ID was not found'): - await rq_by_id.drop() + # Verify it's the same request queue + assert rq2.id == rq1.id + assert rq2.name == 'rq_by_id_test' + + # Verify the request is still there + request = await rq2.fetch_next_request() + assert request is not None + assert request.url == 'https://example.com/open-by-id-test' + + # Clean up + await rq2.drop() + + +async def test_add_request_string_url(rq: RequestQueue) -> None: + """Test adding a request with a string URL.""" + # Add a request with a string URL + url = 'https://example.com' + result = await rq.add_request(url) + + # Verify request was added + assert 
result.id is not None + assert result.unique_key is not None + assert result.was_already_present is False + assert result.was_already_handled is False + + # Verify the queue stats were updated + metadata = await rq.get_metadata() + assert metadata.total_request_count == 1 + assert metadata.pending_request_count == 1 + + +async def test_add_request_object(rq: RequestQueue) -> None: + """Test adding a request object.""" + # Create and add a request object + request = Request.from_url(url='https://example.com', user_data={'key': 'value'}) + result = await rq.add_request(request) + + # Verify request was added + assert result.id is not None + assert result.unique_key is not None + assert result.was_already_present is False + assert result.was_already_handled is False + + # Verify the queue stats were updated + metadata = await rq.get_metadata() + assert metadata.total_request_count == 1 + assert metadata.pending_request_count == 1 + + +async def test_add_duplicate_request(rq: RequestQueue) -> None: + """Test adding a duplicate request to the queue.""" + # Add a request + url = 'https://example.com' + first_result = await rq.add_request(url) + + # Add the same request again + second_result = await rq.add_request(url) + + # Verify the second request was detected as duplicate + assert second_result.was_already_present is True + assert second_result.unique_key == first_result.unique_key + + # Verify the queue stats weren't incremented twice + metadata = await rq.get_metadata() + assert metadata.total_request_count == 1 + assert metadata.pending_request_count == 1 + + +async def test_add_requests_batch(rq: RequestQueue) -> None: + """Test adding multiple requests in a batch.""" + # Create a batch of requests + urls = [ + 'https://example.com/page1', + 'https://example.com/page2', + 'https://example.com/page3', + ] + + # Add the requests + await rq.add_requests(urls) + + # Wait for all background tasks to complete + await asyncio.sleep(0.1) + + # Verify the queue stats + metadata = await rq.get_metadata() + assert metadata.total_request_count == 3 + assert metadata.pending_request_count == 3 + + +async def test_add_requests_batch_with_forefront(rq: RequestQueue) -> None: + """Test adding multiple requests in a batch with forefront option.""" + # Add some initial requests + await rq.add_request('https://example.com/page1') + await rq.add_request('https://example.com/page2') + + # Add a batch of priority requests at the forefront + + await rq.add_requests( + [ + 'https://example.com/priority1', + 'https://example.com/priority2', + 'https://example.com/priority3', + ], + forefront=True, + ) + # Wait for all background tasks to complete + await asyncio.sleep(0.1) -async def test_same_references() -> None: - rq1 = await RequestQueue.open() - rq2 = await RequestQueue.open() - assert rq1 is rq2 + # Fetch requests - they should come out in priority order first + next_request1 = await rq.fetch_next_request() + assert next_request1 is not None + assert next_request1.url.startswith('https://example.com/priority') - rq_name = 'non-default' - rq_named1 = await RequestQueue.open(name=rq_name) - rq_named2 = await RequestQueue.open(name=rq_name) - assert rq_named1 is rq_named2 + next_request2 = await rq.fetch_next_request() + assert next_request2 is not None + assert next_request2.url.startswith('https://example.com/priority') + next_request3 = await rq.fetch_next_request() + assert next_request3 is not None + assert next_request3.url.startswith('https://example.com/priority') -async def test_drop() -> None: - rq1 
= await RequestQueue.open() - await rq1.drop() - rq2 = await RequestQueue.open() - assert rq1 is not rq2 + # Now we should get the original requests + next_request4 = await rq.fetch_next_request() + assert next_request4 is not None + assert next_request4.url == 'https://example.com/page1' + next_request5 = await rq.fetch_next_request() + assert next_request5 is not None + assert next_request5.url == 'https://example.com/page2' -async def test_get_request(request_queue: RequestQueue) -> None: - request = Request.from_url('https://example.com') - processed_request = await request_queue.add_request(request) - assert request.id == processed_request.id - request_2 = await request_queue.get_request(request.id) - assert request_2 is not None - assert request == request_2 + # Queue should be empty now + next_request6 = await rq.fetch_next_request() + assert next_request6 is None -async def test_add_fetch_handle_request(request_queue: RequestQueue) -> None: - request = Request.from_url('https://example.com') - assert await request_queue.is_empty() is True - add_request_info = await request_queue.add_request(request) +async def test_add_requests_with_forefront(rq: RequestQueue) -> None: + """Test adding requests to the front of the queue.""" + # Add some initial requests + await rq.add_request('https://example.com/page1') + await rq.add_request('https://example.com/page2') - assert add_request_info.was_already_present is False - assert add_request_info.was_already_handled is False - assert await request_queue.is_empty() is False + # Add a priority request at the forefront + await rq.add_request('https://example.com/priority', forefront=True) - # Fetch the request - next_request = await request_queue.fetch_next_request() + # Fetch the next request - should be the priority one + next_request = await rq.fetch_next_request() assert next_request is not None + assert next_request.url == 'https://example.com/priority' - # Mark it as handled - next_request.handled_at = datetime.now(timezone.utc) - processed_request = await request_queue.mark_request_as_handled(next_request) - assert processed_request is not None - assert processed_request.id == request.id - assert processed_request.unique_key == request.unique_key - assert await request_queue.is_finished() is True +async def test_add_requests_mixed_forefront(rq: RequestQueue) -> None: + """Test the ordering when adding requests with mixed forefront values.""" + # Add normal requests + await rq.add_request('https://example.com/normal1') + await rq.add_request('https://example.com/normal2') + # Add a batch with forefront=True + await rq.add_requests( + ['https://example.com/priority1', 'https://example.com/priority2'], + forefront=True, + ) -async def test_reclaim_request(request_queue: RequestQueue) -> None: - request = Request.from_url('https://example.com') - await request_queue.add_request(request) + # Add another normal request + await rq.add_request('https://example.com/normal3') - # Fetch the request - next_request = await request_queue.fetch_next_request() - assert next_request is not None - assert next_request.unique_key == request.url - - # Reclaim - await request_queue.reclaim_request(next_request) - # Try to fetch again after a few secs - await asyncio.sleep(4) # 3 seconds is the consistency delay in request queue - next_again = await request_queue.fetch_next_request() - - assert next_again is not None - assert next_again.id == request.id - assert next_again.unique_key == request.unique_key - - -@pytest.mark.parametrize( - 'requests', - [ - 
[Request.from_url('https://apify.com')], - ['https://crawlee.dev'], - [Request.from_url(f'https://example.com/{i}') for i in range(10)], - [f'https://example.com/{i}' for i in range(15)], - ], - ids=['single-request', 'single-url', 'multiple-requests', 'multiple-urls'], -) -async def test_add_batched_requests( - request_queue: RequestQueue, - requests: Sequence[str | Request], -) -> None: - request_count = len(requests) + # Add another priority request + await rq.add_request('https://example.com/priority3', forefront=True) - # Add the requests to the RQ in batches - await request_queue.add_requests_batched(requests, wait_for_all_requests_to_be_added=True) + # Wait for background tasks + await asyncio.sleep(0.1) - # Ensure the batch was processed correctly - assert await request_queue.get_total_count() == request_count + # The expected order should be: + # 1. priority3 (most recent forefront) + # 2. priority1 (from batch, forefront) + # 3. priority2 (from batch, forefront) + # 4. normal1 (oldest normal) + # 5. normal2 + # 6. normal3 (newest normal) - # Fetch and validate each request in the queue - for original_request in requests: - next_request = await request_queue.fetch_next_request() - assert next_request is not None + requests = [] + while True: + req = await rq.fetch_next_request() + if req is None: + break + requests.append(req) + await rq.mark_request_as_handled(req) - expected_url = original_request if isinstance(original_request, str) else original_request.url - assert next_request.url == expected_url + assert len(requests) == 6 + assert requests[0].url == 'https://example.com/priority3' - # Confirm the queue is empty after processing all requests - assert await request_queue.is_empty() is True + # The next two should be from the forefront batch (exact order within batch may vary) + batch_urls = {requests[1].url, requests[2].url} + assert 'https://example.com/priority1' in batch_urls + assert 'https://example.com/priority2' in batch_urls + # Then the normal requests in order + assert requests[3].url == 'https://example.com/normal1' + assert requests[4].url == 'https://example.com/normal2' + assert requests[5].url == 'https://example.com/normal3' -async def test_invalid_user_data_serialization() -> None: - with pytest.raises(ValidationError): - Request.from_url( - 'https://crawlee.dev', - user_data={ - 'foo': datetime(year=2020, month=7, day=4, tzinfo=timezone.utc), - 'bar': {datetime(year=2020, month=4, day=7, tzinfo=timezone.utc)}, - }, - ) +async def test_fetch_next_request_and_mark_handled(rq: RequestQueue) -> None: + """Test fetching and marking requests as handled.""" + # Add some requests + await rq.add_request('https://example.com/page1') + await rq.add_request('https://example.com/page2') -async def test_user_data_serialization(request_queue: RequestQueue) -> None: - request = Request.from_url( - 'https://crawlee.dev', - user_data={ - 'hello': 'world', - 'foo': 42, - }, - ) + # Fetch first request + request1 = await rq.fetch_next_request() + assert request1 is not None + assert request1.url == 'https://example.com/page1' - await request_queue.add_request(request) + # Mark the request as handled + result = await rq.mark_request_as_handled(request1) + assert result is not None + assert result.was_already_handled is True - dequeued_request = await request_queue.fetch_next_request() - assert dequeued_request is not None + # Fetch next request + request2 = await rq.fetch_next_request() + assert request2 is not None + assert request2.url == 'https://example.com/page2' - assert 
dequeued_request.user_data['hello'] == 'world'
-    assert dequeued_request.user_data['foo'] == 42
+    # Mark the second request as handled
+    await rq.mark_request_as_handled(request2)

+    # Verify counts
+    metadata = await rq.get_metadata()
+    assert metadata.total_request_count == 2
+    assert metadata.handled_request_count == 2
+    assert metadata.pending_request_count == 0

-async def test_complex_user_data_serialization(request_queue: RequestQueue) -> None:
-    request = Request.from_url('https://crawlee.dev')
-    request.user_data['hello'] = 'world'
-    request.user_data['foo'] = 42
-    request.crawlee_data.max_retries = 1
-    request.crawlee_data.state = RequestState.ERROR_HANDLER
+    # Verify queue is empty
+    empty_request = await rq.fetch_next_request()
+    assert empty_request is None

-    await request_queue.add_request(request)
-    dequeued_request = await request_queue.fetch_next_request()
-    assert dequeued_request is not None

+async def test_get_request_by_id(rq: RequestQueue) -> None:
+    """Test retrieving a request by its ID."""
+    # Add a request
+    added_result = await rq.add_request('https://example.com')
+    request_id = added_result.id

-    data = dequeued_request.model_dump(by_alias=True)
-    assert data['userData']['hello'] == 'world'
-    assert data['userData']['foo'] == 42
-    assert data['userData']['__crawlee'] == {
-        'maxRetries': 1,
-        'state': RequestState.ERROR_HANDLER,
-    }
+    # Retrieve the request by ID
+    retrieved_request = await rq.get_request(request_id)
+    assert retrieved_request is not None
+    assert retrieved_request.id == request_id
+    assert retrieved_request.url == 'https://example.com'

-async def test_deduplication_of_requests_with_custom_unique_key() -> None:
-    with pytest.raises(ValueError, match='`always_enqueue` cannot be used with a custom `unique_key`'):
-        Request.from_url('https://apify.com', unique_key='apify', always_enqueue=True)
+async def test_get_non_existent_request(rq: RequestQueue) -> None:
+    """Test retrieving a request that doesn't exist."""
+    non_existent_request = await rq.get_request('non-existent-id')
+    assert non_existent_request is None

-async def test_deduplication_of_requests_with_invalid_custom_unique_key() -> None:
-    request_1 = Request.from_url('https://apify.com', always_enqueue=True)
-    request_2 = Request.from_url('https://apify.com', always_enqueue=True)
+async def test_reclaim_request(rq: RequestQueue) -> None:
+    """Test reclaiming a request that failed processing."""
+    # Add a request
+    await rq.add_request('https://example.com')

-    rq = await RequestQueue.open(name='my-rq')
-    await rq.add_request(request_1)
-    await rq.add_request(request_2)
+    # Fetch the request
+    request = await rq.fetch_next_request()
+    assert request is not None

-    assert await rq.get_total_count() == 2
+    # Reclaim the request
+    result = await rq.reclaim_request(request)
+    assert result is not None
+    assert result.was_already_handled is False

-    assert await rq.fetch_next_request() == request_1
-    assert await rq.fetch_next_request() == request_2
+    # Verify we can fetch it again
+    reclaimed_request = await rq.fetch_next_request()
+    assert reclaimed_request is not None
+    assert reclaimed_request.id == request.id
+    assert reclaimed_request.url == 'https://example.com'

-async def test_deduplication_of_requests_with_valid_custom_unique_key() -> None:
-    request_1 = Request.from_url('https://apify.com')
-    request_2 = Request.from_url('https://apify.com')
+async def test_reclaim_request_with_forefront(rq: RequestQueue) -> None:
+    """Test reclaiming a request to the front of the queue."""
+    # Add requests
+    await rq.add_request('https://example.com/first')
+    await rq.add_request('https://example.com/second')

-    rq = await RequestQueue.open(name='my-rq')
-    await rq.add_request(request_1)
-    await rq.add_request(request_2)
+    # Fetch the first request
+    first_request = await rq.fetch_next_request()
+    assert first_request is not None
+    assert first_request.url == 'https://example.com/first'

-    assert await rq.get_total_count() == 1
+    # Reclaim it to the forefront
+    await rq.reclaim_request(first_request, forefront=True)

-    assert await rq.fetch_next_request() == request_1
+    # The reclaimed request should be returned first (before the second request)
+    next_request = await rq.fetch_next_request()
+    assert next_request is not None
+    assert next_request.url == 'https://example.com/first'

-async def test_cache_requests(request_queue: RequestQueue) -> None:
-    request_1 = Request.from_url('https://apify.com')
-    request_2 = Request.from_url('https://crawlee.dev')
+async def test_is_empty(rq: RequestQueue) -> None:
+    """Test checking if a request queue is empty."""
+    # Initially the queue should be empty
+    assert await rq.is_empty() is True

-    await request_queue.add_request(request_1)
-    await request_queue.add_request(request_2)
+    # Add a request
+    await rq.add_request('https://example.com')
+    assert await rq.is_empty() is False

-    assert request_queue._requests_cache.currsize == 2
+    # Fetch and handle the request
+    request = await rq.fetch_next_request()

-    fetched_request = await request_queue.fetch_next_request()
+    assert request is not None
+    await rq.mark_request_as_handled(request)

-    assert fetched_request is not None
-    assert fetched_request.id == request_1.id
+    # Queue should be empty again
+    assert await rq.is_empty() is True

-    # After calling fetch_next_request request_1 moved to the end of the cache store.
-    cached_items = [request_queue._requests_cache.popitem()[0] for _ in range(2)]
-    assert cached_items == [request_2.id, request_1.id]
+async def test_is_finished(rq: RequestQueue) -> None:
+    """Test checking if a request queue is finished."""
+    # Initially the queue should be finished (empty and no background tasks)
+    assert await rq.is_finished() is True

-async def test_from_storage_object() -> None:
-    storage_client = service_locator.get_storage_client()
+    # Add a request
+    await rq.add_request('https://example.com')
+    assert await rq.is_finished() is False

-    storage_object = StorageMetadata(
-        id='dummy-id',
-        name='dummy-name',
-        accessed_at=datetime.now(timezone.utc),
-        created_at=datetime.now(timezone.utc),
-        modified_at=datetime.now(timezone.utc),
-        extra_attribute='extra',
+    # Add requests in the background
+    await rq.add_requests(
+        ['https://example.com/1', 'https://example.com/2'],
+        wait_for_all_requests_to_be_added=False,
     )

-    request_queue = RequestQueue.from_storage_object(storage_client, storage_object)
-
-    assert request_queue.id == storage_object.id
-    assert request_queue.name == storage_object.name
-    assert request_queue.storage_object == storage_object
-    assert storage_object.model_extra.get('extra_attribute') == 'extra'  # type: ignore[union-attr]
-
-
-async def test_add_batched_requests_with_retry(request_queue: RequestQueue) -> None:
-    """Test that unprocessed requests are retried.
-
-    Unprocessed requests should not count in `get_total_count`
-    Test creates situation where in `batch_add_requests` call in first batch 3 requests are unprocessed.
-    On each following `batch_add_requests` call the last request in batch remains unprocessed.
-    In this test `batch_add_requests` is called once with batch of 10 requests. With retries only 1 request should
-    remain unprocessed."""
-
-    batch_add_requests_call_counter = count(start=1)
-    service_locator.get_storage_client()
-    initial_request_count = 10
-    expected_added_requests = 9
-    requests = [f'https://example.com/{i}' for i in range(initial_request_count)]
-
-    class MockedRequestQueueClient(RequestQueueClient):
-        """Patched memory storage client that simulates unprocessed requests."""
-
-        async def _batch_add_requests_without_last_n(
-            self, batch: Sequence[Request], n: int = 0
-        ) -> BatchRequestsOperationResponse:
-            response = await super().batch_add_requests(batch[:-n])
-            response.unprocessed_requests = [
-                UnprocessedRequest(url=r.url, unique_key=r.unique_key, method=r.method) for r in batch[-n:]
-            ]
-            return response
-
-        async def batch_add_requests(
-            self,
-            requests: Sequence[Request],
-            *,
-            forefront: bool = False,  # noqa: ARG002
-        ) -> BatchRequestsOperationResponse:
-            """Mocked client behavior that simulates unprocessed requests.
-
-            It processes all except last three at first run, then all except last none.
-            Overall if tried with the same batch it will process all except the last one.
-            """
-            call_count = next(batch_add_requests_call_counter)
-            if call_count == 1:
-                # Process all but last three
-                return await self._batch_add_requests_without_last_n(requests, n=3)
-            # Process all but last
-            return await self._batch_add_requests_without_last_n(requests, n=1)
-
-    mocked_storage_client = AsyncMock(spec=StorageClient)
-    mocked_storage_client.request_queue = MagicMock(
-        return_value=MockedRequestQueueClient(id='default', memory_storage_client=MemoryStorageClient.from_config())
+    # Queue shouldn't be finished while background tasks are running
+    assert await rq.is_finished() is False
+
+    # Wait for background tasks to finish
+    await asyncio.sleep(0.2)
+
+    # Process all requests
+    while True:
+        request = await rq.fetch_next_request()
+        if request is None:
+            break
+        await rq.mark_request_as_handled(request)
+
+    # Now queue should be finished
+    assert await rq.is_finished() is True
+
+
+async def test_mark_non_existent_request_as_handled(rq: RequestQueue) -> None:
+    """Test marking a non-existent request as handled."""
+    # Create a request that hasn't been added to the queue
+    request = Request.from_url(url='https://example.com', id='non-existent-id')
+
+    # Attempt to mark it as handled
+    result = await rq.mark_request_as_handled(request)
+    assert result is None
+
+
+async def test_reclaim_non_existent_request(rq: RequestQueue) -> None:
+    """Test reclaiming a non-existent request."""
+    # Create a request that hasn't been added to the queue
+    request = Request.from_url(url='https://example.com', id='non-existent-id')
+
+    # Attempt to reclaim it
+    result = await rq.reclaim_request(request)
+    assert result is None
+
+
+async def test_drop(
+    storage_client: StorageClient,
+    configuration: Configuration,
+) -> None:
+    """Test dropping a request queue removes it from cache and clears its data."""
+    rq = await RequestQueue.open(
+        name='drop_test',
+        storage_client=storage_client,
+        configuration=configuration,
     )

-    request_queue = RequestQueue(id='default', name='some_name', storage_client=mocked_storage_client)
+    # Add a request
+    await rq.add_request('https://example.com')

-    # Add the requests to the RQ in batches
-    await request_queue.add_requests_batched(
-        requests, wait_for_all_requests_to_be_added=True, wait_time_between_batches=timedelta(0)
+    # Drop the request queue
+    await rq.drop()
+
+    # Verify request queue is empty (by creating a new one with the same name)
+    new_rq = await RequestQueue.open(
+        name='drop_test',
+        storage_client=storage_client,
+        configuration=configuration,
     )

-    # Ensure the batch was processed correctly
-    assert await request_queue.get_total_count() == expected_added_requests
-    # Fetch and validate each request in the queue
-    for original_request in requests[:expected_added_requests]:
-        next_request = await request_queue.fetch_next_request()
-        assert next_request is not None
+    # Verify the queue is empty
+    assert await new_rq.is_empty() is True
+    metadata = await new_rq.get_metadata()
+    assert metadata.total_request_count == 0
+    assert metadata.pending_request_count == 0
+    await new_rq.drop()
+
+
+async def test_reopen_default(
+    storage_client: StorageClient,
+    configuration: Configuration,
+) -> None:
+    """Test reopening the default request queue."""
+    # First clean up any storage instance caches
+    storage_instance_manager = service_locator.storage_instance_manager
+    storage_instance_manager.clear_cache()
+
+    # Open the default request queue
+    rq1 = await RequestQueue.open(
+        storage_client=storage_client,
+        configuration=configuration,
+    )
+
+    # If a request queue already exists (due to previous test run), purge it to start fresh
+    try:
+        await rq1.purge()
+    except Exception:
+        # If purge fails, try dropping and recreating
+        await rq1.drop()
+        rq1 = await RequestQueue.open(
+            storage_client=storage_client,
+            configuration=configuration,
+        )
+
+    # Verify we're starting fresh
+    metadata1 = await rq1.get_metadata()
+    assert metadata1.pending_request_count == 0
+
+    # Add a request
+    await rq1.add_request('https://example.com/')

-        expected_url = original_request if isinstance(original_request, str) else original_request.url
-        assert next_request.url == expected_url
+    # Verify the request was added
+    metadata1 = await rq1.get_metadata()
+    assert metadata1.pending_request_count == 1

-    # Confirm the queue is empty after processing all requests
-    assert await request_queue.is_empty() is True
+    # Open the default request queue again
+    rq2 = await RequestQueue.open(
+        storage_client=storage_client,
+        configuration=configuration,
+    )
+
+    # Verify they are the same queue
+    assert rq1.id == rq2.id
+    assert rq1.name == rq2.name
+    metadata1 = await rq1.get_metadata()
+    metadata2 = await rq2.get_metadata()
+    assert metadata1.total_request_count == metadata2.total_request_count
+    assert metadata1.pending_request_count == metadata2.pending_request_count
+    assert metadata1.handled_request_count == metadata2.handled_request_count
+
+    # Verify the request is accessible
+    request = await rq2.fetch_next_request()
+    assert request is not None
+    assert request.url == 'https://example.com/'
+
+    # Clean up after the test
+    await rq1.drop()
+
+
+async def test_purge(
+    storage_client: StorageClient,
+    configuration: Configuration,
+) -> None:
+    """Test purging a request queue removes all requests but keeps the queue itself."""
+    # First create a request queue
+    rq = await RequestQueue.open(
+        name='purge_test_queue',
+        storage_client=storage_client,
+        configuration=configuration,
+    )
+
+    # Add some requests
+    await rq.add_requests(
+        [
+            'https://example.com/page1',
+            'https://example.com/page2',
+            'https://example.com/page3',
+        ]
+    )
+
+    # Verify requests were added
+    metadata = await rq.get_metadata()
+    assert metadata.total_request_count == 3
+    assert metadata.pending_request_count == 3
+    assert metadata.handled_request_count == 0
+
+    # Record the queue ID
+    queue_id = rq.id
+
+    # Purge the queue
+    await rq.purge()
+
+    # Verify the queue still exists but is empty
+    assert rq.id == queue_id  # Same ID preserved
+    assert rq.name == 'purge_test_queue'  # Same name preserved
+
+    # Queue should be empty now
+    metadata = await rq.get_metadata()
+    assert metadata.total_request_count == 3
+    assert metadata.pending_request_count == 0
+    assert metadata.handled_request_count == 0
+    assert await rq.is_empty() is True
+
+    # Verify we can add new requests after purging
+    await rq.add_request('https://example.com/new-after-purge')
+
+    request = await rq.fetch_next_request()
+    assert request is not None
+    assert request.url == 'https://example.com/new-after-purge'
+
+    # Clean up
+    await rq.drop()
diff --git a/tests/unit/test_configuration.py b/tests/unit/test_configuration.py
index 73e17d50d9..f89401e5be 100644
--- a/tests/unit/test_configuration.py
+++ b/tests/unit/test_configuration.py
@@ -9,6 +9,7 @@
 from crawlee.configuration import Configuration
 from crawlee.crawlers import HttpCrawler, HttpCrawlingContext
 from crawlee.storage_clients import MemoryStorageClient
+from crawlee.storage_clients._file_system._storage_client import FileSystemStorageClient

 if TYPE_CHECKING:
     from pathlib import Path
@@ -35,14 +36,15 @@ def test_global_configuration_works_reversed() -> None:


 async def test_storage_not_persisted_when_disabled(tmp_path: Path, server_url: URL) -> None:
-    config = Configuration(
-        persist_storage=False,
-        write_metadata=False,
+    configuration = Configuration(
         crawlee_storage_dir=str(tmp_path),  # type: ignore[call-arg]
     )
-    storage_client = MemoryStorageClient.from_config(config)
+    storage_client = MemoryStorageClient()

-    crawler = HttpCrawler(storage_client=storage_client)
+    crawler = HttpCrawler(
+        configuration=configuration,
+        storage_client=storage_client,
+    )

     @crawler.router.default_handler
     async def default_handler(context: HttpCrawlingContext) -> None:
@@ -56,14 +58,16 @@ async def default_handler(context: HttpCrawlingContext) -> None:


 async def test_storage_persisted_when_enabled(tmp_path: Path, server_url: URL) -> None:
-    config = Configuration(
-        persist_storage=True,
-        write_metadata=True,
+    configuration = Configuration(
         crawlee_storage_dir=str(tmp_path),  # type: ignore[call-arg]
     )
-    storage_client = MemoryStorageClient.from_config(config)
-    crawler = HttpCrawler(storage_client=storage_client)
+    storage_client = FileSystemStorageClient()
+
+    crawler = HttpCrawler(
+        configuration=configuration,
+        storage_client=storage_client,
+    )

     @crawler.router.default_handler
     async def default_handler(context: HttpCrawlingContext) -> None:
diff --git a/tests/unit/test_service_locator.py b/tests/unit/test_service_locator.py
index 50da5ddb86..a4ed0620dd 100644
--- a/tests/unit/test_service_locator.py
+++ b/tests/unit/test_service_locator.py
@@ -6,7 +6,7 @@
 from crawlee.configuration import Configuration
 from crawlee.errors import ServiceConflictError
 from crawlee.events import LocalEventManager
-from crawlee.storage_clients import MemoryStorageClient
+from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient


 def test_default_configuration() -> None:
@@ -72,21 +72,21 @@ def test_event_manager_conflict() -> None:

 def test_default_storage_client() -> None:
     default_storage_client = service_locator.get_storage_client()
-    assert isinstance(default_storage_client, MemoryStorageClient)
+    assert isinstance(default_storage_client, FileSystemStorageClient)


 def test_custom_storage_client() -> None:
-    custom_storage_client = MemoryStorageClient.from_config()
+    custom_storage_client = MemoryStorageClient()
     service_locator.set_storage_client(custom_storage_client)
     storage_client = service_locator.get_storage_client()
     assert storage_client is custom_storage_client


 def test_storage_client_overwrite() -> None:
-    custom_storage_client = MemoryStorageClient.from_config()
+    custom_storage_client = MemoryStorageClient()
     service_locator.set_storage_client(custom_storage_client)

-    another_custom_storage_client = MemoryStorageClient.from_config()
+    another_custom_storage_client = MemoryStorageClient()
     service_locator.set_storage_client(another_custom_storage_client)

     assert custom_storage_client != another_custom_storage_client
@@ -95,7 +95,7 @@ def test_storage_client_overwrite() -> None:

 def test_storage_client_conflict() -> None:
     service_locator.get_storage_client()
-    custom_storage_client = MemoryStorageClient.from_config()
+    custom_storage_client = MemoryStorageClient()

     with pytest.raises(ServiceConflictError, match='StorageClient is already in use.'):
         service_locator.set_storage_client(custom_storage_client)
diff --git a/website/generate_module_shortcuts.py b/website/generate_module_shortcuts.py
index 5a18e8d3f3..61acc68ade 100755
--- a/website/generate_module_shortcuts.py
+++ b/website/generate_module_shortcuts.py
@@ -5,6 +5,7 @@
 import importlib
 import inspect
 import json
+from pathlib import Path
 from typing import TYPE_CHECKING

 if TYPE_CHECKING:
@@ -55,5 +56,5 @@ def resolve_shortcuts(shortcuts: dict) -> None:

 resolve_shortcuts(shortcuts)

-with open('module_shortcuts.json', 'w', encoding='utf-8') as shortcuts_file:
+with Path('module_shortcuts.json').open('w', encoding='utf-8') as shortcuts_file:
     json.dump(shortcuts, shortcuts_file, indent=4, sort_keys=True)