Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ classifiers = [
]
dependencies = [
"aiohttp[speedups]>=3.9.3",
"jkclient>=0.0.5",
"jupyter-client>=8.6.0",
"python-dotenv>=1",
"pydantic>=2",
"requests>=2",
Expand All @@ -37,7 +39,6 @@ local = [
# kernel dependencies
"ipython>=8.18.1",
"ipykernel>=6.26.0",
"jupyter-client>=8.6.0",
]

[project.urls]
Expand Down
105 changes: 105 additions & 0 deletions src/pybox/kube.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from __future__ import annotations

import logging
from typing import Any

from dotenv import dotenv_values
from jkclient import CreateKernelRequest, JupyterKernelClient, Kernel
from jupyter_client import AsyncKernelClient, BlockingKernelClient

from pybox import LocalPyBox
from pybox.base import BasePyBoxManager

logger = logging.getLogger(__name__)


class KubePyBoxManager(BasePyBoxManager):
"""Kubernetes kernel pybox, used to create a custom kernel and connect to it to execute code"""

def __init__(
self,
*,
incluster: bool,
env_file: str | None = None,
kernel_env: dict[str, Any] | None = None,
):
self.env_file = env_file
self.kernel_env = dotenv_values(env_file)
if kernel_env:
self.kernel_env.update(kernel_env)

self.client = JupyterKernelClient(incluster=incluster)

def start(self, kernel_id: str, cwd: str, **kwargs) -> LocalPyBox:
"""Retrieve an existing kernel or create a new one in kubernetes

Args:
kernel_id (str): kernel_id
cwd (str): kernel_working_dir

Returns:
LocalPyBox: kubernetes kernel box
"""
env = self.kernel_env.copy()

if kernel_id:
env["KERNEL_ID"] = kernel_id
if cwd:
env["KERNEL_WORKING_DIR"] = cwd
if username := kwargs.pop("username", None):
env["KERNEL_USERNAME"] = username

# Create kernel custom resource
kernel_request = CreateKernelRequest(env=env)
kernel: Kernel = self.client.create(kernel_request, **kwargs)

# New kernel clinet
kernel_client = BlockingKernelClient()
kernel_client.load_connection_info(kernel.conn_info)

return LocalPyBox(kernel_id=kernel_id, client=kernel_client)

async def astart(self, kernel_id: str, cwd: str, **kwargs) -> LocalPyBox:
"""Retrieve an existing kernel or create a new one in kubernetes

Args:
kernel_id (str): kubernetes kernel id
cwd (str): kernel workdir

Returns:
LocalPyBox: An iPython kernel that executes code.
"""
env = self.kernel_env.copy()

if kernel_id:
env["KERNEL_ID"] = kernel_id
if cwd:
env["KERNEL_WORKING_DIR"] = cwd
if username := kwargs.pop("username", None):
env["KERNEL_USERNAME"] = username

# Create kernel custom resource
kernel_request = CreateKernelRequest(env=env)
kernel: Kernel = await self.client.acreate(kernel_request, **kwargs)

# New kernel clinet
kernel_client = AsyncKernelClient()
kernel_client.load_connection_info(kernel.conn_info)

return LocalPyBox(kernel_id=kernel_id, client=kernel_client)

def shutdown(self, kernel_id: str, **kwargs) -> None:
"""Shutdown the kernel in kubernetes.

Args:
kernel_id (str): kernel_id
"""
self.client.delete_by_kernel_id(kernel_id, **kwargs)

async def ashutdown(self, kernel_id: str, **kwargs) -> None:
"""Shutdown the kubernetes kernel by kernel id.

Args:
kernel_id (str): kubernetes kernel id
"""
return await self.client.adelete_by_kernel_id(kernel_id, **kwargs)
100 changes: 100 additions & 0 deletions tests/test_kube.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from uuid import uuid4

import pytest
from pybox.kube import KubePyBoxManager


@pytest.fixture(scope="module")
def kube_manager() -> KubePyBoxManager:
return KubePyBoxManager(
incluster=False,
kernel_env={
"KERNEL_USERNAME": "tablegpt",
"KERNEL_NAMESPACE": "default",
"KERNEL_IMAGE": "zjuici/tablegpt-kernel:0.1.1",
"KERNEL_WORKING_DIR": "/mnt/data",
"KERNEL_VOLUME_MOUNTS": [
{"name": "shared-vol", "mountPath": "/mnt/data"},
{"name": "ipython-profile-vol", "mountPath": "/opt/startup"},
{
"name": "kernel-launch-vol",
"mountPath": "/usr/local/bin/bootstrap-kernel.sh",
"subPath": "bootstrap-kernel.sh",
},
{
"name": "kernel-launch-vol",
"mountPath": "/usr/local/bin/kernel-launchers/python/scripts/launch_ipykernel.py",
"subPath": "launch_ipykernel.py",
},
],
"KERNEL_VOLUMES": [
{
"name": "shared-vol",
"nfs": {
"server": "10.0.0.29",
"path": "/data/tablegpt-slim-py/data",
},
},
{
"name": "ipython-profile-vol",
"configMap": {"name": "ipython-startup-scripts"},
},
{
"name": "kernel-launch-vol",
"configMap": {
"defaultMode": 0o755,
"name": "kernel-launch-scripts",
},
},
],
"KERNEL_STARTUP_SCRIPTS_PATH": "/opt/startup",
"KERNEL_IDLE_TIMEOUT": "1800",
},
)


@pytest.mark.skip(reason="Start kernel cr need kubernetes environment")
def test_start_with_user(kube_manager: KubePyBoxManager) -> None:
kernel_id = str(uuid4())
box = kube_manager.start(
kernel_id=kernel_id,
cwd="/mnt/data",
username="dev",
)
assert box.kernel_id == kernel_id


@pytest.mark.skip(reason="Start kernel cr need kubernetes environment")
def test_start_without_user(kube_manager: KubePyBoxManager) -> None:
kernel_id = str(uuid4())
box = kube_manager.start(
kernel_id=kernel_id,
cwd="/mnt/data",
)
assert box.kernel_id == kernel_id


@pytest.mark.skip(reason="Start kernel cr need kubernetes environment")
@pytest.mark.asyncio
async def test_start_async(kube_manager: KubePyBoxManager) -> None:
kernel_id = str(uuid4())
box = await kube_manager.astart(
kernel_id=kernel_id,
cwd="/mnt/data",
)
assert box.kernel_id == kernel_id


@pytest.mark.skip(reason="Shutting down kernel cr need kubernetes environment")
def test_shutdown_w_id(kube_manager: KubePyBoxManager) -> None:
kube_manager.shutdown(kernel_id="1918a836-e941-4332-9e6f-dbfe91e5771a")


@pytest.mark.skip(reason="Shutting down kernel cr need kubernetes environment")
@pytest.mark.asyncio
async def test_shutdown_async(kube_manager: KubePyBoxManager) -> None:
await kube_manager.ashutdown(kernel_id="1918a836-e941-4332-9e6f-dbfe91e5771a")


if __name__ == "__main__":
pytest.main()