Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@
license="MIT -or- Apache License 2.0",
packages=find_packages(),
package_data={"sniffio": ["py.typed"]},
install_requires=["contextvars >= 2.1; python_version < '3.7'"],
keywords=[
"async",
"trio",
"asyncio",
],
python_requires=">=3.5",
python_requires=">=3.7",
tests_require=['curio'],
classifiers=[
"License :: OSI Approved :: MIT License",
Expand Down
3 changes: 2 additions & 1 deletion sniffio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

__all__ = [
"current_async_library", "AsyncLibraryNotFoundError",
"current_async_library_cvar"
"current_async_library_cvar", "hooks"
]

from ._version import __version__
Expand All @@ -12,4 +12,5 @@
AsyncLibraryNotFoundError,
current_async_library_cvar,
thread_local,
hooks,
)
118 changes: 91 additions & 27 deletions sniffio/_impl.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
from contextvars import ContextVar
from typing import Optional
import sys
Expand All @@ -22,10 +23,68 @@ class AsyncLibraryNotFoundError(RuntimeError):
pass


def _guessed_mode() -> str:
# special support for trio-asyncio
value = thread_local.name
if value is not None:
return value

value = current_async_library_cvar.get()
if value is not None:
return value

# Need to sniff for asyncio
if "asyncio" in sys.modules:
import asyncio
try:
current_task = asyncio.current_task # type: ignore[attr-defined]
except AttributeError:
current_task = asyncio.Task.current_task # type: ignore[attr-defined]
try:
if current_task() is not None:
return "asyncio"
except RuntimeError:
pass

# Sniff for curio (for now)
if 'curio' in sys.modules:
from curio.meta import curio_running
if curio_running():
return 'curio'

raise AsyncLibraryNotFoundError(
"unknown async library, or not in async context"
)


def _noop_hook(v: str) -> str:
return v


_NO_HOOK = object()

# this is publicly mutable, if an async framework wants to implement complex
# async gen hook behaviour it can set
# sniffio.hooks[__package__] = detect_me. As long as it does so before
# defining its async gen finalizer function it is free from race conditions
hooks = {
# could be trio-asyncio or trio-guest mode
# once trio and trio-asyncio and sniffio align trio should set
# sniffio.hooks['trio'] = detect_trio()
"trio": _guessed_mode,
# pre-cache some well-known well behaved asyncgen_finalizer modules
# and so it saves a trip around _is_asyncio(finalizer) when we
# know asyncio is asyncio and curio is curio
"asyncio.base_events": partial(_noop_hook, "asyncio"),
"curio.meta": partial(_noop_hook, "curio"),
_NO_HOOK: _guessed_mode, # no hooks installed, fallback
}


def current_async_library() -> str:
"""Detect which async library is currently running.

The following libraries are currently supported:
The following libraries are currently special-cased:

================ =========== ============================
Library Requires Magic string
Expand Down Expand Up @@ -63,33 +122,38 @@ async def generic_sleep(seconds):
raise RuntimeError(f"Unsupported library {library!r}")

"""
value = thread_local.name
if value is not None:
return value

value = current_async_library_cvar.get()
if value is not None:
return value

# Need to sniff for asyncio
finalizer = sys.get_asyncgen_hooks().finalizer
finalizer_module = getattr(finalizer, "__module__", _NO_HOOK)
if finalizer_module is None: # finalizer is old cython function
if "uvloop" in sys.modules and _is_asyncio(finalizer):
return "asyncio"

try:
hook = hooks[finalizer_module]
except KeyError:
pass
else:
return hook()

# special case asyncio - when implementing an asyncio event loop
# you have to implement _asyncgen_finalizer_hook in your own module
if _is_asyncio(finalizer): # eg qasync _SelectorEventLoop
hooks[finalizer_module] = partial(_noop_hook, "asyncio")
return "asyncio"

# when implementing a twisted reactor you'd need to rely on hooks defined in
# twisted.internet.defer
assert type(finalizer_module) is str
sniffio_name = finalizer_module.rpartition(".")[0]
hooks[finalizer_module] = partial(_noop_hook, sniffio_name)
return sniffio_name


def _is_asyncio(finalizer):
if "asyncio" in sys.modules:
import asyncio
try:
current_task = asyncio.current_task # type: ignore[attr-defined]
except AttributeError:
current_task = asyncio.Task.current_task # type: ignore[attr-defined]
try:
if current_task() is not None:
return "asyncio"
return finalizer == asyncio.get_running_loop()._asyncgen_finalizer_hook
except RuntimeError:
pass

# Sniff for curio (for now)
if 'curio' in sys.modules:
from curio.meta import curio_running
if curio_running():
return 'curio'

raise AsyncLibraryNotFoundError(
"unknown async library, or not in async context"
)
return False
return False
70 changes: 69 additions & 1 deletion sniffio/_tests/test_sniffio.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,30 @@ async def this_is_asyncio():
assert current_async_library() == "asyncio"
ran.append(True)

loop = asyncio.get_event_loop()
loop = asyncio.new_event_loop()
loop.run_until_complete(this_is_asyncio())
assert ran == [True]
loop.close()

with pytest.raises(AsyncLibraryNotFoundError):
current_async_library()


def test_uvloop():
import uvloop

with pytest.raises(AsyncLibraryNotFoundError):
current_async_library()

ran = []

async def this_is_asyncio():
assert current_async_library() == "asyncio"
# Call it a second time to exercise the caching logic
assert current_async_library() == "asyncio"
ran.append(True)

loop = uvloop.new_event_loop()
loop.run_until_complete(this_is_asyncio())
assert ran == [True]
loop.close()
Expand Down Expand Up @@ -79,3 +102,48 @@ async def this_is_curio():

with pytest.raises(AsyncLibraryNotFoundError):
current_async_library()


def test_asyncio_in_curio():
import curio
import asyncio

async def this_is_asyncio():
return current_async_library()

async def this_is_curio():
return current_async_library(), asyncio.run(this_is_asyncio())

assert curio.run(this_is_curio) == ("curio", "asyncio")


def test_curio_in_asyncio():
import asyncio
import curio

async def this_is_curio():
return current_async_library()

async def this_is_asyncio():
return current_async_library(), curio.run(this_is_curio)

assert asyncio.run(this_is_asyncio()) == ("asyncio", "curio")



@pytest.mark.skipif(sys.version_info < (3, 9), reason='to_thread requires 3.9')
def test_curio_in_asyncio_to_thread():
import curio
import sniffio
import asyncio

async def current_framework():
return sniffio.current_async_library()


async def amain():
sniffio.current_async_library()
return await asyncio.to_thread(curio.run, current_framework)


assert asyncio.run(amain()) == "curio"
1 change: 1 addition & 0 deletions test-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pytest
pytest-cov
curio
uvloop