Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions sky/global_user_state.py

Large diffs are not rendered by default.

14 changes: 8 additions & 6 deletions sky/server/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,14 +648,16 @@ def _set_metrics_env_var(env: Union[Dict[str, str], os._Environ], metrics: bool,
deploy: Whether the server is running in deploy mode, which means
multiple processes might be running.
"""
del deploy
if metrics or os.getenv(constants.ENV_VAR_SERVER_METRICS_ENABLED) == 'true':
env[constants.ENV_VAR_SERVER_METRICS_ENABLED] = 'true'
if deploy:
metrics_dir = os.path.join(tempfile.gettempdir(), 'metrics')
shutil.rmtree(metrics_dir, ignore_errors=True)
os.makedirs(metrics_dir, exist_ok=True)
# Refer to https://prometheus.github.io/client_python/multiprocess/
env['PROMETHEUS_MULTIPROC_DIR'] = metrics_dir
# Always set the metrics dir since we need to collect metrics from
# subprocesses like the executor.
metrics_dir = os.path.join(tempfile.gettempdir(), 'metrics')
shutil.rmtree(metrics_dir, ignore_errors=True)
os.makedirs(metrics_dir, exist_ok=True)
# Refer to https://prometheus.github.io/client_python/multiprocess/
env['PROMETHEUS_MULTIPROC_DIR'] = metrics_dir


def check_server_healthy(
Expand Down
88 changes: 82 additions & 6 deletions sky/server/metrics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Instrumentation for the API server."""

import contextlib
import functools
import os
import time

Expand All @@ -11,26 +13,57 @@
import uvicorn

from sky import sky_logging
from sky.skylet import constants

# Whether the metrics are enabled, cannot be changed at runtime.
METRICS_ENABLED = os.environ.get(constants.ENV_VAR_SERVER_METRICS_ENABLED,
'false').lower() == 'true'

logger = sky_logging.init_logger(__name__)

# Total number of API server requests, grouped by path, method, and status.
sky_apiserver_requests_total = prom.Counter(
SKY_APISERVER_REQUESTS_TOTAL = prom.Counter(
'sky_apiserver_requests_total',
'Total number of API server requests',
['path', 'method', 'status'],
)

# Time spent processing API server requests, grouped by path, method, and
# status.
sky_apiserver_request_duration_seconds = prom.Histogram(
SKY_APISERVER_REQUEST_DURATION_SECONDS = prom.Histogram(
'sky_apiserver_request_duration_seconds',
'Time spent processing API server requests',
['path', 'method', 'status'],
buckets=(0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0,
buckets=(0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 20.0, 30.0,
60.0, 120.0, float('inf')),
)

# Time spent processing requests in executor.
SKY_APISERVER_REQUEST_EXECUTION_DURATION_SECONDS = prom.Histogram(
'sky_apiserver_request_execution_duration_seconds',
'Time spent executing requests in executor',
['request', 'worker'],
buckets=(0.5, 1, 2.5, 5.0, 10.0, 15.0, 25.0, 40.0, 60.0, 90.0, 120.0, 180.0,
float('inf')),
)

# Time spent processing a piece of code, refer to time_it().
SKY_APISERVER_CODE_DURATION_SECONDS = prom.Histogram(
'sky_apiserver_code_duration_seconds',
'Time spent processing code',
['name', 'group'],
buckets=(0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 20.0, 30.0,
60.0, 120.0, float('inf')),
)

SKY_APISERVER_EVENT_LOOP_LAG_SECONDS = prom.Histogram(
'sky_apiserver_event_loop_lag_seconds',
'Scheduling delay of the server event loop',
['pid'],
buckets=(0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2, 5, 20.0,
60.0, float('inf')),
)

metrics_app = fastapi.FastAPI()


Expand Down Expand Up @@ -76,7 +109,7 @@ class PrometheusMiddleware(starlette.middleware.base.BaseHTTPMiddleware):

async def dispatch(self, request: fastapi.Request, call_next):
path = request.url.path
logger.info(f'PROM Middleware Request: {request}, {request.url.path}')
logger.debug(f'PROM Middleware Request: {request}, {request.url.path}')
streaming = _is_streaming_api(path)
if not streaming:
# Exclude streaming APIs, the duration is not meaningful.
Expand All @@ -92,13 +125,56 @@ async def dispatch(self, request: fastapi.Request, call_next):
status_code_group = '5xx'
raise
finally:
sky_apiserver_requests_total.labels(path=path,
SKY_APISERVER_REQUESTS_TOTAL.labels(path=path,
method=method,
status=status_code_group).inc()
if not streaming:
duration = time.time() - start_time
sky_apiserver_request_duration_seconds.labels(
SKY_APISERVER_REQUEST_DURATION_SECONDS.labels(
path=path, method=method,
status=status_code_group).observe(duration)

return response


@contextlib.contextmanager
def time_it(name: str, group: str = 'default'):
    """Record the wall-clock duration of the wrapped code span.

    Observes the elapsed time into SKY_APISERVER_CODE_DURATION_SECONDS
    under the given ``name`` and ``group`` labels. Becomes a no-op when
    metrics are disabled for this process.
    """
    if not METRICS_ENABLED:
        yield
        return
    start = time.time()
    try:
        yield
    finally:
        elapsed = time.time() - start
        SKY_APISERVER_CODE_DURATION_SECONDS.labels(
            name=name, group=group).observe(elapsed)


def time_me(func):
    """Decorator that records the duration of each call to ``func``.

    The duration is observed via time_it() under the
    ``{module}/{function}`` name in the ``function`` group. Calls pass
    straight through when metrics are disabled.
    """
    # Hoisted out of the wrapper: the label never changes per call, so
    # there is no need to re-format the string on every invocation.
    name = f'{func.__module__}/{func.__name__}'

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not METRICS_ENABLED:
            return func(*args, **kwargs)
        with time_it(name, group='function'):
            return func(*args, **kwargs)

    return wrapper


def time_me_async(func):
    """Async counterpart of time_me(): times each await of ``func``.

    The duration is observed via time_it() under the
    ``{module}/{function}`` name in the ``function`` group. Calls pass
    straight through when metrics are disabled.
    """
    # Hoisted out of the wrapper: the label never changes per call, so
    # there is no need to re-format the string on every invocation.
    name = f'{func.__module__}/{func.__name__}'

    @functools.wraps(func)
    async def async_wrapper(*args, **kwargs):
        if not METRICS_ENABLED:
            return await func(*args, **kwargs)
        with time_it(name, group='function'):
            return await func(*args, **kwargs)

    return async_wrapper
6 changes: 5 additions & 1 deletion sky/server/requests/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from sky.server import common as server_common
from sky.server import config as server_config
from sky.server import constants as server_constants
from sky.server import metrics as metrics_lib
from sky.server.requests import payloads
from sky.server.requests import preconditions
from sky.server.requests import process
Expand Down Expand Up @@ -373,6 +374,7 @@ def _request_execution_wrapper(request_id: str,
request_task.status = api_requests.RequestStatus.RUNNING
func = request_task.entrypoint
request_body = request_task.request_body
request_name = request_task.name

# Append to the log file instead of overwriting it since there might be
# logs from previous retries.
Expand All @@ -390,7 +392,9 @@ def _request_execution_wrapper(request_id: str,
config = skypilot_config.to_dict()
logger.debug(f'request config: \n'
f'{yaml_utils.dump_yaml_str(dict(config))}')
return_value = func(**request_body.to_kwargs())
with metrics_lib.time_it(name=request_name,
group='request_execution'):
return_value = func(**request_body.to_kwargs())
f.flush()
except KeyboardInterrupt:
logger.info(f'Request {request_id} cancelled by user')
Expand Down
11 changes: 11 additions & 0 deletions sky/server/requests/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from sky.server import common as server_common
from sky.server import constants as server_constants
from sky.server import daemons
from sky.server import metrics as metrics_lib
from sky.server.requests import payloads
from sky.server.requests.serializers import decoders
from sky.server.requests.serializers import encoders
Expand Down Expand Up @@ -460,6 +461,7 @@ def request_lock_path(request_id: str) -> str:

@contextlib.contextmanager
@init_db
@metrics_lib.time_me
def update_request(request_id: str) -> Generator[Optional[Request], None, None]:
"""Get and update a SkyPilot API request."""
request = _get_request_no_lock(request_id)
Expand All @@ -469,6 +471,7 @@ def update_request(request_id: str) -> Generator[Optional[Request], None, None]:


@init_db
@metrics_lib.time_me
def update_request_async(
request_id: str) -> AsyncContextManager[Optional[Request]]:
"""Async version of update_request.
Expand Down Expand Up @@ -517,6 +520,7 @@ async def _get_request_no_lock_async(request_id: str) -> Optional[Request]:


@init_db
@metrics_lib.time_me
def get_latest_request_id() -> Optional[str]:
"""Get the latest request ID."""
assert _DB is not None
Expand All @@ -529,20 +533,23 @@ def get_latest_request_id() -> Optional[str]:


@init_db
@metrics_lib.time_me
def get_request(request_id: str) -> Optional[Request]:
"""Get a SkyPilot API request."""
with filelock.FileLock(request_lock_path(request_id)):
return _get_request_no_lock(request_id)


@init_db_async
@metrics_lib.time_me_async
async def get_request_async(request_id: str) -> Optional[Request]:
"""Async version of get_request."""
async with filelock.AsyncFileLock(request_lock_path(request_id)):
return await _get_request_no_lock_async(request_id)


@init_db
@metrics_lib.time_me
def create_if_not_exists(request: Request) -> bool:
"""Create a SkyPilot API request if it does not exist."""
with filelock.FileLock(request_lock_path(request.request_id)):
Expand All @@ -553,6 +560,7 @@ def create_if_not_exists(request: Request) -> bool:


@init_db_async
@metrics_lib.time_me_async
async def create_if_not_exists_async(request: Request) -> bool:
"""Async version of create_if_not_exists."""
async with filelock.AsyncFileLock(request_lock_path(request.request_id)):
Expand All @@ -563,6 +571,7 @@ async def create_if_not_exists_async(request: Request) -> bool:


@init_db
@metrics_lib.time_me
def get_request_tasks(
status: Optional[List[RequestStatus]] = None,
cluster_names: Optional[List[str]] = None,
Expand Down Expand Up @@ -637,6 +646,7 @@ def get_request_tasks(


@init_db_async
@metrics_lib.time_me_async
async def get_api_request_ids_start_with(incomplete: str) -> List[str]:
"""Get a list of API request ids for shell completion."""
assert _DB is not None
Expand Down Expand Up @@ -711,6 +721,7 @@ def set_request_cancelled(request_id: str) -> None:


@init_db
@metrics_lib.time_me
def _delete_requests(requests: List[Request]):
"""Clean up requests by their IDs."""
id_list_str = ','.join(repr(req.request_id) for req in requests)
Expand Down
27 changes: 27 additions & 0 deletions sky/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
from sky.utils import context
from sky.utils import context_utils
from sky.utils import dag_utils
from sky.utils import perf_utils
from sky.utils import status_lib
from sky.utils import subprocess_utils
from sky.volumes.server import server as volumes_rest
Expand Down Expand Up @@ -421,6 +422,28 @@ async def cleanup_upload_ids():
upload_ids_to_cleanup.pop((upload_id, user_hash))


async def loop_lag_monitor(loop: asyncio.AbstractEventLoop,
                           interval: float = 0.1) -> None:
    """Sample the event loop's scheduling delay every ``interval`` seconds.

    Schedules a repeating callback via loop.call_at() and measures how
    late it actually fires; the lag is observed into
    SKY_APISERVER_EVENT_LOOP_LAG_SECONDS labeled by this process's pid.
    Lags above the configured threshold (if any) are also logged.
    """
    pid = str(os.getpid())
    lag_threshold = perf_utils.get_loop_lag_threshold()
    target = loop.time() + interval

    def _on_tick():
        nonlocal target
        now = loop.time()
        lag = max(0.0, now - target)
        if lag_threshold is not None and lag > lag_threshold:
            logger.warning(f'Event loop lag {lag} seconds exceeds threshold '
                           f'{lag_threshold} seconds.')
        metrics.SKY_APISERVER_EVENT_LOOP_LAG_SECONDS.labels(
            pid=pid).observe(lag)
        # Re-arm relative to "now" so one slow tick does not cascade.
        target = now + interval
        loop.call_at(target, _on_tick)

    loop.call_at(target, _on_tick)


@contextlib.asynccontextmanager
async def lifespan(app: fastapi.FastAPI): # pylint: disable=redefined-outer-name
"""FastAPI lifespan context manager."""
Expand All @@ -446,6 +469,10 @@ async def lifespan(app: fastapi.FastAPI): # pylint: disable=redefined-outer-nam
# can safely ignore the error if the task is already scheduled.
logger.debug(f'Request {event.id} already exists.')
asyncio.create_task(cleanup_upload_ids())
if metrics.METRICS_ENABLED:
# Start monitoring the event loop lag in each server worker
# event loop (process).
asyncio.create_task(loop_lag_monitor(asyncio.get_event_loop()))
yield
# Shutdown: Add any cleanup code here if needed

Expand Down
7 changes: 7 additions & 0 deletions sky/server/uvicorn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from sky.skylet import constants
from sky.utils import context_utils
from sky.utils import env_options
from sky.utils import perf_utils
from sky.utils import subprocess_utils

logger = sky_logging.init_logger(__name__)
Expand Down Expand Up @@ -198,6 +199,12 @@ def run(self, *args, **kwargs):
context_utils.hijack_sys_attrs()
# Use default loop policy of uvicorn (use uvloop if available).
self.config.setup_event_loop()
lag_threshold = perf_utils.get_loop_lag_threshold()
if lag_threshold is not None:
event_loop = asyncio.get_event_loop()
# Same as set PYTHONASYNCIODEBUG=1, but with custom threshold.
event_loop.set_debug(True)
event_loop.slow_callback_duration = lag_threshold
with self.capture_signals():
asyncio.run(self.serve(*args, **kwargs))

Expand Down
3 changes: 3 additions & 0 deletions sky/skylet/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,3 +505,6 @@

# The directory for file locks.
SKY_LOCKS_DIR = os.path.expanduser('~/.sky/locks')

ENV_VAR_LOOP_LAG_THRESHOLD_MS = (SKYPILOT_ENV_VAR_PREFIX +
'DEBUG_LOOP_LAG_THRESHOLD_MS')
22 changes: 22 additions & 0 deletions sky/utils/perf_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Utility functions for performance monitoring."""
import os
from typing import Optional

from sky import sky_logging
from sky.skylet import constants

logger = sky_logging.init_logger(__name__)


def get_loop_lag_threshold() -> Optional[float]:
    """Return the configured event-loop lag threshold in seconds.

    Reads the threshold (in milliseconds) from the environment variable
    and converts it to seconds. Returns None when the variable is unset
    or does not parse as a number (a warning is logged in that case).
    """
    raw_ms = os.getenv(constants.ENV_VAR_LOOP_LAG_THRESHOLD_MS, None)
    if raw_ms is None:
        return None
    try:
        return float(raw_ms) / 1000.0
    except ValueError:
        logger.warning(
            f'Invalid value for {constants.ENV_VAR_LOOP_LAG_THRESHOLD_MS}:'
            f' {raw_ms}')
        return None
Loading