Skip to content

Improve performance of qos status #87

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 15 commits into from
44 changes: 44 additions & 0 deletions alembic/versions/67957f85d934_create_qos_rules_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""create qos_rules table

Revision ID: 67957f85d934
Revises: 01374a5b3c41
Create Date: 2023-10-27 12:04:01.741917

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import JSONB


# revision identifiers, used by Alembic.
revision = '67957f85d934'
down_revision = '01374a5b3c41'
branch_labels = None
depends_on = None


def upgrade() -> None:
op.create_table(
"qos_rules",
sa.Column("id", sa.Text, primary_key=True),
sa.Column("rule", sa.Text),
sa.Column("condition", sa.Text),
sa.Column("conclusion", sa.Text),
sa.Column("info", sa.Text),
sa.Column("timestamp", sa.TIMESTAMP, default=sa.func.now()),
)
op.drop_column("system_requests", "qos_status")
op.add_column(
"system_requests",
sa.Column(
"qos_status_ids", sa.dialects.postgresql.ARRAY(sa.Text), default=[]
),
)
op.execute("update system_requests set qos_status_ids='{}'")


def downgrade() -> None:
op.drop_column("system_requests", "qos_status_ids")
op.add_column("system_requests", sa.Column("qos_status", JSONB, default={}))
op.execute("UPDATE system_requests SET qos_status='{}'")
op.drop_table("qos_rules")
4 changes: 2 additions & 2 deletions cads_broker/Environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ def wrapped(self, *args, **kwargs):


class Environment:
def __init__(self):
self.number_of_workers = None
def __init__(self, number_of_workers=None):
self.number_of_workers = number_of_workers
self.lock = threading.RLock()
self._enabled = {}
self._values = {}
Expand Down
82 changes: 55 additions & 27 deletions cads_broker/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,19 @@ class AdaptorProperties(BaseModel):
form = sa.Column(JSONB)


class QosRule(BaseModel):
"""QoS Rule ORM model."""

__tablename__ = "qos_rules"

id = sa.Column(sa.Text, primary_key=True)
rule = sa.Column(sa.Text)
condition = sa.Column(sa.Text)
conclusion = sa.Column(sa.Text)
info = sa.Column(sa.Text)
timestamp = sa.Column(sa.TIMESTAMP, default=sa.func.now())


class SystemRequest(BaseModel):
"""System Request ORM model."""

Expand Down Expand Up @@ -72,7 +85,7 @@ class SystemRequest(BaseModel):
sa.Text, sa.ForeignKey("adaptor_properties.hash"), nullable=False
)
entry_point = sa.Column(sa.Text)
qos_status = sa.Column(JSONB, default=dict)
qos_status_ids = sa.Column(sa.dialects.postgresql.ARRAY(sa.Text), default="{}")

__table_args__: tuple[sa.ForeignKeyConstraint, dict[None, None]] = (
sa.ForeignKeyConstraint(
Expand All @@ -81,7 +94,6 @@ class SystemRequest(BaseModel):
{},
)

# joined is temporary
cache_entry = sa.orm.relationship(cacholote.database.CacheEntry, lazy="joined")
adaptor_properties = sa.orm.relationship(AdaptorProperties, lazy="select")

Expand Down Expand Up @@ -307,40 +319,56 @@ def count_users(status: str, entry_point: str, session: sa.orm.Session) -> int:
)


def drop_qos_rules_table(session: sa.orm.Session):
session.query(QosRule).delete()
session.commit()


def add_qos_rule(
rule,
session: sa.orm.Session,
) -> None:
session.add(QosRule(
id=rule.get_uid(),
# this works only if the "request" is not needed in the computation of "conclusion"
conclusion=str(rule.evaluate(request=None)),
info=str(rule.info).replace('"', ""),
condition=str(rule.condition),
))
session.commit()


def get_qos_rule_from_id(qos_rule_id: str, session: sa.orm.Session) -> QosRule | None:
statement = sa.select(QosRule).where(
QosRule.id == qos_rule_id
)
try:
return session.scalars(statement).one()
except sa.exc.NoResultFound:
return None


def get_qos_status_from_request(
request: SystemRequest,
session: sa.orm.Session,
) -> dict[str, list[tuple[str, str]]]:
ret_value: dict[str, list[str]] = {}
for rule_name, rules in request.qos_status.items():
ret_value[rule_name] = []
for rule in rules.values():
ret_value[rule_name].append(
(rule.get("info", ""), rule.get("conclusion", ""))
)
ret_value: dict[str, list[tuple[str, str]]] = {}
for qos_rule_id in request.qos_status_ids:
qos_rule = get_qos_rule_from_id(qos_rule_id=qos_rule_id, session=session)
if qos_rule is not None:
ret_value.setdefault(qos_rule.rule, [])
ret_value[qos_rule.rule].append((qos_rule.info, qos_rule.conclusion))
return ret_value


