
feat: add AI module for LLM interaction and a heuristic for checking code–docstring consistency #1121


Open · wants to merge 2 commits into base: main
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -38,6 +38,8 @@ dependencies = [
"problog >= 2.2.6,<3.0.0",
"cryptography >=44.0.0,<45.0.0",
"semgrep == 1.113.0",
"pydantic >= 2.11.5,<2.12.0",
"gradio_client == 1.4.3",
]
keywords = []
# https://pypi.org/classifiers/
50 changes: 50 additions & 0 deletions src/macaron/ai/README.md
@@ -0,0 +1,50 @@
# Macaron AI Module

This module provides the foundation for interacting with Large Language Models (LLMs) in a provider-agnostic way. It includes an abstract client definition, provider-specific client implementations, a client factory, and utility functions for processing responses.

## Module Components

- **ai_client.py**
  Defines the abstract [`AIClient`](./ai_client.py) class. It holds the system prompt and the LLM configuration read from the defaults, and serves as the base for all concrete AI client implementations.

- **openai_client.py**
  Implements the [`OpenAiClient`](./openai_client.py) class, a concrete subclass of [`AIClient`](./ai_client.py). This client interacts with OpenAI-compatible APIs by sending HTTP requests and processing the responses. It also validates and structures responses using the tools provided.

- **ai_factory.py**
Contains the [`AIClientFactory`](./ai_factory.py) class, which is responsible for reading provider configuration from the defaults and creating the correct AI client instance.

- **ai_tools.py**
Offers utility functions such as `structure_response` to assist with parsing and validating the JSON response returned by an LLM. These functions ensure that responses conform to a given Pydantic model for easier downstream processing.

## Usage

1. **Configuration:**
The module reads the LLM configuration from the application defaults (using the `defaults` module). Make sure that the `llm` section in your configuration includes valid settings such as `enabled`, `api_key`, `api_endpoint`, `model`, and `context_window`.

2. **Creating a Client:**
Use the [`AIClientFactory`](./ai_factory.py) to create an AI client instance. The factory checks the configured provider and returns a client (e.g., an instance of [`OpenAiClient`](./openai_client.py)) that can be used to invoke the LLM.

Example:
```py
from macaron.ai.ai_factory import AIClientFactory

factory = AIClientFactory()
client = factory.create_client(system_prompt="You are a helpful assistant.")
response = client.invoke("Hello, how can you assist me?")
print(response)
```

3. **Response Processing:**
When a structured response is required, pass a Pydantic model class to the `invoke` method. The [`ai_tools.py`](./ai_tools.py) module takes care of parsing and validating the response to ensure it meets the expected structure.
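
For example, a minimal sketch of requesting structured output (the `CodeAnalysis` model and the prompts here are hypothetical):
```py
from pydantic import BaseModel

from macaron.ai.ai_factory import AIClientFactory


class CodeAnalysis(BaseModel):
    """Hypothetical output model; define fields that match your prompt."""

    summary: str
    is_consistent: bool


factory = AIClientFactory()
client = factory.create_client(system_prompt="You are a code reviewer. Answer in JSON.")
if client is not None:
    result = client.invoke(
        "Does this docstring match the function body? <code here>",
        structured_output=CodeAnalysis,
    )
    # `result` is a validated CodeAnalysis instance, or None if parsing/validation failed.
    print(result)
```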

## Logging and Error Handling

- The module uses Python's logging framework to report important events, such as token usage and warnings when prompts exceed the allowed context window.
- Configuration errors (e.g., a missing API key or endpoint) are handled by raising descriptive exceptions, such as the [`ConfigurationError`](../errors.py) defined in the errors module.

## Extensibility

The design of the AI module is provider-agnostic. To add support for additional LLM providers:
- Implement a new client by subclassing [`AIClient`](./ai_client.py).
- Add the new client to the [`PROVIDER_MAPPING`](./ai_factory.py).
- Update the configuration defaults accordingly.
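
A minimal sketch of a hypothetical provider client (the `AnthropicClient` name and the stubbed request handling are illustrative only, not part of this module):
```py
from typing import Any, TypeVar

from pydantic import BaseModel

from macaron.ai.ai_client import AIClient
from macaron.ai.ai_tools import structure_response

T = TypeVar("T", bound=BaseModel)


class AnthropicClient(AIClient):
    """Hypothetical client for a non-OpenAI provider."""

    def invoke(
        self,
        user_prompt: str,
        temperature: float = 0.2,
        structured_output: type[T] | None = None,
    ) -> Any:
        # Build the provider-specific payload from self.defaults, self.system_prompt,
        # user_prompt, and temperature, then send it and extract the message text.
        message_content = self._call_provider(user_prompt, temperature)
        # Reuse the shared helper so structured output behaves like OpenAiClient.
        if structured_output:
            return structure_response(message_content, structured_output)
        return message_content

    def _call_provider(self, user_prompt: str, temperature: float) -> str:
        raise NotImplementedError("Provider-specific HTTP handling goes here.")
```
The new class would then be registered as `PROVIDER_MAPPING = {"openai": OpenAiClient, "anthropic": AnthropicClient}` and selected via the `provider` setting in the defaults.
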
2 changes: 2 additions & 0 deletions src/macaron/ai/__init__.py
@@ -0,0 +1,2 @@
# Copyright (c) 2025 - 2025, Oracle and/or its affiliates. All rights reserved.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.
53 changes: 53 additions & 0 deletions src/macaron/ai/ai_client.py
@@ -0,0 +1,53 @@
# Copyright (c) 2025 - 2025, Oracle and/or its affiliates. All rights reserved.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.

"""This module defines the abstract AIClient class for implementing AI clients."""

import logging
from abc import ABC, abstractmethod
from typing import Any, TypeVar

from pydantic import BaseModel

T = TypeVar("T", bound=BaseModel)

logger: logging.Logger = logging.getLogger(__name__)


class AIClient(ABC):
"""This abstract class is used to implement ai clients."""

def __init__(self, system_prompt: str, defaults: dict) -> None:
"""
Initialize the AI client.

The LLM configuration is read from defaults.
"""
self.system_prompt = system_prompt
self.defaults = defaults

@abstractmethod
def invoke(
self,
user_prompt: str,
temperature: float = 0.2,
structured_output: type[T] | None = None,
) -> Any:
"""
Invoke the LLM and optionally validate its response.

Parameters
----------
user_prompt: str
The user prompt to send to the LLM.
temperature: float
The temperature for the LLM response.
        structured_output: type[T] | None
The Pydantic model to validate the response against. If provided, the response will be parsed and validated.

Returns
-------
        T | str | None
The validated Pydantic model instance if `structured_output` is provided,
or the raw string response if not.
"""
70 changes: 70 additions & 0 deletions src/macaron/ai/ai_factory.py
@@ -0,0 +1,70 @@
# Copyright (c) 2025 - 2025, Oracle and/or its affiliates. All rights reserved.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.

"""This module defines the AIClientFactory class for creating AI clients based on provider configuration."""

import logging

from macaron.ai.ai_client import AIClient
from macaron.ai.openai_client import OpenAiClient
from macaron.config.defaults import defaults
from macaron.errors import ConfigurationError

logger: logging.Logger = logging.getLogger(__name__)


class AIClientFactory:
"""Factory to create AI clients based on provider configuration."""

PROVIDER_MAPPING: dict[str, type[AIClient]] = {"openai": OpenAiClient}

def __init__(self) -> None:
"""
        Initialize the AI client factory.

The LLM configuration is read from defaults.
"""
self.defaults = self._load_defaults()

def _load_defaults(self) -> dict:
section_name = "llm"
default_values = {
"enabled": False,
"provider": "",
"api_key": "",
"api_endpoint": "",
"model": "",
"context_window": 10000,
}

if defaults.has_section(section_name):
section = defaults[section_name]
default_values["enabled"] = section.getboolean("enabled", default_values["enabled"])
default_values["api_key"] = str(section.get("api_key", default_values["api_key"])).strip().lower()
default_values["api_endpoint"] = (
str(section.get("api_endpoint", default_values["api_endpoint"])).strip().lower()
)
default_values["model"] = str(section.get("model", default_values["model"])).strip().lower()
default_values["provider"] = str(section.get("provider", default_values["provider"])).strip().lower()
default_values["context_window"] = section.getint("context_window", 10000)

if default_values["enabled"]:
for key, value in default_values.items():
if not value:
raise ConfigurationError(
f"AI client configuration '{key}' is required but not set in the defaults."
)

return default_values

def create_client(self, system_prompt: str) -> AIClient | None:
"""Create an AI client based on the configured provider."""
client_class = self.PROVIDER_MAPPING.get(self.defaults["provider"])
if client_class is None:
logger.error("Provider '%s' is not supported.", self.defaults["provider"])
return None
return client_class(system_prompt, self.defaults)

def list_available_providers(self) -> list[str]:
"""List all registered providers."""
return list(self.PROVIDER_MAPPING.keys())
53 changes: 53 additions & 0 deletions src/macaron/ai/ai_tools.py
@@ -0,0 +1,53 @@
# Copyright (c) 2025 - 2025, Oracle and/or its affiliates. All rights reserved.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.

"""This module provides utility functions for Large Language Model (LLM)."""
import json
import logging
import re
from typing import TypeVar

from pydantic import BaseModel, ValidationError

T = TypeVar("T", bound=BaseModel)

logger: logging.Logger = logging.getLogger(__name__)


def structure_response(response_text: str, response_model: type[T]) -> T | None:
"""
Structure and parse the response from the LLM.

If raw JSON parsing fails, attempts to extract a JSON object from text.

Parameters
----------
response_text: str
The response text from the LLM.
response_model: Type[T]
The Pydantic model to structure the response against.

Returns
-------
T | None
The structured Pydantic model instance.
"""
try:
data = json.loads(response_text)
except json.JSONDecodeError:
logger.debug("Full JSON parse failed; trying to extract JSON from text.")
# If the response is not a valid JSON, try to extract a JSON object from the text.
match = re.search(r"\{.*\}", response_text, re.DOTALL)
if not match:
return None
try:
data = json.loads(match.group(0))
except json.JSONDecodeError as e:
logger.debug("Failed to parse extracted JSON: %s", e)
return None

try:
return response_model.model_validate(data)
except ValidationError as e:
logger.debug("Validation failed against response model: %s", e)
return None
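
A short usage sketch for `structure_response` (the `Verdict` model and the raw response below are made up for illustration):
```py
from pydantic import BaseModel

from macaron.ai.ai_tools import structure_response


class Verdict(BaseModel):
    consistent: bool
    reason: str


# LLMs often wrap JSON in prose; the fallback regex extracts the embedded object.
raw = 'Sure! Here is the result: {"consistent": true, "reason": "Docstring matches the code."}'
verdict = structure_response(raw, Verdict)
assert verdict is not None and verdict.consistent
```
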
100 changes: 100 additions & 0 deletions src/macaron/ai/openai_client.py
@@ -0,0 +1,100 @@
# Copyright (c) 2024 - 2025, Oracle and/or its affiliates. All rights reserved.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.

"""This module provides a client for interacting with a Large Language Model (LLM) that is Openai like."""

import logging
from typing import Any, TypeVar

from pydantic import BaseModel

from macaron.ai.ai_client import AIClient
from macaron.ai.ai_tools import structure_response
from macaron.errors import ConfigurationError, HeuristicAnalyzerValueError
from macaron.util import send_post_http_raw

logger: logging.Logger = logging.getLogger(__name__)

T = TypeVar("T", bound=BaseModel)


class OpenAiClient(AIClient):
"""A client for interacting with a Large Language Model that is OpenAI API like."""

def invoke(
self,
user_prompt: str,
temperature: float = 0.2,
structured_output: type[T] | None = None,
max_tokens: int = 4000,
timeout: int = 30,
) -> Any:
"""
Invoke the LLM and optionally validate its response.

Parameters
----------
user_prompt: str
The user prompt to send to the LLM.
temperature: float
The temperature for the LLM response.
        structured_output: type[T] | None
The Pydantic model to validate the response against. If provided, the response will be parsed and validated.
max_tokens: int
The maximum number of tokens for the LLM response.
timeout: int
The timeout for the HTTP request in seconds.

Returns
-------
        T | str | None
The validated Pydantic model instance if `structured_output` is provided,
or the raw string response if not.

Raises
------
HeuristicAnalyzerValueError
If there is an error in parsing or validating the response.
"""
if not self.defaults["enabled"]:
raise ConfigurationError("AI client is not enabled. Please check your configuration.")

        # The whitespace-delimited word count is used as a rough proxy for the token count.
        if len(user_prompt.split()) > self.defaults["context_window"]:
logger.warning(
"User prompt exceeds context window (%s words). "
"Truncating the prompt to fit within the context window.",
self.defaults["context_window"],
)
user_prompt = " ".join(user_prompt.split()[: self.defaults["context_window"]])

        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.defaults['api_key']}"}
payload = {
"model": self.defaults["model"],
"messages": [{"role": "system", "content": self.system_prompt}, {"role": "user", "content": user_prompt}],
"temperature": temperature,
"max_tokens": max_tokens,
}

try:
response = send_post_http_raw(
url=self.defaults["api_endpoint"], json_data=payload, headers=headers, timeout=timeout
)
if not response:
raise HeuristicAnalyzerValueError("No response received from the LLM.")
response_json = response.json()
usage = response_json.get("usage", {})

if usage:
usage_str = ", ".join(f"{key} = {value}" for key, value in usage.items())
logger.info("LLM call token usage: %s", usage_str)

message_content = response_json["choices"][0]["message"]["content"]

if not structured_output:
logger.debug("Returning raw message content (no structured output requested).")
return message_content
return structure_response(message_content, structured_output)

except Exception as e:
logger.error("Error during LLM invocation: %s", e)
raise HeuristicAnalyzerValueError(f"Failed to get or validate LLM response: {e}") from e
18 changes: 18 additions & 0 deletions src/macaron/config/defaults.ini
@@ -635,3 +635,21 @@ custom_semgrep_rules_path =
# .yaml prefix. Note, this will be ignored if a path to custom semgrep rules is not provided. This list may not contain
# duplicated elements, meaning that ruleset names must be unique.
disabled_custom_rulesets =

[llm]
# The LLM configuration for Macaron.
# If enabled, the LLM will be used to analyze the results and provide insights.
enabled = False
# The provider for the LLM service.
# Supported providers:
# - openai: OpenAI's GPT models.
provider =
# The API key for the LLM service.
api_key =
# The API endpoint for the LLM service.
api_endpoint =
# The model to use for the LLM service.
model =
# The context window size for the LLM service.
# This is the maximum number of tokens that the LLM can process in a single request.
context_window = 10000
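
A filled-in example, assuming an OpenAI-hosted chat model (the endpoint and model name are illustrative and must match your deployment; the key is a placeholder):

```ini
[llm]
enabled = True
provider = openai
api_key = <your-api-key>
api_endpoint = https://api.openai.com/v1/chat/completions
model = gpt-4o-mini
context_window = 10000
```
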
3 changes: 3 additions & 0 deletions src/macaron/malware_analyzer/pypi_heuristics/heuristics.py
@@ -43,6 +43,9 @@ class Heuristics(str, Enum):
#: Indicates that the package source code contains suspicious code patterns.
SUSPICIOUS_PATTERNS = "suspicious_patterns"

    #: Indicates that the package contains code that is inconsistent with its docstrings.
MATCHING_DOCSTRINGS = "matching_docstrings"


class HeuristicResult(str, Enum):
"""Result type indicating the outcome of a heuristic."""
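
How an analyzer for `MATCHING_DOCSTRINGS` might drive the new AI module is sketched below; the `DocstringVerdict` model, the prompts, and the wiring into the heuristic framework are hypothetical and not part of this diff.

```py
from pydantic import BaseModel

from macaron.ai.ai_factory import AIClientFactory


class DocstringVerdict(BaseModel):
    """Hypothetical structured output for the docstring-consistency check."""

    matches: bool
    explanation: str


factory = AIClientFactory()
client = factory.create_client(
    system_prompt="Compare a Python function's docstring with its body and answer in JSON."
)
if client is not None:
    verdict = client.invoke(
        "def add(a, b):\n    '''Multiply two numbers.'''\n    return a + b",
        structured_output=DocstringVerdict,
    )
    if verdict is not None and not verdict.matches:
        # A mismatch would map to a failing HeuristicResult for MATCHING_DOCSTRINGS.
        print(verdict.explanation)
```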