5 changes: 4 additions & 1 deletion pyproject.toml
@@ -20,7 +20,10 @@ classifiers = [
]
dependencies = [
"nnterp",
"ipywidgets",
"ipywidgets",
"fastapi",
"uvicorn[standard]",
"python-multipart",
]


85 changes: 85 additions & 0 deletions test_serve.py
@@ -0,0 +1,85 @@
"""
Test script for the .serve() functionality.
Creates a minimal dashboard to verify that serving works end to end.
"""

import torch as th
from transformers import AutoTokenizer
from tiny_dashboard import OfflineFeatureCentricDashboard, OnlineFeatureCentricDashboard

def test_offline_serve():
"""Test the offline dashboard serve functionality"""
print("Testing Offline Dashboard serve()...")

# Create minimal test data
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Create some fake activation examples
max_activation_examples = {
0: [
(5.2, ["Hello", " world", "!"], [1.0, 5.2, 0.5]),
(4.8, ["Test", " example"], [4.8, 2.1]),
],
1: [
(3.5, ["Another", " test"], [3.5, 1.2]),
]
}

dashboard = OfflineFeatureCentricDashboard(
max_activation_examples=max_activation_examples,
tokenizer=tokenizer,
window_size=50,
max_examples=10
)

    # Serve the dashboard (blocking); stop with Ctrl+C
    print("Starting server on http://localhost:8000")
    print("Press Ctrl+C to stop")
    dashboard.serve(port=8000, open_browser=False, block=True)


def test_online_serve():
"""Test the online dashboard serve functionality"""
print("Testing Online Dashboard serve()...")

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Create a simple activation function for testing
def get_feature_activation(text: str, feature_indices: tuple[int, ...]) -> th.Tensor:
"""Dummy activation function that returns random values"""
tokens = tokenizer.tokenize(text, add_special_tokens=True)
num_tokens = len(tokens)
num_features = len(feature_indices)

# Return random activations
return th.rand(num_tokens, num_features) * 5.0

dashboard = OnlineFeatureCentricDashboard(
get_feature_activation=get_feature_activation,
tokenizer=tokenizer,
window_size=50
)

    # Serve the dashboard (blocking); stop with Ctrl+C
    print("Starting server on http://localhost:8001")
    print("Press Ctrl+C to stop")
    dashboard.serve(port=8001, open_browser=False, block=True)


if __name__ == "__main__":
import sys

if len(sys.argv) < 2:
print("Usage: python test_serve.py [offline|online]")
sys.exit(1)

mode = sys.argv[1].lower()

if mode == "offline":
test_offline_serve()
elif mode == "online":
test_online_serve()
else:
print(f"Unknown mode: {mode}")
print("Use 'offline' or 'online'")
sys.exit(1)
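Both test entry points call serve() with block=True, so they run until interrupted. For reference, the serve() method added in base_servable_dashboard.py returns the uvicorn Server instance when block=False, so a non-blocking variant of these tests could be stopped programmatically. The snippet below is a sketch only (not part of the diff); it assumes a dashboard constructed as above and relies on uvicorn's documented should_exit flag for graceful shutdown.

    # Sketch: non-blocking usage of the serve() API introduced in this PR
    server = dashboard.serve(port=8000, open_browser=False, block=False)
    # ... interact with the dashboard at http://127.0.0.1:8000 ...
    server.should_exit = True  # ask the background uvicorn server to shut down gracefully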
109 changes: 109 additions & 0 deletions test_serve_simple.py
@@ -0,0 +1,109 @@
"""
Simple test script for the .serve() functionality without requiring model downloads.
"""

import torch as th
from tiny_dashboard import OfflineFeatureCentricDashboard, OnlineFeatureCentricDashboard


class MockTokenizer:
"""Mock tokenizer for testing"""
bos_token = "<s>"

def tokenize(self, text, add_special_tokens=False):
# Simple whitespace tokenization
tokens = text.split()
if add_special_tokens:
tokens = [self.bos_token] + tokens
return tokens

def convert_ids_to_tokens(self, ids):
return [f"token_{i}" for i in ids]

def convert_tokens_to_ids(self, tokens):
# Simple hash-based ID generation
return [hash(token) % 50000 for token in tokens]


def test_offline_serve():
"""Test the offline dashboard serve functionality"""
print("Testing Offline Dashboard serve()...")

tokenizer = MockTokenizer()

# Create some fake activation examples
max_activation_examples = {
0: [
(5.2, ["Hello", " world", "!"], [1.0, 5.2, 0.5]),
(4.8, ["Test", " example"], [4.8, 2.1]),
],
1: [
(3.5, ["Another", " test"], [3.5, 1.2]),
],
42: [
(6.1, ["Feature", " forty", " two"], [2.0, 6.1, 3.0]),
]
}

dashboard = OfflineFeatureCentricDashboard(
max_activation_examples=max_activation_examples,
tokenizer=tokenizer,
window_size=50,
max_examples=10
)

# Test serving (blocking)
print("Starting server on http://localhost:8002")
print("Press Ctrl+C to stop")
    dashboard.serve(port=8002, open_browser=False, block=True)


def test_online_serve():
"""Test the online dashboard serve functionality"""
print("Testing Online Dashboard serve()...")

tokenizer = MockTokenizer()

# Create a simple activation function for testing
def get_feature_activation(text: str, feature_indices: tuple[int, ...]) -> th.Tensor:
"""Dummy activation function that returns random values"""
tokens = tokenizer.tokenize(text, add_special_tokens=True)
num_tokens = len(tokens)
num_features = len(feature_indices)

# Return random activations scaled by feature index
activations = th.rand(num_tokens, num_features) * 5.0
# Make different features have different patterns
for i, feat_idx in enumerate(feature_indices):
activations[:, i] *= (1 + feat_idx * 0.1)
return activations

dashboard = OnlineFeatureCentricDashboard(
get_feature_activation=get_feature_activation,
tokenizer=tokenizer,
window_size=50
)

# Test serving (blocking)
print("Starting server on http://localhost:8001")
print("Press Ctrl+C to stop")
    dashboard.serve(port=8001, open_browser=False, block=True)


if __name__ == "__main__":
import sys

if len(sys.argv) < 2:
print("Usage: python test_serve_simple.py [offline|online]")
sys.exit(1)

mode = sys.argv[1].lower()

if mode == "offline":
test_offline_serve()
elif mode == "online":
test_online_serve()
else:
print(f"Unknown mode: {mode}")
print("Use 'offline' or 'online'")
sys.exit(1)
2 changes: 2 additions & 0 deletions tiny_dashboard/__init__.py
@@ -3,11 +3,13 @@
OnlineFeatureCentricDashboard,
AbstractOnlineFeatureCentricDashboard,
)
from .base_servable_dashboard import BaseServableDashboard
from .visualization_utils import activation_visualization

__all__ = [
"OfflineFeatureCentricDashboard",
"OnlineFeatureCentricDashboard",
"AbstractOnlineFeatureCentricDashboard",
"BaseServableDashboard",
"activation_visualization",
]
115 changes: 115 additions & 0 deletions tiny_dashboard/base_servable_dashboard.py
@@ -0,0 +1,115 @@
import threading
import webbrowser
from abc import ABC, abstractmethod
from typing import Any
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse


class BaseServableDashboard(ABC):
"""
Abstract base class that provides web serving capability to dashboards.
Subclasses must implement get_initial_html() and handle_request() to define
their specific UI and behavior.
"""

@abstractmethod
def get_initial_html(self) -> str:
"""
Generate the initial HTML page with input form.

Returns:
HTML string containing the form for user input
"""
pass

@abstractmethod
def handle_request(self, form_data: dict[str, Any]) -> str:
"""
Process form submission and generate result HTML.

Args:
form_data: Dictionary containing form field values

Returns:
HTML string containing the analysis results
"""
pass

def serve(self, port: int = 8000, open_browser: bool = True, block: bool = False):
"""
Start a web server to serve the dashboard in a browser.

Args:
port: Port number to run the server on (default: 8000)
open_browser: Whether to automatically open browser (default: True)
block: Whether to block execution (default: False for non-blocking)

Returns:
Server instance if non-blocking, None if blocking
"""
app = FastAPI()

@app.get("/", response_class=HTMLResponse)
async def root():
return self.get_initial_html()

@app.post("/analyze", response_class=HTMLResponse)
async def analyze(request: Request):
form_data = await request.form()
form_dict = dict(form_data)
try:
return self.handle_request(form_dict)
except Exception as e:
# Return error HTML
import traceback
error_html = f"""
<!DOCTYPE html>
<html>
<head>
<title>Error</title>
<style>
body {{ font-family: sans-serif; padding: 20px; }}
.error {{ color: red; background: #fee; padding: 15px; border-radius: 5px; }}
pre {{ background: #f5f5f5; padding: 10px; overflow-x: auto; }}
</style>
</head>
<body>
<h1>Error Processing Request</h1>
<div class="error">
<strong>Error:</strong> {str(e)}
</div>
<h2>Traceback:</h2>
<pre>{traceback.format_exc()}</pre>
<a href="/">Go Back</a>
</body>
</html>
"""
return error_html

# Configure and create server
config = uvicorn.Config(
app,
host="127.0.0.1",
port=port,
log_level="warning",
access_log=False
)
server = uvicorn.Server(config)

url = f"http://127.0.0.1:{port}"
print(f"Dashboard serving at {url}")

if open_browser:
# Open browser after a short delay to ensure server is ready
threading.Timer(0.5, lambda: webbrowser.open(url)).start()

if block:
server.run()
return None
else:
# Run server in background thread
thread = threading.Thread(target=server.run, daemon=True)
thread.start()
return server