diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 07358572..c463d658 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -18,10 +18,10 @@ jobs: with: python-version: '3.14.0-beta.1' - - name: unittest + - name: pytest run: | python -m pip install -r requirements-dev.txt - coverage run -m unittest + coverage run -m pytest -s - name: Generate coverage report run: coverage report -m --fail-under=100 diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 725e09e9..c073f9fc 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -13,7 +13,7 @@ jobs: strategy: matrix: - python-version: ['3.7.17', '3.8.18', '3.9.18', '3.10.13', '3.11.7', '3.12.1', '3.13.1', '3.14.0-beta.1'] + python-version: ['3.7.17', '3.8.18', '3.9.18', '3.10.13', '3.11.7', '3.12.1', '3.13.1', '3.14.0', '3.15.0-alpha.1'] steps: - uses: actions/checkout@v4 @@ -22,7 +22,7 @@ jobs: with: python-version: ${{ matrix.python-version }} - - name: unittest + - name: pytest run: | python -m pip install -r tests/requirements.txt - python -m unittest + python -m pytest -s diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a23e252b..2182fd28 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -11,7 +11,7 @@ make lint ## run a specific test ```bash -python -m unittest tests.test_stream.TestStream.test_distinct +pytest -s -k test_etl_example ``` ## changelog diff --git a/Makefile b/Makefile index e0bde1b8..0216698d 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,7 @@ help: @echo "Available commands:" @echo " make all - Run all tasks: venv, test, type-check, format" @echo " make venv - Create a virtual environment and install dependencies" - @echo " make test - Run unittests and check coverage" + @echo " make test - Run unit tests and check coverage" @echo " make type-check - Check typing via mypy" @echo " make format - Format via ruff" @echo " make format-check - Check the formatting via ruff" @@ -18,7 +18,7 @@ venv: $(VENV_DIR)/bin/pip install -r requirements-dev.txt test: - $(VENV_DIR)/bin/python -m coverage run -m unittest -v --failfast + $(VENV_DIR)/bin/python -m coverage run -m pytest -sx $(VENV_DIR)/bin/coverage report -m $(VENV_DIR)/bin/coverage html diff --git a/README.md b/README.md index b7618f10..b33e103f 100644 --- a/README.md +++ b/README.md @@ -64,55 +64,55 @@ This toy script gets Pokémons concurrently from [PokéAPI](https://pokeapi.co/) ```python import csv from datetime import timedelta -import itertools -import requests +from itertools import count +import httpx from streamable import Stream with open("./quadruped_pokemons.csv", mode="w") as file: fields = ["id", "name", "is_legendary", "base_happiness", "capture_rate"] writer = csv.DictWriter(file, fields, extrasaction='ignore') writer.writeheader() + with httpx.Client() as http_client: + pipeline = ( + # Infinite Stream[int] of Pokemon ids starting from Pokémon #1: Bulbasaur + Stream(count(1)) + # Limit to 16 requests per second to be friendly to our fellow PokéAPI devs + .throttle(16, per=timedelta(seconds=1)) + # GET pokemons concurrently using a pool of 8 threads + .map(lambda poke_id: f"https://pokeapi.co/api/v2/pokemon-species/{poke_id}") + .map(http_client.get, concurrency=8) + .foreach(httpx.Response.raise_for_status) + .map(httpx.Response.json) + # Stop the iteration when reaching the 1st pokemon of the 4th generation + .truncate(when=lambda poke: poke["generation"]["name"] == "generation-iv") + .observe("pokemons") + # Keep only quadruped Pokemons + .filter(lambda poke: poke["shape"]["name"] == "quadruped") + # Write a batch of pokemons every 5 seconds to the CSV file + .group(interval=timedelta(seconds=5)) + .foreach(writer.writerows) + .flatten() + .observe("written pokemons") + # Catch exceptions and raises the 1st one at the end of the iteration + .catch(Exception, finally_raise=True) + ) - pipeline: Stream = ( - # Infinite Stream[int] of Pokemon ids starting from Pokémon #1: Bulbasaur - Stream(itertools.count(1)) - # Limit to 16 requests per second to be friendly to our fellow PokéAPI devs - .throttle(16, per=timedelta(seconds=1)) - # GET pokemons concurrently using a pool of 8 threads - .map(lambda poke_id: f"https://pokeapi.co/api/v2/pokemon-species/{poke_id}") - .map(requests.get, concurrency=8) - .foreach(requests.Response.raise_for_status) - .map(requests.Response.json) - # Stop the iteration when reaching the 1st pokemon of the 4th generation - .truncate(when=lambda poke: poke["generation"]["name"] == "generation-iv") - .observe("pokemons") - # Keep only quadruped Pokemons - .filter(lambda poke: poke["shape"]["name"] == "quadruped") - .observe("quadruped pokemons") - # Write a batch of pokemons every 5 seconds to the CSV file - .group(interval=timedelta(seconds=5)) - .foreach(writer.writerows) - .flatten() - .observe("written pokemons") - # Catch exceptions and raises the 1st one at the end of the iteration - .catch(Exception, finally_raise=True) - ) - - # Start a full iteration - pipeline() + # Start a full iteration + pipeline() ``` ## ... or the `async` way! Let's write an `async` version of this script: -- using `httpx.AsyncCLient` together with the `.amap` operation (the `async` counterpart of `.map`). -- instead of calling `pipeline()` to iterate over it as an `Iterable`, let's `await pipeline` to iterate over it as an `AsyncIterable`. +- `httpx.CLient` becomes `httpx.AsyncCLient` +- `.map` becomes `.amap` +- `pipeline()` becomes `await pipeline` ```python import asyncio import csv from datetime import timedelta -import itertools +from itertools import count import httpx from streamable import Stream @@ -122,15 +122,15 @@ async def main() -> None: writer = csv.DictWriter(file, fields, extrasaction='ignore') writer.writeheader() - async with httpx.AsyncClient() as http: - pipeline: Stream = ( + async with httpx.AsyncClient() as http_client: + pipeline = ( # Infinite Stream[int] of Pokemon ids starting from Pokémon #1: Bulbasaur - Stream(itertools.count(1)) + Stream(count(1)) # Limit to 16 requests per second to be friendly to our fellow PokéAPI devs .throttle(16, per=timedelta(seconds=1)) # GET pokemons via 8 concurrent coroutines .map(lambda poke_id: f"https://pokeapi.co/api/v2/pokemon-species/{poke_id}") - .amap(http.get, concurrency=8) + .amap(http_client.get, concurrency=8) .foreach(httpx.Response.raise_for_status) .map(httpx.Response.json) # Stop the iteration when reaching the 1st pokemon of the 4th generation @@ -138,7 +138,6 @@ async def main() -> None: .observe("pokemons") # Keep only quadruped Pokemons .filter(lambda poke: poke["shape"]["name"] == "quadruped") - .observe("quadruped pokemons") # Write a batch of pokemons every 5 seconds to the CSV file .group(interval=timedelta(seconds=5)) .foreach(writer.writerows) diff --git a/requirements-dev.txt b/requirements-dev.txt index ca467e0a..28072aca 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,8 +1,10 @@ -ruff==0.9.10 -mypy==1.7.1 -mypy-extensions==1.0.0 -parameterized==0.9.0 coverage==7.10.6 httpx==0.28.1 +parameterized==0.9.0 +mypy==1.7.1 +mypy-extensions==1.0.0 +pytest==7.4.4 +pytest-asyncio==0.21.2 requests==2.32.5 +ruff==0.9.10 typing-extensions==4.12.2 diff --git a/streamable/__init__.py b/streamable/__init__.py index 17cf16ae..50eb0c29 100644 --- a/streamable/__init__.py +++ b/streamable/__init__.py @@ -1,4 +1,4 @@ from streamable.stream import Stream -from streamable.util.functiontools import star +from streamable._util._functiontools import star __all__ = ["Stream", "star"] diff --git a/streamable/aiterators.py b/streamable/_aiterators.py similarity index 84% rename from streamable/aiterators.py rename to streamable/_aiterators.py index e53a7985..86df3938 100644 --- a/streamable/aiterators.py +++ b/streamable/_aiterators.py @@ -1,7 +1,6 @@ import asyncio import datetime import multiprocessing -import queue import time from abc import ABC, abstractmethod from collections import defaultdict, deque @@ -22,7 +21,6 @@ Iterable, Iterator, List, - NamedTuple, Optional, Set, Tuple, @@ -32,16 +30,15 @@ cast, ) -from streamable.util.asynctools import awaitable_to_coroutine, empty_aiter -from streamable.util.contextmanagertools import noop_context_manager -from streamable.util.loggertools import get_logger +from streamable._util._asynctools import awaitable_to_coroutine, empty_aiter +from streamable._util._contextmanagertools import noop_context_manager +from streamable._util._errortools import ExceptionContainer +from streamable._util._loggertools import get_logger -from streamable.util.constants import NO_REPLACEMENT -from streamable.util.futuretools import ( - FDFOAsyncFutureResultCollection, - FDFOOSFutureResultCollection, - FIFOAsyncFutureResultCollection, - FIFOOSFutureResultCollection, +from streamable._util._constants import NO_REPLACEMENT +from streamable._util._futuretools import ( + AsyncFDFOFutureResultCollection, + AsyncFIFOFutureResultCollection, FutureResult, FutureResultCollection, ) @@ -559,9 +556,6 @@ async def __anext__(self) -> T: class _RaisingAsyncIterator(AsyncIterator[T]): - class ExceptionContainer(NamedTuple): - exception: Exception - def __init__( self, iterator: AsyncIterator[Union[T, ExceptionContainer]], @@ -570,7 +564,7 @@ def __init__( async def __anext__(self) -> T: elem = await self.iterator.__anext__() - if isinstance(elem, self.ExceptionContainer): + if isinstance(elem, ExceptionContainer): try: raise elem.exception finally: @@ -586,7 +580,7 @@ async def __anext__(self) -> T: class _BaseConcurrentMapAsyncIterable( Generic[T, U], ABC, - AsyncIterable[Union[U, _RaisingAsyncIterator.ExceptionContainer]], + AsyncIterable[Union[U, ExceptionContainer]], ): def __init__( self, @@ -602,29 +596,29 @@ def _context_manager(self) -> ContextManager: return noop_context_manager() @abstractmethod - def _launch_task( - self, elem: T - ) -> "Future[Union[U, _RaisingAsyncIterator.ExceptionContainer]]": ... + def _launch_task(self, elem: T) -> "Future[Union[U, ExceptionContainer]]": ... - # factory method - @abstractmethod def _future_result_collection( self, - ) -> FutureResultCollection[Union[U, _RaisingAsyncIterator.ExceptionContainer]]: ... + ) -> FutureResultCollection[Union[U, ExceptionContainer]]: + if self.ordered: + return AsyncFIFOFutureResultCollection(asyncio.get_running_loop()) + return AsyncFDFOFutureResultCollection(asyncio.get_running_loop()) async def _next_future( self, - ) -> Optional["Future[Union[U, _RaisingAsyncIterator.ExceptionContainer]]"]: + ) -> Optional["Future[Union[U, ExceptionContainer]]"]: try: - return self._launch_task(await self.iterator.__anext__()) + elem = await self.iterator.__anext__() except StopAsyncIteration: return None except Exception as e: - return FutureResult(_RaisingAsyncIterator.ExceptionContainer(e)) + return FutureResult(ExceptionContainer(e)) + return self._launch_task(elem) async def __aiter__( self, - ) -> AsyncIterator[Union[U, _RaisingAsyncIterator.ExceptionContainer]]: + ) -> AsyncIterator[Union[U, ExceptionContainer]]: with self._context_manager(): future_results = self._future_result_collection() @@ -634,13 +628,13 @@ async def __aiter__( if not future: # no more tasks to queue break - future_results.add_future(future) + future_results.add(future) # queue, wait, yield while future_results: future = await self._next_future() if future: - future_results.add_future(future) + future_results.add(future) yield await future_results.__anext__() @@ -674,26 +668,18 @@ def _context_manager(self) -> ContextManager: @staticmethod def _safe_transformation( transformation: Callable[[T], U], elem: T - ) -> Union[U, _RaisingAsyncIterator.ExceptionContainer]: + ) -> Union[U, ExceptionContainer]: try: return transformation(elem) except Exception as e: - return _RaisingAsyncIterator.ExceptionContainer(e) + return ExceptionContainer(e) - def _launch_task( - self, elem: T - ) -> "Future[Union[U, _RaisingAsyncIterator.ExceptionContainer]]": - return self.executor.submit( - self._safe_transformation, self.transformation, elem - ) - - def _future_result_collection( - self, - ) -> FutureResultCollection[Union[U, _RaisingAsyncIterator.ExceptionContainer]]: - if self.ordered: - return FIFOOSFutureResultCollection() - return FDFOOSFutureResultCollection( - multiprocessing.Queue if self.via == "process" else queue.Queue + def _launch_task(self, elem: T) -> "Future[Union[U, ExceptionContainer]]": + return cast( + "Future[Union[U, ExceptionContainer]]", + asyncio.get_running_loop().run_in_executor( + self.executor, self._safe_transformation, self.transformation, elem + ), ) @@ -724,42 +710,41 @@ def __init__( self, iterator: AsyncIterator[T], transformation: Callable[[T], Coroutine[Any, Any, U]], + concurrency: int, buffersize: int, ordered: bool, ) -> None: super().__init__(iterator, buffersize, ordered) self.transformation = transformation + self.concurrency = concurrency + self._semaphore: Optional[asyncio.Semaphore] = None + + @property + def semaphore(self) -> asyncio.Semaphore: + if not self._semaphore: + self._semaphore = asyncio.Semaphore(self.concurrency) + return self._semaphore - async def _safe_transformation( - self, elem: T - ) -> Union[U, _RaisingAsyncIterator.ExceptionContainer]: + async def _safe_transformation(self, elem: T) -> Union[U, ExceptionContainer]: try: - return await self.transformation(elem) + async with self.semaphore: + return await self.transformation(elem) except Exception as e: - return _RaisingAsyncIterator.ExceptionContainer(e) + return ExceptionContainer(e) - def _launch_task( - self, elem: T - ) -> "Future[Union[U, _RaisingAsyncIterator.ExceptionContainer]]": + def _launch_task(self, elem: T) -> "Future[Union[U, ExceptionContainer]]": return cast( - "Future[Union[U, _RaisingAsyncIterator.ExceptionContainer]]", + "Future[Union[U, ExceptionContainer]]", asyncio.get_running_loop().create_task(self._safe_transformation(elem)), ) - def _future_result_collection( - self, - ) -> FutureResultCollection[Union[U, _RaisingAsyncIterator.ExceptionContainer]]: - if self.ordered: - return FIFOAsyncFutureResultCollection(asyncio.get_running_loop()) - else: - return FDFOAsyncFutureResultCollection(asyncio.get_running_loop()) - class ConcurrentAMapAsyncIterator(_RaisingAsyncIterator[U]): def __init__( self, iterator: AsyncIterator[T], transformation: Callable[[T], Coroutine[Any, Any, U]], + concurrency: int, buffersize: int, ordered: bool, ) -> None: @@ -767,6 +752,7 @@ def __init__( _ConcurrentAMapAsyncIterable( iterator, transformation, + concurrency, buffersize, ordered, ).__aiter__() @@ -778,9 +764,7 @@ def __init__( ###################### -class _ConcurrentFlattenAsyncIterable( - AsyncIterable[Union[T, _RaisingAsyncIterator.ExceptionContainer]] -): +class _ConcurrentFlattenAsyncIterable(AsyncIterable[Union[T, ExceptionContainer]]): def __init__( self, iterables_iterator: AsyncIterator[Iterable[T]], @@ -790,34 +774,29 @@ def __init__( self.iterables_iterator = iterables_iterator self.concurrency = concurrency self.buffersize = buffersize + self._next = ExceptionContainer.wrap(next) async def __aiter__( self, - ) -> AsyncIterator[Union[T, _RaisingAsyncIterator.ExceptionContainer]]: + ) -> AsyncIterator[Union[T, ExceptionContainer]]: with ThreadPoolExecutor(max_workers=self.concurrency) as executor: iterator_and_future_pairs: Deque[ Tuple[ Optional[Iterator[T]], - "Future[Union[T, _RaisingAsyncIterator.ExceptionContainer]]", + "Awaitable[Union[T, ExceptionContainer]]", ] ] = deque() - element_to_yield: Deque[ - Union[T, _RaisingAsyncIterator.ExceptionContainer] - ] = deque(maxlen=1) + to_yield: Deque[Union[T, ExceptionContainer]] = deque(maxlen=1) iterator_to_queue: Optional[Iterator[T]] = None # wait, queue, yield (FIFO) while True: if iterator_and_future_pairs: iterator, future = iterator_and_future_pairs.popleft() - try: - element_to_yield.append(future.result()) - iterator_to_queue = iterator - except StopIteration: - pass - except Exception as e: - element_to_yield.append( - _RaisingAsyncIterator.ExceptionContainer(e) - ) + elem = await future + if not isinstance(elem, ExceptionContainer) or not isinstance( + elem.exception, StopIteration + ): + to_yield.append(elem) iterator_to_queue = iterator # queue tasks up to buffersize @@ -831,18 +810,18 @@ async def __aiter__( iterator_to_queue = iterable.__iter__() except Exception as e: iterator_to_queue = None - future = FutureResult( - _RaisingAsyncIterator.ExceptionContainer(e) - ) + future = FutureResult(ExceptionContainer(e)) iterator_and_future_pairs.append( (iterator_to_queue, future) ) continue - future = executor.submit(next, iterator_to_queue) + future = asyncio.get_running_loop().run_in_executor( + executor, self._next, iterator_to_queue + ) iterator_and_future_pairs.append((iterator_to_queue, future)) iterator_to_queue = None - if element_to_yield: - yield element_to_yield.pop() + if to_yield: + yield to_yield.pop() if not iterator_and_future_pairs: break @@ -863,9 +842,7 @@ def __init__( ) -class _ConcurrentAFlattenAsyncIterable( - AsyncIterable[Union[T, _RaisingAsyncIterator.ExceptionContainer]] -): +class _ConcurrentAFlattenAsyncIterable(AsyncIterable[Union[T, ExceptionContainer]]): def __init__( self, iterables_iterator: AsyncIterator[AsyncIterable[T]], @@ -878,28 +855,26 @@ def __init__( async def __aiter__( self, - ) -> AsyncIterator[Union[T, _RaisingAsyncIterator.ExceptionContainer]]: + ) -> AsyncIterator[Union[T, ExceptionContainer]]: iterator_and_future_pairs: Deque[ Tuple[ Optional[AsyncIterator[T]], - Awaitable[Union[T, _RaisingAsyncIterator.ExceptionContainer]], + Awaitable[Union[T, ExceptionContainer]], ] ] = deque() - element_to_yield: Deque[Union[T, _RaisingAsyncIterator.ExceptionContainer]] = ( - deque(maxlen=1) - ) + to_yield: Deque[Union[T, ExceptionContainer]] = deque(maxlen=1) iterator_to_queue: Optional[AsyncIterator[T]] = None # wait, queue, yield (FIFO) while True: if iterator_and_future_pairs: iterator, future = iterator_and_future_pairs.popleft() try: - element_to_yield.append(await future) + to_yield.append(await future) iterator_to_queue = iterator except StopAsyncIteration: pass except Exception as e: - element_to_yield.append(_RaisingAsyncIterator.ExceptionContainer(e)) + to_yield.append(ExceptionContainer(e)) iterator_to_queue = iterator # queue tasks up to buffersize @@ -913,9 +888,7 @@ async def __aiter__( iterator_to_queue = iterable.__aiter__() except Exception as e: iterator_to_queue = None - future = FutureResult( - _RaisingAsyncIterator.ExceptionContainer(e) - ) + future = FutureResult(ExceptionContainer(e)) iterator_and_future_pairs.append((iterator_to_queue, future)) continue future = asyncio.get_running_loop().create_task( @@ -923,8 +896,8 @@ async def __aiter__( ) iterator_and_future_pairs.append((iterator_to_queue, future)) iterator_to_queue = None - if element_to_yield: - yield element_to_yield.pop() + if to_yield: + yield to_yield.pop() if not iterator_and_future_pairs: break diff --git a/streamable/iterators.py b/streamable/_iterators.py similarity index 85% rename from streamable/iterators.py rename to streamable/_iterators.py index c816bcac..fe2848cf 100644 --- a/streamable/iterators.py +++ b/streamable/_iterators.py @@ -22,7 +22,6 @@ Iterable, Iterator, List, - NamedTuple, Optional, Set, Tuple, @@ -31,20 +30,21 @@ Union, cast, ) -from streamable.util.asynctools import ( +from streamable._util._asynctools import ( CloseEventLoopMixin, awaitable_to_coroutine, empty_aiter, ) -from streamable.util.contextmanagertools import noop_context_manager -from streamable.util.loggertools import get_logger - -from streamable.util.constants import NO_REPLACEMENT -from streamable.util.futuretools import ( - FDFOAsyncFutureResultCollection, - FDFOOSFutureResultCollection, - FIFOAsyncFutureResultCollection, - FIFOOSFutureResultCollection, +from streamable._util._contextmanagertools import noop_context_manager +from streamable._util._errortools import ExceptionContainer +from streamable._util._loggertools import get_logger + +from streamable._util._constants import NO_REPLACEMENT +from streamable._util._futuretools import ( + AsyncFDFOFutureResultCollection, + ExecutorFDFOFutureResultCollection, + AsyncFIFOFutureResultCollection, + ExecutorFIFOFutureResultCollection, FutureResult, FutureResultCollection, ) @@ -512,9 +512,6 @@ def __next__(self) -> T: class _RaisingIterator(Iterator[T]): - class ExceptionContainer(NamedTuple): - exception: Exception - def __init__( self, iterator: Iterator[Union[T, ExceptionContainer]], @@ -523,7 +520,7 @@ def __init__( def __next__(self) -> T: elem = self.iterator.__next__() - if isinstance(elem, self.ExceptionContainer): + if isinstance(elem, ExceptionContainer): try: raise elem.exception finally: @@ -537,7 +534,7 @@ def __next__(self) -> T: class _BaseConcurrentMapIterable( - Generic[T, U], ABC, Iterable[Union[U, _RaisingIterator.ExceptionContainer]] + Generic[T, U], ABC, Iterable[Union[U, ExceptionContainer]] ): def __init__( self, @@ -553,27 +550,26 @@ def _context_manager(self) -> ContextManager: return noop_context_manager() @abstractmethod - def _launch_task( - self, elem: T - ) -> "Future[Union[U, _RaisingIterator.ExceptionContainer]]": ... + def _launch_task(self, elem: T) -> "Future[Union[U, ExceptionContainer]]": ... # factory method @abstractmethod def _future_result_collection( self, - ) -> FutureResultCollection[Union[U, _RaisingIterator.ExceptionContainer]]: ... + ) -> FutureResultCollection[Union[U, ExceptionContainer]]: ... def _next_future( self, - ) -> Optional["Future[Union[U, _RaisingIterator.ExceptionContainer]]"]: + ) -> Optional["Future[Union[U, ExceptionContainer]]"]: try: - return self._launch_task(self.iterator.__next__()) + elem = self.iterator.__next__() except StopIteration: return None except Exception as e: - return FutureResult(_RaisingIterator.ExceptionContainer(e)) + return FutureResult(ExceptionContainer(e)) + return self._launch_task(elem) - def __iter__(self) -> Iterator[Union[U, _RaisingIterator.ExceptionContainer]]: + def __iter__(self) -> Iterator[Union[U, ExceptionContainer]]: with self._context_manager(): future_results = self._future_result_collection() @@ -583,13 +579,13 @@ def __iter__(self) -> Iterator[Union[U, _RaisingIterator.ExceptionContainer]]: if not future: # no more tasks to queue break - future_results.add_future(future) + future_results.add(future) # queue, wait, yield while future_results: future = self._next_future() if future: - future_results.add_future(future) + future_results.add(future) yield future_results.__next__() @@ -623,25 +619,23 @@ def _context_manager(self) -> ContextManager: @staticmethod def _safe_transformation( transformation: Callable[[T], U], elem: T - ) -> Union[U, _RaisingIterator.ExceptionContainer]: + ) -> Union[U, ExceptionContainer]: try: return transformation(elem) except Exception as e: - return _RaisingIterator.ExceptionContainer(e) + return ExceptionContainer(e) - def _launch_task( - self, elem: T - ) -> "Future[Union[U, _RaisingIterator.ExceptionContainer]]": + def _launch_task(self, elem: T) -> "Future[Union[U, ExceptionContainer]]": return self.executor.submit( self._safe_transformation, self.transformation, elem ) def _future_result_collection( self, - ) -> FutureResultCollection[Union[U, _RaisingIterator.ExceptionContainer]]: + ) -> FutureResultCollection[Union[U, ExceptionContainer]]: if self.ordered: - return FIFOOSFutureResultCollection() - return FDFOOSFutureResultCollection( + return ExecutorFIFOFutureResultCollection() + return ExecutorFDFOFutureResultCollection( multiprocessing.Queue if self.via == "process" else queue.Queue ) @@ -674,36 +668,41 @@ def __init__( event_loop: asyncio.AbstractEventLoop, iterator: Iterator[T], transformation: Callable[[T], Coroutine[Any, Any, U]], + concurrency: int, buffersize: int, ordered: bool, ) -> None: super().__init__(iterator, buffersize, ordered) self.transformation = transformation self.event_loop = event_loop + self.concurrency = concurrency + self._semaphore: Optional[asyncio.Semaphore] = None - async def _safe_transformation( - self, elem: T - ) -> Union[U, _RaisingIterator.ExceptionContainer]: + @property + def semaphore(self) -> asyncio.Semaphore: + if not self._semaphore: + self._semaphore = asyncio.Semaphore(self.concurrency) + return self._semaphore + + async def _safe_transformation(self, elem: T) -> Union[U, ExceptionContainer]: try: - return await self.transformation(elem) + async with self.semaphore: + return await self.transformation(elem) except Exception as e: - return _RaisingIterator.ExceptionContainer(e) + return ExceptionContainer(e) - def _launch_task( - self, elem: T - ) -> "Future[Union[U, _RaisingIterator.ExceptionContainer]]": + def _launch_task(self, elem: T) -> "Future[Union[U, ExceptionContainer]]": return cast( - "Future[Union[U, _RaisingIterator.ExceptionContainer]]", + "Future[Union[U, ExceptionContainer]]", self.event_loop.create_task(self._safe_transformation(elem)), ) def _future_result_collection( self, - ) -> FutureResultCollection[Union[U, _RaisingIterator.ExceptionContainer]]: + ) -> FutureResultCollection[Union[U, ExceptionContainer]]: if self.ordered: - return FIFOAsyncFutureResultCollection(self.event_loop) - else: - return FDFOAsyncFutureResultCollection(self.event_loop) + return AsyncFIFOFutureResultCollection(self.event_loop) + return AsyncFDFOFutureResultCollection(self.event_loop) class ConcurrentAMapIterator(_RaisingIterator[U]): @@ -712,6 +711,7 @@ def __init__( event_loop: asyncio.AbstractEventLoop, iterator: Iterator[T], transformation: Callable[[T], Coroutine[Any, Any, U]], + concurrency: int, buffersize: int, ordered: bool, ) -> None: @@ -720,6 +720,7 @@ def __init__( event_loop, iterator, transformation, + concurrency, buffersize, ordered, ).__iter__() @@ -731,9 +732,7 @@ def __init__( ###################### -class _ConcurrentFlattenIterable( - Iterable[Union[T, _RaisingIterator.ExceptionContainer]] -): +class _ConcurrentFlattenIterable(Iterable[Union[T, ExceptionContainer]]): def __init__( self, iterables_iterator: Iterator[Iterable[T]], @@ -743,30 +742,27 @@ def __init__( self.iterables_iterator = iterables_iterator self.concurrency = concurrency self.buffersize = buffersize + self._next = ExceptionContainer.wrap(next) - def __iter__(self) -> Iterator[Union[T, _RaisingIterator.ExceptionContainer]]: + def __iter__(self) -> Iterator[Union[T, ExceptionContainer]]: with ThreadPoolExecutor(max_workers=self.concurrency) as executor: iterator_and_future_pairs: Deque[ Tuple[ Optional[Iterator[T]], - "Future[Union[T, _RaisingIterator.ExceptionContainer]]", + "Future[Union[T, ExceptionContainer]]", ] ] = deque() - element_to_yield: Deque[Union[T, _RaisingIterator.ExceptionContainer]] = ( - deque(maxlen=1) - ) + to_yield: Deque[Union[T, ExceptionContainer]] = deque(maxlen=1) iterator_to_queue: Optional[Iterator[T]] = None # wait, queue, yield (FIFO) while True: if iterator_and_future_pairs: iterator, future = iterator_and_future_pairs.popleft() - try: - element_to_yield.append(future.result()) - iterator_to_queue = iterator - except StopIteration: - pass - except Exception as e: - element_to_yield.append(_RaisingIterator.ExceptionContainer(e)) + elem = future.result() + if not isinstance(elem, ExceptionContainer) or not isinstance( + elem.exception, StopIteration + ): + to_yield.append(elem) iterator_to_queue = iterator # queue tasks up to buffersize @@ -780,18 +776,16 @@ def __iter__(self) -> Iterator[Union[T, _RaisingIterator.ExceptionContainer]]: iterator_to_queue = iterable.__iter__() except Exception as e: iterator_to_queue = None - future = FutureResult( - _RaisingIterator.ExceptionContainer(e) - ) + future = FutureResult(ExceptionContainer(e)) iterator_and_future_pairs.append( (iterator_to_queue, future) ) continue - future = executor.submit(next, iterator_to_queue) + future = executor.submit(self._next, iterator_to_queue) iterator_and_future_pairs.append((iterator_to_queue, future)) iterator_to_queue = None - if element_to_yield: - yield element_to_yield.pop() + if to_yield: + yield to_yield.pop() if not iterator_and_future_pairs: break @@ -813,7 +807,7 @@ def __init__( class _ConcurrentAFlattenIterable( - Iterable[Union[T, _RaisingIterator.ExceptionContainer]], CloseEventLoopMixin + Iterable[Union[T, ExceptionContainer]], CloseEventLoopMixin ): def __init__( self, @@ -827,28 +821,26 @@ def __init__( self.buffersize = buffersize self.event_loop = event_loop - def __iter__(self) -> Iterator[Union[T, _RaisingIterator.ExceptionContainer]]: + def __iter__(self) -> Iterator[Union[T, ExceptionContainer]]: iterator_and_future_pairs: Deque[ Tuple[ Optional[AsyncIterator[T]], - Awaitable[Union[T, _RaisingIterator.ExceptionContainer]], + Awaitable[Union[T, ExceptionContainer]], ] ] = deque() - element_to_yield: Deque[Union[T, _RaisingIterator.ExceptionContainer]] = deque( - maxlen=1 - ) + to_yield: Deque[Union[T, ExceptionContainer]] = deque(maxlen=1) iterator_to_queue: Optional[AsyncIterator[T]] = None # wait, queue, yield (FIFO) while True: if iterator_and_future_pairs: iterator, future = iterator_and_future_pairs.popleft() try: - element_to_yield.append(self.event_loop.run_until_complete(future)) + to_yield.append(self.event_loop.run_until_complete(future)) iterator_to_queue = iterator except StopAsyncIteration: pass except Exception as e: - element_to_yield.append(_RaisingIterator.ExceptionContainer(e)) + to_yield.append(ExceptionContainer(e)) iterator_to_queue = iterator # queue tasks up to buffersize @@ -862,7 +854,7 @@ def __iter__(self) -> Iterator[Union[T, _RaisingIterator.ExceptionContainer]]: iterator_to_queue = iterable.__aiter__() except Exception as e: iterator_to_queue = None - future = FutureResult(_RaisingIterator.ExceptionContainer(e)) + future = FutureResult(ExceptionContainer(e)) iterator_and_future_pairs.append((iterator_to_queue, future)) continue future = self.event_loop.create_task( @@ -870,8 +862,8 @@ def __iter__(self) -> Iterator[Union[T, _RaisingIterator.ExceptionContainer]]: ) iterator_and_future_pairs.append((iterator_to_queue, future)) iterator_to_queue = None - if element_to_yield: - yield element_to_yield.pop() + if to_yield: + yield to_yield.pop() if not iterator_and_future_pairs: break diff --git a/streamable/util/__init__.py b/streamable/_util/__init__.py similarity index 100% rename from streamable/util/__init__.py rename to streamable/_util/__init__.py diff --git a/streamable/util/asynctools.py b/streamable/_util/_asynctools.py similarity index 100% rename from streamable/util/asynctools.py rename to streamable/_util/_asynctools.py diff --git a/streamable/util/constants.py b/streamable/_util/_constants.py similarity index 100% rename from streamable/util/constants.py rename to streamable/_util/_constants.py diff --git a/streamable/util/contextmanagertools.py b/streamable/_util/_contextmanagertools.py similarity index 100% rename from streamable/util/contextmanagertools.py rename to streamable/_util/_contextmanagertools.py diff --git a/streamable/_util/_errortools.py b/streamable/_util/_errortools.py new file mode 100644 index 00000000..4362a9d2 --- /dev/null +++ b/streamable/_util/_errortools.py @@ -0,0 +1,18 @@ +from typing import Callable, NamedTuple, TypeVar, Union + +T = TypeVar("T") +U = TypeVar("U") + + +class ExceptionContainer(NamedTuple): + exception: Exception + + @staticmethod + def wrap(func: Callable[[T], U]) -> Callable[[T], Union[U, "ExceptionContainer"]]: + def error_wrapping(_: T) -> Union[U, "ExceptionContainer"]: + try: + return func(_) + except Exception as e: + return ExceptionContainer(e) + + return error_wrapping diff --git a/streamable/util/functiontools.py b/streamable/_util/_functiontools.py similarity index 98% rename from streamable/util/functiontools.py rename to streamable/_util/_functiontools.py index 72bbb605..019b1adb 100644 --- a/streamable/util/functiontools.py +++ b/streamable/_util/_functiontools.py @@ -10,7 +10,7 @@ overload, ) -from streamable.util.asynctools import CloseEventLoopMixin +from streamable._util._asynctools import CloseEventLoopMixin T = TypeVar("T") R = TypeVar("R") diff --git a/streamable/util/futuretools.py b/streamable/_util/_futuretools.py similarity index 72% rename from streamable/util/futuretools.py rename to streamable/_util/_futuretools.py index 9bcd92ca..535317ff 100644 --- a/streamable/util/futuretools.py +++ b/streamable/_util/_futuretools.py @@ -3,6 +3,7 @@ from collections import deque from concurrent.futures import Future from contextlib import suppress +import sys from typing import ( AsyncIterator, Awaitable, @@ -12,12 +13,13 @@ Sized, Type, TypeVar, + Union, cast, ) with suppress(ImportError): - from streamable.util.protocols import Queue + from streamable._util._protocols import Queue T = TypeVar("T") @@ -38,39 +40,40 @@ class FutureResultCollection(Iterator[T], AsyncIterator[T], Sized, ABC): """ @abstractmethod - def add_future(self, future: "Future[T]") -> None: ... + def add(self, future: "Future[T]") -> None: ... async def __anext__(self) -> T: return self.__next__() -class DequeFutureResultCollection(FutureResultCollection[T]): +class FIFOFutureResultCollection(FutureResultCollection[T]): def __init__(self) -> None: self._futures: Deque["Future[T]"] = deque() def __len__(self) -> int: return len(self._futures) - def add_future(self, future: "Future[T]") -> None: + def add(self, future: "Future[T]") -> None: return self._futures.append(future) -class CallbackFutureResultCollection(FutureResultCollection[T]): +class FDFObackFutureResultCollection(FutureResultCollection[T]): def __init__(self) -> None: self._n_futures = 0 + self._results: "Union[Queue[T], asyncio.Queue[T]]" def __len__(self) -> int: return self._n_futures - @abstractmethod - def _done_callback(self, future: "Future[T]") -> None: ... + def _done_callback(self, future: "Future[T]") -> None: + self._results.put_nowait(future.result()) - def add_future(self, future: "Future[T]") -> None: + def add(self, future: "Future[T]") -> None: future.add_done_callback(self._done_callback) self._n_futures += 1 -class FIFOOSFutureResultCollection(DequeFutureResultCollection[T]): +class ExecutorFIFOFutureResultCollection(FIFOFutureResultCollection[T]): """ First In First Out """ @@ -79,7 +82,7 @@ def __next__(self) -> T: return self._futures.popleft().result() -class FDFOOSFutureResultCollection(CallbackFutureResultCollection[T]): +class ExecutorFDFOFutureResultCollection(FDFObackFutureResultCollection[T]): """ First Done First Out """ @@ -88,16 +91,13 @@ def __init__(self, queue_type: Type["Queue"]) -> None: super().__init__() self._results: "Queue[T]" = queue_type() - def _done_callback(self, future: "Future[T]") -> None: - self._results.put_nowait(future.result()) - def __next__(self) -> T: result = self._results.get() self._n_futures -= 1 return result -class FIFOAsyncFutureResultCollection(DequeFutureResultCollection[T]): +class AsyncFIFOFutureResultCollection(FIFOFutureResultCollection[T]): """ First In First Out """ @@ -115,7 +115,7 @@ async def __anext__(self) -> T: return await cast(Awaitable[T], self._futures.popleft()) -class FDFOAsyncFutureResultCollection(CallbackFutureResultCollection[T]): +class AsyncFDFOFutureResultCollection(FDFObackFutureResultCollection[T]): """ First Done First Out """ @@ -123,20 +123,18 @@ class FDFOAsyncFutureResultCollection(CallbackFutureResultCollection[T]): def __init__(self, event_loop: asyncio.AbstractEventLoop) -> None: super().__init__() self.event_loop = event_loop - asyncio.set_event_loop(event_loop) - self._results: "asyncio.Queue[T]" = asyncio.Queue() - - def _done_callback(self, future: "Future[T]") -> None: - self._results.put_nowait(future.result()) + self._results: "asyncio.Queue[T]" + if sys.version_info >= (3, 10): + self._results = asyncio.Queue() + else: # pragma: no cover + self._results = asyncio.Queue(loop=event_loop) # type: ignore def __next__(self) -> T: result = self.event_loop.run_until_complete(self._results.get()) self._n_futures -= 1 - self._waiter = self.event_loop.create_future() return result async def __anext__(self) -> T: result = await self._results.get() self._n_futures -= 1 - self._waiter = self.event_loop.create_future() return result diff --git a/streamable/util/iterabletools.py b/streamable/_util/_iterabletools.py similarity index 96% rename from streamable/util/iterabletools.py rename to streamable/_util/_iterabletools.py index ab4ea0c0..6f21b722 100644 --- a/streamable/util/iterabletools.py +++ b/streamable/_util/_iterabletools.py @@ -8,7 +8,7 @@ TypeVar, ) -from streamable.util.asynctools import CloseEventLoopMixin +from streamable._util._asynctools import CloseEventLoopMixin T = TypeVar("T") diff --git a/streamable/util/loggertools.py b/streamable/_util/_loggertools.py similarity index 100% rename from streamable/util/loggertools.py rename to streamable/_util/_loggertools.py diff --git a/streamable/util/protocols.py b/streamable/_util/_protocols.py similarity index 100% rename from streamable/util/protocols.py rename to streamable/_util/_protocols.py diff --git a/streamable/util/validationtools.py b/streamable/_util/_validationtools.py similarity index 100% rename from streamable/util/validationtools.py rename to streamable/_util/_validationtools.py diff --git a/streamable/afunctions.py b/streamable/afunctions.py index 43838701..ce652b23 100644 --- a/streamable/afunctions.py +++ b/streamable/afunctions.py @@ -17,7 +17,7 @@ Union, ) -from streamable.aiterators import ( +from streamable._aiterators import ( ACatchAsyncIterator, ADistinctAsyncIterator, AFilterAsyncIterator, @@ -39,8 +39,8 @@ PredicateATruncateAsyncIterator, YieldsPerPeriodThrottleAsyncIterator, ) -from streamable.util.constants import NO_REPLACEMENT -from streamable.util.functiontools import asyncify +from streamable._util._constants import NO_REPLACEMENT +from streamable._util._functiontools import asyncify with suppress(ImportError): from typing import Literal @@ -134,12 +134,11 @@ def flatten( ) -> AsyncIterator[T]: if concurrency == 1: return FlattenAsyncIterator(aiterator) - else: - return ConcurrentFlattenAsyncIterator( - aiterator, - concurrency=concurrency, - buffersize=concurrency, - ) + return ConcurrentFlattenAsyncIterator( + aiterator, + concurrency=concurrency, + buffersize=concurrency, + ) def aflatten( @@ -147,12 +146,11 @@ def aflatten( ) -> AsyncIterator[T]: if concurrency == 1: return AFlattenAsyncIterator(aiterator) - else: - return ConcurrentAFlattenAsyncIterator( - aiterator, - concurrency=concurrency, - buffersize=concurrency, - ) + return ConcurrentAFlattenAsyncIterator( + aiterator, + concurrency=concurrency, + buffersize=concurrency, + ) def group( @@ -212,15 +210,14 @@ def map( ) -> AsyncIterator[U]: if concurrency == 1: return amap(asyncify(transformation), aiterator) - else: - return ConcurrentMapAsyncIterator( - aiterator, - transformation, - concurrency=concurrency, - buffersize=concurrency, - ordered=ordered, - via=via, - ) + return ConcurrentMapAsyncIterator( + aiterator, + transformation, + concurrency=concurrency, + buffersize=concurrency, + ordered=ordered, + via=via, + ) def amap( @@ -235,6 +232,7 @@ def amap( return ConcurrentAMapAsyncIterator( aiterator, transformation, + concurrency=concurrency, buffersize=concurrency, ordered=ordered, ) diff --git a/streamable/functions.py b/streamable/functions.py index 8d971b87..eb562919 100644 --- a/streamable/functions.py +++ b/streamable/functions.py @@ -18,7 +18,7 @@ Union, ) -from streamable.iterators import ( +from streamable._iterators import ( AFlattenIterator, CatchIterator, ConcurrentAFlattenIterator, @@ -38,8 +38,8 @@ PredicateTruncateIterator, YieldsPerPeriodThrottleIterator, ) -from streamable.util.constants import NO_REPLACEMENT -from streamable.util.functiontools import syncify +from streamable._util._constants import NO_REPLACEMENT +from streamable._util._functiontools import syncify with suppress(ImportError): from typing import Literal @@ -117,12 +117,11 @@ def adistinct( def flatten(iterator: Iterator[Iterable[T]], *, concurrency: int = 1) -> Iterator[T]: if concurrency == 1: return FlattenIterator(iterator) - else: - return ConcurrentFlattenIterator( - iterator, - concurrency=concurrency, - buffersize=concurrency, - ) + return ConcurrentFlattenIterator( + iterator, + concurrency=concurrency, + buffersize=concurrency, + ) def aflatten( @@ -133,13 +132,12 @@ def aflatten( ) -> Iterator[T]: if concurrency == 1: return AFlattenIterator(event_loop, iterator) - else: - return ConcurrentAFlattenIterator( - event_loop, - iterator, - concurrency=concurrency, - buffersize=concurrency, - ) + return ConcurrentAFlattenIterator( + event_loop, + iterator, + concurrency=concurrency, + buffersize=concurrency, + ) def group( @@ -206,15 +204,14 @@ def map( ) -> Iterator[U]: if concurrency == 1: return builtins.map(transformation, iterator) - else: - return ConcurrentMapIterator( - iterator, - transformation, - concurrency=concurrency, - buffersize=concurrency, - ordered=ordered, - via=via, - ) + return ConcurrentMapIterator( + iterator, + transformation, + concurrency=concurrency, + buffersize=concurrency, + ordered=ordered, + via=via, + ) def amap( @@ -231,6 +228,7 @@ def amap( event_loop, iterator, transformation, + concurrency=concurrency, buffersize=concurrency, ordered=ordered, ) diff --git a/streamable/stream.py b/streamable/stream.py index a97e65a8..34c8d11f 100644 --- a/streamable/stream.py +++ b/streamable/stream.py @@ -27,10 +27,10 @@ overload, ) -from streamable.util.constants import NO_REPLACEMENT -from streamable.util.functiontools import asyncify -from streamable.util.loggertools import get_logger -from streamable.util.validationtools import ( +from streamable._util._constants import NO_REPLACEMENT +from streamable._util._functiontools import asyncify +from streamable._util._loggertools import get_logger +from streamable._util._validationtools import ( validate_concurrency, validate_errors, validate_group_size, @@ -110,12 +110,12 @@ def source( return self._source def __iter__(self) -> Iterator[T]: - from streamable.visitors.iterator import IteratorVisitor + from streamable.visitors._iterator import IteratorVisitor return self.accept(IteratorVisitor[T]()) def __aiter__(self) -> AsyncIterator[T]: - from streamable.visitors.aiterator import AsyncIteratorVisitor + from streamable.visitors._aiterator import AsyncIteratorVisitor return self.accept(AsyncIteratorVisitor[T]()) @@ -126,17 +126,17 @@ def __eq__(self, other: Any) -> bool: Returns: bool: True if this stream is equal to `other`. """ - from streamable.visitors.equality import EqualityVisitor + from streamable.visitors._equality import EqualityVisitor return self.accept(EqualityVisitor(other)) def __repr__(self) -> str: - from streamable.visitors.representation import ReprVisitor + from streamable.visitors._representation import ReprVisitor return self.accept(ReprVisitor()) def __str__(self) -> str: - from streamable.visitors.representation import StrVisitor + from streamable.visitors._representation import StrVisitor return self.accept(StrVisitor()) @@ -240,7 +240,7 @@ def catch( """ Catches the upstream exceptions if they are instances of `errors` type and they satisfy the `when` predicate. Optionally yields a `replacement` value. - If any exception was caught during the iteration and `finally_raise=True`, the first caught exception will be raised when the iteration finishes. + If any exception was caught during the iteration and `finally_raise=True`, the first exception caught will be raised when the iteration finishes. Args: errors (Optional[Type[Exception]], Iterable[Optional[Type[Exception]]], optional): The exception type to catch, or an iterable of exception types to catch (default: catches all `Exception`s) @@ -273,7 +273,7 @@ def acatch( """ Catches the upstream exceptions if they are instances of `errors` type and they satisfy the `when` predicate. Optionally yields a `replacement` value. - If any exception was caught during the iteration and `finally_raise=True`, the first caught exception will be raised when the iteration finishes. + If any exception was caught during the iteration and `finally_raise=True`, the first exception caught will be raised when the iteration finishes. Args: errors (Optional[Type[Exception]], Iterable[Optional[Type[Exception]]], optional): The exception type to catch, or an iterable of exception types to catch (default: catches all `Exception`s) diff --git a/streamable/visitors/__init__.py b/streamable/visitors/__init__.py index ba67ca01..d4028ec6 100644 --- a/streamable/visitors/__init__.py +++ b/streamable/visitors/__init__.py @@ -1,3 +1,3 @@ -from streamable.visitors.base import Visitor +from streamable.visitors._base import Visitor __all__ = ["Visitor"] diff --git a/streamable/visitors/aiterator.py b/streamable/visitors/_aiterator.py similarity index 98% rename from streamable/visitors/aiterator.py rename to streamable/visitors/_aiterator.py index 9352c72b..c8bcacc0 100644 --- a/streamable/visitors/aiterator.py +++ b/streamable/visitors/_aiterator.py @@ -26,8 +26,8 @@ ThrottleStream, TruncateStream, ) -from streamable.util.functiontools import async_sidify, sidify -from streamable.util.iterabletools import sync_to_async_iter +from streamable._util._functiontools import async_sidify, sidify +from streamable._util._iterabletools import sync_to_async_iter from streamable.visitors import Visitor T = TypeVar("T") diff --git a/streamable/visitors/base.py b/streamable/visitors/_base.py similarity index 100% rename from streamable/visitors/base.py rename to streamable/visitors/_base.py diff --git a/streamable/visitors/equality.py b/streamable/visitors/_equality.py similarity index 100% rename from streamable/visitors/equality.py rename to streamable/visitors/_equality.py diff --git a/streamable/visitors/iterator.py b/streamable/visitors/_iterator.py similarity index 98% rename from streamable/visitors/iterator.py rename to streamable/visitors/_iterator.py index 2a08d5f9..c259c4f0 100644 --- a/streamable/visitors/iterator.py +++ b/streamable/visitors/_iterator.py @@ -27,12 +27,12 @@ ThrottleStream, TruncateStream, ) -from streamable.util.functiontools import ( +from streamable._util._functiontools import ( async_sidify, sidify, syncify, ) -from streamable.util.iterabletools import async_to_sync_iter +from streamable._util._iterabletools import async_to_sync_iter from streamable.visitors import Visitor T = TypeVar("T") diff --git a/streamable/visitors/representation.py b/streamable/visitors/_representation.py similarity index 98% rename from streamable/visitors/representation.py rename to streamable/visitors/_representation.py index fbe7bd61..bbe6fbda 100644 --- a/streamable/visitors/representation.py +++ b/streamable/visitors/_representation.py @@ -26,8 +26,8 @@ ThrottleStream, TruncateStream, ) -from streamable.util.constants import NO_REPLACEMENT -from streamable.util.functiontools import _Star +from streamable._util._constants import NO_REPLACEMENT +from streamable._util._functiontools import _Star from streamable.visitors import Visitor diff --git a/tests/mocks.py b/tests/mocks.py index 34dbe373..f7d1a40f 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -1,14 +1,13 @@ import json -from typing import Any, Dict, List, Union +from typing import Any, Dict, List from unittest.mock import Mock import httpx -import requests with open("tests/pokemons.json") as pokemon_sample: POKEMONS: List[Dict[str, Any]] = json.loads(pokemon_sample.read()) -def get_poke(url: str) -> Union[requests.Response, httpx.Response]: +def get_poke(url: str) -> httpx.Response: poke_id = int(url.split("/")[-1]) response = Mock() response.text = json.dumps(POKEMONS[poke_id - 1]) @@ -17,5 +16,5 @@ def get_poke(url: str) -> Union[requests.Response, httpx.Response]: return response -async def async_get_poke(url: str) -> Union[requests.Response, httpx.Response]: +async def async_get_poke(url: str) -> httpx.Response: return get_poke(url) diff --git a/tests/requirements.txt b/tests/requirements.txt index fa076abf..92ec2b92 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,3 +1,5 @@ parameterized==0.9.0 -requests +pytest==7.4.4 +pytest-asyncio==0.21.2 httpx +requests diff --git a/tests/test_functions.py b/tests/test_functions.py index cf70a872..659d4c23 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -1,42 +1,31 @@ import datetime -import unittest -from typing import Callable, Iterator, List, TypeVar, cast +from typing import Callable, Iterator, List, cast from streamable.functions import catch, flatten, group, map, observe, throttle, truncate -T = TypeVar("T") - -# size of the test collections -N = 256 - - -src = range(N) - - -class TestFunctions(unittest.TestCase): - def test_signatures(self) -> None: - iterator = iter(src) - transformation = cast(Callable[[int], int], ...) - mapped_it_1: Iterator[int] = map(transformation, iterator) # noqa: F841 - mapped_it_2: Iterator[int] = map(transformation, iterator, concurrency=1) # noqa: F841 - mapped_it_3: Iterator[int] = map(transformation, iterator, concurrency=2) # noqa: F841 - grouped_it_1: Iterator[List[int]] = group(iterator, size=1) - grouped_it_2: Iterator[List[int]] = group( # noqa: F841 - iterator, size=1, interval=datetime.timedelta(seconds=0.1) - ) - grouped_it_3: Iterator[List[int]] = group( # noqa: F841 - iterator, size=1, interval=datetime.timedelta(seconds=2) - ) - flattened_grouped_it_1: Iterator[int] = flatten(grouped_it_1) # noqa: F841 - flattened_grouped_it_2: Iterator[int] = flatten(grouped_it_1, concurrency=1) # noqa: F841 - flattened_grouped_it_3: Iterator[int] = flatten(grouped_it_1, concurrency=2) # noqa: F841 - caught_it_1: Iterator[int] = catch(iterator, Exception) # noqa: F841 - caught_it_2: Iterator[int] = catch(iterator, Exception, finally_raise=True) # noqa: F841 - observed_it_1: Iterator[int] = observe(iterator, what="objects") # noqa: F841 - throttleed_it_1: Iterator[int] = throttle( # noqa: F841 - iterator, - 1, - per=datetime.timedelta(seconds=1), - ) - truncated_it_1: Iterator[int] = truncate(iterator, count=1) # noqa: F841 +def test_signatures() -> None: + iterator = iter((0,)) + transformation = cast(Callable[[int], int], ...) + mapped_it_1: Iterator[int] = map(transformation, iterator) # noqa: F841 + mapped_it_2: Iterator[int] = map(transformation, iterator, concurrency=1) # noqa: F841 + mapped_it_3: Iterator[int] = map(transformation, iterator, concurrency=2) # noqa: F841 + grouped_it_1: Iterator[List[int]] = group(iterator, size=1) + grouped_it_2: Iterator[List[int]] = group( # noqa: F841 + iterator, size=1, interval=datetime.timedelta(seconds=0.1) + ) + grouped_it_3: Iterator[List[int]] = group( # noqa: F841 + iterator, size=1, interval=datetime.timedelta(seconds=2) + ) + flattened_grouped_it_1: Iterator[int] = flatten(grouped_it_1) # noqa: F841 + flattened_grouped_it_2: Iterator[int] = flatten(grouped_it_1, concurrency=1) # noqa: F841 + flattened_grouped_it_3: Iterator[int] = flatten(grouped_it_1, concurrency=2) # noqa: F841 + caught_it_1: Iterator[int] = catch(iterator, Exception) # noqa: F841 + caught_it_2: Iterator[int] = catch(iterator, Exception, finally_raise=True) # noqa: F841 + observed_it_1: Iterator[int] = observe(iterator, what="objects") # noqa: F841 + throttleed_it_1: Iterator[int] = throttle( # noqa: F841 + iterator, + 1, + per=datetime.timedelta(seconds=1), + ) + truncated_it_1: Iterator[int] = truncate(iterator, count=1) # noqa: F841 diff --git a/tests/test_iterators.py b/tests/test_iterators.py index 1a8fdd6e..82025e52 100644 --- a/tests/test_iterators.py +++ b/tests/test_iterators.py @@ -1,36 +1,37 @@ import asyncio -import unittest from typing import AsyncIterator -from streamable.aiterators import ( +import pytest + +from streamable._aiterators import ( _ConcurrentAMapAsyncIterable, _RaisingAsyncIterator, ) -from streamable.util.asynctools import awaitable_to_coroutine -from streamable.util.iterabletools import sync_to_async_iter +from streamable._util._asynctools import awaitable_to_coroutine +from streamable._util._iterabletools import sync_to_async_iter from tests.utils import async_identity, identity, src -class TestIterators(unittest.TestCase): - def test_ConcurrentAMapAsyncIterable(self) -> None: - with self.assertRaisesRegex( - TypeError, - r"(object int can't be used in 'await' expression)|('int' object can't be awaited)", - msg="`amap` should raise a TypeError if a non async function is passed to it.", - ): - concurrent_amap_async_iterable: _ConcurrentAMapAsyncIterable[int, int] = ( - _ConcurrentAMapAsyncIterable( - sync_to_async_iter(iter(src)), - async_identity, - buffersize=2, - ordered=True, - ) +def test_ConcurrentAMapAsyncIterable() -> None: + # `amap` should raise a TypeError if a non async function is passed to it. + with pytest.raises( + TypeError, + match=r"(object int can't be used in 'await' expression)|('int' object can't be awaited)", + ): + concurrent_amap_async_iterable: _ConcurrentAMapAsyncIterable[int, int] = ( + _ConcurrentAMapAsyncIterable( + sync_to_async_iter(iter(src)), + async_identity, + concurrency=2, + buffersize=2, + ordered=True, ) + ) - # remove error wrapping - concurrent_amap_async_iterable.transformation = identity # type: ignore + # remove error wrapping + concurrent_amap_async_iterable.transformation = identity # type: ignore - aiterator: AsyncIterator[int] = _RaisingAsyncIterator( - concurrent_amap_async_iterable.__aiter__() - ) - asyncio.run(awaitable_to_coroutine(aiterator.__aiter__().__anext__())) + aiterator: AsyncIterator[int] = _RaisingAsyncIterator( + concurrent_amap_async_iterable.__aiter__() + ) + asyncio.run(awaitable_to_coroutine(aiterator.__aiter__().__anext__())) diff --git a/tests/test_readme.py b/tests/test_readme.py index 6c1b9ba4..fcd99c34 100644 --- a/tests/test_readme.py +++ b/tests/test_readme.py @@ -1,10 +1,12 @@ import asyncio +from pathlib import Path import time -import unittest from datetime import timedelta from typing import Iterator, List, Tuple, TypeVar from unittest.mock import patch +import pytest + from streamable.stream import Stream from tests import mocks @@ -22,379 +24,367 @@ # fmt: off -class TestReadme(unittest.TestCase): - def test_iterate(self) -> None: - self.assertListEqual( - list(inverses), - [1.0, 0.5, 0.33, 0.25, 0.2, 0.17, 0.14, 0.12, 0.11], - ) - self.assertSetEqual( - set(inverses), - {0.5, 1.0, 0.2, 0.33, 0.25, 0.17, 0.14, 0.12, 0.11}, - ) - self.assertAlmostEqual(sum(inverses), 2.82) - self.assertEqual(max(inverses), 1.0) - self.assertEqual(max(inverses), 1.0) - inverses_iter = iter(inverses) - self.assertEqual(next(inverses_iter), 1.0) - self.assertEqual(next(inverses_iter), 0.5) - - async def main() -> List[float]: - return [inverse async for inverse in inverses] - - assert asyncio.run(main()) == [1.0, 0.5, 0.33, 0.25, 0.2, 0.17, 0.14, 0.12, 0.11] - - - def test_map_example(self) -> None: - integer_strings: Stream[str] = integers.map(str) - - assert list(integer_strings) == ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] - - @patch("requests.get", mocks.get_poke) - def test_thread_concurrent_map_example(self) -> None: - import requests - - pokemon_names: Stream[str] = ( - Stream(range(1, 4)) - .map(lambda i: f"https://pokeapi.co/api/v2/pokemon-species/{i}") - .map(requests.get, concurrency=3) - .map(requests.Response.json) - .map(lambda poke: poke["name"]) - ) - assert list(pokemon_names) == ['bulbasaur', 'ivysaur', 'venusaur'] - - def test_process_concurrent_map_example(self) -> None: - state: List[int] = [] - # integers are mapped - assert integers.map(state.append, concurrency=4, via="process").count() == 10 - # but the `state` of the main process is not mutated - assert state == [] - - @patch("httpx.AsyncClient.get", lambda self, url: mocks.async_get_poke(url)) - def test_async_amap_example(self) -> None: - import asyncio - - import httpx - - async def main() -> None: - async with httpx.AsyncClient() as http: - pokemon_names: Stream[str] = ( - Stream(range(1, 4)) - .map(lambda i: f"https://pokeapi.co/api/v2/pokemon-species/{i}") - .amap(http.get, concurrency=3) - .map(httpx.Response.json) - .map(lambda poke: poke["name"]) - ) - # consume as an AsyncIterable[str] - assert [name async for name in pokemon_names] == ['bulbasaur', 'ivysaur', 'venusaur'] +def test_iterate() -> None: + assert list(inverses) == [1.0, 0.5, 0.33, 0.25, 0.2, 0.17, 0.14, 0.12, 0.11] + assert set(inverses) == {0.5, 1.0, 0.2, 0.33, 0.25, 0.17, 0.14, 0.12, 0.11} + assert sum(inverses) == pytest.approx(2.82) + assert max(inverses) == 1.0 + assert max(inverses) == 1.0 + inverses_iter = iter(inverses) + assert next(inverses_iter) == 1.0 + assert next(inverses_iter) == 0.5 + +@pytest.mark.asyncio +async def test_aiterate() -> None: + assert [inverse async for inverse in inverses] == [1.0, 0.5, 0.33, 0.25, 0.2, 0.17, 0.14, 0.12, 0.11] + +def test_map_example() -> None: + integer_strings: Stream[str] = integers.map(str) + + assert list(integer_strings) == ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] + +@patch("requests.get", mocks.get_poke) +def test_thread_concurrent_map_example() -> None: + import requests + + pokemon_names: Stream[str] = ( + Stream(range(1, 4)) + .map(lambda i: f"https://pokeapi.co/api/v2/pokemon-species/{i}") + .map(requests.get, concurrency=3) + .map(requests.Response.json) + .map(lambda poke: poke["name"]) + ) + assert list(pokemon_names) == ['bulbasaur', 'ivysaur', 'venusaur'] + +def test_process_concurrent_map_example() -> None: + state: List[int] = [] + # integers are mapped + assert integers.map(state.append, concurrency=4, via="process").count() == 10 + # but the `state` of the main process is not mutated + assert state == [] + +@patch("httpx.AsyncClient.get", lambda self, url: mocks.async_get_poke(url)) +def test_async_amap_example() -> None: + import asyncio + + import httpx + + async def main() -> None: + async with httpx.AsyncClient() as http: + pokemon_names: Stream[str] = ( + Stream(range(1, 4)) + .map(lambda i: f"https://pokeapi.co/api/v2/pokemon-species/{i}") + .amap(http.get, concurrency=3) + .map(httpx.Response.json) + .map(lambda poke: poke["name"]) + ) + # consume as an AsyncIterable[str] + assert [name async for name in pokemon_names] == ['bulbasaur', 'ivysaur', 'venusaur'] - asyncio.run(main()) + asyncio.run(main()) - def test_starmap_example(self) -> None: - from streamable import star +def test_starmap_example() -> None: + from streamable import star - zeros: Stream[int] = ( - Stream(enumerate(integers)) - .map(star(lambda index, integer: index - integer)) - ) + zeros: Stream[int] = ( + Stream(enumerate(integers)) + .map(star(lambda index, integer: index - integer)) + ) - assert list(zeros) == [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + assert list(zeros) == [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - def test_foreach_example(self) -> None: - state: List[int] = [] - appending_integers: Stream[int] = integers.foreach(state.append) +def test_foreach_example() -> None: + state: List[int] = [] + appending_integers: Stream[int] = integers.foreach(state.append) - assert list(appending_integers) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] - assert state == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + assert list(appending_integers) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + assert state == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + +def test_filter_example() -> None: + even_integers: Stream[int] = integers.filter(lambda n: n % 2 == 0) - def test_filter_example(self) -> None: - even_integers: Stream[int] = integers.filter(lambda n: n % 2 == 0) + assert list(even_integers) == [0, 2, 4, 6, 8] - assert list(even_integers) == [0, 2, 4, 6, 8] +def test_throttle_example() -> None: + from datetime import timedelta - def test_throttle_example(self) -> None: - from datetime import timedelta + three_integers_per_second: Stream[int] = integers.throttle(3, per=timedelta(seconds=1)) - three_integers_per_second: Stream[int] = integers.throttle(3, per=timedelta(seconds=1)) + start = time.perf_counter() + # takes 3s: ceil(10 integers / 3 per_second) - 1 + assert list(three_integers_per_second) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + assert 2.99 < time.perf_counter() - start < 3.25 - start = time.perf_counter() - # takes 3s: ceil(10 integers / 3 per_second) - 1 - assert list(three_integers_per_second) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] - assert 2.99 < time.perf_counter() - start < 3.25 + integers_every_100_millis = ( + integers + .throttle(1, per=timedelta(milliseconds=100)) + ) - integers_every_100_millis = ( - integers - .throttle(1, per=timedelta(milliseconds=100)) - ) + start = time.perf_counter() + # takes 900 millis: (10 integers - 1) * 100 millis + assert list(integers_every_100_millis) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + assert 0.89 < time.perf_counter() - start < 0.95 - start = time.perf_counter() - # takes 900 millis: (10 integers - 1) * 100 millis - assert list(integers_every_100_millis) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] - assert 0.89 < time.perf_counter() - start < 0.95 +def test_group_example() -> None: + global integers_by_parity + integers_by_5: Stream[List[int]] = integers.group(size=5) - def test_group_example(self) -> None: - global integers_by_parity - integers_by_5: Stream[List[int]] = integers.group(size=5) + assert list(integers_by_5) == [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] - assert list(integers_by_5) == [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] + integers_by_parity = integers.group(by=lambda n: n % 2) - integers_by_parity = integers.group(by=lambda n: n % 2) + assert list(integers_by_parity) == [[0, 2, 4, 6, 8], [1, 3, 5, 7, 9]] - assert list(integers_by_parity) == [[0, 2, 4, 6, 8], [1, 3, 5, 7, 9]] + from datetime import timedelta - from datetime import timedelta + integers_within_1_sec: Stream[List[int]] = ( + integers + .throttle(2, per=timedelta(seconds=1)) + .group(interval=timedelta(seconds=0.99)) + ) - integers_within_1_sec: Stream[List[int]] = ( - integers - .throttle(2, per=timedelta(seconds=1)) - .group(interval=timedelta(seconds=0.99)) - ) + assert list(integers_within_1_sec) == [[0, 1, 2], [3, 4], [5, 6], [7, 8], [9]] - assert list(integers_within_1_sec) == [[0, 1, 2], [3, 4], [5, 6], [7, 8], [9]] + integers_by_parity_by_2: Stream[List[int]] = ( + integers + .group(by=lambda n: n % 2, size=2) + ) - integers_by_parity_by_2: Stream[List[int]] = ( - integers - .group(by=lambda n: n % 2, size=2) - ) + assert list(integers_by_parity_by_2) == [[0, 2], [1, 3], [4, 6], [5, 7], [8], [9]] - assert list(integers_by_parity_by_2) == [[0, 2], [1, 3], [4, 6], [5, 7], [8], [9]] +def test_groupby_example() -> None: + integers_by_parity: Stream[Tuple[str, List[int]]] = ( + integers + .groupby(lambda n: "odd" if n % 2 else "even") + ) - def test_groupby_example(self) -> None: - integers_by_parity: Stream[Tuple[str, List[int]]] = ( - integers - .groupby(lambda n: "odd" if n % 2 else "even") - ) + assert list(integers_by_parity) == [("even", [0, 2, 4, 6, 8]), ("odd", [1, 3, 5, 7, 9])] - assert list(integers_by_parity) == [("even", [0, 2, 4, 6, 8]), ("odd", [1, 3, 5, 7, 9])] + from streamable import star - from streamable import star + counts_by_parity: Stream[Tuple[str, int]] = ( + integers_by_parity + .map(star(lambda parity, ints: (parity, len(ints)))) + ) - counts_by_parity: Stream[Tuple[str, int]] = ( - integers_by_parity - .map(star(lambda parity, ints: (parity, len(ints)))) - ) + assert list(counts_by_parity) == [("even", 5), ("odd", 5)] - assert list(counts_by_parity) == [("even", 5), ("odd", 5)] +def test_flatten_example() -> None: + global integers_by_parity + even_then_odd_integers: Stream[int] = integers_by_parity.flatten() - def test_flatten_example(self) -> None: - global integers_by_parity - even_then_odd_integers: Stream[int] = integers_by_parity.flatten() + assert list(even_then_odd_integers) == [0, 2, 4, 6, 8, 1, 3, 5, 7, 9] - assert list(even_then_odd_integers) == [0, 2, 4, 6, 8, 1, 3, 5, 7, 9] + mixed_ones_and_zeros: Stream[int] = ( + Stream([[0] * 4, [1] * 4]) + .flatten(concurrency=2) + ) + assert list(mixed_ones_and_zeros) == [0, 1, 0, 1, 0, 1, 0, 1] - mixed_ones_and_zeros: Stream[int] = ( - Stream([[0] * 4, [1] * 4]) - .flatten(concurrency=2) - ) - assert list(mixed_ones_and_zeros) == [0, 1, 0, 1, 0, 1, 0, 1] +def test_catch_example() -> None: + inverses: Stream[float] = ( + integers + .map(lambda n: round(1 / n, 2)) + .catch(ZeroDivisionError, replacement=float("inf")) + ) - def test_catch_example(self) -> None: - inverses: Stream[float] = ( - integers - .map(lambda n: round(1 / n, 2)) - .catch(ZeroDivisionError, replacement=float("inf")) - ) + assert list(inverses) == [float("inf"), 1.0, 0.5, 0.33, 0.25, 0.2, 0.17, 0.14, 0.12, 0.11] - assert list(inverses) == [float("inf"), 1.0, 0.5, 0.33, 0.25, 0.2, 0.17, 0.14, 0.12, 0.11] + import requests + from requests.exceptions import ConnectionError - import requests - from requests.exceptions import ConnectionError + status_codes_ignoring_resolution_errors: Stream[int] = ( + Stream(["https://github.com", "https://foo.bar", "https://github.com/foo/bar"]) + .map(requests.get, concurrency=2) + .catch(ConnectionError, when=lambda exception: "Max retries exceeded with url" in str(exception)) + .map(lambda response: response.status_code) + ) - status_codes_ignoring_resolution_errors: Stream[int] = ( - Stream(["https://github.com", "https://foo.bar", "https://github.com/foo/bar"]) - .map(requests.get, concurrency=2) - .catch(ConnectionError, when=lambda exception: "Max retries exceeded with url" in str(exception)) - .map(lambda response: response.status_code) - ) + assert list(status_codes_ignoring_resolution_errors) == [200, 404] - assert list(status_codes_ignoring_resolution_errors) == [200, 404] + errors: List[Exception] = [] - errors: List[Exception] = [] + def store_error(error: Exception) -> bool: + errors.append(error) + return True - def store_error(error: Exception) -> bool: - errors.append(error) - return True + integers_in_string: Stream[int] = ( + Stream("012345foo6789") + .map(int) + .catch(ValueError, when=store_error) + ) - integers_in_string: Stream[int] = ( - Stream("012345foo6789") - .map(int) - .catch(ValueError, when=store_error) - ) + assert list(integers_in_string) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + assert len(errors) == len("foo") - assert list(integers_in_string) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] - assert len(errors) == len("foo") +def test_truncate_example() -> None: + five_first_integers: Stream[int] = integers.truncate(5) - def test_truncate_example(self) -> None: - five_first_integers: Stream[int] = integers.truncate(5) + assert list(five_first_integers) == [0, 1, 2, 3, 4] - assert list(five_first_integers) == [0, 1, 2, 3, 4] + five_first_integers = integers.truncate(when=lambda n: n == 5) - five_first_integers = integers.truncate(when=lambda n: n == 5) + assert list(five_first_integers) == [0, 1, 2, 3, 4] - assert list(five_first_integers) == [0, 1, 2, 3, 4] +def test_skip_example() -> None: + integers_after_five: Stream[int] = integers.skip(5) - def test_skip_example(self) -> None: - integers_after_five: Stream[int] = integers.skip(5) + assert list(integers_after_five) == [5, 6, 7, 8, 9] - assert list(integers_after_five) == [5, 6, 7, 8, 9] + integers_after_five = integers.skip(until=lambda n: n >= 5) - integers_after_five = integers.skip(until=lambda n: n >= 5) + assert list(integers_after_five) == [5, 6, 7, 8, 9] - assert list(integers_after_five) == [5, 6, 7, 8, 9] +def test_distinct_example() -> None: + distinct_chars: Stream[str] = Stream("foobarfooo").distinct() - def test_distinct_example(self) -> None: - distinct_chars: Stream[str] = Stream("foobarfooo").distinct() + assert list(distinct_chars) == ["f", "o", "b", "a", "r"] - assert list(distinct_chars) == ["f", "o", "b", "a", "r"] + strings_of_distinct_lengths: Stream[str] = ( + Stream(["a", "foo", "bar", "z"]) + .distinct(len) + ) - strings_of_distinct_lengths: Stream[str] = ( - Stream(["a", "foo", "bar", "z"]) - .distinct(len) - ) + assert list(strings_of_distinct_lengths) == ["a", "foo"] - assert list(strings_of_distinct_lengths) == ["a", "foo"] + consecutively_distinct_chars: Stream[str] = ( + Stream("foobarfooo") + .distinct(consecutive_only=True) + ) - consecutively_distinct_chars: Stream[str] = ( - Stream("foobarfooo") - .distinct(consecutive_only=True) - ) + assert list(consecutively_distinct_chars) == ["f", "o", "b", "a", "r", "f", "o"] - assert list(consecutively_distinct_chars) == ["f", "o", "b", "a", "r", "f", "o"] +def test_observe_example() -> None: + assert list(integers.throttle(2, per=timedelta(seconds=1)).observe("integers")) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] - def test_observe_example(self) -> None: - assert list(integers.throttle(2, per=timedelta(seconds=1)).observe("integers")) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] +def test_plus_example() -> None: + assert list(integers + integers) == [0, 1, 2, 3 ,4, 5, 6, 7, 8, 9, 0, 1, 2, 3 ,4, 5, 6, 7, 8, 9] - def test_plus_example(self) -> None: - assert list(integers + integers) == [0, 1, 2, 3 ,4, 5, 6, 7, 8, 9, 0, 1, 2, 3 ,4, 5, 6, 7, 8, 9] +def test_zip_example() -> None: + from streamable import star - def test_zip_example(self) -> None: - from streamable import star + cubes: Stream[int] = ( + Stream(zip(integers, integers, integers)) # Stream[Tuple[int, int, int]] + .map(star(lambda a, b, c: a * b * c)) # Stream[int] + ) - cubes: Stream[int] = ( - Stream(zip(integers, integers, integers)) # Stream[Tuple[int, int, int]] - .map(star(lambda a, b, c: a * b * c)) # Stream[int] - ) + assert list(cubes) == [0, 1, 8, 27, 64, 125, 216, 343, 512, 729] - assert list(cubes) == [0, 1, 8, 27, 64, 125, 216, 343, 512, 729] +def test_count_example() -> None: + assert integers.count() == 10 - def test_count_example(self) -> None: - assert integers.count() == 10 +def test_acount_example() -> None: + assert asyncio.run(integers.acount()) == 10 - def test_acount_example(self) -> None: - assert asyncio.run(integers.acount()) == 10 +def test_call_example() -> None: + state: List[int] = [] + appending_integers: Stream[int] = integers.foreach(state.append) + assert appending_integers() is appending_integers + assert state == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] - def test_call_example(self) -> None: +def test_await_example() -> None: + async def test() -> None: state: List[int] = [] appending_integers: Stream[int] = integers.foreach(state.append) - assert appending_integers() is appending_integers + appending_integers is await appending_integers assert state == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + asyncio.run(test()) - def test_await_example(self) -> None: - async def test() -> None: - state: List[int] = [] - appending_integers: Stream[int] = integers.foreach(state.append) - appending_integers is await appending_integers - assert state == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] - asyncio.run(test()) +def test_non_stopping_exceptions_example() -> None: + from contextlib import suppress - def test_non_stopping_exceptions_example(self) -> None: - from contextlib import suppress - - casted_ints: Iterator[int] = iter(Stream("0123_56789").map(int).group(3).flatten()) - collected_casted_ints: List[int] = [] - - with suppress(ValueError): - collected_casted_ints.extend(casted_ints) - assert collected_casted_ints == [0, 1, 2, 3] + casted_ints: Iterator[int] = iter(Stream("0123_56789").map(int).group(3).flatten()) + collected_casted_ints: List[int] = [] + with suppress(ValueError): collected_casted_ints.extend(casted_ints) - assert collected_casted_ints == [0, 1, 2, 3, 5, 6, 7, 8, 9] - - @patch("httpx.AsyncClient.get", lambda self, url: mocks.async_get_poke(url)) - def test_async_etl_example(self) -> None: # pragma: no cover - import asyncio - import csv - import itertools - from datetime import timedelta - - import httpx - - from streamable import Stream - - async def main() -> None: - with open("./quadruped_pokemons.csv", mode="w") as file: - fields = ["id", "name", "is_legendary", "base_happiness", "capture_rate"] - writer = csv.DictWriter(file, fields, extrasaction='ignore') - writer.writeheader() - - async with httpx.AsyncClient() as http: - pipeline: Stream = ( - # Infinite Stream[int] of Pokemon ids starting from Pokémon #1: Bulbasaur - Stream(itertools.count(1)) - # Limit to 16 requests per second to be friendly to our fellow PokéAPI devs - .throttle(16, per=timedelta(microseconds=1)) - # GET pokemons via 8 concurrent coroutines - .map(lambda poke_id: f"https://pokeapi.co/api/v2/pokemon-species/{poke_id}") - .amap(http.get, concurrency=8) - .foreach(httpx.Response.raise_for_status) - .map(httpx.Response.json) - # Stop the iteration when reaching the 1st pokemon of the 4th generation - .truncate(when=lambda poke: poke["generation"]["name"] == "generation-iv") - .observe("pokemons") - # Keep only quadruped Pokemons - .filter(lambda poke: poke["shape"]["name"] == "quadruped") - .observe("quadruped pokemons") - # Write a batch of pokemons every 5 seconds to the CSV file - .group(interval=timedelta(seconds=5)) - .foreach(writer.writerows) - .flatten() - .observe("written pokemons") - # Catch exceptions and raises the 1st one at the end of the iteration - # .catch(Exception, finally_raise=True) - ) - - await pipeline - - asyncio.run(main()) - - @patch("requests.get", mocks.get_poke) - def test_etl_example(self) -> None: # pragma: no cover - import csv - import itertools - from datetime import timedelta - - import requests - - from streamable import Stream - - with open("./quadruped_pokemons.csv", mode="w") as file: + assert collected_casted_ints == [0, 1, 2, 3] + + collected_casted_ints.extend(casted_ints) + assert collected_casted_ints == [0, 1, 2, 3, 5, 6, 7, 8, 9] + +@patch("httpx.AsyncClient.get", lambda self, url: mocks.async_get_poke(url)) +def test_async_etl_example(tmp_path: Path) -> None: # pragma: no cover + import asyncio + import csv + from datetime import timedelta + from itertools import count + import httpx + from streamable import Stream + + async def main() -> None: + with (tmp_path / "quadruped_pokemons.csv").open("w") as file: fields = ["id", "name", "is_legendary", "base_happiness", "capture_rate"] writer = csv.DictWriter(file, fields, extrasaction='ignore') writer.writeheader() + + async with httpx.AsyncClient() as http_client: + pipeline = ( + # Infinite Stream[int] of Pokemon ids starting from Pokémon #1: Bulbasaur + Stream(count(1)) + # Limit to 16 requests per second to be friendly to our fellow PokéAPI devs + .throttle(16, per=timedelta(microseconds=1)) + # GET pokemons via 8 concurrent coroutines + .map(lambda poke_id: f"https://pokeapi.co/api/v2/pokemon-species/{poke_id}") + .amap(http_client.get, concurrency=8) + .foreach(httpx.Response.raise_for_status) + .map(httpx.Response.json) + # Stop the iteration when reaching the 1st pokemon of the 4th generation + .truncate(when=lambda poke: poke["generation"]["name"] == "generation-iv") + .observe("pokemons") + # Keep only quadruped Pokemons + .filter(lambda poke: poke["shape"]["name"] == "quadruped") + # Write a batch of pokemons every 5 seconds to the CSV file + .group(interval=timedelta(seconds=5)) + .foreach(writer.writerows) + .flatten() + .observe("written pokemons") + # Catch exceptions and raises the 1st one at the end of the iteration + .catch(Exception, finally_raise=True) + ) + + # Start a full async iteration + await pipeline + + asyncio.run(main()) + +@patch("httpx.Client.get", lambda self, url: mocks.get_poke(url)) +def test_etl_example(tmp_path: Path) -> None: # pragma: no cover + import csv + from datetime import timedelta + from itertools import count + import httpx + from streamable import Stream + + with (tmp_path / "quadruped_pokemons.csv").open("w") as file: + fields = ["id", "name", "is_legendary", "base_happiness", "capture_rate"] + writer = csv.DictWriter(file, fields, extrasaction='ignore') + writer.writeheader() + with httpx.Client() as http_client: pipeline = ( # Infinite Stream[int] of Pokemon ids starting from Pokémon #1: Bulbasaur - Stream(itertools.count(1)) + Stream(count(1)) # Limit to 16 requests per second to be friendly to our fellow PokéAPI devs .throttle(16, per=timedelta(microseconds=1)) # GET pokemons concurrently using a pool of 8 threads .map(lambda poke_id: f"https://pokeapi.co/api/v2/pokemon-species/{poke_id}") - .map(requests.get, concurrency=8) - .foreach(requests.Response.raise_for_status) - .map(requests.Response.json) + .map(http_client.get, concurrency=8) + .foreach(httpx.Response.raise_for_status) + .map(httpx.Response.json) # Stop the iteration when reaching the 1st pokemon of the 4th generation .truncate(when=lambda poke: poke["generation"]["name"] == "generation-iv") .observe("pokemons") # Keep only quadruped Pokemons .filter(lambda poke: poke["shape"]["name"] == "quadruped") - .observe("quadruped pokemons") # Write a batch of pokemons every 5 seconds to the CSV file .group(interval=timedelta(seconds=5)) .foreach(writer.writerows) .flatten() .observe("written pokemons") # Catch exceptions and raises the 1st one at the end of the iteration - # .catch(Exception, finally_raise=True) + .catch(Exception, finally_raise=True) ) + # Start a full iteration pipeline() # fmt: on diff --git a/tests/test_stream.py b/tests/test_stream.py index e40a3da4..7dfab27e 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -33,16 +33,17 @@ from unittest.mock import patch from parameterized import parameterized # type: ignore +import pytest from streamable import Stream -from streamable.util.asynctools import awaitable_to_coroutine -from streamable.util.functiontools import anostop, asyncify, nostop, star -from streamable.util.iterabletools import ( +from streamable._util._asynctools import awaitable_to_coroutine +from streamable._util._functiontools import anostop, asyncify, nostop, star +from streamable._util._iterabletools import ( sync_to_async_iter, sync_to_bi_iterable, ) from tests.utils import ( - DELTA_RATE, + DELTA, ITERABLE_TYPES, IterableType, N, @@ -70,6 +71,7 @@ throw, throw_for_odd_func, throw_func, + timecoro, timestream, to_list, ) @@ -647,7 +649,7 @@ def test_map_or_foreach_concurrency(self, method, func, concurrency, itype) -> N self.assertAlmostEqual( duration, expected_iteration_duration, - delta=expected_iteration_duration * DELTA_RATE, + delta=expected_iteration_duration * DELTA, msg="Increasing the concurrency of mapping should decrease proportionnally the iteration's duration.", ) @@ -866,7 +868,7 @@ def test_flatten_concurrency(self, flatten, itype, slow) -> None: self.assertAlmostEqual( runtime, expected_runtime, - delta=DELTA_RATE * expected_runtime, + delta=DELTA * expected_runtime, msg="`flatten` should process 'a's and 'b's concurrently and then 'c's without concurrency", ) @@ -2086,7 +2088,7 @@ def f(i): with self.assertRaises( ZeroDivisionError, - msg="If a non caught exception type occurs, then it should be raised.", + msg="If a non-caught exception type occurs, then it should be raised.", ): to_list(stream.catch(TestError), itype=itype) @@ -2905,3 +2907,29 @@ def tracking_new_event_loop(): ) asyncio.new_event_loop = original_new_event_loop + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "stream", + ( + Stream(range(N)).map(slow_identity, concurrency=N // 8), + ( + Stream(range(N)) + .map(lambda i: map(slow_identity, (i,))) + .flatten(concurrency=N // 8) + ), + ), +) +async def test_run_in_executor(stream: Stream) -> None: + """ + Tests that executor-based concurrent mapping/flattening are wrapped + in non-loop-blocking run_in_executor-based async tasks. + """ + concurrency = N // 8 + res: tuple[int, int] + duration, res = await timecoro( + lambda: asyncio.gather(stream.acount(), stream.acount()), times=10 + ) + assert tuple(res) == (N, N) + assert duration == pytest.approx(N * slow_identity_duration / concurrency, rel=0.25) diff --git a/tests/test_util.py b/tests/test_util.py index 32e7d774..00c03dfe 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,34 +1,41 @@ -import unittest +import asyncio -from streamable.util.functiontools import sidify, star +from streamable._util._functiontools import sidify, star +from streamable._util._futuretools import ( + ExecutorFIFOFutureResultCollection, + FutureResult, +) -class TestUtil(unittest.TestCase): - def test_sidify(self) -> None: - def f(x: int) -> int: - return x**2 +def test_sidify() -> None: + def f(x: int) -> int: + return x**2 - self.assertEqual(f(2), 4) - self.assertEqual(sidify(f)(2), 2) + assert f(2) == 4 + assert sidify(f)(2) == 2 - # test decoration - @sidify - def g(x): - return x**2 + # test decoration + @sidify + def g(x): + return x**2 - self.assertEqual(g(2), 2) + assert g(2) == 2 - def test_star(self) -> None: - self.assertListEqual( - list(map(star(lambda i, n: i * n), enumerate(range(10)))), - list(map(lambda x: x**2, range(10))), - ) - @star - def mul(a: int, b: int) -> int: - return a * b +def test_star() -> None: + assert list(map(star(lambda i, n: i * n), enumerate(range(10)))) == list( + map(lambda x: x**2, range(10)) + ) - self.assertListEqual( - list(map(mul, enumerate(range(10)))), - list(map(lambda x: x**2, range(10))), - ) + @star + def mul(a: int, b: int) -> int: + return a * b + + assert list(map(mul, enumerate(range(10)))) == list(map(lambda x: x**2, range(10))) + + +def test_os_future_result_collection_anext(): + result = object() + future_results = ExecutorFIFOFutureResultCollection() + future_results.add(FutureResult(result)) + assert asyncio.run(future_results.__anext__()) == result diff --git a/tests/test_visitor.py b/tests/test_visitor.py index ac5cce7c..01b2cd17 100644 --- a/tests/test_visitor.py +++ b/tests/test_visitor.py @@ -1,4 +1,3 @@ -import unittest from typing import cast from streamable.stream import ( @@ -29,51 +28,47 @@ from streamable.visitors import Visitor -class TestVisitor(unittest.TestCase): - def test_visitor(self) -> None: - class ConcreteVisitor(Visitor[None]): - def visit_stream(self, stream: Stream) -> None: - return None +def test_visitor() -> None: + class ConcreteVisitor(Visitor[None]): + def visit_stream(self, stream: Stream) -> None: + return None - visitor = ConcreteVisitor() - visitor.visit_catch_stream(cast(CatchStream, ...)) - visitor.visit_acatch_stream(cast(ACatchStream, ...)) - visitor.visit_distinct_stream(cast(DistinctStream, ...)) - visitor.visit_adistinct_stream(cast(ADistinctStream, ...)) - visitor.visit_filter_stream(cast(FilterStream, ...)) - visitor.visit_afilter_stream(cast(AFilterStream, ...)) - visitor.visit_flatten_stream(cast(FlattenStream, ...)) - visitor.visit_aflatten_stream(cast(AFlattenStream, ...)) - visitor.visit_foreach_stream(cast(ForeachStream, ...)) - visitor.visit_aforeach_stream(cast(AForeachStream, ...)) - visitor.visit_group_stream(cast(GroupStream, ...)) - visitor.visit_agroup_stream(cast(AGroupStream, ...)) - visitor.visit_groupby_stream(cast(GroupbyStream, ...)) - visitor.visit_agroupby_stream(cast(AGroupbyStream, ...)) - visitor.visit_map_stream(cast(MapStream, ...)) - visitor.visit_amap_stream(cast(AMapStream, ...)) - visitor.visit_observe_stream(cast(ObserveStream, ...)) - visitor.visit_skip_stream(cast(SkipStream, ...)) - visitor.visit_askip_stream(cast(ASkipStream, ...)) - visitor.visit_throttle_stream(cast(ThrottleStream, ...)) - visitor.visit_truncate_stream(cast(TruncateStream, ...)) - visitor.visit_atruncate_stream(cast(ATruncateStream, ...)) - visitor.visit_stream(cast(Stream, ...)) + visitor = ConcreteVisitor() + visitor.visit_catch_stream(cast(CatchStream, ...)) + visitor.visit_acatch_stream(cast(ACatchStream, ...)) + visitor.visit_distinct_stream(cast(DistinctStream, ...)) + visitor.visit_adistinct_stream(cast(ADistinctStream, ...)) + visitor.visit_filter_stream(cast(FilterStream, ...)) + visitor.visit_afilter_stream(cast(AFilterStream, ...)) + visitor.visit_flatten_stream(cast(FlattenStream, ...)) + visitor.visit_aflatten_stream(cast(AFlattenStream, ...)) + visitor.visit_foreach_stream(cast(ForeachStream, ...)) + visitor.visit_aforeach_stream(cast(AForeachStream, ...)) + visitor.visit_group_stream(cast(GroupStream, ...)) + visitor.visit_agroup_stream(cast(AGroupStream, ...)) + visitor.visit_groupby_stream(cast(GroupbyStream, ...)) + visitor.visit_agroupby_stream(cast(AGroupbyStream, ...)) + visitor.visit_map_stream(cast(MapStream, ...)) + visitor.visit_amap_stream(cast(AMapStream, ...)) + visitor.visit_observe_stream(cast(ObserveStream, ...)) + visitor.visit_skip_stream(cast(SkipStream, ...)) + visitor.visit_askip_stream(cast(ASkipStream, ...)) + visitor.visit_throttle_stream(cast(ThrottleStream, ...)) + visitor.visit_truncate_stream(cast(TruncateStream, ...)) + visitor.visit_atruncate_stream(cast(ATruncateStream, ...)) + visitor.visit_stream(cast(Stream, ...)) - def test_depth_visitor_example(self): - from streamable.visitors import Visitor - class DepthVisitor(Visitor[int]): - def visit_stream(self, stream: Stream) -> int: - if not stream.upstream: - return 1 - return 1 + stream.upstream.accept(self) +def test_depth_visitor_example(): + from streamable.visitors import Visitor - def depth(stream: Stream) -> int: - return stream.accept(DepthVisitor()) + class DepthVisitor(Visitor[int]): + def visit_stream(self, stream: Stream) -> int: + if not stream.upstream: + return 1 + return 1 + stream.upstream.accept(self) - self.assertEqual( - depth(Stream(range(10)).map(str).foreach(print)), - 3, - msg="DepthVisitor example should work", - ) + def depth(stream: Stream) -> int: + return stream.accept(DepthVisitor()) + + assert depth(Stream(range(10)).map(str).foreach(print)) == 3 diff --git a/tests/utils.py b/tests/utils.py index 145ab8f8..3a3d162f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -18,8 +18,8 @@ ) from streamable.stream import Stream -from streamable.util.asynctools import awaitable_to_coroutine -from streamable.util.iterabletools import BiIterable +from streamable._util._asynctools import awaitable_to_coroutine +from streamable._util._iterabletools import BiIterable T = TypeVar("T") R = TypeVar("R") @@ -28,6 +28,20 @@ ITERABLE_TYPES: Tuple[IterableType, ...] = (Iterable, AsyncIterable) +class TestError(Exception): + pass + + +DELTA = 0.1 + +# size of the test collections +N = 256 + +src = range(N) + +even_src = range(0, N, 2) + + async def _aiter_to_list(aiterable: AsyncIterable[T]) -> List[T]: return [elem async for elem in aiterable] @@ -46,8 +60,7 @@ def to_list(stream: Stream[T], itype: IterableType) -> List[T]: assert isinstance(stream, Stream) if itype is AsyncIterable: return aiterable_to_list(stream) - else: - return list(stream) + return list(stream) def bi_iterable_to_iter( @@ -55,22 +68,19 @@ def bi_iterable_to_iter( ) -> Union[Iterator[T], AsyncIterator[T]]: if itype is AsyncIterable: return iterable.__aiter__() - else: - return iter(iterable) + return iter(iterable) def anext_or_next(it: Union[Iterator[T], AsyncIterator[T]]) -> T: if isinstance(it, AsyncIterator): return asyncio.run(awaitable_to_coroutine(it.__anext__())) - else: - return next(it) + return next(it) def alist_or_list(iterable: Union[Iterable[T], AsyncIterable[T]]) -> List[T]: if isinstance(iterable, AsyncIterable): return aiterable_to_list(iterable) - else: - return list(iterable) + return list(iterable) def timestream( @@ -85,6 +95,16 @@ def iterate(): return timeit.timeit(iterate, number=times) / times, res +async def timecoro( + afn: Callable[[], Union[Coroutine[None, None, T], "asyncio.Future[T]"]], + times: int = 1, +) -> Tuple[float, T]: + start = time.perf_counter() + for _ in range(times): + res = await afn() + return (time.perf_counter() - start) / times, res + + def identity_sleep(seconds: float) -> float: time.sleep(seconds) return seconds @@ -95,7 +115,6 @@ async def async_identity_sleep(seconds: float) -> float: return seconds -# simulates an I/0 bound function slow_identity_duration = 0.05 @@ -152,19 +171,6 @@ async def f(i): return f -class TestError(Exception): - pass - - -DELTA_RATE = 0.4 -# size of the test collections -N = 256 - -src = range(N) - -even_src = range(0, N, 2) - - def randomly_slowed( func: Callable[[T], R], min_sleep: float = 0.001, max_sleep: float = 0.05 ) -> Callable[[T], R]: diff --git a/version.py b/version.py index 81162ee4..b88f2866 100644 --- a/version.py +++ b/version.py @@ -1 +1 @@ -__version__ = "1.6.5" +__version__ = "1.6.6a0"