def set_request_qos_rule(
def add_qos_rule_to_request(
request: SystemRequest,
rule,
session: sa.orm.Session,
):
qos_status = request.qos_status
old_rules = qos_status.get(rule.name, {})
rule_uid = rule.get_uid(request)
if rule_uid in old_rules:
return
old_rules[rule_uid] = {
"conclusion": str(rule.evaluate(request)),
"info": str(rule.info).replace('"', ""),
"condition": str(rule.condition),
}
qos_status[rule.name] = old_rules
session.execute(
sa.update(SystemRequest)
.filter_by(request_uid=request.request_uid)
.values(qos_status=qos_status)
)
request.qos_status_ids = request.qos_status_ids + [rule.get_uid(None)]
session.add(request)
session.commit()


def requeue_request(
Expand Down Expand Up @@ -395,7 +423,7 @@ def set_request_status(
request.response_error = {"message": error_message, "reason": error_reason}
elif status == "running":
request.started_at = sa.func.now()
request.qos_status = {}
request.qos_status_ids = []
# FIXME: logs can't be live updated
request.response_log = json.dumps(log)
request.response_user_visible_log = json.dumps(user_visible_log)
Expand Down
7 changes: 4 additions & 3 deletions cads_broker/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ def get_tasks_on_scheduler(dask_scheduler: distributed.Scheduler) -> dict[str, s


class QoSRules:
def __init__(self) -> None:
self.environment = Environment.Environment()
def __init__(self, number_of_workers) -> None:
self.environment = Environment.Environment(number_of_workers=number_of_workers)
self.rules_path = os.getenv("RULES_PATH", "/src/rules.qos")
if os.path.exists(self.rules_path):
self.rules = self.rules_path
Expand Down Expand Up @@ -125,7 +125,7 @@ def from_address(
session_maker: sa.orm.sessionmaker = None,
):
client = distributed.Client(address)
qos_config = QoSRules()
qos_config = QoSRules(get_number_of_workers(client=client))
factory.register_functions()
session_maker = db.ensure_session_obj(session_maker)
rules_hash = get_rules_hash(qos_config.rules_path)
Expand All @@ -137,6 +137,7 @@ def from_address(
qos_config.rules,
qos_config.environment,
rules_hash=rules_hash,
session_maker=session_maker,
),
address=address,
)
Expand Down
5 changes: 4 additions & 1 deletion cads_broker/expressions/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
import operator
import re

import structlog

logger: structlog.stdlib.BoundLogger = structlog.get_logger(__name__)

class FunctionExpression:
def __init__(self, name, args):
Expand All @@ -26,7 +29,7 @@ def evaluate(self, context):
return self.execute(context, *args)
except Exception as e:
args = ",".join(repr(a) for a in args)
print(f"{self.name}({args}): {e}")
logger.warning(f"{self.name}({args}): {e}")
raise


Expand Down
4 changes: 4 additions & 0 deletions cads_broker/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ def register_functions():
"adaptor",
lambda context, *args: context.request.entry_point,
)
expressions.FunctionFactory.FunctionFactory.register_function(
"number_of_workers",
lambda context, *args: context.environment.number_of_workers,
)
expressions.FunctionFactory.FunctionFactory.register_function(
"user_request_count",
lambda context, seconds: database.count_finished_requests_per_user(
Expand Down
24 changes: 17 additions & 7 deletions cads_broker/qos/QoS.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def wrapped(self, *args, **kwargs):


class QoS:
def __init__(self, rules, environment, rules_hash):
def __init__(self, rules, environment, rules_hash, session_maker):
self.lock = threading.RLock()

self.rules_hash = rules_hash
Expand All @@ -47,10 +47,11 @@ def __init__(self, rules, environment, rules_hash):
self.path = rules
self.rules = None
# Read the files from the rules file
self.read_rules()
with session_maker() as session:
self.read_rules(session=session)

@locked
def read_rules(self):
def read_rules(self, session):
"""Read the rule files and populate the rule_set."""
# Create a parser to parse the rules file
parser = RulesParser(self.path)
Expand All @@ -60,18 +61,27 @@ def read_rules(self):

# Parse the rules
parser.parse_rules(self.rules, self.environment)
self.fill_qos_rules_db(session=session)

# Print the rules
self.rules.dump()

@locked
def fill_qos_rules_db(self, session):
database.drop_qos_rules_table(session=session)
for limit in self.rules.global_limits:
database.add_qos_rule(limit, session=session)
for limit in self.rules.user_limits:
database.add_qos_rule(limit, session=session)

@locked
def reload_rules(self, session):
"""Allow a 'hot' reloading of the rules.

For example, a thread could be monitoring the time stamp of the rules
file and call this method.
"""
self.read_rules()
self.read_rules(session=session)
self.reconfigure(session=session)

@locked
Expand All @@ -97,11 +107,11 @@ def can_run(self, request, session):
"""Check if a request can run."""
properties = self._properties(request=request, session=session)
limits = []
for i, limit in enumerate(properties.limits):
for limit in properties.limits:
if limit.full(request):
# performance. avoid interacting with db if limit is already there
if limit.get_uid(request) not in request.qos_status.get(limit.name, []):
database.set_request_qos_rule(request, limit, session)
if limit.get_uid(request) not in request.qos_status_ids:
database.add_qos_rule_to_request(request, limit, session)
limits.append(limit)
session.commit()
permissions = []
Expand Down
4 changes: 2 additions & 2 deletions cads_broker/qos/Rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ def match(self, request):
def dump(self, out):
out(self)

def get_uid(self, request):
def get_uid(self, request=None):
return str(
hash(f"{self.name} {self.info} {self.condition} : {self.evaluate(request)}")
hash(f"{self.name} {self.info} {self.condition} : {self.evaluate(request=request)}")
)

def __repr__(self):
Expand Down