Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 41 additions & 3 deletions src/pybox/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,22 @@
import logging
import platform
import queue
from typing import TYPE_CHECKING
from uuid import uuid4

try:
from typing import Self
except ImportError:
from typing_extensions import Self

from jupyter_client import AsyncMultiKernelManager, KernelManager, MultiKernelManager
from jupyter_client.multikernelmanager import DuplicateKernelError
from jupyter_core.utils import run_sync

if TYPE_CHECKING:
from types import TracebackType


from pybox.base import BasePyBox, BasePyBoxManager
from pybox.schema import (
ExecutionResponse,
Expand All @@ -21,14 +31,42 @@


class LocalPyBox(BasePyBox):
def __init__(self, km: KernelManager):
def __init__(self, km: KernelManager, mkm: MultiKernelManager | None = None):
self.km = km
self.mkm = mkm
self.client = self.km.client()

@property
def kernel_id(self) -> str | None:
return self.km.kernel_id

def __enter__(self) -> Self:
return self

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> bool:
# If we use `self.km.shutdown_kernel(now=True)`, the kernel_id will last in the multi_kernel_manager.
if self.mkm is not None:
self.mkm.shutdown_kernel(kernel_id=self.kernel_id, now=True)
# 返回 False 让异常继续传播, 返回 True 会抑制异常
return False

async def __aenter__(self) -> Self:
return self

async def __aexit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, traceback: TracebackType | None
) -> bool:
# If we use `await self.km.shutdown_kernel(now=True)`, the kernel_id will last in the multi_kernel_manager.
if self.mkm is not None:
await self.mkm.shutdown_kernel(kernel_id=self.kernel_id, now=True)
# 返回 False 让异常继续传播, 返回 True 会抑制异常
return False

