diff --git a/README.md b/README.md index 003c4d3aa..b3fe28bcd 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,11 @@ Developed at =0.3, <1", "dict2xml>=1.7.6, <2", "xmltodict>=0.13.0, <1", + "bidict>=0.23, <1" ] [project.urls] diff --git a/src/fundus/publishers/base_objects.py b/src/fundus/publishers/base_objects.py index 8e8f1286e..404ca6927 100644 --- a/src/fundus/publishers/base_objects.py +++ b/src/fundus/publishers/base_objects.py @@ -1,6 +1,6 @@ from collections import defaultdict from textwrap import indent -from typing import Dict, Iterator, List, Optional, Set, Type, Union +from typing import Dict, Iterable, Iterator, List, Optional, Set, Type, Union from urllib.robotparser import RobotFileParser from warnings import warn @@ -11,7 +11,7 @@ from fundus.logging import create_logger from fundus.parser.base_parser import ParserProxy from fundus.scraping.filter import URLFilter -from fundus.scraping.session import session_handler +from fundus.scraping.session import _default_header, session_handler from fundus.scraping.url import NewsMap, RSSFeed, Sitemap, URLSource from fundus.utils.iteration import iterate_all_subclasses @@ -27,43 +27,90 @@ class CustomRobotFileParser(RobotFileParser): This is in order to avoid 403 errors when fetching the robots.txt file. """ + _disallow_training_keywords: Set[str] = { + "machine", + "learning", + "training", + "train", + "model", + "models", + "artificial", + "intelligence", + "large", + "language", + "llm", + "llms", + } + + def __init__(self, url: str, headers: Optional[Dict[str, str]] = None): + self.headers = headers + self.disallows_training: bool = False + self.url = url + super().__init__(url) + # noinspection PyAttributeOutsideInit - def read(self, headers: Optional[Dict[str, str]] = None) -> None: + def read(self) -> None: """Reads the robots.txt URL and feeds it to the parser.""" try: # noinspection PyUnresolvedReferences session = session_handler.get_session() - response = session.get_with_interrupt(self.url, headers=headers) # type: ignore[attr-defined] + response = session.get_with_interrupt(self.url, headers=self.headers) except HTTPError as err: if err.response.status_code in (401, 403): + logger.warning( + f"Robots {self.url!r} disallowed access with status code {err.response.status_code}." + " Defaulting to disallow all." + ) self.disallow_all = True elif 400 <= err.response.status_code < 500: self.allow_all = True else: self.parse(response.text.splitlines()) + def parse(self, lines: Iterable[str]) -> None: + for line in lines: + if line.strip().startswith("#") and set(line.split(" ")) & self._disallow_training_keywords: + self.disallows_training = True + break + super().parse(lines) + class Robots: - def __init__(self, url: str): + def __init__(self, url: str, headers: Optional[Dict[str, str]] = None): self.url = url - self.robots_file_parser = CustomRobotFileParser(url) + self.robots_file_parser = CustomRobotFileParser(url, headers=headers) self.ready: bool = False - def read(self, headers: Optional[Dict[str, str]] = None) -> None: + def _read(self) -> None: try: - self.robots_file_parser.read(headers=headers) + self.robots_file_parser.read() except (ConnectionError, ReadTimeout): logger.warning(f"Could not load robots {self.url!r}. 
Ignoring robots and continuing.") self.robots_file_parser.allow_all = True self.ready = True + def ensure_ready(self) -> None: + """Ensure that the robots.txt file is read and parsed.""" + if not self.ready: + self._read() + def can_fetch(self, useragent: str, url: str) -> bool: + self.ensure_ready() return self.robots_file_parser.can_fetch(useragent, url) def crawl_delay(self, useragent: str) -> Optional[float]: + self.ensure_ready() delay = self.robots_file_parser.crawl_delay(useragent) return delay if delay is None else float(delay) + def disallows_training(self) -> bool: + self.ensure_ready() + return self.robots_file_parser.disallows_training + + def disallow_all(self) -> bool: + self.ensure_ready() + return self.robots_file_parser.disallow_all + class Publisher: __name__: str @@ -83,8 +130,9 @@ def __init__( sources: List[URLSource], query_parameter: Optional[Dict[str, str]] = None, url_filter: Optional[URLFilter] = None, - request_header: Optional[Dict[str, str]] = None, + request_header: Optional[Dict[str, str]] = _default_header, deprecated: bool = False, + disallows_training: bool = False, suppress_robots: bool = False, ): """Initialization of a new Publisher object @@ -98,6 +146,10 @@ def __init__( appended to crawled URLs url_filter (Optional[URLFilter]): Regex filter to apply determining URLs to be skipped request_header (Optional[Dict[str, str]]): Request header to be used for the GET-request + deprecated (bool): If True, the publisher is deprecated and skipped by default + disallows_training (bool): If True, the publisher disallows training on its articles in it's robots.txt file. + Note that this is only an indicator and users should verify the terms of use of the publisher before + using the articles for training purposes. """ if not (name and domain and parser and sources): @@ -109,7 +161,11 @@ def __init__( self.url_filter = url_filter self.request_header = request_header self.deprecated = deprecated - self.robots = Robots(self.domain + "robots.txt" if self.domain.endswith("/") else self.domain + "/robots.txt") + self.robots = Robots( + url=self.domain + "robots.txt" if self.domain.endswith("/") else self.domain + "/robots.txt", + headers=self.request_header, + ) + self._disallows_training = disallows_training # Temporary fix to compensate for a bug in the RobotsFileParser treating rule lines # like /? as / disallowing the entire site. we could think about replacing the urllib @@ -117,6 +173,8 @@ def __init__( if suppress_robots: self.robots.robots_file_parser.allow_all = True + # we define the dict here manually instead of using default dict so that we can control + # the order in which sources are proceeded. 
source_mapping: Dict[Type[URLSource], List[URLSource]] = defaultdict(list) for url_source in sources: @@ -129,6 +187,10 @@ def __init__( self._source_mapping = dict(sorted(source_mapping.items(), key=lambda item: self.__SOURCE_ORDER__[item[0]])) + @property + def disallows_training(self) -> bool: + return self._disallows_training or self.robots.disallows_training() + @property def source_mapping(self) -> Dict[Type[URLSource], List[URLSource]]: return self._source_mapping diff --git a/src/fundus/scraping/crawler.py b/src/fundus/scraping/crawler.py index 03783940b..fa0d6c739 100644 --- a/src/fundus/scraping/crawler.py +++ b/src/fundus/scraping/crawler.py @@ -12,10 +12,13 @@ import traceback from abc import ABC, abstractmethod from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import TimeoutError as FuturesTimeoutError +from concurrent.futures import as_completed from datetime import datetime from functools import lru_cache, partial, wraps from multiprocessing import Manager -from multiprocessing.context import TimeoutError +from multiprocessing.context import TimeoutError as MPTimeoutError from multiprocessing.managers import BaseManager from multiprocessing.pool import MapResult, Pool, ThreadPool from pathlib import Path @@ -57,7 +60,7 @@ from fundus.scraping.filter import ExtractionFilter, Requires, RequiresAll, URLFilter from fundus.scraping.html import CCNewsSource from fundus.scraping.scraper import CCNewsScraper, WebScraper -from fundus.scraping.session import session_handler +from fundus.scraping.session import CrashThread, session_handler from fundus.scraping.url import URLSource from fundus.utils.events import __EVENTS__ from fundus.utils.timeout import Timeout @@ -134,12 +137,17 @@ def get_execution_context(): return thread.name, thread.ident -def queue_wrapper(queue: Queue[Union[_T, Exception]], target: Callable[_P, Iterator[_T]]) -> Callable[_P, None]: +def queue_wrapper( + queue: Queue[Union[_T, Exception]], + target: Callable[_P, Iterator[_T]], + silenced_exceptions: Tuple[Type[BaseException], ...] = (), +) -> Callable[_P, None]: """Wraps the target callable to add its results to the queue instead of returning them directly. Args: queue: The buffer queue. target: A target callable. + silenced_exceptions: Exception types that should be silenced Returns: (Callable[_P, None]) The wrapped target. 
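# --- Illustrative sketch, not part of the patch -----------------------------------------
# The queue_wrapper above turns a generator-returning callable into a queue producer and,
# with the new `silenced_exceptions` parameter, lets chosen exception types (e.g. the
# CrashThread raised by an interrupted session) end the worker quietly instead of being
# reported as errors. A minimal standalone version of that pattern, with assumed names:
import queue
import threading
from typing import Callable, Iterator, Tuple, Type


def example_queue_wrapper(
    buffer: "queue.Queue[object]",
    target: Callable[[], Iterator[object]],
    silenced_exceptions: Tuple[Type[BaseException], ...] = (),
) -> Callable[[], None]:
    def wrapper() -> None:
        try:
            for obj in target():
                buffer.put(obj)
        except silenced_exceptions:
            pass  # the worker was asked to stop; exit without reporting anything
        except Exception as err:
            buffer.put(err)  # forward unexpected errors so the consumer can re-raise them

    return wrapper


# usage: run the wrapped producer in a thread and drain results on the main thread
results: "queue.Queue[object]" = queue.Queue()
worker = threading.Thread(target=example_queue_wrapper(results, lambda: iter(range(3))))
worker.start()
worker.join()
while not results.empty():
    print(results.get())
# -----------------------------------------------------------------------------------------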
@@ -150,16 +158,19 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: try: for obj in target(*args, **kwargs): queue.put(obj) + except silenced_exceptions: + pass except Exception as err: tb_str = "".join(traceback.TracebackException.from_exception(err).format()) context, ident = get_execution_context() queue.put( RemoteException( - f"There was a(n) {type(err).__name__!r} occurring in {context} with ident {ident}\n{tb_str}" + f"There was a(n) {type(err).__name__!r} occurring in {context} " + f"with ident {ident} ({__EVENTS__.get_alias(ident)})\n{tb_str}" ) ) - logger.debug(f"Encountered remote exception: {err!r}") + logger.debug(f"Encountered remote exception in thread {ident} ({__EVENTS__.get_alias(ident)}): {err!r}") # prevents a race condition where the ThreadPool shuts down before the exception is pulled from the queue time.sleep(0.2) @@ -188,7 +199,7 @@ def pool_queue_iter(handle: MapResult[Any], queue: Queue[Union[_T, Exception]]) except Empty: try: handle.get(timeout=0.1) - except TimeoutError: + except MPTimeoutError: if __EVENTS__.is_event_set("stop"): __EVENTS__.clear_event("stop") break @@ -212,7 +223,6 @@ def __init__(self, *publishers: PublisherType): raise ValueError("param of must include at least one publisher.") __EVENTS__.alias("main-thread") - __EVENTS__.register_event("stop") @abstractmethod def _build_article_iterator( @@ -222,6 +232,7 @@ def _build_article_iterator( extraction_filter: Optional[ExtractionFilter], url_filter: Optional[URLFilter], language_filter: Optional[List[str]], + skip_publishers_disallowing_training: bool = False, ) -> Iterator[Article]: raise NotImplementedError @@ -236,6 +247,7 @@ def crawl( language_filter: Optional[List[str]] = None, only_unique: bool = True, save_to_file: Union[None, str, Path] = None, + skip_publishers_disallowing_training: bool = False, ) -> Iterator[Article]: """Yields articles from initialized scrapers @@ -267,6 +279,9 @@ def crawl( Always returns the first encountered article. Defaults to True. save_to_file (Union[None, str, Path]): If set, the crawled articles will be collected saved to the specified file as a JSON list. + skip_publishers_disallowing_training (bool): If set to True, publishers that disallow training + are skipped. Note that this is an indicator only and users with the intention of using Fundus to gather + training data should always check the publisher's terms of use beforehand. Returns: Iterator[Article]: An iterator yielding objects of type Article. 
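# --- Illustrative usage sketch, not part of the patch ------------------------------------
# How the new `skip_publishers_disallowing_training` flag is meant to be used, assuming the
# public API from the README (Crawler, PublisherCollection). As the docstring above notes,
# the robots.txt check is only an indicator; the publisher's terms of use still need to be
# verified before articles are used as training data.
from fundus import Crawler, PublisherCollection

crawler = Crawler(PublisherCollection.us)
for article in crawler.crawl(max_articles=10, skip_publishers_disallowing_training=True):
    print(article.title)

# The per-publisher signal can also be read directly; `us.APNews` is only an assumed example
# attribute here. The property combines the manually set flag with the robots.txt keyword
# heuristic and fetches robots.txt lazily on first access.
print(PublisherCollection.us.APNews.disallows_training)
# -----------------------------------------------------------------------------------------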
@@ -364,10 +379,17 @@ def callback() -> None: try: with Timeout(seconds=timeout, silent=True, callback=callback, disable=timeout <= 0) as timer: for article in self._build_article_iterator( - tuple(fitting_publishers), error_handling, build_extraction_filter(), url_filter, language_filter + tuple(fitting_publishers), + error_handling, + build_extraction_filter(), + url_filter, + language_filter, + skip_publishers_disallowing_training, ): if max_articles_per_publisher and article_count[article.publisher] == max_articles_per_publisher: - if isinstance(self, Crawler) and not __EVENTS__.is_event_set("stop", article.publisher): + if (isinstance(self, Crawler) and self.threading) and not __EVENTS__.is_event_set( + "stop", article.publisher + ): __EVENTS__.set_event("stop", article.publisher) if sum(article_count.values()) == len(self.publishers) * max_articles_per_publisher: break @@ -465,7 +487,14 @@ def _fetch_articles( extraction_filter: Optional[ExtractionFilter] = None, url_filter: Optional[URLFilter] = None, language_filter: Optional[List[str]] = None, + skip_publishers_disallowing_training: bool = False, ) -> Iterator[Article]: + if skip_publishers_disallowing_training and ( + publisher.disallows_training or publisher.robots.disallows_training() + ): + logger.info(f"Skipping publisher {publisher.name} because it disallows training.") + return + def build_delay() -> Optional[Delay]: if isinstance(self.delay, float): delay = self.delay @@ -481,6 +510,11 @@ def constant_delay() -> float: else: raise TypeError("param of ") + # we "register" the thread in the event dict as soon as possible to avoid that a + # thread crashes before + if self.threading: + __EVENTS__.alias(publisher.name) + scraper = WebScraper( publisher, self.restrict_sources_to, @@ -507,7 +541,7 @@ def _threaded_crawl( self, publishers: Tuple[Publisher, ...], article_task: Callable[[Publisher], Iterator[Article]] ) -> Iterator[Article]: result_queue: Queue[Union[Article, Exception]] = Queue(len(publishers)) - wrapped_article_task = queue_wrapper(result_queue, article_task) + wrapped_article_task = queue_wrapper(result_queue, article_task, silenced_exceptions=(CrashThread,)) pool = ThreadPool(processes=len(publishers) or None) try: with session_handler.context( @@ -529,6 +563,7 @@ def _build_article_iterator( extraction_filter: Optional[ExtractionFilter], url_filter: Optional[URLFilter], language_filter: Optional[List[str]], + skip_publishers_disallowing_training: bool = False, ) -> Iterator[Article]: article_task = partial( self._fetch_articles, @@ -536,6 +571,7 @@ def _build_article_iterator( extraction_filter=extraction_filter, url_filter=url_filter, language_filter=language_filter, + skip_publishers_disallowing_training=skip_publishers_disallowing_training, ) if self.threading: @@ -737,9 +773,43 @@ def _build_article_iterator( extraction_filter: Optional[ExtractionFilter], url_filter: Optional[URLFilter], language_filter: Optional[List[str]], + skip_publishers_disallowing_training: bool = False, **kwargs, ) -> Iterator[Article]: - warc_paths = tuple(self._get_warc_paths()) + if skip_publishers_disallowing_training: + max_workers = self.processes if self.processes > 0 else min(len(publishers), 5) + verified_publishers: List["Publisher"] = [] + + def run_disallow_training(publisher: Publisher) -> bool: + return publisher.disallows_training + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_publisher = { + executor.submit(run_disallow_training, publisher=publisher): publisher for publisher in 
publishers + } + + warc_paths = tuple(self._get_warc_paths()) + + try: + for future in as_completed(future_to_publisher.keys(), timeout=30): + publisher = future_to_publisher[future] + try: + if not future.result(): + verified_publishers.append(publisher) + else: + logger.warning(f"Skipping publisher {publisher.name!r} because it disallows training.") + except FuturesTimeoutError: + logger.warning(f"Robots.txt check timed out for {publisher.name!r}", exc_info=False) + except Exception as exc: + logger.warning( + f"Could not verify training policy for {publisher.name!r}: {exc}", exc_info=True + ) + publishers = tuple(verified_publishers) + except FuturesTimeoutError: + logger.warning("Publisher verification timed out, proceeding with all publishers") + + else: + warc_paths = tuple(self._get_warc_paths()) with get_proxy_tqdm(total=len(warc_paths), desc="Process WARC files", disable=self.disable_tqdm) as bar: article_task = partial( diff --git a/src/fundus/scraping/html.py b/src/fundus/scraping/html.py index ba0c91709..6a17fdccb 100644 --- a/src/fundus/scraping/html.py +++ b/src/fundus/scraping/html.py @@ -116,15 +116,10 @@ def __init__( self.delay = delay - # register default events - __EVENTS__.register_event("stop") - # parse robots: self.robots: Optional[Robots] = None if not ignore_robots: self.robots = self.publisher.robots - if not self.robots.ready: - self.publisher.robots.read(headers=self.request_header) if not ignore_crawl_delay: if robots_delay := self.robots.crawl_delay(self.request_header.get("user-agent") or "*"): @@ -191,7 +186,6 @@ def filter_url(u: str) -> bool: if isinstance(error, HTTPError) and error.response.status_code >= 500: logger.warning(f"Skipped {self.publisher.name!r} due to server errors: {error!r}") continue - except Exception as error: logger.error(f"Warning! Skipped requested URL {url!r} because of an unexpected error {error!r}") continue diff --git a/src/fundus/scraping/scraper.py b/src/fundus/scraping/scraper.py index e903b0c29..6dfb973fb 100644 --- a/src/fundus/scraping/scraper.py +++ b/src/fundus/scraping/scraper.py @@ -107,8 +107,6 @@ def __init__( parser_mapping: Dict[str, ParserProxy] = {publisher.name: publisher.parser} super().__init__(*html_sources, parser_mapping=parser_mapping) - __EVENTS__.alias(publisher.name) - class CCNewsScraper(BaseScraper): def __init__(self, source: CCNewsSource): diff --git a/src/fundus/scraping/session.py b/src/fundus/scraping/session.py index e4550c36c..c889da5cf 100644 --- a/src/fundus/scraping/session.py +++ b/src/fundus/scraping/session.py @@ -15,6 +15,12 @@ _default_header = {"user-agent": "Fundus/2.0 (contact: github.com/flairnlp/fundus)"} +class CrashThread(BaseException): + """Is raised to end a thread without relying on the thread ending naturally""" + + pass + + class InterruptableSession(requests.Session): def __init__(self, timeout: Optional[int] = None): super().__init__() @@ -25,7 +31,7 @@ def get_with_interrupt(self, *args, **kwargs) -> requests.Response: This function hands over the request to another thread and checks every second for an interrupt event. If there was an interrupt event, this function raises - a requests.exceptions.Timeout error. + a CrashThread exception. Args: *args: requests.Session.get(*) arguments. @@ -33,6 +39,9 @@ def get_with_interrupt(self, *args, **kwargs) -> requests.Response: Returns: The response. + + Raises: + CrashThread: If the request is interrupted by a stop event. 
""" def _req(): @@ -56,8 +65,7 @@ def _req(): except Empty: if __EVENTS__.is_event_set("stop"): logger.debug(f"Interrupt request for {url!r}") - response_queue.task_done() - exit(1) + raise CrashThread(f"Request to {url} was interrupted by stop event") else: if isinstance(response, Exception): raise response diff --git a/src/fundus/scraping/url.py b/src/fundus/scraping/url.py index d005412a0..9989bf01b 100644 --- a/src/fundus/scraping/url.py +++ b/src/fundus/scraping/url.py @@ -139,7 +139,6 @@ def __iter__(self) -> Iterator[str]: except (HTTPError, ConnectionError, ReadTimeout) as err: logger.warning(f"Warning! Couldn't parse rss feed {self.url!r} because of {err}") return - except Exception as error: logger.error(f"Warning! Couldn't parse rss feed {self.url!r} because of an unexpected error {error!r}") return @@ -177,7 +176,6 @@ def yield_recursive(sitemap_url: str) -> Iterator[str]: except (HTTPError, ConnectionError, ReadTimeout) as error: logger.warning(f"Warning! Couldn't reach sitemap {sitemap_url!r} because of {error!r}") return - except Exception as error: logger.error( f"Warning! Couldn't reach sitemap {sitemap_url!r} because of an unexpected error {error!r}" diff --git a/src/fundus/utils/events.py b/src/fundus/utils/events.py index cfff2867c..f82993c2d 100644 --- a/src/fundus/utils/events.py +++ b/src/fundus/utils/events.py @@ -1,44 +1,113 @@ import threading from collections import defaultdict -from typing import Any, Dict, Optional, Union +from typing import Dict, List, Optional, Union + +from bidict import bidict from fundus.logging import create_logger logger = create_logger(__name__) +__DEFAULT_EVENTS__: List[str] = ["stop"] + + +class ThreadEventDict(Dict[str, threading.Event]): + """A dictionary that creates threading.Event() objects on demand for certain keys. + + This class behaves like a standard dictionary but automatically creates + `threading.Event` objects when specific keys (provided via `default_events`) + are accessed. This is similar to `defaultdict`, but the auto-creation only + applies to those specific keys. + + Attributes: + _default_events (List[str]): List of event names for which Events will be auto-created. + """ + + def __init__(self, default_events: Optional[List[str]] = None): + """ + Initialize a new ThreadEventDict. + + Args: + default_events: A list of event names for which Event objects + should be automatically created when accessed. + """ + super().__init__() + self._default_events = default_events or [] + + def __getitem__(self, item: str) -> threading.Event: + """ + Get the Event associated with the given item. + + If the key does not exist and is in `_default_events`, a new + `threading.Event` is created, stored, and returned. + + Args: + item: The event name to retrieve. + + Returns: + threading.Event: The event associated with the key. + + Raises: + KeyError: If the key is not present and not in `_default_events`. + """ + try: + return super().__getitem__(item) + except KeyError as e: + if item in self._default_events: + event = threading.Event() + self[item] = event + return event + raise e + class EventDict: - """A thread-safe event dictionary. + """A thread-safe event registry for managing thread-local events with optional aliases. - Events are registered by name and stored per thread in a dictionary, using the - thread's identifier as the key. For example, calling `register_event("stop")` - registers a "stop" event for the current thread's identifier. 
+ This class maintains per-thread event dictionaries, allowing threads to + register, set, and clear named `threading.Event` objects in an isolated + and synchronized manner. - To enhance usability, threads can be assigned aliases. Calling - `register_event("stop", "BR")` registers the "stop" event (if it is not already - registered) for the current thread and automatically creates an alias mapping - "BR" to the thread's identifier. + Aliases can be assigned to thread identifiers for convenience. Each alias + maps uniquely to a thread ID, allowing event access via human-readable names. + + Attributes: + _events (Dict[int, ThreadEventDict]): Mapping of thread IDs to their events. + _aliases (bidict[str, int]): Bidirectional mapping of aliases to thread IDs. + _lock (threading.RLock): A re-entrant lock to ensure thread safety. """ - def __init__(self): - self._events: Dict[int, Dict[str, threading.Event]] = defaultdict(dict) - self._aliases: Dict[Any, int] = {} - self._lock = threading.Lock() + def __init__(self, default_events: Optional[List[str]] = None): + """ + Initialize a new EventDict. + + Args: + default_events: A list of event names that are automatically available + for all threads (e.g., ["stop"]). + """ + self._events: Dict[int, ThreadEventDict] = defaultdict(lambda: ThreadEventDict(default_events)) + self._aliases: bidict[str, int] = bidict() + self._lock = threading.RLock() @staticmethod def _get_identifier() -> int: + """ + Get the current thread's unique identifier. + + Returns: + int: The current thread's identifier. + """ return threading.get_ident() def _resolve(self, key: Union[int, str, None]) -> int: - """Resolves a given key to a thread identifier + """Resolve a key (thread ID, alias, or None) to a thread identifier. - Should only be used within a Lock! + Should only be called while holding the internal lock. Args: - key: Key to resolve + key: The key to resolve. May be a thread ID, alias, or None. Returns: - Resolved thread identifier + int: The resolved thread identifier. """ if key is None: return self._get_identifier() @@ -46,79 +115,177 @@ def _resolve(self, key: Union[int, str, None]) -> int: return key return self._aliases[key] + def _pretty_resolve(self, key: Union[int, str, None]) -> str: + """ + Resolve a key to a human-readable identifier string, including alias if available. + + Should only be called while holding the internal lock. + + Args: + key: Thread ID, alias, or None. + + Returns: + str: A formatted string of the form " (alias)". + """ + resolved = self._resolve(key) + alias = f" ({self._aliases.inv[resolved]})" if resolved in self._aliases.values() else "" + return f"{resolved:<6}{alias}" + def _alias(self, alias: str, key: Optional[int] = None): + """ + Register an alias for a given thread identifier. + + Should only be called while holding the internal lock. + + Args: + alias: The alias to assign. + key: The thread identifier to associate with this alias. + If None, the current thread's identifier is used. + """ self._aliases[alias] = key if key else self._get_identifier() + if (ident := self._resolve(alias)) not in self._events: + # noinspection PyStatementEffect + # Since defaultdict doesn't provide a direct way to create defaults, + # we simulate it by accessing the key. + self._events[ident] logger.debug(f"Registered alias {alias} -> {self._aliases[alias]}") def register_event(self, event: str, key: Union[int, str, None] = None): + """ + Register a new event for the specified thread or alias. 
+ + If the alias does not exist, it is automatically created. + + Args: + event: The name of the event to register. + key: Thread ID, alias, or None (defaults to the current thread). + """ with self._lock: if isinstance(key, str) and key not in self._aliases: self._alias(key) if (resolved := self._resolve(key)) not in self._events: self._events[resolved][event] = threading.Event() - logger.debug(f"Registered event {event!r} for {resolved}") + logger.debug(f"Registered event {event!r} for {self._pretty_resolve(key)}") def set_event(self, event: str, key: Union[int, str, None] = None): + """ + Set (trigger) an event for the specified thread. + + Args: + event: The name of the event to set. + key: Thread ID, alias, or None (defaults to the current thread). + """ with self._lock: self._events[self._resolve(key)][event].set() - logger.debug(f"Set event {event!r} for {self._resolve(key)}") + logger.debug(f"Set event {event!r} for {self._pretty_resolve(key)}") def clear_event(self, event: str, key: Union[int, str, None] = None): + """ + Clear (reset) an event for the specified thread. + + Args: + event: The name of the event to clear. + key: Thread ID, alias, or None (defaults to the current thread). + """ with self._lock: self._events[self._resolve(key)][event].clear() - logger.debug(f"Cleared event {event!r} for {self._resolve(key)}") + logger.debug(f"Cleared event {event!r} for {self._pretty_resolve(key)}") def set_for_all(self, event: Optional[str] = None): - """Set for all registered keys + """Set an event for all registered threads. - If is None, all events for every registered key will be set. - Args: - event: The event to set. Defaults to None. + If `event` is None, all events for every registered thread are set. - Returns: - None + Args: + event: The event name to set. If None, all events are set. """ with self._lock: - for events in self._events.values(): - if event is not None and event in events: - events[event].set() - else: - for flag in events.values(): - flag.set() + if event is None: + for ident, events in self._events.items(): + for name in events: + self.set_event(name, ident) + else: + for ident in self._events: + self.set_event(event, ident) def clear_for_all(self, event: Optional[str] = None): - """Clear for all registered keys + """Clear an event for all registered threads. - If is None, all events for every registered key will be cleared. - Args: - event: The event to clear. Defaults to None. + If `event` is None, all events for every registered thread are cleared. - Returns: - None + Args: + event: The event name to clear. If None, all events are cleared. """ with self._lock: - for events in self._events.values(): - if event is not None and event in events: - events[event].clear() - else: - for flag in events.values(): - flag.clear() + if event is None: + for ident, events in self._events.items(): + for name in events: + self.clear_event(name, ident) + else: + for ident in self._events: + self.clear_event(event, ident) def is_event_set(self, event: str, key: Union[int, str, None] = None) -> bool: + """ + Check if a specific event is set for a given thread. + + Args: + event: The name of the event to check. + key: Thread ID, alias, or None (defaults to the current thread). + + Returns: + bool: True if the event is set, False otherwise. + """ with self._lock: return self._events[self._resolve(key)][event].is_set() def alias(self, alias: str, key: Optional[int] = None): + """ + Public wrapper to register an alias for a thread. + + Args: + alias: The alias name to register. 
+ key: Optional thread identifier to associate with the alias. + Defaults to the current thread if not provided. + """ with self._lock: self._alias(alias, key) + def get_alias(self, ident: int) -> str: + """ + Get the alias associated with a thread identifier. + + Args: + ident: The thread identifier. + + Returns: + str: The alias associated with the identifier. + """ + return self._aliases.inv[ident] + def remove_alias(self, alias: str): + """ + Remove an alias from the alias mapping. + + Args: + alias: The alias to remove. + """ with self._lock: self._aliases.pop(alias, None) def get(self, event: str, key: Optional[Union[int, str, None]] = None) -> threading.Event: + """ + Get the event object associated with the given event name and thread. + + Args: + event: The name of the event to retrieve. + key: Thread ID, alias, or None (defaults to the current thread). + + Returns: + threading.Event: The event object. + """ with self._lock: return self._events[self._resolve(key)][event] -__EVENTS__: EventDict = EventDict() +__EVENTS__: EventDict = EventDict(default_events=__DEFAULT_EVENTS__)
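# --- Illustrative sketch, not part of the patch -------------------------------------------
# How the reworked EventDict is intended to be used, based on the methods defined above.
# The "stop" event no longer needs explicit registration: it is listed in __DEFAULT_EVENTS__,
# so ThreadEventDict creates it on first access, and the bidict lets aliases and thread
# idents be translated in both directions. The alias "example-publisher" is a placeholder.
import threading
import time

from fundus.utils.events import EventDict

events = EventDict(default_events=["stop"])


def worker() -> None:
    # key=None resolves to the current thread; "stop" is auto-created on first access
    while not events.is_event_set("stop"):
        time.sleep(0.1)


thread = threading.Thread(target=worker)
thread.start()
events.alias("example-publisher", key=thread.ident)  # name the worker thread
events.set_event("stop", key="example-publisher")    # address it by its alias
thread.join()
print(events.get_alias(thread.ident))  # -> "example-publisher"
# -------------------------------------------------------------------------------------------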