From 1c13975ec5dc7f6641511f7d26c9b802cd8ed628 Mon Sep 17 00:00:00 2001
From: Wenxi Onyx
Date: Thu, 17 Jul 2025 14:53:31 -0700
Subject: [PATCH 01/78] minor internet search env vars

---
 deployment/docker_compose/docker-compose.gpu-dev.yml | 1 +
 deployment/docker_compose/env.prod.template          | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/deployment/docker_compose/docker-compose.gpu-dev.yml b/deployment/docker_compose/docker-compose.gpu-dev.yml
index 554ad83e3f8..1b0c1078e5d 100644
--- a/deployment/docker_compose/docker-compose.gpu-dev.yml
+++ b/deployment/docker_compose/docker-compose.gpu-dev.yml
@@ -145,6 +145,7 @@ services:
       - GENERATIVE_MODEL_ACCESS_CHECK_FREQ=${GENERATIVE_MODEL_ACCESS_CHECK_FREQ:-}
       - DISABLE_LITELLM_STREAMING=${DISABLE_LITELLM_STREAMING:-}
       - LITELLM_EXTRA_HEADERS=${LITELLM_EXTRA_HEADERS:-}
+      - EXA_API_KEY=${EXA_API_KEY:-}
       # Query Options
       - DOC_TIME_DECAY=${DOC_TIME_DECAY:-} # Recency Bias for search results, decay at 1 / (1 + DOC_TIME_DECAY * x years)
       - HYBRID_ALPHA=${HYBRID_ALPHA:-} # Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector)
diff --git a/deployment/docker_compose/env.prod.template b/deployment/docker_compose/env.prod.template
index cf36379468d..a4e8856dfe1 100644
--- a/deployment/docker_compose/env.prod.template
+++ b/deployment/docker_compose/env.prod.template
@@ -8,7 +8,7 @@ WEB_DOMAIN=http://localhost:3000
 
 
 # NOTE: Generative AI configurations are done via the UI now
-
+EXA_API_KEY=
 
 # The following are for configuring User Authentication, supported flows are:
 # disabled

From 73258f26ea769c18f7993618ec84c0a863dc29bb Mon Sep 17 00:00:00 2001
From: Wenxi Onyx
Date: Thu, 17 Jul 2025 16:20:50 -0700
Subject: [PATCH 02/78] clean up connector page and add new option for uncommon connectors

---
 backend/onyx/configs/app_configs.py           |  3 +
 backend/onyx/configs/constants.py             |  3 -
 backend/onyx/server/settings/models.py        |  3 +
 backend/onyx/server/settings/store.py         |  4 ++
 .../docker_compose/docker-compose.dev.yml     |  1 +
 .../docker_compose/docker-compose.gpu-dev.yml |  1 +
 .../docker-compose.multitenant-dev.yml        |  1 +
 deployment/docker_compose/env.prod.template   |  4 ++
 web/src/app/admin/add-connector/page.tsx      | 57 ++++++++++---------
 web/src/app/admin/settings/interfaces.ts      |  3 +
 web/src/lib/search/interfaces.ts              |  6 +-
 web/src/lib/sources.ts                        | 17 +++---
 12 files changed, 61 insertions(+), 42 deletions(-)

diff --git a/backend/onyx/configs/app_configs.py b/backend/onyx/configs/app_configs.py
index 2c538950186..047e49d7d8b 100644
--- a/backend/onyx/configs/app_configs.py
+++ b/backend/onyx/configs/app_configs.py
@@ -38,6 +38,9 @@
 # Controls whether users can use User Knowledge (personal documents) in assistants
 DISABLE_USER_KNOWLEDGE = os.environ.get("DISABLE_USER_KNOWLEDGE", "").lower() == "true"
 
+# If set to true, will show extra/uncommon connectors in the "Other" category
+SHOW_EXTRA_CONNECTORS = os.environ.get("SHOW_EXTRA_CONNECTORS", "").lower() == "true"
+
 # Controls whether to allow admin query history reports with:
 # 1. associated user emails
 # 2. anonymized user emails
diff --git a/backend/onyx/configs/constants.py b/backend/onyx/configs/constants.py
index 18bfb61496e..6f1a8f8157c 100644
--- a/backend/onyx/configs/constants.py
+++ b/backend/onyx/configs/constants.py
@@ -216,9 +216,6 @@ class BlobType(str, Enum):
     GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
     OCI_STORAGE = "oci_storage"
 
-    # Special case, for internet search
-    NOT_APPLICABLE = "not_applicable"
-
 
 class DocumentIndexType(str, Enum):
     COMBINED = "combined"  # Vespa
diff --git a/backend/onyx/server/settings/models.py b/backend/onyx/server/settings/models.py
index 9368ed91e50..90a1f7a7143 100644
--- a/backend/onyx/server/settings/models.py
+++ b/backend/onyx/server/settings/models.py
@@ -62,6 +62,9 @@ class Settings(BaseModel):
     # User Knowledge settings
     user_knowledge_enabled: bool | None = True
 
+    # Connector settings
+    show_extra_connectors: bool | None = True
+
 
 class UserSettings(Settings):
     notifications: list[Notification]
diff --git a/backend/onyx/server/settings/store.py b/backend/onyx/server/settings/store.py
index a1dc319ed35..6e32e22c16c 100644
--- a/backend/onyx/server/settings/store.py
+++ b/backend/onyx/server/settings/store.py
@@ -1,5 +1,6 @@
 from onyx.configs.app_configs import DISABLE_USER_KNOWLEDGE
 from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
+from onyx.configs.app_configs import SHOW_EXTRA_CONNECTORS
 from onyx.configs.constants import KV_SETTINGS_KEY
 from onyx.configs.constants import OnyxRedisLocks
 from onyx.key_value_store.factory import get_kv_store
@@ -53,6 +54,9 @@ def load_settings() -> Settings:
     if DISABLE_USER_KNOWLEDGE:
         settings.user_knowledge_enabled = False
 
+    # Override show extra connectors setting based on environment variable
+    settings.show_extra_connectors = SHOW_EXTRA_CONNECTORS
+
     return settings
 
 
diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml
index 643b0535a20..908479af8da 100644
--- a/deployment/docker_compose/docker-compose.dev.yml
+++ b/deployment/docker_compose/docker-compose.dev.yml
@@ -233,6 +233,7 @@ services:
       - GITHUB_CONNECTOR_BASE_URL=${GITHUB_CONNECTOR_BASE_URL:-}
       - MAX_DOCUMENT_CHARS=${MAX_DOCUMENT_CHARS:-}
       - MAX_FILE_SIZE_BYTES=${MAX_FILE_SIZE_BYTES:-}
+      - SHOW_EXTRA_CONNECTORS=${SHOW_EXTRA_CONNECTORS:-}
       # Egnyte OAuth Configs
       - EGNYTE_CLIENT_ID=${EGNYTE_CLIENT_ID:-}
       - EGNYTE_CLIENT_SECRET=${EGNYTE_CLIENT_SECRET:-}
diff --git a/deployment/docker_compose/docker-compose.gpu-dev.yml b/deployment/docker_compose/docker-compose.gpu-dev.yml
index 1b0c1078e5d..46f6b8de001 100644
--- a/deployment/docker_compose/docker-compose.gpu-dev.yml
+++ b/deployment/docker_compose/docker-compose.gpu-dev.yml
@@ -192,6 +192,7 @@ services:
       - GONG_CONNECTOR_START_TIME=${GONG_CONNECTOR_START_TIME:-}
       - NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP=${NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP:-}
       - GITHUB_CONNECTOR_BASE_URL=${GITHUB_CONNECTOR_BASE_URL:-}
+      - SHOW_EXTRA_CONNECTORS=${SHOW_EXTRA_CONNECTORS:-}
       # Onyx SlackBot Configs
      - DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER=${DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER:-}
       - DANSWER_BOT_FEEDBACK_VISIBILITY=${DANSWER_BOT_FEEDBACK_VISIBILITY:-}
diff --git a/deployment/docker_compose/docker-compose.multitenant-dev.yml b/deployment/docker_compose/docker-compose.multitenant-dev.yml
index 98db1c5da6f..6dd60186e59 100644
--- a/deployment/docker_compose/docker-compose.multitenant-dev.yml
+++ b/deployment/docker_compose/docker-compose.multitenant-dev.yml
@@ -214,6 +214,7 @@ services:
       - GITHUB_CONNECTOR_BASE_URL=${GITHUB_CONNECTOR_BASE_URL:-}
       -
MAX_DOCUMENT_CHARS=${MAX_DOCUMENT_CHARS:-} - MAX_FILE_SIZE_BYTES=${MAX_FILE_SIZE_BYTES:-} + - SHOW_EXTRA_CONNECTORS=${SHOW_EXTRA_CONNECTORS:-} # Egnyte OAuth Configs - EGNYTE_CLIENT_ID=${EGNYTE_CLIENT_ID:-} - EGNYTE_CLIENT_SECRET=${EGNYTE_CLIENT_SECRET:-} diff --git a/deployment/docker_compose/env.prod.template b/deployment/docker_compose/env.prod.template index a4e8856dfe1..f8800173aec 100644 --- a/deployment/docker_compose/env.prod.template +++ b/deployment/docker_compose/env.prod.template @@ -65,3 +65,7 @@ DB_READONLY_PASSWORD=password # If setting the vespa language is required, set this ('en', 'de', etc.). # See: https://docs.vespa.ai/en/linguistics.html #VESPA_LANGUAGE_OVERRIDE= + +# Uncommon connectors supported by the community +# See https://docs.onyx.app for list of these connectors +SHOW_EXTRA_CONNECTORS=False \ No newline at end of file diff --git a/web/src/app/admin/add-connector/page.tsx b/web/src/app/admin/add-connector/page.tsx index 91dccea0bab..1c75372b273 100644 --- a/web/src/app/admin/add-connector/page.tsx +++ b/web/src/app/admin/add-connector/page.tsx @@ -7,7 +7,14 @@ import { listSourceMetadata } from "@/lib/sources"; import Title from "@/components/ui/title"; import { Button } from "@/components/ui/button"; import Link from "next/link"; -import { useCallback, useEffect, useMemo, useRef, useState } from "react"; +import { + useCallback, + useContext, + useEffect, + useMemo, + useRef, + useState, +} from "react"; import { Tooltip, TooltipContent, @@ -24,6 +31,7 @@ import useSWR from "swr"; import { errorHandlingFetcher } from "@/lib/fetcher"; import { buildSimilarCredentialInfoURL } from "@/app/admin/connector/[ccPairId]/lib"; import { Credential } from "@/lib/connectors/credentials"; +import { SettingsContext } from "@/components/settings/SettingsProvider"; function SourceTile({ sourceMetadata, @@ -148,6 +156,7 @@ export default function Page() { const sources = useMemo(() => listSourceMetadata(), []); const [searchTerm, setSearchTerm] = useState(""); const { data: federatedConnectors } = useFederatedConnectors(); + const settings = useContext(SettingsContext); // Fetch Slack credentials to determine navigation behavior const { data: slackCredentials } = useSWR[]>( @@ -177,7 +186,7 @@ export default function Page() { const categorizedSources = useMemo(() => { const filtered = filterSources(sources); - return Object.values(SourceCategory).reduce( + const categories = Object.values(SourceCategory).reduce( (acc, category) => { acc[category] = sources.filter( (source) => @@ -189,7 +198,25 @@ export default function Page() { }, {} as Record ); - }, [sources, filterSources, searchTerm]); + + // Filter out the "Other" category if show_extra_connectors is false + if (settings?.settings?.show_extra_connectors === false) { + const filteredCategories = Object.entries(categories).filter( + ([category]) => category !== SourceCategory.Other + ); + return Object.fromEntries(filteredCategories) as Record< + SourceCategory, + SourceMetadata[] + >; + } + + return categories; + }, [ + sources, + filterSources, + searchTerm, + settings?.settings?.show_extra_connectors, + ]); const handleKeyPress = (e: React.KeyboardEvent) => { if (e.key === "Enter") { @@ -251,7 +278,6 @@ export default function Page() {
{category}
-

{getCategoryDescription(category as SourceCategory)}

{sources.map((source, sourceInd) => ( ); } - -function getCategoryDescription(category: SourceCategory): string { - switch (category) { - case SourceCategory.Messaging: - return "Integrate with messaging and communication platforms."; - case SourceCategory.ProjectManagement: - return "Link to project management and task tracking tools."; - case SourceCategory.CustomerSupport: - return "Connect to customer support and helpdesk systems."; - case SourceCategory.CustomerRelationshipManagement: - return "Integrate with customer relationship management platforms."; - case SourceCategory.CodeRepository: - return "Integrate with code repositories and version control systems."; - case SourceCategory.Storage: - return "Connect to cloud storage and file hosting services."; - case SourceCategory.Wiki: - return "Link to wiki and knowledge base platforms."; - case SourceCategory.Other: - return "Connect to other miscellaneous knowledge sources."; - default: - return "Connect to various knowledge sources."; - } -} diff --git a/web/src/app/admin/settings/interfaces.ts b/web/src/app/admin/settings/interfaces.ts index 40dc588340c..6adfd2cdc79 100644 --- a/web/src/app/admin/settings/interfaces.ts +++ b/web/src/app/admin/settings/interfaces.ts @@ -30,6 +30,9 @@ export interface Settings { // User Knowledge settings user_knowledge_enabled?: boolean; + + // Connector settings + show_extra_connectors?: boolean; } export enum NotificationType { diff --git a/web/src/lib/search/interfaces.ts b/web/src/lib/search/interfaces.ts index 1fde64f8d9e..8f183e0ef4a 100644 --- a/web/src/lib/search/interfaces.ts +++ b/web/src/lib/search/interfaces.ts @@ -162,10 +162,10 @@ export interface SearchResponse { } export enum SourceCategory { - Storage = "Storage", - Wiki = "Wiki", + Storage = "Web Crawl & File Storage", + Wiki = "Knowledge Base & Wiki", CustomerSupport = "Customer Support", - CustomerRelationshipManagement = "Customer Relationship Management", + SalesAndMarketing = "Sales & Marketing", Messaging = "Messaging", ProjectManagement = "Project Management", CodeRepository = "Code Repository", diff --git a/web/src/lib/sources.ts b/web/src/lib/sources.ts index 61c622446bb..59072ec77d2 100644 --- a/web/src/lib/sources.ts +++ b/web/src/lib/sources.ts @@ -84,7 +84,7 @@ export const SOURCE_METADATA_MAP: SourceMap = { web: { icon: GlobeIcon2, displayName: "Web", - category: SourceCategory.Other, + category: SourceCategory.Storage, docs: "https://docs.onyx.app/connectors/web", }, file: { @@ -154,7 +154,7 @@ export const SOURCE_METADATA_MAP: SourceMap = { gong: { icon: GongIcon, displayName: "Gong", - category: SourceCategory.Other, + category: SourceCategory.SalesAndMarketing, docs: "https://docs.onyx.app/connectors/gong", }, linear: { @@ -190,7 +190,7 @@ export const SOURCE_METADATA_MAP: SourceMap = { hubspot: { icon: HubSpotIcon, displayName: "HubSpot", - category: SourceCategory.CustomerRelationshipManagement, + category: SourceCategory.SalesAndMarketing, docs: "https://docs.onyx.app/connectors/hubspot", }, document360: { @@ -214,7 +214,7 @@ export const SOURCE_METADATA_MAP: SourceMap = { loopio: { icon: LoopioIcon, displayName: "Loopio", - category: SourceCategory.Other, + category: SourceCategory.SalesAndMarketing, }, dropbox: { icon: DropboxIcon, @@ -225,7 +225,7 @@ export const SOURCE_METADATA_MAP: SourceMap = { salesforce: { icon: SalesforceIcon, displayName: "Salesforce", - category: SourceCategory.CustomerRelationshipManagement, + category: SourceCategory.SalesAndMarketing, docs: 
"https://docs.onyx.app/connectors/salesforce", }, sharepoint: { @@ -319,7 +319,7 @@ export const SOURCE_METADATA_MAP: SourceMap = { fireflies: { icon: FirefliesIcon, displayName: "Fireflies", - category: SourceCategory.Other, + category: SourceCategory.SalesAndMarketing, docs: "https://docs.onyx.app/connectors/fireflies", }, egnyte: { @@ -331,7 +331,7 @@ export const SOURCE_METADATA_MAP: SourceMap = { airtable: { icon: AirtableIcon, displayName: "Airtable", - category: SourceCategory.Other, + category: SourceCategory.ProjectManagement, docs: "https://docs.onyx.app/connectors/airtable", }, gitbook: { @@ -351,8 +351,7 @@ export const SOURCE_METADATA_MAP: SourceMap = { displayName: "Email", category: SourceCategory.Messaging, }, - // currently used for the Internet Search tool docs, which is why - // a globe is used + // Placeholder used as non-null default not_applicable: { icon: GlobeIcon, displayName: "Not Applicable", From 10900c517b87b283746e3f0e6df36f15aaf95f56 Mon Sep 17 00:00:00 2001 From: Wenxi Onyx Date: Thu, 17 Jul 2025 16:21:17 -0700 Subject: [PATCH 03/78] vscode env template --- .vscode/env_template.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.vscode/env_template.txt b/.vscode/env_template.txt index dd5417bbe4d..8e641f8b5cd 100644 --- a/.vscode/env_template.txt +++ b/.vscode/env_template.txt @@ -65,3 +65,6 @@ S3_ENDPOINT_URL=http://localhost:9004 S3_FILE_STORE_BUCKET_NAME=onyx-file-store-bucket S3_AWS_ACCESS_KEY_ID=minioadmin S3_AWS_SECRET_ACCESS_KEY=minioadmin + +# Show extra/uncommon connectors +SHOW_EXTRA_CONNECTORS=False From 8cde39a01bd0e332208ce516173bc2a8afbc1f19 Mon Sep 17 00:00:00 2001 From: Wenxi Onyx Date: Thu, 17 Jul 2025 16:56:52 -0700 Subject: [PATCH 04/78] deployment fix and change default to false --- backend/onyx/server/settings/models.py | 2 +- deployment/docker_compose/docker-compose.dev.yml | 4 +++- deployment/docker_compose/docker-compose.gpu-dev.yml | 5 ++++- deployment/docker_compose/docker-compose.multitenant-dev.yml | 5 ++++- 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/backend/onyx/server/settings/models.py b/backend/onyx/server/settings/models.py index 90a1f7a7143..450e321afc5 100644 --- a/backend/onyx/server/settings/models.py +++ b/backend/onyx/server/settings/models.py @@ -63,7 +63,7 @@ class Settings(BaseModel): user_knowledge_enabled: bool | None = True # Connector settings - show_extra_connectors: bool | None = True + show_extra_connectors: bool | None = False class UserSettings(Settings): diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml index 908479af8da..88c131984a8 100644 --- a/deployment/docker_compose/docker-compose.dev.yml +++ b/deployment/docker_compose/docker-compose.dev.yml @@ -119,6 +119,9 @@ services: # Chat Configs - HARD_DELETE_CHATS=${HARD_DELETE_CHATS:-} + # Enables extra/community-supported connectors + - SHOW_EXTRA_CONNECTORS=${SHOW_EXTRA_CONNECTORS:-} + # Enables the use of bedrock models or IAM Auth - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID:-} - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY:-} @@ -233,7 +236,6 @@ services: - GITHUB_CONNECTOR_BASE_URL=${GITHUB_CONNECTOR_BASE_URL:-} - MAX_DOCUMENT_CHARS=${MAX_DOCUMENT_CHARS:-} - MAX_FILE_SIZE_BYTES=${MAX_FILE_SIZE_BYTES:-} - - SHOW_EXTRA_CONNECTORS=${SHOW_EXTRA_CONNECTORS:-} # Egnyte OAuth Configs - EGNYTE_CLIENT_ID=${EGNYTE_CLIENT_ID:-} - EGNYTE_CLIENT_SECRET=${EGNYTE_CLIENT_SECRET:-} diff --git a/deployment/docker_compose/docker-compose.gpu-dev.yml 
b/deployment/docker_compose/docker-compose.gpu-dev.yml index 46f6b8de001..a2b76517370 100644 --- a/deployment/docker_compose/docker-compose.gpu-dev.yml +++ b/deployment/docker_compose/docker-compose.gpu-dev.yml @@ -42,6 +42,7 @@ services: - DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-} - DISABLE_LITELLM_STREAMING=${DISABLE_LITELLM_STREAMING:-} - LITELLM_EXTRA_HEADERS=${LITELLM_EXTRA_HEADERS:-} + - EXA_API_KEY=${EXA_API_KEY:-} # if set, allows for the use of the token budget system - TOKEN_BUDGET_GLOBALLY_ENABLED=${TOKEN_BUDGET_GLOBALLY_ENABLED:-} @@ -97,6 +98,9 @@ services: # Chat Configs - HARD_DELETE_CHATS=${HARD_DELETE_CHATS:-} + # Enables extra/community-supported connectors + - SHOW_EXTRA_CONNECTORS=${SHOW_EXTRA_CONNECTORS:-} + # Vespa Language Forcing # See: https://docs.vespa.ai/en/linguistics.html - VESPA_LANGUAGE_OVERRIDE=${VESPA_LANGUAGE_OVERRIDE:-} @@ -192,7 +196,6 @@ services: - GONG_CONNECTOR_START_TIME=${GONG_CONNECTOR_START_TIME:-} - NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP=${NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP:-} - GITHUB_CONNECTOR_BASE_URL=${GITHUB_CONNECTOR_BASE_URL:-} - - SHOW_EXTRA_CONNECTORS=${SHOW_EXTRA_CONNECTORS:-} # Onyx SlackBot Configs - DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER=${DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER:-} - DANSWER_BOT_FEEDBACK_VISIBILITY=${DANSWER_BOT_FEEDBACK_VISIBILITY:-} diff --git a/deployment/docker_compose/docker-compose.multitenant-dev.yml b/deployment/docker_compose/docker-compose.multitenant-dev.yml index 6dd60186e59..a3b7f4ee4c0 100644 --- a/deployment/docker_compose/docker-compose.multitenant-dev.yml +++ b/deployment/docker_compose/docker-compose.multitenant-dev.yml @@ -116,6 +116,10 @@ services: # Vespa Language Forcing # See: https://docs.vespa.ai/en/linguistics.html - VESPA_LANGUAGE_OVERRIDE=${VESPA_LANGUAGE_OVERRIDE:-} + + # Enables extra/community-supported connectors + - SHOW_EXTRA_CONNECTORS=${SHOW_EXTRA_CONNECTORS:-} + extra_hosts: - "host.docker.internal:host-gateway" logging: @@ -214,7 +218,6 @@ services: - GITHUB_CONNECTOR_BASE_URL=${GITHUB_CONNECTOR_BASE_URL:-} - MAX_DOCUMENT_CHARS=${MAX_DOCUMENT_CHARS:-} - MAX_FILE_SIZE_BYTES=${MAX_FILE_SIZE_BYTES:-} - - SHOW_EXTRA_CONNECTORS=${SHOW_EXTRA_CONNECTORS:-} # Egnyte OAuth Configs - EGNYTE_CLIENT_ID=${EGNYTE_CLIENT_ID:-} - EGNYTE_CLIENT_SECRET=${EGNYTE_CLIENT_SECRET:-} From 9b92c2f353e3458635d88e12c44328df36652f89 Mon Sep 17 00:00:00 2001 From: Wenxi Onyx Date: Thu, 17 Jul 2025 16:58:46 -0700 Subject: [PATCH 05/78] greptile nit --- web/src/app/admin/add-connector/page.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/src/app/admin/add-connector/page.tsx b/web/src/app/admin/add-connector/page.tsx index 1c75372b273..7abb8b29702 100644 --- a/web/src/app/admin/add-connector/page.tsx +++ b/web/src/app/admin/add-connector/page.tsx @@ -200,7 +200,7 @@ export default function Page() { ); // Filter out the "Other" category if show_extra_connectors is false - if (settings?.settings?.show_extra_connectors === false) { + if (settings?.settings?.show_extra_connectors !== true) { const filteredCategories = Object.entries(categories).filter( ([category]) => category !== SourceCategory.Other ); From c1b706e6028d9bfa193575a7bcd6b22b8ae5351b Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Thu, 17 Jul 2025 15:41:31 -0700 Subject: [PATCH 06/78] fix: Move around group-sync tests (since they require docker services to be running) (#5041) * Move around tests * Add missing fixtures + change directory structure up some more * Add env variables --- 
.../pr-external-dependency-unit-tests.yml | 8 ++ .../workflows/pr-python-connector-tests.yml | 2 +- backend/tests/daily/conftest.py | 10 --- .../connectors/confluence/conftest.py | 0 .../confluence/test_confluence_group_sync.py | 84 +++++++++---------- .../test_google_drive_group_sync.py} | 0 6 files changed, 51 insertions(+), 53 deletions(-) rename backend/tests/{daily => external_dependency_unit}/connectors/confluence/conftest.py (100%) rename backend/tests/{daily => external_dependency_unit}/connectors/confluence/test_confluence_group_sync.py (63%) rename backend/tests/external_dependency_unit/{external_group_sync/test_external_group_sync.py => connectors/google_drive/test_google_drive_group_sync.py} (100%) diff --git a/.github/workflows/pr-external-dependency-unit-tests.yml b/.github/workflows/pr-external-dependency-unit-tests.yml index dd63e2e5a1a..a37f99ec307 100644 --- a/.github/workflows/pr-external-dependency-unit-tests.yml +++ b/.github/workflows/pr-external-dependency-unit-tests.yml @@ -13,6 +13,14 @@ env: # MinIO S3_ENDPOINT_URL: "http://localhost:9004" + # Confluence + CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }} + CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }} + CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }} + CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }} + CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }} + CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }} + jobs: discover-test-dirs: runs-on: ubuntu-latest diff --git a/.github/workflows/pr-python-connector-tests.yml b/.github/workflows/pr-python-connector-tests.yml index aea2f89c473..c4c70ef3914 100644 --- a/.github/workflows/pr-python-connector-tests.yml +++ b/.github/workflows/pr-python-connector-tests.yml @@ -16,8 +16,8 @@ env: # Confluence CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }} CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }} - CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }} CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }} + CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }} CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }} CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }} diff --git a/backend/tests/daily/conftest.py b/backend/tests/daily/conftest.py index 000d3b53a7f..4002b6c1180 100644 --- a/backend/tests/daily/conftest.py +++ b/backend/tests/daily/conftest.py @@ -6,7 +6,6 @@ from fastapi import FastAPI from fastapi.testclient import TestClient -from onyx.db.engine.sql_engine import SqlEngine from onyx.main import fetch_versioned_implementation from onyx.utils.logger import setup_logger @@ -24,12 +23,3 @@ def client() -> Generator[TestClient, Any, None]: )() client = TestClient(app) yield client - - -@pytest.fixture(scope="session", autouse=True) -def initialize_db() -> None: - # Make sure that the db engine is initialized before any tests are run - SqlEngine.init_engine( - pool_size=10, - max_overflow=5, - ) diff --git a/backend/tests/daily/connectors/confluence/conftest.py b/backend/tests/external_dependency_unit/connectors/confluence/conftest.py similarity index 100% rename from backend/tests/daily/connectors/confluence/conftest.py rename to backend/tests/external_dependency_unit/connectors/confluence/conftest.py diff --git a/backend/tests/daily/connectors/confluence/test_confluence_group_sync.py b/backend/tests/external_dependency_unit/connectors/confluence/test_confluence_group_sync.py similarity index 63% rename from 
backend/tests/daily/connectors/confluence/test_confluence_group_sync.py rename to backend/tests/external_dependency_unit/connectors/confluence/test_confluence_group_sync.py index 32ba600aad7..8017fba545b 100644 --- a/backend/tests/daily/connectors/confluence/test_confluence_group_sync.py +++ b/backend/tests/external_dependency_unit/connectors/confluence/test_confluence_group_sync.py @@ -1,9 +1,10 @@ from typing import Any +from sqlalchemy.orm import Session + from ee.onyx.external_permissions.confluence.group_sync import confluence_group_sync from onyx.configs.constants import DocumentSource from onyx.connectors.models import InputType -from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.enums import AccessType from onyx.db.enums import ConnectorCredentialPairStatus from onyx.db.models import Connector @@ -83,51 +84,50 @@ def test_confluence_group_sync( - initialize_db: None, + db_session: Session, confluence_connector_config: dict[str, Any], confluence_credential_json: dict[str, Any], ) -> None: - with get_session_with_current_tenant() as db_session: - connector = Connector( - name="Test Connector", - source=DocumentSource.CONFLUENCE, - input_type=InputType.POLL, - connector_specific_config=confluence_connector_config, - refresh_freq=None, - prune_freq=None, - indexing_start=None, - ) - db_session.add(connector) - db_session.flush() + connector = Connector( + name="Test Connector", + source=DocumentSource.CONFLUENCE, + input_type=InputType.POLL, + connector_specific_config=confluence_connector_config, + refresh_freq=None, + prune_freq=None, + indexing_start=None, + ) + db_session.add(connector) + db_session.flush() - credential = Credential( - source=DocumentSource.CONFLUENCE, - credential_json=confluence_credential_json, - ) - db_session.add(credential) - db_session.flush() + credential = Credential( + source=DocumentSource.CONFLUENCE, + credential_json=confluence_credential_json, + ) + db_session.add(credential) + db_session.flush() - cc_pair = ConnectorCredentialPair( - connector_id=connector.id, - credential_id=credential.id, - name="Test CC Pair", - status=ConnectorCredentialPairStatus.ACTIVE, - access_type=AccessType.SYNC, - auto_sync_options=None, - ) - db_session.add(cc_pair) - db_session.commit() - db_session.refresh(cc_pair) + cc_pair = ConnectorCredentialPair( + connector_id=connector.id, + credential_id=credential.id, + name="Test CC Pair", + status=ConnectorCredentialPairStatus.ACTIVE, + access_type=AccessType.SYNC, + auto_sync_options=None, + ) + db_session.add(cc_pair) + db_session.commit() + db_session.refresh(cc_pair) - tenant_id = get_current_tenant_id() - group_sync_iter = confluence_group_sync( - tenant_id=tenant_id, - cc_pair=cc_pair, - ) + tenant_id = get_current_tenant_id() + group_sync_iter = confluence_group_sync( + tenant_id=tenant_id, + cc_pair=cc_pair, + ) - expected_groups = {group.id: group for group in _EXPECTED_CONFLUENCE_GROUPS} - actual_groups = { - group.id: ExternalUserGroupSet.from_model(external_user_group=group) - for group in group_sync_iter - } - assert expected_groups == actual_groups + expected_groups = {group.id: group for group in _EXPECTED_CONFLUENCE_GROUPS} + actual_groups = { + group.id: ExternalUserGroupSet.from_model(external_user_group=group) + for group in group_sync_iter + } + assert expected_groups == actual_groups diff --git a/backend/tests/external_dependency_unit/external_group_sync/test_external_group_sync.py 
b/backend/tests/external_dependency_unit/connectors/google_drive/test_google_drive_group_sync.py similarity index 100% rename from backend/tests/external_dependency_unit/external_group_sync/test_external_group_sync.py rename to backend/tests/external_dependency_unit/connectors/google_drive/test_google_drive_group_sync.py From 8edcb69ad234a5e0b9c1c368ef487660f52b1c46 Mon Sep 17 00:00:00 2001 From: Wenxi Date: Thu, 17 Jul 2025 16:23:46 -0700 Subject: [PATCH 07/78] remove chat session necessity from send message simple api (#5040) --- .../server/query_and_chat/chat_backend.py | 37 ++++++++++++++++--- .../ee/onyx/server/query_and_chat/models.py | 15 +++++++- 2 files changed, 45 insertions(+), 7 deletions(-) diff --git a/backend/ee/onyx/server/query_and_chat/chat_backend.py b/backend/ee/onyx/server/query_and_chat/chat_backend.py index 915564d69c3..2e30cf0be37 100644 --- a/backend/ee/onyx/server/query_and_chat/chat_backend.py +++ b/backend/ee/onyx/server/query_and_chat/chat_backend.py @@ -1,5 +1,6 @@ import re from typing import cast +from uuid import UUID from fastapi import APIRouter from fastapi import Depends @@ -73,6 +74,7 @@ def _get_final_context_doc_indices( def _convert_packet_stream_to_response( packets: ChatPacketStream, + chat_session_id: UUID, ) -> ChatBasicResponse: response = ChatBasicResponse() final_context_docs: list[LlmDoc] = [] @@ -216,6 +218,8 @@ def _convert_packet_stream_to_response( if answer: response.answer_citationless = remove_answer_citations(answer) + response.chat_session_id = chat_session_id + return response @@ -237,13 +241,36 @@ def handle_simplified_chat_message( if not chat_message_req.message: raise HTTPException(status_code=400, detail="Empty chat message is invalid") + # Handle chat session creation if chat_session_id is not provided + if chat_message_req.chat_session_id is None: + if chat_message_req.persona_id is None: + raise HTTPException( + status_code=400, + detail="Either chat_session_id or persona_id must be provided", + ) + + # Create a new chat session with the provided persona_id + try: + new_chat_session = create_chat_session( + db_session=db_session, + description="", # Leave empty for simple API + user_id=user.id if user else None, + persona_id=chat_message_req.persona_id, + ) + chat_session_id = new_chat_session.id + except Exception as e: + logger.exception(e) + raise HTTPException(status_code=400, detail="Invalid Persona provided.") + else: + chat_session_id = chat_message_req.chat_session_id + try: parent_message, _ = create_chat_chain( - chat_session_id=chat_message_req.chat_session_id, db_session=db_session + chat_session_id=chat_session_id, db_session=db_session ) except Exception: parent_message = get_or_create_root_message( - chat_session_id=chat_message_req.chat_session_id, db_session=db_session + chat_session_id=chat_session_id, db_session=db_session ) if ( @@ -258,7 +285,7 @@ def handle_simplified_chat_message( retrieval_options = chat_message_req.retrieval_options full_chat_msg_info = CreateChatMessageRequest( - chat_session_id=chat_message_req.chat_session_id, + chat_session_id=chat_session_id, parent_message_id=parent_message.id, message=chat_message_req.message, file_descriptors=[], @@ -283,7 +310,7 @@ def handle_simplified_chat_message( enforce_chat_session_id_for_search_docs=False, ) - return _convert_packet_stream_to_response(packets) + return _convert_packet_stream_to_response(packets, chat_session_id) @router.post("/send-message-simple-with-history") @@ -403,4 +430,4 @@ def handle_send_message_simple_with_history( 
enforce_chat_session_id_for_search_docs=False, ) - return _convert_packet_stream_to_response(packets) + return _convert_packet_stream_to_response(packets, chat_session.id) diff --git a/backend/ee/onyx/server/query_and_chat/models.py b/backend/ee/onyx/server/query_and_chat/models.py index d674e9ecb51..9a97c729f35 100644 --- a/backend/ee/onyx/server/query_and_chat/models.py +++ b/backend/ee/onyx/server/query_and_chat/models.py @@ -41,11 +41,13 @@ class DocumentSearchRequest(ChunkContext): class BasicCreateChatMessageRequest(ChunkContext): - """Before creating messages, be sure to create a chat_session and get an id + """If a chat_session_id is not provided, a persona_id must be provided to automatically create a new chat session Note, for simplicity this option only allows for a single linear chain of messages """ - chat_session_id: UUID + chat_session_id: UUID | None = None + # Optional persona_id to create a new chat session if chat_session_id is not provided + persona_id: int | None = None # New message contents message: str # Defaults to using retrieval with no additional filters @@ -62,6 +64,12 @@ class BasicCreateChatMessageRequest(ChunkContext): # If True, uses agentic search instead of basic search use_agentic_search: bool = False + @model_validator(mode="after") + def validate_chat_session_or_persona(self) -> "BasicCreateChatMessageRequest": + if self.chat_session_id is None and self.persona_id is None: + raise ValueError("Either chat_session_id or persona_id must be provided") + return self + class BasicCreateChatMessageWithHistoryRequest(ChunkContext): # Last element is the new query. All previous elements are historical context @@ -171,6 +179,9 @@ class ChatBasicResponse(BaseModel): agent_sub_queries: dict[int, dict[int, list[AgentSubQuery]]] | None = None agent_refined_answer_improvement: bool | None = None + # Chat session ID for tracking conversation continuity + chat_session_id: UUID | None = None + class OneShotQARequest(ChunkContext): # Supports simplier APIs that don't deal with chat histories or message edits From 28d5a63a1c84115f612b327cba5bf5c520b5b1ce Mon Sep 17 00:00:00 2001 From: Chris Weaver Date: Thu, 17 Jul 2025 23:51:39 -0700 Subject: [PATCH 08/78] Improve support for non-default postgres schemas (#5046) --- backend/alembic/env.py | 4 ++-- .../versions/36e9220ab794_update_kg_trigger_functions.py | 6 +++--- .../versions/495cb26ce93e_create_knowlege_graph_tables.py | 8 ++++---- backend/onyx/db/engine/async_sql_engine.py | 3 ++- backend/onyx/db/engine/sql_engine.py | 6 +++--- backend/onyx/kg/clustering/clustering.py | 4 ++-- backend/onyx/kg/clustering/normalizations.py | 4 ++-- backend/shared_configs/configs.py | 2 ++ 8 files changed, 20 insertions(+), 17 deletions(-) diff --git a/backend/alembic/env.py b/backend/alembic/env.py index d33d5c37a37..24ca51a89a5 100644 --- a/backend/alembic/env.py +++ b/backend/alembic/env.py @@ -23,7 +23,7 @@ from onyx.configs.constants import SSL_CERT_FILE from shared_configs.configs import ( MULTI_TENANT, - POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE, + POSTGRES_DEFAULT_SCHEMA, TENANT_ID_PREFIX, ) from onyx.db.models import Base @@ -271,7 +271,7 @@ async def run_async_migrations() -> None: ) = get_schema_options() if not schemas and not MULTI_TENANT: - schemas = [POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE] + schemas = [POSTGRES_DEFAULT_SCHEMA] # without init_engine, subsequent engine calls fail hard intentionally SqlEngine.init_engine(pool_size=20, max_overflow=5) diff --git 
a/backend/alembic/versions/36e9220ab794_update_kg_trigger_functions.py b/backend/alembic/versions/36e9220ab794_update_kg_trigger_functions.py index bde421fcb73..7c9e25fb179 100644 --- a/backend/alembic/versions/36e9220ab794_update_kg_trigger_functions.py +++ b/backend/alembic/versions/36e9220ab794_update_kg_trigger_functions.py @@ -9,7 +9,7 @@ from alembic import op from sqlalchemy.orm import Session from sqlalchemy import text -from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE +from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA # revision identifiers, used by Alembic. revision = "36e9220ab794" @@ -66,7 +66,7 @@ def upgrade() -> None: -- Set name and name trigrams NEW.name = name; - NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name); + NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name); RETURN NEW; END; $$ LANGUAGE plpgsql; @@ -111,7 +111,7 @@ def upgrade() -> None: UPDATE "{tenant_id}".kg_entity SET name = doc_name, - name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name) + name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name) WHERE document_id = NEW.id; RETURN NEW; END; diff --git a/backend/alembic/versions/495cb26ce93e_create_knowlege_graph_tables.py b/backend/alembic/versions/495cb26ce93e_create_knowlege_graph_tables.py index 65cf759d6f3..f1cbf003359 100644 --- a/backend/alembic/versions/495cb26ce93e_create_knowlege_graph_tables.py +++ b/backend/alembic/versions/495cb26ce93e_create_knowlege_graph_tables.py @@ -15,7 +15,7 @@ from onyx.configs.app_configs import DB_READONLY_USER from onyx.configs.app_configs import DB_READONLY_PASSWORD from shared_configs.configs import MULTI_TENANT -from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE +from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA # revision identifiers, used by Alembic. 
@@ -478,7 +478,7 @@ def upgrade() -> None: # Create GIN index for clustering and normalization op.execute( "CREATE INDEX IF NOT EXISTS idx_kg_entity_clustering_trigrams " - f"ON kg_entity USING GIN (name {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.gin_trgm_ops)" + f"ON kg_entity USING GIN (name {POSTGRES_DEFAULT_SCHEMA}.gin_trgm_ops)" ) op.execute( "CREATE INDEX IF NOT EXISTS idx_kg_entity_normalization_trigrams " @@ -518,7 +518,7 @@ def upgrade() -> None: -- Set name and name trigrams NEW.name = name; - NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name); + NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name); RETURN NEW; END; $$ LANGUAGE plpgsql; @@ -563,7 +563,7 @@ def upgrade() -> None: UPDATE kg_entity SET name = doc_name, - name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name) + name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name) WHERE document_id = NEW.id; RETURN NEW; END; diff --git a/backend/onyx/db/engine/async_sql_engine.py b/backend/onyx/db/engine/async_sql_engine.py index a871dbad1ed..0bce3561899 100644 --- a/backend/onyx/db/engine/async_sql_engine.py +++ b/backend/onyx/db/engine/async_sql_engine.py @@ -29,6 +29,7 @@ from onyx.db.engine.sql_engine import SqlEngine from onyx.db.engine.sql_engine import USE_IAM_AUTH from shared_configs.configs import MULTI_TENANT +from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE from shared_configs.contextvars import get_current_tenant_id @@ -118,7 +119,7 @@ async def get_async_session( engine = get_sqlalchemy_async_engine() # no need to use the schema translation map for self-hosted + default schema - if not MULTI_TENANT: + if not MULTI_TENANT and tenant_id == POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE: async with AsyncSession(bind=engine, expire_on_commit=False) as session: yield session return diff --git a/backend/onyx/db/engine/sql_engine.py b/backend/onyx/db/engine/sql_engine.py index beac099265e..459afb9d849 100644 --- a/backend/onyx/db/engine/sql_engine.py +++ b/backend/onyx/db/engine/sql_engine.py @@ -31,6 +31,7 @@ from onyx.utils.logger import setup_logger from shared_configs.configs import MULTI_TENANT from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA +from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR from shared_configs.contextvars import get_current_tenant_id @@ -324,7 +325,7 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None] raise HTTPException(status_code=400, detail="Invalid tenant ID") # no need to use the schema translation map for self-hosted + default schema - if not MULTI_TENANT: + if not MULTI_TENANT and tenant_id == POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE: with Session(bind=engine, expire_on_commit=False) as session: yield session return @@ -370,12 +371,11 @@ def get_db_readonly_user_session_with_current_tenant() -> ( raise HTTPException(status_code=400, detail="Invalid tenant ID") # no need to use the schema translation map for self-hosted + default schema - if not MULTI_TENANT: + if not MULTI_TENANT and tenant_id == POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE: with Session(readonly_engine, expire_on_commit=False) as session: yield session return - # no need to use the schema translation map for self-hosted + default schema schema_translate_map = {None: tenant_id} with readonly_engine.connect().execution_options( schema_translate_map=schema_translate_map diff --git 
a/backend/onyx/kg/clustering/clustering.py b/backend/onyx/kg/clustering/clustering.py index 12ca5ffa023..7012e01d23b 100644 --- a/backend/onyx/kg/clustering/clustering.py +++ b/backend/onyx/kg/clustering/clustering.py @@ -34,7 +34,7 @@ from onyx.kg.utils.formatting_utils import make_relationship_id from onyx.utils.logger import setup_logger from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel -from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE +from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA logger = setup_logger() @@ -180,7 +180,7 @@ def _cluster_one_grounded_entity( # find entities of the same type with a similar name *filtering, KGEntity.entity_type_id_name == entity.entity_type_id_name, - getattr(func, POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE).similarity_op( + getattr(func, POSTGRES_DEFAULT_SCHEMA).similarity_op( KGEntity.name, entity_name ), ) diff --git a/backend/onyx/kg/clustering/normalizations.py b/backend/onyx/kg/clustering/normalizations.py index 3f272021145..3a611cec7ad 100644 --- a/backend/onyx/kg/clustering/normalizations.py +++ b/backend/onyx/kg/clustering/normalizations.py @@ -33,7 +33,7 @@ from onyx.kg.utils.formatting_utils import split_relationship_id from onyx.utils.logger import setup_logger from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel -from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE +from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA logger = setup_logger() @@ -95,7 +95,7 @@ def _normalize_one_entity( # generate trigrams of the queried entity Q query_trigrams = db_session.query( - getattr(func, POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE) + getattr(func, POSTGRES_DEFAULT_SCHEMA) .show_trgm(cleaned_entity) .cast(ARRAY(String(3))) .label("trigrams") diff --git a/backend/shared_configs/configs.py b/backend/shared_configs/configs.py index 12d12a49697..a21e890e360 100644 --- a/backend/shared_configs/configs.py +++ b/backend/shared_configs/configs.py @@ -140,6 +140,8 @@ def validate_cors_origin(origin: str) -> None: # Multi-tenancy configuration MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true" +# Outside this file, should almost always use `POSTGRES_DEFAULT_SCHEMA` unless you +# have a very good reason POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE = "public" POSTGRES_DEFAULT_SCHEMA = ( os.environ.get("POSTGRES_DEFAULT_SCHEMA") or POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE From 524eb1e8b7e92e8d942cb73ac3f8ed9ac47871d3 Mon Sep 17 00:00:00 2001 From: Chris Weaver Date: Thu, 17 Jul 2025 23:52:51 -0700 Subject: [PATCH 09/78] fix: improve check for indexing status (#5042) * Improve check_for_indexing + check_for_vespa_sync_task * Remove unused * Fix * Simplify query * Add more logging * Address bot comments * Increase # of tasks generated since we're not going cc-pair by cc-pair * Only index 50 user files at a time --- .../onyx/background/celery/apps/app_base.py | 10 +- .../onyx/background/celery/apps/primary.py | 9 +- .../background/celery/tasks/indexing/tasks.py | 39 +++- .../celery/tasks/vespa/document_sync.py | 178 +++++++++++++++ .../background/celery/tasks/vespa/tasks.py | 109 ++------- backend/onyx/configs/app_configs.py | 2 +- backend/onyx/db/connector_credential_pair.py | 30 ++- backend/onyx/db/document.py | 71 +----- backend/onyx/db/enums.py | 14 +- .../redis/redis_connector_credential_pair.py | 207 ------------------ backend/onyx/redis/redis_utils.py | 5 - 11 files changed, 282 insertions(+), 392 deletions(-) create mode 100644 
backend/onyx/background/celery/tasks/vespa/document_sync.py delete mode 100644 backend/onyx/redis/redis_connector_credential_pair.py diff --git a/backend/onyx/background/celery/apps/app_base.py b/backend/onyx/background/celery/apps/app_base.py index e4619ce3c11..59ecf1cba59 100644 --- a/backend/onyx/background/celery/apps/app_base.py +++ b/backend/onyx/background/celery/apps/app_base.py @@ -24,13 +24,14 @@ from onyx.background.celery.apps.task_formatters import CeleryTaskPlainFormatter from onyx.background.celery.celery_utils import celery_is_worker_primary from onyx.background.celery.celery_utils import make_probe_path +from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_PREFIX +from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_TASKSET_KEY from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX from onyx.configs.constants import OnyxRedisLocks from onyx.db.engine.sql_engine import get_sqlalchemy_engine from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout from onyx.httpx.httpx_pool import HttpxPool from onyx.redis.redis_connector import RedisConnector -from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair from onyx.redis.redis_connector_delete import RedisConnectorDelete from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync @@ -145,8 +146,11 @@ def on_task_postrun( r = get_redis_client(tenant_id=tenant_id) - if task_id.startswith(RedisConnectorCredentialPair.PREFIX): - r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id) + # NOTE: we want to remove the `Redis*` classes, prefer to just have functions to + # do these things going forward. 
In short, things should generally be like the doc + # sync task rather than the others below + if task_id.startswith(DOCUMENT_SYNC_PREFIX): + r.srem(DOCUMENT_SYNC_TASKSET_KEY, task_id) return if task_id.startswith(RedisDocumentSet.PREFIX): diff --git a/backend/onyx/background/celery/apps/primary.py b/backend/onyx/background/celery/apps/primary.py index 298918f0ef2..e63546a7488 100644 --- a/backend/onyx/background/celery/apps/primary.py +++ b/backend/onyx/background/celery/apps/primary.py @@ -21,6 +21,7 @@ from onyx.background.celery.tasks.indexing.utils import ( get_unfenced_index_attempt_ids, ) +from onyx.background.celery.tasks.vespa.document_sync import reset_document_sync from onyx.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT from onyx.configs.constants import OnyxRedisConstants from onyx.configs.constants import OnyxRedisLocks @@ -29,9 +30,6 @@ from onyx.db.engine.sql_engine import SqlEngine from onyx.db.index_attempt import get_index_attempt from onyx.db.index_attempt import mark_attempt_canceled -from onyx.redis.redis_connector_credential_pair import ( - RedisGlobalConnectorCredentialPair, -) from onyx.redis.redis_connector_delete import RedisConnectorDelete from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync @@ -156,7 +154,10 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None: r.delete(OnyxRedisConstants.ACTIVE_FENCES) - RedisGlobalConnectorCredentialPair.reset_all(r) + # NOTE: we want to remove the `Redis*` classes, prefer to just have functions + # This is the preferred way to do this going forward + reset_document_sync(r) + RedisDocumentSet.reset_all(r) RedisUserGroup.reset_all(r) RedisConnectorDelete.reset_all(r) diff --git a/backend/onyx/background/celery/tasks/indexing/tasks.py b/backend/onyx/background/celery/tasks/indexing/tasks.py index b07a0a2133e..c911271f154 100644 --- a/backend/onyx/background/celery/tasks/indexing/tasks.py +++ b/backend/onyx/background/celery/tasks/indexing/tasks.py @@ -54,7 +54,10 @@ from onyx.configs.constants import OnyxRedisSignals from onyx.connectors.exceptions import ConnectorValidationError from onyx.db.connector import mark_ccpair_with_indexing_trigger -from onyx.db.connector_credential_pair import fetch_connector_credential_pairs +from onyx.db.connector_credential_pair import ConnectorType +from onyx.db.connector_credential_pair import ( + fetch_indexable_connector_credential_pair_ids, +) from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id from onyx.db.connector_credential_pair import set_cc_pair_repeated_error_state from onyx.db.engine.sql_engine import get_session_with_current_tenant @@ -86,6 +89,8 @@ logger = setup_logger() +USER_FILE_INDEXING_LIMIT = 100 + def _get_fence_validation_block_expiration() -> int: """ @@ -480,20 +485,37 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None: embedding_model=embedding_model, ) - # gather cc_pair_ids + # gather cc_pair_ids + current search settings lock_beat.reacquire() - cc_pair_ids: list[int] = [] with get_session_with_current_tenant() as db_session: - cc_pairs = fetch_connector_credential_pairs( - db_session, include_user_files=True + standard_cc_pair_ids = fetch_indexable_connector_credential_pair_ids( + db_session, connector_type=ConnectorType.STANDARD + ) + # only index 50 user files at a time. This makes sense since user files are + # indexed only once, and then they are done. 
In practice, we would rarely + # have more than `USER_FILE_INDEXING_LIMIT` user files to index. + user_file_cc_pair_ids = fetch_indexable_connector_credential_pair_ids( + db_session, + connector_type=ConnectorType.USER_FILE, + limit=USER_FILE_INDEXING_LIMIT, ) - for cc_pair_entry in cc_pairs: - cc_pair_ids.append(cc_pair_entry.id) + cc_pair_ids = standard_cc_pair_ids + user_file_cc_pair_ids + + # NOTE: some potential race conditions here, but the worse case is + # kicking off some "invalid" indexing tasks which will just fail + search_settings_list = get_active_search_settings_list(db_session) + + current_search_settings = next( + search_settings_instance + for search_settings_instance in search_settings_list + if search_settings_instance.status.is_current() + ) # mark CC Pairs that are repeatedly failing as in repeated error state with get_session_with_current_tenant() as db_session: - current_search_settings = get_current_search_settings(db_session) for cc_pair_id in cc_pair_ids: + lock_beat.reacquire() + if is_in_repeated_error_state( cc_pair_id=cc_pair_id, search_settings_id=current_search_settings.id, @@ -511,7 +533,6 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None: redis_connector = RedisConnector(tenant_id, cc_pair_id) with get_session_with_current_tenant() as db_session: - search_settings_list = get_active_search_settings_list(db_session) for search_settings_instance in search_settings_list: # skip non-live search settings that don't have background reindex enabled # those should just auto-change to live shortly after creation without diff --git a/backend/onyx/background/celery/tasks/vespa/document_sync.py b/backend/onyx/background/celery/tasks/vespa/document_sync.py new file mode 100644 index 00000000000..489d127a830 --- /dev/null +++ b/backend/onyx/background/celery/tasks/vespa/document_sync.py @@ -0,0 +1,178 @@ +import time +from typing import cast +from uuid import uuid4 + +from celery import Celery +from redis import Redis +from redis.lock import Lock as RedisLock +from sqlalchemy.orm import Session + +from onyx.configs.app_configs import DB_YIELD_PER_DEFAULT +from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT +from onyx.configs.constants import OnyxCeleryPriority +from onyx.configs.constants import OnyxCeleryQueues +from onyx.configs.constants import OnyxCeleryTask +from onyx.configs.constants import OnyxRedisConstants +from onyx.db.document import construct_document_id_select_by_needs_sync +from onyx.db.document import count_documents_by_needs_sync +from onyx.utils.logger import setup_logger + +# Redis keys for document sync tracking +DOCUMENT_SYNC_PREFIX = "documentsync" +DOCUMENT_SYNC_FENCE_KEY = f"{DOCUMENT_SYNC_PREFIX}_fence" +DOCUMENT_SYNC_TASKSET_KEY = f"{DOCUMENT_SYNC_PREFIX}_taskset" + +logger = setup_logger() + + +def is_document_sync_fenced(r: Redis) -> bool: + """Check if document sync tasks are currently in progress.""" + return bool(r.exists(DOCUMENT_SYNC_FENCE_KEY)) + + +def get_document_sync_payload(r: Redis) -> int | None: + """Get the initial number of tasks that were created.""" + bytes_result = r.get(DOCUMENT_SYNC_FENCE_KEY) + if bytes_result is None: + return None + return int(cast(int, bytes_result)) + + +def get_document_sync_remaining(r: Redis) -> int: + """Get the number of tasks still pending completion.""" + return cast(int, r.scard(DOCUMENT_SYNC_TASKSET_KEY)) + + +def set_document_sync_fence(r: Redis, payload: int | None) -> None: + """Set up the fence and register with active fences.""" + if payload is 
None: + r.srem(OnyxRedisConstants.ACTIVE_FENCES, DOCUMENT_SYNC_FENCE_KEY) + r.delete(DOCUMENT_SYNC_FENCE_KEY) + return + + r.set(DOCUMENT_SYNC_FENCE_KEY, payload) + r.sadd(OnyxRedisConstants.ACTIVE_FENCES, DOCUMENT_SYNC_FENCE_KEY) + + +def delete_document_sync_taskset(r: Redis) -> None: + """Clear the document sync taskset.""" + r.delete(DOCUMENT_SYNC_TASKSET_KEY) + + +def reset_document_sync(r: Redis) -> None: + """Reset all document sync tracking data.""" + r.srem(OnyxRedisConstants.ACTIVE_FENCES, DOCUMENT_SYNC_FENCE_KEY) + r.delete(DOCUMENT_SYNC_TASKSET_KEY) + r.delete(DOCUMENT_SYNC_FENCE_KEY) + + +def generate_document_sync_tasks( + r: Redis, + max_tasks: int, + celery_app: Celery, + db_session: Session, + lock: RedisLock, + tenant_id: str, +) -> tuple[int, int]: + """Generate sync tasks for all documents that need syncing. + + Args: + r: Redis client + max_tasks: Maximum number of tasks to generate + celery_app: Celery application instance + db_session: Database session + lock: Redis lock for coordination + tenant_id: Tenant identifier + + Returns: + tuple[int, int]: (tasks_generated, total_docs_found) + """ + last_lock_time = time.monotonic() + num_tasks_sent = 0 + num_docs = 0 + + # Get all documents that need syncing + stmt = construct_document_id_select_by_needs_sync() + + for doc_id in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT): + doc_id = cast(str, doc_id) + current_time = time.monotonic() + + # Reacquire lock periodically to prevent timeout + if current_time - last_lock_time >= (CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4): + lock.reacquire() + last_lock_time = current_time + + num_docs += 1 + + # Create a unique task ID + custom_task_id = f"{DOCUMENT_SYNC_PREFIX}_{uuid4()}" + + # Add to the tracking taskset in Redis BEFORE creating the celery task + r.sadd(DOCUMENT_SYNC_TASKSET_KEY, custom_task_id) + + # Create the Celery task + celery_app.send_task( + OnyxCeleryTask.VESPA_METADATA_SYNC_TASK, + kwargs=dict(document_id=doc_id, tenant_id=tenant_id), + queue=OnyxCeleryQueues.VESPA_METADATA_SYNC, + task_id=custom_task_id, + priority=OnyxCeleryPriority.MEDIUM, + ignore_result=True, + ) + + num_tasks_sent += 1 + + if num_tasks_sent >= max_tasks: + break + + return num_tasks_sent, num_docs + + +def try_generate_stale_document_sync_tasks( + celery_app: Celery, + max_tasks: int, + db_session: Session, + r: Redis, + lock_beat: RedisLock, + tenant_id: str, +) -> int | None: + # the fence is up, do nothing + if is_document_sync_fenced(r): + return None + + # add tasks to celery and build up the task set to monitor in redis + stale_doc_count = count_documents_by_needs_sync(db_session) + if stale_doc_count == 0: + logger.info("No stale documents found. Skipping sync tasks generation.") + return None + + logger.info( + f"Stale documents found (at least {stale_doc_count}). Generating sync tasks in one batch." + ) + + logger.info("generate_document_sync_tasks starting for all documents.") + + # Generate all tasks in one pass + result = generate_document_sync_tasks( + r, max_tasks, celery_app, db_session, lock_beat, tenant_id + ) + + if result is None: + return None + + tasks_generated, total_docs = result + + if tasks_generated >= max_tasks: + logger.info( + f"generate_document_sync_tasks reached the task generation limit: " + f"tasks_generated={tasks_generated} max_tasks={max_tasks}" + ) + else: + logger.info( + f"generate_document_sync_tasks finished for all documents. 
" + f"tasks_generated={tasks_generated} total_docs_found={total_docs}" + ) + + set_document_sync_fence(r, tasks_generated) + return tasks_generated diff --git a/backend/onyx/background/celery/tasks/vespa/tasks.py b/backend/onyx/background/celery/tasks/vespa/tasks.py index 9966d8e5934..cee8a6b0e53 100644 --- a/backend/onyx/background/celery/tasks/vespa/tasks.py +++ b/backend/onyx/background/celery/tasks/vespa/tasks.py @@ -20,14 +20,19 @@ from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT from onyx.background.celery.tasks.shared.tasks import OnyxCeleryTaskCompletionStatus +from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_FENCE_KEY +from onyx.background.celery.tasks.vespa.document_sync import get_document_sync_payload +from onyx.background.celery.tasks.vespa.document_sync import get_document_sync_remaining +from onyx.background.celery.tasks.vespa.document_sync import reset_document_sync +from onyx.background.celery.tasks.vespa.document_sync import ( + try_generate_stale_document_sync_tasks, +) from onyx.configs.app_configs import JOB_TIMEOUT from onyx.configs.app_configs import VESPA_SYNC_MAX_TASKS from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT from onyx.configs.constants import OnyxCeleryTask from onyx.configs.constants import OnyxRedisConstants from onyx.configs.constants import OnyxRedisLocks -from onyx.db.connector_credential_pair import get_connector_credential_pairs -from onyx.db.document import count_documents_by_needs_sync from onyx.db.document import get_document from onyx.db.document import mark_document_as_synced from onyx.db.document_set import delete_document_set @@ -47,10 +52,6 @@ from onyx.document_index.factory import get_default_document_index from onyx.document_index.interfaces import VespaDocumentFields from onyx.httpx.httpx_pool import HttpxPool -from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair -from onyx.redis.redis_connector_credential_pair import ( - RedisGlobalConnectorCredentialPair, -) from onyx.redis.redis_document_set import RedisDocumentSet from onyx.redis.redis_pool import get_redis_client from onyx.redis.redis_pool import get_redis_replica_client @@ -166,8 +167,11 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None: continue key_str = key_bytes.decode("utf-8") - if key_str == RedisGlobalConnectorCredentialPair.FENCE_KEY: - monitor_connector_taskset(r) + # NOTE: removing the "Redis*" classes, prefer to just have functions to + # do these things going forward. 
In short, things should generally be like the doc + # sync task rather than the others + if key_str == DOCUMENT_SYNC_FENCE_KEY: + monitor_document_sync_taskset(r) elif key_str.startswith(RedisDocumentSet.FENCE_PREFIX): with get_session_with_current_tenant() as db_session: monitor_document_set_taskset(tenant_id, key_bytes, r, db_session) @@ -203,82 +207,6 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None: return True -def try_generate_stale_document_sync_tasks( - celery_app: Celery, - max_tasks: int, - db_session: Session, - r: Redis, - lock_beat: RedisLock, - tenant_id: str, -) -> int | None: - # the fence is up, do nothing - - redis_global_ccpair = RedisGlobalConnectorCredentialPair(r) - if redis_global_ccpair.fenced: - return None - - redis_global_ccpair.delete_taskset() - - # add tasks to celery and build up the task set to monitor in redis - stale_doc_count = count_documents_by_needs_sync(db_session) - if stale_doc_count == 0: - return None - - task_logger.info( - f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair." - ) - - task_logger.info( - "RedisConnector.generate_tasks starting by cc_pair. " - "Documents spanning multiple cc_pairs will only be synced once." - ) - - docs_to_skip: set[str] = set() - - # rkuo: we could technically sync all stale docs in one big pass. - # but I feel it's more understandable to group the docs by cc_pair - total_tasks_generated = 0 - tasks_remaining = max_tasks - cc_pairs = get_connector_credential_pairs(db_session) - for cc_pair in cc_pairs: - lock_beat.reacquire() - - rc = RedisConnectorCredentialPair(tenant_id, cc_pair.id) - rc.set_skip_docs(docs_to_skip) - result = rc.generate_tasks( - tasks_remaining, celery_app, db_session, r, lock_beat, tenant_id - ) - - if result is None: - continue - - if result[1] == 0: - continue - - task_logger.info( - f"RedisConnector.generate_tasks finished for single cc_pair. " - f"cc_pair={cc_pair.id} tasks_generated={result[0]} tasks_possible={result[1]}" - ) - - total_tasks_generated += result[0] - tasks_remaining -= result[0] - if tasks_remaining <= 0: - break - - if tasks_remaining <= 0: - task_logger.info( - f"RedisConnector.generate_tasks reached the task generation limit: " - f"total_tasks_generated={total_tasks_generated} max_tasks={max_tasks}" - ) - else: - task_logger.info( - f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}" - ) - - redis_global_ccpair.set_fence(total_tasks_generated) - return total_tasks_generated - - def try_generate_document_set_sync_tasks( celery_app: Celery, document_set_id: int, @@ -433,19 +361,18 @@ def try_generate_user_group_sync_tasks( return tasks_generated -def monitor_connector_taskset(r: Redis) -> None: - redis_global_ccpair = RedisGlobalConnectorCredentialPair(r) - initial_count = redis_global_ccpair.payload +def monitor_document_sync_taskset(r: Redis) -> None: + initial_count = get_document_sync_payload(r) if initial_count is None: return - remaining = redis_global_ccpair.get_remaining() + remaining = get_document_sync_remaining(r) task_logger.info( - f"Stale document sync progress: remaining={remaining} initial={initial_count}" + f"Document sync progress: remaining={remaining} initial={initial_count}" ) if remaining == 0: - redis_global_ccpair.reset() - task_logger.info(f"Successfully synced stale documents. count={initial_count}") + reset_document_sync(r) + task_logger.info(f"Successfully synced all documents. 
count={initial_count}") def monitor_document_set_taskset( diff --git a/backend/onyx/configs/app_configs.py b/backend/onyx/configs/app_configs.py index 047e49d7d8b..650ae61fb87 100644 --- a/backend/onyx/configs/app_configs.py +++ b/backend/onyx/configs/app_configs.py @@ -332,7 +332,7 @@ ) # The maximum number of tasks that can be queued up to sync to Vespa in a single pass -VESPA_SYNC_MAX_TASKS = 1024 +VESPA_SYNC_MAX_TASKS = 8192 DB_YIELD_PER_DEFAULT = 64 diff --git a/backend/onyx/db/connector_credential_pair.py b/backend/onyx/db/connector_credential_pair.py index 7de64dfa56b..040a43ec909 100644 --- a/backend/onyx/db/connector_credential_pair.py +++ b/backend/onyx/db/connector_credential_pair.py @@ -1,4 +1,5 @@ from datetime import datetime +from enum import Enum from typing import TypeVarTuple from fastapi import HTTPException @@ -41,6 +42,11 @@ R = TypeVarTuple("R") +class ConnectorType(str, Enum): + STANDARD = "standard" + USER_FILE = "user_file" + + def _add_user_filters( stmt: Select[tuple[*R]], user: User | None, get_editable: bool = True ) -> Select[tuple[*R]]: @@ -619,14 +625,24 @@ def remove_credential_from_connector( ) -def fetch_connector_credential_pairs( +def fetch_indexable_connector_credential_pair_ids( db_session: Session, - include_user_files: bool = False, -) -> list[ConnectorCredentialPair]: - stmt = select(ConnectorCredentialPair) - if not include_user_files: - stmt = stmt.where(ConnectorCredentialPair.is_user_file != True) # noqa: E712 - return list(db_session.scalars(stmt).unique().all()) + connector_type: ConnectorType | None = None, + limit: int | None = None, +) -> list[int]: + stmt = select(ConnectorCredentialPair.id) + stmt = stmt.where( + ConnectorCredentialPair.status.in_( + ConnectorCredentialPairStatus.active_statuses() + ) + ) + if connector_type == ConnectorType.USER_FILE: + stmt = stmt.where(ConnectorCredentialPair.is_user_file.is_(True)) + elif connector_type == ConnectorType.STANDARD: + stmt = stmt.where(ConnectorCredentialPair.is_user_file.is_(False)) + if limit: + stmt = stmt.limit(limit) + return list(db_session.scalars(stmt).all()) def fetch_connector_credential_pair_for_connector( diff --git a/backend/onyx/db/document.py b/backend/onyx/db/document.py index f7d0afecf7f..729cbd4f51e 100644 --- a/backend/onyx/db/document.py +++ b/backend/onyx/db/document.py @@ -79,10 +79,6 @@ def count_documents_by_needs_sync(session: Session) -> int: return ( session.query(DbDocument.id) - .join( - DocumentByConnectorCredentialPair, - DbDocument.id == DocumentByConnectorCredentialPair.id, - ) .filter( or_( DbDocument.last_modified > DbDocument.last_synced, @@ -93,67 +89,22 @@ def count_documents_by_needs_sync(session: Session) -> int: ) -def construct_document_select_for_connector_credential_pair_by_needs_sync( - connector_id: int, credential_id: int -) -> Select: - return ( - select(DbDocument) - .join( - DocumentByConnectorCredentialPair, - DbDocument.id == DocumentByConnectorCredentialPair.id, - ) - .where( - and_( - DocumentByConnectorCredentialPair.connector_id == connector_id, - DocumentByConnectorCredentialPair.credential_id == credential_id, - or_( - DbDocument.last_modified > DbDocument.last_synced, - DbDocument.last_synced.is_(None), - ), - ) - ) - ) - +def construct_document_id_select_by_needs_sync() -> Select: + """Get all document IDs that need syncing across all connector credential pairs. 
-def construct_document_id_select_for_connector_credential_pair_by_needs_sync( - connector_id: int, credential_id: int -) -> Select: - return ( - select(DbDocument.id) - .join( - DocumentByConnectorCredentialPair, - DbDocument.id == DocumentByConnectorCredentialPair.id, - ) - .where( - and_( - DocumentByConnectorCredentialPair.connector_id == connector_id, - DocumentByConnectorCredentialPair.credential_id == credential_id, - or_( - DbDocument.last_modified > DbDocument.last_synced, - DbDocument.last_synced.is_(None), - ), - ) + Returns a Select statement for documents where: + 1. last_modified is newer than last_synced + 2. last_synced is null (meaning we've never synced) + AND the document has a relationship with a connector/credential pair + """ + return select(DbDocument.id).where( + or_( + DbDocument.last_modified > DbDocument.last_synced, + DbDocument.last_synced.is_(None), ) ) -def get_all_documents_needing_vespa_sync_for_cc_pair( - db_session: Session, cc_pair_id: int -) -> list[DbDocument]: - cc_pair = get_connector_credential_pair_from_id( - db_session=db_session, - cc_pair_id=cc_pair_id, - ) - if not cc_pair: - raise ValueError(f"No CC pair found with ID: {cc_pair_id}") - - stmt = construct_document_select_for_connector_credential_pair_by_needs_sync( - cc_pair.connector_id, cc_pair.credential_id - ) - - return list(db_session.scalars(stmt).all()) - - def construct_document_id_select_for_connector_credential_pair( connector_id: int, credential_id: int | None = None ) -> Select: diff --git a/backend/onyx/db/enums.py b/backend/onyx/db/enums.py index 0730096990a..39ef8574233 100644 --- a/backend/onyx/db/enums.py +++ b/backend/onyx/db/enums.py @@ -86,12 +86,16 @@ class ConnectorCredentialPairStatus(str, PyEnum): DELETING = "DELETING" INVALID = "INVALID" + @classmethod + def active_statuses(cls) -> list["ConnectorCredentialPairStatus"]: + return [ + ConnectorCredentialPairStatus.ACTIVE, + ConnectorCredentialPairStatus.SCHEDULED, + ConnectorCredentialPairStatus.INITIAL_INDEXING, + ] + def is_active(self) -> bool: - return ( - self == ConnectorCredentialPairStatus.ACTIVE - or self == ConnectorCredentialPairStatus.SCHEDULED - or self == ConnectorCredentialPairStatus.INITIAL_INDEXING - ) + return self in self.active_statuses() class AccessType(str, PyEnum): diff --git a/backend/onyx/redis/redis_connector_credential_pair.py b/backend/onyx/redis/redis_connector_credential_pair.py deleted file mode 100644 index 5bbbd2e08f2..00000000000 --- a/backend/onyx/redis/redis_connector_credential_pair.py +++ /dev/null @@ -1,207 +0,0 @@ -import time -from typing import cast -from uuid import uuid4 - -import redis -from celery import Celery -from redis import Redis -from redis.lock import Lock as RedisLock -from sqlalchemy.orm import Session - -from onyx.configs.app_configs import DB_YIELD_PER_DEFAULT -from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT -from onyx.configs.constants import OnyxCeleryPriority -from onyx.configs.constants import OnyxCeleryQueues -from onyx.configs.constants import OnyxCeleryTask -from onyx.configs.constants import OnyxRedisConstants -from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id -from onyx.db.document import ( - construct_document_id_select_for_connector_credential_pair_by_needs_sync, -) -from onyx.redis.redis_object_helper import RedisObjectHelper - - -class RedisConnectorCredentialPair(RedisObjectHelper): - """This class is used to scan documents by cc_pair in the db and collect them into - a unified set for 
syncing. - - It differs from the other redis helpers in that the taskset used spans - all connectors and is not per connector.""" - - PREFIX = "connectorsync" - TASKSET_PREFIX = PREFIX + "_taskset" - - def __init__(self, tenant_id: str, id: int) -> None: - super().__init__(tenant_id, str(id)) - - # documents that should be skipped - self.skip_docs: set[str] = set() - - @classmethod - def get_taskset_key(cls) -> str: - return RedisConnectorCredentialPair.TASKSET_PREFIX - - @property - def taskset_key(self) -> str: - """Notice that this is intentionally reusing the same taskset for all - connector syncs""" - # example: connectorsync_taskset - return f"{self.TASKSET_PREFIX}" - - def set_skip_docs(self, skip_docs: set[str]) -> None: - # documents that should be skipped. Note that this class updates - # the list on the fly - self.skip_docs = skip_docs - - def generate_tasks( - self, - max_tasks: int, - celery_app: Celery, - db_session: Session, - redis_client: Redis, - lock: RedisLock, - tenant_id: str, - ) -> tuple[int, int] | None: - """We can limit the number of tasks generated here, which is useful to prevent - one tenant from overwhelming the sync queue. - - This works because the dirty state of a document is in the DB, so more docs - get picked up after the limited set of tasks is complete. - """ - - last_lock_time = time.monotonic() - - num_tasks_sent = 0 - - cc_pair = get_connector_credential_pair_from_id( - db_session=db_session, - cc_pair_id=int(self._id), - ) - if not cc_pair: - return None - - stmt = construct_document_id_select_for_connector_credential_pair_by_needs_sync( - cc_pair.connector_id, cc_pair.credential_id - ) - - num_docs = 0 - - for doc_id in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT): - doc_id = cast(str, doc_id) - current_time = time.monotonic() - if current_time - last_lock_time >= ( - CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4 - ): - lock.reacquire() - last_lock_time = current_time - - num_docs += 1 - - # check if we should skip the document (typically because it's already syncing) - if doc_id in self.skip_docs: - continue - - # celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac" - # the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac" - # we prefix the task id so it's easier to keep track of who created the task - # aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac" - custom_task_id = f"{self.task_id_prefix}_{uuid4()}" - - # add to the tracking taskset in redis BEFORE creating the celery task. - # note that for the moment we are using a single taskset key, not differentiated by cc_pair id - redis_client.sadd( - RedisConnectorCredentialPair.get_taskset_key(), custom_task_id - ) - - # Priority on sync's triggered by new indexing should be medium - celery_app.send_task( - OnyxCeleryTask.VESPA_METADATA_SYNC_TASK, - kwargs=dict(document_id=doc_id, tenant_id=tenant_id), - queue=OnyxCeleryQueues.VESPA_METADATA_SYNC, - task_id=custom_task_id, - priority=OnyxCeleryPriority.MEDIUM, - ignore_result=True, - ) - - num_tasks_sent += 1 - self.skip_docs.add(doc_id) - - if num_tasks_sent >= max_tasks: - break - - return num_tasks_sent, num_docs - - -class RedisGlobalConnectorCredentialPair: - """This class is used to scan documents by cc_pair in the db and collect them into - a unified set for syncing. 
- - It differs from the other redis helpers in that the taskset used spans - all connectors and is not per connector.""" - - PREFIX = "connectorsync" - FENCE_KEY = PREFIX + "_fence" - TASKSET_KEY = PREFIX + "_taskset" - - def __init__(self, redis: redis.Redis) -> None: - self.redis = redis - - @property - def fenced(self) -> bool: - if self.redis.exists(self.fence_key): - return True - - return False - - @property - def payload(self) -> int | None: - bytes = self.redis.get(self.fence_key) - if bytes is None: - return None - - progress = int(cast(int, bytes)) - return progress - - def get_remaining(self) -> int: - remaining = cast(int, self.redis.scard(self.taskset_key)) - return remaining - - @property - def fence_key(self) -> str: - """Notice that this is intentionally reusing the same fence for all - connector syncs""" - # example: connectorsync_fence - return f"{self.FENCE_KEY}" - - @property - def taskset_key(self) -> str: - """Notice that this is intentionally reusing the same taskset for all - connector syncs""" - # example: connectorsync_taskset - return f"{self.TASKSET_KEY}" - - def set_fence(self, payload: int | None) -> None: - if payload is None: - self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key) - self.redis.delete(self.fence_key) - return - - self.redis.set(self.fence_key, payload) - self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key) - - def delete_taskset(self) -> None: - self.redis.delete(self.taskset_key) - - def reset(self) -> None: - self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key) - self.redis.delete(self.taskset_key) - self.redis.delete(self.fence_key) - - @staticmethod - def reset_all(r: redis.Redis) -> None: - r.srem( - OnyxRedisConstants.ACTIVE_FENCES, - RedisGlobalConnectorCredentialPair.FENCE_KEY, - ) - r.delete(RedisGlobalConnectorCredentialPair.TASKSET_KEY) - r.delete(RedisGlobalConnectorCredentialPair.FENCE_KEY) diff --git a/backend/onyx/redis/redis_utils.py b/backend/onyx/redis/redis_utils.py index d311ca84eea..1403238513a 100644 --- a/backend/onyx/redis/redis_utils.py +++ b/backend/onyx/redis/redis_utils.py @@ -1,6 +1,3 @@ -from onyx.redis.redis_connector_credential_pair import ( - RedisGlobalConnectorCredentialPair, -) from onyx.redis.redis_connector_delete import RedisConnectorDelete from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync from onyx.redis.redis_connector_index import RedisConnectorIndex @@ -11,8 +8,6 @@ def is_fence(key_bytes: bytes) -> bool: key_str = key_bytes.decode("utf-8") - if key_str == RedisGlobalConnectorCredentialPair.FENCE_KEY: - return True if key_str.startswith(RedisDocumentSet.FENCE_PREFIX): return True if key_str.startswith(RedisUserGroup.FENCE_PREFIX): From 7dbe4ed50a59f6f7bde680a108a1c0945fb64b3b Mon Sep 17 00:00:00 2001 From: Chris Weaver Date: Fri, 18 Jul 2025 14:16:10 -0700 Subject: [PATCH 10/78] fix: improve assistant fetching efficiency (#5047) * Improve assistant fetching efficiency * More fix * Fix weird build stuff * Improve --- .../usage/PersonaMessagesChart.tsx | 13 ++-- .../app/ee/admin/performance/usage/page.tsx | 7 ++- web/src/app/layout.tsx | 9 +-- web/src/components/context/AppProvider.tsx | 10 +--- .../components/context/AssistantsContext.tsx | 59 +------------------ web/src/lib/chat/fetchAssistantdata.ts | 55 ++--------------- 6 files changed, 24 insertions(+), 129 deletions(-) diff --git a/web/src/app/ee/admin/performance/usage/PersonaMessagesChart.tsx b/web/src/app/ee/admin/performance/usage/PersonaMessagesChart.tsx index 
f084657e7a7..a90fa8f6985 100644 --- a/web/src/app/ee/admin/performance/usage/PersonaMessagesChart.tsx +++ b/web/src/app/ee/admin/performance/usage/PersonaMessagesChart.tsx @@ -5,7 +5,6 @@ import { usePersonaMessages, usePersonaUniqueUsers, } from "../lib"; -import { useAssistants } from "@/components/context/AssistantsContext"; import { DateRangePickerValue } from "@/components/dateRangeSelectors/AdminDateRangeSelector"; import Text from "@/components/ui/text"; import Title from "@/components/ui/title"; @@ -19,10 +18,13 @@ import { SelectValue, } from "@/components/ui/select"; import { useState, useMemo, useEffect } from "react"; +import { Persona } from "@/app/admin/assistants/interfaces"; export function PersonaMessagesChart({ + availablePersonas, timeRange, }: { + availablePersonas: Persona[]; timeRange: DateRangePickerValue; }) { const [selectedPersonaId, setSelectedPersonaId] = useState< @@ -30,7 +32,6 @@ export function PersonaMessagesChart({ >(undefined); const [searchQuery, setSearchQuery] = useState(""); const [highlightedIndex, setHighlightedIndex] = useState(-1); - const { allAssistants: personaList } = useAssistants(); const { data: personaMessagesData, @@ -48,11 +49,11 @@ export function PersonaMessagesChart({ const hasError = personaMessagesError || personaUniqueUsersError; const filteredPersonaList = useMemo(() => { - if (!personaList) return []; - return personaList.filter((persona) => + if (!availablePersonas) return []; + return availablePersonas.filter((persona) => persona.name.toLowerCase().includes(searchQuery.toLowerCase()) ); - }, [personaList, searchQuery]); + }, [availablePersonas, searchQuery]); const handleKeyDown = (e: React.KeyboardEvent) => { e.stopPropagation(); @@ -142,7 +143,7 @@ export function PersonaMessagesChart({
); - } else if (!personaList || hasError) { + } else if (!availablePersonas || hasError) { content = (

Failed to fetch data...

diff --git a/web/src/app/ee/admin/performance/usage/page.tsx b/web/src/app/ee/admin/performance/usage/page.tsx index ec5604298a2..937983a5a4d 100644 --- a/web/src/app/ee/admin/performance/usage/page.tsx +++ b/web/src/app/ee/admin/performance/usage/page.tsx @@ -10,9 +10,11 @@ import { AdminPageTitle } from "@/components/admin/Title"; import { FiActivity } from "react-icons/fi"; import UsageReports from "./UsageReports"; import { Separator } from "@/components/ui/separator"; +import { useAdminPersonas } from "@/app/admin/assistants/hooks"; export default function AnalyticsPage() { const [timeRange, setTimeRange] = useTimeRange(); + const { personas } = useAdminPersonas(); return (
@@ -27,7 +29,10 @@ export default function AnalyticsPage() { - +
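For reference, the analytics page now hands the admin personas to the chart through the new `availablePersonas` prop instead of the chart reading `allAssistants` from `AssistantsContext`. A minimal sketch of that usage, assuming a hypothetical wrapper component and a relative import path for the chart (the hook and prop names come from the surrounding diffs):

```
// Sketch only: `PersonaUsageSection` and the "./PersonaMessagesChart" import
// path are assumptions; `useAdminPersonas`, `availablePersonas`, and
// `timeRange` are taken from the diffs above.
import { useAdminPersonas } from "@/app/admin/assistants/hooks";
import { DateRangePickerValue } from "@/components/dateRangeSelectors/AdminDateRangeSelector";
import { PersonaMessagesChart } from "./PersonaMessagesChart";

export function PersonaUsageSection({
  timeRange,
}: {
  timeRange: DateRangePickerValue;
}) {
  // The page fetches admin personas once and passes them down, which keeps the
  // admin-only persona fetch out of the shared AssistantsContext.
  const { personas } = useAdminPersonas();
  return (
    <PersonaMessagesChart
      availablePersonas={personas}
      timeRange={timeRange}
    />
  );
}
```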
diff --git a/web/src/app/layout.tsx b/web/src/app/layout.tsx index 707367a459d..d68a3b756c6 100644 --- a/web/src/app/layout.tsx +++ b/web/src/app/layout.tsx @@ -17,7 +17,6 @@ import { EnterpriseSettings, ApplicationStatus, } from "./admin/settings/interfaces"; -import { fetchAssistantData } from "@/lib/chat/fetchAssistantdata"; import { AppProvider } from "@/components/context/AppProvider"; import { PHProvider } from "./providers"; import { getAuthTypeMetadataSS, getCurrentUserSS } from "@/lib/userSS"; @@ -31,6 +30,7 @@ import { DocumentsProvider } from "./chat/my-documents/DocumentsContext"; import CloudError from "@/components/errorPages/CloudErrorPage"; import Error from "@/components/errorPages/ErrorPage"; import AccessRestrictedPage from "@/components/errorPages/AccessRestrictedPage"; +import { fetchAssistantData } from "@/lib/chat/fetchAssistantdata"; const inter = Inter({ subsets: ["latin"], @@ -71,7 +71,7 @@ export default async function RootLayout({ }: { children: React.ReactNode; }) { - const [combinedSettings, assistantsData, user, authTypeMetadata] = + const [combinedSettings, assistants, user, authTypeMetadata] = await Promise.all([ fetchSettingsSS(), fetchAssistantData(), @@ -145,17 +145,12 @@ export default async function RootLayout({ ); } - const { assistants, hasAnyConnectors, hasImageCompatibleModel } = - assistantsData; - return getPageContent( diff --git a/web/src/components/context/AppProvider.tsx b/web/src/components/context/AppProvider.tsx index 8e693336a9d..53b9c957097 100644 --- a/web/src/components/context/AppProvider.tsx +++ b/web/src/components/context/AppProvider.tsx @@ -14,8 +14,6 @@ interface AppProviderProps { user: User | null; settings: CombinedSettings; assistants: MinimalPersonaSnapshot[]; - hasAnyConnectors: boolean; - hasImageCompatibleModel: boolean; authTypeMetadata: AuthTypeMetadata; } @@ -24,8 +22,6 @@ export const AppProvider = ({ user, settings, assistants, - hasAnyConnectors, - hasImageCompatibleModel, authTypeMetadata, }: AppProviderProps) => { return ( @@ -36,11 +32,7 @@ export const AppProvider = ({ authTypeMetadata={authTypeMetadata} > - + {children} diff --git a/web/src/components/context/AssistantsContext.tsx b/web/src/components/context/AssistantsContext.tsx index 950f6e7ba75..7bb48c73194 100644 --- a/web/src/components/context/AssistantsContext.tsx +++ b/web/src/components/context/AssistantsContext.tsx @@ -25,9 +25,6 @@ interface AssistantsContextProps { ownedButHiddenAssistants: MinimalPersonaSnapshot[]; refreshAssistants: () => Promise; isImageGenerationAvailable: boolean; - // Admin only - editablePersonas: MinimalPersonaSnapshot[]; - allAssistants: MinimalPersonaSnapshot[]; pinnedAssistants: MinimalPersonaSnapshot[]; setPinnedAssistants: Dispatch>; } @@ -41,22 +38,11 @@ export const AssistantsProvider: React.FC<{ initialAssistants: MinimalPersonaSnapshot[]; hasAnyConnectors?: boolean; hasImageCompatibleModel?: boolean; -}> = ({ - children, - initialAssistants, - hasAnyConnectors, - hasImageCompatibleModel, -}) => { +}> = ({ children, initialAssistants }) => { const [assistants, setAssistants] = useState( initialAssistants || [] ); - const { user, isAdmin, isCurator } = useUser(); - const [editablePersonas, setEditablePersonas] = useState< - MinimalPersonaSnapshot[] - >([]); - const [allAssistants, setAllAssistants] = useState( - [] - ); + const { user } = useUser(); const [pinnedAssistants, setPinnedAssistants] = useState< MinimalPersonaSnapshot[] @@ -107,37 +93,6 @@ export const AssistantsProvider: React.FC<{ 
checkImageGenerationAvailability(); }, []); - const fetchPersonas = async () => { - if (!isAdmin && !isCurator) { - return; - } - - try { - const [editableResponse, allResponse] = await Promise.all([ - fetch("/api/admin/persona?get_editable=true"), - fetch("/api/admin/persona"), - ]); - - if (editableResponse.ok) { - const editablePersonas = await editableResponse.json(); - setEditablePersonas(editablePersonas); - } - - if (allResponse.ok) { - const allPersonas = await allResponse.json(); - setAllAssistants(allPersonas); - } else { - console.error("Error fetching personas:", allResponse); - } - } catch (error) { - console.error("Error fetching personas:", error); - } - }; - - useEffect(() => { - fetchPersonas(); - }, [isAdmin, isCurator]); - const refreshAssistants = async () => { try { const response = await fetch("/api/persona", { @@ -148,13 +103,7 @@ export const AssistantsProvider: React.FC<{ }); if (!response.ok) throw new Error("Failed to fetch assistants"); let assistants: MinimalPersonaSnapshot[] = await response.json(); - - let filteredAssistants = filterAssistants(assistants); - - setAssistants(filteredAssistants); - - // Fetch and update allAssistants for admins and curators - await fetchPersonas(); + setAssistants(filterAssistants(assistants)); } catch (error) { console.error("Error refreshing assistants:", error); } @@ -197,8 +146,6 @@ export const AssistantsProvider: React.FC<{ finalAssistants, ownedButHiddenAssistants, refreshAssistants, - editablePersonas, - allAssistants, isImageGenerationAvailable, setPinnedAssistants, pinnedAssistants, diff --git a/web/src/lib/chat/fetchAssistantdata.ts b/web/src/lib/chat/fetchAssistantdata.ts index f17b70de593..7c76fa932f5 100644 --- a/web/src/lib/chat/fetchAssistantdata.ts +++ b/web/src/lib/chat/fetchAssistantdata.ts @@ -1,65 +1,20 @@ -import { fetchSS } from "@/lib/utilsSS"; import { MinimalPersonaSnapshot } from "@/app/admin/assistants/interfaces"; -import { fetchLLMProvidersSS } from "@/lib/llm/fetchLLMs"; import { fetchAssistantsSS } from "../assistants/fetchAssistantsSS"; -import { modelSupportsImageInput } from "../llm/utils"; import { filterAssistants } from "../assistants/utils"; -interface AssistantData { - assistants: MinimalPersonaSnapshot[]; - hasAnyConnectors: boolean; - hasImageCompatibleModel: boolean; -} -export async function fetchAssistantData(): Promise { - // Default state if anything fails - const defaultState: AssistantData = { - assistants: [], - hasAnyConnectors: false, - hasImageCompatibleModel: false, - }; - +export async function fetchAssistantData(): Promise { try { - // Fetch core assistants data first + // Fetch core assistants data const [assistants, assistantsFetchError] = await fetchAssistantsSS(); if (assistantsFetchError) { // This is not a critical error and occurs when the user is not logged in console.warn(`Failed to fetch assistants - ${assistantsFetchError}`); - return defaultState; + return []; } - // Parallel fetch of additional data - const [ccPairsResponse, llmProviders] = await Promise.all([ - fetchSS("/manage/connector-status").catch((error) => { - console.error("Failed to fetch connectors:", error); - return null; - }), - fetchLLMProvidersSS().catch((error) => { - console.error("Failed to fetch LLM providers:", error); - return []; - }), - ]); - - const hasAnyConnectors = ccPairsResponse?.ok - ? 
(await ccPairsResponse.json()).length > 0 - : false; - - const hasImageCompatibleModel = llmProviders.some( - (provider) => - provider.provider === "openai" || - provider.model_configurations.some((modelConfiguration) => - modelSupportsImageInput(llmProviders, modelConfiguration.name) - ) - ); - - let filteredAssistants = filterAssistants(assistants); - - return { - assistants: filteredAssistants, - hasAnyConnectors, - hasImageCompatibleModel, - }; + return filterAssistants(assistants); } catch (error) { console.error("Unexpected error in fetchAssistantData:", error); - return defaultState; + return []; } } From d1d8626b405d49527a0beab6919344c8de852e36 Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Fri, 18 Jul 2025 16:15:11 -0700 Subject: [PATCH 11/78] feat: KG improvements (#5048) * improvements * drop views if SQL fails * mypy fix --- .../kb_search/nodes/a3_generate_simple_sql.py | 48 +++++++++++++++++-- backend/onyx/db/kg_temp_view.py | 9 ++-- 2 files changed, 49 insertions(+), 8 deletions(-) diff --git a/backend/onyx/agents/agent_search/kb_search/nodes/a3_generate_simple_sql.py b/backend/onyx/agents/agent_search/kb_search/nodes/a3_generate_simple_sql.py index 53f823017bc..181a15fcee2 100644 --- a/backend/onyx/agents/agent_search/kb_search/nodes/a3_generate_simple_sql.py +++ b/backend/onyx/agents/agent_search/kb_search/nodes/a3_generate_simple_sql.py @@ -203,6 +203,8 @@ def generate_simple_sql( if state.kg_entity_temp_view_name is None: raise ValueError("kg_entity_temp_view_name is not set") + sql_statement_display: str | None = None + ## STEP 3 - articulate goals stream_write_step_activities(writer, _KG_STEP_NR) @@ -381,7 +383,18 @@ def generate_simple_sql( raise e - logger.debug(f"A3 - sql_statement after correction: {sql_statement}") + # display sql statement with view names replaced by general view names + sql_statement_display = sql_statement.replace( + state.kg_doc_temp_view_name, "" + ) + sql_statement_display = sql_statement_display.replace( + state.kg_rel_temp_view_name, "" + ) + sql_statement_display = sql_statement_display.replace( + state.kg_entity_temp_view_name, "" + ) + + logger.debug(f"A3 - sql_statement after correction: {sql_statement_display}") # Get SQL for source documents @@ -409,7 +422,20 @@ def generate_simple_sql( "relationship_table", rel_temp_view ) - logger.debug(f"A3 source_documents_sql: {source_documents_sql}") + if source_documents_sql: + source_documents_sql_display = source_documents_sql.replace( + state.kg_doc_temp_view_name, "" + ) + source_documents_sql_display = source_documents_sql_display.replace( + state.kg_rel_temp_view_name, "" + ) + source_documents_sql_display = source_documents_sql_display.replace( + state.kg_entity_temp_view_name, "" + ) + else: + source_documents_sql_display = "(No source documents SQL generated)" + + logger.debug(f"A3 source_documents_sql: {source_documents_sql_display}") scalar_result = None query_results = None @@ -435,7 +461,13 @@ def generate_simple_sql( rows = result.fetchall() query_results = [dict(row._mapping) for row in rows] except Exception as e: + # TODO: raise error on frontend logger.error(f"Error executing SQL query: {e}") + drop_views( + allowed_docs_view_name=doc_temp_view, + kg_relationships_view_name=rel_temp_view, + kg_entity_view_name=ent_temp_view, + ) raise e @@ -459,8 +491,14 @@ def generate_simple_sql( for source_document_result in query_source_document_results ] except Exception as e: - # No stopping here, the individualized SQL query is not mandatory # TODO: raise error on frontend + + 
drop_views( + allowed_docs_view_name=doc_temp_view, + kg_relationships_view_name=rel_temp_view, + kg_entity_view_name=ent_temp_view, + ) + logger.error(f"Error executing Individualized SQL query: {e}") else: @@ -493,11 +531,11 @@ def generate_simple_sql( if reasoning: stream_write_step_answer_explicit(writer, step_nr=_KG_STEP_NR, answer=reasoning) - if main_sql_statement: + if sql_statement_display: stream_write_step_answer_explicit( writer, step_nr=_KG_STEP_NR, - answer=f" \n Generated SQL: {main_sql_statement}", + answer=f" \n Generated SQL: {sql_statement_display}", ) stream_close_step_answer(writer, _KG_STEP_NR) diff --git a/backend/onyx/db/kg_temp_view.py b/backend/onyx/db/kg_temp_view.py index 5a956ea53e2..baa4b012141 100644 --- a/backend/onyx/db/kg_temp_view.py +++ b/backend/onyx/db/kg_temp_view.py @@ -1,3 +1,5 @@ +import random + from sqlalchemy import text from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Session @@ -17,10 +19,11 @@ def get_user_view_names(user_email: str, tenant_id: str) -> KGViewNames: user_email_cleaned = ( user_email.replace("@", "__").replace(".", "_").replace("+", "_") ) + random_suffix_str = str(random.randint(1000000, 9999999)) return KGViewNames( - allowed_docs_view_name=f'"{tenant_id}".{KG_TEMP_ALLOWED_DOCS_VIEW_NAME_PREFIX}_{user_email_cleaned}', - kg_relationships_view_name=f'"{tenant_id}".{KG_TEMP_KG_RELATIONSHIPS_VIEW_NAME_PREFIX}_{user_email_cleaned}', - kg_entity_view_name=f'"{tenant_id}".{KG_TEMP_KG_ENTITIES_VIEW_NAME_PREFIX}_{user_email_cleaned}', + allowed_docs_view_name=f'"{tenant_id}".{KG_TEMP_ALLOWED_DOCS_VIEW_NAME_PREFIX}_{user_email_cleaned}_{random_suffix_str}', + kg_relationships_view_name=f'"{tenant_id}".{KG_TEMP_KG_RELATIONSHIPS_VIEW_NAME_PREFIX}_{user_email_cleaned}_{random_suffix_str}', + kg_entity_view_name=f'"{tenant_id}".{KG_TEMP_KG_ENTITIES_VIEW_NAME_PREFIX}_{user_email_cleaned}_{random_suffix_str}', ) From 2b856d40d4574d38cff079b95fffb5039c561cc6 Mon Sep 17 00:00:00 2001 From: Rei Meguro <36625832+Orbital-Web@users.noreply.github.com> Date: Sat, 19 Jul 2025 10:51:51 +0900 Subject: [PATCH 12/78] feat: Search and Answer Quality Test Script (#4974) * aefads * search quality tests improvement Co-authored-by: wenxi-onyx * nits * refactor: config refactor * document context + skip genai fix * feat: answer eval * more error messages * mypy ragas * mypy * small fixes * feat: more metrics * fix * feat: grab content * typing * feat: lazy updates * mypy * all at front * feat: answer correctness * use api key so it works with auth enabled * update readme * feat: auto add path * feat: rate limit * fix: readme + remove rerank all * fix: raise exception immediately * docs: improved clarity * feat: federated handling * fix: mypy * nits --------- Co-authored-by: wenxi-onyx --- backend/onyx/chat/process_message.py | 1 + .../tests/regression/search_quality/README.md | 60 +- .../tests/regression/search_quality/models.py | 82 ++ .../search_quality/run_search_eval.py | 822 +++++++++++++++--- .../search_eval_config.yaml.template | 16 - .../search_quality/test_queries.json.template | 10 +- .../regression/search_quality/util_config.py | 75 -- .../regression/search_quality/util_data.py | 166 ---- .../regression/search_quality/util_eval.py | 94 -- .../search_quality/util_retrieve.py | 88 -- .../tests/regression/search_quality/utils.py | 208 +++++ 11 files changed, 1023 insertions(+), 599 deletions(-) create mode 100644 backend/tests/regression/search_quality/models.py delete mode 100644 
backend/tests/regression/search_quality/search_eval_config.yaml.template delete mode 100644 backend/tests/regression/search_quality/util_config.py delete mode 100644 backend/tests/regression/search_quality/util_data.py delete mode 100644 backend/tests/regression/search_quality/util_eval.py delete mode 100644 backend/tests/regression/search_quality/util_retrieve.py create mode 100644 backend/tests/regression/search_quality/utils.py diff --git a/backend/onyx/chat/process_message.py b/backend/onyx/chat/process_message.py index 1a7f64b5003..dbb04d3a963 100644 --- a/backend/onyx/chat/process_message.py +++ b/backend/onyx/chat/process_message.py @@ -1012,6 +1012,7 @@ def create_response( tools=tools, db_session=db_session, use_agentic_search=new_msg_req.use_agentic_search, + skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation, ) info_by_subq: dict[SubQuestionKey, AnswerPostInfo] = defaultdict( diff --git a/backend/tests/regression/search_quality/README.md b/backend/tests/regression/search_quality/README.md index db35cf972fd..c4eb87d5ac0 100644 --- a/backend/tests/regression/search_quality/README.md +++ b/backend/tests/regression/search_quality/README.md @@ -1,62 +1,50 @@ # Search Quality Test Script -This Python script evaluates the search results for a list of queries. - -This script will likely get refactored in the future as an API endpoint. -In the meanwhile, it is used to evaluate the search quality using locally ingested documents. -The key differentiating factor with `answer_quality` is that it can evaluate results without explicit "ground truth" using the reranker as a reference. +This Python script evaluates the search and answer quality for a list of queries, against a ground truth. It will use the currently ingested documents for the search, answer generation, and ground truth comparisons. ## Usage 1. Ensure you have the required dependencies installed and onyx running. -2. Ensure a reranker model is configured in the search settings. -This can be checked/modified by opening the admin panel, going to search settings, and ensuring a reranking model is set. - -3. Set up the PYTHONPATH permanently: - Add the following line to your shell configuration file (e.g., `~/.bashrc`, `~/.zshrc`, or `~/.bash_profile`): - ``` - export PYTHONPATH=$PYTHONPATH:/path/to/onyx/backend - ``` - Replace `/path/to/onyx` with the actual path to your Onyx repository. - After adding this line, restart your terminal or run `source ~/.bashrc` (or the appropriate config file) to apply the changes. +2. Ensure you have `OPENAI_API_KEY` set if you intend to do answer evaluation (enabled by default, unless you run the script with the `-s` flag). Also, if you're not using `AUTH_TYPE=disabled`, go to the API Keys page in the admin panel, generate a basic api token, and add it to the env file as `ONYX_API_KEY=on_...`. -4. Navigate to Onyx repo, search_quality folder: +3. Navigate to Onyx repo, **search_quality** folder: ``` cd path/to/onyx/backend/tests/regression/search_quality ``` -5. Copy `test_queries.json.template` to `test_queries.json` and add/remove test queries in it. The possible fields are: +4. Copy `test_queries.json.template` to `test_queries.json` and add/remove test queries in it. 
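For illustration, a single entry in `test_queries.json` might look like the sketch below; the values are hypothetical, `test_queries.json.template` remains the authoritative format, and each `doc_link` must resolve to a document that is already indexed:

```
[
  {
    "question": "How do I rotate an API key?",
    "ground_truth": [
      { "doc_source": "web", "doc_link": "https://docs.example.com/admin/api-keys" }
    ],
    "ground_truth_response": "Keys are rotated from the admin panel; the old key is revoked immediately.",
    "categories": ["admin"]
  }
]
```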
The fields for each query are: - `question: str` the query - - `question_search: Optional[str]` modified query specifically for the search step - - `ground_truth: Optional[list[GroundTruth]]` a ranked list of expected search results with fields: - - `doc_source: str` document source (e.g., Web, Drive, Linear), currently unused + - `ground_truth: list[GroundTruth]` an un-ranked list of expected search results with fields: + - `doc_source: str` document source (e.g., web, google_drive, linear), used to normalize the links in some cases - `doc_link: str` link associated with document, used to find corresponding document in local index + - `ground_truth_response: Optional[str]` a response with clauses the ideal answer should include - `categories: Optional[list[str]]` list of categories, used to aggregate evaluation results -6. Copy `search_eval_config.yaml.template` to `search_eval_config.yaml` and specify the search and eval parameters - -7. Run `run_search_eval.py` to run the search and evaluate the search results +5. Run `run_search_eval.py` to evaluate the queries. All parameters are optional and have sensible defaults: ``` python run_search_eval.py + -d --dataset # Path to the test-set JSON file (default: ./test_queries.json) + -n --num_search # Maximum number of documents to retrieve per search (default: 50) + -a --num_answer # Maximum number of documents to use for answer evaluation (default: 25) + -w --max_workers # Maximum number of concurrent search requests (0 = unlimited, default: 10). + -r --max_req_rate # Maximum number of search requests per minute (0 = unlimited, default: 0). + -q --timeout # Request timeout in seconds (default: 120) + -e --api_endpoint # Base URL of the Onyx API server (default: http://127.0.0.1:8080) + -s --search_only # Only perform search and not answer evaluation (default: false) + -t --tenant_id # Tenant ID to use for the evaluation (default: None) ``` -8. Optionally, save the generated `test_queries.json` in the export folder to reuse the generated `question_search`, and rerun the search evaluation with alternative search parameters. - -## Metrics -There are two main metrics currently implemented: -- ratio_topk: the ratio of documents in the comparison set that are in the topk search results (higher is better, 0-1) -- avg_rank_delta: the average rank difference between the comparison set and search results (lower is better, 0-inf) - -Ratio topk gives a general idea on whether the most relevant documents are appearing first in the search results. Decreasing `eval_topk` will make this metric stricter, requiring relevant documents to appear in a narrow window. - -Avg rank delta is another metric which can give insight on the performance of documents not in the topk search results. If none of the comparison documents are in the topk, `ratio_topk` will only show a 0, whereas `avg_rank_delta` will show a higher value the worse the search results gets. +Note: If you only care about search quality, you should run with the `-s` flag for a significantly faster evaluation. Furthermore, you should set `-r` to 1 if running with federated search enabled to avoid hitting rate limits. -Furthermore, there are two versions of the metrics: ground truth, and soft truth. +6. After the run, an `eval-YYYY-MM-DD-HH-MM-SS` folder is created containing: -The ground truth includes documents explicitly listed as relevant in the test dataset. The ground truth metrics will only be computed if a ground truth set is provided for the question and exists in the index. 
+ * `test_queries.json` – the dataset used with the list of valid queries and corresponding indexed ground truth. + * `search_results.json` – per-query search and answer details. + * `results_by_category.csv` – aggregated metrics per category and for "all". + * `search_position_chart.png` – bar-chart of ground-truth ranks. -The soft truth is built on top of the ground truth (if provided), filling the remaining entries with results from the reranker. The soft truth metrics will only be computed if `skip_rerank` is false. Computing the soft truth metric can be extremely slow, especially for large `num_returned_hits`. However, it can provide a good basis when there are many relevant documents in no particular order, or for running quick tests without explicitly having to mention which documents are relevant. \ No newline at end of file +You can replace `test_queries.json` with the generated one for a slightly faster loading of the queries the next time around. \ No newline at end of file diff --git a/backend/tests/regression/search_quality/models.py b/backend/tests/regression/search_quality/models.py new file mode 100644 index 00000000000..b8c00e003fa --- /dev/null +++ b/backend/tests/regression/search_quality/models.py @@ -0,0 +1,82 @@ +from pydantic import BaseModel + +from onyx.configs.constants import DocumentSource +from onyx.context.search.models import SavedSearchDoc + + +class GroundTruth(BaseModel): + doc_source: DocumentSource + doc_link: str + + +class TestQuery(BaseModel): + question: str + ground_truth: list[GroundTruth] = [] + ground_truth_response: str | None = None + categories: list[str] = [] + + # autogenerated + ground_truth_docids: list[str] = [] + + +class EvalConfig(BaseModel): + max_search_results: int + max_answer_context: int + num_workers: int # 0 = unlimited + max_request_rate: int # 0 = unlimited + request_timeout: int + api_url: str + search_only: bool + + +class OneshotQAResult(BaseModel): + time_taken: float + top_documents: list[SavedSearchDoc] + answer: str | None + + +class RetrievedDocument(BaseModel): + document_id: str + chunk_id: int + content: str + + +class AnalysisSummary(BaseModel): + question: str + categories: list[str] + found: bool + rank: int | None + total_results: int + ground_truth_count: int + response_relevancy: float | None = None + faithfulness: float | None = None + factual_correctness: float | None = None + answer: str | None = None + retrieved: list[RetrievedDocument] = [] + time_taken: float + + +class SearchMetrics(BaseModel): + total_queries: int + found_count: int + + # for found results + best_rank: int + worst_rank: int + average_rank: float + top_k_accuracy: dict[int, float] + + +class AnswerMetrics(BaseModel): + response_relevancy: float + faithfulness: float + factual_correctness: float + + # only for metric computation + n_response_relevancy: int + n_faithfulness: int + n_factual_correctness: int + + +class CombinedMetrics(SearchMetrics, AnswerMetrics): + average_time_taken: float diff --git a/backend/tests/regression/search_quality/run_search_eval.py b/backend/tests/regression/search_quality/run_search_eval.py index 1c3d03744d5..43dcd55474b 100644 --- a/backend/tests/regression/search_quality/run_search_eval.py +++ b/backend/tests/regression/search_quality/run_search_eval.py @@ -1,151 +1,725 @@ import csv +import json +import os +import sys +import time from collections import defaultdict +from concurrent.futures import as_completed +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime from pathlib 
import Path +from threading import Event +from threading import Lock +from threading import Semaphore +from typing import cast +import matplotlib.pyplot as plt # type: ignore +import requests +from dotenv import load_dotenv +from matplotlib.patches import Patch # type: ignore +from pydantic import ValidationError +from requests.exceptions import RequestException +from retry import retry + +# add onyx/backend to path (since this isn't done automatically when running as a script) +current_dir = Path(__file__).parent +onyx_dir = current_dir.parent.parent.parent.parent +sys.path.append(str(onyx_dir / "backend")) + +# load env before app_config loads (since env doesn't get loaded when running as a script) +env_path = onyx_dir / ".vscode" / ".env" +if not env_path.exists(): + raise RuntimeError( + "Could not find .env file. Please create one in the root .vscode directory." + ) +load_dotenv(env_path) + +# pylint: disable=E402 +# flake8: noqa: E402 + +from ee.onyx.server.query_and_chat.models import OneShotQARequest +from ee.onyx.server.query_and_chat.models import OneShotQAResponse +from onyx.chat.models import ThreadMessage from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE -from onyx.context.search.models import RerankingDetails -from onyx.db.engine.sql_engine import get_session_with_current_tenant +from onyx.configs.app_configs import AUTH_TYPE +from onyx.configs.constants import AuthType +from onyx.configs.constants import MessageType +from onyx.context.search.enums import OptionalSearchSetting +from onyx.context.search.models import IndexFilters +from onyx.context.search.models import RetrievalDetails +from onyx.db.engine.sql_engine import get_session_with_tenant from onyx.db.engine.sql_engine import SqlEngine -from onyx.db.search_settings import get_current_search_settings -from onyx.db.search_settings import get_multilingual_expansion -from onyx.document_index.factory import get_default_document_index from onyx.utils.logger import setup_logger from shared_configs.configs import MULTI_TENANT -from tests.regression.search_quality.util_config import load_config -from tests.regression.search_quality.util_data import export_test_queries -from tests.regression.search_quality.util_data import load_test_queries -from tests.regression.search_quality.util_eval import evaluate_one_query -from tests.regression.search_quality.util_eval import get_corresponding_document -from tests.regression.search_quality.util_eval import metric_names -from tests.regression.search_quality.util_retrieve import rerank_one_query -from tests.regression.search_quality.util_retrieve import search_one_query +from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE +from tests.regression.search_quality.models import AnalysisSummary +from tests.regression.search_quality.models import CombinedMetrics +from tests.regression.search_quality.models import EvalConfig +from tests.regression.search_quality.models import OneshotQAResult +from tests.regression.search_quality.models import TestQuery +from tests.regression.search_quality.utils import compute_overall_scores +from tests.regression.search_quality.utils import find_document_id +from tests.regression.search_quality.utils import get_federated_sources +from tests.regression.search_quality.utils import LazyJsonWriter +from tests.regression.search_quality.utils import ragas_evaluate +from tests.regression.search_quality.utils import search_docs_to_doc_contexts logger = 
setup_logger(__name__) +GENERAL_HEADERS = {"Content-Type": "application/json"} +TOP_K_LIST = [1, 3, 5, 10] -def run_search_eval() -> None: - config = load_config() - test_queries = load_test_queries() - # export related - export_path = Path(config.export_folder) - export_test_queries(test_queries, export_path / "test_queries.json") - search_result_path = export_path / "search_results.csv" - eval_path = export_path / "eval_results.csv" - aggregate_eval_path = export_path / "aggregate_eval.csv" - aggregate_results: dict[str, list[list[float]]] = defaultdict( - lambda: [[] for _ in metric_names] - ) +class SearchAnswerAnalyzer: + def __init__( + self, + config: EvalConfig, + tenant_id: str | None = None, + ): + if not MULTI_TENANT: + logger.info("Running in single-tenant mode") + tenant_id = POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE + elif tenant_id is None: + raise ValueError("Tenant ID is required for multi-tenant") + + self.config = config + self.tenant_id = tenant_id - with get_session_with_current_tenant() as db_session: - multilingual_expansion = get_multilingual_expansion(db_session) - search_settings = get_current_search_settings(db_session) - document_index = get_default_document_index(search_settings, None) - rerank_settings = RerankingDetails.from_db_model(search_settings) - - if config.skip_rerank: - logger.warning("Reranking is disabled, evaluation will not run") - elif rerank_settings.rerank_model_name is None: - raise ValueError( - "Reranking is enabled but no reranker is configured. " - "Please set the reranker in the admin panel search settings." + # shared analysis results + self._lock = Lock() + self._progress_counter = 0 + self._result_writer: LazyJsonWriter | None = None + self.ranks: list[int | None] = [] + self.metrics: dict[str, CombinedMetrics] = defaultdict( + lambda: CombinedMetrics( + total_queries=0, + found_count=0, + best_rank=config.max_search_results, + worst_rank=1, + average_rank=0.0, + top_k_accuracy={k: 0.0 for k in TOP_K_LIST}, + response_relevancy=0.0, + faithfulness=0.0, + factual_correctness=0.0, + n_response_relevancy=0, + n_faithfulness=0, + n_factual_correctness=0, + average_time_taken=0.0, ) + ) + + def run_analysis(self, dataset_path: Path, export_path: Path) -> None: + # load and save the dataset + dataset = self._load_dataset(dataset_path) + dataset_size = len(dataset) + dataset_export_path = export_path / "test_queries.json" + with dataset_export_path.open("w") as f: + dataset_serializable = [q.model_dump(mode="json") for q in dataset] + json.dump(dataset_serializable, f, indent=4) - # run search and evaluate - logger.info( - "Running search and evaluation... 
" - f"Individual search and evaluation results will be saved to {search_result_path} and {eval_path}" + result_export_path = export_path / "search_results.json" + self._result_writer = LazyJsonWriter(result_export_path) + + # set up rate limiting and threading primitives + interval = ( + 60.0 / self.config.max_request_rate + if self.config.max_request_rate > 0 + else 0.0 ) - with ( - search_result_path.open("w") as search_file, - eval_path.open("w") as eval_file, - ): - search_csv_writer = csv.writer(search_file) - eval_csv_writer = csv.writer(eval_file) - search_csv_writer.writerow( - ["source", "query", "rank", "score", "doc_id", "chunk_id"] - ) - eval_csv_writer.writerow(["query", *metric_names]) - - for query in test_queries: - # search and write results - assert query.question_search is not None - search_chunks = search_one_query( - query.question_search, - multilingual_expansion, - document_index, - db_session, - config, - ) - for rank, result in enumerate(search_chunks): - search_csv_writer.writerow( + available_workers = Semaphore(self.config.num_workers) + stop_event = Event() + + def _submit_wrapper(tc: TestQuery) -> AnalysisSummary: + try: + return self._run_and_analyze_one(tc, dataset_size) + except Exception as e: + logger.error("Error during analysis: %s", e) + stop_event.set() + raise + finally: + available_workers.release() + + # run the analysis + logger.info("Starting analysis of %d queries", dataset_size) + logger.info("Using %d parallel workers", self.config.num_workers) + logger.info("Exporting search results to %s", result_export_path) + + with ThreadPoolExecutor( + max_workers=self.config.num_workers or None + ) as executor: + # submit requests at configured rate, break early if any error occurs + futures = [] + for tc in dataset: + if stop_event.is_set(): + break + + available_workers.acquire() + fut = executor.submit(_submit_wrapper, tc) + futures.append(fut) + + if ( + len(futures) != dataset_size + and interval > 0 + and not stop_event.is_set() + ): + time.sleep(interval) + + # ensure all tasks finish and surface any exceptions + for fut in as_completed(futures): + fut.result() + + if self._result_writer: + self._result_writer.close() + self._aggregate_metrics() + + def generate_detailed_report(self, export_path: Path) -> None: + logger.info("Generating detailed report...") + + csv_path = export_path / "results_by_category.csv" + with csv_path.open("w", newline="") as csv_file: + csv_writer = csv.writer(csv_file) + csv_writer.writerow( + [ + "category", + "total_queries", + "found", + "percent_found", + "best_rank", + "worst_rank", + "avg_rank", + *[f"top_{k}_accuracy" for k in TOP_K_LIST], + *( [ - "search", - query.question_search, - rank, - result.score, - result.document_id, - result.chunk_id, + "avg_response_relevancy", + "avg_faithfulness", + "avg_factual_correctness", ] - ) + if not self.config.search_only + else [] + ), + "search_score", + *(["answer_score"] if not self.config.search_only else []), + "avg_time_taken", + ] + ) + + for category, metrics in sorted( + self.metrics.items(), key=lambda c: (0 if c[0] == "all" else 1, c[0]) + ): + found_count = metrics.found_count + total_count = metrics.total_queries + accuracy = found_count / total_count * 100 if total_count > 0 else 0 - rerank_chunks = [] - if not config.skip_rerank: - # rerank and write results - rerank_chunks = rerank_one_query( - query.question, search_chunks, rerank_settings + print( + f"\n{category.upper()}:" + f" total queries: {total_count}\n" + f" found: {found_count} 
({accuracy:.1f}%)" + ) + best_rank = metrics.best_rank if metrics.found_count > 0 else None + worst_rank = metrics.worst_rank if metrics.found_count > 0 else None + avg_rank = metrics.average_rank if metrics.found_count > 0 else None + if metrics.found_count > 0: + print( + f" average rank (for found results): {avg_rank:.2f}\n" + f" best rank (for found results): {best_rank:.2f}\n" + f" worst rank (for found results): {worst_rank:.2f}" ) - for rank, result in enumerate(rerank_chunks): - search_csv_writer.writerow( - [ - "rerank", - query.question, - rank, - result.score, - result.document_id, - result.chunk_id, - ] + for k, acc in metrics.top_k_accuracy.items(): + print(f" top-{k} accuracy: {acc:.1f}%") + if not self.config.search_only: + if metrics.n_response_relevancy > 0: + print( + f" average response relevancy: {metrics.response_relevancy:.2f}" ) + if metrics.n_faithfulness > 0: + print(f" average faithfulness: {metrics.faithfulness:.2f}") + if metrics.n_factual_correctness > 0: + print( + f" average factual correctness: {metrics.factual_correctness:.2f}" + ) + search_score, answer_score = compute_overall_scores(metrics) + print(f" search score: {search_score:.1f}") + if not self.config.search_only: + print(f" answer score: {answer_score:.1f}") + print(f" average time taken: {metrics.average_time_taken:.2f}s") - # evaluate and write results - truth_documents = [ - doc - for truth in query.ground_truth - if (doc := get_corresponding_document(truth.doc_link, db_session)) - ] - metrics = evaluate_one_query( - search_chunks, rerank_chunks, truth_documents, config.eval_topk - ) - metric_vals = [getattr(metrics, field) for field in metric_names] - eval_csv_writer.writerow([query.question, *metric_vals]) - - # add to aggregation - for category in ["all"] + query.categories: - for i, val in enumerate(metric_vals): - if val is not None: - aggregate_results[category][i].append(val) - - # aggregate and write results - with aggregate_eval_path.open("w") as file: - aggregate_csv_writer = csv.writer(file) - aggregate_csv_writer.writerow(["category", *metric_names]) - - for category, agg_metrics in aggregate_results.items(): - aggregate_csv_writer.writerow( + csv_writer.writerow( [ category, + total_count, + found_count, + f"{accuracy:.1f}", + best_rank or "", + worst_rank or "", + f"{avg_rank:.2f}" if avg_rank is not None else "", + *[f"{acc:.1f}" for acc in metrics.top_k_accuracy.values()], + *( + [ + ( + f"{metrics.response_relevancy:.2f}" + if metrics.n_response_relevancy > 0 + else "" + ), + ( + f"{metrics.faithfulness:.2f}" + if metrics.n_faithfulness > 0 + else "" + ), + ( + f"{metrics.factual_correctness:.2f}" + if metrics.n_factual_correctness > 0 + else "" + ), + ] + if not self.config.search_only + else [] + ), + f"{search_score:.1f}", *( - sum(metric) / len(metric) if metric else None - for metric in agg_metrics + [f"{answer_score:.1f}"] + if not self.config.search_only + else [] ), + f"{metrics.average_time_taken:.2f}", ] ) + logger.info("Saved category breakdown csv to %s", csv_path) + + def generate_chart(self, export_path: Path) -> None: + logger.info("Generating search position chart...") + + if len(self.ranks) == 0: + logger.warning("No results to chart") + return + + found_count = 0 + not_found_count = 0 + rank_counts: dict[int, int] = defaultdict(int) + for rank in self.ranks: + if rank is None: + not_found_count += 1 + else: + found_count += 1 + rank_counts[rank] += 1 + + # create the data for plotting + if found_count: + max_rank = max(rank_counts.keys()) + positions = 
list(range(1, max_rank + 1)) + counts = [rank_counts.get(pos, 0) for pos in positions] + else: + positions = [] + counts = [] + + # add the "not found" bar on the far right + if not_found_count: + # add some spacing between found positions and "not found" + not_found_position = (max(positions) + 2) if positions else 1 + positions.append(not_found_position) + counts.append(not_found_count) + + # create labels for x-axis + x_labels = [str(pos) for pos in positions[:-1]] + [ + f"not found\n(>{self.config.max_search_results})" + ] + else: + x_labels = [str(pos) for pos in positions] + + # create the figure and bar chart + plt.figure(figsize=(14, 6)) + + # use different colors for found vs not found + colors = ( + ["#3498db"] * (len(positions) - 1) + ["#e74c3c"] + if not_found_count > 0 + else ["#3498db"] * len(positions) + ) + bars = plt.bar( + positions, counts, color=colors, alpha=0.7, edgecolor="black", linewidth=0.5 + ) + + # customize the chart + plt.xlabel("Position in Search Results", fontsize=12) + plt.ylabel("Number of Ground Truth Documents", fontsize=12) + plt.title( + "Ground Truth Document Positions in Search Results", + fontsize=14, + fontweight="bold", + ) + plt.grid(axis="y", alpha=0.3) + + # add value labels on top of each bar + for bar, count in zip(bars, counts): + if count > 0: + plt.text( + bar.get_x() + bar.get_width() / 2, + bar.get_height() + 0.1, + str(count), + ha="center", + va="bottom", + fontweight="bold", + ) + + # set x-axis labels + plt.xticks(positions, x_labels, rotation=45 if not_found_count > 0 else 0) + + # add legend if we have both found and not found + if not_found_count and found_count: + legend_elements = [ + Patch(facecolor="#3498db", alpha=0.7, label="Found in Results"), + Patch(facecolor="#e74c3c", alpha=0.7, label="Not Found"), + ] + plt.legend(handles=legend_elements, loc="upper right") + + # make layout tight and save + plt.tight_layout() + chart_file = export_path / "search_position_chart.png" + plt.savefig(chart_file, dpi=300, bbox_inches="tight") + logger.info("Search position chart saved to: %s", chart_file) + plt.show() + + def _load_dataset(self, dataset_path: Path) -> list[TestQuery]: + """Load the test dataset from a JSON file and validate the ground truth documents.""" + with dataset_path.open("r") as f: + dataset_raw: list[dict] = json.load(f) + + with get_session_with_tenant(tenant_id=self.tenant_id) as db_session: + federated_sources = get_federated_sources(db_session) + + dataset: list[TestQuery] = [] + for datum in dataset_raw: + # validate the raw datum + try: + test_query = TestQuery(**datum) + except ValidationError as e: + logger.error("Incorrectly formatted query %s: %s", datum, e) + continue + + # in case the dataset was copied from the previous run export + if test_query.ground_truth_docids: + dataset.append(test_query) + continue + + # validate and get the ground truth documents + with get_session_with_tenant(tenant_id=self.tenant_id) as db_session: + for ground_truth in test_query.ground_truth: + if ( + doc_id := find_document_id( + ground_truth, federated_sources, db_session + ) + ) is not None: + test_query.ground_truth_docids.append(doc_id) + + if len(test_query.ground_truth_docids) == 0: + logger.warning( + "No ground truth documents found for query: %s, skipping...", + test_query.question, + ) + continue + + dataset.append(test_query) + + return dataset + + @retry(tries=3, delay=1, backoff=2) + def _perform_oneshot_qa(self, query: str) -> OneshotQAResult: + """Perform a OneShot QA query against the Onyx API and time 
it.""" + # create the OneShot QA request + messages = [ThreadMessage(message=query, sender=None, role=MessageType.USER)] + filters = IndexFilters(access_control_list=None, tenant_id=self.tenant_id) + qa_request = OneShotQARequest( + messages=messages, + persona_id=0, # default persona + retrieval_options=RetrievalDetails( + run_search=OptionalSearchSetting.ALWAYS, + real_time=True, + filters=filters, + enable_auto_detect_filters=False, + limit=self.config.max_search_results, + ), + return_contexts=True, + skip_gen_ai_answer_generation=self.config.search_only, + ) + + # send the request + response = None + try: + request_data = qa_request.model_dump() + headers = GENERAL_HEADERS.copy() + if AUTH_TYPE != AuthType.DISABLED: + headers["Authorization"] = f"Bearer {os.environ.get('ONYX_API_KEY')}" + + start_time = time.monotonic() + response = requests.post( + url=f"{self.config.api_url}/query/answer-with-citation", + json=request_data, + headers=headers, + timeout=self.config.request_timeout, + ) + time_taken = time.monotonic() - start_time + response.raise_for_status() + result = OneShotQAResponse.model_validate(response.json()) + + # extract documents from the QA response + if result.docs: + top_documents = result.docs.top_documents + return OneshotQAResult( + time_taken=time_taken, + top_documents=top_documents, + answer=result.answer, + ) + except RequestException as e: + raise RuntimeError( + f"OneShot QA failed for query '{query}': {e}." + f" Response: {response.json()}" + if response + else "" + ) + raise RuntimeError(f"OneShot QA returned no documents for query {query}") + + def _run_and_analyze_one(self, test_case: TestQuery, total: int) -> AnalysisSummary: + result = self._perform_oneshot_qa(test_case.question) + + # compute rank + rank = None + found = False + ground_truths = set(test_case.ground_truth_docids) + for i, doc in enumerate(result.top_documents, 1): + if doc.document_id in ground_truths: + rank = i + found = True + break + + # print search progress and result + with self._lock: + self._progress_counter += 1 + completed = self._progress_counter + status = "✓ Found" if found else "✗ Not found" + rank_info = f" (rank {rank})" if found else "" + question_snippet = ( + test_case.question[:50] + "..." 
+ if len(test_case.question) > 50 + else test_case.question + ) + print(f"[{completed}/{total}] {status}{rank_info}: {question_snippet}") + + # get the search contents + retrieved = search_docs_to_doc_contexts(result.top_documents, self.tenant_id) + + # do answer evaluation + response_relevancy: float | None = None + faithfulness: float | None = None + factual_correctness: float | None = None + contexts = [c.content for c in retrieved[: self.config.max_answer_context]] + if not self.config.search_only: + if result.answer is None: + logger.error( + "No answer found for query: %s, skipping answer evaluation", + test_case.question, + ) + else: + try: + ragas_result = ragas_evaluate( + question=test_case.question, + answer=result.answer, + contexts=contexts, + reference_answer=test_case.ground_truth_response, + ).scores[0] + response_relevancy = ragas_result["answer_relevancy"] + faithfulness = ragas_result["faithfulness"] + factual_correctness = ragas_result.get( + "factual_correctness(mode=recall)" + ) + except Exception as e: + logger.error( + "Error evaluating answer for query %s: %s", + test_case.question, + e, + ) + + # save results + analysis = AnalysisSummary( + question=test_case.question, + categories=test_case.categories, + found=found, + rank=rank, + total_results=len(result.top_documents), + ground_truth_count=len(test_case.ground_truth_docids), + answer=result.answer, + response_relevancy=response_relevancy, + faithfulness=faithfulness, + factual_correctness=factual_correctness, + retrieved=retrieved, + time_taken=result.time_taken, + ) + with self._lock: + self.ranks.append(analysis.rank) + if self._result_writer: + self._result_writer.append(analysis.model_dump(mode="json")) + self._update_metrics(analysis) + + return analysis + + def _update_metrics(self, result: AnalysisSummary) -> None: + for cat in result.categories + ["all"]: + self.metrics[cat].total_queries += 1 + self.metrics[cat].average_time_taken += result.time_taken + + if result.found: + self.metrics[cat].found_count += 1 + + rank = cast(int, result.rank) + self.metrics[cat].best_rank = min(self.metrics[cat].best_rank, rank) + self.metrics[cat].worst_rank = max(self.metrics[cat].worst_rank, rank) + self.metrics[cat].average_rank += rank + for k in TOP_K_LIST: + self.metrics[cat].top_k_accuracy[k] += int(rank <= k) + + if self.config.search_only: + continue + if result.response_relevancy is not None: + self.metrics[cat].response_relevancy += result.response_relevancy + self.metrics[cat].n_response_relevancy += 1 + if result.faithfulness is not None: + self.metrics[cat].faithfulness += result.faithfulness + self.metrics[cat].n_faithfulness += 1 + if result.factual_correctness is not None: + self.metrics[cat].factual_correctness += result.factual_correctness + self.metrics[cat].n_factual_correctness += 1 + + def _aggregate_metrics(self) -> None: + for cat in self.metrics: + total = self.metrics[cat].total_queries + self.metrics[cat].average_time_taken /= total + + if self.metrics[cat].found_count > 0: + self.metrics[cat].average_rank /= self.metrics[cat].found_count + for k in TOP_K_LIST: + self.metrics[cat].top_k_accuracy[k] /= total + self.metrics[cat].top_k_accuracy[k] *= 100 + + if self.config.search_only: + continue + if (n := self.metrics[cat].n_response_relevancy) > 0: + self.metrics[cat].response_relevancy /= n + if (n := self.metrics[cat].n_faithfulness) > 0: + self.metrics[cat].faithfulness /= n + if (n := self.metrics[cat].n_factual_correctness) > 0: + self.metrics[cat].factual_correctness /= n + + +def 
run_search_eval( + dataset_path: Path, + config: EvalConfig, + tenant_id: str | None, +) -> None: + # check openai api key is set if doing answer eval (must be called that for ragas to recognize) + if not config.search_only and not os.environ.get("OPENAI_API_KEY"): + raise RuntimeError( + "OPENAI_API_KEY is required for answer evaluation. " + "Please add it to the root .vscode/.env file." + ) + + # check onyx api key is set if auth is enabled + if AUTH_TYPE != AuthType.DISABLED and not os.environ.get("ONYX_API_KEY"): + raise RuntimeError( + "ONYX_API_KEY is required if auth is enabled. " + "Please create one in the admin panel and add it to the root .vscode/.env file." + ) + + # check onyx is running + try: + response = requests.get( + f"{config.api_url}/health", timeout=config.request_timeout + ) + response.raise_for_status() + except RequestException as e: + raise RuntimeError(f"Could not connect to Onyx API: {e}") + + # create the export folder + export_folder = current_dir / datetime.now().strftime("eval-%Y-%m-%d-%H-%M-%S") + export_path = Path(export_folder) + export_path.mkdir(parents=True, exist_ok=True) + logger.info("Created export folder: %s", export_path) + + # run the search eval + analyzer = SearchAnswerAnalyzer(config=config, tenant_id=tenant_id) + analyzer.run_analysis(dataset_path, export_path) + analyzer.generate_detailed_report(export_path) + analyzer.generate_chart(export_path) if __name__ == "__main__": - if MULTI_TENANT: - raise ValueError("Multi-tenant is not supported currently") + import argparse + + current_dir = Path(__file__).parent + parser = argparse.ArgumentParser(description="Run search quality evaluation.") + parser.add_argument( + "-d", + "--dataset", + type=Path, + default=current_dir / "test_queries.json", + help="Path to the test-set JSON file (default: %(default)s).", + ) + parser.add_argument( + "-n", + "--num_search", + type=int, + default=50, + help="Maximum number of documents to retrieve per search (default: %(default)s).", + ) + parser.add_argument( + "-a", + "--num_answer", + type=int, + default=25, + help="Maximum number of documents to use for answer evaluation (default: %(default)s).", + ) + parser.add_argument( + "-w", + "--max_workers", + type=int, + default=10, + help="Maximum number of concurrent search requests (0 = unlimited, default: %(default)s).", + ) + parser.add_argument( + "-r", + "--max_req_rate", + type=int, + default=0, + help="Maximum number of search requests per minute (0 = unlimited, default: %(default)s).", + ) + parser.add_argument( + "-q", + "--timeout", + type=int, + default=120, + help="Request timeout in seconds (default: %(default)s).", + ) + parser.add_argument( + "-e", + "--api_endpoint", + type=str, + default="http://127.0.0.1:8080", + help="Base URL of the Onyx API server (default: %(default)s).", + ) + parser.add_argument( + "-s", + "--search_only", + action="store_true", + default=False, + help="Only perform search and not answer evaluation (default: %(default)s).", + ) + parser.add_argument( + "-t", + "--tenant_id", + type=str, + default=None, + help="Tenant ID to use for the evaluation (default: %(default)s).", + ) + + args = parser.parse_args() SqlEngine.init_engine( pool_size=POSTGRES_API_SERVER_POOL_SIZE, @@ -153,9 +727,21 @@ def run_search_eval() -> None: ) try: - run_search_eval() + run_search_eval( + args.dataset, + EvalConfig( + max_search_results=args.num_search, + max_answer_context=args.num_answer, + num_workers=args.max_workers, + max_request_rate=args.max_req_rate, + request_timeout=args.timeout, + 
api_url=args.api_endpoint, + search_only=args.search_only, + ), + args.tenant_id, + ) except Exception as e: - logger.error(f"Error running search evaluation: {e}") - raise e + logger.error("Unexpected error during search evaluation: %s", e) + raise finally: SqlEngine.reset_engine() diff --git a/backend/tests/regression/search_quality/search_eval_config.yaml.template b/backend/tests/regression/search_quality/search_eval_config.yaml.template deleted file mode 100644 index 68405ebb116..00000000000 --- a/backend/tests/regression/search_quality/search_eval_config.yaml.template +++ /dev/null @@ -1,16 +0,0 @@ -# Search Parameters -HYBRID_ALPHA: 0.5 -HYBRID_ALPHA_KEYWORD: 0.4 -DOC_TIME_DECAY: 0.5 -NUM_RETURNED_HITS: 50 # Setting to a higher value will improve evaluation quality but increase reranking time -RANK_PROFILE: 'semantic' -OFFSET: 0 -TITLE_CONTENT_RATIO: 0.1 -USER_EMAIL: null # User email to use for testing, modifies access control list, null means only public files - -# Evaluation parameters -SKIP_RERANK: false # Whether to skip reranking, reranking must be enabled to evaluate the search results -EVAL_TOPK: 5 # Number of top results from the searcher and reranker to evaluate, lower means stricter evaluation - -# Export file, will export a csv file with the results and a json file with the parameters -EXPORT_FOLDER: "eval-%Y-%m-%d-%H-%M-%S" diff --git a/backend/tests/regression/search_quality/test_queries.json.template b/backend/tests/regression/search_quality/test_queries.json.template index 93e855472b7..e5646d3f03b 100644 --- a/backend/tests/regression/search_quality/test_queries.json.template +++ b/backend/tests/regression/search_quality/test_queries.json.template @@ -3,20 +3,18 @@ "question": "What is Onyx?", "ground_truth": [ { - "doc_source": "Web", + "doc_source": "web", "doc_link": "https://docs.onyx.app/more/use_cases/overview" }, { - "doc_source": "Web", + "doc_source": "web", "doc_link": "https://docs.onyx.app/more/use_cases/ai_platform" } ], "categories": [ "keyword", - "broad" + "broad", + "easy" ] - }, - { - "question": "What is the meaning of life?" 
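
Each entry in the template above is validated into a `TestQuery` before the run, and queries whose ground-truth links cannot be resolved to document ids are skipped. A minimal sketch of that validation step, using simplified stand-ins for the `GroundTruth`/`TestQuery` models that the patch imports from `tests/regression/search_quality/models.py` (not shown here), so the exact field types may differ:

from pydantic import BaseModel, ValidationError


# Simplified stand-ins for the real models; doc_source is an enum in the patch.
class GroundTruth(BaseModel):
    doc_source: str
    doc_link: str


class TestQuery(BaseModel):
    question: str
    ground_truth: list[GroundTruth] = []
    ground_truth_docids: list[str] = []
    categories: list[str] = []


entry = {
    "question": "What is Onyx?",
    "ground_truth": [
        {"doc_source": "web", "doc_link": "https://docs.onyx.app/more/use_cases/overview"}
    ],
    "categories": ["keyword", "broad", "easy"],
}

try:
    query = TestQuery(**entry)
    print(query.question, [gt.doc_link for gt in query.ground_truth])
except ValidationError as e:
    print(f"Incorrectly formatted query: {e}")
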
} ] \ No newline at end of file diff --git a/backend/tests/regression/search_quality/util_config.py b/backend/tests/regression/search_quality/util_config.py deleted file mode 100644 index 4a06b7b9ec5..00000000000 --- a/backend/tests/regression/search_quality/util_config.py +++ /dev/null @@ -1,75 +0,0 @@ -from datetime import datetime -from pathlib import Path - -import yaml -from pydantic import BaseModel - -from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType -from onyx.configs.chat_configs import DOC_TIME_DECAY -from onyx.configs.chat_configs import HYBRID_ALPHA -from onyx.configs.chat_configs import HYBRID_ALPHA_KEYWORD -from onyx.configs.chat_configs import NUM_RETURNED_HITS -from onyx.configs.chat_configs import TITLE_CONTENT_RATIO -from onyx.utils.logger import setup_logger - -logger = setup_logger(__name__) - - -class SearchEvalConfig(BaseModel): - hybrid_alpha: float - hybrid_alpha_keyword: float - doc_time_decay: float - num_returned_hits: int - rank_profile: QueryExpansionType - offset: int - title_content_ratio: float - user_email: str | None - skip_rerank: bool - eval_topk: int - export_folder: str - - -def load_config() -> SearchEvalConfig: - """Loads the search evaluation configs from the config file.""" - # open the config file - current_dir = Path(__file__).parent - config_path = current_dir / "search_eval_config.yaml" - if not config_path.exists(): - raise FileNotFoundError(f"Search eval config file not found at {config_path}") - with config_path.open("r") as file: - config_raw = yaml.safe_load(file) - - # create the export folder - export_folder = config_raw.get("EXPORT_FOLDER", "eval-%Y-%m-%d-%H-%M-%S") - export_folder = datetime.now().strftime(export_folder) - export_path = Path(export_folder) - export_path.mkdir(parents=True, exist_ok=True) - logger.info(f"Created export folder: {export_path}") - - # create the config - config = SearchEvalConfig( - hybrid_alpha=config_raw.get("HYBRID_ALPHA", HYBRID_ALPHA), - hybrid_alpha_keyword=config_raw.get( - "HYBRID_ALPHA_KEYWORD", HYBRID_ALPHA_KEYWORD - ), - doc_time_decay=config_raw.get("DOC_TIME_DECAY", DOC_TIME_DECAY), - num_returned_hits=config_raw.get("NUM_RETURNED_HITS", NUM_RETURNED_HITS), - rank_profile=config_raw.get("RANK_PROFILE", QueryExpansionType.SEMANTIC), - offset=config_raw.get("OFFSET", 0), - title_content_ratio=config_raw.get("TITLE_CONTENT_RATIO", TITLE_CONTENT_RATIO), - user_email=config_raw.get("USER_EMAIL"), - skip_rerank=config_raw.get("SKIP_RERANK", False), - eval_topk=config_raw.get("EVAL_TOPK", 5), - export_folder=export_folder, - ) - logger.info(f"Using search parameters: {config}") - - # export the config - config_file = export_path / "search_eval_config.yaml" - with config_file.open("w") as file: - config_dict = config.model_dump(mode="python") - config_dict["rank_profile"] = config.rank_profile.value - yaml.dump(config_dict, file, sort_keys=False) - logger.info(f"Exported config to {config_file}") - - return config diff --git a/backend/tests/regression/search_quality/util_data.py b/backend/tests/regression/search_quality/util_data.py deleted file mode 100644 index 34f0c5515eb..00000000000 --- a/backend/tests/regression/search_quality/util_data.py +++ /dev/null @@ -1,166 +0,0 @@ -import json -from pathlib import Path -from typing import cast -from typing import Optional - -from langgraph.types import StreamWriter -from pydantic import BaseModel -from pydantic import ValidationError - -from onyx.agents.agent_search.basic.utils import process_llm_stream -from 
onyx.chat.models import PromptConfig -from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder -from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message -from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message -from onyx.configs.constants import DEFAULT_PERSONA_ID -from onyx.db.engine.sql_engine import get_session_with_current_tenant -from onyx.db.persona import get_persona_by_id -from onyx.llm.factory import get_llms_for_persona -from onyx.llm.interfaces import LLM -from onyx.tools.tool_implementations.search.search_tool import SearchTool -from onyx.tools.utils import explicit_tool_calling_supported -from onyx.utils.logger import setup_logger - -logger = setup_logger() - - -class GroundTruth(BaseModel): - doc_source: str - doc_link: str - - -class TestQuery(BaseModel): - question: str - question_search: Optional[str] = None - ground_truth: list[GroundTruth] = [] - categories: list[str] = [] - - -def load_test_queries() -> list[TestQuery]: - """ - Loads the test queries from the test_queries.json file. - If `question_search` is missing, it will use the tool-calling LLM to generate it. - """ - # open test queries file - current_dir = Path(__file__).parent - test_queries_path = current_dir / "test_queries.json" - logger.info(f"Loading test queries from {test_queries_path}") - if not test_queries_path.exists(): - raise FileNotFoundError(f"Test queries file not found at {test_queries_path}") - with test_queries_path.open("r") as f: - test_queries_raw: list[dict] = json.load(f) - - # setup llm for question_search generation - with get_session_with_current_tenant() as db_session: - persona = get_persona_by_id(DEFAULT_PERSONA_ID, None, db_session) - llm, _ = get_llms_for_persona(persona) - prompt_config = PromptConfig.from_model(persona.prompts[0]) - search_tool = SearchToolOverride() - - tool_call_supported = explicit_tool_calling_supported( - llm.config.model_provider, llm.config.model_name - ) - - # validate keys and generate question_search if missing - test_queries: list[TestQuery] = [] - for query_raw in test_queries_raw: - try: - test_query = TestQuery(**query_raw) - except ValidationError as e: - logger.error(f"Incorrectly formatted query: {e}") - continue - - if test_query.question_search is None: - test_query.question_search = _modify_one_query( - query=test_query.question, - llm=llm, - prompt_config=prompt_config, - tool=search_tool, - tool_call_supported=tool_call_supported, - ) - test_queries.append(test_query) - - return test_queries - - -def export_test_queries(test_queries: list[TestQuery], export_path: Path) -> None: - """Exports the test queries to a JSON file.""" - logger.info(f"Exporting test queries to {export_path}") - with export_path.open("w") as f: - json.dump( - [query.model_dump() for query in test_queries], - f, - indent=4, - ) - - -class SearchToolOverride(SearchTool): - def __init__(self) -> None: - # do nothing, only class variables are required for the functions we call - pass - - -warned = False - - -def _modify_one_query( - query: str, - llm: LLM, - prompt_config: PromptConfig, - tool: SearchTool, - tool_call_supported: bool, - writer: StreamWriter = lambda _: None, -) -> str: - global warned - if not warned: - logger.warning( - "Generating question_search. If you do not save the question_search, " - "it will be generated again on the next run, potentially altering the search results." 
- ) - warned = True - - prompt_builder = AnswerPromptBuilder( - user_message=default_build_user_message( - user_query=query, - prompt_config=prompt_config, - files=[], - single_message_history=None, - ), - system_message=default_build_system_message(prompt_config, llm.config), - message_history=[], - llm_config=llm.config, - raw_user_query=query, - raw_user_uploaded_files=[], - single_message_history=None, - ) - - if tool_call_supported: - prompt = prompt_builder.build() - tool_definition = tool.tool_definition() - stream = llm.stream( - prompt=prompt, - tools=[tool_definition], - tool_choice="required", - structured_response_format=None, - ) - tool_message = process_llm_stream( - messages=stream, - should_stream_answer=False, - writer=writer, - ) - return ( - tool_message.tool_calls[0]["args"]["query"] - if tool_message.tool_calls - else query - ) - - history = prompt_builder.get_message_history() - return cast( - dict[str, str], - tool.get_args_for_non_tool_calling_llm( - query=query, - history=history, - llm=llm, - force_run=True, - ), - )["query"] diff --git a/backend/tests/regression/search_quality/util_eval.py b/backend/tests/regression/search_quality/util_eval.py deleted file mode 100644 index 47fb86d7ee3..00000000000 --- a/backend/tests/regression/search_quality/util_eval.py +++ /dev/null @@ -1,94 +0,0 @@ -from typing import Optional - -from pydantic import BaseModel -from sqlalchemy.orm import Session - -from onyx.context.search.models import InferenceChunk -from onyx.db.models import Document -from onyx.utils.logger import setup_logger -from tests.regression.search_quality.util_retrieve import group_by_documents - -logger = setup_logger(__name__) - - -class Metrics(BaseModel): - # computed if ground truth is provided - ground_truth_ratio_topk: Optional[float] = None - ground_truth_avg_rank_delta: Optional[float] = None - - # computed if reranked results are provided - soft_truth_ratio_topk: Optional[float] = None - soft_truth_avg_rank_delta: Optional[float] = None - - -metric_names = list(Metrics.model_fields.keys()) - - -def get_corresponding_document( - doc_link: str, db_session: Session -) -> Optional[Document]: - """Get the corresponding document from the database.""" - doc_filter = db_session.query(Document).filter(Document.link == doc_link) - count = doc_filter.count() - if count == 0: - logger.warning(f"Could not find document with link {doc_link}, ignoring") - return None - if count > 1: - logger.warning(f"Found multiple documents with link {doc_link}, using first") - return doc_filter.first() - - -def evaluate_one_query( - search_chunks: list[InferenceChunk], - rerank_chunks: list[InferenceChunk], - true_documents: list[Document], - topk: int, -) -> Metrics: - """Computes metrics for the search results, relative to the ground truth and reranked results.""" - metrics_dict: dict[str, float] = {} - - search_documents = group_by_documents(search_chunks) - search_ranks = {docid: rank for rank, docid in enumerate(search_documents)} - search_ranks_topk = { - docid: rank for rank, docid in enumerate(search_documents[:topk]) - } - true_ranks = {doc.id: rank for rank, doc in enumerate(true_documents)} - - if true_documents: - metrics_dict["ground_truth_ratio_topk"] = _compute_ratio( - search_ranks_topk, true_ranks - ) - metrics_dict["ground_truth_avg_rank_delta"] = _compute_avg_rank_delta( - search_ranks, true_ranks - ) - - if rerank_chunks: - # build soft truth out of ground truth + reranked results, up to topk - soft_ranks = true_ranks - for docid in 
group_by_documents(rerank_chunks): - if len(soft_ranks) >= topk: - break - if docid not in soft_ranks: - soft_ranks[docid] = len(soft_ranks) - - metrics_dict["soft_truth_ratio_topk"] = _compute_ratio( - search_ranks_topk, soft_ranks - ) - metrics_dict["soft_truth_avg_rank_delta"] = _compute_avg_rank_delta( - search_ranks, soft_ranks - ) - - return Metrics(**metrics_dict) - - -def _compute_ratio(search_ranks: dict[str, int], true_ranks: dict[str, int]) -> float: - return len(set(search_ranks) & set(true_ranks)) / len(true_ranks) - - -def _compute_avg_rank_delta( - search_ranks: dict[str, int], true_ranks: dict[str, int] -) -> float: - out = len(search_ranks) - return sum( - abs(search_ranks.get(docid, out) - rank) for docid, rank in true_ranks.items() - ) / len(true_ranks) diff --git a/backend/tests/regression/search_quality/util_retrieve.py b/backend/tests/regression/search_quality/util_retrieve.py deleted file mode 100644 index 5ddfa29471f..00000000000 --- a/backend/tests/regression/search_quality/util_retrieve.py +++ /dev/null @@ -1,88 +0,0 @@ -from sqlalchemy.orm import Session - -from onyx.context.search.models import IndexFilters -from onyx.context.search.models import InferenceChunk -from onyx.context.search.models import RerankingDetails -from onyx.context.search.postprocessing.postprocessing import semantic_reranking -from onyx.context.search.preprocessing.preprocessing import query_analysis -from onyx.context.search.retrieval.search_runner import get_query_embedding -from onyx.context.search.utils import remove_stop_words_and_punctuation -from onyx.document_index.interfaces import DocumentIndex -from onyx.utils.logger import setup_logger -from tests.regression.search_quality.util_config import SearchEvalConfig - -logger = setup_logger(__name__) - - -def search_one_query( - question_search: str, - multilingual_expansion: list[str], - document_index: DocumentIndex, - db_session: Session, - config: SearchEvalConfig, -) -> list[InferenceChunk]: - """Uses the search pipeline to retrieve relevant chunks for the given query.""" - # the retrieval preprocessing is fairly stripped down so the query doesn't unexpectedly change - query_embedding = get_query_embedding(question_search, db_session) - - all_query_terms = question_search.split() - processed_keywords = ( - remove_stop_words_and_punctuation(all_query_terms) - if not multilingual_expansion - else all_query_terms - ) - - is_keyword = query_analysis(question_search)[0] - hybrid_alpha = config.hybrid_alpha_keyword if is_keyword else config.hybrid_alpha - - access_control_list = ["PUBLIC"] - if config.user_email: - access_control_list.append(f"user_email:{config.user_email}") - filters = IndexFilters( - tags=[], - user_file_ids=[], - user_folder_ids=[], - access_control_list=access_control_list, - tenant_id=None, - ) - - results = document_index.hybrid_retrieval( - query=question_search, - query_embedding=query_embedding, - final_keywords=processed_keywords, - filters=filters, - hybrid_alpha=hybrid_alpha, - time_decay_multiplier=config.doc_time_decay, - num_to_retrieve=config.num_returned_hits, - ranking_profile_type=config.rank_profile, - offset=config.offset, - title_content_ratio=config.title_content_ratio, - ) - - return [result.to_inference_chunk() for result in results] - - -def rerank_one_query( - question: str, - retrieved_chunks: list[InferenceChunk], - rerank_settings: RerankingDetails, -) -> list[InferenceChunk]: - """Uses the reranker to rerank the retrieved chunks for the given query.""" - rerank_settings.num_rerank = 
len(retrieved_chunks) - return semantic_reranking( - query_str=question, - rerank_settings=rerank_settings, - chunks=retrieved_chunks, - rerank_metrics_callback=None, - )[0] - - -def group_by_documents(chunks: list[InferenceChunk]) -> list[str]: - """Groups a sorted list of chunks into a sorted list of document ids.""" - seen_docids: set[str] = set() - retrieved_docids: list[str] = [] - for chunk in chunks: - if chunk.document_id not in seen_docids: - seen_docids.add(chunk.document_id) - retrieved_docids.append(chunk.document_id) - return retrieved_docids diff --git a/backend/tests/regression/search_quality/utils.py b/backend/tests/regression/search_quality/utils.py new file mode 100644 index 00000000000..dc5b6e53352 --- /dev/null +++ b/backend/tests/regression/search_quality/utils.py @@ -0,0 +1,208 @@ +import json +import re +from pathlib import Path +from textwrap import indent +from typing import Any +from typing import TextIO + +from ragas import evaluate # type: ignore +from ragas import EvaluationDataset # type: ignore +from ragas import SingleTurnSample # type: ignore +from ragas.dataset_schema import EvaluationResult # type: ignore +from ragas.metrics import FactualCorrectness # type: ignore +from ragas.metrics import Faithfulness # type: ignore +from ragas.metrics import ResponseRelevancy # type: ignore +from sqlalchemy.orm import Session + +from onyx.configs.constants import DocumentSource +from onyx.context.search.models import IndexFilters +from onyx.context.search.models import SavedSearchDoc +from onyx.db.engine.sql_engine import get_session_with_tenant +from onyx.db.models import Document +from onyx.db.models import FederatedConnector +from onyx.db.search_settings import get_current_search_settings +from onyx.document_index.factory import get_default_document_index +from onyx.document_index.interfaces import VespaChunkRequest +from onyx.prompts.prompt_utils import build_doc_context_str +from onyx.utils.logger import setup_logger +from tests.regression.search_quality.models import CombinedMetrics +from tests.regression.search_quality.models import GroundTruth +from tests.regression.search_quality.models import RetrievedDocument + +logger = setup_logger(__name__) + + +def get_federated_sources(db_session: Session) -> set[DocumentSource]: + """Get all federated sources from the database.""" + return { + source + for connector in db_session.query(FederatedConnector).all() + if (source := connector.source.to_non_federated_source()) is not None + } + + +def find_document_id( + ground_truth: GroundTruth, + federated_sources: set[DocumentSource], + db_session: Session, +) -> str | None: + """Find a document by its link and return its id if found.""" + # handle federated sources TODO: maybe make handler dictionary by source if this gets complex + if ground_truth.doc_source in federated_sources: + if ground_truth.doc_source == DocumentSource.SLACK: + groups = re.search( + r"archives\/([A-Z0-9]+)\/p([0-9]+)", ground_truth.doc_link + ) + if groups: + channel_id = groups.group(1) + message_id = groups.group(2) + return f"{channel_id}__{message_id[:-6]}.{message_id[-6:]}" + + # preprocess links + doc_link = ground_truth.doc_link + if ground_truth.doc_source == DocumentSource.GOOGLE_DRIVE: + if "/edit" in doc_link: + doc_link = doc_link.split("/edit", 1)[0] + elif "/view" in doc_link: + doc_link = doc_link.split("/view", 1)[0] + elif ground_truth.doc_source == DocumentSource.FIREFLIES: + doc_link = doc_link.split("?", 1)[0] + + docs = 
db_session.query(Document).filter(Document.link.ilike(f"{doc_link}%")).all() + if len(docs) == 0: + logger.warning("Could not find ground truth document: %s", doc_link) + return None + elif len(docs) > 1: + logger.warning( + "Found multiple ground truth documents: %s, using the first one: %s", + doc_link, + docs[0].id, + ) + return docs[0].id + + +def get_doc_contents( + docs: list[SavedSearchDoc], tenant_id: str +) -> dict[tuple[str, int], str]: + with get_session_with_tenant(tenant_id=tenant_id) as db_session: + search_settings = get_current_search_settings(db_session) + document_index = get_default_document_index(search_settings, None) + + filters = IndexFilters(access_control_list=None, tenant_id=tenant_id) + + reqs: list[VespaChunkRequest] = [ + VespaChunkRequest( + document_id=doc.document_id, + min_chunk_ind=doc.chunk_ind, + max_chunk_ind=doc.chunk_ind, + ) + for doc in docs + ] + + results = document_index.id_based_retrieval(chunk_requests=reqs, filters=filters) + return {(doc.document_id, doc.chunk_id): doc.content for doc in results} + + +def search_docs_to_doc_contexts( + docs: list[SavedSearchDoc], tenant_id: str +) -> list[RetrievedDocument]: + try: + doc_contents = get_doc_contents(docs, tenant_id) + except Exception as e: + logger.error("Error getting doc contents: %s", e) + doc_contents = {} + + return [ + RetrievedDocument( + document_id=doc.document_id, + chunk_id=doc.chunk_ind, + content=build_doc_context_str( + semantic_identifier=doc.semantic_identifier, + source_type=doc.source_type, + content=doc_contents.get( + (doc.document_id, doc.chunk_ind), f"Blurb: {doc.blurb}" + ), + metadata_dict=doc.metadata, + updated_at=doc.updated_at, + ind=ind, + include_metadata=True, + ), + ) + for ind, doc in enumerate(docs) + ] + + +def ragas_evaluate( + question: str, answer: str, contexts: list[str], reference_answer: str | None = None +) -> EvaluationResult: + sample = SingleTurnSample( + user_input=question, + retrieved_contexts=contexts, + response=answer, + reference=reference_answer, + ) + dataset = EvaluationDataset([sample]) + return evaluate( + dataset, + metrics=[ + ResponseRelevancy(), + Faithfulness(), + *( + [FactualCorrectness(mode="recall")] + if reference_answer is not None + else [] + ), + ], + ) + + +def compute_overall_scores(metrics: CombinedMetrics) -> tuple[float, float]: + """Compute the overall search and answer quality scores. + The scores are subjective and may require tuning.""" + # search score + FOUND_RATIO_WEIGHT = 0.4 + TOP_IMPORTANCE = 0.7 # 0-inf, how important is it to be no. 
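
`ragas_evaluate` above scores one question/answer/contexts sample per call, and `FactualCorrectness(mode="recall")` is only included when a reference answer is supplied, which is why callers read the optional `factual_correctness(mode=recall)` key. A usage sketch of that helper, assuming ragas is installed and `OPENAI_API_KEY` is set for the judge model; the question, answer and contexts below are made up:

# Minimal usage sketch of the ragas_evaluate helper defined above.
scores = ragas_evaluate(
    question="What is Onyx?",
    answer="Onyx is an open-source AI assistant connected to company knowledge.",
    contexts=["Onyx is an open source Gen-AI and enterprise search platform."],
    reference_answer="Onyx is an open-source enterprise search and AI assistant.",
).scores[0]

response_relevancy = scores["answer_relevancy"]
faithfulness = scores["faithfulness"]
# only present because a reference answer was supplied
factual_correctness = scores.get("factual_correctness(mode=recall)")
print(response_relevancy, faithfulness, factual_correctness)
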
1 over other ranks + + found_ratio = metrics.found_count / metrics.total_queries + sum_k = sum(1.0 / pow(k, TOP_IMPORTANCE) for k in metrics.top_k_accuracy) + weighted_topk = sum( + acc / (pow(k, TOP_IMPORTANCE) * sum_k * 100) + for k, acc in metrics.top_k_accuracy.items() + ) + search_score = 100 * ( + FOUND_RATIO_WEIGHT * found_ratio + (1.0 - FOUND_RATIO_WEIGHT) * weighted_topk + ) + + # answer score + mets = [ + *([metrics.response_relevancy] if metrics.n_response_relevancy > 0 else []), + *([metrics.faithfulness] if metrics.n_faithfulness > 0 else []), + *([metrics.factual_correctness] if metrics.n_factual_correctness > 0 else []), + ] + answer_score = 100 * sum(mets) / len(mets) if mets else 0.0 + + return search_score, answer_score + + +class LazyJsonWriter: + def __init__(self, filepath: Path, indent: int = 4) -> None: + self.filepath = filepath + self.file: TextIO | None = None + self.indent = indent + + def append(self, serializable_item: dict[str, Any]) -> None: + if not self.file: + self.file = open(self.filepath, "a") + self.file.write("[\n") + else: + self.file.write(",\n") + + data = json.dumps(serializable_item, indent=self.indent) + self.file.write(indent(data, " " * self.indent)) + + def close(self) -> None: + if not self.file: + return + self.file.write("\n]") + self.file.close() + self.file = None From 44bee6fc4bf6dc5483018c2c15a962f0a4f995d3 Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Mon, 21 Jul 2025 12:45:48 -0700 Subject: [PATCH 13/78] Remove empty tooltip (#5050) --- web/src/app/admin/add-connector/page.tsx | 68 ++++++++++-------------- web/src/components/SourceTile.tsx | 50 +++++++++++++++++ 2 files changed, 79 insertions(+), 39 deletions(-) create mode 100644 web/src/components/SourceTile.tsx diff --git a/web/src/app/admin/add-connector/page.tsx b/web/src/app/admin/add-connector/page.tsx index 7abb8b29702..588d23cd42d 100644 --- a/web/src/app/admin/add-connector/page.tsx +++ b/web/src/app/admin/add-connector/page.tsx @@ -1,7 +1,6 @@ "use client"; -import { SourceIcon } from "@/components/SourceIcon"; import { AdminPageTitle } from "@/components/admin/Title"; -import { AlertIcon, ConnectorIcon, InfoIcon } from "@/components/icons/icons"; +import { ConnectorIcon } from "@/components/icons/icons"; import { SourceCategory, SourceMetadata } from "@/lib/search/interfaces"; import { listSourceMetadata } from "@/lib/sources"; import Title from "@/components/ui/title"; @@ -31,9 +30,10 @@ import useSWR from "swr"; import { errorHandlingFetcher } from "@/lib/fetcher"; import { buildSimilarCredentialInfoURL } from "@/app/admin/connector/[ccPairId]/lib"; import { Credential } from "@/lib/connectors/credentials"; +import SourceTile from "@/components/SourceTile"; import { SettingsContext } from "@/components/settings/SettingsProvider"; -function SourceTile({ +function SourceTileTooltipWrapper({ sourceMetadata, preSelect, federatedConnectors, @@ -82,46 +82,36 @@ function SourceTile({ sourceMetadata.adminUrl, ]); + // Compute whether to hide the tooltip based on the provided condition + const shouldHideTooltip = + !(existingFederatedConnector && !hasExistingSlackCredentials) && + !hasExistingSlackCredentials && + !sourceMetadata.federated; + + // If tooltip should be hidden, just render the tile as a component + if (shouldHideTooltip) { + return ( + + ); + } + return ( - - {sourceMetadata.federated && !hasExistingSlackCredentials && ( -
- -
- )} - + -

- {sourceMetadata.displayName} -

- +
{existingFederatedConnector && !hasExistingSlackCredentials ? ( @@ -280,7 +270,7 @@ export default function Page() {
{sources.map((source, sourceInd) => ( - 0 && categoryInd == 0 && sourceInd == 0 } diff --git a/web/src/components/SourceTile.tsx b/web/src/components/SourceTile.tsx new file mode 100644 index 00000000000..a6d47a199e4 --- /dev/null +++ b/web/src/components/SourceTile.tsx @@ -0,0 +1,50 @@ +import { SourceIcon } from "@/components/SourceIcon"; +import { AlertIcon } from "@/components/icons/icons"; +import Link from "next/link"; +import { SourceMetadata } from "@/lib/search/interfaces"; +import React from "react"; + +interface SourceTileProps { + sourceMetadata: SourceMetadata; + preSelect?: boolean; + navigationUrl: string; + hasExistingSlackCredentials: boolean; +} + +export default function SourceTile({ + sourceMetadata, + preSelect, + navigationUrl, + hasExistingSlackCredentials, +}: SourceTileProps) { + return ( + + {sourceMetadata.federated && !hasExistingSlackCredentials && ( +
+ +
+ )} + +

{sourceMetadata.displayName}

+ + ); +} From 48f8a68c78ee0a269e613c82624c63fe705d7b53 Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Mon, 21 Jul 2025 15:37:27 -0700 Subject: [PATCH 14/78] feat: Updated KG admin page (#5044) * Update KG admin UI * Styling changes * More changes * Make edits auto-save * Add more stylings / transitions * Fix opacity * Separate out modal into new component * Revert backend changes * Update styling * Add convenience / styling changes to date-picker * More styling / functional updates to kg admin-page * Avoid reducing opacity of active-toggle * Update backend APIs for new KG admin page * More updates of styling for kg-admin page * Remove nullability * Remove console log * Remove unused imports * Change type of `children` variable * Update web/src/app/admin/kg/interfaces.ts Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com> * Update web/src/components/CollapsibleCard.tsx Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com> * Remove null * Update web/src/components/CollapsibleCard.tsx Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> * Force non-null * Fix failing test --------- Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com> Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- backend/onyx/db/entities.py | 24 ++ backend/onyx/db/entity_type.py | 22 +- backend/onyx/server/kg/api.py | 28 +- backend/onyx/server/kg/models.py | 11 + .../tests/integration/tests/kg/test_kg_api.py | 22 +- .../app/admin/connector/[ccPairId]/unused.txt | 0 web/src/app/admin/kg/KGEntityTypes.tsx | 323 ++++++++++++++++++ web/src/app/admin/kg/interfaces.ts | 35 +- web/src/app/admin/kg/page.tsx | 228 +------------ web/src/app/admin/kg/utils.ts | 10 + web/src/components/CollapsibleCard.tsx | 80 +++++ web/src/components/ui/dataTable.tsx | 81 ----- web/src/components/ui/datePicker.tsx | 6 +- web/tailwind-themes/tailwind.config.js | 7 +- 14 files changed, 538 insertions(+), 339 deletions(-) delete mode 100644 web/src/app/admin/connector/[ccPairId]/unused.txt create mode 100644 web/src/app/admin/kg/KGEntityTypes.tsx create mode 100644 web/src/components/CollapsibleCard.tsx delete mode 100644 web/src/components/ui/dataTable.tsx diff --git a/backend/onyx/db/entities.py b/backend/onyx/db/entities.py index d3e92d43116..9696ce92673 100644 --- a/backend/onyx/db/entities.py +++ b/backend/onyx/db/entities.py @@ -308,3 +308,27 @@ def get_entity_name(db_session: Session, entity_id_name: str) -> str | None: db_session.query(KGEntity).filter(KGEntity.id_name == entity_id_name).first() ) return entity.name if entity else None + + +def get_entity_stats_by_grounded_source_name( + db_session: Session, +) -> dict[str, tuple[datetime, int]]: + """ + Returns a dict mapping each grounded_source_name to a tuple in which: + - the first element is the latest update time across all entities with the same entity-type + - the second element is the count of `KGEntity`s + """ + results = ( + db_session.query( + KGEntityType.grounded_source_name, + func.count(KGEntity.id_name).label("entities_count"), + func.max(KGEntity.time_updated).label("last_updated"), + ) + .join(KGEntityType, KGEntity.entity_type_id_name == KGEntityType.id_name) + .group_by(KGEntityType.grounded_source_name) + .all() + ) + return { + row.grounded_source_name: (row.last_updated, row.entities_count) + for row in results + } diff --git a/backend/onyx/db/entity_type.py 
b/backend/onyx/db/entity_type.py index 56c8d367496..54b7bfaff35 100644 --- a/backend/onyx/db/entity_type.py +++ b/backend/onyx/db/entity_type.py @@ -1,3 +1,5 @@ +from collections import defaultdict + from sqlalchemy import update from sqlalchemy.orm import Session @@ -9,6 +11,9 @@ from onyx.server.kg.models import EntityType +_UNGROUNDED_SOURCE_NAME = "Ungrounded" + + def get_entity_types_with_grounded_source_name( db_session: Session, ) -> list[KGEntityType]: @@ -45,7 +50,7 @@ def get_entity_types( ) -def get_configured_entity_types(db_session: Session) -> list[KGEntityType]: +def get_configured_entity_types(db_session: Session) -> dict[str, list[KGEntityType]]: # get entity types from configured sources configured_connector_sources = { source.value.lower() @@ -73,12 +78,20 @@ def get_configured_entity_types(db_session: Session) -> list[KGEntityType]: elif isinstance(implied_et, str): if implied_et not in entity_type_set: entity_type_set.add(implied_et) - return ( + + ets = ( db_session.query(KGEntityType) .filter(KGEntityType.id_name.in_(entity_type_set)) .all() ) + et_map = defaultdict(list) + for et in ets: + key = et.grounded_source_name or _UNGROUNDED_SOURCE_NAME + et_map[key].append(et) + + return et_map + def update_entity_types_and_related_connectors__commit( db_session: Session, updates: list[EntityType] @@ -99,7 +112,10 @@ def update_entity_types_and_related_connectors__commit( configured_entity_types = get_configured_entity_types(db_session=db_session) active_entity_type_sources = { - et.grounded_source_name for et in configured_entity_types if et.active + et.grounded_source_name + for ets in configured_entity_types.values() + for et in ets + if et.active } # Update connectors that should be enabled diff --git a/backend/onyx/server/kg/api.py b/backend/onyx/server/kg/api.py index c59434efa7b..56081997424 100644 --- a/backend/onyx/server/kg/api.py +++ b/backend/onyx/server/kg/api.py @@ -5,6 +5,7 @@ from onyx.auth.users import current_admin_user from onyx.context.search.enums import RecencyBiasSetting from onyx.db.engine.sql_engine import get_session +from onyx.db.entities import get_entity_stats_by_grounded_source_name from onyx.db.entity_type import get_configured_entity_types from onyx.db.entity_type import update_entity_types_and_related_connectors__commit from onyx.db.kg_config import disable_kg @@ -28,6 +29,8 @@ from onyx.server.kg.models import EntityType from onyx.server.kg.models import KGConfig from onyx.server.kg.models import KGConfig as KGConfigAPIModel +from onyx.server.kg.models import SourceAndEntityTypeView +from onyx.server.kg.models import SourceStatistics from onyx.tools.built_in_tools import get_search_tool @@ -54,7 +57,7 @@ def get_kg_exposed(_: User | None = Depends(current_admin_user)) -> bool: def reset_kg( _: User | None = Depends(current_admin_user), db_session: Session = Depends(get_session), -) -> list[EntityType]: +) -> SourceAndEntityTypeView: reset_full_kg_index__commit(db_session) populate_missing_default_entity_types__commit(db_session=db_session) return get_kg_entity_types(db_session=db_session) @@ -173,11 +176,26 @@ def enable_or_disable_kg( def get_kg_entity_types( _: User | None = Depends(current_admin_user), db_session: Session = Depends(get_session), -) -> list[EntityType]: +) -> SourceAndEntityTypeView: # when using for the first time, populate with default entity types - kg_entity_types = get_configured_entity_types(db_session=db_session) - - return [EntityType.from_model(kg_entity_type) for kg_entity_type in kg_entity_types] + 
entity_types = { + key: [EntityType.from_model(et) for et in ets] + for key, ets in get_configured_entity_types(db_session=db_session).items() + } + + source_statistics = { + key: SourceStatistics( + source_name=key, last_updated=last_updated, entities_count=entities_count + ) + for key, ( + last_updated, + entities_count, + ) in get_entity_stats_by_grounded_source_name(db_session=db_session).items() + } + + return SourceAndEntityTypeView( + source_statistics=source_statistics, entity_types=entity_types + ) @admin_router.put("/entity-types") diff --git a/backend/onyx/server/kg/models.py b/backend/onyx/server/kg/models.py index 4bffa673be6..e4527f8e947 100644 --- a/backend/onyx/server/kg/models.py +++ b/backend/onyx/server/kg/models.py @@ -62,3 +62,14 @@ def from_model( active=model.active, grounded_source_name=model.grounded_source_name, ) + + +class SourceStatistics(BaseModel): + source_name: str + last_updated: datetime + entities_count: int + + +class SourceAndEntityTypeView(BaseModel): + source_statistics: dict[str, SourceStatistics] + entity_types: dict[str, list[EntityType]] diff --git a/backend/tests/integration/tests/kg/test_kg_api.py b/backend/tests/integration/tests/kg/test_kg_api.py index 46addb9af31..82cdca15fb9 100644 --- a/backend/tests/integration/tests/kg/test_kg_api.py +++ b/backend/tests/integration/tests/kg/test_kg_api.py @@ -17,6 +17,7 @@ from onyx.server.kg.models import EnableKGConfigRequest from onyx.server.kg.models import EntityType from onyx.server.kg.models import KGConfig as KGConfigAPIModel +from onyx.server.kg.models import SourceAndEntityTypeView from tests.integration.common_utils.constants import API_SERVER_URL from tests.integration.common_utils.managers.user import UserManager from tests.integration.common_utils.reset import reset_all @@ -169,6 +170,7 @@ def test_update_kg_entity_types(connectors: None) -> None: assert ( res2.status_code == HTTPStatus.OK ), f"Error response: {res2.status_code} - {res2.text}" + res2_parsed = SourceAndEntityTypeView.model_validate(res2.json()) # Update entity types req3 = [ @@ -210,16 +212,20 @@ def test_update_kg_entity_types(connectors: None) -> None: assert ( res4.status_code == HTTPStatus.OK ), f"Error response: {res4.status_code} - {res4.text}" + res4_parsed = SourceAndEntityTypeView.model_validate(res4.json()) - new_entity_types = { - entity_type["name"]: EntityType.model_validate(entity_type) - for entity_type in res4.json() - } + def to_entity_type_map(map: dict[str, list[EntityType]]) -> dict[str, EntityType]: + return { + entity_type.name: entity_type + for entity_types in map.values() + for entity_type in entity_types + } - expected_entity_types = { - entity_type["name"]: EntityType.model_validate(entity_type) - for entity_type in res2.json() - } + expected_entity_types = to_entity_type_map(map=res2_parsed.entity_types) + new_entity_types = to_entity_type_map(map=res4_parsed.entity_types) + + # These are the updates. + # We're just manually updating them. expected_entity_types["ACCOUNT"].active = True expected_entity_types["ACCOUNT"].description = "Test." 
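
With these changes, `GET /admin/kg/entity-types` returns entity types grouped by grounded source together with per-source statistics instead of a flat list. A small, self-contained sketch of assembling that response shape, using simplified stand-ins for the pydantic models added in `onyx/server/kg/models.py` and made-up example rows in place of the DB helpers:

from collections import defaultdict
from datetime import datetime

from pydantic import BaseModel


# Stand-ins mirroring the models in onyx/server/kg/models.py.
class EntityType(BaseModel):
    name: str
    description: str
    active: bool
    grounded_source_name: str


class SourceStatistics(BaseModel):
    source_name: str
    last_updated: datetime
    entities_count: int


class SourceAndEntityTypeView(BaseModel):
    source_statistics: dict[str, SourceStatistics]
    entity_types: dict[str, list[EntityType]]


# Illustrative rows standing in for get_configured_entity_types and
# get_entity_stats_by_grounded_source_name.
rows = [
    EntityType(name="ACCOUNT", description="A customer account", active=True,
               grounded_source_name="salesforce"),
    EntityType(name="JIRA", description="A tracked issue", active=True,
               grounded_source_name="jira"),
]
stats = {"salesforce": (datetime(2025, 7, 21), 120), "jira": (datetime(2025, 7, 20), 45)}

grouped: dict[str, list[EntityType]] = defaultdict(list)
for et in rows:
    grouped[et.grounded_source_name].append(et)

view = SourceAndEntityTypeView(
    entity_types=dict(grouped),
    source_statistics={
        source: SourceStatistics(
            source_name=source, last_updated=updated, entities_count=count
        )
        for source, (updated, count) in stats.items()
    },
)
print(view.model_dump(mode="json"))
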
expected_entity_types["OPPORTUNITY"].active = False diff --git a/web/src/app/admin/connector/[ccPairId]/unused.txt b/web/src/app/admin/connector/[ccPairId]/unused.txt deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/web/src/app/admin/kg/KGEntityTypes.tsx b/web/src/app/admin/kg/KGEntityTypes.tsx new file mode 100644 index 00000000000..b674dd91303 --- /dev/null +++ b/web/src/app/admin/kg/KGEntityTypes.tsx @@ -0,0 +1,323 @@ +import { SourceIcon } from "@/components/SourceIcon"; +import React, { useEffect, useState } from "react"; +import { Switch } from "@/components/ui/switch"; +import Link from "next/link"; +import { EntityType, SourceAndEntityTypeView } from "./interfaces"; +import CollapsibleCard from "@/components/CollapsibleCard"; +import { ValidSources } from "@/lib/types"; +import { FaCircleQuestion } from "react-icons/fa6"; +import { Input } from "@/components/ui/input"; +import { CheckmarkIcon } from "@/components/icons/icons"; +import { Button } from "@/components/ui/button"; + +// Utility: Convert capitalized snake case to human readable case +function snakeToHumanReadable(str: string): string { + return ( + str + .toLowerCase() + .replace(/_/g, " ") + .replace(/\b\w/g, (match) => match.toUpperCase()) + // # TODO (@raunakab) + // Special case to replace all instances of "Pr" with "PR". + // This is a *dumb* implementation. If there exists a string that starts with "Pr" (e.g., "Prompt"), + // then this line will stupidly convert it to "PRompt". + // Fix this later (or if this becomes a problem lol). + .replace("Pr", "PR") + ); +} + +// Custom Header Component +function TableHeader() { + return ( +
+
Entity Name
+
Description
+
Active
+
+ ); +} + +// Custom Row Component +function TableRow({ entityType }: { entityType: EntityType }) { + const [entityTypeState, setEntityTypeState] = useState(entityType); + const [descriptionSavingState, setDescriptionSavingState] = useState< + "saving" | "saved" | "failed" | undefined + >(undefined); + + const [timer, setTimer] = useState(null); + const [checkmarkVisible, setCheckmarkVisible] = useState(false); + const [hasMounted, setHasMounted] = useState(false); + + const handleToggle = async (checked: boolean) => { + const response = await fetch("/api/admin/kg/entity-types", { + method: "PUT", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify([{ ...entityType, active: checked }]), + }); + + if (!response.ok) return; + + setEntityTypeState({ ...entityTypeState, active: checked }); + }; + + const handleDescriptionChange = async (description: string) => { + try { + const response = await fetch("/api/admin/kg/entity-types", { + method: "PUT", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify([{ ...entityType, description }]), + }); + if (response.ok) { + setDescriptionSavingState("saved"); + setCheckmarkVisible(true); + setTimeout(() => setCheckmarkVisible(false), 1000); + } else { + setDescriptionSavingState("failed"); + setCheckmarkVisible(false); + } + } catch { + setDescriptionSavingState("failed"); + setCheckmarkVisible(false); + } finally { + setTimeout(() => setDescriptionSavingState(undefined), 1000); + } + }; + + useEffect(() => { + if (!hasMounted) { + setHasMounted(true); + return; + } + if (timer) clearTimeout(timer); + setTimer( + setTimeout(() => { + setDescriptionSavingState("saving"); + setCheckmarkVisible(false); + setTimer( + setTimeout( + () => handleDescriptionChange(entityTypeState.description), + 500 + ) + ); + }, 1000) + ); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [entityTypeState.description]); + + return ( +
+
+
+
+ + {snakeToHumanReadable(entityType.name)} + +
+
+ + setEntityTypeState({ + ...entityTypeState, + description: e.target.value, + }) + } + onKeyDown={async (e) => { + if (e.key === "Enter") { + e.preventDefault(); + if (timer) { + clearTimeout(timer); + setTimer(null); + } + setDescriptionSavingState("saving"); + setCheckmarkVisible(false); + await handleDescriptionChange( + (e.target as HTMLInputElement).value + ); + } + }} + /> + + + + + + + + +
+
+
+ +
+
+
+ ); +} + +interface KGEntityTypesProps { + sourceAndEntityTypes: SourceAndEntityTypeView; +} + +export default function KGEntityTypes({ + sourceAndEntityTypes, +}: KGEntityTypesProps) { + // State to control open/close of all CollapsibleCards + const [openCards, setOpenCards] = useState<{ [key: string]: boolean }>({}); + // State for search query + const [search, setSearch] = useState(""); + + // Initialize openCards state when data changes + useEffect(() => { + const initialState: { [key: string]: boolean } = {}; + Object.keys(sourceAndEntityTypes.entity_types).forEach((key) => { + initialState[key] = true; + }); + setOpenCards(initialState); + }, [sourceAndEntityTypes]); + + // Handlers for expand/collapse all + const handleExpandAll = () => { + const newState: { [key: string]: boolean } = {}; + Object.keys(sourceAndEntityTypes.entity_types).forEach((key) => { + newState[key] = true; + }); + setOpenCards(newState); + }; + const handleCollapseAll = () => { + const newState: { [key: string]: boolean } = {}; + Object.keys(sourceAndEntityTypes.entity_types).forEach((key) => { + newState[key] = false; + }); + setOpenCards(newState); + }; + + // Determine if all cards are closed + const allClosed = Object.values(openCards).every((v) => v === false); + + return ( +
+
+ setSearch(e.target.value)} + /> + +
+
+ {Object.entries(sourceAndEntityTypes.entity_types).length === 0 ? ( +
+

No results available.

+

+ To configure Knowledge Graph, first connect some{" "} + + Connectors. + +

+
+ ) : ( + Object.entries(sourceAndEntityTypes.entity_types) + .filter(([key]) => + snakeToHumanReadable(key) + .toLowerCase() + .includes(search.toLowerCase()) + ) + .sort(([keyA], [keyB]) => keyA.localeCompare(keyB)) + .map(([key, entityTypesArr]) => { + const stats = sourceAndEntityTypes.source_statistics[key] ?? { + source_name: key, + last_updated: undefined, + entities_count: 0, + }; + return ( +
+ + {Object.values(ValidSources).includes( + key as ValidSources + ) ? ( + + ) : ( + + )} + {snakeToHumanReadable(key)} + + + + Entities Count + + + {stats.entities_count} + + + + + Last Updated + + + {stats.last_updated + ? new Date(stats.last_updated).toLocaleString() + : "N/A"} + + + + + } + // Use a key that changes with openCards[key] to force remount and update defaultOpen + key={`${key}-${openCards[key]}`} + defaultOpen={ + openCards[key] !== undefined ? openCards[key] : true + } + > +
+ + {entityTypesArr.map( + (entityType: EntityType, index: number) => ( + + ) + )} +
+
+
+ ); + }) + )} +
+
+ ); +} diff --git a/web/src/app/admin/kg/interfaces.ts b/web/src/app/admin/kg/interfaces.ts index fa2985a2968..aca740d96ca 100644 --- a/web/src/app/admin/kg/interfaces.ts +++ b/web/src/app/admin/kg/interfaces.ts @@ -16,31 +16,20 @@ export type KGConfigRaw = { export type EntityTypeValues = { [key: string]: EntityType }; +export type SourceAndEntityTypeView = { + source_statistics: Record; + entity_types: Record; +}; + +export type SourceStatistics = { + source_name: string; + last_updated: string; + entities_count: number; +}; + export type EntityType = { name: string; description: string; active: boolean; + grounded_source_name: string; }; - -export function sanitizeKGConfig(raw: KGConfigRaw): KGConfig { - const coverage_start = new Date(raw.coverage_start); - - return { - ...raw, - coverage_start, - }; -} - -export function sanitizeKGEntityTypes( - entityTypes: EntityType[] -): [EntityTypeValues, EntityType[]] { - const entityTypeMap: EntityTypeValues = {}; - for (const entityType of entityTypes) { - entityTypeMap[entityType.name.toLowerCase()] = entityType; - } - - const sortedData = Object.values(entityTypeMap); - sortedData.sort((a, b) => a.name.localeCompare(b.name)); - - return [entityTypeMap, sortedData]; -} diff --git a/web/src/app/admin/kg/page.tsx b/web/src/app/admin/kg/page.tsx index 53f51329889..e83c11be6a9 100644 --- a/web/src/app/admin/kg/page.tsx +++ b/web/src/app/admin/kg/page.tsx @@ -5,7 +5,6 @@ import { AdminPageTitle } from "@/components/admin/Title"; import { DatePickerField, FieldLabel, - TextAreaField, TextArrayField, TextFormField, } from "@/components/Field"; @@ -13,33 +12,19 @@ import { BrainIcon } from "@/components/icons/icons"; import { Modal } from "@/components/Modal"; import { Button } from "@/components/ui/button"; import { SwitchField } from "@/components/ui/switch"; -import { - Form, - Formik, - FormikProps, - FormikState, - useFormikContext, -} from "formik"; +import { Form, Formik, FormikState, useFormikContext } from "formik"; import { useState } from "react"; import { FiSettings } from "react-icons/fi"; import * as Yup from "yup"; -import { - EntityType, - KGConfig, - EntityTypeValues, - sanitizeKGConfig, - KGConfigRaw, - sanitizeKGEntityTypes, -} from "./interfaces"; -import { ColumnDef } from "@tanstack/react-table"; -import { DataTable } from "@/components/ui/dataTable"; +import { KGConfig, KGConfigRaw, SourceAndEntityTypeView } from "./interfaces"; +import { sanitizeKGConfig } from "./utils"; import useSWR from "swr"; import { errorHandlingFetcher } from "@/lib/fetcher"; import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup"; import Title from "@/components/ui/title"; import { redirect } from "next/navigation"; -import Link from "next/link"; import { useIsKGExposed } from "./utils"; +import KGEntityTypes from "./KGEntityTypes"; function createDomainField( name: string, @@ -225,191 +210,6 @@ function KGConfiguration({ ); } -function KGEntityTypes({ - kgEntityTypes, - sortedKGEntityTypes: sorted, - setPopup, - refreshKGEntityTypes, -}: { - kgEntityTypes: EntityTypeValues; - sortedKGEntityTypes: EntityType[]; - setPopup?: (spec: PopupSpec | null) => void; - refreshKGEntityTypes?: () => void; -}) { - const [sortedKGEntityTypes, setSortedKGEntityTypes] = useState(sorted); - console.log({ sortedKGEntityTypes }); - - const columns: ColumnDef[] = [ - { - accessorKey: "name", - header: "Name", - }, - { - accessorKey: "description", - header: "Description", - cell: ({ row }) => ( -
diff --git a/web/src/app/admin/kg/page.tsx b/web/src/app/admin/kg/page.tsx
index 53f51329889..e83c11be6a9 100644
--- a/web/src/app/admin/kg/page.tsx
+++ b/web/src/app/admin/kg/page.tsx
@@ -5,7 +5,6 @@ import { AdminPageTitle } from "@/components/admin/Title";
 import {
   DatePickerField,
   FieldLabel,
-  TextAreaField,
   TextArrayField,
   TextFormField,
 } from "@/components/Field";
@@ -13,33 +12,19 @@ import { BrainIcon } from "@/components/icons/icons";
 import { Modal } from "@/components/Modal";
 import { Button } from "@/components/ui/button";
 import { SwitchField } from "@/components/ui/switch";
-import {
-  Form,
-  Formik,
-  FormikProps,
-  FormikState,
-  useFormikContext,
-} from "formik";
+import { Form, Formik, FormikState, useFormikContext } from "formik";
 import { useState } from "react";
 import { FiSettings } from "react-icons/fi";
 import * as Yup from "yup";
-import {
-  EntityType,
-  KGConfig,
-  EntityTypeValues,
-  sanitizeKGConfig,
-  KGConfigRaw,
-  sanitizeKGEntityTypes,
-} from "./interfaces";
-import { ColumnDef } from "@tanstack/react-table";
-import { DataTable } from "@/components/ui/dataTable";
+import { KGConfig, KGConfigRaw, SourceAndEntityTypeView } from "./interfaces";
+import { sanitizeKGConfig } from "./utils";
 import useSWR from "swr";
 import { errorHandlingFetcher } from "@/lib/fetcher";
 import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
 import Title from "@/components/ui/title";
 import { redirect } from "next/navigation";
-import Link from "next/link";
 import { useIsKGExposed } from "./utils";
+import KGEntityTypes from "./KGEntityTypes";
 
 function createDomainField(
   name: string,
@@ -225,191 +210,6 @@ function KGConfiguration({
   );
 }
 
-function KGEntityTypes({
-  kgEntityTypes,
-  sortedKGEntityTypes: sorted,
-  setPopup,
-  refreshKGEntityTypes,
-}: {
-  kgEntityTypes: EntityTypeValues;
-  sortedKGEntityTypes: EntityType[];
-  setPopup?: (spec: PopupSpec | null) => void;
-  refreshKGEntityTypes?: () => void;
-}) {
-  const [sortedKGEntityTypes, setSortedKGEntityTypes] = useState(sorted);
-  console.log({ sortedKGEntityTypes });
-
-  const columns: ColumnDef[] = [
-    {
-      accessorKey: "name",
-      header: "Name",
-    },
-    {
-      accessorKey: "description",
-      header: "Description",
-      cell: ({ row }) => (
-
-
-      ),
-    },
-    {
-      accessorKey: "active",
-      header: "Active",
-      cell: ({ row }) => (
-
-      ),
-    },
-  ];
-
-  const validationSchema = Yup.array(
-    Yup.object({
-      active: Yup.boolean().required(),
-    })
-  );
-
-  const onSubmit = async (
-    values: EntityTypeValues,
-    {
-      resetForm,
-    }: {
-      resetForm: (nextState?: Partial>) => void;
-    }
-  ) => {
-    const diffs: EntityType[] = [];
-
-    for (const key in kgEntityTypes) {
-      const initialValue = kgEntityTypes[key]!;
-      const currentValue = values[key]!;
-      const equals =
-        initialValue.description === currentValue.description &&
-        initialValue.active === currentValue.active;
-      if (!equals) {
-        diffs.push(currentValue);
-      }
-    }
-
-    if (diffs.length === 0) return;
-
-    const response = await fetch("/api/admin/kg/entity-types", {
-      method: "PUT",
-      headers: {
-        "Content-Type": "application/json",
-      },
-      body: JSON.stringify(diffs),
-    });
-
-    if (!response.ok) {
-      const errorMsg = (await response.json()).detail;
-      console.warn({ errorMsg });
-      setPopup?.({
-        message: "Failed to configure Entity Types.",
-        type: "error",
-      });
-      return;
-    }
-
-    setPopup?.({
-      message: "Successfully updated Entity Types.",
-      type: "success",
-    });
-
-    refreshKGEntityTypes?.();
-
-    resetForm({ values });
-  };
-
-  const reset = async (props: FormikProps) => {
-    const result = await fetch("/api/admin/kg/reset", { method: "PUT" });
-
-    if (!result.ok) {
-      setPopup?.({
-        message: "Failed to reset Knowledge Graph.",
-        type: "error",
-      });
-      return;
-    }
-
-    const rawData = (await result.json()) as EntityType[];
-    const [newEntityTypes, newSortedEntityTypes] =
-      sanitizeKGEntityTypes(rawData);
-    props.resetForm({ values: newEntityTypes });
-    setSortedKGEntityTypes(newSortedEntityTypes);
-
-    setPopup?.({
-      message: "Successfully reset Knowledge Graph.",
-      type: "success",
-    });
-
-    refreshKGEntityTypes?.();
-  };
-
-  return (
-
-    {(props) => (
-
-
-
-            No results available.
-
-              To configure Knowledge Graph, first connect some {` `}
-
-              Connectors.
-
-
-
-            }
-          />
-
-
-
-
-
-
-            Danger
-
-
-            Resetting will delete all extracted entities and relationships
-            and deactivate all entity types. After reset, you can reactivate
-            entity types to begin populating the Knowledge Graph again.
-
-
-
-
-      )}
-
-  );
-}
 
 function Main() {
   // Data:
   const {
     data: configData,
     isLoading: configIsLoading,
     mutate: configMutate,
   } = useSWR("/api/admin/kg/config", errorHandlingFetcher);
   const {
-    data: entityTypesData,
+    data: sourceAndEntityTypesData,
     isLoading: entityTypesIsLoading,
     mutate: entityTypesMutate,
-  } = useSWR("/api/admin/kg/entity-types", errorHandlingFetcher);
+  } = useSWR(
+    "/api/admin/kg/entity-types",
+    errorHandlingFetcher
+  );
 
   // Local State:
   const { popup, setPopup } = usePopup();
 
   if (
     configIsLoading ||
     entityTypesIsLoading ||
     !configData ||
-    !entityTypesData
+    !sourceAndEntityTypesData
   ) {
     return <></>;
   }
 
   const kgConfig = sanitizeKGConfig(configData);
-  const [kgEntityTypes, sortedKGEntityTypes] =
-    sanitizeKGEntityTypes(entityTypesData);
 
   return (
@@ -484,15 +285,10 @@ function Main() { {kgConfig.enabled && ( <> -

+
 
               Entity Types
 
-
+
       )}
 
       {configureModalShown && (
diff --git a/web/src/app/admin/kg/utils.ts b/web/src/app/admin/kg/utils.ts
index 8142db581ce..855aa0f0c98 100644
--- a/web/src/app/admin/kg/utils.ts
+++ b/web/src/app/admin/kg/utils.ts
@@ -1,6 +1,7 @@
 import { useUser } from "@/components/user/UserProvider";
 import { errorHandlingFetcher } from "@/lib/fetcher";
 import useSWR from "swr";
+import { KGConfig, KGConfigRaw } from "./interfaces";
 
 export type KgExposedStatus = { kgExposed: boolean; isLoading: boolean };
 
 export function useIsKGExposed(): KgExposedStatus {
   );
   return { kgExposed: kgExposedRaw ?? false, isLoading };
 }
+
+export function sanitizeKGConfig(raw: KGConfigRaw): KGConfig {
+  const coverage_start = new Date(raw.coverage_start);
+
+  return {
+    ...raw,
+    coverage_start,
+  };
+}
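// Usage sketch for the sanitizeKGConfig helper that now lives in utils.ts: the
// /api/admin/kg/config payload (KGConfigRaw) carries coverage_start as a raw
// value (e.g., an ISO string), and the helper converts it to a Date before the
// page works with it. The page itself fetches through useSWR with
// errorHandlingFetcher; the plain-fetch wrapper below is illustrative only.
import { KGConfig, KGConfigRaw } from "@/app/admin/kg/interfaces";
import { sanitizeKGConfig } from "@/app/admin/kg/utils";

async function loadKGConfig(): Promise<KGConfig> {
  const response = await fetch("/api/admin/kg/config");
  if (!response.ok) {
    throw new Error(`Failed to load KG config: ${response.status}`);
  }
  const raw = (await response.json()) as KGConfigRaw;
  return sanitizeKGConfig(raw); // coverage_start -> Date
}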
diff --git a/web/src/components/CollapsibleCard.tsx b/web/src/components/CollapsibleCard.tsx
new file mode 100644
index 00000000000..0fe029616e3
--- /dev/null
+++ b/web/src/components/CollapsibleCard.tsx
@@ -0,0 +1,80 @@
+import { ChevronDown } from "lucide-react";
+import React, { useState, ReactNode, useRef, useLayoutEffect } from "react";
+
+interface CollapsibleCardProps {
+  header: JSX.Element;
+  children: ReactNode;
+  defaultOpen?: boolean;
+  className?: string;
+}
+
+/**
+ * Renders a "collapsible" card which, when collapsed, is meant to showcase very "high-level" information (e.g., the name), but when expanded, can show a list of sub-items which are all related to one another.
+ */
+export default function CollapsibleCard({
+  header,
+  children,
+  defaultOpen = false,
+  className = "",
+}: CollapsibleCardProps) {
+  const [open, setOpen] = useState(defaultOpen);
+  const [maxHeight, setMaxHeight] = useState(undefined);
+  const contentRef = useRef(null);
+
+  // Update maxHeight for animation when open/close
+  useLayoutEffect(() => {
+    if (open && contentRef.current) {
+      setMaxHeight(contentRef.current.scrollHeight + "px");
+    } else {
+      setMaxHeight("0px");
+    }
+  }, [open, children]);
+
+  // If content changes size while open, update maxHeight
+  useLayoutEffect(() => {
+    if (open && contentRef.current) {
+      const handleResize = () => {
+        setMaxHeight(contentRef.current!.scrollHeight + "px");
+      };
+      handleResize();
+      window.addEventListener("resize", handleResize);
+      return () => window.removeEventListener("resize", handleResize);
+    }
+  }, [open, children]);
+
+  return (
+
+
+
+        {children}
+
+
+
+  );
+}
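// Usage sketch for the new CollapsibleCard: the header stays visible while the
// card is collapsed, and the children animate open via the max-height effect
// above. The connector name and rows here are placeholder content, not code
// from this patch.
import CollapsibleCard from "@/components/CollapsibleCard";

export function ExampleSourceCard() {
  return (
    <CollapsibleCard
      defaultOpen
      className="mb-4"
      header={<span className="font-semibold">Confluence</span>}
    >
      <ul className="p-4">
        <li>Entity types: 12 (8 active)</li>
        <li>Last updated: 2025-07-17</li>
      </ul>
    </CollapsibleCard>
  );
}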
diff --git a/web/src/components/ui/dataTable.tsx b/web/src/components/ui/dataTable.tsx
deleted file mode 100644
index ccc786981d5..00000000000
--- a/web/src/components/ui/dataTable.tsx
+++ /dev/null
@@ -1,81 +0,0 @@
-"use client";
-
-import {
-  ColumnDef,
-  flexRender,
-  getCoreRowModel,
-  useReactTable,
-} from "@tanstack/react-table";
-
-import {
-  Table,
-  TableBody,
-  TableCell,
-  TableHead,
-  TableHeader,
-  TableRow,
-} from "@/components/ui/table";
-
-interface DataTableProps {
-  columns: ColumnDef[];
-  data: TData[];
-  emptyMessage?: string | JSX.Element;
-}
-
-export function DataTable({
-  columns,
-  data,
-  emptyMessage = "No results.",
-}: DataTableProps) {
-  const table = useReactTable({
-    data,
-    columns,
-    getCoreRowModel: getCoreRowModel(),
-  });
-
-  return (
-
-
-
-        {table.getHeaderGroups().map((headerGroup) => (
-
-            {headerGroup.headers.map((header) => (
-
-                {header.isPlaceholder
-                  ? null
-                  : flexRender(
-                      header.column.columnDef.header,
-                      header.getContext()
-                    )}
-
-            ))}
-
-        ))}
-
-
-          {table.getRowModel().rows?.length ? (
-            table.getRowModel().rows.map((row) => (
-
-                {row.getVisibleCells().map((cell) => (
-
-                    {flexRender(cell.column.columnDef.cell, cell.getContext())}
-
-                ))}
-
-            ))
-          ) : (
-
-
-                {emptyMessage}
-
-
-          )}
-
-
-  );
-}
diff --git a/web/src/components/ui/datePicker.tsx b/web/src/components/ui/datePicker.tsx
index e169d12d882..370ac9fca5c 100644
--- a/web/src/components/ui/datePicker.tsx
+++ b/web/src/components/ui/datePicker.tsx
@@ -37,9 +37,10 @@ export function DatePicker({
     .fill(currYear)
     .map((currYear, index) => currYear - index);
   const [shownDate, setShownDate] = useState(selectedDate ?? new Date());
+  const [open, setOpen] = useState(false);
 
   return (
-
+