def run(self, code: str, timeout: int = 60) -> PyBoxOut:
if not self.client.channels_running:
# `wait_for_ready` raises a RuntimeError if the kernel is not ready
Expand Down Expand Up @@ -317,7 +355,7 @@ def start(
# it's OK if the kernel already exists
kid = kernel_id
km = self.kernel_manager.get_kernel(kernel_id=kid)
return LocalPyBox(km=km)
return LocalPyBox(km=km, mkm=self.kernel_manager)

async def astart(
self,
Expand All @@ -337,7 +375,7 @@ async def astart(
# it's OK if the kernel already exists
kid = kernel_id
km = self.async_kernel_manager.get_kernel(kernel_id=kid)
return LocalPyBox(km=km)
return LocalPyBox(km=km, mkm=self.async_kernel_manager)

def shutdown(
self,
Expand Down
71 changes: 44 additions & 27 deletions tests/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,29 +26,37 @@ def test_box_lifecycle(self, local_manager: LocalPyBoxManager):
local_manager.shutdown(kernel_id, now=True)
assert kernel_id not in local_manager.kernel_manager

def test_box_lifecycle_w_context_manager(self, local_manager: LocalPyBoxManager):
with local_manager.start() as box:
kernel_id = box.kernel_id

assert kernel_id in local_manager.kernel_manager
assert local_manager.kernel_manager.is_alive(kernel_id)

assert kernel_id not in local_manager.kernel_manager
with pytest.raises(KeyError):
assert not local_manager.kernel_manager.is_alive(kernel_id)

def test_start_w_id(self, local_manager: LocalPyBoxManager):
kernel_id = str(uuid4())
box = local_manager.start(kernel_id)
assert box.kernel_id == kernel_id
assert kernel_id in local_manager.kernel_manager
assert local_manager.kernel_manager.is_alive(box.kernel_id)
local_manager.shutdown(box.kernel_id)
with local_manager.start(kernel_id) as box:
assert box.kernel_id == kernel_id
assert kernel_id in local_manager.kernel_manager
assert local_manager.kernel_manager.is_alive(box.kernel_id)

def test_set_cwd(self, local_manager: LocalPyBoxManager):
# even we don't set the cwd, it defaults to os.getcwd()
# in order to test this is working, we need to change the cwd to a cross platform path
box = local_manager.start(cwd=os.path.expanduser("~"))
test_code = "import os\nprint(os.getcwd())"
out: PyBoxOut = box.run(code=test_code)
assert len(out.data) == 1
assert os.path.expanduser("~") + "\n" == out.data[0]["text/plain"]
local_manager.shutdown(box.kernel_id, now=True)
with local_manager.start(cwd=os.path.expanduser("~")) as box:
test_code = "import os\nprint(os.getcwd())"
out: PyBoxOut = box.run(code=test_code)
assert len(out.data) == 1
assert os.path.expanduser("~") + "\n" == out.data[0]["text/plain"]

@pytest.fixture
def local_box(self, local_manager: LocalPyBoxManager) -> Iterator[LocalPyBox]:
_box = local_manager.start()
yield _box
local_manager.shutdown(_box.kernel_id, now=True)
with local_manager.start() as _box:
yield _box

def test_code_execute(self, local_box: LocalPyBox):
test_code = "print('test')"
Expand Down Expand Up @@ -132,31 +140,40 @@ async def test_box_lifecycle_async(self, async_local_manager: LocalPyBoxManager)
await async_local_manager.ashutdown(kernel_id)
assert kernel_id not in async_local_manager.async_kernel_manager

async def test_box_lifecycle_w_async_context_manager(self, async_local_manager: LocalPyBoxManager):
async with await async_local_manager.astart() as box:
kernel_id = box.kernel_id

assert kernel_id in async_local_manager.async_kernel_manager
assert await async_local_manager.async_kernel_manager.is_alive(kernel_id)

assert kernel_id not in async_local_manager.async_kernel_manager
with pytest.raises(KeyError):
assert not await async_local_manager.async_kernel_manager.is_alive(kernel_id)

async def test_start_async_w_id(self, async_local_manager: LocalPyBoxManager):
kernel_id = str(uuid4())
box = await async_local_manager.astart(kernel_id)
assert box.kernel_id == kernel_id
assert kernel_id in async_local_manager.async_kernel_manager
assert await async_local_manager.async_kernel_manager.is_alive(kernel_id)
await async_local_manager.ashutdown(kernel_id)
async with await async_local_manager.astart(kernel_id) as box:
assert box.kernel_id == kernel_id
assert kernel_id in async_local_manager.async_kernel_manager
assert await async_local_manager.async_kernel_manager.is_alive(kernel_id)

async def test_set_cwd_async(self, async_local_manager: LocalPyBoxManager):
# even we don't set the cwd, it defaults to os.getcwd()
# in order to test this is working, we need to change the cwd to a cross platform path
box = await async_local_manager.astart(cwd=os.path.expanduser("~"))
test_code = "import os\nprint(os.getcwd())"
out: PyBoxOut = await box.arun(code=test_code)
assert len(out.data) == 1
assert os.path.expanduser("~") + "\n" == out.data[0]["text/plain"]
async with await async_local_manager.astart(cwd=os.path.expanduser("~")) as box:
test_code = "import os\nprint(os.getcwd())"
out: PyBoxOut = await box.arun(code=test_code)
assert len(out.data) == 1
assert os.path.expanduser("~") + "\n" == out.data[0]["text/plain"]

@pytest_asyncio.fixture(loop_scope="class")
async def async_local_box(
self,
async_local_manager: LocalPyBoxManager,
) -> AsyncIterator[LocalPyBox]:
_box = await async_local_manager.astart()
yield _box
await async_local_manager.ashutdown(_box.kernel_id, now=True)
async with await async_local_manager.astart() as _box:
yield _box

async def test_code_execute_async(self, async_local_box: LocalPyBox):
test_code = "print('test')"
Expand Down
Loading