diff --git a/backend/onyx/configs/app_configs.py b/backend/onyx/configs/app_configs.py index 92a15bde6fc..61b7530e07a 100644 --- a/backend/onyx/configs/app_configs.py +++ b/backend/onyx/configs/app_configs.py @@ -121,6 +121,8 @@ os.environ.get("OAUTH_CLIENT_SECRET", os.environ.get("GOOGLE_OAUTH_CLIENT_SECRET")) or "" ) +# OpenID Connect configuration URL for Okta Profile Tool and other OIDC integrations +OPENID_CONFIG_URL = os.environ.get("OPENID_CONFIG_URL") or "" USER_AUTH_SECRET = os.environ.get("USER_AUTH_SECRET", "") @@ -617,6 +619,17 @@ def get_current_tz_offset() -> int: MAX_TOKENS_FOR_FULL_INCLUSION = 4096 + +##### +# Tool Configs +##### +OKTA_PROFILE_TOOL_ENABLED = ( + os.environ.get("OKTA_PROFILE_TOOL_ENABLED", "").lower() == "true" +) +# API token for SSWS auth to Okta Admin API. If set, Users API will be used to enrich profile. +OKTA_API_TOKEN = os.environ.get("OKTA_API_TOKEN") or "" + + ##### # Miscellaneous ##### diff --git a/backend/onyx/llm/models.py b/backend/onyx/llm/models.py index d85d89c2a79..5459955987b 100644 --- a/backend/onyx/llm/models.py +++ b/backend/onyx/llm/models.py @@ -72,6 +72,7 @@ def from_langchain_msg( message_type = MessageType.USER elif isinstance(msg, AIMessage): message_type = MessageType.ASSISTANT + message = message_to_string(msg) return cls( message=message, diff --git a/backend/onyx/tools/built_in_tools.py b/backend/onyx/tools/built_in_tools.py index c8da72a9d66..958f3b49008 100644 --- a/backend/onyx/tools/built_in_tools.py +++ b/backend/onyx/tools/built_in_tools.py @@ -6,6 +6,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session +from onyx.configs.app_configs import OKTA_PROFILE_TOOL_ENABLED from onyx.db.models import Persona from onyx.db.models import Tool as ToolDBModel from onyx.tools.tool_implementations.images.image_generation_tool import ( @@ -17,6 +18,9 @@ from onyx.tools.tool_implementations.internet_search.providers import ( get_available_providers, ) +from onyx.tools.tool_implementations.okta_profile.okta_profile_tool import ( + OktaProfileTool, +) from onyx.tools.tool_implementations.search.search_tool import SearchTool from onyx.tools.tool import Tool from onyx.utils.logger import setup_logger @@ -63,6 +67,19 @@ class InCodeToolInfo(TypedDict): if (bool(get_available_providers())) else [] ), + # Show Okta Profile tool if the environment variables are set + *( + [ + InCodeToolInfo( + cls=OktaProfileTool, + description="The Okta Profile Action allows the assistant to fetch user information from Okta.", + in_code_tool_id=OktaProfileTool.__name__, + display_name=OktaProfileTool._DISPLAY_NAME, + ) + ] + if OKTA_PROFILE_TOOL_ENABLED + else [] + ), ] diff --git a/backend/onyx/tools/tool.py b/backend/onyx/tools/tool.py index 2c7f53647f0..65f6c91c2a3 100644 --- a/backend/onyx/tools/tool.py +++ b/backend/onyx/tools/tool.py @@ -5,12 +5,12 @@ from typing import TYPE_CHECKING from typing import TypeVar -from onyx.llm.interfaces import LLM -from onyx.llm.models import PreviousMessage from onyx.utils.special_types import JSON_ro if TYPE_CHECKING: + from onyx.llm.interfaces import LLM + from onyx.llm.models import PreviousMessage from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder from onyx.tools.message import ToolCallSummary from onyx.tools.models import ToolResponse @@ -53,8 +53,8 @@ def build_tool_message_content( def get_args_for_non_tool_calling_llm( self, query: str, - history: list[PreviousMessage], - llm: LLM, + history: list["PreviousMessage"], + llm: "LLM", force_run: bool = False, ) -> dict[str, Any] | None: raise NotImplementedError diff --git a/backend/onyx/tools/tool_constructor.py b/backend/onyx/tools/tool_constructor.py index dd3d080a0e3..8ba0a1c6c24 100644 --- a/backend/onyx/tools/tool_constructor.py +++ b/backend/onyx/tools/tool_constructor.py @@ -14,6 +14,10 @@ from onyx.configs.app_configs import AZURE_DALLE_API_VERSION from onyx.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME from onyx.configs.app_configs import IMAGE_MODEL_NAME +from onyx.configs.app_configs import OAUTH_CLIENT_ID +from onyx.configs.app_configs import OAUTH_CLIENT_SECRET +from onyx.configs.app_configs import OKTA_API_TOKEN +from onyx.configs.app_configs import OPENID_CONFIG_URL from onyx.configs.chat_configs import NUM_INTERNET_SEARCH_CHUNKS from onyx.configs.chat_configs import NUM_INTERNET_SEARCH_RESULTS from onyx.configs.model_configs import GEN_AI_TEMPERATURE @@ -41,6 +45,9 @@ from onyx.tools.tool_implementations.internet_search.internet_search_tool import ( InternetSearchTool, ) +from onyx.tools.tool_implementations.okta_profile.okta_profile_tool import ( + OktaProfileTool, +) from onyx.tools.tool_implementations.search.search_tool import SearchTool from onyx.tools.utils import compute_all_tool_tokens from onyx.tools.utils import explicit_tool_calling_supported @@ -265,6 +272,33 @@ def construct_tools( "Internet search tool requires a Bing or Exa API key, please contact your Onyx admin to get it added!" ) + # Handle Okta Profile Tool + elif tool_cls.__name__ == OktaProfileTool.__name__: + if not user_oauth_token: + raise ValueError( + "Okta Profile Tool requires user OAuth token but none found" + ) + + if not all([OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL]): + raise ValueError( + "Okta Profile Tool requires OAuth configuration to be set" + ) + + if not OKTA_API_TOKEN: + raise ValueError( + "Okta Profile Tool requires OKTA_API_TOKEN to be set" + ) + + tool_dict[db_tool_model.id] = [ + OktaProfileTool( + access_token=user_oauth_token, + client_id=OAUTH_CLIENT_ID, + client_secret=OAUTH_CLIENT_SECRET, + openid_config_url=OPENID_CONFIG_URL, + okta_api_token=OKTA_API_TOKEN, + ) + ] + # Handle custom tools elif db_tool_model.openapi_schema: if not custom_tool_config: diff --git a/backend/onyx/tools/tool_implementations/okta_profile/okta_profile_tool.py b/backend/onyx/tools/tool_implementations/okta_profile/okta_profile_tool.py new file mode 100644 index 00000000000..11e841011ce --- /dev/null +++ b/backend/onyx/tools/tool_implementations/okta_profile/okta_profile_tool.py @@ -0,0 +1,243 @@ +import json +from collections.abc import Generator +from typing import Any +from urllib.parse import urlparse + +import requests +from pydantic import BaseModel + +from onyx.llm.interfaces import LLM +from onyx.llm.models import PreviousMessage +from onyx.llm.utils import message_to_string +from onyx.prompts.constants import GENERAL_SEP_PAT +from onyx.tools.base_tool import BaseTool +from onyx.tools.models import ToolResponse +from onyx.utils.logger import setup_logger +from onyx.utils.special_types import JSON_ro + + +logger = setup_logger() + + +OKTA_PROFILE_RESPONSE_ID = "okta_profile" + +OKTA_TOOL_DESCRIPTION = """ +The Okta profile tool can retrieve user profile information from Okta including: +- User ID, status, creation date +- Profile details like name, email, department, location, title, manager, and more +- Account status and activity +""" + + +class OIDCConfig(BaseModel): + issuer: str + jwks_uri: str | None = None + userinfo_endpoint: str | None = None + introspection_endpoint: str | None = None + token_endpoint: str | None = None + + +class OktaProfileTool(BaseTool): + _NAME = "get_okta_profile" + _DESCRIPTION = "This tool is used to get the user's profile information." + _DISPLAY_NAME = "Okta Profile" + + def __init__( + self, + access_token: str, + client_id: str, + client_secret: str, + openid_config_url: str, + okta_api_token: str, + request_timeout_sec: int = 15, + ) -> None: + self.access_token = access_token + self.client_id = client_id + self.client_secret = client_secret + self.openid_config_url = openid_config_url + self.request_timeout_sec = request_timeout_sec + + # Extract Okta org URL from OpenID config URL using URL parsing + # OpenID config URL format: https://{org}.okta.com/.well-known/openid_configuration + parsed_url = urlparse(self.openid_config_url) + self.okta_org_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + self.okta_api_token = okta_api_token + + self._oidc_config: OIDCConfig | None = None + + @property + def name(self) -> str: + return self._NAME + + @property + def description(self) -> str: + return self._DESCRIPTION + + @property + def display_name(self) -> str: + return self._DISPLAY_NAME + + def tool_definition(self) -> dict: + return { + "type": "function", + "function": { + "name": self.name, + "description": self.description, + "parameters": {"type": "object", "properties": {}, "required": []}, + }, + } + + def _load_oidc_config(self) -> OIDCConfig: + if self._oidc_config is not None: + return self._oidc_config + + resp = requests.get(self.openid_config_url, timeout=self.request_timeout_sec) + resp.raise_for_status() + data = resp.json() + self._oidc_config = OIDCConfig(**data) + logger.debug(f"Loaded OIDC config from {self.openid_config_url}") + return self._oidc_config + + def _call_userinfo(self, access_token: str) -> dict[str, Any] | None: + try: + cfg = self._load_oidc_config() + if not cfg.userinfo_endpoint: + logger.info("OIDC config missing userinfo_endpoint") + return None + headers = {"Authorization": f"Bearer {access_token}"} + r = requests.get( + cfg.userinfo_endpoint, headers=headers, timeout=self.request_timeout_sec + ) + if r.status_code == 200: + return r.json() + logger.info( + f"userinfo call returned status {r.status_code}: {r.text[:200]}" + ) + return None + except requests.RequestException as e: + logger.debug(f"userinfo request failed: {e}") + return None + + def _call_introspection(self, access_token: str) -> dict[str, Any] | None: + try: + cfg = self._load_oidc_config() + if not cfg.introspection_endpoint: + logger.info("OIDC config missing introspection_endpoint") + return None + data = { + "token": access_token, + "token_type_hint": "access_token", + } + auth: tuple[str, str] | None = (self.client_id, self.client_secret) + r = requests.post( + cfg.introspection_endpoint, + data=data, + auth=auth, + headers={"Accept": "application/json"}, + timeout=self.request_timeout_sec, + ) + if r.status_code == 200: + return r.json() + logger.info( + f"introspection call returned status {r.status_code}: {r.text[:200]}" + ) + return None + except requests.RequestException as e: + logger.debug(f"introspection request failed: {e}") + return None + + def _call_users_api(self, uid: str) -> dict[str, Any]: + """Call Okta Users API to fetch full user profile. + + Requires okta_org_url and okta_api_token to be set. Raises exception on any error. + """ + if not self.okta_org_url or not self.okta_api_token: + raise ValueError( + "Okta org URL and API token are required for user profile lookup" + ) + + try: + url = f"{self.okta_org_url.rstrip('/')}/api/v1/users/{uid}" + headers = {"Authorization": f"SSWS {self.okta_api_token}"} + r = requests.get(url, headers=headers, timeout=self.request_timeout_sec) + if r.status_code == 200: + return r.json() + raise ValueError( + f"Okta Users API call failed with status {r.status_code}: {r.text[:200]}" + ) + except requests.RequestException as e: + raise ValueError(f"Okta Users API request failed: {e}") from e + + def build_tool_message_content( + self, *args: ToolResponse + ) -> str | list[str | dict[str, Any]]: + # The tool emits a single aggregated packet; pass it through as compact JSON + profile = args[-1].response if args else {} + return json.dumps(profile) + + def get_args_for_non_tool_calling_llm( + self, + query: str, + history: list[PreviousMessage], + llm: LLM, + force_run: bool = False, + ) -> dict[str, Any] | None: + if force_run: + return {} + + # Use LLM to determine if this tool should be called based on the query + prompt = f""" +You are helping to determine if an Okta profile lookup tool should be called based on a user's query. + +{OKTA_TOOL_DESCRIPTION} + +Query: {query} + +Conversation history: +{GENERAL_SEP_PAT} +{history} +{GENERAL_SEP_PAT} + +Should the Okta profile tool be called for this query? Respond with only "YES" or "NO". +""".strip() + response = llm.invoke(prompt) + if response and "YES" in message_to_string(response).upper(): + return {} + + return None + + def run( + self, override_kwargs: None = None, **llm_kwargs: Any + ) -> Generator[ToolResponse, None, None]: + # Try to get UID from userinfo first, then fallback to introspection + uid_candidate = None + + # Try userinfo endpoint first + userinfo_data = self._call_userinfo(self.access_token) + if userinfo_data and isinstance(userinfo_data, dict): + uid_candidate = userinfo_data.get("uid") + + # Only try introspection if userinfo didn't provide a UID + if not uid_candidate: + introspection_data = self._call_introspection(self.access_token) + if introspection_data and isinstance(introspection_data, dict): + uid_candidate = introspection_data.get("uid") + + if not uid_candidate: + raise ValueError( + "Unable to fetch user profile from Okta. This likely means your Okta " + "token has expired. Please logout, log back in, and try again." + ) + + # Call Users API to get full profile - this is now required + users_api_data = self._call_users_api(uid_candidate) + + yield ToolResponse( + id=OKTA_PROFILE_RESPONSE_ID, response=users_api_data["profile"] + ) + + def final_result(self, *args: ToolResponse) -> JSON_ro: + # Return the single aggregated profile packet + if not args: + return {} + return args[-1].response diff --git a/web/src/app/chat/message/Messages.tsx b/web/src/app/chat/message/Messages.tsx index 4762ef98ae0..e64342ae0ae 100644 --- a/web/src/app/chat/message/Messages.tsx +++ b/web/src/app/chat/message/Messages.tsx @@ -554,29 +554,32 @@ export const AIMessage = ({ )} ) : null)} - {userKnowledgeFiles && ( + {userKnowledgeFiles.length > 0 && ( )} - {!userKnowledgeFiles && + {userKnowledgeFiles.length === 0 && toolCall && !TOOLS_WITH_CUSTOM_HANDLING.includes( toolCall.tool_name ) && ( - - } - isRunning={!toolCall.tool_result || !content} - /> +
+ + } + isRunning={!toolCall.tool_result || !content} + /> +
)} + {toolCall && (!files || files.length == 0) && toolCall.tool_name === IMAGE_GENERATION_TOOL_NAME &&