diff --git a/alembic/versions/67957f85d934_create_qos_rules_table.py b/alembic/versions/67957f85d934_create_qos_rules_table.py new file mode 100644 index 00000000..92f73593 --- /dev/null +++ b/alembic/versions/67957f85d934_create_qos_rules_table.py @@ -0,0 +1,44 @@ +"""create qos_rules table + +Revision ID: 67957f85d934 +Revises: 01374a5b3c41 +Create Date: 2023-10-27 12:04:01.741917 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import JSONB + + +# revision identifiers, used by Alembic. +revision = '67957f85d934' +down_revision = '01374a5b3c41' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "qos_rules", + sa.Column("id", sa.Text, primary_key=True), + sa.Column("rule", sa.Text), + sa.Column("condition", sa.Text), + sa.Column("conclusion", sa.Text), + sa.Column("info", sa.Text), + sa.Column("timestamp", sa.TIMESTAMP, default=sa.func.now()), + ) + op.drop_column("system_requests", "qos_status") + op.add_column( + "system_requests", + sa.Column( + "qos_status_ids", sa.dialects.postgresql.ARRAY(sa.Text), default=[] + ), + ) + op.execute("update system_requests set qos_status_ids='{}'") + + +def downgrade() -> None: + op.drop_column("system_requests", "qos_status_ids") + op.add_column("system_requests", sa.Column("qos_status", JSONB, default={})) + op.execute("UPDATE system_requests SET qos_status='{}'") + op.drop_table("qos_rules") diff --git a/cads_broker/Environment.py b/cads_broker/Environment.py index ea7075eb..60a07ac7 100644 --- a/cads_broker/Environment.py +++ b/cads_broker/Environment.py @@ -23,8 +23,8 @@ def wrapped(self, *args, **kwargs): class Environment: - def __init__(self): - self.number_of_workers = None + def __init__(self, number_of_workers=None): + self.number_of_workers = number_of_workers self.lock = threading.RLock() self._enabled = {} self._values = {} diff --git a/cads_broker/database.py b/cads_broker/database.py index 39c6a9dd..f5c7cfd0 100644 --- a/cads_broker/database.py +++ b/cads_broker/database.py @@ -41,6 +41,19 @@ class AdaptorProperties(BaseModel): form = sa.Column(JSONB) +class QosRule(BaseModel): + """QoS Rule ORM model.""" + + __tablename__ = "qos_rules" + + id = sa.Column(sa.Text, primary_key=True) + rule = sa.Column(sa.Text) + condition = sa.Column(sa.Text) + conclusion = sa.Column(sa.Text) + info = sa.Column(sa.Text) + timestamp = sa.Column(sa.TIMESTAMP, default=sa.func.now()) + + class SystemRequest(BaseModel): """System Request ORM model.""" @@ -72,7 +85,7 @@ class SystemRequest(BaseModel): sa.Text, sa.ForeignKey("adaptor_properties.hash"), nullable=False ) entry_point = sa.Column(sa.Text) - qos_status = sa.Column(JSONB, default=dict) + qos_status_ids = sa.Column(sa.dialects.postgresql.ARRAY(sa.Text), default="{}") __table_args__: tuple[sa.ForeignKeyConstraint, dict[None, None]] = ( sa.ForeignKeyConstraint( @@ -81,7 +94,6 @@ class SystemRequest(BaseModel): {}, ) - # joined is temporary cache_entry = sa.orm.relationship(cacholote.database.CacheEntry, lazy="joined") adaptor_properties = sa.orm.relationship(AdaptorProperties, lazy="select") @@ -307,40 +319,56 @@ def count_users(status: str, entry_point: str, session: sa.orm.Session) -> int: ) +def drop_qos_rules_table(session: sa.orm.Session): + session.query(QosRule).delete() + session.commit() + + +def add_qos_rule( + rule, + session: sa.orm.Session, +) -> None: + session.add(QosRule( + id=rule.get_uid(), + # this works only if the "request" is not needed in the computation of "conclusion" + conclusion=str(rule.evaluate(request=None)), + info=str(rule.info).replace('"', ""), + condition=str(rule.condition), + )) + session.commit() + + +def get_qos_rule_from_id(qos_rule_id: str, session: sa.orm.Session) -> QosRule | None: + statement = sa.select(QosRule).where( + QosRule.id == qos_rule_id + ) + try: + return session.scalars(statement).one() + except sa.exc.NoResultFound: + return None + + def get_qos_status_from_request( request: SystemRequest, + session: sa.orm.Session, ) -> dict[str, list[tuple[str, str]]]: - ret_value: dict[str, list[str]] = {} - for rule_name, rules in request.qos_status.items(): - ret_value[rule_name] = [] - for rule in rules.values(): - ret_value[rule_name].append( - (rule.get("info", ""), rule.get("conclusion", "")) - ) + ret_value: dict[str, list[tuple[str, str]]] = {} + for qos_rule_id in request.qos_status_ids: + qos_rule = get_qos_rule_from_id(qos_rule_id=qos_rule_id, session=session) + if qos_rule is not None: + ret_value.setdefault(qos_rule.rule, []) + ret_value[qos_rule.rule].append((qos_rule.info, qos_rule.conclusion)) return ret_value -def set_request_qos_rule( +def add_qos_rule_to_request( request: SystemRequest, rule, session: sa.orm.Session, ): - qos_status = request.qos_status - old_rules = qos_status.get(rule.name, {}) - rule_uid = rule.get_uid(request) - if rule_uid in old_rules: - return - old_rules[rule_uid] = { - "conclusion": str(rule.evaluate(request)), - "info": str(rule.info).replace('"', ""), - "condition": str(rule.condition), - } - qos_status[rule.name] = old_rules - session.execute( - sa.update(SystemRequest) - .filter_by(request_uid=request.request_uid) - .values(qos_status=qos_status) - ) + request.qos_status_ids = request.qos_status_ids + [rule.get_uid(None)] + session.add(request) + session.commit() def requeue_request( @@ -395,7 +423,7 @@ def set_request_status( request.response_error = {"message": error_message, "reason": error_reason} elif status == "running": request.started_at = sa.func.now() - request.qos_status = {} + request.qos_status_ids = [] # FIXME: logs can't be live updated request.response_log = json.dumps(log) request.response_user_visible_log = json.dumps(user_visible_log) diff --git a/cads_broker/dispatcher.py b/cads_broker/dispatcher.py index f47fa0bd..3d73aa8c 100644 --- a/cads_broker/dispatcher.py +++ b/cads_broker/dispatcher.py @@ -89,8 +89,8 @@ def get_tasks_on_scheduler(dask_scheduler: distributed.Scheduler) -> dict[str, s class QoSRules: - def __init__(self) -> None: - self.environment = Environment.Environment() + def __init__(self, number_of_workers) -> None: + self.environment = Environment.Environment(number_of_workers=number_of_workers) self.rules_path = os.getenv("RULES_PATH", "/src/rules.qos") if os.path.exists(self.rules_path): self.rules = self.rules_path @@ -125,7 +125,7 @@ def from_address( session_maker: sa.orm.sessionmaker = None, ): client = distributed.Client(address) - qos_config = QoSRules() + qos_config = QoSRules(get_number_of_workers(client=client)) factory.register_functions() session_maker = db.ensure_session_obj(session_maker) rules_hash = get_rules_hash(qos_config.rules_path) @@ -137,6 +137,7 @@ def from_address( qos_config.rules, qos_config.environment, rules_hash=rules_hash, + session_maker=session_maker, ), address=address, ) diff --git a/cads_broker/expressions/functions.py b/cads_broker/expressions/functions.py index 84b8d892..92278e2a 100644 --- a/cads_broker/expressions/functions.py +++ b/cads_broker/expressions/functions.py @@ -10,6 +10,9 @@ import operator import re +import structlog + +logger: structlog.stdlib.BoundLogger = structlog.get_logger(__name__) class FunctionExpression: def __init__(self, name, args): @@ -26,7 +29,7 @@ def evaluate(self, context): return self.execute(context, *args) except Exception as e: args = ",".join(repr(a) for a in args) - print(f"{self.name}({args}): {e}") + logger.warning(f"{self.name}({args}): {e}") raise diff --git a/cads_broker/factory.py b/cads_broker/factory.py index 330e7c86..7a264e20 100644 --- a/cads_broker/factory.py +++ b/cads_broker/factory.py @@ -33,6 +33,10 @@ def register_functions(): "adaptor", lambda context, *args: context.request.entry_point, ) + expressions.FunctionFactory.FunctionFactory.register_function( + "number_of_workers", + lambda context, *args: context.environment.number_of_workers, + ) expressions.FunctionFactory.FunctionFactory.register_function( "user_request_count", lambda context, seconds: database.count_finished_requests_per_user( diff --git a/cads_broker/qos/QoS.py b/cads_broker/qos/QoS.py index b955d7b0..4fb929f9 100644 --- a/cads_broker/qos/QoS.py +++ b/cads_broker/qos/QoS.py @@ -26,7 +26,7 @@ def wrapped(self, *args, **kwargs): class QoS: - def __init__(self, rules, environment, rules_hash): + def __init__(self, rules, environment, rules_hash, session_maker): self.lock = threading.RLock() self.rules_hash = rules_hash @@ -47,10 +47,11 @@ def __init__(self, rules, environment, rules_hash): self.path = rules self.rules = None # Read the files from the rules file - self.read_rules() + with session_maker() as session: + self.read_rules(session=session) @locked - def read_rules(self): + def read_rules(self, session): """Read the rule files and populate the rule_set.""" # Create a parser to parse the rules file parser = RulesParser(self.path) @@ -60,10 +61,19 @@ def read_rules(self): # Parse the rules parser.parse_rules(self.rules, self.environment) + self.fill_qos_rules_db(session=session) # Print the rules self.rules.dump() + @locked + def fill_qos_rules_db(self, session): + database.drop_qos_rules_table(session=session) + for limit in self.rules.global_limits: + database.add_qos_rule(limit, session=session) + for limit in self.rules.user_limits: + database.add_qos_rule(limit, session=session) + @locked def reload_rules(self, session): """Allow a 'hot' reloading of the rules. @@ -71,7 +81,7 @@ def reload_rules(self, session): For example, a thread could be monitoring the time stamp of the rules file and call this method. """ - self.read_rules() + self.read_rules(session=session) self.reconfigure(session=session) @locked @@ -97,11 +107,11 @@ def can_run(self, request, session): """Check if a request can run.""" properties = self._properties(request=request, session=session) limits = [] - for i, limit in enumerate(properties.limits): + for limit in properties.limits: if limit.full(request): # performance. avoid interacting with db if limit is already there - if limit.get_uid(request) not in request.qos_status.get(limit.name, []): - database.set_request_qos_rule(request, limit, session) + if limit.get_uid(request) not in request.qos_status_ids: + database.add_qos_rule_to_request(request, limit, session) limits.append(limit) session.commit() permissions = [] diff --git a/cads_broker/qos/Rule.py b/cads_broker/qos/Rule.py index c0dc5150..6661cf76 100644 --- a/cads_broker/qos/Rule.py +++ b/cads_broker/qos/Rule.py @@ -45,9 +45,9 @@ def match(self, request): def dump(self, out): out(self) - def get_uid(self, request): + def get_uid(self, request=None): return str( - hash(f"{self.name} {self.info} {self.condition} : {self.evaluate(request)}") + hash(f"{self.name} {self.info} {self.condition} : {self.evaluate(request=request)}") ) def __repr__(self):