diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index aad6b6fc99f..b5ea772d5e2 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -88,7 +88,12 @@ jobs: echo "Waiting for service to be available..." sleep 5 done - cd sdk/python && uv sync --python 3.10 --group test --frozen && uv pip install . && source .venv/bin/activate && cd test/test_sdk_api && pytest -s --tb=short get_email.py t_dataset.py t_chat.py t_session.py t_document.py t_chunk.py + if [[ $GITHUB_EVENT_NAME == 'schedule' ]]; then + export HTTP_API_TEST_LEVEL=p3 + else + export HTTP_API_TEST_LEVEL=p2 + fi + UV_LINK_MODE=copy uv sync --python 3.10 --only-group test --no-default-groups --frozen && uv pip install sdk/python && uv run --only-group test --no-default-groups pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api - name: Run frontend api tests against Elasticsearch run: | @@ -98,7 +103,7 @@ jobs: echo "Waiting for service to be available..." sleep 5 done - cd sdk/python && uv sync --python 3.10 --group test --frozen && source .venv/bin/activate && cd test/test_frontend_api && pytest -s --tb=short get_email.py test_dataset.py + cd sdk/python && UV_LINK_MODE=copy uv sync --python 3.10 --group test --frozen && source .venv/bin/activate && cd test/test_frontend_api && pytest -s --tb=short get_email.py test_dataset.py - name: Run http api tests against Elasticsearch run: | @@ -113,7 +118,7 @@ jobs: else export HTTP_API_TEST_LEVEL=p2 fi - cd sdk/python && uv sync --python 3.10 --group test --frozen && source .venv/bin/activate && cd test/test_http_api && pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} + UV_LINK_MODE=copy uv sync --python 3.10 --only-group test --no-default-groups --frozen && uv run --only-group test --no-default-groups pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api - name: Stop ragflow:nightly if: always() # always run this step even if previous steps failed @@ -132,7 +137,12 @@ jobs: echo "Waiting for service to be available..." sleep 5 done - cd sdk/python && uv sync --python 3.10 --group test --frozen && uv pip install . && source .venv/bin/activate && cd test/test_sdk_api && pytest -s --tb=short get_email.py t_dataset.py t_chat.py t_session.py t_document.py t_chunk.py + if [[ $GITHUB_EVENT_NAME == 'schedule' ]]; then + export HTTP_API_TEST_LEVEL=p3 + else + export HTTP_API_TEST_LEVEL=p2 + fi + UV_LINK_MODE=copy uv sync --python 3.10 --only-group test --no-default-groups --frozen && uv pip install sdk/python && DOC_ENGINE=infinity uv run --only-group test --no-default-groups pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api - name: Run frontend api tests against Infinity run: | @@ -142,7 +152,7 @@ jobs: echo "Waiting for service to be available..." 
sleep 5 done - cd sdk/python && uv sync --python 3.10 --group test --frozen && source .venv/bin/activate && cd test/test_frontend_api && pytest -s --tb=short get_email.py test_dataset.py + cd sdk/python && UV_LINK_MODE=copy uv sync --python 3.10 --group test --frozen && source .venv/bin/activate && cd test/test_frontend_api && pytest -s --tb=short get_email.py test_dataset.py - name: Run http api tests against Infinity run: | @@ -157,7 +167,7 @@ jobs: else export HTTP_API_TEST_LEVEL=p2 fi - cd sdk/python && uv sync --python 3.10 --group test --frozen && source .venv/bin/activate && cd test/test_http_api && DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} + UV_LINK_MODE=copy uv sync --python 3.10 --only-group test --no-default-groups --frozen && DOC_ENGINE=infinity uv run --only-group test --no-default-groups pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api - name: Stop ragflow:nightly if: always() # always run this step even if previous steps failed diff --git a/.gitignore b/.gitignore index b9c688dcc19..52c53277043 100644 --- a/.gitignore +++ b/.gitignore @@ -36,6 +36,12 @@ sdk/python/ragflow.egg-info/ sdk/python/build/ sdk/python/dist/ sdk/python/ragflow_sdk.egg-info/ + +# Exclude dep files +libssl*.deb +tika-server*.jar* +cl100k_base.tiktoken +chrome* huggingface.co/ nltk_data/ @@ -44,3 +50,146 @@ nltk_data/ .lh/ .venv docker/data + + +#--------------------------------------------------# +# The following was generated with gitignore.nvim: # +#--------------------------------------------------# +# Gitignore for the following technologies: Node + +# Logs +logs +*.log +npm-debug.log* +yarn-debug.log* +yarn-error.log* +lerna-debug.log* +.pnpm-debug.log* + +# Diagnostic reports (https://nodejs.org/api/report.html) +report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json + +# Runtime data +pids +*.pid +*.seed +*.pid.lock + +# Directory for instrumented libs generated by jscoverage/JSCover +lib-cov + +# Coverage directory used by tools like istanbul +coverage +*.lcov + +# nyc test coverage +.nyc_output + +# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files) +.grunt + +# Bower dependency directory (https://bower.io/) +bower_components + +# node-waf configuration +.lock-wscript + +# Compiled binary addons (https://nodejs.org/api/addons.html) +build/Release + +# Dependency directories +node_modules/ +jspm_packages/ + +# Snowpack dependency directory (https://snowpack.dev/) +web_modules/ + +# TypeScript cache +*.tsbuildinfo + +# Optional npm cache directory +.npm + +# Optional eslint cache +.eslintcache + +# Optional stylelint cache +.stylelintcache + +# Microbundle cache +.rpt2_cache/ +.rts2_cache_cjs/ +.rts2_cache_es/ +.rts2_cache_umd/ + +# Optional REPL history +.node_repl_history + +# Output of 'npm pack' +*.tgz + +# Yarn Integrity file +.yarn-integrity + +# dotenv environment variable files +.env +.env.development.local +.env.test.local +.env.production.local +.env.local + +# parcel-bundler cache (https://parceljs.org/) +.cache +.parcel-cache + +# Next.js build output +.next +out + +# Nuxt.js build / generate output +.nuxt +dist + +# Gatsby files +.cache/ +# Comment in the public line in if your project uses Gatsby and not Next.js +# https://nextjs.org/blog/next-9-1#public-directory-support +# public + +# vuepress build output +.vuepress/dist + +# vuepress v2.x temp and cache directory +.temp + +# Docusaurus cache and generated files +.docusaurus + +# Serverless directories +.serverless/ + +# FuseBox cache +.fusebox/ + 
+# DynamoDB Local files
+.dynamodb/
+
+# TernJS port file
+.tern-port
+
+# Stores VSCode versions used for testing VSCode extensions
+.vscode-test
+
+# yarn v2
+.yarn/cache
+.yarn/unplugged
+.yarn/build-state.yml
+.yarn/install-state.gz
+.pnp.*
+
+# Serverless Webpack directories
+.webpack/
+
+# SvelteKit build / generate output
+.svelte-kit
+
diff --git a/README.md b/README.md
index ee9bc979fcb..c8e47cdbc92 100644
--- a/README.md
+++ b/README.md
@@ -5,13 +5,13 @@

- English |
- 简体中文 |
- 繁体中文 |
- 日本語 |
- 한국어 |
- Bahasa Indonesia |
- Português (Brasil)
+ README in English
+ 简体中文版自述文件
+ 繁體版中文自述文件
+ 日本語のREADME
+ 한국어
+ Bahasa Indonesia
+ Português(Brasil)

@@ -22,7 +22,7 @@ Static Badge - docker pull infiniflow/ragflow:v0.19.0 + docker pull infiniflow/ragflow:v0.19.1 Latest Release @@ -30,6 +30,9 @@ license + + Ask DeepWiki +

@@ -40,6 +43,12 @@ Demo

+# + +
+infiniflow%2Fragflow | Trendshift +
+
📕 Table of Contents @@ -78,11 +87,11 @@ Try our demo at [https://demo.ragflow.io](https://demo.ragflow.io). ## 🔥 Latest Updates +- 2025-05-23 Adds a Python/JavaScript code executor component to Agent. +- 2025-05-05 Supports cross-language query. - 2025-03-19 Supports using a multi-modal model to make sense of images within PDF or DOCX files. - 2025-02-28 Combined with Internet search (Tavily), supports reasoning like Deep Research for any LLMs. -- 2025-01-26 Optimizes knowledge graph extraction and application, offering various configuration options. - 2024-12-18 Upgrades Document Layout Analysis model in DeepDoc. -- 2024-11-01 Adds keyword extraction and related question generation to the parsed chunks to improve the accuracy of retrieval. - 2024-08-22 Support text to SQL statements through RAG. ## 🎉 Stay Tuned @@ -178,7 +187,7 @@ releases! 🌟 > All Docker images are built for x86 platforms. We don't currently offer Docker images for ARM64. > If you are on an ARM64 platform, follow [this guide](https://ragflow.io/docs/dev/build_docker_image) to build a Docker image compatible with your system. - > The command below downloads the `v0.19.0-slim` edition of the RAGFlow Docker image. See the following table for descriptions of different RAGFlow editions. To download a RAGFlow edition different from `v0.19.0-slim`, update the `RAGFLOW_IMAGE` variable accordingly in **docker/.env** before using `docker compose` to start the server. For example: set `RAGFLOW_IMAGE=infiniflow/ragflow:v0.19.0` for the full edition `v0.19.0`. + > The command below downloads the `v0.19.1-slim` edition of the RAGFlow Docker image. See the following table for descriptions of different RAGFlow editions. To download a RAGFlow edition different from `v0.19.1-slim`, update the `RAGFLOW_IMAGE` variable accordingly in **docker/.env** before using `docker compose` to start the server. For example: set `RAGFLOW_IMAGE=infiniflow/ragflow:v0.19.1` for the full edition `v0.19.1`. ```bash $ cd ragflow/docker @@ -191,8 +200,8 @@ releases! 🌟 | RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? | |-------------------|-----------------|-----------------------|--------------------------| - | v0.19.0 | ≈9 | :heavy_check_mark: | Stable release | - | v0.19.0-slim | ≈2 | ❌ | Stable release | + | v0.19.1 | ≈9 | :heavy_check_mark: | Stable release | + | v0.19.1-slim | ≈2 | ❌ | Stable release | | nightly | ≈9 | :heavy_check_mark: | _Unstable_ nightly build | | nightly-slim | ≈2 | ❌ | _Unstable_ nightly build | diff --git a/README_id.md b/README_id.md index d8b0292d1f1..50e00a9d5ed 100644 --- a/README_id.md +++ b/README_id.md @@ -5,13 +5,13 @@

- English |
- 简体中文 |
- 繁体中文 |
- 日本語 |
- 한국어 |
- Bahasa Indonesia |
- Português (Brasil)
+ README in English
+ 简体中文版自述文件
+ 繁體中文版自述文件
+ 日本語のREADME
+ 한국어
+ Bahasa Indonesia
+ Português(Brasil)

@@ -22,7 +22,7 @@ Lencana Daring - docker pull infiniflow/ragflow:v0.19.0 + docker pull infiniflow/ragflow:v0.19.1 Rilis Terbaru @@ -30,6 +30,9 @@ Lisensi + + Ask DeepWiki +

@@ -40,6 +43,8 @@ Demo

+# +
📕 Daftar Isi @@ -75,11 +80,11 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io). ## 🔥 Pembaruan Terbaru +- 2025-05-23 Menambahkan komponen pelaksana kode Python/JS ke Agen. +- 2025-05-05 Mendukung kueri lintas bahasa. - 2025-03-19 Mendukung penggunaan model multi-modal untuk memahami gambar di dalam file PDF atau DOCX. - 2025-02-28 dikombinasikan dengan pencarian Internet (TAVILY), mendukung penelitian mendalam untuk LLM apa pun. -- 2025-01-26 Optimalkan ekstraksi dan penerapan grafik pengetahuan dan sediakan berbagai opsi konfigurasi. - 2024-12-18 Meningkatkan model Analisis Tata Letak Dokumen di DeepDoc. -- 2024-11-01 Penambahan ekstraksi kata kunci dan pembuatan pertanyaan terkait untuk meningkatkan akurasi pengambilan. - 2024-08-22 Dukungan untuk teks ke pernyataan SQL melalui RAG. ## 🎉 Tetap Terkini @@ -173,7 +178,7 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io). > Semua gambar Docker dibangun untuk platform x86. Saat ini, kami tidak menawarkan gambar Docker untuk ARM64. > Jika Anda menggunakan platform ARM64, [silakan gunakan panduan ini untuk membangun gambar Docker yang kompatibel dengan sistem Anda](https://ragflow.io/docs/dev/build_docker_image). -> Perintah di bawah ini mengunduh edisi v0.19.0-slim dari gambar Docker RAGFlow. Silakan merujuk ke tabel berikut untuk deskripsi berbagai edisi RAGFlow. Untuk mengunduh edisi RAGFlow yang berbeda dari v0.19.0-slim, perbarui variabel RAGFLOW_IMAGE di docker/.env sebelum menggunakan docker compose untuk memulai server. Misalnya, atur RAGFLOW_IMAGE=infiniflow/ragflow:v0.19.0 untuk edisi lengkap v0.19.0. +> Perintah di bawah ini mengunduh edisi v0.19.1-slim dari gambar Docker RAGFlow. Silakan merujuk ke tabel berikut untuk deskripsi berbagai edisi RAGFlow. Untuk mengunduh edisi RAGFlow yang berbeda dari v0.19.1-slim, perbarui variabel RAGFLOW_IMAGE di docker/.env sebelum menggunakan docker compose untuk memulai server. Misalnya, atur RAGFLOW_IMAGE=infiniflow/ragflow:v0.19.1 untuk edisi lengkap v0.19.1. ```bash $ cd ragflow/docker @@ -186,8 +191,8 @@ $ docker compose -f docker-compose.yml up -d | RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? | | ----------------- | --------------- | --------------------- | ------------------------ | -| v0.19.0 | ≈9 | :heavy_check_mark: | Stable release | -| v0.19.0-slim | ≈2 | ❌ | Stable release | +| v0.19.1 | ≈9 | :heavy_check_mark: | Stable release | +| v0.19.1-slim | ≈2 | ❌ | Stable release | | nightly | ≈9 | :heavy_check_mark: | _Unstable_ nightly build | | nightly-slim | ≈2 | ❌ | _Unstable_ nightly build | diff --git a/README_ja.md b/README_ja.md index e9ad224860a..f0f153de060 100644 --- a/README_ja.md +++ b/README_ja.md @@ -5,13 +5,13 @@

- English |
- 简体中文 |
- 繁体中文 |
- 日本語 |
- 한국어 |
- Bahasa Indonesia |
- Português (Brasil)
+ README in English
+ 简体中文版自述文件
+ 繁體中文版自述文件
+ 日本語のREADME
+ 한국어
+ Bahasa Indonesia
+ Português(Brasil)

@@ -22,7 +22,7 @@ Static Badge - docker pull infiniflow/ragflow:v0.19.0 + docker pull infiniflow/ragflow:v0.19.1 Latest Release @@ -30,6 +30,9 @@ license + + Ask DeepWiki +

@@ -40,6 +43,8 @@ Demo

+# + ## 💡 RAGFlow とは? [RAGFlow](https://ragflow.io/) は、深い文書理解に基づいたオープンソースの RAG (Retrieval-Augmented Generation) エンジンである。LLM(大規模言語モデル)を組み合わせることで、様々な複雑なフォーマットのデータから根拠のある引用に裏打ちされた、信頼できる質問応答機能を実現し、あらゆる規模のビジネスに適した RAG ワークフローを提供します。 @@ -55,11 +60,11 @@ ## 🔥 最新情報 +- 2025-05-23 エージェントに Python/JS コードエグゼキュータコンポーネントを追加しました。 +- 2025-05-05 言語間クエリをサポートしました。 - 2025-03-19 PDFまたはDOCXファイル内の画像を理解するために、多モーダルモデルを使用することをサポートします。 - 2025-02-28 インターネット検索 (TAVILY) と組み合わせて、あらゆる LLM の詳細な調査をサポートします。 -- 2025-01-26 ナレッジ グラフの抽出と適用を最適化し、さまざまな構成オプションを提供します。 - 2024-12-18 DeepDoc のドキュメント レイアウト分析モデルをアップグレードします。 -- 2024-11-01 再現の精度を向上させるために、解析されたチャンクにキーワード抽出と関連質問の生成を追加しました。 - 2024-08-22 RAG を介して SQL ステートメントへのテキストをサポートします。 ## 🎉 続きを楽しみに @@ -152,7 +157,7 @@ > 現在、公式に提供されているすべての Docker イメージは x86 アーキテクチャ向けにビルドされており、ARM64 用の Docker イメージは提供されていません。 > ARM64 アーキテクチャのオペレーティングシステムを使用している場合は、[このドキュメント](https://ragflow.io/docs/dev/build_docker_image)を参照して Docker イメージを自分でビルドしてください。 - > 以下のコマンドは、RAGFlow Docker イメージの v0.19.0-slim エディションをダウンロードします。異なる RAGFlow エディションの説明については、以下の表を参照してください。v0.19.0-slim とは異なるエディションをダウンロードするには、docker/.env ファイルの RAGFLOW_IMAGE 変数を適宜更新し、docker compose を使用してサーバーを起動してください。例えば、完全版 v0.19.0 をダウンロードするには、RAGFLOW_IMAGE=infiniflow/ragflow:v0.19.0 と設定します。 + > 以下のコマンドは、RAGFlow Docker イメージの v0.19.1-slim エディションをダウンロードします。異なる RAGFlow エディションの説明については、以下の表を参照してください。v0.19.1-slim とは異なるエディションをダウンロードするには、docker/.env ファイルの RAGFLOW_IMAGE 変数を適宜更新し、docker compose を使用してサーバーを起動してください。例えば、完全版 v0.19.1 をダウンロードするには、RAGFLOW_IMAGE=infiniflow/ragflow:v0.19.1 と設定します。 ```bash $ cd ragflow/docker @@ -165,8 +170,8 @@ | RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? | | ----------------- | --------------- | --------------------- | ------------------------ | - | v0.19.0 | ≈9 | :heavy_check_mark: | Stable release | - | v0.19.0-slim | ≈2 | ❌ | Stable release | + | v0.19.1 | ≈9 | :heavy_check_mark: | Stable release | + | v0.19.1-slim | ≈2 | ❌ | Stable release | | nightly | ≈9 | :heavy_check_mark: | _Unstable_ nightly build | | nightly-slim | ≈2 | ❌ | _Unstable_ nightly build | diff --git a/README_ko.md b/README_ko.md index 8489159521f..322a32d20ed 100644 --- a/README_ko.md +++ b/README_ko.md @@ -5,13 +5,13 @@

- English |
- 简体中文 |
- 繁体中文 |
- 日本語 |
- 한국어 |
- Bahasa Indonesia |
- Português (Brasil)
+ README in English
+ 简体中文版自述文件
+ 繁體版中文自述文件
+ 日本語のREADME
+ 한국어
+ Bahasa Indonesia
+ Português(Brasil)

@@ -22,7 +22,7 @@ Static Badge - docker pull infiniflow/ragflow:v0.19.0 + docker pull infiniflow/ragflow:v0.19.1 Latest Release @@ -30,6 +30,9 @@ license + + Ask DeepWiki +

@@ -40,6 +43,8 @@ Demo

+# + ## 💡 RAGFlow란? [RAGFlow](https://ragflow.io/)는 심층 문서 이해에 기반한 오픈소스 RAG (Retrieval-Augmented Generation) 엔진입니다. 이 엔진은 대규모 언어 모델(LLM)과 결합하여 정확한 질문 응답 기능을 제공하며, 다양한 복잡한 형식의 데이터에서 신뢰할 수 있는 출처를 바탕으로 한 인용을 통해 이를 뒷받침합니다. RAGFlow는 규모에 상관없이 모든 기업에 최적화된 RAG 워크플로우를 제공합니다. @@ -55,11 +60,11 @@ ## 🔥 업데이트 +- 2025-05-23 Agent에 Python/JS 코드 실행기 구성 요소를 추가합니다. +- 2025-05-05 언어 간 쿼리를 지원합니다. - 2025-03-19 PDF 또는 DOCX 파일 내의 이미지를 이해하기 위해 다중 모드 모델을 사용하는 것을 지원합니다. - 2025-02-28 인터넷 검색(TAVILY)과 결합되어 모든 LLM에 대한 심층 연구를 지원합니다. -- 2025-01-26 지식 그래프 추출 및 적용을 최적화하고 다양한 구성 옵션을 제공합니다. - 2024-12-18 DeepDoc의 문서 레이아웃 분석 모델 업그레이드. -- 2024-11-01 파싱된 청크에 키워드 추출 및 관련 질문 생성을 추가하여 재현율을 향상시킵니다. - 2024-08-22 RAG를 통해 SQL 문에 텍스트를 지원합니다. ## 🎉 계속 지켜봐 주세요 @@ -152,7 +157,7 @@ > 모든 Docker 이미지는 x86 플랫폼을 위해 빌드되었습니다. 우리는 현재 ARM64 플랫폼을 위한 Docker 이미지를 제공하지 않습니다. > ARM64 플랫폼을 사용 중이라면, [시스템과 호환되는 Docker 이미지를 빌드하려면 이 가이드를 사용해 주세요](https://ragflow.io/docs/dev/build_docker_image). - > 아래 명령어는 RAGFlow Docker 이미지의 v0.19.0-slim 버전을 다운로드합니다. 다양한 RAGFlow 버전에 대한 설명은 다음 표를 참조하십시오. v0.19.0-slim과 다른 RAGFlow 버전을 다운로드하려면, docker/.env 파일에서 RAGFLOW_IMAGE 변수를 적절히 업데이트한 후 docker compose를 사용하여 서버를 시작하십시오. 예를 들어, 전체 버전인 v0.19.0을 다운로드하려면 RAGFLOW_IMAGE=infiniflow/ragflow:v0.19.0로 설정합니다. + > 아래 명령어는 RAGFlow Docker 이미지의 v0.19.1-slim 버전을 다운로드합니다. 다양한 RAGFlow 버전에 대한 설명은 다음 표를 참조하십시오. v0.19.1-slim과 다른 RAGFlow 버전을 다운로드하려면, docker/.env 파일에서 RAGFLOW_IMAGE 변수를 적절히 업데이트한 후 docker compose를 사용하여 서버를 시작하십시오. 예를 들어, 전체 버전인 v0.19.1을 다운로드하려면 RAGFLOW_IMAGE=infiniflow/ragflow:v0.19.1로 설정합니다. ```bash $ cd ragflow/docker @@ -165,8 +170,8 @@ | RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? | | ----------------- | --------------- | --------------------- | ------------------------ | - | v0.19.0 | ≈9 | :heavy_check_mark: | Stable release | - | v0.19.0-slim | ≈2 | ❌ | Stable release | + | v0.19.1 | ≈9 | :heavy_check_mark: | Stable release | + | v0.19.1-slim | ≈2 | ❌ | Stable release | | nightly | ≈9 | :heavy_check_mark: | _Unstable_ nightly build | | nightly-slim | ≈2 | ❌ | _Unstable_ nightly build | diff --git a/README_pt_br.md b/README_pt_br.md index 29bd87f5209..8154751ff31 100644 --- a/README_pt_br.md +++ b/README_pt_br.md @@ -5,13 +5,13 @@

- English |
- 简体中文 |
- 繁体中文 |
- 日本語 |
- 한국어 |
- Bahasa Indonesia |
- Português (Brasil)
+ README in English
+ 简体中文版自述文件
+ 繁體版中文自述文件
+ 日本語のREADME
+ 한국어
+ Bahasa Indonesia
+ Português(Brasil)

@@ -22,7 +22,7 @@ Badge Estático - docker pull infiniflow/ragflow:v0.19.0 + docker pull infiniflow/ragflow:v0.19.1 Última Versão @@ -30,6 +30,9 @@ licença + + Ask DeepWiki +

@@ -40,6 +43,8 @@ Demo

+# +
📕 Índice @@ -75,11 +80,11 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io). ## 🔥 Últimas Atualizações +- 23-05-2025 Adicione o componente executor de código Python/JS ao Agente. +- 05-05-2025 Suporte a consultas entre idiomas. - 19-03-2025 Suporta o uso de um modelo multi-modal para entender imagens dentro de arquivos PDF ou DOCX. - 28-02-2025 combinado com a pesquisa na Internet (T AVI LY), suporta pesquisas profundas para qualquer LLM. -- 26-01-2025 Otimize a extração e aplicação de gráficos de conhecimento e forneça uma variedade de opções de configuração. - 18-12-2024 Atualiza o modelo de Análise de Layout de Documentos no DeepDoc. -- 01-11-2024 Adiciona extração de palavras-chave e geração de perguntas relacionadas aos blocos analisados para melhorar a precisão da recuperação. - 22-08-2024 Suporta conversão de texto para comandos SQL via RAG. ## 🎉 Fique Ligado @@ -172,7 +177,7 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io). > Todas as imagens Docker são construídas para plataformas x86. Atualmente, não oferecemos imagens Docker para ARM64. > Se você estiver usando uma plataforma ARM64, por favor, utilize [este guia](https://ragflow.io/docs/dev/build_docker_image) para construir uma imagem Docker compatível com o seu sistema. - > O comando abaixo baixa a edição `v0.19.0-slim` da imagem Docker do RAGFlow. Consulte a tabela a seguir para descrições de diferentes edições do RAGFlow. Para baixar uma edição do RAGFlow diferente da `v0.19.0-slim`, atualize a variável `RAGFLOW_IMAGE` conforme necessário no **docker/.env** antes de usar `docker compose` para iniciar o servidor. Por exemplo: defina `RAGFLOW_IMAGE=infiniflow/ragflow:v0.19.0` para a edição completa `v0.19.0`. + > O comando abaixo baixa a edição `v0.19.1-slim` da imagem Docker do RAGFlow. Consulte a tabela a seguir para descrições de diferentes edições do RAGFlow. Para baixar uma edição do RAGFlow diferente da `v0.19.1-slim`, atualize a variável `RAGFLOW_IMAGE` conforme necessário no **docker/.env** antes de usar `docker compose` para iniciar o servidor. Por exemplo: defina `RAGFLOW_IMAGE=infiniflow/ragflow:v0.19.1` para a edição completa `v0.19.1`. ```bash $ cd ragflow/docker @@ -185,8 +190,8 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io). | Tag da imagem RAGFlow | Tamanho da imagem (GB) | Possui modelos de incorporação? | Estável? | | --------------------- | ---------------------- | ------------------------------- | ------------------------ | - | v0.19.0 | ~9 | :heavy_check_mark: | Lançamento estável | - | v0.19.0-slim | ~2 | ❌ | Lançamento estável | + | v0.19.1 | ~9 | :heavy_check_mark: | Lançamento estável | + | v0.19.1-slim | ~2 | ❌ | Lançamento estável | | nightly | ~9 | :heavy_check_mark: | _Instável_ build noturno | | nightly-slim | ~2 | ❌ | _Instável_ build noturno | diff --git a/README_tzh.md b/README_tzh.md index f3d82c1088a..e6010b0b640 100644 --- a/README_tzh.md +++ b/README_tzh.md @@ -5,12 +5,13 @@

- English |
- 简体中文 |
- 日本語 |
- 한국어 |
- Bahasa Indonesia |
- Português (Brasil)
+ README in English
+ 简体中文版自述文件
+ 繁體版中文自述文件
+ 日本語のREADME
+ 한국어
+ Bahasa Indonesia
+ Português(Brasil)

@@ -21,7 +22,7 @@ Static Badge - docker pull infiniflow/ragflow:v0.19.0 + docker pull infiniflow/ragflow:v0.19.1 Latest Release @@ -29,6 +30,9 @@ license + + Ask DeepWiki +

@@ -39,6 +43,31 @@ Demo

+# + +
+infiniflow%2Fragflow | Trendshift +
+ +
+📕 目錄
+
+- 💡 [RAGFlow 是什麼?](#-RAGFlow-是什麼)
+- 🎮 [Demo-試用](#-demo-試用)
+- 📌 [近期更新](#-近期更新)
+- 🌟 [主要功能](#-主要功能)
+- 🔎 [系統架構](#-系統架構)
+- 🎬 [快速開始](#-快速開始)
+- 🔧 [系統配置](#-系統配置)
+- 🔨 [以原始碼啟動服務](#-以原始碼啟動服務)
+- 📚 [技術文檔](#-技術文檔)
+- 📜 [路線圖](#-路線圖)
+- 🏄 [貢獻指南](#-貢獻指南)
+- 🙌 [加入社區](#-加入社區)
+- 🤝 [商務合作](#-商務合作)
+
+ ## 💡 RAGFlow 是什麼? [RAGFlow](https://ragflow.io/) 是一款基於深度文件理解所建構的開源 RAG(Retrieval-Augmented Generation)引擎。 RAGFlow 可以為各種規模的企業及個人提供一套精簡的 RAG 工作流程,結合大語言模型(LLM)針對用戶各類不同的複雜格式數據提供可靠的問答以及有理有據的引用。 @@ -54,11 +83,11 @@ ## 🔥 近期更新 +- 2025-05-23 為 Agent 新增 Python/JS 程式碼執行器元件。 +- 2025-05-05 支援跨語言查詢。 - 2025-03-19 PDF和DOCX中的圖支持用多模態大模型去解析得到描述. - 2025-02-28 結合網路搜尋(Tavily),對於任意大模型實現類似 Deep Research 的推理功能. -- 2025-01-26 最佳化知識圖譜的擷取與應用,提供了多種配置選擇。 - 2024-12-18 升級了 DeepDoc 的文檔佈局分析模型。 -- 2024-11-01 對解析後的 chunk 加入關鍵字抽取和相關問題產生以提高回想的準確度。 - 2024-08-22 支援用 RAG 技術實現從自然語言到 SQL 語句的轉換。 ## 🎉 關注項目 @@ -151,7 +180,7 @@ > 所有 Docker 映像檔都是為 x86 平台建置的。目前,我們不提供 ARM64 平台的 Docker 映像檔。 > 如果您使用的是 ARM64 平台,請使用 [這份指南](https://ragflow.io/docs/dev/build_docker_image) 來建置適合您系統的 Docker 映像檔。 - > 執行以下指令會自動下載 RAGFlow slim Docker 映像 `v0.19.0-slim`。請參考下表查看不同 Docker 發行版的說明。如需下載不同於 `v0.19.0-slim` 的 Docker 映像,請在執行 `docker compose` 啟動服務之前先更新 **docker/.env** 檔案內的 `RAGFLOW_IMAGE` 變數。例如,你可以透過設定 `RAGFLOW_IMAGE=infiniflow/ragflow:v0.19.0` 來下載 RAGFlow 鏡像的 `v0.19.0` 完整發行版。 + > 執行以下指令會自動下載 RAGFlow slim Docker 映像 `v0.19.1-slim`。請參考下表查看不同 Docker 發行版的說明。如需下載不同於 `v0.19.1-slim` 的 Docker 映像,請在執行 `docker compose` 啟動服務之前先更新 **docker/.env** 檔案內的 `RAGFLOW_IMAGE` 變數。例如,你可以透過設定 `RAGFLOW_IMAGE=infiniflow/ragflow:v0.19.1` 來下載 RAGFlow 鏡像的 `v0.19.1` 完整發行版。 ```bash $ cd ragflow/docker @@ -164,8 +193,8 @@ | RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? | | ----------------- | --------------- | --------------------- | ------------------------ | - | v0.19.0 | ≈9 | :heavy_check_mark: | Stable release | - | v0.19.0-slim | ≈2 | ❌ | Stable release | + | v0.19.1 | ≈9 | :heavy_check_mark: | Stable release | + | v0.19.1-slim | ≈2 | ❌ | Stable release | | nightly | ≈9 | :heavy_check_mark: | _Unstable_ nightly build | | nightly-slim | ≈2 | ❌ | _Unstable_ nightly build | diff --git a/README_zh.md b/README_zh.md index 2758f20ab3b..1669c68446c 100644 --- a/README_zh.md +++ b/README_zh.md @@ -5,13 +5,13 @@

- English |
- 简体中文 |
- 繁体中文 |
- 日本語 |
- 한국어 |
- Bahasa Indonesia |
- Português (Brasil)
+ README in English
+ 简体中文版自述文件
+ 繁體版中文自述文件
+ 日本語のREADME
+ 한국어
+ Bahasa Indonesia
+ Português(Brasil)

@@ -22,7 +22,7 @@ Static Badge - docker pull infiniflow/ragflow:v0.19.0 + docker pull infiniflow/ragflow:v0.19.1 Latest Release @@ -30,6 +30,9 @@ license + + Ask DeepWiki +

@@ -40,6 +43,31 @@ Demo

+# + +
+infiniflow%2Fragflow | Trendshift +
+ +
+📕 目录
+
+- 💡 [RAGFlow 是什么?](#-RAGFlow-是什么)
+- 🎮 [Demo](#-demo)
+- 📌 [近期更新](#-近期更新)
+- 🌟 [主要功能](#-主要功能)
+- 🔎 [系统架构](#-系统架构)
+- 🎬 [快速开始](#-快速开始)
+- 🔧 [系统配置](#-系统配置)
+- 🔨 [以源代码启动服务](#-以源代码启动服务)
+- 📚 [技术文档](#-技术文档)
+- 📜 [路线图](#-路线图)
+- 🏄 [贡献指南](#-贡献指南)
+- 🙌 [加入社区](#-加入社区)
+- 🤝 [商务合作](#-商务合作)
+
+ ## 💡 RAGFlow 是什么? [RAGFlow](https://ragflow.io/) 是一款基于深度文档理解构建的开源 RAG(Retrieval-Augmented Generation)引擎。RAGFlow 可以为各种规模的企业及个人提供一套精简的 RAG 工作流程,结合大语言模型(LLM)针对用户各类不同的复杂格式数据提供可靠的问答以及有理有据的引用。 @@ -55,11 +83,11 @@ ## 🔥 近期更新 -- 2025-03-19 PDF和DOCX中的图支持用多模态大模型去解析得到描述. +- 2025-05-23 Agent 新增 Python/JS 代码执行器组件。 +- 2025-05-05 支持跨语言查询。 +- 2025-03-19 PDF 和 DOCX 中的图支持用多模态大模型去解析得到描述. - 2025-02-28 结合互联网搜索(Tavily),对于任意大模型实现类似 Deep Research 的推理功能. -- 2025-01-26 优化知识图谱的提取和应用,提供了多种配置选择。 - 2024-12-18 升级了 DeepDoc 的文档布局分析模型。 -- 2024-11-01 对解析后的 chunk 加入关键词抽取和相关问题生成以提高召回的准确度。 - 2024-08-22 支持用 RAG 技术实现从自然语言到 SQL 语句的转换。 ## 🎉 关注项目 @@ -152,7 +180,7 @@ > 请注意,目前官方提供的所有 Docker 镜像均基于 x86 架构构建,并不提供基于 ARM64 的 Docker 镜像。 > 如果你的操作系统是 ARM64 架构,请参考[这篇文档](https://ragflow.io/docs/dev/build_docker_image)自行构建 Docker 镜像。 - > 运行以下命令会自动下载 RAGFlow slim Docker 镜像 `v0.19.0-slim`。请参考下表查看不同 Docker 发行版的描述。如需下载不同于 `v0.19.0-slim` 的 Docker 镜像,请在运行 `docker compose` 启动服务之前先更新 **docker/.env** 文件内的 `RAGFLOW_IMAGE` 变量。比如,你可以通过设置 `RAGFLOW_IMAGE=infiniflow/ragflow:v0.19.0` 来下载 RAGFlow 镜像的 `v0.19.0` 完整发行版。 + > 运行以下命令会自动下载 RAGFlow slim Docker 镜像 `v0.19.1-slim`。请参考下表查看不同 Docker 发行版的描述。如需下载不同于 `v0.19.1-slim` 的 Docker 镜像,请在运行 `docker compose` 启动服务之前先更新 **docker/.env** 文件内的 `RAGFLOW_IMAGE` 变量。比如,你可以通过设置 `RAGFLOW_IMAGE=infiniflow/ragflow:v0.19.1` 来下载 RAGFlow 镜像的 `v0.19.1` 完整发行版。 ```bash $ cd ragflow/docker @@ -165,8 +193,8 @@ | RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? | | ----------------- | --------------- | --------------------- | ------------------------ | - | v0.19.0 | ≈9 | :heavy_check_mark: | Stable release | - | v0.19.0-slim | ≈2 | ❌ | Stable release | + | v0.19.1 | ≈9 | :heavy_check_mark: | Stable release | + | v0.19.1-slim | ≈2 | ❌ | Stable release | | nightly | ≈9 | :heavy_check_mark: | _Unstable_ nightly build | | nightly-slim | ≈2 | ❌ | _Unstable_ nightly build | diff --git a/agent/canvas.py b/agent/canvas.py index 7fc64245669..4b859fbf356 100644 --- a/agent/canvas.py +++ b/agent/canvas.py @@ -169,6 +169,7 @@ def get_component_name(self, cid): def run(self, running_hint_text = "is running...🕞", **kwargs): if not running_hint_text or not isinstance(running_hint_text, str): running_hint_text = "is running...🕞" + bypass_begin = bool(kwargs.get("bypass_begin", False)) if self.answer: cpn_id = self.answer[0] @@ -188,6 +189,12 @@ def run(self, running_hint_text = "is running...🕞", **kwargs): if not self.path: self.components["begin"]["obj"].run(self.history, **kwargs) self.path.append(["begin"]) + if bypass_begin: + cpn = self.get_component("begin") + downstream = cpn["downstream"] + self.path.append(downstream) + + self.path.append([]) @@ -304,6 +311,8 @@ def get_tenant_id(self): def get_history(self, window_size): convs = [] + if window_size <= 0: + return convs for role, obj in self.history[window_size * -1:]: if isinstance(obj, list) and obj and all([isinstance(o, dict) for o in obj]): convs.append({"role": role, "content": '\n'.join([str(s.get("content", "")) for s in obj])}) diff --git a/agent/component/answer.py b/agent/component/answer.py index 67dcbc63f7c..c8c3439c00b 100644 --- a/agent/component/answer.py +++ b/agent/component/answer.py @@ -64,14 +64,17 @@ def stream_output(self): for ii, row in stream.iterrows(): answer += row.to_dict()["content"] yield {"content": answer} - else: + elif stream is not None: for st in stream(): res = st yield st - if self._param.post_answers: + if self._param.post_answers and res: res["content"] += random.choice(self._param.post_answers) yield res + if res is None: + res = 
{"content": ""} + self.set_output(res) def set_exception(self, e): diff --git a/agent/component/baidu.py b/agent/component/baidu.py index daec9f058b6..b75faa4a031 100644 --- a/agent/component/baidu.py +++ b/agent/component/baidu.py @@ -17,6 +17,7 @@ from abc import ABC import pandas as pd import requests +from bs4 import BeautifulSoup import re from agent.component.base import ComponentBase, ComponentParamBase @@ -44,17 +45,28 @@ def _run(self, history, **kwargs): return Baidu.be_output("") try: - url = 'http://www.baidu.com/s?wd=' + ans + '&rn=' + str(self._param.top_n) + url = 'https://www.baidu.com/s?wd=' + ans + '&rn=' + str(self._param.top_n) headers = { - 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.104 Safari/537.36'} + 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36', + 'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8', + 'Accept-Language': 'zh-CN,zh;q=0.9,en;q=0.8', + 'Connection': 'keep-alive', + } response = requests.get(url=url, headers=headers) - - url_res = re.findall(r"'url': \\\"(.*?)\\\"}", response.text) - title_res = re.findall(r"'title': \\\"(.*?)\\\",\\n", response.text) - body_res = re.findall(r"\"contentText\":\"(.*?)\"", response.text) - baidu_res = [{"content": re.sub('|', '', '' + title + ' ' + body)} for - url, title, body in zip(url_res, title_res, body_res)] - del body_res, url_res, title_res + # check if request success + if response.status_code == 200: + soup = BeautifulSoup(response.text, 'html.parser') + url_res = [] + title_res = [] + body_res = [] + for item in soup.select('.result.c-container'): + # extract title + title_res.append(item.select_one('h3 a').get_text(strip=True)) + url_res.append(item.select_one('h3 a')['href']) + body_res.append(item.select_one('.c-abstract').get_text(strip=True) if item.select_one('.c-abstract') else '') + baidu_res = [{"content": re.sub('|', '', '' + title + ' ' + body)} for + url, title, body in zip(url_res, title_res, body_res)] + del body_res, url_res, title_res except Exception as e: return Baidu.be_output("**ERROR**: " + str(e)) diff --git a/agent/component/base.py b/agent/component/base.py index 15b3c345e1b..e35a84e64e6 100644 --- a/agent/component/base.py +++ b/agent/component/base.py @@ -13,11 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -from abc import ABC import builtins import json -import os import logging +import os +from abc import ABC from functools import partial from typing import Any, Tuple, Union @@ -110,15 +110,11 @@ def update(self, conf, allow_redundant=False): update_from_raw_conf = conf.get(_IS_RAW_CONF, True) if update_from_raw_conf: deprecated_params_set = self._get_or_init_deprecated_params_set() - feeded_deprecated_params_set = ( - self._get_or_init_feeded_deprecated_params_set() - ) + feeded_deprecated_params_set = self._get_or_init_feeded_deprecated_params_set() user_feeded_params_set = self._get_or_init_user_feeded_params_set() setattr(self, _IS_RAW_CONF, False) else: - feeded_deprecated_params_set = ( - self._get_or_init_feeded_deprecated_params_set(conf) - ) + feeded_deprecated_params_set = self._get_or_init_feeded_deprecated_params_set(conf) user_feeded_params_set = self._get_or_init_user_feeded_params_set(conf) def _recursive_update_param(param, config, depth, prefix): @@ -154,15 +150,11 @@ def _recursive_update_param(param, config, depth, prefix): else: # recursive set obj attr - sub_params = _recursive_update_param( - attr, config_value, depth + 1, prefix=f"{prefix}{config_key}." - ) + sub_params = _recursive_update_param(attr, config_value, depth + 1, prefix=f"{prefix}{config_key}.") setattr(param, config_key, sub_params) if not allow_redundant and redundant_attrs: - raise ValueError( - f"cpn `{getattr(self, '_name', type(self))}` has redundant parameters: `{[redundant_attrs]}`" - ) + raise ValueError(f"cpn `{getattr(self, '_name', type(self))}` has redundant parameters: `{[redundant_attrs]}`") return param @@ -193,9 +185,7 @@ def validate(self): param_validation_path_prefix = home_dir + "/param_validation/" param_name = type(self).__name__ - param_validation_path = "/".join( - [param_validation_path_prefix, param_name + ".json"] - ) + param_validation_path = "/".join([param_validation_path_prefix, param_name + ".json"]) validation_json = None @@ -228,11 +218,7 @@ def _validate_param(self, param_obj, validation_json): break if not value_legal: - raise ValueError( - "Plase check runtime conf, {} = {} does not match user-parameter restriction".format( - variable, value - ) - ) + raise ValueError("Plase check runtime conf, {} = {} does not match user-parameter restriction".format(variable, value)) elif variable in validation_json: self._validate_param(attr, validation_json) @@ -240,94 +226,63 @@ def _validate_param(self, param_obj, validation_json): @staticmethod def check_string(param, descr): if type(param).__name__ not in ["str"]: - raise ValueError( - descr + " {} not supported, should be string type".format(param) - ) + raise ValueError(descr + " {} not supported, should be string type".format(param)) @staticmethod def check_empty(param, descr): if not param: - raise ValueError( - descr + " does not support empty value." 
- ) + raise ValueError(descr + " does not support empty value.") @staticmethod def check_positive_integer(param, descr): if type(param).__name__ not in ["int", "long"] or param <= 0: - raise ValueError( - descr + " {} not supported, should be positive integer".format(param) - ) + raise ValueError(descr + " {} not supported, should be positive integer".format(param)) @staticmethod def check_positive_number(param, descr): if type(param).__name__ not in ["float", "int", "long"] or param <= 0: - raise ValueError( - descr + " {} not supported, should be positive numeric".format(param) - ) + raise ValueError(descr + " {} not supported, should be positive numeric".format(param)) @staticmethod def check_nonnegative_number(param, descr): if type(param).__name__ not in ["float", "int", "long"] or param < 0: - raise ValueError( - descr - + " {} not supported, should be non-negative numeric".format(param) - ) + raise ValueError(descr + " {} not supported, should be non-negative numeric".format(param)) @staticmethod def check_decimal_float(param, descr): if type(param).__name__ not in ["float", "int"] or param < 0 or param > 1: - raise ValueError( - descr - + " {} not supported, should be a float number in range [0, 1]".format( - param - ) - ) + raise ValueError(descr + " {} not supported, should be a float number in range [0, 1]".format(param)) @staticmethod def check_boolean(param, descr): if type(param).__name__ != "bool": - raise ValueError( - descr + " {} not supported, should be bool type".format(param) - ) + raise ValueError(descr + " {} not supported, should be bool type".format(param)) @staticmethod def check_open_unit_interval(param, descr): if type(param).__name__ not in ["float"] or param <= 0 or param >= 1: - raise ValueError( - descr + " should be a numeric number between 0 and 1 exclusively" - ) + raise ValueError(descr + " should be a numeric number between 0 and 1 exclusively") @staticmethod def check_valid_value(param, descr, valid_values): if param not in valid_values: - raise ValueError( - descr - + " {} is not supported, it should be in {}".format(param, valid_values) - ) + raise ValueError(descr + " {} is not supported, it should be in {}".format(param, valid_values)) @staticmethod def check_defined_type(param, descr, types): if type(param).__name__ not in types: - raise ValueError( - descr + " {} not supported, should be one of {}".format(param, types) - ) + raise ValueError(descr + " {} not supported, should be one of {}".format(param, types)) @staticmethod def check_and_change_lower(param, valid_list, descr=""): if type(param).__name__ != "str": - raise ValueError( - descr - + " {} not supported, should be one of {}".format(param, valid_list) - ) + raise ValueError(descr + " {} not supported, should be one of {}".format(param, valid_list)) lower_param = param.lower() if lower_param in valid_list: return lower_param else: - raise ValueError( - descr - + " {} not supported, should be one of {}".format(param, valid_list) - ) + raise ValueError(descr + " {} not supported, should be one of {}".format(param, valid_list)) @staticmethod def _greater_equal_than(value, limit): @@ -341,11 +296,7 @@ def _less_equal_than(value, limit): def _range(value, ranges): in_range = False for left_limit, right_limit in ranges: - if ( - left_limit - settings.FLOAT_ZERO - <= value - <= right_limit + settings.FLOAT_ZERO - ): + if left_limit - settings.FLOAT_ZERO <= value <= right_limit + settings.FLOAT_ZERO: in_range = True break @@ -361,16 +312,11 @@ def _not_in(value, wrong_value_list): def 
_warn_deprecated_param(self, param_name, descr): if self._deprecated_params_set.get(param_name): - logging.warning( - f"{descr} {param_name} is deprecated and ignored in this version." - ) + logging.warning(f"{descr} {param_name} is deprecated and ignored in this version.") def _warn_to_deprecate_param(self, param_name, descr, new_param): if self._deprecated_params_set.get(param_name): - logging.warning( - f"{descr} {param_name} will be deprecated in future release; " - f"please use {new_param} instead." - ) + logging.warning(f"{descr} {param_name} will be deprecated in future release; please use {new_param} instead.") return True return False @@ -395,14 +341,16 @@ def __str__(self): "params": {}, "output": {}, "inputs": {} - }}""".format(self.component_name, - self._param, - json.dumps(json.loads(str(self._param)).get("output", {}), ensure_ascii=False), - json.dumps(json.loads(str(self._param)).get("inputs", []), ensure_ascii=False) + }}""".format( + self.component_name, + self._param, + json.dumps(json.loads(str(self._param)).get("output", {}), ensure_ascii=False), + json.dumps(json.loads(str(self._param)).get("inputs", []), ensure_ascii=False), ) def __init__(self, canvas, id, param: ComponentParamBase): from agent.canvas import Canvas # Local import to avoid cyclic dependency + assert isinstance(canvas, Canvas), "canvas must be an instance of Canvas" self._canvas = canvas self._id = id @@ -410,15 +358,17 @@ def __init__(self, canvas, id, param: ComponentParamBase): self._param.check() def get_dependent_components(self): - cpnts = set([para["component_id"].split("@")[0] for para in self._param.query \ - if para.get("component_id") \ - and para["component_id"].lower().find("answer") < 0 \ - and para["component_id"].lower().find("begin") < 0]) + cpnts = set( + [ + para["component_id"].split("@")[0] + for para in self._param.query + if para.get("component_id") and para["component_id"].lower().find("answer") < 0 and para["component_id"].lower().find("begin") < 0 + ] + ) return list(cpnts) def run(self, history, **kwargs): - logging.debug("{}, history: {}, kwargs: {}".format(self, json.dumps(history, ensure_ascii=False), - json.dumps(kwargs, ensure_ascii=False))) + logging.debug("{}, history: {}, kwargs: {}".format(self, json.dumps(history, ensure_ascii=False), json.dumps(kwargs, ensure_ascii=False))) self._param.debug_inputs = [] try: res = self._run(history, **kwargs) @@ -465,7 +415,7 @@ def set_output(self, v): def set_infor(self, v): setattr(self._param, self._param.infor_var_name, v) - + def _fetch_outputs_from(self, sources: list[dict[str, Any]]) -> list[pd.DataFrame]: outs = [] for q in sources: @@ -482,7 +432,7 @@ def _fetch_outputs_from(self, sources: list[dict[str, Any]]) -> list[pd.DataFram if q["component_id"].lower().find("answer") == 0: txt = [] - for r, c in self._canvas.history[::-1][:self._param.message_history_window_size][::-1]: + for r, c in self._canvas.history[::-1][: self._param.message_history_window_size][::-1]: txt.append(f"{r.upper()}:{c}") txt = "\n".join(txt) outs.append(pd.DataFrame([{"content": txt}])) @@ -512,21 +462,16 @@ def get_input(self): content: str if len(records) > 1: - content = "\n".join( - [str(d["content"]) for d in records] - ) + content = "\n".join([str(d["content"]) for d in records]) else: content = records[0]["content"] - self._param.inputs.append({ - "component_id": records[0].get("component_id"), - "content": content - }) + self._param.inputs.append({"component_id": records[0].get("component_id"), "content": content}) if outs: df = 
pd.concat(outs, ignore_index=True) if "content" in df: - df = df.drop_duplicates(subset=['content']).reset_index(drop=True) + df = df.drop_duplicates(subset=["content"]).reset_index(drop=True) return df upstream_outs = [] @@ -540,9 +485,8 @@ def get_input(self): o["component_id"] = u upstream_outs.append(o) continue - #if self.component_name.lower()!="answer" and u not in self._canvas.get_component(self._id)["upstream"]: continue - if self.component_name.lower().find("switch") < 0 \ - and self.get_component_name(u) in ["relevant", "categorize"]: + # if self.component_name.lower()!="answer" and u not in self._canvas.get_component(self._id)["upstream"]: continue + if self.component_name.lower().find("switch") < 0 and self.get_component_name(u) in ["relevant", "categorize"]: continue if u.lower().find("answer") >= 0: for r, c in self._canvas.history[::-1]: @@ -562,7 +506,7 @@ def get_input(self): df = pd.concat(upstream_outs, ignore_index=True) if "content" in df: - df = df.drop_duplicates(subset=['content']).reset_index(drop=True) + df = df.drop_duplicates(subset=["content"]).reset_index(drop=True) self._param.inputs = [] for _, r in df.iterrows(): @@ -614,5 +558,5 @@ def get_parent(self): return self._canvas.get_component(pid)["obj"] def get_upstream(self): - cpn_nms = self._canvas.get_component(self._id)['upstream'] + cpn_nms = self._canvas.get_component(self._id)["upstream"] return cpn_nms diff --git a/agent/component/categorize.py b/agent/component/categorize.py index 9171f809aa5..34bd2cdeacf 100644 --- a/agent/component/categorize.py +++ b/agent/component/categorize.py @@ -99,9 +99,13 @@ def _run(self, history, **kwargs): # If a category is found, return the category with the highest count. if any(category_counts.values()): max_category = max(category_counts.items(), key=lambda x: x[1]) - return Categorize.be_output(self._param.category_description[max_category[0]]["to"]) + res = Categorize.be_output(self._param.category_description[max_category[0]]["to"]) + self.set_output(res) + return res - return Categorize.be_output(list(self._param.category_description.items())[-1][1]["to"]) + res = Categorize.be_output(list(self._param.category_description.items())[-1][1]["to"]) + self.set_output(res) + return res def debug(self, **kwargs): df = self._run([], **kwargs) diff --git a/agent/component/code.py b/agent/component/code.py index 0abf0b47247..215ffcfe574 100644 --- a/agent/component/code.py +++ b/agent/component/code.py @@ -81,21 +81,32 @@ def _run(self, history, **kwargs): for input in self._param.arguments: if "@" in input["component_id"]: component_id = input["component_id"].split("@")[0] - refered_component_key = input["component_id"].split("@")[1] - refered_component = self._canvas.get_component(component_id)["obj"] + referred_component_key = input["component_id"].split("@")[1] + referred_component = self._canvas.get_component(component_id)["obj"] - for param in refered_component._param.query: - if param["key"] == refered_component_key: + for param in referred_component._param.query: + if param["key"] == referred_component_key: if "value" in param: arguments[input["name"]] = param["value"] else: - cpn = self._canvas.get_component(input["component_id"])["obj"] - if cpn.component_name.lower() == "answer": - arguments[input["name"]] = self._canvas.get_history(1)[0]["content"] - continue - _, out = cpn.output(allow_partial=False) - if not out.empty: - arguments[input["name"]] = "\n".join(out["content"]) + referred_component = self._canvas.get_component(input["component_id"])["obj"] + 
referred_component_name = referred_component.component_name + referred_component_id = referred_component._id + + debug_inputs = self._param.debug_inputs + if debug_inputs: + for param in debug_inputs: + if param["key"] == referred_component_id: + if "value" in param and param["name"] == input["name"]: + arguments[input["name"]] = param["value"] + else: + if referred_component_name.lower() == "answer": + arguments[input["name"]] = self._canvas.get_history(1)[0]["content"] + continue + + _, out = referred_component.output(allow_partial=False) + if not out.empty: + arguments[input["name"]] = "\n".join(out["content"]) return self._execute_code( language=self._param.lang, @@ -136,3 +147,6 @@ def get_input_elements(self): cpn_id = input["component_id"] elements.append({"key": cpn_id, "name": input["name"]}) return elements + + def debug(self, **kwargs): + return self._run([], **kwargs) diff --git a/agent/component/exesql.py b/agent/component/exesql.py index 2c414ddb71f..6f8eae02598 100644 --- a/agent/component/exesql.py +++ b/agent/component/exesql.py @@ -105,6 +105,7 @@ def _run(self, history, **kwargs): sql_res = [] for i in range(len(input_list)): single_sql = input_list[i] + single_sql = single_sql.replace('```','') while self._loop <= self._param.loop: self._loop += 1 if not single_sql: diff --git a/agent/component/message.py b/agent/component/message.py index a193dd122ba..c60d4d307c2 100644 --- a/agent/component/message.py +++ b/agent/component/message.py @@ -40,7 +40,9 @@ def _run(self, history, **kwargs): if kwargs.get("stream"): return partial(self.stream_output) - return Message.be_output(random.choice(self._param.messages)) + res = Message.be_output(random.choice(self._param.messages)) + self.set_output(res) + return res def stream_output(self): res = None diff --git a/agent/component/retrieval.py b/agent/component/retrieval.py index 859d65478b1..218dae96999 100644 --- a/agent/component/retrieval.py +++ b/agent/component/retrieval.py @@ -96,6 +96,7 @@ def _run(self, history, **kwargs): rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id) if kbs: + query = re.sub(r"^user[::\s]*", "", query, flags=re.IGNORECASE) kbinfos = settings.retrievaler.retrieval( query, embd_mdl, diff --git a/agent/component/template.py b/agent/component/template.py index b54b93d56af..ab02a111b20 100644 --- a/agent/component/template.py +++ b/agent/component/template.py @@ -15,8 +15,11 @@ # import json import re + +from jinja2 import StrictUndefined +from jinja2.sandbox import SandboxedEnvironment + from agent.component.base import ComponentBase, ComponentParamBase -from jinja2 import Template as Jinja2Template class TemplateParam(ComponentParamBase): @@ -75,6 +78,11 @@ def _run(self, history, **kwargs): if p["key"] == key: value = p.get("value", "") self.make_kwargs(para, kwargs, value) + + origin_pattern = "{begin@" + key + "}" + new_pattern = "begin_" + key + content = content.replace(origin_pattern, new_pattern) + kwargs[new_pattern] = kwargs.pop(origin_pattern, "") break else: assert False, f"Can't find parameter '{key}' for {cpn_id}" @@ -89,19 +97,27 @@ def _run(self, history, **kwargs): else: hist = "" self.make_kwargs(para, kwargs, hist) + + if ":" in component_id: + origin_pattern = "{" + component_id + "}" + new_pattern = component_id.replace(":", "_") + content = content.replace(origin_pattern, new_pattern) + kwargs[new_pattern] = kwargs.pop(component_id, "") continue _, out = cpn.output(allow_partial=False) result = "" if "content" in out.columns: - result = "\n".join( - [o 
if isinstance(o, str) else str(o) for o in out["content"]] - ) + result = "\n".join([o if isinstance(o, str) else str(o) for o in out["content"]]) self.make_kwargs(para, kwargs, result) - template = Jinja2Template(content) + env = SandboxedEnvironment( + autoescape=True, + undefined=StrictUndefined, + ) + template = env.from_string(content) try: content = template.render(kwargs) @@ -114,19 +130,16 @@ def _run(self, history, **kwargs): v = json.dumps(v, ensure_ascii=False) except Exception: pass - content = re.sub( - r"\{%s\}" % re.escape(n), v, content - ) - content = re.sub( - r"(#+)", r" \1 ", content - ) + # Process backslashes in strings, Use Lambda function to avoid escape issues + if isinstance(v, str): + v = v.replace("\\", "\\\\") + content = re.sub(r"\{%s\}" % re.escape(n), lambda match: v, content) + content = re.sub(r"(#+)", r" \1 ", content) return Template.be_output(content) def make_kwargs(self, para, kwargs, value): - self._param.inputs.append( - {"component_id": para["key"], "content": value} - ) + self._param.inputs.append({"component_id": para["key"], "content": value}) try: value = json.loads(value) except Exception: diff --git a/agent/templates/customer_service.json b/agent/templates/customer_service.json index 0723ed5d887..d3dcc70fe1a 100644 --- a/agent/templates/customer_service.json +++ b/agent/templates/customer_service.json @@ -52,7 +52,10 @@ "parameters": [], "presence_penalty": 0.4, "prompt": "", - "query": [], + "query": [{ + "type": "reference", + "component_id": "RewriteQuestion:AllNightsSniff" + }], "temperature": 0.1, "top_p": 0.3 } @@ -195,11 +198,15 @@ "message_history_window_size": 22, "output": null, "output_var_name": "output", - "query": [], "rerank_id": "", "similarity_threshold": 0.2, "top_k": 1024, - "top_n": 6 + "top_n": 6, + "query": [{ + "type": "reference", + "component_id": "RewriteQuestion:AllNightsSniff" + }], + "use_kg": false } }, "upstream": [ @@ -548,7 +555,11 @@ "temperature": 0.1, "temperatureEnabled": true, "topPEnabled": true, - "top_p": 0.3 + "top_p": 0.3, + "query": [{ + "type": "reference", + "component_id": "RewriteQuestion:AllNightsSniff" + }] }, "label": "Categorize", "name": "Question Categorize" @@ -625,7 +636,11 @@ "keywords_similarity_weight": 0.3, "similarity_threshold": 0.2, "top_k": 1024, - "top_n": 6 + "top_n": 6, + "query": [{ + "type": "reference", + "component_id": "RewriteQuestion:AllNightsSniff" + }] }, "label": "Retrieval", "name": "Search product info" @@ -932,7 +947,7 @@ "y": 962.5655101584402 }, "resizing": false, - "selected": true, + "selected": false, "sourcePosition": "right", "style": { "height": 163, diff --git a/agentic_reasoning/deep_research.py b/agentic_reasoning/deep_research.py index 3f2a2d8a9d4..d7121245f0f 100644 --- a/agentic_reasoning/deep_research.py +++ b/agentic_reasoning/deep_research.py @@ -36,17 +36,20 @@ def __init__(self, self._kb_retrieve = kb_retrieve self._kg_retrieve = kg_retrieve - @staticmethod - def _remove_query_tags(text): - """Remove query tags from text""" - pattern = re.escape(BEGIN_SEARCH_QUERY) + r"(.*?)" + re.escape(END_SEARCH_QUERY) + def _remove_tags(text: str, start_tag: str, end_tag: str) -> str: + """General Tag Removal Method""" + pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag) return re.sub(pattern, "", text) @staticmethod - def _remove_result_tags(text): - """Remove result tags from text""" - pattern = re.escape(BEGIN_SEARCH_RESULT) + r"(.*?)" + re.escape(END_SEARCH_RESULT) - return re.sub(pattern, "", text) + def _remove_query_tags(text: str) -> str: 
+ """Remove Query Tags""" + return DeepResearcher._remove_tags(text, BEGIN_SEARCH_QUERY, END_SEARCH_QUERY) + + @staticmethod + def _remove_result_tags(text: str) -> str: + """Remove Result Tags""" + return DeepResearcher._remove_tags(text, BEGIN_SEARCH_RESULT, END_SEARCH_RESULT) def _generate_reasoning(self, msg_history): """Generate reasoning steps""" @@ -95,21 +98,31 @@ def _truncate_previous_reasoning(self, all_reasoning_steps): def _retrieve_information(self, search_query): """Retrieve information from different sources""" # 1. Knowledge base retrieval - kbinfos = self._kb_retrieve(question=search_query) if self._kb_retrieve else {"chunks": [], "doc_aggs": []} - + kbinfos = [] + try: + kbinfos = self._kb_retrieve(question=search_query) if self._kb_retrieve else {"chunks": [], "doc_aggs": []} + except Exception as e: + logging.error(f"Knowledge base retrieval error: {e}") + # 2. Web retrieval (if Tavily API is configured) - if self.prompt_config.get("tavily_api_key"): - tav = Tavily(self.prompt_config["tavily_api_key"]) - tav_res = tav.retrieve_chunks(search_query) - kbinfos["chunks"].extend(tav_res["chunks"]) - kbinfos["doc_aggs"].extend(tav_res["doc_aggs"]) - + try: + if self.prompt_config.get("tavily_api_key"): + tav = Tavily(self.prompt_config["tavily_api_key"]) + tav_res = tav.retrieve_chunks(search_query) + kbinfos["chunks"].extend(tav_res["chunks"]) + kbinfos["doc_aggs"].extend(tav_res["doc_aggs"]) + except Exception as e: + logging.error(f"Web retrieval error: {e}") + # 3. Knowledge graph retrieval (if configured) - if self.prompt_config.get("use_kg") and self._kg_retrieve: - ck = self._kg_retrieve(question=search_query) - if ck["content_with_weight"]: - kbinfos["chunks"].insert(0, ck) - + try: + if self.prompt_config.get("use_kg") and self._kg_retrieve: + ck = self._kg_retrieve(question=search_query) + if ck["content_with_weight"]: + kbinfos["chunks"].insert(0, ck) + except Exception as e: + logging.error(f"Knowledge graph retrieval error: {e}") + return kbinfos def _update_chunk_info(self, chunk_info, kbinfos): diff --git a/api/apps/__init__.py b/api/apps/__init__.py index 7acef9be5bc..007e37430e4 100644 --- a/api/apps/__init__.py +++ b/api/apps/__init__.py @@ -146,10 +146,23 @@ def load_user(web_request): if authorization: try: access_token = str(jwt.loads(authorization)) + + if not access_token or not access_token.strip(): + logging.warning("Authentication attempt with empty access token") + return None + + # Access tokens should be UUIDs (32 hex characters) + if len(access_token.strip()) < 32: + logging.warning(f"Authentication attempt with invalid token format: {len(access_token)} chars") + return None + user = UserService.query( access_token=access_token, status=StatusEnum.VALID.value ) if user: + if not user[0].access_token or not user[0].access_token.strip(): + logging.warning(f"User {user[0].email} has empty access_token in database") + return None return user[0] else: return None diff --git a/api/apps/api_app.py b/api/apps/api_app.py index cc417c9a22d..f66eb8067b5 100644 --- a/api/apps/api_app.py +++ b/api/apps/api_app.py @@ -18,7 +18,7 @@ import re from datetime import datetime, timedelta from flask import request, Response -from api.db.services.llm_service import TenantLLMService +from api.db.services.llm_service import LLMBundle from flask_login import login_required, current_user from api.db import VALID_FILE_TYPES, VALID_TASK_STATUS, FileType, LLMType, ParserType, FileSource @@ -875,14 +875,12 @@ def retrieval(): data=False, message='Knowledge bases use different 
embedding models or does not exist."', code=settings.RetCode.AUTHENTICATION_ERROR) - embd_mdl = TenantLLMService.model_instance( - kbs[0].tenant_id, LLMType.EMBEDDING.value, llm_name=kbs[0].embd_id) + embd_mdl = LLMBundle(kbs[0].tenant_id, LLMType.EMBEDDING, llm_name=kbs[0].embd_id) rerank_mdl = None if req.get("rerank_id"): - rerank_mdl = TenantLLMService.model_instance( - kbs[0].tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"]) + rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, llm_name=req["rerank_id"]) if req.get("keyword", False): - chat_mdl = TenantLLMService.model_instance(kbs[0].tenant_id, LLMType.CHAT) + chat_mdl = LLMBundle(kbs[0].tenant_id, LLMType.CHAT) question += keyword_extraction(chat_mdl, question) ranks = settings.retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top, diff --git a/api/apps/auth/oidc.py b/api/apps/auth/oidc.py index 2fcdb6f5d02..9c59ffaebaa 100644 --- a/api/apps/auth/oidc.py +++ b/api/apps/auth/oidc.py @@ -68,8 +68,7 @@ def parse_id_token(self, id_token): alg = headers.get("alg", "RS256") # Use PyJWT's PyJWKClient to fetch JWKS and find signing key - jwks_url = f"{self.issuer}/.well-known/jwks.json" - jwks_cli = jwt.PyJWKClient(jwks_url) + jwks_cli = jwt.PyJWKClient(self.jwks_uri) signing_key = jwks_cli.get_signing_key_from_jwt(id_token).key # Decode and verify signature diff --git a/api/apps/canvas_app.py b/api/apps/canvas_app.py index b55ab0d5df7..d80eb093c94 100644 --- a/api/apps/canvas_app.py +++ b/api/apps/canvas_app.py @@ -249,7 +249,9 @@ def debug(): code=RetCode.OPERATING_ERROR) canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id) - canvas.get_component(req["component_id"])["obj"]._param.debug_inputs = req["params"] + componant = canvas.get_component(req["component_id"])["obj"] + componant.reset() + componant._param.debug_inputs = req["params"] df = canvas.get_component(req["component_id"])["obj"].debug() return get_json_result(data=df.to_dict(orient="records")) except Exception as e: diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py index 23fb3b704e2..5ba39716f8b 100644 --- a/api/apps/conversation_app.py +++ b/api/apps/conversation_app.py @@ -42,6 +42,7 @@ def set_conversation(): conv_id = req.get("conversation_id") is_new = req.get("is_new") name = req.get("name", "New conversation") + req["user_id"] = current_user.id if len(name) > 255: name = name[0:255] @@ -64,7 +65,7 @@ def set_conversation(): e, dia = DialogService.get_by_id(req["dialog_id"]) if not e: return get_data_error_result(message="Dialog not found") - conv = {"id": conv_id, "dialog_id": req["dialog_id"], "name": name, "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]} + conv = {"id": conv_id, "dialog_id": req["dialog_id"], "name": name, "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}],"user_id": current_user.id} ConversationService.save(**conv) return get_json_result(data=conv) except Exception as e: @@ -248,7 +249,7 @@ def stream(): else: answer = None for ans in chat(dia, msg, **req): - answer = structure_answer(conv, ans, message_id, req["conversation_id"]) + answer = structure_answer(conv, ans, message_id, conv.id) ConversationService.update_by_id(conv.id, conv.to_dict()) break return get_json_result(data=answer) diff --git a/api/apps/dialog_app.py b/api/apps/dialog_app.py index 4878e0732df..2c4bd725c47 100644 --- a/api/apps/dialog_app.py +++ b/api/apps/dialog_app.py @@ -28,6 +28,7 @@ 
@manager.route('/set', methods=['POST']) # noqa: F821 +@validate_request("prompt_config") @login_required def set_dialog(): req = request.json @@ -43,33 +44,10 @@ def set_dialog(): similarity_threshold = req.get("similarity_threshold", 0.1) vector_similarity_weight = req.get("vector_similarity_weight", 0.3) llm_setting = req.get("llm_setting", {}) - default_prompt_with_dataset = { - "system": """你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答。当所有知识库内容都与问题无关时,你的回答必须包括“知识库中未找到您要的答案!”这句话。回答需要考虑聊天历史。 -以下是知识库: -{knowledge} -以上是知识库。""", - "prologue": "您好,我是您的助手小樱,长得可爱又善良,can I help you?", - "parameters": [ - {"key": "knowledge", "optional": False} - ], - "empty_response": "Sorry! 知识库中未找到相关内容!" - } - default_prompt_no_dataset = { - "system": """You are a helpful assistant.""", - "prologue": "您好,我是您的助手小樱,长得可爱又善良,can I help you?", - "parameters": [ - - ], - "empty_response": "" - } - prompt_config = req.get("prompt_config", default_prompt_with_dataset) - - if not prompt_config["system"]: - prompt_config["system"] = default_prompt_with_dataset["system"] + prompt_config = req["prompt_config"] - if not req.get("kb_ids", []): - if prompt_config['system'] == default_prompt_with_dataset['system'] or "{knowledge}" in prompt_config['system']: - prompt_config = default_prompt_no_dataset + if not req.get("kb_ids", []) and not prompt_config.get("tavily_api_key") and "{knowledge}" in prompt_config['system']: + return get_data_error_result(message="Please remove `{knowledge}` in system prompt since no knowledge base/Tavily used here.") for p in prompt_config["parameters"]: if p["optional"]: diff --git a/api/apps/document_app.py b/api/apps/document_app.py index 43c7f813b7e..68a76394f54 100644 --- a/api/apps/document_app.py +++ b/api/apps/document_app.py @@ -23,7 +23,7 @@ from flask_login import current_user, login_required from api import settings -from api.constants import IMG_BASE64_PREFIX +from api.constants import FILE_NAME_LEN_LIMIT, IMG_BASE64_PREFIX from api.db import VALID_FILE_TYPES, VALID_TASK_STATUS, FileSource, FileType, ParserType, TaskStatus from api.db.db_models import File, Task from api.db.services import duplicate_name @@ -61,18 +61,21 @@ def upload(): for file_obj in file_objs: if file_obj.filename == "": return get_json_result(data=False, message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR) + if len(file_obj.filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT: + return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=settings.RetCode.ARGUMENT_ERROR) e, kb = KnowledgebaseService.get_by_id(kb_id) if not e: raise LookupError("Can't find this knowledgebase!") err, files = FileService.upload_document(kb, file_objs, current_user.id) + if err: + return get_json_result(data=files, message="\n".join(err), code=settings.RetCode.SERVER_ERROR) + if not files: return get_json_result(data=files, message="There seems to be an issue with your file format. 
Please verify it is correct and not corrupted.", code=settings.RetCode.DATA_ERROR) files = [f[0] for f in files] # remove the blob - if err: - return get_json_result(data=files, message="\n".join(err), code=settings.RetCode.SERVER_ERROR) return get_json_result(data=files) @@ -146,6 +149,12 @@ def create(): kb_id = req["kb_id"] if not kb_id: return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR) + if len(req["name"].encode("utf-8")) > FILE_NAME_LEN_LIMIT: + return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=settings.RetCode.ARGUMENT_ERROR) + + if req["name"].strip() == "": + return get_json_result(data=False, message="File name can't be empty.", code=settings.RetCode.ARGUMENT_ERROR) + req["name"] = req["name"].strip() try: e, kb = KnowledgebaseService.get_by_id(kb_id) @@ -190,7 +199,10 @@ def list_docs(): page_number = int(request.args.get("page", 0)) items_per_page = int(request.args.get("page_size", 0)) orderby = request.args.get("orderby", "create_time") - desc = request.args.get("desc", True) + if request.args.get("desc", "true").lower() == "false": + desc = False + else: + desc = True req = request.get_json() @@ -401,6 +413,9 @@ def rename(): return get_data_error_result(message="Document not found!") if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix: return get_json_result(data=False, message="The extension of file can't be changed", code=settings.RetCode.ARGUMENT_ERROR) + if len(req["name"].encode("utf-8")) > FILE_NAME_LEN_LIMIT: + return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=settings.RetCode.ARGUMENT_ERROR) + for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id): if d.name == req["name"]: return get_data_error_result(message="Duplicated document name in the same knowledgebase.") diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py index 3267c806016..43e6e4eac3e 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -34,6 +34,7 @@ from rag.nlp import search from api.constants import DATASET_NAME_LIMIT from rag.settings import PAGERANK_FLD +from rag.utils.storage_factory import STORAGE_IMPL @manager.route('/create', methods=['post']) # noqa: F821 @@ -44,11 +45,11 @@ def create(): dataset_name = req["name"] if not isinstance(dataset_name, str): return get_data_error_result(message="Dataset name must be string.") - if dataset_name == "": + if dataset_name.strip() == "": return get_data_error_result(message="Dataset name can't be empty.") - if len(dataset_name) >= DATASET_NAME_LIMIT: + if len(dataset_name.encode("utf-8")) > DATASET_NAME_LIMIT: return get_data_error_result( - message=f"Dataset name length is {len(dataset_name)} which is large than {DATASET_NAME_LIMIT}") + message=f"Dataset name length is {len(dataset_name)} which is larger than {DATASET_NAME_LIMIT}") dataset_name = dataset_name.strip() dataset_name = duplicate_name( @@ -78,7 +79,15 @@ def create(): @not_allowed_parameters("id", "tenant_id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by") def update(): req = request.json + if not isinstance(req["name"], str): + return get_data_error_result(message="Dataset name must be string.") + if req["name"].strip() == "": + return get_data_error_result(message="Dataset name can't be empty.") + if len(req["name"].encode("utf-8")) > DATASET_NAME_LIMIT: + return get_data_error_result( + message=f"Dataset name length is {len(req['name'])} 
which is large than {DATASET_NAME_LIMIT}") req["name"] = req["name"].strip() + if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id): return get_json_result( data=False, @@ -106,7 +115,7 @@ def update(): if req["name"].lower() != kb.name.lower() \ and len( - KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) > 1: + KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) >= 1: return get_data_error_result( message="Duplicated knowledgebase name.") @@ -115,6 +124,9 @@ def update(): return get_data_error_result() if kb.pagerank != req.get("pagerank", 0): + if os.environ.get("DOC_ENGINE", "elasticsearch") != "elasticsearch": + return get_data_error_result(message="'pagerank' can only be set when doc_engine is elasticsearch") + if req.get("pagerank", 0) > 0: settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, search.index_name(kb.tenant_id), kb.id) @@ -167,7 +179,10 @@ def list_kbs(): items_per_page = int(request.args.get("page_size", 0)) parser_id = request.args.get("parser_id") orderby = request.args.get("orderby", "create_time") - desc = request.args.get("desc", True) + if request.args.get("desc", "true").lower() == "false": + desc = False + else: + desc = True req = request.get_json() owner_ids = req.get("owner_ids", []) @@ -184,9 +199,9 @@ def list_kbs(): tenants, current_user.id, 0, 0, orderby, desc, keywords, parser_id) kbs = [kb for kb in kbs if kb["tenant_id"] in tenants] + total = len(kbs) if page_number and items_per_page: kbs = kbs[(page_number-1)*items_per_page:page_number*items_per_page] - total = len(kbs) return get_json_result(data={"kbs": kbs, "total": total}) except Exception as e: return server_error_response(e) @@ -226,6 +241,8 @@ def rm(): for kb in kbs: settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id) settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id) + if hasattr(STORAGE_IMPL, 'remove_bucket'): + STORAGE_IMPL.remove_bucket(kb.id) return get_json_result(data=True) except Exception as e: return server_error_response(e) diff --git a/api/apps/sdk/chat.py b/api/apps/sdk/chat.py index 24d89db85c2..3667dab548f 100644 --- a/api/apps/sdk/chat.py +++ b/api/apps/sdk/chat.py @@ -173,8 +173,10 @@ def update(tenant_id, chat_id): if llm: if "model_name" in llm: req["llm_id"] = llm.pop("model_name") - if not TenantLLMService.query(tenant_id=tenant_id, llm_name=req["llm_id"], model_type="chat"): - return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist") + if req.get("llm_id") is not None: + llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(req["llm_id"]) + if not TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type="chat"): + return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist") req["llm_setting"] = req.pop("llm") e, tenant = TenantService.get_by_id(tenant_id) if not e: diff --git a/api/apps/sdk/dataset.py b/api/apps/sdk/dataset.py index f76cf2f9d93..e3675b8cd00 100644 --- a/api/apps/sdk/dataset.py +++ b/api/apps/sdk/dataset.py @@ -16,10 +16,12 @@ import logging +import os from flask import request from peewee import OperationalError +from api import settings from api.db import FileSource, StatusEnum from api.db.db_models import File from api.db.services.document_service import DocumentService @@ -48,6 +50,8 @@ validate_and_parse_json_request, 
validate_and_parse_request_args, ) +from rag.nlp import search +from rag.settings import PAGERANK_FLD @manager.route("/datasets", methods=["POST"]) # noqa: F821 @@ -97,9 +101,6 @@ def create(tenant_id): "picture", "presentation", "qa", "table", "tag" ] description: Chunking method. - pagerank: - type: integer - description: Set page rank. parser_config: type: object description: Parser configuration. @@ -124,48 +125,36 @@ def create(tenant_id): try: if KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value): return get_error_operating_result(message=f"Dataset name '{req['name']}' already exists") - except OperationalError as e: - logging.exception(e) - return get_error_data_result(message="Database operation failed") - req["parser_config"] = get_parser_config(req["parser_id"], req["parser_config"]) - req["id"] = get_uuid() - req["tenant_id"] = tenant_id - req["created_by"] = tenant_id + req["parser_config"] = get_parser_config(req["parser_id"], req["parser_config"]) + req["id"] = get_uuid() + req["tenant_id"] = tenant_id + req["created_by"] = tenant_id - try: ok, t = TenantService.get_by_id(tenant_id) if not ok: return get_error_permission_result(message="Tenant not found") - except OperationalError as e: - logging.exception(e) - return get_error_data_result(message="Database operation failed") - if not req.get("embd_id"): - req["embd_id"] = t.embd_id - else: - ok, err = verify_embedding_availability(req["embd_id"], tenant_id) - if not ok: - return err + if not req.get("embd_id"): + req["embd_id"] = t.embd_id + else: + ok, err = verify_embedding_availability(req["embd_id"], tenant_id) + if not ok: + return err - try: if not KnowledgebaseService.save(**req): return get_error_data_result(message="Create dataset error.(Database error)") - except OperationalError as e: - logging.exception(e) - return get_error_data_result(message="Database operation failed") - try: ok, k = KnowledgebaseService.get_by_id(req["id"]) if not ok: return get_error_data_result(message="Dataset created failed") + + response_data = remap_dictionary_keys(k.to_dict()) + return get_result(data=response_data) except OperationalError as e: logging.exception(e) return get_error_data_result(message="Database operation failed") - response_data = remap_dictionary_keys(k.to_dict()) - return get_result(data=response_data) - @manager.route("/datasets", methods=["DELETE"]) # noqa: F821 @token_required @@ -211,34 +200,27 @@ def delete(tenant_id): if err is not None: return get_error_argument_result(err) - kb_id_instance_pairs = [] - if req["ids"] is None: - try: + try: + kb_id_instance_pairs = [] + if req["ids"] is None: kbs = KnowledgebaseService.query(tenant_id=tenant_id) for kb in kbs: kb_id_instance_pairs.append((kb.id, kb)) - except OperationalError as e: - logging.exception(e) - return get_error_data_result(message="Database operation failed") - else: - error_kb_ids = [] - for kb_id in req["ids"]: - try: + + else: + error_kb_ids = [] + for kb_id in req["ids"]: kb = KnowledgebaseService.get_or_none(id=kb_id, tenant_id=tenant_id) if kb is None: error_kb_ids.append(kb_id) continue kb_id_instance_pairs.append((kb_id, kb)) - except OperationalError as e: - logging.exception(e) - return get_error_data_result(message="Database operation failed") - if len(error_kb_ids) > 0: - return get_error_permission_result(message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""") - - errors = [] - success_count = 0 - for kb_id, kb in kb_id_instance_pairs: - try: + 
if len(error_kb_ids) > 0: + return get_error_permission_result(message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""") + + errors = [] + success_count = 0 + for kb_id, kb in kb_id_instance_pairs: for doc in DocumentService.query(kb_id=kb_id): if not DocumentService.remove_document(doc, tenant_id): errors.append(f"Remove document '{doc.id}' error for dataset '{kb_id}'") @@ -256,18 +238,18 @@ def delete(tenant_id): errors.append(f"Delete dataset error for {kb_id}") continue success_count += 1 - except OperationalError as e: - logging.exception(e) - return get_error_data_result(message="Database operation failed") - if not errors: - return get_result() + if not errors: + return get_result() - error_message = f"Successfully deleted {success_count} datasets, {len(errors)} failed. Details: {'; '.join(errors)[:128]}..." - if success_count == 0: - return get_error_data_result(message=error_message) + error_message = f"Successfully deleted {success_count} datasets, {len(errors)} failed. Details: {'; '.join(errors)[:128]}..." + if success_count == 0: + return get_error_data_result(message=error_message) - return get_result(data={"success_count": success_count, "errors": errors[:5]}, message=error_message) + return get_result(data={"success_count": success_count, "errors": errors[:5]}, message=error_message) + except OperationalError as e: + logging.exception(e) + return get_error_data_result(message="Database operation failed") @manager.route("/datasets/", methods=["PUT"]) # noqa: F821 @@ -349,44 +331,51 @@ def update(tenant_id, dataset_id): kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id) if kb is None: return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'") - except OperationalError as e: - logging.exception(e) - return get_error_data_result(message="Database operation failed") - if req.get("parser_config"): - req["parser_config"] = deep_merge(kb.parser_config, req["parser_config"]) + if req.get("parser_config"): + req["parser_config"] = deep_merge(kb.parser_config, req["parser_config"]) - if (chunk_method := req.get("parser_id")) and chunk_method != kb.parser_id: - if not req.get("parser_config"): - req["parser_config"] = get_parser_config(chunk_method, None) - elif "parser_config" in req and not req["parser_config"]: - del req["parser_config"] + if (chunk_method := req.get("parser_id")) and chunk_method != kb.parser_id: + if not req.get("parser_config"): + req["parser_config"] = get_parser_config(chunk_method, None) + elif "parser_config" in req and not req["parser_config"]: + del req["parser_config"] - if "name" in req and req["name"].lower() != kb.name.lower(): - try: + if "name" in req and req["name"].lower() != kb.name.lower(): exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value) if exists: return get_error_data_result(message=f"Dataset name '{req['name']}' already exists") - except OperationalError as e: - logging.exception(e) - return get_error_data_result(message="Database operation failed") - - if "embd_id" in req: - if kb.chunk_num != 0 and req["embd_id"] != kb.embd_id: - return get_error_data_result(message=f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}") - ok, err = verify_embedding_availability(req["embd_id"], tenant_id) - if not ok: - return err - try: + if "embd_id" in req: + if kb.chunk_num != 0 and req["embd_id"] != kb.embd_id: + return get_error_data_result(message=f"When 
chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}") + ok, err = verify_embedding_availability(req["embd_id"], tenant_id) + if not ok: + return err + + if "pagerank" in req and req["pagerank"] != kb.pagerank: + if os.environ.get("DOC_ENGINE", "elasticsearch") == "infinity": + return get_error_argument_result(message="'pagerank' can only be set when doc_engine is elasticsearch") + + if req["pagerank"] > 0: + settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, search.index_name(kb.tenant_id), kb.id) + else: + # Elasticsearch requires PAGERANK_FLD be non-zero! + settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, search.index_name(kb.tenant_id), kb.id) + if not KnowledgebaseService.update_by_id(kb.id, req): return get_error_data_result(message="Update dataset error.(Database error)") + + ok, k = KnowledgebaseService.get_by_id(kb.id) + if not ok: + return get_error_data_result(message="Dataset created failed") + + response_data = remap_dictionary_keys(k.to_dict()) + return get_result(data=response_data) except OperationalError as e: logging.exception(e) return get_error_data_result(message="Database operation failed") - return get_result() - @manager.route("/datasets", methods=["GET"]) # noqa: F821 @token_required @@ -450,26 +439,19 @@ def list_datasets(tenant_id): if err is not None: return get_error_argument_result(err) - kb_id = request.args.get("id") - name = args.get("name") - if kb_id: - try: + try: + kb_id = request.args.get("id") + name = args.get("name") + if kb_id: kbs = KnowledgebaseService.get_kb_by_id(kb_id, tenant_id) - except OperationalError as e: - logging.exception(e) - return get_error_data_result(message="Database operation failed") - if not kbs: - return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{kb_id}'") - if name: - try: + + if not kbs: + return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{kb_id}'") + if name: kbs = KnowledgebaseService.get_kb_by_name(name, tenant_id) - except OperationalError as e: - logging.exception(e) - return get_error_data_result(message="Database operation failed") - if not kbs: - return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{name}'") + if not kbs: + return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{name}'") - try: tenants = TenantService.get_joined_tenants_by_user_id(tenant_id) kbs = KnowledgebaseService.get_list( [m["tenant_id"] for m in tenants], @@ -481,11 +463,11 @@ def list_datasets(tenant_id): kb_id, name, ) + + response_data_list = [] + for kb in kbs: + response_data_list.append(remap_dictionary_keys(kb)) + return get_result(data=response_data_list) except OperationalError as e: logging.exception(e) return get_error_data_result(message="Database operation failed") - - response_data_list = [] - for kb in kbs: - response_data_list.append(remap_dictionary_keys(kb)) - return get_result(data=response_data_list) diff --git a/api/apps/sdk/dify_retrieval.py b/api/apps/sdk/dify_retrieval.py index 5d6a8c896c5..f15eb2396d0 100644 --- a/api/apps/sdk/dify_retrieval.py +++ b/api/apps/sdk/dify_retrieval.py @@ -16,6 +16,7 @@ from flask import request, jsonify from api.db import LLMType +from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle from api import settings @@ 
-70,12 +71,13 @@ def retrieval(tenant_id): records = [] for c in ranks["chunks"]: + e, doc = DocumentService.get_by_id( c["doc_id"]) c.pop("vector", None) records.append({ "content": c["content_with_weight"], "score": c["similarity"], "title": c["docnm_kwd"], - "metadata": {} + "metadata": doc.meta_fields }) return jsonify({"records": records}) diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index f76065c1cd9..e0f77c985e9 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -13,38 +13,35 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import pathlib import datetime - -from rag.app.qa import rmPrefix, beAdoc -from rag.nlp import rag_tokenizer -from api.db import LLMType, ParserType -from api.db.services.llm_service import TenantLLMService, LLMBundle -from api import settings -import xxhash +import logging +import pathlib import re -from api.utils.api_utils import token_required -from api.db.db_models import Task -from api.db.services.task_service import TaskService, queue_tasks -from api.utils.api_utils import server_error_response -from api.utils.api_utils import get_result, get_error_data_result from io import BytesIO + +import xxhash from flask import request, send_file -from api.db import FileSource, TaskStatus, FileType -from api.db.db_models import File +from peewee import OperationalError +from pydantic import BaseModel, Field, validator + +from api import settings +from api.constants import FILE_NAME_LEN_LIMIT +from api.db import FileSource, FileType, LLMType, ParserType, TaskStatus +from api.db.db_models import File, Task from api.db.services.document_service import DocumentService from api.db.services.file2document_service import File2DocumentService from api.db.services.file_service import FileService from api.db.services.knowledgebase_service import KnowledgebaseService -from api.utils.api_utils import construct_json_result, get_parser_config, check_duplicate_ids -from rag.nlp import search -from rag.prompts import keyword_extraction +from api.db.services.llm_service import LLMBundle, TenantLLMService +from api.db.services.task_service import TaskService, queue_tasks +from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required +from rag.app.qa import beAdoc, rmPrefix from rag.app.tag import label_question +from rag.nlp import rag_tokenizer, search +from rag.prompts import keyword_extraction from rag.utils import rmSpace from rag.utils.storage_factory import STORAGE_IMPL -from pydantic import BaseModel, Field, validator - MAXIMUM_OF_UPLOADING_FILES = 256 @@ -60,7 +57,7 @@ class Chunk(BaseModel): available: bool = True positions: list[list[int]] = Field(default_factory=list) - @validator('positions') + @validator("positions") def validate_positions(cls, value): for sublist in value: if len(sublist) != 5: @@ -128,20 +125,14 @@ def upload(dataset_id, tenant_id): description: Processing status. 
""" if "file" not in request.files: - return get_error_data_result( - message="No file part!", code=settings.RetCode.ARGUMENT_ERROR - ) + return get_error_data_result(message="No file part!", code=settings.RetCode.ARGUMENT_ERROR) file_objs = request.files.getlist("file") for file_obj in file_objs: if file_obj.filename == "": - return get_result( - message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR - ) - if len(file_obj.filename.encode("utf-8")) >= 128: - return get_result( - message="File name should be less than 128 bytes.", code=settings.RetCode.ARGUMENT_ERROR - ) - ''' + return get_result(message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR) + if len(file_obj.filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT: + return get_result(message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=settings.RetCode.ARGUMENT_ERROR) + """ # total size total_size = 0 for file_obj in file_objs: @@ -154,7 +145,7 @@ def upload(dataset_id, tenant_id): message=f"Total file size exceeds 10MB limit! ({total_size / (1024 * 1024):.2f} MB)", code=settings.RetCode.ARGUMENT_ERROR, ) - ''' + """ e, kb = KnowledgebaseService.get_by_id(dataset_id) if not e: raise LookupError(f"Can't find the dataset with ID {dataset_id}!") @@ -236,8 +227,7 @@ def update_doc(tenant_id, dataset_id, document_id): return get_error_data_result(message="You don't own the dataset.") e, kb = KnowledgebaseService.get_by_id(dataset_id) if not e: - return get_error_data_result( - message="Can't find this knowledgebase!") + return get_error_data_result(message="Can't find this knowledgebase!") doc = DocumentService.query(kb_id=dataset_id, id=document_id) if not doc: return get_error_data_result(message="The dataset doesn't own the document.") @@ -258,24 +248,19 @@ def update_doc(tenant_id, dataset_id, document_id): DocumentService.update_meta_fields(document_id, req["meta_fields"]) if "name" in req and req["name"] != doc.name: - if len(req["name"].encode("utf-8")) >= 128: + if len(req["name"].encode("utf-8")) > FILE_NAME_LEN_LIMIT: return get_result( - message="The name should be less than 128 bytes.", + message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=settings.RetCode.ARGUMENT_ERROR, ) - if ( - pathlib.Path(req["name"].lower()).suffix - != pathlib.Path(doc.name.lower()).suffix - ): + if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix: return get_result( message="The extension of file can't be changed", code=settings.RetCode.ARGUMENT_ERROR, ) for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id): if d.name == req["name"]: - return get_error_data_result( - message="Duplicated document name in the same dataset." 
- ) + return get_error_data_result(message="Duplicated document name in the same dataset.") if not DocumentService.update_by_id(document_id, {"name": req["name"]}): return get_error_data_result(message="Database error (Document rename)!") @@ -287,46 +272,28 @@ def update_doc(tenant_id, dataset_id, document_id): if "parser_config" in req: DocumentService.update_parser_config(doc.id, req["parser_config"]) if "chunk_method" in req: - valid_chunk_method = { - "naive", - "manual", - "qa", - "table", - "paper", - "book", - "laws", - "presentation", - "picture", - "one", - "knowledge_graph", - "email", - "tag" - } + valid_chunk_method = {"naive", "manual", "qa", "table", "paper", "book", "laws", "presentation", "picture", "one", "knowledge_graph", "email", "tag"} if req.get("chunk_method") not in valid_chunk_method: - return get_error_data_result( - f"`chunk_method` {req['chunk_method']} doesn't exist" - ) - if doc.parser_id.lower() == req["chunk_method"].lower(): - return get_result() + return get_error_data_result(f"`chunk_method` {req['chunk_method']} doesn't exist") if doc.type == FileType.VISUAL or re.search(r"\.(ppt|pptx|pages)$", doc.name): return get_error_data_result(message="Not supported yet!") - e = DocumentService.update_by_id( - doc.id, - { - "parser_id": req["chunk_method"], - "progress": 0, - "progress_msg": "", - "run": TaskStatus.UNSTART.value, - }, - ) - if not e: - return get_error_data_result(message="Document not found!") - req["parser_config"] = get_parser_config( - req["chunk_method"], req.get("parser_config") - ) - DocumentService.update_parser_config(doc.id, req["parser_config"]) + if doc.parser_id.lower() != req["chunk_method"].lower(): + e = DocumentService.update_by_id( + doc.id, + { + "parser_id": req["chunk_method"], + "progress": 0, + "progress_msg": "", + "run": TaskStatus.UNSTART.value, + }, + ) + if not e: + return get_error_data_result(message="Document not found!") + if not req.get("parser_config"): + req["parser_config"] = get_parser_config(req["chunk_method"], req.get("parser_config")) + DocumentService.update_parser_config(doc.id, req["parser_config"]) if doc.token_num > 0: e = DocumentService.increment_chunk_num( doc.id, @@ -343,19 +310,45 @@ def update_doc(tenant_id, dataset_id, document_id): status = int(req["enabled"]) if doc.status != req["enabled"]: try: - if not DocumentService.update_by_id( - doc.id, {"status": str(status)}): - return get_error_data_result( - message="Database error (Document update)!") + if not DocumentService.update_by_id(doc.id, {"status": str(status)}): + return get_error_data_result(message="Database error (Document update)!") - settings.docStoreConn.update({"doc_id": doc.id}, {"available_int": status}, - search.index_name(kb.tenant_id), doc.kb_id) + settings.docStoreConn.update({"doc_id": doc.id}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id) return get_result(data=True) except Exception as e: return server_error_response(e) - return get_result() + try: + ok, doc = DocumentService.get_by_id(doc.id) + if not ok: + return get_error_data_result(message="Dataset created failed") + except OperationalError as e: + logging.exception(e) + return get_error_data_result(message="Database operation failed") + key_mapping = { + "chunk_num": "chunk_count", + "kb_id": "dataset_id", + "token_num": "token_count", + "parser_id": "chunk_method", + } + run_mapping = { + "0": "UNSTART", + "1": "RUNNING", + "2": "CANCEL", + "3": "DONE", + "4": "FAIL", + } + renamed_doc = {} + for key, value in doc.to_dict().items(): + if 
key == "run": + renamed_doc["run"] = run_mapping.get(str(value)) + new_key = key_mapping.get(key, key) + renamed_doc[new_key] = value + if key == "run": + renamed_doc["run"] = run_mapping.get(value) + + return get_result(data=renamed_doc) @manager.route("/datasets//documents/", methods=["GET"]) # noqa: F821 @@ -397,25 +390,17 @@ def download(tenant_id, dataset_id, document_id): type: object """ if not document_id: - return get_error_data_result( - message="Specify document_id please." - ) + return get_error_data_result(message="Specify document_id please.") if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id): return get_error_data_result(message=f"You do not own the dataset {dataset_id}.") doc = DocumentService.query(kb_id=dataset_id, id=document_id) if not doc: - return get_error_data_result( - message=f"The dataset not own the document {document_id}." - ) + return get_error_data_result(message=f"The dataset not own the document {document_id}.") # The process of downloading - doc_id, doc_location = File2DocumentService.get_storage_address( - doc_id=document_id - ) # minio address + doc_id, doc_location = File2DocumentService.get_storage_address(doc_id=document_id) # minio address file_stream = STORAGE_IMPL.get(doc_id, doc_location) if not file_stream: - return construct_json_result( - message="This file is empty.", code=settings.RetCode.DATA_ERROR - ) + return construct_json_result(message="This file is empty.", code=settings.RetCode.DATA_ERROR) file = BytesIO(file_stream) # Use send_file with a proper filename and MIME type return send_file( @@ -530,9 +515,7 @@ def list_docs(dataset_id, tenant_id): desc = False else: desc = True - docs, tol = DocumentService.get_list( - dataset_id, page, page_size, orderby, desc, keywords, id, name - ) + docs, tol = DocumentService.get_list(dataset_id, page, page_size, orderby, desc, keywords, id, name) # rename key's name renamed_doc_list = [] @@ -638,9 +621,7 @@ def delete(tenant_id, dataset_id): b, n = File2DocumentService.get_storage_address(doc_id=doc_id) if not DocumentService.remove_document(doc, tenant_id): - return get_error_data_result( - message="Database error (Document removal)!" 
- ) + return get_error_data_result(message="Database error (Document removal)!") f2d = File2DocumentService.get_by_document_id(doc_id) FileService.filter_delete( @@ -664,7 +645,10 @@ def delete(tenant_id, dataset_id): if duplicate_messages: if success_count > 0: - return get_result(message=f"Partially deleted {success_count} datasets with {len(duplicate_messages)} errors", data={"success_count": success_count, "errors": duplicate_messages},) + return get_result( + message=f"Partially deleted {success_count} datasets with {len(duplicate_messages)} errors", + data={"success_count": success_count, "errors": duplicate_messages}, + ) else: return get_error_data_result(message=";".join(duplicate_messages)) @@ -729,9 +713,7 @@ def parse(tenant_id, dataset_id): if not doc: return get_error_data_result(message=f"You don't own the document {id}.") if 0.0 < doc[0].progress < 1.0: - return get_error_data_result( - "Can't parse document that is currently being processed" - ) + return get_error_data_result("Can't parse document that is currently being processed") info = {"run": "1", "progress": 0, "progress_msg": "", "chunk_num": 0, "token_num": 0} DocumentService.update_by_id(id, info) settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id) @@ -746,7 +728,10 @@ def parse(tenant_id, dataset_id): return get_result(message=f"Documents not found: {not_found}", code=settings.RetCode.DATA_ERROR) if duplicate_messages: if success_count > 0: - return get_result(message=f"Partially parsed {success_count} documents with {len(duplicate_messages)} errors", data={"success_count": success_count, "errors": duplicate_messages},) + return get_result( + message=f"Partially parsed {success_count} documents with {len(duplicate_messages)} errors", + data={"success_count": success_count, "errors": duplicate_messages}, + ) else: return get_error_data_result(message=";".join(duplicate_messages)) @@ -808,16 +793,17 @@ def stop_parsing(tenant_id, dataset_id): if not doc: return get_error_data_result(message=f"You don't own the document {id}.") if int(doc[0].progress) == 1 or doc[0].progress == 0: - return get_error_data_result( - "Can't stop parsing document with progress at 0 or 1" - ) + return get_error_data_result("Can't stop parsing document with progress at 0 or 1") info = {"run": "2", "progress": 0, "chunk_num": 0} DocumentService.update_by_id(id, info) settings.docStoreConn.delete({"doc_id": doc[0].id}, search.index_name(tenant_id), dataset_id) success_count += 1 if duplicate_messages: if success_count > 0: - return get_result(message=f"Partially stopped {success_count} documents with {len(duplicate_messages)} errors", data={"success_count": success_count, "errors": duplicate_messages},) + return get_result( + message=f"Partially stopped {success_count} documents with {len(duplicate_messages)} errors", + data={"success_count": success_count, "errors": duplicate_messages}, + ) else: return get_error_data_result(message=";".join(duplicate_messages)) return get_result() @@ -906,9 +892,7 @@ def list_chunks(tenant_id, dataset_id, document_id): return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") doc = DocumentService.query(id=document_id, kb_id=dataset_id) if not doc: - return get_error_data_result( - message=f"You don't own the document {document_id}." 
- ) + return get_error_data_result(message=f"You don't own the document {document_id}.") doc = doc[0] req = request.args doc_id = document_id @@ -956,34 +940,29 @@ def list_chunks(tenant_id, dataset_id, document_id): del chunk[n] if not chunk: return get_error_data_result(f"Chunk `{req.get('id')}` not found.") - res['total'] = 1 + res["total"] = 1 final_chunk = { - "id":chunk.get("id",chunk.get("chunk_id")), - "content":chunk["content_with_weight"], - "document_id":chunk.get("doc_id",chunk.get("document_id")), - "docnm_kwd":chunk["docnm_kwd"], - "important_keywords":chunk.get("important_kwd",[]), - "questions":chunk.get("question_kwd",[]), - "dataset_id":chunk.get("kb_id",chunk.get("dataset_id")), - "image_id":chunk.get("img_id", ""), - "available":bool(chunk.get("available_int",1)), - "positions":chunk.get("position_int",[]), + "id": chunk.get("id", chunk.get("chunk_id")), + "content": chunk["content_with_weight"], + "document_id": chunk.get("doc_id", chunk.get("document_id")), + "docnm_kwd": chunk["docnm_kwd"], + "important_keywords": chunk.get("important_kwd", []), + "questions": chunk.get("question_kwd", []), + "dataset_id": chunk.get("kb_id", chunk.get("dataset_id")), + "image_id": chunk.get("img_id", ""), + "available": bool(chunk.get("available_int", 1)), + "positions": chunk.get("position_int", []), } res["chunks"].append(final_chunk) _ = Chunk(**final_chunk) elif settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id): - sres = settings.retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, - highlight=True) + sres = settings.retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True) res["total"] = sres.total for id in sres.ids: d = { "id": id, - "content": ( - rmSpace(sres.highlight[id]) - if question and id in sres.highlight - else sres.field[id].get("content_with_weight", "") - ), + "content": (rmSpace(sres.highlight[id]) if question and id in sres.highlight else sres.field[id].get("content_with_weight", "")), "document_id": sres.field[id]["doc_id"], "docnm_kwd": sres.field[id]["docnm_kwd"], "important_keywords": sres.field[id].get("important_kwd", []), @@ -991,10 +970,10 @@ def list_chunks(tenant_id, dataset_id, document_id): "dataset_id": sres.field[id].get("kb_id", sres.field[id].get("dataset_id")), "image_id": sres.field[id].get("img_id", ""), "available": bool(int(sres.field[id].get("available_int", "1"))), - "positions": sres.field[id].get("position_int",[]), + "positions": sres.field[id].get("position_int", []), } res["chunks"].append(d) - _ = Chunk(**d) # validate the chunk + _ = Chunk(**d) # validate the chunk return get_result(data=res) @@ -1070,23 +1049,17 @@ def add_chunk(tenant_id, dataset_id, document_id): return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") doc = DocumentService.query(id=document_id, kb_id=dataset_id) if not doc: - return get_error_data_result( - message=f"You don't own the document {document_id}." 
- ) + return get_error_data_result(message=f"You don't own the document {document_id}.") doc = doc[0] req = request.json if not str(req.get("content", "")).strip(): return get_error_data_result(message="`content` is required") if "important_keywords" in req: if not isinstance(req["important_keywords"], list): - return get_error_data_result( - "`important_keywords` is required to be a list" - ) + return get_error_data_result("`important_keywords` is required to be a list") if "questions" in req: if not isinstance(req["questions"], list): - return get_error_data_result( - "`questions` is required to be a list" - ) + return get_error_data_result("`questions` is required to be a list") chunk_id = xxhash.xxh64((req["content"] + document_id).encode("utf-8")).hexdigest() d = { "id": chunk_id, @@ -1095,22 +1068,16 @@ def add_chunk(tenant_id, dataset_id, document_id): } d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) d["important_kwd"] = req.get("important_keywords", []) - d["important_tks"] = rag_tokenizer.tokenize( - " ".join(req.get("important_keywords", [])) - ) + d["important_tks"] = rag_tokenizer.tokenize(" ".join(req.get("important_keywords", []))) d["question_kwd"] = [str(q).strip() for q in req.get("questions", []) if str(q).strip()] - d["question_tks"] = rag_tokenizer.tokenize( - "\n".join(req.get("questions", [])) - ) + d["question_tks"] = rag_tokenizer.tokenize("\n".join(req.get("questions", []))) d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19] d["create_timestamp_flt"] = datetime.datetime.now().timestamp() d["kb_id"] = dataset_id d["docnm_kwd"] = doc.name d["doc_id"] = document_id embd_id = DocumentService.get_embd_id(document_id) - embd_mdl = TenantLLMService.model_instance( - tenant_id, LLMType.EMBEDDING.value, embd_id - ) + embd_mdl = TenantLLMService.model_instance(tenant_id, LLMType.EMBEDDING.value, embd_id) v, c = embd_mdl.encode([doc.name, req["content"] if not d["question_kwd"] else "\n".join(d["question_kwd"])]) v = 0.1 * v[0] + 0.9 * v[1] d["q_%d_vec" % len(v)] = v.tolist() @@ -1203,7 +1170,10 @@ def rm_chunk(tenant_id, dataset_id, document_id): return get_result(message=f"deleted {chunk_number} chunks") return get_error_data_result(message=f"rm_chunk deleted chunks {chunk_number}, expect {len(unique_chunk_ids)}") if duplicate_messages: - return get_result(message=f"Partially deleted {chunk_number} chunks with {len(duplicate_messages)} errors", data={"success_count": chunk_number, "errors": duplicate_messages},) + return get_result( + message=f"Partially deleted {chunk_number} chunks with {len(duplicate_messages)} errors", + data={"success_count": chunk_number, "errors": duplicate_messages}, + ) return get_result(message=f"deleted {chunk_number} chunks") @@ -1271,9 +1241,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id): return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") doc = DocumentService.query(id=document_id, kb_id=dataset_id) if not doc: - return get_error_data_result( - message=f"You don't own the document {document_id}." 
- ) + return get_error_data_result(message=f"You don't own the document {document_id}.") doc = doc[0] req = request.json if "content" in req: @@ -1296,19 +1264,13 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id): if "available" in req: d["available_int"] = int(req["available"]) embd_id = DocumentService.get_embd_id(document_id) - embd_mdl = TenantLLMService.model_instance( - tenant_id, LLMType.EMBEDDING.value, embd_id - ) + embd_mdl = TenantLLMService.model_instance(tenant_id, LLMType.EMBEDDING.value, embd_id) if doc.parser_id == ParserType.QA: arr = [t for t in re.split(r"[\n\t]", d["content_with_weight"]) if len(t) > 1] if len(arr) != 2: - return get_error_data_result( - message="Q&A must be separated by TAB/ENTER key." - ) + return get_error_data_result(message="Q&A must be separated by TAB/ENTER key.") q, a = rmPrefix(arr[0]), rmPrefix(arr[1]) - d = beAdoc( - d, arr[0], arr[1], not any([rag_tokenizer.is_chinese(t) for t in q + a]) - ) + d = beAdoc(d, arr[0], arr[1], not any([rag_tokenizer.is_chinese(t) for t in q + a])) v, c = embd_mdl.encode([doc.name, d["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])]) v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] @@ -1425,9 +1387,7 @@ def retrieval_test(tenant_id): doc_ids_list = KnowledgebaseService.list_documents_by_ids(kb_ids) for doc_id in doc_ids: if doc_id not in doc_ids_list: - return get_error_data_result( - f"The datasets don't own the document {doc_id}" - ) + return get_error_data_result(f"The datasets don't own the document {doc_id}") similarity_threshold = float(req.get("similarity_threshold", 0.2)) vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3)) top = int(req.get("top_k", 1024)) @@ -1463,14 +1423,10 @@ def retrieval_test(tenant_id): doc_ids, rerank_mdl=rerank_mdl, highlight=highlight, - rank_feature=label_question(question, kbs) + rank_feature=label_question(question, kbs), ) if use_kg: - ck = settings.kg_retrievaler.retrieval(question, - [k.tenant_id for k in kbs], - kb_ids, - embd_mdl, - LLMBundle(kb.tenant_id, LLMType.CHAT)) + ck = settings.kg_retrievaler.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT)) if ck["content_with_weight"]: ranks["chunks"].insert(0, ck) @@ -1487,7 +1443,7 @@ def retrieval_test(tenant_id): "important_kwd": "important_keywords", "question_kwd": "questions", "docnm_kwd": "document_keyword", - "kb_id":"dataset_id" + "kb_id": "dataset_id", } rename_chunk = {} for key, value in chunk.items(): diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index e9edab8b259..1716db9eee5 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -388,10 +388,10 @@ def agents_completion_openai_compatibility (tenant_id, agent_id): question = next((m["content"] for m in reversed(messages) if m["role"] == "user"), "") if req.get("stream", True): - return Response(completionOpenAI(tenant_id, agent_id, question, session_id=req.get("id", ""), stream=True), mimetype="text/event-stream") + return Response(completionOpenAI(tenant_id, agent_id, question, session_id=req.get("id", req.get("metadata", {}).get("id","")), stream=True), mimetype="text/event-stream") else: # For non-streaming, just return the response directly - response = next(completionOpenAI(tenant_id, agent_id, question, session_id=req.get("id", ""), stream=False)) + response = next(completionOpenAI(tenant_id, agent_id, question, session_id=req.get("id", req.get("metadata", {}).get("id","")), 
stream=False)) return jsonify(response) @@ -464,7 +464,7 @@ def list_session(tenant_id, chat_id): if conv["reference"]: messages = conv["messages"] message_num = 0 - while message_num < len(messages): + while message_num < len(messages) and message_num < len(conv["reference"]): if message_num != 0 and messages[message_num]["role"] != "user": chunk_list = [] if "chunks" in conv["reference"][message_num]: diff --git a/api/apps/search_app.py b/api/apps/search_app.py new file mode 100644 index 00000000000..083e6308331 --- /dev/null +++ b/api/apps/search_app.py @@ -0,0 +1,188 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from flask import request +from flask_login import current_user, login_required + +from api import settings +from api.constants import DATASET_NAME_LIMIT +from api.db import StatusEnum +from api.db.db_models import DB +from api.db.services import duplicate_name +from api.db.services.knowledgebase_service import KnowledgebaseService +from api.db.services.search_service import SearchService +from api.db.services.user_service import TenantService, UserTenantService +from api.utils import get_uuid +from api.utils.api_utils import get_data_error_result, get_json_result, not_allowed_parameters, server_error_response, validate_request + + +@manager.route("/create", methods=["post"]) # noqa: F821 +@login_required +@validate_request("name") +def create(): + req = request.get_json() + search_name = req["name"] + description = req.get("description", "") + if not isinstance(search_name, str): + return get_data_error_result(message="Search name must be string.") + if search_name.strip() == "": + return get_data_error_result(message="Search name can't be empty.") + if len(search_name.encode("utf-8")) > DATASET_NAME_LIMIT: + return get_data_error_result(message=f"Search name length is {len(search_name)} which is large than {DATASET_NAME_LIMIT}") + e, _ = TenantService.get_by_id(current_user.id) + if not e: + return get_data_error_result(message="Authorizationd identity.") + + search_name = search_name.strip() + search_name = duplicate_name(KnowledgebaseService.query, name=search_name, tenant_id=current_user.id, status=StatusEnum.VALID.value) + + req["id"] = get_uuid() + req["name"] = search_name + req["description"] = description + req["tenant_id"] = current_user.id + req["created_by"] = current_user.id + with DB.atomic(): + try: + if not SearchService.save(**req): + return get_data_error_result() + return get_json_result(data={"search_id": req["id"]}) + except Exception as e: + return server_error_response(e) + + +@manager.route("/update", methods=["post"]) # noqa: F821 +@login_required +@validate_request("search_id", "name", "search_config", "tenant_id") +@not_allowed_parameters("id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by") +def update(): + req = request.get_json() + if not isinstance(req["name"], str): + return get_data_error_result(message="Search name must be string.") 
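# Aside: a minimal sketch of the name-validation pattern shared by the dataset,
# document and search endpoints in this patch. Limits are enforced on the UTF-8
# byte length rather than the character count, so multi-byte (e.g. CJK) names are
# measured the way the storage layer sees them. DATASET_NAME_LIMIT mirrors the
# 128-byte constant in api/constants.py; the helper itself is illustrative.
DATASET_NAME_LIMIT = 128  # bytes


def validate_name(name, limit: int = DATASET_NAME_LIMIT):
    """Return (cleaned_name, error); error is None when the name is acceptable."""
    if not isinstance(name, str):
        return None, "Name must be a string."
    cleaned = name.strip()
    if not cleaned:
        return None, "Name can't be empty."
    size = len(cleaned.encode("utf-8"))
    if size > limit:
        return None, f"Name is {size} bytes, which is larger than {limit}."
    return cleaned, None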
+ if req["name"].strip() == "": + return get_data_error_result(message="Search name can't be empty.") + if len(req["name"].encode("utf-8")) > DATASET_NAME_LIMIT: + return get_data_error_result(message=f"Search name length is {len(req['name'])} which is large than {DATASET_NAME_LIMIT}") + req["name"] = req["name"].strip() + tenant_id = req["tenant_id"] + e, _ = TenantService.get_by_id(tenant_id) + if not e: + return get_data_error_result(message="Authorizationd identity.") + + search_id = req["search_id"] + if not SearchService.accessible4deletion(search_id, current_user.id): + return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR) + + try: + search_app = SearchService.query(tenant_id=tenant_id, id=search_id)[0] + if not search_app: + return get_json_result(data=False, message=f"Cannot find search {search_id}", code=settings.RetCode.DATA_ERROR) + + if req["name"].lower() != search_app.name.lower() and len(SearchService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) >= 1: + return get_data_error_result(message="Duplicated search name.") + + if "search_config" in req: + current_config = search_app.search_config or {} + new_config = req["search_config"] + + if not isinstance(new_config, dict): + return get_data_error_result(message="search_config must be a JSON object") + + updated_config = {**current_config, **new_config} + req["search_config"] = updated_config + + req.pop("search_id", None) + req.pop("tenant_id", None) + + updated = SearchService.update_by_id(search_id, req) + if not updated: + return get_data_error_result(message="Failed to update search") + + e, updated_search = SearchService.get_by_id(search_id) + if not e: + return get_data_error_result(message="Failed to fetch updated search") + + return get_json_result(data=updated_search.to_dict()) + + except Exception as e: + return server_error_response(e) + + +@manager.route("/detail", methods=["GET"]) # noqa: F821 +@login_required +def detail(): + search_id = request.args["search_id"] + try: + tenants = UserTenantService.query(user_id=current_user.id) + for tenant in tenants: + if SearchService.query(tenant_id=tenant.tenant_id, id=search_id): + break + else: + return get_json_result(data=False, message="Has no permission for this operation.", code=settings.RetCode.OPERATING_ERROR) + + search = SearchService.get_detail(search_id) + if not search: + return get_data_error_result(message="Can't find this Search App!") + return get_json_result(data=search) + except Exception as e: + return server_error_response(e) + + +@manager.route("/list", methods=["POST"]) # noqa: F821 +@login_required +def list_search_app(): + keywords = request.args.get("keywords", "") + page_number = int(request.args.get("page", 0)) + items_per_page = int(request.args.get("page_size", 0)) + orderby = request.args.get("orderby", "create_time") + if request.args.get("desc", "true").lower() == "false": + desc = False + else: + desc = True + + req = request.get_json() + owner_ids = req.get("owner_ids", []) + try: + if not owner_ids: + tenants = TenantService.get_joined_tenants_by_user_id(current_user.id) + tenants = [m["tenant_id"] for m in tenants] + search_apps, total = SearchService.get_by_tenant_ids(tenants, current_user.id, page_number, items_per_page, orderby, desc, keywords) + else: + tenants = owner_ids + search_apps, total = SearchService.get_by_tenant_ids(tenants, current_user.id, 0, 0, orderby, desc, keywords) + search_apps = [search_app for search_app in search_apps if 
search_app["tenant_id"] in tenants] + total = len(search_apps) + if page_number and items_per_page: + search_apps = search_apps[(page_number - 1) * items_per_page : page_number * items_per_page] + return get_json_result(data={"search_apps": search_apps, "total": total}) + except Exception as e: + return server_error_response(e) + + +@manager.route("/rm", methods=["post"]) # noqa: F821 +@login_required +@validate_request("search_id") +def rm(): + req = request.get_json() + search_id = req["search_id"] + if not SearchService.accessible4deletion(search_id, current_user.id): + return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR) + + try: + if not SearchService.delete_by_id(search_id): + return get_data_error_result(message=f"Failed to delete search App {search_id}") + return get_json_result(data=True) + except Exception as e: + return server_error_response(e) diff --git a/api/apps/user_app.py b/api/apps/user_app.py index 597f50971b9..b8d66ecba81 100644 --- a/api/apps/user_app.py +++ b/api/apps/user_app.py @@ -16,6 +16,7 @@ import json import logging import re +import secrets from datetime import datetime from flask import redirect, request, session @@ -465,7 +466,7 @@ def log_out(): schema: type: object """ - current_user.access_token = "" + current_user.access_token = f"INVALID_{secrets.token_hex(16)}" current_user.save() logout_user() return get_json_result(data=True) diff --git a/api/constants.py b/api/constants.py index e6a97e2c1b1..ce5cdeb3a8d 100644 --- a/api/constants.py +++ b/api/constants.py @@ -13,9 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -NAME_LENGTH_LIMIT = 2 ** 10 +NAME_LENGTH_LIMIT = 2**10 -IMG_BASE64_PREFIX = 'data:image/png;base64,' +IMG_BASE64_PREFIX = "data:image/png;base64," SERVICE_CONF = "service_conf.yaml" @@ -25,3 +25,4 @@ REQUEST_MAX_WAIT_SEC = 300 DATASET_NAME_LIMIT = 128 +FILE_NAME_LEN_LIMIT = 255 diff --git a/api/db/db_models.py b/api/db/db_models.py index ce71f7b6ff1..3ccfbdba392 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -13,16 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import hashlib import inspect import logging import operator import os import sys -import typing import time +import typing from enum import Enum from functools import wraps -import hashlib from flask_login import UserMixin from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer @@ -264,14 +264,15 @@ def __init__(self): def with_retry(max_retries=3, retry_delay=1.0): """Decorator: Add retry mechanism to database operations - + Args: max_retries (int): maximum number of retries retry_delay (float): initial retry delay (seconds), will increase exponentially - + Returns: decorated function """ + def decorator(func): @wraps(func) def wrapper(*args, **kwargs): @@ -284,26 +285,28 @@ def wrapper(*args, **kwargs): # get self and method name for logging self_obj = args[0] if args else None func_name = func.__name__ - lock_name = getattr(self_obj, 'lock_name', 'unknown') if self_obj else 'unknown' - + lock_name = getattr(self_obj, "lock_name", "unknown") if self_obj else "unknown" + if retry < max_retries - 1: - current_delay = retry_delay * (2 ** retry) - logging.warning(f"{func_name} {lock_name} failed: {str(e)}, retrying ({retry+1}/{max_retries})") + current_delay = retry_delay * (2**retry) + logging.warning(f"{func_name} {lock_name} failed: {str(e)}, retrying ({retry + 1}/{max_retries})") time.sleep(current_delay) else: logging.error(f"{func_name} {lock_name} failed after all attempts: {str(e)}") - + if last_exception: raise last_exception return False + return wrapper + return decorator class PostgresDatabaseLock: def __init__(self, lock_name, timeout=10, db=None): self.lock_name = lock_name - self.lock_id = int(hashlib.md5(lock_name.encode()).hexdigest(), 16) % (2**31-1) + self.lock_id = int(hashlib.md5(lock_name.encode()).hexdigest(), 16) % (2**31 - 1) self.timeout = int(timeout) self.db = db if db else DB @@ -542,7 +545,7 @@ class LLM(DataBaseModel): max_tokens = IntegerField(default=0) tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...", index=True) - is_tools = BooleanField(null=False, help_text="support tools", default=False) + is_tools = BooleanField(null=False, help_text="support tools", default=False) status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True) def __str__(self): @@ -796,6 +799,50 @@ class Meta: db_table = "user_canvas_version" +class Search(DataBaseModel): + id = CharField(max_length=32, primary_key=True) + avatar = TextField(null=True, help_text="avatar base64 string") + tenant_id = CharField(max_length=32, null=False, index=True) + name = CharField(max_length=128, null=False, help_text="Search name", index=True) + description = TextField(null=True, help_text="KB description") + created_by = CharField(max_length=32, null=False, index=True) + search_config = JSONField( + null=False, + default={ + "kb_ids": [], + "doc_ids": [], + "similarity_threshold": 0.0, + "vector_similarity_weight": 0.3, + "use_kg": False, + # rerank settings + "rerank_id": "", + "top_k": 1024, + # chat settings + "summary": False, + "chat_id": "", + "llm_setting": { + "temperature": 0.1, + "top_p": 0.3, + "frequency_penalty": 0.7, + "presence_penalty": 0.4, + }, + "chat_settingcross_languages": [], + "highlight": False, + "keyword": False, + "web_search": False, + "related_search": False, + "query_mindmap": False, + }, + ) + status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True) + + def __str__(self): + 
return self.name + + class Meta: + db_table = "search" + + def migrate_db(): migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB) try: diff --git a/api/db/init_data.py b/api/db/init_data.py index f10aa64617a..b46b27ce6b0 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -84,14 +84,14 @@ def init_superuser(): {"role": "user", "content": "Hello!"}], gen_conf={}) if msg.find("ERROR: ") == 0: logging.error( - "'{}' dosen't work. {}".format( + "'{}' doesn't work. {}".format( tenant["llm_id"], msg)) embd_mdl = LLMBundle(tenant["id"], LLMType.EMBEDDING, tenant["embd_id"]) v, c = embd_mdl.encode(["Hello!"]) if c == 0: logging.error( - "'{}' dosen't work!".format( + "'{}' doesn't work!".format( tenant["embd_id"])) diff --git a/api/db/services/__init__.py b/api/db/services/__init__.py index 964a7a17b28..4b3af3ecfb4 100644 --- a/api/db/services/__init__.py +++ b/api/db/services/__init__.py @@ -13,27 +13,87 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import pathlib import re +from pathlib import PurePath + from .user_service import UserService as UserService -def duplicate_name(query_func, **kwargs): - fnm = kwargs["name"] - objs = query_func(**kwargs) - if not objs: - return fnm - ext = pathlib.Path(fnm).suffix #.jpg - nm = re.sub(r"%s$"%ext, "", fnm) - r = re.search(r"\(([0-9]+)\)$", nm) - c = 0 - if r: - c = int(r.group(1)) - nm = re.sub(r"\([0-9]+\)$", "", nm) - c += 1 - nm = f"{nm}({c})" - if ext: - nm += f"{ext}" - - kwargs["name"] = nm - return duplicate_name(query_func, **kwargs) +def split_name_counter(filename: str) -> tuple[str, int | None]: + """ + Splits a filename into main part and counter (if present in parentheses). + + Args: + filename: Input filename string to be parsed + + Returns: + A tuple containing: + - The main filename part (string) + - The counter from parentheses (integer) or None if no counter exists + """ + pattern = re.compile(r"^(.*?)\((\d+)\)$") + + match = pattern.search(filename) + if match: + main_part = match.group(1).rstrip() + bracket_part = match.group(2) + return main_part, int(bracket_part) + + return filename, None + + +def duplicate_name(query_func, **kwargs) -> str: + """ + Generates a unique filename by appending/incrementing a counter when duplicates exist. + + Continuously checks for name availability using the provided query function, + automatically appending (1), (2), etc. until finding an available name or + reaching maximum retries. 
+ + Args: + query_func: Callable that accepts keyword arguments and returns: + - True if name exists (should be modified) + - False if name is available + **kwargs: Must contain 'name' key with original filename to check + + Returns: + str: Available filename, either: + - Original name (if available) + - Modified name with counter (e.g., "file(1).txt") + + Raises: + KeyError: If 'name' key not provided in kwargs + RuntimeError: If unable to generate unique name after maximum retries + + Example: + >>> def name_exists(name): return name in existing_files + >>> duplicate_name(name_exists, name="document.pdf") + 'document(1).pdf' # If original exists + """ + MAX_RETRIES = 1000 + + if "name" not in kwargs: + raise KeyError("Arguments must contain 'name' key") + + original_name = kwargs["name"] + current_name = original_name + retries = 0 + + while retries < MAX_RETRIES: + if not query_func(**kwargs): + return current_name + + path = PurePath(current_name) + stem = path.stem + suffix = path.suffix + + main_part, counter = split_name_counter(stem) + counter = counter + 1 if counter else 1 + + new_name = f"{main_part}({counter}){suffix}" + + kwargs["name"] = new_name + current_name = new_name + retries += 1 + + raise RuntimeError(f"Failed to generate unique name within {MAX_RETRIES} attempts. Original: {original_name}") diff --git a/api/db/services/canvas_service.py b/api/db/services/canvas_service.py index bca6a6be91c..8bcb7b1bc47 100644 --- a/api/db/services/canvas_service.py +++ b/api/db/services/canvas_service.py @@ -73,11 +73,11 @@ def get_by_tenant_id(cls, pid): User.nickname, User.avatar.alias('tenant_avatar'), ] - angents = cls.model.select(*fields) \ + agents = cls.model.select(*fields) \ .join(User, on=(cls.model.user_id == User.id)) \ .where(cls.model.id == pid) # obj = cls.model.query(id=pid)[0] - return True, angents.dicts()[0] + return True, agents.dicts()[0] except Exception as e: print(e) return False, None @@ -100,25 +100,25 @@ def get_by_tenant_ids(cls, joined_tenant_ids, user_id, cls.model.update_time ] if keywords: - angents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where( + agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where( ((cls.model.user_id.in_(joined_tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | ( cls.model.user_id == user_id)), (fn.LOWER(cls.model.title).contains(keywords.lower())) ) else: - angents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where( + agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where( ((cls.model.user_id.in_(joined_tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | ( cls.model.user_id == user_id)) ) if desc: - angents = angents.order_by(cls.model.getter_by(orderby).desc()) + agents = agents.order_by(cls.model.getter_by(orderby).desc()) else: - angents = angents.order_by(cls.model.getter_by(orderby).asc()) - count = angents.count() - angents = angents.paginate(page_number, items_per_page) - return list(angents.dicts()), count + agents = agents.order_by(cls.model.getter_by(orderby).asc()) + count = agents.count() + agents = agents.paginate(page_number, items_per_page) + return list(agents.dicts()), count def completion(tenant_id, agent_id, question, session_id=None, stream=True, **kwargs): @@ -173,6 +173,19 @@ def completion(tenant_id, agent_id, question, session_id=None, stream=True, **kw conv.reference = [] conv.reference.append({"chunks": [], "doc_aggs": []}) + kwargs_changed 
= False + if kwargs: + query = canvas.get_preset_param() + if query: + for ele in query: + if ele["key"] in kwargs: + if ele["value"] != kwargs[ele["key"]]: + ele["value"] = kwargs[ele["key"]] + kwargs_changed = True + if kwargs_changed: + conv.dsl = json.loads(str(canvas)) + API4ConversationService.update_by_id(session_id, {"dsl": conv.dsl}) + final_ans = {"reference": [], "content": ""} if stream: try: @@ -281,8 +294,22 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True "source": "agent", "dsl": cvs.dsl } + canvas.messages.append({"role": "user", "content": question, "id": message_id}) + canvas.add_user_input(question) + API4ConversationService.save(**conv) conv = API4Conversation(**conv) + if not conv.message: + conv.message = [] + conv.message.append({ + "role": "user", + "content": question, + "id": message_id + }) + + if not conv.reference: + conv.reference = [] + conv.reference.append({"chunks": [], "doc_aggs": []}) # Handle existing session else: @@ -318,7 +345,7 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True if stream: try: completion_tokens = 0 - for ans in canvas.run(stream=True): + for ans in canvas.run(stream=True, bypass_begin=True): if ans.get("running_status"): completion_tokens += len(tiktokenenc.encode(ans.get("content", ""))) yield "data: " + json.dumps( @@ -381,7 +408,7 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True else: # Non-streaming mode try: all_answer_content = "" - for answer in canvas.run(stream=False): + for answer in canvas.run(stream=False, bypass_begin=True): if answer.get("running_status"): continue diff --git a/api/db/services/common_service.py b/api/db/services/common_service.py index 95f5d759f0f..7645b43d4e4 100644 --- a/api/db/services/common_service.py +++ b/api/db/services/common_service.py @@ -254,7 +254,7 @@ def delete_by_id(cls, pid): # Returns: # Number of records deleted return cls.model.delete().where(cls.model.id == pid).execute() - + @classmethod @DB.connection_context() def delete_by_ids(cls, pids): diff --git a/api/db/services/conversation_service.py b/api/db/services/conversation_service.py index 575eea6955a..5e247c21cc3 100644 --- a/api/db/services/conversation_service.py +++ b/api/db/services/conversation_service.py @@ -90,17 +90,18 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None "user_id": kwargs.get("user_id", "") } ConversationService.save(**conv) - yield "data:" + json.dumps({"code": 0, "message": "", - "data": { - "answer": conv["message"][0]["content"], - "reference": {}, - "audio_binary": None, - "id": None, - "session_id": session_id - }}, - ensure_ascii=False) + "\n\n" - yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" - return + if stream: + yield "data:" + json.dumps({"code": 0, "message": "", + "data": { + "answer": conv["message"][0]["content"], + "reference": {}, + "audio_binary": None, + "id": None, + "session_id": session_id + }}, + ensure_ascii=False) + "\n\n" + yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" + return conv = ConversationService.query(id=session_id, dialog_id=chat_id) if not conv: @@ -123,6 +124,8 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None message_id = msg[-1].get("id") e, dia = DialogService.get_by_id(conv.dialog_id) + kb_ids = kwargs.get("kb_ids",[]) + dia.kb_ids = list(set(dia.kb_ids + kb_ids)) if not conv.reference: 
conv.reference = [] conv.message.append({"role": "assistant", "content": "", "id": message_id}) diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index c498b924a14..211178a51b7 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -127,9 +127,71 @@ def chat_solo(dialog, messages, stream=True): yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()} +def get_models(dialog): + embd_mdl, chat_mdl, rerank_mdl, tts_mdl = None, None, None, None + kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids) + embedding_list = list(set([kb.embd_id for kb in kbs])) + if len(embedding_list) > 1: + raise Exception("**ERROR**: Knowledge bases use different embedding models.") + + if embedding_list: + embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_list[0]) + if not embd_mdl: + raise LookupError("Embedding model(%s) not found" % embedding_list[0]) + + if llm_id2llm_type(dialog.llm_id) == "image2text": + chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id) + else: + chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) + + if dialog.rerank_id: + rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id) + + if dialog.prompt_config.get("tts"): + tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS) + return kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl + + +BAD_CITATION_PATTERNS = [ + re.compile(r"\(\s*ID\s*[: ]*\s*(\d+)\s*\)"), # (ID: 12) + re.compile(r"\[\s*ID\s*[: ]*\s*(\d+)\s*\]"), # [ID: 12] + re.compile(r"【\s*ID\s*[: ]*\s*(\d+)\s*】"), # 【ID: 12】 + re.compile(r"ref\s*(\d+)", flags=re.IGNORECASE), # ref12、REF 12 +] + + +def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set): + max_index = len(kbinfos["chunks"]) + + def safe_add(i): + if 0 <= i < max_index: + idx.add(i) + return True + return False + + def find_and_replace(pattern, group_index=1, repl=lambda i: f"ID:{i}", flags=0): + nonlocal answer + + def replacement(match): + try: + i = int(match.group(group_index)) + if safe_add(i): + return f"[{repl(i)}]" + except Exception: + pass + return match.group(0) + + answer = re.sub(pattern, replacement, answer, flags=flags) + + for pattern in BAD_CITATION_PATTERNS: + find_and_replace(pattern) + + return answer, idx + + def chat(dialog, messages, stream=True, **kwargs): assert messages[-1]["role"] == "user", "The last content of this conversation is not from user." 
- if not dialog.kb_ids: + if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"): for ans in chat_solo(dialog, messages, stream): yield ans return @@ -154,45 +216,19 @@ def chat(dialog, messages, stream=True, **kwargs): langfuse.trace = langfuse_tracer.trace(name=f"{dialog.name}-{llm_model_config['llm_name']}") check_langfuse_tracer_ts = timer() - - kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids) - embedding_list = list(set([kb.embd_id for kb in kbs])) - if len(embedding_list) != 1: - yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []} - return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []} - - embedding_model_name = embedding_list[0] + kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl = get_models(dialog) + toolcall_session, tools = kwargs.get("toolcall_session"), kwargs.get("tools") + if toolcall_session and tools: + chat_mdl.bind_tools(toolcall_session, tools) + bind_models_ts = timer() retriever = settings.retrievaler - questions = [m["content"] for m in messages if m["role"] == "user"][-3:] attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None if "doc_ids" in messages[-1]: attachments = messages[-1]["doc_ids"] - - create_retriever_ts = timer() - - embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_model_name) - if not embd_mdl: - raise LookupError("Embedding model(%s) not found" % embedding_model_name) - - bind_embedding_ts = timer() - - if llm_id2llm_type(dialog.llm_id) == "image2text": - chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id) - else: - chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) - toolcall_session, tools = kwargs.get("toolcall_session"), kwargs.get("tools") - if toolcall_session and tools: - chat_mdl.bind_tools(toolcall_session, tools) - - bind_llm_ts = timer() - prompt_config = dialog.prompt_config field_map = KnowledgebaseService.get_field_map(dialog.kb_ids) - tts_mdl = None - if prompt_config.get("tts"): - tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS) # try to use sql if field mapping is good to go if field_map: logging.debug("Use SQL to retrieval:{}".format(questions[-1])) @@ -217,26 +253,18 @@ def chat(dialog, messages, stream=True, **kwargs): if prompt_config.get("cross_languages"): questions = [cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])] - refine_question_ts = timer() + if prompt_config.get("keyword", False): + questions[-1] += keyword_extraction(chat_mdl, questions[-1]) - rerank_mdl = None - if dialog.rerank_id: - rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id) + refine_question_ts = timer() - bind_reranker_ts = timer() - generate_keyword_ts = bind_reranker_ts thought = "" kbinfos = {"total": 0, "chunks": [], "doc_aggs": []} if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]: knowledges = [] else: - if prompt_config.get("keyword", False): - questions[-1] += keyword_extraction(chat_mdl, questions[-1]) - generate_keyword_ts = timer() - tenant_ids = list(set([kb.tenant_id for kb in kbs])) - knowledges = [] if prompt_config.get("reasoning", False): reasoner = DeepResearcher( @@ -252,21 +280,22 @@ def chat(dialog, messages, stream=True, **kwargs): elif stream: yield think else: - kbinfos = retriever.retrieval( - " ".join(questions), - embd_mdl, - tenant_ids, - dialog.kb_ids, - 1, - dialog.top_n, - dialog.similarity_threshold, - dialog.vector_similarity_weight, - 
doc_ids=attachments, - top=dialog.top_k, - aggs=False, - rerank_mdl=rerank_mdl, - rank_feature=label_question(" ".join(questions), kbs), - ) + if embd_mdl: + kbinfos = retriever.retrieval( + " ".join(questions), + embd_mdl, + tenant_ids, + dialog.kb_ids, + 1, + dialog.top_n, + dialog.similarity_threshold, + dialog.vector_similarity_weight, + doc_ids=attachments, + top=dialog.top_k, + aggs=False, + rerank_mdl=rerank_mdl, + rank_feature=label_question(" ".join(questions), kbs), + ) if prompt_config.get("tavily_api_key"): tav = Tavily(prompt_config["tavily_api_key"]) tav_res = tav.retrieve_chunks(" ".join(questions)) @@ -302,41 +331,8 @@ def chat(dialog, messages, stream=True, **kwargs): if "max_tokens" in gen_conf: gen_conf["max_tokens"] = min(gen_conf["max_tokens"], max_tokens - used_token_count) - def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set): - max_index = len(kbinfos["chunks"]) - - def safe_add(i): - if 0 <= i < max_index: - idx.add(i) - return True - return False - - def find_and_replace(pattern, group_index=1, repl=lambda i: f"##{i}$$", flags=0): - nonlocal answer - for match in re.finditer(pattern, answer, flags=flags): - try: - i = int(match.group(group_index)) - if safe_add(i): - answer = answer.replace(match.group(0), repl(i)) - except Exception: - continue - - find_and_replace(r"\(\s*ID:\s*(\d+)\s*\)") # (ID: 12) - find_and_replace(r"ID[: ]+(\d+)") # ID: 12, ID 12 - find_and_replace(r"\$\$(\d+)\$\$") # $$12$$ - find_and_replace(r"\$\[(\d+)\]\$") # $[12]$ - find_and_replace(r"\$\$(\d+)\${2,}") # $$12$$$$ - find_and_replace(r"\$(\d+)\$") # $12$ - find_and_replace(r"(#{2,})(\d+)(\${2,})", group_index=2) # 2+ # and 2+ $ - find_and_replace(r"(#{2,})(\d+)(#{1,})", group_index=2) # 2+ # and 1+ # - find_and_replace(r"##(\d+)#{2,}") # ##12### - find_and_replace(r"【(\d+)】") # 【12】 - find_and_replace(r"ref\s*(\d+)", flags=re.IGNORECASE) # ref12, ref 12, REF 12 - - return answer, idx - def decorate_answer(answer): - nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions, langfuse_tracer + nonlocal embd_mdl, prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions, langfuse_tracer refs = [] ans = answer.split("") @@ -346,9 +342,8 @@ def decorate_answer(answer): answer = ans[1] if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)): - answer = re.sub(r"##[ij]\$\$", "", answer, flags=re.DOTALL) idx = set([]) - if not re.search(r"##[0-9]+\$\$", answer): + if embd_mdl and not re.search(r"\[ID:([0-9]+)\]", answer): answer, idx = retriever.insert_citations( answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], @@ -358,7 +353,7 @@ def decorate_answer(answer): vtweight=dialog.vector_similarity_weight, ) else: - for match in re.finditer(r"##([0-9]+)\$\$", answer): + for match in re.finditer(r"\[ID:([0-9]+)\]", answer): i = int(match.group(1)) if i < len(kbinfos["chunks"]): idx.add(i) @@ -383,13 +378,9 @@ def decorate_answer(answer): total_time_cost = (finish_chat_ts - chat_start_ts) * 1000 check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000 check_langfuse_tracer_cost = (check_langfuse_tracer_ts - check_llm_ts) * 1000 - create_retriever_time_cost = (create_retriever_ts - check_langfuse_tracer_ts) * 1000 - bind_embedding_time_cost = (bind_embedding_ts - create_retriever_ts) * 1000 - bind_llm_time_cost = (bind_llm_ts - bind_embedding_ts) * 1000 - refine_question_time_cost = (refine_question_ts - bind_llm_ts) * 1000 - bind_reranker_time_cost = (bind_reranker_ts - refine_question_ts) * 1000 - 
generate_keyword_time_cost = (generate_keyword_ts - bind_reranker_ts) * 1000 - retrieval_time_cost = (retrieval_ts - generate_keyword_ts) * 1000 + bind_embedding_time_cost = (bind_models_ts - check_langfuse_tracer_ts) * 1000 + refine_question_time_cost = (refine_question_ts - bind_models_ts) * 1000 + retrieval_time_cost = (retrieval_ts - refine_question_ts) * 1000 generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000 tk_num = num_tokens_from_string(think + answer) @@ -400,12 +391,8 @@ def decorate_answer(answer): f" - Total: {total_time_cost:.1f}ms\n" f" - Check LLM: {check_llm_time_cost:.1f}ms\n" f" - Check Langfuse tracer: {check_langfuse_tracer_cost:.1f}ms\n" - f" - Create retriever: {create_retriever_time_cost:.1f}ms\n" - f" - Bind embedding: {bind_embedding_time_cost:.1f}ms\n" - f" - Bind LLM: {bind_llm_time_cost:.1f}ms\n" - f" - Multi-turn optimization: {refine_question_time_cost:.1f}ms\n" - f" - Bind reranker: {bind_reranker_time_cost:.1f}ms\n" - f" - Generate keyword: {generate_keyword_time_cost:.1f}ms\n" + f" - Bind models: {bind_embedding_time_cost:.1f}ms\n" + f" - Query refinement(LLM): {refine_question_time_cost:.1f}ms\n" f" - Retrieval: {retrieval_time_cost:.1f}ms\n" f" - Generate answer: {generate_result_time_cost:.1f}ms\n\n" "## Token usage:\n" @@ -569,7 +556,7 @@ def tts(tts_mdl, text): return binascii.hexlify(bin).decode("utf-8") -def ask(question, kb_ids, tenant_id): +def ask(question, kb_ids, tenant_id, chat_llm_name=None): kbs = KnowledgebaseService.get_by_ids(kb_ids) embedding_list = list(set([kb.embd_id for kb in kbs])) @@ -577,7 +564,7 @@ def ask(question, kb_ids, tenant_id): retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0]) - chat_mdl = LLMBundle(tenant_id, LLMType.CHAT) + chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_llm_name) max_tokens = chat_mdl.max_length tenant_ids = list(set([kb.tenant_id for kb in kbs])) kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False, rank_feature=label_question(question, kbs)) diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index 284df853a73..8b7bc666000 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -27,6 +27,7 @@ from peewee import fn from api import settings +from api.constants import IMG_BASE64_PREFIX from api.db import FileType, LLMType, ParserType, StatusEnum, TaskStatus, UserTenantRole from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant from api.db.db_utils import bulk_insert_into_db @@ -34,7 +35,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService from api.utils import current_timestamp, get_format_time, get_uuid from rag.nlp import rag_tokenizer, search -from rag.settings import get_svr_queue_name +from rag.settings import get_svr_queue_name, SVR_CONSUMER_GROUP_NAME from rag.utils.redis_conn import REDIS_CONN from rag.utils.storage_factory import STORAGE_IMPL from rag.utils.doc_store_conn import OrderByExpr @@ -147,7 +148,26 @@ def insert(cls, doc): def remove_document(cls, doc, tenant_id): cls.clear_chunk_num(doc.id) try: + page = 0 + page_size = 1000 + all_chunk_ids = [] + while True: + chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(), + page * page_size, page_size, search.index_name(tenant_id), + [doc.kb_id]) + chunk_ids = settings.docStoreConn.getChunkIds(chunks) + if not 
chunk_ids: + break + all_chunk_ids.extend(chunk_ids) + page += 1 + for cid in all_chunk_ids: + if STORAGE_IMPL.obj_exist(doc.kb_id, cid): + STORAGE_IMPL.rm(doc.kb_id, cid) + if doc.thumbnail and not doc.thumbnail.startswith(IMG_BASE64_PREFIX): + if STORAGE_IMPL.obj_exist(doc.kb_id, doc.thumbnail): + STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail) settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id) + graph_source = settings.docStoreConn.getFields( settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"] ) @@ -464,7 +484,8 @@ def update_progress(cls): if t.progress == -1: bad += 1 prg += t.progress if t.progress >= 0 else 0 - msg.append(t.progress_msg) + if t.progress_msg.strip(): + msg.append(t.progress_msg) if t.task_type == "raptor": has_raptor = True elif t.task_type == "graphrag": @@ -494,6 +515,8 @@ def update_progress(cls): info["progress"] = prg if msg: info["progress_msg"] = msg + else: + info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority) cls.update_by_id(d["id"], info) except Exception as e: if str(e).find("'0'") < 0: @@ -542,6 +565,11 @@ def new_task(): assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status." +def get_queue_length(priority): + group_info = REDIS_CONN.queue_info(get_svr_queue_name(priority), SVR_CONSUMER_GROUP_NAME) + return int(group_info.get("lag", 0)) + + def doc_upload_and_parse(conversation_id, file_objs, user_id): from api.db.services.api_service import API4ConversationService from api.db.services.conversation_service import ConversationService diff --git a/api/db/services/file_service.py b/api/db/services/file_service.py index 803164f9eef..25c856531b5 100644 --- a/api/db/services/file_service.py +++ b/api/db/services/file_service.py @@ -21,6 +21,7 @@ from flask_login import current_user from peewee import fn +from api.constants import FILE_NAME_LEN_LIMIT from api.db import KNOWLEDGEBASE_FOLDER_NAME, FileSource, FileType, ParserType from api.db.db_models import DB, Document, File, File2Document, Knowledgebase from api.db.services import duplicate_name @@ -412,8 +413,8 @@ def upload_document(self, kb, file_objs, user_id): MAX_FILE_NUM_PER_USER = int(os.environ.get("MAX_FILE_NUM_PER_USER", 0)) if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(kb.tenant_id) >= MAX_FILE_NUM_PER_USER: raise RuntimeError("Exceed the maximum file number of a free user!") - if len(file.filename.encode("utf-8")) >= 128: - raise RuntimeError("Exceed the maximum length of file name!") + if len(file.filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT: + raise RuntimeError(f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.") filename = duplicate_name(DocumentService.query, name=file.filename, kb_id=kb.id) filetype = filename_type(filename) @@ -492,4 +493,3 @@ def get_parser(doc_type, filename, default): if re.search(r"\.(eml)$", filename): return ParserType.EMAIL.value return default - diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index 02e66944ec0..e124b5b16ac 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -169,7 +169,7 @@ def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None): return 0 llm_map = { - LLMType.EMBEDDING.value: tenant.embd_id, + LLMType.EMBEDDING.value: tenant.embd_id if not llm_name else llm_name, 
LLMType.SPEECH2TEXT.value: tenant.asr_id, LLMType.IMAGE2TEXT.value: tenant.img2txt_id, LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name, @@ -235,7 +235,8 @@ def encode(self, texts: list): generation = self.trace.generation(name="encode", model=self.llm_name, input={"texts": texts}) embeddings, used_tokens = self.mdl.encode(texts) - if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens): + llm_name = getattr(self, "llm_name", None) + if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, llm_name): logging.error("LLMBundle.encode can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens)) if self.langfuse: @@ -248,7 +249,8 @@ def encode_queries(self, query: str): generation = self.trace.generation(name="encode_queries", model=self.llm_name, input={"query": query}) emd, used_tokens = self.mdl.encode_queries(query) - if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens): + llm_name = getattr(self, "llm_name", None) + if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, llm_name): logging.error("LLMBundle.encode_queries can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens)) if self.langfuse: diff --git a/api/db/services/search_service.py b/api/db/services/search_service.py new file mode 100644 index 00000000000..c5c812cc99f --- /dev/null +++ b/api/db/services/search_service.py @@ -0,0 +1,110 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from datetime import datetime + +from peewee import fn + +from api.db import StatusEnum +from api.db.db_models import DB, Search, User +from api.db.services.common_service import CommonService +from api.utils import current_timestamp, datetime_format + + +class SearchService(CommonService): + model = Search + + @classmethod + def save(cls, **kwargs): + kwargs["create_time"] = current_timestamp() + kwargs["create_date"] = datetime_format(datetime.now()) + kwargs["update_time"] = current_timestamp() + kwargs["update_date"] = datetime_format(datetime.now()) + obj = cls.model.create(**kwargs) + return obj + + @classmethod + @DB.connection_context() + def accessible4deletion(cls, search_id, user_id) -> bool: + search = ( + cls.model.select(cls.model.id) + .where( + cls.model.id == search_id, + cls.model.created_by == user_id, + cls.model.status == StatusEnum.VALID.value, + ) + .first() + ) + return search is not None + + @classmethod + @DB.connection_context() + def get_detail(cls, search_id): + fields = [ + cls.model.id, + cls.model.avatar, + cls.model.tenant_id, + cls.model.name, + cls.model.description, + cls.model.created_by, + cls.model.search_config, + cls.model.update_time, + User.nickname, + User.avatar.alias("tenant_avatar"), + ] + search = ( + cls.model.select(*fields) + .join(User, on=((User.id == cls.model.tenant_id) & (User.status == StatusEnum.VALID.value))) + .where((cls.model.id == search_id) & (cls.model.status == StatusEnum.VALID.value)) + .first() + .to_dict() + ) + return search + + @classmethod + @DB.connection_context() + def get_by_tenant_ids(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc, keywords): + fields = [ + cls.model.id, + cls.model.avatar, + cls.model.tenant_id, + cls.model.name, + cls.model.description, + cls.model.created_by, + cls.model.status, + cls.model.update_time, + cls.model.create_time, + User.nickname, + User.avatar.alias("tenant_avatar"), + ] + query = ( + cls.model.select(*fields) + .join(User, on=(cls.model.tenant_id == User.id)) + .where(((cls.model.tenant_id.in_(joined_tenant_ids)) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value)) + ) + + if keywords: + query = query.where(fn.LOWER(cls.model.name).contains(keywords.lower())) + if desc: + query = query.order_by(cls.model.getter_by(orderby).desc()) + else: + query = query.order_by(cls.model.getter_by(orderby).asc()) + + count = query.count() + + if page_number and items_per_page: + query = query.paginate(page_number, items_per_page) + + return list(query.dicts()), count diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py index 4693a6684bd..5fd0eefc3ca 100644 --- a/api/db/services/task_service.py +++ b/api/db/services/task_service.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import logging import os import random import xxhash @@ -256,36 +257,55 @@ def do_cancel(cls, id): @DB.connection_context() def update_progress(cls, id, info): """Update the progress information for a task. - + This method updates both the progress message and completion percentage of a task. It handles platform-specific behavior (macOS vs others) and uses database locking when necessary to ensure thread safety. - + + Update Rules: + - progress_msg: Always appends the new message to the existing one, and trims the result to max 3000 lines. 
+ - progress: Only updates if the current progress is not -1 AND + (the new progress is -1 OR greater than the existing progress), + to avoid overwriting valid progress with invalid or regressive values. + Args: id (str): The unique identifier of the task to update. info (dict): Dictionary containing progress information with keys: - progress_msg (str, optional): Progress message to append - progress (float, optional): Progress percentage (0.0 to 1.0) """ + task = cls.model.get_by_id(id) + if not task: + logging.warning("Update_progress error: task not found") + return + if os.environ.get("MACOS"): if info["progress_msg"]: - task = cls.model.get_by_id(id) progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000) cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute() if "progress" in info: - cls.model.update(progress=info["progress"]).where( - cls.model.id == id + prog = info["progress"] + cls.model.update(progress=prog).where( + (cls.model.id == id) & + ( + (cls.model.progress != -1) & + ((prog == -1) | (prog > cls.model.progress)) + ) ).execute() return with DB.lock("update_progress", -1): if info["progress_msg"]: - task = cls.model.get_by_id(id) progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000) cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute() if "progress" in info: - cls.model.update(progress=info["progress"]).where( - cls.model.id == id + prog = info["progress"] + cls.model.update(progress=prog).where( + (cls.model.id == id) & + ( + (cls.model.progress != -1) & + ((prog == -1) | (prog > cls.model.progress)) + ) ).execute() diff --git a/api/db/services/user_service.py b/api/db/services/user_service.py index 1edd46c107f..e8344cb43ba 100644 --- a/api/db/services/user_service.py +++ b/api/db/services/user_service.py @@ -15,6 +15,7 @@ # import hashlib from datetime import datetime +import logging import peewee from werkzeug.security import generate_password_hash, check_password_hash @@ -39,6 +40,30 @@ class UserService(CommonService): """ model = User + @classmethod + @DB.connection_context() + def query(cls, cols=None, reverse=None, order_by=None, **kwargs): + if 'access_token' in kwargs: + access_token = kwargs['access_token'] + + # Reject empty, None, or whitespace-only access tokens + if not access_token or not str(access_token).strip(): + logging.warning("UserService.query: Rejecting empty access_token query") + return cls.model.select().where(cls.model.id == "INVALID_EMPTY_TOKEN") # Returns empty result + + # Reject tokens that are too short (should be UUID, 32+ chars) + if len(str(access_token).strip()) < 32: + logging.warning(f"UserService.query: Rejecting short access_token query: {len(str(access_token))} chars") + return cls.model.select().where(cls.model.id == "INVALID_SHORT_TOKEN") # Returns empty result + + # Reject tokens that start with "INVALID_" (from logout) + if str(access_token).startswith("INVALID_"): + logging.warning("UserService.query: Rejecting invalidated access_token") + return cls.model.select().where(cls.model.id == "INVALID_LOGOUT_TOKEN") # Returns empty result + + # Call parent query method for valid requests + return super().query(cols=cols, reverse=reverse, order_by=order_by, **kwargs) + @classmethod @DB.connection_context() def filter_by_id(cls, user_id): diff --git a/api/ragflow_server.py b/api/ragflow_server.py index 024492cecb1..75bc8916c7d 100644 --- a/api/ragflow_server.py +++ b/api/ragflow_server.py @@ -18,9 +18,9 @@ # from 
beartype.claw import beartype_all # <-- you didn't sign up for this # beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code -from api.utils.log_utils import initRootLogger +from api.utils.log_utils import init_root_logger from plugin import GlobalPluginManager -initRootLogger("ragflow_server") +init_root_logger("ragflow_server") import logging import os @@ -28,7 +28,6 @@ import sys import time import traceback -from concurrent.futures import ThreadPoolExecutor import threading import uuid @@ -125,8 +124,16 @@ def signal_handler(sig, frame): signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) - thread = ThreadPoolExecutor(max_workers=1) - thread.submit(update_progress) + def delayed_start_update_progress(): + logging.info("Starting update_progress thread (delayed)") + t = threading.Thread(target=update_progress, daemon=True) + t.start() + + if RuntimeConfig.DEBUG: + if os.environ.get("WERKZEUG_RUN_MAIN") == "true": + threading.Timer(1.0, delayed_start_update_progress).start() + else: + threading.Timer(1.0, delayed_start_update_progress).start() # start http server try: diff --git a/api/settings.py b/api/settings.py index 2d743f90474..22e9d03f461 100644 --- a/api/settings.py +++ b/api/settings.py @@ -15,6 +15,7 @@ # import json import os +import secrets from datetime import date from enum import Enum, IntEnum @@ -73,6 +74,25 @@ BUILTIN_EMBEDDING_MODELS = ["BAAI/bge-large-zh-v1.5@BAAI", "maidalun1020/bce-embedding-base_v1@Youdao"] +def get_or_create_secret_key(): + secret_key = os.environ.get("RAGFLOW_SECRET_KEY") + if secret_key and len(secret_key) >= 32: + return secret_key + + # Check if there's a configured secret key + configured_key = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("secret_key") + if configured_key and configured_key != str(date.today()) and len(configured_key) >= 32: + return configured_key + + # Generate a new secure key and warn about it + import logging + new_key = secrets.token_hex(32) + logging.warning( + "SECURITY WARNING: Using auto-generated SECRET_KEY. " + f"Generated key: {new_key}" + ) + return new_key + def init_settings(): global LLM, LLM_FACTORY, LLM_BASE_URL, LIGHTEN, DATABASE_TYPE, DATABASE, FACTORY_LLM_INFOS, REGISTER_ENABLED @@ -121,7 +141,7 @@ def init_settings(): HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1") HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port") - SECRET_KEY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("secret_key", str(date.today())) + SECRET_KEY = get_or_create_secret_key() global AUTHENTICATION_CONF, CLIENT_AUTHENTICATION, HTTP_APP_KEY, GITHUB_OAUTH, FEISHU_OAUTH, OAUTH_CONFIG # authentication diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index c0f2c195703..8368d9ad421 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -428,11 +428,11 @@ def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, R """ Verifies availability of an embedding model for a specific tenant. - Implements a four-stage validation process: - 1. Model identifier parsing and validation - 2. System support verification - 3. Tenant authorization check - 4. Database operation error handling + Performs comprehensive verification through: + 1. Identifier Parsing: Decomposes embd_id into name and factory components + 2. System Verification: Checks model registration in LLMService + 3. Tenant Authorization: Validates tenant-specific model assignments + 4. 
Built-in Model Check: Confirms inclusion in predefined system models Args: embd_id (str): Unique identifier for the embedding model in format "model_name@factory" @@ -460,14 +460,15 @@ def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, R """ try: llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(embd_id) - if not LLMService.query(llm_name=llm_name, fid=llm_factory, model_type="embedding"): - return False, get_error_argument_result(f"Unsupported model: <{embd_id}>") + in_llm_service = bool(LLMService.query(llm_name=llm_name, fid=llm_factory, model_type="embedding")) - # Tongyi-Qianwen is added to TenantLLM by default, but remains unusable with empty api_key tenant_llms = TenantLLMService.get_my_llms(tenant_id=tenant_id) is_tenant_model = any(llm["llm_name"] == llm_name and llm["llm_factory"] == llm_factory and llm["model_type"] == "embedding" for llm in tenant_llms) is_builtin_model = embd_id in settings.BUILTIN_EMBEDDING_MODELS + if not (is_builtin_model or is_tenant_model or in_llm_service): + return False, get_error_argument_result(f"Unsupported model: <{embd_id}>") + if not (is_builtin_model or is_tenant_model): return False, get_error_argument_result(f"Unauthorized model: <{embd_id}>") except OperationalError as e: diff --git a/api/utils/file_utils.py b/api/utils/file_utils.py index b90527c704f..7fefc54a651 100644 --- a/api/utils/file_utils.py +++ b/api/utils/file_utils.py @@ -158,7 +158,7 @@ def filename_type(filename): if re.match(r".*\.(eml|doc|docx|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|html|sql)$", filename): return FileType.DOC.value - if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename): + if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus)$", filename): return FileType.AURAL.value if re.match(r".*\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico|mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$", filename): diff --git a/api/utils/log_utils.py b/api/utils/log_utils.py index 7ec1cafef25..3ebedd14821 100644 --- a/api/utils/log_utils.py +++ b/api/utils/log_utils.py @@ -30,7 +30,7 @@ def get_project_base_directory(): ) return PROJECT_BASE -def initRootLogger(logfile_basename: str, log_format: str = "%(asctime)-15s %(levelname)-8s %(process)d %(message)s"): +def init_root_logger(logfile_basename: str, log_format: str = "%(asctime)-15s %(levelname)-8s %(process)d %(message)s"): global initialized_root_logger if initialized_root_logger: return @@ -77,4 +77,11 @@ def initRootLogger(logfile_basename: str, log_format: str = "%(asctime)-15s %(le pkg_logger.setLevel(pkg_level) msg = f"{logfile_basename} log path: {log_path}, log levels: {pkg_levels}" - logger.info(msg) \ No newline at end of file + logger.info(msg) + + +def log_exception(e, *args): + logging.exception(e) + for a in args: + logging.error(str(a)) + raise e \ No newline at end of file diff --git a/api/utils/t_crypt.py b/api/utils/t_crypt.py index cd9d1edcc9f..d0763c19f45 100644 --- a/api/utils/t_crypt.py +++ b/api/utils/t_crypt.py @@ -35,6 +35,6 @@ def crypt(line): if __name__ == "__main__": - pswd = crypt(sys.argv[1]) - print(pswd) - print(decrypt(pswd)) + passwd = crypt(sys.argv[1]) + print(passwd) + print(decrypt(passwd)) diff --git a/api/utils/validation_utils.py b/api/utils/validation_utils.py index 206a91f12d4..d60dc556102 100644 --- a/api/utils/validation_utils.py +++ 
b/api/utils/validation_utils.py @@ -312,7 +312,7 @@ class PermissionEnum(StrEnum): team = auto() -class ChunkMethodnEnum(StrEnum): +class ChunkMethodEnum(StrEnum): naive = auto() book = auto() email = auto() @@ -382,8 +382,7 @@ class CreateDatasetReq(Base): description: str | None = Field(default=None, max_length=65535) embedding_model: Annotated[str, StringConstraints(strip_whitespace=True, max_length=255), Field(default="", serialization_alias="embd_id")] permission: PermissionEnum = Field(default=PermissionEnum.me, min_length=1, max_length=16) - chunk_method: ChunkMethodnEnum = Field(default=ChunkMethodnEnum.naive, min_length=1, max_length=32, serialization_alias="parser_id") - pagerank: int = Field(default=0, ge=0, le=100) + chunk_method: ChunkMethodEnum = Field(default=ChunkMethodEnum.naive, min_length=1, max_length=32, serialization_alias="parser_id") parser_config: ParserConfig | None = Field(default=None) @field_validator("avatar") @@ -539,6 +538,7 @@ def validate_parser_config_json_length(cls, v: ParserConfig | None) -> ParserCon class UpdateDatasetReq(CreateDatasetReq): dataset_id: str = Field(...) name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(default="")] + pagerank: int = Field(default=0, ge=0, le=100) @field_validator("dataset_id", mode="before") @classmethod diff --git a/conf/llm_factories.json b/conf/llm_factories.json index 08e5268aa2b..5e02696e669 100644 --- a/conf/llm_factories.json +++ b/conf/llm_factories.json @@ -360,6 +360,12 @@ "max_tokens": 8192, "model_type": "embedding" }, + { + "llm_name": "text-embedding-v4", + "tags": "TEXT EMBEDDING,8K", + "max_tokens": 8192, + "model_type": "embedding" + }, { "llm_name": "qwen-vl-max", "tags": "LLM,CHAT,IMAGE2TEXT", @@ -3195,6 +3201,12 @@ "tags": "TEXT EMBEDDING, TEXT RE-RANK", "status": "1", "llm": [ + { + "llm_name": "voyage-multimodal-3", + "tags": "TEXT EMBEDDING,Chat,IMAGE2TEXT,32000", + "max_tokens": 32000, + "model_type": "embedding" + }, { "llm_name": "voyage-large-2-instruct", "tags": "TEXT EMBEDDING,16000", diff --git a/conf/service_conf.yaml b/conf/service_conf.yaml index b57f582067d..4c0635770c9 100644 --- a/conf/service_conf.yaml +++ b/conf/service_conf.yaml @@ -9,6 +9,7 @@ mysql: port: 5455 max_connections: 900 stale_timeout: 300 + max_allowed_packet: 1073741824 minio: user: 'rag_flow' password: 'infini_rag_flow' @@ -28,7 +29,6 @@ redis: db: 1 password: 'infini_rag_flow' host: 'localhost:6379' - # postgres: # name: 'rag_flow' # user: 'rag_flow' @@ -58,6 +58,11 @@ redis: # secret: 'secret' # tenant_id: 'tenant_id' # container_name: 'container_name' +# The OSS object storage uses the MySQL configuration above by default. If you need to switch to another object storage service, please uncomment and configure the following parameters. +# opendal: +# scheme: 'mysql' # Storage type, such as s3, oss, azure, etc. 
+# config: +# oss_table: 'your_table_name' # user_default_llm: # factory: 'Tongyi-Qianwen' # api_key: 'sk-xxxxxxxxxxxxx' diff --git a/deepdoc/parser/docx_parser.py b/deepdoc/parser/docx_parser.py index dfe3f37fd16..f3711961564 100644 --- a/deepdoc/parser/docx_parser.py +++ b/deepdoc/parser/docx_parser.py @@ -69,7 +69,7 @@ def blockType(b): max_type = max(max_type.items(), key=lambda x: x[1])[0] colnm = len(df.iloc[0, :]) - hdrows = [0] # header is not nessesarily appear in the first line + hdrows = [0] # header is not necessarily appear in the first line if max_type == "Nu": for r in range(1, len(df)): tys = Counter([blockType(str(df.iloc[r, j])) diff --git a/deepdoc/parser/figure_parser.py b/deepdoc/parser/figure_parser.py index 49263630d20..b29a4a8a527 100644 --- a/deepdoc/parser/figure_parser.py +++ b/deepdoc/parser/figure_parser.py @@ -21,7 +21,7 @@ from rag.prompts import vision_llm_figure_describe_prompt -def vision_figure_parser_figure_data_wraper(figures_data_without_positions): +def vision_figure_parser_figure_data_wrapper(figures_data_without_positions): return [ ( (figure_data[1], [figure_data[0]]), diff --git a/deepdoc/parser/pdf_parser.py b/deepdoc/parser/pdf_parser.py index 492c4dc5455..68c8946978b 100644 --- a/deepdoc/parser/pdf_parser.py +++ b/deepdoc/parser/pdf_parser.py @@ -61,7 +61,7 @@ def __init__(self, **kwargs): self.ocr = OCR() self.parallel_limiter = None - if PARALLEL_DEVICES is not None and PARALLEL_DEVICES > 1: + if PARALLEL_DEVICES > 1: self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(PARALLEL_DEVICES)] if hasattr(self, "model_speciess"): @@ -180,13 +180,13 @@ def _updown_concat_features(self, up, down): return fea @staticmethod - def sort_X_by_page(arr, threashold): + def sort_X_by_page(arr, threshold): # sort using y1 first and then x1 arr = sorted(arr, key=lambda r: (r["page_number"], r["x0"], r["top"])) for i in range(len(arr) - 1): for j in range(i, -1, -1): # restore the order using th - if abs(arr[j + 1]["x0"] - arr[j]["x0"]) < threashold \ + if abs(arr[j + 1]["x0"] - arr[j]["x0"]) < threshold \ and arr[j + 1]["top"] < arr[j]["top"] \ and arr[j + 1]["page_number"] == arr[j]["page_number"]: tmp = arr[j] @@ -264,13 +264,13 @@ def gather(kwd, fzy=10, ption=0.6): for b in self.boxes: if b.get("layout_type", "") != "table": continue - ii = Recognizer.find_overlapped_with_threashold(b, rows, thr=0.3) + ii = Recognizer.find_overlapped_with_threshold(b, rows, thr=0.3) if ii is not None: b["R"] = ii b["R_top"] = rows[ii]["top"] b["R_bott"] = rows[ii]["bottom"] - ii = Recognizer.find_overlapped_with_threashold( + ii = Recognizer.find_overlapped_with_threshold( b, headers, thr=0.3) if ii is not None: b["H_top"] = headers[ii]["top"] @@ -285,7 +285,7 @@ def gather(kwd, fzy=10, ption=0.6): b["C_left"] = clmns[ii]["x0"] b["C_right"] = clmns[ii]["x1"] - ii = Recognizer.find_overlapped_with_threashold(b, spans, thr=0.3) + ii = Recognizer.find_overlapped_with_threshold(b, spans, thr=0.3) if ii is not None: b["H_top"] = spans[ii]["top"] b["H_bott"] = spans[ii]["bottom"] diff --git a/deepdoc/parser/ppt_parser.py b/deepdoc/parser/ppt_parser.py index 83c27530908..58a983266a8 100644 --- a/deepdoc/parser/ppt_parser.py +++ b/deepdoc/parser/ppt_parser.py @@ -63,7 +63,7 @@ def __extract(self, shape): if shape_type == 6: texts = [] for p in sorted(shape.shapes, key=lambda x: (x.top // 10, x.left)): - t = self.__extract_texts(p) + t = self.__extract(p) if t: texts.append(t) return "\n".join(texts) diff --git a/deepdoc/parser/resume/entities/corporations.py 
b/deepdoc/parser/resume/entities/corporations.py index 43793668d8b..0396281deed 100644 --- a/deepdoc/parser/resume/entities/corporations.py +++ b/deepdoc/parser/resume/entities/corporations.py @@ -53,14 +53,14 @@ def corpNorm(nm, add_region=True): nm = re.sub(r"&", "&", nm) nm = re.sub(r"[\(\)()\+'\"\t \*\\【】-]+", " ", nm) nm = re.sub( - r"([—-]+.*| +co\..*|corp\..*| +inc\..*| +ltd.*)", "", nm, 10000, re.IGNORECASE + r"([—-]+.*| +co\..*|corp\..*| +inc\..*| +ltd.*)", "", nm, count=10000, flags=re.IGNORECASE ) nm = re.sub( r"(计算机|技术|(技术|科技|网络)*有限公司|公司|有限|研发中心|中国|总部)$", "", nm, - 10000, - re.IGNORECASE, + count=10000, + flags=re.IGNORECASE, ) if not nm or (len(nm) < 5 and not regions.isName(nm[0:2])): return nm diff --git a/deepdoc/parser/resume/step_two.py b/deepdoc/parser/resume/step_two.py index 6097a0132dc..0aa3ad38359 100644 --- a/deepdoc/parser/resume/step_two.py +++ b/deepdoc/parser/resume/step_two.py @@ -51,7 +51,7 @@ def signal_handler(signum, frame): def rmHtmlTag(line): - return re.sub(r"<[a-z0-9.\"=';,:\+_/ -]+>", " ", line, 100000, re.IGNORECASE) + return re.sub(r"<[a-z0-9.\"=';,:\+_/ -]+>", " ", line, count=100000, flags=re.IGNORECASE) def highest_degree(dg): @@ -507,7 +507,7 @@ def hasValues(flds): (r".*国有.*", "国企"), (r"[ ()\(\)人/·0-9-]+", ""), (r".*(元|规模|于|=|北京|上海|至今|中国|工资|州|shanghai|强|餐饮|融资|职).*", "")]: - cv["corporation_type"] = re.sub(p, r, cv["corporation_type"], 1000, re.IGNORECASE) + cv["corporation_type"] = re.sub(p, r, cv["corporation_type"], count=1000, flags=re.IGNORECASE) if len(cv["corporation_type"]) < 2: del cv["corporation_type"] diff --git a/deepdoc/vision/layout_recognizer.py b/deepdoc/vision/layout_recognizer.py index 5126171900a..46be451c634 100644 --- a/deepdoc/vision/layout_recognizer.py +++ b/deepdoc/vision/layout_recognizer.py @@ -106,7 +106,7 @@ def findLayout(ty): bxs.pop(i) continue - ii = self.find_overlapped_with_threashold(bxs[i], lts_, + ii = self.find_overlapped_with_threshold(bxs[i], lts_, thr=0.4) if ii is None: # belong to nothing bxs[i]["layout_type"] = "" diff --git a/deepdoc/vision/ocr.py b/deepdoc/vision/ocr.py index 4dedb7c67a3..e9e594274c7 100644 --- a/deepdoc/vision/ocr.py +++ b/deepdoc/vision/ocr.py @@ -529,31 +529,30 @@ def __init__(self, model_dir=None): "rag/res/deepdoc") # Append muti-gpus task to the list - if PARALLEL_DEVICES is not None and PARALLEL_DEVICES > 0: + if PARALLEL_DEVICES > 0: self.text_detector = [] self.text_recognizer = [] for device_id in range(PARALLEL_DEVICES): self.text_detector.append(TextDetector(model_dir, device_id)) self.text_recognizer.append(TextRecognizer(model_dir, device_id)) else: - self.text_detector = [TextDetector(model_dir, 0)] - self.text_recognizer = [TextRecognizer(model_dir, 0)] + self.text_detector = [TextDetector(model_dir)] + self.text_recognizer = [TextRecognizer(model_dir)] except Exception: model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc", local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"), local_dir_use_symlinks=False) - if PARALLEL_DEVICES is not None: - assert PARALLEL_DEVICES > 0, "Number of devices must be >= 1" + if PARALLEL_DEVICES > 0: self.text_detector = [] self.text_recognizer = [] for device_id in range(PARALLEL_DEVICES): self.text_detector.append(TextDetector(model_dir, device_id)) self.text_recognizer.append(TextRecognizer(model_dir, device_id)) else: - self.text_detector = [TextDetector(model_dir, 0)] - self.text_recognizer = [TextRecognizer(model_dir, 0)] + self.text_detector = [TextDetector(model_dir)] + self.text_recognizer = 
[TextRecognizer(model_dir)] self.drop_score = 0.5 self.crop_image_res_index = 0 @@ -589,7 +588,29 @@ def get_rotate_crop_image(self, img, points): flags=cv2.INTER_CUBIC) dst_img_height, dst_img_width = dst_img.shape[0:2] if dst_img_height * 1.0 / dst_img_width >= 1.5: - dst_img = np.rot90(dst_img) + # Try original orientation + rec_result = self.text_recognizer[0]([dst_img]) + text, score = rec_result[0][0] + best_score = score + best_img = dst_img + + # Try clockwise 90° rotation + rotated_cw = np.rot90(dst_img, k=3) + rec_result = self.text_recognizer[0]([rotated_cw]) + rotated_cw_text, rotated_cw_score = rec_result[0][0] + if rotated_cw_score > best_score: + best_score = rotated_cw_score + best_img = rotated_cw + + # Try counter-clockwise 90° rotation + rotated_ccw = np.rot90(dst_img, k=1) + rec_result = self.text_recognizer[0]([rotated_ccw]) + rotated_ccw_text, rotated_ccw_score = rec_result[0][0] + if rotated_ccw_score > best_score: + best_img = rotated_ccw + + # Use the best image + dst_img = best_img return dst_img def sorted_boxes(self, dt_boxes): diff --git a/deepdoc/vision/recognizer.py b/deepdoc/vision/recognizer.py index 6911d86980b..9fa82d7f5d0 100644 --- a/deepdoc/vision/recognizer.py +++ b/deepdoc/vision/recognizer.py @@ -52,20 +52,20 @@ def __init__(self, label_list, task_name, model_dir=None): self.label_list = label_list @staticmethod - def sort_Y_firstly(arr, threashold): + def sort_Y_firstly(arr, threshold): def cmp(c1, c2): diff = c1["top"] - c2["top"] - if abs(diff) < threashold: + if abs(diff) < threshold: diff = c1["x0"] - c2["x0"] return diff arr = sorted(arr, key=cmp_to_key(cmp)) return arr @staticmethod - def sort_X_firstly(arr, threashold): + def sort_X_firstly(arr, threshold): def cmp(c1, c2): diff = c1["x0"] - c2["x0"] - if abs(diff) < threashold: + if abs(diff) < threshold: diff = c1["top"] - c2["top"] return diff arr = sorted(arr, key=cmp_to_key(cmp)) @@ -133,7 +133,7 @@ def overlapped_area(a, b, ratio=True): @staticmethod def layouts_cleanup(boxes, layouts, far=2, thr=0.7): - def notOverlapped(a, b): + def not_overlapped(a, b): return any([a["x1"] < b["x0"], a["x0"] > b["x1"], a["bottom"] < b["top"], @@ -144,7 +144,7 @@ def notOverlapped(a, b): j = i + 1 while j < min(i + far, len(layouts)) \ and (layouts[i].get("type", "") != layouts[j].get("type", "") - or notOverlapped(layouts[i], layouts[j])): + or not_overlapped(layouts[i], layouts[j])): j += 1 if j >= min(i + far, len(layouts)): i += 1 @@ -163,9 +163,9 @@ def notOverlapped(a, b): area_i, area_i_1 = 0, 0 for b in boxes: - if not notOverlapped(b, layouts[i]): + if not not_overlapped(b, layouts[i]): area_i += Recognizer.overlapped_area(b, layouts[i], False) - if not notOverlapped(b, layouts[j]): + if not not_overlapped(b, layouts[j]): area_i_1 += Recognizer.overlapped_area(b, layouts[j], False) if area_i > area_i_1: @@ -239,15 +239,15 @@ def find_overlapped(box, boxes_sorted_by_y, naive=False): e -= 1 break - max_overlaped_i, max_overlaped = None, 0 + max_overlapped_i, max_overlapped = None, 0 for i in range(s, e): ov = Recognizer.overlapped_area(bxs[i], box) - if ov <= max_overlaped: + if ov <= max_overlapped: continue - max_overlaped_i = i - max_overlaped = ov + max_overlapped_i = i + max_overlapped = ov - return max_overlaped_i + return max_overlapped_i @staticmethod def find_horizontally_tightest_fit(box, boxes): @@ -264,7 +264,7 @@ def find_horizontally_tightest_fit(box, boxes): return min_i @staticmethod - def find_overlapped_with_threashold(box, boxes, thr=0.3): + def 
find_overlapped_with_threshold(box, boxes, thr=0.3): if not boxes: return max_overlapped_i, max_overlapped, _max_overlapped = None, thr, 0 @@ -408,18 +408,18 @@ def iou_filter(boxes, scores, iou_threshold): def __call__(self, image_list, thr=0.7, batch_size=16): res = [] - imgs = [] + images = [] for i in range(len(image_list)): if not isinstance(image_list[i], np.ndarray): - imgs.append(np.array(image_list[i])) + images.append(np.array(image_list[i])) else: - imgs.append(image_list[i]) + images.append(image_list[i]) - batch_loop_cnt = math.ceil(float(len(imgs)) / batch_size) + batch_loop_cnt = math.ceil(float(len(images)) / batch_size) for i in range(batch_loop_cnt): start_index = i * batch_size - end_index = min((i + 1) * batch_size, len(imgs)) - batch_image_list = imgs[start_index:end_index] + end_index = min((i + 1) * batch_size, len(images)) + batch_image_list = images[start_index:end_index] inputs = self.preprocess(batch_image_list) logging.debug("preprocess") for ins in inputs: diff --git a/deepdoc/vision/t_recognizer.py b/deepdoc/vision/t_recognizer.py index 1db3356a9d6..264014c8602 100644 --- a/deepdoc/vision/t_recognizer.py +++ b/deepdoc/vision/t_recognizer.py @@ -84,13 +84,13 @@ def gather(kwd, fzy=10, ption=0.6): clmns = LayoutRecognizer.layouts_cleanup(boxes, clmns, 5, 0.5) for b in boxes: - ii = LayoutRecognizer.find_overlapped_with_threashold(b, rows, thr=0.3) + ii = LayoutRecognizer.find_overlapped_with_threshold(b, rows, thr=0.3) if ii is not None: b["R"] = ii b["R_top"] = rows[ii]["top"] b["R_bott"] = rows[ii]["bottom"] - ii = LayoutRecognizer.find_overlapped_with_threashold(b, headers, thr=0.3) + ii = LayoutRecognizer.find_overlapped_with_threshold(b, headers, thr=0.3) if ii is not None: b["H_top"] = headers[ii]["top"] b["H_bott"] = headers[ii]["bottom"] @@ -104,7 +104,7 @@ def gather(kwd, fzy=10, ption=0.6): b["C_left"] = clmns[ii]["x0"] b["C_right"] = clmns[ii]["x1"] - ii = LayoutRecognizer.find_overlapped_with_threashold(b, spans, thr=0.3) + ii = LayoutRecognizer.find_overlapped_with_threshold(b, spans, thr=0.3) if ii is not None: b["H_top"] = spans[ii]["top"] b["H_bott"] = spans[ii]["bottom"] diff --git a/docker/.env b/docker/.env index 675e1704dc0..75bda17c004 100644 --- a/docker/.env +++ b/docker/.env @@ -129,6 +129,14 @@ TIMEZONE='Asia/Shanghai' # Note that neither `MAX_CONTENT_LENGTH` nor `client_max_body_size` sets the maximum size for files uploaded to an agent. # See https://ragflow.io/docs/dev/begin_component for details. +# Controls how many documents are processed in a single batch. +# Defaults to 4 if DOC_BULK_SIZE is not explicitly set. +DOC_BULK_SIZE=${DOC_BULK_SIZE:-4} + +# Defines the number of items to process per batch when generating embeddings. +# Defaults to 16 if EMBEDDING_BATCH_SIZE is not set in the environment. +EMBEDDING_BATCH_SIZE=${EMBEDDING_BATCH_SIZE:-16} + # Log level for the RAGFlow's own and imported packages. # Available levels: # - `DEBUG` diff --git a/docker/README.md b/docker/README.md index fd49164d3a8..c63d7385102 100644 --- a/docker/README.md +++ b/docker/README.md @@ -78,8 +78,8 @@ The [.env](./.env) file contains important environment variables for Docker. - `RAGFLOW-IMAGE` The Docker image edition. Available editions: - - `infiniflow/ragflow:v0.19.0-slim` (default): The RAGFlow Docker image without embedding models. - - `infiniflow/ragflow:v0.19.0`: The RAGFlow Docker image with embedding models including: + - `infiniflow/ragflow:v0.19.1-slim` (default): The RAGFlow Docker image without embedding models. 
+ - `infiniflow/ragflow:v0.19.1`: The RAGFlow Docker image with embedding models including: - Built-in embedding models: - `BAAI/bge-large-zh-v1.5` - `maidalun1020/bce-embedding-base_v1` @@ -115,6 +115,16 @@ The [.env](./.env) file contains important environment variables for Docker. - `MAX_CONTENT_LENGTH` The maximum file size for each uploaded file, in bytes. You can uncomment this line if you wish to change the 128M file size limit. After making the change, ensure you update `client_max_body_size` in nginx/nginx.conf correspondingly. +### Doc bulk size + +- `DOC_BULK_SIZE` + The number of document chunks processed in a single batch during document parsing. Defaults to `4`. + +### Embedding batch size + +- `EMBEDDING_BATCH_SIZE` + The number of text chunks processed in a single batch during embedding vectorization. Defaults to `16`. + ## 🐋 Service configuration [service_conf.yaml](./service_conf.yaml) specifies the system-level configuration for RAGFlow and is used by its API server and task executor. In a dockerized setup, this file is automatically created based on the [service_conf.yaml.template](./service_conf.yaml.template) file (replacing all environment variables by their values). diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 7378ab4be0b..bbf2111ed16 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -1,3 +1,4 @@ + include: - ./docker-compose-base.yml diff --git a/docker/service_conf.yaml.template b/docker/service_conf.yaml.template index 06789e0ac91..616ccb524b0 100644 --- a/docker/service_conf.yaml.template +++ b/docker/service_conf.yaml.template @@ -20,7 +20,7 @@ es: os: hosts: 'http://${OS_HOST:-opensearch01}:9201' username: '${OS_USER:-admin}' - password: '${OPENSEARCHH_PASSWORD:-infini_rag_flow_OS_01}' + password: '${OPENSEARCH_PASSWORD:-infini_rag_flow_OS_01}' infinity: uri: '${INFINITY_HOST:-infinity}:23817' db_name: 'default_db' diff --git a/docs/configurations.md b/docs/configurations.md index 9c106ff9d46..4f1535d32dd 100644 --- a/docs/configurations.md +++ b/docs/configurations.md @@ -99,8 +99,8 @@ RAGFlow utilizes MinIO as its object storage solution, leveraging its scalabilit - `RAGFLOW-IMAGE` The Docker image edition. Available editions: - - `infiniflow/ragflow:v0.19.0-slim` (default): The RAGFlow Docker image without embedding models. - - `infiniflow/ragflow:v0.19.0`: The RAGFlow Docker image with embedding models including: + - `infiniflow/ragflow:v0.19.1-slim` (default): The RAGFlow Docker image without embedding models. + - `infiniflow/ragflow:v0.19.1`: The RAGFlow Docker image with embedding models including: - Built-in embedding models: - `BAAI/bge-large-zh-v1.5` - `maidalun1020/bce-embedding-base_v1` diff --git a/docs/develop/build_docker_image.mdx b/docs/develop/build_docker_image.mdx index 7849611cc18..dd4b82a128c 100644 --- a/docs/develop/build_docker_image.mdx +++ b/docs/develop/build_docker_image.mdx @@ -77,7 +77,7 @@ After building the infiniflow/ragflow:nightly-slim image, you are ready to launc 1. Edit Docker Compose Configuration -Open the `docker/.env` file. Find the `RAGFLOW_IMAGE` setting and change the image reference from `infiniflow/ragflow:v0.19.0-slim` to `infiniflow/ragflow:nightly-slim` to use the pre-built image. +Open the `docker/.env` file. Find the `RAGFLOW_IMAGE` setting and change the image reference from `infiniflow/ragflow:v0.19.1-slim` to `infiniflow/ragflow:nightly-slim` to use the pre-built image. 2. 
Launch the Service diff --git a/docs/develop/mcp/launch_mcp_server.md b/docs/develop/mcp/launch_mcp_server.md index ea7bc928b84..a98939efb4e 100644 --- a/docs/develop/mcp/launch_mcp_server.md +++ b/docs/develop/mcp/launch_mcp_server.md @@ -23,7 +23,7 @@ Once a connection is established, an MCP server communicates with its client in ## Prerequisites 1. Ensure RAGFlow is upgraded to v0.18.0 or later. -2. Have your RAGFlow API key ready. See [Acquire a RAGFlow API key](./acquire_ragflow_api_key.md). +2. Have your RAGFlow API key ready. See [Acquire a RAGFlow API key](../acquire_ragflow_api_key.md). :::tip INFO If you wish to try out our MCP server without upgrading RAGFlow, community contributor [yiminghub2024](https://github.com/yiminghub2024) 👏 shares their recommended steps [here](#launch-an-mcp-server-without-upgrading-ragflow). @@ -42,10 +42,10 @@ You can start an MCP server either from source code or via Docker. ```bash # Launch the MCP server to work in self-host mode, run either of the following uv run mcp/server/server.py --host=127.0.0.1 --port=9382 --base_url=http://127.0.0.1:9380 --api_key=ragflow-xxxxx -# uv run mcp/server/server.py --host=127.0.0.1 --port=9382 --base_url=http://127.0.0.1:9380 mode=self-host --api_key=ragflow-xxxxx +# uv run mcp/server/server.py --host=127.0.0.1 --port=9382 --base_url=http://127.0.0.1:9380 --mode=self-host --api_key=ragflow-xxxxx # To launch the MCP server to work in host mode, run the following instead: -# uv run mcp/server/server.py --host=127.0.0.1 --port=9382 --base_url=http://127.0.0.1:9380 mode=host +# uv run mcp/server/server.py --host=127.0.0.1 --port=9382 --base_url=http://127.0.0.1:9380 --mode=host ``` Where: diff --git a/docs/develop/mcp/mcp_client_example.md b/docs/develop/mcp/mcp_client_example.md index 40b38b03957..ee9c5c2cfde 100644 --- a/docs/develop/mcp/mcp_client_example.md +++ b/docs/develop/mcp/mcp_client_example.md @@ -1,16 +1,240 @@ --- sidebar_position: 3 slug: /mcp_client + --- -# RAGFlow MCP client example +# RAGFlow MCP client examples + +Python and curl MCP client examples. + +------ + +## Example MCP Python client We provide a *prototype* MCP client example for testing [here](https://github.com/infiniflow/ragflow/blob/main/mcp/client/client.py). :::danger IMPORTANT -If your MCP server is running in host mode, include your acquired API key in your client's `headers` as shown below: +If your MCP server is running in host mode, include your acquired API key in your client's `headers` when connecting asynchronously to it: + ```python async with sse_client("http://localhost:9382/sse", headers={"api_key": "YOUR_KEY_HERE"}) as streams: # Rest of your code... ``` -::: \ No newline at end of file + +Alternatively, to comply with [OAuth 2.1 Section 5](https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-12#section-5), you can run the following code *instead* to connect to your MCP server: + +```python +async with sse_client("http://localhost:9382/sse", headers={"Authorization": "YOUR_KEY_HERE"}) as streams: + # Rest of your code... +``` +::: + +## Use curl to interact with the RAGFlow MCP server + +When interacting with the MCP server via HTTP requests, follow this initialization sequence: + +1. **The client sends an `initialize` request** with protocol version and capabilities. +2. **The server replies with an `initialize` response**, including the supported protocol and capabilities. +3. **The client confirms readiness with an `initialized` notification**. 
+ _The connection is established between the client and the server, and further operations (such as tool listing) may proceed._ + +:::tip NOTE +For more information about this initialization process, see [here](https://modelcontextprotocol.io/docs/concepts/architecture#1-initialization). +::: + +In the following sections, we will walk you through a complete tool calling process. + +### 1. Obtain a session ID + +Each curl request with the MCP server must include a session ID: + +```bash +$ curl -N -H "api_key: YOUR_API_KEY" http://127.0.0.1:9382/sse +``` + +:::tip NOTE +See [here](../acquire_ragflow_api_key.md) for information about acquiring an API key. +::: + +#### Transport + +The transport will stream messages such as tool results, server responses, and keep-alive pings. + +_The server returns the session ID:_ + +```bash +event: endpoint +data: /messages/?session_id=5c6600ef61b845a788ddf30dceb25c54 +``` + +### 2. Send an `Initialize` request + +The client sends an `initialize` request with protocol version and capabilities: + +```bash +session_id="5c6600ef61b845a788ddf30dceb25c54" && \ + +curl -X POST "http://127.0.0.1:9382/messages/?session_id=$session_id" \ + -H "api_key: YOUR_API_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "1.0", + "capabilities": {}, + "clientInfo": { + "name": "ragflow-mcp-client", + "version": "0.1" + } + } + }' && \ +``` + +#### Transport + +_The server replies with an `initialize` response, including the supported protocol and capabilities:_ + +```bash +event: message +data: {"jsonrpc":"2.0","id":1,"result":{"protocolVersion":"2025-03-26","capabilities":{"experimental":{"headers":{"host":"127.0.0.1:9382","user-agent":"curl/8.7.1","accept":"*/*","api_key":"ragflow-xxxxxxxxxxxx","accept-encoding":"gzip"}},"tools":{"listChanged":false}},"serverInfo":{"name":"ragflow-server","version":"1.9.4"}}} +``` + +### 3. Acknowledge readiness + +The client confirms readiness with an `initialized` notification: + +```bash +curl -X POST "http://127.0.0.1:9382/messages/?session_id=$session_id" \ + -H "api_key: YOUR_API_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "jsonrpc": "2.0", + "method": "notifications/initialized", + "params": {} + }' && \ +``` + + _The connection is established between the client and the server, and further operations (such as tool listing) may proceed._ + +### 4. Tool listing + +```bash +curl -X POST "http://127.0.0.1:9382/messages/?session_id=$session_id" \ + -H "api_key: YOUR_API_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "jsonrpc": "2.0", + "id": 3, + "method": "tools/list", + "params": {} + }' && \ +``` + +#### Transport + +```bash +event: message +data: {"jsonrpc":"2.0","id":3,"result":{"tools":[{"name":"ragflow_retrieval","description":"Retrieve relevant chunks from the RAGFlow retrieve interface based on the question, using the specified dataset_ids and optionally document_ids. Below is the list of all available datasets, including their descriptions and IDs. If you're unsure which datasets are relevant to the question, simply pass all dataset IDs to the function.","inputSchema":{"type":"object","properties":{"dataset_ids":{"type":"array","items":{"type":"string"}},"document_ids":{"type":"array","items":{"type":"string"}},"question":{"type":"string"}},"required":["dataset_ids","question"]}}]}} + +``` + +### 5. 
Tool calling + +```bash +curl -X POST "http://127.0.0.1:9382/messages/?session_id=$session_id" \ + -H "api_key: YOUR_API_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "jsonrpc": "2.0", + "id": 4, + "method": "tools/call", + "params": { + "name": "ragflow_retrieval", + "arguments": { + "question": "How to install neovim?", + "dataset_ids": ["DATASET_ID_HERE"], + "document_ids": [] + } + +``` + +#### Transport + +```bash +event: message +data: {"jsonrpc":"2.0","id":4,"result":{...}} + +``` + +### A complete curl example + +```bash +session_id="YOUR_SESSION_ID" && \ + +# Step 1: Initialize request +curl -X POST "http://127.0.0.1:9382/messages/?session_id=$session_id" \ + -H "api_key: YOUR_API_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "1.0", + "capabilities": {}, + "clientInfo": { + "name": "ragflow-mcp-client", + "version": "0.1" + } + } + }' && \ + +sleep 2 && \ + +# Step 2: Initialized notification +curl -X POST "http://127.0.0.1:9382/messages/?session_id=$session_id" \ + -H "api_key: YOUR_API_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "jsonrpc": "2.0", + "method": "notifications/initialized", + "params": {} + }' && \ + +sleep 2 && \ + +# Step 3: Tool listing +curl -X POST "http://127.0.0.1:9382/messages/?session_id=$session_id" \ + -H "api_key: YOUR_API_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "jsonrpc": "2.0", + "id": 3, + "method": "tools/list", + "params": {} + }' && \ + +sleep 2 && \ + +# Step 4: Tool call +curl -X POST "http://127.0.0.1:9382/messages/?session_id=$session_id" \ + -H "api_key: YOUR_API_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "jsonrpc": "2.0", + "id": 4, + "method": "tools/call", + "params": { + "name": "ragflow_retrieval", + "arguments": { + "question": "How to install neovim?", + "dataset_ids": ["DATASET_ID_HERE"], + "document_ids": [] + } + } + }' + +``` diff --git a/docs/develop/switch_doc_engine.md b/docs/develop/switch_doc_engine.md index 07a3c2e4c1f..ebac20bd686 100644 --- a/docs/develop/switch_doc_engine.md +++ b/docs/develop/switch_doc_engine.md @@ -11,7 +11,7 @@ Switch your doc engine from Elasticsearch to Infinity. RAGFlow uses Elasticsearch by default for storing full text and vectors. To switch to [Infinity](https://github.com/infiniflow/infinity/), follow these steps: -:::danger WARNING +:::caution WARNING Switching to Infinity on a Linux/arm64 machine is not yet officially supported. ::: @@ -21,7 +21,7 @@ Switching to Infinity on a Linux/arm64 machine is not yet officially supported. $ docker compose -f docker/docker-compose.yml down -v ``` -:::cautiion WARNING +:::caution WARNING `-v` will delete the docker container volumes, and the existing data will be cleared. ::: diff --git a/docs/faq.mdx b/docs/faq.mdx index 297b478af0a..8c939889e67 100644 --- a/docs/faq.mdx +++ b/docs/faq.mdx @@ -19,7 +19,7 @@ import TOCInline from '@theme/TOCInline'; ### What sets RAGFlow apart from other RAG products? -The "garbage in garbage out" status quo remains unchanged despite the fact that LLMs have advanced Natural Language Processing (NLP) significantly. In response, RAGFlow introduces two unique features compared to other Retrieval-Augmented Generation (RAG) products. +The "garbage in garbage out" status quo remains unchanged despite the fact that LLMs have advanced Natural Language Processing (NLP) significantly. 
In its response, RAGFlow introduces two unique features compared to other Retrieval-Augmented Generation (RAG) products. - Fine-grained document parsing: Document parsing involves images and tables, with the flexibility for you to intervene as needed. - Traceable answers with reduced hallucinations: You can trust RAGFlow's responses as you can view the citations and references supporting them. @@ -30,17 +30,17 @@ The "garbage in garbage out" status quo remains unchanged despite the fact that Each RAGFlow release is available in two editions: -- **Slim edition**: excludes built-in embedding models and is identified by a **-slim** suffix added to the version name. Example: `infiniflow/ragflow:v0.19.0-slim` -- **Full edition**: includes built-in embedding models and has no suffix added to the version name. Example: `infiniflow/ragflow:v0.19.0` +- **Slim edition**: excludes built-in embedding models and is identified by a **-slim** suffix added to the version name. Example: `infiniflow/ragflow:v0.19.1-slim` +- **Full edition**: includes built-in embedding models and has no suffix added to the version name. Example: `infiniflow/ragflow:v0.19.1` --- ### Which embedding models can be deployed locally? -RAGFlow offers two Docker image editions, `v0.19.0-slim` and `v0.19.0`: +RAGFlow offers two Docker image editions, `v0.19.1-slim` and `v0.19.1`: -- `infiniflow/ragflow:v0.19.0-slim` (default): The RAGFlow Docker image without embedding models. -- `infiniflow/ragflow:v0.19.0`: The RAGFlow Docker image with embedding models including: +- `infiniflow/ragflow:v0.19.1-slim` (default): The RAGFlow Docker image without embedding models. +- `infiniflow/ragflow:v0.19.1`: The RAGFlow Docker image with embedding models including: - Built-in embedding models: - `BAAI/bge-large-zh-v1.5` - `maidalun1020/bce-embedding-base_v1` @@ -127,7 +127,19 @@ The corresponding APIs are now available. See the [RAGFlow HTTP API Reference](. ### Do you support stream output? -Yes, we do. +Yes, we do. Stream output is enabled by default in the chat assistant and agent. Note that you cannot disable stream output via RAGFlow's UI. To disable stream output in responses, use RAGFlow's Python or RESTful APIs: + +Python: + +- [Create chat completion](./references/python_api_reference.md#create-chat-completion) +- [Converse with chat assistant](./references/python_api_reference.md#converse-with-chat-assistant) +- [Converse with agent](./references/python_api_reference.md#converse-with-agent) + +RESTful: + +- [Create chat completion](./references/http_api_reference.md#create-chat-completion) +- [Converse with chat assistant](./references/http_api_reference.md#converse-with-chat-assistant) +- [Converse with agent](./references/http_api_reference.md#converse-with-agent) --- @@ -488,4 +500,10 @@ To switch your document engine from Elasticsearch to [Infinity](https://github.c All uploaded files are stored in Minio, RAGFlow's object storage solution. For instance, if you upload your file directly to a knowledge base, it is located at `/filename`. ---- \ No newline at end of file +--- + +### How to tune batch size for document parsing and embedding? + +You can control the batch size for document parsing and embedding by setting the environment variables `DOC_BULK_SIZE` and `EMBEDDING_BATCH_SIZE`. Increasing these values may improve throughput for large-scale data processing, but will also increase memory usage. Adjust them according to your hardware resources. 
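For illustration, here is a minimal sketch of overriding both knobs before bringing the stack up. The variable names and their defaults (`DOC_BULK_SIZE=4`, `EMBEDDING_BATCH_SIZE=16`) come from `docker/.env`; the larger values below are arbitrary placeholders, not tuned recommendations:

```bash
# docker/.env falls back to 4 and 16 when these are unset,
# so exporting them in the shell overrides the defaults.
export DOC_BULK_SIZE=8          # document chunks parsed per batch
export EMBEDDING_BATCH_SIZE=32  # text chunks embedded per batch

# Restart the stack so the task executor picks up the new values.
docker compose -f docker/docker-compose.yml up -d
```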
+ +--- diff --git a/docs/guides/agent/agent_component_reference/begin.mdx b/docs/guides/agent/agent_component_reference/begin.mdx index 74c7234241b..548a782a1ec 100644 --- a/docs/guides/agent/agent_component_reference/begin.mdx +++ b/docs/guides/agent/agent_component_reference/begin.mdx @@ -21,7 +21,7 @@ Click the component to display its **Configuration** window. Here, you can set a ### ID -The ID is the unique identifier for the component within the workflow. Unlike the IDs of other components, the ID of the **Begin** component *cannot* be changed. +The ID is the unique identifier for the component within the workflow. Unlike the IDs of other components, the ID of the **Begin** component _cannot_ be changed. ### Opening greeting @@ -31,30 +31,36 @@ An opening greeting is the agent's first message to the user. It can be a welcom You can set global variables within the **Begin** component, which can be either required or optional. Once established, users will need to provide values for these variables when interacting or chatting with the agent. Click **+ Add variable** to add a global variable, each with the following attributes: -- **Key**: *Required* +- **Key**: _Required_ The unique variable name. -- **Name**: *Required* +- **Name**: _Required_ A descriptive name providing additional details about the variable. For example, if **Key** is set to `lang`, you can set its **Name** to `Target language`. -- **Type**: *Required* - The type of the variable: +- **Type**: _Required_ + The type of the variable: - **line**: Accepts a single line of text without line breaks. - **paragraph**: Accepts multiple lines of text, including line breaks. - - **options**: Requires the user to select a value for this variable from a dropdown menu. And you are required to set *at least* one option for the dropdown menu. + - **options**: Requires the user to select a value for this variable from a dropdown menu. And you are required to set _at least_ one option for the dropdown menu. - **file**: Requires the user to upload one or multiple files. - **integer**: Accepts an integer as input. - **boolean**: Requires the user to toggle between on and off. -- **Optional**: A toggle indicating whether the variable is optional. +- **Optional**: A toggle indicating whether the variable is optional. :::tip NOTE To pass in parameters from a client, call: + - HTTP method [Converse with agent](../../../references/http_api_reference.md#converse-with-agent), or - Python method [Converse with agent](../../../references/python_api_reference.md#converse-with-agent). -::: + ::: :::danger IMPORTANT + - If you set the key type as **file**, ensure the token count of the uploaded file does not exceed your model provider's maximum token limit; otherwise, the plain text in your file will be truncated and incomplete. -- If your agent's **Begin** component takes a variable, you *cannot* embed it into a webpage. +- If your agent's **Begin** component takes a variable, you _cannot_ embed it into a webpage. + ::: + +:::note +You can tune document parsing and embedding efficiency by setting the environment variables `DOC_BULK_SIZE` and `EMBEDDING_BATCH_SIZE`. ::: ## Examples @@ -71,7 +77,7 @@ As mentioned earlier, the **Begin** component is indispensable for an agent. Sti ### Is the uploaded file in a knowledge base? -No. Files uploaded to an agent as input are not stored in a knowledge base and hence will not be processed using RAGFlow's built-in OCR, DLR or TSR models, or chunked using RAGFlow's built-in chunking methods. +No. 
Files uploaded to an agent as input are not stored in a knowledge base and hence will not be processed using RAGFlow's built-in OCR, DLR or TSR models, or chunked using RAGFlow's built-in chunking methods. ### How to upload a webpage or file from a URL? @@ -81,8 +87,8 @@ If you set the type of a variable as **file**, your users will be able to upload ### File size limit for an uploaded file -There is no *specific* file size limit for a file uploaded to an agent. However, note that model providers typically have a default or explicit maximum token setting, which can range from 8196 to 128k: The plain text part of the uploaded file will be passed in as the key value, but if the file's token count exceeds this limit, the string will be truncated and incomplete. +There is no _specific_ file size limit for a file uploaded to an agent. However, note that model providers typically have a default or explicit maximum token setting, which can range from 8196 to 128k: The plain text part of the uploaded file will be passed in as the key value, but if the file's token count exceeds this limit, the string will be truncated and incomplete. :::tip NOTE The variables `MAX_CONTENT_LENGTH` in `/docker/.env` and `client_max_body_size` in `/docker/nginx/nginx.conf` set the file size limit for each upload to a knowledge base or **File Management**. These settings DO NOT apply in this scenario. -::: \ No newline at end of file +::: diff --git a/docs/guides/agent/agent_component_reference/code.mdx b/docs/guides/agent/agent_component_reference/code.mdx index cb572f6d1d4..8ae9c374447 100644 --- a/docs/guides/agent/agent_component_reference/code.mdx +++ b/docs/guides/agent/agent_component_reference/code.mdx @@ -23,6 +23,8 @@ After defining an input variable, you are required to select from the dropdown m ## Coding field +This field allows you to enter and edit your source code. + ### A Python code example ```Python diff --git a/docs/guides/agent/sandbox_quickstart.md b/docs/guides/agent/sandbox_quickstart.md new file mode 100644 index 00000000000..b0e0ddce138 --- /dev/null +++ b/docs/guides/agent/sandbox_quickstart.md @@ -0,0 +1,116 @@ +--- +sidebar_position: 20 +slug: /sandbox_quickstart +--- + +# Sandbox quickstart + +A secure, pluggable code execution backend designed for RAGFlow and other applications requiring isolated code execution environments. + +## Features: + +- Seamless RAGFlow Integration — Works out-of-the-box with the code component of RAGFlow. +- High Security — Uses gVisor for syscall-level sandboxing to isolate execution. +- Customisable Sandboxing — Modify seccomp profiles easily to tailor syscall restrictions. +- Pluggable Runtime Support — Extendable to support any programming language runtime. +- Developer Friendly — Quick setup with a convenient Makefile. + +## Architecture + +The architecture consists of isolated Docker base images for each supported language runtime, managed by the executor manager service. The executor manager orchestrates sandboxed code execution using gVisor for syscall interception and optional seccomp profiles for enhanced syscall filtering. + +## Prerequisites + +- Linux distribution compatible with gVisor. +- gVisor installed and configured. +- Docker version 24.0.0 or higher. +- Docker Compose version 2.26.1 or higher (similar to RAGFlow requirements). +- uv package and project manager installed. +- (Optional) GNU Make for simplified command-line management. 
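Before building any images, it may help to sanity-check the prerequisites above from a shell. A rough sketch, assuming gVisor's `runsc` runtime has already been registered with Docker under its default name:

```bash
docker --version          # expect 24.0.0 or higher
docker compose version    # expect 2.26.1 or higher
uv --version              # uv package and project manager

# Confirm Docker can actually start a container under the gVisor runtime.
docker run --rm --runtime=runsc hello-world
```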
+ +## Build Docker base images + +The sandbox uses isolated base images for secure containerised execution environments. + +Build the base images manually: + +```bash +docker build -t sandbox-base-python:latest ./sandbox_base_image/python +docker build -t sandbox-base-nodejs:latest ./sandbox_base_image/nodejs +``` + +Alternatively, build all base images at once using the Makefile: + +```bash +make build +``` + +Next, build the executor manager image: + +```bash +docker build -t sandbox-executor-manager:latest ./executor_manager +``` + +## Running with RAGFlow + +1. Verify that gVisor is properly installed and operational. + +2. Configure the .env file located at docker/.env: + +- Uncomment sandbox-related environment variables. +- Enable the sandbox profile at the bottom of the file. + +3. Add the following entry to your /etc/hosts file to resolve the executor manager service: + +```bash +127.0.0.1 sandbox-executor-manager +``` + +4. Start the RAGFlow service as usual. + +## Running standalone + +### Manual setup + +1. Initialize the environment variables: + +```bash +cp .env.example .env +``` + +2. Launch the sandbox services with Docker Compose: + +```bash +docker compose -f docker-compose.yml up +``` + +3. Test the sandbox setup: + +```bash +source .venv/bin/activate +export PYTHONPATH=$(pwd) +uv pip install -r executor_manager/requirements.txt +uv run tests/sandbox_security_tests_full.py +``` + +### Using Makefile + +Run all setup, build, launch, and tests with a single command: + +```bash +make +``` + +### Monitoring + +To follow logs of the executor manager container: + +```bash +docker logs -f sandbox-executor-manager +``` + +Or use the Makefile shortcut: + +```bash +make logs +``` \ No newline at end of file diff --git a/docs/guides/ai_search.md b/docs/guides/ai_search.md index f2db3e41b3d..abf0e47a038 100644 --- a/docs/guides/ai_search.md +++ b/docs/guides/ai_search.md @@ -9,7 +9,7 @@ Conduct an AI search. --- -An AI search is a single-turn AI conversation using a predefined retrieval strategy (a hybrid search of weighted keyword similarity and weighted vector similarity) and the system's default chat model. It does not involve advanced RAG strategies like knowledge graph, auto-keyword, or auto-question. Retrieved chunks will be listed below the chat model's response. +An AI search is a single-turn AI conversation using a predefined retrieval strategy (a hybrid search of weighted keyword similarity and weighted vector similarity) and the system's default chat model. It does not involve advanced RAG strategies like knowledge graph, auto-keyword, or auto-question. The related chunks are listed below the chat model's response in descending order based on their similarity scores. ![](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/ai_search.jpg) @@ -25,7 +25,7 @@ When debugging your chat assistant, you can use AI search as a reference to veri ## Frequently asked questions -### key difference between an AI search and an AI chat? +### Key difference between an AI search and an AI chat? A chat is a multi-turn AI conversation where you can define your retrieval strategy (a weighted reranking score can be used to replace the weighted vector similarity in a hybrid search) and choose your chat model. In an AI chat, you can configure advanced RAG strategies, such as knowledge graphs, auto-keyword, and auto-question, for your specific case. Retrieved chunks are not displayed along with the answer. 
diff --git a/docs/guides/chat/set_chat_variables.md b/docs/guides/chat/set_chat_variables.md index d850748ef81..a4364b2045c 100644 --- a/docs/guides/chat/set_chat_variables.md +++ b/docs/guides/chat/set_chat_variables.md @@ -30,7 +30,7 @@ In the **Variable** section, you add, remove, or update variables. `{knowledge}` is the system's reserved variable, representing the chunks retrieved from the knowledge base(s) specified by **Knowledge bases** under the **Assistant settings** tab. If your chat assistant is associated with certain knowledge bases, you can keep it as is. :::info NOTE -It does not currently make a difference whether you set `{knowledge}` to optional or mandatory, but note that this design will be updated at a later point. +It currently makes no difference whether `{knowledge}` is set as optional or mandatory, but please note this design will be updated in due course. ::: From v0.17.0 onward, you can start an AI chat without specifying knowledge bases. In this case, we recommend removing the `{knowledge}` variable to prevent unnecessary reference and keeping the **Empty response** field empty to avoid errors. diff --git a/docs/guides/chat/start_chat.md b/docs/guides/chat/start_chat.md index a6135674958..f86289e4b13 100644 --- a/docs/guides/chat/start_chat.md +++ b/docs/guides/chat/start_chat.md @@ -42,13 +42,13 @@ You start an AI conversation by creating an assistant. - **Rerank model** sets the reranker model to use. It is left empty by default. - If **Rerank model** is left empty, the hybrid score system uses keyword similarity and vector similarity, and the default weight assigned to the vector similarity component is 1-0.7=0.3. - If **Rerank model** is selected, the hybrid score system uses keyword similarity and reranker score, and the default weight assigned to the reranker score is 1-0.7=0.3. - - **Cross-language search**: Optional + - [Cross-language search](../../references/glossary.mdx#cross-language-search): Optional Select one or more target languages from the dropdown menu. The system’s default chat model will then translate your query into the selected target language(s). This translation ensures accurate semantic matching across languages, allowing you to retrieve relevant results regardless of language differences. - When selecting target languages, please ensure that these languages are present in the knowledge base to guarantee an effective search. - If no target language is selected, the system will search only in the language of your query, which may cause relevant information in other languages to be missed. - **Variable** refers to the variables (keys) to be used in the system prompt. `{knowledge}` is a reserved variable. Click **Add** to add more variables for the system prompt. - If you are uncertain about the logic behind **Variable**, leave it *as-is*. - - As of v0.19.0, if you add custom variables here, the only way you can pass in their values is to call: + - As of v0.19.1, if you add custom variables here, the only way you can pass in their values is to call: - HTTP method [Converse with chat assistant](../../references/http_api_reference.md#converse-with-chat-assistant), or - Python method [Converse with chat assistant](../../references/python_api_reference.md#converse-with-chat-assistant). 
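To illustrate how such a custom variable might be supplied from a client, the following is a rough curl sketch. The `/api/v1/chats/{chat_id}/completions` route and the `question`, `stream`, and `session_id` fields follow the HTTP API reference; passing the custom variable (`lang` here) as an extra key in the request body is an assumption to verify against [Converse with chat assistant](../../references/http_api_reference.md#converse-with-chat-assistant):

```bash
# Hypothetical call: "lang" stands in for a custom variable declared in the
# chat assistant's system prompt; confirm the exact payload shape in the
# HTTP API reference before relying on it.
curl --request POST \
  --url http://{address}/api/v1/chats/{chat_id}/completions \
  --header 'Content-Type: application/json' \
  --header 'Authorization: Bearer <YOUR_API_KEY>' \
  --data '{
    "question": "Summarize the uploaded contract.",
    "stream": false,
    "session_id": "<SESSION_ID>",
    "lang": "German"
  }'
```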
diff --git a/docs/guides/dataset/autokeyword_autoquestion.mdx b/docs/guides/dataset/autokeyword_autoquestion.mdx new file mode 100644 index 00000000000..c7a1293af8e --- /dev/null +++ b/docs/guides/dataset/autokeyword_autoquestion.mdx @@ -0,0 +1,72 @@ +--- +sidebar_position: 3 +slug: /autokeyword_autoquestion +--- + +# Auto-keyword Auto-question +import APITable from '@site/src/components/APITable'; + +Use a chat model to generate keywords or questions from each chunk in the knowledge base. + +--- + +When selecting a chunking method, you can also enable auto-keyword or auto-question generation to increase retrieval rates. This feature uses a chat model to produce a specified number of keywords and questions from each created chunk, generating an "additional layer of information" from the original content. + +:::caution WARNING +Enabling this feature increases document indexing time and uses extra tokens, as all created chunks will be sent to the chat model for keyword or question generation. +::: + +## What is Auto-keyword? + +Auto-keyword refers to the auto-keyword generation feature of RAGFlow. It uses a chat model to generate a set of keywords or synonyms from each chunk to correct errors and enhance retrieval accuracy. This feature is implemented as a slider under **Page rank** on the **Configuration** page of your knowledge base. + +**Values**: + +- 0: (Default) Disabled. +- Between 3 and 5 (inclusive): Recommended if you have chunks of approximately 1,000 characters. +- 30 (maximum) + +:::tip NOTE +- If your chunk size increases, you can increase the value accordingly. Please note, as the value increases, the marginal benefit decreases. +- An Auto-keyword value must be an integer. If you set it to a non-integer, say 1.7, it will be rounded down to the nearest integer, which in this case is 1. +::: + +## What is Auto-question? + +Auto-question is a feature of RAGFlow that automatically generates questions from chunks of data using a chat model. These questions (e.g. who, what, and why) also help correct errors and improve the matching of user queries. The feature usually works with FAQ retrieval scenarios involving product manuals or policy documents. And you can find this feature as a slider under **Page rank** on the **Configuration** page of your knowledge base. + +**Values**: + +- 0: (Default) Disabled. +- 1 or 2: Recommended if you have chunks of approximately 1,000 characters. +- 10 (maximum) + +:::tip NOTE +- If your chunk size increases, you can increase the value accordingly. Please note, as the value increases, the marginal benefit decreases. +- An Auto-question value must be an integer. If you set it to a non-integer, say 1.7, it will be rounded down to the nearest integer, which in this case is 1. +::: + +## Tips from the community + +The Auto-keyword or Auto-question values relate closely to the chunking size in your knowledge base. However, if you are new to this feature and unsure which value(s) to start with, the following are some value settings we gathered from our community. While they may not be accurate, they provide a starting point at the very least. 
+ +```mdx-code-block + +``` + +| Use cases or typical scenarios | Document volume/length | Auto_keyword (0–30) | Auto_question (0–10) | +|---------------------------------------------------------------------|---------------------------------|----------------------------|----------------------------| +| Internal process guidance for employee handbook | Small, under 10 pages | 0 | 0 | +| Customer service FAQs | Medium, 10–100 pages | 3–7 | 1–3 | +| Technical whitepapers: Development standards, protocol details | Large, over 100 pages | 2–4 | 1–2 | +| Contracts / Regulations / Legal clause retrieval | Large, over 50 pages | 2–5 | 0–1 | +| Multi-repository layered new documents + old archive | Many | Adjust as appropriate |Adjust as appropriate | +| Social media comment pool: multilingual & mixed spelling | Very large volume of short text | 8–12 | 0 | +| Operational logs for troubleshooting | Very large volume of short text | 3–6 | 0 | +| Marketing asset library: multilingual product descriptions | Medium | 6–10 | 1–2 | +| Training courses / eBooks | Large | 2–5 | 1–2 | +| Maintenance manual: equipment diagrams + steps | Medium | 3–7 | 1–2 | + +```mdx-code-block + +``` diff --git a/docs/guides/dataset/best_practices/accelerate_doc_indexing.mdx b/docs/guides/dataset/best_practices/accelerate_doc_indexing.mdx index 7765d8b4afb..bc0dde11b4c 100644 --- a/docs/guides/dataset/best_practices/accelerate_doc_indexing.mdx +++ b/docs/guides/dataset/best_practices/accelerate_doc_indexing.mdx @@ -16,4 +16,4 @@ Please note that some of your settings may consume a significant amount of time. - On the configuration page of your knowledge base, switch off **Use RAPTOR to enhance retrieval**. - Extracting knowledge graph (GraphRAG) is time-consuming. - Disable **Auto-keyword** and **Auto-question** on the configuration page of your knowledge base, as both depend on the LLM. -- **v0.17.0+:** If your document is plain text PDF and does not require GPU-intensive processes like OCR (Optical Character Recognition), TSR (Table Structure Recognition), or DLA (Document Layout Analysis), you can choose **Naive** over **DeepDoc** or other time-consuming large model options in the **Document parser** dropdown. This will substantially reduce document parsing time. +- **v0.17.0+:** If all PDFs in your knowledge base are plain text and do not require GPU-intensive processes like OCR (Optical Character Recognition), TSR (Table Structure Recognition), or DLA (Document Layout Analysis), you can choose **Naive** over **DeepDoc** or other time-consuming large model options in the **Document parser** dropdown. This will substantially reduce document parsing time. diff --git a/docs/guides/dataset/configure_knowledge_base.md b/docs/guides/dataset/configure_knowledge_base.md index 6bc4656866a..ff628e2d557 100644 --- a/docs/guides/dataset/configure_knowledge_base.md +++ b/docs/guides/dataset/configure_knowledge_base.md @@ -1,5 +1,5 @@ --- -sidebar_position: 0 +sidebar_position: -1 slug: /configure_knowledge_base --- @@ -41,7 +41,7 @@ RAGFlow offers multiple chunking template to facilitate chunking files of differ | **Template** | Description | File format | |--------------|-----------------------------------------------------------------------|-----------------------------------------------------------------------------------------------| -| General | Files are consecutively chunked based on a preset chunk token number. 
| DOCX, XLSX, XLS (Excel 97-2003), PPT, PDF, TXT, JPEG, JPG, PNG, TIF, GIF, CSV, JSON, EML, HTML | +| General | Files are consecutively chunked based on a preset chunk token number. | MD, MDX, DOCX, XLSX, XLS (Excel 97-2003), PPT, PDF, TXT, JPEG, JPG, PNG, TIF, GIF, CSV, JSON, EML, HTML | | Q&A | | XLSX, XLS (Excel 97-2003), CSV/TXT | | Resume | Enterprise edition only. You can also try it out on demo.ragflow.io. | DOCX, PDF, TXT | | Manual | | PDF | @@ -68,7 +68,7 @@ The following embedding models can be deployed locally: - maidalun1020/bce-embedding-base_v1 :::danger IMPORTANT -Please note these two embedding models support both English and Chinese. If your knowledge base contains other languages, the performance may be COMPROMISED. +These two embedding models are optimized specifically for English and Chinese, so performance may be compromised if you use them to embed documents in other languages. ::: ### Upload file @@ -128,7 +128,7 @@ See [Run retrieval test](./run_retrieval_test.md) for details. ## Search for knowledge base -As of RAGFlow v0.19.0, the search feature is still in a rudimentary form, supporting only knowledge base search by name. +As of RAGFlow v0.19.1, the search feature is still in a rudimentary form, supporting only knowledge base search by name. ![search knowledge base](https://github.com/infiniflow/ragflow/assets/93570324/836ae94c-2438-42be-879e-c7ad2a59693e) diff --git a/docs/guides/dataset/enable_excel2html.md b/docs/guides/dataset/enable_excel2html.md index ae01f1998f4..531a673cce4 100644 --- a/docs/guides/dataset/enable_excel2html.md +++ b/docs/guides/dataset/enable_excel2html.md @@ -9,7 +9,7 @@ Convert complex Excel spreadsheets into HTML tables. --- -When using the General chunking method, you can enable the **Excel to HTML** toggle to convert spreadsheet files into HTML tables. If it is disabled, spreadsheet tables will be represented as key-value pairs. For complex tables that cannot be simply represented this way, you must enable this feature. +When using the **General** chunking method, you can enable the **Excel to HTML** toggle to convert spreadsheet files into HTML tables. If it is disabled, spreadsheet tables will be represented as key-value pairs. For complex tables that cannot be simply represented this way, you must enable this feature. :::caution WARNING The feature is disabled by default. If your knowledge base contains spreadsheets with complex tables and you do not enable this feature, RAGFlow will not throw an error but your tables are likely to be garbled. @@ -22,7 +22,7 @@ Works with complex tables that cannot be represented as key-value pairs. Example ## Considerations - The Excel2HTML feature applies only to spreadsheet files (XLSX or XLS (Excel 97-2003)). -- This feature is associated with the General chunking method. In other words, it is available *only when* you select the General chunking method. +- This feature is associated with the **General** chunking method. In other words, it is available *only when* you select the **General** chunking method. - When this feature is enabled, spreadsheet tables with more than 12 rows will be split into chunks of 12 rows each. ## Procedure diff --git a/docs/guides/dataset/enable_raptor.md b/docs/guides/dataset/enable_raptor.md index 660233a38ab..4beab3dea1e 100644 --- a/docs/guides/dataset/enable_raptor.md +++ b/docs/guides/dataset/enable_raptor.md @@ -47,7 +47,7 @@ The RAPTOR feature is disabled by default. 
To enable it, manually switch on the ### Prompt -The following prompt will be applied recursively for cluster summarization, with `{cluster_content}` serving as an internal parameter. We recommend that you keep it as-is for now. The design will be updated at a later point. +The following prompt will be applied *recursively* for cluster summarization, with `{cluster_content}` serving as an internal parameter. We recommend that you keep it as-is for now. The design will be updated in due course. ``` Please summarize the following paragraphs... Paragraphs as following: diff --git a/docs/guides/dataset/run_retrieval_test.md b/docs/guides/dataset/run_retrieval_test.md index a8bbe2181e4..a9ca9f192cb 100644 --- a/docs/guides/dataset/run_retrieval_test.md +++ b/docs/guides/dataset/run_retrieval_test.md @@ -62,7 +62,7 @@ Using a knowledge graph in a retrieval test will significantly increase the time ### Cross-language search -To perform a cross-language search, select one or more target languages from the dropdown menu. The system’s default chat model will then translate your query entered in the Test text field into the selected target language(s). This translation ensures accurate semantic matching across languages, allowing you to retrieve relevant results regardless of language differences. +To perform a [cross-language search](../../references/glossary.mdx#cross-language-search), select one or more target languages from the dropdown menu. The system’s default chat model will then translate your query entered in the Test text field into the selected target language(s). This translation ensures accurate semantic matching across languages, allowing you to retrieve relevant results regardless of language differences. :::tip NOTE - When selecting target languages, please ensure that these languages are present in the knowledge base to guarantee an effective search. diff --git a/docs/guides/dataset/select_pdf_parser.md b/docs/guides/dataset/select_pdf_parser.md new file mode 100644 index 00000000000..1bdda5d1d5d --- /dev/null +++ b/docs/guides/dataset/select_pdf_parser.md @@ -0,0 +1,53 @@ +--- +sidebar_position: 1 +slug: /select_pdf_parser +--- + +# Select PDF parser + +Select a visual model for parsing your PDFs. + +--- + +RAGFlow isn't one-size-fits-all. It is built for flexibility and supports deeper customization to accommodate more complex use cases. From v0.17.0 onwards, RAGFlow decouples DeepDoc-specific data extraction tasks from chunking methods **for PDF files**. This separation enables you to autonomously select a visual model for OCR (Optical Character Recognition), TSR (Table Structure Recognition), and DLR (Document Layout Recognition) tasks that balances speed and performance to suit your specific use cases. If your PDFs contain only plain text, you can opt to skip these tasks by selecting the **Naive** option, to reduce the overall parsing time. + +![data extraction](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/data_extraction.jpg) + +## Prerequisites + +- The PDF parser dropdown menu appears only when you select a chunking method compatible with PDFs, including: + - **General** + - **Manual** + - **Paper** + - **Book** + - **Laws** + - **Presentation** + - **One** +- To use a third-party visual model for parsing PDFs, ensure you have set a default img2txt model under **Set default models** on the **Model providers** page. + +## Procedure + +1. On your knowledge base's **Configuration** page, select a chunking method, say **General**. 
+ + _The **PDF parser** dropdown menu appears._ + +2. Select the option that works best with your scenario: + + - DeepDoc: (Default) The default visual model performing OCR, TSR, and DLR tasks on PDFs, which can be time-consuming. + - Naive: Skip OCR, TSR, and DLR tasks if *all* your PDFs are plain text. + - A third-party visual model provided by a specific model provider. + +:::caution WARNING +Third-party visual models are marked **Experimental**, because we have not fully tested these models for the aforementioned data extraction tasks. +::: + +## Frequently asked questions + +### When should I select DeepDoc or a third-party visual model as the PDF parser? + +Use a visual model to extract data if your PDFs contain formatted or image-based text rather than plain text. DeepDoc is the default visual model but can be time-consuming. You can also choose a lightweight or high-performance img2txt model depending on your needs and hardware capabilities. + +### Can I select a visual model to parse my DOCX files? + +No, you cannot. This dropdown menu is for PDFs only. To use this feature, convert your DOCX files to PDF first. + diff --git a/docs/guides/dataset/set_metadata.md b/docs/guides/dataset/set_metadata.md index b6d9f5f93ea..b36761c9486 100644 --- a/docs/guides/dataset/set_metadata.md +++ b/docs/guides/dataset/set_metadata.md @@ -1,5 +1,5 @@ --- -sidebar_position: 1 +sidebar_position: 0 slug: /set_metada --- @@ -19,4 +19,10 @@ For example, if you have a dataset of HTML files and want the LLM to cite the so Ensure that your metadata is in JSON format; otherwise, your updates will not be applied. ::: -![Image](https://github.com/user-attachments/assets/379cf2c5-4e37-4b79-8aeb-53bf8e01d326) \ No newline at end of file +![Image](https://github.com/user-attachments/assets/379cf2c5-4e37-4b79-8aeb-53bf8e01d326) + +## Frequently asked questions + +### Can I set metadata for multiple documents at once? + +No, you must set metadata *individually* for each document, as RAGFlow does not support batch setting of metadata. If you still consider this feature essential, please [raise an issue](https://github.com/infiniflow/ragflow/issues) explaining your use case and its importance. \ No newline at end of file diff --git a/docs/guides/dataset/set_page_rank.md b/docs/guides/dataset/set_page_rank.md index 46d984ecf17..c0af823080a 100644 --- a/docs/guides/dataset/set_page_rank.md +++ b/docs/guides/dataset/set_page_rank.md @@ -1,5 +1,5 @@ --- -sidebar_position: 3 +sidebar_position: 2 slug: /set_page_rank --- diff --git a/docs/guides/dataset/use_tag_sets.md b/docs/guides/dataset/use_tag_sets.md index 4d713dcc882..3205ba9c271 100644 --- a/docs/guides/dataset/use_tag_sets.md +++ b/docs/guides/dataset/use_tag_sets.md @@ -5,7 +5,7 @@ slug: /use_tag_sets # Use tag set -Use a tag set to tag chunks in your datasets. +Use a tag set to auto-tag chunks in your datasets. --- @@ -21,7 +21,7 @@ The auto-tagging feature is *unavailable* on the [Infinity](https://github.com/i Auto-tagging applies in situations where chunks are so similar to each other that the intended chunks cannot be distinguished from the rest. For example, when you have a few chunks about iPhone and a majority about iPhone case or iPhone accessaries, it becomes difficult to retrieve those chunks about iPhone without additional information. -## Create tag set +## 1. Create tag set You can consider a tag set as a closed set, and the tags to attach to the chunks in your dataset (knowledge base) are *exclusively* from the specified tag set. 
You use a tag set to "inform" RAGFlow which chunks to tag and which tags to apply. @@ -41,6 +41,10 @@ As a rule of thumb, consider including the following entries in your tag table: ### Create a tag set +:::danger IMPORTANT +A tag set is *not* involved in document indexing or retrieval. Do not specify a tag set when configuring your chat assistant or agent. +::: + 1. Click **+ Create knowledge base** to create a knowledge base. 2. Navigate to the **Configuration** page of the created knowledge base and choose **Tag** as the default chunking method. 3. Navigate to the **Dataset** page and upload and parse your table file in XLSX, CSV, or TXT formats. @@ -49,11 +53,7 @@ As a rule of thumb, consider including the following entries in your tag table: 4. Click the **Table** tab to view the tag frequency table: ![Image](https://github.com/user-attachments/assets/af91d10c-5ea5-491f-ab21-3803d5ebf59f) -:::danger IMPORTANT -A tag set is *not* involved in document indexing or retrieval. Do not specify a tag set when configuring your chat assistant or agent. -::: - -## Tag chunks +## 2. Tag chunks Once a tag set is created, you can apply it to your dataset: @@ -67,7 +67,7 @@ If the tag set is missing from the dropdown, check that it has been created or c 3. Re-parse your documents to start the auto-tagging process. _In an AI chat scenario using auto-tagged datasets, each query will be tagged using the corresponding tag set(s) and chunks with these tags will have a higher chance to be retrieved._ -## Update tag set +## 3. Update tag set Creating a tag set is *not* for once and for all. Oftentimes, you may find it necessary to update or delete existing tags or add new entries. diff --git a/docs/guides/manage_files.md b/docs/guides/manage_files.md index 78c24622b13..241a0b88d93 100644 --- a/docs/guides/manage_files.md +++ b/docs/guides/manage_files.md @@ -87,4 +87,4 @@ RAGFlow's file management allows you to download an uploaded file: ![download_file](https://github.com/infiniflow/ragflow/assets/93570324/cf3b297f-7d9b-4522-bf5f-4f45743e4ed5) -> As of RAGFlow v0.19.0, bulk download is not supported, nor can you download an entire folder. +> As of RAGFlow v0.19.1, bulk download is not supported, nor can you download an entire folder. diff --git a/docs/guides/models/llm_api_key_setup.md b/docs/guides/models/llm_api_key_setup.md index d42d4de358b..46fd6c8a072 100644 --- a/docs/guides/models/llm_api_key_setup.md +++ b/docs/guides/models/llm_api_key_setup.md @@ -49,6 +49,6 @@ After logging into RAGFlow, you can *only* configure API Key on the **Model prov 5. Click **OK** to confirm your changes. :::note -To update an existing model API key at a later point: +To update an existing model API key: ![update api key](https://github.com/infiniflow/ragflow/assets/93570324/0bfba679-33f7-4f6b-9ed6-f0e6e4b228ad) ::: \ No newline at end of file diff --git a/docs/guides/tracing.mdx b/docs/guides/tracing.mdx index 27178f7265d..07fd8ddfdf9 100644 --- a/docs/guides/tracing.mdx +++ b/docs/guides/tracing.mdx @@ -18,7 +18,7 @@ RAGFlow ships with a built-in [Langfuse](https://langfuse.com) integration so th Langfuse stores traces, spans and prompt payloads in a purpose-built observability backend and offers filtering and visualisations on top. 
:::info NOTE -• RAGFlow **≥ 0.19.0** (contains the Langfuse connector) +• RAGFlow **≥ 0.19.1** (contains the Langfuse connector) • A Langfuse workspace (cloud or self-hosted) with a _Project Public Key_ and _Secret Key_ ::: diff --git a/docs/guides/upgrade_ragflow.mdx b/docs/guides/upgrade_ragflow.mdx index 18d4fd479c4..7eb8b88f7e6 100644 --- a/docs/guides/upgrade_ragflow.mdx +++ b/docs/guides/upgrade_ragflow.mdx @@ -66,16 +66,16 @@ To upgrade RAGFlow, you must upgrade **both** your code **and** your Docker imag git clone https://github.com/infiniflow/ragflow.git ``` -2. Switch to the latest, officially published release, e.g., `v0.19.0`: +2. Switch to the latest, officially published release, e.g., `v0.19.1`: ```bash - git checkout -f v0.19.0 + git checkout -f v0.19.1 ``` 3. Update **ragflow/docker/.env** as follows: ```bash - RAGFLOW_IMAGE=infiniflow/ragflow:v0.19.0 + RAGFLOW_IMAGE=infiniflow/ragflow:v0.19.1 ``` 4. Update the RAGFlow image and restart RAGFlow: @@ -92,10 +92,10 @@ To upgrade RAGFlow, you must upgrade **both** your code **and** your Docker imag 1. From an environment with Internet access, pull the required Docker image. 2. Save the Docker image to a **.tar** file. ```bash - docker save -o ragflow.v0.19.0.tar infiniflow/ragflow:v0.19.0 + docker save -o ragflow.v0.19.1.tar infiniflow/ragflow:v0.19.1 ``` 3. Copy the **.tar** file to the target server. 4. Load the **.tar** file into Docker: ```bash - docker load -i ragflow.v0.19.0.tar + docker load -i ragflow.v0.19.1.tar ``` diff --git a/docs/quickstart.mdx b/docs/quickstart.mdx index 88321f81edd..8a0f262bfba 100644 --- a/docs/quickstart.mdx +++ b/docs/quickstart.mdx @@ -29,7 +29,7 @@ If you are on an ARM platform, follow [this guide](./develop/build_docker_image. - RAM ≥ 16 GB; - Disk ≥ 50 GB; - Docker ≥ 24.0.0 & Docker Compose ≥ v2.26.1. -- [gVisor](https://gvisor.dev/docs/user_guide/install/): Required only if you intend to use the code executor (sandbox) feature of RAGFlow. +- [gVisor](https://gvisor.dev/docs/user_guide/install/): Required only if you intend to use the code executor ([sandbox](https://github.com/infiniflow/ragflow/tree/main/sandbox)) feature of RAGFlow. :::tip NOTE If you have not installed Docker on your local machine (Windows, Mac, or Linux), see [Install Docker Engine](https://docs.docker.com/engine/install/). @@ -44,7 +44,7 @@ This section provides instructions on setting up the RAGFlow server on Linux. If `vm.max_map_count`. This value sets the maximum number of memory map areas a process may have. Its default value is 65530. While most applications require fewer than a thousand maps, reducing this value can result in abnormal behaviors, and the system will throw out-of-memory errors when a process reaches the limitation. - RAGFlow v0.19.0 uses Elasticsearch or [Infinity](https://github.com/infiniflow/infinity) for multiple recall. Setting the value of `vm.max_map_count` correctly is crucial to the proper functioning of the Elasticsearch component. + RAGFlow v0.19.1 uses Elasticsearch or [Infinity](https://github.com/infiniflow/infinity) for multiple recall. Setting the value of `vm.max_map_count` correctly is crucial to the proper functioning of the Elasticsearch component. Each RAGFlow account is able to use **text-embedding-v2** for free, an embedding model of Tongyi-Qianwen. This is why you can see Tongyi-Qianwen in the **Added models** list. And you may need to update your Tongyi-Qianwen API key at a later point. - 2. 
Click on the desired LLM and update the API key accordingly (DeepSeek-V2 in this case): ![update api key](https://github.com/infiniflow/ragflow/assets/93570324/4e5e13ef-a98d-42e6-bcb1-0c6045fc1666) @@ -289,7 +287,7 @@ To add and configure an LLM: ## Create your first knowledge base -You are allowed to upload files to a knowledge base in RAGFlow and parse them into datasets. A knowledge base is virtually a collection of datasets. Question answering in RAGFlow can be based on a particular knowledge base or multiple knowledge bases. File formats that RAGFlow supports include documents (PDF, DOC, DOCX, TXT, MD), tables (CSV, XLSX, XLS), pictures (JPEG, JPG, PNG, TIF, GIF), and slides (PPT, PPTX). +You are allowed to upload files to a knowledge base in RAGFlow and parse them into datasets. A knowledge base is virtually a collection of datasets. Question answering in RAGFlow can be based on a particular knowledge base or multiple knowledge bases. File formats that RAGFlow supports include documents (PDF, DOC, DOCX, TXT, MD, MDX), tables (CSV, XLSX, XLS), pictures (JPEG, JPG, PNG, TIF, GIF), and slides (PPT, PPTX). To create your first knowledge base: diff --git a/docs/references/glossary.mdx b/docs/references/glossary.mdx index ceec555dd24..d0e8712186b 100644 --- a/docs/references/glossary.mdx +++ b/docs/references/glossary.mdx @@ -19,7 +19,7 @@ import TOCInline from '@theme/TOCInline'; ### Cross-language search -Cross-language search (also known as cross-lingual retrieval) is a feature introduced in version 0.19.0. It enables users to submit queries in one language (for example, English) and retrieve relevant documents written in other languages such as Chinese or Spanish. This feature is enabled by the system’s default chat model, which translates queries to ensure accurate matching of semantic meaning across languages. +Cross-language search (also known as cross-lingual retrieval) is a feature introduced in version 0.19.1. It enables users to submit queries in one language (for example, English) and retrieve relevant documents written in other languages such as Chinese or Spanish. This feature is enabled by the system’s default chat model, which translates queries to ensure accurate matching of semantic meaning across languages. By enabling cross-language search, users can effortlessly access a broader range of information regardless of language barriers, significantly enhancing the system’s usability and inclusiveness. diff --git a/docs/references/http_api_reference.md b/docs/references/http_api_reference.md index 86ea4c0e46d..aef4ce2a2c2 100644 --- a/docs/references/http_api_reference.md +++ b/docs/references/http_api_reference.md @@ -343,7 +343,6 @@ Creates a dataset. - `"embedding_model"`: `string` - `"permission"`: `string` - `"chunk_method"`: `string` - - `"pagerank"`: `int` - `"parser_config"`: `object` ##### Request example @@ -384,12 +383,6 @@ curl --request POST \ - `"me"`: (Default) Only you can manage the dataset. - `"team"`: All team members can manage the dataset. -- `"pagerank"`: (*Body parameter*), `int` - refer to [Set page rank](https://ragflow.io/docs/dev/set_page_rank) - - Default: `0` - - Minimum: `0` - - Maximum: `100` - - `"chunk_method"`: (*Body parameter*), `enum` The chunking method of the dataset to create. Available options: - `"naive"`: General (default) @@ -900,7 +893,7 @@ curl --request PUT \ - `document_id`: (*Path parameter*) The ID of the document to update. 
- `"name"`: (*Body parameter*), `string` -- `"meta_fields"`: (*Body parameter*), `dict[str, Any]` The meta fields of the document. +- `"meta_fields"`: (*Body parameter*), `dict[str, Any]` The meta fields of the document. - `"chunk_method"`: (*Body parameter*), `string` The parsing method to apply to the document: - `"naive"`: General @@ -2142,7 +2135,7 @@ Success: "id": "4606b4ec87ad11efbc4f0242ac120006", "messages": [ { - "content": "Hi! I am your assistant,can I help you?", + "content": "Hi! I am your assistant, can I help you?", "role": "assistant" } ], @@ -2283,7 +2276,7 @@ Success: "id": "578d541e87ad11ef96b90242ac120006", "messages": [ { - "content": "Hi! I am your assistant,can I help you?", + "content": "Hi! I am your assistant, can I help you?", "role": "assistant" } ], @@ -3227,22 +3220,22 @@ Failure: --- -### Related Questions +### Generate related questions -**POST** `/v1/conversation/related_questions` +**POST** `/v1/sessions/related_questions` Generates five to ten alternative question strings from the user's original query to retrieve more relevant search results. -This operation requires a `Bearer Login Token`, typically expires with in 24 hours. You can find the it in the browser request easily. +This operation requires a `Bearer Login Token`, which typically expires with in 24 hours. You can find the it in the Request Headers in your browser easily. :::tip NOTE -The chat model dynamically determines the number of questions to generate based on the instruction, typically between five and ten. +The chat model autonomously determines the number of questions to generate based on the instruction, typically between five and ten. ::: #### Request - Method: POST -- URL: `/v1/conversation/related_questions` +- URL: `/v1/sessions/related_questions` - Headers: - `'content-Type: application/json'` - `'Authorization: Bearer '` @@ -3253,7 +3246,7 @@ The chat model dynamically determines the number of questions to generate based ```bash curl --request POST \ - --url http://{address}/v1/conversation/related_questions \ + --url http://{address}/v1/sessions/related_questions \ --header 'Content-Type: application/json' \ --header 'Authorization: Bearer ' \ --data ' diff --git a/docs/references/python_api_reference.md b/docs/references/python_api_reference.md index edfa45ce72d..af66c0e4983 100644 --- a/docs/references/python_api_reference.md +++ b/docs/references/python_api_reference.md @@ -100,7 +100,6 @@ RAGFlow.create_dataset( embedding_model: Optional[str] = "BAAI/bge-large-zh-v1.5@BAAI", permission: str = "me", chunk_method: str = "naive", - pagerank: int = 0, parser_config: DataSet.ParserConfig = None ) -> DataSet ``` @@ -148,10 +147,6 @@ The chunking method of the dataset to create. Available options: - `"one"`: One - `"email"`: Email -##### pagerank, `int` - -The pagerank of the dataset to create. Defaults to `0`. - ##### parser_config The parser configuration of the dataset. A `ParserConfig` object's attributes vary based on the selected `chunk_method`: @@ -1238,7 +1233,7 @@ The name of the chat session to create. - Success: A `Session` object containing the following attributes: - `id`: `str` The auto-generated unique identifier of the created session. - `name`: `str` The name of the created session. - - `message`: `list[Message]` The opening message of the created session. Default: `[{"role": "assistant", "content": "Hi! I am your assistant,can I help you?"}]` + - `message`: `list[Message]` The opening message of the created session. 
Default: `[{"role": "assistant", "content": "Hi! I am your assistant, can I help you?"}]` - `chat_id`: `str` The ID of the associated chat assistant. - Failure: `Exception` @@ -1497,7 +1492,7 @@ The parameters in `begin` component. - Success: A `Session` object containing the following attributes: - `id`: `str` The auto-generated unique identifier of the created session. - - `message`: `list[Message]` The messages of the created session assistant. Default: `[{"role": "assistant", "content": "Hi! I am your assistant,can I help you?"}]` + - `message`: `list[Message]` The messages of the created session assistant. Default: `[{"role": "assistant", "content": "Hi! I am your assistant, can I help you?"}]` - `agent_id`: `str` The ID of the associated agent. - Failure: `Exception` diff --git a/docs/release_notes.md b/docs/release_notes.md index fa6cd507c97..20f66ed9a33 100644 --- a/docs/release_notes.md +++ b/docs/release_notes.md @@ -14,12 +14,12 @@ Each RAGFlow release is available in two editions: ::: :::danger IMPORTANT -:collision: The embedding models included in a full edition are: +The embedding models included in a full edition are: - BAAI/bge-large-zh-v1.5 - maidalun1020/bce-embedding-base_v1 -Please note these two embedding models support both English and Chinese. If your knowledge base contains other languages, the performance may be COMPROMISED. +These two embedding models are optimized specifically for English and Chinese, so performance may be compromised if you use them to embed documents in other languages. ::: ## v0.19.0 @@ -28,10 +28,10 @@ Released on May 26, 2025. ### New features -- Cross-language search is supported in the Knowledge and Chat modules, enhancing search accuracy and user experience in multilingual environments, such as in Chinese-English knowledge bases. +- [Cross-language search](./references/glossary.mdx#cross-language-search) is supported in the Knowledge and Chat modules, enhancing search accuracy and user experience in multilingual environments, such as in Chinese-English knowledge bases. - Agent component: A new Code component supports Python and JavaScript scripts, enabling developers to handle more complex tasks like dynamic data processing. - Enhanced image display: Images in Chat and Search now render directly within responses, rather than as external references. Knowledge retrieval testing can retrieve images directly, instead of texts extracted from images. -- Claude 4 and ChatGPT o3: Developers can now use the newly released, most advanced Claude model alongside OpenAI’s latest ChatGPT o3 inference model. +- Claude 4 and ChatGPT o3: Developers can now use the newly released, most advanced Claude model and OpenAI’s latest ChatGPT o3 inference model. > The following features are contributed by our community contributors: @@ -39,6 +39,14 @@ Released on May 26, 2025. - Markdown rendering: Image references in a markdown file can be displayed after chunking. Thanks to [Woody-Hu](https://github.com/Woody-Hu). - Document engine support: OpenSearch can now be used as RAGFlow's document engine. Thanks to [pyyuhao](https://github.com/pyyuhao). +### Documentation + +#### Added documents + +- [Select PDF parser](./guides/dataset/select_pdf_parser.md) +- [Enable Excel2HTML](./guides/dataset/enable_excel2html.md) +- [Code component](./guides/agent/agent_component_reference/code.mdx) + ## v0.18.0 Released on April 23, 2025. @@ -143,7 +151,7 @@ Released on March 3, 2025. - AI chat: Leverages Tavily-based web search to enhance contexts in agentic reasoning. 
To activate this, enter the correct Tavily API key under the **Assistant settings** tab of your chat assistant dialogue. - AI chat: Supports starting a chat without specifying knowledge bases. - AI chat: HTML files can also be previewed and referenced, in addition to PDF files. -- Dataset: Adds a **PDF parser**, aka **Document parser**, dropdown menu to dataset configurations. This includes a DeepDoc model option, which is time-consuming, a much faster **naive** option (plain text), which skips DLA (Document Layout Analysis), OCR (Optical Character Recognition), and TSR (Table Structure Recognition) tasks, and several currently *experimental* large model options. +- Dataset: Adds a **PDF parser**, aka **Document parser**, dropdown menu to dataset configurations. This includes a DeepDoc model option, which is time-consuming, a much faster **naive** option (plain text), which skips DLA (Document Layout Analysis), OCR (Optical Character Recognition), and TSR (Table Structure Recognition) tasks, and several currently *experimental* large model options. See [here](./guides/dataset/select_pdf_parser.md). - Agent component: **(x)** or a forward slash `/` can be used to insert available keys (variables) in the system prompt field of the **Generate** or **Template** component. - Object storage: Supports using Aliyun OSS (Object Storage Service) as a file storage option. - Models: Updates the supported model list for Tongyi-Qianwen (Qwen), adding DeepSeek-specific models; adds ModelScope as a model provider. diff --git a/download_deps.py b/download_deps.py index 3ada5be2c7e..058cf258b03 100644 --- a/download_deps.py +++ b/download_deps.py @@ -11,12 +11,13 @@ # /// from huggingface_hub import snapshot_download +from typing import Union import nltk import os import urllib.request import argparse -def get_urls(use_china_mirrors=False): +def get_urls(use_china_mirrors=False) -> Union[str, list[str]]: if use_china_mirrors: return [ "http://mirrors.tuna.tsinghua.edu.cn/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb", @@ -24,8 +25,8 @@ def get_urls(use_china_mirrors=False): "https://repo.huaweicloud.com/repository/maven/org/apache/tika/tika-server-standard/3.0.0/tika-server-standard-3.0.0.jar", "https://repo.huaweicloud.com/repository/maven/org/apache/tika/tika-server-standard/3.0.0/tika-server-standard-3.0.0.jar.md5", "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken", - "https://storage.googleapis.com/chrome-for-testing-public/121.0.6167.85/linux64/chrome-linux64.zip", - "https://storage.googleapis.com/chrome-for-testing-public/121.0.6167.85/linux64/chromedriver-linux64.zip", + ["https://storage.googleapis.com/chrome-for-testing-public/121.0.6167.85/linux64/chrome-linux64.zip", "chrome-linux64-121-0-6167-85"], + ["https://storage.googleapis.com/chrome-for-testing-public/121.0.6167.85/linux64/chromedriver-linux64.zip", "chromedriver-linux64-121-0-6167-85"], ] else: return [ @@ -49,7 +50,7 @@ def get_urls(use_china_mirrors=False): def download_model(repo_id): local_dir = os.path.abspath(os.path.join("huggingface.co", repo_id)) os.makedirs(local_dir, exist_ok=True) - snapshot_download(repo_id=repo_id, local_dir=local_dir, local_dir_use_symlinks=False) + snapshot_download(repo_id=repo_id, local_dir=local_dir) if __name__ == "__main__": @@ -60,10 +61,11 @@ def download_model(repo_id): urls = get_urls(args.china_mirrors) for url in urls: - filename = url.split("/")[-1] - print(f"Downloading {url}...") + download_url = url[0] if isinstance(url, list) else url + filename 
= url[1] if isinstance(url, list) else url.split("/")[-1] + print(f"Downloading {filename} from {download_url}...") if not os.path.exists(filename): - urllib.request.urlretrieve(url, filename) + urllib.request.urlretrieve(download_url, filename) local_dir = os.path.abspath('nltk_data') for data in ['wordnet', 'punkt', 'punkt_tab']: @@ -72,4 +74,4 @@ def download_model(repo_id): for repo_id in repos: print(f"Downloading huggingface repo {repo_id}...") - download_model(repo_id) \ No newline at end of file + download_model(repo_id) diff --git a/graphrag/entity_resolution.py b/graphrag/entity_resolution.py index 60792d381a6..97b13577595 100644 --- a/graphrag/entity_resolution.py +++ b/graphrag/entity_resolution.py @@ -94,25 +94,52 @@ async def __call__(self, graph: nx.Graph, candidate_resolution[k] = [(a, b) for a, b in itertools.combinations(v, 2) if (a in subgraph_nodes or b in subgraph_nodes) and self.is_similarity(a, b)] num_candidates = sum([len(candidates) for _, candidates in candidate_resolution.items()]) callback(msg=f"Identified {num_candidates} candidate pairs") + remain_candidates_to_resolve = num_candidates resolution_result = set() + resolution_result_lock = trio.Lock() resolution_batch_size = 100 + max_concurrent_tasks = 5 + semaphore = trio.Semaphore(max_concurrent_tasks) + + async def limited_resolve_candidate(candidate_batch, result_set, result_lock): + nonlocal remain_candidates_to_resolve, callback + async with semaphore: + try: + with trio.move_on_after(180) as cancel_scope: + await self._resolve_candidate(candidate_batch, result_set, result_lock) + remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1]) + callback(msg=f"Resolved {len(candidate_batch[1])} pairs, {remain_candidates_to_resolve} remain to resolve. ") + if cancel_scope.cancelled_caught: + logging.warning(f"Timeout resolving {candidate_batch}, skipping...") + remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1]) + callback(msg=f"Failed to resolve {len(candidate_batch[1])} pairs due to timeout, skipped. {remain_candidates_to_resolve} remain to resolve. 
") + except Exception as e: + logging.error(f"Error resolving candidate batch: {e}") + + async with trio.open_nursery() as nursery: for candidate_resolution_i in candidate_resolution.items(): if not candidate_resolution_i[1]: continue for i in range(0, len(candidate_resolution_i[1]), resolution_batch_size): candidate_batch = candidate_resolution_i[0], candidate_resolution_i[1][i:i + resolution_batch_size] - nursery.start_soon(self._resolve_candidate, candidate_batch, resolution_result) + nursery.start_soon(limited_resolve_candidate, candidate_batch, resolution_result, resolution_result_lock) + callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.") change = GraphChange() connect_graph = nx.Graph() connect_graph.add_edges_from(resolution_result) + + async def limited_merge_nodes(graph, nodes, change): + async with semaphore: + await self._merge_graph_nodes(graph, nodes, change) + async with trio.open_nursery() as nursery: for sub_connect_graph in nx.connected_components(connect_graph): merging_nodes = list(sub_connect_graph) - nursery.start_soon(self._merge_graph_nodes, graph, merging_nodes, change) + nursery.start_soon(limited_merge_nodes, graph, merging_nodes, change) # Update pagerank pr = nx.pagerank(graph) @@ -124,7 +151,7 @@ async def __call__(self, graph: nx.Graph, change=change, ) - async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str]): + async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str], resolution_result_lock: trio.Lock): gen_conf = {"temperature": 0.5} pair_txt = [ f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n'] @@ -142,7 +169,16 @@ async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple text = perform_variable_replacements(self._resolution_prompt, variables=variables) logging.info(f"Created resolution prompt {len(text)} bytes for {len(candidate_resolution_i[1])} entity pairs of type {candidate_resolution_i[0]}") async with chat_limiter: - response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)) + try: + with trio.move_on_after(120) as cancel_scope: + response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], gen_conf) + if cancel_scope.cancelled_caught: + logging.warning("_resolve_candidate._chat timeout, skipping...") + return + except Exception as e: + logging.error(f"_resolve_candidate._chat failed: {e}") + return + logging.debug(f"_resolve_candidate chat prompt: {text}\nchat response: {response}") result = self._process_results(len(candidate_resolution_i[1]), response, self.prompt_variables.get(self._record_delimiter_key, @@ -151,8 +187,9 @@ async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple DEFAULT_ENTITY_INDEX_DELIMITER), self.prompt_variables.get(self._resolution_result_delimiter_key, DEFAULT_RESOLUTION_RESULT_DELIMITER)) - for result_i in result: - resolution_result.add(candidate_resolution_i[1][result_i[0] - 1]) + async with resolution_result_lock: + for result_i in result: + resolution_result.add(candidate_resolution_i[1][result_i[0] - 1]) def _process_results( self, @@ -185,6 +222,7 @@ def is_similarity(self, a, b): if is_english(a) and is_english(b): if editdistance.eval(a, b) <= min(len(a), len(b)) // 2: return 
True + return False if len(set(a) & set(b)) > 1: return True diff --git a/graphrag/general/community_reports_extractor.py b/graphrag/general/community_reports_extractor.py index 14966af02cb..4d8b33bfdfa 100644 --- a/graphrag/general/community_reports_extractor.py +++ b/graphrag/general/community_reports_extractor.py @@ -89,7 +89,15 @@ async def extract_community_report(community): text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables) gen_conf = {"temperature": 0.3} async with chat_limiter: - response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)) + try: + with trio.move_on_after(120) as cancel_scope: + response = await trio.to_thread.run_sync( self._chat, text, [{"role": "user", "content": "Output:"}], gen_conf) + if cancel_scope.cancelled_caught: + logging.warning("extract_community_report._chat timeout, skipping...") + return + except Exception as e: + logging.error(f"extract_community_report._chat failed: {e}") + return token_count += num_tokens_from_string(text + response) response = re.sub(r"^[^\{]*", "", response) response = re.sub(r"[^\}]*$", "", response) diff --git a/graphrag/general/index.py b/graphrag/general/index.py index 4aa73ac40bc..8c98636791f 100644 --- a/graphrag/general/index.py +++ b/graphrag/general/index.py @@ -57,7 +57,7 @@ async def run_graphrag( subgraph = await generate_subgraph( LightKGExt - if row["kb_parser_config"]["graphrag"]["method"] != "general" + if "method" not in row["kb_parser_config"]["graphrag"] or row["kb_parser_config"]["graphrag"]["method"] != "general" else GeneralKGExt, tenant_id, kb_id, @@ -166,7 +166,7 @@ async def generate_subgraph( ) if ignored_rels: callback(msg=f"ignored {ignored_rels} relations due to missing entities.") - tidy_graph(subgraph, callback) + tidy_graph(subgraph, callback, check_attribute=False) subgraph.graph["source_id"] = [doc_id] chunk = { diff --git a/graphrag/utils.py b/graphrag/utils.py index 151b5aab74a..81df2a24b4d 100644 --- a/graphrag/utils.py +++ b/graphrag/utils.py @@ -157,30 +157,32 @@ def set_tags_to_cache(kb_ids, tags): k = hasher.hexdigest() REDIS_CONN.set(k, json.dumps(tags).encode("utf-8"), 600) -def tidy_graph(graph: nx.Graph, callback): +def tidy_graph(graph: nx.Graph, callback, check_attribute: bool = True): """ Ensure all nodes and edges in the graph have some essential attribute. 
""" - def is_valid_node(node_attrs: dict) -> bool: + def is_valid_item(node_attrs: dict) -> bool: valid_node = True for attr in ["description", "source_id"]: if attr not in node_attrs: valid_node = False break return valid_node - purged_nodes = [] - for node, node_attrs in graph.nodes(data=True): - if not is_valid_node(node_attrs): - purged_nodes.append(node) - for node in purged_nodes: - graph.remove_node(node) - if purged_nodes and callback: - callback(msg=f"Purged {len(purged_nodes)} nodes from graph due to missing essential attributes.") + if check_attribute: + purged_nodes = [] + for node, node_attrs in graph.nodes(data=True): + if not is_valid_item(node_attrs): + purged_nodes.append(node) + for node in purged_nodes: + graph.remove_node(node) + if purged_nodes and callback: + callback(msg=f"Purged {len(purged_nodes)} nodes from graph due to missing essential attributes.") purged_edges = [] for source, target, attr in graph.edges(data=True): - if not is_valid_node(attr): - purged_edges.append((source, target)) + if check_attribute: + if not is_valid_item(attr): + purged_edges.append((source, target)) if "keywords" not in attr: attr["keywords"] = [] for source, target in purged_edges: diff --git a/helm/values.yaml b/helm/values.yaml index 9b8477d009c..2d0cbee168a 100644 --- a/helm/values.yaml +++ b/helm/values.yaml @@ -27,13 +27,13 @@ env: REDIS_PASSWORD: infini_rag_flow_helm # The RAGFlow Docker image to download. - # Defaults to the v0.19.0-slim edition, which is the RAGFlow Docker image without embedding models. - RAGFLOW_IMAGE: infiniflow/ragflow:v0.19.0-slim + # Defaults to the v0.19.1-slim edition, which is the RAGFlow Docker image without embedding models. + RAGFLOW_IMAGE: infiniflow/ragflow:v0.19.1-slim # # To download the RAGFlow Docker image with embedding models, uncomment the following line instead: - # RAGFLOW_IMAGE: infiniflow/ragflow:v0.19.0 + # RAGFLOW_IMAGE: infiniflow/ragflow:v0.19.1 # - # The Docker image of the v0.19.0 edition includes: + # The Docker image of the v0.19.1 edition includes: # - Built-in embedding models: # - BAAI/bge-large-zh-v1.5 # - BAAI/bge-reranker-v2-m3 @@ -62,6 +62,12 @@ env: # MAX_CONTENT_LENGTH: "134217728" # After making the change, ensure you update `client_max_body_size` in nginx/nginx.conf correspondingly. + # The number of document chunks processed in a single batch during document parsing. + DOC_BULK_SIZE: 4 + + # The number of text chunks processed in a single batch during embedding vectorization. + EMBEDDING_BATCH_SIZE: 16 + ragflow: deployment: strategy: diff --git a/mcp/client/client.py b/mcp/client/client.py index 2f3ad81ce08..3c54ea03095 100644 --- a/mcp/client/client.py +++ b/mcp/client/client.py @@ -23,6 +23,9 @@ async def main(): try: # To access RAGFlow server in `host` mode, you need to attach `api_key` for each request to indicate identification. 
# async with sse_client("http://localhost:9382/sse", headers={"api_key": "ragflow-IyMGI1ZDhjMTA2ZTExZjBiYTMyMGQ4Zm"}) as streams: + # Or follow the requirements of OAuth 2.1 Section 5 with Authorization header + # async with sse_client("http://localhost:9382/sse", headers={"Authorization": "Bearer ragflow-IyMGI1ZDhjMTA2ZTExZjBiYTMyMGQ4Zm"}) as streams: + async with sse_client("http://localhost:9382/sse") as streams: async with ClientSession( streams[0], diff --git a/mcp/server/server.py b/mcp/server/server.py index de87c221d98..743cd16f94a 100644 --- a/mcp/server/server.py +++ b/mcp/server/server.py @@ -17,6 +17,7 @@ import json from collections.abc import AsyncIterator from contextlib import asynccontextmanager +from functools import wraps import requests from starlette.applications import Starlette @@ -127,22 +128,45 @@ async def server_lifespan(server: Server) -> AsyncIterator[dict]: sse = SseServerTransport("/messages/") -@app.list_tools() -async def list_tools() -> list[types.Tool]: - ctx = app.request_context - ragflow_ctx = ctx.lifespan_context["ragflow_ctx"] - if not ragflow_ctx: - raise ValueError("Get RAGFlow Context failed") - connector = ragflow_ctx.conn +def with_api_key(required=True): + def decorator(func): + @wraps(func) + async def wrapper(*args, **kwargs): + ctx = app.request_context + ragflow_ctx = ctx.lifespan_context.get("ragflow_ctx") + if not ragflow_ctx: + raise ValueError("Get RAGFlow Context failed") + + connector = ragflow_ctx.conn + + if MODE == LaunchMode.HOST: + headers = ctx.session._init_options.capabilities.experimental.get("headers", {}) + token = None + + # lower case here, because of Starlette conversion + auth = headers.get("authorization", "") + if auth.startswith("Bearer "): + token = auth.removeprefix("Bearer ").strip() + elif "api_key" in headers: + token = headers["api_key"] + + if required and not token: + raise ValueError("RAGFlow API key or Bearer token is required.") + + connector.bind_api_key(token) + else: + connector.bind_api_key(HOST_API_KEY) + + return await func(*args, connector=connector, **kwargs) + + return wrapper + + return decorator - if MODE == LaunchMode.HOST: - api_key = ctx.session._init_options.capabilities.experimental["headers"]["api_key"] - if not api_key: - raise ValueError("RAGFlow API_KEY is required.") - else: - api_key = HOST_API_KEY - connector.bind_api_key(api_key) +@app.list_tools() +@with_api_key(required=True) +async def list_tools(*, connector) -> list[types.Tool]: dataset_description = connector.list_datasets() return [ @@ -152,7 +176,17 @@ async def list_tools() -> list[types.Tool]: + dataset_description, inputSchema={ "type": "object", - "properties": {"dataset_ids": {"type": "array", "items": {"type": "string"}}, "document_ids": {"type": "array", "items": {"type": "string"}}, "question": {"type": "string"}}, + "properties": { + "dataset_ids": { + "type": "array", + "items": {"type": "string"}, + }, + "document_ids": { + "type": "array", + "items": {"type": "string"}, + }, + "question": {"type": "string"}, + }, "required": ["dataset_ids", "question"], }, ), @@ -160,24 +194,15 @@ async def list_tools() -> list[types.Tool]: @app.call_tool() -async def call_tool(name: str, arguments: dict) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: - ctx = app.request_context - ragflow_ctx = ctx.lifespan_context["ragflow_ctx"] - if not ragflow_ctx: - raise ValueError("Get RAGFlow Context failed") - connector = ragflow_ctx.conn - - if MODE == LaunchMode.HOST: - api_key = 
ctx.session._init_options.capabilities.experimental["headers"]["api_key"] - if not api_key: - raise ValueError("RAGFlow API_KEY is required.") - else: - api_key = HOST_API_KEY - connector.bind_api_key(api_key) - +@with_api_key(required=True) +async def call_tool(name: str, arguments: dict, *, connector) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: if name == "ragflow_retrieval": document_ids = arguments.get("document_ids", []) - return connector.retrieval(dataset_ids=arguments["dataset_ids"], document_ids=document_ids, question=arguments["question"]) + return connector.retrieval( + dataset_ids=arguments["dataset_ids"], + document_ids=document_ids, + question=arguments["question"], + ) raise ValueError(f"Tool not found: {name}") @@ -188,25 +213,34 @@ async def handle_sse(request): class AuthMiddleware(BaseHTTPMiddleware): async def dispatch(self, request, call_next): + # Authentication is deferred, will be handled by RAGFlow core service. if request.url.path.startswith("/sse") or request.url.path.startswith("/messages"): - api_key = request.headers.get("api_key") - if not api_key: - return JSONResponse({"error": "Missing unauthorization header"}, status_code=401) - return await call_next(request) + token = None + auth_header = request.headers.get("Authorization") + if auth_header and auth_header.startswith("Bearer "): + token = auth_header.removeprefix("Bearer ").strip() + elif request.headers.get("api_key"): + token = request.headers["api_key"] + + if not token: + return JSONResponse({"error": "Missing or invalid authorization header"}, status_code=401) + return await call_next(request) -middleware = None -if MODE == LaunchMode.HOST: - middleware = [Middleware(AuthMiddleware)] -starlette_app = Starlette( - debug=True, - routes=[ - Route("/sse", endpoint=handle_sse), - Mount("/messages/", app=sse.handle_post_message), - ], - middleware=middleware, -) +def create_starlette_app(): + middleware = None + if MODE == LaunchMode.HOST: + middleware = [Middleware(AuthMiddleware)] + + return Starlette( + debug=True, + routes=[ + Route("/sse", endpoint=handle_sse), + Mount("/messages/", app=sse.handle_post_message), + ], + middleware=middleware, + ) if __name__ == "__main__": @@ -236,7 +270,7 @@ async def dispatch(self, request, call_next): default="self-host", help="Launch mode options:\n" " * self-host: Launches an MCP server to access a specific tenant space. The 'api_key' argument is required.\n" - " * host: Launches an MCP server that allows users to access their own spaces. Each request must include a header " + " * host: Launches an MCP server that allows users to access their own spaces. Each request must include a Authorization header " "indicating the user's identification.", ) parser.add_argument("--api_key", type=str, default="", help="RAGFlow MCP SERVER HOST API KEY") @@ -268,7 +302,7 @@ async def dispatch(self, request, call_next): print(f"MCP base_url: {BASE_URL}", flush=True) uvicorn.run( - starlette_app, + create_starlette_app(), host=HOST, port=int(PORT), ) diff --git a/pyproject.toml b/pyproject.toml index f9bf62697d3..ded458b98b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ragflow" -version = "0.19.0" +version = "0.19.1" description = "[RAGFlow](https://ragflow.io/) is an open-source RAG (Retrieval-Augmented Generation) engine based on deep document understanding. 
It offers a streamlined RAG workflow for businesses of any scale, combining LLM (Large Language Models) to provide truthful question-answering capabilities, backed by well-founded citations from various complex formatted data." authors = [{ name = "Zhichang Yu", email = "yuzhichang@gmail.com" }] license-files = ["LICENSE"] @@ -62,6 +62,7 @@ dependencies = [ "opencv-python==4.10.0.84", "opencv-python-headless==4.10.0.84", "openpyxl>=3.1.0,<4.0.0", + "opendal>=0.45.0,<0.46.0", "ormsgpack==1.5.0", "pandas>=2.2.0,<3.0.0", "pdfplumber==0.10.4", @@ -71,8 +72,8 @@ dependencies = [ "psycopg2-binary==2.9.9", "pyclipper==1.3.0.post5", "pycryptodomex==3.20.0", + "pymysql>=1.1.1,<2.0.0", "pypdf>=5.0.0,<6.0.0", - "pytest>=8.3.0,<9.0.0", "python-dotenv==1.0.1", "python-dateutil==2.8.2", "python-pptx>=1.0.2,<2.0.0", @@ -85,6 +86,7 @@ dependencies = [ "replicate==0.31.0", "roman-numbers==1.0.2", "ruamel-base==1.0.0", + "ruamel-yaml>=0.18.6,<0.19.0", "scholarly==1.7.11", "scikit-learn==1.5.0", "selenium==4.22.0", @@ -106,7 +108,7 @@ dependencies = [ "werkzeug==3.0.6", "wikipedia==1.4.0", "word2number==1.1", - "xgboost==1.5.0", + "xgboost==1.6.0", "xpinyin==0.7.6", "yfinance==0.1.96", "zhipuai==2.0.1", @@ -136,22 +138,49 @@ full = [ "fastembed-gpu>=0.3.6,<0.4.0; sys_platform != 'darwin' and platform_machine == 'x86_64'", "flagembedding==1.2.10", "torch>=2.5.0,<3.0.0", - "transformers>=4.35.0,<5.0.0" + "transformers>=4.35.0,<5.0.0", +] + +[dependency-groups] +test = [ + "hypothesis>=6.132.0", + "openpyxl>=3.1.5", + "pillow>=10.4.0", + "pytest>=8.3.5", + "python-docx>=1.1.2", + "python-pptx>=1.0.2", + "reportlab>=4.4.1", + "requests>=2.32.2", + "requests-toolbelt>=1.0.0", ] [tool.setuptools] -packages = ['agent', 'agentic_reasoning', 'api', 'deepdoc', 'graphrag', 'intergrations.chatgpt-on-wechat.plugins', 'mcp.server', 'rag', 'sdk.python.ragflow_sdk'] +packages = [ + 'agent', + 'agentic_reasoning', + 'api', + 'deepdoc', + 'graphrag', + 'intergrations.chatgpt-on-wechat.plugins', + 'mcp.server', + 'rag', + 'sdk.python.ragflow_sdk', +] [[tool.uv.index]] url = "https://mirrors.aliyun.com/pypi/simple" [tool.ruff] line-length = 200 -exclude = [ - ".venv", - "rag/svr/discord_svr.py", -] +exclude = [".venv", "rag/svr/discord_svr.py"] [tool.ruff.lint] extend-select = ["ASYNC", "ASYNC1"] ignore = ["E402"] + +[tool.pytest.ini_options] +markers = [ + "p1: high priority test cases", + "p2: medium priority test cases", + "p3: low priority test cases", +] diff --git a/rag/app/naive.py b/rag/app/naive.py index 28e3bbbcc91..809da121d30 100644 --- a/rag/app/naive.py +++ b/rag/app/naive.py @@ -29,10 +29,9 @@ from api.db import LLMType from api.db.services.llm_service import LLMBundle from deepdoc.parser import DocxParser, ExcelParser, HtmlParser, JsonParser, MarkdownParser, PdfParser, TxtParser -from deepdoc.parser.figure_parser import VisionFigureParser, vision_figure_parser_figure_data_wraper +from deepdoc.parser.figure_parser import VisionFigureParser, vision_figure_parser_figure_data_wrapper from deepdoc.parser.pdf_parser import PlainParser, VisionParser from rag.nlp import concat_img, find_codec, naive_merge, naive_merge_with_images, naive_merge_docx, rag_tokenizer, tokenize_chunks, tokenize_chunks_with_images, tokenize_table -from rag.utils import num_tokens_from_string class Docx(DocxParser): @@ -335,17 +334,13 @@ def __call__(self, filename, binary=None): sections = [] tbls = [] for sec in remainder.split("\n"): - if num_tokens_from_string(sec) > 3 * self.chunk_token_num: - sections.append((sec[:int(len(sec) / 2)], "")) 
- sections.append((sec[int(len(sec) / 2):], "")) + if sec.strip().find("#") == 0: + sections.append((sec, "")) + elif sections and sections[-1][0].strip().find("#") == 0: + sec_, _ = sections.pop(-1) + sections.append((sec_ + "\n" + sec, "")) else: - if sec.strip().find("#") == 0: - sections.append((sec, "")) - elif sections and sections[-1][0].strip().find("#") == 0: - sec_, _ = sections.pop(-1) - sections.append((sec_ + "\n" + sec, "")) - else: - sections.append((sec, "")) + sections.append((sec, "")) for table in tables: tbls.append(((None, markdown(table, extensions=['markdown.extensions.tables'])), "")) return sections, tbls @@ -384,7 +379,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, sections, tables = Docx()(filename, binary) if vision_model: - figures_data = vision_figure_parser_figure_data_wraper(sections) + figures_data = vision_figure_parser_figure_data_wrapper(sections) try: docx_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data=figures_data, **kwargs) boosted_figures = docx_vision_parser(callback=callback) diff --git a/rag/app/presentation.py b/rag/app/presentation.py index d3e4f021f1b..fd32c261bc4 100644 --- a/rag/app/presentation.py +++ b/rag/app/presentation.py @@ -20,6 +20,9 @@ from PIL import Image +from api.db import LLMType +from api.db.services.llm_service import LLMBundle +from deepdoc.parser.pdf_parser import VisionParser from rag.nlp import tokenize, is_english from rag.nlp import rag_tokenizer from deepdoc.parser import PdfParser, PptParser, PlainParser @@ -123,11 +126,21 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, res.append(d) return res elif re.search(r"\.pdf$", filename, re.IGNORECASE): - pdf_parser = Pdf() - if kwargs.get("layout_recognize", "DeepDOC") == "Plain Text": + layout_recognizer = kwargs.get("layout_recognize", "DeepDOC") + if layout_recognizer == "DeepDOC": + pdf_parser = Pdf() + sections = pdf_parser(filename, binary, from_page=from_page, to_page=to_page, callback=callback) + elif layout_recognizer == "Plain Text": pdf_parser = PlainParser() - for pn, (txt, img) in enumerate(pdf_parser(filename, binary, - from_page=from_page, to_page=to_page, callback=callback)): + sections, _ = pdf_parser(filename, binary, from_page=from_page, to_page=to_page, callback=callback) + else: + vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT, llm_name=layout_recognizer, lang=lang) + pdf_parser = VisionParser(vision_model=vision_model, **kwargs) + sections, _ = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page, + callback=callback) + + callback(0.8, "Finish parsing.") + for pn, (txt, img) in enumerate(sections): d = copy.deepcopy(doc) pn += from_page if img: diff --git a/rag/app/table.py b/rag/app/table.py index 14facd16909..450cd628063 100644 --- a/rag/app/table.py +++ b/rag/app/table.py @@ -20,6 +20,8 @@ from xpinyin import Pinyin import numpy as np import pandas as pd +from collections import Counter + # from openpyxl import load_workbook, Workbook from dateutil.parser import parse as datetime_parse @@ -30,8 +32,7 @@ class Excel(ExcelParser): - def __call__(self, fnm, binary=None, from_page=0, - to_page=10000000000, callback=None): + def __call__(self, fnm, binary=None, from_page=0, to_page=10000000000, callback=None): if not binary: wb = Excel._load_excel_to_workbook(fnm) else: @@ -49,10 +50,7 @@ def __call__(self, fnm, binary=None, from_page=0, continue headers = [cell.value for cell in rows[0]] missed = set([i for i, h in enumerate(headers) if h is 
None]) - headers = [ - cell.value for i, - cell in enumerate( - rows[0]) if i not in missed] + headers = [cell.value for i, cell in enumerate(rows[0]) if i not in missed] if not headers: continue data = [] @@ -62,9 +60,7 @@ def __call__(self, fnm, binary=None, from_page=0, continue if rn - 1 >= to_page: break - row = [ - cell.value for ii, - cell in enumerate(r) if ii not in missed] + row = [cell.value for ii, cell in enumerate(r) if ii not in missed] if len(row) != len(headers): fails.append(str(i)) continue @@ -74,8 +70,7 @@ def __call__(self, fnm, binary=None, from_page=0, continue res.append(pd.DataFrame(np.array(data), columns=headers)) - callback(0.3, ("Extract records: {}~{}".format(from_page + 1, min(to_page, from_page + rn)) + ( - f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) + callback(0.3, ("Extract records: {}~{}".format(from_page + 1, min(to_page, from_page + rn)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) return res @@ -87,8 +82,7 @@ def trans_datatime(s): def trans_bool(s): - if re.match(r"(true|yes|是|\*|✓|✔|☑|✅|√)$", - str(s).strip(), flags=re.IGNORECASE): + if re.match(r"(true|yes|是|\*|✓|✔|☑|✅|√)$", str(s).strip(), flags=re.IGNORECASE): return "yes" if re.match(r"(false|no|否|⍻|×)$", str(s).strip(), flags=re.IGNORECASE): return "no" @@ -97,8 +91,7 @@ def trans_bool(s): def column_data_type(arr): arr = list(arr) counts = {"int": 0, "float": 0, "text": 0, "datetime": 0, "bool": 0} - trans = {t: f for f, t in - [(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]} + trans = {t: f for f, t in [(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]} for a in arr: if a is None: continue @@ -127,31 +120,25 @@ def column_data_type(arr): return arr, ty -def chunk(filename, binary=None, from_page=0, to_page=10000000000, - lang="Chinese", callback=None, **kwargs): +def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese", callback=None, **kwargs): """ - Excel and csv(txt) format files are supported. - For csv or txt file, the delimiter between columns is TAB. - The first line must be column headers. - Column headers must be meaningful terms inorder to make our NLP model understanding. - It's good to enumerate some synonyms using slash '/' to separate, and even better to - enumerate values using brackets like 'gender/sex(male, female)'. - Here are some examples for headers: - 1. supplier/vendor\tcolor(yellow, red, brown)\tgender/sex(male, female)\tsize(M,L,XL,XXL) - 2. 姓名/名字\t电话/手机/微信\t最高学历(高中,职高,硕士,本科,博士,初中,中技,中专,专科,专升本,MPA,MBA,EMBA) - - Every row in table will be treated as a chunk. + Excel and csv(txt) format files are supported. + For csv or txt file, the delimiter between columns is TAB. + The first line must be column headers. + Column headers must be meaningful terms inorder to make our NLP model understanding. + It's good to enumerate some synonyms using slash '/' to separate, and even better to + enumerate values using brackets like 'gender/sex(male, female)'. + Here are some examples for headers: + 1. supplier/vendor\tcolor(yellow, red, brown)\tgender/sex(male, female)\tsize(M,L,XL,XXL) + 2. 姓名/名字\t电话/手机/微信\t最高学历(高中,职高,硕士,本科,博士,初中,中技,中专,专科,专升本,MPA,MBA,EMBA) + + Every row in table will be treated as a chunk. 
""" if re.search(r"\.xlsx?$", filename, re.IGNORECASE): callback(0.1, "Start to parse.") excel_parser = Excel() - dfs = excel_parser( - filename, - binary, - from_page=from_page, - to_page=to_page, - callback=callback) + dfs = excel_parser(filename, binary, from_page=from_page, to_page=to_page, callback=callback) elif re.search(r"\.(txt|csv)$", filename, re.IGNORECASE): callback(0.1, "Start to parse.") txt = get_text(filename, binary) @@ -170,40 +157,29 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, continue rows.append(row) - callback(0.3, ("Extract records: {}~{}".format(from_page, min(len(lines), to_page)) + ( - f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) + callback(0.3, ("Extract records: {}~{}".format(from_page, min(len(lines), to_page)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) dfs = [pd.DataFrame(np.array(rows), columns=headers)] else: - raise NotImplementedError( - "file type not supported yet(excel, text, csv supported)") + raise NotImplementedError("file type not supported yet(excel, text, csv supported)") res = [] PY = Pinyin() - fieds_map = { - "text": "_tks", - "int": "_long", - "keyword": "_kwd", - "float": "_flt", - "datetime": "_dt", - "bool": "_kwd"} + fieds_map = {"text": "_tks", "int": "_long", "keyword": "_kwd", "float": "_flt", "datetime": "_dt", "bool": "_kwd"} for df in dfs: for n in ["id", "_id", "index", "idx"]: if n in df.columns: del df[n] clmns = df.columns.values if len(clmns) != len(set(clmns)): - duplicates = [col for col in clmns if list(clmns).count(col) > 1] - raise ValueError(f"Duplicate column names detected: {set(duplicates)}") + col_counts = Counter(clmns) + duplicates = [col for col, count in col_counts.items() if count > 1] + if duplicates: + raise ValueError(f"Duplicate column names detected: {duplicates}\nFrom: {clmns}") + txts = list(copy.deepcopy(clmns)) - py_clmns = [ - PY.get_pinyins( - re.sub( - r"(/.*|([^()]+?)|\([^()]+?\))", - "", - str(n)), - '_')[0] for n in clmns] + py_clmns = [PY.get_pinyins(re.sub(r"(/.*|([^()]+?)|\([^()]+?\))", "", str(n)), "_")[0] for n in clmns] clmn_tys = [] for j in range(len(clmns)): cln, ty = column_data_type(df[clmns[j]]) @@ -211,15 +187,11 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, df[clmns[j]] = cln if ty == "text": txts.extend([str(c) for c in cln if c]) - clmns_map = [(py_clmns[i].lower() + fieds_map[clmn_tys[i]], str(clmns[i]).replace("_", " ")) - for i in range(len(clmns))] + clmns_map = [(py_clmns[i].lower() + fieds_map[clmn_tys[i]], str(clmns[i]).replace("_", " ")) for i in range(len(clmns))] eng = lang.lower() == "english" # is_english(txts) for ii, row in df.iterrows(): - d = { - "docnm_kwd": filename, - "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename)) - } + d = {"docnm_kwd": filename, "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))} row_txt = [] for j in range(len(clmns)): if row[clmns[j]] is None: @@ -229,16 +201,14 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, if not isinstance(row[clmns[j]], pd.Series) and pd.isna(row[clmns[j]]): continue fld = clmns_map[j][0] - d[fld] = row[clmns[j]] if clmn_tys[j] != "text" else rag_tokenizer.tokenize( - row[clmns[j]]) + d[fld] = row[clmns[j]] if clmn_tys[j] != "text" else rag_tokenizer.tokenize(row[clmns[j]]) row_txt.append("{}:{}".format(clmns[j], row[clmns[j]])) if not row_txt: continue tokenize(d, "; ".join(row_txt), eng) res.append(d) - 
KnowledgebaseService.update_parser_config( - kwargs["kb_id"], {"field_map": {k: v for k, v in clmns_map}}) + KnowledgebaseService.update_parser_config(kwargs["kb_id"], {"field_map": {k: v for k, v in clmns_map}}) callback(0.35, "") return res diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py index ce78e2914fb..323ac9502f8 100644 --- a/rag/llm/__init__.py +++ b/rag/llm/__init__.py @@ -45,6 +45,7 @@ HuggingFaceEmbed, VolcEngineEmbed, GPUStackEmbed, + NovitaEmbed ) from .chat_model import ( GptTurbo, @@ -128,6 +129,7 @@ QWenRerank, GPUStackRerank, HuggingfaceRerank, + NovitaRerank ) from .sequence2txt_model import ( @@ -180,6 +182,7 @@ "HuggingFace": HuggingFaceEmbed, "VolcEngine": VolcEngineEmbed, "GPUStack": GPUStackEmbed, + "NovitaAI": NovitaEmbed } CvModel = { @@ -267,6 +270,7 @@ "Tongyi-Qianwen": QWenRerank, "GPUStack": GPUStackRerank, "HuggingFace": HuggingfaceRerank, + "NovitaAI": NovitaRerank } Seq2txtModel = { diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index c9c8f88848c..fbef3478180 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -21,8 +21,12 @@ import re import time from abc import ABC +from copy import deepcopy +from http import HTTPStatus from typing import Any, Protocol +from urllib.parse import urljoin +import json_repair import openai import requests from dashscope import Generation @@ -57,18 +61,19 @@ def tool_call(self, name: str, arguments: dict[str, Any]) -> str: ... class Base(ABC): - def __init__(self, key, model_name, base_url): + def __init__(self, key, model_name, base_url, **kwargs): timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600)) self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout) self.model_name = model_name # Configure retry parameters - self.max_retries = int(os.environ.get("LLM_MAX_RETRIES", 5)) - self.base_delay = float(os.environ.get("LLM_BASE_DELAY", 2.0)) + self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5))) + self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0))) + self.max_rounds = kwargs.get("max_rounds", 5) self.is_tools = False - def _get_delay(self, attempt): + def _get_delay(self): """Calculate retry delay time""" - return self.base_delay * (2**attempt) + random.uniform(0, 0.5) + return self.base_delay + random.uniform(0, 0.5) def _classify_error(self, error): """Classify error based on error message content""" @@ -95,6 +100,47 @@ def _classify_error(self, error): else: return ERROR_GENERIC + def _clean_conf(self, gen_conf): + if "max_tokens" in gen_conf: + del gen_conf["max_tokens"] + return gen_conf + + def _chat(self, history, gen_conf): + response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf) + + if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]): + return "", 0 + ans = response.choices[0].message.content.strip() + if response.choices[0].finish_reason == "length": + if is_chinese(ans): + ans += LENGTH_NOTIFICATION_CN + else: + ans += LENGTH_NOTIFICATION_EN + return ans, self.total_token_count(response) + + def _length_stop(self, ans): + if is_chinese([ans]): + return ans + LENGTH_NOTIFICATION_CN + return ans + LENGTH_NOTIFICATION_EN + + def _exceptions(self, e, attempt): + logging.exception("OpenAI cat_with_tools") + # Classify the error + error_code = self._classify_error(e) + + # Check if it's a rate limit error or server error and not the last attempt + should_retry = (error_code == ERROR_RATE_LIMIT or 
error_code == ERROR_SERVER) and attempt < self.max_retries + + if should_retry: + delay = self._get_delay() + logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})") + time.sleep(delay) + else: + # For non-rate limit errors or the last attempt, return an error message + if attempt == self.max_retries: + error_code = ERROR_MAX_RETRIES + return f"{ERROR_PREFIX}: {error_code} - {str(e)}" + def bind_tools(self, toolcall_session, tools): if not (toolcall_session and tools): return @@ -103,120 +149,63 @@ def bind_tools(self, toolcall_session, tools): self.tools = tools def chat_with_tools(self, system: str, history: list, gen_conf: dict): - if "max_tokens" in gen_conf: - del gen_conf["max_tokens"] - - tools = self.tools - + gen_conf = self._clean_conf() if system: history.insert(0, {"role": "system", "content": system}) ans = "" tk_count = 0 + hist = deepcopy(history) # Implement exponential backoff retry strategy - for attempt in range(self.max_retries): - try: - response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=tools, **gen_conf) - - assistant_output = response.choices[0].message - if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output: - ans += "" + ans + "" - ans += response.choices[0].message.content - - if not response.choices[0].message.tool_calls: + for attempt in range(self.max_retries+1): + history = hist + for _ in range(self.max_rounds*2): + try: + response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, **gen_conf) tk_count += self.total_token_count(response) - if response.choices[0].finish_reason == "length": - if is_chinese([ans]): - ans += LENGTH_NOTIFICATION_CN - else: - ans += LENGTH_NOTIFICATION_EN - return ans, tk_count + if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]): + raise Exception("500 response structure error.") - tk_count += self.total_token_count(response) - history.append(assistant_output) + if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls: + if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content: + ans += "" + response.choices[0].message.reasoning_content + "" - for tool_call in response.choices[0].message.tool_calls: - name = tool_call.function.name - args = json.loads(tool_call.function.arguments) - - tool_response = self.toolcall_session.tool_call(name, args) - # if tool_response.choices[0].finish_reason == "length": - # if is_chinese(ans): - # ans += LENGTH_NOTIFICATION_CN - # else: - # ans += LENGTH_NOTIFICATION_EN - # return ans, tk_count + self.total_token_count(tool_response) - history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)}) - - final_response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=tools, **gen_conf) - assistant_output = final_response.choices[0].message - if "tool_calls" not in assistant_output and "reasoning_content" in assistant_output: - ans += "" + ans + "" - ans += final_response.choices[0].message.content - if final_response.choices[0].finish_reason == "length": - tk_count += self.total_token_count(response) - if is_chinese([ans]): - ans += LENGTH_NOTIFICATION_CN - else: - ans += LENGTH_NOTIFICATION_EN - return ans, tk_count - return ans, tk_count + ans += response.choices[0].message.content + if 
response.choices[0].finish_reason == "length": + ans = self._length_stop(ans) - except Exception as e: - logging.exception("OpenAI cat_with_tools") - # Classify the error - error_code = self._classify_error(e) + return ans, tk_count - # Check if it's a rate limit error or server error and not the last attempt - should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries - 1 + for tool_call in response.choices[0].message.tool_calls: + name = tool_call.function.name + try: + args = json_repair.loads(tool_call.function.arguments) + tool_response = self.toolcall_session.tool_call(name, args) + history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)}) + except Exception as e: + history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)}) - if should_retry: - delay = self._get_delay(attempt) - logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})") - time.sleep(delay) - else: - # For non-rate limit errors or the last attempt, return an error message - if attempt == self.max_retries - 1: - error_code = ERROR_MAX_RETRIES - return f"{ERROR_PREFIX}: {error_code} - {str(e)}", 0 + + except Exception as e: + e = self._exceptions(e, attempt) + if e: + return e, tk_count + assert False, "Shouldn't be here." def chat(self, system, history, gen_conf): if system: history.insert(0, {"role": "system", "content": system}) - if "max_tokens" in gen_conf: - del gen_conf["max_tokens"] + gen_conf = self._clean_conf(gen_conf) # Implement exponential backoff retry strategy - for attempt in range(self.max_retries): + for attempt in range(self.max_retries+1): try: - response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf) - - if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]): - return "", 0 - ans = response.choices[0].message.content.strip() - if response.choices[0].finish_reason == "length": - if is_chinese(ans): - ans += LENGTH_NOTIFICATION_CN - else: - ans += LENGTH_NOTIFICATION_EN - return ans, self.total_token_count(response) + return self._chat(history, gen_conf) except Exception as e: - logging.exception("chat_model.Base.chat got exception") - # Classify the error - error_code = self._classify_error(e) - - # Check if it's a rate limit error or server error and not the last attempt - should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries - 1 - - if should_retry: - delay = self._get_delay(attempt) - logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})") - time.sleep(delay) - else: - # For non-rate limit errors or the last attempt, return an error message - if attempt == self.max_retries - 1: - error_code = ERROR_MAX_RETRIES - return f"{ERROR_PREFIX}: {error_code} - {str(e)}", 0 + e = self._exceptions(e, attempt) + if e: + return e, 0 + assert False, "Shouldn't be here." 
def _wrap_toolcall_message(self, stream): final_tool_calls = {} @@ -237,41 +226,48 @@ def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict): del gen_conf["max_tokens"] tools = self.tools - if system: history.insert(0, {"role": "system", "content": system}) - ans = "" total_tokens = 0 - reasoning_start = False - finish_completion = False - final_tool_calls = {} - try: - response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf) - while not finish_completion: - for resp in response: - if resp.choices[0].delta.tool_calls: - for tool_call in resp.choices[0].delta.tool_calls or []: - index = tool_call.index - - if index not in final_tool_calls: - final_tool_calls[index] = tool_call - else: - final_tool_calls[index].function.arguments += tool_call.function.arguments - else: - if not resp.choices: + hist = deepcopy(history) + # Implement exponential backoff retry strategy + for attempt in range(self.max_retries+1): + history = hist + for _ in range(self.max_rounds*2): + reasoning_start = False + try: + response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf) + final_tool_calls = {} + answer = "" + for resp in response: + if resp.choices[0].delta.tool_calls: + for tool_call in resp.choices[0].delta.tool_calls or []: + index = tool_call.index + + if index not in final_tool_calls: + final_tool_calls[index] = tool_call + else: + final_tool_calls[index].function.arguments += tool_call.function.arguments continue + + if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]): + raise Exception("500 response structure error.") + if not resp.choices[0].delta.content: resp.choices[0].delta.content = "" + if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content: ans = "" if not reasoning_start: reasoning_start = True ans = "" ans += resp.choices[0].delta.reasoning_content + "" + yield ans else: reasoning_start = False - ans = resp.choices[0].delta.content + answer += resp.choices[0].delta.content + yield resp.choices[0].delta.content tol = self.total_token_count(resp) if not tol: @@ -279,18 +275,18 @@ def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict): else: total_tokens += tol - finish_reason = resp.choices[0].finish_reason - if finish_reason == "tool_calls" and final_tool_calls: - for tool_call in final_tool_calls.values(): - name = tool_call.function.name - try: - args = json.loads(tool_call.function.arguments) - except Exception as e: - logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}") - yield ans + "\n**ERROR**: " + str(e) - finish_completion = True - break + finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else "" + if finish_reason == "length": + yield self._length_stop("") + if answer: + yield total_tokens + return + + for tool_call in final_tool_calls.values(): + name = tool_call.function.name + try: + args = json_repair.loads(tool_call.function.arguments) tool_response = self.toolcall_session.tool_call(name, args) history.append( { @@ -308,33 +304,17 @@ def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict): ], } ) - # if tool_response.choices[0].finish_reason == "length": - # if is_chinese(ans): - # ans += LENGTH_NOTIFICATION_CN - # else: - # ans += LENGTH_NOTIFICATION_EN - # return ans, total_tokens + 
self.total_token_count(tool_response) history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)}) - final_tool_calls = {} - response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf) - continue - if finish_reason == "length": - if is_chinese(ans): - ans += LENGTH_NOTIFICATION_CN - else: - ans += LENGTH_NOTIFICATION_EN - return ans, total_tokens - if finish_reason == "stop": - finish_completion = True - yield ans - break - yield ans - continue - - except openai.APIError as e: - yield ans + "\n**ERROR**: " + str(e) + except Exception as e: + logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}") + history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)}) + except Exception as e: + e = self._exceptions(e, attempt) + if e: + yield total_tokens + return - yield total_tokens + assert False, "Shouldn't be here." def chat_streamly(self, system, history, gen_conf): if system: @@ -428,68 +408,64 @@ def count_tokens(text): class GptTurbo(Base): - def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"): + def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1", **kwargs): if not base_url: base_url = "https://api.openai.com/v1" - super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url, **kwargs) class MoonshotChat(Base): - def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"): + def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1", **kwargs): if not base_url: base_url = "https://api.moonshot.cn/v1" super().__init__(key, model_name, base_url) class XinferenceChat(Base): - def __init__(self, key=None, model_name="", base_url=""): + def __init__(self, key=None, model_name="", base_url="", **kwargs): if not base_url: raise ValueError("Local llm url cannot be None") - if base_url.split("/")[-1] != "v1": - base_url = os.path.join(base_url, "v1") - super().__init__(key, model_name, base_url) + base_url = urljoin(base_url, "v1") + super().__init__(key, model_name, base_url, **kwargs) class HuggingFaceChat(Base): - def __init__(self, key=None, model_name="", base_url=""): + def __init__(self, key=None, model_name="", base_url="", **kwargs): if not base_url: raise ValueError("Local llm url cannot be None") - if base_url.split("/")[-1] != "v1": - base_url = os.path.join(base_url, "v1") - super().__init__(key, model_name.split("___")[0], base_url) + base_url = urljoin(base_url, "v1") + super().__init__(key, model_name.split("___")[0], base_url, **kwargs) class ModelScopeChat(Base): - def __init__(self, key=None, model_name="", base_url=""): + def __init__(self, key=None, model_name="", base_url="", **kwargs): if not base_url: raise ValueError("Local llm url cannot be None") - base_url = base_url.rstrip("/") - if base_url.split("/")[-1] != "v1": - base_url = os.path.join(base_url, "v1") - super().__init__(key, model_name.split("___")[0], base_url) + base_url = urljoin(base_url, "v1") + super().__init__(key, model_name.split("___")[0], base_url, **kwargs) class DeepSeekChat(Base): - def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"): + def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1", **kwargs): if not base_url: base_url = "https://api.deepseek.com/v1" - 
super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url, **kwargs) class AzureChat(Base): def __init__(self, key, model_name, **kwargs): api_key = json.loads(key).get("api_key", "") api_version = json.loads(key).get("api_version", "2024-02-01") - super().__init__(key, model_name, kwargs["base_url"]) + super().__init__(key, model_name, kwargs["base_url"], **kwargs) self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version) self.model_name = model_name class BaiChuanChat(Base): - def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1"): + def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1", **kwargs): if not base_url: base_url = "https://api.baichuan-ai.com/v1" - super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url, **kwargs) @staticmethod def _format_params(params): @@ -498,27 +474,26 @@ def _format_params(params): "top_p": params.get("top_p", 0.85), } - def chat(self, system, history, gen_conf): - if system: - history.insert(0, {"role": "system", "content": system}) - if "max_tokens" in gen_conf: - del gen_conf["max_tokens"] - try: - response = self.client.chat.completions.create( - model=self.model_name, - messages=history, - extra_body={"tools": [{"type": "web_search", "web_search": {"enable": True, "search_mode": "performance_first"}}]}, - **self._format_params(gen_conf), - ) - ans = response.choices[0].message.content.strip() - if response.choices[0].finish_reason == "length": - if is_chinese([ans]): - ans += LENGTH_NOTIFICATION_CN - else: - ans += LENGTH_NOTIFICATION_EN - return ans, self.total_token_count(response) - except openai.APIError as e: - return "**ERROR**: " + str(e), 0 + def _clean_conf(self, gen_conf): + return { + "temperature": gen_conf.get("temperature", 0.3), + "top_p": gen_conf.get("top_p", 0.85), + } + + def _chat(self, history, gen_conf): + response = self.client.chat.completions.create( + model=self.model_name, + messages=history, + extra_body={"tools": [{"type": "web_search", "web_search": {"enable": True, "search_mode": "performance_first"}}]}, + **gen_conf, + ) + ans = response.choices[0].message.content.strip() + if response.choices[0].finish_reason == "length": + if is_chinese([ans]): + ans += LENGTH_NOTIFICATION_CN + else: + ans += LENGTH_NOTIFICATION_EN + return ans, self.total_token_count(response) def chat_streamly(self, system, history, gen_conf): if system: @@ -560,15 +535,15 @@ def chat_streamly(self, system, history, gen_conf): class QWenChat(Base): - def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs): - super().__init__(key, model_name, base_url=None) + def __init__(self, key, model_name=Generation.Models.qwen_turbo, base_url=None, **kwargs): + super().__init__(key, model_name, base_url=base_url, **kwargs) import dashscope dashscope.api_key = key self.model_name = model_name if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]: - super().__init__(key, model_name, "https://dashscope.aliyuncs.com/compatible-mode/v1") + super().__init__(key, model_name, "https://dashscope.aliyuncs.com/compatible-mode/v1", **kwargs) def chat_with_tools(self, system: str, history: list, gen_conf: dict) -> tuple[str, int]: if "max_tokens" in gen_conf: @@ -642,41 +617,22 @@ def chat_with_tools(self, system: str, history: list, gen_conf: dict) -> tuple[s else: return 
"".join(result_list[:-1]), result_list[-1] - def chat(self, system, history, gen_conf): - if "max_tokens" in gen_conf: - del gen_conf["max_tokens"] + def _chat(self, history, gen_conf): if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]: - return super().chat(system, history, gen_conf) - - stream_flag = str(os.environ.get("QWEN_CHAT_BY_STREAM", "true")).lower() == "true" - if not stream_flag: - from http import HTTPStatus - - if system: - history.insert(0, {"role": "system", "content": system}) - - response = Generation.call(self.model_name, messages=history, result_format="message", **gen_conf) - ans = "" - tk_count = 0 - if response.status_code == HTTPStatus.OK: - ans += response.output.choices[0]["message"]["content"] - tk_count += self.total_token_count(response) - if response.output.choices[0].get("finish_reason", "") == "length": - if is_chinese([ans]): - ans += LENGTH_NOTIFICATION_CN - else: - ans += LENGTH_NOTIFICATION_EN - return ans, tk_count - - return "**ERROR**: " + response.message, tk_count - else: - g = self._chat_streamly(system, history, gen_conf, incremental_output=True) - result_list = list(g) - error_msg_list = [item for item in result_list if str(item).find("**ERROR**") >= 0] - if len(error_msg_list) > 0: - return "**ERROR**: " + "".join(error_msg_list), 0 - else: - return "".join(result_list[:-1]), result_list[-1] + return super()._chat(history, gen_conf) + response = Generation.call(self.model_name, messages=history, result_format="message", **gen_conf) + ans = "" + tk_count = 0 + if response.status_code == HTTPStatus.OK: + ans += response.output.choices[0]["message"]["content"] + tk_count += self.total_token_count(response) + if response.output.choices[0].get("finish_reason", "") == "length": + if is_chinese([ans]): + ans += LENGTH_NOTIFICATION_CN + else: + ans += LENGTH_NOTIFICATION_EN + return ans, tk_count + return "**ERROR**: " + response.message, tk_count def _wrap_toolcall_message(self, old_message, message): if not old_message: @@ -829,32 +785,20 @@ def is_reasoning_model(model_name: str) -> bool: class ZhipuChat(Base): - def __init__(self, key, model_name="glm-3-turbo", **kwargs): - super().__init__(key, model_name, base_url=None) + def __init__(self, key, model_name="glm-3-turbo", base_url=None, **kwargs): + super().__init__(key, model_name, base_url=base_url, **kwargs) self.client = ZhipuAI(api_key=key) self.model_name = model_name - def chat(self, system, history, gen_conf): - if system: - history.insert(0, {"role": "system", "content": system}) + def _clean_conf(self, gen_conf): if "max_tokens" in gen_conf: del gen_conf["max_tokens"] - try: - if "presence_penalty" in gen_conf: - del gen_conf["presence_penalty"] - if "frequency_penalty" in gen_conf: - del gen_conf["frequency_penalty"] - response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf) - ans = response.choices[0].message.content.strip() - if response.choices[0].finish_reason == "length": - if is_chinese(ans): - ans += LENGTH_NOTIFICATION_CN - else: - ans += LENGTH_NOTIFICATION_EN - return ans, self.total_token_count(response) - except Exception as e: - return "**ERROR**: " + str(e), 0 + if "presence_penalty" in gen_conf: + del gen_conf["presence_penalty"] + if "frequency_penalty" in gen_conf: + del gen_conf["frequency_penalty"] + return gen_conf def chat_with_tools(self, system: str, history: list, gen_conf: dict): if "presence_penalty" in gen_conf: @@ -906,39 
+850,31 @@ def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict): class OllamaChat(Base): - def __init__(self, key, model_name, **kwargs): - super().__init__(key, model_name, base_url=None) + def __init__(self, key, model_name, base_url=None, **kwargs): + super().__init__(key, model_name, base_url=base_url, **kwargs) - self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bearer {key}"}) + self.client = Client(host=base_url) if not key or key == "x" else Client(host=base_url, headers={"Authorization": f"Bearer {key}"}) self.model_name = model_name - def chat(self, system, history, gen_conf): - if system: - history.insert(0, {"role": "system", "content": system}) + def _clean_conf(self, gen_conf): + options = {} if "max_tokens" in gen_conf: - del gen_conf["max_tokens"] - try: - # Calculate context size - ctx_size = self._calculate_dynamic_ctx(history) - - options = {"num_ctx": ctx_size} - if "temperature" in gen_conf: - options["temperature"] = gen_conf["temperature"] - if "max_tokens" in gen_conf: - options["num_predict"] = gen_conf["max_tokens"] - if "top_p" in gen_conf: - options["top_p"] = gen_conf["top_p"] - if "presence_penalty" in gen_conf: - options["presence_penalty"] = gen_conf["presence_penalty"] - if "frequency_penalty" in gen_conf: - options["frequency_penalty"] = gen_conf["frequency_penalty"] - - response = self.client.chat(model=self.model_name, messages=history, options=options) - ans = response["message"]["content"].strip() - token_count = response.get("eval_count", 0) + response.get("prompt_eval_count", 0) - return ans, token_count - except Exception as e: - return "**ERROR**: " + str(e), 0 + options["num_predict"] = gen_conf["max_tokens"] + for k in ["temperature", "top_p", "presence_penalty", "frequency_penalty"]: + if k not in gen_conf: + continue + options[k] = gen_conf[k] + return options + + def _chat(self, history, gen_conf): + # Calculate context size + ctx_size = self._calculate_dynamic_ctx(history) + + gen_conf["num_ctx"] = ctx_size + response = self.client.chat(model=self.model_name, messages=history, options=gen_conf, keep_alive=-1) + ans = response["message"]["content"].strip() + token_count = response.get("eval_count", 0) + response.get("prompt_eval_count", 0) + return ans, token_count def chat_streamly(self, system, history, gen_conf): if system: @@ -962,7 +898,7 @@ def chat_streamly(self, system, history, gen_conf): ans = "" try: - response = self.client.chat(model=self.model_name, messages=history, stream=True, options=options) + response = self.client.chat(model=self.model_name, messages=history, stream=True, options=options, keep_alive=-1) for resp in response: if resp["done"]: token_count = resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0) @@ -978,48 +914,21 @@ def chat_streamly(self, system, history, gen_conf): class LocalAIChat(Base): - def __init__(self, key, model_name, base_url): - super().__init__(key, model_name, base_url=None) + def __init__(self, key, model_name, base_url=None, **kwargs): + super().__init__(key, model_name, base_url=base_url, **kwargs) if not base_url: raise ValueError("Local llm url cannot be None") - if base_url.split("/")[-1] != "v1": - base_url = os.path.join(base_url, "v1") + base_url = urljoin(base_url, "v1") self.client = OpenAI(api_key="empty", base_url=base_url) self.model_name = model_name.split("___")[0] class LocalLLM(Base): - class RPCProxy: - def __init__(self, host, port): - self.host = host - 
self.port = int(port) - self.__conn() - - def __conn(self): - from multiprocessing.connection import Client - - self._connection = Client((self.host, self.port), authkey=b"infiniflow-token4kevinhu") - - def __getattr__(self, name): - import pickle - - def do_rpc(*args, **kwargs): - for _ in range(3): - try: - self._connection.send(pickle.dumps((name, args, kwargs))) - return pickle.loads(self._connection.recv()) - except Exception: - self.__conn() - raise Exception("RPC connection lost!") - - return do_rpc - - def __init__(self, key, model_name): - super().__init__(key, model_name, base_url=None) + def __init__(self, key, model_name, base_url=None, **kwargs): + super().__init__(key, model_name, base_url=base_url, **kwargs) from jina import Client - self.client = Client(port=12345, protocol="grpc", asyncio=True) def _prepare_prompt(self, system, history, gen_conf): @@ -1063,9 +972,7 @@ def chat_streamly(self, system, history, gen_conf): class VolcEngineChat(Base): - def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"): - super().__init__(key, model_name, base_url=None) - + def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3", **kwargs): """ Since do not want to modify the original database fields, and the VolcEngine authentication method is quite special, Assemble ark_api_key, ep_id into api_key, store it as a dictionary type, and parse it for use @@ -1074,7 +981,7 @@ def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/ base_url = base_url if base_url else "https://ark.cn-beijing.volces.com/api/v3" ark_api_key = json.loads(key).get("ark_api_key", "") model_name = json.loads(key).get("ep_id", "") + json.loads(key).get("endpoint_id", "") - super().__init__(ark_api_key, model_name, base_url) + super().__init__(ark_api_key, model_name, base_url, **kwargs) class MiniMaxChat(Base): @@ -1083,8 +990,9 @@ def __init__( key, model_name, base_url="https://api.minimax.chat/v1/text/chatcompletion_v2", + **kwargs ): - super().__init__(key, model_name, base_url=None) + super().__init__(key, model_name, base_url=base_url, **kwargs) if not base_url: base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2" @@ -1092,29 +1000,27 @@ def __init__( self.model_name = model_name self.api_key = key - def chat(self, system, history, gen_conf): - if system: - history.insert(0, {"role": "system", "content": system}) + def _clean_conf(self, gen_conf): for k in list(gen_conf.keys()): if k not in ["temperature", "top_p", "max_tokens"]: del gen_conf[k] + return gen_conf + + def _chat(self, history, gen_conf): headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", } payload = json.dumps({"model": self.model_name, "messages": history, **gen_conf}) - try: - response = requests.request("POST", url=self.base_url, headers=headers, data=payload) - response = response.json() - ans = response["choices"][0]["message"]["content"].strip() - if response["choices"][0]["finish_reason"] == "length": - if is_chinese(ans): - ans += LENGTH_NOTIFICATION_CN - else: - ans += LENGTH_NOTIFICATION_EN - return ans, self.total_token_count(response) - except Exception as e: - return "**ERROR**: " + str(e), 0 + response = requests.request("POST", url=self.base_url, headers=headers, data=payload) + response = response.json() + ans = response["choices"][0]["message"]["content"].strip() + if response["choices"][0]["finish_reason"] == "length": + if is_chinese(ans): + ans += LENGTH_NOTIFICATION_CN + else: + ans += 
LENGTH_NOTIFICATION_EN + return ans, self.total_token_count(response) def chat_streamly(self, system, history, gen_conf): if system: @@ -1163,31 +1069,29 @@ def chat_streamly(self, system, history, gen_conf): class MistralChat(Base): - def __init__(self, key, model_name, base_url=None): - super().__init__(key, model_name, base_url=None) + def __init__(self, key, model_name, base_url=None, **kwargs): + super().__init__(key, model_name, base_url=base_url, **kwargs) from mistralai.client import MistralClient self.client = MistralClient(api_key=key) self.model_name = model_name - def chat(self, system, history, gen_conf): - if system: - history.insert(0, {"role": "system", "content": system}) + def _clean_conf(self, gen_conf): for k in list(gen_conf.keys()): if k not in ["temperature", "top_p", "max_tokens"]: del gen_conf[k] - try: - response = self.client.chat(model=self.model_name, messages=history, **gen_conf) - ans = response.choices[0].message.content - if response.choices[0].finish_reason == "length": - if is_chinese(ans): - ans += LENGTH_NOTIFICATION_CN - else: - ans += LENGTH_NOTIFICATION_EN - return ans, self.total_token_count(response) - except openai.APIError as e: - return "**ERROR**: " + str(e), 0 + return gen_conf + + def _chat(self, history, gen_conf): + response = self.client.chat(model=self.model_name, messages=history, **gen_conf) + ans = response.choices[0].message.content + if response.choices[0].finish_reason == "length": + if is_chinese(ans): + ans += LENGTH_NOTIFICATION_CN + else: + ans += LENGTH_NOTIFICATION_EN + return ans, self.total_token_count(response) def chat_streamly(self, system, history, gen_conf): if system: @@ -1218,8 +1122,8 @@ def chat_streamly(self, system, history, gen_conf): class BedrockChat(Base): - def __init__(self, key, model_name, **kwargs): - super().__init__(key, model_name, base_url=None) + def __init__(self, key, model_name, base_url=None, **kwargs): + super().__init__(key, model_name, base_url=base_url, **kwargs) import boto3 @@ -1234,31 +1138,32 @@ def __init__(self, key, model_name, **kwargs): else: self.client = boto3.client(service_name="bedrock-runtime", region_name=self.bedrock_region, aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk) - def chat(self, system, history, gen_conf): - from botocore.exceptions import ClientError - + def _clean_conf(self, gen_conf): for k in list(gen_conf.keys()): if k not in ["temperature"]: del gen_conf[k] - for item in history: - if not isinstance(item["content"], list) and not isinstance(item["content"], tuple): - item["content"] = [{"text": item["content"]}] - - try: - # Send the message to the model, using a basic inference configuration. - response = self.client.converse( - modelId=self.model_name, - messages=history, - inferenceConfig=gen_conf, - system=[{"text": (system if system else "Answer the user's message.")}], - ) + return gen_conf - # Extract and print the response text. - ans = response["output"]["message"]["content"][0]["text"] - return ans, num_tokens_from_string(ans) + def _chat(self, history, gen_conf): + system = history[0]["content"] if history and history[0]["role"] == "system" else "" + hist = [] + for item in history: + if item["role"] == "system": + continue + hist.append(deepcopy(item)) + if not isinstance(hist[-1]["content"], list) and not isinstance(hist[-1]["content"], tuple): + hist[-1]["content"] = [{"text": hist[-1]["content"]}] + # Send the message to the model, using a basic inference configuration. 
+ response = self.client.converse( + modelId=self.model_name, + messages=hist, + inferenceConfig=gen_conf, + system=[{"text": (system if system else "Answer the user's message.")}], + ) - except (ClientError, Exception) as e: - return f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}", 0 + # Extract and print the response text. + ans = response["output"]["message"]["content"][0]["text"] + return ans, num_tokens_from_string(ans) def chat_streamly(self, system, history, gen_conf): from botocore.exceptions import ClientError @@ -1299,8 +1204,8 @@ def chat_streamly(self, system, history, gen_conf): class GeminiChat(Base): - def __init__(self, key, model_name, base_url=None): - super().__init__(key, model_name, base_url=None) + def __init__(self, key, model_name, base_url=None, **kwargs): + super().__init__(key, model_name, base_url=base_url, **kwargs) from google.generativeai import GenerativeModel, client @@ -1310,15 +1215,21 @@ def __init__(self, key, model_name, base_url=None): self.model = GenerativeModel(model_name=self.model_name) self.model._client = _client - def chat(self, system, history, gen_conf): - from google.generativeai.types import content_types - - if system: - self.model._system_instruction = content_types.to_content(system) + def _clean_conf(self, gen_conf): for k in list(gen_conf.keys()): - if k not in ["temperature", "top_p", "max_tokens"]: + if k not in ["temperature", "top_p"]: del gen_conf[k] + return gen_conf + + def _chat(self, history, gen_conf): + from google.generativeai.types import content_types + system = history[0]["content"] if history and history[0]["role"] == "system" else "" + hist = [] for item in history: + if item["role"] == "system": + continue + hist.append(deepcopy(item)) + item = hist[-1] if "role" in item and item["role"] == "assistant": item["role"] = "model" if "role" in item and item["role"] == "system": @@ -1326,12 +1237,11 @@ def chat(self, system, history, gen_conf): if "content" in item: item["parts"] = item.pop("content") - try: - response = self.model.generate_content(history, generation_config=gen_conf) - ans = response.text - return ans, response.usage_metadata.total_token_count - except Exception as e: - return "**ERROR**: " + str(e), 0 + if system: + self.model._system_instruction = content_types.to_content(system) + response = self.model.generate_content(hist, generation_config=gen_conf) + ans = response.text + return ans, response.usage_metadata.total_token_count def chat_streamly(self, system, history, gen_conf): from google.generativeai.types import content_types @@ -1361,32 +1271,19 @@ def chat_streamly(self, system, history, gen_conf): class GroqChat(Base): - def __init__(self, key, model_name, base_url=""): - super().__init__(key, model_name, base_url=None) + def __init__(self, key, model_name, base_url=None, **kwargs): + super().__init__(key, model_name, base_url=base_url, **kwargs) from groq import Groq self.client = Groq(api_key=key) self.model_name = model_name - def chat(self, system, history, gen_conf): - if system: - history.insert(0, {"role": "system", "content": system}) + def _clean_conf(self, gen_conf): for k in list(gen_conf.keys()): if k not in ["temperature", "top_p", "max_tokens"]: del gen_conf[k] - ans = "" - try: - response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf) - ans = response.choices[0].message.content - if response.choices[0].finish_reason == "length": - if is_chinese(ans): - ans += LENGTH_NOTIFICATION_CN - else: - ans += LENGTH_NOTIFICATION_EN - 
return ans, self.total_token_count(response) - except Exception as e: - return ans + "\n**ERROR**: " + str(e), 0 + return gen_conf def chat_streamly(self, system, history, gen_conf): if system: @@ -1418,33 +1315,32 @@ def chat_streamly(self, system, history, gen_conf): ## openrouter class OpenRouterChat(Base): - def __init__(self, key, model_name, base_url="https://openrouter.ai/api/v1"): + def __init__(self, key, model_name, base_url="https://openrouter.ai/api/v1", **kwargs): if not base_url: base_url = "https://openrouter.ai/api/v1" - super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url, **kwargs) class StepFunChat(Base): - def __init__(self, key, model_name, base_url="https://api.stepfun.com/v1"): + def __init__(self, key, model_name, base_url="https://api.stepfun.com/v1", **kwargs): if not base_url: base_url = "https://api.stepfun.com/v1" - super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url, **kwargs) class NvidiaChat(Base): - def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1"): + def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1", **kwargs): if not base_url: base_url = "https://integrate.api.nvidia.com/v1" - super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url, **kwargs) class LmStudioChat(Base): - def __init__(self, key, model_name, base_url): + def __init__(self, key, model_name, base_url, **kwargs): if not base_url: raise ValueError("Local llm url cannot be None") - if base_url.split("/")[-1] != "v1": - base_url = os.path.join(base_url, "v1") - super().__init__(key, model_name, base_url) + base_url = urljoin(base_url, "v1") + super().__init__(key, model_name, base_url, **kwargs) self.client = OpenAI(api_key="lm-studio", base_url=base_url) self.model_name = model_name @@ -1458,50 +1354,50 @@ def __init__(self, key, model_name, base_url): class PPIOChat(Base): - def __init__(self, key, model_name, base_url="https://api.ppinfra.com/v3/openai"): + def __init__(self, key, model_name, base_url="https://api.ppinfra.com/v3/openai", **kwargs): if not base_url: base_url = "https://api.ppinfra.com/v3/openai" - super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url, **kwargs) class CoHereChat(Base): - def __init__(self, key, model_name, base_url=""): - super().__init__(key, model_name, base_url=None) + def __init__(self, key, model_name, base_url=None, **kwargs): + super().__init__(key, model_name, base_url=base_url, **kwargs) from cohere import Client self.client = Client(api_key=key) self.model_name = model_name - def chat(self, system, history, gen_conf): - if system: - history.insert(0, {"role": "system", "content": system}) + def _clean_conf(self, gen_conf): if "max_tokens" in gen_conf: del gen_conf["max_tokens"] if "top_p" in gen_conf: gen_conf["p"] = gen_conf.pop("top_p") if "frequency_penalty" in gen_conf and "presence_penalty" in gen_conf: gen_conf.pop("presence_penalty") + return gen_conf + + def _chat(self, history, gen_conf): + hist = [] for item in history: + hist.append(deepcopy(item)) + item = hist[-1] if "role" in item and item["role"] == "user": item["role"] = "USER" if "role" in item and item["role"] == "assistant": item["role"] = "CHATBOT" if "content" in item: item["message"] = item.pop("content") - mes = history.pop()["message"] - ans = "" - try: - response = self.client.chat(model=self.model_name, chat_history=history, message=mes, **gen_conf) - ans = response.text - 
if response.finish_reason == "MAX_TOKENS": - ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" - return ( - ans, - response.meta.tokens.input_tokens + response.meta.tokens.output_tokens, - ) - except Exception as e: - return ans + "\n**ERROR**: " + str(e), 0 + mes = hist.pop()["message"] + response = self.client.chat(model=self.model_name, chat_history=hist, message=mes, **gen_conf) + ans = response.text + if response.finish_reason == "MAX_TOKENS": + ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" + return ( + ans, + response.meta.tokens.input_tokens + response.meta.tokens.output_tokens, + ) def chat_streamly(self, system, history, gen_conf): if system: @@ -1540,92 +1436,82 @@ def chat_streamly(self, system, history, gen_conf): class LeptonAIChat(Base): - def __init__(self, key, model_name, base_url=None): + def __init__(self, key, model_name, base_url=None, **kwargs): if not base_url: - base_url = os.path.join("https://" + model_name + ".lepton.run", "api", "v1") - super().__init__(key, model_name, base_url) + base_url = urljoin("https://" + model_name + ".lepton.run", "api/v1") + super().__init__(key, model_name, base_url, **kwargs) class TogetherAIChat(Base): - def __init__(self, key, model_name, base_url="https://api.together.xyz/v1"): + def __init__(self, key, model_name, base_url="https://api.together.xyz/v1", **kwargs): if not base_url: base_url = "https://api.together.xyz/v1" - super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url, **kwargs) class PerfXCloudChat(Base): - def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"): + def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1", **kwargs): if not base_url: base_url = "https://cloud.perfxlab.cn/v1" - super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url, **kwargs) class UpstageChat(Base): - def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar"): + def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar", **kwargs): if not base_url: base_url = "https://api.upstage.ai/v1/solar" - super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url, **kwargs) class NovitaAIChat(Base): - def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai"): + def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai", **kwargs): if not base_url: base_url = "https://api.novita.ai/v3/openai" - super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url, **kwargs) class SILICONFLOWChat(Base): - def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1"): + def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1", **kwargs): if not base_url: base_url = "https://api.siliconflow.cn/v1" - super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url, **kwargs) class YiChat(Base): - def __init__(self, key, model_name, base_url="https://api.lingyiwanwu.com/v1"): + def __init__(self, key, model_name, base_url="https://api.lingyiwanwu.com/v1", **kwargs): if not base_url: base_url = "https://api.lingyiwanwu.com/v1" - super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url, **kwargs) class ReplicateChat(Base): - def __init__(self, key, model_name, base_url=None): - 
super().__init__(key, model_name, base_url=None) + def __init__(self, key, model_name, base_url=None, **kwargs): + super().__init__(key, model_name, base_url=base_url, **kwargs) from replicate.client import Client self.model_name = model_name self.client = Client(api_token=key) - self.system = "" - def chat(self, system, history, gen_conf): - if "max_tokens" in gen_conf: - del gen_conf["max_tokens"] - if system: - self.system = system - prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:]]) - ans = "" - try: - response = self.client.run( - self.model_name, - input={"system_prompt": self.system, "prompt": prompt, **gen_conf}, - ) - ans = "".join(response) - return ans, num_tokens_from_string(ans) - except Exception as e: - return ans + "\n**ERROR**: " + str(e), 0 + def _chat(self, history, gen_conf): + system = history[0]["content"] if history and history[0]["role"] == "system" else "" + prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:] if item["role"] != "system"]) + response = self.client.run( + self.model_name, + input={"system_prompt": system, "prompt": prompt, **gen_conf}, + ) + ans = "".join(response) + return ans, num_tokens_from_string(ans) def chat_streamly(self, system, history, gen_conf): if "max_tokens" in gen_conf: del gen_conf["max_tokens"] - if system: - self.system = system prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:]]) ans = "" try: response = self.client.run( self.model_name, - input={"system_prompt": self.system, "prompt": prompt, **gen_conf}, + input={"system_prompt": system, "prompt": prompt, **gen_conf}, ) for resp in response: ans = resp @@ -1638,8 +1524,8 @@ def chat_streamly(self, system, history, gen_conf): class HunyuanChat(Base): - def __init__(self, key, model_name, base_url=None): - super().__init__(key, model_name, base_url=None) + def __init__(self, key, model_name, base_url=None, **kwargs): + super().__init__(key, model_name, base_url=base_url, **kwargs) from tencentcloud.common import credential from tencentcloud.hunyuan.v20230901 import hunyuan_client @@ -1651,33 +1537,24 @@ def __init__(self, key, model_name, base_url=None): self.model_name = model_name self.client = hunyuan_client.HunyuanClient(cred, "") - def chat(self, system, history, gen_conf): - from tencentcloud.common.exception.tencent_cloud_sdk_exception import ( - TencentCloudSDKException, - ) - from tencentcloud.hunyuan.v20230901 import models - + def _clean_conf(self, gen_conf): _gen_conf = {} - _history = [{k.capitalize(): v for k, v in item.items()} for item in history] - if system: - _history.insert(0, {"Role": "system", "Content": system}) - if "max_tokens" in gen_conf: - del gen_conf["max_tokens"] if "temperature" in gen_conf: _gen_conf["Temperature"] = gen_conf["temperature"] if "top_p" in gen_conf: _gen_conf["TopP"] = gen_conf["top_p"] + return _gen_conf + + def _chat(self, history, gen_conf): + from tencentcloud.hunyuan.v20230901 import models + hist = [{k.capitalize(): v for k, v in item.items()} for item in history] req = models.ChatCompletionsRequest() - params = {"Model": self.model_name, "Messages": _history, **_gen_conf} + params = {"Model": self.model_name, "Messages": hist, **gen_conf} req.from_json_string(json.dumps(params)) - ans = "" - try: - response = self.client.ChatCompletions(req) - ans = response.Choices[0].Message.Content - return ans, response.Usage.TotalTokens - except TencentCloudSDKException as e: - return ans + "\n**ERROR**: " + str(e), 0 + response = 
self.client.ChatCompletions(req) + ans = response.Choices[0].Message.Content + return ans, response.Usage.TotalTokens def chat_streamly(self, system, history, gen_conf): from tencentcloud.common.exception.tencent_cloud_sdk_exception import ( @@ -1723,7 +1600,7 @@ def chat_streamly(self, system, history, gen_conf): class SparkChat(Base): - def __init__(self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1"): + def __init__(self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1", **kwargs): if not base_url: base_url = "https://spark-api-open.xf-yun.com/v1" model2version = { @@ -1739,12 +1616,12 @@ def __init__(self, key, model_name, base_url="https://spark-api-open.xf-yun.com/ model_version = model2version[model_name] else: model_version = model_name - super().__init__(key, model_version, base_url) + super().__init__(key, model_version, base_url, **kwargs) class BaiduYiyanChat(Base): - def __init__(self, key, model_name, base_url=None): - super().__init__(key, model_name, base_url=None) + def __init__(self, key, model_name, base_url=None, **kwargs): + super().__init__(key, model_name, base_url=base_url, **kwargs) import qianfan @@ -1753,27 +1630,20 @@ def __init__(self, key, model_name, base_url=None): sk = key.get("yiyan_sk", "") self.client = qianfan.ChatCompletion(ak=ak, sk=sk) self.model_name = model_name.lower() - self.system = "" - def chat(self, system, history, gen_conf): - if system: - self.system = system + def _clean_conf(self, gen_conf): gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1 if "max_tokens" in gen_conf: del gen_conf["max_tokens"] - ans = "" - - try: - response = self.client.do(model=self.model_name, messages=history, system=self.system, **gen_conf).body - ans = response["result"] - return ans, self.total_token_count(response) + return gen_conf - except Exception as e: - return ans + "\n**ERROR**: " + str(e), 0 + def _chat(self, history, gen_conf): + system = history[0]["content"] if history and history[0]["role"] == "system" else "" + response = self.client.do(model=self.model_name, messages=[h for h in history if h["role"] != "system"], system=system, **gen_conf).body + ans = response["result"] + return ans, self.total_token_count(response) def chat_streamly(self, system, history, gen_conf): - if system: - self.system = system gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1 if "max_tokens" in gen_conf: del gen_conf["max_tokens"] @@ -1781,7 +1651,7 @@ def chat_streamly(self, system, history, gen_conf): total_tokens = 0 try: - response = self.client.do(model=self.model_name, messages=history, system=self.system, stream=True, **gen_conf) + response = self.client.do(model=self.model_name, messages=history, system=system, stream=True, **gen_conf) for resp in response: resp = resp.body ans = resp["result"] @@ -1796,18 +1666,15 @@ def chat_streamly(self, system, history, gen_conf): class AnthropicChat(Base): - def __init__(self, key, model_name, base_url=None): - super().__init__(key, model_name, base_url=None) + def __init__(self, key, model_name, base_url=None, **kwargs): + super().__init__(key, model_name, base_url=base_url, **kwargs) import anthropic self.client = anthropic.Anthropic(api_key=key) self.model_name = model_name - self.system = "" - def chat(self, system, history, gen_conf): - if system: - self.system = system + def _clean_conf(self, gen_conf): if "presence_penalty" in gen_conf: del 
gen_conf["presence_penalty"] if "frequency_penalty" in gen_conf: @@ -1815,29 +1682,26 @@ def chat(self, system, history, gen_conf): gen_conf["max_tokens"] = 8192 if "haiku" in self.model_name or "opus" in self.model_name: gen_conf["max_tokens"] = 4096 - - ans = "" - try: - response = self.client.messages.create( - model=self.model_name, - messages=history, - system=self.system, - stream=False, - **gen_conf, - ).to_dict() - ans = response["content"][0]["text"] - if response["stop_reason"] == "max_tokens": - ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" - return ( - ans, - response["usage"]["input_tokens"] + response["usage"]["output_tokens"], - ) - except Exception as e: - return ans + "\n**ERROR**: " + str(e), 0 + return gen_conf + + def _chat(self, history, gen_conf): + system = history[0]["content"] if history and history[0]["role"] == "system" else "" + response = self.client.messages.create( + model=self.model_name, + messages=[h for h in history if h["role"] != "system"], + system=system, + stream=False, + **gen_conf, + ).to_dict() + ans = response["content"][0]["text"] + if response["stop_reason"] == "max_tokens": + ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" + return ( + ans, + response["usage"]["input_tokens"] + response["usage"]["output_tokens"], + ) def chat_streamly(self, system, history, gen_conf): - if system: - self.system = system if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"] if "frequency_penalty" in gen_conf: @@ -1878,8 +1742,8 @@ def chat_streamly(self, system, history, gen_conf): class GoogleChat(Base): - def __init__(self, key, model_name, base_url=None): - super().__init__(key, model_name, base_url=None) + def __init__(self, key, model_name, base_url=None, **kwargs): + super().__init__(key, model_name, base_url=base_url, **kwargs) import base64 @@ -1892,7 +1756,6 @@ def __init__(self, key, model_name, base_url=None): scopes = ["https://www.googleapis.com/auth/cloud-platform"] self.model_name = model_name - self.system = "" if "claude" in self.model_name: from anthropic import AnthropicVertex @@ -1917,53 +1780,53 @@ def __init__(self, key, model_name, base_url=None): aiplatform.init(project=project_id, location=region) self.client = glm.GenerativeModel(model_name=self.model_name) - def chat(self, system, history, gen_conf): - if system: - self.system = system - + def _clean_conf(self, gen_conf): if "claude" in self.model_name: if "max_tokens" in gen_conf: del gen_conf["max_tokens"] - try: - response = self.client.messages.create( - model=self.model_name, - messages=history, - system=self.system, - stream=False, - **gen_conf, - ).json() - ans = response["content"][0]["text"] - if response["stop_reason"] == "max_tokens": - ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" 
- return ( - ans, - response["usage"]["input_tokens"] + response["usage"]["output_tokens"], - ) - except Exception as e: - return "\n**ERROR**: " + str(e), 0 else: - self.client._system_instruction = self.system if "max_tokens" in gen_conf: gen_conf["max_output_tokens"] = gen_conf["max_tokens"] for k in list(gen_conf.keys()): if k not in ["temperature", "top_p", "max_output_tokens"]: del gen_conf[k] - for item in history: - if "role" in item and item["role"] == "assistant": - item["role"] = "model" - if "content" in item: - item["parts"] = item.pop("content") - try: - response = self.client.generate_content(history, generation_config=gen_conf) - ans = response.text - return ans, response.usage_metadata.total_token_count - except Exception as e: - return "**ERROR**: " + str(e), 0 + return gen_conf - def chat_streamly(self, system, history, gen_conf): - if system: - self.system = system + def _chat(self, history, gen_conf): + system = history[0]["content"] if history and history[0]["role"] == "system" else "" + if "claude" in self.model_name: + response = self.client.messages.create( + model=self.model_name, + messages=[h for h in history if h["role"] != "system"], + system=system, + stream=False, + **gen_conf, + ).json() + ans = response["content"][0]["text"] + if response["stop_reason"] == "max_tokens": + ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" + return ( + ans, + response["usage"]["input_tokens"] + response["usage"]["output_tokens"], + ) + self.client._system_instruction = system + hist = [] + for item in history: + if item["role"] == "system": + continue + hist.append(deepcopy(item)) + item = hist[-1] + if "role" in item and item["role"] == "assistant": + item["role"] = "model" + if "content" in item: + item["parts"] = item.pop("content") + + response = self.client.generate_content(hist, generation_config=gen_conf) + ans = response.text + return ans, response.usage_metadata.total_token_count + + def chat_streamly(self, system, history, gen_conf): if "claude" in self.model_name: if "max_tokens" in gen_conf: del gen_conf["max_tokens"] @@ -1973,7 +1836,7 @@ def chat_streamly(self, system, history, gen_conf): response = self.client.messages.create( model=self.model_name, messages=history, - system=self.system, + system=system, stream=True, **gen_conf, ) @@ -1988,7 +1851,7 @@ def chat_streamly(self, system, history, gen_conf): yield total_tokens else: - self.client._system_instruction = self.system + self.client._system_instruction = system if "max_tokens" in gen_conf: gen_conf["max_output_tokens"] = gen_conf["max_tokens"] for k in list(gen_conf.keys()): @@ -2013,9 +1876,8 @@ def chat_streamly(self, system, history, gen_conf): class GPUStackChat(Base): - def __init__(self, key=None, model_name="", base_url=""): + def __init__(self, key=None, model_name="", base_url="", **kwargs): if not base_url: raise ValueError("Local llm url cannot be None") - if base_url.split("/")[-1] != "v1": - base_url = os.path.join(base_url, "v1") - super().__init__(key, model_name, base_url) + base_url = urljoin(base_url, "v1") + super().__init__(key, model_name, base_url, **kwargs) \ No newline at end of file diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py index f9d4e67c185..82640b56f53 100644 --- a/rag/llm/cv_model.py +++ b/rag/llm/cv_model.py @@ -19,6 +19,7 @@ import os from abc import ABC from io import BytesIO +from urllib.parse import urljoin import requests from ollama import Client @@ -322,7 +323,9 @@ def chat(self, 
system, history, gen_conf, image=""): ans = "" tk_count = 0 if response.status_code == HTTPStatus.OK: - ans += response.output.choices[0]['message']['content'] + ans = response.output.choices[0]['message']['content'] + if isinstance(ans, list): + ans = ans[0]["text"] if ans else "" tk_count += response.usage.total_tokens if response.output.choices[0].get("finish_reason", "") == "length": ans += "...\nFor the content length reason, it stopped, continue?" if is_english( @@ -351,7 +354,10 @@ def chat_streamly(self, system, history, gen_conf, image=""): stream=True) for resp in response: if resp.status_code == HTTPStatus.OK: - ans = resp.output.choices[0]['message']['content'] + cnt = resp.output.choices[0]['message']['content'] + if isinstance(cnt, list): + cnt = cnt[0]["text"] if cnt else "" + ans += cnt tk_count = resp.usage.total_tokens if resp.output.choices[0].get("finish_reason", "") == "length": ans += "...\nFor the content length reason, it stopped, continue?" if is_english( @@ -500,7 +506,8 @@ def chat(self, system, history, gen_conf, image=""): response = self.client.chat( model=self.model_name, messages=history, - options=options + options=options, + keep_alive=-1 ) ans = response["message"]["content"].strip() @@ -530,7 +537,8 @@ def chat_streamly(self, system, history, gen_conf, image=""): model=self.model_name, messages=history, stream=True, - options=options + options=options, + keep_alive=-1 ) for resp in response: if resp["done"]: @@ -546,8 +554,7 @@ class LocalAICV(GptV4): def __init__(self, key, model_name, base_url, lang="Chinese"): if not base_url: raise ValueError("Local cv model url cannot be None") - if base_url.split("/")[-1] != "v1": - base_url = os.path.join(base_url, "v1") + base_url = urljoin(base_url, "v1") self.client = OpenAI(api_key="empty", base_url=base_url) self.model_name = model_name.split("___")[0] self.lang = lang @@ -555,8 +562,7 @@ def __init__(self, key, model_name, base_url, lang="Chinese"): class XinferenceCV(Base): def __init__(self, key, model_name="", lang="Chinese", base_url=""): - if base_url.split("/")[-1] != "v1": - base_url = os.path.join(base_url, "v1") + base_url = urljoin(base_url, "v1") self.client = OpenAI(api_key=key, base_url=base_url) self.model_name = model_name self.lang = lang @@ -706,11 +712,9 @@ def __init__( self.lang = lang factory, llm_name = model_name.split("/") if factory != "liuhaotian": - self.base_url = os.path.join(base_url, factory, llm_name) + self.base_url = urljoin(base_url, f"{factory}/{llm_name}") else: - self.base_url = os.path.join( - base_url, "community", llm_name.replace("-v1.6", "16") - ) + self.base_url = urljoin(f"{base_url}/community", llm_name.replace("-v1.6", "16")) self.key = key def describe(self, image): @@ -799,8 +803,7 @@ class LmStudioCV(GptV4): def __init__(self, key, model_name, lang="Chinese", base_url=""): if not base_url: raise ValueError("Local llm url cannot be None") - if base_url.split("/")[-1] != "v1": - base_url = os.path.join(base_url, "v1") + base_url = urljoin(base_url, "v1") self.client = OpenAI(api_key="lm-studio", base_url=base_url) self.model_name = model_name self.lang = lang @@ -810,8 +813,7 @@ class OpenAI_APICV(GptV4): def __init__(self, key, model_name, lang="Chinese", base_url=""): if not base_url: raise ValueError("url cannot be None") - if base_url.split("/")[-1] != "v1": - base_url = os.path.join(base_url, "v1") + base_url = urljoin(base_url, "v1") self.client = OpenAI(api_key=key, base_url=base_url) self.model_name = model_name.split("___")[0] self.lang = lang @@ 
-1032,8 +1034,7 @@ class GPUStackCV(GptV4): def __init__(self, key, model_name, lang="Chinese", base_url=""): if not base_url: raise ValueError("Local llm url cannot be None") - if base_url.split("/")[-1] != "v1": - base_url = os.path.join(base_url, "v1") + base_url = urljoin(base_url, "v1") self.client = OpenAI(api_key=key, base_url=base_url) self.model_name = model_name self.lang = lang \ No newline at end of file diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 5c978040b68..25cbc5250e0 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -16,6 +16,8 @@ import logging import re import threading +from urllib.parse import urljoin + import requests from huggingface_hub import snapshot_download from zhipuai import ZhipuAI @@ -29,6 +31,7 @@ from api import settings from api.utils.file_utils import get_home_cache_dir +from api.utils.log_utils import log_exception from rag.utils import num_tokens_from_string, truncate import google.generativeai as genai import json @@ -127,8 +130,11 @@ def encode(self, texts: list): for i in range(0, len(texts), batch_size): res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name) - ress.extend([d.embedding for d in res.data]) - total_tokens += self.total_token_count(res) + try: + ress.extend([d.embedding for d in res.data]) + total_tokens += self.total_token_count(res) + except Exception as _e: + log_exception(_e, res) return np.array(ress), total_tokens def encode_queries(self, text): @@ -141,8 +147,7 @@ class LocalAIEmbed(Base): def __init__(self, key, model_name, base_url): if not base_url: raise ValueError("Local embedding model url cannot be None") - if base_url.split("/")[-1] != "v1": - base_url = os.path.join(base_url, "v1") + base_url = urljoin(base_url, "v1") self.client = OpenAI(api_key="empty", base_url=base_url) self.model_name = model_name.split("___")[0] @@ -151,7 +156,10 @@ def encode(self, texts: list): ress = [] for i in range(0, len(texts), batch_size): res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name) - ress.extend([d.embedding for d in res.data]) + try: + ress.extend([d.embedding for d in res.data]) + except Exception as _e: + log_exception(_e, res) # local embedding for LmStudio donot count tokens return np.array(ress), 1024 @@ -186,40 +194,39 @@ def __init__(self, key, model_name="text_embedding_v2", **kwargs): def encode(self, texts: list): import dashscope batch_size = 4 - try: - res = [] - token_count = 0 - texts = [truncate(t, 2048) for t in texts] - for i in range(0, len(texts), batch_size): - resp = dashscope.TextEmbedding.call( - model=self.model_name, - input=texts[i:i + batch_size], - api_key=self.key, - text_type="document" - ) + res = [] + token_count = 0 + texts = [truncate(t, 2048) for t in texts] + for i in range(0, len(texts), batch_size): + resp = dashscope.TextEmbedding.call( + model=self.model_name, + input=texts[i:i + batch_size], + api_key=self.key, + text_type="document" + ) + try: embds = [[] for _ in range(len(resp["output"]["embeddings"]))] for e in resp["output"]["embeddings"]: embds[e["text_index"]] = e["embedding"] res.extend(embds) token_count += self.total_token_count(resp) - return np.array(res), token_count - except Exception as e: - raise Exception("Account abnormal. 
Please ensure it's on good standing to use QWen's "+self.model_name) - return np.array([]), 0 + except Exception as _e: + log_exception(_e, resp) + raise + return np.array(res), token_count def encode_queries(self, text): + resp = dashscope.TextEmbedding.call( + model=self.model_name, + input=text[:2048], + api_key=self.key, + text_type="query" + ) try: - resp = dashscope.TextEmbedding.call( - model=self.model_name, - input=text[:2048], - api_key=self.key, - text_type="query" - ) return np.array(resp["output"]["embeddings"][0] ["embedding"]), self.total_token_count(resp) - except Exception: - raise Exception("Account abnormal. Please ensure it's on good standing to use QWen's "+self.model_name) - return np.array([]), 0 + except Exception as _e: + log_exception(_e, resp) class ZhipuEmbed(Base): @@ -241,14 +248,20 @@ def encode(self, texts: list): for txt in texts: res = self.client.embeddings.create(input=txt, model=self.model_name) - arr.append(res.data[0].embedding) - tks_num += self.total_token_count(res) + try: + arr.append(res.data[0].embedding) + tks_num += self.total_token_count(res) + except Exception as _e: + log_exception(_e, res) return np.array(arr), tks_num def encode_queries(self, text): res = self.client.embeddings.create(input=text, model=self.model_name) - return np.array(res.data[0].embedding), self.total_token_count(res) + try: + return np.array(res.data[0].embedding), self.total_token_count(res) + except Exception as _e: + log_exception(_e, res) class OllamaEmbed(Base): @@ -264,7 +277,10 @@ def encode(self, texts: list): res = self.client.embeddings(prompt=txt, model=self.model_name, options={"use_mmap": True}) - arr.append(res["embedding"]) + try: + arr.append(res["embedding"]) + except Exception as _e: + log_exception(_e, res) tks_num += 128 return np.array(arr), tks_num @@ -272,7 +288,10 @@ def encode_queries(self, text): res = self.client.embeddings(prompt=text, model=self.model_name, options={"use_mmap": True}) - return np.array(res["embedding"]), 128 + try: + return np.array(res["embedding"]), 128 + except Exception as _e: + log_exception(_e, res) class FastEmbed(DefaultEmbedding): @@ -322,8 +341,7 @@ def encode_queries(self, text: str): class XinferenceEmbed(Base): def __init__(self, key, model_name="", base_url=""): - if base_url.split("/")[-1] != "v1": - base_url = os.path.join(base_url, "v1") + base_url = urljoin(base_url, "v1") self.client = OpenAI(api_key=key, base_url=base_url) self.model_name = model_name @@ -333,14 +351,20 @@ def encode(self, texts: list): total_tokens = 0 for i in range(0, len(texts), batch_size): res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name) - ress.extend([d.embedding for d in res.data]) - total_tokens += self.total_token_count(res) + try: + ress.extend([d.embedding for d in res.data]) + total_tokens += self.total_token_count(res) + except Exception as _e: + log_exception(_e, res) return np.array(ress), total_tokens def encode_queries(self, text): res = self.client.embeddings.create(input=[text], model=self.model_name) - return np.array(res.data[0].embedding), self.total_token_count(res) + try: + return np.array(res.data[0].embedding), self.total_token_count(res) + except Exception as _e: + log_exception(_e, res) class YoudaoEmbed(Base): @@ -397,9 +421,13 @@ def encode(self, texts: list): "input": texts[i:i + batch_size], 'encoding_type': 'float' } - res = requests.post(self.base_url, headers=self.headers, json=data).json() - ress.extend([d["embedding"] for d in res["data"]]) - token_count += 
self.total_token_count(res) + response = requests.post(self.base_url, headers=self.headers, json=data) + try: + res = response.json() + ress.extend([d["embedding"] for d in res["data"]]) + token_count += self.total_token_count(res) + except Exception as _e: + log_exception(_e, response) return np.array(ress), token_count def encode_queries(self, text): @@ -462,14 +490,20 @@ def encode(self, texts: list): for i in range(0, len(texts), batch_size): res = self.client.embeddings(input=texts[i:i + batch_size], model=self.model_name) - ress.extend([d.embedding for d in res.data]) - token_count += self.total_token_count(res) + try: + ress.extend([d.embedding for d in res.data]) + token_count += self.total_token_count(res) + except Exception as _e: + log_exception(_e, res) return np.array(ress), token_count def encode_queries(self, text): res = self.client.embeddings(input=[truncate(text, 8196)], model=self.model_name) - return np.array(res.data[0].embedding), self.total_token_count(res) + try: + return np.array(res.data[0].embedding), self.total_token_count(res) + except Exception as _e: + log_exception(_e, res) class BedrockEmbed(Base): @@ -499,9 +533,12 @@ def encode(self, texts: list): body = {"texts": [text], "input_type": 'search_document'} response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body)) - model_response = json.loads(response["body"].read()) - embeddings.extend([model_response["embedding"]]) - token_count += num_tokens_from_string(text) + try: + model_response = json.loads(response["body"].read()) + embeddings.extend([model_response["embedding"]]) + token_count += num_tokens_from_string(text) + except Exception as _e: + log_exception(_e, response) return np.array(embeddings), token_count @@ -514,8 +551,11 @@ def encode_queries(self, text): body = {"texts": [truncate(text, 8196)], "input_type": 'search_query'} response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body)) - model_response = json.loads(response["body"].read()) - embeddings.extend(model_response["embedding"]) + try: + model_response = json.loads(response["body"].read()) + embeddings.extend(model_response["embedding"]) + except Exception as _e: + log_exception(_e, response) return np.array(embeddings), token_count @@ -538,7 +578,10 @@ def encode(self, texts: list): content=texts[i: i + batch_size], task_type="retrieval_document", title="Embedding of single string") - ress.extend(result['embedding']) + try: + ress.extend(result['embedding']) + except Exception as _e: + log_exception(_e, result) return np.array(ress),token_count def encode_queries(self, text): @@ -549,7 +592,10 @@ def encode_queries(self, text): task_type="retrieval_document", title="Embedding of single string") token_count = num_tokens_from_string(text) - return np.array(result['embedding']), token_count + try: + return np.array(result['embedding']), token_count + except Exception as _e: + log_exception(_e, result) class NvidiaEmbed(Base): @@ -584,7 +630,11 @@ def encode(self, texts: list): "encoding_format": "float", "truncate": "END", } - res = requests.post(self.base_url, headers=self.headers, json=payload).json() + response = requests.post(self.base_url, headers=self.headers, json=payload) + try: + res = response.json() + except Exception as _e: + log_exception(_e, response) ress.extend([d["embedding"] for d in res["data"]]) token_count += self.total_token_count(res) return np.array(ress), token_count @@ -598,8 +648,7 @@ class LmStudioEmbed(LocalAIEmbed): def __init__(self, key, model_name, base_url): 
if not base_url: raise ValueError("Local llm url cannot be None") - if base_url.split("/")[-1] != "v1": - base_url = os.path.join(base_url, "v1") + base_url = urljoin(base_url, "v1") self.client = OpenAI(api_key="lm-studio", base_url=base_url) self.model_name = model_name @@ -608,8 +657,7 @@ class OpenAI_APIEmbed(OpenAIEmbed): def __init__(self, key, model_name, base_url): if not base_url: raise ValueError("url cannot be None") - if base_url.split("/")[-1] != "v1": - base_url = os.path.join(base_url, "v1") + base_url = urljoin(base_url, "v1") self.client = OpenAI(api_key=key, base_url=base_url) self.model_name = model_name.split("___")[0] @@ -632,8 +680,11 @@ def encode(self, texts: list): input_type="search_document", embedding_types=["float"], ) - ress.extend([d for d in res.embeddings.float]) - token_count += res.meta.billed_units.input_tokens + try: + ress.extend([d for d in res.embeddings.float]) + token_count += res.meta.billed_units.input_tokens + except Exception as _e: + log_exception(_e, res) return np.array(ress), token_count def encode_queries(self, text): @@ -643,9 +694,10 @@ def encode_queries(self, text): input_type="search_query", embedding_types=["float"], ) - return np.array(res.embeddings.float[0]), int( - res.meta.billed_units.input_tokens - ) + try: + return np.array(res.embeddings.float[0]), int(res.meta.billed_units.input_tokens) + except Exception as _e: + log_exception(_e, res) class TogetherAIEmbed(OpenAIEmbed): @@ -694,11 +746,14 @@ def encode(self, texts: list): "input": texts_batch, "encoding_format": "float", } - res = requests.post(self.base_url, json=payload, headers=self.headers).json() - if "data" not in res or not isinstance(res["data"], list) or len(res["data"]) != len(texts_batch): - raise ValueError(f"SILICONFLOWEmbed.encode got invalid response from {self.base_url}") - ress.extend([d["embedding"] for d in res["data"]]) - token_count += self.total_token_count(res) + response = requests.post(self.base_url, json=payload, headers=self.headers) + try: + res = response.json() + ress.extend([d["embedding"] for d in res["data"]]) + token_count += self.total_token_count(res) + except Exception as _e: + log_exception(_e, response) + return np.array(ress), token_count def encode_queries(self, text): @@ -707,10 +762,12 @@ def encode_queries(self, text): "input": text, "encoding_format": "float", } - res = requests.post(self.base_url, json=payload, headers=self.headers).json() - if "data" not in res or not isinstance(res["data"], list) or len(res["data"])!= 1: - raise ValueError(f"SILICONFLOWEmbed.encode_queries got invalid response from {self.base_url}") - return np.array(res["data"][0]["embedding"]), self.total_token_count(res) + response = requests.post(self.base_url, json=payload, headers=self.headers) + try: + res = response.json() + return np.array(res["data"][0]["embedding"]), self.total_token_count(res) + except Exception as _e: + log_exception(_e, response) class ReplicateEmbed(Base): @@ -746,17 +803,23 @@ def __init__(self, key, model_name, base_url=None): def encode(self, texts: list, batch_size=16): res = self.client.do(model=self.model_name, texts=texts).body - return ( - np.array([r["embedding"] for r in res["data"]]), - self.total_token_count(res), - ) + try: + return ( + np.array([r["embedding"] for r in res["data"]]), + self.total_token_count(res), + ) + except Exception as _e: + log_exception(_e, res) def encode_queries(self, text): res = self.client.do(model=self.model_name, texts=[text]).body - return ( - np.array([r["embedding"] for r in 
res["data"]]), - self.total_token_count(res), - ) + try: + return ( + np.array([r["embedding"] for r in res["data"]]), + self.total_token_count(res), + ) + except Exception as _e: + log_exception(_e, res) class VoyageEmbed(Base): @@ -774,15 +837,21 @@ def encode(self, texts: list): res = self.client.embed( texts=texts[i : i + batch_size], model=self.model_name, input_type="document" ) - ress.extend(res.embeddings) - token_count += res.total_tokens + try: + ress.extend(res.embeddings) + token_count += res.total_tokens + except Exception as _e: + log_exception(_e, res) return np.array(ress), token_count def encode_queries(self, text): res = self.client.embed( texts=text, model=self.model_name, input_type="query" ) - return np.array(res.embeddings)[0], res.total_tokens + try: + return np.array(res.embeddings)[0], res.total_tokens + except Exception as _e: + log_exception(_e, res) class HuggingFaceEmbed(Base): @@ -802,11 +871,14 @@ def encode(self, texts: list): headers={'Content-Type': 'application/json'} ) if response.status_code == 200: - embedding = response.json() - embeddings.append(embedding[0]) + try: + embedding = response.json() + embeddings.append(embedding[0]) + return np.array(embeddings), sum([num_tokens_from_string(text) for text in texts]) + except Exception as _e: + log_exception(_e, response) else: raise Exception(f"Error: {response.status_code} - {response.text}") - return np.array(embeddings), sum([num_tokens_from_string(text) for text in texts]) def encode_queries(self, text): response = requests.post( @@ -815,8 +887,11 @@ def encode_queries(self, text): headers={'Content-Type': 'application/json'} ) if response.status_code == 200: - embedding = response.json() - return np.array(embedding[0]), num_tokens_from_string(text) + try: + embedding = response.json() + return np.array(embedding[0]), num_tokens_from_string(text) + except Exception as _e: + log_exception(_e, response) else: raise Exception(f"Error: {response.status_code} - {response.text}") @@ -829,12 +904,17 @@ def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/ model_name = json.loads(key).get('ep_id', '') + json.loads(key).get('endpoint_id', '') super().__init__(ark_api_key,model_name,base_url) + class GPUStackEmbed(OpenAIEmbed): def __init__(self, key, model_name, base_url): if not base_url: raise ValueError("url cannot be None") - if base_url.split("/")[-1] != "v1": - base_url = os.path.join(base_url, "v1") + base_url = urljoin(base_url, "v1") self.client = OpenAI(api_key=key, base_url=base_url) self.model_name = model_name + + +class NovitaEmbed(SILICONFLOWEmbed): + def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/embeddings"): + super().__init__(key, model_name, base_url) \ No newline at end of file diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py index 09631e9933e..5310a5f3baa 100644 --- a/rag/llm/rerank_model.py +++ b/rag/llm/rerank_model.py @@ -28,6 +28,7 @@ from api import settings from api.utils.file_utils import get_home_cache_dir +from api.utils.log_utils import log_exception from rag.utils import num_tokens_from_string, truncate import json @@ -170,8 +171,11 @@ def similarity(self, query: str, texts: list): } res = requests.post(self.base_url, headers=self.headers, json=data).json() rank = np.zeros(len(texts), dtype=float) - for d in res["results"]: - rank[d["index"]] = d["relevance_score"] + try: + for d in res["results"]: + rank[d["index"]] = d["relevance_score"] + except Exception as _e: + log_exception(_e, res) return rank, 
self.total_token_count(res) @@ -238,8 +242,11 @@ def similarity(self, query: str, texts: list): } res = requests.post(self.base_url, headers=self.headers, json=data).json() rank = np.zeros(len(texts), dtype=float) - for d in res["results"]: - rank[d["index"]] = d["relevance_score"] + try: + for d in res["results"]: + rank[d["index"]] = d["relevance_score"] + except Exception as _e: + log_exception(_e, res) return rank, token_count @@ -269,10 +276,11 @@ def similarity(self, query: str, texts: list): token_count += num_tokens_from_string(t) res = requests.post(self.base_url, headers=self.headers, json=data).json() rank = np.zeros(len(texts), dtype=float) - if 'results' not in res: - raise ValueError("response not contains results\n" + str(res)) - for d in res["results"]: - rank[d["index"]] = d["relevance_score"] + try: + for d in res["results"]: + rank[d["index"]] = d["relevance_score"] + except Exception as _e: + log_exception(_e, res) # Normalize the rank values to the range 0 to 1 min_rank = np.min(rank) @@ -296,12 +304,11 @@ def __init__( self.model_name = model_name if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3": - self.base_url = os.path.join( - base_url, "nv-rerankqa-mistral-4b-v3", "reranking" + self.base_url = urljoin(base_url, "nv-rerankqa-mistral-4b-v3/reranking" ) if self.model_name == "nvidia/rerank-qa-mistral-4b": - self.base_url = os.path.join(base_url, "reranking") + self.base_url = urljoin(base_url, "reranking") self.model_name = "nv-rerank-qa-mistral-4b:1" self.headers = { @@ -323,8 +330,11 @@ def similarity(self, query: str, texts: list): } res = requests.post(self.base_url, headers=self.headers, json=data).json() rank = np.zeros(len(texts), dtype=float) - for d in res["rankings"]: - rank[d["index"]] = d["logit"] + try: + for d in res["rankings"]: + rank[d["index"]] = d["logit"] + except Exception as _e: + log_exception(_e, res) return rank, token_count @@ -362,10 +372,11 @@ def similarity(self, query: str, texts: list): token_count += num_tokens_from_string(t) res = requests.post(self.base_url, headers=self.headers, json=data).json() rank = np.zeros(len(texts), dtype=float) - if 'results' not in res: - raise ValueError("response not contains results\n" + str(res)) - for d in res["results"]: - rank[d["index"]] = d["relevance_score"] + try: + for d in res["results"]: + rank[d["index"]] = d["relevance_score"] + except Exception as _e: + log_exception(_e, res) # Normalize the rank values to the range 0 to 1 min_rank = np.min(rank) @@ -399,8 +410,11 @@ def similarity(self, query: str, texts: list): return_documents=False, ) rank = np.zeros(len(texts), dtype=float) - for d in res.results: - rank[d.index] = d.relevance_score + try: + for d in res.results: + rank[d.index] = d.relevance_score + except Exception as _e: + log_exception(_e, res) return rank, token_count @@ -440,11 +454,11 @@ def similarity(self, query: str, texts: list): self.base_url, json=payload, headers=self.headers ).json() rank = np.zeros(len(texts), dtype=float) - if "results" not in response: - return rank, 0 - - for d in response["results"]: - rank[d["index"]] = d["relevance_score"] + try: + for d in response["results"]: + rank[d["index"]] = d["relevance_score"] + except Exception as _e: + log_exception(_e, response) return ( rank, response["meta"]["tokens"]["input_tokens"] + response["meta"]["tokens"]["output_tokens"], @@ -469,8 +483,11 @@ def similarity(self, query: str, texts: list): top_n=len(texts), ).body rank = np.zeros(len(texts), dtype=float) - for d in res["results"]: - rank[d["index"]] = 
d["relevance_score"] + try: + for d in res["results"]: + rank[d["index"]] = d["relevance_score"] + except Exception as _e: + log_exception(_e, res) return rank, self.total_token_count(res) @@ -488,8 +505,11 @@ def similarity(self, query: str, texts: list): res = self.client.rerank( query=query, documents=texts, model=self.model_name, top_k=len(texts) ) - for r in res.results: - rank[r.index] = r.relevance_score + try: + for r in res.results: + rank[r.index] = r.relevance_score + except Exception as _e: + log_exception(_e, res) return rank, res.total_tokens @@ -512,8 +532,11 @@ def similarity(self, query: str, texts: list): ) rank = np.zeros(len(texts), dtype=float) if resp.status_code == HTTPStatus.OK: - for r in resp.output.results: - rank[r.index] = r.relevance_score + try: + for r in resp.output.results: + rank[r.index] = r.relevance_score + except Exception as _e: + log_exception(_e, resp) return rank, resp.usage.total_tokens else: raise ValueError(f"Error calling QWenRerank model {self.model_name}: {resp.status_code} - {resp.text}") @@ -530,6 +553,7 @@ def post(query: str, texts: list, url="127.0.0.1"): res = requests.post(f"http://{url}/rerank", headers={"Content-Type": "application/json"}, json={"query": query, "texts": texts[i: i + batch_size], "raw_scores": False, "truncate": True}) + for o in res.json(): scores[o["index"] + i] = o["score"] except Exception as e: @@ -583,15 +607,15 @@ def similarity(self, query: str, texts: list): response_json = response.json() rank = np.zeros(len(texts), dtype=float) - if "results" not in response_json: - return rank, 0 token_count = 0 for t in texts: token_count += num_tokens_from_string(t) - - for result in response_json["results"]: - rank[result["index"]] = result["relevance_score"] + try: + for result in response_json["results"]: + rank[result["index"]] = result["relevance_score"] + except Exception as _e: + log_exception(_e, response) return ( rank, @@ -602,3 +626,7 @@ def similarity(self, query: str, texts: list): raise ValueError( f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}") + +class NovitaRerank(JinaRerank): + def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/rerank"): + super().__init__(key, model_name, base_url) \ No newline at end of file diff --git a/rag/nlp/__init__.py b/rag/nlp/__init__.py index 5b0d4ff36c5..f88c059a5f4 100644 --- a/rag/nlp/__init__.py +++ b/rag/nlp/__init__.py @@ -279,12 +279,13 @@ def tokenize_chunks(chunks, doc, eng, pdf_parser=None): def tokenize_chunks_with_images(chunks, doc, eng, images): res = [] # wrap up as es documents - for ck, image in zip(chunks, images): + for ii, (ck, image) in enumerate(zip(chunks, images)): if len(ck.strip()) == 0: continue logging.debug("-- {}".format(ck)) d = copy.deepcopy(doc) d["image"] = image + add_positions(d, [[ii]*5]) tokenize(d, ck, eng) res.append(d) return res @@ -343,7 +344,7 @@ def get(i): type("")) else sections[i][0]).strip() if not re.match(r"(contents|目录|目次|table of contents|致谢|acknowledge)$", - re.sub(r"( | |\u3000)+", "", get(i).split("@@")[0], re.IGNORECASE)): + re.sub(r"( | |\u3000)+", "", get(i).split("@@")[0], flags=re.IGNORECASE)): i += 1 continue sections.pop(i) @@ -524,7 +525,7 @@ def add_chunk(t, pos): if tnum < 8: pos = "" # Ensure that the length of the merged chunk does not exceed chunk_token_num - if tk_nums[-1] > chunk_token_num: + if cks[-1] == "" or tk_nums[-1] > chunk_token_num: if t.find(pos) < 0: t += pos @@ -536,11 +537,16 @@ def add_chunk(t, pos): cks[-1] 
+= t tk_nums[-1] += tnum + dels = get_delimiters(delimiter) for sec, pos in sections: - add_chunk(sec, pos) + splited_sec = re.split(r"(%s)" % dels, sec) + for sub_sec in splited_sec: + if re.match(f"^{dels}$", sub_sec): + continue + add_chunk(sub_sec, pos) return cks - + def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。;!?"): if not texts or len(texts) != len(images): @@ -560,7 +566,7 @@ def add_chunk(t, image, pos=""): if tnum < 8: pos = "" # Ensure that the length of the merged chunk does not exceed chunk_token_num - if tk_nums[-1] > chunk_token_num: + if cks[-1] == "" or tk_nums[-1] > chunk_token_num: if t.find(pos) < 0: t += pos cks.append(t) @@ -576,8 +582,13 @@ def add_chunk(t, image, pos=""): result_images[-1] = concat_img(result_images[-1], image) tk_nums[-1] += tnum + dels = get_delimiters(delimiter) for text, image in zip(texts, images): - add_chunk(text, image) + splited_sec = re.split(r"(%s)" % dels, text) + for sub_sec in splited_sec: + if re.match(f"^{dels}$", sub_sec): + continue + add_chunk(text, image) return cks, result_images @@ -627,7 +638,7 @@ def add_chunk(t, image, pos=""): tnum = num_tokens_from_string(t) if tnum < 8: pos = "" - if tk_nums[-1] > chunk_token_num: + if cks[-1] == "" or tk_nums[-1] > chunk_token_num: if t.find(pos) < 0: t += pos cks.append(t) @@ -640,8 +651,13 @@ def add_chunk(t, image, pos=""): images[-1] = concat_img(images[-1], image) tk_nums[-1] += tnum + dels = get_delimiters(delimiter) for sec, image in sections: - add_chunk(sec, image, '') + splited_sec = re.split(r"(%s)" % dels, sec) + for sub_sec in splited_sec: + if re.match(f"^{dels}$", sub_sec): + continue + add_chunk(sub_sec, image,"") return cks, images @@ -649,3 +665,22 @@ def add_chunk(t, image, pos=""): def extract_between(text: str, start_tag: str, end_tag: str) -> list[str]: pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag) return re.findall(pattern, text, flags=re.DOTALL) + + +def get_delimiters(delimiters: str): + dels = [] + s = 0 + for m in re.finditer(r"`([^`]+)`", delimiters, re.I): + f, t = m.span() + dels.append(m.group(1)) + dels.extend(list(delimiters[s: f])) + s = t + if s < len(delimiters): + dels.extend(list(delimiters[s:])) + + dels.sort(key=lambda x: -len(x)) + dels = [re.escape(d) for d in dels if d] + dels = [d for d in dels if d] + dels_pattern = "|".join(dels) + + return dels_pattern diff --git a/rag/nlp/query.py b/rag/nlp/query.py index 34333a3505e..55b4e9d3260 100644 --- a/rag/nlp/query.py +++ b/rag/nlp/query.py @@ -71,7 +71,19 @@ def rmWWW(txt): txt = otxt return txt + @staticmethod + def add_space_between_eng_zh(txt): + # (ENG/ENG+NUM) + ZH + txt = re.sub(r'([A-Za-z]+[0-9]+)([\u4e00-\u9fa5]+)', r'\1 \2', txt) + # ENG + ZH + txt = re.sub(r'([A-Za-z])([\u4e00-\u9fa5]+)', r'\1 \2', txt) + # ZH + (ENG/ENG+NUM) + txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z]+[0-9]+)', r'\1 \2', txt) + txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z])', r'\1 \2', txt) + return txt + def question(self, txt, tbl="qa", min_match: float = 0.6): + txt = FulltextQueryer.add_space_between_eng_zh(txt) txt = re.sub( r"[ :|\r\n\t,,。??/`!!&^%%()\[\]{}<>]+", " ", diff --git a/rag/nlp/search.py b/rag/nlp/search.py index cf024a381ac..855468c9e0d 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -245,7 +245,7 @@ def insert_citations(self, answer, chunks, chunk_v, for c in cites[i]: if c in seted: continue - res += f" ##{c}$$" + res += f" [ID:{c}]" seted.add(c) return res, seted @@ -380,15 +380,12 @@ def retrieval(self, question, embd_mdl, tenant_ids, 
kb_ids, page, page_size, sim rank_feature=rank_feature) # Already paginated in search function idx = np.argsort(sim * -1)[(page - 1) * page_size:page * page_size] - - dim = len(sres.query_vector) vector_column = f"q_{dim}_vec" zero_vector = [0.0] * dim - if doc_ids: - similarity_threshold = 0 - page_size = 30 sim_np = np.array(sim) + if doc_ids: + similarity_threshold = 0 filtered_count = (sim_np >= similarity_threshold).sum() ranks["total"] = int(filtered_count) # Convert from np.int64 to Python int otherwise JSON serializable error for i in idx: diff --git a/rag/prompts.py b/rag/prompts.py index cb1e1108b01..389d6a66d84 100644 --- a/rag/prompts.py +++ b/rag/prompts.py @@ -119,7 +119,7 @@ def kb_prompt(kbinfos, max_tokens): doc2chunks = defaultdict(lambda: {"chunks": [], "meta": []}) for i, ck in enumerate(kbinfos["chunks"][:chunks_num]): cnt = f"---\nID: {i}\n" + (f"URL: {ck['url']}\n" if "url" in ck else "") - cnt += ck["content_with_weight"] + cnt += re.sub(r"( style=\"[^\"]+\"||)", " ", ck["content_with_weight"], flags=re.DOTALL|re.IGNORECASE) doc2chunks[ck["docnm_kwd"]]["chunks"].append(cnt) doc2chunks[ck["docnm_kwd"]]["meta"] = docs.get(ck["doc_id"], {}) @@ -136,16 +136,18 @@ def kb_prompt(kbinfos, max_tokens): def citation_prompt(): + print("USE PROMPT", flush=True) return """ # Citation requirements: -- Inserts CITATIONS in format '##i$$ ##j$$' where i,j are the ID of the content you are citing and encapsulated with '##' and '$$'. -- Inserts the CITATION symbols at the end of a sentence, AND NO MORE than 4 citations. + +- Use a uniform citation format such as [ID:i] [ID:j], where "i" and "j" are document IDs enclosed in square brackets. Separate multiple IDs with spaces (e.g., [ID:0] [ID:1]). +- Citation markers must be placed at the end of a sentence, separated by a space from the final punctuation (e.g., period, question mark). A maximum of 4 citations are allowed per sentence. - DO NOT insert CITATION in the answer if the content is not from retrieved chunks. - DO NOT use standalone Document IDs (e.g., '#ID#'). -- Under NO circumstances any other citation styles or formats (e.g., '~~i==', '[i]', '(i)', etc.) be used. -- Citations ALWAYS the '##i$$' format. -- Any failure to adhere to the above rules, including but not limited to incorrect formatting, use of prohibited styles, or unsupported citations, will be considered a error, should skip adding Citation for this sentence. +- Citations ALWAYS in the "[ID:i]" format. +- STRICTLY prohibit the use of strikethrough symbols (e.g., ~~) or any other non-standard formatting syntax. +- Any failure to adhere to the above rules, including but not limited to incorrect formatting, use of prohibited styles, or unsupported citations, will be considered an error, and no citation will be added for that sentence. --- Example START --- : Here is the knowledge base: @@ -171,8 +173,8 @@ def citation_prompt(): : What's the Elon's view on dogecoin? -: Musk has consistently expressed his fondness for Dogecoin, often citing its humor and the inclusion of dogs in its branding. He has referred to it as his favorite cryptocurrency ##0$$ ##1$$. -Recently, Musk has hinted at potential future roles for Dogecoin. His tweets have sparked speculation about Dogecoin's potential integration into public services ##3$$. +: Musk has consistently expressed his fondness for Dogecoin, often citing its humor and the inclusion of dogs in its branding. He has referred to it as his favorite cryptocurrency [ID:0] [ID:1]. 
+Recently, Musk has hinted at potential future roles for Dogecoin. His tweets have sparked speculation about Dogecoin's potential integration into public services [ID:3]. Overall, while Musk enjoys Dogecoin and often promotes it, he also warns against over-investing in it, reflecting both his personal amusement and caution regarding its speculative nature. --- Example END --- @@ -182,13 +184,13 @@ def citation_prompt(): def keyword_extraction(chat_mdl, content, topn=3): prompt = f""" -Role: You're a text analyzer. -Task: extract the most important keywords/phrases of a given piece of text content. +Role: You are a text analyzer. +Task: Extract the most important keywords/phrases of a given piece of text content. Requirements: - - Summarize the text content, and give top {topn} important keywords/phrases. - - The keywords MUST be in language of the given piece of text content. + - Summarize the text content, and give the top {topn} important keywords/phrases. + - The keywords MUST be in the same language as the given piece of text content. - The keywords are delimited by ENGLISH COMMA. - - Keywords ONLY in output. + - Output keywords ONLY. ### Text Content {content} @@ -207,15 +209,15 @@ def keyword_extraction(chat_mdl, content, topn=3): def question_proposal(chat_mdl, content, topn=3): prompt = f""" -Role: You're a text analyzer. -Task: propose {topn} questions about a given piece of text content. +Role: You are a text analyzer. +Task: Propose {topn} questions about a given piece of text content. Requirements: - - Understand and summarize the text content, and propose top {topn} important questions. + - Understand and summarize the text content, and propose the top {topn} important questions. - The questions SHOULD NOT have overlapping meanings. - The questions SHOULD cover the main content of the text as much as possible. - - The questions MUST be in language of the given piece of text content. + - The questions MUST be in the same language as the given piece of text content. - One question per line. - - Question ONLY in output. + - Output questions ONLY. ### Text Content {content} @@ -256,14 +258,14 @@ def full_question(tenant_id, llm_id, messages, language=None): 2. If the user's question involves relative date, you need to convert it into absolute date based on the current date, which is {today}. For example: 'yesterday' would be converted to {yesterday}. Requirements & Restrictions: - - If the user's latest question is completely, don't do anything, just return the original question. + - If the user's latest question is already complete, don't do anything, just return the original question. - DON'T generate anything except a refined question.""" if language: prompt += f""" - Text generated MUST be in {language}.""" else: prompt += """ - - Text generated MUST be in the same language of the original user's question. + - Text generated MUST be in the same language as the original user's question. """ prompt += f""" @@ -309,6 +311,7 @@ def full_question(tenant_id, llm_id, messages, language=None): ans = re.sub(r"^.*", "", ans, flags=re.DOTALL) return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"] + def cross_languages(tenant_id, llm_id, query, languages=[]): from api.db.services.llm_service import LLMBundle @@ -339,7 +342,7 @@ def cross_languages(tenant_id, llm_id, query, languages=[]): Input: Hello World! Let's discuss AI safety. 
=== -Chinese, French, Jappanese +Chinese, French, Japanese Output: 你好世界!让我们讨论人工智能安全问题。 @@ -348,11 +351,11 @@ def cross_languages(tenant_id, llm_id, query, languages=[]): ### こんにちは世界!AIの安全性について話し合いましょう。 """ - user_prompt=f""" + user_prompt = f""" Input: {query} === -{', '.join(languages)} +{", ".join(languages)} Output: """ @@ -366,20 +369,20 @@ def cross_languages(tenant_id, llm_id, query, languages=[]): def content_tagging(chat_mdl, content, all_tags, examples, topn=3): prompt = f""" -Role: You're a text analyzer. +Role: You are a text analyzer. -Task: Tag (put on some labels) to a given piece of text content based on the examples and the entire tag set. +Task: Add tags (labels) to a given piece of text content based on the examples and the entire tag set. -Steps:: - - Comprehend the tag/label set. - - Comprehend examples which all consist of both text content and assigned tags with relevance score in format of JSON. - - Summarize the text content, and tag it with top {topn} most relevant tags from the set of tag/label and the corresponding relevance score. +Steps: + - Review the tag/label set. + - Review examples which all consist of both text content and assigned tags with relevance score in JSON format. + - Summarize the text content, and tag it with the top {topn} most relevant tags from the set of tags/labels and the corresponding relevance score. -Requirements +Requirements: - The tags MUST be from the tag set. - The output MUST be in JSON format only, the key is tag and the value is its relevance score. - - The relevance score must be range from 1 to 10. - - Keywords ONLY in output. + - The relevance score must range from 1 to 10. + - Output keywords ONLY. # TAG SET {", ".join(all_tags)} @@ -479,6 +482,6 @@ def vision_llm_figure_describe_prompt() -> str: - Trends / Insights: [Analysis and interpretation] - Captions / Annotations: [Text and relevance, if available] -Ensure high accuracy, clarity, and completeness in your analysis, and includes only the information present in the image. Avoid unnecessary statements about missing elements. +Ensure high accuracy, clarity, and completeness in your analysis, and include only the information present in the image. Avoid unnecessary statements about missing elements. """ return prompt diff --git a/rag/raptor.py b/rag/raptor.py index 007f2529a13..db2c82d2eec 100644 --- a/rag/raptor.py +++ b/rag/raptor.py @@ -151,8 +151,7 @@ async def summarize(ck_idx: list[int]): for c in range(n_clusters): ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c] assert len(ck_idx) > 0 - async with chat_limiter: - nursery.start_soon(summarize, ck_idx) + nursery.start_soon(summarize, ck_idx) assert len(chunks) - end == n_clusters, "{} vs. 
{}".format( len(chunks) - end, n_clusters diff --git a/rag/res/synonym.json b/rag/res/synonym.json index 369ba1759ba..d5f1336caeb 100644 --- a/rag/res/synonym.json +++ b/rag/res/synonym.json @@ -10527,10 +10527,10 @@ "833454": "同心传动", "博纳影业": "001330", "001330": "博纳影业", -"去年": "2022", -"前年": "2021", -"今年": "2023", -"上季度": ["三季度", "q3"], +"去年": "2024", +"前年": "2023", +"今年": "2025", +"上季度": ["一季度", "q1"], "q1": "一季度", "q2": "二季度", "q3": "三季度", diff --git a/rag/settings.py b/rag/settings.py index 2dfaea62760..70d1b6234cc 100644 --- a/rag/settings.py +++ b/rag/settings.py @@ -56,13 +56,14 @@ REDIS = {} pass DOC_MAXIMUM_SIZE = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024)) - +DOC_BULK_SIZE = int(os.environ.get("DOC_BULK_SIZE", 4)) +EMBEDDING_BATCH_SIZE = int(os.environ.get("EMBEDDING_BATCH_SIZE", 16)) SVR_QUEUE_NAME = "rag_flow_svr_queue" SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_task_broker" PAGERANK_FLD = "pagerank_fea" TAG_FLD = "tag_feas" -PARALLEL_DEVICES = None +PARALLEL_DEVICES = 0 try: import torch.cuda PARALLEL_DEVICES = torch.cuda.device_count() diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 6448c4e81e9..fbe6f134f61 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -21,7 +21,7 @@ import threading import time -from api.utils.log_utils import initRootLogger, get_project_base_directory +from api.utils.log_utils import init_root_logger, get_project_base_directory from graphrag.general.index import run_graphrag from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache from rag.prompts import keyword_extraction, question_proposal, content_tagging @@ -58,7 +58,7 @@ email, tag from rag.nlp import search, rag_tokenizer from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor -from rag.settings import DOC_MAXIMUM_SIZE, SVR_CONSUMER_GROUP_NAME, get_svr_queue_name, get_svr_queue_names, print_rag_settings, TAG_FLD, PAGERANK_FLD +from rag.settings import DOC_MAXIMUM_SIZE, DOC_BULK_SIZE, EMBEDDING_BATCH_SIZE, SVR_CONSUMER_GROUP_NAME, get_svr_queue_name, get_svr_queue_names, print_rag_settings, TAG_FLD, PAGERANK_FLD from rag.utils import num_tokens_from_string, truncate from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock from rag.utils.storage_factory import STORAGE_IMPL @@ -100,9 +100,10 @@ MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5")) MAX_CONCURRENT_CHUNK_BUILDERS = int(os.environ.get('MAX_CONCURRENT_CHUNK_BUILDERS', "1")) MAX_CONCURRENT_MINIO = int(os.environ.get('MAX_CONCURRENT_MINIO', '10')) -task_limiter = trio.CapacityLimiter(MAX_CONCURRENT_TASKS) +task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS) chunk_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS) minio_limiter = trio.CapacityLimiter(MAX_CONCURRENT_MINIO) +kg_limiter = trio.CapacityLimiter(2) WORKER_HEARTBEAT_TIMEOUT = int(os.environ.get('WORKER_HEARTBEAT_TIMEOUT', '120')) stop_event = threading.Event() @@ -185,6 +186,7 @@ def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing... 
async def collect(): global CONSUMER_NAME, DONE_TASKS, FAILED_TASKS global UNACKED_ITERATOR + svr_queue_names = get_svr_queue_names() try: if not UNACKED_ITERATOR: @@ -276,28 +278,27 @@ async def build_chunks(task, progress_callback): async def upload_to_minio(document, chunk): try: + d = copy.deepcopy(document) + d.update(chunk) + d["id"] = xxhash.xxh64((chunk["content_with_weight"] + str(d["doc_id"])).encode("utf-8")).hexdigest() + d["create_time"] = str(datetime.now()).replace("T", " ")[:19] + d["create_timestamp_flt"] = datetime.now().timestamp() + if not d.get("image"): + _ = d.pop("image", None) + d["img_id"] = "" + docs.append(d) + return + + output_buffer = BytesIO() + if isinstance(d["image"], bytes): + output_buffer = BytesIO(d["image"]) + else: + d["image"].save(output_buffer, format='JPEG') async with minio_limiter: - d = copy.deepcopy(document) - d.update(chunk) - d["id"] = xxhash.xxh64((chunk["content_with_weight"] + str(d["doc_id"])).encode("utf-8")).hexdigest() - d["create_time"] = str(datetime.now()).replace("T", " ")[:19] - d["create_timestamp_flt"] = datetime.now().timestamp() - if not d.get("image"): - _ = d.pop("image", None) - d["img_id"] = "" - docs.append(d) - return - - output_buffer = BytesIO() - if isinstance(d["image"], bytes): - output_buffer = BytesIO(d["image"]) - else: - d["image"].save(output_buffer, format='JPEG') await trio.to_thread.run_sync(lambda: STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue())) - - d["img_id"] = "{}-{}".format(task["kb_id"], d["id"]) - del d["image"] - docs.append(d) + d["img_id"] = "{}-{}".format(task["kb_id"], d["id"]) + del d["image"] + docs.append(d) except Exception: logging.exception( "Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"])) @@ -368,6 +369,10 @@ async def doc_question_proposal(chat_mdl, d, topn): docs_to_tag = [] for d in docs: + task_canceled = TaskService.do_cancel(task["id"]) + if task_canceled: + progress_callback(-1, msg="Task has been canceled.") + return if settings.retrievaler.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0: examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]}) else: @@ -402,7 +407,6 @@ def init_kb(row, vector_size: int): async def embedding(docs, mdl, parser_config=None, callback=None): if parser_config is None: parser_config = {} - batch_size = 16 tts, cnts = [], [] for d in docs: tts.append(d.get("docnm_kwd", "Title")) @@ -421,8 +425,8 @@ async def embedding(docs, mdl, parser_config=None, callback=None): tk_count += c cnts_ = np.array([]) - for i in range(0, len(cnts), batch_size): - vts, c = await trio.to_thread.run_sync(lambda: mdl.encode([truncate(c, mdl.max_length-10) for c in cnts[i: i + batch_size]])) + for i in range(0, len(cnts), EMBEDDING_BATCH_SIZE): + vts, c = await trio.to_thread.run_sync(lambda: mdl.encode([truncate(c, mdl.max_length-10) for c in cnts[i: i + EMBEDDING_BATCH_SIZE]])) if len(cnts_) == 0: cnts_ = vts else: @@ -532,11 +536,10 @@ async def do_handle_task(task): # bind LLM for raptor chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language) # run RAPTOR - chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback) + async with kg_limiter: + chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback) # Either using graphrag or Standard chunking methods elif task.get("task_type", "") == "graphrag": - global 
task_limiter - task_limiter = trio.CapacityLimiter(2) if not task_parser_config.get("graphrag", {}).get("use_graphrag", False): return graphrag_conf = task["kb_parser_config"].get("graphrag", {}) @@ -544,7 +547,8 @@ async def do_handle_task(task): chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language) with_resolution = graphrag_conf.get("resolution", False) with_community = graphrag_conf.get("community", False) - await run_graphrag(task, task_language, with_resolution, with_community, chat_model, embedding_model, progress_callback) + async with kg_limiter: + await run_graphrag(task, task_language, with_resolution, with_community, chat_model, embedding_model, progress_callback) progress_callback(prog=1.0, msg="Knowledge Graph done ({:.2f}s)".format(timer() - start_ts)) return else: @@ -576,23 +580,40 @@ async def do_handle_task(task): chunk_count = len(set([chunk["id"] for chunk in chunks])) start_ts = timer() doc_store_result = "" - es_bulk_size = 4 - for b in range(0, len(chunks), es_bulk_size): - doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id), task_dataset_id)) + + async def delete_image(kb_id, chunk_id): + try: + async with minio_limiter: + STORAGE_IMPL.delete(kb_id, chunk_id) + except Exception: + logging.exception( + "Deleting image of chunk {}/{}/{} got exception".format(task["location"], task["name"], chunk_id)) + raise + + for b in range(0, len(chunks), DOC_BULK_SIZE): + doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id)) + task_canceled = TaskService.do_cancel(task_id) + if task_canceled: + progress_callback(-1, msg="Task has been canceled.") + return if b % 128 == 0: progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="") if doc_store_result: error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!" 
progress_callback(-1, msg=error_message) raise Exception(error_message) - chunk_ids = [chunk["id"] for chunk in chunks[:b + es_bulk_size]] + chunk_ids = [chunk["id"] for chunk in chunks[:b + DOC_BULK_SIZE]] chunk_ids_str = " ".join(chunk_ids) try: TaskService.update_chunk_ids(task["id"], chunk_ids_str) except DoesNotExist: logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.") doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id)) + async with trio.open_nursery() as nursery: + for chunk_id in chunk_ids: + nursery.start_soon(delete_image, task_dataset_id, chunk_id) return + logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page, task_to_page, len(chunks), timer() - start_ts)) @@ -684,39 +705,13 @@ async def report_status(): finally: redis_lock.release() await trio.sleep(30) - - -def recover_pending_tasks(): - redis_lock = RedisDistributedLock("recover_pending_tasks", lock_value=CONSUMER_NAME, timeout=60) - svr_queue_names = get_svr_queue_names() - while not stop_event.is_set(): - try: - if redis_lock.acquire(): - for queue_name in svr_queue_names: - msgs = REDIS_CONN.get_pending_msg(queue=queue_name, group_name=SVR_CONSUMER_GROUP_NAME) - msgs = [msg for msg in msgs if msg['consumer'] != CONSUMER_NAME] - if len(msgs) == 0: - continue - - task_executors = REDIS_CONN.smembers("TASKEXE") - task_executor_set = {t for t in task_executors} - msgs = [msg for msg in msgs if msg['consumer'] not in task_executor_set] - for msg in msgs: - logging.info( - f"Recover pending task: {msg['message_id']}, consumer: {msg['consumer']}, " - f"time since delivered: {msg['time_since_delivered'] / 1000} s" - ) - REDIS_CONN.requeue_msg(queue_name, SVR_CONSUMER_GROUP_NAME, msg['message_id']) - except Exception: - logging.warning("recover_pending_tasks got exception") - finally: - redis_lock.release() - stop_event.wait(60) + async def task_manager(): - global task_limiter - async with task_limiter: + try: await handle_task() + finally: + task_limiter.release() async def main(): @@ -740,16 +735,14 @@ async def main(): signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) - threading.Thread(name="RecoverPendingTask", target=recover_pending_tasks).start() - async with trio.open_nursery() as nursery: nursery.start_soon(report_status) while not stop_event.is_set(): + await task_limiter.acquire() nursery.start_soon(task_manager) - await trio.sleep(0.1) logging.error("BUG!!! 
You should not reach here!!!") if __name__ == "__main__": faulthandler.enable() - initRootLogger(CONSUMER_NAME) + init_root_logger(CONSUMER_NAME) trio.run(main) diff --git a/rag/utils/es_conn.py b/rag/utils/es_conn.py index f761a54e723..5e1f6ceb05b 100644 --- a/rag/utils/es_conn.py +++ b/rag/utils/es_conn.py @@ -483,6 +483,9 @@ def getFields(self, res, fields: list[str]) -> dict[str, dict]: if isinstance(v, list): m[n] = v continue + if n == "available_int" and isinstance(v, (int, float)): + m[n] = v + continue if not isinstance(v, str): m[n] = str(m[n]) # if n.find("tks") > 0: diff --git a/rag/utils/minio_conn.py b/rag/utils/minio_conn.py index 03b90a5cfe3..80a723a5c89 100644 --- a/rag/utils/minio_conn.py +++ b/rag/utils/minio_conn.py @@ -118,3 +118,13 @@ def get_presigned_url(self, bucket, fnm, expires): time.sleep(1) return + def remove_bucket(self, bucket): + try: + if self.conn.bucket_exists(bucket): + objects_to_delete = self.conn.list_objects(bucket, recursive=True) + for obj in objects_to_delete: + self.conn.remove_object(bucket, obj.object_name) + self.conn.remove_bucket(bucket) + except Exception: + logging.exception(f"Fail to remove bucket {bucket}") + diff --git a/rag/utils/opendal_conn.py b/rag/utils/opendal_conn.py new file mode 100644 index 00000000000..715c18d190a --- /dev/null +++ b/rag/utils/opendal_conn.py @@ -0,0 +1,120 @@ +import opendal +import logging +import pymysql +import yaml + +from rag.utils import singleton + +SERVICE_CONF_PATH = "conf/service_conf.yaml" + +CREATE_TABLE_SQL = """ +CREATE TABLE IF NOT EXISTS `{}` ( + `key` VARCHAR(255) PRIMARY KEY, + `value` LONGBLOB, + `created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + `updated_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP +); +""" +SET_MAX_ALLOWED_PACKET_SQL = """ +SET GLOBAL max_allowed_packet={} +""" + + +def get_opendal_config_from_yaml(yaml_path=SERVICE_CONF_PATH): + try: + with open(yaml_path, 'r') as f: + config = yaml.safe_load(f) + + opendal_config = config.get('opendal', {}) + kwargs = {} + if opendal_config.get("scheme") == 'mysql': + mysql_config = config.get('mysql', {}) + kwargs = { + "scheme": "mysql", + "host": mysql_config.get("host", "127.0.0.1"), + "port": str(mysql_config.get("port", 3306)), + "user": mysql_config.get("user", "root"), + "password": mysql_config.get("password", ""), + "database": mysql_config.get("name", "test_open_dal"), + "table": opendal_config.get("config").get("table", "opendal_storage") + } + kwargs["connection_string"] = f"mysql://{kwargs['user']}:{kwargs['password']}@{kwargs['host']}:{kwargs['port']}/{kwargs['database']}" + else: + scheme = opendal_config.get("scheme") + config_data = opendal_config.get("config", {}) + kwargs = {"scheme": scheme, **config_data} + logging.info("Loaded OpenDAL configuration from yaml: %s", kwargs) + return kwargs + except Exception as e: + logging.error("Failed to load OpenDAL configuration from yaml: %s", str(e)) + raise + + +@singleton +class OpenDALStorage: + def __init__(self): + self._kwargs = get_opendal_config_from_yaml() + self._scheme = self._kwargs.get('scheme', 'mysql') + if self._scheme == 'mysql': + self.init_db_config() + self.init_opendal_mysql_table() + self._operator = opendal.Operator(**self._kwargs) + + logging.info("OpenDALStorage initialized successfully") + + def health(self): + bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1" + r = self._operator.write(f"{bucket}/{fnm}", binary) + return r + + def put(self, bucket, fnm, binary): + self._operator.write(f"{bucket}/{fnm}", binary) + 
+ def get(self, bucket, fnm): + return self._operator.read(f"{bucket}/{fnm}") + + def rm(self, bucket, fnm): + self._operator.delete(f"{bucket}/{fnm}") + self._operator.__init__() + + def scan(self, bucket, fnm): + return self._operator.scan(f"{bucket}/{fnm}") + + def obj_exist(self, bucket, fnm): + return self._operator.exists(f"{bucket}/{fnm}") + + + def init_db_config(self): + try: + conn = pymysql.connect( + host=self._kwargs['host'], + port=int(self._kwargs['port']), + user=self._kwargs['user'], + password=self._kwargs['password'], + database=self._kwargs['database'] + ) + cursor = conn.cursor() + max_packet = self._kwargs.get('max_allowed_packet', 4194304) # Default to 4MB if not specified + cursor.execute(SET_MAX_ALLOWED_PACKET_SQL.format(max_packet)) + conn.commit() + cursor.close() + conn.close() + logging.info(f"Database configuration initialized with max_allowed_packet={max_packet}") + except Exception as e: + logging.error(f"Failed to initialize database configuration: {str(e)}") + raise + + def init_opendal_mysql_table(self): + conn = pymysql.connect( + host=self._kwargs['host'], + port=int(self._kwargs['port']), + user=self._kwargs['user'], + password=self._kwargs['password'], + database=self._kwargs['database'] + ) + cursor = conn.cursor() + cursor.execute(CREATE_TABLE_SQL.format(self._kwargs['table'])) + conn.commit() + cursor.close() + conn.close() + logging.info(f"Table `{self._kwargs['table']}` initialized.") diff --git a/rag/utils/opensearch_coon.py b/rag/utils/opensearch_coon.py index 8bbde7e076e..4a8fd0889b2 100644 --- a/rag/utils/opensearch_coon.py +++ b/rag/utils/opensearch_coon.py @@ -217,7 +217,7 @@ def search( if bqry: s = s.query(bqry) for field in highlightFields: - s = s.highlight(field) + s = s.highlight(field,force_source=True,no_match_size=30,require_field_match=False) if orderBy: orders = list() @@ -269,7 +269,7 @@ def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict for i in range(ATTEMPT_TIME): try: res = self.os.get(index=(indexName), - id=chunkId, source=True, ) + id=chunkId, _source=True, ) if str(res.get("timed_out", "")).lower() == "true": raise Exception("Es Timeout.") chunk = res["_source"] @@ -329,7 +329,7 @@ def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseI chunkId = condition["id"] for i in range(ATTEMPT_TIME): try: - self.os.update(index=indexName, id=chunkId, doc=doc) + self.os.update(index=indexName, id=chunkId, body=doc) return True except Exception as e: logger.exception( @@ -411,7 +411,10 @@ def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int: chunk_ids = condition["id"] if not isinstance(chunk_ids, list): chunk_ids = [chunk_ids] - qry = Q("ids", values=chunk_ids) + if not chunk_ids: # when chunk_ids is empty, delete all + qry = Q("match_all") + else: + qry = Q("ids", values=chunk_ids) else: qry = Q("bool") for k, v in condition.items(): diff --git a/rag/utils/s3_conn.py b/rag/utils/s3_conn.py index 05c68880a75..bccfd91fc10 100644 --- a/rag/utils/s3_conn.py +++ b/rag/utils/s3_conn.py @@ -65,10 +65,14 @@ def __open__(self): pass try: - s3_params = { - 'aws_access_key_id': self.access_key, - 'aws_secret_access_key': self.secret_key, - } + s3_params = {} + # if not set ak/sk, boto3 s3 client would try several ways to do the authentication + # see doc: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html#configuring-credentials + if self.access_key and self.secret_key: + s3_params = { + 'aws_access_key_id': self.access_key, + 
'aws_secret_access_key': self.secret_key, + } if self.region in self.s3_config: s3_params['region_name'] = self.region if 'endpoint_url' in self.s3_config: diff --git a/rag/utils/storage_factory.py b/rag/utils/storage_factory.py index 63587b3b01e..4ac091f85f5 100644 --- a/rag/utils/storage_factory.py +++ b/rag/utils/storage_factory.py @@ -20,6 +20,7 @@ from rag.utils.azure_sas_conn import RAGFlowAzureSasBlob from rag.utils.azure_spn_conn import RAGFlowAzureSpnBlob from rag.utils.minio_conn import RAGFlowMinio +from rag.utils.opendal_conn import OpenDALStorage from rag.utils.s3_conn import RAGFlowS3 from rag.utils.oss_conn import RAGFlowOSS @@ -30,6 +31,7 @@ class Storage(Enum): AZURE_SAS = 3 AWS_S3 = 4 OSS = 5 + OPENDAL = 6 class StorageFactory: @@ -39,6 +41,7 @@ class StorageFactory: Storage.AZURE_SAS: RAGFlowAzureSasBlob, Storage.AWS_S3: RAGFlowS3, Storage.OSS: RAGFlowOSS, + Storage.OPENDAL: OpenDALStorage } @classmethod diff --git a/sandbox/.env.example b/sandbox/.env.example new file mode 100644 index 00000000000..e88ed561cea --- /dev/null +++ b/sandbox/.env.example @@ -0,0 +1,9 @@ +# Copy this file to `.env` and modify as needed + +SANDBOX_EXECUTOR_MANAGER_POOL_SIZE=5 +SANDBOX_BASE_PYTHON_IMAGE=sandbox-base-python:latest +SANDBOX_BASE_NODEJS_IMAGE=sandbox-base-nodejs:latest +SANDBOX_EXECUTOR_MANAGER_PORT=9385 +SANDBOX_ENABLE_SECCOMP=false +SANDBOX_MAX_MEMORY=256m # b, k, m, g +SANDBOX_TIMEOUT=10s # s, m, 1m30s diff --git a/sandbox/Makefile b/sandbox/Makefile new file mode 100644 index 00000000000..da07da2b7ab --- /dev/null +++ b/sandbox/Makefile @@ -0,0 +1,115 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Force using Bash to ensure the source command is available +SHELL := /bin/bash + +# Environment variable definitions +VENV := .venv +PYTHON := $(VENV)/bin/python +UV := uv +ACTIVATE_SCRIPT := $(VENV)/bin/activate +SYS_PYTHON := python3 +PYTHONPATH := $(shell pwd) + +.PHONY: all setup ensure_env ensure_uv start stop restart build clean test logs + +all: setup start + +# 🌱 Initialize environment + install dependencies +setup: ensure_env ensure_uv + @echo "📦 Installing dependencies with uv..." + @$(UV) sync --python 3.11 + source $(ACTIVATE_SCRIPT) && \ + export PYTHONPATH=$(PYTHONPATH) + @$(UV) pip install -r executor_manager/requirements.txt + @echo "✅ Setup complete." + +# 🔑 Ensure .env exists (copy from .env.example on first run) +ensure_env: + @if [ ! -f ".env" ]; then \ + if [ -f ".env.example" ]; then \ + echo "📝 Creating .env from .env.example..."; \ + cp .env.example .env; \ + else \ + echo "⚠️ Warning: .env.example not found, creating empty .env"; \ + touch .env; \ + fi; \ + else \ + echo "✅ .env already exists."; \ + fi + +# 🔧 Ensure uv is executable (install using system Python) +ensure_uv: + @if ! 
command -v $(UV) >/dev/null 2>&1; then \ + echo "🛠️ Installing uv using system Python..."; \ + $(SYS_PYTHON) -m pip install -q --upgrade pip; \ + $(SYS_PYTHON) -m pip install -q uv || (echo "⚠️ uv install failed, check manually" && exit 1); \ + fi + +# 🐳 Service control (using safer variable loading) +start: + @echo "🚀 Starting services..." + source $(ACTIVATE_SCRIPT) && \ + export PYTHONPATH=$(PYTHONPATH) && \ + [ -f .env ] && source .env || true && \ + bash scripts/start.sh + +stop: + @echo "🛑 Stopping services..." + source $(ACTIVATE_SCRIPT) && \ + bash scripts/stop.sh + +restart: stop start + @echo "🔁 Restarting services..." + +build: + @echo "🔧 Building base sandbox images..." + @if [ -f .env ]; then \ + source .env && \ + echo "🐍 Building base sandbox image for Python ($$SANDBOX_BASE_PYTHON_IMAGE)..." && \ + docker build -t "$$SANDBOX_BASE_PYTHON_IMAGE" ./sandbox_base_image/python && \ + echo "⬢ Building base sandbox image for Nodejs ($$SANDBOX_BASE_NODEJS_IMAGE)..." && \ + docker build -t "$$SANDBOX_BASE_NODEJS_IMAGE" ./sandbox_base_image/nodejs; \ + else \ + echo "⚠️ .env file not found, skipping build."; \ + fi + +test: + @echo "🧪 Running sandbox security tests..." + source $(ACTIVATE_SCRIPT) && \ + export PYTHONPATH=$(PYTHONPATH) && \ + $(PYTHON) tests/sandbox_security_tests_full.py + +logs: + @echo "📋 Showing logs from api-server and executor-manager..." + docker compose logs -f + +# 🧹 Clean all containers and volumes +clean: + @echo "🧹 Cleaning all containers and volumes..." + @docker compose down -v || true + @if [ -f .env ]; then \ + source .env && \ + for i in $$(seq 0 $$((SANDBOX_EXECUTOR_MANAGER_POOL_SIZE - 1))); do \ + echo "🧹 Deleting sandbox_python_$$i..." && \ + docker rm -f sandbox_python_$$i 2>/dev/null || true && \ + echo "🧹 Deleting sandbox_nodejs_$$i..." && \ + docker rm -f sandbox_nodejs_$$i 2>/dev/null || true; \ + done; \ + else \ + echo "⚠️ .env not found, skipping container cleanup"; \ + fi diff --git a/sandbox/README.md b/sandbox/README.md new file mode 100644 index 00000000000..2e4e209ff47 --- /dev/null +++ b/sandbox/README.md @@ -0,0 +1,291 @@ +# RAGFlow Sandbox + +A secure, pluggable code execution backend for RAGFlow and beyond. + +## 🔧 Features + +- ✅ **Seamless RAGFlow Integration** — Out-of-the-box compatibility with the `code` component. +- 🔐 **High Security** — Leverages [gVisor](https://gvisor.dev/) for syscall-level sandboxing. +- 🔧 **Customizable Sandboxing** — Easily modify `seccomp` settings as needed. +- 🧩 **Pluggable Runtime Support** — Easily extend to support any programming language. +- ⚙️ **Developer Friendly** — Get started with a single command using `Makefile`. + +## 🏗 Architecture + +

+  <img src="asserts/code_executor_manager.svg" alt="Architecture Diagram">

+ +## 🚀 Quick Start + +### 📋 Prerequisites + +#### Required + +- Linux distro compatible with gVisor +- [gVisor](https://gvisor.dev/docs/user_guide/install/) +- Docker >= `24.0.0` +- Docker Compose >= `v2.26.1` like [RAGFlow](https://github.com/infiniflow/ragflow) +- [uv](https://docs.astral.sh/uv/) as package and project manager + +#### Optional (Recommended) + +- [GNU Make](https://www.gnu.org/software/make/) for simplified CLI management + +--- + +### 🐳 Build Docker Base Images + +We use isolated base images for secure containerized execution: + +```bash +# Build base images manually +docker build -t sandbox-base-python:latest ./sandbox_base_image/python +docker build -t sandbox-base-nodejs:latest ./sandbox_base_image/nodejs + +# OR use Makefile +make build +``` + +Then, build the executor manager image: + +```bash +docker build -t sandbox-executor-manager:latest ./executor_manager +``` + +--- + +### 📦 Running with RAGFlow + +1. Ensure gVisor is correctly installed. +2. Configure your `.env` in `docker/.env`: + + - Uncomment sandbox-related variables. + - Enable sandbox profile at the bottom. +3. Add the following line to `/etc/hosts` as recommended: + + ```text + 127.0.0.1 sandbox-executor-manager + ``` + +4. Start RAGFlow service. + +--- + +### 🧭 Running Standalone + +#### Manual Setup + +1. Initialize environment: + + ```bash + cp .env.example .env + ``` + +2. Launch: + + ```bash + docker compose -f docker-compose.yml up + ``` + +3. Test: + + ```bash + source .venv/bin/activate + export PYTHONPATH=$(pwd) + uv pip install -r executor_manager/requirements.txt + uv run tests/sandbox_security_tests_full.py + ``` + +#### With Make + +```bash +make # setup + build + launch + test +``` + +--- + +### 📈 Monitoring + +```bash +docker logs -f sandbox-executor-manager # Manual +make logs # With Make +``` + +--- + +### 🧰 Makefile Toolbox + +| Command | Description | +| ----------------- | ------------------------------------------------ | +| `make` | Setup, build, launch and test all at once | +| `make setup` | Initialize environment and install uv | +| `make ensure_env` | Auto-create `.env` if missing | +| `make ensure_uv` | Install `uv` package manager if missing | +| `make build` | Build all Docker base images | +| `make start` | Start services with safe env loading and testing | +| `make stop` | Gracefully stop all services | +| `make restart` | Shortcut for `stop` + `start` | +| `make test` | Run full test suite | +| `make logs` | Stream container logs | +| `make clean` | Stop and remove orphan containers and volumes | + +--- + +## 🔐 Security + +The RAGFlow sandbox is designed to balance security and usability, offering solid protection without compromising developer experience. + +### ✅ gVisor Isolation + +At its core, we use [gVisor](https://gvisor.dev/docs/architecture_guide/security/), a user-space kernel, to isolate code execution from the host system. gVisor intercepts and restricts syscalls, offering robust protection against container escapes and privilege escalations. + +### 🔒 Optional seccomp Support (Advanced) + +For users who need **zero-trust-level syscall control**, we support an additional `seccomp` profile. This feature restricts containers to only a predefined set of system calls, as specified in `executor_manager/seccomp-profile-default.json`. + +> ⚠️ This feature is **disabled by default** to maintain compatibility and usability. Enabling it may cause compatibility issues with some dependencies. + +#### To enable seccomp + +1. 
Edit your `.env` file: + + ```dotenv + SANDBOX_ENABLE_SECCOMP=true + ``` + +2. Customize allowed syscalls in: + + ``` + executor_manager/seccomp-profile-default.json + ``` + + This profile is passed to the container with: + + ```bash + --security-opt seccomp=/app/seccomp-profile-default.json + ``` + +### 🧠 Python Code AST Inspection + +In addition to sandboxing, Python code is **statically analyzed via AST (Abstract Syntax Tree)** before execution. Potentially malicious code (e.g. file operations, subprocess calls, etc.) is rejected early, providing an extra layer of protection. + +--- + +This security model strikes a balance between **robust isolation** and **developer usability**. While `seccomp` can be highly restrictive, our default setup aims to keep things usable for most developers — no obscure crashes or cryptic setup required. + +## 📦 Add Extra Dependencies for Supported Languages + +Currently, the following languages are officially supported: + +| Language | Priority | +| -------- | -------- | +| Python | High | +| Node.js | Medium | + +### 🐍 Python + +To add Python dependencies, simply edit the following file: + +```bash +sandbox_base_image/python/requirements.txt +``` + +Add any additional packages you need, one per line (just like a normal pip requirements file). + +### 🟨 Node.js + +To add Node.js dependencies: + +1. Navigate to the Node.js base image directory: + + ```bash + cd sandbox_base_image/nodejs + ``` + +2. Use `npm` to install the desired packages. For example: + + ```bash + npm install lodash + ``` + +3. The dependencies will be saved to `package.json` and `package-lock.json`, and included in the Docker image when rebuilt. + +--- + +## 📋 FAQ + +### ❓Sandbox Not Working? + +Follow this checklist to troubleshoot: + +- [ ] **Is your machine compatible with gVisor?** + + Ensure that your system supports gVisor. Refer to the [gVisor installation guide](https://gvisor.dev/docs/user_guide/install/). + +- [ ] **Is gVisor properly installed?** + + **Common error:** + + `HTTPConnectionPool(host='sandbox-executor-manager', port=9385): Read timed out.` + + Cause: `runsc` is an unknown or invalid Docker runtime. + **Fix:** + + - Install gVisor + + - Restart Docker + + - Test with: + + ```bash + docker run --rm --runtime=runsc hello-world + ``` + +- [ ] **Is `sandbox-executor-manager` mapped in `/etc/hosts`?** + + **Common error:** + + `HTTPConnectionPool(host='none', port=9385): Max retries exceeded.` + + **Fix:** + + Add the following entry to `/etc/hosts`: + + ```text + 127.0.0.1 es01 infinity mysql minio redis sandbox-executor-manager + ``` + +- [ ] **Have you enabled sandbox-related configurations in RAGFlow?** + + Double-check that all sandbox settings are correctly enabled in your RAGFlow configuration. + +- [ ] **Have you pulled the required base images for the runners?** + + **Common error:** + + `HTTPConnectionPool(host='sandbox-executor-manager', port=9385): Read timed out.` + + Cause: no runner was started. + + **Fix:** + + Pull the necessary base images: + + ```bash + docker pull infiniflow/sandbox-base-nodejs:latest + docker pull infiniflow/sandbox-base-python:latest + ``` + +- [ ] **Did you restart the service after making changes?** + + Any changes to configuration or environment require a full service restart to take effect. + + +### ❓Container pool is busy? + +All available runners are currently in use, executing tasks/running code. Please try again shortly, or consider increasing the pool size in the configuration to improve availability and reduce wait times. 
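As a concrete illustration of the AST inspection described in the Security section above, here is a minimal, hypothetical sketch of such a pre-execution screen. The deny-lists and the helper name are assumptions for illustration only, not the executor_manager's actual `analyze_code_security` implementation; only the `(message, line)` issue shape mirrors how `run_code_handler` formats its error details.

```python
# Illustrative sketch only: a simplified AST-based safety screen.
# blocked_modules / blocked_calls are assumed deny-lists, not the real ones.
import ast


def screen_python_code(code: str) -> tuple[bool, list[tuple[str, int]]]:
    """Return (is_safe, issues); each issue is a (message, lineno) pair."""
    blocked_modules = {"os", "subprocess", "socket", "shutil"}
    blocked_calls = {"open", "exec", "eval", "__import__"}
    issues: list[tuple[str, int]] = []
    try:
        tree = ast.parse(code)
    except SyntaxError as e:
        return False, [(f"syntax error: {e.msg}", e.lineno or 0)]
    for node in ast.walk(tree):
        # Flag imports of modules on the deny-list.
        if isinstance(node, ast.Import):
            names = [alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom):
            names = [node.module or ""]
        else:
            names = []
        for name in names:
            if name.split(".")[0] in blocked_modules:
                issues.append((f"import of blocked module '{name}'", node.lineno))
        # Flag direct calls to dangerous builtins.
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id in blocked_calls:
            issues.append((f"call to blocked builtin '{node.func.id}'", node.lineno))
    return not issues, issues
```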
+ +## 🤝 Contribution + +Contributions are welcome! diff --git a/sandbox/asserts/code_executor_manager.svg b/sandbox/asserts/code_executor_manager.svg new file mode 100644 index 00000000000..710beabf214 --- /dev/null +++ b/sandbox/asserts/code_executor_manager.svg @@ -0,0 +1,4 @@ + + + +
+[drawio-exported SVG; recoverable text labels only: RAGFlow ⇄ executor_manager ("code run request" / "response"), executor_manager lifespan ("Before: creating gVisor guarded code executor pool", "After: resource clean up"), executor_manager ⇄ code executor pool ("patch run task" / "code result"), pool of Python and Node.js runners in runsc (gVisor), "Task orchestration and pool management", fallback note "Text is not SVG - cannot display"]
\ No newline at end of file diff --git a/sandbox/docker-compose.yml b/sandbox/docker-compose.yml new file mode 100644 index 00000000000..71067f3b633 --- /dev/null +++ b/sandbox/docker-compose.yml @@ -0,0 +1,33 @@ +services: + sandbox-executor-manager: + container_name: sandbox-executor-manager + build: + context: ./executor_manager + dockerfile: Dockerfile + image: sandbox-executor-manager:latest + runtime: runc + privileged: true + ports: + - "${EXECUTOR_PORT:-9385}:9385" + volumes: + - /var/run/docker.sock:/var/run/docker.sock + networks: + - sandbox-network + restart: always + security_opt: + - no-new-privileges:true + environment: + - SANDBOX_EXECUTOR_MANAGER_POOL_SIZE=${SANDBOX_EXECUTOR_MANAGER_POOL_SIZE:-5} + - SANDBOX_BASE_PYTHON_IMAGE=${SANDBOX_BASE_PYTHON_IMAGE-sandbox-base-python:latest} + - SANDBOX_BASE_NODEJS_IMAGE=${SANDBOX_BASE_NODEJS_IMAGE-sandbox-base-nodejs:latest} + - SANDBOX_ENABLE_SECCOMP=${SANDBOX_ENABLE_SECCOMP:-false} + - SANDBOX_MAX_MEMORY=${SANDBOX_MAX_MEMORY:-256m} # b, k, m, g + - SANDBOX_TIMEOUT=${SANDBOX_TIMEOUT:-10s} # s, m, 1m30s + healthcheck: + test: ["CMD-SHELL", "curl --fail http://localhost:9385/healthz || exit 1"] + interval: 10s + timeout: 5s + retries: 5 +networks: + sandbox-network: + driver: bridge diff --git a/sandbox/executor_manager/Dockerfile b/sandbox/executor_manager/Dockerfile new file mode 100644 index 00000000000..85f4f36c722 --- /dev/null +++ b/sandbox/executor_manager/Dockerfile @@ -0,0 +1,23 @@ +FROM python:3.11-slim-bookworm + +RUN grep -rl 'deb.debian.org' /etc/apt/ | xargs sed -i 's|http[s]*://deb.debian.org|https://mirrors.tuna.tsinghua.edu.cn|g' && \ + apt-get update && \ + apt-get install -y curl gcc && \ + rm -rf /var/lib/apt/lists/* + +RUN curl -fsSL https://mirrors.aliyun.com/docker-ce/linux/static/stable/x86_64/docker-24.0.7.tgz -o docker.tgz && \ + tar -xzf docker.tgz && \ + mv docker/docker /usr/bin/docker && \ + rm -rf docker docker.tgz + +COPY --from=ghcr.io/astral-sh/uv:0.7.5 /uv /uvx /bin/ +ENV UV_INDEX_URL=https://pypi.tuna.tsinghua.edu.cn/simple + + +WORKDIR /app +COPY . . + +RUN uv pip install --system -r requirements.txt + +CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "9385"] + diff --git a/sandbox/executor_manager/api/__init__.py b/sandbox/executor_manager/api/__init__.py new file mode 100644 index 00000000000..177b91dd051 --- /dev/null +++ b/sandbox/executor_manager/api/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/sandbox/executor_manager/api/handlers.py b/sandbox/executor_manager/api/handlers.py new file mode 100644 index 00000000000..c6c673df468 --- /dev/null +++ b/sandbox/executor_manager/api/handlers.py @@ -0,0 +1,44 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import base64 + +from core.logger import logger +from fastapi import Request +from models.enums import ResultStatus +from models.schemas import CodeExecutionRequest, CodeExecutionResult +from services.execution import execute_code +from services.limiter import limiter +from services.security import analyze_code_security + + +async def healthz_handler(): + return {"status": "ok"} + + +@limiter.limit("5/second") +async def run_code_handler(req: CodeExecutionRequest, request: Request): + logger.info("🟢 Received /run request") + + code = base64.b64decode(req.code_b64).decode("utf-8") + is_safe, issues = analyze_code_security(code, language=req.language) + if not is_safe: + issue_details = "\n".join([f"Line {lineno}: {issue}" for issue, lineno in issues]) + return CodeExecutionResult(status=ResultStatus.PROGRAM_RUNNER_ERROR, stdout="", stderr=issue_details, exit_code=-999, detail="Code is unsafe") + + try: + return await execute_code(req) + except Exception as e: + return CodeExecutionResult(status=ResultStatus.PROGRAM_RUNNER_ERROR, stdout="", stderr=str(e), exit_code=-999, detail="unhandled_exception") diff --git a/sandbox/executor_manager/api/routes.py b/sandbox/executor_manager/api/routes.py new file mode 100644 index 00000000000..69317b6720c --- /dev/null +++ b/sandbox/executor_manager/api/routes.py @@ -0,0 +1,23 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from fastapi import APIRouter + +from api.handlers import healthz_handler, run_code_handler + +router = APIRouter() + +router.get("/healthz")(healthz_handler) +router.post("/run")(run_code_handler) diff --git a/sandbox/executor_manager/core/__init__.py b/sandbox/executor_manager/core/__init__.py new file mode 100644 index 00000000000..177b91dd051 --- /dev/null +++ b/sandbox/executor_manager/core/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
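run_code_handler above expects a JSON body matching CodeExecutionRequest: base64-encoded source, a language, and an optional arguments object. A minimal client sketch (assumptions: the manager is reachable on the default port 9385 and `requests` is available):

```python
# Client sketch for the /run endpoint handled above. Assumes the executor
# manager is reachable at localhost:9385; the submitted snippet follows the
# runner convention of exposing a main() function.
import base64

import requests

code = '''
def main(name: str = "world"):
    return f"hello, {name}"
'''

payload = {
    "code_b64": base64.b64encode(code.encode("utf-8")).decode("ascii"),
    "language": "python",
    "arguments": {"name": "sandbox"},
}

result = requests.post("http://localhost:9385/run", json=payload, timeout=30).json()
print(result["status"])   # "success" when the snippet ran cleanly
print(result["stdout"])   # "hello, sandbox"
print(result["stderr"])   # security or runtime errors end up here
```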
+# diff --git a/sandbox/executor_manager/core/config.py b/sandbox/executor_manager/core/config.py new file mode 100644 index 00000000000..962b3f0bfb2 --- /dev/null +++ b/sandbox/executor_manager/core/config.py @@ -0,0 +1,44 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +from contextlib import asynccontextmanager + +from fastapi import FastAPI +from util import format_timeout_duration, parse_timeout_duration + +from core.container import init_containers, teardown_containers +from core.logger import logger + +TIMEOUT = 10 + + +@asynccontextmanager +async def _lifespan(app: FastAPI): + """Asynchronous lifecycle management""" + size = int(os.getenv("SANDBOX_EXECUTOR_MANAGER_POOL_SIZE", 1)) + + success_count, total_task_count = await init_containers(size) + logger.info(f"\n📊 Container pool initialization complete: {success_count}/{total_task_count} available") + + yield + + await teardown_containers() + + +def init(): + TIMEOUT = parse_timeout_duration(os.getenv("SANDBOX_TIMEOUT")) + logger.info(f"Global timeout: {format_timeout_duration(TIMEOUT)}") + return _lifespan diff --git a/sandbox/executor_manager/core/container.py b/sandbox/executor_manager/core/container.py new file mode 100644 index 00000000000..a026de112fb --- /dev/null +++ b/sandbox/executor_manager/core/container.py @@ -0,0 +1,188 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import asyncio +import contextlib +import os +from queue import Empty, Queue +from threading import Lock + +from models.enums import SupportLanguage +from util import env_setting_enabled, is_valid_memory_limit +from utils.common import async_run_command + +from core.logger import logger + +_CONTAINER_QUEUES: dict[SupportLanguage, Queue] = {} +_CONTAINER_LOCK: Lock = Lock() + + +async def init_containers(size: int) -> tuple[int, int]: + global _CONTAINER_QUEUES + _CONTAINER_QUEUES = {SupportLanguage.PYTHON: Queue(), SupportLanguage.NODEJS: Queue()} + + with _CONTAINER_LOCK: + while not _CONTAINER_QUEUES[SupportLanguage.PYTHON].empty(): + _CONTAINER_QUEUES[SupportLanguage.PYTHON].get_nowait() + while not _CONTAINER_QUEUES[SupportLanguage.NODEJS].empty(): + _CONTAINER_QUEUES[SupportLanguage.NODEJS].get_nowait() + + create_tasks = [] + for i in range(size): + name = f"sandbox_python_{i}" + logger.info(f"🛠️ Creating Python container {i + 1}/{size}") + create_tasks.append(_prepare_container(name, SupportLanguage.PYTHON)) + + name = f"sandbox_nodejs_{i}" + logger.info(f"🛠️ Creating Node.js container {i + 1}/{size}") + create_tasks.append(_prepare_container(name, SupportLanguage.NODEJS)) + + results = await asyncio.gather(*create_tasks, return_exceptions=True) + success_count = sum(1 for r in results if r is True) + total_task_count = len(create_tasks) + return success_count, total_task_count + + +async def teardown_containers(): + with _CONTAINER_LOCK: + while not _CONTAINER_QUEUES[SupportLanguage.PYTHON].empty(): + name = _CONTAINER_QUEUES[SupportLanguage.PYTHON].get_nowait() + await async_run_command("docker", "rm", "-f", name, timeout=5) + while not _CONTAINER_QUEUES[SupportLanguage.NODEJS].empty(): + name = _CONTAINER_QUEUES[SupportLanguage.NODEJS].get_nowait() + await async_run_command("docker", "rm", "-f", name, timeout=5) + + +async def _prepare_container(name: str, language: SupportLanguage) -> bool: + """Prepare a single container""" + with contextlib.suppress(Exception): + await async_run_command("docker", "rm", "-f", name, timeout=5) + + if await create_container(name, language): + _CONTAINER_QUEUES[language].put(name) + return True + return False + + +async def create_container(name: str, language: SupportLanguage) -> bool: + """Asynchronously create a container""" + create_args = [ + "docker", + "run", + "-d", + "--runtime=runsc", + "--name", + name, + "--read-only", + "--tmpfs", + "/workspace:rw,exec,size=100M,uid=65534,gid=65534", + "--tmpfs", + "/tmp:rw,exec,size=50M", + "--user", + "nobody", + "--workdir", + "/workspace", + ] + if os.getenv("SANDBOX_MAX_MEMORY"): + memory_limit = os.getenv("SANDBOX_MAX_MEMORY") or "256m" + if is_valid_memory_limit(memory_limit): + logger.info(f"SANDBOX_MAX_MEMORY: {os.getenv('SANDBOX_MAX_MEMORY')}") + else: + logger.info("Invalid SANDBOX_MAX_MEMORY, using default value: 256m") + memory_limit = "256m" + create_args.extend(["--memory", memory_limit]) + else: + logger.info("Set default SANDBOX_MAX_MEMORY: 256m") + create_args.extend(["--memory", "256m"]) + + if env_setting_enabled("SANDBOX_ENABLE_SECCOMP", "false"): + logger.info(f"SANDBOX_ENABLE_SECCOMP: {os.getenv('SANDBOX_ENABLE_SECCOMP')}") + create_args.extend(["--security-opt", "seccomp=/app/seccomp-profile-default.json"]) + + if language == SupportLanguage.PYTHON: + create_args.append(os.getenv("SANDBOX_BASE_PYTHON_IMAGE", "sandbox-base-python:latest")) + elif language == SupportLanguage.NODEJS: + create_args.append(os.getenv("SANDBOX_BASE_NODEJS_IMAGE", "sandbox-base-nodejs:latest")) + + 
logger.info(f"Sandbox config:\n\t {create_args}") + + try: + returncode, _, stderr = await async_run_command(*create_args, timeout=10) + if returncode != 0: + logger.error(f"❌ Container creation failed {name}: {stderr}") + return False + + if language == SupportLanguage.NODEJS: + copy_cmd = ["docker", "exec", name, "bash", "-c", "cp -a /app/node_modules /workspace/"] + returncode, _, stderr = await async_run_command(*copy_cmd, timeout=10) + if returncode != 0: + logger.error(f"❌ Failed to prepare dependencies for {name}: {stderr}") + return False + + return await container_is_running(name) + except Exception as e: + logger.error(f"❌ Container creation exception {name}: {str(e)}") + return False + + +async def recreate_container(name: str, language: SupportLanguage) -> bool: + """Asynchronously recreate a container""" + logger.info(f"🛠️ Recreating container: {name}") + try: + await async_run_command("docker", "rm", "-f", name, timeout=5) + + return await create_container(name, language) + except Exception as e: + logger.error(f"❌ Container {name} recreation failed: {str(e)}") + return False + + +async def release_container(name: str, language: SupportLanguage): + """Asynchronously release a container""" + with _CONTAINER_LOCK: + if await container_is_running(name): + _CONTAINER_QUEUES[language].put(name) + logger.info(f"🟢 Released container: {name} (remaining available: {_CONTAINER_QUEUES[language].qsize()})") + else: + logger.warning(f"⚠️ Container {name} has crashed, attempting to recreate...") + if await recreate_container(name, language): + _CONTAINER_QUEUES[language].put(name) + logger.info(f"✅ Container {name} successfully recreated and returned to queue") + + +async def allocate_container_blocking(language: SupportLanguage, timeout=10) -> str: + """Asynchronously allocate an available container""" + start_time = asyncio.get_running_loop().time() + while asyncio.get_running_loop().time() - start_time < timeout: + try: + name = _CONTAINER_QUEUES[language].get_nowait() + with _CONTAINER_LOCK: + if not await container_is_running(name) and not await recreate_container(name, language): + continue + + return name + except Empty: + await asyncio.sleep(0.1) + + return "" + + +async def container_is_running(name: str) -> bool: + """Asynchronously check the container status""" + try: + returncode, stdout, _ = await async_run_command("docker", "inspect", "-f", "{{.State.Running}}", name, timeout=2) + return returncode == 0 and stdout.strip() == "true" + except Exception: + return False diff --git a/sandbox/executor_manager/core/logger.py b/sandbox/executor_manager/core/logger.py new file mode 100644 index 00000000000..c393129db24 --- /dev/null +++ b/sandbox/executor_manager/core/logger.py @@ -0,0 +1,19 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import logging + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("sandbox") diff --git a/sandbox/executor_manager/main.py b/sandbox/executor_manager/main.py new file mode 100644 index 00000000000..ccad79b48ab --- /dev/null +++ b/sandbox/executor_manager/main.py @@ -0,0 +1,25 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from api.routes import router as api_router +from core.config import init +from fastapi import FastAPI +from services.limiter import limiter, rate_limit_exceeded_handler +from slowapi.errors import RateLimitExceeded + +app = FastAPI(lifespan=init()) +app.include_router(api_router) +app.state.limiter = limiter +app.add_exception_handler(RateLimitExceeded, rate_limit_exceeded_handler) diff --git a/sandbox/executor_manager/models/__init__.py b/sandbox/executor_manager/models/__init__.py new file mode 100644 index 00000000000..177b91dd051 --- /dev/null +++ b/sandbox/executor_manager/models/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/sandbox/executor_manager/models/enums.py b/sandbox/executor_manager/models/enums.py new file mode 100644 index 00000000000..b575e54c26e --- /dev/null +++ b/sandbox/executor_manager/models/enums.py @@ -0,0 +1,47 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
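main.py above wires the router, the limiter state, and the lifespan together. A quick smoke-test sketch using FastAPI's TestClient (assumption: run from sandbox/executor_manager so the local packages import; without the context manager the lifespan, and hence the Docker-backed pool warm-up, is not triggered):

```python
# Smoke-test sketch for the app assembled above; only exercises /healthz,
# which does not touch the container pool.
from fastapi.testclient import TestClient

from main import app

client = TestClient(app)  # not used as a context manager -> lifespan not run
resp = client.get("/healthz")
assert resp.status_code == 200
assert resp.json() == {"status": "ok"}
```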
+# +from enum import Enum + + +class SupportLanguage(str, Enum): + PYTHON = "python" + NODEJS = "nodejs" + + +class ResultStatus(str, Enum): + SUCCESS = "success" + PROGRAM_ERROR = "program_error" + RESOURCE_LIMIT_EXCEEDED = "resource_limit_exceeded" + UNAUTHORIZED_ACCESS = "unauthorized_access" + RUNTIME_ERROR = "runtime_error" + PROGRAM_RUNNER_ERROR = "program_runner_error" + + +class ResourceLimitType(str, Enum): + TIME = "time" + MEMORY = "memory" + OUTPUT = "output" + + +class UnauthorizedAccessType(str, Enum): + DISALLOWED_SYSCALL = "disallowed_syscall" + FILE_ACCESS = "file_access" + NETWORK_ACCESS = "network_access" + + +class RuntimeErrorType(str, Enum): + SIGNALLED = "signalled" + NONZERO_EXIT = "nonzero_exit" diff --git a/sandbox/executor_manager/models/schemas.py b/sandbox/executor_manager/models/schemas.py new file mode 100644 index 00000000000..750db5bc8cf --- /dev/null +++ b/sandbox/executor_manager/models/schemas.py @@ -0,0 +1,53 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import base64 +from typing import Optional + +from pydantic import BaseModel, Field, field_validator + +from models.enums import ResourceLimitType, ResultStatus, RuntimeErrorType, SupportLanguage, UnauthorizedAccessType + + +class CodeExecutionResult(BaseModel): + status: ResultStatus + stdout: str + stderr: str + exit_code: int + detail: Optional[str] = None + + # Resource usage + time_used_ms: Optional[float] = None + memory_used_kb: Optional[float] = None + + # Error details + resource_limit_type: Optional[ResourceLimitType] = None + unauthorized_access_type: Optional[UnauthorizedAccessType] = None + runtime_error_type: Optional[RuntimeErrorType] = None + + +class CodeExecutionRequest(BaseModel): + code_b64: str = Field(..., description="Base64 encoded code string") + language: SupportLanguage = Field(default=SupportLanguage.PYTHON, description="Programming language") + arguments: Optional[dict] = Field(default={}, description="Arguments") + + @field_validator("code_b64") + @classmethod + def validate_base64(cls, v: str) -> str: + try: + base64.b64decode(v, validate=True) + return v + except Exception as e: + raise ValueError(f"Invalid base64 encoding: {str(e)}") diff --git a/sandbox/executor_manager/requirements.txt b/sandbox/executor_manager/requirements.txt new file mode 100644 index 00000000000..4ee4c706eb5 --- /dev/null +++ b/sandbox/executor_manager/requirements.txt @@ -0,0 +1,3 @@ +fastapi +uvicorn +slowapi diff --git a/sandbox/executor_manager/seccomp-profile-default.json b/sandbox/executor_manager/seccomp-profile-default.json new file mode 100644 index 00000000000..e384ac35824 --- /dev/null +++ b/sandbox/executor_manager/seccomp-profile-default.json @@ -0,0 +1,55 @@ +{ + "defaultAction": "SCMP_ACT_ERRNO", + "archMap": [ + { + "architecture": "SCMP_ARCH_X86_64", + "subArchitectures": [ + "SCMP_ARCH_X86", + "SCMP_ARCH_X32" + ] + } + ], + "syscalls": [ + { + "names": [ + "read", + "write", + "exit", + "sigreturn", + "brk", + 
"mmap", + "munmap", + "rt_sigaction", + "rt_sigprocmask", + "futex", + "clone", + "execve", + "arch_prctl", + "access", + "openat", + "close", + "stat", + "fstat", + "lstat", + "getpid", + "gettid", + "getuid", + "getgid", + "geteuid", + "getegid", + "clock_gettime", + "nanosleep", + "uname", + "writev", + "readlink", + "getrandom", + "statx", + "faccessat2", + "pread64", + "pwrite64", + "rt_sigreturn" + ], + "action": "SCMP_ACT_ALLOW" + } + ] +} diff --git a/sandbox/executor_manager/services/__init__.py b/sandbox/executor_manager/services/__init__.py new file mode 100644 index 00000000000..177b91dd051 --- /dev/null +++ b/sandbox/executor_manager/services/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/sandbox/executor_manager/services/execution.py b/sandbox/executor_manager/services/execution.py new file mode 100644 index 00000000000..c196ef6225a --- /dev/null +++ b/sandbox/executor_manager/services/execution.py @@ -0,0 +1,245 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import asyncio +import base64 +import json +import os +import time +import uuid + +from core.config import TIMEOUT +from core.container import allocate_container_blocking, release_container +from core.logger import logger +from models.enums import ResourceLimitType, ResultStatus, RuntimeErrorType, SupportLanguage, UnauthorizedAccessType +from models.schemas import CodeExecutionRequest, CodeExecutionResult +from utils.common import async_run_command + + +async def execute_code(req: CodeExecutionRequest): + """Fully asynchronous execution logic""" + language = req.language + container = await allocate_container_blocking(language) + if not container: + return CodeExecutionResult( + status=ResultStatus.PROGRAM_RUNNER_ERROR, + stdout="", + stderr="Container pool is busy", + exit_code=-10, + detail="no_available_container", + ) + + task_id = str(uuid.uuid4()) + workdir = f"/tmp/sandbox_{task_id}" + os.makedirs(workdir, mode=0o700, exist_ok=True) + + try: + if language == SupportLanguage.PYTHON: + code_name = "main.py" + # code + code_path = os.path.join(workdir, code_name) + with open(code_path, "wb") as f: + f.write(base64.b64decode(req.code_b64)) + # runner + runner_name = "runner.py" + runner_path = os.path.join(workdir, runner_name) + with open(runner_path, "w") as f: + f.write("""import json +import os +import sys +sys.path.insert(0, os.path.dirname(__file__)) +from main import main +if __name__ == "__main__": + args = json.loads(sys.argv[1]) + result = main(**args) + if result is not None: + print(result) +""") + + elif language == SupportLanguage.NODEJS: + code_name = "main.js" + code_path = os.path.join(workdir, "main.js") + with open(code_path, "wb") as f: + f.write(base64.b64decode(req.code_b64)) + + runner_name = "runner.js" + runner_path = os.path.join(workdir, "runner.js") + with open(runner_path, "w") as f: + f.write(""" +const fs = require('fs'); +const path = require('path'); + +const args = JSON.parse(process.argv[2]); + +const mainPath = path.join(__dirname, 'main.js'); + +if (fs.existsSync(mainPath)) { + const { main } = require(mainPath); + + if (typeof args === 'object' && args !== null) { + main(args).then(result => { + if (result !== null) { + console.log(result); + } + }).catch(err => { + console.error('Error in main function:', err); + }); + } else { + console.error('Error: args is not a valid object:', args); + } +} else { + console.error('main.js not found in the current directory'); +} +""") + # dirs + returncode, _, stderr = await async_run_command("docker", "exec", container, "mkdir", "-p", f"/workspace/{task_id}", timeout=5) + if returncode != 0: + raise RuntimeError(f"Directory creation failed: {stderr}") + + # archive + tar_proc = await asyncio.create_subprocess_exec("tar", "czf", "-", "-C", workdir, code_name, runner_name, stdout=asyncio.subprocess.PIPE) + tar_stdout, _ = await tar_proc.communicate() + + # unarchive + docker_proc = await asyncio.create_subprocess_exec( + "docker", "exec", "-i", container, "tar", "xzf", "-", "-C", f"/workspace/{task_id}", stdin=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE + ) + stdout, stderr = await docker_proc.communicate(input=tar_stdout) + + if docker_proc.returncode != 0: + raise RuntimeError(stderr.decode()) + + # exec + start_time = time.time() + try: + logger.info(f"Passed in args: {req.arguments}") + args_json = json.dumps(req.arguments or {}) + run_args = [ + "docker", + "exec", + "--workdir", + f"/workspace/{task_id}", + container, + "timeout", + str(TIMEOUT), + language, + ] + # flags + if language == 
SupportLanguage.PYTHON: + run_args.extend(["-I", "-B"]) + elif language == SupportLanguage.NODEJS: + run_args.extend([]) + else: + assert True, "Will never reach here" + run_args.extend([runner_name, args_json]) + + returncode, stdout, stderr = await async_run_command( + *run_args, + timeout=TIMEOUT + 5, + ) + + time_used_ms = (time.time() - start_time) * 1000 + + logger.info("----------------------------------------------") + logger.info(f"Code: {str(base64.b64decode(req.code_b64))}") + logger.info(f"{returncode=}") + logger.info(f"{stdout=}") + logger.info(f"{stderr=}") + logger.info(f"{args_json=}") + + if returncode == 0: + return CodeExecutionResult( + status=ResultStatus.SUCCESS, + stdout=str(stdout), + stderr=stderr, + exit_code=0, + time_used_ms=time_used_ms, + ) + elif returncode == 124: + return CodeExecutionResult( + status=ResultStatus.RESOURCE_LIMIT_EXCEEDED, + stdout="", + stderr="Execution timeout", + exit_code=-124, + resource_limit_type=ResourceLimitType.TIME, + time_used_ms=time_used_ms, + ) + elif returncode == 137: + return CodeExecutionResult( + status=ResultStatus.RESOURCE_LIMIT_EXCEEDED, + stdout="", + stderr="Memory limit exceeded (killed by OOM)", + exit_code=-137, + resource_limit_type=ResourceLimitType.MEMORY, + time_used_ms=time_used_ms, + ) + return analyze_error_result(stderr, returncode) + + except asyncio.TimeoutError: + await async_run_command("docker", "exec", container, "pkill", "-9", language) + return CodeExecutionResult( + status=ResultStatus.RESOURCE_LIMIT_EXCEEDED, + stdout="", + stderr="Execution timeout", + exit_code=-1, + resource_limit_type=ResourceLimitType.TIME, + time_used_ms=(time.time() - start_time) * 1000, + ) + + except Exception as e: + logger.error(f"Execution exception: {str(e)}") + return CodeExecutionResult(status=ResultStatus.PROGRAM_RUNNER_ERROR, stdout="", stderr=str(e), exit_code=-3, detail="internal_error") + + finally: + # cleanup + cleanup_tasks = [async_run_command("docker", "exec", container, "rm", "-rf", f"/workspace/{task_id}"), async_run_command("rm", "-rf", workdir)] + await asyncio.gather(*cleanup_tasks, return_exceptions=True) + await release_container(container, language) + + +def analyze_error_result(stderr: str, exit_code: int) -> CodeExecutionResult: + """Analyze the error result and classify it""" + if "Permission denied" in stderr: + return CodeExecutionResult( + status=ResultStatus.UNAUTHORIZED_ACCESS, + stdout="", + stderr=stderr, + exit_code=exit_code, + unauthorized_access_type=UnauthorizedAccessType.FILE_ACCESS, + ) + elif "Operation not permitted" in stderr: + return CodeExecutionResult( + status=ResultStatus.UNAUTHORIZED_ACCESS, + stdout="", + stderr=stderr, + exit_code=exit_code, + unauthorized_access_type=UnauthorizedAccessType.DISALLOWED_SYSCALL, + ) + elif "MemoryError" in stderr: + return CodeExecutionResult( + status=ResultStatus.RESOURCE_LIMIT_EXCEEDED, + stdout="", + stderr=stderr, + exit_code=exit_code, + resource_limit_type=ResourceLimitType.MEMORY, + ) + else: + return CodeExecutionResult( + status=ResultStatus.PROGRAM_ERROR, + stdout="", + stderr=stderr, + exit_code=exit_code, + runtime_error_type=RuntimeErrorType.NONZERO_EXIT, + ) diff --git a/sandbox/executor_manager/services/limiter.py b/sandbox/executor_manager/services/limiter.py new file mode 100644 index 00000000000..cdaffbd43f8 --- /dev/null +++ b/sandbox/executor_manager/services/limiter.py @@ -0,0 +1,38 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. 
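The runner files written by execute_code above imply a contract for submitted code: a Python snippet must define main(**kwargs) (its return value is printed to stdout), and a Node snippet must export a main function that is called with the arguments object and awaited. Exit codes from timeout(1) and the OOM killer are mapped to dedicated statuses. A hedged sketch of a conforming payload and of reading those mappings back:

```python
# Sketch of the contract enforced by runner.py above, plus a helper that maps
# the documented exit codes back to a readable reason. Illustrative only.
import base64

user_code = '''
def main(x: int = 1, y: int = 2):
    return x + y          # printed to stdout by runner.py
'''
code_b64 = base64.b64encode(user_code.encode()).decode()


def describe(result: dict) -> str:
    if result["status"] == "success":
        return f"ok: {result['stdout'].strip()}"
    if result["exit_code"] == -124:
        return "killed: execution timeout (timeout(1) exit code 124)"
    if result["exit_code"] == -137:
        return "killed: memory limit exceeded (OOM kill, exit code 137)"
    return f"failed ({result['status']}): {result['stderr']}"
```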
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from fastapi import Request +from fastapi.responses import JSONResponse +from models.enums import ResultStatus +from models.schemas import CodeExecutionResult +from slowapi import Limiter +from slowapi.errors import RateLimitExceeded +from slowapi.util import get_remote_address + +limiter = Limiter(key_func=get_remote_address) + + +async def rate_limit_exceeded_handler(request: Request, exc: Exception) -> JSONResponse: + if isinstance(exc, RateLimitExceeded): + return JSONResponse( + content=CodeExecutionResult( + status=ResultStatus.PROGRAM_RUNNER_ERROR, + stdout="", + stderr="Too many requests, please try again later", + exit_code=-429, + detail="Too many requests, please try again later", + ).model_dump(), + ) + raise exc diff --git a/sandbox/executor_manager/services/security.py b/sandbox/executor_manager/services/security.py new file mode 100644 index 00000000000..cbe1ca27e1a --- /dev/null +++ b/sandbox/executor_manager/services/security.py @@ -0,0 +1,173 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import ast +from typing import List, Tuple + +from core.logger import logger +from models.enums import SupportLanguage + + +class SecurePythonAnalyzer(ast.NodeVisitor): + """ + An AST-based analyzer for detecting unsafe Python code patterns. 
+ """ + + DANGEROUS_IMPORTS = {"os", "subprocess", "sys", "shutil", "socket", "ctypes", "pickle", "threading", "multiprocessing", "asyncio", "http.client", "ftplib", "telnetlib"} + + DANGEROUS_CALLS = { + "eval", + "exec", + "open", + "__import__", + "compile", + "input", + "system", + "popen", + "remove", + "rename", + "rmdir", + "chdir", + "chmod", + "chown", + "getattr", + "setattr", + "globals", + "locals", + "shutil.rmtree", + "subprocess.call", + "subprocess.Popen", + "ctypes", + "pickle.load", + "pickle.loads", + "pickle.dump", + "pickle.dumps", + } + + def __init__(self): + self.unsafe_items: List[Tuple[str, int]] = [] + + def visit_Import(self, node: ast.Import): + """Check for dangerous imports.""" + for alias in node.names: + if alias.name.split(".")[0] in self.DANGEROUS_IMPORTS: + self.unsafe_items.append((f"Import: {alias.name}", node.lineno)) + self.generic_visit(node) + + def visit_ImportFrom(self, node: ast.ImportFrom): + """Check for dangerous imports from specific modules.""" + if node.module and node.module.split(".")[0] in self.DANGEROUS_IMPORTS: + self.unsafe_items.append((f"From Import: {node.module}", node.lineno)) + self.generic_visit(node) + + def visit_Call(self, node: ast.Call): + """Check for dangerous function calls.""" + if isinstance(node.func, ast.Name) and node.func.id in self.DANGEROUS_CALLS: + self.unsafe_items.append((f"Call: {node.func.id}", node.lineno)) + self.generic_visit(node) + + def visit_Attribute(self, node: ast.Attribute): + """Check for dangerous attribute access.""" + if isinstance(node.value, ast.Name) and node.value.id in self.DANGEROUS_IMPORTS: + self.unsafe_items.append((f"Attribute Access: {node.value.id}.{node.attr}", node.lineno)) + self.generic_visit(node) + + def visit_BinOp(self, node: ast.BinOp): + """Check for possible unsafe operations like concatenating strings with commands.""" + # This could be useful to detect `eval("os." 
+ "system")` + if isinstance(node.left, ast.Constant) and isinstance(node.right, ast.Constant): + self.unsafe_items.append(("Possible unsafe string concatenation", node.lineno)) + self.generic_visit(node) + + def visit_FunctionDef(self, node: ast.FunctionDef): + """Check for dangerous function definitions (e.g., user-defined eval).""" + if node.name in self.DANGEROUS_CALLS: + self.unsafe_items.append((f"Function Definition: {node.name}", node.lineno)) + self.generic_visit(node) + + def visit_Assign(self, node: ast.Assign): + """Check for assignments to variables that might lead to dangerous operations.""" + for target in node.targets: + if isinstance(target, ast.Name) and target.id in self.DANGEROUS_CALLS: + self.unsafe_items.append((f"Assignment to dangerous variable: {target.id}", node.lineno)) + self.generic_visit(node) + + def visit_Lambda(self, node: ast.Lambda): + """Check for lambda functions with dangerous operations.""" + if isinstance(node.body, ast.Call) and isinstance(node.body.func, ast.Name) and node.body.func.id in self.DANGEROUS_CALLS: + self.unsafe_items.append(("Lambda with dangerous function call", node.lineno)) + self.generic_visit(node) + + def visit_ListComp(self, node: ast.ListComp): + """Check for list comprehensions with dangerous operations.""" + # First, visit the generators to check for any issues there + for elem in node.generators: + if isinstance(elem, ast.comprehension): + self.generic_visit(elem) + + if isinstance(node.elt, ast.Call) and isinstance(node.elt.func, ast.Name) and node.elt.func.id in self.DANGEROUS_CALLS: + self.unsafe_items.append(("List comprehension with dangerous function call", node.lineno)) + self.generic_visit(node) + + def visit_DictComp(self, node: ast.DictComp): + """Check for dictionary comprehensions with dangerous operations.""" + # Check for dangerous calls in both the key and value expressions of the dictionary comprehension + if isinstance(node.key, ast.Call) and isinstance(node.key.func, ast.Name) and node.key.func.id in self.DANGEROUS_CALLS: + self.unsafe_items.append(("Dict comprehension with dangerous function call in key", node.lineno)) + + if isinstance(node.value, ast.Call) and isinstance(node.value.func, ast.Name) and node.value.func.id in self.DANGEROUS_CALLS: + self.unsafe_items.append(("Dict comprehension with dangerous function call in value", node.lineno)) + + # Visit other sub-nodes (e.g., the generators in the comprehension) + self.generic_visit(node) + + def visit_SetComp(self, node: ast.SetComp): + """Check for set comprehensions with dangerous operations.""" + for elt in node.generators: + if isinstance(elt, ast.comprehension): + self.generic_visit(elt) + + if isinstance(node.elt, ast.Call) and isinstance(node.elt.func, ast.Name) and node.elt.func.id in self.DANGEROUS_CALLS: + self.unsafe_items.append(("Set comprehension with dangerous function call", node.lineno)) + + self.generic_visit(node) + + def visit_Yield(self, node: ast.Yield): + """Check for yield statements that could be used to produce unsafe values.""" + if isinstance(node.value, ast.Call) and isinstance(node.value.func, ast.Name) and node.value.func.id in self.DANGEROUS_CALLS: + self.unsafe_items.append(("Yield with dangerous function call", node.lineno)) + self.generic_visit(node) + + +def analyze_code_security(code: str, language: SupportLanguage) -> Tuple[bool, List[Tuple[str, int]]]: + """ + Analyze the provided code string and return whether it's safe and why. + + :param code: The source code to analyze. 
+ :param language: The programming language of the code. + :return: (is_safe: bool, issues: List of (description, line number)) + """ + if language == SupportLanguage.PYTHON: + try: + tree = ast.parse(code) + analyzer = SecurePythonAnalyzer() + analyzer.visit(tree) + return len(analyzer.unsafe_items) == 0, analyzer.unsafe_items + except Exception as e: + logger.error(f"[SafeCheck] Python parsing failed: {str(e)}") + return False, [(f"Parsing Error: {str(e)}", -1)] + else: + logger.warning(f"[SafeCheck] Unsupported language for security analysis: {language} — defaulting to SAFE (manual review recommended)") + return True, [(f"Unsupported language for security analysis: {language} — defaulted to SAFE, manual review recommended", -1)] diff --git a/sandbox/executor_manager/util.py b/sandbox/executor_manager/util.py new file mode 100644 index 00000000000..a84fe570610 --- /dev/null +++ b/sandbox/executor_manager/util.py @@ -0,0 +1,76 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import re + + +def is_enabled(value: str) -> bool: + return str(value).strip().lower() in {"1", "true", "yes", "on"} + + +def env_setting_enabled(env_key: str, default: str = "false") -> bool: + value = os.getenv(env_key, default) + return is_enabled(value) + + +def is_valid_memory_limit(mem: str | None) -> bool: + """ + Return True if the input string is a valid Docker memory limit (e.g. '256m', '1g'). + Units allowed: b, k, m, g (case-insensitive). + Disallows zero or negative values. + """ + if not mem or not isinstance(mem, str): + return False + + mem = mem.strip().lower() + + return re.fullmatch(r"[1-9]\d*(b|k|m|g)", mem) is not None + + +def parse_timeout_duration(timeout: str | None, default_seconds: int = 10) -> int: + """ + Parses a string like '90s', '2m', '1m30s' into total seconds (int). + Supports 's', 'm' (lower or upper case). Returns default if invalid. + '1m30s' -> 90 + """ + if not timeout or not isinstance(timeout, str): + return default_seconds + + timeout = timeout.strip().lower() + + pattern = r"^(?:(\d+)m)?(?:(\d+)s)?$" + match = re.fullmatch(pattern, timeout) + if not match: + return default_seconds + + minutes = int(match.group(1)) if match.group(1) else 0 + seconds = int(match.group(2)) if match.group(2) else 0 + total = minutes * 60 + seconds + + return total if total > 0 else default_seconds + + +def format_timeout_duration(seconds: int) -> str: + """ + Formats an integer number of seconds into a string like '1m30s'. + 90 -> '1m30s' + """ + if seconds < 60: + return f"{seconds}s" + minutes, sec = divmod(seconds, 60) + if sec == 0: + return f"{minutes}m" + return f"{minutes}m{sec}s" diff --git a/sandbox/executor_manager/utils/__init__.py b/sandbox/executor_manager/utils/__init__.py new file mode 100644 index 00000000000..177b91dd051 --- /dev/null +++ b/sandbox/executor_manager/utils/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. 
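A quick sketch exercising the security analyzer and the helpers above: an `import os` snippet is flagged with line numbers, and the duration/memory parsers show the accepted formats (run from sandbox/executor_manager so the local imports resolve):

```python
# Exercises analyze_code_security() and the util helpers defined above.
from models.enums import SupportLanguage
from services.security import analyze_code_security
from util import format_timeout_duration, is_valid_memory_limit, parse_timeout_duration

safe, issues = analyze_code_security("import os\nos.system('ls')", SupportLanguage.PYTHON)
print(safe)    # False
print(issues)  # [('Import: os', 1), ('Attribute Access: os.system', 2)]

print(parse_timeout_duration("1m30s"))   # 90
print(format_timeout_duration(90))       # '1m30s'
print(is_valid_memory_limit("256m"))     # True
print(is_valid_memory_limit("0g"))       # False -- zero and malformed values are rejected
```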
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/sandbox/executor_manager/utils/common.py b/sandbox/executor_manager/utils/common.py new file mode 100644 index 00000000000..9a85566ce1a --- /dev/null +++ b/sandbox/executor_manager/utils/common.py @@ -0,0 +1,36 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import asyncio +from typing import Tuple + + +async def async_run_command(*args, timeout: float = 5) -> Tuple[int, str, str]: + """Safe asynchronous command execution tool""" + proc = await asyncio.create_subprocess_exec(*args, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE) + + try: + stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=timeout) + if proc.returncode is None: + raise RuntimeError("Process finished but returncode is None") + return proc.returncode, stdout.decode(), stderr.decode() + except asyncio.TimeoutError: + proc.kill() + await proc.wait() + raise RuntimeError("Command timed out") + except Exception as e: + proc.kill() + await proc.wait() + raise e diff --git a/sandbox/pyproject.toml b/sandbox/pyproject.toml new file mode 100644 index 00000000000..c1380d3409e --- /dev/null +++ b/sandbox/pyproject.toml @@ -0,0 +1,28 @@ +[project] +name = "gvisor-sandbox" +version = "0.1.0" +description = "Add your description here" +readme = "README.md" +requires-python = ">=3.10" +dependencies = [ + "fastapi>=0.115.12", + "httpx>=0.28.1", + "pydantic>=2.11.4", + "requests>=2.32.3", + "slowapi>=0.1.9", + "uvicorn>=0.34.2", +] + +[[tool.uv.index]] +url = "https://pypi.tuna.tsinghua.edu.cn/simple" + +[dependency-groups] +dev = [ + "basedpyright>=1.29.1", +] + +[tool.ruff] +line-length = 200 + +[tool.ruff.lint] +extend-select = ["C4", "SIM", "TCH"] diff --git a/sandbox/sandbox_base_image/nodejs/Dockerfile b/sandbox/sandbox_base_image/nodejs/Dockerfile new file mode 100644 index 00000000000..ada730faf1c --- /dev/null +++ b/sandbox/sandbox_base_image/nodejs/Dockerfile @@ -0,0 +1,17 @@ +FROM node:24-bookworm-slim + +RUN npm config set registry https://registry.npmmirror.com + +# RUN grep -rl 'deb.debian.org' /etc/apt/ | xargs sed -i 's|http[s]*://deb.debian.org|https://mirrors.ustc.edu.cn|g' && \ +# apt-get update && \ +# apt-get install -y curl gcc make + + +WORKDIR /app + +COPY package.json package-lock.json . 
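async_run_command above wraps asyncio.create_subprocess_exec with a timeout and converts expiry into a RuntimeError after killing the process. A minimal usage sketch (assumes a Unix-like host with `echo` and `sleep` on PATH):

```python
# Minimal usage sketch for the async command helper above.
import asyncio

from utils.common import async_run_command


async def demo() -> None:
    returncode, stdout, _ = await async_run_command("echo", "hello", timeout=5)
    print(returncode, stdout.strip())   # 0 hello

    try:
        await async_run_command("sleep", "10", timeout=1)
    except RuntimeError as exc:
        print(exc)                      # "Command timed out"


asyncio.run(demo())
```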
+ +RUN npm install + +CMD ["sleep", "infinity"] + diff --git a/sandbox/sandbox_base_image/nodejs/package-lock.json b/sandbox/sandbox_base_image/nodejs/package-lock.json new file mode 100644 index 00000000000..6aa834100c0 --- /dev/null +++ b/sandbox/sandbox_base_image/nodejs/package-lock.json @@ -0,0 +1,294 @@ +{ + "name": "nodejs", + "version": "1.0.0", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "nodejs", + "version": "1.0.0", + "license": "ISC", + "dependencies": { + "axios": "^1.9.0" + } + }, + "node_modules/asynckit": { + "version": "0.4.0", + "resolved": "https://registry.npmmirror.com/asynckit/-/asynckit-0.4.0.tgz", + "integrity": "sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==", + "license": "MIT" + }, + "node_modules/axios": { + "version": "1.9.0", + "resolved": "https://registry.npmmirror.com/axios/-/axios-1.9.0.tgz", + "integrity": "sha512-re4CqKTJaURpzbLHtIi6XpDv20/CnpXOtjRY5/CU32L8gU8ek9UIivcfvSWvmKEngmVbrUtPpdDwWDWL7DNHvg==", + "license": "MIT", + "dependencies": { + "follow-redirects": "^1.15.6", + "form-data": "^4.0.0", + "proxy-from-env": "^1.1.0" + } + }, + "node_modules/call-bind-apply-helpers": { + "version": "1.0.2", + "resolved": "https://registry.npmmirror.com/call-bind-apply-helpers/-/call-bind-apply-helpers-1.0.2.tgz", + "integrity": "sha512-Sp1ablJ0ivDkSzjcaJdxEunN5/XvksFJ2sMBFfq6x0ryhQV/2b/KwFe21cMpmHtPOSij8K99/wSfoEuTObmuMQ==", + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "function-bind": "^1.1.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/combined-stream": { + "version": "1.0.8", + "resolved": "https://registry.npmmirror.com/combined-stream/-/combined-stream-1.0.8.tgz", + "integrity": "sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg==", + "license": "MIT", + "dependencies": { + "delayed-stream": "~1.0.0" + }, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/delayed-stream": { + "version": "1.0.0", + "resolved": "https://registry.npmmirror.com/delayed-stream/-/delayed-stream-1.0.0.tgz", + "integrity": "sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ==", + "license": "MIT", + "engines": { + "node": ">=0.4.0" + } + }, + "node_modules/dunder-proto": { + "version": "1.0.1", + "resolved": "https://registry.npmmirror.com/dunder-proto/-/dunder-proto-1.0.1.tgz", + "integrity": "sha512-KIN/nDJBQRcXw0MLVhZE9iQHmG68qAVIBg9CqmUYjmQIhgij9U5MFvrqkUL5FbtyyzZuOeOt0zdeRe4UY7ct+A==", + "license": "MIT", + "dependencies": { + "call-bind-apply-helpers": "^1.0.1", + "es-errors": "^1.3.0", + "gopd": "^1.2.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-define-property": { + "version": "1.0.1", + "resolved": "https://registry.npmmirror.com/es-define-property/-/es-define-property-1.0.1.tgz", + "integrity": "sha512-e3nRfgfUZ4rNGL232gUgX06QNyyez04KdjFrF+LTRoOXmrOgFKDg4BCdsjW8EnT69eqdYGmRpJwiPVYNrCaW3g==", + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-errors": { + "version": "1.3.0", + "resolved": "https://registry.npmmirror.com/es-errors/-/es-errors-1.3.0.tgz", + "integrity": "sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==", + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-object-atoms": { + "version": "1.1.1", + "resolved": "https://registry.npmmirror.com/es-object-atoms/-/es-object-atoms-1.1.1.tgz", + "integrity": 
"sha512-FGgH2h8zKNim9ljj7dankFPcICIK9Cp5bm+c2gQSYePhpaG5+esrLODihIorn+Pe6FGJzWhXQotPv73jTaldXA==", + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-set-tostringtag": { + "version": "2.1.0", + "resolved": "https://registry.npmmirror.com/es-set-tostringtag/-/es-set-tostringtag-2.1.0.tgz", + "integrity": "sha512-j6vWzfrGVfyXxge+O0x5sh6cvxAog0a/4Rdd2K36zCMV5eJ+/+tOAngRO8cODMNWbVRdVlmGZQL2YS3yR8bIUA==", + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.6", + "has-tostringtag": "^1.0.2", + "hasown": "^2.0.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/follow-redirects": { + "version": "1.15.9", + "resolved": "https://registry.npmmirror.com/follow-redirects/-/follow-redirects-1.15.9.tgz", + "integrity": "sha512-gew4GsXizNgdoRyqmyfMHyAmXsZDk6mHkSxZFCzW9gwlbtOW44CDtYavM+y+72qD/Vq2l550kMF52DT8fOLJqQ==", + "funding": [ + { + "type": "individual", + "url": "https://github.com/sponsors/RubenVerborgh" + } + ], + "license": "MIT", + "engines": { + "node": ">=4.0" + }, + "peerDependenciesMeta": { + "debug": { + "optional": true + } + } + }, + "node_modules/form-data": { + "version": "4.0.2", + "resolved": "https://registry.npmmirror.com/form-data/-/form-data-4.0.2.tgz", + "integrity": "sha512-hGfm/slu0ZabnNt4oaRZ6uREyfCj6P4fT/n6A1rGV+Z0VdGXjfOhVUpkn6qVQONHGIFwmveGXyDs75+nr6FM8w==", + "license": "MIT", + "dependencies": { + "asynckit": "^0.4.0", + "combined-stream": "^1.0.8", + "es-set-tostringtag": "^2.1.0", + "mime-types": "^2.1.12" + }, + "engines": { + "node": ">= 6" + } + }, + "node_modules/function-bind": { + "version": "1.1.2", + "resolved": "https://registry.npmmirror.com/function-bind/-/function-bind-1.1.2.tgz", + "integrity": "sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==", + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/get-intrinsic": { + "version": "1.3.0", + "resolved": "https://registry.npmmirror.com/get-intrinsic/-/get-intrinsic-1.3.0.tgz", + "integrity": "sha512-9fSjSaos/fRIVIp+xSJlE6lfwhES7LNtKaCBIamHsjr2na1BiABJPo0mOjjz8GJDURarmCPGqaiVg5mfjb98CQ==", + "license": "MIT", + "dependencies": { + "call-bind-apply-helpers": "^1.0.2", + "es-define-property": "^1.0.1", + "es-errors": "^1.3.0", + "es-object-atoms": "^1.1.1", + "function-bind": "^1.1.2", + "get-proto": "^1.0.1", + "gopd": "^1.2.0", + "has-symbols": "^1.1.0", + "hasown": "^2.0.2", + "math-intrinsics": "^1.1.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/get-proto": { + "version": "1.0.1", + "resolved": "https://registry.npmmirror.com/get-proto/-/get-proto-1.0.1.tgz", + "integrity": "sha512-sTSfBjoXBp89JvIKIefqw7U2CCebsc74kiY6awiGogKtoSGbgjYE/G/+l9sF3MWFPNc9IcoOC4ODfKHfxFmp0g==", + "license": "MIT", + "dependencies": { + "dunder-proto": "^1.0.1", + "es-object-atoms": "^1.0.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/gopd": { + "version": "1.2.0", + "resolved": "https://registry.npmmirror.com/gopd/-/gopd-1.2.0.tgz", + "integrity": "sha512-ZUKRh6/kUFoAiTAtTYPZJ3hw9wNxx+BIBOijnlG9PnrJsCcSjs1wyyD6vJpaYtgnzDrKYRSqf3OO6Rfa93xsRg==", + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/has-symbols": { + "version": "1.1.0", + "resolved": 
"https://registry.npmmirror.com/has-symbols/-/has-symbols-1.1.0.tgz", + "integrity": "sha512-1cDNdwJ2Jaohmb3sg4OmKaMBwuC48sYni5HUw2DvsC8LjGTLK9h+eb1X6RyuOHe4hT0ULCW68iomhjUoKUqlPQ==", + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/has-tostringtag": { + "version": "1.0.2", + "resolved": "https://registry.npmmirror.com/has-tostringtag/-/has-tostringtag-1.0.2.tgz", + "integrity": "sha512-NqADB8VjPFLM2V0VvHUewwwsw0ZWBaIdgo+ieHtK3hasLz4qeCRjYcqfB6AQrBggRKppKF8L52/VqdVsO47Dlw==", + "license": "MIT", + "dependencies": { + "has-symbols": "^1.0.3" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/hasown": { + "version": "2.0.2", + "resolved": "https://registry.npmmirror.com/hasown/-/hasown-2.0.2.tgz", + "integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==", + "license": "MIT", + "dependencies": { + "function-bind": "^1.1.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/math-intrinsics": { + "version": "1.1.0", + "resolved": "https://registry.npmmirror.com/math-intrinsics/-/math-intrinsics-1.1.0.tgz", + "integrity": "sha512-/IXtbwEk5HTPyEwyKX6hGkYXxM9nbj64B+ilVJnC/R6B0pH5G4V3b0pVbL7DBj4tkhBAppbQUlf6F6Xl9LHu1g==", + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/mime-db": { + "version": "1.52.0", + "resolved": "https://registry.npmmirror.com/mime-db/-/mime-db-1.52.0.tgz", + "integrity": "sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==", + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/mime-types": { + "version": "2.1.35", + "resolved": "https://registry.npmmirror.com/mime-types/-/mime-types-2.1.35.tgz", + "integrity": "sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==", + "license": "MIT", + "dependencies": { + "mime-db": "1.52.0" + }, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/proxy-from-env": { + "version": "1.1.0", + "resolved": "https://registry.npmmirror.com/proxy-from-env/-/proxy-from-env-1.1.0.tgz", + "integrity": "sha512-D+zkORCbA9f1tdWRK0RaCR3GPv50cMxcrz4X8k5LTSUD1Dkw47mKJEZQNunItRTkWwgtaUSo1RVFRIG9ZXiFYg==", + "license": "MIT" + } + } +} diff --git a/sandbox/sandbox_base_image/nodejs/package.json b/sandbox/sandbox_base_image/nodejs/package.json new file mode 100644 index 00000000000..3bdae4a936c --- /dev/null +++ b/sandbox/sandbox_base_image/nodejs/package.json @@ -0,0 +1,15 @@ +{ + "name": "nodejs", + "version": "1.0.0", + "main": "index.js", + "scripts": { + "test": "echo \"Error: no test specified\" && exit 1" + }, + "keywords": [], + "author": "", + "license": "ISC", + "description": "", + "dependencies": { + "axios": "^1.9.0" + } +} diff --git a/sandbox/sandbox_base_image/python/Dockerfile b/sandbox/sandbox_base_image/python/Dockerfile new file mode 100644 index 00000000000..7b985764f60 --- /dev/null +++ b/sandbox/sandbox_base_image/python/Dockerfile @@ -0,0 +1,15 @@ +FROM python:3.11-slim-bookworm + +COPY --from=ghcr.io/astral-sh/uv:0.7.5 /uv /uvx /bin/ +ENV UV_INDEX_URL=https://pypi.tuna.tsinghua.edu.cn/simple + +COPY requirements.txt . 
+ +RUN grep -rl 'deb.debian.org' /etc/apt/ | xargs sed -i 's|http[s]*://deb.debian.org|https://mirrors.tuna.tsinghua.edu.cn|g' && \ + apt-get update && \ + apt-get install -y curl gcc && \ + uv pip install --system -r requirements.txt + +WORKDIR /workspace + +CMD ["sleep", "infinity"] diff --git a/sandbox/sandbox_base_image/python/requirements.txt b/sandbox/sandbox_base_image/python/requirements.txt new file mode 100644 index 00000000000..4ad1501633d --- /dev/null +++ b/sandbox/sandbox_base_image/python/requirements.txt @@ -0,0 +1,3 @@ +numpy +pandas +requests diff --git a/sandbox/scripts/restart.sh b/sandbox/scripts/restart.sh new file mode 100755 index 00000000000..525465903e4 --- /dev/null +++ b/sandbox/scripts/restart.sh @@ -0,0 +1,21 @@ +#!/bin/bash +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +set -e + +bash "$(dirname "$0")/stop.sh" +bash "$(dirname "$0")/start.sh" diff --git a/sandbox/scripts/start.sh b/sandbox/scripts/start.sh new file mode 100755 index 00000000000..68c7227d531 --- /dev/null +++ b/sandbox/scripts/start.sh @@ -0,0 +1,72 @@ +#!/bin/bash +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +set -e + +BASE_DIR="$(cd "$(dirname "$0")/.." && pwd)" +cd "$BASE_DIR" + +if [ -f .env ]; then + source .env + SANDBOX_EXECUTOR_MANAGER_PORT="${SANDBOX_EXECUTOR_MANAGER_PORT:-9385}" # Default to 9385 if not set in .env + SANDBOX_EXECUTOR_MANAGER_POOL_SIZE="${SANDBOX_EXECUTOR_MANAGER_POOL_SIZE:-5}" # Default to 5 if not set in .env + SANDBOX_BASE_PYTHON_IMAGE=${SANDBOX_BASE_PYTHON_IMAGE-"sandbox-base-python:latest"} + SANDBOX_BASE_NODEJS_IMAGE=${SANDBOX_BASE_NODEJS_IMAGE-"sandbox-base-nodejs:latest"} +else + echo "⚠️ .env not found, using default ports and pool size" + SANDBOX_EXECUTOR_MANAGER_PORT=9385 + SANDBOX_EXECUTOR_MANAGER_POOL_SIZE=5 + SANDBOX_BASE_PYTHON_IMAGE=sandbox-base-python:latest + SANDBOX_BASE_NODEJS_IMAGE=sandbox-base-nodejs:latest +fi + +echo "📦 STEP 1: Build sandbox-base image ..." +if [ -f .env ]; then + source .env && + echo "🐍 Building base sandbox image for Python ($SANDBOX_BASE_PYTHON_IMAGE)..." && + docker build -t "$SANDBOX_BASE_PYTHON_IMAGE" ./sandbox_base_image/python && + echo "⬢ Building base sandbox image for Nodejs ($SANDBOX_BASE_NODEJS_IMAGE)..." && + docker build -t "$SANDBOX_BASE_NODEJS_IMAGE" ./sandbox_base_image/nodejs +else + echo "⚠️ .env file not found, skipping build." 
+fi + +echo "🧹 STEP 2: Clean up old sandbox containers (sandbox_nodejs_0~$((SANDBOX_EXECUTOR_MANAGER_POOL_SIZE - 1)) and sandbox_python_0~$((SANDBOX_EXECUTOR_MANAGER_POOL_SIZE - 1))) ..." +for i in $(seq 0 $((SANDBOX_EXECUTOR_MANAGER_POOL_SIZE - 1))); do + echo "🧹 Deleting sandbox_python_$i..." + docker rm -f "sandbox_python_$i" >/dev/null 2>&1 || true + + echo "🧹 Deleting sandbox_nodejs_$i..." + docker rm -f "sandbox_nodejs_$i" >/dev/null 2>&1 || true +done + +echo "🔧 STEP 3: Build executor services ..." +docker compose build + +echo "🚀 STEP 4: Start services ..." +docker compose up -d + +echo "⏳ STEP 5a: Check if ports are open (basic connectivity) ..." +bash ./scripts/wait-for-it.sh "localhost" "$SANDBOX_EXECUTOR_MANAGER_PORT" -t 30 + +echo "⏳ STEP 5b: Check if the interfaces are healthy (/healthz) ..." +bash ./scripts/wait-for-it-http.sh "http://localhost:$SANDBOX_EXECUTOR_MANAGER_PORT/healthz" 30 + +echo "✅ STEP 6: Run security tests ..." +python3 ./tests/sandbox_security_tests_full.py + +echo "🎉 Service is ready: http://localhost:$SANDBOX_EXECUTOR_MANAGER_PORT/docs" diff --git a/sandbox/scripts/stop.sh b/sandbox/scripts/stop.sh new file mode 100755 index 00000000000..51bd2b6e93b --- /dev/null +++ b/sandbox/scripts/stop.sh @@ -0,0 +1,40 @@ +#!/bin/bash +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +set -e + +BASE_DIR="$(cd "$(dirname "$0")/.." && pwd)" +cd "$BASE_DIR" + +echo "🛑 Stopping all services..." +docker compose down + +echo "🧹 Deleting sandbox containers..." +if [ -f .env ]; then + source .env + for i in $(seq 0 $((SANDBOX_EXECUTOR_MANAGER_POOL_SIZE - 1))); do + echo "🧹 Deleting sandbox_python_$i..." + docker rm -f "sandbox_python_$i" >/dev/null 2>&1 || true + + echo "🧹 Deleting sandbox_nodejs_$i..." + docker rm -f "sandbox_nodejs_$i" >/dev/null 2>&1 || true + done +else + echo "⚠️ .env not found, skipping container cleanup" +fi + +echo "✅ Stopping and cleanup complete" diff --git a/sandbox/scripts/wait-for-it-http.sh b/sandbox/scripts/wait-for-it-http.sh new file mode 100755 index 00000000000..c99c4970d3c --- /dev/null +++ b/sandbox/scripts/wait-for-it-http.sh @@ -0,0 +1,31 @@ +#!/bin/bash +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +url=$1 +timeout=${2:-15} +quiet=${3:-0} + +for i in $(seq "$timeout"); do + if curl -fs "$url" >/dev/null; then + [[ "$quiet" -ne 1 ]] && echo "✔ $url is healthy after $i seconds" + exit 0 + fi + sleep 1 +done + +echo "✖ Timeout after $timeout seconds waiting for $url" +exit 1 diff --git a/sandbox/scripts/wait-for-it.sh b/sandbox/scripts/wait-for-it.sh new file mode 100755 index 00000000000..718f25488f8 --- /dev/null +++ b/sandbox/scripts/wait-for-it.sh @@ -0,0 +1,50 @@ +#!/bin/bash +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +host=$1 +port=$2 +shift 2 + +timeout=15 +quiet=0 + +while [[ $# -gt 0 ]]; do + case "$1" in + -t | --timeout) + timeout="$2" + shift 2 + ;; + -q | --quiet) + quiet=1 + shift + ;; + *) + break + ;; + esac +done + +for i in $(seq "$timeout"); do + if nc -z "$host" "$port" >/dev/null 2>&1; then + [[ "$quiet" -ne 1 ]] && echo "✔ $host:$port is available after $i seconds" + exit 0 + fi + sleep 1 +done + +echo "✖ Timeout after $timeout seconds waiting for $host:$port" +exit 1 diff --git a/sandbox/tests/sandbox_security_tests_full.py b/sandbox/tests/sandbox_security_tests_full.py new file mode 100644 index 00000000000..758120758e6 --- /dev/null +++ b/sandbox/tests/sandbox_security_tests_full.py @@ -0,0 +1,436 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import base64 +import os +import textwrap +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from enum import Enum +from typing import Dict, Optional + +import requests +from pydantic import BaseModel + +API_URL = os.getenv("SANDBOX_API_URL", "http://localhost:9385/run") +TIMEOUT = 15 +MAX_WORKERS = 5 + + +class ResultStatus(str, Enum): + SUCCESS = "success" + PROGRAM_ERROR = "program_error" + RESOURCE_LIMIT_EXCEEDED = "resource_limit_exceeded" + UNAUTHORIZED_ACCESS = "unauthorized_access" + RUNTIME_ERROR = "runtime_error" + PROGRAM_RUNNER_ERROR = "program_runner_error" + + +class ResourceLimitType(str, Enum): + TIME = "time" + MEMORY = "memory" + OUTPUT = "output" + + +class UnauthorizedAccessType(str, Enum): + DISALLOWED_SYSCALL = "disallowed_syscall" + FILE_ACCESS = "file_access" + NETWORK_ACCESS = "network_access" + + +class RuntimeErrorType(str, Enum): + SIGNALLED = "signalled" + NONZERO_EXIT = "nonzero_exit" + + +class ExecutionResult(BaseModel): + status: ResultStatus + stdout: str + stderr: str + exit_code: int + detail: Optional[str] = None + resource_limit_type: Optional[ResourceLimitType] = None + unauthorized_access_type: Optional[UnauthorizedAccessType] = None + runtime_error_type: Optional[RuntimeErrorType] = None + + +class TestResult(BaseModel): + name: str + passed: bool + duration: float + expected_failure: bool = False + result: Optional[ExecutionResult] = None + error: Optional[str] = None + validation_error: Optional[str] = None + + +def encode_code(code: str) -> str: + return base64.b64encode(code.encode("utf-8")).decode("utf-8") + + +def execute_single_test(name: str, code: str, language: str, arguments: dict, expect_fail: bool = False) -> TestResult: + """Execute a single test case""" + payload = { + "code_b64": encode_code(textwrap.dedent(code)), + "language": language, + "arguments": arguments, + } + + test_result = TestResult(name=name, passed=False, duration=0, expected_failure=expect_fail) + + really_processed = False + try: + while not really_processed: + start_time = time.perf_counter() + + resp = requests.post(API_URL, json=payload, timeout=TIMEOUT) + resp.raise_for_status() + response_data = resp.json() + if response_data["exit_code"] == -429: # too many requests + print(f"[{name}] Reached request limit, retrying...") + time.sleep(0.5) + continue + really_processed = True + + print("-------------------") + print(f"{name}:\n{response_data}") + print("-------------------") + + test_result.duration = time.perf_counter() - start_time + test_result.result = ExecutionResult(**response_data) + + # Validate test result expectations + validate_test_result(name, expect_fail, test_result) + + except requests.exceptions.RequestException as e: + test_result.duration = time.perf_counter() - start_time + test_result.error = f"Request failed: {str(e)}" + test_result.result = ExecutionResult( + status=ResultStatus.PROGRAM_RUNNER_ERROR, + stdout="", + stderr=str(e), + exit_code=-999, + detail="request_failed", + ) + + return test_result + + +def validate_test_result(name: str, expect_fail: bool, test_result: TestResult): + """Validate if the test result meets expectations""" + if not test_result.result: + test_result.passed = False + test_result.validation_error = "No result returned" + return + + test_result.passed = test_result.result.status == ResultStatus.SUCCESS + # General validation logic + if expect_fail: + # Tests expected to fail should return a non-success status + if test_result.passed: + test_result.validation_error = "Expected 
failure but actually succeeded" + else: + # Tests expected to succeed should return a success status + if not test_result.passed: + test_result.validation_error = f"Unexpected failure (status={test_result.result.status})" + + +def get_test_cases() -> Dict[str, dict]: + """Return test cases (code, whether expected to fail)""" + return { + "1 Infinite loop: Should be forcibly terminated": { + "code": """ +def main(): + while True: + pass + """, + "should_fail": True, + "arguments": {}, + "language": "python", + }, + "2 Infinite loop: Should be forcibly terminated": { + "code": """ +def main(): + while True: + pass + """, + "should_fail": True, + "arguments": {}, + "language": "python", + }, + "3 Infinite loop: Should be forcibly terminated": { + "code": """ +def main(): + while True: + pass + """, + "should_fail": True, + "arguments": {}, + "language": "python", + }, + "4 Infinite loop: Should be forcibly terminated": { + "code": """ +def main(): + while True: + pass + """, + "should_fail": True, + "arguments": {}, + "language": "python", + }, + "5 Infinite loop: Should be forcibly terminated": { + "code": """ +def main(): + while True: + pass + """, + "should_fail": True, + "arguments": {}, + "language": "python", + }, + "6 Infinite loop: Should be forcibly terminated": { + "code": """ +def main(): + while True: + pass + """, + "should_fail": True, + "arguments": {}, + "language": "python", + }, + "7 Normal test: Python without dependencies": { + "code": """ +def main(): + return {"data": "hello, world"} + """, + "should_fail": False, + "arguments": {}, + "language": "python", + }, + "8 Normal test: Python with pandas, should pass without any error": { + "code": """ +import pandas as pd + +def main(): + data = {'Name': ['Alice', 'Bob', 'Charlie'], + 'Age': [25, 30, 35]} + df = pd.DataFrame(data) + """, + "should_fail": False, + "arguments": {}, + "language": "python", + }, + "9 Normal test: Nodejs without dependencies, should pass without any error": { + "code": """ +const https = require('https'); + +async function main(args) { + return new Promise((resolve, reject) => { + const req = https.get('https://example.com/', (res) => { + let data = ''; + + res.on('data', (chunk) => { + data += chunk; + }); + + res.on('end', () => { + clearTimeout(timeout); + console.log('Body:', data); + resolve(data); + }); + }); + + const timeout = setTimeout(() => { + req.destroy(new Error('Request timeout after 10s')); + }, 10000); + + req.on('error', (err) => { + clearTimeout(timeout); + console.error('Error:', err.message); + reject(err); + }); + }); +} + +module.exports = { main }; + """, + "should_fail": False, + "arguments": {}, + "language": "nodejs", + }, + "10 Normal test: Nodejs with axios, should pass without any error": { + "code": """ +const axios = require('axios'); + +async function main(args) { + try { + const response = await axios.get('https://example.com/', { + timeout: 10000 + }); + console.log('Body:', response.data); + } catch (error) { + console.error('Error:', error.message); + } +} + +module.exports = { main }; + """, + "should_fail": False, + "arguments": {}, + "language": "nodejs", + }, + "11 Dangerous import: Should fail due to os module import": { + "code": """ +import os + +def main(): + pass + """, + "should_fail": True, + "arguments": {}, + "language": "python", + }, + "12 Dangerous import from subprocess: Should fail due to subprocess import": { + "code": """ +from subprocess import Popen + +def main(): + pass + """, + "should_fail": True, + "arguments": {}, + "language": 
"python", + }, + "13 Dangerous call: Should fail due to eval function call": { + "code": """ +def main(): + eval('os.system("echo hello")') + """, + "should_fail": True, + "arguments": {}, + "language": "python", + }, + "14 Dangerous attribute access: Should fail due to shutil.rmtree": { + "code": """ +import shutil + +def main(): + shutil.rmtree('/some/path') + """, + "should_fail": True, + "arguments": {}, + "language": "python", + }, + "15 Dangerous binary operation: Should fail due to unsafe concatenation leading to eval": { + "code": """ +def main(): + dangerous_string = "os." + "system" + eval(dangerous_string + '("echo hello")') + """, + "should_fail": True, + "arguments": {}, + "language": "python", + }, + "16 Dangerous function definition: Should fail due to user-defined eval function": { + "code": """ +def eval_function(): + eval('os.system("echo hello")') + +def main(): + eval_function() + """, + "should_fail": True, + "arguments": {}, + "language": "python", + }, + "17 Memory exhaustion(256m): Should fail due to exceeding memory limit(try to allocate 300m)": { + "code": """ +def main(): + x = ['a' * 1024 * 1024] * 300 # 300MB +""", + "should_fail": True, + "arguments": {}, + "language": "python", + }, + } + + +def print_test_report(results: Dict[str, TestResult]): + print("\n=== 🔍 Test Report ===") + + max_name_len = max(len(name) for name in results) + + for name, result in results.items(): + status = "✅" if result.passed else "❌" + if result.expected_failure: + status = "⚠️" if result.passed else "✓" # Expected failure case + + print(f"{status} {name.ljust(max_name_len)} {result.duration:.2f}s") + + if result.error: + print(f" REQUEST ERROR: {result.error}") + if result.validation_error: + print(f" VALIDATION ERROR: {result.validation_error}") + + if result.result and not result.passed: + print(f" STATUS: {result.result.status}") + if result.result.stderr: + print(f" STDERR: {result.result.stderr[:200]}...") + if result.result.detail: + print(f" DETAIL: {result.result.detail}") + + passed = sum(1 for r in results.values() if ((not r.expected_failure and r.passed) or (r.expected_failure and not r.passed))) + failed = len(results) - passed + + print("\n=== 📊 Statistics ===") + print(f"✅ Passed: {passed}") + print(f"❌ Failed: {failed}") + print(f"📌 Total: {len(results)}") + + +def main(): + print(f"🔐 Starting sandbox security tests (API: {API_URL})") + print(f"🚀 Concurrent threads: {MAX_WORKERS}") + + test_cases = get_test_cases() + results = {} + + with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor: + futures = {} + for name, detail in test_cases.items(): + # ✅ Log when a task is submitted + print(f"✅ Task submitted: {name}") + time.sleep(0.4) + future = executor.submit(execute_single_test, name, detail["code"], detail["language"], detail["arguments"], detail["should_fail"]) + futures[future] = name + + print("\n=== 🚦 Test Progress ===") + for i, future in enumerate(as_completed(futures)): + name = futures[future] + print(f" {i + 1}/{len(test_cases)} completed: {name}") + try: + results[name] = future.result() + except Exception as e: + print(f"⚠️ Test {name} execution exception: {str(e)}") + results[name] = TestResult(name=name, passed=False, duration=0, error=f"Execution exception: {str(e)}") + + print_test_report(results) + + if any(not r.passed and not r.expected_failure for r in results.values()): + exit(1) + + +if __name__ == "__main__": + main() diff --git a/sandbox/uv.lock b/sandbox/uv.lock new file mode 100644 index 00000000000..1f27216bf06 --- /dev/null +++ 
b/sandbox/uv.lock @@ -0,0 +1,539 @@ +version = 1 +revision = 1 +requires-python = ">=3.10" + +[[package]] +name = "annotated-types" +version = "0.7.0" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ee/67/531ea369ba64dcff5ec9c3402f9f51bf748cec26dde048a2f973a4eea7f5/annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89", size = 16081 } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643 }, +] + +[[package]] +name = "anyio" +version = "4.9.0" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, + { name = "idna" }, + { name = "sniffio" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/95/7d/4c1bd541d4dffa1b52bd83fb8527089e097a106fc90b467a7313b105f840/anyio-4.9.0.tar.gz", hash = "sha256:673c0c244e15788651a4ff38710fea9675823028a6f08a5eda409e0c9840a028", size = 190949 } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a1/ee/48ca1a7c89ffec8b6a0c5d02b89c305671d5ffd8d3c94acf8b8c408575bb/anyio-4.9.0-py3-none-any.whl", hash = "sha256:9f76d541cad6e36af7beb62e978876f3b41e3e04f2c1fbf0884604c0a9c4d93c", size = 100916 }, +] + +[[package]] +name = "basedpyright" +version = "1.29.1" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "nodejs-wheel-binaries" }, +] +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b9/18/f5e488eac4960ad9a2e71b95f0d91cf93a982c7f68aa90e4e0554f0bc37e/basedpyright-1.29.1.tar.gz", hash = "sha256:06bbe6c3b50ab4af20f80e154049477a50d8b81d2522eadbc9f472f2f92cd44b", size = 21773469 } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/95/1b/1bb837bbb7e259928f33d3c105dfef4f5349ef08b3ef45576801256e3234/basedpyright-1.29.1-py3-none-any.whl", hash = "sha256:b7eb65b9d4aaeeea29a349ac494252032a75a364942d0ac466d7f07ddeacc786", size = 11397959 }, +] + +[[package]] +name = "certifi" +version = "2025.4.26" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/e8/9e/c05b3920a3b7d20d3d3310465f50348e5b3694f4f88c6daf736eef3024c4/certifi-2025.4.26.tar.gz", hash = "sha256:0a816057ea3cdefcef70270d2c515e4506bbc954f417fa5ade2021213bb8f0c6", size = 160705 } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/4a/7e/3db2bd1b1f9e95f7cddca6d6e75e2f2bd9f51b1246e546d88addca0106bd/certifi-2025.4.26-py3-none-any.whl", hash = "sha256:30350364dfe371162649852c63336a15c70c6510c2ad5015b21c2345311805f3", size = 159618 }, +] + +[[package]] +name = "charset-normalizer" +version = "3.4.2" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/e4/33/89c2ced2b67d1c2a61c19c6751aa8902d46ce3dacb23600a283619f5a12d/charset_normalizer-3.4.2.tar.gz", hash = "sha256:5baececa9ecba31eff645232d59845c07aa030f0c81ee70184a90d35099a0e63", size = 126367 } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/95/28/9901804da60055b406e1a1c5ba7aac1276fb77f1dde635aabfc7fd84b8ab/charset_normalizer-3.4.2-cp310-cp310-macosx_10_9_universal2.whl", 
hash = "sha256:7c48ed483eb946e6c04ccbe02c6b4d1d48e51944b6db70f697e089c193404941", size = 201818 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/d9/9b/892a8c8af9110935e5adcbb06d9c6fe741b6bb02608c6513983048ba1a18/charset_normalizer-3.4.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b2d318c11350e10662026ad0eb71bb51c7812fc8590825304ae0bdd4ac283acd", size = 144649 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/7b/a5/4179abd063ff6414223575e008593861d62abfc22455b5d1a44995b7c101/charset_normalizer-3.4.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9cbfacf36cb0ec2897ce0ebc5d08ca44213af24265bd56eca54bee7923c48fd6", size = 155045 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/3b/95/bc08c7dfeddd26b4be8c8287b9bb055716f31077c8b0ea1cd09553794665/charset_normalizer-3.4.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:18dd2e350387c87dabe711b86f83c9c78af772c748904d372ade190b5c7c9d4d", size = 147356 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a8/2d/7a5b635aa65284bf3eab7653e8b4151ab420ecbae918d3e359d1947b4d61/charset_normalizer-3.4.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8075c35cd58273fee266c58c0c9b670947c19df5fb98e7b66710e04ad4e9ff86", size = 149471 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ae/38/51fc6ac74251fd331a8cfdb7ec57beba8c23fd5493f1050f71c87ef77ed0/charset_normalizer-3.4.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5bf4545e3b962767e5c06fe1738f951f77d27967cb2caa64c28be7c4563e162c", size = 151317 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b7/17/edee1e32215ee6e9e46c3e482645b46575a44a2d72c7dfd49e49f60ce6bf/charset_normalizer-3.4.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:7a6ab32f7210554a96cd9e33abe3ddd86732beeafc7a28e9955cdf22ffadbab0", size = 146368 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/26/2c/ea3e66f2b5f21fd00b2825c94cafb8c326ea6240cd80a91eb09e4a285830/charset_normalizer-3.4.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:b33de11b92e9f75a2b545d6e9b6f37e398d86c3e9e9653c4864eb7e89c5773ef", size = 154491 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/52/47/7be7fa972422ad062e909fd62460d45c3ef4c141805b7078dbab15904ff7/charset_normalizer-3.4.2-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:8755483f3c00d6c9a77f490c17e6ab0c8729e39e6390328e42521ef175380ae6", size = 157695 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/2f/42/9f02c194da282b2b340f28e5fb60762de1151387a36842a92b533685c61e/charset_normalizer-3.4.2-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:68a328e5f55ec37c57f19ebb1fdc56a248db2e3e9ad769919a58672958e8f366", size = 154849 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/67/44/89cacd6628f31fb0b63201a618049be4be2a7435a31b55b5eb1c3674547a/charset_normalizer-3.4.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:21b2899062867b0e1fde9b724f8aecb1af14f2778d69aacd1a5a1853a597a5db", size = 150091 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/1f/79/4b8da9f712bc079c0f16b6d67b099b0b8d808c2292c937f267d816ec5ecc/charset_normalizer-3.4.2-cp310-cp310-win32.whl", hash = "sha256:e8082b26888e2f8b36a042a58307d5b917ef2b1cacab921ad3323ef91901c71a", size = 98445 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/7d/d7/96970afb4fb66497a40761cdf7bd4f6fca0fc7bafde3a84f836c1f57a926/charset_normalizer-3.4.2-cp310-cp310-win_amd64.whl", hash = 
"sha256:f69a27e45c43520f5487f27627059b64aaf160415589230992cec34c5e18a509", size = 105782 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/05/85/4c40d00dcc6284a1c1ad5de5e0996b06f39d8232f1031cd23c2f5c07ee86/charset_normalizer-3.4.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:be1e352acbe3c78727a16a455126d9ff83ea2dfdcbc83148d2982305a04714c2", size = 198794 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/41/d9/7a6c0b9db952598e97e93cbdfcb91bacd89b9b88c7c983250a77c008703c/charset_normalizer-3.4.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aa88ca0b1932e93f2d961bf3addbb2db902198dca337d88c89e1559e066e7645", size = 142846 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/66/82/a37989cda2ace7e37f36c1a8ed16c58cf48965a79c2142713244bf945c89/charset_normalizer-3.4.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d524ba3f1581b35c03cb42beebab4a13e6cdad7b36246bd22541fa585a56cccd", size = 153350 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/df/68/a576b31b694d07b53807269d05ec3f6f1093e9545e8607121995ba7a8313/charset_normalizer-3.4.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28a1005facc94196e1fb3e82a3d442a9d9110b8434fc1ded7a24a2983c9888d8", size = 145657 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/92/9b/ad67f03d74554bed3aefd56fe836e1623a50780f7c998d00ca128924a499/charset_normalizer-3.4.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fdb20a30fe1175ecabed17cbf7812f7b804b8a315a25f24678bcdf120a90077f", size = 147260 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a6/e6/8aebae25e328160b20e31a7e9929b1578bbdc7f42e66f46595a432f8539e/charset_normalizer-3.4.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0f5d9ed7f254402c9e7d35d2f5972c9bbea9040e99cd2861bd77dc68263277c7", size = 149164 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/8b/f2/b3c2f07dbcc248805f10e67a0262c93308cfa149a4cd3d1fe01f593e5fd2/charset_normalizer-3.4.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:efd387a49825780ff861998cd959767800d54f8308936b21025326de4b5a42b9", size = 144571 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/60/5b/c3f3a94bc345bc211622ea59b4bed9ae63c00920e2e8f11824aa5708e8b7/charset_normalizer-3.4.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:f0aa37f3c979cf2546b73e8222bbfa3dc07a641585340179d768068e3455e544", size = 151952 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/e2/4d/ff460c8b474122334c2fa394a3f99a04cf11c646da895f81402ae54f5c42/charset_normalizer-3.4.2-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:e70e990b2137b29dc5564715de1e12701815dacc1d056308e2b17e9095372a82", size = 155959 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a2/2b/b964c6a2fda88611a1fe3d4c400d39c66a42d6c169c924818c848f922415/charset_normalizer-3.4.2-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:0c8c57f84ccfc871a48a47321cfa49ae1df56cd1d965a09abe84066f6853b9c0", size = 153030 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/59/2e/d3b9811db26a5ebf444bc0fa4f4be5aa6d76fc6e1c0fd537b16c14e849b6/charset_normalizer-3.4.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:6b66f92b17849b85cad91259efc341dce9c1af48e2173bf38a85c6329f1033e5", size = 148015 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/90/07/c5fd7c11eafd561bb51220d600a788f1c8d77c5eef37ee49454cc5c35575/charset_normalizer-3.4.2-cp311-cp311-win32.whl", hash = 
"sha256:daac4765328a919a805fa5e2720f3e94767abd632ae410a9062dff5412bae65a", size = 98106 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a8/05/5e33dbef7e2f773d672b6d79f10ec633d4a71cd96db6673625838a4fd532/charset_normalizer-3.4.2-cp311-cp311-win_amd64.whl", hash = "sha256:e53efc7c7cee4c1e70661e2e112ca46a575f90ed9ae3fef200f2a25e954f4b28", size = 105402 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/d7/a4/37f4d6035c89cac7930395a35cc0f1b872e652eaafb76a6075943754f095/charset_normalizer-3.4.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:0c29de6a1a95f24b9a1aa7aefd27d2487263f00dfd55a77719b530788f75cff7", size = 199936 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ee/8a/1a5e33b73e0d9287274f899d967907cd0bf9c343e651755d9307e0dbf2b3/charset_normalizer-3.4.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cddf7bd982eaa998934a91f69d182aec997c6c468898efe6679af88283b498d3", size = 143790 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/66/52/59521f1d8e6ab1482164fa21409c5ef44da3e9f653c13ba71becdd98dec3/charset_normalizer-3.4.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fcbe676a55d7445b22c10967bceaaf0ee69407fbe0ece4d032b6eb8d4565982a", size = 153924 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/86/2d/fb55fdf41964ec782febbf33cb64be480a6b8f16ded2dbe8db27a405c09f/charset_normalizer-3.4.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d41c4d287cfc69060fa91cae9683eacffad989f1a10811995fa309df656ec214", size = 146626 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/8c/73/6ede2ec59bce19b3edf4209d70004253ec5f4e319f9a2e3f2f15601ed5f7/charset_normalizer-3.4.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e594135de17ab3866138f496755f302b72157d115086d100c3f19370839dd3a", size = 148567 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/09/14/957d03c6dc343c04904530b6bef4e5efae5ec7d7990a7cbb868e4595ee30/charset_normalizer-3.4.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cf713fe9a71ef6fd5adf7a79670135081cd4431c2943864757f0fa3a65b1fafd", size = 150957 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/0d/c8/8174d0e5c10ccebdcb1b53cc959591c4c722a3ad92461a273e86b9f5a302/charset_normalizer-3.4.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a370b3e078e418187da8c3674eddb9d983ec09445c99a3a263c2011993522981", size = 145408 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/58/aa/8904b84bc8084ac19dc52feb4f5952c6df03ffb460a887b42615ee1382e8/charset_normalizer-3.4.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:a955b438e62efdf7e0b7b52a64dc5c3396e2634baa62471768a64bc2adb73d5c", size = 153399 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/c2/26/89ee1f0e264d201cb65cf054aca6038c03b1a0c6b4ae998070392a3ce605/charset_normalizer-3.4.2-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:7222ffd5e4de8e57e03ce2cef95a4c43c98fcb72ad86909abdfc2c17d227fc1b", size = 156815 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/fd/07/68e95b4b345bad3dbbd3a8681737b4338ff2c9df29856a6d6d23ac4c73cb/charset_normalizer-3.4.2-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:bee093bf902e1d8fc0ac143c88902c3dfc8941f7ea1d6a8dd2bcb786d33db03d", size = 154537 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/77/1a/5eefc0ce04affb98af07bc05f3bac9094513c0e23b0562d64af46a06aae4/charset_normalizer-3.4.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = 
"sha256:dedb8adb91d11846ee08bec4c8236c8549ac721c245678282dcb06b221aab59f", size = 149565 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/37/a0/2410e5e6032a174c95e0806b1a6585eb21e12f445ebe239fac441995226a/charset_normalizer-3.4.2-cp312-cp312-win32.whl", hash = "sha256:db4c7bf0e07fc3b7d89ac2a5880a6a8062056801b83ff56d8464b70f65482b6c", size = 98357 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/6c/4f/c02d5c493967af3eda9c771ad4d2bbc8df6f99ddbeb37ceea6e8716a32bc/charset_normalizer-3.4.2-cp312-cp312-win_amd64.whl", hash = "sha256:5a9979887252a82fefd3d3ed2a8e3b937a7a809f65dcb1e068b090e165bbe99e", size = 105776 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ea/12/a93df3366ed32db1d907d7593a94f1fe6293903e3e92967bebd6950ed12c/charset_normalizer-3.4.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:926ca93accd5d36ccdabd803392ddc3e03e6d4cd1cf17deff3b989ab8e9dbcf0", size = 199622 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/04/93/bf204e6f344c39d9937d3c13c8cd5bbfc266472e51fc8c07cb7f64fcd2de/charset_normalizer-3.4.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eba9904b0f38a143592d9fc0e19e2df0fa2e41c3c3745554761c5f6447eedabf", size = 143435 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/22/2a/ea8a2095b0bafa6c5b5a55ffdc2f924455233ee7b91c69b7edfcc9e02284/charset_normalizer-3.4.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3fddb7e2c84ac87ac3a947cb4e66d143ca5863ef48e4a5ecb83bd48619e4634e", size = 153653 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b6/57/1b090ff183d13cef485dfbe272e2fe57622a76694061353c59da52c9a659/charset_normalizer-3.4.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:98f862da73774290f251b9df8d11161b6cf25b599a66baf087c1ffe340e9bfd1", size = 146231 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/e2/28/ffc026b26f441fc67bd21ab7f03b313ab3fe46714a14b516f931abe1a2d8/charset_normalizer-3.4.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c9379d65defcab82d07b2a9dfbfc2e95bc8fe0ebb1b176a3190230a3ef0e07c", size = 148243 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/c0/0f/9abe9bd191629c33e69e47c6ef45ef99773320e9ad8e9cb08b8ab4a8d4cb/charset_normalizer-3.4.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e635b87f01ebc977342e2697d05b56632f5f879a4f15955dfe8cef2448b51691", size = 150442 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/67/7c/a123bbcedca91d5916c056407f89a7f5e8fdfce12ba825d7d6b9954a1a3c/charset_normalizer-3.4.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:1c95a1e2902a8b722868587c0e1184ad5c55631de5afc0eb96bc4b0d738092c0", size = 145147 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ec/fe/1ac556fa4899d967b83e9893788e86b6af4d83e4726511eaaad035e36595/charset_normalizer-3.4.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:ef8de666d6179b009dce7bcb2ad4c4a779f113f12caf8dc77f0162c29d20490b", size = 153057 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/2b/ff/acfc0b0a70b19e3e54febdd5301a98b72fa07635e56f24f60502e954c461/charset_normalizer-3.4.2-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:32fc0341d72e0f73f80acb0a2c94216bd704f4f0bce10aedea38f30502b271ff", size = 156454 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/92/08/95b458ce9c740d0645feb0e96cea1f5ec946ea9c580a94adfe0b617f3573/charset_normalizer-3.4.2-cp313-cp313-musllinux_1_2_s390x.whl", hash = 
"sha256:289200a18fa698949d2b39c671c2cc7a24d44096784e76614899a7ccf2574b7b", size = 154174 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/78/be/8392efc43487ac051eee6c36d5fbd63032d78f7728cb37aebcc98191f1ff/charset_normalizer-3.4.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4a476b06fbcf359ad25d34a057b7219281286ae2477cc5ff5e3f70a246971148", size = 149166 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/44/96/392abd49b094d30b91d9fbda6a69519e95802250b777841cf3bda8fe136c/charset_normalizer-3.4.2-cp313-cp313-win32.whl", hash = "sha256:aaeeb6a479c7667fbe1099af9617c83aaca22182d6cf8c53966491a0f1b7ffb7", size = 98064 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/e9/b0/0200da600134e001d91851ddc797809e2fe0ea72de90e09bec5a2fbdaccb/charset_normalizer-3.4.2-cp313-cp313-win_amd64.whl", hash = "sha256:aa6af9e7d59f9c12b33ae4e9450619cf2488e2bbe9b44030905877f0b2324980", size = 105641 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/20/94/c5790835a017658cbfabd07f3bfb549140c3ac458cfc196323996b10095a/charset_normalizer-3.4.2-py3-none-any.whl", hash = "sha256:7f56930ab0abd1c45cd15be65cc741c28b1c9a34876ce8c17a2fa107810c0af0", size = 52626 }, +] + +[[package]] +name = "click" +version = "8.1.8" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/7e/d4/7ebdbd03970677812aac39c869717059dbb71a4cfc033ca6e5221787892c/click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2", size = 98188 }, +] + +[[package]] +name = "colorama" +version = "0.4.6" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697 } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335 }, +] + +[[package]] +name = "deprecated" +version = "1.2.18" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "wrapt" }, +] +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/98/97/06afe62762c9a8a86af0cfb7bfdab22a43ad17138b07af5b1a58442690a2/deprecated-1.2.18.tar.gz", hash = "sha256:422b6f6d859da6f2ef57857761bfb392480502a64c3028ca9bbe86085d72115d", size = 2928744 } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/6e/c6/ac0b6c1e2d138f1002bcf799d330bd6d85084fece321e662a14223794041/Deprecated-1.2.18-py2.py3-none-any.whl", hash = "sha256:bd5011788200372a32418f888e326a09ff80d0214bd961147cfed01b5c018eec", size = 9998 }, +] + +[[package]] +name = "exceptiongroup" +version = "1.2.2" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/09/35/2495c4ac46b980e4ca1f6ad6db102322ef3ad2410b79fdde159a4b0f3b92/exceptiongroup-1.2.2.tar.gz", hash = 
"sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc", size = 28883 } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/02/cc/b7e31358aac6ed1ef2bb790a9746ac2c69bcb3c8588b41616914eb106eaf/exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b", size = 16453 }, +] + +[[package]] +name = "fastapi" +version = "0.115.12" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "starlette" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f4/55/ae499352d82338331ca1e28c7f4a63bfd09479b16395dce38cf50a39e2c2/fastapi-0.115.12.tar.gz", hash = "sha256:1e2c2a2646905f9e83d32f04a3f86aff4a286669c6c950ca95b5fd68c2602681", size = 295236 } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/50/b3/b51f09c2ba432a576fe63758bddc81f78f0c6309d9e5c10d194313bf021e/fastapi-0.115.12-py3-none-any.whl", hash = "sha256:e94613d6c05e27be7ffebdd6ea5f388112e5e430c8f7d6494a9d1d88d43e814d", size = 95164 }, +] + +[[package]] +name = "gvisor-sandbox" +version = "0.1.0" +source = { virtual = "." } +dependencies = [ + { name = "fastapi" }, + { name = "httpx" }, + { name = "pydantic" }, + { name = "requests" }, + { name = "slowapi" }, + { name = "uvicorn" }, +] + +[package.dev-dependencies] +dev = [ + { name = "basedpyright" }, +] + +[package.metadata] +requires-dist = [ + { name = "fastapi", specifier = ">=0.115.12" }, + { name = "httpx", specifier = ">=0.28.1" }, + { name = "pydantic", specifier = ">=2.11.4" }, + { name = "requests", specifier = ">=2.32.3" }, + { name = "slowapi", specifier = ">=0.1.9" }, + { name = "uvicorn", specifier = ">=0.34.2" }, +] + +[package.metadata.requires-dev] +dev = [{ name = "basedpyright", specifier = ">=1.29.1" }] + +[[package]] +name = "h11" +version = "0.16.0" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250 } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515 }, +] + +[[package]] +name = "httpcore" +version = "1.0.9" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "certifi" }, + { name = "h11" }, +] +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/06/94/82699a10bca87a5556c9c59b5963f2d039dbd239f25bc2a63907a05a14cb/httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8", size = 85484 } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784 }, +] + +[[package]] +name = "httpx" +version = "0.28.1" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "anyio" }, + { name = "certifi" }, + { name = "httpcore" }, + { name = "idna" }, +] +sdist = { url = 
"https://pypi.tuna.tsinghua.edu.cn/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406 } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517 }, +] + +[[package]] +name = "idna" +version = "3.10" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f1/70/7703c29685631f5a7590aa73f1f1d3fa9a380e654b86af429e0934a32f7d/idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9", size = 190490 } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442 }, +] + +[[package]] +name = "limits" +version = "5.1.0" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "deprecated" }, + { name = "packaging" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/c6/94/a04e64f487a56f97aff67c53df609cc19d5c3f3e7e5697ec8a1ff8413829/limits-5.1.0.tar.gz", hash = "sha256:b298e4af0b47997da03cbeee9df027ddc2328f8630546125e81083bb56311827", size = 94655 } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/de/00/876a5ec60addda62ee13ac4b588a5afc0d1a86a431645a91711ceae834cf/limits-5.1.0-py3-none-any.whl", hash = "sha256:f368d4572ac3ef8190cb8b9911ed481175a0b4189894a63cac95cae39ebeb147", size = 60472 }, +] + +[[package]] +name = "nodejs-wheel-binaries" +version = "22.15.0" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/45/5b/6c5f973765b96793d4e4d03684bcbd273b17e471ecc7e9bec4c32b595ebd/nodejs_wheel_binaries-22.15.0.tar.gz", hash = "sha256:ff81aa2a79db279c2266686ebcb829b6634d049a5a49fc7dc6921e4f18af9703", size = 8054 } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/d3/a8/a32e5bb99e95c536e7dac781cffab1e7e9f8661b8ee296b93df77e4df7f9/nodejs_wheel_binaries-22.15.0-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:aa16366d48487fff89446fb237693e777aa2ecd987208db7d4e35acc40c3e1b1", size = 50514526 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/05/e8/eb024dbb3a7d3b98c8922d1c306be989befad4d2132292954cb902f43b07/nodejs_wheel_binaries-22.15.0-py2.py3-none-macosx_11_0_x86_64.whl", hash = "sha256:a54bb3fee9170003fa8abc69572d819b2b1540344eff78505fcc2129a9175596", size = 51409179 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/3f/0f/baa968456c3577e45c7d0e3715258bd175dcecc67b683a41a5044d5dae40/nodejs_wheel_binaries-22.15.0-py2.py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:867121ccf99d10523f6878a26db86e162c4939690e24cfb5bea56d01ea696c93", size = 57364460 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/2f/a2/977f63cd07ed8fc27bc0d0cd72e801fc3691ffc8cd40a51496ff18a6d0a2/nodejs_wheel_binaries-22.15.0-py2.py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ab0fbcda2ddc8aab7db1505d72cb958f99324b3834c4543541a305e02bfe860", size = 57889101 }, + { url = 
"https://pypi.tuna.tsinghua.edu.cn/packages/67/7f/57b9c24a4f0d25490527b043146aa0fdff2d8fdc82f90667cdaf6f00cfc9/nodejs_wheel_binaries-22.15.0-py2.py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:2bde1d8e00cd955b9ce9ee9ac08309923e2778a790ee791b715e93e487e74bfd", size = 59190817 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/fd/7f/970acbe33b81c22b3c7928f52e32347030aa46d23d779cf781cf9a9cf557/nodejs_wheel_binaries-22.15.0-py2.py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:acdd4ef73b6701aab9fbe02ac5e104f208a5e3c300402fa41ad7bc7f49499fbf", size = 60220316 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/07/4c/030243c04bb60f0de66c2d7ee3be289c6d28ef09113c06ffa417bdfedf8f/nodejs_wheel_binaries-22.15.0-py2.py3-none-win_amd64.whl", hash = "sha256:51deaf13ee474e39684ce8c066dfe86240edb94e7241950ca789befbbbcbd23d", size = 40718853 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/1f/49/011d472814af4fabeaab7d7ce3d5a1a635a3dadc23ae404d1f546839ecb3/nodejs_wheel_binaries-22.15.0-py2.py3-none-win_arm64.whl", hash = "sha256:01a3fe4d60477f93bf21a44219db33548c75d7fed6dc6e6f4c05cf0adf015609", size = 36436645 }, +] + +[[package]] +name = "packaging" +version = "25.0" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727 } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469 }, +] + +[[package]] +name = "pydantic" +version = "2.11.4" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "annotated-types" }, + { name = "pydantic-core" }, + { name = "typing-extensions" }, + { name = "typing-inspection" }, +] +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/77/ab/5250d56ad03884ab5efd07f734203943c8a8ab40d551e208af81d0257bf2/pydantic-2.11.4.tar.gz", hash = "sha256:32738d19d63a226a52eed76645a98ee07c1f410ee41d93b4afbfa85ed8111c2d", size = 786540 } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/e7/12/46b65f3534d099349e38ef6ec98b1a5a81f42536d17e0ba382c28c67ba67/pydantic-2.11.4-py3-none-any.whl", hash = "sha256:d9615eaa9ac5a063471da949c8fc16376a84afb5024688b3ff885693506764eb", size = 443900 }, +] + +[[package]] +name = "pydantic-core" +version = "2.33.2" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ad/88/5f2260bdfae97aabf98f1778d43f69574390ad787afb646292a638c923d4/pydantic_core-2.33.2.tar.gz", hash = "sha256:7cb8bc3605c29176e1b105350d2e6474142d7c1bd1d9327c4a9bdb46bf827acc", size = 435195 } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/e5/92/b31726561b5dae176c2d2c2dc43a9c5bfba5d32f96f8b4c0a600dd492447/pydantic_core-2.33.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2b3d326aaef0c0399d9afffeb6367d5e26ddc24d351dbc9c636840ac355dc5d8", size = 2028817 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a3/44/3f0b95fafdaca04a483c4e685fe437c6891001bf3ce8b2fded82b9ea3aa1/pydantic_core-2.33.2-cp310-cp310-macosx_11_0_arm64.whl", hash = 
"sha256:0e5b2671f05ba48b94cb90ce55d8bdcaaedb8ba00cc5359f6810fc918713983d", size = 1861357 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/30/97/e8f13b55766234caae05372826e8e4b3b96e7b248be3157f53237682e43c/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0069c9acc3f3981b9ff4cdfaf088e98d83440a4c7ea1bc07460af3d4dc22e72d", size = 1898011 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/9b/a3/99c48cf7bafc991cc3ee66fd544c0aae8dc907b752f1dad2d79b1b5a471f/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d53b22f2032c42eaaf025f7c40c2e3b94568ae077a606f006d206a463bc69572", size = 1982730 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/de/8e/a5b882ec4307010a840fb8b58bd9bf65d1840c92eae7534c7441709bf54b/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0405262705a123b7ce9f0b92f123334d67b70fd1f20a9372b907ce1080c7ba02", size = 2136178 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/e4/bb/71e35fc3ed05af6834e890edb75968e2802fe98778971ab5cba20a162315/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4b25d91e288e2c4e0662b8038a28c6a07eaac3e196cfc4ff69de4ea3db992a1b", size = 2736462 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/31/0d/c8f7593e6bc7066289bbc366f2235701dcbebcd1ff0ef8e64f6f239fb47d/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6bdfe4b3789761f3bcb4b1ddf33355a71079858958e3a552f16d5af19768fef2", size = 2005652 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/d2/7a/996d8bd75f3eda405e3dd219ff5ff0a283cd8e34add39d8ef9157e722867/pydantic_core-2.33.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:efec8db3266b76ef9607c2c4c419bdb06bf335ae433b80816089ea7585816f6a", size = 2113306 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ff/84/daf2a6fb2db40ffda6578a7e8c5a6e9c8affb251a05c233ae37098118788/pydantic_core-2.33.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:031c57d67ca86902726e0fae2214ce6770bbe2f710dc33063187a68744a5ecac", size = 2073720 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/77/fb/2258da019f4825128445ae79456a5499c032b55849dbd5bed78c95ccf163/pydantic_core-2.33.2-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:f8de619080e944347f5f20de29a975c2d815d9ddd8be9b9b7268e2e3ef68605a", size = 2244915 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/d8/7a/925ff73756031289468326e355b6fa8316960d0d65f8b5d6b3a3e7866de7/pydantic_core-2.33.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:73662edf539e72a9440129f231ed3757faab89630d291b784ca99237fb94db2b", size = 2241884 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/0b/b0/249ee6d2646f1cdadcb813805fe76265745c4010cf20a8eba7b0e639d9b2/pydantic_core-2.33.2-cp310-cp310-win32.whl", hash = "sha256:0a39979dcbb70998b0e505fb1556a1d550a0781463ce84ebf915ba293ccb7e22", size = 1910496 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/66/ff/172ba8f12a42d4b552917aa65d1f2328990d3ccfc01d5b7c943ec084299f/pydantic_core-2.33.2-cp310-cp310-win_amd64.whl", hash = "sha256:b0379a2b24882fef529ec3b4987cb5d003b9cda32256024e6fe1586ac45fc640", size = 1955019 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/3f/8d/71db63483d518cbbf290261a1fc2839d17ff89fce7089e08cad07ccfce67/pydantic_core-2.33.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:4c5b0a576fb381edd6d27f0a85915c6daf2f8138dc5c267a57c08a62900758c7", size = 
2028584 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/24/2f/3cfa7244ae292dd850989f328722d2aef313f74ffc471184dc509e1e4e5a/pydantic_core-2.33.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e799c050df38a639db758c617ec771fd8fb7a5f8eaaa4b27b101f266b216a246", size = 1855071 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b3/d3/4ae42d33f5e3f50dd467761304be2fa0a9417fbf09735bc2cce003480f2a/pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dc46a01bf8d62f227d5ecee74178ffc448ff4e5197c756331f71efcc66dc980f", size = 1897823 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f4/f3/aa5976e8352b7695ff808599794b1fba2a9ae2ee954a3426855935799488/pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a144d4f717285c6d9234a66778059f33a89096dfb9b39117663fd8413d582dcc", size = 1983792 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/d5/7a/cda9b5a23c552037717f2b2a5257e9b2bfe45e687386df9591eff7b46d28/pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:73cf6373c21bc80b2e0dc88444f41ae60b2f070ed02095754eb5a01df12256de", size = 2136338 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/2b/9f/b8f9ec8dd1417eb9da784e91e1667d58a2a4a7b7b34cf4af765ef663a7e5/pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3dc625f4aa79713512d1976fe9f0bc99f706a9dee21dfd1810b4bbbf228d0e8a", size = 2730998 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/47/bc/cd720e078576bdb8255d5032c5d63ee5c0bf4b7173dd955185a1d658c456/pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:881b21b5549499972441da4758d662aeea93f1923f953e9cbaff14b8b9565aef", size = 2003200 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ca/22/3602b895ee2cd29d11a2b349372446ae9727c32e78a94b3d588a40fdf187/pydantic_core-2.33.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:bdc25f3681f7b78572699569514036afe3c243bc3059d3942624e936ec93450e", size = 2113890 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ff/e6/e3c5908c03cf00d629eb38393a98fccc38ee0ce8ecce32f69fc7d7b558a7/pydantic_core-2.33.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:fe5b32187cbc0c862ee201ad66c30cf218e5ed468ec8dc1cf49dec66e160cc4d", size = 2073359 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/12/e7/6a36a07c59ebefc8777d1ffdaf5ae71b06b21952582e4b07eba88a421c79/pydantic_core-2.33.2-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:bc7aee6f634a6f4a95676fcb5d6559a2c2a390330098dba5e5a5f28a2e4ada30", size = 2245883 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/16/3f/59b3187aaa6cc0c1e6616e8045b284de2b6a87b027cce2ffcea073adf1d2/pydantic_core-2.33.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:235f45e5dbcccf6bd99f9f472858849f73d11120d76ea8707115415f8e5ebebf", size = 2241074 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/e0/ed/55532bb88f674d5d8f67ab121a2a13c385df382de2a1677f30ad385f7438/pydantic_core-2.33.2-cp311-cp311-win32.whl", hash = "sha256:6368900c2d3ef09b69cb0b913f9f8263b03786e5b2a387706c5afb66800efd51", size = 1910538 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/fe/1b/25b7cccd4519c0b23c2dd636ad39d381abf113085ce4f7bec2b0dc755eb1/pydantic_core-2.33.2-cp311-cp311-win_amd64.whl", hash = "sha256:1e063337ef9e9820c77acc768546325ebe04ee38b08703244c1309cccc4f1bab", size = 1952909 }, + { url = 
"https://pypi.tuna.tsinghua.edu.cn/packages/49/a9/d809358e49126438055884c4366a1f6227f0f84f635a9014e2deb9b9de54/pydantic_core-2.33.2-cp311-cp311-win_arm64.whl", hash = "sha256:6b99022f1d19bc32a4c2a0d544fc9a76e3be90f0b3f4af413f87d38749300e65", size = 1897786 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/18/8a/2b41c97f554ec8c71f2a8a5f85cb56a8b0956addfe8b0efb5b3d77e8bdc3/pydantic_core-2.33.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:a7ec89dc587667f22b6a0b6579c249fca9026ce7c333fc142ba42411fa243cdc", size = 2009000 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a1/02/6224312aacb3c8ecbaa959897af57181fb6cf3a3d7917fd44d0f2917e6f2/pydantic_core-2.33.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3c6db6e52c6d70aa0d00d45cdb9b40f0433b96380071ea80b09277dba021ddf7", size = 1847996 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/d6/46/6dcdf084a523dbe0a0be59d054734b86a981726f221f4562aed313dbcb49/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e61206137cbc65e6d5256e1166f88331d3b6238e082d9f74613b9b765fb9025", size = 1880957 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ec/6b/1ec2c03837ac00886ba8160ce041ce4e325b41d06a034adbef11339ae422/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb8c529b2819c37140eb51b914153063d27ed88e3bdc31b71198a198e921e011", size = 1964199 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/2d/1d/6bf34d6adb9debd9136bd197ca72642203ce9aaaa85cfcbfcf20f9696e83/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c52b02ad8b4e2cf14ca7b3d918f3eb0ee91e63b3167c32591e57c4317e134f8f", size = 2120296 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/e0/94/2bd0aaf5a591e974b32a9f7123f16637776c304471a0ab33cf263cf5591a/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:96081f1605125ba0855dfda83f6f3df5ec90c61195421ba72223de35ccfb2f88", size = 2676109 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f9/41/4b043778cf9c4285d59742281a769eac371b9e47e35f98ad321349cc5d61/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f57a69461af2a5fa6e6bbd7a5f60d3b7e6cebb687f55106933188e79ad155c1", size = 2002028 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/cb/d5/7bb781bf2748ce3d03af04d5c969fa1308880e1dca35a9bd94e1a96a922e/pydantic_core-2.33.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:572c7e6c8bb4774d2ac88929e3d1f12bc45714ae5ee6d9a788a9fb35e60bb04b", size = 2100044 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/fe/36/def5e53e1eb0ad896785702a5bbfd25eed546cdcf4087ad285021a90ed53/pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:db4b41f9bd95fbe5acd76d89920336ba96f03e149097365afe1cb092fceb89a1", size = 2058881 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/01/6c/57f8d70b2ee57fc3dc8b9610315949837fa8c11d86927b9bb044f8705419/pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:fa854f5cf7e33842a892e5c73f45327760bc7bc516339fda888c75ae60edaeb6", size = 2227034 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/27/b9/9c17f0396a82b3d5cbea4c24d742083422639e7bb1d5bf600e12cb176a13/pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5f483cfb75ff703095c59e365360cb73e00185e01aaea067cd19acffd2ab20ea", size = 2234187 }, + { url = 
"https://pypi.tuna.tsinghua.edu.cn/packages/b0/6a/adf5734ffd52bf86d865093ad70b2ce543415e0e356f6cacabbc0d9ad910/pydantic_core-2.33.2-cp312-cp312-win32.whl", hash = "sha256:9cb1da0f5a471435a7bc7e439b8a728e8b61e59784b2af70d7c169f8dd8ae290", size = 1892628 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/43/e4/5479fecb3606c1368d496a825d8411e126133c41224c1e7238be58b87d7e/pydantic_core-2.33.2-cp312-cp312-win_amd64.whl", hash = "sha256:f941635f2a3d96b2973e867144fde513665c87f13fe0e193c158ac51bfaaa7b2", size = 1955866 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/0d/24/8b11e8b3e2be9dd82df4b11408a67c61bb4dc4f8e11b5b0fc888b38118b5/pydantic_core-2.33.2-cp312-cp312-win_arm64.whl", hash = "sha256:cca3868ddfaccfbc4bfb1d608e2ccaaebe0ae628e1416aeb9c4d88c001bb45ab", size = 1888894 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/46/8c/99040727b41f56616573a28771b1bfa08a3d3fe74d3d513f01251f79f172/pydantic_core-2.33.2-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:1082dd3e2d7109ad8b7da48e1d4710c8d06c253cbc4a27c1cff4fbcaa97a9e3f", size = 2015688 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/3a/cc/5999d1eb705a6cefc31f0b4a90e9f7fc400539b1a1030529700cc1b51838/pydantic_core-2.33.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f517ca031dfc037a9c07e748cefd8d96235088b83b4f4ba8939105d20fa1dcd6", size = 1844808 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/6f/5e/a0a7b8885c98889a18b6e376f344da1ef323d270b44edf8174d6bce4d622/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a9f2c9dd19656823cb8250b0724ee9c60a82f3cdf68a080979d13092a3b0fef", size = 1885580 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/3b/2a/953581f343c7d11a304581156618c3f592435523dd9d79865903272c256a/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2b0a451c263b01acebe51895bfb0e1cc842a5c666efe06cdf13846c7418caa9a", size = 1973859 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/e6/55/f1a813904771c03a3f97f676c62cca0c0a4138654107c1b61f19c644868b/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ea40a64d23faa25e62a70ad163571c0b342b8bf66d5fa612ac0dec4f069d916", size = 2120810 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/aa/c3/053389835a996e18853ba107a63caae0b9deb4a276c6b472931ea9ae6e48/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0fb2d542b4d66f9470e8065c5469ec676978d625a8b7a363f07d9a501a9cb36a", size = 2676498 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/eb/3c/f4abd740877a35abade05e437245b192f9d0ffb48bbbbd708df33d3cda37/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fdac5d6ffa1b5a83bca06ffe7583f5576555e6c8b3a91fbd25ea7780f825f7d", size = 2000611 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/59/a7/63ef2fed1837d1121a894d0ce88439fe3e3b3e48c7543b2a4479eb99c2bd/pydantic_core-2.33.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:04a1a413977ab517154eebb2d326da71638271477d6ad87a769102f7c2488c56", size = 2107924 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/04/8f/2551964ef045669801675f1cfc3b0d74147f4901c3ffa42be2ddb1f0efc4/pydantic_core-2.33.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:c8e7af2f4e0194c22b5b37205bfb293d166a7344a5b0d0eaccebc376546d77d5", size = 2063196 }, + { url = 
"https://pypi.tuna.tsinghua.edu.cn/packages/26/bd/d9602777e77fc6dbb0c7db9ad356e9a985825547dce5ad1d30ee04903918/pydantic_core-2.33.2-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:5c92edd15cd58b3c2d34873597a1e20f13094f59cf88068adb18947df5455b4e", size = 2236389 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/42/db/0e950daa7e2230423ab342ae918a794964b053bec24ba8af013fc7c94846/pydantic_core-2.33.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:65132b7b4a1c0beded5e057324b7e16e10910c106d43675d9bd87d4f38dde162", size = 2239223 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/58/4d/4f937099c545a8a17eb52cb67fe0447fd9a373b348ccfa9a87f141eeb00f/pydantic_core-2.33.2-cp313-cp313-win32.whl", hash = "sha256:52fb90784e0a242bb96ec53f42196a17278855b0f31ac7c3cc6f5c1ec4811849", size = 1900473 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a0/75/4a0a9bac998d78d889def5e4ef2b065acba8cae8c93696906c3a91f310ca/pydantic_core-2.33.2-cp313-cp313-win_amd64.whl", hash = "sha256:c083a3bdd5a93dfe480f1125926afcdbf2917ae714bdb80b36d34318b2bec5d9", size = 1955269 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f9/86/1beda0576969592f1497b4ce8e7bc8cbdf614c352426271b1b10d5f0aa64/pydantic_core-2.33.2-cp313-cp313-win_arm64.whl", hash = "sha256:e80b087132752f6b3d714f041ccf74403799d3b23a72722ea2e6ba2e892555b9", size = 1893921 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a4/7d/e09391c2eebeab681df2b74bfe6c43422fffede8dc74187b2b0bf6fd7571/pydantic_core-2.33.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:61c18fba8e5e9db3ab908620af374db0ac1baa69f0f32df4f61ae23f15e586ac", size = 1806162 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f1/3d/847b6b1fed9f8ed3bb95a9ad04fbd0b212e832d4f0f50ff4d9ee5a9f15cf/pydantic_core-2.33.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95237e53bb015f67b63c91af7518a62a8660376a6a0db19b89acc77a4d6199f5", size = 1981560 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/6f/9a/e73262f6c6656262b5fdd723ad90f518f579b7bc8622e43a942eec53c938/pydantic_core-2.33.2-cp313-cp313t-win_amd64.whl", hash = "sha256:c2fc0a768ef76c15ab9238afa6da7f69895bb5d1ee83aeea2e3509af4472d0b9", size = 1935777 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/30/68/373d55e58b7e83ce371691f6eaa7175e3a24b956c44628eb25d7da007917/pydantic_core-2.33.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5c4aa4e82353f65e548c476b37e64189783aa5384903bfea4f41580f255fddfa", size = 2023982 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a4/16/145f54ac08c96a63d8ed6442f9dec17b2773d19920b627b18d4f10a061ea/pydantic_core-2.33.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d946c8bf0d5c24bf4fe333af284c59a19358aa3ec18cb3dc4370080da1e8ad29", size = 1858412 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/41/b1/c6dc6c3e2de4516c0bb2c46f6a373b91b5660312342a0cf5826e38ad82fa/pydantic_core-2.33.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:87b31b6846e361ef83fedb187bb5b4372d0da3f7e28d85415efa92d6125d6e6d", size = 1892749 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/12/73/8cd57e20afba760b21b742106f9dbdfa6697f1570b189c7457a1af4cd8a0/pydantic_core-2.33.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa9d91b338f2df0508606f7009fde642391425189bba6d8c653afd80fd6bb64e", size = 2067527 }, + { url = 
"https://pypi.tuna.tsinghua.edu.cn/packages/e3/d5/0bb5d988cc019b3cba4a78f2d4b3854427fc47ee8ec8e9eaabf787da239c/pydantic_core-2.33.2-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2058a32994f1fde4ca0480ab9d1e75a0e8c87c22b53a3ae66554f9af78f2fe8c", size = 2108225 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f1/c5/00c02d1571913d496aabf146106ad8239dc132485ee22efe08085084ff7c/pydantic_core-2.33.2-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:0e03262ab796d986f978f79c943fc5f620381be7287148b8010b4097f79a39ec", size = 2069490 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/22/a8/dccc38768274d3ed3a59b5d06f59ccb845778687652daa71df0cab4040d7/pydantic_core-2.33.2-pp310-pypy310_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:1a8695a8d00c73e50bff9dfda4d540b7dee29ff9b8053e38380426a85ef10052", size = 2237525 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/d4/e7/4f98c0b125dda7cf7ccd14ba936218397b44f50a56dd8c16a3091df116c3/pydantic_core-2.33.2-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:fa754d1850735a0b0e03bcffd9d4b4343eb417e47196e4485d9cca326073a42c", size = 2238446 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ce/91/2ec36480fdb0b783cd9ef6795753c1dea13882f2e68e73bce76ae8c21e6a/pydantic_core-2.33.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:a11c8d26a50bfab49002947d3d237abe4d9e4b5bdc8846a63537b6488e197808", size = 2066678 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/7b/27/d4ae6487d73948d6f20dddcd94be4ea43e74349b56eba82e9bdee2d7494c/pydantic_core-2.33.2-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:dd14041875d09cc0f9308e37a6f8b65f5585cf2598a53aa0123df8b129d481f8", size = 2025200 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f1/b8/b3cb95375f05d33801024079b9392a5ab45267a63400bf1866e7ce0f0de4/pydantic_core-2.33.2-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:d87c561733f66531dced0da6e864f44ebf89a8fba55f31407b00c2f7f9449593", size = 1859123 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/05/bc/0d0b5adeda59a261cd30a1235a445bf55c7e46ae44aea28f7bd6ed46e091/pydantic_core-2.33.2-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2f82865531efd18d6e07a04a17331af02cb7a651583c418df8266f17a63c6612", size = 1892852 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/3e/11/d37bdebbda2e449cb3f519f6ce950927b56d62f0b84fd9cb9e372a26a3d5/pydantic_core-2.33.2-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bfb5112df54209d820d7bf9317c7a6c9025ea52e49f46b6a2060104bba37de7", size = 2067484 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/8c/55/1f95f0a05ce72ecb02a8a8a1c3be0579bbc29b1d5ab68f1378b7bebc5057/pydantic_core-2.33.2-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:64632ff9d614e5eecfb495796ad51b0ed98c453e447a76bcbeeb69615079fc7e", size = 2108896 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/53/89/2b2de6c81fa131f423246a9109d7b2a375e83968ad0800d6e57d0574629b/pydantic_core-2.33.2-pp311-pypy311_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:f889f7a40498cc077332c7ab6b4608d296d852182211787d4f3ee377aaae66e8", size = 2069475 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b8/e9/1f7efbe20d0b2b10f6718944b5d8ece9152390904f29a78e68d4e7961159/pydantic_core-2.33.2-pp311-pypy311_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:de4b83bb311557e439b9e186f733f6c645b9417c84e2eb8203f3f820a4b988bf", size = 2239013 }, + { url = 
"https://pypi.tuna.tsinghua.edu.cn/packages/3c/b2/5309c905a93811524a49b4e031e9851a6b00ff0fb668794472ea7746b448/pydantic_core-2.33.2-pp311-pypy311_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:82f68293f055f51b51ea42fafc74b6aad03e70e191799430b90c13d643059ebb", size = 2238715 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/32/56/8a7ca5d2cd2cda1d245d34b1c9a942920a718082ae8e54e5f3e5a58b7add/pydantic_core-2.33.2-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:329467cecfb529c925cf2bbd4d60d2c509bc2fb52a20c1045bf09bb70971a9c1", size = 2066757 }, +] + +[[package]] +name = "requests" +version = "2.32.3" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "certifi" }, + { name = "charset-normalizer" }, + { name = "idna" }, + { name = "urllib3" }, +] +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/63/70/2bf7780ad2d390a8d301ad0b550f1581eadbd9a20f896afe06353c2a2913/requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760", size = 131218 } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f9/9b/335f9764261e915ed497fcdeb11df5dfd6f7bf257d4a6a2a686d80da4d54/requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6", size = 64928 }, +] + +[[package]] +name = "slowapi" +version = "0.1.9" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "limits" }, +] +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a0/99/adfc7f94ca024736f061257d39118e1542bade7a52e86415a4c4ae92d8ff/slowapi-0.1.9.tar.gz", hash = "sha256:639192d0f1ca01b1c6d95bf6c71d794c3a9ee189855337b4821f7f457dddad77", size = 14028 } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/2b/bb/f71c4b7d7e7eb3fc1e8c0458a8979b912f40b58002b9fbf37729b8cb464b/slowapi-0.1.9-py3-none-any.whl", hash = "sha256:cfad116cfb84ad9d763ee155c1e5c5cbf00b0d47399a769b227865f5df576e36", size = 14670 }, +] + +[[package]] +name = "sniffio" +version = "1.3.1" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a2/87/a6771e1546d97e7e041b6ae58d80074f81b7d5121207425c964ddf5cfdbd/sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc", size = 20372 } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235 }, +] + +[[package]] +name = "starlette" +version = "0.46.2" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "anyio" }, +] +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ce/20/08dfcd9c983f6a6f4a1000d934b9e6d626cff8d2eeb77a89a68eef20a2b7/starlette-0.46.2.tar.gz", hash = "sha256:7f7361f34eed179294600af672f565727419830b54b7b084efe44bb82d2fccd5", size = 2580846 } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/8b/0c/9d30a4ebeb6db2b25a841afbb80f6ef9a854fc3b41be131d249a977b4959/starlette-0.46.2-py3-none-any.whl", hash = "sha256:595633ce89f8ffa71a015caed34a5b2dc1c0cdb3f0f1fbd1e69339cf2abeec35", size = 72037 }, +] + +[[package]] +name = "typing-extensions" +version = "4.13.2" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = 
"https://pypi.tuna.tsinghua.edu.cn/packages/f6/37/23083fcd6e35492953e8d2aaaa68b860eb422b34627b13f2ce3eb6106061/typing_extensions-4.13.2.tar.gz", hash = "sha256:e6c81219bd689f51865d9e372991c540bda33a0379d5573cddb9a3a23f7caaef", size = 106967 } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/8b/54/b1ae86c0973cc6f0210b53d508ca3641fb6d0c56823f288d108bc7ab3cc8/typing_extensions-4.13.2-py3-none-any.whl", hash = "sha256:a439e7c04b49fec3e5d3e2beaa21755cadbbdc391694e28ccdd36ca4a1408f8c", size = 45806 }, +] + +[[package]] +name = "typing-inspection" +version = "0.4.0" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/82/5c/e6082df02e215b846b4b8c0b887a64d7d08ffaba30605502639d44c06b82/typing_inspection-0.4.0.tar.gz", hash = "sha256:9765c87de36671694a67904bf2c96e395be9c6439bb6c87b5142569dcdd65122", size = 76222 } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/31/08/aa4fdfb71f7de5176385bd9e90852eaf6b5d622735020ad600f2bab54385/typing_inspection-0.4.0-py3-none-any.whl", hash = "sha256:50e72559fcd2a6367a19f7a7e610e6afcb9fac940c650290eed893d61386832f", size = 14125 }, +] + +[[package]] +name = "urllib3" +version = "2.4.0" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/8a/78/16493d9c386d8e60e442a35feac5e00f0913c0f4b7c217c11e8ec2ff53e0/urllib3-2.4.0.tar.gz", hash = "sha256:414bc6535b787febd7567804cc015fee39daab8ad86268f1310a9250697de466", size = 390672 } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/6b/11/cc635220681e93a0183390e26485430ca2c7b5f9d33b15c74c2861cb8091/urllib3-2.4.0-py3-none-any.whl", hash = "sha256:4e16665048960a0900c702d4a66415956a584919c03361cac9f1df5c5dd7e813", size = 128680 }, +] + +[[package]] +name = "uvicorn" +version = "0.34.2" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "click" }, + { name = "h11" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a6/ae/9bbb19b9e1c450cf9ecaef06463e40234d98d95bf572fab11b4f19ae5ded/uvicorn-0.34.2.tar.gz", hash = "sha256:0e929828f6186353a80b58ea719861d2629d766293b6d19baf086ba31d4f3328", size = 76815 } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b1/4b/4cef6ce21a2aaca9d852a6e84ef4f135d99fcd74fa75105e2fc0c8308acd/uvicorn-0.34.2-py3-none-any.whl", hash = "sha256:deb49af569084536d269fe0a6d67e3754f104cf03aba7c11c40f01aadf33c403", size = 62483 }, +] + +[[package]] +name = "wrapt" +version = "1.17.2" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/c3/fc/e91cc220803d7bc4db93fb02facd8461c37364151b8494762cc88b0fbcef/wrapt-1.17.2.tar.gz", hash = "sha256:41388e9d4d1522446fe79d3213196bd9e3b301a336965b9e27ca2788ebd122f3", size = 55531 } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/5a/d1/1daec934997e8b160040c78d7b31789f19b122110a75eca3d4e8da0049e1/wrapt-1.17.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3d57c572081fed831ad2d26fd430d565b76aa277ed1d30ff4d40670b1c0dd984", size = 53307 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/1b/7b/13369d42651b809389c1a7153baa01d9700430576c81a2f5c5e460df0ed9/wrapt-1.17.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = 
"sha256:b5e251054542ae57ac7f3fba5d10bfff615b6c2fb09abeb37d2f1463f841ae22", size = 38486 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/62/bf/e0105016f907c30b4bd9e377867c48c34dc9c6c0c104556c9c9126bd89ed/wrapt-1.17.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:80dd7db6a7cb57ffbc279c4394246414ec99537ae81ffd702443335a61dbf3a7", size = 38777 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/27/70/0f6e0679845cbf8b165e027d43402a55494779295c4b08414097b258ac87/wrapt-1.17.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a6e821770cf99cc586d33833b2ff32faebdbe886bd6322395606cf55153246c", size = 83314 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/0f/77/0576d841bf84af8579124a93d216f55d6f74374e4445264cb378a6ed33eb/wrapt-1.17.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b60fb58b90c6d63779cb0c0c54eeb38941bae3ecf7a73c764c52c88c2dcb9d72", size = 74947 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/90/ec/00759565518f268ed707dcc40f7eeec38637d46b098a1f5143bff488fe97/wrapt-1.17.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b870b5df5b71d8c3359d21be8f0d6c485fa0ebdb6477dda51a1ea54a9b558061", size = 82778 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f8/5a/7cffd26b1c607b0b0c8a9ca9d75757ad7620c9c0a9b4a25d3f8a1480fafc/wrapt-1.17.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:4011d137b9955791f9084749cba9a367c68d50ab8d11d64c50ba1688c9b457f2", size = 81716 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/7e/09/dccf68fa98e862df7e6a60a61d43d644b7d095a5fc36dbb591bbd4a1c7b2/wrapt-1.17.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:1473400e5b2733e58b396a04eb7f35f541e1fb976d0c0724d0223dd607e0f74c", size = 74548 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b7/8e/067021fa3c8814952c5e228d916963c1115b983e21393289de15128e867e/wrapt-1.17.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:3cedbfa9c940fdad3e6e941db7138e26ce8aad38ab5fe9dcfadfed9db7a54e62", size = 81334 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/4b/0d/9d4b5219ae4393f718699ca1c05f5ebc0c40d076f7e65fd48f5f693294fb/wrapt-1.17.2-cp310-cp310-win32.whl", hash = "sha256:582530701bff1dec6779efa00c516496968edd851fba224fbd86e46cc6b73563", size = 36427 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/72/6a/c5a83e8f61aec1e1aeef939807602fb880e5872371e95df2137142f5c58e/wrapt-1.17.2-cp310-cp310-win_amd64.whl", hash = "sha256:58705da316756681ad3c9c73fd15499aa4d8c69f9fd38dc8a35e06c12468582f", size = 38774 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/cd/f7/a2aab2cbc7a665efab072344a8949a71081eed1d2f451f7f7d2b966594a2/wrapt-1.17.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ff04ef6eec3eee8a5efef2401495967a916feaa353643defcc03fc74fe213b58", size = 53308 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/50/ff/149aba8365fdacef52b31a258c4dc1c57c79759c335eff0b3316a2664a64/wrapt-1.17.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4db983e7bca53819efdbd64590ee96c9213894272c776966ca6306b73e4affda", size = 38488 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/65/46/5a917ce85b5c3b490d35c02bf71aedaa9f2f63f2d15d9949cc4ba56e8ba9/wrapt-1.17.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9abc77a4ce4c6f2a3168ff34b1da9b0f311a8f1cfd694ec96b0603dff1c79438", size = 38776 }, + { url = 
"https://pypi.tuna.tsinghua.edu.cn/packages/ca/74/336c918d2915a4943501c77566db41d1bd6e9f4dbc317f356b9a244dfe83/wrapt-1.17.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b929ac182f5ace000d459c59c2c9c33047e20e935f8e39371fa6e3b85d56f4a", size = 83776 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/09/99/c0c844a5ccde0fe5761d4305485297f91d67cf2a1a824c5f282e661ec7ff/wrapt-1.17.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f09b286faeff3c750a879d336fb6d8713206fc97af3adc14def0cdd349df6000", size = 75420 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b4/b0/9fc566b0fe08b282c850063591a756057c3247b2362b9286429ec5bf1721/wrapt-1.17.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1a7ed2d9d039bd41e889f6fb9364554052ca21ce823580f6a07c4ec245c1f5d6", size = 83199 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/9d/4b/71996e62d543b0a0bd95dda485219856def3347e3e9380cc0d6cf10cfb2f/wrapt-1.17.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:129a150f5c445165ff941fc02ee27df65940fcb8a22a61828b1853c98763a64b", size = 82307 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/39/35/0282c0d8789c0dc9bcc738911776c762a701f95cfe113fb8f0b40e45c2b9/wrapt-1.17.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1fb5699e4464afe5c7e65fa51d4f99e0b2eadcc176e4aa33600a3df7801d6662", size = 75025 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/4f/6d/90c9fd2c3c6fee181feecb620d95105370198b6b98a0770cba090441a828/wrapt-1.17.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:9a2bce789a5ea90e51a02dfcc39e31b7f1e662bc3317979aa7e5538e3a034f72", size = 81879 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/8f/fa/9fb6e594f2ce03ef03eddbdb5f4f90acb1452221a5351116c7c4708ac865/wrapt-1.17.2-cp311-cp311-win32.whl", hash = "sha256:4afd5814270fdf6380616b321fd31435a462019d834f83c8611a0ce7484c7317", size = 36419 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/47/f8/fb1773491a253cbc123c5d5dc15c86041f746ed30416535f2a8df1f4a392/wrapt-1.17.2-cp311-cp311-win_amd64.whl", hash = "sha256:acc130bc0375999da18e3d19e5a86403667ac0c4042a094fefb7eec8ebac7cf3", size = 38773 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a1/bd/ab55f849fd1f9a58ed7ea47f5559ff09741b25f00c191231f9f059c83949/wrapt-1.17.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:d5e2439eecc762cd85e7bd37161d4714aa03a33c5ba884e26c81559817ca0925", size = 53799 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/53/18/75ddc64c3f63988f5a1d7e10fb204ffe5762bc663f8023f18ecaf31a332e/wrapt-1.17.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:3fc7cb4c1c744f8c05cd5f9438a3caa6ab94ce8344e952d7c45a8ed59dd88392", size = 38821 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/48/2a/97928387d6ed1c1ebbfd4efc4133a0633546bec8481a2dd5ec961313a1c7/wrapt-1.17.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8fdbdb757d5390f7c675e558fd3186d590973244fab0c5fe63d373ade3e99d40", size = 38919 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/73/54/3bfe5a1febbbccb7a2f77de47b989c0b85ed3a6a41614b104204a788c20e/wrapt-1.17.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5bb1d0dbf99411f3d871deb6faa9aabb9d4e744d67dcaaa05399af89d847a91d", size = 88721 }, + { url = 
"https://pypi.tuna.tsinghua.edu.cn/packages/25/cb/7262bc1b0300b4b64af50c2720ef958c2c1917525238d661c3e9a2b71b7b/wrapt-1.17.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d18a4865f46b8579d44e4fe1e2bcbc6472ad83d98e22a26c963d46e4c125ef0b", size = 80899 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/2a/5a/04cde32b07a7431d4ed0553a76fdb7a61270e78c5fd5a603e190ac389f14/wrapt-1.17.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc570b5f14a79734437cb7b0500376b6b791153314986074486e0b0fa8d71d98", size = 89222 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/09/28/2e45a4f4771fcfb109e244d5dbe54259e970362a311b67a965555ba65026/wrapt-1.17.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6d9187b01bebc3875bac9b087948a2bccefe464a7d8f627cf6e48b1bbae30f82", size = 86707 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/c6/d2/dcb56bf5f32fcd4bd9aacc77b50a539abdd5b6536872413fd3f428b21bed/wrapt-1.17.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:9e8659775f1adf02eb1e6f109751268e493c73716ca5761f8acb695e52a756ae", size = 79685 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/80/4e/eb8b353e36711347893f502ce91c770b0b0929f8f0bed2670a6856e667a9/wrapt-1.17.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e8b2816ebef96d83657b56306152a93909a83f23994f4b30ad4573b00bd11bb9", size = 87567 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/17/27/4fe749a54e7fae6e7146f1c7d914d28ef599dacd4416566c055564080fe2/wrapt-1.17.2-cp312-cp312-win32.whl", hash = "sha256:468090021f391fe0056ad3e807e3d9034e0fd01adcd3bdfba977b6fdf4213ea9", size = 36672 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/15/06/1dbf478ea45c03e78a6a8c4be4fdc3c3bddea5c8de8a93bc971415e47f0f/wrapt-1.17.2-cp312-cp312-win_amd64.whl", hash = "sha256:ec89ed91f2fa8e3f52ae53cd3cf640d6feff92ba90d62236a81e4e563ac0e991", size = 38865 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ce/b9/0ffd557a92f3b11d4c5d5e0c5e4ad057bd9eb8586615cdaf901409920b14/wrapt-1.17.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:6ed6ffac43aecfe6d86ec5b74b06a5be33d5bb9243d055141e8cabb12aa08125", size = 53800 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/c0/ef/8be90a0b7e73c32e550c73cfb2fa09db62234227ece47b0e80a05073b375/wrapt-1.17.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:35621ae4c00e056adb0009f8e86e28eb4a41a4bfa8f9bfa9fca7d343fe94f998", size = 38824 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/36/89/0aae34c10fe524cce30fe5fc433210376bce94cf74d05b0d68344c8ba46e/wrapt-1.17.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a604bf7a053f8362d27eb9fefd2097f82600b856d5abe996d623babd067b1ab5", size = 38920 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/3b/24/11c4510de906d77e0cfb5197f1b1445d4fec42c9a39ea853d482698ac681/wrapt-1.17.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5cbabee4f083b6b4cd282f5b817a867cf0b1028c54d445b7ec7cfe6505057cf8", size = 88690 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/71/d7/cfcf842291267bf455b3e266c0c29dcb675b5540ee8b50ba1699abf3af45/wrapt-1.17.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:49703ce2ddc220df165bd2962f8e03b84c89fee2d65e1c24a7defff6f988f4d6", size = 80861 }, + { url = 
"https://pypi.tuna.tsinghua.edu.cn/packages/d5/66/5d973e9f3e7370fd686fb47a9af3319418ed925c27d72ce16b791231576d/wrapt-1.17.2-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8112e52c5822fc4253f3901b676c55ddf288614dc7011634e2719718eaa187dc", size = 89174 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a7/d3/8e17bb70f6ae25dabc1aaf990f86824e4fd98ee9cadf197054e068500d27/wrapt-1.17.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9fee687dce376205d9a494e9c121e27183b2a3df18037f89d69bd7b35bcf59e2", size = 86721 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/6f/54/f170dfb278fe1c30d0ff864513cff526d624ab8de3254b20abb9cffedc24/wrapt-1.17.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:18983c537e04d11cf027fbb60a1e8dfd5190e2b60cc27bc0808e653e7b218d1b", size = 79763 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/4a/98/de07243751f1c4a9b15c76019250210dd3486ce098c3d80d5f729cba029c/wrapt-1.17.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:703919b1633412ab54bcf920ab388735832fdcb9f9a00ae49387f0fe67dad504", size = 87585 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f9/f0/13925f4bd6548013038cdeb11ee2cbd4e37c30f8bfd5db9e5a2a370d6e20/wrapt-1.17.2-cp313-cp313-win32.whl", hash = "sha256:abbb9e76177c35d4e8568e58650aa6926040d6a9f6f03435b7a522bf1c487f9a", size = 36676 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/bf/ae/743f16ef8c2e3628df3ddfd652b7d4c555d12c84b53f3d8218498f4ade9b/wrapt-1.17.2-cp313-cp313-win_amd64.whl", hash = "sha256:69606d7bb691b50a4240ce6b22ebb319c1cfb164e5f6569835058196e0f3a845", size = 38871 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/3d/bc/30f903f891a82d402ffb5fda27ec1d621cc97cb74c16fea0b6141f1d4e87/wrapt-1.17.2-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:4a721d3c943dae44f8e243b380cb645a709ba5bd35d3ad27bc2ed947e9c68192", size = 56312 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/8a/04/c97273eb491b5f1c918857cd26f314b74fc9b29224521f5b83f872253725/wrapt-1.17.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:766d8bbefcb9e00c3ac3b000d9acc51f1b399513f44d77dfe0eb026ad7c9a19b", size = 40062 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/4e/ca/3b7afa1eae3a9e7fefe499db9b96813f41828b9fdb016ee836c4c379dadb/wrapt-1.17.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:e496a8ce2c256da1eb98bd15803a79bee00fc351f5dfb9ea82594a3f058309e0", size = 40155 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/89/be/7c1baed43290775cb9030c774bc53c860db140397047cc49aedaf0a15477/wrapt-1.17.2-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:40d615e4fe22f4ad3528448c193b218e077656ca9ccb22ce2cb20db730f8d306", size = 113471 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/32/98/4ed894cf012b6d6aae5f5cc974006bdeb92f0241775addad3f8cd6ab71c8/wrapt-1.17.2-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a5aaeff38654462bc4b09023918b7f21790efb807f54c000a39d41d69cf552cb", size = 101208 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ea/fd/0c30f2301ca94e655e5e057012e83284ce8c545df7661a78d8bfca2fac7a/wrapt-1.17.2-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a7d15bbd2bc99e92e39f49a04653062ee6085c0e18b3b7512a4f2fe91f2d681", size = 109339 }, + { url = 
"https://pypi.tuna.tsinghua.edu.cn/packages/75/56/05d000de894c4cfcb84bcd6b1df6214297b8089a7bd324c21a4765e49b14/wrapt-1.17.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:e3890b508a23299083e065f435a492b5435eba6e304a7114d2f919d400888cc6", size = 110232 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/53/f8/c3f6b2cf9b9277fb0813418e1503e68414cd036b3b099c823379c9575e6d/wrapt-1.17.2-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:8c8b293cd65ad716d13d8dd3624e42e5a19cc2a2f1acc74b30c2c13f15cb61a6", size = 100476 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a7/b1/0bb11e29aa5139d90b770ebbfa167267b1fc548d2302c30c8f7572851738/wrapt-1.17.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:4c82b8785d98cdd9fed4cac84d765d234ed3251bd6afe34cb7ac523cb93e8b4f", size = 106377 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/6a/e1/0122853035b40b3f333bbb25f1939fc1045e21dd518f7f0922b60c156f7c/wrapt-1.17.2-cp313-cp313t-win32.whl", hash = "sha256:13e6afb7fe71fe7485a4550a8844cc9ffbe263c0f1a1eea569bc7091d4898555", size = 37986 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/09/5e/1655cf481e079c1f22d0cabdd4e51733679932718dc23bf2db175f329b76/wrapt-1.17.2-cp313-cp313t-win_amd64.whl", hash = "sha256:eaf675418ed6b3b31c7a989fd007fa7c3be66ce14e5c3b27336383604c9da85c", size = 40750 }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/2d/82/f56956041adef78f849db6b289b282e72b55ab8045a75abad81898c28d19/wrapt-1.17.2-py3-none-any.whl", hash = "sha256:b18f2d1533a71f069c7f82d524a52599053d4c7166e9dd374ae2136b7f40f7c8", size = 23594 }, +] diff --git a/sdk/python/pyproject.toml b/sdk/python/pyproject.toml index 0faa307f46b..e15c14e8735 100644 --- a/sdk/python/pyproject.toml +++ b/sdk/python/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ragflow-sdk" -version = "0.19.0" +version = "0.19.1" description = "Python client sdk of [RAGFlow](https://github.com/infiniflow/ragflow). RAGFlow is an open-source RAG (Retrieval-Augmented Generation) engine based on deep document understanding." authors = [{ name = "Zhichang Yu", email = "yuzhichang@gmail.com" }] license = { text = "Apache License, Version 2.0" } diff --git a/sdk/python/ragflow_sdk/modules/base.py b/sdk/python/ragflow_sdk/modules/base.py index 9014dd02d84..6b958fb8d1a 100644 --- a/sdk/python/ragflow_sdk/modules/base.py +++ b/sdk/python/ragflow_sdk/modules/base.py @@ -14,9 +14,13 @@ # limitations under the License. 
# + class Base: def __init__(self, rag, res_dict): self.rag = rag + self._update_from_dict(rag, res_dict) + + def _update_from_dict(self, rag, res_dict): for k, v in res_dict.items(): if isinstance(v, dict): self.__dict__[k] = Base(rag, v) @@ -27,7 +31,7 @@ def to_json(self): pr = {} for name in dir(self): value = getattr(self, name) - if not name.startswith('__') and not callable(value) and name != "rag": + if not name.startswith("__") and not callable(value) and name != "rag": if isinstance(value, Base): pr[name] = value.to_json() else: @@ -35,7 +39,7 @@ def to_json(self): return pr def post(self, path, json=None, stream=False, files=None): - res = self.rag.post(path, json, stream=stream,files=files) + res = self.rag.post(path, json, stream=stream, files=files) return res def get(self, path, params=None): @@ -46,8 +50,8 @@ def rm(self, path, json): res = self.rag.delete(path, json) return res - def put(self,path, json): - res = self.rag.put(path,json) + def put(self, path, json): + res = self.rag.put(path, json) return res def __str__(self): diff --git a/sdk/python/ragflow_sdk/modules/chat.py b/sdk/python/ragflow_sdk/modules/chat.py index e85011504d3..01083b37d10 100644 --- a/sdk/python/ragflow_sdk/modules/chat.py +++ b/sdk/python/ragflow_sdk/modules/chat.py @@ -31,7 +31,7 @@ def __init__(self, rag, res_dict): class LLM(Base): def __init__(self, rag, res_dict): - self.model_name = "deepseek-chat" + self.model_name = None self.temperature = 0.1 self.top_p = 0.3 self.presence_penalty = 0.4 @@ -46,7 +46,7 @@ def __init__(self, rag, res_dict): self.top_n = 8 self.top_k = 1024 self.variables = [{"key": "knowledge", "optional": True}] - self.rerank_model = None + self.rerank_model = "" self.empty_response = None self.opener = "Hi! I'm your assistant, what can I do for you?" self.show_quote = True @@ -59,8 +59,7 @@ def __init__(self, rag, res_dict): super().__init__(rag, res_dict) def update(self, update_message: dict): - res = self.put(f'/chats/{self.id}', - update_message) + res = self.put(f"/chats/{self.id}", update_message) res = res.json() if res.get("code") != 0: raise Exception(res["message"]) @@ -69,13 +68,11 @@ def create_session(self, name: str = "New session") -> Session: res = self.post(f"/chats/{self.id}/sessions", {"name": name}) res = res.json() if res.get("code") == 0: - return Session(self.rag, res['data']) + return Session(self.rag, res["data"]) raise Exception(res["message"]) - def list_sessions(self, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True, - id: str = None, name: str = None) -> list[Session]: - res = self.get(f'/chats/{self.id}/sessions', - {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name}) + def list_sessions(self, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True, id: str = None, name: str = None) -> list[Session]: + res = self.get(f"/chats/{self.id}/sessions", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name}) res = res.json() if res.get("code") == 0: result_list = [] diff --git a/sdk/python/ragflow_sdk/modules/dataset.py b/sdk/python/ragflow_sdk/modules/dataset.py index fdecde4a5e2..fc0bc8f5ba4 100644 --- a/sdk/python/ragflow_sdk/modules/dataset.py +++ b/sdk/python/ragflow_sdk/modules/dataset.py @@ -14,9 +14,8 @@ # limitations under the License. 
# -from .document import Document - from .base import Base +from .document import Document class DataSet(Base): @@ -43,12 +42,14 @@ def __init__(self, rag, res_dict): super().__init__(rag, res_dict) def update(self, update_message: dict): - res = self.put(f'/datasets/{self.id}', - update_message) + res = self.put(f"/datasets/{self.id}", update_message) res = res.json() if res.get("code") != 0: raise Exception(res["message"]) + self._update_from_dict(self.rag, res.get("data", {})) + + return self + def upload_documents(self, document_list: list[dict]): url = f"/datasets/{self.id}/documents" files = [("file", (ele["display_name"], ele["blob"])) for ele in document_list] @@ -62,11 +63,8 @@ def upload_documents(self, document_list: list[dict]): return doc_list raise Exception(res.get("message")) - def list_documents(self, id: str | None = None, keywords: str | None = None, page: int = 1, page_size: int = 30, - orderby: str = "create_time", desc: bool = True): - res = self.get(f"/datasets/{self.id}/documents", - params={"id": id, "keywords": keywords, "page": page, "page_size": page_size, "orderby": orderby, - "desc": desc}) + def list_documents(self, id: str | None = None, name: str | None = None, keywords: str | None = None, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True): + res = self.get(f"/datasets/{self.id}/documents", params={"id": id, "name": name, "keywords": keywords, "page": page, "page_size": page_size, "orderby": orderby, "desc": desc}) res = res.json() documents = [] if res.get("code") == 0: diff --git a/sdk/python/ragflow_sdk/modules/document.py b/sdk/python/ragflow_sdk/modules/document.py index 187d089012f..cec8de161f2 100644 --- a/sdk/python/ragflow_sdk/modules/document.py +++ b/sdk/python/ragflow_sdk/modules/document.py @@ -15,6 +15,7 @@ # import json + from .base import Base from .chunk import Chunk @@ -40,7 +41,7 @@ def __init__(self, rag, res_dict): self.progress = 0.0 self.progress_msg = "" self.process_begin_at = None - self.process_duration = 0.0 + self.process_duation = 0.0 self.run = "0" self.status = "1" for k in list(res_dict.keys()): @@ -52,23 +53,30 @@ def update(self, update_message: dict): if "meta_fields" in update_message: if not isinstance(update_message["meta_fields"], dict): raise Exception("meta_fields must be a dictionary") - res = self.put(f'/datasets/{self.dataset_id}/documents/{self.id}', - update_message) + res = self.put(f"/datasets/{self.dataset_id}/documents/{self.id}", update_message) res = res.json() if res.get("code") != 0: raise Exception(res["message"]) + self._update_from_dict(self.rag, res.get("data", {})) + + return self + def download(self): res = self.get(f"/datasets/{self.dataset_id}/documents/{self.id}") + error_keys = set(["code", "message"]) try: - res = res.json() - raise Exception(res.get("message")) + response = res.json() + actual_keys = set(response.keys()) + if actual_keys == error_keys: + raise Exception(response.get("message")) + else: + return res.content except json.JSONDecodeError: return res.content - def list_chunks(self, page=1, page_size=30, keywords="", id = ""): + def list_chunks(self, page=1, page_size=30, keywords="", id=""): data = {"keywords": keywords, "page": page, "page_size": page_size, "id": id} - res = self.get(f'/datasets/{self.dataset_id}/documents/{self.id}/chunks', data) + res = self.get(f"/datasets/{self.dataset_id}/documents/{self.id}/chunks", data) res = res.json() if res.get("code") == 0: chunks = [] @@ -79,8 +87,7 @@ def list_chunks(self, page=1, page_size=30, keywords="", 
id = ""): raise Exception(res.get("message")) def add_chunk(self, content: str, important_keywords: list[str] = [], questions: list[str] = []): - res = self.post(f'/datasets/{self.dataset_id}/documents/{self.id}/chunks', - {"content": content, "important_keywords": important_keywords, "questions": questions}) + res = self.post(f"/datasets/{self.dataset_id}/documents/{self.id}/chunks", {"content": content, "important_keywords": important_keywords, "questions": questions}) res = res.json() if res.get("code") == 0: return Chunk(self.rag, res["data"].get("chunk")) diff --git a/sdk/python/ragflow_sdk/modules/session.py b/sdk/python/ragflow_sdk/modules/session.py index f91405cab86..d534c782b0d 100644 --- a/sdk/python/ragflow_sdk/modules/session.py +++ b/sdk/python/ragflow_sdk/modules/session.py @@ -23,7 +23,7 @@ class Session(Base): def __init__(self, rag, res_dict): self.id = None self.name = "New session" - self.messages = [{"role": "assistant", "content": "Hi! I am your assistant,can I help you?"}] + self.messages = [{"role": "assistant", "content": "Hi! I am your assistant, can I help you?"}] for key, value in res_dict.items(): if key == "chat_id" and value is not None: self.chat_id = None @@ -50,33 +50,28 @@ def ask(self, question="", stream=True, **kwargs): json_data = json.loads(line[5:]) if json_data["data"] is True or json_data["data"].get("running_status"): continue - answer = json_data["data"]["answer"] - reference = json_data["data"].get("reference", {}) - temp_dict = { - "content": answer, - "role": "assistant" - } - if reference and "chunks" in reference: - chunks = reference["chunks"] - temp_dict["reference"] = chunks - message = Message(self.rag, temp_dict) + message = self._structure_answer(json_data) yield message else: try: json_data = json.loads(res.text) except ValueError: raise Exception(f"Invalid response {res}") - answer = json_data["data"]["answer"] - reference = json_data["data"].get("reference", {}) - temp_dict = { - "content": answer, - "role": "assistant" - } - if reference and "chunks" in reference: - chunks = reference["chunks"] - temp_dict["reference"] = chunks - message = Message(self.rag, temp_dict) - return message + return self._structure_answer(json_data) + + + def _structure_answer(self, json_data): + answer = json_data["data"]["answer"] + reference = json_data["data"].get("reference", {}) + temp_dict = { + "content": answer, + "role": "assistant" + } + if reference and "chunks" in reference: + chunks = reference["chunks"] + temp_dict["reference"] = chunks + message = Message(self.rag, temp_dict) + return message def _ask_chat(self, question: str, stream: bool, **kwargs): json_data = {"question": question, "stream": stream, "session_id": self.id} @@ -100,7 +95,7 @@ def update(self, update_message): class Message(Base): def __init__(self, rag, res_dict): - self.content = "Hi! I am your assistant,can I help you?" + self.content = "Hi! I am your assistant, can I help you?" 
self.reference = None self.role = "assistant" self.prompt = None diff --git a/sdk/python/ragflow_sdk/ragflow.py b/sdk/python/ragflow_sdk/ragflow.py index a2569353ca1..5b65d6201d5 100644 --- a/sdk/python/ragflow_sdk/ragflow.py +++ b/sdk/python/ragflow_sdk/ragflow.py @@ -56,8 +56,7 @@ def create_dataset( embedding_model: Optional[str] = "BAAI/bge-large-zh-v1.5@BAAI", permission: str = "me", chunk_method: str = "naive", - pagerank: int = 0, - parser_config: DataSet.ParserConfig = None, + parser_config: Optional[DataSet.ParserConfig] = None, ) -> DataSet: payload = { "name": name, @@ -66,7 +65,6 @@ def create_dataset( "embedding_model": embedding_model, "permission": permission, "chunk_method": chunk_method, - "pagerank": pagerank, } if parser_config is not None: payload["parser_config"] = parser_config.to_json() @@ -246,10 +244,7 @@ def list_agents(self, page: int = 1, page_size: int = 30, orderby: str = "update raise Exception(res["message"]) def create_agent(self, title: str, dsl: dict, description: str | None = None) -> None: - req = { - "title": title, - "dsl": dsl - } + req = {"title": title, "dsl": dsl} if description is not None: req["description"] = description @@ -260,13 +255,7 @@ def create_agent(self, title: str, dsl: dict, description: str | None = None) -> if res.get("code") != 0: raise Exception(res["message"]) - def update_agent( - self, - agent_id: str, - title: str | None = None, - description: str | None = None, - dsl: dict | None = None - ) -> None: + def update_agent(self, agent_id: str, title: str | None = None, description: str | None = None, dsl: dict | None = None) -> None: req = {} if title is not None: diff --git a/sdk/python/test/conftest.py b/sdk/python/test/conftest.py index 6eae2c2de32..6e018491bc5 100644 --- a/sdk/python/test/conftest.py +++ b/sdk/python/test/conftest.py @@ -20,7 +20,7 @@ import requests HOST_ADDRESS = os.getenv("HOST_ADDRESS", "http://127.0.0.1:9380") -ZHIPU_AI_API_KEY = os.getenv("ZHIPU_AI_API_KEY", "ca148e43209c40109e2bc2f56281dd11.BltyA2N1B043B7Ra") +ZHIPU_AI_API_KEY = os.getenv("ZHIPU_AI_API_KEY") if ZHIPU_AI_API_KEY is None: pytest.exit("Error: Environment variable ZHIPU_AI_API_KEY must be set") diff --git a/sdk/python/test/test_frontend_api/test_dataset.py b/sdk/python/test/test_frontend_api/test_dataset.py index 3baaf4e8ca1..c60dc4d0c00 100644 --- a/sdk/python/test/test_frontend_api/test_dataset.py +++ b/sdk/python/test/test_frontend_api/test_dataset.py @@ -108,7 +108,7 @@ def test_invalid_name_dataset(get_auth): long_string = "" - while len(long_string) <= DATASET_NAME_LIMIT: + while len(long_string.encode("utf-8")) <= DATASET_NAME_LIMIT: long_string += random.choice(string.ascii_letters + string.digits) res = create_dataset(get_auth, long_string) diff --git a/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_list_chunks.py b/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_list_chunks.py index 35b6e416c4a..f866d3f09db 100644 --- a/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_list_chunks.py +++ b/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_list_chunks.py @@ -69,7 +69,7 @@ def test_page(self, get_http_api_auth, add_chunks, params, expected_code, expect [ ({"page_size": None}, 0, 5, ""), pytest.param({"page_size": 0}, 0, 5, "", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="Infinity does not support page_size=0")), - pytest.param({"page_size": 0}, 100, 0, "3013", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in 
[None, "elasticsearch"], reason="Infinity does not support page_size=0")), + pytest.param({"page_size": 0}, 100, 0, "3013", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="Infinity does not support page_size=0")), ({"page_size": 1}, 0, 1, ""), ({"page_size": 6}, 0, 5, ""), ({"page_size": "1"}, 0, 1, ""), diff --git a/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_retrieval_chunks.py b/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_retrieval_chunks.py index df6731b1668..c4fd4b62688 100644 --- a/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_retrieval_chunks.py +++ b/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_retrieval_chunks.py @@ -185,28 +185,28 @@ def test_vector_similarity_weight(self, get_http_api_auth, add_chunks, payload, 0, 4, "", - marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="Infinity"), + marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in ["infinity", "opensearch"], reason="Infinity"), ), pytest.param( {"top_k": 1}, 0, 1, "", - marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "elasticsearch"], reason="elasticsearch"), + marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="elasticsearch"), ), pytest.param( {"top_k": -1}, 100, 4, "must be greater than 0", - marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="Infinity"), + marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in ["infinity", "opensearch"], reason="Infinity"), ), pytest.param( {"top_k": -1}, 100, 4, "3014", - marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "elasticsearch"], reason="elasticsearch"), + marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="elasticsearch"), ), pytest.param( {"top_k": "a"}, diff --git a/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_update_chunk.py b/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_update_chunk.py index 710829ac15a..b364f81bd91 100644 --- a/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_update_chunk.py +++ b/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_update_chunk.py @@ -146,7 +146,7 @@ def test_available( [ ("", 100, ""), pytest.param("invalid_dataset_id", 102, "You don't own the dataset invalid_dataset_id.", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="infinity")), - pytest.param("invalid_dataset_id", 102, "Can't find this chunk", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "elasticsearch"], reason="elasticsearch")), + pytest.param("invalid_dataset_id", 102, "Can't find this chunk", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch","elasticsearch"], reason="elasticsearch")), ], ) def test_invalid_dataset_id(self, get_http_api_auth, add_chunks, dataset_id, expected_code, expected_message): diff --git a/sdk/python/test/test_http_api/test_file_management_within_dataset/test_stop_parse_documents.py b/sdk/python/test/test_http_api/test_file_management_within_dataset/test_stop_parse_documents.py index e2dc4f7e04f..7f05302b56a 100644 --- a/sdk/python/test/test_http_api/test_file_management_within_dataset/test_stop_parse_documents.py +++ b/sdk/python/test/test_http_api/test_file_management_within_dataset/test_stop_parse_documents.py @@ -92,13 +92,13 @@ def condition(_auth, _dataset_id, _document_ids): res = 
stop_parse_documnets(get_http_api_auth, dataset_id, payload) assert res["code"] == expected_code - if expected_code != 0: - assert res["message"] == expected_message - else: + if expected_code == 0: completed_document_ids = list(set(document_ids) - set(payload["document_ids"])) condition(get_http_api_auth, dataset_id, completed_document_ids) validate_document_parse_cancel(get_http_api_auth, dataset_id, payload["document_ids"]) validate_document_parse_done(get_http_api_auth, dataset_id, completed_document_ids) + else: + assert res["message"] == expected_message @pytest.mark.p3 @pytest.mark.parametrize( diff --git a/sdk/python/test/test_http_api/test_file_management_within_dataset/test_update_document.py b/sdk/python/test/test_http_api/test_file_management_within_dataset/test_update_document.py index 04735b5502d..5bf6f0410a5 100644 --- a/sdk/python/test/test_http_api/test_file_management_within_dataset/test_update_document.py +++ b/sdk/python/test/test_http_api/test_file_management_within_dataset/test_update_document.py @@ -173,10 +173,10 @@ def test_chunk_method(self, get_http_api_auth, add_documents, chunk_method, expe assert res["code"] == expected_code if expected_code == 0: res = list_documnets(get_http_api_auth, dataset_id, {"id": document_ids[0]}) - if chunk_method != "": - assert res["data"]["docs"][0]["chunk_method"] == chunk_method - else: + if chunk_method == "": assert res["data"]["docs"][0]["chunk_method"] == "naive" + else: + assert res["data"]["docs"][0]["chunk_method"] == chunk_method else: assert res["message"] == expected_message @@ -532,10 +532,7 @@ def test_parser_config( assert res["code"] == expected_code if expected_code == 0: res = list_documnets(get_http_api_auth, dataset_id, {"id": document_ids[0]}) - if parser_config != {}: - for k, v in parser_config.items(): - assert res["data"]["docs"][0]["parser_config"][k] == v - else: + if parser_config == {}: assert res["data"]["docs"][0]["parser_config"] == { "chunk_token_num": 128, "delimiter": r"\n", @@ -543,5 +540,8 @@ def test_parser_config( "layout_recognize": "DeepDOC", "raptor": {"use_raptor": False}, } + else: + for k, v in parser_config.items(): + assert res["data"]["docs"][0]["parser_config"][k] == v if expected_code != 0 or expected_message: assert res["message"] == expected_message diff --git a/sdk/python/test/test_http_api/test_session_management/test_list_sessions_with_chat_assistant.py b/sdk/python/test/test_http_api/test_session_management/test_list_sessions_with_chat_assistant.py index 19379d340f6..e84bd6f1f01 100644 --- a/sdk/python/test/test_http_api/test_session_management/test_list_sessions_with_chat_assistant.py +++ b/sdk/python/test/test_http_api/test_session_management/test_list_sessions_with_chat_assistant.py @@ -162,10 +162,10 @@ def test_name(self, get_http_api_auth, add_sessions_with_chat_assistant, params, res = list_session_with_chat_assistants(get_http_api_auth, chat_assistant_id, params=params) assert res["code"] == expected_code if expected_code == 0: - if params["name"] != "session_with_chat_assistant_1": - assert len(res["data"]) == expected_num - else: + if params["name"] == "session_with_chat_assistant_1": assert res["data"][0]["name"] == params["name"] + else: + assert len(res["data"]) == expected_num else: assert res["message"] == expected_message @@ -189,10 +189,10 @@ def test_id(self, get_http_api_auth, add_sessions_with_chat_assistant, session_i res = list_session_with_chat_assistants(get_http_api_auth, chat_assistant_id, params=params) assert res["code"] == expected_code if 
expected_code == 0: - if params["id"] != session_ids[0]: - assert len(res["data"]) == expected_num - else: + if params["id"] == session_ids[0]: assert res["data"][0]["id"] == params["id"] + else: + assert len(res["data"]) == expected_num else: assert res["message"] == expected_message diff --git a/sdk/python/uv.lock b/sdk/python/uv.lock index 19fb76ebfb0..48eda3cd43f 100644 --- a/sdk/python/uv.lock +++ b/sdk/python/uv.lock @@ -342,7 +342,7 @@ wheels = [ [[package]] name = "ragflow-sdk" -version = "0.19.0" +version = "0.19.1" source = { virtual = "." } dependencies = [ { name = "beartype" }, diff --git a/test/configs.py b/test/configs.py new file mode 100644 index 00000000000..7381567b54b --- /dev/null +++ b/test/configs.py @@ -0,0 +1,36 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os + +import pytest + +HOST_ADDRESS = os.getenv("HOST_ADDRESS", "http://127.0.0.1:9380") +VERSION = "v1" +ZHIPU_AI_API_KEY = os.getenv("ZHIPU_AI_API_KEY") +if ZHIPU_AI_API_KEY is None: + pytest.exit("Error: Environment variable ZHIPU_AI_API_KEY must be set") + +EMAIL = "qa@infiniflow.org" +# password is "123" +PASSWORD = """ctAseGvejiaSWWZ88T/m4FQVOpQyUvP+x7sXtdv3feqZACiQleuewkUi35E16wSd5C5QcnkkcV9cYc8TKPTRZlxappDuirxghxoOvFcJxFU4ixLsD +fN33jCHRoDUW81IH9zjij/vaw8IbVyb6vuwg6MX6inOEBRRzVbRYxXOu1wkWY6SsI8X70oF9aeLFp/PzQpjoe/YbSqpTq8qqrmHzn9vO+yvyYyvmDsphXe +X8f7fp9c7vUsfOCkM+gHY3PadG+QHa7KI7mzTKgUTZImK6BZtfRBATDTthEUbbaTewY4H0MnWiCeeDhcbeQao6cFy1To8pE3RpmxnGnS8BsBn8w==""" + +INVALID_API_TOKEN = "invalid_key_123" +DATASET_NAME_LIMIT = 128 +DOCUMENT_NAME_LIMIT = 255 +CHAT_ASSISTANT_NAME_LIMIT = 255 +SESSION_WITH_CHAT_NAME_LIMIT = 255 diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 00000000000..94952640029 --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,152 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import pytest +import requests +from configs import EMAIL, HOST_ADDRESS, PASSWORD, VERSION, ZHIPU_AI_API_KEY + +MARKER_EXPRESSIONS = { + "p1": "p1", + "p2": "p1 or p2", + "p3": "p1 or p2 or p3", +} + + +def pytest_addoption(parser: pytest.Parser) -> None: + parser.addoption( + "--level", + action="store", + default="p2", + choices=list(MARKER_EXPRESSIONS.keys()), + help=f"Test level ({'/'.join(MARKER_EXPRESSIONS)}): p1=smoke, p2=core, p3=full", + ) + + parser.addoption( + "--client-type", + action="store", + default="http", + choices=["python_sdk", "http", "web"], + help="Test client type: 'python_sdk', 'http', 'web'", + ) + + +def pytest_configure(config: pytest.Config) -> None: + level = config.getoption("--level") + config.option.markexpr = MARKER_EXPRESSIONS[level] + if config.option.verbose > 0: + print(f"\n[CONFIG] Active test level: {level}") + + +def register(): + url = HOST_ADDRESS + f"/{VERSION}/user/register" + name = "qa" + register_data = {"email": EMAIL, "nickname": name, "password": PASSWORD} + res = requests.post(url=url, json=register_data) + res = res.json() + if res.get("code") != 0 and "has already registered" not in res.get("message"): + raise Exception(res.get("message")) + + +def login(): + url = HOST_ADDRESS + f"/{VERSION}/user/login" + login_data = {"email": EMAIL, "password": PASSWORD} + response = requests.post(url=url, json=login_data) + res = response.json() + if res.get("code") != 0: + raise Exception(res.get("message")) + auth = response.headers["Authorization"] + return auth + + +@pytest.fixture(scope="session") +def auth(): + try: + register() + except Exception as e: + print(e) + auth = login() + return auth + + +@pytest.fixture(scope="session") +def token(auth): + url = HOST_ADDRESS + f"/{VERSION}/system/new_token" + auth = {"Authorization": auth} + response = requests.post(url=url, headers=auth) + res = response.json() + if res.get("code") != 0: + raise Exception(res.get("message")) + return res["data"].get("token") + + +def get_my_llms(auth, name): + url = HOST_ADDRESS + f"/{VERSION}/llm/my_llms" + authorization = {"Authorization": auth} + response = requests.get(url=url, headers=authorization) + res = response.json() + if res.get("code") != 0: + raise Exception(res.get("message")) + if name in res.get("data"): + return True + return False + + +def add_models(auth): + url = HOST_ADDRESS + f"/{VERSION}/llm/set_api_key" + authorization = {"Authorization": auth} + models_info = { + "ZHIPU-AI": {"llm_factory": "ZHIPU-AI", "api_key": ZHIPU_AI_API_KEY}, + } + + for name, model_info in models_info.items(): + if not get_my_llms(auth, name): + response = requests.post(url=url, headers=authorization, json=model_info) + res = response.json() + if res.get("code") != 0: + pytest.exit(f"Critical error in add_models: {res.get('message')}") + + +def get_tenant_info(auth): + url = HOST_ADDRESS + f"/{VERSION}/user/tenant_info" + authorization = {"Authorization": auth} + response = requests.get(url=url, headers=authorization) + res = response.json() + if res.get("code") != 0: + raise Exception(res.get("message")) + return res["data"].get("tenant_id") + + +@pytest.fixture(scope="session", autouse=True) +def set_tenant_info(auth): + try: + add_models(auth) + tenant_id = get_tenant_info(auth) + except Exception as e: + pytest.exit(f"Error in set_tenant_info: {str(e)}") + url = HOST_ADDRESS + f"/{VERSION}/user/set_tenant_info" + authorization = {"Authorization": auth} + tenant_info = { + "tenant_id": tenant_id, + "llm_id": "glm-4-flash@ZHIPU-AI", + "embd_id": 
"BAAI/bge-large-zh-v1.5@BAAI", + "img2txt_id": "", + "asr_id": "", + "tts_id": None, + } + response = requests.post(url=url, headers=authorization, json=tenant_info) + res = response.json() + if res.get("code") != 0: + raise Exception(res.get("message")) diff --git a/test/libs/__init__.py b/test/libs/__init__.py new file mode 100644 index 00000000000..177b91dd051 --- /dev/null +++ b/test/libs/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/test/libs/auth.py b/test/libs/auth.py new file mode 100644 index 00000000000..cdc31c94b14 --- /dev/null +++ b/test/libs/auth.py @@ -0,0 +1,34 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from requests.auth import AuthBase + + +class RAGFlowHttpApiAuth(AuthBase): + def __init__(self, token): + self._token = token + + def __call__(self, r): + r.headers["Authorization"] = f"Bearer {self._token}" + return r + + +class RAGFlowWebApiAuth(AuthBase): + def __init__(self, token): + self._token = token + + def __call__(self, r): + r.headers["Authorization"] = self._token + return r diff --git a/test/testcases/test_http_api/common.py b/test/testcases/test_http_api/common.py new file mode 100644 index 00000000000..123fb766736 --- /dev/null +++ b/test/testcases/test_http_api/common.py @@ -0,0 +1,250 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from pathlib import Path + +import requests +from configs import HOST_ADDRESS +from requests_toolbelt import MultipartEncoder +from utils.file_utils import create_txt_file + +HEADERS = {"Content-Type": "application/json"} +DATASETS_API_URL = "/api/v1/datasets" +FILE_API_URL = "/api/v1/datasets/{dataset_id}/documents" +FILE_CHUNK_API_URL = "/api/v1/datasets/{dataset_id}/chunks" +CHUNK_API_URL = "/api/v1/datasets/{dataset_id}/documents/{document_id}/chunks" +CHAT_ASSISTANT_API_URL = "/api/v1/chats" +SESSION_WITH_CHAT_ASSISTANT_API_URL = "/api/v1/chats/{chat_id}/sessions" +SESSION_WITH_AGENT_API_URL = "/api/v1/agents/{agent_id}/sessions" + + +# DATASET MANAGEMENT +def create_dataset(auth, payload=None, *, headers=HEADERS, data=None): + res = requests.post(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload, data=data) + return res.json() + + +def list_datasets(auth, params=None, *, headers=HEADERS): + res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, params=params) + return res.json() + + +def update_dataset(auth, dataset_id, payload=None, *, headers=HEADERS, data=None): + res = requests.put(url=f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}", headers=headers, auth=auth, json=payload, data=data) + return res.json() + + +def delete_datasets(auth, payload=None, *, headers=HEADERS, data=None): + res = requests.delete(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload, data=data) + return res.json() + + +def batch_create_datasets(auth, num): + ids = [] + for i in range(num): + res = create_dataset(auth, {"name": f"dataset_{i}"}) + ids.append(res["data"]["id"]) + return ids + + +# FILE MANAGEMENT WITHIN DATASET +def upload_documents(auth, dataset_id, files_path=None): + url = f"{HOST_ADDRESS}{FILE_API_URL}".format(dataset_id=dataset_id) + + if files_path is None: + files_path = [] + + fields = [] + file_objects = [] + try: + for fp in files_path: + p = Path(fp) + f = p.open("rb") + fields.append(("file", (p.name, f))) + file_objects.append(f) + m = MultipartEncoder(fields=fields) + + res = requests.post( + url=url, + headers={"Content-Type": m.content_type}, + auth=auth, + data=m, + ) + return res.json() + finally: + for f in file_objects: + f.close() + + +def download_document(auth, dataset_id, document_id, save_path): + url = f"{HOST_ADDRESS}{FILE_API_URL}/{document_id}".format(dataset_id=dataset_id) + res = requests.get(url=url, auth=auth, stream=True) + try: + if res.status_code == 200: + with open(save_path, "wb") as f: + for chunk in res.iter_content(chunk_size=8192): + f.write(chunk) + finally: + res.close() + + return res + + +def list_documents(auth, dataset_id, params=None): + url = f"{HOST_ADDRESS}{FILE_API_URL}".format(dataset_id=dataset_id) + res = requests.get(url=url, headers=HEADERS, auth=auth, params=params) + return res.json() + + +def update_document(auth, dataset_id, document_id, payload=None): + url = f"{HOST_ADDRESS}{FILE_API_URL}/{document_id}".format(dataset_id=dataset_id) + res = requests.put(url=url, headers=HEADERS, auth=auth, json=payload) + return res.json() + + +def delete_documents(auth, dataset_id, payload=None): + url = f"{HOST_ADDRESS}{FILE_API_URL}".format(dataset_id=dataset_id) + res = requests.delete(url=url, headers=HEADERS, auth=auth, json=payload) + return res.json() + + +def parse_documents(auth, dataset_id, payload=None): + url = f"{HOST_ADDRESS}{FILE_CHUNK_API_URL}".format(dataset_id=dataset_id) + res = requests.post(url=url, headers=HEADERS, 
auth=auth, json=payload) + return res.json() + + +def stop_parse_documents(auth, dataset_id, payload=None): + url = f"{HOST_ADDRESS}{FILE_CHUNK_API_URL}".format(dataset_id=dataset_id) + res = requests.delete(url=url, headers=HEADERS, auth=auth, json=payload) + return res.json() + + +def bulk_upload_documents(auth, dataset_id, num, tmp_path): + fps = [] + for i in range(num): + fp = create_txt_file(tmp_path / f"ragflow_test_upload_{i}.txt") + fps.append(fp) + res = upload_documents(auth, dataset_id, fps) + document_ids = [] + for document in res["data"]: + document_ids.append(document["id"]) + return document_ids + + +# CHUNK MANAGEMENT WITHIN DATASET +def add_chunk(auth, dataset_id, document_id, payload=None): + url = f"{HOST_ADDRESS}{CHUNK_API_URL}".format(dataset_id=dataset_id, document_id=document_id) + res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) + return res.json() + + +def list_chunks(auth, dataset_id, document_id, params=None): + url = f"{HOST_ADDRESS}{CHUNK_API_URL}".format(dataset_id=dataset_id, document_id=document_id) + res = requests.get(url=url, headers=HEADERS, auth=auth, params=params) + return res.json() + + +def update_chunk(auth, dataset_id, document_id, chunk_id, payload=None): + url = f"{HOST_ADDRESS}{CHUNK_API_URL}/{chunk_id}".format(dataset_id=dataset_id, document_id=document_id) + res = requests.put(url=url, headers=HEADERS, auth=auth, json=payload) + return res.json() + + +def delete_chunks(auth, dataset_id, document_id, payload=None): + url = f"{HOST_ADDRESS}{CHUNK_API_URL}".format(dataset_id=dataset_id, document_id=document_id) + res = requests.delete(url=url, headers=HEADERS, auth=auth, json=payload) + return res.json() + + +def retrieval_chunks(auth, payload=None): + url = f"{HOST_ADDRESS}/api/v1/retrieval" + res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) + return res.json() + + +def batch_add_chunks(auth, dataset_id, document_id, num): + chunk_ids = [] + for i in range(num): + res = add_chunk(auth, dataset_id, document_id, {"content": f"chunk test {i}"}) + chunk_ids.append(res["data"]["chunk"]["id"]) + return chunk_ids + + +# CHAT ASSISTANT MANAGEMENT +def create_chat_assistant(auth, payload=None): + url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}" + res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) + return res.json() + + +def list_chat_assistants(auth, params=None): + url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}" + res = requests.get(url=url, headers=HEADERS, auth=auth, params=params) + return res.json() + + +def update_chat_assistant(auth, chat_assistant_id, payload=None): + url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}/{chat_assistant_id}" + res = requests.put(url=url, headers=HEADERS, auth=auth, json=payload) + return res.json() + + +def delete_chat_assistants(auth, payload=None): + url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}" + res = requests.delete(url=url, headers=HEADERS, auth=auth, json=payload) + return res.json() + + +def batch_create_chat_assistants(auth, num): + chat_assistant_ids = [] + for i in range(num): + res = create_chat_assistant(auth, {"name": f"test_chat_assistant_{i}", "dataset_ids": []}) + chat_assistant_ids.append(res["data"]["id"]) + return chat_assistant_ids + + +# SESSION MANAGEMENT +def create_session_with_chat_assistant(auth, chat_assistant_id, payload=None): + url = f"{HOST_ADDRESS}{SESSION_WITH_CHAT_ASSISTANT_API_URL}".format(chat_id=chat_assistant_id) + res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) + return res.json() + + +def 
list_session_with_chat_assistants(auth, chat_assistant_id, params=None): + url = f"{HOST_ADDRESS}{SESSION_WITH_CHAT_ASSISTANT_API_URL}".format(chat_id=chat_assistant_id) + res = requests.get(url=url, headers=HEADERS, auth=auth, params=params) + return res.json() + + +def update_session_with_chat_assistant(auth, chat_assistant_id, session_id, payload=None): + url = f"{HOST_ADDRESS}{SESSION_WITH_CHAT_ASSISTANT_API_URL}/{session_id}".format(chat_id=chat_assistant_id) + res = requests.put(url=url, headers=HEADERS, auth=auth, json=payload) + return res.json() + + +def delete_session_with_chat_assistants(auth, chat_assistant_id, payload=None): + url = f"{HOST_ADDRESS}{SESSION_WITH_CHAT_ASSISTANT_API_URL}".format(chat_id=chat_assistant_id) + res = requests.delete(url=url, headers=HEADERS, auth=auth, json=payload) + return res.json() + + +def batch_add_sessions_with_chat_assistant(auth, chat_assistant_id, num): + session_ids = [] + for i in range(num): + res = create_session_with_chat_assistant(auth, chat_assistant_id, {"name": f"session_with_chat_assistant_{i}"}) + session_ids.append(res["data"]["id"]) + return session_ids diff --git a/test/testcases/test_http_api/conftest.py b/test/testcases/test_http_api/conftest.py new file mode 100644 index 00000000000..983ef8aee2c --- /dev/null +++ b/test/testcases/test_http_api/conftest.py @@ -0,0 +1,165 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from time import sleep + +import pytest +from common import ( + batch_add_chunks, + batch_create_chat_assistants, + batch_create_datasets, + bulk_upload_documents, + delete_chat_assistants, + delete_datasets, + delete_session_with_chat_assistants, + list_documents, + parse_documents, +) +from libs.auth import RAGFlowHttpApiAuth +from utils import wait_for +from utils.file_utils import ( + create_docx_file, + create_eml_file, + create_excel_file, + create_html_file, + create_image_file, + create_json_file, + create_md_file, + create_pdf_file, + create_ppt_file, + create_txt_file, +) + + +@wait_for(30, 1, "Document parsing timeout") +def condition(_auth, _dataset_id): + res = list_documents(_auth, _dataset_id) + for doc in res["data"]["docs"]: + if doc["run"] != "DONE": + return False + return True + + +@pytest.fixture +def generate_test_files(request, tmp_path): + file_creators = { + "docx": (tmp_path / "ragflow_test.docx", create_docx_file), + "excel": (tmp_path / "ragflow_test.xlsx", create_excel_file), + "ppt": (tmp_path / "ragflow_test.pptx", create_ppt_file), + "image": (tmp_path / "ragflow_test.png", create_image_file), + "pdf": (tmp_path / "ragflow_test.pdf", create_pdf_file), + "txt": (tmp_path / "ragflow_test.txt", create_txt_file), + "md": (tmp_path / "ragflow_test.md", create_md_file), + "json": (tmp_path / "ragflow_test.json", create_json_file), + "eml": (tmp_path / "ragflow_test.eml", create_eml_file), + "html": (tmp_path / "ragflow_test.html", create_html_file), + } + + files = {} + for file_type, (file_path, creator_func) in file_creators.items(): + if request.param in ["", file_type]: + creator_func(file_path) + files[file_type] = file_path + return files + + +@pytest.fixture(scope="class") +def ragflow_tmp_dir(request, tmp_path_factory): + class_name = request.cls.__name__ + return tmp_path_factory.mktemp(class_name) + + +@pytest.fixture(scope="session") +def HttpApiAuth(token): + return RAGFlowHttpApiAuth(token) + + +@pytest.fixture(scope="function") +def clear_datasets(request, HttpApiAuth): + def cleanup(): + delete_datasets(HttpApiAuth, {"ids": None}) + + request.addfinalizer(cleanup) + + +@pytest.fixture(scope="function") +def clear_chat_assistants(request, HttpApiAuth): + def cleanup(): + delete_chat_assistants(HttpApiAuth) + + request.addfinalizer(cleanup) + + +@pytest.fixture(scope="function") +def clear_session_with_chat_assistants(request, HttpApiAuth, add_chat_assistants): + def cleanup(): + for chat_assistant_id in chat_assistant_ids: + delete_session_with_chat_assistants(HttpApiAuth, chat_assistant_id) + + request.addfinalizer(cleanup) + + _, _, chat_assistant_ids = add_chat_assistants + + +@pytest.fixture(scope="class") +def add_dataset(request, HttpApiAuth): + def cleanup(): + delete_datasets(HttpApiAuth, {"ids": None}) + + request.addfinalizer(cleanup) + + dataset_ids = batch_create_datasets(HttpApiAuth, 1) + return dataset_ids[0] + + +@pytest.fixture(scope="function") +def add_dataset_func(request, HttpApiAuth): + def cleanup(): + delete_datasets(HttpApiAuth, {"ids": None}) + + request.addfinalizer(cleanup) + + return batch_create_datasets(HttpApiAuth, 1)[0] + + +@pytest.fixture(scope="class") +def add_document(HttpApiAuth, add_dataset, ragflow_tmp_dir): + dataset_id = add_dataset + document_ids = bulk_upload_documents(HttpApiAuth, dataset_id, 1, ragflow_tmp_dir) + return dataset_id, document_ids[0] + + +@pytest.fixture(scope="class") +def add_chunks(HttpApiAuth, add_document): + dataset_id, document_id = add_document + parse_documents(HttpApiAuth, 
dataset_id, {"document_ids": [document_id]}) + condition(HttpApiAuth, dataset_id) + chunk_ids = batch_add_chunks(HttpApiAuth, dataset_id, document_id, 4) + sleep(1) # issues/6487 + return dataset_id, document_id, chunk_ids + + +@pytest.fixture(scope="class") +def add_chat_assistants(request, HttpApiAuth, add_document): + def cleanup(): + delete_chat_assistants(HttpApiAuth) + + request.addfinalizer(cleanup) + + dataset_id, document_id = add_document + parse_documents(HttpApiAuth, dataset_id, {"document_ids": [document_id]}) + condition(HttpApiAuth, dataset_id) + return dataset_id, document_id, batch_create_chat_assistants(HttpApiAuth, 5) diff --git a/test/testcases/test_http_api/test_chat_assistant_management/conftest.py b/test/testcases/test_http_api/test_chat_assistant_management/conftest.py new file mode 100644 index 00000000000..3087d5929c3 --- /dev/null +++ b/test/testcases/test_http_api/test_chat_assistant_management/conftest.py @@ -0,0 +1,40 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +from common import batch_create_chat_assistants, delete_chat_assistants, list_documents, parse_documents +from utils import wait_for + + +@wait_for(30, 1, "Document parsing timeout") +def condition(_auth, _dataset_id): + res = list_documents(_auth, _dataset_id) + for doc in res["data"]["docs"]: + if doc["run"] != "DONE": + return False + return True + + +@pytest.fixture(scope="function") +def add_chat_assistants_func(request, HttpApiAuth, add_document): + def cleanup(): + delete_chat_assistants(HttpApiAuth) + + request.addfinalizer(cleanup) + + dataset_id, document_id = add_document + parse_documents(HttpApiAuth, dataset_id, {"document_ids": [document_id]}) + condition(HttpApiAuth, dataset_id) + return dataset_id, document_id, batch_create_chat_assistants(HttpApiAuth, 5) diff --git a/test/testcases/test_http_api/test_chat_assistant_management/test_create_chat_assistant.py b/test/testcases/test_http_api/test_chat_assistant_management/test_create_chat_assistant.py new file mode 100644 index 00000000000..7e7d9c11617 --- /dev/null +++ b/test/testcases/test_http_api/test_chat_assistant_management/test_create_chat_assistant.py @@ -0,0 +1,242 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import pytest +from common import create_chat_assistant +from configs import CHAT_ASSISTANT_NAME_LIMIT, INVALID_API_TOKEN +from libs.auth import RAGFlowHttpApiAuth +from utils import encode_avatar +from utils.file_utils import create_image_file + + +@pytest.mark.p1 +class TestAuthorization: + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 0, "`Authorization` can't be empty"), + ( + RAGFlowHttpApiAuth(INVALID_API_TOKEN), + 109, + "Authentication error: API key is invalid!", + ), + ], + ) + def test_invalid_auth(self, invalid_auth, expected_code, expected_message): + res = create_chat_assistant(invalid_auth) + assert res["code"] == expected_code + assert res["message"] == expected_message + + +@pytest.mark.usefixtures("clear_chat_assistants") +class TestChatAssistantCreate: + @pytest.mark.p1 + @pytest.mark.parametrize( + "payload, expected_code, expected_message", + [ + ({"name": "valid_name"}, 0, ""), + pytest.param({"name": "a" * (CHAT_ASSISTANT_NAME_LIMIT + 1)}, 102, "", marks=pytest.mark.skip(reason="issues/")), + pytest.param({"name": 1}, 100, "", marks=pytest.mark.skip(reason="issues/")), + ({"name": ""}, 102, "`name` is required."), + ({"name": "duplicated_name"}, 102, "Duplicated chat name in creating chat."), + ({"name": "case insensitive"}, 102, "Duplicated chat name in creating chat."), + ], + ) + def test_name(self, HttpApiAuth, add_chunks, payload, expected_code, expected_message): + payload["dataset_ids"] = [] # issues/ + if payload["name"] == "duplicated_name": + create_chat_assistant(HttpApiAuth, payload) + elif payload["name"] == "case insensitive": + create_chat_assistant(HttpApiAuth, {"name": payload["name"].upper()}) + + res = create_chat_assistant(HttpApiAuth, payload) + assert res["code"] == expected_code, res + if expected_code == 0: + assert res["data"]["name"] == payload["name"] + else: + assert res["message"] == expected_message + + @pytest.mark.p1 + @pytest.mark.parametrize( + "dataset_ids, expected_code, expected_message", + [ + ([], 0, ""), + (lambda r: [r], 0, ""), + (["invalid_dataset_id"], 102, "You don't own the dataset invalid_dataset_id"), + ("invalid_dataset_id", 102, "You don't own the dataset i"), + ], + ) + def test_dataset_ids(self, HttpApiAuth, add_chunks, dataset_ids, expected_code, expected_message): + dataset_id, _, _ = add_chunks + payload = {"name": "ragflow test"} + if callable(dataset_ids): + payload["dataset_ids"] = dataset_ids(dataset_id) + else: + payload["dataset_ids"] = dataset_ids + + res = create_chat_assistant(HttpApiAuth, payload) + assert res["code"] == expected_code, res + if expected_code == 0: + assert res["data"]["name"] == payload["name"] + else: + assert res["message"] == expected_message + + @pytest.mark.p3 + def test_avatar(self, HttpApiAuth, tmp_path): + fn = create_image_file(tmp_path / "ragflow_test.png") + payload = {"name": "avatar_test", "avatar": encode_avatar(fn), "dataset_ids": []} + res = create_chat_assistant(HttpApiAuth, payload) + assert res["code"] == 0 + + @pytest.mark.p2 + @pytest.mark.parametrize( + "llm, expected_code, expected_message", + [ + ({}, 0, ""), + ({"model_name": "glm-4"}, 0, ""), + ({"model_name": "unknown"}, 102, "`model_name` unknown doesn't exist"), + ({"temperature": 0}, 0, ""), + ({"temperature": 1}, 0, ""), + pytest.param({"temperature": -1}, 0, "", marks=pytest.mark.skip), + pytest.param({"temperature": 10}, 0, "", marks=pytest.mark.skip), + pytest.param({"temperature": "a"}, 0, "", marks=pytest.mark.skip), + ({"top_p": 0}, 0, ""), + 
({"top_p": 1}, 0, ""), + pytest.param({"top_p": -1}, 0, "", marks=pytest.mark.skip), + pytest.param({"top_p": 10}, 0, "", marks=pytest.mark.skip), + pytest.param({"top_p": "a"}, 0, "", marks=pytest.mark.skip), + ({"presence_penalty": 0}, 0, ""), + ({"presence_penalty": 1}, 0, ""), + pytest.param({"presence_penalty": -1}, 0, "", marks=pytest.mark.skip), + pytest.param({"presence_penalty": 10}, 0, "", marks=pytest.mark.skip), + pytest.param({"presence_penalty": "a"}, 0, "", marks=pytest.mark.skip), + ({"frequency_penalty": 0}, 0, ""), + ({"frequency_penalty": 1}, 0, ""), + pytest.param({"frequency_penalty": -1}, 0, "", marks=pytest.mark.skip), + pytest.param({"frequency_penalty": 10}, 0, "", marks=pytest.mark.skip), + pytest.param({"frequency_penalty": "a"}, 0, "", marks=pytest.mark.skip), + ({"max_token": 0}, 0, ""), + ({"max_token": 1024}, 0, ""), + pytest.param({"max_token": -1}, 0, "", marks=pytest.mark.skip), + pytest.param({"max_token": 10}, 0, "", marks=pytest.mark.skip), + pytest.param({"max_token": "a"}, 0, "", marks=pytest.mark.skip), + pytest.param({"unknown": "unknown"}, 0, "", marks=pytest.mark.skip), + ], + ) + def test_llm(self, HttpApiAuth, add_chunks, llm, expected_code, expected_message): + dataset_id, _, _ = add_chunks + payload = {"name": "llm_test", "dataset_ids": [dataset_id], "llm": llm} + res = create_chat_assistant(HttpApiAuth, payload) + assert res["code"] == expected_code + if expected_code == 0: + if llm: + for k, v in llm.items(): + assert res["data"]["llm"][k] == v + else: + assert res["data"]["llm"]["model_name"] == "glm-4-flash@ZHIPU-AI" + assert res["data"]["llm"]["temperature"] == 0.1 + assert res["data"]["llm"]["top_p"] == 0.3 + assert res["data"]["llm"]["presence_penalty"] == 0.4 + assert res["data"]["llm"]["frequency_penalty"] == 0.7 + assert res["data"]["llm"]["max_tokens"] == 512 + else: + assert res["message"] == expected_message + + @pytest.mark.p2 + @pytest.mark.parametrize( + "prompt, expected_code, expected_message", + [ + ({}, 0, ""), + ({"similarity_threshold": 0}, 0, ""), + ({"similarity_threshold": 1}, 0, ""), + pytest.param({"similarity_threshold": -1}, 0, "", marks=pytest.mark.skip), + pytest.param({"similarity_threshold": 10}, 0, "", marks=pytest.mark.skip), + pytest.param({"similarity_threshold": "a"}, 0, "", marks=pytest.mark.skip), + ({"keywords_similarity_weight": 0}, 0, ""), + ({"keywords_similarity_weight": 1}, 0, ""), + pytest.param({"keywords_similarity_weight": -1}, 0, "", marks=pytest.mark.skip), + pytest.param({"keywords_similarity_weight": 10}, 0, "", marks=pytest.mark.skip), + pytest.param({"keywords_similarity_weight": "a"}, 0, "", marks=pytest.mark.skip), + ({"variables": []}, 0, ""), + ({"top_n": 0}, 0, ""), + ({"top_n": 1}, 0, ""), + pytest.param({"top_n": -1}, 0, "", marks=pytest.mark.skip), + pytest.param({"top_n": 10}, 0, "", marks=pytest.mark.skip), + pytest.param({"top_n": "a"}, 0, "", marks=pytest.mark.skip), + ({"empty_response": "Hello World"}, 0, ""), + ({"empty_response": ""}, 0, ""), + ({"empty_response": "!@#$%^&*()"}, 0, ""), + ({"empty_response": "中文测试"}, 0, ""), + pytest.param({"empty_response": 123}, 0, "", marks=pytest.mark.skip), + pytest.param({"empty_response": True}, 0, "", marks=pytest.mark.skip), + pytest.param({"empty_response": " "}, 0, "", marks=pytest.mark.skip), + ({"opener": "Hello World"}, 0, ""), + ({"opener": ""}, 0, ""), + ({"opener": "!@#$%^&*()"}, 0, ""), + ({"opener": "中文测试"}, 0, ""), + pytest.param({"opener": 123}, 0, "", marks=pytest.mark.skip), + pytest.param({"opener": True}, 0, "", 
marks=pytest.mark.skip), + pytest.param({"opener": " "}, 0, "", marks=pytest.mark.skip), + ({"show_quote": True}, 0, ""), + ({"show_quote": False}, 0, ""), + ({"prompt": "Hello World {knowledge}"}, 0, ""), + ({"prompt": "{knowledge}"}, 0, ""), + ({"prompt": "!@#$%^&*() {knowledge}"}, 0, ""), + ({"prompt": "中文测试 {knowledge}"}, 0, ""), + ({"prompt": "Hello World"}, 102, "Parameter 'knowledge' is not used"), + ({"prompt": "Hello World", "variables": []}, 0, ""), + pytest.param({"prompt": 123}, 100, """AttributeError("\'int\' object has no attribute \'find\'")""", marks=pytest.mark.skip), + pytest.param({"prompt": True}, 100, """AttributeError("\'int\' object has no attribute \'find\'")""", marks=pytest.mark.skip), + pytest.param({"unknown": "unknown"}, 0, "", marks=pytest.mark.skip), + ], + ) + def test_prompt(self, HttpApiAuth, add_chunks, prompt, expected_code, expected_message): + dataset_id, _, _ = add_chunks + payload = {"name": "prompt_test", "dataset_ids": [dataset_id], "prompt": prompt} + res = create_chat_assistant(HttpApiAuth, payload) + assert res["code"] == expected_code + if expected_code == 0: + if prompt: + for k, v in prompt.items(): + if k == "keywords_similarity_weight": + assert res["data"]["prompt"][k] == 1 - v + else: + assert res["data"]["prompt"][k] == v + else: + assert res["data"]["prompt"]["similarity_threshold"] == 0.2 + assert res["data"]["prompt"]["keywords_similarity_weight"] == 0.7 + assert res["data"]["prompt"]["top_n"] == 6 + assert res["data"]["prompt"]["variables"] == [{"key": "knowledge", "optional": False}] + assert res["data"]["prompt"]["rerank_model"] == "" + assert res["data"]["prompt"]["empty_response"] == "Sorry! No relevant content was found in the knowledge base!" + assert res["data"]["prompt"]["opener"] == "Hi! I'm your assistant, what can I do for you?" + assert res["data"]["prompt"]["show_quote"] is True + assert ( + res["data"]["prompt"]["prompt"] + == 'You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence "The answer you are looking for is not found in the knowledge base!" Answers need to consider chat history.\n Here is the knowledge base:\n {knowledge}\n The above is the knowledge base.' + ) + else: + assert res["message"] == expected_message + + +class TestChatAssistantCreate2: + @pytest.mark.p2 + def test_unparsed_document(self, HttpApiAuth, add_document): + dataset_id, _ = add_document + payload = {"name": "prompt_test", "dataset_ids": [dataset_id]} + res = create_chat_assistant(HttpApiAuth, payload) + assert res["code"] == 102 + assert "doesn't own parsed file" in res["message"] diff --git a/test/testcases/test_http_api/test_chat_assistant_management/test_delete_chat_assistants.py b/test/testcases/test_http_api/test_chat_assistant_management/test_delete_chat_assistants.py new file mode 100644 index 00000000000..2a2fdc9a6a5 --- /dev/null +++ b/test/testcases/test_http_api/test_chat_assistant_management/test_delete_chat_assistants.py @@ -0,0 +1,127 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from common import batch_create_chat_assistants, delete_chat_assistants, list_chat_assistants +from configs import INVALID_API_TOKEN +from libs.auth import RAGFlowHttpApiAuth + + +@pytest.mark.p1 +class TestAuthorization: + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 0, "`Authorization` can't be empty"), + ( + RAGFlowHttpApiAuth(INVALID_API_TOKEN), + 109, + "Authentication error: API key is invalid!", + ), + ], + ) + def test_invalid_auth(self, invalid_auth, expected_code, expected_message): + res = delete_chat_assistants(invalid_auth) + assert res["code"] == expected_code + assert res["message"] == expected_message + + +class TestChatAssistantsDelete: + @pytest.mark.parametrize( + "payload, expected_code, expected_message, remaining", + [ + pytest.param(None, 0, "", 0, marks=pytest.mark.p3), + pytest.param({"ids": []}, 0, "", 0, marks=pytest.mark.p3), + pytest.param({"ids": ["invalid_id"]}, 102, "Assistant(invalid_id) not found.", 5, marks=pytest.mark.p3), + pytest.param({"ids": ["\n!?。;!?\"'"]}, 102, """Assistant(\n!?。;!?"\') not found.""", 5, marks=pytest.mark.p3), + pytest.param("not json", 100, "AttributeError(\"'str' object has no attribute 'get'\")", 5, marks=pytest.mark.p3), + pytest.param(lambda r: {"ids": r[:1]}, 0, "", 4, marks=pytest.mark.p3), + pytest.param(lambda r: {"ids": r}, 0, "", 0, marks=pytest.mark.p1), + ], + ) + def test_basic_scenarios(self, HttpApiAuth, add_chat_assistants_func, payload, expected_code, expected_message, remaining): + _, _, chat_assistant_ids = add_chat_assistants_func + if callable(payload): + payload = payload(chat_assistant_ids) + res = delete_chat_assistants(HttpApiAuth, payload) + assert res["code"] == expected_code + if res["code"] != 0: + assert res["message"] == expected_message + + res = list_chat_assistants(HttpApiAuth) + assert len(res["data"]) == remaining + + @pytest.mark.parametrize( + "payload", + [ + pytest.param(lambda r: {"ids": ["invalid_id"] + r}, marks=pytest.mark.p3), + pytest.param(lambda r: {"ids": r[:1] + ["invalid_id"] + r[1:5]}, marks=pytest.mark.p1), + pytest.param(lambda r: {"ids": r + ["invalid_id"]}, marks=pytest.mark.p3), + ], + ) + def test_delete_partial_invalid_id(self, HttpApiAuth, add_chat_assistants_func, payload): + _, _, chat_assistant_ids = add_chat_assistants_func + if callable(payload): + payload = payload(chat_assistant_ids) + res = delete_chat_assistants(HttpApiAuth, payload) + assert res["code"] == 0 + assert res["data"]["errors"][0] == "Assistant(invalid_id) not found." 
+ assert res["data"]["success_count"] == 5 + + res = list_chat_assistants(HttpApiAuth) + assert len(res["data"]) == 0 + + @pytest.mark.p3 + def test_repeated_deletion(self, HttpApiAuth, add_chat_assistants_func): + _, _, chat_assistant_ids = add_chat_assistants_func + res = delete_chat_assistants(HttpApiAuth, {"ids": chat_assistant_ids}) + assert res["code"] == 0 + + res = delete_chat_assistants(HttpApiAuth, {"ids": chat_assistant_ids}) + assert res["code"] == 102 + assert "not found" in res["message"] + + @pytest.mark.p3 + def test_duplicate_deletion(self, HttpApiAuth, add_chat_assistants_func): + _, _, chat_assistant_ids = add_chat_assistants_func + res = delete_chat_assistants(HttpApiAuth, {"ids": chat_assistant_ids + chat_assistant_ids}) + assert res["code"] == 0 + assert "Duplicate assistant ids" in res["data"]["errors"][0] + assert res["data"]["success_count"] == 5 + + res = list_chat_assistants(HttpApiAuth) + assert res["code"] == 0 + + @pytest.mark.p3 + def test_concurrent_deletion(self, HttpApiAuth): + count = 100 + ids = batch_create_chat_assistants(HttpApiAuth, count) + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(delete_chat_assistants, HttpApiAuth, {"ids": ids[i : i + 1]}) for i in range(count)] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + assert all(future.result()["code"] == 0 for future in futures) + + @pytest.mark.p3 + def test_delete_10k(self, HttpApiAuth): + ids = batch_create_chat_assistants(HttpApiAuth, 1_000) + res = delete_chat_assistants(HttpApiAuth, {"ids": ids}) + assert res["code"] == 0 + + res = list_chat_assistants(HttpApiAuth) + assert len(res["data"]) == 0 diff --git a/test/testcases/test_http_api/test_chat_assistant_management/test_list_chat_assistants.py b/test/testcases/test_http_api/test_chat_assistant_management/test_list_chat_assistants.py new file mode 100644 index 00000000000..20bce689eea --- /dev/null +++ b/test/testcases/test_http_api/test_chat_assistant_management/test_list_chat_assistants.py @@ -0,0 +1,314 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from common import delete_datasets, list_chat_assistants +from configs import INVALID_API_TOKEN +from libs.auth import RAGFlowHttpApiAuth +from utils import is_sorted + + +@pytest.mark.p1 +class TestAuthorization: + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 0, "`Authorization` can't be empty"), + ( + RAGFlowHttpApiAuth(INVALID_API_TOKEN), + 109, + "Authentication error: API key is invalid!", + ), + ], + ) + def test_invalid_auth(self, invalid_auth, expected_code, expected_message): + res = list_chat_assistants(invalid_auth) + assert res["code"] == expected_code + assert res["message"] == expected_message + + +@pytest.mark.usefixtures("add_chat_assistants") +class TestChatAssistantsList: + @pytest.mark.p1 + def test_default(self, HttpApiAuth): + res = list_chat_assistants(HttpApiAuth) + assert res["code"] == 0 + assert len(res["data"]) == 5 + + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_code, expected_page_size, expected_message", + [ + ({"page": None, "page_size": 2}, 0, 2, ""), + ({"page": 0, "page_size": 2}, 0, 2, ""), + ({"page": 2, "page_size": 2}, 0, 2, ""), + ({"page": 3, "page_size": 2}, 0, 1, ""), + ({"page": "3", "page_size": 2}, 0, 1, ""), + pytest.param( + {"page": -1, "page_size": 2}, + 100, + 0, + "1064", + marks=pytest.mark.skip(reason="issues/5851"), + ), + pytest.param( + {"page": "a", "page_size": 2}, + 100, + 0, + """ValueError("invalid literal for int() with base 10: \'a\'")""", + marks=pytest.mark.skip(reason="issues/5851"), + ), + ], + ) + def test_page(self, HttpApiAuth, params, expected_code, expected_page_size, expected_message): + res = list_chat_assistants(HttpApiAuth, params=params) + assert res["code"] == expected_code + if expected_code == 0: + assert len(res["data"]) == expected_page_size + else: + assert res["message"] == expected_message + + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_code, expected_page_size, expected_message", + [ + ({"page_size": None}, 0, 5, ""), + ({"page_size": 0}, 0, 0, ""), + ({"page_size": 1}, 0, 1, ""), + ({"page_size": 6}, 0, 5, ""), + ({"page_size": "1"}, 0, 1, ""), + pytest.param( + {"page_size": -1}, + 100, + 0, + "1064", + marks=pytest.mark.skip(reason="issues/5851"), + ), + pytest.param( + {"page_size": "a"}, + 100, + 0, + """ValueError("invalid literal for int() with base 10: \'a\'")""", + marks=pytest.mark.skip(reason="issues/5851"), + ), + ], + ) + def test_page_size( + self, + HttpApiAuth, + params, + expected_code, + expected_page_size, + expected_message, + ): + res = list_chat_assistants(HttpApiAuth, params=params) + assert res["code"] == expected_code + if expected_code == 0: + assert len(res["data"]) == expected_page_size + else: + assert res["message"] == expected_message + + @pytest.mark.p3 + @pytest.mark.parametrize( + "params, expected_code, assertions, expected_message", + [ + ({"orderby": None}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), + ({"orderby": "create_time"}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), + ({"orderby": "update_time"}, 0, lambda r: (is_sorted(r["data"], "update_time", True)), ""), + pytest.param( + {"orderby": "name", "desc": "False"}, + 0, + lambda r: (is_sorted(r["data"], "name", False)), + "", + marks=pytest.mark.skip(reason="issues/5851"), + ), + pytest.param( + {"orderby": "unknown"}, + 102, + 0, + "orderby should be create_time or update_time", + 
marks=pytest.mark.skip(reason="issues/5851"), + ), + ], + ) + def test_orderby( + self, + HttpApiAuth, + params, + expected_code, + assertions, + expected_message, + ): + res = list_chat_assistants(HttpApiAuth, params=params) + assert res["code"] == expected_code + if expected_code == 0: + if callable(assertions): + assert assertions(res) + else: + assert res["message"] == expected_message + + @pytest.mark.p3 + @pytest.mark.parametrize( + "params, expected_code, assertions, expected_message", + [ + ({"desc": None}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), + ({"desc": "true"}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), + ({"desc": "True"}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), + ({"desc": True}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), + ({"desc": "false"}, 0, lambda r: (is_sorted(r["data"], "create_time", False)), ""), + ({"desc": "False"}, 0, lambda r: (is_sorted(r["data"], "create_time", False)), ""), + ({"desc": False}, 0, lambda r: (is_sorted(r["data"], "create_time", False)), ""), + ({"desc": "False", "orderby": "update_time"}, 0, lambda r: (is_sorted(r["data"], "update_time", False)), ""), + pytest.param( + {"desc": "unknown"}, + 102, + 0, + "desc should be true or false", + marks=pytest.mark.skip(reason="issues/5851"), + ), + ], + ) + def test_desc( + self, + HttpApiAuth, + params, + expected_code, + assertions, + expected_message, + ): + res = list_chat_assistants(HttpApiAuth, params=params) + assert res["code"] == expected_code + if expected_code == 0: + if callable(assertions): + assert assertions(res) + else: + assert res["message"] == expected_message + + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_code, expected_num, expected_message", + [ + ({"name": None}, 0, 5, ""), + ({"name": ""}, 0, 5, ""), + ({"name": "test_chat_assistant_1"}, 0, 1, ""), + ({"name": "unknown"}, 102, 0, "The chat doesn't exist"), + ], + ) + def test_name(self, HttpApiAuth, params, expected_code, expected_num, expected_message): + res = list_chat_assistants(HttpApiAuth, params=params) + assert res["code"] == expected_code + if expected_code == 0: + if params["name"] in [None, ""]: + assert len(res["data"]) == expected_num + else: + assert res["data"][0]["name"] == params["name"] + else: + assert res["message"] == expected_message + + @pytest.mark.p1 + @pytest.mark.parametrize( + "chat_assistant_id, expected_code, expected_num, expected_message", + [ + (None, 0, 5, ""), + ("", 0, 5, ""), + (lambda r: r[0], 0, 1, ""), + ("unknown", 102, 0, "The chat doesn't exist"), + ], + ) + def test_id( + self, + HttpApiAuth, + add_chat_assistants, + chat_assistant_id, + expected_code, + expected_num, + expected_message, + ): + _, _, chat_assistant_ids = add_chat_assistants + if callable(chat_assistant_id): + params = {"id": chat_assistant_id(chat_assistant_ids)} + else: + params = {"id": chat_assistant_id} + + res = list_chat_assistants(HttpApiAuth, params=params) + assert res["code"] == expected_code + if expected_code == 0: + if params["id"] in [None, ""]: + assert len(res["data"]) == expected_num + else: + assert res["data"][0]["id"] == params["id"] + else: + assert res["message"] == expected_message + + @pytest.mark.p3 + @pytest.mark.parametrize( + "chat_assistant_id, name, expected_code, expected_num, expected_message", + [ + (lambda r: r[0], "test_chat_assistant_0", 0, 1, ""), + (lambda r: r[0], "test_chat_assistant_1", 102, 0, "The chat doesn't exist"), + (lambda r: r[0], "unknown", 102, 0, "The chat 
doesn't exist"), + ("id", "chat_assistant_0", 102, 0, "The chat doesn't exist"), + ], + ) + def test_name_and_id( + self, + HttpApiAuth, + add_chat_assistants, + chat_assistant_id, + name, + expected_code, + expected_num, + expected_message, + ): + _, _, chat_assistant_ids = add_chat_assistants + if callable(chat_assistant_id): + params = {"id": chat_assistant_id(chat_assistant_ids), "name": name} + else: + params = {"id": chat_assistant_id, "name": name} + + res = list_chat_assistants(HttpApiAuth, params=params) + assert res["code"] == expected_code + if expected_code == 0: + assert len(res["data"]) == expected_num + else: + assert res["message"] == expected_message + + @pytest.mark.p3 + def test_concurrent_list(self, HttpApiAuth): + count = 100 + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(list_chat_assistants, HttpApiAuth) for i in range(count)] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + assert all(future.result()["code"] == 0 for future in futures) + + @pytest.mark.p3 + def test_invalid_params(self, HttpApiAuth): + params = {"a": "b"} + res = list_chat_assistants(HttpApiAuth, params=params) + assert res["code"] == 0 + assert len(res["data"]) == 5 + + @pytest.mark.p2 + def test_list_chats_after_deleting_associated_dataset(self, HttpApiAuth, add_chat_assistants): + dataset_id, _, _ = add_chat_assistants + res = delete_datasets(HttpApiAuth, {"ids": [dataset_id]}) + assert res["code"] == 0 + + res = list_chat_assistants(HttpApiAuth) + assert res["code"] == 0 + assert len(res["data"]) == 5 diff --git a/test/testcases/test_http_api/test_chat_assistant_management/test_update_chat_assistant.py b/test/testcases/test_http_api/test_chat_assistant_management/test_update_chat_assistant.py new file mode 100644 index 00000000000..54a16131933 --- /dev/null +++ b/test/testcases/test_http_api/test_chat_assistant_management/test_update_chat_assistant.py @@ -0,0 +1,229 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import pytest +from common import list_chat_assistants, update_chat_assistant +from configs import CHAT_ASSISTANT_NAME_LIMIT, INVALID_API_TOKEN +from libs.auth import RAGFlowHttpApiAuth +from utils import encode_avatar +from utils.file_utils import create_image_file + + +@pytest.mark.p1 +class TestAuthorization: + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 0, "`Authorization` can't be empty"), + ( + RAGFlowHttpApiAuth(INVALID_API_TOKEN), + 109, + "Authentication error: API key is invalid!", + ), + ], + ) + def test_invalid_auth(self, invalid_auth, expected_code, expected_message): + res = update_chat_assistant(invalid_auth, "chat_assistant_id") + assert res["code"] == expected_code + assert res["message"] == expected_message + + +class TestChatAssistantUpdate: + @pytest.mark.parametrize( + "payload, expected_code, expected_message", + [ + pytest.param({"name": "valid_name"}, 0, "", marks=pytest.mark.p1), + pytest.param({"name": "a" * (CHAT_ASSISTANT_NAME_LIMIT + 1)}, 102, "", marks=pytest.mark.skip(reason="issues/")), + pytest.param({"name": 1}, 100, "", marks=pytest.mark.skip(reason="issues/")), + pytest.param({"name": ""}, 102, "`name` cannot be empty.", marks=pytest.mark.p3), + pytest.param({"name": "test_chat_assistant_1"}, 102, "Duplicated chat name in updating chat.", marks=pytest.mark.p3), + pytest.param({"name": "TEST_CHAT_ASSISTANT_1"}, 102, "Duplicated chat name in updating chat.", marks=pytest.mark.p3), + ], + ) + def test_name(self, HttpApiAuth, add_chat_assistants_func, payload, expected_code, expected_message): + _, _, chat_assistant_ids = add_chat_assistants_func + + res = update_chat_assistant(HttpApiAuth, chat_assistant_ids[0], payload) + assert res["code"] == expected_code, res + if expected_code == 0: + res = list_chat_assistants(HttpApiAuth, {"id": chat_assistant_ids[0]}) + assert res["data"][0]["name"] == payload.get("name") + else: + assert res["message"] == expected_message + + @pytest.mark.parametrize( + "dataset_ids, expected_code, expected_message", + [ + pytest.param([], 0, "", marks=pytest.mark.skip(reason="issues/")), + pytest.param(lambda r: [r], 0, "", marks=pytest.mark.p1), + pytest.param(["invalid_dataset_id"], 102, "You don't own the dataset invalid_dataset_id", marks=pytest.mark.p3), + pytest.param("invalid_dataset_id", 102, "You don't own the dataset i", marks=pytest.mark.p3), + ], + ) + def test_dataset_ids(self, HttpApiAuth, add_chat_assistants_func, dataset_ids, expected_code, expected_message): + dataset_id, _, chat_assistant_ids = add_chat_assistants_func + payload = {"name": "ragflow test"} + if callable(dataset_ids): + payload["dataset_ids"] = dataset_ids(dataset_id) + else: + payload["dataset_ids"] = dataset_ids + + res = update_chat_assistant(HttpApiAuth, chat_assistant_ids[0], payload) + assert res["code"] == expected_code, res + if expected_code == 0: + res = list_chat_assistants(HttpApiAuth, {"id": chat_assistant_ids[0]}) + assert res["data"][0]["name"] == payload.get("name") + else: + assert res["message"] == expected_message + + @pytest.mark.p3 + def test_avatar(self, HttpApiAuth, add_chat_assistants_func, tmp_path): + dataset_id, _, chat_assistant_ids = add_chat_assistants_func + fn = create_image_file(tmp_path / "ragflow_test.png") + payload = {"name": "avatar_test", "avatar": encode_avatar(fn), "dataset_ids": [dataset_id]} + res = update_chat_assistant(HttpApiAuth, chat_assistant_ids[0], payload) + assert res["code"] == 0 + + @pytest.mark.p3 + @pytest.mark.parametrize( + "llm, 
expected_code, expected_message", + [ + ({}, 100, "ValueError"), + ({"model_name": "glm-4"}, 0, ""), + ({"model_name": "unknown"}, 102, "`model_name` unknown doesn't exist"), + ({"temperature": 0}, 0, ""), + ({"temperature": 1}, 0, ""), + pytest.param({"temperature": -1}, 0, "", marks=pytest.mark.skip), + pytest.param({"temperature": 10}, 0, "", marks=pytest.mark.skip), + pytest.param({"temperature": "a"}, 0, "", marks=pytest.mark.skip), + ({"top_p": 0}, 0, ""), + ({"top_p": 1}, 0, ""), + pytest.param({"top_p": -1}, 0, "", marks=pytest.mark.skip), + pytest.param({"top_p": 10}, 0, "", marks=pytest.mark.skip), + pytest.param({"top_p": "a"}, 0, "", marks=pytest.mark.skip), + ({"presence_penalty": 0}, 0, ""), + ({"presence_penalty": 1}, 0, ""), + pytest.param({"presence_penalty": -1}, 0, "", marks=pytest.mark.skip), + pytest.param({"presence_penalty": 10}, 0, "", marks=pytest.mark.skip), + pytest.param({"presence_penalty": "a"}, 0, "", marks=pytest.mark.skip), + ({"frequency_penalty": 0}, 0, ""), + ({"frequency_penalty": 1}, 0, ""), + pytest.param({"frequency_penalty": -1}, 0, "", marks=pytest.mark.skip), + pytest.param({"frequency_penalty": 10}, 0, "", marks=pytest.mark.skip), + pytest.param({"frequency_penalty": "a"}, 0, "", marks=pytest.mark.skip), + ({"max_token": 0}, 0, ""), + ({"max_token": 1024}, 0, ""), + pytest.param({"max_token": -1}, 0, "", marks=pytest.mark.skip), + pytest.param({"max_token": 10}, 0, "", marks=pytest.mark.skip), + pytest.param({"max_token": "a"}, 0, "", marks=pytest.mark.skip), + pytest.param({"unknown": "unknown"}, 0, "", marks=pytest.mark.skip), + ], + ) + def test_llm(self, HttpApiAuth, add_chat_assistants_func, llm, expected_code, expected_message): + dataset_id, _, chat_assistant_ids = add_chat_assistants_func + payload = {"name": "llm_test", "dataset_ids": [dataset_id], "llm": llm} + res = update_chat_assistant(HttpApiAuth, chat_assistant_ids[0], payload) + assert res["code"] == expected_code + if expected_code == 0: + res = list_chat_assistants(HttpApiAuth, {"id": chat_assistant_ids[0]}) + if llm: + for k, v in llm.items(): + assert res["data"][0]["llm"][k] == v + else: + assert res["data"][0]["llm"]["model_name"] == "glm-4-flash@ZHIPU-AI" + assert res["data"][0]["llm"]["temperature"] == 0.1 + assert res["data"][0]["llm"]["top_p"] == 0.3 + assert res["data"][0]["llm"]["presence_penalty"] == 0.4 + assert res["data"][0]["llm"]["frequency_penalty"] == 0.7 + assert res["data"][0]["llm"]["max_tokens"] == 512 + else: + assert expected_message in res["message"] + + @pytest.mark.p3 + @pytest.mark.parametrize( + "prompt, expected_code, expected_message", + [ + ({}, 100, "ValueError"), + ({"similarity_threshold": 0}, 0, ""), + ({"similarity_threshold": 1}, 0, ""), + pytest.param({"similarity_threshold": -1}, 0, "", marks=pytest.mark.skip), + pytest.param({"similarity_threshold": 10}, 0, "", marks=pytest.mark.skip), + pytest.param({"similarity_threshold": "a"}, 0, "", marks=pytest.mark.skip), + ({"keywords_similarity_weight": 0}, 0, ""), + ({"keywords_similarity_weight": 1}, 0, ""), + pytest.param({"keywords_similarity_weight": -1}, 0, "", marks=pytest.mark.skip), + pytest.param({"keywords_similarity_weight": 10}, 0, "", marks=pytest.mark.skip), + pytest.param({"keywords_similarity_weight": "a"}, 0, "", marks=pytest.mark.skip), + ({"variables": []}, 0, ""), + ({"top_n": 0}, 0, ""), + ({"top_n": 1}, 0, ""), + pytest.param({"top_n": -1}, 0, "", marks=pytest.mark.skip), + pytest.param({"top_n": 10}, 0, "", marks=pytest.mark.skip), + pytest.param({"top_n": "a"}, 0, "", 
marks=pytest.mark.skip), + ({"empty_response": "Hello World"}, 0, ""), + ({"empty_response": ""}, 0, ""), + ({"empty_response": "!@#$%^&*()"}, 0, ""), + ({"empty_response": "中文测试"}, 0, ""), + pytest.param({"empty_response": 123}, 0, "", marks=pytest.mark.skip), + pytest.param({"empty_response": True}, 0, "", marks=pytest.mark.skip), + pytest.param({"empty_response": " "}, 0, "", marks=pytest.mark.skip), + ({"opener": "Hello World"}, 0, ""), + ({"opener": ""}, 0, ""), + ({"opener": "!@#$%^&*()"}, 0, ""), + ({"opener": "中文测试"}, 0, ""), + pytest.param({"opener": 123}, 0, "", marks=pytest.mark.skip), + pytest.param({"opener": True}, 0, "", marks=pytest.mark.skip), + pytest.param({"opener": " "}, 0, "", marks=pytest.mark.skip), + ({"show_quote": True}, 0, ""), + ({"show_quote": False}, 0, ""), + ({"prompt": "Hello World {knowledge}"}, 0, ""), + ({"prompt": "{knowledge}"}, 0, ""), + ({"prompt": "!@#$%^&*() {knowledge}"}, 0, ""), + ({"prompt": "中文测试 {knowledge}"}, 0, ""), + ({"prompt": "Hello World"}, 102, "Parameter 'knowledge' is not used"), + ({"prompt": "Hello World", "variables": []}, 0, ""), + pytest.param({"prompt": 123}, 100, """AttributeError("\'int\' object has no attribute \'find\'")""", marks=pytest.mark.skip), + pytest.param({"prompt": True}, 100, """AttributeError("\'int\' object has no attribute \'find\'")""", marks=pytest.mark.skip), + pytest.param({"unknown": "unknown"}, 0, "", marks=pytest.mark.skip), + ], + ) + def test_prompt(self, HttpApiAuth, add_chat_assistants_func, prompt, expected_code, expected_message): + dataset_id, _, chat_assistant_ids = add_chat_assistants_func + payload = {"name": "prompt_test", "dataset_ids": [dataset_id], "prompt": prompt} + res = update_chat_assistant(HttpApiAuth, chat_assistant_ids[0], payload) + assert res["code"] == expected_code + if expected_code == 0: + res = list_chat_assistants(HttpApiAuth, {"id": chat_assistant_ids[0]}) + if prompt: + for k, v in prompt.items(): + if k == "keywords_similarity_weight": + assert res["data"][0]["prompt"][k] == 1 - v + else: + assert res["data"][0]["prompt"][k] == v + else: + assert res["data"]["prompt"][0]["similarity_threshold"] == 0.2 + assert res["data"]["prompt"][0]["keywords_similarity_weight"] == 0.7 + assert res["data"]["prompt"][0]["top_n"] == 6 + assert res["data"]["prompt"][0]["variables"] == [{"key": "knowledge", "optional": False}] + assert res["data"]["prompt"][0]["rerank_model"] == "" + assert res["data"]["prompt"][0]["empty_response"] == "Sorry! No relevant content was found in the knowledge base!" + assert res["data"]["prompt"][0]["opener"] == "Hi! I'm your assistant, what can I do for you?" + assert res["data"]["prompt"][0]["show_quote"] is True + assert ( + res["data"]["prompt"][0]["prompt"] + == 'You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence "The answer you are looking for is not found in the knowledge base!" Answers need to consider chat history.\n Here is the knowledge base:\n {knowledge}\n The above is the knowledge base.' 
+ ) + else: + assert expected_message in res["message"] diff --git a/test/testcases/test_http_api/test_chunk_management_within_dataset/conftest.py b/test/testcases/test_http_api/test_chunk_management_within_dataset/conftest.py new file mode 100644 index 00000000000..7a06a23eb57 --- /dev/null +++ b/test/testcases/test_http_api/test_chunk_management_within_dataset/conftest.py @@ -0,0 +1,47 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +from time import sleep + +import pytest +from common import batch_add_chunks, delete_chunks, list_documents, parse_documents +from utils import wait_for + + +@wait_for(30, 1, "Document parsing timeout") +def condition(_auth, _dataset_id): + res = list_documents(_auth, _dataset_id) + for doc in res["data"]["docs"]: + if doc["run"] != "DONE": + return False + return True + + +@pytest.fixture(scope="function") +def add_chunks_func(request, HttpApiAuth, add_document): + def cleanup(): + delete_chunks(HttpApiAuth, dataset_id, document_id, {"chunk_ids": []}) + + request.addfinalizer(cleanup) + + dataset_id, document_id = add_document + parse_documents(HttpApiAuth, dataset_id, {"document_ids": [document_id]}) + condition(HttpApiAuth, dataset_id) + chunk_ids = batch_add_chunks(HttpApiAuth, dataset_id, document_id, 4) + # issues/6487 + sleep(1) + return dataset_id, document_id, chunk_ids diff --git a/test/testcases/test_http_api/test_chunk_management_within_dataset/test_add_chunk.py b/test/testcases/test_http_api/test_chunk_management_within_dataset/test_add_chunk.py new file mode 100644 index 00000000000..d46469d91cb --- /dev/null +++ b/test/testcases/test_http_api/test_chunk_management_within_dataset/test_add_chunk.py @@ -0,0 +1,252 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from common import add_chunk, delete_documents, list_chunks +from configs import INVALID_API_TOKEN +from libs.auth import RAGFlowHttpApiAuth + + +def validate_chunk_details(dataset_id, document_id, payload, res): + chunk = res["data"]["chunk"] + assert chunk["dataset_id"] == dataset_id + assert chunk["document_id"] == document_id + assert chunk["content"] == payload["content"] + if "important_keywords" in payload: + assert chunk["important_keywords"] == payload["important_keywords"] + if "questions" in payload: + assert chunk["questions"] == [str(q).strip() for q in payload.get("questions", []) if str(q).strip()] + + +@pytest.mark.p1 +class TestAuthorization: + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 0, "`Authorization` can't be empty"), + ( + RAGFlowHttpApiAuth(INVALID_API_TOKEN), + 109, + "Authentication error: API key is invalid!", + ), + ], + ) + def test_invalid_auth(self, invalid_auth, expected_code, expected_message): + res = add_chunk(invalid_auth, "dataset_id", "document_id") + assert res["code"] == expected_code + assert res["message"] == expected_message + + +class TestAddChunk: + @pytest.mark.p1 + @pytest.mark.parametrize( + "payload, expected_code, expected_message", + [ + ({"content": None}, 100, """TypeError("unsupported operand type(s) for +: \'NoneType\' and \'str\'")"""), + ({"content": ""}, 102, "`content` is required"), + pytest.param( + {"content": 1}, + 100, + """TypeError("unsupported operand type(s) for +: \'int\' and \'str\'")""", + marks=pytest.mark.skip, + ), + ({"content": "a"}, 0, ""), + ({"content": " "}, 102, "`content` is required"), + ({"content": "\n!?。;!?\"'"}, 0, ""), + ], + ) + def test_content(self, HttpApiAuth, add_document, payload, expected_code, expected_message): + dataset_id, document_id = add_document + res = list_chunks(HttpApiAuth, dataset_id, document_id) + if res["code"] != 0: + assert False, res + chunks_count = res["data"]["doc"]["chunk_count"] + res = add_chunk(HttpApiAuth, dataset_id, document_id, payload) + assert res["code"] == expected_code + if expected_code == 0: + validate_chunk_details(dataset_id, document_id, payload, res) + res = list_chunks(HttpApiAuth, dataset_id, document_id) + if res["code"] != 0: + assert False, res + assert res["data"]["doc"]["chunk_count"] == chunks_count + 1 + else: + assert res["message"] == expected_message + + @pytest.mark.p2 + @pytest.mark.parametrize( + "payload, expected_code, expected_message", + [ + ({"content": "chunk test", "important_keywords": ["a", "b", "c"]}, 0, ""), + ({"content": "chunk test", "important_keywords": [""]}, 0, ""), + ( + {"content": "chunk test", "important_keywords": [1]}, + 100, + "TypeError('sequence item 0: expected str instance, int found')", + ), + ({"content": "chunk test", "important_keywords": ["a", "a"]}, 0, ""), + ({"content": "chunk test", "important_keywords": "abc"}, 102, "`important_keywords` is required to be a list"), + ({"content": "chunk test", "important_keywords": 123}, 102, "`important_keywords` is required to be a list"), + ], + ) + def test_important_keywords(self, HttpApiAuth, add_document, payload, expected_code, expected_message): + dataset_id, document_id = add_document + res = list_chunks(HttpApiAuth, dataset_id, document_id) + if res["code"] != 0: + assert False, res + chunks_count = res["data"]["doc"]["chunk_count"] + res = add_chunk(HttpApiAuth, dataset_id, document_id, payload) + assert res["code"] == 
expected_code + if expected_code == 0: + validate_chunk_details(dataset_id, document_id, payload, res) + res = list_chunks(HttpApiAuth, dataset_id, document_id) + if res["code"] != 0: + assert False, res + assert res["data"]["doc"]["chunk_count"] == chunks_count + 1 + else: + assert res["message"] == expected_message + + @pytest.mark.p2 + @pytest.mark.parametrize( + "payload, expected_code, expected_message", + [ + ({"content": "chunk test", "questions": ["a", "b", "c"]}, 0, ""), + ({"content": "chunk test", "questions": [""]}, 0, ""), + ({"content": "chunk test", "questions": [1]}, 100, "TypeError('sequence item 0: expected str instance, int found')"), + ({"content": "chunk test", "questions": ["a", "a"]}, 0, ""), + ({"content": "chunk test", "questions": "abc"}, 102, "`questions` is required to be a list"), + ({"content": "chunk test", "questions": 123}, 102, "`questions` is required to be a list"), + ], + ) + def test_questions(self, HttpApiAuth, add_document, payload, expected_code, expected_message): + dataset_id, document_id = add_document + res = list_chunks(HttpApiAuth, dataset_id, document_id) + if res["code"] != 0: + assert False, res + chunks_count = res["data"]["doc"]["chunk_count"] + res = add_chunk(HttpApiAuth, dataset_id, document_id, payload) + assert res["code"] == expected_code + if expected_code == 0: + validate_chunk_details(dataset_id, document_id, payload, res) + if res["code"] != 0: + assert False, res + res = list_chunks(HttpApiAuth, dataset_id, document_id) + assert res["data"]["doc"]["chunk_count"] == chunks_count + 1 + else: + assert res["message"] == expected_message + + @pytest.mark.p3 + @pytest.mark.parametrize( + "dataset_id, expected_code, expected_message", + [ + ("", 100, ""), + ( + "invalid_dataset_id", + 102, + "You don't own the dataset invalid_dataset_id.", + ), + ], + ) + def test_invalid_dataset_id( + self, + HttpApiAuth, + add_document, + dataset_id, + expected_code, + expected_message, + ): + _, document_id = add_document + res = add_chunk(HttpApiAuth, dataset_id, document_id, {"content": "a"}) + assert res["code"] == expected_code + assert res["message"] == expected_message + + @pytest.mark.p3 + @pytest.mark.parametrize( + "document_id, expected_code, expected_message", + [ + ("", 100, ""), + ( + "invalid_document_id", + 102, + "You don't own the document invalid_document_id.", + ), + ], + ) + def test_invalid_document_id(self, HttpApiAuth, add_document, document_id, expected_code, expected_message): + dataset_id, _ = add_document + res = add_chunk(HttpApiAuth, dataset_id, document_id, {"content": "chunk test"}) + assert res["code"] == expected_code + assert res["message"] == expected_message + + @pytest.mark.p3 + def test_repeated_add_chunk(self, HttpApiAuth, add_document): + payload = {"content": "chunk test"} + dataset_id, document_id = add_document + res = list_chunks(HttpApiAuth, dataset_id, document_id) + if res["code"] != 0: + assert False, res + chunks_count = res["data"]["doc"]["chunk_count"] + res = add_chunk(HttpApiAuth, dataset_id, document_id, payload) + assert res["code"] == 0 + validate_chunk_details(dataset_id, document_id, payload, res) + res = list_chunks(HttpApiAuth, dataset_id, document_id) + if res["code"] != 0: + assert False, res + assert res["data"]["doc"]["chunk_count"] == chunks_count + 1 + + res = add_chunk(HttpApiAuth, dataset_id, document_id, payload) + assert res["code"] == 0 + validate_chunk_details(dataset_id, document_id, payload, res) + res = list_chunks(HttpApiAuth, dataset_id, document_id) + if res["code"] != 0: 
+ assert False, res + assert res["data"]["doc"]["chunk_count"] == chunks_count + 2 + + @pytest.mark.p2 + def test_add_chunk_to_deleted_document(self, HttpApiAuth, add_document): + dataset_id, document_id = add_document + delete_documents(HttpApiAuth, dataset_id, {"ids": [document_id]}) + res = add_chunk(HttpApiAuth, dataset_id, document_id, {"content": "chunk test"}) + assert res["code"] == 102 + assert res["message"] == f"You don't own the document {document_id}." + + @pytest.mark.skip(reason="issues/6411") + def test_concurrent_add_chunk(self, HttpApiAuth, add_document): + count = 50 + dataset_id, document_id = add_document + res = list_chunks(HttpApiAuth, dataset_id, document_id) + if res["code"] != 0: + assert False, res + chunks_count = res["data"]["doc"]["chunk_count"] + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [ + executor.submit( + add_chunk, + HttpApiAuth, + dataset_id, + document_id, + {"content": f"chunk test {i}"}, + ) + for i in range(count) + ] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + assert all(future.result()["code"] == 0 for future in futures) + res = list_chunks(HttpApiAuth, dataset_id, document_id) + if res["code"] != 0: + assert False, res + assert res["data"]["doc"]["chunk_count"] == chunks_count + count diff --git a/test/testcases/test_http_api/test_chunk_management_within_dataset/test_delete_chunks.py b/test/testcases/test_http_api/test_chunk_management_within_dataset/test_delete_chunks.py new file mode 100644 index 00000000000..69f1744e288 --- /dev/null +++ b/test/testcases/test_http_api/test_chunk_management_within_dataset/test_delete_chunks.py @@ -0,0 +1,196 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from common import batch_add_chunks, delete_chunks, list_chunks +from configs import INVALID_API_TOKEN +from libs.auth import RAGFlowHttpApiAuth + + +@pytest.mark.p1 +class TestAuthorization: + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 0, "`Authorization` can't be empty"), + ( + RAGFlowHttpApiAuth(INVALID_API_TOKEN), + 109, + "Authentication error: API key is invalid!", + ), + ], + ) + def test_invalid_auth(self, invalid_auth, expected_code, expected_message): + res = delete_chunks(invalid_auth, "dataset_id", "document_id") + assert res["code"] == expected_code + assert res["message"] == expected_message + + +class TestChunksDeletion: + @pytest.mark.p3 + @pytest.mark.parametrize( + "dataset_id, expected_code, expected_message", + [ + ("", 100, ""), + ( + "invalid_dataset_id", + 102, + "You don't own the dataset invalid_dataset_id.", + ), + ], + ) + def test_invalid_dataset_id(self, HttpApiAuth, add_chunks_func, dataset_id, expected_code, expected_message): + _, document_id, chunk_ids = add_chunks_func + res = delete_chunks(HttpApiAuth, dataset_id, document_id, {"chunk_ids": chunk_ids}) + assert res["code"] == expected_code + assert res["message"] == expected_message + + @pytest.mark.p3 + @pytest.mark.parametrize( + "document_id, expected_code, expected_message", + [ + ("", 100, ""), + ("invalid_document_id", 100, """LookupError("Can't find the document with ID invalid_document_id!")"""), + ], + ) + def test_invalid_document_id(self, HttpApiAuth, add_chunks_func, document_id, expected_code, expected_message): + dataset_id, _, chunk_ids = add_chunks_func + res = delete_chunks(HttpApiAuth, dataset_id, document_id, {"chunk_ids": chunk_ids}) + assert res["code"] == expected_code + assert res["message"] == expected_message + + @pytest.mark.parametrize( + "payload", + [ + pytest.param(lambda r: {"chunk_ids": ["invalid_id"] + r}, marks=pytest.mark.p3), + pytest.param(lambda r: {"chunk_ids": r[:1] + ["invalid_id"] + r[1:4]}, marks=pytest.mark.p1), + pytest.param(lambda r: {"chunk_ids": r + ["invalid_id"]}, marks=pytest.mark.p3), + ], + ) + def test_delete_partial_invalid_id(self, HttpApiAuth, add_chunks_func, payload): + dataset_id, document_id, chunk_ids = add_chunks_func + if callable(payload): + payload = payload(chunk_ids) + res = delete_chunks(HttpApiAuth, dataset_id, document_id, payload) + assert res["code"] == 102 + assert res["message"] == "rm_chunk deleted chunks 4, expect 5" + + res = list_chunks(HttpApiAuth, dataset_id, document_id) + if res["code"] != 0: + assert False, res + assert len(res["data"]["chunks"]) == 1 + assert res["data"]["total"] == 1 + + @pytest.mark.p3 + def test_repeated_deletion(self, HttpApiAuth, add_chunks_func): + dataset_id, document_id, chunk_ids = add_chunks_func + payload = {"chunk_ids": chunk_ids} + res = delete_chunks(HttpApiAuth, dataset_id, document_id, payload) + assert res["code"] == 0 + + res = delete_chunks(HttpApiAuth, dataset_id, document_id, payload) + assert res["code"] == 102 + assert res["message"] == "rm_chunk deleted chunks 0, expect 4" + + @pytest.mark.p3 + def test_duplicate_deletion(self, HttpApiAuth, add_chunks_func): + dataset_id, document_id, chunk_ids = add_chunks_func + res = delete_chunks(HttpApiAuth, dataset_id, document_id, {"chunk_ids": chunk_ids * 2}) + assert res["code"] == 0 + assert "Duplicate chunk ids" in res["data"]["errors"][0] + assert res["data"]["success_count"] == 4 + + res = 
list_chunks(HttpApiAuth, dataset_id, document_id) + if res["code"] != 0: + assert False, res + assert len(res["data"]["chunks"]) == 1 + assert res["data"]["total"] == 1 + + @pytest.mark.p3 + def test_concurrent_deletion(self, HttpApiAuth, add_document): + count = 100 + dataset_id, document_id = add_document + chunk_ids = batch_add_chunks(HttpApiAuth, dataset_id, document_id, count) + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [ + executor.submit( + delete_chunks, + HttpApiAuth, + dataset_id, + document_id, + {"chunk_ids": chunk_ids[i : i + 1]}, + ) + for i in range(count) + ] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + assert all(future.result()["code"] == 0 for future in futures) + + @pytest.mark.p3 + def test_delete_1k(self, HttpApiAuth, add_document): + chunks_num = 1_000 + dataset_id, document_id = add_document + chunk_ids = batch_add_chunks(HttpApiAuth, dataset_id, document_id, chunks_num) + + # issues/6487 + from time import sleep + + sleep(1) + + res = delete_chunks(HttpApiAuth, dataset_id, document_id, {"chunk_ids": chunk_ids}) + assert res["code"] == 0 + + res = list_chunks(HttpApiAuth, dataset_id, document_id) + if res["code"] != 0: + assert False, res + assert len(res["data"]["chunks"]) == 0 + assert res["data"]["total"] == 0 + + @pytest.mark.parametrize( + "payload, expected_code, expected_message, remaining", + [ + pytest.param(None, 100, """TypeError("argument of type \'NoneType\' is not iterable")""", 5, marks=pytest.mark.skip), + pytest.param({"chunk_ids": ["invalid_id"]}, 102, "rm_chunk deleted chunks 0, expect 1", 5, marks=pytest.mark.p3), + pytest.param("not json", 100, """UnboundLocalError("local variable \'duplicate_messages\' referenced before assignment")""", 5, marks=pytest.mark.skip(reason="pull/6376")), + pytest.param(lambda r: {"chunk_ids": r[:1]}, 0, "", 4, marks=pytest.mark.p3), + pytest.param(lambda r: {"chunk_ids": r}, 0, "", 1, marks=pytest.mark.p1), + pytest.param({"chunk_ids": []}, 0, "", 0, marks=pytest.mark.p3), + ], + ) + def test_basic_scenarios( + self, + HttpApiAuth, + add_chunks_func, + payload, + expected_code, + expected_message, + remaining, + ): + dataset_id, document_id, chunk_ids = add_chunks_func + if callable(payload): + payload = payload(chunk_ids) + res = delete_chunks(HttpApiAuth, dataset_id, document_id, payload) + assert res["code"] == expected_code + if res["code"] != 0: + assert res["message"] == expected_message + + res = list_chunks(HttpApiAuth, dataset_id, document_id) + if res["code"] != 0: + assert False, res + assert len(res["data"]["chunks"]) == remaining + assert res["data"]["total"] == remaining diff --git a/test/testcases/test_http_api/test_chunk_management_within_dataset/test_list_chunks.py b/test/testcases/test_http_api/test_chunk_management_within_dataset/test_list_chunks.py new file mode 100644 index 00000000000..c8134214c70 --- /dev/null +++ b/test/testcases/test_http_api/test_chunk_management_within_dataset/test_list_chunks.py @@ -0,0 +1,210 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from common import batch_add_chunks, list_chunks +from configs import INVALID_API_TOKEN +from libs.auth import RAGFlowHttpApiAuth + + +@pytest.mark.p1 +class TestAuthorization: + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 0, "`Authorization` can't be empty"), + ( + RAGFlowHttpApiAuth(INVALID_API_TOKEN), + 109, + "Authentication error: API key is invalid!", + ), + ], + ) + def test_invalid_auth(self, invalid_auth, expected_code, expected_message): + res = list_chunks(invalid_auth, "dataset_id", "document_id") + assert res["code"] == expected_code + assert res["message"] == expected_message + + +class TestChunksList: + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_code, expected_page_size, expected_message", + [ + ({"page": None, "page_size": 2}, 0, 2, ""), + pytest.param({"page": 0, "page_size": 2}, 100, 0, "ValueError('Search does not support negative slicing.')", marks=pytest.mark.skip), + ({"page": 2, "page_size": 2}, 0, 2, ""), + ({"page": 3, "page_size": 2}, 0, 1, ""), + ({"page": "3", "page_size": 2}, 0, 1, ""), + pytest.param({"page": -1, "page_size": 2}, 100, 0, "ValueError('Search does not support negative slicing.')", marks=pytest.mark.skip), + pytest.param({"page": "a", "page_size": 2}, 100, 0, """ValueError("invalid literal for int() with base 10: \'a\'")""", marks=pytest.mark.skip), + ], + ) + def test_page(self, HttpApiAuth, add_chunks, params, expected_code, expected_page_size, expected_message): + dataset_id, document_id, _ = add_chunks + res = list_chunks(HttpApiAuth, dataset_id, document_id, params=params) + assert res["code"] == expected_code + if expected_code == 0: + assert len(res["data"]["chunks"]) == expected_page_size + else: + assert res["message"] == expected_message + + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_code, expected_page_size, expected_message", + [ + ({"page_size": None}, 0, 5, ""), + pytest.param({"page_size": 0}, 0, 5, "", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="Infinity does not support page_size=0")), + pytest.param({"page_size": 0}, 100, 0, "3013", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="Infinity does not support page_size=0")), + ({"page_size": 1}, 0, 1, ""), + ({"page_size": 6}, 0, 5, ""), + ({"page_size": "1"}, 0, 1, ""), + pytest.param({"page_size": -1}, 0, 5, "", marks=pytest.mark.skip), + pytest.param({"page_size": "a"}, 100, 0, """ValueError("invalid literal for int() with base 10: \'a\'")""", marks=pytest.mark.skip), + ], + ) + def test_page_size(self, HttpApiAuth, add_chunks, params, expected_code, expected_page_size, expected_message): + dataset_id, document_id, _ = add_chunks + res = list_chunks(HttpApiAuth, dataset_id, document_id, params=params) + assert res["code"] == expected_code + if expected_code == 0: + assert len(res["data"]["chunks"]) == expected_page_size + else: + assert res["message"] == expected_message + + @pytest.mark.p2 + @pytest.mark.parametrize( + "params, expected_page_size", + [ + ({"keywords": None}, 5), + ({"keywords": ""}, 5), + ({"keywords": "1"}, 1), + pytest.param({"keywords": "chunk"}, 4, marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6509")), + ({"keywords": "ragflow"}, 1), + ({"keywords": "unknown"}, 
0), + ], + ) + def test_keywords(self, HttpApiAuth, add_chunks, params, expected_page_size): + dataset_id, document_id, _ = add_chunks + res = list_chunks(HttpApiAuth, dataset_id, document_id, params=params) + assert res["code"] == 0 + assert len(res["data"]["chunks"]) == expected_page_size + + @pytest.mark.p1 + @pytest.mark.parametrize( + "chunk_id, expected_code, expected_page_size, expected_message", + [ + (None, 0, 5, ""), + ("", 0, 5, ""), + pytest.param(lambda r: r[0], 0, 1, "", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6499")), + pytest.param("unknown", 100, 0, """AttributeError("\'NoneType\' object has no attribute \'keys\'")""", marks=pytest.mark.skip), + ], + ) + def test_id( + self, + HttpApiAuth, + add_chunks, + chunk_id, + expected_code, + expected_page_size, + expected_message, + ): + dataset_id, document_id, chunk_ids = add_chunks + if callable(chunk_id): + params = {"id": chunk_id(chunk_ids)} + else: + params = {"id": chunk_id} + res = list_chunks(HttpApiAuth, dataset_id, document_id, params=params) + assert res["code"] == expected_code + if expected_code == 0: + if params["id"] in [None, ""]: + assert len(res["data"]["chunks"]) == expected_page_size + else: + assert res["data"]["chunks"][0]["id"] == params["id"] + else: + assert res["message"] == expected_message + + @pytest.mark.p3 + def test_invalid_params(self, HttpApiAuth, add_chunks): + dataset_id, document_id, _ = add_chunks + params = {"a": "b"} + res = list_chunks(HttpApiAuth, dataset_id, document_id, params=params) + assert res["code"] == 0 + assert len(res["data"]["chunks"]) == 5 + + @pytest.mark.p3 + def test_concurrent_list(self, HttpApiAuth, add_chunks): + dataset_id, document_id, _ = add_chunks + count = 100 + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(list_chunks, HttpApiAuth, dataset_id, document_id) for i in range(count)] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + assert all(len(future.result()["data"]["chunks"]) == 5 for future in futures) + + @pytest.mark.p1 + def test_default(self, HttpApiAuth, add_document): + dataset_id, document_id = add_document + + res = list_chunks(HttpApiAuth, dataset_id, document_id) + chunks_count = res["data"]["doc"]["chunk_count"] + batch_add_chunks(HttpApiAuth, dataset_id, document_id, 31) + # issues/6487 + from time import sleep + + sleep(3) + res = list_chunks(HttpApiAuth, dataset_id, document_id) + assert res["code"] == 0 + assert len(res["data"]["chunks"]) == 30 + assert res["data"]["doc"]["chunk_count"] == chunks_count + 31 + + @pytest.mark.p3 + @pytest.mark.parametrize( + "dataset_id, expected_code, expected_message", + [ + ("", 100, ""), + ( + "invalid_dataset_id", + 102, + "You don't own the dataset invalid_dataset_id.", + ), + ], + ) + def test_invalid_dataset_id(self, HttpApiAuth, add_chunks, dataset_id, expected_code, expected_message): + _, document_id, _ = add_chunks + res = list_chunks(HttpApiAuth, dataset_id, document_id) + assert res["code"] == expected_code + assert res["message"] == expected_message + + @pytest.mark.p3 + @pytest.mark.parametrize( + "document_id, expected_code, expected_message", + [ + ("", 102, "The dataset not own the document chunks."), + ( + "invalid_document_id", + 102, + "You don't own the document invalid_document_id.", + ), + ], + ) + def test_invalid_document_id(self, HttpApiAuth, add_chunks, document_id, expected_code, expected_message): + dataset_id, _, _ = add_chunks + res = list_chunks(HttpApiAuth, dataset_id, 
document_id) + assert res["code"] == expected_code + assert res["message"] == expected_message diff --git a/test/testcases/test_http_api/test_chunk_management_within_dataset/test_retrieval_chunks.py b/test/testcases/test_http_api/test_chunk_management_within_dataset/test_retrieval_chunks.py new file mode 100644 index 00000000000..52421d5b16b --- /dev/null +++ b/test/testcases/test_http_api/test_chunk_management_within_dataset/test_retrieval_chunks.py @@ -0,0 +1,312 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from common import retrieval_chunks +from configs import INVALID_API_TOKEN +from libs.auth import RAGFlowHttpApiAuth + + +@pytest.mark.p1 +class TestAuthorization: + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 0, "`Authorization` can't be empty"), + ( + RAGFlowHttpApiAuth(INVALID_API_TOKEN), + 109, + "Authentication error: API key is invalid!", + ), + ], + ) + def test_invalid_auth(self, invalid_auth, expected_code, expected_message): + res = retrieval_chunks(invalid_auth) + assert res["code"] == expected_code + assert res["message"] == expected_message + + +class TestChunksRetrieval: + @pytest.mark.p1 + @pytest.mark.parametrize( + "payload, expected_code, expected_page_size, expected_message", + [ + ({"question": "chunk", "dataset_ids": None}, 0, 4, ""), + ({"question": "chunk", "document_ids": None}, 102, 0, "`dataset_ids` is required."), + ({"question": "chunk", "dataset_ids": None, "document_ids": None}, 0, 4, ""), + ({"question": "chunk"}, 102, 0, "`dataset_ids` is required."), + ], + ) + def test_basic_scenarios(self, HttpApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message): + dataset_id, document_id, _ = add_chunks + if "dataset_ids" in payload: + payload["dataset_ids"] = [dataset_id] + if "document_ids" in payload: + payload["document_ids"] = [document_id] + res = retrieval_chunks(HttpApiAuth, payload) + assert res["code"] == expected_code + if expected_code == 0: + assert len(res["data"]["chunks"]) == expected_page_size + else: + assert res["message"] == expected_message + + @pytest.mark.p2 + @pytest.mark.parametrize( + "payload, expected_code, expected_page_size, expected_message", + [ + pytest.param( + {"page": None, "page_size": 2}, + 100, + 2, + """TypeError("int() argument must be a string, a bytes-like object or a real number, not \'NoneType\'")""", + marks=pytest.mark.skip, + ), + pytest.param( + {"page": 0, "page_size": 2}, + 100, + 0, + "ValueError('Search does not support negative slicing.')", + marks=pytest.mark.skip, + ), + pytest.param({"page": 2, "page_size": 2}, 0, 2, "", marks=pytest.mark.skip(reason="issues/6646")), + ({"page": 3, "page_size": 2}, 0, 0, ""), + ({"page": "3", "page_size": 2}, 0, 0, ""), + pytest.param( + {"page": -1, "page_size": 2}, + 100, + 0, + "ValueError('Search does not support negative slicing.')", + 
marks=pytest.mark.skip, + ), + pytest.param( + {"page": "a", "page_size": 2}, + 100, + 0, + """ValueError("invalid literal for int() with base 10: \'a\'")""", + marks=pytest.mark.skip, + ), + ], + ) + def test_page(self, HttpApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message): + dataset_id, _, _ = add_chunks + payload.update({"question": "chunk", "dataset_ids": [dataset_id]}) + res = retrieval_chunks(HttpApiAuth, payload) + assert res["code"] == expected_code + if expected_code == 0: + assert len(res["data"]["chunks"]) == expected_page_size + else: + assert res["message"] == expected_message + + @pytest.mark.p3 + @pytest.mark.parametrize( + "payload, expected_code, expected_page_size, expected_message", + [ + pytest.param( + {"page_size": None}, + 100, + 0, + """TypeError("int() argument must be a string, a bytes-like object or a real number, not \'NoneType\'")""", + marks=pytest.mark.skip, + ), + # ({"page_size": 0}, 0, 0, ""), + ({"page_size": 1}, 0, 1, ""), + ({"page_size": 5}, 0, 4, ""), + ({"page_size": "1"}, 0, 1, ""), + # ({"page_size": -1}, 0, 0, ""), + pytest.param( + {"page_size": "a"}, + 100, + 0, + """ValueError("invalid literal for int() with base 10: \'a\'")""", + marks=pytest.mark.skip, + ), + ], + ) + def test_page_size(self, HttpApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message): + dataset_id, _, _ = add_chunks + payload.update({"question": "chunk", "dataset_ids": [dataset_id]}) + + res = retrieval_chunks(HttpApiAuth, payload) + assert res["code"] == expected_code + if expected_code == 0: + assert len(res["data"]["chunks"]) == expected_page_size + else: + assert res["message"] == expected_message + + @pytest.mark.p3 + @pytest.mark.parametrize( + "payload, expected_code, expected_page_size, expected_message", + [ + ({"vector_similarity_weight": 0}, 0, 4, ""), + ({"vector_similarity_weight": 0.5}, 0, 4, ""), + ({"vector_similarity_weight": 10}, 0, 4, ""), + pytest.param( + {"vector_similarity_weight": "a"}, + 100, + 0, + """ValueError("could not convert string to float: \'a\'")""", + marks=pytest.mark.skip, + ), + ], + ) + def test_vector_similarity_weight(self, HttpApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message): + dataset_id, _, _ = add_chunks + payload.update({"question": "chunk", "dataset_ids": [dataset_id]}) + res = retrieval_chunks(HttpApiAuth, payload) + assert res["code"] == expected_code + if expected_code == 0: + assert len(res["data"]["chunks"]) == expected_page_size + else: + assert res["message"] == expected_message + + @pytest.mark.p2 + @pytest.mark.parametrize( + "payload, expected_code, expected_page_size, expected_message", + [ + ({"top_k": 10}, 0, 4, ""), + pytest.param( + {"top_k": 1}, + 0, + 4, + "", + marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in ["infinity", "opensearch"], reason="Infinity"), + ), + pytest.param( + {"top_k": 1}, + 0, + 1, + "", + marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="elasticsearch"), + ), + pytest.param( + {"top_k": -1}, + 100, + 4, + "must be greater than 0", + marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in ["infinity", "opensearch"], reason="Infinity"), + ), + pytest.param( + {"top_k": -1}, + 100, + 4, + "3014", + marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="elasticsearch"), + ), + pytest.param( + {"top_k": "a"}, + 100, + 0, + """ValueError("invalid literal for int() with base 10: \'a\'")""", + 
marks=pytest.mark.skip, + ), + ], + ) + def test_top_k(self, HttpApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message): + dataset_id, _, _ = add_chunks + payload.update({"question": "chunk", "dataset_ids": [dataset_id]}) + res = retrieval_chunks(HttpApiAuth, payload) + assert res["code"] == expected_code + if expected_code == 0: + assert len(res["data"]["chunks"]) == expected_page_size + else: + assert expected_message in res["message"] + + @pytest.mark.skip + @pytest.mark.parametrize( + "payload, expected_code, expected_message", + [ + ({"rerank_id": "BAAI/bge-reranker-v2-m3"}, 0, ""), + pytest.param({"rerank_id": "unknown"}, 100, "LookupError('Model(unknown) not authorized')", marks=pytest.mark.skip), + ], + ) + def test_rerank_id(self, HttpApiAuth, add_chunks, payload, expected_code, expected_message): + dataset_id, _, _ = add_chunks + payload.update({"question": "chunk", "dataset_ids": [dataset_id]}) + res = retrieval_chunks(HttpApiAuth, payload) + assert res["code"] == expected_code + if expected_code == 0: + assert len(res["data"]["chunks"]) > 0 + else: + assert expected_message in res["message"] + + @pytest.mark.skip + @pytest.mark.parametrize( + "payload, expected_code, expected_page_size, expected_message", + [ + ({"keyword": True}, 0, 5, ""), + ({"keyword": "True"}, 0, 5, ""), + ({"keyword": False}, 0, 5, ""), + ({"keyword": "False"}, 0, 5, ""), + ({"keyword": None}, 0, 5, ""), + ], + ) + def test_keyword(self, HttpApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message): + dataset_id, _, _ = add_chunks + payload.update({"question": "chunk test", "dataset_ids": [dataset_id]}) + res = retrieval_chunks(HttpApiAuth, payload) + assert res["code"] == expected_code + if expected_code == 0: + assert len(res["data"]["chunks"]) == expected_page_size + else: + assert res["message"] == expected_message + + @pytest.mark.p3 + @pytest.mark.parametrize( + "payload, expected_code, expected_highlight, expected_message", + [ + ({"highlight": True}, 0, True, ""), + ({"highlight": "True"}, 0, True, ""), + pytest.param({"highlight": False}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")), + ({"highlight": "False"}, 0, False, ""), + pytest.param({"highlight": None}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")), + ], + ) + def test_highlight(self, HttpApiAuth, add_chunks, payload, expected_code, expected_highlight, expected_message): + dataset_id, _, _ = add_chunks + payload.update({"question": "chunk", "dataset_ids": [dataset_id]}) + res = retrieval_chunks(HttpApiAuth, payload) + assert res["code"] == expected_code + if expected_highlight: + for chunk in res["data"]["chunks"]: + assert "highlight" in chunk + else: + for chunk in res["data"]["chunks"]: + assert "highlight" not in chunk + + if expected_code != 0: + assert res["message"] == expected_message + + @pytest.mark.p3 + def test_invalid_params(self, HttpApiAuth, add_chunks): + dataset_id, _, _ = add_chunks + payload = {"question": "chunk", "dataset_ids": [dataset_id], "a": "b"} + res = retrieval_chunks(HttpApiAuth, payload) + assert res["code"] == 0 + assert len(res["data"]["chunks"]) == 4 + + @pytest.mark.p3 + def test_concurrent_retrieval(self, HttpApiAuth, add_chunks): + dataset_id, _, _ = add_chunks + count = 100 + payload = {"question": "chunk", "dataset_ids": [dataset_id]} + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(retrieval_chunks, HttpApiAuth, payload) for i in range(count)] + responses = list(as_completed(futures)) 
+ assert len(responses) == count, responses + assert all(future.result()["code"] == 0 for future in futures) diff --git a/test/testcases/test_http_api/test_chunk_management_within_dataset/test_update_chunk.py b/test/testcases/test_http_api/test_chunk_management_within_dataset/test_update_chunk.py new file mode 100644 index 00000000000..d6d0278fee4 --- /dev/null +++ b/test/testcases/test_http_api/test_chunk_management_within_dataset/test_update_chunk.py @@ -0,0 +1,248 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +from concurrent.futures import ThreadPoolExecutor, as_completed +from random import randint + +import pytest +from common import delete_documents, update_chunk +from configs import INVALID_API_TOKEN +from libs.auth import RAGFlowHttpApiAuth + + +@pytest.mark.p1 +class TestAuthorization: + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 0, "`Authorization` can't be empty"), + ( + RAGFlowHttpApiAuth(INVALID_API_TOKEN), + 109, + "Authentication error: API key is invalid!", + ), + ], + ) + def test_invalid_auth(self, invalid_auth, expected_code, expected_message): + res = update_chunk(invalid_auth, "dataset_id", "document_id", "chunk_id") + assert res["code"] == expected_code + assert res["message"] == expected_message + + +class TestUpdatedChunk: + @pytest.mark.p1 + @pytest.mark.parametrize( + "payload, expected_code, expected_message", + [ + ({"content": None}, 100, "TypeError('expected string or bytes-like object')"), + pytest.param( + {"content": ""}, + 100, + """APIRequestFailedError(\'Error code: 400, with error text {"error":{"code":"1213","message":"未正常接收到prompt参数。"}}\')""", + marks=pytest.mark.skip(reason="issues/6541"), + ), + pytest.param( + {"content": 1}, + 100, + "TypeError('expected string or bytes-like object')", + marks=pytest.mark.skip, + ), + ({"content": "update chunk"}, 0, ""), + pytest.param( + {"content": " "}, + 100, + """APIRequestFailedError(\'Error code: 400, with error text {"error":{"code":"1213","message":"未正常接收到prompt参数。"}}\')""", + marks=pytest.mark.skip(reason="issues/6541"), + ), + ({"content": "\n!?。;!?\"'"}, 0, ""), + ], + ) + def test_content(self, HttpApiAuth, add_chunks, payload, expected_code, expected_message): + dataset_id, document_id, chunk_ids = add_chunks + res = update_chunk(HttpApiAuth, dataset_id, document_id, chunk_ids[0], payload) + assert res["code"] == expected_code + if expected_code != 0: + assert res["message"] == expected_message + + @pytest.mark.p2 + @pytest.mark.parametrize( + "payload, expected_code, expected_message", + [ + ({"important_keywords": ["a", "b", "c"]}, 0, ""), + ({"important_keywords": [""]}, 0, ""), + ({"important_keywords": [1]}, 100, "TypeError('sequence item 0: expected str instance, int found')"), + ({"important_keywords": ["a", "a"]}, 0, ""), + ({"important_keywords": "abc"}, 102, "`important_keywords` should be a list"), + ({"important_keywords": 123}, 102, "`important_keywords` 
should be a list"), + ], + ) + def test_important_keywords(self, HttpApiAuth, add_chunks, payload, expected_code, expected_message): + dataset_id, document_id, chunk_ids = add_chunks + res = update_chunk(HttpApiAuth, dataset_id, document_id, chunk_ids[0], payload) + assert res["code"] == expected_code + if expected_code != 0: + assert res["message"] == expected_message + + @pytest.mark.p2 + @pytest.mark.parametrize( + "payload, expected_code, expected_message", + [ + ({"questions": ["a", "b", "c"]}, 0, ""), + ({"questions": [""]}, 0, ""), + ({"questions": [1]}, 100, "TypeError('sequence item 0: expected str instance, int found')"), + ({"questions": ["a", "a"]}, 0, ""), + ({"questions": "abc"}, 102, "`questions` should be a list"), + ({"questions": 123}, 102, "`questions` should be a list"), + ], + ) + def test_questions(self, HttpApiAuth, add_chunks, payload, expected_code, expected_message): + dataset_id, document_id, chunk_ids = add_chunks + res = update_chunk(HttpApiAuth, dataset_id, document_id, chunk_ids[0], payload) + assert res["code"] == expected_code + if expected_code != 0: + assert res["message"] == expected_message + + @pytest.mark.p2 + @pytest.mark.parametrize( + "payload, expected_code, expected_message", + [ + ({"available": True}, 0, ""), + pytest.param({"available": "True"}, 100, """ValueError("invalid literal for int() with base 10: \'True\'")""", marks=pytest.mark.skip), + ({"available": 1}, 0, ""), + ({"available": False}, 0, ""), + pytest.param({"available": "False"}, 100, """ValueError("invalid literal for int() with base 10: \'False\'")""", marks=pytest.mark.skip), + ({"available": 0}, 0, ""), + ], + ) + def test_available( + self, + HttpApiAuth, + add_chunks, + payload, + expected_code, + expected_message, + ): + dataset_id, document_id, chunk_ids = add_chunks + res = update_chunk(HttpApiAuth, dataset_id, document_id, chunk_ids[0], payload) + assert res["code"] == expected_code + if expected_code != 0: + assert res["message"] == expected_message + + @pytest.mark.p3 + @pytest.mark.parametrize( + "dataset_id, expected_code, expected_message", + [ + ("", 100, ""), + pytest.param("invalid_dataset_id", 102, "You don't own the dataset invalid_dataset_id.", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="infinity")), + pytest.param("invalid_dataset_id", 102, "Can't find this chunk", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="elasticsearch")), + ], + ) + def test_invalid_dataset_id(self, HttpApiAuth, add_chunks, dataset_id, expected_code, expected_message): + _, document_id, chunk_ids = add_chunks + res = update_chunk(HttpApiAuth, dataset_id, document_id, chunk_ids[0]) + assert res["code"] == expected_code + assert expected_message in res["message"] + + @pytest.mark.p3 + @pytest.mark.parametrize( + "document_id, expected_code, expected_message", + [ + ("", 100, ""), + ( + "invalid_document_id", + 102, + "You don't own the document invalid_document_id.", + ), + ], + ) + def test_invalid_document_id(self, HttpApiAuth, add_chunks, document_id, expected_code, expected_message): + dataset_id, _, chunk_ids = add_chunks + res = update_chunk(HttpApiAuth, dataset_id, document_id, chunk_ids[0]) + assert res["code"] == expected_code + assert res["message"] == expected_message + + @pytest.mark.p3 + @pytest.mark.parametrize( + "chunk_id, expected_code, expected_message", + [ + ("", 100, ""), + ( + "invalid_document_id", + 102, + "Can't find this chunk invalid_document_id", + ), + ], + ) + def 
test_invalid_chunk_id(self, HttpApiAuth, add_chunks, chunk_id, expected_code, expected_message): + dataset_id, document_id, _ = add_chunks + res = update_chunk(HttpApiAuth, dataset_id, document_id, chunk_id) + assert res["code"] == expected_code + assert res["message"] == expected_message + + @pytest.mark.p3 + def test_repeated_update_chunk(self, HttpApiAuth, add_chunks): + dataset_id, document_id, chunk_ids = add_chunks + res = update_chunk(HttpApiAuth, dataset_id, document_id, chunk_ids[0], {"content": "chunk test 1"}) + assert res["code"] == 0 + + res = update_chunk(HttpApiAuth, dataset_id, document_id, chunk_ids[0], {"content": "chunk test 2"}) + assert res["code"] == 0 + + @pytest.mark.p3 + @pytest.mark.parametrize( + "payload, expected_code, expected_message", + [ + ({"unknown_key": "unknown_value"}, 0, ""), + ({}, 0, ""), + pytest.param(None, 100, """TypeError("argument of type \'NoneType\' is not iterable")""", marks=pytest.mark.skip), + ], + ) + def test_invalid_params(self, HttpApiAuth, add_chunks, payload, expected_code, expected_message): + dataset_id, document_id, chunk_ids = add_chunks + res = update_chunk(HttpApiAuth, dataset_id, document_id, chunk_ids[0], payload) + assert res["code"] == expected_code + if expected_code != 0: + assert res["message"] == expected_message + + @pytest.mark.p3 + @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6554") + def test_concurrent_update_chunk(self, HttpApiAuth, add_chunks): + count = 50 + dataset_id, document_id, chunk_ids = add_chunks + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [ + executor.submit( + update_chunk, + HttpApiAuth, + dataset_id, + document_id, + chunk_ids[randint(0, 3)], + {"content": f"update chunk test {i}"}, + ) + for i in range(count) + ] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + assert all(future.result()["code"] == 0 for future in futures) + + @pytest.mark.p3 + def test_update_chunk_to_deleted_document(self, HttpApiAuth, add_chunks): + dataset_id, document_id, chunk_ids = add_chunks + delete_documents(HttpApiAuth, dataset_id, {"ids": [document_id]}) + res = update_chunk(HttpApiAuth, dataset_id, document_id, chunk_ids[0]) + assert res["code"] == 102 + assert res["message"] == f"Can't find this chunk {chunk_ids[0]}" diff --git a/test/testcases/test_http_api/test_dataset_mangement/conftest.py b/test/testcases/test_http_api/test_dataset_mangement/conftest.py new file mode 100644 index 00000000000..d4ef989ff7a --- /dev/null +++ b/test/testcases/test_http_api/test_dataset_mangement/conftest.py @@ -0,0 +1,39 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + + +import pytest +from common import batch_create_datasets, delete_datasets + + +@pytest.fixture(scope="class") +def add_datasets(HttpApiAuth, request): + def cleanup(): + delete_datasets(HttpApiAuth, {"ids": None}) + + request.addfinalizer(cleanup) + + return batch_create_datasets(HttpApiAuth, 5) + + +@pytest.fixture(scope="function") +def add_datasets_func(HttpApiAuth, request): + def cleanup(): + delete_datasets(HttpApiAuth, {"ids": None}) + + request.addfinalizer(cleanup) + + return batch_create_datasets(HttpApiAuth, 3) diff --git a/test/testcases/test_http_api/test_dataset_mangement/test_create_dataset.py b/test/testcases/test_http_api/test_dataset_mangement/test_create_dataset.py new file mode 100644 index 00000000000..b3b3f9b8abc --- /dev/null +++ b/test/testcases/test_http_api/test_dataset_mangement/test_create_dataset.py @@ -0,0 +1,695 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from common import create_dataset +from configs import DATASET_NAME_LIMIT, INVALID_API_TOKEN +from hypothesis import example, given, settings +from libs.auth import RAGFlowHttpApiAuth +from utils import encode_avatar +from utils.file_utils import create_image_file +from utils.hypothesis_utils import valid_names + + +@pytest.mark.usefixtures("clear_datasets") +class TestAuthorization: + @pytest.mark.p1 + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 0, "`Authorization` can't be empty"), + ( + RAGFlowHttpApiAuth(INVALID_API_TOKEN), + 109, + "Authentication error: API key is invalid!", + ), + ], + ids=["empty_auth", "invalid_api_token"], + ) + def test_auth_invalid(self, invalid_auth, expected_code, expected_message): + res = create_dataset(invalid_auth, {"name": "auth_test"}) + assert res["code"] == expected_code, res + assert res["message"] == expected_message, res + + +class TestRquest: + @pytest.mark.p3 + def test_content_type_bad(self, HttpApiAuth): + BAD_CONTENT_TYPE = "text/xml" + res = create_dataset(HttpApiAuth, {"name": "bad_content_type"}, headers={"Content-Type": BAD_CONTENT_TYPE}) + assert res["code"] == 101, res + assert res["message"] == f"Unsupported content type: Expected application/json, got {BAD_CONTENT_TYPE}", res + + @pytest.mark.p3 + @pytest.mark.parametrize( + "payload, expected_message", + [ + ("a", "Malformed JSON syntax: Missing commas/brackets or invalid encoding"), + ('"a"', "Invalid request payload: expected object, got str"), + ], + ids=["malformed_json_syntax", "invalid_request_payload_type"], + ) + def test_payload_bad(self, HttpApiAuth, payload, expected_message): + res = create_dataset(HttpApiAuth, data=payload) + assert res["code"] == 101, res + assert res["message"] == expected_message, res + + +@pytest.mark.usefixtures("clear_datasets") +class TestCapability: + @pytest.mark.p3 + def test_create_dataset_1k(self, HttpApiAuth): + for i in range(1_000): + payload = {"name": 
f"dataset_{i}"} + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 0, f"Failed to create dataset {i}" + + @pytest.mark.p3 + def test_create_dataset_concurrent(self, HttpApiAuth): + count = 100 + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(create_dataset, HttpApiAuth, {"name": f"dataset_{i}"}) for i in range(count)] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + assert all(future.result()["code"] == 0 for future in futures) + + +@pytest.mark.usefixtures("clear_datasets") +class TestDatasetCreate: + @pytest.mark.p1 + @given(name=valid_names()) + @example("a" * 128) + @settings(max_examples=20) + def test_name(self, HttpApiAuth, name): + res = create_dataset(HttpApiAuth, {"name": name}) + assert res["code"] == 0, res + assert res["data"]["name"] == name, res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "name, expected_message", + [ + ("", "String should have at least 1 character"), + (" ", "String should have at least 1 character"), + ("a" * (DATASET_NAME_LIMIT + 1), "String should have at most 128 characters"), + (0, "Input should be a valid string"), + (None, "Input should be a valid string"), + ], + ids=["empty_name", "space_name", "too_long_name", "invalid_name", "None_name"], + ) + def test_name_invalid(self, HttpApiAuth, name, expected_message): + payload = {"name": name} + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 101, res + assert expected_message in res["message"], res + + @pytest.mark.p3 + def test_name_duplicated(self, HttpApiAuth): + name = "duplicated_name" + payload = {"name": name} + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 0, res + + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 103, res + assert res["message"] == f"Dataset name '{name}' already exists", res + + @pytest.mark.p3 + def test_name_case_insensitive(self, HttpApiAuth): + name = "CaseInsensitive" + payload = {"name": name.upper()} + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 0, res + + payload = {"name": name.lower()} + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 103, res + assert res["message"] == f"Dataset name '{name.lower()}' already exists", res + + @pytest.mark.p2 + def test_avatar(self, HttpApiAuth, tmp_path): + fn = create_image_file(tmp_path / "ragflow_test.png") + payload = { + "name": "avatar", + "avatar": f"data:image/png;base64,{encode_avatar(fn)}", + } + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 0, res + + @pytest.mark.p2 + def test_avatar_exceeds_limit_length(self, HttpApiAuth): + payload = {"name": "avatar_exceeds_limit_length", "avatar": "a" * 65536} + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 101, res + assert "String should have at most 65535 characters" in res["message"], res + + @pytest.mark.p3 + @pytest.mark.parametrize( + "name, prefix, expected_message", + [ + ("empty_prefix", "", "Missing MIME prefix. Expected format: data:;base64,"), + ("missing_comma", "data:image/png;base64", "Missing MIME prefix. Expected format: data:;base64,"), + ("unsupported_mine_type", "invalid_mine_prefix:image/png;base64,", "Invalid MIME prefix format. Must start with 'data:'"), + ("invalid_mine_type", "data:unsupported_mine_type;base64,", "Unsupported MIME type. 
Allowed: ['image/jpeg', 'image/png']"), + ], + ids=["empty_prefix", "missing_comma", "unsupported_mine_type", "invalid_mine_type"], + ) + def test_avatar_invalid_prefix(self, HttpApiAuth, tmp_path, name, prefix, expected_message): + fn = create_image_file(tmp_path / "ragflow_test.png") + payload = { + "name": name, + "avatar": f"{prefix}{encode_avatar(fn)}", + } + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 101, res + assert expected_message in res["message"], res + + @pytest.mark.p3 + def test_avatar_unset(self, HttpApiAuth): + payload = {"name": "avatar_unset"} + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["avatar"] is None, res + + @pytest.mark.p3 + def test_avatar_none(self, HttpApiAuth): + payload = {"name": "avatar_none", "avatar": None} + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["avatar"] is None, res + + @pytest.mark.p2 + def test_description(self, HttpApiAuth): + payload = {"name": "description", "description": "description"} + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["description"] == "description", res + + @pytest.mark.p2 + def test_description_exceeds_limit_length(self, HttpApiAuth): + payload = {"name": "description_exceeds_limit_length", "description": "a" * 65536} + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 101, res + assert "String should have at most 65535 characters" in res["message"], res + + @pytest.mark.p3 + def test_description_unset(self, HttpApiAuth): + payload = {"name": "description_unset"} + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["description"] is None, res + + @pytest.mark.p3 + def test_description_none(self, HttpApiAuth): + payload = {"name": "description_none", "description": None} + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["description"] is None, res + + @pytest.mark.p1 + @pytest.mark.parametrize( + "name, embedding_model", + [ + ("BAAI/bge-large-zh-v1.5@BAAI", "BAAI/bge-large-zh-v1.5@BAAI"), + ("maidalun1020/bce-embedding-base_v1@Youdao", "maidalun1020/bce-embedding-base_v1@Youdao"), + ("embedding-3@ZHIPU-AI", "embedding-3@ZHIPU-AI"), + ], + ids=["builtin_baai", "builtin_youdao", "tenant_zhipu"], + ) + def test_embedding_model(self, HttpApiAuth, name, embedding_model): + payload = {"name": name, "embedding_model": embedding_model} + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["embedding_model"] == embedding_model, res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "name, embedding_model", + [ + ("unknown_llm_name", "unknown@ZHIPU-AI"), + ("unknown_llm_factory", "embedding-3@unknown"), + ("tenant_no_auth_default_tenant_llm", "text-embedding-v3@Tongyi-Qianwen"), + ("tenant_no_auth", "text-embedding-3-small@OpenAI"), + ], + ids=["unknown_llm_name", "unknown_llm_factory", "tenant_no_auth_default_tenant_llm", "tenant_no_auth"], + ) + def test_embedding_model_invalid(self, HttpApiAuth, name, embedding_model): + payload = {"name": name, "embedding_model": embedding_model} + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 101, res + if "tenant_no_auth" in name: + assert res["message"] == f"Unauthorized model: <{embedding_model}>", res + else: + assert res["message"] == f"Unsupported model: <{embedding_model}>", res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "name, 
embedding_model", + [ + ("missing_at", "BAAI/bge-large-zh-v1.5BAAI"), + ("missing_model_name", "@BAAI"), + ("missing_provider", "BAAI/bge-large-zh-v1.5@"), + ("whitespace_only_model_name", " @BAAI"), + ("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "), + ], + ids=["missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"], + ) + def test_embedding_model_format(self, HttpApiAuth, name, embedding_model): + payload = {"name": name, "embedding_model": embedding_model} + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 101, res + if name == "missing_at": + assert "Embedding model identifier must follow @ format" in res["message"], res + else: + assert "Both model_name and provider must be non-empty strings" in res["message"], res + + @pytest.mark.p2 + def test_embedding_model_unset(self, HttpApiAuth): + payload = {"name": "embedding_model_unset"} + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["embedding_model"] == "BAAI/bge-large-zh-v1.5@BAAI", res + + @pytest.mark.p2 + def test_embedding_model_none(self, HttpApiAuth): + payload = {"name": "embedding_model_none", "embedding_model": None} + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 101, res + assert "Input should be a valid string" in res["message"], res + + @pytest.mark.p1 + @pytest.mark.parametrize( + "name, permission", + [ + ("me", "me"), + ("team", "team"), + ("me_upercase", "ME"), + ("team_upercase", "TEAM"), + ("whitespace", " ME "), + ], + ids=["me", "team", "me_upercase", "team_upercase", "whitespace"], + ) + def test_permission(self, HttpApiAuth, name, permission): + payload = {"name": name, "permission": permission} + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["permission"] == permission.lower().strip(), res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "name, permission", + [ + ("empty", ""), + ("unknown", "unknown"), + ("type_error", list()), + ], + ids=["empty", "unknown", "type_error"], + ) + def test_permission_invalid(self, HttpApiAuth, name, permission): + payload = {"name": name, "permission": permission} + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 101 + assert "Input should be 'me' or 'team'" in res["message"] + + @pytest.mark.p2 + def test_permission_unset(self, HttpApiAuth): + payload = {"name": "permission_unset"} + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["permission"] == "me", res + + @pytest.mark.p3 + def test_permission_none(self, HttpApiAuth): + payload = {"name": "permission_none", "permission": None} + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 101, res + assert "Input should be 'me' or 'team'" in res["message"], res + + @pytest.mark.p1 + @pytest.mark.parametrize( + "name, chunk_method", + [ + ("naive", "naive"), + ("book", "book"), + ("email", "email"), + ("laws", "laws"), + ("manual", "manual"), + ("one", "one"), + ("paper", "paper"), + ("picture", "picture"), + ("presentation", "presentation"), + ("qa", "qa"), + ("table", "table"), + ("tag", "tag"), + ], + ids=["naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"], + ) + def test_chunk_method(self, HttpApiAuth, name, chunk_method): + payload = {"name": name, "chunk_method": chunk_method} + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["chunk_method"] == 
chunk_method, res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "name, chunk_method", + [ + ("empty", ""), + ("unknown", "unknown"), + ("type_error", list()), + ], + ids=["empty", "unknown", "type_error"], + ) + def test_chunk_method_invalid(self, HttpApiAuth, name, chunk_method): + payload = {"name": name, "chunk_method": chunk_method} + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 101, res + assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res + + @pytest.mark.p2 + def test_chunk_method_unset(self, HttpApiAuth): + payload = {"name": "chunk_method_unset"} + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["chunk_method"] == "naive", res + + @pytest.mark.p3 + def test_chunk_method_none(self, HttpApiAuth): + payload = {"name": "chunk_method_none", "chunk_method": None} + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 101, res + assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res + + @pytest.mark.p1 + @pytest.mark.parametrize( + "name, parser_config", + [ + ("auto_keywords_min", {"auto_keywords": 0}), + ("auto_keywords_mid", {"auto_keywords": 16}), + ("auto_keywords_max", {"auto_keywords": 32}), + ("auto_questions_min", {"auto_questions": 0}), + ("auto_questions_mid", {"auto_questions": 5}), + ("auto_questions_max", {"auto_questions": 10}), + ("chunk_token_num_min", {"chunk_token_num": 1}), + ("chunk_token_num_mid", {"chunk_token_num": 1024}), + ("chunk_token_num_max", {"chunk_token_num": 2048}), + ("delimiter", {"delimiter": "\n"}), + ("delimiter_space", {"delimiter": " "}), + ("html4excel_true", {"html4excel": True}), + ("html4excel_false", {"html4excel": False}), + ("layout_recognize_DeepDOC", {"layout_recognize": "DeepDOC"}), + ("layout_recognize_navie", {"layout_recognize": "Plain Text"}), + ("tag_kb_ids", {"tag_kb_ids": ["1", "2"]}), + ("topn_tags_min", {"topn_tags": 1}), + ("topn_tags_mid", {"topn_tags": 5}), + ("topn_tags_max", {"topn_tags": 10}), + ("filename_embd_weight_min", {"filename_embd_weight": 0.1}), + ("filename_embd_weight_mid", {"filename_embd_weight": 0.5}), + ("filename_embd_weight_max", {"filename_embd_weight": 1.0}), + ("task_page_size_min", {"task_page_size": 1}), + ("task_page_size_None", {"task_page_size": None}), + ("pages", {"pages": [[1, 100]]}), + ("pages_none", {"pages": None}), + ("graphrag_true", {"graphrag": {"use_graphrag": True}}), + ("graphrag_false", {"graphrag": {"use_graphrag": False}}), + ("graphrag_entity_types", {"graphrag": {"entity_types": ["age", "sex", "height", "weight"]}}), + ("graphrag_method_general", {"graphrag": {"method": "general"}}), + ("graphrag_method_light", {"graphrag": {"method": "light"}}), + ("graphrag_community_true", {"graphrag": {"community": True}}), + ("graphrag_community_false", {"graphrag": {"community": False}}), + ("graphrag_resolution_true", {"graphrag": {"resolution": True}}), + ("graphrag_resolution_false", {"graphrag": {"resolution": False}}), + ("raptor_true", {"raptor": {"use_raptor": True}}), + ("raptor_false", {"raptor": {"use_raptor": False}}), + ("raptor_prompt", {"raptor": {"prompt": "Who are you?"}}), + ("raptor_max_token_min", {"raptor": {"max_token": 1}}), + ("raptor_max_token_mid", {"raptor": {"max_token": 1024}}), + ("raptor_max_token_max", {"raptor": {"max_token": 2048}}), + ("raptor_threshold_min", 
{"raptor": {"threshold": 0.0}}), + ("raptor_threshold_mid", {"raptor": {"threshold": 0.5}}), + ("raptor_threshold_max", {"raptor": {"threshold": 1.0}}), + ("raptor_max_cluster_min", {"raptor": {"max_cluster": 1}}), + ("raptor_max_cluster_mid", {"raptor": {"max_cluster": 512}}), + ("raptor_max_cluster_max", {"raptor": {"max_cluster": 1024}}), + ("raptor_random_seed_min", {"raptor": {"random_seed": 0}}), + ], + ids=[ + "auto_keywords_min", + "auto_keywords_mid", + "auto_keywords_max", + "auto_questions_min", + "auto_questions_mid", + "auto_questions_max", + "chunk_token_num_min", + "chunk_token_num_mid", + "chunk_token_num_max", + "delimiter", + "delimiter_space", + "html4excel_true", + "html4excel_false", + "layout_recognize_DeepDOC", + "layout_recognize_navie", + "tag_kb_ids", + "topn_tags_min", + "topn_tags_mid", + "topn_tags_max", + "filename_embd_weight_min", + "filename_embd_weight_mid", + "filename_embd_weight_max", + "task_page_size_min", + "task_page_size_None", + "pages", + "pages_none", + "graphrag_true", + "graphrag_false", + "graphrag_entity_types", + "graphrag_method_general", + "graphrag_method_light", + "graphrag_community_true", + "graphrag_community_false", + "graphrag_resolution_true", + "graphrag_resolution_false", + "raptor_true", + "raptor_false", + "raptor_prompt", + "raptor_max_token_min", + "raptor_max_token_mid", + "raptor_max_token_max", + "raptor_threshold_min", + "raptor_threshold_mid", + "raptor_threshold_max", + "raptor_max_cluster_min", + "raptor_max_cluster_mid", + "raptor_max_cluster_max", + "raptor_random_seed_min", + ], + ) + def test_parser_config(self, HttpApiAuth, name, parser_config): + payload = {"name": name, "parser_config": parser_config} + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 0, res + for k, v in parser_config.items(): + if isinstance(v, dict): + for kk, vv in v.items(): + assert res["data"]["parser_config"][k][kk] == vv, res + else: + assert res["data"]["parser_config"][k] == v, res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "name, parser_config, expected_message", + [ + ("auto_keywords_min_limit", {"auto_keywords": -1}, "Input should be greater than or equal to 0"), + ("auto_keywords_max_limit", {"auto_keywords": 33}, "Input should be less than or equal to 32"), + ("auto_keywords_float_not_allowed", {"auto_keywords": 3.14}, "Input should be a valid integer, got a number with a fractional part"), + ("auto_keywords_type_invalid", {"auto_keywords": "string"}, "Input should be a valid integer, unable to parse string as an integer"), + ("auto_questions_min_limit", {"auto_questions": -1}, "Input should be greater than or equal to 0"), + ("auto_questions_max_limit", {"auto_questions": 11}, "Input should be less than or equal to 10"), + ("auto_questions_float_not_allowed", {"auto_questions": 3.14}, "Input should be a valid integer, got a number with a fractional part"), + ("auto_questions_type_invalid", {"auto_questions": "string"}, "Input should be a valid integer, unable to parse string as an integer"), + ("chunk_token_num_min_limit", {"chunk_token_num": 0}, "Input should be greater than or equal to 1"), + ("chunk_token_num_max_limit", {"chunk_token_num": 2049}, "Input should be less than or equal to 2048"), + ("chunk_token_num_float_not_allowed", {"chunk_token_num": 3.14}, "Input should be a valid integer, got a number with a fractional part"), + ("chunk_token_num_type_invalid", {"chunk_token_num": "string"}, "Input should be a valid integer, unable to parse string as an integer"), + ("delimiter_empty", 
{"delimiter": ""}, "String should have at least 1 character"), + ("html4excel_type_invalid", {"html4excel": "string"}, "Input should be a valid boolean, unable to interpret input"), + ("tag_kb_ids_not_list", {"tag_kb_ids": "1,2"}, "Input should be a valid list"), + ("tag_kb_ids_int_in_list", {"tag_kb_ids": [1, 2]}, "Input should be a valid string"), + ("topn_tags_min_limit", {"topn_tags": 0}, "Input should be greater than or equal to 1"), + ("topn_tags_max_limit", {"topn_tags": 11}, "Input should be less than or equal to 10"), + ("topn_tags_float_not_allowed", {"topn_tags": 3.14}, "Input should be a valid integer, got a number with a fractional part"), + ("topn_tags_type_invalid", {"topn_tags": "string"}, "Input should be a valid integer, unable to parse string as an integer"), + ("filename_embd_weight_min_limit", {"filename_embd_weight": -1}, "Input should be greater than or equal to 0"), + ("filename_embd_weight_max_limit", {"filename_embd_weight": 1.1}, "Input should be less than or equal to 1"), + ("filename_embd_weight_type_invalid", {"filename_embd_weight": "string"}, "Input should be a valid number, unable to parse string as a number"), + ("task_page_size_min_limit", {"task_page_size": 0}, "Input should be greater than or equal to 1"), + ("task_page_size_float_not_allowed", {"task_page_size": 3.14}, "Input should be a valid integer, got a number with a fractional part"), + ("task_page_size_type_invalid", {"task_page_size": "string"}, "Input should be a valid integer, unable to parse string as an integer"), + ("pages_not_list", {"pages": "1,2"}, "Input should be a valid list"), + ("pages_not_list_in_list", {"pages": ["1,2"]}, "Input should be a valid list"), + ("pages_not_int_list", {"pages": [["string1", "string2"]]}, "Input should be a valid integer, unable to parse string as an integer"), + ("graphrag_type_invalid", {"graphrag": {"use_graphrag": "string"}}, "Input should be a valid boolean, unable to interpret input"), + ("graphrag_entity_types_not_list", {"graphrag": {"entity_types": "1,2"}}, "Input should be a valid list"), + ("graphrag_entity_types_not_str_in_list", {"graphrag": {"entity_types": [1, 2]}}, "nput should be a valid string"), + ("graphrag_method_unknown", {"graphrag": {"method": "unknown"}}, "Input should be 'light' or 'general'"), + ("graphrag_method_none", {"graphrag": {"method": None}}, "Input should be 'light' or 'general'"), + ("graphrag_community_type_invalid", {"graphrag": {"community": "string"}}, "Input should be a valid boolean, unable to interpret input"), + ("graphrag_resolution_type_invalid", {"graphrag": {"resolution": "string"}}, "Input should be a valid boolean, unable to interpret input"), + ("raptor_type_invalid", {"raptor": {"use_raptor": "string"}}, "Input should be a valid boolean, unable to interpret input"), + ("raptor_prompt_empty", {"raptor": {"prompt": ""}}, "String should have at least 1 character"), + ("raptor_prompt_space", {"raptor": {"prompt": " "}}, "String should have at least 1 character"), + ("raptor_max_token_min_limit", {"raptor": {"max_token": 0}}, "Input should be greater than or equal to 1"), + ("raptor_max_token_max_limit", {"raptor": {"max_token": 2049}}, "Input should be less than or equal to 2048"), + ("raptor_max_token_float_not_allowed", {"raptor": {"max_token": 3.14}}, "Input should be a valid integer, got a number with a fractional part"), + ("raptor_max_token_type_invalid", {"raptor": {"max_token": "string"}}, "Input should be a valid integer, unable to parse string as an integer"), + ("raptor_threshold_min_limit", 
{"raptor": {"threshold": -0.1}}, "Input should be greater than or equal to 0"), + ("raptor_threshold_max_limit", {"raptor": {"threshold": 1.1}}, "Input should be less than or equal to 1"), + ("raptor_threshold_type_invalid", {"raptor": {"threshold": "string"}}, "Input should be a valid number, unable to parse string as a number"), + ("raptor_max_cluster_min_limit", {"raptor": {"max_cluster": 0}}, "Input should be greater than or equal to 1"), + ("raptor_max_cluster_max_limit", {"raptor": {"max_cluster": 1025}}, "Input should be less than or equal to 1024"), + ("raptor_max_cluster_float_not_allowed", {"raptor": {"max_cluster": 3.14}}, "Input should be a valid integer, got a number with a fractional par"), + ("raptor_max_cluster_type_invalid", {"raptor": {"max_cluster": "string"}}, "Input should be a valid integer, unable to parse string as an integer"), + ("raptor_random_seed_min_limit", {"raptor": {"random_seed": -1}}, "Input should be greater than or equal to 0"), + ("raptor_random_seed_float_not_allowed", {"raptor": {"random_seed": 3.14}}, "Input should be a valid integer, got a number with a fractional part"), + ("raptor_random_seed_type_invalid", {"raptor": {"random_seed": "string"}}, "Input should be a valid integer, unable to parse string as an integer"), + ("parser_config_type_invalid", {"delimiter": "a" * 65536}, "Parser config exceeds size limit (max 65,535 characters)"), + ], + ids=[ + "auto_keywords_min_limit", + "auto_keywords_max_limit", + "auto_keywords_float_not_allowed", + "auto_keywords_type_invalid", + "auto_questions_min_limit", + "auto_questions_max_limit", + "auto_questions_float_not_allowed", + "auto_questions_type_invalid", + "chunk_token_num_min_limit", + "chunk_token_num_max_limit", + "chunk_token_num_float_not_allowed", + "chunk_token_num_type_invalid", + "delimiter_empty", + "html4excel_type_invalid", + "tag_kb_ids_not_list", + "tag_kb_ids_int_in_list", + "topn_tags_min_limit", + "topn_tags_max_limit", + "topn_tags_float_not_allowed", + "topn_tags_type_invalid", + "filename_embd_weight_min_limit", + "filename_embd_weight_max_limit", + "filename_embd_weight_type_invalid", + "task_page_size_min_limit", + "task_page_size_float_not_allowed", + "task_page_size_type_invalid", + "pages_not_list", + "pages_not_list_in_list", + "pages_not_int_list", + "graphrag_type_invalid", + "graphrag_entity_types_not_list", + "graphrag_entity_types_not_str_in_list", + "graphrag_method_unknown", + "graphrag_method_none", + "graphrag_community_type_invalid", + "graphrag_resolution_type_invalid", + "raptor_type_invalid", + "raptor_prompt_empty", + "raptor_prompt_space", + "raptor_max_token_min_limit", + "raptor_max_token_max_limit", + "raptor_max_token_float_not_allowed", + "raptor_max_token_type_invalid", + "raptor_threshold_min_limit", + "raptor_threshold_max_limit", + "raptor_threshold_type_invalid", + "raptor_max_cluster_min_limit", + "raptor_max_cluster_max_limit", + "raptor_max_cluster_float_not_allowed", + "raptor_max_cluster_type_invalid", + "raptor_random_seed_min_limit", + "raptor_random_seed_float_not_allowed", + "raptor_random_seed_type_invalid", + "parser_config_type_invalid", + ], + ) + def test_parser_config_invalid(self, HttpApiAuth, name, parser_config, expected_message): + payload = {"name": name, "parser_config": parser_config} + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 101, res + assert expected_message in res["message"], res + + @pytest.mark.p2 + def test_parser_config_empty(self, HttpApiAuth): + payload = {"name": "parser_config_empty", 
"parser_config": {}} + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["parser_config"] == { + "chunk_token_num": 128, + "delimiter": r"\n", + "html4excel": False, + "layout_recognize": "DeepDOC", + "raptor": {"use_raptor": False}, + }, res + + @pytest.mark.p2 + def test_parser_config_unset(self, HttpApiAuth): + payload = {"name": "parser_config_unset"} + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["parser_config"] == { + "chunk_token_num": 128, + "delimiter": r"\n", + "html4excel": False, + "layout_recognize": "DeepDOC", + "raptor": {"use_raptor": False}, + }, res + + @pytest.mark.p3 + def test_parser_config_none(self, HttpApiAuth): + payload = {"name": "parser_config_none", "parser_config": None} + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["parser_config"] == { + "chunk_token_num": 128, + "delimiter": "\\n", + "html4excel": False, + "layout_recognize": "DeepDOC", + "raptor": {"use_raptor": False}, + }, res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "payload", + [ + {"name": "id", "id": "id"}, + {"name": "tenant_id", "tenant_id": "e57c1966f99211efb41e9e45646e0111"}, + {"name": "created_by", "created_by": "created_by"}, + {"name": "create_date", "create_date": "Tue, 11 Mar 2025 13:37:23 GMT"}, + {"name": "create_time", "create_time": 1741671443322}, + {"name": "update_date", "update_date": "Tue, 11 Mar 2025 13:37:23 GMT"}, + {"name": "update_time", "update_time": 1741671443339}, + {"name": "document_count", "document_count": 1}, + {"name": "chunk_count", "chunk_count": 1}, + {"name": "token_num", "token_num": 1}, + {"name": "status", "status": "1"}, + {"name": "pagerank", "pagerank": 50}, + {"name": "unknown_field", "unknown_field": "unknown_field"}, + ], + ) + def test_unsupported_field(self, HttpApiAuth, payload): + res = create_dataset(HttpApiAuth, payload) + assert res["code"] == 101, res + assert "Extra inputs are not permitted" in res["message"], res diff --git a/test/testcases/test_http_api/test_dataset_mangement/test_delete_datasets.py b/test/testcases/test_http_api/test_dataset_mangement/test_delete_datasets.py new file mode 100644 index 00000000000..1bba3fac9eb --- /dev/null +++ b/test/testcases/test_http_api/test_dataset_mangement/test_delete_datasets.py @@ -0,0 +1,220 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import uuid +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from common import ( + batch_create_datasets, + delete_datasets, + list_datasets, +) +from configs import INVALID_API_TOKEN +from libs.auth import RAGFlowHttpApiAuth + + +class TestAuthorization: + @pytest.mark.p1 + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 0, "`Authorization` can't be empty"), + ( + RAGFlowHttpApiAuth(INVALID_API_TOKEN), + 109, + "Authentication error: API key is invalid!", + ), + ], + ) + def test_auth_invalid(self, invalid_auth, expected_code, expected_message): + res = delete_datasets(invalid_auth) + assert res["code"] == expected_code, res + assert res["message"] == expected_message, res + + +class TestRquest: + @pytest.mark.p3 + def test_content_type_bad(self, HttpApiAuth): + BAD_CONTENT_TYPE = "text/xml" + res = delete_datasets(HttpApiAuth, headers={"Content-Type": BAD_CONTENT_TYPE}) + assert res["code"] == 101, res + assert res["message"] == f"Unsupported content type: Expected application/json, got {BAD_CONTENT_TYPE}", res + + @pytest.mark.p3 + @pytest.mark.parametrize( + "payload, expected_message", + [ + ("a", "Malformed JSON syntax: Missing commas/brackets or invalid encoding"), + ('"a"', "Invalid request payload: expected object, got str"), + ], + ids=["malformed_json_syntax", "invalid_request_payload_type"], + ) + def test_payload_bad(self, HttpApiAuth, payload, expected_message): + res = delete_datasets(HttpApiAuth, data=payload) + assert res["code"] == 101, res + assert res["message"] == expected_message, res + + @pytest.mark.p3 + def test_payload_unset(self, HttpApiAuth): + res = delete_datasets(HttpApiAuth, None) + assert res["code"] == 101, res + assert res["message"] == "Malformed JSON syntax: Missing commas/brackets or invalid encoding", res + + +class TestCapability: + @pytest.mark.p3 + def test_delete_dataset_1k(self, HttpApiAuth): + ids = batch_create_datasets(HttpApiAuth, 1_000) + res = delete_datasets(HttpApiAuth, {"ids": ids}) + assert res["code"] == 0, res + + res = list_datasets(HttpApiAuth) + assert len(res["data"]) == 0, res + + @pytest.mark.p3 + def test_concurrent_deletion(self, HttpApiAuth): + count = 1_000 + ids = batch_create_datasets(HttpApiAuth, count) + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(delete_datasets, HttpApiAuth, {"ids": ids[i : i + 1]}) for i in range(count)] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + assert all(future.result()["code"] == 0 for future in futures) + + +class TestDatasetsDelete: + @pytest.mark.p1 + @pytest.mark.parametrize( + "func, expected_code, remaining", + [ + (lambda r: {"ids": r[:1]}, 0, 2), + (lambda r: {"ids": r}, 0, 0), + ], + ids=["single_dataset", "multiple_datasets"], + ) + def test_ids(self, HttpApiAuth, add_datasets_func, func, expected_code, remaining): + dataset_ids = add_datasets_func + if callable(func): + payload = func(dataset_ids) + res = delete_datasets(HttpApiAuth, payload) + assert res["code"] == expected_code, res + + res = list_datasets(HttpApiAuth) + assert len(res["data"]) == remaining, res + + @pytest.mark.p1 + @pytest.mark.usefixtures("add_dataset_func") + def test_ids_empty(self, HttpApiAuth): + payload = {"ids": []} + res = delete_datasets(HttpApiAuth, payload) + assert res["code"] == 0, res + + res = list_datasets(HttpApiAuth) + assert len(res["data"]) == 1, res + + @pytest.mark.p1 + @pytest.mark.usefixtures("add_datasets_func") + def 
test_ids_none(self, HttpApiAuth): + payload = {"ids": None} + res = delete_datasets(HttpApiAuth, payload) + assert res["code"] == 0, res + + res = list_datasets(HttpApiAuth) + assert len(res["data"]) == 0, res + + @pytest.mark.p2 + @pytest.mark.usefixtures("add_dataset_func") + def test_id_not_uuid(self, HttpApiAuth): + payload = {"ids": ["not_uuid"]} + res = delete_datasets(HttpApiAuth, payload) + assert res["code"] == 101, res + assert "Invalid UUID1 format" in res["message"], res + + res = list_datasets(HttpApiAuth) + assert len(res["data"]) == 1, res + + @pytest.mark.p3 + @pytest.mark.usefixtures("add_dataset_func") + def test_id_not_uuid1(self, HttpApiAuth): + payload = {"ids": [uuid.uuid4().hex]} + res = delete_datasets(HttpApiAuth, payload) + assert res["code"] == 101, res + assert "Invalid UUID1 format" in res["message"], res + + @pytest.mark.p2 + @pytest.mark.usefixtures("add_dataset_func") + def test_id_wrong_uuid(self, HttpApiAuth): + payload = {"ids": ["d94a8dc02c9711f0930f7fbc369eab6d"]} + res = delete_datasets(HttpApiAuth, payload) + assert res["code"] == 108, res + assert "lacks permission for dataset" in res["message"], res + + res = list_datasets(HttpApiAuth) + assert len(res["data"]) == 1, res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "func", + [ + lambda r: {"ids": ["d94a8dc02c9711f0930f7fbc369eab6d"] + r}, + lambda r: {"ids": r[:1] + ["d94a8dc02c9711f0930f7fbc369eab6d"] + r[1:3]}, + lambda r: {"ids": r + ["d94a8dc02c9711f0930f7fbc369eab6d"]}, + ], + ) + def test_ids_partial_invalid(self, HttpApiAuth, add_datasets_func, func): + dataset_ids = add_datasets_func + if callable(func): + payload = func(dataset_ids) + res = delete_datasets(HttpApiAuth, payload) + assert res["code"] == 108, res + assert "lacks permission for dataset" in res["message"], res + + res = list_datasets(HttpApiAuth) + assert len(res["data"]) == 3, res + + @pytest.mark.p2 + def test_ids_duplicate(self, HttpApiAuth, add_datasets_func): + dataset_ids = add_datasets_func + payload = {"ids": dataset_ids + dataset_ids} + res = delete_datasets(HttpApiAuth, payload) + assert res["code"] == 101, res + assert "Duplicate ids:" in res["message"], res + + res = list_datasets(HttpApiAuth) + assert len(res["data"]) == 3, res + + @pytest.mark.p2 + def test_repeated_delete(self, HttpApiAuth, add_datasets_func): + dataset_ids = add_datasets_func + payload = {"ids": dataset_ids} + res = delete_datasets(HttpApiAuth, payload) + assert res["code"] == 0, res + + res = delete_datasets(HttpApiAuth, payload) + assert res["code"] == 108, res + assert "lacks permission for dataset" in res["message"], res + + @pytest.mark.p2 + @pytest.mark.usefixtures("add_dataset_func") + def test_field_unsupported(self, HttpApiAuth): + payload = {"unknown_field": "unknown_field"} + res = delete_datasets(HttpApiAuth, payload) + assert res["code"] == 101, res + assert "Extra inputs are not permitted" in res["message"], res + + res = list_datasets(HttpApiAuth) + assert len(res["data"]) == 1, res diff --git a/test/testcases/test_http_api/test_dataset_mangement/test_list_datasets.py b/test/testcases/test_http_api/test_dataset_mangement/test_list_datasets.py new file mode 100644 index 00000000000..9d81491b4b2 --- /dev/null +++ b/test/testcases/test_http_api/test_dataset_mangement/test_list_datasets.py @@ -0,0 +1,342 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import uuid +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from common import list_datasets +from configs import INVALID_API_TOKEN +from libs.auth import RAGFlowHttpApiAuth +from utils import is_sorted + + +class TestAuthorization: + @pytest.mark.p1 + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 0, "`Authorization` can't be empty"), + ( + RAGFlowHttpApiAuth(INVALID_API_TOKEN), + 109, + "Authentication error: API key is invalid!", + ), + ], + ) + def test_auth_invalid(self, invalid_auth, expected_code, expected_message): + res = list_datasets(invalid_auth) + assert res["code"] == expected_code, res + assert res["message"] == expected_message, res + + +class TestCapability: + @pytest.mark.p3 + def test_concurrent_list(self, HttpApiAuth): + count = 100 + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(list_datasets, HttpApiAuth) for i in range(count)] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + assert all(future.result()["code"] == 0 for future in futures) + + +@pytest.mark.usefixtures("add_datasets") +class TestDatasetsList: + @pytest.mark.p1 + def test_params_unset(self, HttpApiAuth): + res = list_datasets(HttpApiAuth, None) + assert res["code"] == 0, res + assert len(res["data"]) == 5, res + + @pytest.mark.p2 + def test_params_empty(self, HttpApiAuth): + res = list_datasets(HttpApiAuth, {}) + assert res["code"] == 0, res + assert len(res["data"]) == 5, res + + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_page_size", + [ + ({"page": 2, "page_size": 2}, 2), + ({"page": 3, "page_size": 2}, 1), + ({"page": 4, "page_size": 2}, 0), + ({"page": "2", "page_size": 2}, 2), + ({"page": 1, "page_size": 10}, 5), + ], + ids=["normal_middle_page", "normal_last_partial_page", "beyond_max_page", "string_page_number", "full_data_single_page"], + ) + def test_page(self, HttpApiAuth, params, expected_page_size): + res = list_datasets(HttpApiAuth, params) + assert res["code"] == 0, res + assert len(res["data"]) == expected_page_size, res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "params, expected_code, expected_message", + [ + ({"page": 0}, 101, "Input should be greater than or equal to 1"), + ({"page": "a"}, 101, "Input should be a valid integer, unable to parse string as an integer"), + ], + ids=["page_0", "page_a"], + ) + def test_page_invalid(self, HttpApiAuth, params, expected_code, expected_message): + res = list_datasets(HttpApiAuth, params=params) + assert res["code"] == expected_code, res + assert expected_message in res["message"], res + + @pytest.mark.p2 + def test_page_none(self, HttpApiAuth): + params = {"page": None} + res = list_datasets(HttpApiAuth, params) + assert res["code"] == 0, res + assert len(res["data"]) == 5, res + + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_page_size", + [ + ({"page_size": 1}, 1), + ({"page_size": 3}, 3), + ({"page_size": 5}, 5), + ({"page_size": 6}, 5), + ({"page_size": "1"}, 1), + ], + ids=["min_valid_page_size", "medium_page_size", 
"page_size_equals_total", "page_size_exceeds_total", "string_type_page_size"], + ) + def test_page_size(self, HttpApiAuth, params, expected_page_size): + res = list_datasets(HttpApiAuth, params) + assert res["code"] == 0, res + assert len(res["data"]) == expected_page_size, res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "params, expected_code, expected_message", + [ + ({"page_size": 0}, 101, "Input should be greater than or equal to 1"), + ({"page_size": "a"}, 101, "Input should be a valid integer, unable to parse string as an integer"), + ], + ) + def test_page_size_invalid(self, HttpApiAuth, params, expected_code, expected_message): + res = list_datasets(HttpApiAuth, params) + assert res["code"] == expected_code, res + assert expected_message in res["message"], res + + @pytest.mark.p2 + def test_page_size_none(self, HttpApiAuth): + params = {"page_size": None} + res = list_datasets(HttpApiAuth, params) + assert res["code"] == 0, res + assert len(res["data"]) == 5, res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "params, assertions", + [ + ({"orderby": "create_time"}, lambda r: (is_sorted(r["data"], "create_time", True))), + ({"orderby": "update_time"}, lambda r: (is_sorted(r["data"], "update_time", True))), + ({"orderby": "CREATE_TIME"}, lambda r: (is_sorted(r["data"], "create_time", True))), + ({"orderby": "UPDATE_TIME"}, lambda r: (is_sorted(r["data"], "update_time", True))), + ({"orderby": " create_time "}, lambda r: (is_sorted(r["data"], "update_time", True))), + ], + ids=["orderby_create_time", "orderby_update_time", "orderby_create_time_upper", "orderby_update_time_upper", "whitespace"], + ) + def test_orderby(self, HttpApiAuth, params, assertions): + res = list_datasets(HttpApiAuth, params) + assert res["code"] == 0, res + if callable(assertions): + assert assertions(res), res + + @pytest.mark.p3 + @pytest.mark.parametrize( + "params", + [ + {"orderby": ""}, + {"orderby": "unknown"}, + ], + ids=["empty", "unknown"], + ) + def test_orderby_invalid(self, HttpApiAuth, params): + res = list_datasets(HttpApiAuth, params) + assert res["code"] == 101, res + assert "Input should be 'create_time' or 'update_time'" in res["message"], res + + @pytest.mark.p3 + def test_orderby_none(self, HttpApiAuth): + params = {"orderby": None} + res = list_datasets(HttpApiAuth, params) + assert res["code"] == 0, res + assert is_sorted(res["data"], "create_time", True), res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "params, assertions", + [ + ({"desc": True}, lambda r: (is_sorted(r["data"], "create_time", True))), + ({"desc": False}, lambda r: (is_sorted(r["data"], "create_time", False))), + ({"desc": "true"}, lambda r: (is_sorted(r["data"], "create_time", True))), + ({"desc": "false"}, lambda r: (is_sorted(r["data"], "create_time", False))), + ({"desc": 1}, lambda r: (is_sorted(r["data"], "create_time", True))), + ({"desc": 0}, lambda r: (is_sorted(r["data"], "create_time", False))), + ({"desc": "yes"}, lambda r: (is_sorted(r["data"], "create_time", True))), + ({"desc": "no"}, lambda r: (is_sorted(r["data"], "create_time", False))), + ({"desc": "y"}, lambda r: (is_sorted(r["data"], "create_time", True))), + ({"desc": "n"}, lambda r: (is_sorted(r["data"], "create_time", False))), + ], + ids=["desc=True", "desc=False", "desc=true", "desc=false", "desc=1", "desc=0", "desc=yes", "desc=no", "desc=y", "desc=n"], + ) + def test_desc(self, HttpApiAuth, params, assertions): + res = list_datasets(HttpApiAuth, params) + assert res["code"] == 0, res + if callable(assertions): + assert 
assertions(res), res + + @pytest.mark.p3 + @pytest.mark.parametrize( + "params", + [ + {"desc": 3.14}, + {"desc": "unknown"}, + ], + ids=["empty", "unknown"], + ) + def test_desc_invalid(self, HttpApiAuth, params): + res = list_datasets(HttpApiAuth, params) + assert res["code"] == 101, res + assert "Input should be a valid boolean, unable to interpret input" in res["message"], res + + @pytest.mark.p3 + def test_desc_none(self, HttpApiAuth): + params = {"desc": None} + res = list_datasets(HttpApiAuth, params) + assert res["code"] == 0, res + assert is_sorted(res["data"], "create_time", True), res + + @pytest.mark.p1 + def test_name(self, HttpApiAuth): + params = {"name": "dataset_1"} + res = list_datasets(HttpApiAuth, params) + assert res["code"] == 0, res + assert len(res["data"]) == 1, res + assert res["data"][0]["name"] == "dataset_1", res + + @pytest.mark.p2 + def test_name_wrong(self, HttpApiAuth): + params = {"name": "wrong name"} + res = list_datasets(HttpApiAuth, params) + assert res["code"] == 108, res + assert "lacks permission for dataset" in res["message"], res + + @pytest.mark.p2 + def test_name_empty(self, HttpApiAuth): + params = {"name": ""} + res = list_datasets(HttpApiAuth, params) + assert res["code"] == 0, res + assert len(res["data"]) == 5, res + + @pytest.mark.p2 + def test_name_none(self, HttpApiAuth): + params = {"name": None} + res = list_datasets(HttpApiAuth, params) + assert res["code"] == 0, res + assert len(res["data"]) == 5, res + + @pytest.mark.p1 + def test_id(self, HttpApiAuth, add_datasets): + dataset_ids = add_datasets + params = {"id": dataset_ids[0]} + res = list_datasets(HttpApiAuth, params) + assert res["code"] == 0 + assert len(res["data"]) == 1 + assert res["data"][0]["id"] == dataset_ids[0] + + @pytest.mark.p2 + def test_id_not_uuid(self, HttpApiAuth): + params = {"id": "not_uuid"} + res = list_datasets(HttpApiAuth, params) + assert res["code"] == 101, res + assert "Invalid UUID1 format" in res["message"], res + + @pytest.mark.p2 + def test_id_not_uuid1(self, HttpApiAuth): + params = {"id": uuid.uuid4().hex} + res = list_datasets(HttpApiAuth, params) + assert res["code"] == 101, res + assert "Invalid UUID1 format" in res["message"], res + + @pytest.mark.p2 + def test_id_wrong_uuid(self, HttpApiAuth): + params = {"id": "d94a8dc02c9711f0930f7fbc369eab6d"} + res = list_datasets(HttpApiAuth, params) + assert res["code"] == 108, res + assert "lacks permission for dataset" in res["message"], res + + @pytest.mark.p2 + def test_id_empty(self, HttpApiAuth): + params = {"id": ""} + res = list_datasets(HttpApiAuth, params) + assert res["code"] == 101, res + assert "Invalid UUID1 format" in res["message"], res + + @pytest.mark.p2 + def test_id_none(self, HttpApiAuth): + params = {"id": None} + res = list_datasets(HttpApiAuth, params) + assert res["code"] == 0, res + assert len(res["data"]) == 5, res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "func, name, expected_num", + [ + (lambda r: r[0], "dataset_0", 1), + (lambda r: r[0], "dataset_1", 0), + ], + ids=["name_and_id_match", "name_and_id_mismatch"], + ) + def test_name_and_id(self, HttpApiAuth, add_datasets, func, name, expected_num): + dataset_ids = add_datasets + if callable(func): + params = {"id": func(dataset_ids), "name": name} + res = list_datasets(HttpApiAuth, params) + assert res["code"] == 0, res + assert len(res["data"]) == expected_num, res + + @pytest.mark.p3 + @pytest.mark.parametrize( + "dataset_id, name", + [ + (lambda r: r[0], "wrong_name"), + (uuid.uuid1().hex, "dataset_0"), + ], + 
ids=["name", "id"], + ) + def test_name_and_id_wrong(self, HttpApiAuth, add_datasets, dataset_id, name): + dataset_ids = add_datasets + if callable(dataset_id): + params = {"id": dataset_id(dataset_ids), "name": name} + else: + params = {"id": dataset_id, "name": name} + res = list_datasets(HttpApiAuth, params) + assert res["code"] == 108, res + assert "lacks permission for dataset" in res["message"], res + + @pytest.mark.p2 + def test_field_unsupported(self, HttpApiAuth): + params = {"unknown_field": "unknown_field"} + res = list_datasets(HttpApiAuth, params) + assert res["code"] == 101, res + assert "Extra inputs are not permitted" in res["message"], res diff --git a/test/testcases/test_http_api/test_dataset_mangement/test_update_dataset.py b/test/testcases/test_http_api/test_dataset_mangement/test_update_dataset.py new file mode 100644 index 00000000000..15278800086 --- /dev/null +++ b/test/testcases/test_http_api/test_dataset_mangement/test_update_dataset.py @@ -0,0 +1,848 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import uuid +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from common import list_datasets, update_dataset +from configs import DATASET_NAME_LIMIT, INVALID_API_TOKEN +from hypothesis import HealthCheck, example, given, settings +from libs.auth import RAGFlowHttpApiAuth +from utils import encode_avatar +from utils.file_utils import create_image_file +from utils.hypothesis_utils import valid_names + +# TODO: Missing scenario for updating embedding_model with chunk_count != 0 + + +class TestAuthorization: + @pytest.mark.p1 + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 0, "`Authorization` can't be empty"), + ( + RAGFlowHttpApiAuth(INVALID_API_TOKEN), + 109, + "Authentication error: API key is invalid!", + ), + ], + ids=["empty_auth", "invalid_api_token"], + ) + def test_auth_invalid(self, invalid_auth, expected_code, expected_message): + res = update_dataset(invalid_auth, "dataset_id") + assert res["code"] == expected_code, res + assert res["message"] == expected_message, res + + +class TestRquest: + @pytest.mark.p3 + def test_bad_content_type(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + BAD_CONTENT_TYPE = "text/xml" + res = update_dataset(HttpApiAuth, dataset_id, {"name": "bad_content_type"}, headers={"Content-Type": BAD_CONTENT_TYPE}) + assert res["code"] == 101, res + assert res["message"] == f"Unsupported content type: Expected application/json, got {BAD_CONTENT_TYPE}", res + + @pytest.mark.p3 + @pytest.mark.parametrize( + "payload, expected_message", + [ + ("a", "Malformed JSON syntax: Missing commas/brackets or invalid encoding"), + ('"a"', "Invalid request payload: expected object, got str"), + ], + ids=["malformed_json_syntax", "invalid_request_payload_type"], + ) + def test_payload_bad(self, HttpApiAuth, add_dataset_func, payload, expected_message): + dataset_id = 
add_dataset_func + res = update_dataset(HttpApiAuth, dataset_id, data=payload) + assert res["code"] == 101, res + assert res["message"] == expected_message, res + + @pytest.mark.p2 + def test_payload_empty(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + res = update_dataset(HttpApiAuth, dataset_id, {}) + assert res["code"] == 101, res + assert res["message"] == "No properties were modified", res + + @pytest.mark.p3 + def test_payload_unset(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + res = update_dataset(HttpApiAuth, dataset_id, None) + assert res["code"] == 101, res + assert res["message"] == "Malformed JSON syntax: Missing commas/brackets or invalid encoding", res + + +class TestCapability: + @pytest.mark.p3 + def test_update_dateset_concurrent(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + count = 100 + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(update_dataset, HttpApiAuth, dataset_id, {"name": f"dataset_{i}"}) for i in range(count)] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + assert all(future.result()["code"] == 0 for future in futures) + + +class TestDatasetUpdate: + @pytest.mark.p3 + def test_dataset_id_not_uuid(self, HttpApiAuth): + payload = {"name": "not uuid"} + res = update_dataset(HttpApiAuth, "not_uuid", payload) + assert res["code"] == 101, res + assert "Invalid UUID1 format" in res["message"], res + + @pytest.mark.p3 + def test_dataset_id_not_uuid1(self, HttpApiAuth): + payload = {"name": "not uuid1"} + res = update_dataset(HttpApiAuth, uuid.uuid4().hex, payload) + assert res["code"] == 101, res + assert "Invalid UUID1 format" in res["message"], res + + @pytest.mark.p3 + def test_dataset_id_wrong_uuid(self, HttpApiAuth): + payload = {"name": "wrong uuid"} + res = update_dataset(HttpApiAuth, "d94a8dc02c9711f0930f7fbc369eab6d", payload) + assert res["code"] == 108, res + assert "lacks permission for dataset" in res["message"], res + + @pytest.mark.p1 + @given(name=valid_names()) + @example("a" * 128) + @settings(max_examples=20, suppress_health_check=[HealthCheck.function_scoped_fixture]) + def test_name(self, HttpApiAuth, add_dataset_func, name): + dataset_id = add_dataset_func + payload = {"name": name} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 0, res + + res = list_datasets(HttpApiAuth) + assert res["code"] == 0, res + assert res["data"][0]["name"] == name, res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "name, expected_message", + [ + ("", "String should have at least 1 character"), + (" ", "String should have at least 1 character"), + ("a" * (DATASET_NAME_LIMIT + 1), "String should have at most 128 characters"), + (0, "Input should be a valid string"), + (None, "Input should be a valid string"), + ], + ids=["empty_name", "space_name", "too_long_name", "invalid_name", "None_name"], + ) + def test_name_invalid(self, HttpApiAuth, add_dataset_func, name, expected_message): + dataset_id = add_dataset_func + payload = {"name": name} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 101, res + assert expected_message in res["message"], res + + @pytest.mark.p3 + def test_name_duplicated(self, HttpApiAuth, add_datasets_func): + dataset_id = add_datasets_func[0] + name = "dataset_1" + payload = {"name": name} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 102, res + assert res["message"] == f"Dataset name '{name}' 
already exists", res + + @pytest.mark.p3 + def test_name_case_insensitive(self, HttpApiAuth, add_datasets_func): + dataset_id = add_datasets_func[0] + name = "DATASET_1" + payload = {"name": name} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 102, res + assert res["message"] == f"Dataset name '{name}' already exists", res + + @pytest.mark.p2 + def test_avatar(self, HttpApiAuth, add_dataset_func, tmp_path): + dataset_id = add_dataset_func + fn = create_image_file(tmp_path / "ragflow_test.png") + payload = { + "avatar": f"data:image/png;base64,{encode_avatar(fn)}", + } + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 0, res + + res = list_datasets(HttpApiAuth) + assert res["code"] == 0, res + assert res["data"][0]["avatar"] == f"data:image/png;base64,{encode_avatar(fn)}", res + + @pytest.mark.p2 + def test_avatar_exceeds_limit_length(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + payload = {"avatar": "a" * 65536} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 101, res + assert "String should have at most 65535 characters" in res["message"], res + + @pytest.mark.p3 + @pytest.mark.parametrize( + "avatar_prefix, expected_message", + [ + ("", "Missing MIME prefix. Expected format: data:;base64,"), + ("data:image/png;base64", "Missing MIME prefix. Expected format: data:;base64,"), + ("invalid_mine_prefix:image/png;base64,", "Invalid MIME prefix format. Must start with 'data:'"), + ("data:unsupported_mine_type;base64,", "Unsupported MIME type. Allowed: ['image/jpeg', 'image/png']"), + ], + ids=["empty_prefix", "missing_comma", "unsupported_mine_type", "invalid_mine_type"], + ) + def test_avatar_invalid_prefix(self, HttpApiAuth, add_dataset_func, tmp_path, avatar_prefix, expected_message): + dataset_id = add_dataset_func + fn = create_image_file(tmp_path / "ragflow_test.png") + payload = {"avatar": f"{avatar_prefix}{encode_avatar(fn)}"} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 101, res + assert expected_message in res["message"], res + + @pytest.mark.p3 + def test_avatar_none(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + payload = {"avatar": None} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 0, res + + res = list_datasets(HttpApiAuth) + assert res["code"] == 0, res + assert res["data"][0]["avatar"] is None, res + + @pytest.mark.p2 + def test_description(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + payload = {"description": "description"} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 0 + + res = list_datasets(HttpApiAuth, {"id": dataset_id}) + assert res["code"] == 0, res + assert res["data"][0]["description"] == "description" + + @pytest.mark.p2 + def test_description_exceeds_limit_length(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + payload = {"description": "a" * 65536} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 101, res + assert "String should have at most 65535 characters" in res["message"], res + + @pytest.mark.p3 + def test_description_none(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + payload = {"description": None} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 0, res + + res = list_datasets(HttpApiAuth, {"id": dataset_id}) + assert res["code"] == 0, res + assert 
res["data"][0]["description"] is None + + @pytest.mark.p1 + @pytest.mark.parametrize( + "embedding_model", + [ + "BAAI/bge-large-zh-v1.5@BAAI", + "maidalun1020/bce-embedding-base_v1@Youdao", + "embedding-3@ZHIPU-AI", + ], + ids=["builtin_baai", "builtin_youdao", "tenant_zhipu"], + ) + def test_embedding_model(self, HttpApiAuth, add_dataset_func, embedding_model): + dataset_id = add_dataset_func + payload = {"embedding_model": embedding_model} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 0, res + + res = list_datasets(HttpApiAuth) + assert res["code"] == 0, res + assert res["data"][0]["embedding_model"] == embedding_model, res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "name, embedding_model", + [ + ("unknown_llm_name", "unknown@ZHIPU-AI"), + ("unknown_llm_factory", "embedding-3@unknown"), + ("tenant_no_auth_default_tenant_llm", "text-embedding-v3@Tongyi-Qianwen"), + ("tenant_no_auth", "text-embedding-3-small@OpenAI"), + ], + ids=["unknown_llm_name", "unknown_llm_factory", "tenant_no_auth_default_tenant_llm", "tenant_no_auth"], + ) + def test_embedding_model_invalid(self, HttpApiAuth, add_dataset_func, name, embedding_model): + dataset_id = add_dataset_func + payload = {"name": name, "embedding_model": embedding_model} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 101, res + if "tenant_no_auth" in name: + assert res["message"] == f"Unauthorized model: <{embedding_model}>", res + else: + assert res["message"] == f"Unsupported model: <{embedding_model}>", res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "name, embedding_model", + [ + ("missing_at", "BAAI/bge-large-zh-v1.5BAAI"), + ("missing_model_name", "@BAAI"), + ("missing_provider", "BAAI/bge-large-zh-v1.5@"), + ("whitespace_only_model_name", " @BAAI"), + ("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "), + ], + ids=["missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"], + ) + def test_embedding_model_format(self, HttpApiAuth, add_dataset_func, name, embedding_model): + dataset_id = add_dataset_func + payload = {"name": name, "embedding_model": embedding_model} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 101, res + if name == "missing_at": + assert "Embedding model identifier must follow @ format" in res["message"], res + else: + assert "Both model_name and provider must be non-empty strings" in res["message"], res + + @pytest.mark.p2 + def test_embedding_model_none(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + payload = {"embedding_model": None} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 101, res + assert "Input should be a valid string" in res["message"], res + + @pytest.mark.p1 + @pytest.mark.parametrize( + "permission", + [ + "me", + "team", + "ME", + "TEAM", + " ME ", + ], + ids=["me", "team", "me_upercase", "team_upercase", "whitespace"], + ) + def test_permission(self, HttpApiAuth, add_dataset_func, permission): + dataset_id = add_dataset_func + payload = {"permission": permission} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 0, res + + res = list_datasets(HttpApiAuth) + assert res["code"] == 0, res + assert res["data"][0]["permission"] == permission.lower().strip(), res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "permission", + [ + "", + "unknown", + list(), + ], + ids=["empty", "unknown", "type_error"], + ) + def test_permission_invalid(self, 
HttpApiAuth, add_dataset_func, permission): + dataset_id = add_dataset_func + payload = {"permission": permission} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 101 + assert "Input should be 'me' or 'team'" in res["message"] + + @pytest.mark.p3 + def test_permission_none(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + payload = {"permission": None} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 101, res + assert "Input should be 'me' or 'team'" in res["message"], res + + @pytest.mark.p1 + @pytest.mark.parametrize( + "chunk_method", + [ + "naive", + "book", + "email", + "laws", + "manual", + "one", + "paper", + "picture", + "presentation", + "qa", + "table", + "tag", + ], + ids=["naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"], + ) + def test_chunk_method(self, HttpApiAuth, add_dataset_func, chunk_method): + dataset_id = add_dataset_func + payload = {"chunk_method": chunk_method} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 0, res + + res = list_datasets(HttpApiAuth) + assert res["code"] == 0, res + assert res["data"][0]["chunk_method"] == chunk_method, res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "chunk_method", + [ + "", + "unknown", + list(), + ], + ids=["empty", "unknown", "type_error"], + ) + def test_chunk_method_invalid(self, HttpApiAuth, add_dataset_func, chunk_method): + dataset_id = add_dataset_func + payload = {"chunk_method": chunk_method} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 101, res + assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res + + @pytest.mark.p3 + def test_chunk_method_none(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + payload = {"chunk_method": None} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 101, res + assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res + + @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208") + @pytest.mark.p2 + @pytest.mark.parametrize("pagerank", [0, 50, 100], ids=["min", "mid", "max"]) + def test_pagerank(self, HttpApiAuth, add_dataset_func, pagerank): + dataset_id = add_dataset_func + payload = {"pagerank": pagerank} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 0 + + res = list_datasets(HttpApiAuth, {"id": dataset_id}) + assert res["code"] == 0, res + assert res["data"][0]["pagerank"] == pagerank + + @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208") + @pytest.mark.p2 + def test_pagerank_set_to_0(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + payload = {"pagerank": 50} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 0, res + + res = list_datasets(HttpApiAuth, {"id": dataset_id}) + assert res["code"] == 0, res + assert res["data"][0]["pagerank"] == 50, res + + payload = {"pagerank": 0} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 0 + + res = list_datasets(HttpApiAuth, {"id": dataset_id}) + assert res["code"] == 0, res + assert res["data"][0]["pagerank"] == 0, res + + @pytest.mark.skipif(os.getenv("DOC_ENGINE") != "infinity", reason="#8208") + @pytest.mark.p2 + def 
test_pagerank_infinity(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + payload = {"pagerank": 50} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 101, res + assert res["message"] == "'pagerank' can only be set when doc_engine is elasticsearch", res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "pagerank, expected_message", + [ + (-1, "Input should be greater than or equal to 0"), + (101, "Input should be less than or equal to 100"), + ], + ids=["min_limit", "max_limit"], + ) + def test_pagerank_invalid(self, HttpApiAuth, add_dataset_func, pagerank, expected_message): + dataset_id = add_dataset_func + payload = {"pagerank": pagerank} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 101, res + assert expected_message in res["message"], res + + @pytest.mark.p3 + def test_pagerank_none(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + payload = {"pagerank": None} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 101, res + assert "Input should be a valid integer" in res["message"], res + + @pytest.mark.p1 + @pytest.mark.parametrize( + "parser_config", + [ + {"auto_keywords": 0}, + {"auto_keywords": 16}, + {"auto_keywords": 32}, + {"auto_questions": 0}, + {"auto_questions": 5}, + {"auto_questions": 10}, + {"chunk_token_num": 1}, + {"chunk_token_num": 1024}, + {"chunk_token_num": 2048}, + {"delimiter": "\n"}, + {"delimiter": " "}, + {"html4excel": True}, + {"html4excel": False}, + {"layout_recognize": "DeepDOC"}, + {"layout_recognize": "Plain Text"}, + {"tag_kb_ids": ["1", "2"]}, + {"topn_tags": 1}, + {"topn_tags": 5}, + {"topn_tags": 10}, + {"filename_embd_weight": 0.1}, + {"filename_embd_weight": 0.5}, + {"filename_embd_weight": 1.0}, + {"task_page_size": 1}, + {"task_page_size": None}, + {"pages": [[1, 100]]}, + {"pages": None}, + {"graphrag": {"use_graphrag": True}}, + {"graphrag": {"use_graphrag": False}}, + {"graphrag": {"entity_types": ["age", "sex", "height", "weight"]}}, + {"graphrag": {"method": "general"}}, + {"graphrag": {"method": "light"}}, + {"graphrag": {"community": True}}, + {"graphrag": {"community": False}}, + {"graphrag": {"resolution": True}}, + {"graphrag": {"resolution": False}}, + {"raptor": {"use_raptor": True}}, + {"raptor": {"use_raptor": False}}, + {"raptor": {"prompt": "Who are you?"}}, + {"raptor": {"max_token": 1}}, + {"raptor": {"max_token": 1024}}, + {"raptor": {"max_token": 2048}}, + {"raptor": {"threshold": 0.0}}, + {"raptor": {"threshold": 0.5}}, + {"raptor": {"threshold": 1.0}}, + {"raptor": {"max_cluster": 1}}, + {"raptor": {"max_cluster": 512}}, + {"raptor": {"max_cluster": 1024}}, + {"raptor": {"random_seed": 0}}, + ], + ids=[ + "auto_keywords_min", + "auto_keywords_mid", + "auto_keywords_max", + "auto_questions_min", + "auto_questions_mid", + "auto_questions_max", + "chunk_token_num_min", + "chunk_token_num_mid", + "chunk_token_num_max", + "delimiter", + "delimiter_space", + "html4excel_true", + "html4excel_false", + "layout_recognize_DeepDOC", + "layout_recognize_navie", + "tag_kb_ids", + "topn_tags_min", + "topn_tags_mid", + "topn_tags_max", + "filename_embd_weight_min", + "filename_embd_weight_mid", + "filename_embd_weight_max", + "task_page_size_min", + "task_page_size_None", + "pages", + "pages_none", + "graphrag_true", + "graphrag_false", + "graphrag_entity_types", + "graphrag_method_general", + "graphrag_method_light", + "graphrag_community_true", + "graphrag_community_false", + 
"graphrag_resolution_true", + "graphrag_resolution_false", + "raptor_true", + "raptor_false", + "raptor_prompt", + "raptor_max_token_min", + "raptor_max_token_mid", + "raptor_max_token_max", + "raptor_threshold_min", + "raptor_threshold_mid", + "raptor_threshold_max", + "raptor_max_cluster_min", + "raptor_max_cluster_mid", + "raptor_max_cluster_max", + "raptor_random_seed_min", + ], + ) + def test_parser_config(self, HttpApiAuth, add_dataset_func, parser_config): + dataset_id = add_dataset_func + payload = {"parser_config": parser_config} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 0, res + + res = list_datasets(HttpApiAuth) + assert res["code"] == 0, res + for k, v in parser_config.items(): + if isinstance(v, dict): + for kk, vv in v.items(): + assert res["data"][0]["parser_config"][k][kk] == vv, res + else: + assert res["data"][0]["parser_config"][k] == v, res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "parser_config, expected_message", + [ + ({"auto_keywords": -1}, "Input should be greater than or equal to 0"), + ({"auto_keywords": 33}, "Input should be less than or equal to 32"), + ({"auto_keywords": 3.14}, "Input should be a valid integer, got a number with a fractional part"), + ({"auto_keywords": "string"}, "Input should be a valid integer, unable to parse string as an integer"), + ({"auto_questions": -1}, "Input should be greater than or equal to 0"), + ({"auto_questions": 11}, "Input should be less than or equal to 10"), + ({"auto_questions": 3.14}, "Input should be a valid integer, got a number with a fractional part"), + ({"auto_questions": "string"}, "Input should be a valid integer, unable to parse string as an integer"), + ({"chunk_token_num": 0}, "Input should be greater than or equal to 1"), + ({"chunk_token_num": 2049}, "Input should be less than or equal to 2048"), + ({"chunk_token_num": 3.14}, "Input should be a valid integer, got a number with a fractional part"), + ({"chunk_token_num": "string"}, "Input should be a valid integer, unable to parse string as an integer"), + ({"delimiter": ""}, "String should have at least 1 character"), + ({"html4excel": "string"}, "Input should be a valid boolean, unable to interpret input"), + ({"tag_kb_ids": "1,2"}, "Input should be a valid list"), + ({"tag_kb_ids": [1, 2]}, "Input should be a valid string"), + ({"topn_tags": 0}, "Input should be greater than or equal to 1"), + ({"topn_tags": 11}, "Input should be less than or equal to 10"), + ({"topn_tags": 3.14}, "Input should be a valid integer, got a number with a fractional part"), + ({"topn_tags": "string"}, "Input should be a valid integer, unable to parse string as an integer"), + ({"filename_embd_weight": -1}, "Input should be greater than or equal to 0"), + ({"filename_embd_weight": 1.1}, "Input should be less than or equal to 1"), + ({"filename_embd_weight": "string"}, "Input should be a valid number, unable to parse string as a number"), + ({"task_page_size": 0}, "Input should be greater than or equal to 1"), + ({"task_page_size": 3.14}, "Input should be a valid integer, got a number with a fractional part"), + ({"task_page_size": "string"}, "Input should be a valid integer, unable to parse string as an integer"), + ({"pages": "1,2"}, "Input should be a valid list"), + ({"pages": ["1,2"]}, "Input should be a valid list"), + ({"pages": [["string1", "string2"]]}, "Input should be a valid integer, unable to parse string as an integer"), + ({"graphrag": {"use_graphrag": "string"}}, "Input should be a valid boolean, unable to interpret 
input"), + ({"graphrag": {"entity_types": "1,2"}}, "Input should be a valid list"), + ({"graphrag": {"entity_types": [1, 2]}}, "nput should be a valid string"), + ({"graphrag": {"method": "unknown"}}, "Input should be 'light' or 'general'"), + ({"graphrag": {"method": None}}, "Input should be 'light' or 'general'"), + ({"graphrag": {"community": "string"}}, "Input should be a valid boolean, unable to interpret input"), + ({"graphrag": {"resolution": "string"}}, "Input should be a valid boolean, unable to interpret input"), + ({"raptor": {"use_raptor": "string"}}, "Input should be a valid boolean, unable to interpret input"), + ({"raptor": {"prompt": ""}}, "String should have at least 1 character"), + ({"raptor": {"prompt": " "}}, "String should have at least 1 character"), + ({"raptor": {"max_token": 0}}, "Input should be greater than or equal to 1"), + ({"raptor": {"max_token": 2049}}, "Input should be less than or equal to 2048"), + ({"raptor": {"max_token": 3.14}}, "Input should be a valid integer, got a number with a fractional part"), + ({"raptor": {"max_token": "string"}}, "Input should be a valid integer, unable to parse string as an integer"), + ({"raptor": {"threshold": -0.1}}, "Input should be greater than or equal to 0"), + ({"raptor": {"threshold": 1.1}}, "Input should be less than or equal to 1"), + ({"raptor": {"threshold": "string"}}, "Input should be a valid number, unable to parse string as a number"), + ({"raptor": {"max_cluster": 0}}, "Input should be greater than or equal to 1"), + ({"raptor": {"max_cluster": 1025}}, "Input should be less than or equal to 1024"), + ({"raptor": {"max_cluster": 3.14}}, "Input should be a valid integer, got a number with a fractional par"), + ({"raptor": {"max_cluster": "string"}}, "Input should be a valid integer, unable to parse string as an integer"), + ({"raptor": {"random_seed": -1}}, "Input should be greater than or equal to 0"), + ({"raptor": {"random_seed": 3.14}}, "Input should be a valid integer, got a number with a fractional part"), + ({"raptor": {"random_seed": "string"}}, "Input should be a valid integer, unable to parse string as an integer"), + ({"delimiter": "a" * 65536}, "Parser config exceeds size limit (max 65,535 characters)"), + ], + ids=[ + "auto_keywords_min_limit", + "auto_keywords_max_limit", + "auto_keywords_float_not_allowed", + "auto_keywords_type_invalid", + "auto_questions_min_limit", + "auto_questions_max_limit", + "auto_questions_float_not_allowed", + "auto_questions_type_invalid", + "chunk_token_num_min_limit", + "chunk_token_num_max_limit", + "chunk_token_num_float_not_allowed", + "chunk_token_num_type_invalid", + "delimiter_empty", + "html4excel_type_invalid", + "tag_kb_ids_not_list", + "tag_kb_ids_int_in_list", + "topn_tags_min_limit", + "topn_tags_max_limit", + "topn_tags_float_not_allowed", + "topn_tags_type_invalid", + "filename_embd_weight_min_limit", + "filename_embd_weight_max_limit", + "filename_embd_weight_type_invalid", + "task_page_size_min_limit", + "task_page_size_float_not_allowed", + "task_page_size_type_invalid", + "pages_not_list", + "pages_not_list_in_list", + "pages_not_int_list", + "graphrag_type_invalid", + "graphrag_entity_types_not_list", + "graphrag_entity_types_not_str_in_list", + "graphrag_method_unknown", + "graphrag_method_none", + "graphrag_community_type_invalid", + "graphrag_resolution_type_invalid", + "raptor_type_invalid", + "raptor_prompt_empty", + "raptor_prompt_space", + "raptor_max_token_min_limit", + "raptor_max_token_max_limit", + 
"raptor_max_token_float_not_allowed", + "raptor_max_token_type_invalid", + "raptor_threshold_min_limit", + "raptor_threshold_max_limit", + "raptor_threshold_type_invalid", + "raptor_max_cluster_min_limit", + "raptor_max_cluster_max_limit", + "raptor_max_cluster_float_not_allowed", + "raptor_max_cluster_type_invalid", + "raptor_random_seed_min_limit", + "raptor_random_seed_float_not_allowed", + "raptor_random_seed_type_invalid", + "parser_config_type_invalid", + ], + ) + def test_parser_config_invalid(self, HttpApiAuth, add_dataset_func, parser_config, expected_message): + dataset_id = add_dataset_func + payload = {"parser_config": parser_config} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 101, res + assert expected_message in res["message"], res + + @pytest.mark.p2 + def test_parser_config_empty(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + payload = {"parser_config": {}} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 0, res + + res = list_datasets(HttpApiAuth) + assert res["code"] == 0, res + assert res["data"][0]["parser_config"] == { + "chunk_token_num": 128, + "delimiter": r"\n", + "html4excel": False, + "layout_recognize": "DeepDOC", + "raptor": {"use_raptor": False}, + }, res + + @pytest.mark.p3 + def test_parser_config_none(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + payload = {"parser_config": None} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 0, res + + res = list_datasets(HttpApiAuth, {"id": dataset_id}) + assert res["code"] == 0, res + assert res["data"][0]["parser_config"] == { + "chunk_token_num": 128, + "delimiter": r"\n", + "html4excel": False, + "layout_recognize": "DeepDOC", + "raptor": {"use_raptor": False}, + }, res + + @pytest.mark.p3 + def test_parser_config_empty_with_chunk_method_change(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + payload = {"chunk_method": "qa", "parser_config": {}} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 0, res + + res = list_datasets(HttpApiAuth) + assert res["code"] == 0, res + assert res["data"][0]["parser_config"] == {"raptor": {"use_raptor": False}}, res + + @pytest.mark.p3 + def test_parser_config_unset_with_chunk_method_change(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + payload = {"chunk_method": "qa"} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 0, res + + res = list_datasets(HttpApiAuth) + assert res["code"] == 0, res + assert res["data"][0]["parser_config"] == {"raptor": {"use_raptor": False}}, res + + @pytest.mark.p3 + def test_parser_config_none_with_chunk_method_change(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + payload = {"chunk_method": "qa", "parser_config": None} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 0, res + + res = list_datasets(HttpApiAuth, {"id": dataset_id}) + assert res["code"] == 0, res + assert res["data"][0]["parser_config"] == {"raptor": {"use_raptor": False}}, res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "payload", + [ + {"id": "id"}, + {"tenant_id": "e57c1966f99211efb41e9e45646e0111"}, + {"created_by": "created_by"}, + {"create_date": "Tue, 11 Mar 2025 13:37:23 GMT"}, + {"create_time": 1741671443322}, + {"update_date": "Tue, 11 Mar 2025 13:37:23 GMT"}, + {"update_time": 1741671443339}, + {"document_count": 1}, + {"chunk_count": 1}, + {"token_num": 1}, + 
{"status": "1"}, + {"unknown_field": "unknown_field"}, + ], + ) + def test_field_unsupported(self, HttpApiAuth, add_dataset_func, payload): + dataset_id = add_dataset_func + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 101, res + assert "Extra inputs are not permitted" in res["message"], res + + @pytest.mark.p2 + def test_field_unset(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + res = list_datasets(HttpApiAuth) + assert res["code"] == 0, res + original_data = res["data"][0] + + payload = {"name": "default_unset"} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 0, res + + res = list_datasets(HttpApiAuth) + assert res["code"] == 0, res + assert res["data"][0]["avatar"] == original_data["avatar"], res + assert res["data"][0]["description"] == original_data["description"], res + assert res["data"][0]["embedding_model"] == original_data["embedding_model"], res + assert res["data"][0]["permission"] == original_data["permission"], res + assert res["data"][0]["chunk_method"] == original_data["chunk_method"], res + assert res["data"][0]["pagerank"] == original_data["pagerank"], res + assert res["data"][0]["parser_config"] == original_data["parser_config"], res diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/conftest.py b/test/testcases/test_http_api/test_file_management_within_dataset/conftest.py new file mode 100644 index 00000000000..cd1014382e8 --- /dev/null +++ b/test/testcases/test_http_api/test_file_management_within_dataset/conftest.py @@ -0,0 +1,52 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + + +import pytest +from common import bulk_upload_documents, delete_documents + + +@pytest.fixture(scope="function") +def add_document_func(request, HttpApiAuth, add_dataset, ragflow_tmp_dir): + def cleanup(): + delete_documents(HttpApiAuth, dataset_id, {"ids": None}) + + request.addfinalizer(cleanup) + + dataset_id = add_dataset + return dataset_id, bulk_upload_documents(HttpApiAuth, dataset_id, 1, ragflow_tmp_dir)[0] + + +@pytest.fixture(scope="class") +def add_documents(request, HttpApiAuth, add_dataset, ragflow_tmp_dir): + def cleanup(): + delete_documents(HttpApiAuth, dataset_id, {"ids": None}) + + request.addfinalizer(cleanup) + + dataset_id = add_dataset + return dataset_id, bulk_upload_documents(HttpApiAuth, dataset_id, 5, ragflow_tmp_dir) + + +@pytest.fixture(scope="function") +def add_documents_func(request, HttpApiAuth, add_dataset_func, ragflow_tmp_dir): + def cleanup(): + delete_documents(HttpApiAuth, dataset_id, {"ids": None}) + + request.addfinalizer(cleanup) + + dataset_id = add_dataset_func + return dataset_id, bulk_upload_documents(HttpApiAuth, dataset_id, 3, ragflow_tmp_dir) diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_delete_documents.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_delete_documents.py new file mode 100644 index 00000000000..74f5c060639 --- /dev/null +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_delete_documents.py @@ -0,0 +1,183 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from common import bulk_upload_documents, delete_documents, list_documents +from configs import INVALID_API_TOKEN +from libs.auth import RAGFlowHttpApiAuth + + +@pytest.mark.p1 +class TestAuthorization: + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 0, "`Authorization` can't be empty"), + ( + RAGFlowHttpApiAuth(INVALID_API_TOKEN), + 109, + "Authentication error: API key is invalid!", + ), + ], + ) + def test_invalid_auth(self, invalid_auth, expected_code, expected_message): + res = delete_documents(invalid_auth, "dataset_id") + assert res["code"] == expected_code + assert res["message"] == expected_message + + +class TestDocumentsDeletion: + @pytest.mark.p1 + @pytest.mark.parametrize( + "payload, expected_code, expected_message, remaining", + [ + (None, 0, "", 0), + ({"ids": []}, 0, "", 0), + ({"ids": ["invalid_id"]}, 102, "Documents not found: ['invalid_id']", 3), + ( + {"ids": ["\n!?。;!?\"'"]}, + 102, + """Documents not found: [\'\\n!?。;!?"\\\'\']""", + 3, + ), + ( + "not json", + 100, + "AttributeError(\"'str' object has no attribute 'get'\")", + 3, + ), + (lambda r: {"ids": r[:1]}, 0, "", 2), + (lambda r: {"ids": r}, 0, "", 0), + ], + ) + def test_basic_scenarios( + self, + HttpApiAuth, + add_documents_func, + payload, + expected_code, + expected_message, + remaining, + ): + dataset_id, document_ids = add_documents_func + if callable(payload): + payload = payload(document_ids) + res = delete_documents(HttpApiAuth, dataset_id, payload) + assert res["code"] == expected_code + if res["code"] != 0: + assert res["message"] == expected_message + + res = list_documents(HttpApiAuth, dataset_id) + assert len(res["data"]["docs"]) == remaining + assert res["data"]["total"] == remaining + + @pytest.mark.p3 + @pytest.mark.parametrize( + "dataset_id, expected_code, expected_message", + [ + ("", 100, ""), + ( + "invalid_dataset_id", + 102, + "You don't own the dataset invalid_dataset_id. 
", + ), + ], + ) + def test_invalid_dataset_id(self, HttpApiAuth, add_documents_func, dataset_id, expected_code, expected_message): + _, document_ids = add_documents_func + res = delete_documents(HttpApiAuth, dataset_id, {"ids": document_ids[:1]}) + assert res["code"] == expected_code + assert res["message"] == expected_message + + @pytest.mark.p2 + @pytest.mark.parametrize( + "payload", + [ + lambda r: {"ids": ["invalid_id"] + r}, + lambda r: {"ids": r[:1] + ["invalid_id"] + r[1:3]}, + lambda r: {"ids": r + ["invalid_id"]}, + ], + ) + def test_delete_partial_invalid_id(self, HttpApiAuth, add_documents_func, payload): + dataset_id, document_ids = add_documents_func + if callable(payload): + payload = payload(document_ids) + res = delete_documents(HttpApiAuth, dataset_id, payload) + assert res["code"] == 102 + assert res["message"] == "Documents not found: ['invalid_id']" + + res = list_documents(HttpApiAuth, dataset_id) + assert len(res["data"]["docs"]) == 0 + assert res["data"]["total"] == 0 + + @pytest.mark.p2 + def test_repeated_deletion(self, HttpApiAuth, add_documents_func): + dataset_id, document_ids = add_documents_func + res = delete_documents(HttpApiAuth, dataset_id, {"ids": document_ids}) + assert res["code"] == 0 + + res = delete_documents(HttpApiAuth, dataset_id, {"ids": document_ids}) + assert res["code"] == 102 + assert "Documents not found" in res["message"] + + @pytest.mark.p2 + def test_duplicate_deletion(self, HttpApiAuth, add_documents_func): + dataset_id, document_ids = add_documents_func + res = delete_documents(HttpApiAuth, dataset_id, {"ids": document_ids + document_ids}) + assert res["code"] == 0 + assert "Duplicate document ids" in res["data"]["errors"][0] + assert res["data"]["success_count"] == 3 + + res = list_documents(HttpApiAuth, dataset_id) + assert len(res["data"]["docs"]) == 0 + assert res["data"]["total"] == 0 + + +@pytest.mark.p3 +def test_concurrent_deletion(HttpApiAuth, add_dataset, tmp_path): + count = 100 + dataset_id = add_dataset + document_ids = bulk_upload_documents(HttpApiAuth, dataset_id, count, tmp_path) + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [ + executor.submit( + delete_documents, + HttpApiAuth, + dataset_id, + {"ids": document_ids[i : i + 1]}, + ) + for i in range(count) + ] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + assert all(future.result()["code"] == 0 for future in futures) + + +@pytest.mark.p3 +def test_delete_1k(HttpApiAuth, add_dataset, tmp_path): + documents_num = 1_000 + dataset_id = add_dataset + document_ids = bulk_upload_documents(HttpApiAuth, dataset_id, documents_num, tmp_path) + res = list_documents(HttpApiAuth, dataset_id) + assert res["data"]["total"] == documents_num + + res = delete_documents(HttpApiAuth, dataset_id, {"ids": document_ids}) + assert res["code"] == 0 + + res = list_documents(HttpApiAuth, dataset_id) + assert res["data"]["total"] == 0 diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_download_document.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_download_document.py new file mode 100644 index 00000000000..2d04ae53192 --- /dev/null +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_download_document.py @@ -0,0 +1,179 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import json +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from common import bulk_upload_documents, download_document, upload_documents +from configs import INVALID_API_TOKEN +from libs.auth import RAGFlowHttpApiAuth +from requests import codes +from utils import compare_by_hash + + +@pytest.mark.p1 +class TestAuthorization: + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 0, "`Authorization` can't be empty"), + ( + RAGFlowHttpApiAuth(INVALID_API_TOKEN), + 109, + "Authentication error: API key is invalid!", + ), + ], + ) + def test_invalid_auth(self, invalid_auth, tmp_path, expected_code, expected_message): + res = download_document(invalid_auth, "dataset_id", "document_id", tmp_path / "ragflow_tes.txt") + assert res.status_code == codes.ok + with (tmp_path / "ragflow_tes.txt").open("r") as f: + response_json = json.load(f) + assert response_json["code"] == expected_code + assert response_json["message"] == expected_message + + +@pytest.mark.p1 +@pytest.mark.parametrize( + "generate_test_files", + [ + "docx", + "excel", + "ppt", + "image", + "pdf", + "txt", + "md", + "json", + "eml", + "html", + ], + indirect=True, +) +def test_file_type_validation(HttpApiAuth, add_dataset, generate_test_files, request): + dataset_id = add_dataset + fp = generate_test_files[request.node.callspec.params["generate_test_files"]] + res = upload_documents(HttpApiAuth, dataset_id, [fp]) + document_id = res["data"][0]["id"] + + res = download_document( + HttpApiAuth, + dataset_id, + document_id, + fp.with_stem("ragflow_test_download"), + ) + assert res.status_code == codes.ok + assert compare_by_hash( + fp, + fp.with_stem("ragflow_test_download"), + ) + + +class TestDocumentDownload: + @pytest.mark.p3 + @pytest.mark.parametrize( + "document_id, expected_code, expected_message", + [ + ( + "invalid_document_id", + 102, + "The dataset not own the document invalid_document_id.", + ), + ], + ) + def test_invalid_document_id(self, HttpApiAuth, add_documents, tmp_path, document_id, expected_code, expected_message): + dataset_id, _ = add_documents + res = download_document( + HttpApiAuth, + dataset_id, + document_id, + tmp_path / "ragflow_test_download_1.txt", + ) + assert res.status_code == codes.ok + with (tmp_path / "ragflow_test_download_1.txt").open("r") as f: + response_json = json.load(f) + assert response_json["code"] == expected_code + assert response_json["message"] == expected_message + + @pytest.mark.p3 + @pytest.mark.parametrize( + "dataset_id, expected_code, expected_message", + [ + ("", 100, ""), + ( + "invalid_dataset_id", + 102, + "You do not own the dataset invalid_dataset_id.", + ), + ], + ) + def test_invalid_dataset_id(self, HttpApiAuth, add_documents, tmp_path, dataset_id, expected_code, expected_message): + _, document_ids = add_documents + res = download_document( + HttpApiAuth, + dataset_id, + document_ids[0], + tmp_path / "ragflow_test_download_1.txt", + ) + assert res.status_code == codes.ok + with (tmp_path / "ragflow_test_download_1.txt").open("r") as f: + response_json = json.load(f) + assert 
response_json["code"] == expected_code + assert response_json["message"] == expected_message + + @pytest.mark.p3 + def test_same_file_repeat(self, HttpApiAuth, add_documents, tmp_path, ragflow_tmp_dir): + num = 5 + dataset_id, document_ids = add_documents + for i in range(num): + res = download_document( + HttpApiAuth, + dataset_id, + document_ids[0], + tmp_path / f"ragflow_test_download_{i}.txt", + ) + assert res.status_code == codes.ok + assert compare_by_hash( + ragflow_tmp_dir / "ragflow_test_upload_0.txt", + tmp_path / f"ragflow_test_download_{i}.txt", + ) + + +@pytest.mark.p3 +def test_concurrent_download(HttpApiAuth, add_dataset, tmp_path): + count = 20 + dataset_id = add_dataset + document_ids = bulk_upload_documents(HttpApiAuth, dataset_id, count, tmp_path) + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [ + executor.submit( + download_document, + HttpApiAuth, + dataset_id, + document_ids[i], + tmp_path / f"ragflow_test_download_{i}.txt", + ) + for i in range(count) + ] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + for i in range(count): + assert compare_by_hash( + tmp_path / f"ragflow_test_upload_{i}.txt", + tmp_path / f"ragflow_test_download_{i}.txt", + ) diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_list_documents.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_list_documents.py new file mode 100644 index 00000000000..4fbe59b0bdb --- /dev/null +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_list_documents.py @@ -0,0 +1,360 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from common import list_documents +from configs import INVALID_API_TOKEN +from libs.auth import RAGFlowHttpApiAuth +from utils import is_sorted + + +@pytest.mark.p1 +class TestAuthorization: + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 0, "`Authorization` can't be empty"), + ( + RAGFlowHttpApiAuth(INVALID_API_TOKEN), + 109, + "Authentication error: API key is invalid!", + ), + ], + ) + def test_invalid_auth(self, invalid_auth, expected_code, expected_message): + res = list_documents(invalid_auth, "dataset_id") + assert res["code"] == expected_code + assert res["message"] == expected_message + + +class TestDocumentsList: + @pytest.mark.p1 + def test_default(self, HttpApiAuth, add_documents): + dataset_id, _ = add_documents + res = list_documents(HttpApiAuth, dataset_id) + assert res["code"] == 0 + assert len(res["data"]["docs"]) == 5 + assert res["data"]["total"] == 5 + + @pytest.mark.p3 + @pytest.mark.parametrize( + "dataset_id, expected_code, expected_message", + [ + ("", 100, ""), + ( + "invalid_dataset_id", + 102, + "You don't own the dataset invalid_dataset_id. 
", + ), + ], + ) + def test_invalid_dataset_id(self, HttpApiAuth, dataset_id, expected_code, expected_message): + res = list_documents(HttpApiAuth, dataset_id) + assert res["code"] == expected_code + assert res["message"] == expected_message + + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_code, expected_page_size, expected_message", + [ + ({"page": None, "page_size": 2}, 0, 2, ""), + ({"page": 0, "page_size": 2}, 0, 2, ""), + ({"page": 2, "page_size": 2}, 0, 2, ""), + ({"page": 3, "page_size": 2}, 0, 1, ""), + ({"page": "3", "page_size": 2}, 0, 1, ""), + pytest.param( + {"page": -1, "page_size": 2}, + 100, + 0, + "1064", + marks=pytest.mark.skip(reason="issues/5851"), + ), + pytest.param( + {"page": "a", "page_size": 2}, + 100, + 0, + """ValueError("invalid literal for int() with base 10: \'a\'")""", + marks=pytest.mark.skip(reason="issues/5851"), + ), + ], + ) + def test_page( + self, + HttpApiAuth, + add_documents, + params, + expected_code, + expected_page_size, + expected_message, + ): + dataset_id, _ = add_documents + res = list_documents(HttpApiAuth, dataset_id, params=params) + assert res["code"] == expected_code + if expected_code == 0: + assert len(res["data"]["docs"]) == expected_page_size + assert res["data"]["total"] == 5 + else: + assert res["message"] == expected_message + + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_code, expected_page_size, expected_message", + [ + ({"page_size": None}, 0, 5, ""), + ({"page_size": 0}, 0, 0, ""), + ({"page_size": 1}, 0, 1, ""), + ({"page_size": 6}, 0, 5, ""), + ({"page_size": "1"}, 0, 1, ""), + pytest.param( + {"page_size": -1}, + 100, + 0, + "1064", + marks=pytest.mark.skip(reason="issues/5851"), + ), + pytest.param( + {"page_size": "a"}, + 100, + 0, + """ValueError("invalid literal for int() with base 10: \'a\'")""", + marks=pytest.mark.skip(reason="issues/5851"), + ), + ], + ) + def test_page_size( + self, + HttpApiAuth, + add_documents, + params, + expected_code, + expected_page_size, + expected_message, + ): + dataset_id, _ = add_documents + res = list_documents(HttpApiAuth, dataset_id, params=params) + assert res["code"] == expected_code + if expected_code == 0: + assert len(res["data"]["docs"]) == expected_page_size + else: + assert res["message"] == expected_message + + @pytest.mark.p3 + @pytest.mark.parametrize( + "params, expected_code, assertions, expected_message", + [ + ({"orderby": None}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", True)), ""), + ({"orderby": "create_time"}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", True)), ""), + ({"orderby": "update_time"}, 0, lambda r: (is_sorted(r["data"]["docs"], "update_time", True)), ""), + pytest.param({"orderby": "name", "desc": "False"}, 0, lambda r: (is_sorted(r["data"]["docs"], "name", False)), "", marks=pytest.mark.skip(reason="issues/5851")), + pytest.param({"orderby": "unknown"}, 102, 0, "orderby should be create_time or update_time", marks=pytest.mark.skip(reason="issues/5851")), + ], + ) + def test_orderby( + self, + HttpApiAuth, + add_documents, + params, + expected_code, + assertions, + expected_message, + ): + dataset_id, _ = add_documents + res = list_documents(HttpApiAuth, dataset_id, params=params) + assert res["code"] == expected_code + if expected_code == 0: + if callable(assertions): + assert assertions(res) + else: + assert res["message"] == expected_message + + @pytest.mark.p3 + @pytest.mark.parametrize( + "params, expected_code, assertions, expected_message", + [ + ({"desc": None}, 0, lambda r: 
(is_sorted(r["data"]["docs"], "create_time", True)), ""), + ({"desc": "true"}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", True)), ""), + ({"desc": "True"}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", True)), ""), + ({"desc": True}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", True)), ""), + pytest.param({"desc": "false"}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", False)), "", marks=pytest.mark.skip(reason="issues/5851")), + ({"desc": "False"}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", False)), ""), + ({"desc": False}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", False)), ""), + ({"desc": "False", "orderby": "update_time"}, 0, lambda r: (is_sorted(r["data"]["docs"], "update_time", False)), ""), + pytest.param({"desc": "unknown"}, 102, 0, "desc should be true or false", marks=pytest.mark.skip(reason="issues/5851")), + ], + ) + def test_desc( + self, + HttpApiAuth, + add_documents, + params, + expected_code, + assertions, + expected_message, + ): + dataset_id, _ = add_documents + res = list_documents(HttpApiAuth, dataset_id, params=params) + assert res["code"] == expected_code + if expected_code == 0: + if callable(assertions): + assert assertions(res) + else: + assert res["message"] == expected_message + + @pytest.mark.p2 + @pytest.mark.parametrize( + "params, expected_num", + [ + ({"keywords": None}, 5), + ({"keywords": ""}, 5), + ({"keywords": "0"}, 1), + ({"keywords": "ragflow_test_upload"}, 5), + ({"keywords": "unknown"}, 0), + ], + ) + def test_keywords(self, HttpApiAuth, add_documents, params, expected_num): + dataset_id, _ = add_documents + res = list_documents(HttpApiAuth, dataset_id, params=params) + assert res["code"] == 0 + assert len(res["data"]["docs"]) == expected_num + assert res["data"]["total"] == expected_num + + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_code, expected_num, expected_message", + [ + ({"name": None}, 0, 5, ""), + ({"name": ""}, 0, 5, ""), + ({"name": "ragflow_test_upload_0.txt"}, 0, 1, ""), + ( + {"name": "unknown.txt"}, + 102, + 0, + "You don't own the document unknown.txt.", + ), + ], + ) + def test_name( + self, + HttpApiAuth, + add_documents, + params, + expected_code, + expected_num, + expected_message, + ): + dataset_id, _ = add_documents + res = list_documents(HttpApiAuth, dataset_id, params=params) + assert res["code"] == expected_code + if expected_code == 0: + if params["name"] in [None, ""]: + assert len(res["data"]["docs"]) == expected_num + else: + assert res["data"]["docs"][0]["name"] == params["name"] + else: + assert res["message"] == expected_message + + @pytest.mark.p1 + @pytest.mark.parametrize( + "document_id, expected_code, expected_num, expected_message", + [ + (None, 0, 5, ""), + ("", 0, 5, ""), + (lambda r: r[0], 0, 1, ""), + ("unknown.txt", 102, 0, "You don't own the document unknown.txt."), + ], + ) + def test_id( + self, + HttpApiAuth, + add_documents, + document_id, + expected_code, + expected_num, + expected_message, + ): + dataset_id, document_ids = add_documents + if callable(document_id): + params = {"id": document_id(document_ids)} + else: + params = {"id": document_id} + res = list_documents(HttpApiAuth, dataset_id, params=params) + + assert res["code"] == expected_code + if expected_code == 0: + if params["id"] in [None, ""]: + assert len(res["data"]["docs"]) == expected_num + else: + assert res["data"]["docs"][0]["id"] == params["id"] + else: + assert res["message"] == expected_message + + @pytest.mark.p3 + 
@pytest.mark.parametrize( + "document_id, name, expected_code, expected_num, expected_message", + [ + (lambda r: r[0], "ragflow_test_upload_0.txt", 0, 1, ""), + (lambda r: r[0], "ragflow_test_upload_1.txt", 0, 0, ""), + (lambda r: r[0], "unknown", 102, 0, "You don't own the document unknown."), + ( + "id", + "ragflow_test_upload_0.txt", + 102, + 0, + "You don't own the document id.", + ), + ], + ) + def test_name_and_id( + self, + HttpApiAuth, + add_documents, + document_id, + name, + expected_code, + expected_num, + expected_message, + ): + dataset_id, document_ids = add_documents + if callable(document_id): + params = {"id": document_id(document_ids), "name": name} + else: + params = {"id": document_id, "name": name} + + res = list_documents(HttpApiAuth, dataset_id, params=params) + if expected_code == 0: + assert len(res["data"]["docs"]) == expected_num + else: + assert res["message"] == expected_message + + @pytest.mark.p3 + def test_concurrent_list(self, HttpApiAuth, add_documents): + dataset_id, _ = add_documents + count = 100 + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(list_documents, HttpApiAuth, dataset_id) for i in range(count)] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + assert all(future.result()["code"] == 0 for future in futures) + + @pytest.mark.p3 + def test_invalid_params(self, HttpApiAuth, add_documents): + dataset_id, _ = add_documents + params = {"a": "b"} + res = list_documents(HttpApiAuth, dataset_id, params=params) + assert res["code"] == 0 + assert len(res["data"]["docs"]) == 5 diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_parse_documents.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_parse_documents.py new file mode 100644 index 00000000000..e8ffa914ed3 --- /dev/null +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_parse_documents.py @@ -0,0 +1,219 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from common import bulk_upload_documents, list_documents, parse_documents +from configs import INVALID_API_TOKEN +from libs.auth import RAGFlowHttpApiAuth +from utils import wait_for + + +@wait_for(30, 1, "Document parsing timeout") +def condition(_auth, _dataset_id, _document_ids=None): + res = list_documents(_auth, _dataset_id) + target_docs = res["data"]["docs"] + + if _document_ids is None: + for doc in target_docs: + if doc["run"] != "DONE": + return False + return True + + target_ids = set(_document_ids) + for doc in target_docs: + if doc["id"] in target_ids: + if doc.get("run") != "DONE": + return False + return True + + +def validate_document_details(auth, dataset_id, document_ids): + for document_id in document_ids: + res = list_documents(auth, dataset_id, params={"id": document_id}) + doc = res["data"]["docs"][0] + assert doc["run"] == "DONE" + assert len(doc["process_begin_at"]) > 0 + assert doc["process_duation"] > 0 + assert doc["progress"] > 0 + assert "Task done" in doc["progress_msg"] + + +@pytest.mark.p1 +class TestAuthorization: + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 0, "`Authorization` can't be empty"), + ( + RAGFlowHttpApiAuth(INVALID_API_TOKEN), + 109, + "Authentication error: API key is invalid!", + ), + ], + ) + def test_invalid_auth(self, invalid_auth, expected_code, expected_message): + res = parse_documents(invalid_auth, "dataset_id") + assert res["code"] == expected_code + assert res["message"] == expected_message + + +class TestDocumentsParse: + @pytest.mark.parametrize( + "payload, expected_code, expected_message", + [ + pytest.param(None, 102, """AttributeError("\'NoneType\' object has no attribute \'get\'")""", marks=pytest.mark.skip), + pytest.param({"document_ids": []}, 102, "`document_ids` is required", marks=pytest.mark.p1), + pytest.param({"document_ids": ["invalid_id"]}, 102, "Documents not found: ['invalid_id']", marks=pytest.mark.p3), + pytest.param({"document_ids": ["\n!?。;!?\"'"]}, 102, """Documents not found: [\'\\n!?。;!?"\\\'\']""", marks=pytest.mark.p3), + pytest.param("not json", 102, "AttributeError(\"'str' object has no attribute 'get'\")", marks=pytest.mark.skip), + pytest.param(lambda r: {"document_ids": r[:1]}, 0, "", marks=pytest.mark.p1), + pytest.param(lambda r: {"document_ids": r}, 0, "", marks=pytest.mark.p1), + ], + ) + def test_basic_scenarios(self, HttpApiAuth, add_documents_func, payload, expected_code, expected_message): + dataset_id, document_ids = add_documents_func + if callable(payload): + payload = payload(document_ids) + res = parse_documents(HttpApiAuth, dataset_id, payload) + assert res["code"] == expected_code + if expected_code != 0: + assert res["message"] == expected_message + if expected_code == 0: + condition(HttpApiAuth, dataset_id, payload["document_ids"]) + validate_document_details(HttpApiAuth, dataset_id, payload["document_ids"]) + + @pytest.mark.p3 + @pytest.mark.parametrize( + "dataset_id, expected_code, expected_message", + [ + ("", 100, ""), + ( + "invalid_dataset_id", + 102, + "You don't own the dataset invalid_dataset_id.", + ), + ], + ) + def test_invalid_dataset_id( + self, + HttpApiAuth, + add_documents_func, + dataset_id, + expected_code, + expected_message, + ): + _, document_ids = add_documents_func + res = parse_documents(HttpApiAuth, dataset_id, {"document_ids": document_ids}) + assert res["code"] == expected_code + assert res["message"] == expected_message + + 
@pytest.mark.parametrize( + "payload", + [ + pytest.param(lambda r: {"document_ids": ["invalid_id"] + r}, marks=pytest.mark.p3), + pytest.param(lambda r: {"document_ids": r[:1] + ["invalid_id"] + r[1:3]}, marks=pytest.mark.p1), + pytest.param(lambda r: {"document_ids": r + ["invalid_id"]}, marks=pytest.mark.p3), + ], + ) + def test_parse_partial_invalid_document_id(self, HttpApiAuth, add_documents_func, payload): + dataset_id, document_ids = add_documents_func + if callable(payload): + payload = payload(document_ids) + res = parse_documents(HttpApiAuth, dataset_id, payload) + assert res["code"] == 102 + assert res["message"] == "Documents not found: ['invalid_id']" + + condition(HttpApiAuth, dataset_id) + + validate_document_details(HttpApiAuth, dataset_id, document_ids) + + @pytest.mark.p3 + def test_repeated_parse(self, HttpApiAuth, add_documents_func): + dataset_id, document_ids = add_documents_func + res = parse_documents(HttpApiAuth, dataset_id, {"document_ids": document_ids}) + assert res["code"] == 0 + + condition(HttpApiAuth, dataset_id) + + res = parse_documents(HttpApiAuth, dataset_id, {"document_ids": document_ids}) + assert res["code"] == 0 + + @pytest.mark.p3 + def test_duplicate_parse(self, HttpApiAuth, add_documents_func): + dataset_id, document_ids = add_documents_func + res = parse_documents(HttpApiAuth, dataset_id, {"document_ids": document_ids + document_ids}) + assert res["code"] == 0 + assert "Duplicate document ids" in res["data"]["errors"][0] + assert res["data"]["success_count"] == 3 + + condition(HttpApiAuth, dataset_id) + + validate_document_details(HttpApiAuth, dataset_id, document_ids) + + +@pytest.mark.p3 +def test_parse_100_files(HttpApiAuth, add_dataset_func, tmp_path): + @wait_for(100, 1, "Document parsing timeout") + def condition(_auth, _dataset_id, _document_num): + res = list_documents(_auth, _dataset_id, {"page_size": _document_num}) + for doc in res["data"]["docs"]: + if doc["run"] != "DONE": + return False + return True + + document_num = 100 + dataset_id = add_dataset_func + document_ids = bulk_upload_documents(HttpApiAuth, dataset_id, document_num, tmp_path) + res = parse_documents(HttpApiAuth, dataset_id, {"document_ids": document_ids}) + assert res["code"] == 0 + + condition(HttpApiAuth, dataset_id, document_num) + + validate_document_details(HttpApiAuth, dataset_id, document_ids) + + +@pytest.mark.p3 +def test_concurrent_parse(HttpApiAuth, add_dataset_func, tmp_path): + @wait_for(120, 1, "Document parsing timeout") + def condition(_auth, _dataset_id, _document_num): + res = list_documents(_auth, _dataset_id, {"page_size": _document_num}) + for doc in res["data"]["docs"]: + if doc["run"] != "DONE": + return False + return True + + count = 100 + dataset_id = add_dataset_func + document_ids = bulk_upload_documents(HttpApiAuth, dataset_id, count, tmp_path) + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [ + executor.submit( + parse_documents, + HttpApiAuth, + dataset_id, + {"document_ids": document_ids[i : i + 1]}, + ) + for i in range(count) + ] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + assert all(future.result()["code"] == 0 for future in futures) + + condition(HttpApiAuth, dataset_id, count) + + validate_document_details(HttpApiAuth, dataset_id, document_ids) diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_stop_parse_documents.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_stop_parse_documents.py new file mode 100644 
index 00000000000..4c324487893 --- /dev/null +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_stop_parse_documents.py @@ -0,0 +1,203 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from concurrent.futures import ThreadPoolExecutor +from time import sleep + +import pytest +from common import bulk_upload_documents, list_documents, parse_documents, stop_parse_documents +from configs import INVALID_API_TOKEN +from libs.auth import RAGFlowHttpApiAuth +from utils import wait_for + + +def validate_document_parse_done(auth, dataset_id, document_ids): + for document_id in document_ids: + res = list_documents(auth, dataset_id, params={"id": document_id}) + doc = res["data"]["docs"][0] + assert doc["run"] == "DONE" + assert len(doc["process_begin_at"]) > 0 + assert doc["process_duation"] > 0 + assert doc["progress"] > 0 + assert "Task done" in doc["progress_msg"] + + +def validate_document_parse_cancel(auth, dataset_id, document_ids): + for document_id in document_ids: + res = list_documents(auth, dataset_id, params={"id": document_id}) + doc = res["data"]["docs"][0] + assert doc["run"] == "CANCEL" + assert len(doc["process_begin_at"]) > 0 + assert doc["progress"] == 0.0 + + +@pytest.mark.p1 +class TestAuthorization: + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 0, "`Authorization` can't be empty"), + ( + RAGFlowHttpApiAuth(INVALID_API_TOKEN), + 109, + "Authentication error: API key is invalid!", + ), + ], + ) + def test_invalid_auth(self, invalid_auth, expected_code, expected_message): + res = stop_parse_documents(invalid_auth, "dataset_id") + assert res["code"] == expected_code + assert res["message"] == expected_message + + +@pytest.mark.skip +class TestDocumentsParseStop: + @pytest.mark.parametrize( + "payload, expected_code, expected_message", + [ + pytest.param(None, 102, """AttributeError("\'NoneType\' object has no attribute \'get\'")""", marks=pytest.mark.skip), + pytest.param({"document_ids": []}, 102, "`document_ids` is required", marks=pytest.mark.p1), + pytest.param({"document_ids": ["invalid_id"]}, 102, "You don't own the document invalid_id.", marks=pytest.mark.p3), + pytest.param({"document_ids": ["\n!?。;!?\"'"]}, 102, """You don\'t own the document \n!?。;!?"\'.""", marks=pytest.mark.p3), + pytest.param("not json", 102, "AttributeError(\"'str' object has no attribute 'get'\")", marks=pytest.mark.skip), + pytest.param(lambda r: {"document_ids": r[:1]}, 0, "", marks=pytest.mark.p1), + pytest.param(lambda r: {"document_ids": r}, 0, "", marks=pytest.mark.p1), + ], + ) + def test_basic_scenarios(self, HttpApiAuth, add_documents_func, payload, expected_code, expected_message): + @wait_for(10, 1, "Document parsing timeout") + def condition(_auth, _dataset_id, _document_ids): + for _document_id in _document_ids: + res = list_documents(_auth, _dataset_id, {"id": _document_id}) + if res["data"]["docs"][0]["run"] != "DONE": + return False + return True + + 
dataset_id, document_ids = add_documents_func + parse_documents(HttpApiAuth, dataset_id, {"document_ids": document_ids}) + + if callable(payload): + payload = payload(document_ids) + + res = stop_parse_documents(HttpApiAuth, dataset_id, payload) + assert res["code"] == expected_code + if expected_code != 0: + assert res["message"] == expected_message + else: + completed_document_ids = list(set(document_ids) - set(payload["document_ids"])) + condition(HttpApiAuth, dataset_id, completed_document_ids) + validate_document_parse_cancel(HttpApiAuth, dataset_id, payload["document_ids"]) + validate_document_parse_done(HttpApiAuth, dataset_id, completed_document_ids) + + @pytest.mark.p3 + @pytest.mark.parametrize( + "invalid_dataset_id, expected_code, expected_message", + [ + ("", 100, ""), + ( + "invalid_dataset_id", + 102, + "You don't own the dataset invalid_dataset_id.", + ), + ], + ) + def test_invalid_dataset_id( + self, + HttpApiAuth, + add_documents_func, + invalid_dataset_id, + expected_code, + expected_message, + ): + dataset_id, document_ids = add_documents_func + parse_documents(HttpApiAuth, dataset_id, {"document_ids": document_ids}) + res = stop_parse_documents(HttpApiAuth, invalid_dataset_id, {"document_ids": document_ids}) + assert res["code"] == expected_code + assert res["message"] == expected_message + + @pytest.mark.skip + @pytest.mark.parametrize( + "payload", + [ + lambda r: {"document_ids": ["invalid_id"] + r}, + lambda r: {"document_ids": r[:1] + ["invalid_id"] + r[1:3]}, + lambda r: {"document_ids": r + ["invalid_id"]}, + ], + ) + def test_stop_parse_partial_invalid_document_id(self, HttpApiAuth, add_documents_func, payload): + dataset_id, document_ids = add_documents_func + parse_documents(HttpApiAuth, dataset_id, {"document_ids": document_ids}) + + if callable(payload): + payload = payload(document_ids) + res = stop_parse_documents(HttpApiAuth, dataset_id, payload) + assert res["code"] == 102 + assert res["message"] == "You don't own the document invalid_id." 
+ + validate_document_parse_cancel(HttpApiAuth, dataset_id, document_ids) + + @pytest.mark.p3 + def test_repeated_stop_parse(self, HttpApiAuth, add_documents_func): + dataset_id, document_ids = add_documents_func + parse_documents(HttpApiAuth, dataset_id, {"document_ids": document_ids}) + res = stop_parse_documents(HttpApiAuth, dataset_id, {"document_ids": document_ids}) + assert res["code"] == 0 + + res = stop_parse_documents(HttpApiAuth, dataset_id, {"document_ids": document_ids}) + assert res["code"] == 102 + assert res["message"] == "Can't stop parsing document with progress at 0 or 1" + + @pytest.mark.p3 + def test_duplicate_stop_parse(self, HttpApiAuth, add_documents_func): + dataset_id, document_ids = add_documents_func + parse_documents(HttpApiAuth, dataset_id, {"document_ids": document_ids}) + res = stop_parse_documents(HttpApiAuth, dataset_id, {"document_ids": document_ids + document_ids}) + assert res["code"] == 0 + assert res["data"]["success_count"] == 3 + assert f"Duplicate document ids: {document_ids[0]}" in res["data"]["errors"] + + +@pytest.mark.skip(reason="unstable") +def test_stop_parse_100_files(HttpApiAuth, add_dataset_func, tmp_path): + document_num = 100 + dataset_id = add_dataset_func + document_ids = bulk_upload_documents(HttpApiAuth, dataset_id, document_num, tmp_path) + parse_documents(HttpApiAuth, dataset_id, {"document_ids": document_ids}) + sleep(1) + res = stop_parse_documents(HttpApiAuth, dataset_id, {"document_ids": document_ids}) + assert res["code"] == 0 + validate_document_parse_cancel(HttpApiAuth, dataset_id, document_ids) + + +@pytest.mark.skip(reason="unstable") +def test_concurrent_parse(HttpApiAuth, add_dataset_func, tmp_path): + document_num = 50 + dataset_id = add_dataset_func + document_ids = bulk_upload_documents(HttpApiAuth, dataset_id, document_num, tmp_path) + parse_documents(HttpApiAuth, dataset_id, {"document_ids": document_ids}) + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [ + executor.submit( + stop_parse_documents, + HttpApiAuth, + dataset_id, + {"document_ids": document_ids[i : i + 1]}, + ) + for i in range(document_num) + ] + responses = [f.result() for f in futures] + assert all(r["code"] == 0 for r in responses) + validate_document_parse_cancel(HttpApiAuth, dataset_id, document_ids) diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_update_document.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_update_document.py new file mode 100644 index 00000000000..ca7bbe5c748 --- /dev/null +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_update_document.py @@ -0,0 +1,548 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + + +import pytest +from common import list_documents, update_document +from configs import DOCUMENT_NAME_LIMIT, INVALID_API_TOKEN +from libs.auth import RAGFlowHttpApiAuth + + +@pytest.mark.p1 +class TestAuthorization: + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 0, "`Authorization` can't be empty"), + ( + RAGFlowHttpApiAuth(INVALID_API_TOKEN), + 109, + "Authentication error: API key is invalid!", + ), + ], + ) + def test_invalid_auth(self, invalid_auth, expected_code, expected_message): + res = update_document(invalid_auth, "dataset_id", "document_id") + assert res["code"] == expected_code + assert res["message"] == expected_message + + +class TestDocumentsUpdated: + @pytest.mark.p1 + @pytest.mark.parametrize( + "name, expected_code, expected_message", + [ + ("new_name.txt", 0, ""), + ( + f"{'a' * (DOCUMENT_NAME_LIMIT - 4)}.txt", + 0, + "", + ), + ( + 0, + 100, + """AttributeError("\'int\' object has no attribute \'encode\'")""", + ), + ( + None, + 100, + """AttributeError("\'NoneType\' object has no attribute \'encode\'")""", + ), + ( + "", + 101, + "The extension of file can't be changed", + ), + ( + "ragflow_test_upload_0", + 101, + "The extension of file can't be changed", + ), + ( + "ragflow_test_upload_1.txt", + 102, + "Duplicated document name in the same dataset.", + ), + ( + "RAGFLOW_TEST_UPLOAD_1.TXT", + 0, + "", + ), + ], + ) + def test_name(self, HttpApiAuth, add_documents, name, expected_code, expected_message): + dataset_id, document_ids = add_documents + res = update_document(HttpApiAuth, dataset_id, document_ids[0], {"name": name}) + assert res["code"] == expected_code + if expected_code == 0: + res = list_documents(HttpApiAuth, dataset_id, {"id": document_ids[0]}) + assert res["data"]["docs"][0]["name"] == name + else: + assert res["message"] == expected_message + + @pytest.mark.p3 + @pytest.mark.parametrize( + "document_id, expected_code, expected_message", + [ + ("", 100, ""), + ( + "invalid_document_id", + 102, + "The dataset doesn't own the document.", + ), + ], + ) + def test_invalid_document_id(self, HttpApiAuth, add_documents, document_id, expected_code, expected_message): + dataset_id, _ = add_documents + res = update_document(HttpApiAuth, dataset_id, document_id, {"name": "new_name.txt"}) + assert res["code"] == expected_code + assert res["message"] == expected_message + + @pytest.mark.p3 + @pytest.mark.parametrize( + "dataset_id, expected_code, expected_message", + [ + ("", 100, ""), + ( + "invalid_dataset_id", + 102, + "You don't own the dataset.", + ), + ], + ) + def test_invalid_dataset_id(self, HttpApiAuth, add_documents, dataset_id, expected_code, expected_message): + _, document_ids = add_documents + res = update_document(HttpApiAuth, dataset_id, document_ids[0], {"name": "new_name.txt"}) + assert res["code"] == expected_code + assert res["message"] == expected_message + + @pytest.mark.p3 + @pytest.mark.parametrize( + "meta_fields, expected_code, expected_message", + [({"test": "test"}, 0, ""), ("test", 102, "meta_fields must be a dictionary")], + ) + def test_meta_fields(self, HttpApiAuth, add_documents, meta_fields, expected_code, expected_message): + dataset_id, document_ids = add_documents + res = update_document(HttpApiAuth, dataset_id, document_ids[0], {"meta_fields": meta_fields}) + if expected_code == 0: + res = list_documents(HttpApiAuth, dataset_id, {"id": document_ids[0]}) + assert res["data"]["docs"][0]["meta_fields"] == meta_fields + else: + assert res["message"] == expected_message + + 
@pytest.mark.p2 + @pytest.mark.parametrize( + "chunk_method, expected_code, expected_message", + [ + ("naive", 0, ""), + ("manual", 0, ""), + ("qa", 0, ""), + ("table", 0, ""), + ("paper", 0, ""), + ("book", 0, ""), + ("laws", 0, ""), + ("presentation", 0, ""), + ("picture", 0, ""), + ("one", 0, ""), + ("knowledge_graph", 0, ""), + ("email", 0, ""), + ("tag", 0, ""), + ("", 102, "`chunk_method` doesn't exist"), + ( + "other_chunk_method", + 102, + "`chunk_method` other_chunk_method doesn't exist", + ), + ], + ) + def test_chunk_method(self, HttpApiAuth, add_documents, chunk_method, expected_code, expected_message): + dataset_id, document_ids = add_documents + res = update_document(HttpApiAuth, dataset_id, document_ids[0], {"chunk_method": chunk_method}) + assert res["code"] == expected_code + if expected_code == 0: + res = list_documents(HttpApiAuth, dataset_id, {"id": document_ids[0]}) + if chunk_method == "": + assert res["data"]["docs"][0]["chunk_method"] == "naive" + else: + assert res["data"]["docs"][0]["chunk_method"] == chunk_method + else: + assert res["message"] == expected_message + + @pytest.mark.p3 + @pytest.mark.parametrize( + "payload, expected_code, expected_message", + [ + ({"chunk_count": 1}, 102, "Can't change `chunk_count`."), + pytest.param( + {"create_date": "Fri, 14 Mar 2025 16:53:42 GMT"}, + 102, + "The input parameters are invalid.", + marks=pytest.mark.skip(reason="issues/6104"), + ), + pytest.param( + {"create_time": 1}, + 102, + "The input parameters are invalid.", + marks=pytest.mark.skip(reason="issues/6104"), + ), + pytest.param( + {"created_by": "ragflow_test"}, + 102, + "The input parameters are invalid.", + marks=pytest.mark.skip(reason="issues/6104"), + ), + pytest.param( + {"dataset_id": "ragflow_test"}, + 102, + "The input parameters are invalid.", + marks=pytest.mark.skip(reason="issues/6104"), + ), + pytest.param( + {"id": "ragflow_test"}, + 102, + "The input parameters are invalid.", + marks=pytest.mark.skip(reason="issues/6104"), + ), + pytest.param( + {"location": "ragflow_test.txt"}, + 102, + "The input parameters are invalid.", + marks=pytest.mark.skip(reason="issues/6104"), + ), + pytest.param( + {"process_begin_at": 1}, + 102, + "The input parameters are invalid.", + marks=pytest.mark.skip(reason="issues/6104"), + ), + pytest.param( + {"process_duation": 1.0}, + 102, + "The input parameters are invalid.", + marks=pytest.mark.skip(reason="issues/6104"), + ), + pytest.param({"progress": 1.0}, 102, "Can't change `progress`."), + pytest.param( + {"progress_msg": "ragflow_test"}, + 102, + "The input parameters are invalid.", + marks=pytest.mark.skip(reason="issues/6104"), + ), + pytest.param( + {"run": "ragflow_test"}, + 102, + "The input parameters are invalid.", + marks=pytest.mark.skip(reason="issues/6104"), + ), + pytest.param( + {"size": 1}, + 102, + "The input parameters are invalid.", + marks=pytest.mark.skip(reason="issues/6104"), + ), + pytest.param( + {"source_type": "ragflow_test"}, + 102, + "The input parameters are invalid.", + marks=pytest.mark.skip(reason="issues/6104"), + ), + pytest.param( + {"thumbnail": "ragflow_test"}, + 102, + "The input parameters are invalid.", + marks=pytest.mark.skip(reason="issues/6104"), + ), + ({"token_count": 1}, 102, "Can't change `token_count`."), + pytest.param( + {"type": "ragflow_test"}, + 102, + "The input parameters are invalid.", + marks=pytest.mark.skip(reason="issues/6104"), + ), + pytest.param( + {"update_date": "Fri, 14 Mar 2025 16:33:17 GMT"}, + 102, + "The input parameters are invalid.", + 
marks=pytest.mark.skip(reason="issues/6104"), + ), + pytest.param( + {"update_time": 1}, + 102, + "The input parameters are invalid.", + marks=pytest.mark.skip(reason="issues/6104"), + ), + ], + ) + def test_invalid_field( + self, + HttpApiAuth, + add_documents, + payload, + expected_code, + expected_message, + ): + dataset_id, document_ids = add_documents + res = update_document(HttpApiAuth, dataset_id, document_ids[0], payload) + assert res["code"] == expected_code + assert res["message"] == expected_message + + +class TestUpdateDocumentParserConfig: + @pytest.mark.p2 + @pytest.mark.parametrize( + "chunk_method, parser_config, expected_code, expected_message", + [ + ("naive", {}, 0, ""), + ( + "naive", + { + "chunk_token_num": 128, + "layout_recognize": "DeepDOC", + "html4excel": False, + "delimiter": r"\n", + "task_page_size": 12, + "raptor": {"use_raptor": False}, + }, + 0, + "", + ), + pytest.param( + "naive", + {"chunk_token_num": -1}, + 100, + "AssertionError('chunk_token_num should be in range from 1 to 100000000')", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"chunk_token_num": 0}, + 100, + "AssertionError('chunk_token_num should be in range from 1 to 100000000')", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"chunk_token_num": 100000000}, + 100, + "AssertionError('chunk_token_num should be in range from 1 to 100000000')", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"chunk_token_num": 3.14}, + 102, + "", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"chunk_token_num": "1024"}, + 100, + "", + marks=pytest.mark.skip(reason="issues/6098"), + ), + ( + "naive", + {"layout_recognize": "DeepDOC"}, + 0, + "", + ), + ( + "naive", + {"layout_recognize": "Naive"}, + 0, + "", + ), + ("naive", {"html4excel": True}, 0, ""), + ("naive", {"html4excel": False}, 0, ""), + pytest.param( + "naive", + {"html4excel": 1}, + 100, + "AssertionError('html4excel should be True or False')", + marks=pytest.mark.skip(reason="issues/6098"), + ), + ("naive", {"delimiter": ""}, 0, ""), + ("naive", {"delimiter": "`##`"}, 0, ""), + pytest.param( + "naive", + {"delimiter": 1}, + 100, + "", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"task_page_size": -1}, + 100, + "AssertionError('task_page_size should be in range from 1 to 100000000')", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"task_page_size": 0}, + 100, + "AssertionError('task_page_size should be in range from 1 to 100000000')", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"task_page_size": 100000000}, + 100, + "AssertionError('task_page_size should be in range from 1 to 100000000')", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"task_page_size": 3.14}, + 100, + "", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"task_page_size": "1024"}, + 100, + "", + marks=pytest.mark.skip(reason="issues/6098"), + ), + ("naive", {"raptor": {"use_raptor": True}}, 0, ""), + ("naive", {"raptor": {"use_raptor": False}}, 0, ""), + pytest.param( + "naive", + {"invalid_key": "invalid_value"}, + 100, + """AssertionError("Abnormal \'parser_config\'. 
Invalid key: invalid_key")""", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"auto_keywords": -1}, + 100, + "AssertionError('auto_keywords should be in range from 0 to 32')", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"auto_keywords": 32}, + 100, + "AssertionError('auto_keywords should be in range from 0 to 32')", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"auto_keywords": 3.14}, + 100, + "", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"auto_keywords": "1024"}, + 100, + "", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"auto_questions": -1}, + 100, + "AssertionError('auto_questions should be in range from 0 to 10')", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"auto_questions": 10}, + 100, + "AssertionError('auto_questions should be in range from 0 to 10')", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"auto_questions": 3.14}, + 100, + "", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"auto_questions": "1024"}, + 100, + "", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"topn_tags": -1}, + 100, + "AssertionError('topn_tags should be in range from 0 to 10')", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"topn_tags": 10}, + 100, + "AssertionError('topn_tags should be in range from 0 to 10')", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"topn_tags": 3.14}, + 100, + "", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"topn_tags": "1024"}, + 100, + "", + marks=pytest.mark.skip(reason="issues/6098"), + ), + ], + ) + def test_parser_config( + self, + HttpApiAuth, + add_documents, + chunk_method, + parser_config, + expected_code, + expected_message, + ): + dataset_id, document_ids = add_documents + res = update_document( + HttpApiAuth, + dataset_id, + document_ids[0], + {"chunk_method": chunk_method, "parser_config": parser_config}, + ) + assert res["code"] == expected_code + if expected_code == 0: + res = list_documents(HttpApiAuth, dataset_id, {"id": document_ids[0]}) + if parser_config == {}: + assert res["data"]["docs"][0]["parser_config"] == { + "chunk_token_num": 128, + "delimiter": r"\n", + "html4excel": False, + "layout_recognize": "DeepDOC", + "raptor": {"use_raptor": False}, + } + else: + for k, v in parser_config.items(): + assert res["data"]["docs"][0]["parser_config"][k] == v + if expected_code != 0 or expected_message: + assert res["message"] == expected_message diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_upload_documents.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_upload_documents.py new file mode 100644 index 00000000000..f8f23864154 --- /dev/null +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_upload_documents.py @@ -0,0 +1,218 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import string +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +import requests +from common import FILE_API_URL, list_datasets, upload_documents +from configs import DOCUMENT_NAME_LIMIT, HOST_ADDRESS, INVALID_API_TOKEN +from libs.auth import RAGFlowHttpApiAuth +from requests_toolbelt import MultipartEncoder +from utils.file_utils import create_txt_file + + +@pytest.mark.p1 +@pytest.mark.usefixtures("clear_datasets") +class TestAuthorization: + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 0, "`Authorization` can't be empty"), + ( + RAGFlowHttpApiAuth(INVALID_API_TOKEN), + 109, + "Authentication error: API key is invalid!", + ), + ], + ) + def test_invalid_auth(self, invalid_auth, expected_code, expected_message): + res = upload_documents(invalid_auth, "dataset_id") + assert res["code"] == expected_code + assert res["message"] == expected_message + + +class TestDocumentsUpload: + @pytest.mark.p1 + def test_valid_single_upload(self, HttpApiAuth, add_dataset_func, tmp_path): + dataset_id = add_dataset_func + fp = create_txt_file(tmp_path / "ragflow_test.txt") + res = upload_documents(HttpApiAuth, dataset_id, [fp]) + assert res["code"] == 0 + assert res["data"][0]["dataset_id"] == dataset_id + assert res["data"][0]["name"] == fp.name + + @pytest.mark.p1 + @pytest.mark.parametrize( + "generate_test_files", + [ + "docx", + "excel", + "ppt", + "image", + "pdf", + "txt", + "md", + "json", + "eml", + "html", + ], + indirect=True, + ) + def test_file_type_validation(self, HttpApiAuth, add_dataset_func, generate_test_files, request): + dataset_id = add_dataset_func + fp = generate_test_files[request.node.callspec.params["generate_test_files"]] + res = upload_documents(HttpApiAuth, dataset_id, [fp]) + assert res["code"] == 0 + assert res["data"][0]["dataset_id"] == dataset_id + assert res["data"][0]["name"] == fp.name + + @pytest.mark.p2 + @pytest.mark.parametrize( + "file_type", + ["exe", "unknown"], + ) + def test_unsupported_file_type(self, HttpApiAuth, add_dataset_func, tmp_path, file_type): + dataset_id = add_dataset_func + fp = tmp_path / f"ragflow_test.{file_type}" + fp.touch() + res = upload_documents(HttpApiAuth, dataset_id, [fp]) + assert res["code"] == 500 + assert res["message"] == f"ragflow_test.{file_type}: This type of file has not been supported yet!" + + @pytest.mark.p2 + def test_missing_file(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + res = upload_documents(HttpApiAuth, dataset_id) + assert res["code"] == 101 + assert res["message"] == "No file part!" 
+ + @pytest.mark.p3 + def test_empty_file(self, HttpApiAuth, add_dataset_func, tmp_path): + dataset_id = add_dataset_func + fp = tmp_path / "empty.txt" + fp.touch() + + res = upload_documents(HttpApiAuth, dataset_id, [fp]) + assert res["code"] == 0 + assert res["data"][0]["size"] == 0 + + @pytest.mark.p3 + def test_filename_empty(self, HttpApiAuth, add_dataset_func, tmp_path): + dataset_id = add_dataset_func + fp = create_txt_file(tmp_path / "ragflow_test.txt") + url = f"{HOST_ADDRESS}{FILE_API_URL}".format(dataset_id=dataset_id) + fields = (("file", ("", fp.open("rb"))),) + m = MultipartEncoder(fields=fields) + res = requests.post( + url=url, + headers={"Content-Type": m.content_type}, + auth=HttpApiAuth, + data=m, + ) + assert res.json()["code"] == 101 + assert res.json()["message"] == "No file selected!" + + @pytest.mark.p2 + def test_filename_max_length(self, HttpApiAuth, add_dataset_func, tmp_path): + dataset_id = add_dataset_func + fp = create_txt_file(tmp_path / f"{'a' * (DOCUMENT_NAME_LIMIT - 4)}.txt") + res = upload_documents(HttpApiAuth, dataset_id, [fp]) + assert res["code"] == 0 + assert res["data"][0]["name"] == fp.name + + @pytest.mark.p2 + def test_invalid_dataset_id(self, HttpApiAuth, tmp_path): + fp = create_txt_file(tmp_path / "ragflow_test.txt") + res = upload_documents(HttpApiAuth, "invalid_dataset_id", [fp]) + assert res["code"] == 100 + assert res["message"] == """LookupError("Can\'t find the dataset with ID invalid_dataset_id!")""" + + @pytest.mark.p2 + def test_duplicate_files(self, HttpApiAuth, add_dataset_func, tmp_path): + dataset_id = add_dataset_func + fp = create_txt_file(tmp_path / "ragflow_test.txt") + res = upload_documents(HttpApiAuth, dataset_id, [fp, fp]) + assert res["code"] == 0 + assert len(res["data"]) == 2 + for i in range(len(res["data"])): + assert res["data"][i]["dataset_id"] == dataset_id + expected_name = fp.name + if i != 0: + expected_name = f"{fp.stem}({i}){fp.suffix}" + assert res["data"][i]["name"] == expected_name + + @pytest.mark.p2 + def test_same_file_repeat(self, HttpApiAuth, add_dataset_func, tmp_path): + dataset_id = add_dataset_func + fp = create_txt_file(tmp_path / "ragflow_test.txt") + for i in range(3): + res = upload_documents(HttpApiAuth, dataset_id, [fp]) + assert res["code"] == 0 + assert len(res["data"]) == 1 + assert res["data"][0]["dataset_id"] == dataset_id + expected_name = fp.name + if i != 0: + expected_name = f"{fp.stem}({i}){fp.suffix}" + assert res["data"][0]["name"] == expected_name + + @pytest.mark.p3 + def test_filename_special_characters(self, HttpApiAuth, add_dataset_func, tmp_path): + dataset_id = add_dataset_func + illegal_chars = '<>:"/\\|?*' + translation_table = str.maketrans({char: "_" for char in illegal_chars}) + safe_filename = string.punctuation.translate(translation_table) + fp = tmp_path / f"{safe_filename}.txt" + fp.write_text("Sample text content") + + res = upload_documents(HttpApiAuth, dataset_id, [fp]) + assert res["code"] == 0 + assert len(res["data"]) == 1 + assert res["data"][0]["dataset_id"] == dataset_id + assert res["data"][0]["name"] == fp.name + + @pytest.mark.p1 + def test_multiple_files(self, HttpApiAuth, add_dataset_func, tmp_path): + dataset_id = add_dataset_func + expected_document_count = 20 + fps = [] + for i in range(expected_document_count): + fp = create_txt_file(tmp_path / f"ragflow_test_{i}.txt") + fps.append(fp) + res = upload_documents(HttpApiAuth, dataset_id, fps) + assert res["code"] == 0 + + res = list_datasets(HttpApiAuth, {"id": dataset_id}) + assert 
res["data"][0]["document_count"] == expected_document_count + + @pytest.mark.p3 + def test_concurrent_upload(self, HttpApiAuth, add_dataset_func, tmp_path): + dataset_id = add_dataset_func + + count = 20 + fps = [] + for i in range(count): + fp = create_txt_file(tmp_path / f"ragflow_test_{i}.txt") + fps.append(fp) + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(upload_documents, HttpApiAuth, dataset_id, fps[i : i + 1]) for i in range(count)] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + assert all(future.result()["code"] == 0 for future in futures) + + res = list_datasets(HttpApiAuth, {"id": dataset_id}) + assert res["data"][0]["document_count"] == count diff --git a/test/testcases/test_http_api/test_session_management/conftest.py b/test/testcases/test_http_api/test_session_management/conftest.py new file mode 100644 index 00000000000..56eafab0aab --- /dev/null +++ b/test/testcases/test_http_api/test_session_management/conftest.py @@ -0,0 +1,41 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +from common import batch_add_sessions_with_chat_assistant, delete_session_with_chat_assistants + + +@pytest.fixture(scope="class") +def add_sessions_with_chat_assistant(request, HttpApiAuth, add_chat_assistants): + def cleanup(): + for chat_assistant_id in chat_assistant_ids: + delete_session_with_chat_assistants(HttpApiAuth, chat_assistant_id) + + request.addfinalizer(cleanup) + + _, _, chat_assistant_ids = add_chat_assistants + return chat_assistant_ids[0], batch_add_sessions_with_chat_assistant(HttpApiAuth, chat_assistant_ids[0], 5) + + +@pytest.fixture(scope="function") +def add_sessions_with_chat_assistant_func(request, HttpApiAuth, add_chat_assistants): + def cleanup(): + for chat_assistant_id in chat_assistant_ids: + delete_session_with_chat_assistants(HttpApiAuth, chat_assistant_id) + + request.addfinalizer(cleanup) + + _, _, chat_assistant_ids = add_chat_assistants + return chat_assistant_ids[0], batch_add_sessions_with_chat_assistant(HttpApiAuth, chat_assistant_ids[0], 5) diff --git a/test/testcases/test_http_api/test_session_management/test_create_session_with_chat_assistant.py b/test/testcases/test_http_api/test_session_management/test_create_session_with_chat_assistant.py new file mode 100644 index 00000000000..322fd1b7a71 --- /dev/null +++ b/test/testcases/test_http_api/test_session_management/test_create_session_with_chat_assistant.py @@ -0,0 +1,119 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from common import create_session_with_chat_assistant, delete_chat_assistants, list_session_with_chat_assistants +from configs import INVALID_API_TOKEN, SESSION_WITH_CHAT_NAME_LIMIT +from libs.auth import RAGFlowHttpApiAuth + + +@pytest.mark.p1 +class TestAuthorization: + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 0, "`Authorization` can't be empty"), + ( + RAGFlowHttpApiAuth(INVALID_API_TOKEN), + 109, + "Authentication error: API key is invalid!", + ), + ], + ) + def test_invalid_auth(self, invalid_auth, expected_code, expected_message): + res = create_session_with_chat_assistant(invalid_auth, "chat_assistant_id") + assert res["code"] == expected_code + assert res["message"] == expected_message + + +@pytest.mark.usefixtures("clear_session_with_chat_assistants") +class TestSessionWithChatAssistantCreate: + @pytest.mark.p1 + @pytest.mark.parametrize( + "payload, expected_code, expected_message", + [ + ({"name": "valid_name"}, 0, ""), + pytest.param({"name": "a" * (SESSION_WITH_CHAT_NAME_LIMIT + 1)}, 102, "", marks=pytest.mark.skip(reason="issues/")), + pytest.param({"name": 1}, 100, "", marks=pytest.mark.skip(reason="issues/")), + ({"name": ""}, 102, "`name` can not be empty."), + ({"name": "duplicated_name"}, 0, ""), + ({"name": "case insensitive"}, 0, ""), + ], + ) + def test_name(self, HttpApiAuth, add_chat_assistants, payload, expected_code, expected_message): + _, _, chat_assistant_ids = add_chat_assistants + if payload["name"] == "duplicated_name": + create_session_with_chat_assistant(HttpApiAuth, chat_assistant_ids[0], payload) + elif payload["name"] == "case insensitive": + create_session_with_chat_assistant(HttpApiAuth, chat_assistant_ids[0], {"name": payload["name"].upper()}) + + res = create_session_with_chat_assistant(HttpApiAuth, chat_assistant_ids[0], payload) + assert res["code"] == expected_code, res + if expected_code == 0: + assert res["data"]["name"] == payload["name"] + assert res["data"]["chat_id"] == chat_assistant_ids[0] + else: + assert res["message"] == expected_message + + @pytest.mark.p3 + @pytest.mark.parametrize( + "chat_assistant_id, expected_code, expected_message", + [ + ("", 100, ""), + ("invalid_chat_assistant_id", 102, "You do not own the assistant."), + ], + ) + def test_invalid_chat_assistant_id(self, HttpApiAuth, chat_assistant_id, expected_code, expected_message): + res = create_session_with_chat_assistant(HttpApiAuth, chat_assistant_id, {"name": "valid_name"}) + assert res["code"] == expected_code + assert res["message"] == expected_message + + @pytest.mark.p3 + def test_concurrent_create_session(self, HttpApiAuth, add_chat_assistants): + count = 1000 + _, _, chat_assistant_ids = add_chat_assistants + res = list_session_with_chat_assistants(HttpApiAuth, chat_assistant_ids[0]) + if res["code"] != 0: + assert False, res + sessions_count = len(res["data"]) + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [ + executor.submit( + create_session_with_chat_assistant, + HttpApiAuth, + 
chat_assistant_ids[0], + {"name": f"session with chat assistant test {i}"}, + ) + for i in range(count) + ] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + assert all(future.result()["code"] == 0 for future in futures) + res = list_session_with_chat_assistants(HttpApiAuth, chat_assistant_ids[0], {"page_size": count * 2}) + if res["code"] != 0: + assert False, res + assert len(res["data"]) == sessions_count + count + + @pytest.mark.p3 + def test_add_session_to_deleted_chat_assistant(self, HttpApiAuth, add_chat_assistants): + _, _, chat_assistant_ids = add_chat_assistants + res = delete_chat_assistants(HttpApiAuth, {"ids": [chat_assistant_ids[0]]}) + assert res["code"] == 0 + res = create_session_with_chat_assistant(HttpApiAuth, chat_assistant_ids[0], {"name": "valid_name"}) + assert res["code"] == 102 + assert res["message"] == "You do not own the assistant." diff --git a/test/testcases/test_http_api/test_session_management/test_delete_sessions_with_chat_assistant.py b/test/testcases/test_http_api/test_session_management/test_delete_sessions_with_chat_assistant.py new file mode 100644 index 00000000000..818050819b2 --- /dev/null +++ b/test/testcases/test_http_api/test_session_management/test_delete_sessions_with_chat_assistant.py @@ -0,0 +1,172 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from common import batch_add_sessions_with_chat_assistant, delete_session_with_chat_assistants, list_session_with_chat_assistants +from configs import INVALID_API_TOKEN +from libs.auth import RAGFlowHttpApiAuth + + +@pytest.mark.p1 +class TestAuthorization: + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 0, "`Authorization` can't be empty"), + ( + RAGFlowHttpApiAuth(INVALID_API_TOKEN), + 109, + "Authentication error: API key is invalid!", + ), + ], + ) + def test_invalid_auth(self, invalid_auth, expected_code, expected_message): + res = delete_session_with_chat_assistants(invalid_auth, "chat_assistant_id") + assert res["code"] == expected_code + assert res["message"] == expected_message + + +class TestSessionWithChatAssistantDelete: + @pytest.mark.p3 + @pytest.mark.parametrize( + "chat_assistant_id, expected_code, expected_message", + [ + ("", 100, ""), + ( + "invalid_chat_assistant_id", + 102, + "You don't own the chat", + ), + ], + ) + def test_invalid_chat_assistant_id(self, HttpApiAuth, add_sessions_with_chat_assistant_func, chat_assistant_id, expected_code, expected_message): + _, session_ids = add_sessions_with_chat_assistant_func + res = delete_session_with_chat_assistants(HttpApiAuth, chat_assistant_id, {"ids": session_ids}) + assert res["code"] == expected_code + assert res["message"] == expected_message + + @pytest.mark.parametrize( + "payload", + [ + pytest.param(lambda r: {"ids": ["invalid_id"] + r}, marks=pytest.mark.p3), + pytest.param(lambda r: {"ids": r[:1] + ["invalid_id"] + r[1:5]}, marks=pytest.mark.p1), + pytest.param(lambda r: {"ids": r + ["invalid_id"]}, marks=pytest.mark.p3), + ], + ) + def test_delete_partial_invalid_id(self, HttpApiAuth, add_sessions_with_chat_assistant_func, payload): + chat_assistant_id, session_ids = add_sessions_with_chat_assistant_func + if callable(payload): + payload = payload(session_ids) + res = delete_session_with_chat_assistants(HttpApiAuth, chat_assistant_id, payload) + assert res["code"] == 0 + assert res["data"]["errors"][0] == "The chat doesn't own the session invalid_id" + + res = list_session_with_chat_assistants(HttpApiAuth, chat_assistant_id) + if res["code"] != 0: + assert False, res + assert len(res["data"]) == 0 + + @pytest.mark.p3 + def test_repeated_deletion(self, HttpApiAuth, add_sessions_with_chat_assistant_func): + chat_assistant_id, session_ids = add_sessions_with_chat_assistant_func + payload = {"ids": session_ids} + res = delete_session_with_chat_assistants(HttpApiAuth, chat_assistant_id, payload) + assert res["code"] == 0 + + res = delete_session_with_chat_assistants(HttpApiAuth, chat_assistant_id, payload) + assert res["code"] == 102 + assert "The chat doesn't own the session" in res["message"] + + @pytest.mark.p3 + def test_duplicate_deletion(self, HttpApiAuth, add_sessions_with_chat_assistant_func): + chat_assistant_id, session_ids = add_sessions_with_chat_assistant_func + res = delete_session_with_chat_assistants(HttpApiAuth, chat_assistant_id, {"ids": session_ids * 2}) + assert res["code"] == 0 + assert "Duplicate session ids" in res["data"]["errors"][0] + assert res["data"]["success_count"] == 5 + + res = list_session_with_chat_assistants(HttpApiAuth, chat_assistant_id) + if res["code"] != 0: + assert False, res + assert len(res["data"]) == 0 + + @pytest.mark.p3 + def test_concurrent_deletion(self, HttpApiAuth, add_chat_assistants): + count = 100 + _, _, chat_assistant_ids = 
add_chat_assistants + session_ids = batch_add_sessions_with_chat_assistant(HttpApiAuth, chat_assistant_ids[0], count) + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [ + executor.submit( + delete_session_with_chat_assistants, + HttpApiAuth, + chat_assistant_ids[0], + {"ids": session_ids[i : i + 1]}, + ) + for i in range(count) + ] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + assert all(future.result()["code"] == 0 for future in futures) + + @pytest.mark.p3 + def test_delete_1k(self, HttpApiAuth, add_chat_assistants): + sessions_num = 1_000 + _, _, chat_assistant_ids = add_chat_assistants + session_ids = batch_add_sessions_with_chat_assistant(HttpApiAuth, chat_assistant_ids[0], sessions_num) + + res = delete_session_with_chat_assistants(HttpApiAuth, chat_assistant_ids[0], {"ids": session_ids}) + assert res["code"] == 0 + + res = list_session_with_chat_assistants(HttpApiAuth, chat_assistant_ids[0]) + if res["code"] != 0: + assert False, res + assert len(res["data"]) == 0 + + @pytest.mark.parametrize( + "payload, expected_code, expected_message, remaining", + [ + pytest.param(None, 0, """TypeError("argument of type \'NoneType\' is not iterable")""", 0, marks=pytest.mark.skip), + pytest.param({"ids": ["invalid_id"]}, 102, "The chat doesn't own the session invalid_id", 5, marks=pytest.mark.p3), + pytest.param("not json", 100, """AttributeError("\'str\' object has no attribute \'get\'")""", 5, marks=pytest.mark.skip), + pytest.param(lambda r: {"ids": r[:1]}, 0, "", 4, marks=pytest.mark.p3), + pytest.param(lambda r: {"ids": r}, 0, "", 0, marks=pytest.mark.p1), + pytest.param({"ids": []}, 0, "", 0, marks=pytest.mark.p3), + ], + ) + def test_basic_scenarios( + self, + HttpApiAuth, + add_sessions_with_chat_assistant_func, + payload, + expected_code, + expected_message, + remaining, + ): + chat_assistant_id, session_ids = add_sessions_with_chat_assistant_func + if callable(payload): + payload = payload(session_ids) + res = delete_session_with_chat_assistants(HttpApiAuth, chat_assistant_id, payload) + assert res["code"] == expected_code + if res["code"] != 0: + assert res["message"] == expected_message + + res = list_session_with_chat_assistants(HttpApiAuth, chat_assistant_id) + if res["code"] != 0: + assert False, res + assert len(res["data"]) == remaining diff --git a/test/testcases/test_http_api/test_session_management/test_list_sessions_with_chat_assistant.py b/test/testcases/test_http_api/test_session_management/test_list_sessions_with_chat_assistant.py new file mode 100644 index 00000000000..fb1f1737a32 --- /dev/null +++ b/test/testcases/test_http_api/test_session_management/test_list_sessions_with_chat_assistant.py @@ -0,0 +1,250 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from common import delete_chat_assistants, list_session_with_chat_assistants +from configs import INVALID_API_TOKEN +from libs.auth import RAGFlowHttpApiAuth +from utils import is_sorted + + +@pytest.mark.p1 +class TestAuthorization: + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 0, "`Authorization` can't be empty"), + ( + RAGFlowHttpApiAuth(INVALID_API_TOKEN), + 109, + "Authentication error: API key is invalid!", + ), + ], + ) + def test_invalid_auth(self, invalid_auth, expected_code, expected_message): + res = list_session_with_chat_assistants(invalid_auth, "chat_assistant_id") + assert res["code"] == expected_code + assert res["message"] == expected_message + + +class TestSessionsWithChatAssistantList: + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_code, expected_page_size, expected_message", + [ + ({"page": None, "page_size": 2}, 0, 2, ""), + pytest.param({"page": 0, "page_size": 2}, 100, 0, "ValueError('Search does not support negative slicing.')", marks=pytest.mark.skip), + ({"page": 2, "page_size": 2}, 0, 2, ""), + ({"page": 3, "page_size": 2}, 0, 1, ""), + ({"page": "3", "page_size": 2}, 0, 1, ""), + pytest.param({"page": -1, "page_size": 2}, 100, 0, "ValueError('Search does not support negative slicing.')", marks=pytest.mark.skip), + pytest.param({"page": "a", "page_size": 2}, 100, 0, """ValueError("invalid literal for int() with base 10: \'a\'")""", marks=pytest.mark.skip), + ], + ) + def test_page(self, HttpApiAuth, add_sessions_with_chat_assistant, params, expected_code, expected_page_size, expected_message): + chat_assistant_id, _ = add_sessions_with_chat_assistant + res = list_session_with_chat_assistants(HttpApiAuth, chat_assistant_id, params=params) + assert res["code"] == expected_code + if expected_code == 0: + assert len(res["data"]) == expected_page_size + else: + assert res["message"] == expected_message + + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_code, expected_page_size, expected_message", + [ + ({"page_size": None}, 0, 5, ""), + ({"page_size": 0}, 0, 0, ""), + ({"page_size": 1}, 0, 1, ""), + ({"page_size": 6}, 0, 5, ""), + ({"page_size": "1"}, 0, 1, ""), + pytest.param({"page_size": -1}, 0, 5, "", marks=pytest.mark.skip), + pytest.param({"page_size": "a"}, 100, 0, """ValueError("invalid literal for int() with base 10: \'a\'")""", marks=pytest.mark.skip), + ], + ) + def test_page_size(self, HttpApiAuth, add_sessions_with_chat_assistant, params, expected_code, expected_page_size, expected_message): + chat_assistant_id, _ = add_sessions_with_chat_assistant + res = list_session_with_chat_assistants(HttpApiAuth, chat_assistant_id, params=params) + assert res["code"] == expected_code + if expected_code == 0: + assert len(res["data"]) == expected_page_size + else: + assert res["message"] == expected_message + + @pytest.mark.p3 + @pytest.mark.parametrize( + "params, expected_code, assertions, expected_message", + [ + ({"orderby": None}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), + ({"orderby": "create_time"}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), + ({"orderby": "update_time"}, 0, lambda r: (is_sorted(r["data"], "update_time", True)), ""), + ({"orderby": "name", "desc": "False"}, 0, lambda r: (is_sorted(r["data"], "name", False)), ""), + pytest.param({"orderby": "unknown"}, 102, 0, "orderby should be create_time or update_time", 
marks=pytest.mark.skip(reason="issues/")), + ], + ) + def test_orderby( + self, + HttpApiAuth, + add_sessions_with_chat_assistant, + params, + expected_code, + assertions, + expected_message, + ): + chat_assistant_id, _ = add_sessions_with_chat_assistant + res = list_session_with_chat_assistants(HttpApiAuth, chat_assistant_id, params=params) + assert res["code"] == expected_code + if expected_code == 0: + if callable(assertions): + assert assertions(res) + else: + assert res["message"] == expected_message + + @pytest.mark.p3 + @pytest.mark.parametrize( + "params, expected_code, assertions, expected_message", + [ + ({"desc": None}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), + ({"desc": "true"}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), + ({"desc": "True"}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), + ({"desc": True}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), + ({"desc": "false"}, 0, lambda r: (is_sorted(r["data"], "create_time", False)), ""), + ({"desc": "False"}, 0, lambda r: (is_sorted(r["data"], "create_time", False)), ""), + ({"desc": False}, 0, lambda r: (is_sorted(r["data"], "create_time", False)), ""), + ({"desc": "False", "orderby": "update_time"}, 0, lambda r: (is_sorted(r["data"], "update_time", False)), ""), + pytest.param({"desc": "unknown"}, 102, 0, "desc should be true or false", marks=pytest.mark.skip(reason="issues/")), + ], + ) + def test_desc( + self, + HttpApiAuth, + add_sessions_with_chat_assistant, + params, + expected_code, + assertions, + expected_message, + ): + chat_assistant_id, _ = add_sessions_with_chat_assistant + res = list_session_with_chat_assistants(HttpApiAuth, chat_assistant_id, params=params) + assert res["code"] == expected_code + if expected_code == 0: + if callable(assertions): + assert assertions(res) + else: + assert res["message"] == expected_message + + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_code, expected_num, expected_message", + [ + ({"name": None}, 0, 5, ""), + ({"name": ""}, 0, 5, ""), + ({"name": "session_with_chat_assistant_1"}, 0, 1, ""), + ({"name": "unknown"}, 0, 0, ""), + ], + ) + def test_name(self, HttpApiAuth, add_sessions_with_chat_assistant, params, expected_code, expected_num, expected_message): + chat_assistant_id, _ = add_sessions_with_chat_assistant + res = list_session_with_chat_assistants(HttpApiAuth, chat_assistant_id, params=params) + assert res["code"] == expected_code + if expected_code == 0: + if params["name"] == "session_with_chat_assistant_1": + assert res["data"][0]["name"] == params["name"] + else: + assert len(res["data"]) == expected_num + else: + assert res["message"] == expected_message + + @pytest.mark.p1 + @pytest.mark.parametrize( + "session_id, expected_code, expected_num, expected_message", + [ + (None, 0, 5, ""), + ("", 0, 5, ""), + (lambda r: r[0], 0, 1, ""), + ("unknown", 0, 0, "The chat doesn't exist"), + ], + ) + def test_id(self, HttpApiAuth, add_sessions_with_chat_assistant, session_id, expected_code, expected_num, expected_message): + chat_assistant_id, session_ids = add_sessions_with_chat_assistant + if callable(session_id): + params = {"id": session_id(session_ids)} + else: + params = {"id": session_id} + + res = list_session_with_chat_assistants(HttpApiAuth, chat_assistant_id, params=params) + assert res["code"] == expected_code + if expected_code == 0: + if params["id"] == session_ids[0]: + assert res["data"][0]["id"] == params["id"] + else: + assert len(res["data"]) == expected_num + else: 
+ assert res["message"] == expected_message + + @pytest.mark.p3 + @pytest.mark.parametrize( + "session_id, name, expected_code, expected_num, expected_message", + [ + (lambda r: r[0], "session_with_chat_assistant_0", 0, 1, ""), + (lambda r: r[0], "session_with_chat_assistant_100", 0, 0, ""), + (lambda r: r[0], "unknown", 0, 0, ""), + ("id", "session_with_chat_assistant_0", 0, 0, ""), + ], + ) + def test_name_and_id(self, HttpApiAuth, add_sessions_with_chat_assistant, session_id, name, expected_code, expected_num, expected_message): + chat_assistant_id, session_ids = add_sessions_with_chat_assistant + if callable(session_id): + params = {"id": session_id(session_ids), "name": name} + else: + params = {"id": session_id, "name": name} + + res = list_session_with_chat_assistants(HttpApiAuth, chat_assistant_id, params=params) + assert res["code"] == expected_code + if expected_code == 0: + assert len(res["data"]) == expected_num + else: + assert res["message"] == expected_message + + @pytest.mark.p3 + def test_concurrent_list(self, HttpApiAuth, add_sessions_with_chat_assistant): + count = 100 + chat_assistant_id, _ = add_sessions_with_chat_assistant + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(list_session_with_chat_assistants, HttpApiAuth, chat_assistant_id) for i in range(count)] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + assert all(future.result()["code"] == 0 for future in futures) + + @pytest.mark.p3 + def test_invalid_params(self, HttpApiAuth, add_sessions_with_chat_assistant): + chat_assistant_id, _ = add_sessions_with_chat_assistant + params = {"a": "b"} + res = list_session_with_chat_assistants(HttpApiAuth, chat_assistant_id, params=params) + assert res["code"] == 0 + assert len(res["data"]) == 5 + + @pytest.mark.p3 + def test_list_chats_after_deleting_associated_chat_assistant(self, HttpApiAuth, add_sessions_with_chat_assistant): + chat_assistant_id, _ = add_sessions_with_chat_assistant + res = delete_chat_assistants(HttpApiAuth, {"ids": [chat_assistant_id]}) + assert res["code"] == 0 + + res = list_session_with_chat_assistants(HttpApiAuth, chat_assistant_id) + assert res["code"] == 102 + assert "You don't own the assistant" in res["message"] diff --git a/test/testcases/test_http_api/test_session_management/test_update_session_with_chat_assistant.py b/test/testcases/test_http_api/test_session_management/test_update_session_with_chat_assistant.py new file mode 100644 index 00000000000..e035e876b54 --- /dev/null +++ b/test/testcases/test_http_api/test_session_management/test_update_session_with_chat_assistant.py @@ -0,0 +1,150 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from concurrent.futures import ThreadPoolExecutor, as_completed +from random import randint + +import pytest +from common import delete_chat_assistants, list_session_with_chat_assistants, update_session_with_chat_assistant +from configs import INVALID_API_TOKEN, SESSION_WITH_CHAT_NAME_LIMIT +from libs.auth import RAGFlowHttpApiAuth + + +@pytest.mark.p1 +class TestAuthorization: + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 0, "`Authorization` can't be empty"), + ( + RAGFlowHttpApiAuth(INVALID_API_TOKEN), + 109, + "Authentication error: API key is invalid!", + ), + ], + ) + def test_invalid_auth(self, invalid_auth, expected_code, expected_message): + res = update_session_with_chat_assistant(invalid_auth, "chat_assistant_id", "session_id") + assert res["code"] == expected_code + assert res["message"] == expected_message + + +class TestSessionWithChatAssistantUpdate: + @pytest.mark.parametrize( + "payload, expected_code, expected_message", + [ + pytest.param({"name": "valid_name"}, 0, "", marks=pytest.mark.p1), + pytest.param({"name": "a" * (SESSION_WITH_CHAT_NAME_LIMIT + 1)}, 102, "", marks=pytest.mark.skip(reason="issues/")), + pytest.param({"name": 1}, 100, "", marks=pytest.mark.skip(reason="issues/")), + pytest.param({"name": ""}, 102, "`name` can not be empty.", marks=pytest.mark.p3), + pytest.param({"name": "duplicated_name"}, 0, "", marks=pytest.mark.p3), + pytest.param({"name": "case insensitive"}, 0, "", marks=pytest.mark.p3), + ], + ) + def test_name(self, HttpApiAuth, add_sessions_with_chat_assistant_func, payload, expected_code, expected_message): + chat_assistant_id, session_ids = add_sessions_with_chat_assistant_func + if payload["name"] == "duplicated_name": + update_session_with_chat_assistant(HttpApiAuth, chat_assistant_id, session_ids[0], payload) + elif payload["name"] == "case insensitive": + update_session_with_chat_assistant(HttpApiAuth, chat_assistant_id, session_ids[0], {"name": payload["name"].upper()}) + + res = update_session_with_chat_assistant(HttpApiAuth, chat_assistant_id, session_ids[0], payload) + assert res["code"] == expected_code, res + if expected_code == 0: + res = list_session_with_chat_assistants(HttpApiAuth, chat_assistant_id, {"id": session_ids[0]}) + assert res["data"][0]["name"] == payload["name"] + else: + assert res["message"] == expected_message + + @pytest.mark.p3 + @pytest.mark.parametrize( + "chat_assistant_id, expected_code, expected_message", + [ + ("", 100, ""), + pytest.param("invalid_chat_assistant_id", 102, "Session does not exist", marks=pytest.mark.skip(reason="issues/")), + ], + ) + def test_invalid_chat_assistant_id(self, HttpApiAuth, add_sessions_with_chat_assistant_func, chat_assistant_id, expected_code, expected_message): + _, session_ids = add_sessions_with_chat_assistant_func + res = update_session_with_chat_assistant(HttpApiAuth, chat_assistant_id, session_ids[0], {"name": "valid_name"}) + assert res["code"] == expected_code + assert res["message"] == expected_message + + @pytest.mark.p3 + @pytest.mark.parametrize( + "session_id, expected_code, expected_message", + [ + ("", 100, ""), + ("invalid_session_id", 102, "Session does not exist"), + ], + ) + def test_invalid_session_id(self, HttpApiAuth, add_sessions_with_chat_assistant_func, session_id, expected_code, expected_message): + chat_assistant_id, _ = add_sessions_with_chat_assistant_func + res = update_session_with_chat_assistant(HttpApiAuth, chat_assistant_id, session_id, {"name": "valid_name"}) + assert res["code"] == 
expected_code + assert res["message"] == expected_message + + @pytest.mark.p3 + def test_repeated_update_session(self, HttpApiAuth, add_sessions_with_chat_assistant_func): + chat_assistant_id, session_ids = add_sessions_with_chat_assistant_func + res = update_session_with_chat_assistant(HttpApiAuth, chat_assistant_id, session_ids[0], {"name": "valid_name_1"}) + assert res["code"] == 0 + + res = update_session_with_chat_assistant(HttpApiAuth, chat_assistant_id, session_ids[0], {"name": "valid_name_2"}) + assert res["code"] == 0 + + @pytest.mark.p3 + @pytest.mark.parametrize( + "payload, expected_code, expected_message", + [ + pytest.param({"unknown_key": "unknown_value"}, 100, "ValueError", marks=pytest.mark.skip), + ({}, 0, ""), + pytest.param(None, 100, "TypeError", marks=pytest.mark.skip), + ], + ) + def test_invalid_params(self, HttpApiAuth, add_sessions_with_chat_assistant_func, payload, expected_code, expected_message): + chat_assistant_id, session_ids = add_sessions_with_chat_assistant_func + res = update_session_with_chat_assistant(HttpApiAuth, chat_assistant_id, session_ids[0], payload) + assert res["code"] == expected_code + if expected_code != 0: + assert expected_message in res["message"] + + @pytest.mark.p3 + def test_concurrent_update_session(self, HttpApiAuth, add_sessions_with_chat_assistant_func): + count = 50 + chat_assistant_id, session_ids = add_sessions_with_chat_assistant_func + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [ + executor.submit( + update_session_with_chat_assistant, + HttpApiAuth, + chat_assistant_id, + session_ids[randint(0, 4)], + {"name": f"update session test {i}"}, + ) + for i in range(count) + ] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + assert all(future.result()["code"] == 0 for future in futures) + + @pytest.mark.p3 + def test_update_session_to_deleted_chat_assistant(self, HttpApiAuth, add_sessions_with_chat_assistant_func): + chat_assistant_id, session_ids = add_sessions_with_chat_assistant_func + delete_chat_assistants(HttpApiAuth, {"ids": [chat_assistant_id]}) + res = update_session_with_chat_assistant(HttpApiAuth, chat_assistant_id, session_ids[0], {"name": "valid_name"}) + assert res["code"] == 102 + assert res["message"] == "You do not own the session" diff --git a/test/testcases/test_sdk_api/common.py b/test/testcases/test_sdk_api/common.py new file mode 100644 index 00000000000..3035383a472 --- /dev/null +++ b/test/testcases/test_sdk_api/common.py @@ -0,0 +1,52 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from pathlib import Path + +from ragflow_sdk import Chat, Chunk, DataSet, Document, RAGFlow, Session +from utils.file_utils import create_txt_file + + +# DATASET MANAGEMENT +def batch_create_datasets(client: RAGFlow, num: int) -> list[DataSet]: + return [client.create_dataset(name=f"dataset_{i}") for i in range(num)] + + +# FILE MANAGEMENT WITHIN DATASET +def bulk_upload_documents(dataset: DataSet, num: int, tmp_path: Path) -> list[Document]: + document_infos = [] + for i in range(num): + fp = create_txt_file(tmp_path / f"ragflow_test_upload_{i}.txt") + with fp.open("rb") as f: + blob = f.read() + document_infos.append({"display_name": fp.name, "blob": blob}) + + return dataset.upload_documents(document_infos) + + +# CHUNK MANAGEMENT WITHIN DATASET +def batch_add_chunks(document: Document, num: int) -> list[Chunk]: + return [document.add_chunk(content=f"chunk test {i}") for i in range(num)] + + +# CHAT ASSISTANT MANAGEMENT +def batch_create_chat_assistants(client: RAGFlow, num: int) -> list[Chat]: + return [client.create_chat(name=f"test_chat_assistant_{i}") for i in range(num)] + + +# SESSION MANAGEMENT +def batch_add_sessions_with_chat_assistant(chat_assistant: Chat, num) -> list[Session]: + return [chat_assistant.create_session(name=f"session_with_chat_assistant_{i}") for i in range(num)] diff --git a/test/testcases/test_sdk_api/conftest.py b/test/testcases/test_sdk_api/conftest.py new file mode 100644 index 00000000000..11a258a5ad1 --- /dev/null +++ b/test/testcases/test_sdk_api/conftest.py @@ -0,0 +1,173 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from pathlib import Path +from time import sleep + +import pytest +from common import ( + batch_add_chunks, + batch_create_chat_assistants, + batch_create_datasets, + bulk_upload_documents, +) +from configs import HOST_ADDRESS, VERSION +from pytest import FixtureRequest +from ragflow_sdk import Chat, Chunk, DataSet, Document, RAGFlow +from utils import wait_for +from utils.file_utils import ( + create_docx_file, + create_eml_file, + create_excel_file, + create_html_file, + create_image_file, + create_json_file, + create_md_file, + create_pdf_file, + create_ppt_file, + create_txt_file, +) + + +@wait_for(30, 1, "Document parsing timeout") +def condition(_dataset: DataSet): + documents = _dataset.list_documents(page_size=1000) + for document in documents: + if document.run != "DONE": + return False + return True + + +@pytest.fixture +def generate_test_files(request: FixtureRequest, tmp_path: Path): + file_creators = { + "docx": (tmp_path / "ragflow_test.docx", create_docx_file), + "excel": (tmp_path / "ragflow_test.xlsx", create_excel_file), + "ppt": (tmp_path / "ragflow_test.pptx", create_ppt_file), + "image": (tmp_path / "ragflow_test.png", create_image_file), + "pdf": (tmp_path / "ragflow_test.pdf", create_pdf_file), + "txt": (tmp_path / "ragflow_test.txt", create_txt_file), + "md": (tmp_path / "ragflow_test.md", create_md_file), + "json": (tmp_path / "ragflow_test.json", create_json_file), + "eml": (tmp_path / "ragflow_test.eml", create_eml_file), + "html": (tmp_path / "ragflow_test.html", create_html_file), + } + + files = {} + for file_type, (file_path, creator_func) in file_creators.items(): + if request.param in ["", file_type]: + creator_func(file_path) + files[file_type] = file_path + return files + + +@pytest.fixture(scope="class") +def ragflow_tmp_dir(request: FixtureRequest, tmp_path_factory: Path) -> Path: + class_name = request.cls.__name__ + return tmp_path_factory.mktemp(class_name) + + +@pytest.fixture(scope="session") +def client(token: str) -> RAGFlow: + return RAGFlow(api_key=token, base_url=HOST_ADDRESS, version=VERSION) + + +@pytest.fixture(scope="function") +def clear_datasets(request: FixtureRequest, client: RAGFlow): + def cleanup(): + client.delete_datasets(ids=None) + + request.addfinalizer(cleanup) + + +@pytest.fixture(scope="function") +def clear_chat_assistants(request: FixtureRequest, client: RAGFlow): + def cleanup(): + client.delete_chats(ids=None) + + request.addfinalizer(cleanup) + + +@pytest.fixture(scope="function") +def clear_session_with_chat_assistants(request, add_chat_assistants): + def cleanup(): + for chat_assistant in chat_assistants: + try: + chat_assistant.delete_sessions(ids=None) + except Exception: + pass + + request.addfinalizer(cleanup) + + _, _, chat_assistants = add_chat_assistants + + +@pytest.fixture(scope="class") +def add_dataset(request: FixtureRequest, client: RAGFlow) -> DataSet: + def cleanup(): + client.delete_datasets(ids=None) + + request.addfinalizer(cleanup) + return batch_create_datasets(client, 1)[0] + + +@pytest.fixture(scope="function") +def add_dataset_func(request: FixtureRequest, client: RAGFlow) -> DataSet: + def cleanup(): + client.delete_datasets(ids=None) + + request.addfinalizer(cleanup) + return batch_create_datasets(client, 1)[0] + + +@pytest.fixture(scope="class") +def add_document(add_dataset: DataSet, ragflow_tmp_dir: Path) -> tuple[DataSet, Document]: + return add_dataset, bulk_upload_documents(add_dataset, 1, ragflow_tmp_dir)[0] + + +@pytest.fixture(scope="class") +def add_chunks(request: 
FixtureRequest, add_document: tuple[DataSet, Document]) -> tuple[DataSet, Document, list[Chunk]]: + def cleanup(): + try: + document.delete_chunks(ids=[]) + except Exception: + pass + + request.addfinalizer(cleanup) + + dataset, document = add_document + dataset.async_parse_documents([document.id]) + condition(dataset) + chunks = batch_add_chunks(document, 4) + # issues/6487 + sleep(1) + return dataset, document, chunks + + +@pytest.fixture(scope="class") +def add_chat_assistants(request, client, add_document) -> tuple[DataSet, Document, list[Chat]]: + def cleanup(): + try: + client.delete_chats(ids=None) + except Exception: + pass + + request.addfinalizer(cleanup) + + dataset, document = add_document + dataset.async_parse_documents([document.id]) + condition(dataset) + return dataset, document, batch_create_chat_assistants(client, 5) diff --git a/test/testcases/test_sdk_api/test_chat_assistant_management/conftest.py b/test/testcases/test_sdk_api/test_chat_assistant_management/conftest.py new file mode 100644 index 00000000000..79347d67a99 --- /dev/null +++ b/test/testcases/test_sdk_api/test_chat_assistant_management/conftest.py @@ -0,0 +1,42 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +from common import batch_create_chat_assistants +from pytest import FixtureRequest +from ragflow_sdk import Chat, DataSet, Document, RAGFlow +from utils import wait_for + + +@wait_for(30, 1, "Document parsing timeout") +def condition(_dataset: DataSet): + documents = _dataset.list_documents(page_size=1000) + for document in documents: + if document.run != "DONE": + return False + return True + + +@pytest.fixture(scope="function") +def add_chat_assistants_func(request: FixtureRequest, client: RAGFlow, add_document: tuple[DataSet, Document]) -> tuple[DataSet, Document, list[Chat]]: + def cleanup(): + client.delete_chats(ids=None) + + request.addfinalizer(cleanup) + + dataset, document = add_document + dataset.async_parse_documents([document.id]) + condition(dataset) + return dataset, document, batch_create_chat_assistants(client, 5) diff --git a/test/testcases/test_sdk_api/test_chat_assistant_management/test_create_chat_assistant.py b/test/testcases/test_sdk_api/test_chat_assistant_management/test_create_chat_assistant.py new file mode 100644 index 00000000000..7ba87a2a927 --- /dev/null +++ b/test/testcases/test_sdk_api/test_chat_assistant_management/test_create_chat_assistant.py @@ -0,0 +1,224 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +from operator import attrgetter + +import pytest +from configs import CHAT_ASSISTANT_NAME_LIMIT +from ragflow_sdk import Chat +from utils import encode_avatar +from utils.file_utils import create_image_file + + +@pytest.mark.usefixtures("clear_chat_assistants") +class TestChatAssistantCreate: + @pytest.mark.p1 + @pytest.mark.usefixtures("add_chunks") + @pytest.mark.parametrize( + "name, expected_message", + [ + ("valid_name", ""), + pytest.param("a" * (CHAT_ASSISTANT_NAME_LIMIT + 1), "", marks=pytest.mark.skip(reason="issues/")), + pytest.param(1, "", marks=pytest.mark.skip(reason="issues/")), + ("", "`name` is required."), + ("duplicated_name", "Duplicated chat name in creating chat."), + ("case insensitive", "Duplicated chat name in creating chat."), + ], + ) + def test_name(self, client, name, expected_message): + if name == "duplicated_name": + client.create_chat(name=name) + elif name == "case insensitive": + client.create_chat(name=name.upper()) + + if expected_message: + with pytest.raises(Exception) as excinfo: + client.create_chat(name=name) + assert expected_message in str(excinfo.value) + else: + chat_assistant = client.create_chat(name=name) + assert chat_assistant.name == name + + @pytest.mark.p1 + @pytest.mark.parametrize( + "dataset_ids, expected_message", + [ + ([], ""), + (lambda r: [r], ""), + (["invalid_dataset_id"], "You don't own the dataset invalid_dataset_id"), + ("invalid_dataset_id", "You don't own the dataset i"), + ], + ) + def test_dataset_ids(self, client, add_chunks, dataset_ids, expected_message): + dataset, _, _ = add_chunks + if callable(dataset_ids): + dataset_ids = dataset_ids(dataset.id) + + if expected_message: + with pytest.raises(Exception) as excinfo: + client.create_chat(name="ragflow test", dataset_ids=dataset_ids) + assert expected_message in str(excinfo.value) + else: + chat_assistant = client.create_chat(name="ragflow test", dataset_ids=dataset_ids) + assert chat_assistant.name == "ragflow test" + + @pytest.mark.p3 + def test_avatar(self, client, tmp_path): + fn = create_image_file(tmp_path / "ragflow_test.png") + chat_assistant = client.create_chat(name="avatar_test", avatar=encode_avatar(fn), dataset_ids=[]) + assert chat_assistant.name == "avatar_test" + + @pytest.mark.p2 + @pytest.mark.parametrize( + "llm, expected_message", + [ + ({}, ""), + ({"model_name": "glm-4"}, ""), + ({"model_name": "unknown"}, "`model_name` unknown doesn't exist"), + ({"temperature": 0}, ""), + ({"temperature": 1}, ""), + pytest.param({"temperature": -1}, "", marks=pytest.mark.skip), + pytest.param({"temperature": 10}, "", marks=pytest.mark.skip), + pytest.param({"temperature": "a"}, "", marks=pytest.mark.skip), + ({"top_p": 0}, ""), + ({"top_p": 1}, ""), + pytest.param({"top_p": -1}, "", marks=pytest.mark.skip), + pytest.param({"top_p": 10}, "", marks=pytest.mark.skip), + pytest.param({"top_p": "a"}, "", marks=pytest.mark.skip), + ({"presence_penalty": 0}, ""), + ({"presence_penalty": 1}, ""), + pytest.param({"presence_penalty": -1}, "", marks=pytest.mark.skip), + pytest.param({"presence_penalty": 10}, "", marks=pytest.mark.skip), + pytest.param({"presence_penalty": "a"}, "", marks=pytest.mark.skip), + ({"frequency_penalty": 0}, ""), + ({"frequency_penalty": 1}, ""), + pytest.param({"frequency_penalty": -1}, "", marks=pytest.mark.skip), + pytest.param({"frequency_penalty": 10}, "", marks=pytest.mark.skip), + pytest.param({"frequency_penalty": "a"}, "", 
marks=pytest.mark.skip), + ({"max_token": 0}, ""), + ({"max_token": 1024}, ""), + pytest.param({"max_token": -1}, "", marks=pytest.mark.skip), + pytest.param({"max_token": 10}, "", marks=pytest.mark.skip), + pytest.param({"max_token": "a"}, "", marks=pytest.mark.skip), + pytest.param({"unknown": "unknown"}, "", marks=pytest.mark.skip), + ], + ) + def test_llm(self, client, add_chunks, llm, expected_message): + dataset, _, _ = add_chunks + llm_o = Chat.LLM(client, llm) + + if expected_message: + with pytest.raises(Exception) as excinfo: + client.create_chat(name="llm_test", dataset_ids=[dataset.id], llm=llm_o) + assert expected_message in str(excinfo.value) + else: + chat_assistant = client.create_chat(name="llm_test", dataset_ids=[dataset.id], llm=llm_o) + if llm: + for k, v in llm.items(): + assert attrgetter(k)(chat_assistant.llm) == v + else: + assert attrgetter("model_name")(chat_assistant.llm) == "glm-4-flash@ZHIPU-AI" + assert attrgetter("temperature")(chat_assistant.llm) == 0.1 + assert attrgetter("top_p")(chat_assistant.llm) == 0.3 + assert attrgetter("presence_penalty")(chat_assistant.llm) == 0.4 + assert attrgetter("frequency_penalty")(chat_assistant.llm) == 0.7 + assert attrgetter("max_tokens")(chat_assistant.llm) == 512 + + @pytest.mark.p2 + @pytest.mark.parametrize( + "prompt, expected_message", + [ + ({"similarity_threshold": 0}, ""), + ({"similarity_threshold": 1}, ""), + pytest.param({"similarity_threshold": -1}, "", marks=pytest.mark.skip), + pytest.param({"similarity_threshold": 10}, "", marks=pytest.mark.skip), + pytest.param({"similarity_threshold": "a"}, "", marks=pytest.mark.skip), + ({"keywords_similarity_weight": 0}, ""), + ({"keywords_similarity_weight": 1}, ""), + pytest.param({"keywords_similarity_weight": -1}, "", marks=pytest.mark.skip), + pytest.param({"keywords_similarity_weight": 10}, "", marks=pytest.mark.skip), + pytest.param({"keywords_similarity_weight": "a"}, "", marks=pytest.mark.skip), + ({"variables": []}, ""), + ({"top_n": 0}, ""), + ({"top_n": 1}, ""), + pytest.param({"top_n": -1}, "", marks=pytest.mark.skip), + pytest.param({"top_n": 10}, "", marks=pytest.mark.skip), + pytest.param({"top_n": "a"}, "", marks=pytest.mark.skip), + ({"empty_response": "Hello World"}, ""), + ({"empty_response": ""}, ""), + ({"empty_response": "!@#$%^&*()"}, ""), + ({"empty_response": "中文测试"}, ""), + pytest.param({"empty_response": 123}, "", marks=pytest.mark.skip), + pytest.param({"empty_response": True}, "", marks=pytest.mark.skip), + pytest.param({"empty_response": " "}, "", marks=pytest.mark.skip), + ({"opener": "Hello World"}, ""), + ({"opener": ""}, ""), + ({"opener": "!@#$%^&*()"}, ""), + ({"opener": "中文测试"}, ""), + pytest.param({"opener": 123}, "", marks=pytest.mark.skip), + pytest.param({"opener": True}, "", marks=pytest.mark.skip), + pytest.param({"opener": " "}, "", marks=pytest.mark.skip), + ({"show_quote": True}, ""), + ({"show_quote": False}, ""), + ({"prompt": "Hello World {knowledge}"}, ""), + ({"prompt": "{knowledge}"}, ""), + ({"prompt": "!@#$%^&*() {knowledge}"}, ""), + ({"prompt": "中文测试 {knowledge}"}, ""), + ({"prompt": "Hello World"}, ""), + ({"prompt": "Hello World", "variables": []}, ""), + pytest.param({"prompt": 123}, """AttributeError("\'int\' object has no attribute \'find\'")""", marks=pytest.mark.skip), + pytest.param({"prompt": True}, """AttributeError("\'int\' object has no attribute \'find\'")""", marks=pytest.mark.skip), + pytest.param({"unknown": "unknown"}, "", marks=pytest.mark.skip), + ], + ) + def test_prompt(self, client, 
add_chunks, prompt, expected_message): + dataset, _, _ = add_chunks + prompt_o = Chat.Prompt(client, prompt) + + if expected_message: + with pytest.raises(Exception) as excinfo: + client.create_chat(name="prompt_test", dataset_ids=[dataset.id], prompt=prompt_o) + assert expected_message in str(excinfo.value) + else: + chat_assistant = client.create_chat(name="prompt_test", dataset_ids=[dataset.id], prompt=prompt_o) + if prompt: + for k, v in prompt.items(): + if k == "keywords_similarity_weight": + assert attrgetter(k)(chat_assistant.prompt) == 1 - v + else: + assert attrgetter(k)(chat_assistant.prompt) == v + else: + assert attrgetter("similarity_threshold")(chat_assistant.prompt) == 0.2 + assert attrgetter("keywords_similarity_weight")(chat_assistant.prompt) == 0.7 + assert attrgetter("top_n")(chat_assistant.prompt) == 6 + assert attrgetter("variables")(chat_assistant.prompt) == [{"key": "knowledge", "optional": False}] + assert attrgetter("rerank_model")(chat_assistant.prompt) == "" + assert attrgetter("empty_response")(chat_assistant.prompt) == "Sorry! No relevant content was found in the knowledge base!" + assert attrgetter("opener")(chat_assistant.prompt) == "Hi! I'm your assistant, what can I do for you?" + assert attrgetter("show_quote")(chat_assistant.prompt) is True + assert ( + attrgetter("prompt")(chat_assistant.prompt) + == 'You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence "The answer you are looking for is not found in the knowledge base!" Answers need to consider chat history.\n Here is the knowledge base:\n {knowledge}\n The above is the knowledge base.' + ) + + +class TestChatAssistantCreate2: + @pytest.mark.p2 + def test_unparsed_document(self, client, add_document): + dataset, _ = add_document + with pytest.raises(Exception) as excinfo: + client.create_chat(name="prompt_test", dataset_ids=[dataset.id]) + assert "doesn't own parsed file" in str(excinfo.value) diff --git a/test/testcases/test_sdk_api/test_chat_assistant_management/test_delete_chat_assistants.py b/test/testcases/test_sdk_api/test_chat_assistant_management/test_delete_chat_assistants.py new file mode 100644 index 00000000000..db0c39ff82d --- /dev/null +++ b/test/testcases/test_sdk_api/test_chat_assistant_management/test_delete_chat_assistants.py @@ -0,0 +1,105 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from common import batch_create_chat_assistants + + +class TestChatAssistantsDelete: + @pytest.mark.parametrize( + "payload, expected_message, remaining", + [ + pytest.param(None, "", 0, marks=pytest.mark.p3), + pytest.param({"ids": []}, "", 0, marks=pytest.mark.p3), + pytest.param({"ids": ["invalid_id"]}, "Assistant(invalid_id) not found.", 5, marks=pytest.mark.p3), + pytest.param({"ids": ["\n!?。;!?\"'"]}, """Assistant(\n!?。;!?"\') not found.""", 5, marks=pytest.mark.p3), + pytest.param(lambda r: {"ids": r[:1]}, "", 4, marks=pytest.mark.p3), + pytest.param(lambda r: {"ids": r}, "", 0, marks=pytest.mark.p1), + ], + ) + def test_basic_scenarios(self, client, add_chat_assistants_func, payload, expected_message, remaining): + _, _, chat_assistants = add_chat_assistants_func + if callable(payload): + payload = payload([chat_assistant.id for chat_assistant in chat_assistants]) + + if expected_message: + with pytest.raises(Exception) as excinfo: + client.delete_chats(**payload) + assert expected_message in str(excinfo.value) + else: + if payload is None: + client.delete_chats(payload) + else: + client.delete_chats(**payload) + + assistants = client.list_chats() + assert len(assistants) == remaining + + @pytest.mark.parametrize( + "payload", + [ + pytest.param(lambda r: {"ids": ["invalid_id"] + r}, marks=pytest.mark.p3), + pytest.param(lambda r: {"ids": r[:1] + ["invalid_id"] + r[1:5]}, marks=pytest.mark.p1), + pytest.param(lambda r: {"ids": r + ["invalid_id"]}, marks=pytest.mark.p3), + ], + ) + def test_delete_partial_invalid_id(self, client, add_chat_assistants_func, payload): + _, _, chat_assistants = add_chat_assistants_func + payload = payload([chat_assistant.id for chat_assistant in chat_assistants]) + client.delete_chats(**payload) + + assistants = client.list_chats() + assert len(assistants) == 0 + + @pytest.mark.p3 + def test_repeated_deletion(self, client, add_chat_assistants_func): + _, _, chat_assistants = add_chat_assistants_func + chat_ids = [chat.id for chat in chat_assistants] + client.delete_chats(ids=chat_ids) + + with pytest.raises(Exception) as excinfo: + client.delete_chats(ids=chat_ids) + assert "not found" in str(excinfo.value) + + @pytest.mark.p3 + def test_duplicate_deletion(self, client, add_chat_assistants_func): + _, _, chat_assistants = add_chat_assistants_func + chat_ids = [chat.id for chat in chat_assistants] + client.delete_chats(ids=chat_ids + chat_ids) + + assistants = client.list_chats() + assert len(assistants) == 0 + + @pytest.mark.p3 + def test_concurrent_deletion(self, client): + count = 100 + chat_ids = [client.create_chat(name=f"test_{i}").id for i in range(count)] + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(client.delete_chats, ids=[chat_ids[i]]) for i in range(count)] + responses = list(as_completed(futures)) + + assert len(responses) == count + assert all(future.exception() is None for future in futures) + + @pytest.mark.p3 + def test_delete_1k(self, client): + chat_assistants = batch_create_chat_assistants(client, 1_000) + client.delete_chats(ids=[chat_assistants.id for chat_assistants in chat_assistants]) + + assistants = client.list_chats() + assert len(assistants) == 0 diff --git a/test/testcases/test_sdk_api/test_chat_assistant_management/test_list_chat_assistants.py b/test/testcases/test_sdk_api/test_chat_assistant_management/test_list_chat_assistants.py new file mode 100644 index 00000000000..d79a4f55ddf --- /dev/null +++ 
b/test/testcases/test_sdk_api/test_chat_assistant_management/test_list_chat_assistants.py @@ -0,0 +1,224 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest + + +@pytest.mark.usefixtures("add_chat_assistants") +class TestChatAssistantsList: + @pytest.mark.p1 + def test_default(self, client): + assistants = client.list_chats() + assert len(assistants) == 5 + + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_page_size, expected_message", + [ + ({"page": 0, "page_size": 2}, 2, ""), + ({"page": 2, "page_size": 2}, 2, ""), + ({"page": 3, "page_size": 2}, 1, ""), + ({"page": "3", "page_size": 2}, 0, "not instance of"), + pytest.param( + {"page": -1, "page_size": 2}, + 0, + "1064", + marks=pytest.mark.skip(reason="issues/5851"), + ), + pytest.param( + {"page": "a", "page_size": 2}, + 0, + """ValueError("invalid literal for int() with base 10: \'a\'")""", + marks=pytest.mark.skip(reason="issues/5851"), + ), + ], + ) + def test_page(self, client, params, expected_page_size, expected_message): + if expected_message: + with pytest.raises(Exception) as excinfo: + client.list_chats(**params) + assert expected_message in str(excinfo.value) + else: + assistants = client.list_chats(**params) + assert len(assistants) == expected_page_size + + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_page_size, expected_message", + [ + ({"page_size": 0}, 0, ""), + ({"page_size": 1}, 1, ""), + ({"page_size": 6}, 5, ""), + ({"page_size": "1"}, 0, "not instance of"), + pytest.param( + {"page_size": -1}, + 0, + "1064", + marks=pytest.mark.skip(reason="issues/5851"), + ), + pytest.param( + {"page_size": "a"}, + 0, + """ValueError("invalid literal for int() with base 10: \'a\'")""", + marks=pytest.mark.skip(reason="issues/5851"), + ), + ], + ) + def test_page_size(self, client, params, expected_page_size, expected_message): + if expected_message: + with pytest.raises(Exception) as excinfo: + client.list_chats(**params) + assert expected_message in str(excinfo.value) + else: + assistants = client.list_chats(**params) + assert len(assistants) == expected_page_size + + @pytest.mark.p3 + @pytest.mark.parametrize( + "params, expected_message", + [ + ({"orderby": "create_time"}, ""), + ({"orderby": "update_time"}, ""), + pytest.param({"orderby": "name", "desc": "False"}, "", marks=pytest.mark.skip(reason="issues/5851")), + pytest.param({"orderby": "unknown"}, "orderby should be create_time or update_time", marks=pytest.mark.skip(reason="issues/5851")), + ], + ) + def test_orderby(self, client, params, expected_message): + if expected_message: + with pytest.raises(Exception) as excinfo: + client.list_chats(**params) + assert expected_message in str(excinfo.value) + else: + client.list_chats(**params) + + @pytest.mark.p3 + @pytest.mark.parametrize( + "params, expected_message", + [ + ({"desc": None}, "not instance of"), + ({"desc": "true"}, "not instance 
of"), + ({"desc": "True"}, "not instance of"), + ({"desc": True}, ""), + ({"desc": "false"}, "not instance of"), + ({"desc": "False"}, "not instance of"), + ({"desc": False}, ""), + ({"desc": "False", "orderby": "update_time"}, "not instance of"), + pytest.param( + {"desc": "unknown"}, + "desc should be true or false", + marks=pytest.mark.skip(reason="issues/5851"), + ), + ], + ) + def test_desc(self, client, params, expected_message): + if expected_message: + with pytest.raises(Exception) as excinfo: + client.list_chats(**params) + assert expected_message in str(excinfo.value) + else: + client.list_chats(**params) + + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_num, expected_message", + [ + ({"name": None}, 5, ""), + ({"name": ""}, 5, ""), + ({"name": "test_chat_assistant_1"}, 1, ""), + ({"name": "unknown"}, 0, "The chat doesn't exist"), + ], + ) + def test_name(self, client, params, expected_num, expected_message): + if expected_message: + with pytest.raises(Exception) as excinfo: + client.list_chats(**params) + assert expected_message in str(excinfo.value) + else: + assistants = client.list_chats(**params) + if params["name"] in [None, ""]: + assert len(assistants) == expected_num + else: + assert assistants[0].name == params["name"] + + @pytest.mark.p1 + @pytest.mark.parametrize( + "chat_assistant_id, expected_num, expected_message", + [ + (None, 5, ""), + ("", 5, ""), + (lambda r: r[0], 1, ""), + ("unknown", 0, "The chat doesn't exist"), + ], + ) + def test_id(self, client, add_chat_assistants, chat_assistant_id, expected_num, expected_message): + _, _, chat_assistants = add_chat_assistants + if callable(chat_assistant_id): + params = {"id": chat_assistant_id([chat.id for chat in chat_assistants])} + else: + params = {"id": chat_assistant_id} + + if expected_message: + with pytest.raises(Exception) as excinfo: + client.list_chats(**params) + assert expected_message in str(excinfo.value) + else: + assistants = client.list_chats(**params) + if params["id"] in [None, ""]: + assert len(assistants) == expected_num + else: + assert assistants[0].id == params["id"] + + @pytest.mark.p3 + @pytest.mark.parametrize( + "chat_assistant_id, name, expected_num, expected_message", + [ + (lambda r: r[0], "test_chat_assistant_0", 1, ""), + (lambda r: r[0], "test_chat_assistant_1", 0, "The chat doesn't exist"), + (lambda r: r[0], "unknown", 0, "The chat doesn't exist"), + ("id", "chat_assistant_0", 0, "The chat doesn't exist"), + ], + ) + def test_name_and_id(self, client, add_chat_assistants, chat_assistant_id, name, expected_num, expected_message): + _, _, chat_assistants = add_chat_assistants + if callable(chat_assistant_id): + params = {"id": chat_assistant_id([chat.id for chat in chat_assistants]), "name": name} + else: + params = {"id": chat_assistant_id, "name": name} + + if expected_message: + with pytest.raises(Exception) as excinfo: + client.list_chats(**params) + assert expected_message in str(excinfo.value) + else: + assistants = client.list_chats(**params) + assert len(assistants) == expected_num + + @pytest.mark.p3 + def test_concurrent_list(self, client): + count = 100 + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(client.list_chats) for _ in range(count)] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + + @pytest.mark.p2 + def test_list_chats_after_deleting_associated_dataset(self, client, add_chat_assistants): + dataset, _, _ = add_chat_assistants + client.delete_datasets(ids=[dataset.id]) + + 
assistants = client.list_chats() + assert len(assistants) == 5 diff --git a/test/testcases/test_sdk_api/test_chat_assistant_management/test_update_chat_assistant.py b/test/testcases/test_sdk_api/test_chat_assistant_management/test_update_chat_assistant.py new file mode 100644 index 00000000000..805460d5fc9 --- /dev/null +++ b/test/testcases/test_sdk_api/test_chat_assistant_management/test_update_chat_assistant.py @@ -0,0 +1,208 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from operator import attrgetter + +import pytest +from configs import CHAT_ASSISTANT_NAME_LIMIT +from ragflow_sdk import Chat +from utils import encode_avatar +from utils.file_utils import create_image_file + + +class TestChatAssistantUpdate: + @pytest.mark.parametrize( + "payload, expected_message", + [ + pytest.param({"name": "valid_name"}, "", marks=pytest.mark.p1), + pytest.param({"name": "a" * (CHAT_ASSISTANT_NAME_LIMIT + 1)}, "", marks=pytest.mark.skip(reason="issues/")), + pytest.param({"name": 1}, "", marks=pytest.mark.skip(reason="issues/")), + pytest.param({"name": ""}, "`name` cannot be empty.", marks=pytest.mark.p3), + pytest.param({"name": "test_chat_assistant_1"}, "Duplicated chat name in updating chat.", marks=pytest.mark.p3), + pytest.param({"name": "TEST_CHAT_ASSISTANT_1"}, "Duplicated chat name in updating chat.", marks=pytest.mark.p3), + ], + ) + def test_name(self, client, add_chat_assistants_func, payload, expected_message): + _, _, chat_assistants = add_chat_assistants_func + chat_assistant = chat_assistants[0] + + if expected_message: + with pytest.raises(Exception) as excinfo: + chat_assistant.update(payload) + assert expected_message in str(excinfo.value) + else: + chat_assistant.update(payload) + updated_chat = client.list_chats(id=chat_assistant.id)[0] + assert updated_chat.name == payload["name"], str(updated_chat) + + @pytest.mark.p3 + def test_avatar(self, client, add_chat_assistants_func, tmp_path): + dataset, _, chat_assistants = add_chat_assistants_func + chat_assistant = chat_assistants[0] + + fn = create_image_file(tmp_path / "ragflow_test.png") + payload = {"name": "avatar_test", "avatar": encode_avatar(fn), "dataset_ids": [dataset.id]} + + chat_assistant.update(payload) + updated_chat = client.list_chats(id=chat_assistant.id)[0] + assert updated_chat.name == payload["name"], str(updated_chat) + assert updated_chat.avatar is not None, str(updated_chat) + + @pytest.mark.p3 + @pytest.mark.parametrize( + "llm, expected_message", + [ + ({}, "ValueError"), + ({"model_name": "glm-4"}, ""), + ({"model_name": "unknown"}, "`model_name` unknown doesn't exist"), + ({"temperature": 0}, ""), + ({"temperature": 1}, ""), + pytest.param({"temperature": -1}, "", marks=pytest.mark.skip), + pytest.param({"temperature": 10}, "", marks=pytest.mark.skip), + pytest.param({"temperature": "a"}, "", marks=pytest.mark.skip), + ({"top_p": 0}, ""), + ({"top_p": 1}, ""), + pytest.param({"top_p": -1}, "", marks=pytest.mark.skip), + 
pytest.param({"top_p": 10}, "", marks=pytest.mark.skip), + pytest.param({"top_p": "a"}, "", marks=pytest.mark.skip), + ({"presence_penalty": 0}, ""), + ({"presence_penalty": 1}, ""), + pytest.param({"presence_penalty": -1}, "", marks=pytest.mark.skip), + pytest.param({"presence_penalty": 10}, "", marks=pytest.mark.skip), + pytest.param({"presence_penalty": "a"}, "", marks=pytest.mark.skip), + ({"frequency_penalty": 0}, ""), + ({"frequency_penalty": 1}, ""), + pytest.param({"frequency_penalty": -1}, "", marks=pytest.mark.skip), + pytest.param({"frequency_penalty": 10}, "", marks=pytest.mark.skip), + pytest.param({"frequency_penalty": "a"}, "", marks=pytest.mark.skip), + ({"max_token": 0}, ""), + ({"max_token": 1024}, ""), + pytest.param({"max_token": -1}, "", marks=pytest.mark.skip), + pytest.param({"max_token": 10}, "", marks=pytest.mark.skip), + pytest.param({"max_token": "a"}, "", marks=pytest.mark.skip), + pytest.param({"unknown": "unknown"}, "", marks=pytest.mark.skip), + ], + ) + def test_llm(self, client, add_chat_assistants_func, llm, expected_message): + dataset, _, chat_assistants = add_chat_assistants_func + chat_assistant = chat_assistants[0] + payload = {"name": "llm_test", "llm": llm, "dataset_ids": [dataset.id]} + + if expected_message: + with pytest.raises(Exception) as excinfo: + chat_assistant.update(payload) + assert expected_message in str(excinfo.value) + else: + chat_assistant.update(payload) + updated_chat = client.list_chats(id=chat_assistant.id)[0] + if llm: + for k, v in llm.items(): + assert attrgetter(k)(updated_chat.llm) == v, str(updated_chat) + else: + excepted_value = Chat.LLM( + client, + { + "model_name": "glm-4-flash@ZHIPU-AI", + "temperature": 0.1, + "top_p": 0.3, + "presence_penalty": 0.4, + "frequency_penalty": 0.7, + "max_tokens": 512, + }, + ) + assert str(updated_chat.llm) == str(excepted_value), str(updated_chat) + + @pytest.mark.p3 + @pytest.mark.parametrize( + "prompt, expected_message", + [ + ({}, "ValueError"), + ({"similarity_threshold": 0}, ""), + ({"similarity_threshold": 1}, ""), + pytest.param({"similarity_threshold": -1}, "", marks=pytest.mark.skip), + pytest.param({"similarity_threshold": 10}, "", marks=pytest.mark.skip), + pytest.param({"similarity_threshold": "a"}, "", marks=pytest.mark.skip), + ({"keywords_similarity_weight": 0}, ""), + ({"keywords_similarity_weight": 1}, ""), + pytest.param({"keywords_similarity_weight": -1}, "", marks=pytest.mark.skip), + pytest.param({"keywords_similarity_weight": 10}, "", marks=pytest.mark.skip), + pytest.param({"keywords_similarity_weight": "a"}, "", marks=pytest.mark.skip), + ({"variables": []}, ""), + ({"top_n": 0}, ""), + ({"top_n": 1}, ""), + pytest.param({"top_n": -1}, "", marks=pytest.mark.skip), + pytest.param({"top_n": 10}, "", marks=pytest.mark.skip), + pytest.param({"top_n": "a"}, "", marks=pytest.mark.skip), + ({"empty_response": "Hello World"}, ""), + ({"empty_response": ""}, ""), + ({"empty_response": "!@#$%^&*()"}, ""), + ({"empty_response": "中文测试"}, ""), + pytest.param({"empty_response": 123}, "", marks=pytest.mark.skip), + pytest.param({"empty_response": True}, "", marks=pytest.mark.skip), + pytest.param({"empty_response": " "}, "", marks=pytest.mark.skip), + ({"opener": "Hello World"}, ""), + ({"opener": ""}, ""), + ({"opener": "!@#$%^&*()"}, ""), + ({"opener": "中文测试"}, ""), + pytest.param({"opener": 123}, "", marks=pytest.mark.skip), + pytest.param({"opener": True}, "", marks=pytest.mark.skip), + pytest.param({"opener": " "}, "", marks=pytest.mark.skip), + ({"show_quote": True}, 
""), + ({"show_quote": False}, ""), + ({"prompt": "Hello World {knowledge}"}, ""), + ({"prompt": "{knowledge}"}, ""), + ({"prompt": "!@#$%^&*() {knowledge}"}, ""), + ({"prompt": "中文测试 {knowledge}"}, ""), + ({"prompt": "Hello World"}, ""), + ({"prompt": "Hello World", "variables": []}, ""), + pytest.param({"prompt": 123}, """AttributeError("\'int\' object has no attribute \'find\'")""", marks=pytest.mark.skip), + pytest.param({"prompt": True}, """AttributeError("\'int\' object has no attribute \'find\'")""", marks=pytest.mark.skip), + pytest.param({"unknown": "unknown"}, "", marks=pytest.mark.skip), + ], + ) + def test_prompt(self, client, add_chat_assistants_func, prompt, expected_message): + dataset, _, chat_assistants = add_chat_assistants_func + chat_assistant = chat_assistants[0] + payload = {"name": "prompt_test", "prompt": prompt, "dataset_ids": [dataset.id]} + + if expected_message: + with pytest.raises(Exception) as excinfo: + chat_assistant.update(payload) + assert expected_message in str(excinfo.value) + else: + chat_assistant.update(payload) + updated_chat = client.list_chats(id=chat_assistant.id)[0] + if prompt: + for k, v in prompt.items(): + if k == "keywords_similarity_weight": + assert attrgetter(k)(updated_chat.prompt) == 1 - v, str(updated_chat) + else: + assert attrgetter(k)(updated_chat.prompt) == v, str(updated_chat) + else: + excepted_value = Chat.LLM( + client, + { + "similarity_threshold": 0.2, + "keywords_similarity_weight": 0.7, + "top_n": 6, + "variables": [{"key": "knowledge", "optional": False}], + "rerank_model": "", + "empty_response": "Sorry! No relevant content was found in the knowledge base!", + "opener": "Hi! I'm your assistant, what can I do for you?", + "show_quote": True, + "prompt": 'You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence "The answer you are looking for is not found in the knowledge base!" Answers need to consider chat history.\n Here is the knowledge base:\n {knowledge}\n The above is the knowledge base.', + }, + ) + assert str(updated_chat.prompt) == str(excepted_value), str(updated_chat) diff --git a/test/testcases/test_sdk_api/test_chunk_management_within_dataset/conftest.py b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/conftest.py new file mode 100644 index 00000000000..d9ed678387f --- /dev/null +++ b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/conftest.py @@ -0,0 +1,52 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + + +from time import sleep + +import pytest +from common import batch_add_chunks +from pytest import FixtureRequest +from ragflow_sdk import Chunk, DataSet, Document +from utils import wait_for + + +@wait_for(30, 1, "Document parsing timeout") +def condition(_dataset: DataSet): + documents = _dataset.list_documents(page_size=1000) + for document in documents: + if document.run != "DONE": + return False + return True + + +@pytest.fixture(scope="function") +def add_chunks_func(request: FixtureRequest, add_document: tuple[DataSet, Document]) -> tuple[DataSet, Document, list[Chunk]]: + def cleanup(): + try: + document.delete_chunks(ids=[]) + except Exception: + pass + + request.addfinalizer(cleanup) + + dataset, document = add_document + dataset.async_parse_documents([document.id]) + condition(dataset) + chunks = batch_add_chunks(document, 4) + # issues/6487 + sleep(1) + return dataset, document, chunks diff --git a/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_add_chunk.py b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_add_chunk.py new file mode 100644 index 00000000000..5d1638db215 --- /dev/null +++ b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_add_chunk.py @@ -0,0 +1,160 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from concurrent.futures import ThreadPoolExecutor, as_completed +from time import sleep + +import pytest +from ragflow_sdk import Chunk + + +def validate_chunk_details(dataset_id: str, document_id: str, payload: dict, chunk: Chunk): + assert chunk.dataset_id == dataset_id + assert chunk.document_id == document_id + assert chunk.content == payload["content"] + if "important_keywords" in payload: + assert chunk.important_keywords == payload["important_keywords"] + if "questions" in payload: + assert chunk.questions == [str(q).strip() for q in payload.get("questions", []) if str(q).strip()] + + +class TestAddChunk: + @pytest.mark.p1 + @pytest.mark.parametrize( + "payload, expected_message", + [ + ({"content": None}, "not instance of"), + ({"content": ""}, "`content` is required"), + ({"content": 1}, "not instance of"), + ({"content": "a"}, ""), + ({"content": " "}, "`content` is required"), + ({"content": "\n!?。;!?\"'"}, ""), + ], + ) + def test_content(self, add_document, payload, expected_message): + dataset, document = add_document + chunks_count = len(document.list_chunks()) + + if expected_message: + with pytest.raises(Exception) as excinfo: + document.add_chunk(**payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + chunk = document.add_chunk(**payload) + validate_chunk_details(dataset.id, document.id, payload, chunk) + + sleep(1) + chunks = document.list_chunks() + assert len(chunks) == chunks_count + 1, str(chunks) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "payload, expected_message", + [ + ({"content": "chunk test important_keywords 1", "important_keywords": ["a", "b", "c"]}, ""), + ({"content": "chunk test important_keywords 2", "important_keywords": [""]}, ""), + ({"content": "chunk test important_keywords 3", "important_keywords": [1]}, "not instance of"), + ({"content": "chunk test important_keywords 4", "important_keywords": ["a", "a"]}, ""), + ({"content": "chunk test important_keywords 5", "important_keywords": "abc"}, "not instance of"), + ({"content": "chunk test important_keywords 6", "important_keywords": 123}, "not instance of"), + ], + ) + def test_important_keywords(self, add_document, payload, expected_message): + dataset, document = add_document + chunks_count = len(document.list_chunks()) + + if expected_message: + with pytest.raises(Exception) as excinfo: + document.add_chunk(**payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + chunk = document.add_chunk(**payload) + validate_chunk_details(dataset.id, document.id, payload, chunk) + + sleep(1) + chunks = document.list_chunks() + assert len(chunks) == chunks_count + 1, str(chunks) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "payload, expected_message", + [ + ({"content": "chunk test test_questions 1", "questions": ["a", "b", "c"]}, ""), + ({"content": "chunk test test_questions 2", "questions": [""]}, ""), + ({"content": "chunk test test_questions 3", "questions": [1]}, "not instance of"), + ({"content": "chunk test test_questions 4", "questions": ["a", "a"]}, ""), + ({"content": "chunk test test_questions 5", "questions": "abc"}, "not instance of"), + ({"content": "chunk test test_questions 6", "questions": 123}, "not instance of"), + ], + ) + def test_questions(self, add_document, payload, expected_message): + dataset, document = add_document + chunks_count = len(document.list_chunks()) + + if expected_message: + with pytest.raises(Exception) as excinfo: + document.add_chunk(**payload) + assert expected_message in 
str(excinfo.value), str(excinfo.value) + else: + chunk = document.add_chunk(**payload) + validate_chunk_details(dataset.id, document.id, payload, chunk) + + sleep(1) + chunks = document.list_chunks() + assert len(chunks) == chunks_count + 1, str(chunks) + + @pytest.mark.p3 + def test_repeated_add_chunk(self, add_document): + payload = {"content": "chunk test repeated_add_chunk"} + dataset, document = add_document + chunks_count = len(document.list_chunks()) + + chunk1 = document.add_chunk(**payload) + validate_chunk_details(dataset.id, document.id, payload, chunk1) + sleep(1) + chunks = document.list_chunks() + assert len(chunks) == chunks_count + 1, str(chunks) + + chunk2 = document.add_chunk(**payload) + validate_chunk_details(dataset.id, document.id, payload, chunk2) + sleep(1) + chunks = document.list_chunks() + assert len(chunks) == chunks_count + 1, str(chunks) + + @pytest.mark.p2 + def test_add_chunk_to_deleted_document(self, add_document): + dataset, document = add_document + dataset.delete_documents(ids=[document.id]) + + with pytest.raises(Exception) as excinfo: + document.add_chunk(content="chunk test") + assert f"You don't own the document {document.id}" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.skip(reason="issues/6411") + @pytest.mark.p3 + def test_concurrent_add_chunk(self, add_document): + count = 50 + _, document = add_document + initial_chunk_count = len(document.list_chunks()) + + def add_chunk_task(i): + return document.add_chunk(content=f"chunk test concurrent {i}") + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(add_chunk_task, i) for i in range(count)] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + sleep(5) + assert len(document.list_chunks(page_size=100)) == initial_chunk_count + count diff --git a/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_delete_chunks.py b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_delete_chunks.py new file mode 100644 index 00000000000..25aac7b88b8 --- /dev/null +++ b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_delete_chunks.py @@ -0,0 +1,113 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from common import batch_add_chunks + + +class TestChunksDeletion: + @pytest.mark.parametrize( + "payload", + [ + pytest.param(lambda r: {"ids": ["invalid_id"] + r}, marks=pytest.mark.p3), + pytest.param(lambda r: {"ids": r[:1] + ["invalid_id"] + r[1:4]}, marks=pytest.mark.p1), + pytest.param(lambda r: {"ids": r + ["invalid_id"]}, marks=pytest.mark.p3), + ], + ) + def test_delete_partial_invalid_id(self, add_chunks_func, payload): + _, document, chunks = add_chunks_func + chunk_ids = [chunk.id for chunk in chunks] + payload = payload(chunk_ids) + + with pytest.raises(Exception) as excinfo: + document.delete_chunks(**payload) + assert "rm_chunk deleted chunks" in str(excinfo.value), str(excinfo.value) + + remaining_chunks = document.list_chunks() + assert len(remaining_chunks) == 1, str(remaining_chunks) + + @pytest.mark.p3 + def test_repeated_deletion(self, add_chunks_func): + _, document, chunks = add_chunks_func + chunk_ids = [chunk.id for chunk in chunks] + document.delete_chunks(ids=chunk_ids) + + with pytest.raises(Exception) as excinfo: + document.delete_chunks(ids=chunk_ids) + assert "rm_chunk deleted chunks 0, expect" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p3 + def test_duplicate_deletion(self, add_chunks_func): + _, document, chunks = add_chunks_func + chunk_ids = [chunk.id for chunk in chunks] + document.delete_chunks(ids=chunk_ids * 2) + remaining_chunks = document.list_chunks() + assert len(remaining_chunks) == 1, str(remaining_chunks) + + @pytest.mark.p3 + def test_concurrent_deletion(self, add_document): + count = 100 + _, document = add_document + chunks = batch_add_chunks(document, count) + chunk_ids = [chunk.id for chunk in chunks] + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(document.delete_chunks, ids=[chunk_id]) for chunk_id in chunk_ids] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + + @pytest.mark.p3 + def test_delete_1k(self, add_document): + count = 1_000 + _, document = add_document + chunks = batch_add_chunks(document, count) + chunk_ids = [chunk.id for chunk in chunks] + + from time import sleep + + sleep(1) + + document.delete_chunks(ids=chunk_ids) + remaining_chunks = document.list_chunks() + assert len(remaining_chunks) == 0, str(remaining_chunks) + + @pytest.mark.parametrize( + "payload, expected_message, remaining", + [ + pytest.param(None, "TypeError", 5, marks=pytest.mark.skip), + pytest.param({"ids": ["invalid_id"]}, "rm_chunk deleted chunks 0, expect 1", 5, marks=pytest.mark.p3), + pytest.param("not json", "UnboundLocalError", 5, marks=pytest.mark.skip(reason="pull/6376")), + pytest.param(lambda r: {"ids": r[:1]}, "", 4, marks=pytest.mark.p3), + pytest.param(lambda r: {"ids": r}, "", 1, marks=pytest.mark.p1), + pytest.param({"ids": []}, "", 0, marks=pytest.mark.p3), + ], + ) + def test_basic_scenarios(self, add_chunks_func, payload, expected_message, remaining): + _, document, chunks = add_chunks_func + chunk_ids = [chunk.id for chunk in chunks] + if callable(payload): + payload = payload(chunk_ids) + + if expected_message: + with pytest.raises(Exception) as excinfo: + document.delete_chunks(**payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + document.delete_chunks(**payload) + + remaining_chunks = document.list_chunks() + assert len(remaining_chunks) == remaining, str(remaining_chunks) diff --git 
a/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_list_chunks.py b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_list_chunks.py new file mode 100644 index 00000000000..76f9da5e052 --- /dev/null +++ b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_list_chunks.py @@ -0,0 +1,140 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from common import batch_add_chunks + + +class TestChunksList: + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_page_size, expected_message", + [ + ({"page": None, "page_size": 2}, 2, ""), + pytest.param({"page": 0, "page_size": 2}, 0, "ValueError('Search does not support negative slicing.')", marks=pytest.mark.skip), + ({"page": 2, "page_size": 2}, 2, ""), + ({"page": 3, "page_size": 2}, 1, ""), + ({"page": "3", "page_size": 2}, 1, ""), + pytest.param({"page": -1, "page_size": 2}, 0, "ValueError('Search does not support negative slicing.')", marks=pytest.mark.skip), + pytest.param({"page": "a", "page_size": 2}, 0, """ValueError("invalid literal for int() with base 10: \'a\'")""", marks=pytest.mark.skip), + ], + ) + def test_page(self, add_chunks, params, expected_page_size, expected_message): + _, document, _ = add_chunks + + if expected_message: + with pytest.raises(Exception) as excinfo: + document.list_chunks(**params) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + chunks = document.list_chunks(**params) + assert len(chunks) == expected_page_size, str(chunks) + + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_page_size, expected_message", + [ + ({"page_size": None}, 5, ""), + pytest.param({"page_size": 0}, 5, "", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="Infinity does not support page_size=0")), + pytest.param({"page_size": 0}, 0, "3013", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="Infinity does not support page_size=0")), + ({"page_size": 1}, 1, ""), + ({"page_size": 6}, 5, ""), + ({"page_size": "1"}, 1, ""), + pytest.param({"page_size": -1}, 5, "", marks=pytest.mark.skip), + pytest.param({"page_size": "a"}, 0, """ValueError("invalid literal for int() with base 10: \'a\'")""", marks=pytest.mark.skip), + ], + ) + def test_page_size(self, add_chunks, params, expected_page_size, expected_message): + _, document, _ = add_chunks + + if expected_message: + with pytest.raises(Exception) as excinfo: + document.list_chunks(**params) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + chunks = document.list_chunks(**params) + assert len(chunks) == expected_page_size, str(chunks) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "params, expected_page_size", + [ + ({"keywords": None}, 5), + ({"keywords": ""}, 5), + ({"keywords": "1"}, 1), + pytest.param({"keywords": "chunk"}, 
4, marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6509")), + ({"keywords": "ragflow"}, 1), + ({"keywords": "unknown"}, 0), + ], + ) + def test_keywords(self, add_chunks, params, expected_page_size): + _, document, _ = add_chunks + chunks = document.list_chunks(**params) + assert len(chunks) == expected_page_size, str(chunks) + + @pytest.mark.p1 + @pytest.mark.parametrize( + "chunk_id, expected_page_size, expected_message", + [ + (None, 5, ""), + ("", 5, ""), + pytest.param(lambda r: r[0], 1, "", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6499")), + pytest.param("unknown", 0, """AttributeError("\'NoneType\' object has no attribute \'keys\'")""", marks=pytest.mark.skip), + ], + ) + def test_id(self, add_chunks, chunk_id, expected_page_size, expected_message): + _, document, chunks = add_chunks + chunk_ids = [chunk.id for chunk in chunks] + if callable(chunk_id): + params = {"id": chunk_id(chunk_ids)} + else: + params = {"id": chunk_id} + + if expected_message: + with pytest.raises(Exception) as excinfo: + document.list_chunks(**params) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + chunks = document.list_chunks(**params) + if params["id"] in [None, ""]: + assert len(chunks) == expected_page_size, str(chunks) + else: + assert chunks[0].id == params["id"], str(chunks) + + @pytest.mark.p3 + def test_concurrent_list(self, add_chunks): + _, document, _ = add_chunks + count = 100 + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(document.list_chunks) for _ in range(count)] + + responses = list(as_completed(futures)) + assert len(responses) == count, responses + assert all(len(future.result()) == 5 for future in futures) + + @pytest.mark.p1 + def test_default(self, add_document): + _, document = add_document + batch_add_chunks(document, 31) + + from time import sleep + + sleep(3) + + chunks = document.list_chunks() + assert len(chunks) == 30, str(chunks) diff --git a/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_retrieval_chunks.py b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_retrieval_chunks.py new file mode 100644 index 00000000000..e1b3fa7f04c --- /dev/null +++ b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_retrieval_chunks.py @@ -0,0 +1,254 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import os +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest + + +class TestChunksRetrieval: + @pytest.mark.p1 + @pytest.mark.parametrize( + "payload, expected_page_size, expected_message", + [ + ({"question": "chunk", "dataset_ids": None}, 4, ""), + ({"question": "chunk", "document_ids": None}, 0, "missing 1 required positional argument"), + ({"question": "chunk", "dataset_ids": None, "document_ids": None}, 4, ""), + ({"question": "chunk"}, 0, "missing 1 required positional argument"), + ], + ) + def test_basic_scenarios(self, client, add_chunks, payload, expected_page_size, expected_message): + dataset, document, _ = add_chunks + if "dataset_ids" in payload: + payload["dataset_ids"] = [dataset.id] + if "document_ids" in payload: + payload["document_ids"] = [document.id] + + if expected_message: + with pytest.raises(Exception) as excinfo: + client.retrieve(**payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + chunks = client.retrieve(**payload) + assert len(chunks) == expected_page_size, str(chunks) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "payload, expected_page_size, expected_message", + [ + pytest.param( + {"page": None, "page_size": 2}, + 2, + """TypeError("int() argument must be a string, a bytes-like object or a real number, not \'NoneType\'")""", + marks=pytest.mark.skip, + ), + pytest.param( + {"page": 0, "page_size": 2}, + 0, + "ValueError('Search does not support negative slicing.')", + marks=pytest.mark.skip, + ), + pytest.param({"page": 2, "page_size": 2}, 2, "", marks=pytest.mark.skip(reason="issues/6646")), + ({"page": 3, "page_size": 2}, 0, ""), + ({"page": "3", "page_size": 2}, 0, ""), + pytest.param( + {"page": -1, "page_size": 2}, + 0, + "ValueError('Search does not support negative slicing.')", + marks=pytest.mark.skip, + ), + pytest.param( + {"page": "a", "page_size": 2}, + 0, + """ValueError("invalid literal for int() with base 10: \'a\'")""", + marks=pytest.mark.skip, + ), + ], + ) + def test_page(self, client, add_chunks, payload, expected_page_size, expected_message): + dataset, _, _ = add_chunks + payload.update({"question": "chunk", "dataset_ids": [dataset.id]}) + + if expected_message: + with pytest.raises(Exception) as excinfo: + client.retrieve(**payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + chunks = client.retrieve(**payload) + assert len(chunks) == expected_page_size, str(chunks) + + @pytest.mark.p3 + @pytest.mark.parametrize( + "payload, expected_page_size, expected_message", + [ + pytest.param( + {"page_size": None}, + 0, + """TypeError("int() argument must be a string, a bytes-like object or a real number, not \'NoneType\'")""", + marks=pytest.mark.skip, + ), + ({"page_size": 1}, 1, ""), + ({"page_size": 5}, 4, ""), + ({"page_size": "1"}, 1, ""), + pytest.param( + {"page_size": "a"}, + 0, + """ValueError("invalid literal for int() with base 10: \'a\'")""", + marks=pytest.mark.skip, + ), + ], + ) + def test_page_size(self, client, add_chunks, payload, expected_page_size, expected_message): + dataset, _, _ = add_chunks + payload.update({"question": "chunk", "dataset_ids": [dataset.id]}) + + if expected_message: + with pytest.raises(Exception) as excinfo: + client.retrieve(**payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + chunks = client.retrieve(**payload) + assert len(chunks) == expected_page_size, str(chunks) + + @pytest.mark.p3 + @pytest.mark.parametrize( + "payload, expected_page_size, 
expected_message", + [ + ({"vector_similarity_weight": 0}, 4, ""), + ({"vector_similarity_weight": 0.5}, 4, ""), + ({"vector_similarity_weight": 10}, 4, ""), + pytest.param( + {"vector_similarity_weight": "a"}, + 0, + """ValueError("could not convert string to float: 'a'")""", + marks=pytest.mark.skip, + ), + ], + ) + def test_vector_similarity_weight(self, client, add_chunks, payload, expected_page_size, expected_message): + dataset, _, _ = add_chunks + payload.update({"question": "chunk", "dataset_ids": [dataset.id]}) + + if expected_message: + with pytest.raises(Exception) as excinfo: + client.retrieve(**payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + chunks = client.retrieve(**payload) + assert len(chunks) == expected_page_size, str(chunks) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "payload, expected_page_size, expected_message", + [ + ({"top_k": 10}, 4, ""), + pytest.param( + {"top_k": 1}, + 4, + "", + marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in ["infinity", "opensearch"], reason="Infinity"), + ), + pytest.param( + {"top_k": 1}, + 1, + "", + marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="elasticsearch"), + ), + pytest.param( + {"top_k": -1}, + 4, + "must be greater than 0", + marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in ["infinity", "opensearch"], reason="Infinity"), + ), + pytest.param( + {"top_k": -1}, + 4, + "3014", + marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="elasticsearch"), + ), + pytest.param( + {"top_k": "a"}, + 0, + """ValueError("invalid literal for int() with base 10: \'a\'")""", + marks=pytest.mark.skip, + ), + ], + ) + def test_top_k(self, client, add_chunks, payload, expected_page_size, expected_message): + dataset, _, _ = add_chunks + payload.update({"question": "chunk", "dataset_ids": [dataset.id]}) + + if expected_message: + with pytest.raises(Exception) as excinfo: + client.retrieve(**payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + chunks = client.retrieve(**payload) + assert len(chunks) == expected_page_size, str(chunks) + + @pytest.mark.skip + @pytest.mark.parametrize( + "payload, expected_message", + [ + ({"rerank_id": "BAAI/bge-reranker-v2-m3"}, ""), + pytest.param({"rerank_id": "unknown"}, "LookupError('Model(unknown) not authorized')", marks=pytest.mark.skip), + ], + ) + def test_rerank_id(self, client, add_chunks, payload, expected_message): + dataset, _, _ = add_chunks + payload.update({"question": "chunk", "dataset_ids": [dataset.id]}) + + if expected_message: + with pytest.raises(Exception) as excinfo: + client.retrieve(**payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + chunks = client.retrieve(**payload) + assert len(chunks) > 0, str(chunks) + + @pytest.mark.skip + @pytest.mark.parametrize( + "payload, expected_page_size, expected_message", + [ + ({"keyword": True}, 5, ""), + ({"keyword": "True"}, 5, ""), + ({"keyword": False}, 5, ""), + ({"keyword": "False"}, 5, ""), + ({"keyword": None}, 5, ""), + ], + ) + def test_keyword(self, client, add_chunks, payload, expected_page_size, expected_message): + dataset, _, _ = add_chunks + payload.update({"question": "chunk test", "dataset_ids": [dataset.id]}) + + if expected_message: + with pytest.raises(Exception) as excinfo: + client.retrieve(**payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + chunks = client.retrieve(**payload) + 
assert len(chunks) == expected_page_size, str(chunks) + + @pytest.mark.p3 + def test_concurrent_retrieval(self, client, add_chunks): + dataset, _, _ = add_chunks + count = 100 + payload = {"question": "chunk", "dataset_ids": [dataset.id]} + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(client.retrieve, **payload) for _ in range(count)] + responses = list(as_completed(futures)) + assert len(responses) == count, responses diff --git a/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_update_chunk.py b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_update_chunk.py new file mode 100644 index 00000000000..dc85d6385a0 --- /dev/null +++ b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_update_chunk.py @@ -0,0 +1,154 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +from concurrent.futures import ThreadPoolExecutor, as_completed +from random import randint + +import pytest + + +class TestUpdatedChunk: + @pytest.mark.p1 + @pytest.mark.parametrize( + "payload, expected_message", + [ + ({"content": None}, "TypeError('expected string or bytes-like object')"), + pytest.param( + {"content": ""}, + """APIRequestFailedError(\'Error code: 400, with error text {"error":{"code":"1213","message":"未正常接收到prompt参数。"}}\')""", + marks=pytest.mark.skip(reason="issues/6541"), + ), + pytest.param( + {"content": 1}, + "TypeError('expected string or bytes-like object')", + marks=pytest.mark.skip, + ), + ({"content": "update chunk"}, ""), + pytest.param( + {"content": " "}, + """APIRequestFailedError(\'Error code: 400, with error text {"error":{"code":"1213","message":"未正常接收到prompt参数。"}}\')""", + marks=pytest.mark.skip(reason="issues/6541"), + ), + ({"content": "\n!?。;!?\"'"}, ""), + ], + ) + def test_content(self, add_chunks, payload, expected_message): + _, _, chunks = add_chunks + chunk = chunks[0] + + if expected_message: + with pytest.raises(Exception) as excinfo: + chunk.update(payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + chunk.update(payload) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "payload, expected_message", + [ + ({"important_keywords": ["a", "b", "c"]}, ""), + ({"important_keywords": [""]}, ""), + ({"important_keywords": [1]}, "TypeError('sequence item 0: expected str instance, int found')"), + ({"important_keywords": ["a", "a"]}, ""), + ({"important_keywords": "abc"}, "`important_keywords` should be a list"), + ({"important_keywords": 123}, "`important_keywords` should be a list"), + ], + ) + def test_important_keywords(self, add_chunks, payload, expected_message): + _, _, chunks = add_chunks + chunk = chunks[0] + + if expected_message: + with pytest.raises(Exception) as excinfo: + chunk.update(payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + chunk.update(payload) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "payload, expected_message", + 
[ + ({"questions": ["a", "b", "c"]}, ""), + ({"questions": [""]}, ""), + ({"questions": [1]}, "TypeError('sequence item 0: expected str instance, int found')"), + ({"questions": ["a", "a"]}, ""), + ({"questions": "abc"}, "`questions` should be a list"), + ({"questions": 123}, "`questions` should be a list"), + ], + ) + def test_questions(self, add_chunks, payload, expected_message): + _, _, chunks = add_chunks + chunk = chunks[0] + + if expected_message: + with pytest.raises(Exception) as excinfo: + chunk.update(payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + chunk.update(payload) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "payload, expected_message", + [ + ({"available": True}, ""), + pytest.param({"available": "True"}, """ValueError("invalid literal for int() with base 10: \'True\'")""", marks=pytest.mark.skip), + ({"available": 1}, ""), + ({"available": False}, ""), + pytest.param({"available": "False"}, """ValueError("invalid literal for int() with base 10: \'False\'")""", marks=pytest.mark.skip), + ({"available": 0}, ""), + ], + ) + def test_available(self, add_chunks, payload, expected_message): + _, _, chunks = add_chunks + chunk = chunks[0] + + if expected_message: + with pytest.raises(Exception) as excinfo: + chunk.update(payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + chunk.update(payload) + + @pytest.mark.p3 + def test_repeated_update_chunk(self, add_chunks): + _, _, chunks = add_chunks + chunk = chunks[0] + + chunk.update({"content": "chunk test 1"}) + chunk.update({"content": "chunk test 2"}) + + @pytest.mark.p3 + @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6554") + def test_concurrent_update_chunk(self, add_chunks): + count = 50 + _, _, chunks = add_chunks + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(chunks[randint(0, 3)].update, {"content": f"update chunk test {i}"}) for i in range(count)] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + + @pytest.mark.p3 + def test_update_chunk_to_deleted_document(self, add_chunks): + dataset, document, chunks = add_chunks + dataset.delete_documents(ids=[document.id]) + + with pytest.raises(Exception) as excinfo: + chunks[0].update({}) + assert f"Can't find this chunk {chunks[0].id}" in str(excinfo.value), str(excinfo.value) diff --git a/test/testcases/test_sdk_api/test_dataset_mangement/conftest.py b/test/testcases/test_sdk_api/test_dataset_mangement/conftest.py new file mode 100644 index 00000000000..8d53eac2ee8 --- /dev/null +++ b/test/testcases/test_sdk_api/test_dataset_mangement/conftest.py @@ -0,0 +1,39 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + + +import pytest +from common import batch_create_datasets + + +@pytest.fixture(scope="class") +def add_datasets(client, request): + def cleanup(): + client.delete_datasets(**{"ids": None}) + + request.addfinalizer(cleanup) + + return batch_create_datasets(client, 5) + + +@pytest.fixture(scope="function") +def add_datasets_func(client, request): + def cleanup(): + client.delete_datasets(**{"ids": None}) + + request.addfinalizer(cleanup) + + return batch_create_datasets(client, 3) diff --git a/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py b/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py new file mode 100644 index 00000000000..4ba2696481e --- /dev/null +++ b/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py @@ -0,0 +1,656 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from concurrent.futures import ThreadPoolExecutor, as_completed +from operator import attrgetter + +import pytest +from configs import DATASET_NAME_LIMIT, HOST_ADDRESS, INVALID_API_TOKEN +from hypothesis import example, given, settings +from ragflow_sdk import DataSet, RAGFlow +from utils import encode_avatar +from utils.file_utils import create_image_file +from utils.hypothesis_utils import valid_names + + +@pytest.mark.usefixtures("clear_datasets") +class TestAuthorization: + @pytest.mark.p1 + @pytest.mark.parametrize( + "invalid_auth, expected_message", + [ + (None, "Authentication error: API key is invalid!"), + (INVALID_API_TOKEN, "Authentication error: API key is invalid!"), + ], + ids=["empty_auth", "invalid_api_token"], + ) + def test_auth_invalid(self, invalid_auth, expected_message): + client = RAGFlow(invalid_auth, HOST_ADDRESS) + with pytest.raises(Exception) as excinfo: + client.create_dataset(**{"name": "auth_test"}) + assert str(excinfo.value) == expected_message + + +@pytest.mark.usefixtures("clear_datasets") +class TestCapability: + @pytest.mark.p3 + def test_create_dataset_1k(self, client): + count = 1_000 + for i in range(count): + payload = {"name": f"dataset_{i}"} + client.create_dataset(**payload) + assert len(client.list_datasets(page_size=2000)) == count + + @pytest.mark.p3 + def test_create_dataset_concurrent(self, client): + count = 100 + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(client.create_dataset, **{"name": f"dataset_{i}"}) for i in range(100)] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + + +@pytest.mark.usefixtures("clear_datasets") +class TestDatasetCreate: + @pytest.mark.p1 + @given(name=valid_names()) + @example("a" * 128) + @settings(max_examples=20) + def test_name(self, client, name): + dataset = client.create_dataset(**{"name": name}) + assert dataset.name == name, str(dataset) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "name, expected_message", + [ + ("", "String should have at least 1 character"), + (" ", "String should have at least 1 character"), 
+ ("a" * (DATASET_NAME_LIMIT + 1), "String should have at most 128 characters"), + (0, "not instance of"), + (None, "not instance of"), + ], + ids=["empty_name", "space_name", "too_long_name", "invalid_name", "None_name"], + ) + def test_name_invalid(self, client, name, expected_message): + with pytest.raises(Exception) as excinfo: + client.create_dataset(**{"name": name}) + assert expected_message in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p3 + def test_name_duplicated(self, client): + name = "duplicated_name" + payload = {"name": name} + client.create_dataset(**payload) + + with pytest.raises(Exception) as excinfo: + client.create_dataset(**payload) + assert str(excinfo.value) == f"Dataset name '{name}' already exists", str(excinfo.value) + + @pytest.mark.p3 + def test_name_case_insensitive(self, client): + name = "CaseInsensitive" + payload = {"name": name.upper()} + client.create_dataset(**payload) + + payload = {"name": name.lower()} + with pytest.raises(Exception) as excinfo: + client.create_dataset(**payload) + assert str(excinfo.value) == f"Dataset name '{name.lower()}' already exists", str(excinfo.value) + + @pytest.mark.p2 + def test_avatar(self, client, tmp_path): + fn = create_image_file(tmp_path / "ragflow_test.png") + payload = { + "name": "avatar", + "avatar": f"data:image/png;base64,{encode_avatar(fn)}", + } + client.create_dataset(**payload) + + @pytest.mark.p2 + def test_avatar_exceeds_limit_length(self, client): + payload = {"name": "avatar_exceeds_limit_length", "avatar": "a" * 65536} + with pytest.raises(Exception) as excinfo: + client.create_dataset(**payload) + assert "String should have at most 65535 characters" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p3 + @pytest.mark.parametrize( + "name, prefix, expected_message", + [ + ("empty_prefix", "", "Missing MIME prefix. Expected format: data:;base64,"), + ("missing_comma", "data:image/png;base64", "Missing MIME prefix. Expected format: data:;base64,"), + ("unsupported_mine_type", "invalid_mine_prefix:image/png;base64,", "Invalid MIME prefix format. Must start with 'data:'"), + ("invalid_mine_type", "data:unsupported_mine_type;base64,", "Unsupported MIME type. 
Allowed: ['image/jpeg', 'image/png']"), + ], + ids=["empty_prefix", "missing_comma", "unsupported_mine_type", "invalid_mine_type"], + ) + def test_avatar_invalid_prefix(self, client, tmp_path, name, prefix, expected_message): + fn = create_image_file(tmp_path / "ragflow_test.png") + payload = { + "name": name, + "avatar": f"{prefix}{encode_avatar(fn)}", + } + with pytest.raises(Exception) as excinfo: + client.create_dataset(**payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p3 + def test_avatar_unset(self, client): + payload = {"name": "avatar_unset"} + dataset = client.create_dataset(**payload) + assert dataset.avatar is None, str(dataset) + + @pytest.mark.p2 + def test_description(self, client): + payload = {"name": "description", "description": "description"} + dataset = client.create_dataset(**payload) + assert dataset.description == "description", str(dataset) + + @pytest.mark.p2 + def test_description_exceeds_limit_length(self, client): + payload = {"name": "description_exceeds_limit_length", "description": "a" * 65536} + with pytest.raises(Exception) as excinfo: + client.create_dataset(**payload) + assert "String should have at most 65535 characters" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p3 + def test_description_unset(self, client): + payload = {"name": "description_unset"} + dataset = client.create_dataset(**payload) + assert dataset.description is None, str(dataset) + + @pytest.mark.p3 + def test_description_none(self, client): + payload = {"name": "description_none", "description": None} + dataset = client.create_dataset(**payload) + assert dataset.description is None, str(dataset) + + @pytest.mark.p1 + @pytest.mark.parametrize( + "name, embedding_model", + [ + ("BAAI/bge-large-zh-v1.5@BAAI", "BAAI/bge-large-zh-v1.5@BAAI"), + ("maidalun1020/bce-embedding-base_v1@Youdao", "maidalun1020/bce-embedding-base_v1@Youdao"), + ("embedding-3@ZHIPU-AI", "embedding-3@ZHIPU-AI"), + ], + ids=["builtin_baai", "builtin_youdao", "tenant_zhipu"], + ) + def test_embedding_model(self, client, name, embedding_model): + payload = {"name": name, "embedding_model": embedding_model} + dataset = client.create_dataset(**payload) + assert dataset.embedding_model == embedding_model, str(dataset) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "name, embedding_model", + [ + ("unknown_llm_name", "unknown@ZHIPU-AI"), + ("unknown_llm_factory", "embedding-3@unknown"), + ("tenant_no_auth_default_tenant_llm", "text-embedding-v3@Tongyi-Qianwen"), + ("tenant_no_auth", "text-embedding-3-small@OpenAI"), + ], + ids=["unknown_llm_name", "unknown_llm_factory", "tenant_no_auth_default_tenant_llm", "tenant_no_auth"], + ) + def test_embedding_model_invalid(self, client, name, embedding_model): + payload = {"name": name, "embedding_model": embedding_model} + with pytest.raises(Exception) as excinfo: + client.create_dataset(**payload) + if "tenant_no_auth" in name: + assert str(excinfo.value) == f"Unauthorized model: <{embedding_model}>", str(excinfo.value) + else: + assert str(excinfo.value) == f"Unsupported model: <{embedding_model}>", str(excinfo.value) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "name, embedding_model", + [ + ("missing_at", "BAAI/bge-large-zh-v1.5BAAI"), + ("missing_model_name", "@BAAI"), + ("missing_provider", "BAAI/bge-large-zh-v1.5@"), + ("whitespace_only_model_name", " @BAAI"), + ("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "), + ], + ids=["missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", 
"whitespace_only_provider"], + ) + def test_embedding_model_format(self, client, name, embedding_model): + payload = {"name": name, "embedding_model": embedding_model} + with pytest.raises(Exception) as excinfo: + client.create_dataset(**payload) + if name == "missing_at": + assert "Embedding model identifier must follow @ format" in str(excinfo.value), str(excinfo.value) + else: + assert "Both model_name and provider must be non-empty strings" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p2 + def test_embedding_model_unset(self, client): + payload = {"name": "embedding_model_unset"} + dataset = client.create_dataset(**payload) + assert dataset.embedding_model == "BAAI/bge-large-zh-v1.5@BAAI", str(dataset) + + @pytest.mark.p2 + def test_embedding_model_none(self, client): + payload = {"name": "embedding_model_none", "embedding_model": None} + with pytest.raises(Exception) as excinfo: + client.create_dataset(**payload) + assert "Input should be a valid string" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p1 + @pytest.mark.parametrize( + "name, permission", + [ + ("me", "me"), + ("team", "team"), + ("me_upercase", "ME"), + ("team_upercase", "TEAM"), + ("whitespace", " ME "), + ], + ids=["me", "team", "me_upercase", "team_upercase", "whitespace"], + ) + def test_permission(self, client, name, permission): + payload = {"name": name, "permission": permission} + dataset = client.create_dataset(**payload) + assert dataset.permission == permission.lower().strip(), str(dataset) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "name, permission", + [ + ("empty", ""), + ("unknown", "unknown"), + ], + ids=["empty", "unknown"], + ) + def test_permission_invalid(self, client, name, permission): + payload = {"name": name, "permission": permission} + with pytest.raises(Exception) as excinfo: + client.create_dataset(**payload) + assert "Input should be 'me' or 'team'" in str(excinfo.value) + + @pytest.mark.p2 + def test_permission_unset(self, client): + payload = {"name": "permission_unset"} + dataset = client.create_dataset(**payload) + assert dataset.permission == "me", str(dataset) + + @pytest.mark.p3 + def test_permission_none(self, client): + payload = {"name": "permission_none", "permission": None} + with pytest.raises(Exception) as excinfo: + client.create_dataset(**payload) + assert "not instance of" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p1 + @pytest.mark.parametrize( + "name, chunk_method", + [ + ("naive", "naive"), + ("book", "book"), + ("email", "email"), + ("laws", "laws"), + ("manual", "manual"), + ("one", "one"), + ("paper", "paper"), + ("picture", "picture"), + ("presentation", "presentation"), + ("qa", "qa"), + ("table", "table"), + ("tag", "tag"), + ], + ids=["naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"], + ) + def test_chunk_method(self, client, name, chunk_method): + payload = {"name": name, "chunk_method": chunk_method} + dataset = client.create_dataset(**payload) + assert dataset.chunk_method == chunk_method, str(dataset) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "name, chunk_method", + [ + ("empty", ""), + ("unknown", "unknown"), + ], + ids=["empty", "unknown"], + ) + def test_chunk_method_invalid(self, client, name, chunk_method): + payload = {"name": name, "chunk_method": chunk_method} + with pytest.raises(Exception) as excinfo: + client.create_dataset(**payload) + assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 
'presentation', 'qa', 'table' or 'tag'" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p2 + def test_chunk_method_unset(self, client): + payload = {"name": "chunk_method_unset"} + dataset = client.create_dataset(**payload) + assert dataset.chunk_method == "naive", str(dataset) + + @pytest.mark.p3 + def test_chunk_method_none(self, client): + payload = {"name": "chunk_method_none", "chunk_method": None} + with pytest.raises(Exception) as excinfo: + client.create_dataset(**payload) + assert "not instance of" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p1 + @pytest.mark.parametrize( + "name, parser_config", + [ + ("auto_keywords_min", {"auto_keywords": 0}), + ("auto_keywords_mid", {"auto_keywords": 16}), + ("auto_keywords_max", {"auto_keywords": 32}), + ("auto_questions_min", {"auto_questions": 0}), + ("auto_questions_mid", {"auto_questions": 5}), + ("auto_questions_max", {"auto_questions": 10}), + ("chunk_token_num_min", {"chunk_token_num": 1}), + ("chunk_token_num_mid", {"chunk_token_num": 1024}), + ("chunk_token_num_max", {"chunk_token_num": 2048}), + ("delimiter", {"delimiter": "\n"}), + ("delimiter_space", {"delimiter": " "}), + ("html4excel_true", {"html4excel": True}), + ("html4excel_false", {"html4excel": False}), + ("layout_recognize_DeepDOC", {"layout_recognize": "DeepDOC"}), + ("layout_recognize_navie", {"layout_recognize": "Plain Text"}), + ("tag_kb_ids", {"tag_kb_ids": ["1", "2"]}), + ("topn_tags_min", {"topn_tags": 1}), + ("topn_tags_mid", {"topn_tags": 5}), + ("topn_tags_max", {"topn_tags": 10}), + ("filename_embd_weight_min", {"filename_embd_weight": 0.1}), + ("filename_embd_weight_mid", {"filename_embd_weight": 0.5}), + ("filename_embd_weight_max", {"filename_embd_weight": 1.0}), + ("task_page_size_min", {"task_page_size": 1}), + ("task_page_size_None", {"task_page_size": None}), + ("pages", {"pages": [[1, 100]]}), + ("pages_none", {"pages": None}), + ("graphrag_true", {"graphrag": {"use_graphrag": True}}), + ("graphrag_false", {"graphrag": {"use_graphrag": False}}), + ("graphrag_entity_types", {"graphrag": {"entity_types": ["age", "sex", "height", "weight"]}}), + ("graphrag_method_general", {"graphrag": {"method": "general"}}), + ("graphrag_method_light", {"graphrag": {"method": "light"}}), + ("graphrag_community_true", {"graphrag": {"community": True}}), + ("graphrag_community_false", {"graphrag": {"community": False}}), + ("graphrag_resolution_true", {"graphrag": {"resolution": True}}), + ("graphrag_resolution_false", {"graphrag": {"resolution": False}}), + ("raptor_true", {"raptor": {"use_raptor": True}}), + ("raptor_false", {"raptor": {"use_raptor": False}}), + ("raptor_prompt", {"raptor": {"prompt": "Who are you?"}}), + ("raptor_max_token_min", {"raptor": {"max_token": 1}}), + ("raptor_max_token_mid", {"raptor": {"max_token": 1024}}), + ("raptor_max_token_max", {"raptor": {"max_token": 2048}}), + ("raptor_threshold_min", {"raptor": {"threshold": 0.0}}), + ("raptor_threshold_mid", {"raptor": {"threshold": 0.5}}), + ("raptor_threshold_max", {"raptor": {"threshold": 1.0}}), + ("raptor_max_cluster_min", {"raptor": {"max_cluster": 1}}), + ("raptor_max_cluster_mid", {"raptor": {"max_cluster": 512}}), + ("raptor_max_cluster_max", {"raptor": {"max_cluster": 1024}}), + ("raptor_random_seed_min", {"raptor": {"random_seed": 0}}), + ], + ids=[ + "auto_keywords_min", + "auto_keywords_mid", + "auto_keywords_max", + "auto_questions_min", + "auto_questions_mid", + "auto_questions_max", + "chunk_token_num_min", + "chunk_token_num_mid", + "chunk_token_num_max", 
+ "delimiter", + "delimiter_space", + "html4excel_true", + "html4excel_false", + "layout_recognize_DeepDOC", + "layout_recognize_navie", + "tag_kb_ids", + "topn_tags_min", + "topn_tags_mid", + "topn_tags_max", + "filename_embd_weight_min", + "filename_embd_weight_mid", + "filename_embd_weight_max", + "task_page_size_min", + "task_page_size_None", + "pages", + "pages_none", + "graphrag_true", + "graphrag_false", + "graphrag_entity_types", + "graphrag_method_general", + "graphrag_method_light", + "graphrag_community_true", + "graphrag_community_false", + "graphrag_resolution_true", + "graphrag_resolution_false", + "raptor_true", + "raptor_false", + "raptor_prompt", + "raptor_max_token_min", + "raptor_max_token_mid", + "raptor_max_token_max", + "raptor_threshold_min", + "raptor_threshold_mid", + "raptor_threshold_max", + "raptor_max_cluster_min", + "raptor_max_cluster_mid", + "raptor_max_cluster_max", + "raptor_random_seed_min", + ], + ) + def test_parser_config(self, client, name, parser_config): + parser_config_o = DataSet.ParserConfig(client, parser_config) + payload = {"name": name, "parser_config": parser_config_o} + dataset = client.create_dataset(**payload) + for k, v in parser_config.items(): + if isinstance(v, dict): + for kk, vv in v.items(): + assert attrgetter(f"{k}.{kk}")(dataset.parser_config) == vv, str(dataset) + else: + assert attrgetter(k)(dataset.parser_config) == v, str(dataset) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "name, parser_config, expected_message", + [ + ("auto_keywords_min_limit", {"auto_keywords": -1}, "Input should be greater than or equal to 0"), + ("auto_keywords_max_limit", {"auto_keywords": 33}, "Input should be less than or equal to 32"), + ("auto_keywords_float_not_allowed", {"auto_keywords": 3.14}, "Input should be a valid integer, got a number with a fractional part"), + ("auto_keywords_type_invalid", {"auto_keywords": "string"}, "Input should be a valid integer, unable to parse string as an integer"), + ("auto_questions_min_limit", {"auto_questions": -1}, "Input should be greater than or equal to 0"), + ("auto_questions_max_limit", {"auto_questions": 11}, "Input should be less than or equal to 10"), + ("auto_questions_float_not_allowed", {"auto_questions": 3.14}, "Input should be a valid integer, got a number with a fractional part"), + ("auto_questions_type_invalid", {"auto_questions": "string"}, "Input should be a valid integer, unable to parse string as an integer"), + ("chunk_token_num_min_limit", {"chunk_token_num": 0}, "Input should be greater than or equal to 1"), + ("chunk_token_num_max_limit", {"chunk_token_num": 2049}, "Input should be less than or equal to 2048"), + ("chunk_token_num_float_not_allowed", {"chunk_token_num": 3.14}, "Input should be a valid integer, got a number with a fractional part"), + ("chunk_token_num_type_invalid", {"chunk_token_num": "string"}, "Input should be a valid integer, unable to parse string as an integer"), + ("delimiter_empty", {"delimiter": ""}, "String should have at least 1 character"), + ("html4excel_type_invalid", {"html4excel": "string"}, "Input should be a valid boolean, unable to interpret input"), + ("tag_kb_ids_not_list", {"tag_kb_ids": "1,2"}, "Input should be a valid list"), + ("tag_kb_ids_int_in_list", {"tag_kb_ids": [1, 2]}, "Input should be a valid string"), + ("topn_tags_min_limit", {"topn_tags": 0}, "Input should be greater than or equal to 1"), + ("topn_tags_max_limit", {"topn_tags": 11}, "Input should be less than or equal to 10"), + ("topn_tags_float_not_allowed", {"topn_tags": 
3.14}, "Input should be a valid integer, got a number with a fractional part"), + ("topn_tags_type_invalid", {"topn_tags": "string"}, "Input should be a valid integer, unable to parse string as an integer"), + ("filename_embd_weight_min_limit", {"filename_embd_weight": -1}, "Input should be greater than or equal to 0"), + ("filename_embd_weight_max_limit", {"filename_embd_weight": 1.1}, "Input should be less than or equal to 1"), + ("filename_embd_weight_type_invalid", {"filename_embd_weight": "string"}, "Input should be a valid number, unable to parse string as a number"), + ("task_page_size_min_limit", {"task_page_size": 0}, "Input should be greater than or equal to 1"), + ("task_page_size_float_not_allowed", {"task_page_size": 3.14}, "Input should be a valid integer, got a number with a fractional part"), + ("task_page_size_type_invalid", {"task_page_size": "string"}, "Input should be a valid integer, unable to parse string as an integer"), + ("pages_not_list", {"pages": "1,2"}, "Input should be a valid list"), + ("pages_not_list_in_list", {"pages": ["1,2"]}, "Input should be a valid list"), + ("pages_not_int_list", {"pages": [["string1", "string2"]]}, "Input should be a valid integer, unable to parse string as an integer"), + ("graphrag_type_invalid", {"graphrag": {"use_graphrag": "string"}}, "Input should be a valid boolean, unable to interpret input"), + ("graphrag_entity_types_not_list", {"graphrag": {"entity_types": "1,2"}}, "Input should be a valid list"), + ("graphrag_entity_types_not_str_in_list", {"graphrag": {"entity_types": [1, 2]}}, "nput should be a valid string"), + ("graphrag_method_unknown", {"graphrag": {"method": "unknown"}}, "Input should be 'light' or 'general'"), + ("graphrag_method_none", {"graphrag": {"method": None}}, "Input should be 'light' or 'general'"), + ("graphrag_community_type_invalid", {"graphrag": {"community": "string"}}, "Input should be a valid boolean, unable to interpret input"), + ("graphrag_resolution_type_invalid", {"graphrag": {"resolution": "string"}}, "Input should be a valid boolean, unable to interpret input"), + ("raptor_type_invalid", {"raptor": {"use_raptor": "string"}}, "Input should be a valid boolean, unable to interpret input"), + ("raptor_prompt_empty", {"raptor": {"prompt": ""}}, "String should have at least 1 character"), + ("raptor_prompt_space", {"raptor": {"prompt": " "}}, "String should have at least 1 character"), + ("raptor_max_token_min_limit", {"raptor": {"max_token": 0}}, "Input should be greater than or equal to 1"), + ("raptor_max_token_max_limit", {"raptor": {"max_token": 2049}}, "Input should be less than or equal to 2048"), + ("raptor_max_token_float_not_allowed", {"raptor": {"max_token": 3.14}}, "Input should be a valid integer, got a number with a fractional part"), + ("raptor_max_token_type_invalid", {"raptor": {"max_token": "string"}}, "Input should be a valid integer, unable to parse string as an integer"), + ("raptor_threshold_min_limit", {"raptor": {"threshold": -0.1}}, "Input should be greater than or equal to 0"), + ("raptor_threshold_max_limit", {"raptor": {"threshold": 1.1}}, "Input should be less than or equal to 1"), + ("raptor_threshold_type_invalid", {"raptor": {"threshold": "string"}}, "Input should be a valid number, unable to parse string as a number"), + ("raptor_max_cluster_min_limit", {"raptor": {"max_cluster": 0}}, "Input should be greater than or equal to 1"), + ("raptor_max_cluster_max_limit", {"raptor": {"max_cluster": 1025}}, "Input should be less than or equal to 1024"), + 
("raptor_max_cluster_float_not_allowed", {"raptor": {"max_cluster": 3.14}}, "Input should be a valid integer, got a number with a fractional par"), + ("raptor_max_cluster_type_invalid", {"raptor": {"max_cluster": "string"}}, "Input should be a valid integer, unable to parse string as an integer"), + ("raptor_random_seed_min_limit", {"raptor": {"random_seed": -1}}, "Input should be greater than or equal to 0"), + ("raptor_random_seed_float_not_allowed", {"raptor": {"random_seed": 3.14}}, "Input should be a valid integer, got a number with a fractional part"), + ("raptor_random_seed_type_invalid", {"raptor": {"random_seed": "string"}}, "Input should be a valid integer, unable to parse string as an integer"), + ("parser_config_type_invalid", {"delimiter": "a" * 65536}, "Parser config exceeds size limit (max 65,535 characters)"), + ], + ids=[ + "auto_keywords_min_limit", + "auto_keywords_max_limit", + "auto_keywords_float_not_allowed", + "auto_keywords_type_invalid", + "auto_questions_min_limit", + "auto_questions_max_limit", + "auto_questions_float_not_allowed", + "auto_questions_type_invalid", + "chunk_token_num_min_limit", + "chunk_token_num_max_limit", + "chunk_token_num_float_not_allowed", + "chunk_token_num_type_invalid", + "delimiter_empty", + "html4excel_type_invalid", + "tag_kb_ids_not_list", + "tag_kb_ids_int_in_list", + "topn_tags_min_limit", + "topn_tags_max_limit", + "topn_tags_float_not_allowed", + "topn_tags_type_invalid", + "filename_embd_weight_min_limit", + "filename_embd_weight_max_limit", + "filename_embd_weight_type_invalid", + "task_page_size_min_limit", + "task_page_size_float_not_allowed", + "task_page_size_type_invalid", + "pages_not_list", + "pages_not_list_in_list", + "pages_not_int_list", + "graphrag_type_invalid", + "graphrag_entity_types_not_list", + "graphrag_entity_types_not_str_in_list", + "graphrag_method_unknown", + "graphrag_method_none", + "graphrag_community_type_invalid", + "graphrag_resolution_type_invalid", + "raptor_type_invalid", + "raptor_prompt_empty", + "raptor_prompt_space", + "raptor_max_token_min_limit", + "raptor_max_token_max_limit", + "raptor_max_token_float_not_allowed", + "raptor_max_token_type_invalid", + "raptor_threshold_min_limit", + "raptor_threshold_max_limit", + "raptor_threshold_type_invalid", + "raptor_max_cluster_min_limit", + "raptor_max_cluster_max_limit", + "raptor_max_cluster_float_not_allowed", + "raptor_max_cluster_type_invalid", + "raptor_random_seed_min_limit", + "raptor_random_seed_float_not_allowed", + "raptor_random_seed_type_invalid", + "parser_config_type_invalid", + ], + ) + def test_parser_config_invalid(self, client, name, parser_config, expected_message): + parser_config_o = DataSet.ParserConfig(client, parser_config) + payload = {"name": name, "parser_config": parser_config_o} + with pytest.raises(Exception) as excinfo: + client.create_dataset(**payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p2 + def test_parser_config_empty(self, client): + excepted_value = DataSet.ParserConfig( + client, + { + "chunk_token_num": 128, + "delimiter": r"\n", + "html4excel": False, + "layout_recognize": "DeepDOC", + "raptor": {"use_raptor": False}, + }, + ) + parser_config_o = DataSet.ParserConfig(client, {}) + payload = {"name": "parser_config_empty", "parser_config": parser_config_o} + dataset = client.create_dataset(**payload) + assert str(dataset.parser_config) == str(excepted_value), str(dataset) + + @pytest.mark.p2 + def test_parser_config_unset(self, client): + excepted_value = 
DataSet.ParserConfig( + client, + { + "chunk_token_num": 128, + "delimiter": r"\n", + "html4excel": False, + "layout_recognize": "DeepDOC", + "raptor": {"use_raptor": False}, + }, + ) + payload = {"name": "parser_config_unset"} + dataset = client.create_dataset(**payload) + assert str(dataset.parser_config) == str(excepted_value), str(dataset) + + @pytest.mark.p3 + def test_parser_config_none(self, client): + excepted_value = DataSet.ParserConfig( + client, + { + "chunk_token_num": 128, + "delimiter": r"\n", + "html4excel": False, + "layout_recognize": "DeepDOC", + "raptor": {"use_raptor": False}, + }, + ) + payload = {"name": "parser_config_empty", "parser_config": None} + dataset = client.create_dataset(**payload) + assert str(dataset.parser_config) == str(excepted_value), str(dataset) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "payload", + [ + {"name": "id", "id": "id"}, + {"name": "tenant_id", "tenant_id": "e57c1966f99211efb41e9e45646e0111"}, + {"name": "created_by", "created_by": "created_by"}, + {"name": "create_date", "create_date": "Tue, 11 Mar 2025 13:37:23 GMT"}, + {"name": "create_time", "create_time": 1741671443322}, + {"name": "update_date", "update_date": "Tue, 11 Mar 2025 13:37:23 GMT"}, + {"name": "update_time", "update_time": 1741671443339}, + {"name": "document_count", "document_count": 1}, + {"name": "chunk_count", "chunk_count": 1}, + {"name": "token_num", "token_num": 1}, + {"name": "status", "status": "1"}, + {"name": "pagerank", "pagerank": 50}, + {"name": "unknown_field", "unknown_field": "unknown_field"}, + ], + ) + def test_unsupported_field(self, client, payload): + with pytest.raises(Exception) as excinfo: + client.create_dataset(**payload) + assert "got an unexpected keyword argument" in str(excinfo.value), str(excinfo.value) diff --git a/test/testcases/test_sdk_api/test_dataset_mangement/test_delete_datasets.py b/test/testcases/test_sdk_api/test_dataset_mangement/test_delete_datasets.py new file mode 100644 index 00000000000..5a27d89bc69 --- /dev/null +++ b/test/testcases/test_sdk_api/test_dataset_mangement/test_delete_datasets.py @@ -0,0 +1,178 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import uuid +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from common import batch_create_datasets +from configs import HOST_ADDRESS, INVALID_API_TOKEN +from ragflow_sdk import RAGFlow + + +class TestAuthorization: + @pytest.mark.p1 + @pytest.mark.parametrize( + "invalid_auth, expected_message", + [ + (None, "Authentication error: API key is invalid!"), + (INVALID_API_TOKEN, "Authentication error: API key is invalid!"), + ], + ) + def test_auth_invalid(self, invalid_auth, expected_message): + client = RAGFlow(invalid_auth, HOST_ADDRESS) + with pytest.raises(Exception) as excinfo: + client.delete_datasets() + assert str(excinfo.value) == expected_message + + +class TestCapability: + @pytest.mark.p3 + def test_delete_dataset_1k(self, client): + datasets = batch_create_datasets(client, 1_000) + client.delete_datasets(**{"ids": [dataset.id for dataset in datasets]}) + + datasets = client.list_datasets() + assert len(datasets) == 0, datasets + + @pytest.mark.p3 + def test_concurrent_deletion(self, client): + count = 1_000 + datasets = batch_create_datasets(client, count) + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(client.delete_datasets, **{"ids": [dataset.id for dataset in datasets][i : i + 1]}) for i in range(count)] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + + datasets = client.list_datasets() + assert len(datasets) == 0, datasets + + +class TestDatasetsDelete: + @pytest.mark.p1 + @pytest.mark.parametrize( + "func, remaining", + [ + (lambda r: {"ids": r[:1]}, 2), + (lambda r: {"ids": r}, 0), + ], + ids=["single_dataset", "multiple_datasets"], + ) + def test_ids(self, client, add_datasets_func, func, remaining): + if callable(func): + payload = func([dataset.id for dataset in add_datasets_func]) + client.delete_datasets(**payload) + + datasets = client.list_datasets() + assert len(datasets) == remaining, str(datasets) + + @pytest.mark.p1 + @pytest.mark.usefixtures("add_dataset_func") + def test_ids_empty(self, client): + payload = {"ids": []} + client.delete_datasets(**payload) + + datasets = client.list_datasets() + assert len(datasets) == 1, str(datasets) + + @pytest.mark.p1 + @pytest.mark.usefixtures("add_datasets_func") + def test_ids_none(self, client): + payload = {"ids": None} + client.delete_datasets(**payload) + + datasets = client.list_datasets() + assert len(datasets) == 0, str(datasets) + + @pytest.mark.p2 + @pytest.mark.usefixtures("add_dataset_func") + def test_id_not_uuid(self, client): + payload = {"ids": ["not_uuid"]} + with pytest.raises(Exception) as excinfo: + client.delete_datasets(**payload) + assert "Invalid UUID1 format" in str(excinfo.value), str(excinfo.value) + + datasets = client.list_datasets() + assert len(datasets) == 1, str(datasets) + + @pytest.mark.p3 + @pytest.mark.usefixtures("add_dataset_func") + def test_id_not_uuid1(self, client): + payload = {"ids": [uuid.uuid4().hex]} + with pytest.raises(Exception) as excinfo: + client.delete_datasets(**payload) + assert "Invalid UUID1 format" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p2 + @pytest.mark.usefixtures("add_dataset_func") + def test_id_wrong_uuid(self, client): + payload = {"ids": ["d94a8dc02c9711f0930f7fbc369eab6d"]} + with pytest.raises(Exception) as excinfo: + client.delete_datasets(**payload) + assert "lacks permission for dataset" in str(excinfo.value), str(excinfo.value) + + datasets = client.list_datasets() + assert len(datasets) == 1, str(datasets) + + 
@pytest.mark.p2 + @pytest.mark.parametrize( + "func", + [ + lambda r: {"ids": ["d94a8dc02c9711f0930f7fbc369eab6d"] + r}, + lambda r: {"ids": r[:1] + ["d94a8dc02c9711f0930f7fbc369eab6d"] + r[1:3]}, + lambda r: {"ids": r + ["d94a8dc02c9711f0930f7fbc369eab6d"]}, + ], + ) + def test_ids_partial_invalid(self, client, add_datasets_func, func): + if callable(func): + payload = func([dataset.id for dataset in add_datasets_func]) + with pytest.raises(Exception) as excinfo: + client.delete_datasets(**payload) + assert "lacks permission for dataset" in str(excinfo.value), str(excinfo.value) + + datasets = client.list_datasets() + assert len(datasets) == 3, str(datasets) + + @pytest.mark.p2 + def test_ids_duplicate(self, client, add_datasets_func): + dataset_ids = [dataset.id for dataset in add_datasets_func] + payload = {"ids": dataset_ids + dataset_ids} + with pytest.raises(Exception) as excinfo: + client.delete_datasets(**payload) + assert "Duplicate ids:" in str(excinfo.value), str(excinfo.value) + + datasets = client.list_datasets() + assert len(datasets) == 3, str(datasets) + + @pytest.mark.p2 + def test_repeated_delete(self, client, add_datasets_func): + dataset_ids = [dataset.id for dataset in add_datasets_func] + payload = {"ids": dataset_ids} + client.delete_datasets(**payload) + + with pytest.raises(Exception) as excinfo: + client.delete_datasets(**payload) + assert "lacks permission for dataset" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p2 + @pytest.mark.usefixtures("add_dataset_func") + def test_field_unsupported(self, client): + payload = {"unknown_field": "unknown_field"} + with pytest.raises(Exception) as excinfo: + client.delete_datasets(**payload) + assert "got an unexpected keyword argument 'unknown_field'" in str(excinfo.value), str(excinfo.value) + + datasets = client.list_datasets() + assert len(datasets) == 1, str(datasets) diff --git a/test/testcases/test_sdk_api/test_dataset_mangement/test_list_datasets.py b/test/testcases/test_sdk_api/test_dataset_mangement/test_list_datasets.py new file mode 100644 index 00000000000..f067acda5e4 --- /dev/null +++ b/test/testcases/test_sdk_api/test_dataset_mangement/test_list_datasets.py @@ -0,0 +1,313 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import uuid +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from configs import HOST_ADDRESS, INVALID_API_TOKEN +from ragflow_sdk import RAGFlow + + +class TestAuthorization: + @pytest.mark.p1 + @pytest.mark.parametrize( + "invalid_auth, expected_message", + [ + (None, "Authentication error: API key is invalid!"), + (INVALID_API_TOKEN, "Authentication error: API key is invalid!"), + ], + ) + def test_auth_invalid(self, invalid_auth, expected_message): + client = RAGFlow(invalid_auth, HOST_ADDRESS) + with pytest.raises(Exception) as excinfo: + client.list_datasets() + assert expected_message in str(excinfo.value) + + +class TestCapability: + @pytest.mark.p3 + def test_concurrent_list(self, client): + count = 100 + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [ + executor.submit( + client.list_datasets, + ) + for i in range(count) + ] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + + +@pytest.mark.usefixtures("add_datasets") +class TestDatasetsList: + @pytest.mark.p1 + def test_params_unset(self, client): + datasets = client.list_datasets() + assert len(datasets) == 5, str(datasets) + + @pytest.mark.p2 + def test_params_empty(self, client): + datasets = client.list_datasets(**{}) + assert len(datasets) == 5, str(datasets) + + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_page_size", + [ + ({"page": 2, "page_size": 2}, 2), + ({"page": 3, "page_size": 2}, 1), + ({"page": 4, "page_size": 2}, 0), + ({"page": 1, "page_size": 10}, 5), + ], + ids=["normal_middle_page", "normal_last_partial_page", "beyond_max_page", "full_data_single_page"], + ) + def test_page(self, client, params, expected_page_size): + datasets = client.list_datasets(**params) + assert len(datasets) == expected_page_size, str(datasets) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "params, expected_message", + [ + ({"page": 0}, "Input should be greater than or equal to 1"), + ({"page": "a"}, "not instance of"), + ], + ids=["page_0", "page_a"], + ) + def test_page_invalid(self, client, params, expected_message): + with pytest.raises(Exception) as excinfo: + client.list_datasets(**params) + assert expected_message in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p2 + def test_page_none(self, client): + params = {"page": None} + with pytest.raises(Exception) as excinfo: + client.list_datasets(**params) + assert "not instance of" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_page_size", + [ + ({"page_size": 1}, 1), + ({"page_size": 3}, 3), + ({"page_size": 5}, 5), + ({"page_size": 6}, 5), + ], + ids=["min_valid_page_size", "medium_page_size", "page_size_equals_total", "page_size_exceeds_total"], + ) + def test_page_size(self, client, params, expected_page_size): + datasets = client.list_datasets(**params) + assert len(datasets) == expected_page_size, str(datasets) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "params, expected_message", + [ + ({"page_size": 0}, "Input should be greater than or equal to 1"), + ({"page_size": "a"}, "not instance of"), + ], + ) + def test_page_size_invalid(self, client, params, expected_message): + with pytest.raises(Exception) as excinfo: + client.list_datasets(**params) + assert expected_message in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p2 + def test_page_size_none(self, client): + params = {"page_size": None} + with pytest.raises(Exception) as excinfo: + client.list_datasets(**params) + 
assert "not instance of" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "params", + [ + {"orderby": "create_time"}, + {"orderby": "update_time"}, + {"orderby": "CREATE_TIME"}, + {"orderby": "UPDATE_TIME"}, + {"orderby": " create_time "}, + ], + ids=["orderby_create_time", "orderby_update_time", "orderby_create_time_upper", "orderby_update_time_upper", "whitespace"], + ) + def test_orderby(self, client, params): + client.list_datasets(**params) + + @pytest.mark.p3 + @pytest.mark.parametrize( + "params", + [ + {"orderby": ""}, + {"orderby": "unknown"}, + ], + ids=["empty", "unknown"], + ) + def test_orderby_invalid(self, client, params): + with pytest.raises(Exception) as excinfo: + client.list_datasets(**params) + assert "Input should be 'create_time' or 'update_time'" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p3 + def test_orderby_none(self, client): + params = {"orderby": None} + with pytest.raises(Exception) as excinfo: + client.list_datasets(**params) + assert "not instance of" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "params", + [ + {"desc": True}, + {"desc": False}, + ], + ids=["desc=True", "desc=False"], + ) + def test_desc(self, client, params): + client.list_datasets(**params) + + @pytest.mark.p3 + @pytest.mark.parametrize( + "params", + [ + {"desc": 3.14}, + {"desc": "unknown"}, + ], + ids=["float_value", "invalid_string"], + ) + def test_desc_invalid(self, client, params): + with pytest.raises(Exception) as excinfo: + client.list_datasets(**params) + assert "not instance of" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p3 + def test_desc_none(self, client): + params = {"desc": None} + with pytest.raises(Exception) as excinfo: + client.list_datasets(**params) + assert "not instance of" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p1 + def test_name(self, client): + params = {"name": "dataset_1"} + datasets = client.list_datasets(**params) + assert len(datasets) == 1, str(datasets) + assert datasets[0].name == "dataset_1", str(datasets) + + @pytest.mark.p2 + def test_name_wrong(self, client): + params = {"name": "wrong name"} + with pytest.raises(Exception) as excinfo: + client.list_datasets(**params) + assert "lacks permission for dataset" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p2 + def test_name_empty(self, client): + params = {"name": ""} + datasets = client.list_datasets(**params) + assert len(datasets) == 5, str(datasets) + + @pytest.mark.p2 + def test_name_none(self, client): + params = {"name": None} + datasets = client.list_datasets(**params) + assert len(datasets) == 5, str(datasets) + + @pytest.mark.p1 + def test_id(self, client, add_datasets): + dataset_ids = [dataset.id for dataset in add_datasets] + params = {"id": dataset_ids[0]} + datasets = client.list_datasets(**params) + assert len(datasets) == 1, str(datasets) + assert datasets[0].id == dataset_ids[0], str(datasets) + + @pytest.mark.p2 + def test_id_not_uuid(self, client): + params = {"id": "not_uuid"} + with pytest.raises(Exception) as excinfo: + client.list_datasets(**params) + assert "Invalid UUID1 format" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p2 + def test_id_not_uuid1(self, client): + params = {"id": uuid.uuid4().hex} + with pytest.raises(Exception) as excinfo: + client.list_datasets(**params) + assert "Invalid UUID1 format" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p2 + def test_id_wrong_uuid(self, client): + 
params = {"id": "d94a8dc02c9711f0930f7fbc369eab6d"} + with pytest.raises(Exception) as excinfo: + client.list_datasets(**params) + assert "lacks permission for dataset" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p2 + def test_id_empty(self, client): + params = {"id": ""} + with pytest.raises(Exception) as excinfo: + client.list_datasets(**params) + assert "Invalid UUID1 format" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p2 + def test_id_none(self, client): + params = {"id": None} + datasets = client.list_datasets(**params) + assert len(datasets) == 5, str(datasets) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "func, name, expected_num", + [ + (lambda r: r[0].id, "dataset_0", 1), + (lambda r: r[0].id, "dataset_1", 0), + ], + ids=["name_and_id_match", "name_and_id_mismatch"], + ) + def test_name_and_id(self, client, add_datasets, func, name, expected_num): + if callable(func): + params = {"id": func(add_datasets), "name": name} + datasets = client.list_datasets(**params) + assert len(datasets) == expected_num, str(datasets) + + @pytest.mark.p3 + @pytest.mark.parametrize( + "dataset_id, name", + [ + (lambda r: r[0].id, "wrong_name"), + (uuid.uuid1().hex, "dataset_0"), + ], + ids=["name", "id"], + ) + def test_name_and_id_wrong(self, client, add_datasets, dataset_id, name): + if callable(dataset_id): + params = {"id": dataset_id(add_datasets), "name": name} + else: + params = {"id": dataset_id, "name": name} + with pytest.raises(Exception) as excinfo: + client.list_datasets(**params) + assert "lacks permission for dataset" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p2 + def test_field_unsupported(self, client): + params = {"unknown_field": "unknown_field"} + with pytest.raises(Exception) as excinfo: + client.list_datasets(**params) + assert "got an unexpected keyword argument" in str(excinfo.value), str(excinfo.value) diff --git a/test/testcases/test_sdk_api/test_dataset_mangement/test_update_dataset.py b/test/testcases/test_sdk_api/test_dataset_mangement/test_update_dataset.py new file mode 100644 index 00000000000..f4a0a916324 --- /dev/null +++ b/test/testcases/test_sdk_api/test_dataset_mangement/test_update_dataset.py @@ -0,0 +1,750 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import os +from concurrent.futures import ThreadPoolExecutor, as_completed +from operator import attrgetter + +import pytest +from configs import DATASET_NAME_LIMIT +from hypothesis import HealthCheck, example, given, settings +from ragflow_sdk import DataSet +from utils import encode_avatar +from utils.file_utils import create_image_file +from utils.hypothesis_utils import valid_names + + +class TestRquest: + @pytest.mark.p2 + def test_payload_empty(self, add_dataset_func): + dataset = add_dataset_func + with pytest.raises(Exception) as excinfo: + dataset.update({}) + assert "No properties were modified" in str(excinfo.value), str(excinfo.value) + + +class TestCapability: + @pytest.mark.p3 + def test_update_dateset_concurrent(self, add_dataset_func): + dataset = add_dataset_func + count = 100 + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(dataset.update, {"name": f"dataset_{i}"}) for i in range(count)] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + + +class TestDatasetUpdate: + @pytest.mark.p1 + @given(name=valid_names()) + @example("a" * 128) + @settings(max_examples=20, suppress_health_check=[HealthCheck.function_scoped_fixture]) + def test_name(self, client, add_dataset_func, name): + dataset = add_dataset_func + payload = {"name": name} + dataset.update(payload) + assert dataset.name == name, str(dataset) + + retrieved_dataset = client.get_dataset(name=dataset.name) + assert retrieved_dataset.name == name, str(retrieved_dataset) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "name, expected_message", + [ + ("", "String should have at least 1 character"), + (" ", "String should have at least 1 character"), + ("a" * (DATASET_NAME_LIMIT + 1), "String should have at most 128 characters"), + (0, "Input should be a valid string"), + (None, "Input should be a valid string"), + ], + ids=["empty_name", "space_name", "too_long_name", "invalid_name", "None_name"], + ) + def test_name_invalid(self, add_dataset_func, name, expected_message): + dataset = add_dataset_func + with pytest.raises(Exception) as excinfo: + dataset.update({"name": name}) + assert expected_message in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p3 + def test_name_duplicated(self, add_datasets_func): + datasets = add_datasets_func + name = "dataset_1" + with pytest.raises(Exception) as excinfo: + datasets[0].update({"name": name}) + assert f"Dataset name '{name}' already exists" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p3 + def test_name_case_insensitive(self, add_datasets_func): + dataset = add_datasets_func[0] + name = "DATASET_1" + with pytest.raises(Exception) as excinfo: + dataset.update({"name": name}) + assert f"Dataset name '{name}' already exists" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p2 + def test_avatar(self, client, add_dataset_func, tmp_path): + dataset = add_dataset_func + fn = create_image_file(tmp_path / "ragflow_test.png") + avatar_data = f"data:image/png;base64,{encode_avatar(fn)}" + dataset.update({"avatar": avatar_data}) + assert dataset.avatar == avatar_data, str(dataset) + + retrieved_dataset = client.get_dataset(name=dataset.name) + assert retrieved_dataset.avatar == avatar_data, str(retrieved_dataset) + + @pytest.mark.p2 + def test_avatar_exceeds_limit_length(self, add_dataset_func): + dataset = add_dataset_func + with pytest.raises(Exception) as excinfo: + dataset.update({"avatar": "a" * 65536}) + assert "String should have at most 65535 characters" in 
str(excinfo.value), str(excinfo.value) + + @pytest.mark.p3 + @pytest.mark.parametrize( + "avatar_prefix, expected_message", + [ + ("", "Missing MIME prefix. Expected format: data:<mime>;base64,<data>"), + ("data:image/png;base64", "Missing MIME prefix. Expected format: data:<mime>;base64,<data>"), + ("invalid_mine_prefix:image/png;base64,", "Invalid MIME prefix format. Must start with 'data:'"), + ("data:unsupported_mine_type;base64,", "Unsupported MIME type. Allowed: ['image/jpeg', 'image/png']"), + ], + ids=["empty_prefix", "missing_comma", "unsupported_mine_type", "invalid_mine_type"], + ) + def test_avatar_invalid_prefix(self, add_dataset_func, tmp_path, avatar_prefix, expected_message): + dataset = add_dataset_func + fn = create_image_file(tmp_path / "ragflow_test.png") + with pytest.raises(Exception) as excinfo: + dataset.update({"avatar": f"{avatar_prefix}{encode_avatar(fn)}"}) + assert expected_message in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p3 + def test_avatar_none(self, client, add_dataset_func): + dataset = add_dataset_func + dataset.update({"avatar": None}) + assert dataset.avatar is None, str(dataset) + + retrieved_dataset = client.get_dataset(name=dataset.name) + assert retrieved_dataset.avatar is None, str(retrieved_dataset) + + @pytest.mark.p2 + def test_description(self, client, add_dataset_func): + dataset = add_dataset_func + dataset.update({"description": "description"}) + assert dataset.description == "description", str(dataset) + + retrieved_dataset = client.get_dataset(name=dataset.name) + assert retrieved_dataset.description == "description", str(retrieved_dataset) + + @pytest.mark.p2 + def test_description_exceeds_limit_length(self, add_dataset_func): + dataset = add_dataset_func + with pytest.raises(Exception) as excinfo: + dataset.update({"description": "a" * 65536}) + assert "String should have at most 65535 characters" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p3 + def test_description_none(self, client, add_dataset_func): + dataset = add_dataset_func + dataset.update({"description": None}) + assert dataset.description is None, str(dataset) + + retrieved_dataset = client.get_dataset(name=dataset.name) + assert retrieved_dataset.description is None, str(retrieved_dataset) + + @pytest.mark.p1 + @pytest.mark.parametrize( + "embedding_model", + [ + "BAAI/bge-large-zh-v1.5@BAAI", + "maidalun1020/bce-embedding-base_v1@Youdao", + "embedding-3@ZHIPU-AI", + ], + ids=["builtin_baai", "builtin_youdao", "tenant_zhipu"], + ) + def test_embedding_model(self, client, add_dataset_func, embedding_model): + dataset = add_dataset_func + dataset.update({"embedding_model": embedding_model}) + assert dataset.embedding_model == embedding_model, str(dataset) + + retrieved_dataset = client.get_dataset(name=dataset.name) + assert retrieved_dataset.embedding_model == embedding_model, str(retrieved_dataset) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "name, embedding_model", + [ + ("unknown_llm_name", "unknown@ZHIPU-AI"), + ("unknown_llm_factory", "embedding-3@unknown"), + ("tenant_no_auth_default_tenant_llm", "text-embedding-v3@Tongyi-Qianwen"), + ("tenant_no_auth", "text-embedding-3-small@OpenAI"), + ], + ids=["unknown_llm_name", "unknown_llm_factory", "tenant_no_auth_default_tenant_llm", "tenant_no_auth"], + ) + def test_embedding_model_invalid(self, add_dataset_func, name, embedding_model): + dataset = add_dataset_func + with pytest.raises(Exception) as excinfo: + dataset.update({"name": name, "embedding_model": embedding_model}) + error_msg = str(excinfo.value) + 
if "tenant_no_auth" in name: + assert error_msg == f"Unauthorized model: <{embedding_model}>", error_msg + else: + assert error_msg == f"Unsupported model: <{embedding_model}>", error_msg + + @pytest.mark.p2 + @pytest.mark.parametrize( + "name, embedding_model", + [ + ("missing_at", "BAAI/bge-large-zh-v1.5BAAI"), + ("missing_model_name", "@BAAI"), + ("missing_provider", "BAAI/bge-large-zh-v1.5@"), + ("whitespace_only_model_name", " @BAAI"), + ("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "), + ], + ids=["missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"], + ) + def test_embedding_model_format(self, add_dataset_func, name, embedding_model): + dataset = add_dataset_func + with pytest.raises(Exception) as excinfo: + dataset.update({"name": name, "embedding_model": embedding_model}) + error_msg = str(excinfo.value) + if name == "missing_at": + assert "Embedding model identifier must follow @ format" in error_msg, error_msg + else: + assert "Both model_name and provider must be non-empty strings" in error_msg, error_msg + + @pytest.mark.p2 + def test_embedding_model_none(self, add_dataset_func): + dataset = add_dataset_func + with pytest.raises(Exception) as excinfo: + dataset.update({"embedding_model": None}) + assert "Input should be a valid string" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p1 + @pytest.mark.parametrize( + "permission", + [ + "me", + "team", + "ME", + "TEAM", + " ME ", + ], + ids=["me", "team", "me_upercase", "team_upercase", "whitespace"], + ) + def test_permission(self, client, add_dataset_func, permission): + dataset = add_dataset_func + dataset.update({"permission": permission}) + assert dataset.permission == permission.lower().strip(), str(dataset) + + retrieved_dataset = client.get_dataset(name=dataset.name) + assert retrieved_dataset.permission == permission.lower().strip(), str(retrieved_dataset) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "permission", + [ + "", + "unknown", + list(), + ], + ids=["empty", "unknown", "type_error"], + ) + def test_permission_invalid(self, add_dataset_func, permission): + dataset = add_dataset_func + with pytest.raises(Exception) as excinfo: + dataset.update({"permission": permission}) + assert "Input should be 'me' or 'team'" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p3 + def test_permission_none(self, add_dataset_func): + dataset = add_dataset_func + with pytest.raises(Exception) as excinfo: + dataset.update({"permission": None}) + assert "Input should be 'me' or 'team'" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p1 + @pytest.mark.parametrize( + "chunk_method", + [ + "naive", + "book", + "email", + "laws", + "manual", + "one", + "paper", + "picture", + "presentation", + "qa", + "table", + "tag", + ], + ids=["naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"], + ) + def test_chunk_method(self, client, add_dataset_func, chunk_method): + dataset = add_dataset_func + dataset.update({"chunk_method": chunk_method}) + assert dataset.chunk_method == chunk_method, str(dataset) + + retrieved_dataset = client.get_dataset(name=dataset.name) + assert retrieved_dataset.chunk_method == chunk_method, str(retrieved_dataset) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "chunk_method", + [ + "", + "unknown", + list(), + ], + ids=["empty", "unknown", "type_error"], + ) + def test_chunk_method_invalid(self, add_dataset_func, chunk_method): + dataset = add_dataset_func + with 
pytest.raises(Exception) as excinfo: + dataset.update({"chunk_method": chunk_method}) + assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p3 + def test_chunk_method_none(self, add_dataset_func): + dataset = add_dataset_func + with pytest.raises(Exception) as excinfo: + dataset.update({"chunk_method": None}) + assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208") + @pytest.mark.p2 + @pytest.mark.parametrize("pagerank", [0, 50, 100], ids=["min", "mid", "max"]) + def test_pagerank(self, client, add_dataset_func, pagerank): + dataset = add_dataset_func + dataset.update({"pagerank": pagerank}) + assert dataset.pagerank == pagerank, str(dataset) + + retrieved_dataset = client.get_dataset(name=dataset.name) + assert retrieved_dataset.pagerank == pagerank, str(retrieved_dataset) + + @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208") + @pytest.mark.p2 + def test_pagerank_set_to_0(self, client, add_dataset_func): + dataset = add_dataset_func + dataset.update({"pagerank": 50}) + assert dataset.pagerank == 50, str(dataset) + + retrieved_dataset = client.get_dataset(name=dataset.name) + assert retrieved_dataset.pagerank == 50, str(retrieved_dataset) + + dataset.update({"pagerank": 0}) + assert dataset.pagerank == 0, str(dataset) + + retrieved_dataset = client.get_dataset(name=dataset.name) + assert retrieved_dataset.pagerank == 0, str(retrieved_dataset) + + @pytest.mark.skipif(os.getenv("DOC_ENGINE") != "infinity", reason="#8208") + @pytest.mark.p2 + def test_pagerank_infinity(self, client, add_dataset_func): + dataset = add_dataset_func + with pytest.raises(Exception) as excinfo: + dataset.update({"pagerank": 50}) + assert "'pagerank' can only be set when doc_engine is elasticsearch" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "pagerank, expected_message", + [ + (-1, "Input should be greater than or equal to 0"), + (101, "Input should be less than or equal to 100"), + ], + ids=["min_limit", "max_limit"], + ) + def test_pagerank_invalid(self, add_dataset_func, pagerank, expected_message): + dataset = add_dataset_func + with pytest.raises(Exception) as excinfo: + dataset.update({"pagerank": pagerank}) + assert expected_message in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p3 + def test_pagerank_none(self, add_dataset_func): + dataset = add_dataset_func + with pytest.raises(Exception) as excinfo: + dataset.update({"pagerank": None}) + assert "Input should be a valid integer" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p1 + @pytest.mark.parametrize( + "parser_config", + [ + {"auto_keywords": 0}, + {"auto_keywords": 16}, + {"auto_keywords": 32}, + {"auto_questions": 0}, + {"auto_questions": 5}, + {"auto_questions": 10}, + {"chunk_token_num": 1}, + {"chunk_token_num": 1024}, + {"chunk_token_num": 2048}, + {"delimiter": "\n"}, + {"delimiter": " "}, + {"html4excel": True}, + {"html4excel": False}, + {"layout_recognize": "DeepDOC"}, + {"layout_recognize": "Plain Text"}, + {"tag_kb_ids": ["1", "2"]}, + {"topn_tags": 1}, + {"topn_tags": 5}, + {"topn_tags": 10}, + {"filename_embd_weight": 0.1}, + {"filename_embd_weight": 0.5}, + {"filename_embd_weight": 1.0}, + 
{"task_page_size": 1}, + {"task_page_size": None}, + {"pages": [[1, 100]]}, + {"pages": None}, + {"graphrag": {"use_graphrag": True}}, + {"graphrag": {"use_graphrag": False}}, + {"graphrag": {"entity_types": ["age", "sex", "height", "weight"]}}, + {"graphrag": {"method": "general"}}, + {"graphrag": {"method": "light"}}, + {"graphrag": {"community": True}}, + {"graphrag": {"community": False}}, + {"graphrag": {"resolution": True}}, + {"graphrag": {"resolution": False}}, + {"raptor": {"use_raptor": True}}, + {"raptor": {"use_raptor": False}}, + {"raptor": {"prompt": "Who are you?"}}, + {"raptor": {"max_token": 1}}, + {"raptor": {"max_token": 1024}}, + {"raptor": {"max_token": 2048}}, + {"raptor": {"threshold": 0.0}}, + {"raptor": {"threshold": 0.5}}, + {"raptor": {"threshold": 1.0}}, + {"raptor": {"max_cluster": 1}}, + {"raptor": {"max_cluster": 512}}, + {"raptor": {"max_cluster": 1024}}, + {"raptor": {"random_seed": 0}}, + ], + ids=[ + "auto_keywords_min", + "auto_keywords_mid", + "auto_keywords_max", + "auto_questions_min", + "auto_questions_mid", + "auto_questions_max", + "chunk_token_num_min", + "chunk_token_num_mid", + "chunk_token_num_max", + "delimiter", + "delimiter_space", + "html4excel_true", + "html4excel_false", + "layout_recognize_DeepDOC", + "layout_recognize_navie", + "tag_kb_ids", + "topn_tags_min", + "topn_tags_mid", + "topn_tags_max", + "filename_embd_weight_min", + "filename_embd_weight_mid", + "filename_embd_weight_max", + "task_page_size_min", + "task_page_size_None", + "pages", + "pages_none", + "graphrag_true", + "graphrag_false", + "graphrag_entity_types", + "graphrag_method_general", + "graphrag_method_light", + "graphrag_community_true", + "graphrag_community_false", + "graphrag_resolution_true", + "graphrag_resolution_false", + "raptor_true", + "raptor_false", + "raptor_prompt", + "raptor_max_token_min", + "raptor_max_token_mid", + "raptor_max_token_max", + "raptor_threshold_min", + "raptor_threshold_mid", + "raptor_threshold_max", + "raptor_max_cluster_min", + "raptor_max_cluster_mid", + "raptor_max_cluster_max", + "raptor_random_seed_min", + ], + ) + def test_parser_config(self, client, add_dataset_func, parser_config): + dataset = add_dataset_func + dataset.update({"parser_config": parser_config}) + for k, v in parser_config.items(): + if isinstance(v, dict): + for kk, vv in v.items(): + assert attrgetter(f"{k}.{kk}")(dataset.parser_config) == vv, str(dataset) + else: + assert attrgetter(k)(dataset.parser_config) == v, str(dataset) + + retrieved_dataset = client.get_dataset(name=dataset.name) + for k, v in parser_config.items(): + if isinstance(v, dict): + for kk, vv in v.items(): + assert attrgetter(f"{k}.{kk}")(retrieved_dataset.parser_config) == vv, str(retrieved_dataset) + else: + assert attrgetter(k)(retrieved_dataset.parser_config) == v, str(retrieved_dataset) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "parser_config, expected_message", + [ + ({"auto_keywords": -1}, "Input should be greater than or equal to 0"), + ({"auto_keywords": 33}, "Input should be less than or equal to 32"), + ({"auto_keywords": 3.14}, "Input should be a valid integer, got a number with a fractional part"), + ({"auto_keywords": "string"}, "Input should be a valid integer, unable to parse string as an integer"), + ({"auto_questions": -1}, "Input should be greater than or equal to 0"), + ({"auto_questions": 11}, "Input should be less than or equal to 10"), + ({"auto_questions": 3.14}, "Input should be a valid integer, got a number with a fractional part"), + ({"auto_questions": 
"string"}, "Input should be a valid integer, unable to parse string as an integer"), + ({"chunk_token_num": 0}, "Input should be greater than or equal to 1"), + ({"chunk_token_num": 2049}, "Input should be less than or equal to 2048"), + ({"chunk_token_num": 3.14}, "Input should be a valid integer, got a number with a fractional part"), + ({"chunk_token_num": "string"}, "Input should be a valid integer, unable to parse string as an integer"), + ({"delimiter": ""}, "String should have at least 1 character"), + ({"html4excel": "string"}, "Input should be a valid boolean, unable to interpret input"), + ({"tag_kb_ids": "1,2"}, "Input should be a valid list"), + ({"tag_kb_ids": [1, 2]}, "Input should be a valid string"), + ({"topn_tags": 0}, "Input should be greater than or equal to 1"), + ({"topn_tags": 11}, "Input should be less than or equal to 10"), + ({"topn_tags": 3.14}, "Input should be a valid integer, got a number with a fractional part"), + ({"topn_tags": "string"}, "Input should be a valid integer, unable to parse string as an integer"), + ({"filename_embd_weight": -1}, "Input should be greater than or equal to 0"), + ({"filename_embd_weight": 1.1}, "Input should be less than or equal to 1"), + ({"filename_embd_weight": "string"}, "Input should be a valid number, unable to parse string as a number"), + ({"task_page_size": 0}, "Input should be greater than or equal to 1"), + ({"task_page_size": 3.14}, "Input should be a valid integer, got a number with a fractional part"), + ({"task_page_size": "string"}, "Input should be a valid integer, unable to parse string as an integer"), + ({"pages": "1,2"}, "Input should be a valid list"), + ({"pages": ["1,2"]}, "Input should be a valid list"), + ({"pages": [["string1", "string2"]]}, "Input should be a valid integer, unable to parse string as an integer"), + ({"graphrag": {"use_graphrag": "string"}}, "Input should be a valid boolean, unable to interpret input"), + ({"graphrag": {"entity_types": "1,2"}}, "Input should be a valid list"), + ({"graphrag": {"entity_types": [1, 2]}}, "nput should be a valid string"), + ({"graphrag": {"method": "unknown"}}, "Input should be 'light' or 'general'"), + ({"graphrag": {"method": None}}, "Input should be 'light' or 'general'"), + ({"graphrag": {"community": "string"}}, "Input should be a valid boolean, unable to interpret input"), + ({"graphrag": {"resolution": "string"}}, "Input should be a valid boolean, unable to interpret input"), + ({"raptor": {"use_raptor": "string"}}, "Input should be a valid boolean, unable to interpret input"), + ({"raptor": {"prompt": ""}}, "String should have at least 1 character"), + ({"raptor": {"prompt": " "}}, "String should have at least 1 character"), + ({"raptor": {"max_token": 0}}, "Input should be greater than or equal to 1"), + ({"raptor": {"max_token": 2049}}, "Input should be less than or equal to 2048"), + ({"raptor": {"max_token": 3.14}}, "Input should be a valid integer, got a number with a fractional part"), + ({"raptor": {"max_token": "string"}}, "Input should be a valid integer, unable to parse string as an integer"), + ({"raptor": {"threshold": -0.1}}, "Input should be greater than or equal to 0"), + ({"raptor": {"threshold": 1.1}}, "Input should be less than or equal to 1"), + ({"raptor": {"threshold": "string"}}, "Input should be a valid number, unable to parse string as a number"), + ({"raptor": {"max_cluster": 0}}, "Input should be greater than or equal to 1"), + ({"raptor": {"max_cluster": 1025}}, "Input should be less than or equal to 1024"), + 
({"raptor": {"max_cluster": 3.14}}, "Input should be a valid integer, got a number with a fractional par"), + ({"raptor": {"max_cluster": "string"}}, "Input should be a valid integer, unable to parse string as an integer"), + ({"raptor": {"random_seed": -1}}, "Input should be greater than or equal to 0"), + ({"raptor": {"random_seed": 3.14}}, "Input should be a valid integer, got a number with a fractional part"), + ({"raptor": {"random_seed": "string"}}, "Input should be a valid integer, unable to parse string as an integer"), + ({"delimiter": "a" * 65536}, "Parser config exceeds size limit (max 65,535 characters)"), + ], + ids=[ + "auto_keywords_min_limit", + "auto_keywords_max_limit", + "auto_keywords_float_not_allowed", + "auto_keywords_type_invalid", + "auto_questions_min_limit", + "auto_questions_max_limit", + "auto_questions_float_not_allowed", + "auto_questions_type_invalid", + "chunk_token_num_min_limit", + "chunk_token_num_max_limit", + "chunk_token_num_float_not_allowed", + "chunk_token_num_type_invalid", + "delimiter_empty", + "html4excel_type_invalid", + "tag_kb_ids_not_list", + "tag_kb_ids_int_in_list", + "topn_tags_min_limit", + "topn_tags_max_limit", + "topn_tags_float_not_allowed", + "topn_tags_type_invalid", + "filename_embd_weight_min_limit", + "filename_embd_weight_max_limit", + "filename_embd_weight_type_invalid", + "task_page_size_min_limit", + "task_page_size_float_not_allowed", + "task_page_size_type_invalid", + "pages_not_list", + "pages_not_list_in_list", + "pages_not_int_list", + "graphrag_type_invalid", + "graphrag_entity_types_not_list", + "graphrag_entity_types_not_str_in_list", + "graphrag_method_unknown", + "graphrag_method_none", + "graphrag_community_type_invalid", + "graphrag_resolution_type_invalid", + "raptor_type_invalid", + "raptor_prompt_empty", + "raptor_prompt_space", + "raptor_max_token_min_limit", + "raptor_max_token_max_limit", + "raptor_max_token_float_not_allowed", + "raptor_max_token_type_invalid", + "raptor_threshold_min_limit", + "raptor_threshold_max_limit", + "raptor_threshold_type_invalid", + "raptor_max_cluster_min_limit", + "raptor_max_cluster_max_limit", + "raptor_max_cluster_float_not_allowed", + "raptor_max_cluster_type_invalid", + "raptor_random_seed_min_limit", + "raptor_random_seed_float_not_allowed", + "raptor_random_seed_type_invalid", + "parser_config_type_invalid", + ], + ) + def test_parser_config_invalid(self, add_dataset_func, parser_config, expected_message): + dataset = add_dataset_func + with pytest.raises(Exception) as excinfo: + dataset.update({"parser_config": parser_config}) + assert expected_message in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p2 + def test_parser_config_empty(self, client, add_dataset_func): + dataset = add_dataset_func + expected_config = DataSet.ParserConfig( + client, + { + "chunk_token_num": 128, + "delimiter": r"\n", + "html4excel": False, + "layout_recognize": "DeepDOC", + "raptor": {"use_raptor": False}, + }, + ) + dataset.update({"parser_config": {}}) + assert str(dataset.parser_config) == str(expected_config), str(dataset) + + retrieved_dataset = client.get_dataset(name=dataset.name) + assert str(retrieved_dataset.parser_config) == str(expected_config), str(retrieved_dataset) + + @pytest.mark.p3 + def test_parser_config_none(self, client, add_dataset_func): + dataset = add_dataset_func + expected_config = DataSet.ParserConfig( + client, + { + "chunk_token_num": 128, + "delimiter": r"\n", + "html4excel": False, + "layout_recognize": "DeepDOC", + "raptor": {"use_raptor": 
False}, + }, + ) + dataset.update({"parser_config": None}) + assert str(dataset.parser_config) == str(expected_config), str(dataset) + + retrieved_dataset = client.get_dataset(name=dataset.name) + assert str(retrieved_dataset.parser_config) == str(expected_config), str(retrieved_dataset) + + @pytest.mark.p3 + def test_parser_config_empty_with_chunk_method_change(self, client, add_dataset_func): + dataset = add_dataset_func + expected_config = DataSet.ParserConfig( + client, + { + "raptor": {"use_raptor": False}, + }, + ) + dataset.update({"chunk_method": "qa", "parser_config": {}}) + assert str(dataset.parser_config) == str(expected_config), str(dataset) + + retrieved_dataset = client.get_dataset(name=dataset.name) + assert str(retrieved_dataset.parser_config) == str(expected_config), str(retrieved_dataset) + + @pytest.mark.p3 + def test_parser_config_unset_with_chunk_method_change(self, client, add_dataset_func): + dataset = add_dataset_func + expected_config = DataSet.ParserConfig( + client, + { + "raptor": {"use_raptor": False}, + }, + ) + dataset.update({"chunk_method": "qa"}) + assert str(dataset.parser_config) == str(expected_config), str(dataset) + + retrieved_dataset = client.get_dataset(name=dataset.name) + assert str(retrieved_dataset.parser_config) == str(expected_config), str(retrieved_dataset) + + @pytest.mark.p3 + def test_parser_config_none_with_chunk_method_change(self, client, add_dataset_func): + dataset = add_dataset_func + expected_config = DataSet.ParserConfig( + client, + { + "raptor": {"use_raptor": False}, + }, + ) + dataset.update({"chunk_method": "qa", "parser_config": None}) + assert str(dataset.parser_config) == str(expected_config), str(dataset) + + retrieved_dataset = client.get_dataset(name=dataset.name) + assert str(retrieved_dataset.parser_config) == str(expected_config), str(retrieved_dataset) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "payload", + [ + {"id": "id"}, + {"tenant_id": "e57c1966f99211efb41e9e45646e0111"}, + {"created_by": "created_by"}, + {"create_date": "Tue, 11 Mar 2025 13:37:23 GMT"}, + {"create_time": 1741671443322}, + {"update_date": "Tue, 11 Mar 2025 13:37:23 GMT"}, + {"update_time": 1741671443339}, + {"document_count": 1}, + {"chunk_count": 1}, + {"token_num": 1}, + {"status": "1"}, + {"unknown_field": "unknown_field"}, + ], + ) + def test_field_unsupported(self, add_dataset_func, payload): + dataset = add_dataset_func + with pytest.raises(Exception) as excinfo: + dataset.update(payload) + assert "Extra inputs are not permitted" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p2 + def test_field_unset(self, client, add_dataset_func): + dataset = add_dataset_func + original_dataset = client.get_dataset(name=dataset.name) + + dataset.update({"name": "default_unset"}) + + updated_dataset = client.get_dataset(name="default_unset") + assert updated_dataset.avatar == original_dataset.avatar, str(updated_dataset) + assert updated_dataset.description == original_dataset.description, str(updated_dataset) + assert updated_dataset.embedding_model == original_dataset.embedding_model, str(updated_dataset) + assert updated_dataset.permission == original_dataset.permission, str(updated_dataset) + assert updated_dataset.chunk_method == original_dataset.chunk_method, str(updated_dataset) + assert updated_dataset.pagerank == original_dataset.pagerank, str(updated_dataset) + assert str(updated_dataset.parser_config) == str(original_dataset.parser_config), str(updated_dataset) diff --git 
a/test/testcases/test_sdk_api/test_file_management_within_dataset/conftest.py b/test/testcases/test_sdk_api/test_file_management_within_dataset/conftest.py new file mode 100644 index 00000000000..32be9683a5b --- /dev/null +++ b/test/testcases/test_sdk_api/test_file_management_within_dataset/conftest.py @@ -0,0 +1,57 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import pytest +from common import bulk_upload_documents +from pytest import FixtureRequest +from ragflow_sdk import DataSet, Document + + +@pytest.fixture(scope="function") +def add_document_func(request: FixtureRequest, add_dataset: DataSet, ragflow_tmp_dir) -> tuple[DataSet, Document]: + dataset = add_dataset + documents = bulk_upload_documents(dataset, 1, ragflow_tmp_dir) + + def cleanup(): + dataset.delete_documents(ids=None) + + request.addfinalizer(cleanup) + return dataset, documents[0] + + +@pytest.fixture(scope="class") +def add_documents(request: FixtureRequest, add_dataset: DataSet, ragflow_tmp_dir) -> tuple[DataSet, list[Document]]: + dataset = add_dataset + documents = bulk_upload_documents(dataset, 5, ragflow_tmp_dir) + + def cleanup(): + dataset.delete_documents(ids=None) + + request.addfinalizer(cleanup) + return dataset, documents + + +@pytest.fixture(scope="function") +def add_documents_func(request: FixtureRequest, add_dataset_func: DataSet, ragflow_tmp_dir) -> tuple[DataSet, list[Document]]: + dataset = add_dataset_func + documents = bulk_upload_documents(dataset, 3, ragflow_tmp_dir) + + def cleanup(): + dataset.delete_documents(ids=None) + + request.addfinalizer(cleanup) + return dataset, documents diff --git a/test/testcases/test_sdk_api/test_file_management_within_dataset/test_delete_documents.py b/test/testcases/test_sdk_api/test_file_management_within_dataset/test_delete_documents.py new file mode 100644 index 00000000000..c90d4294e13 --- /dev/null +++ b/test/testcases/test_sdk_api/test_file_management_within_dataset/test_delete_documents.py @@ -0,0 +1,118 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from common import bulk_upload_documents + + +class TestDocumentsDeletion: + @pytest.mark.p1 + @pytest.mark.parametrize( + "payload, expected_message, remaining", + [ + ({"ids": None}, "", 0), + ({"ids": []}, "", 0), + ({"ids": ["invalid_id"]}, "Documents not found: ['invalid_id']", 3), + ({"ids": ["\n!?。;!?\"'"]}, "Documents not found: ['\\n!?。;!?\"\\'']", 3), + ("not json", "must be a mapping", 3), + (lambda r: {"ids": r[:1]}, "", 2), + (lambda r: {"ids": r}, "", 0), + ], + ) + def test_basic_scenarios( + self, + add_documents_func, + payload, + expected_message, + remaining, + ): + dataset, documents = add_documents_func + if callable(payload): + payload = payload([document.id for document in documents]) + + if expected_message: + with pytest.raises(Exception) as excinfo: + dataset.delete_documents(**payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + dataset.delete_documents(**payload) + + documents = dataset.list_documents() + assert len(documents) == remaining, str(documents) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "payload", + [ + lambda r: {"ids": ["invalid_id"] + r}, + lambda r: {"ids": r[:1] + ["invalid_id"] + r[1:3]}, + lambda r: {"ids": r + ["invalid_id"]}, + ], + ) + def test_delete_partial_invalid_id(self, add_documents_func, payload): + dataset, documents = add_documents_func + payload = payload([document.id for document in documents]) + + with pytest.raises(Exception) as excinfo: + dataset.delete_documents(**payload) + assert "Documents not found: ['invalid_id']" in str(excinfo.value), str(excinfo.value) + + documents = dataset.list_documents() + assert len(documents) == 0, str(documents) + + @pytest.mark.p2 + def test_repeated_deletion(self, add_documents_func): + dataset, documents = add_documents_func + document_ids = [document.id for document in documents] + dataset.delete_documents(ids=document_ids) + with pytest.raises(Exception) as excinfo: + dataset.delete_documents(ids=document_ids) + assert "Documents not found" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p2 + def test_duplicate_deletion(self, add_documents_func): + dataset, documents = add_documents_func + document_ids = [document.id for document in documents] + dataset.delete_documents(ids=document_ids + document_ids) + assert len(dataset.list_documents()) == 0, str(dataset.list_documents()) + + +@pytest.mark.p3 +def test_concurrent_deletion(add_dataset, tmp_path): + count = 100 + dataset = add_dataset + documents = bulk_upload_documents(dataset, count, tmp_path) + + def delete_doc(doc_id): + dataset.delete_documents(ids=[doc_id]) + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(delete_doc, doc.id) for doc in documents] + + responses = list(as_completed(futures)) + assert len(responses) == count, responses + + +@pytest.mark.p3 +def test_delete_1k(add_dataset, tmp_path): + count = 1_000 + dataset = add_dataset + documents = bulk_upload_documents(dataset, count, tmp_path) + assert len(dataset.list_documents(page_size=count * 2)) == count + + dataset.delete_documents(ids=[doc.id for doc in documents]) + assert len(dataset.list_documents()) == 0 diff --git a/test/testcases/test_sdk_api/test_file_management_within_dataset/test_download_document.py b/test/testcases/test_sdk_api/test_file_management_within_dataset/test_download_document.py new file mode 100644 index 00000000000..3c9169fbb70 --- /dev/null +++ 
b/test/testcases/test_sdk_api/test_file_management_within_dataset/test_download_document.py @@ -0,0 +1,89 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from common import bulk_upload_documents +from utils import compare_by_hash + + +@pytest.mark.p1 +@pytest.mark.parametrize( + "generate_test_files", + [ + "docx", + "excel", + "ppt", + "image", + "pdf", + "txt", + "md", + "json", + "eml", + "html", + ], + indirect=True, +) +def test_file_type_validation(add_dataset, generate_test_files, request): + dataset = add_dataset + fp = generate_test_files[request.node.callspec.params["generate_test_files"]] + with fp.open("rb") as f: + blob = f.read() + + documents = dataset.upload_documents([{"display_name": fp.name, "blob": blob}]) + + for document in documents: + with fp.with_stem("ragflow_test_download").open("wb") as f: + f.write(document.download()) + + assert compare_by_hash(fp, fp.with_stem("ragflow_test_download")) + + +class TestDocumentDownload: + @pytest.mark.p3 + def test_same_file_repeat(self, add_documents, tmp_path, ragflow_tmp_dir): + num = 5 + _, documents = add_documents + + for i in range(num): + download_path = tmp_path / f"ragflow_test_download_{i}.txt" + with download_path.open("wb") as f: + f.write(documents[0].download()) + assert compare_by_hash(ragflow_tmp_dir / "ragflow_test_upload_0.txt", download_path), f"Downloaded file {i} does not match original" + + +@pytest.mark.p3 +def test_concurrent_download(add_dataset, tmp_path): + count = 20 + dataset = add_dataset + documents = bulk_upload_documents(dataset, count, tmp_path) + + def download_doc(document, i): + download_path = tmp_path / f"ragflow_test_download_{i}.txt" + with download_path.open("wb") as f: + f.write(document.download()) + # assert compare_by_hash(tmp_path / f"ragflow_test_upload_{i}.txt", download_path), str(download_path) + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(download_doc, documents[i], i) for i in range(count)] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + + for i in range(count): + assert compare_by_hash( + tmp_path / f"ragflow_test_upload_{i}.txt", + tmp_path / f"ragflow_test_download_{i}.txt", + ) diff --git a/test/testcases/test_sdk_api/test_file_management_within_dataset/test_list_documents.py b/test/testcases/test_sdk_api/test_file_management_within_dataset/test_list_documents.py new file mode 100644 index 00000000000..189093da267 --- /dev/null +++ b/test/testcases/test_sdk_api/test_file_management_within_dataset/test_list_documents.py @@ -0,0 +1,247 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest + + +class TestDocumentsList: + @pytest.mark.p1 + def test_default(self, add_documents): + dataset, _ = add_documents + documents = dataset.list_documents() + assert len(documents) == 5, str(documents) + + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_page_size, expected_message", + [ + ({"page": None, "page_size": 2}, 2, "not instance of"), + ({"page": 0, "page_size": 2}, 2, ""), + ({"page": 2, "page_size": 2}, 2, ""), + ({"page": 3, "page_size": 2}, 1, ""), + ({"page": "3", "page_size": 2}, 1, "not instance of"), + pytest.param( + {"page": -1, "page_size": 2}, + 0, + "Invalid page number", + marks=pytest.mark.skip(reason="issues/5851"), + ), + pytest.param( + {"page": "a", "page_size": 2}, + 0, + "Invalid page value", + marks=pytest.mark.skip(reason="issues/5851"), + ), + ], + ) + def test_page(self, add_documents, params, expected_page_size, expected_message): + dataset, _ = add_documents + if expected_message: + with pytest.raises(Exception) as excinfo: + dataset.list_documents(**params) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + documents = dataset.list_documents(**params) + assert len(documents) == expected_page_size, str(documents) + + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_page_size, expected_message", + [ + ({"page_size": None}, 5, "not instance of"), + ({"page_size": 0}, 0, ""), + ({"page_size": 1}, 1, ""), + ({"page_size": 6}, 5, ""), + ({"page_size": "1"}, 1, "not instance of"), + pytest.param( + {"page_size": -1}, + 0, + "Invalid page size", + marks=pytest.mark.skip(reason="issues/5851"), + ), + pytest.param( + {"page_size": "a"}, + 0, + "Invalid page size value", + marks=pytest.mark.skip(reason="issues/5851"), + ), + ], + ) + def test_page_size(self, add_documents, params, expected_page_size, expected_message): + dataset, _ = add_documents + if expected_message: + with pytest.raises(Exception) as excinfo: + dataset.list_documents(**params) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + documents = dataset.list_documents(**params) + assert len(documents) == expected_page_size, str(documents) + + @pytest.mark.p3 + @pytest.mark.parametrize( + "params, expected_message", + [ + ({"orderby": None}, "not instance of"), + ({"orderby": "create_time"}, ""), + ({"orderby": "update_time"}, ""), + pytest.param({"orderby": "name", "desc": "False"}, "", marks=pytest.mark.skip(reason="issues/5851")), + pytest.param({"orderby": "unknown"}, "orderby should be create_time or update_time", marks=pytest.mark.skip(reason="issues/5851")), + ], + ) + def test_orderby(self, add_documents, params, expected_message): + dataset, _ = add_documents + if expected_message: + with pytest.raises(Exception) as excinfo: + dataset.list_documents(**params) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + dataset.list_documents(**params) + + @pytest.mark.p3 + @pytest.mark.parametrize( + "params, expected_message", + [ + ({"desc": None}, "not instance of"), + ({"desc": "true"}, "not 
instance of"), + ({"desc": "True"}, "not instance of"), + ({"desc": True}, ""), + pytest.param({"desc": "false"}, "", marks=pytest.mark.skip(reason="issues/5851")), + ({"desc": "False"}, "not instance of"), + ({"desc": False}, ""), + ({"desc": "False", "orderby": "update_time"}, "not instance of"), + pytest.param({"desc": "unknown"}, "desc should be true or false", marks=pytest.mark.skip(reason="issues/5851")), + ], + ) + def test_desc(self, add_documents, params, expected_message): + dataset, _ = add_documents + if expected_message: + with pytest.raises(Exception) as excinfo: + dataset.list_documents(**params) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + dataset.list_documents(**params) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "params, expected_num", + [ + ({"keywords": None}, 5), + ({"keywords": ""}, 5), + ({"keywords": "0"}, 1), + ({"keywords": "ragflow_test_upload"}, 5), + ({"keywords": "unknown"}, 0), + ], + ) + def test_keywords(self, add_documents, params, expected_num): + dataset, _ = add_documents + documents = dataset.list_documents(**params) + assert len(documents) == expected_num, str(documents) + + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_num, expected_message", + [ + ({"name": None}, 5, ""), + ({"name": ""}, 5, ""), + ({"name": "ragflow_test_upload_0.txt"}, 1, ""), + ({"name": "unknown.txt"}, 0, "You don't own the document unknown.txt"), + ], + ) + def test_name(self, add_documents, params, expected_num, expected_message): + dataset, _ = add_documents + if expected_message: + with pytest.raises(Exception) as excinfo: + dataset.list_documents(**params) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + documents = dataset.list_documents(**params) + assert len(documents) == expected_num, str(documents) + if params["name"] not in [None, ""]: + assert documents[0].name == params["name"], str(documents) + + @pytest.mark.p1 + @pytest.mark.parametrize( + "document_id, expected_num, expected_message", + [ + (None, 5, ""), + ("", 5, ""), + (lambda docs: docs[0].id, 1, ""), + ("unknown.txt", 0, "You don't own the document unknown.txt"), + ], + ) + def test_id(self, add_documents, document_id, expected_num, expected_message): + dataset, documents = add_documents + if callable(document_id): + params = {"id": document_id(documents)} + else: + params = {"id": document_id} + + if expected_message: + with pytest.raises(Exception) as excinfo: + dataset.list_documents(**params) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + documents = dataset.list_documents(**params) + assert len(documents) == expected_num, str(documents) + if params["id"] not in [None, ""]: + assert documents[0].id == params["id"], str(documents) + + @pytest.mark.p3 + @pytest.mark.parametrize( + "document_id, name, expected_num, expected_message", + [ + (lambda docs: docs[0].id, "ragflow_test_upload_0.txt", 1, ""), + (lambda docs: docs[0].id, "ragflow_test_upload_1.txt", 0, ""), + (lambda docs: docs[0].id, "unknown", 0, "You don't own the document unknown"), + ("invalid_id", "ragflow_test_upload_0.txt", 0, "You don't own the document invalid_id"), + ], + ) + def test_name_and_id(self, add_documents, document_id, name, expected_num, expected_message): + dataset, documents = add_documents + params = {"id": document_id(documents) if callable(document_id) else document_id, "name": name} + + if expected_message: + with pytest.raises(Exception) as excinfo: + dataset.list_documents(**params) + assert 
expected_message in str(excinfo.value), str(excinfo.value) + else: + documents = dataset.list_documents(**params) + assert len(documents) == expected_num, str(documents) + + @pytest.mark.p3 + def test_concurrent_list(self, add_documents): + dataset, _ = add_documents + count = 100 + + def list_docs(): + return dataset.list_documents() + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(list_docs) for _ in range(count)] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + for future in futures: + docs = future.result() + assert len(docs) == 5, str(docs) + + @pytest.mark.p3 + def test_invalid_params(self, add_documents): + dataset, _ = add_documents + params = {"a": "b"} + with pytest.raises(TypeError) as excinfo: + dataset.list_documents(**params) + assert "got an unexpected keyword argument" in str(excinfo.value), str(excinfo.value) diff --git a/test/testcases/test_sdk_api/test_file_management_within_dataset/test_parse_documents.py b/test/testcases/test_sdk_api/test_file_management_within_dataset/test_parse_documents.py new file mode 100644 index 00000000000..9b76bf6d1d4 --- /dev/null +++ b/test/testcases/test_sdk_api/test_file_management_within_dataset/test_parse_documents.py @@ -0,0 +1,162 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from common import bulk_upload_documents +from ragflow_sdk import DataSet +from utils import wait_for + + +@wait_for(30, 1, "Document parsing timeout") +def condition(_dataset: DataSet, _document_ids: list[str] = None): + documents = _dataset.list_documents(page_size=1000) + + if _document_ids is None: + for document in documents: + if document.run != "DONE": + return False + return True + + target_ids = set(_document_ids) + for document in documents: + if document.id in target_ids: + if document.run != "DONE": + return False + return True + + +def validate_document_details(dataset, document_ids): + documents = dataset.list_documents(page_size=1000) + for document in documents: + if document.id in document_ids: + assert document.run == "DONE" + assert len(document.process_begin_at) > 0 + assert document.process_duation > 0 + assert document.progress > 0 + assert "Task done" in document.progress_msg + + +class TestDocumentsParse: + @pytest.mark.parametrize( + "payload, expected_message", + [ + pytest.param(None, "AttributeError", marks=pytest.mark.skip), + pytest.param({"document_ids": []}, "`document_ids` is required", marks=pytest.mark.p1), + pytest.param({"document_ids": ["invalid_id"]}, "Documents not found: ['invalid_id']", marks=pytest.mark.p3), + pytest.param({"document_ids": ["\n!?。;!?\"'"]}, "Documents not found: ['\\n!?。;!?\"\\'']", marks=pytest.mark.p3), + pytest.param("not json", "AttributeError", marks=pytest.mark.skip), + pytest.param(lambda r: {"document_ids": r[:1]}, "", marks=pytest.mark.p1), + pytest.param(lambda r: {"document_ids": r}, "", marks=pytest.mark.p1), + ], + ) + def test_basic_scenarios(self, add_documents_func, payload, expected_message): + dataset, documents = add_documents_func + if callable(payload): + payload = payload([doc.id for doc in documents]) + + if expected_message: + with pytest.raises(Exception) as excinfo: + dataset.async_parse_documents(**payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + dataset.async_parse_documents(**payload) + condition(dataset, payload["document_ids"]) + validate_document_details(dataset, payload["document_ids"]) + + @pytest.mark.parametrize( + "payload", + [ + pytest.param(lambda r: {"document_ids": ["invalid_id"] + r}, marks=pytest.mark.p3), + pytest.param(lambda r: {"document_ids": r[:1] + ["invalid_id"] + r[1:3]}, marks=pytest.mark.p1), + pytest.param(lambda r: {"document_ids": r + ["invalid_id"]}, marks=pytest.mark.p3), + ], + ) + def test_parse_partial_invalid_document_id(self, add_documents_func, payload): + dataset, documents = add_documents_func + document_ids = [doc.id for doc in documents] + payload = payload(document_ids) + + with pytest.raises(Exception) as excinfo: + dataset.async_parse_documents(**payload) + assert "Documents not found: ['invalid_id']" in str(excinfo.value), str(excinfo.value) + + condition(dataset, document_ids) + validate_document_details(dataset, document_ids) + + @pytest.mark.p3 + def test_repeated_parse(self, add_documents_func): + dataset, documents = add_documents_func + document_ids = [doc.id for doc in documents] + dataset.async_parse_documents(document_ids=document_ids) + condition(dataset, document_ids) + dataset.async_parse_documents(document_ids=document_ids) + + @pytest.mark.p3 + def test_duplicate_parse(self, add_documents_func): + dataset, documents = add_documents_func + document_ids = [doc.id for doc in documents] + 
dataset.async_parse_documents(document_ids=document_ids + document_ids) + condition(dataset, document_ids) + validate_document_details(dataset, document_ids) + + +@pytest.mark.p3 +def test_parse_100_files(add_dataset_func, tmp_path): + @wait_for(100, 1, "Document parsing timeout") + def condition(_dataset: DataSet, _count: int): + documents = _dataset.list_documents(page_size=_count * 2) + for document in documents: + if document.run != "DONE": + return False + return True + + count = 100 + dataset = add_dataset_func + documents = bulk_upload_documents(dataset, count, tmp_path) + document_ids = [doc.id for doc in documents] + + dataset.async_parse_documents(document_ids=document_ids) + condition(dataset, count) + validate_document_details(dataset, document_ids) + + +@pytest.mark.p3 +def test_concurrent_parse(add_dataset_func, tmp_path): + @wait_for(120, 1, "Document parsing timeout") + def condition(_dataset: DataSet, _count: int): + documents = _dataset.list_documents(page_size=_count * 2) + for document in documents: + if document.run != "DONE": + return False + return True + + count = 100 + dataset = add_dataset_func + documents = bulk_upload_documents(dataset, count, tmp_path) + document_ids = [doc.id for doc in documents] + + def parse_doc(doc_id): + dataset.async_parse_documents(document_ids=[doc_id]) + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(parse_doc, doc.id) for doc in documents] + + responses = list(as_completed(futures)) + assert len(responses) == count, responses + + condition(dataset, count) + validate_document_details(dataset, document_ids) diff --git a/test/testcases/test_sdk_api/test_file_management_within_dataset/test_stop_parse_documents.py b/test/testcases/test_sdk_api/test_file_management_within_dataset/test_stop_parse_documents.py new file mode 100644 index 00000000000..9b561881ef3 --- /dev/null +++ b/test/testcases/test_sdk_api/test_file_management_within_dataset/test_stop_parse_documents.py @@ -0,0 +1,41 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import pytest + + +def validate_document_parse_done(dataset, document_ids): + documents = dataset.list_documents(page_size=1000) + for document in documents: + if document.id in document_ids: + assert document.run == "DONE" + assert len(document.process_begin_at) > 0 + assert document.process_duation > 0 + assert document.progress > 0 + assert "Task done" in document.progress_msg + + +def validate_document_parse_cancel(dataset, document_ids): + documents = dataset.list_documents(page_size=1000) + for document in documents: + assert document.run == "CANCEL" + assert len(document.process_begin_at) > 0 + assert document.progress == 0.0 + + +@pytest.mark.skip +class TestDocumentsParseStop: + pass diff --git a/test/testcases/test_sdk_api/test_file_management_within_dataset/test_update_document.py b/test/testcases/test_sdk_api/test_file_management_within_dataset/test_update_document.py new file mode 100644 index 00000000000..f1d3dadb8d2 --- /dev/null +++ b/test/testcases/test_sdk_api/test_file_management_within_dataset/test_update_document.py @@ -0,0 +1,411 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest +from configs import DOCUMENT_NAME_LIMIT +from ragflow_sdk import DataSet + + +class TestDocumentsUpdated: + @pytest.mark.p1 + @pytest.mark.parametrize( + "name, expected_message", + [ + ("new_name.txt", ""), + (f"{'a' * (DOCUMENT_NAME_LIMIT - 4)}.txt", ""), + (0, "AttributeError"), + (None, "AttributeError"), + ("", "The extension of file can't be changed"), + ("ragflow_test_upload_0", "The extension of file can't be changed"), + ("ragflow_test_upload_1.txt", "Duplicated document name in the same dataset"), + ("RAGFLOW_TEST_UPLOAD_1.TXT", ""), + ], + ) + def test_name(self, add_documents, name, expected_message): + dataset, documents = add_documents + document = documents[0] + + if expected_message: + with pytest.raises(Exception) as excinfo: + document.update({"name": name}) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + document.update({"name": name}) + updated_doc = dataset.list_documents(id=document.id)[0] + assert updated_doc.name == name, str(updated_doc) + + @pytest.mark.p3 + @pytest.mark.parametrize( + "meta_fields, expected_message", + [ + ({"test": "test"}, ""), + ("test", "meta_fields must be a dictionary"), + ], + ) + def test_meta_fields(self, add_documents, meta_fields, expected_message): + _, documents = add_documents + document = documents[0] + + if expected_message: + with pytest.raises(Exception) as excinfo: + document.update({"meta_fields": meta_fields}) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + document.update({"meta_fields": meta_fields}) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "chunk_method, expected_message", + [ + ("naive", ""), + ("manual", ""), + ("qa", ""), + ("table", ""), + ("paper", ""), + ("book", ""), + ("laws", ""), + ("presentation", ""), + ("picture", ""), + ("one", ""), + ("knowledge_graph", 
""), + ("email", ""), + ("tag", ""), + ("", "`chunk_method` doesn't exist"), + ("other_chunk_method", "`chunk_method` other_chunk_method doesn't exist"), + ], + ) + def test_chunk_method(self, add_documents, chunk_method, expected_message): + dataset, documents = add_documents + document = documents[0] + + if expected_message: + with pytest.raises(Exception) as excinfo: + document.update({"chunk_method": chunk_method}) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + document.update({"chunk_method": chunk_method}) + updated_doc = dataset.list_documents(id=document.id)[0] + assert updated_doc.chunk_method == chunk_method, str(updated_doc) + + @pytest.mark.p3 + @pytest.mark.parametrize( + "payload, expected_message", + [ + ({"chunk_count": 1}, "Can't change `chunk_count`"), + pytest.param( + {"create_date": "Fri, 14 Mar 2025 16:53:42 GMT"}, + "The input parameters are invalid", + marks=pytest.mark.skip(reason="issues/6104"), + ), + pytest.param( + {"create_time": 1}, + "The input parameters are invalid", + marks=pytest.mark.skip(reason="issues/6104"), + ), + pytest.param( + {"created_by": "ragflow_test"}, + "The input parameters are invalid", + marks=pytest.mark.skip(reason="issues/6104"), + ), + pytest.param( + {"dataset_id": "ragflow_test"}, + "The input parameters are invalid", + marks=pytest.mark.skip(reason="issues/6104"), + ), + pytest.param( + {"id": "ragflow_test"}, + "The input parameters are invalid", + marks=pytest.mark.skip(reason="issues/6104"), + ), + pytest.param( + {"location": "ragflow_test.txt"}, + "The input parameters are invalid", + marks=pytest.mark.skip(reason="issues/6104"), + ), + pytest.param( + {"process_begin_at": 1}, + "The input parameters are invalid", + marks=pytest.mark.skip(reason="issues/6104"), + ), + pytest.param( + {"process_duation": 1.0}, + "The input parameters are invalid", + marks=pytest.mark.skip(reason="issues/6104"), + ), + ({"progress": 1.0}, "Can't change `progress`"), + pytest.param( + {"progress_msg": "ragflow_test"}, + "The input parameters are invalid", + marks=pytest.mark.skip(reason="issues/6104"), + ), + pytest.param( + {"run": "ragflow_test"}, + "The input parameters are invalid", + marks=pytest.mark.skip(reason="issues/6104"), + ), + pytest.param( + {"size": 1}, + "The input parameters are invalid", + marks=pytest.mark.skip(reason="issues/6104"), + ), + pytest.param( + {"source_type": "ragflow_test"}, + "The input parameters are invalid", + marks=pytest.mark.skip(reason="issues/6104"), + ), + pytest.param( + {"thumbnail": "ragflow_test"}, + "The input parameters are invalid", + marks=pytest.mark.skip(reason="issues/6104"), + ), + ({"token_count": 1}, "Can't change `token_count`"), + pytest.param( + {"type": "ragflow_test"}, + "The input parameters are invalid", + marks=pytest.mark.skip(reason="issues/6104"), + ), + pytest.param( + {"update_date": "Fri, 14 Mar 2025 16:33:17 GMT"}, + "The input parameters are invalid", + marks=pytest.mark.skip(reason="issues/6104"), + ), + pytest.param( + {"update_time": 1}, + "The input parameters are invalid", + marks=pytest.mark.skip(reason="issues/6104"), + ), + ], + ) + def test_invalid_field(self, add_documents, payload, expected_message): + _, documents = add_documents + document = documents[0] + + with pytest.raises(Exception) as excinfo: + document.update(payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + + +class TestUpdateDocumentParserConfig: + @pytest.mark.p2 + @pytest.mark.parametrize( + "chunk_method, parser_config, 
expected_message", + [ + ("naive", {}, ""), + ( + "naive", + { + "chunk_token_num": 128, + "layout_recognize": "DeepDOC", + "html4excel": False, + "delimiter": r"\n", + "task_page_size": 12, + "raptor": {"use_raptor": False}, + }, + "", + ), + pytest.param( + "naive", + {"chunk_token_num": -1}, + "chunk_token_num should be in range from 1 to 100000000", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"chunk_token_num": 0}, + "chunk_token_num should be in range from 1 to 100000000", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"chunk_token_num": 100000000}, + "chunk_token_num should be in range from 1 to 100000000", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"chunk_token_num": 3.14}, + "", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"chunk_token_num": "1024"}, + "", + marks=pytest.mark.skip(reason="issues/6098"), + ), + ("naive", {"layout_recognize": "DeepDOC"}, ""), + ("naive", {"layout_recognize": "Naive"}, ""), + ("naive", {"html4excel": True}, ""), + ("naive", {"html4excel": False}, ""), + pytest.param( + "naive", + {"html4excel": 1}, + "html4excel should be True or False", + marks=pytest.mark.skip(reason="issues/6098"), + ), + ("naive", {"delimiter": ""}, ""), + ("naive", {"delimiter": "`##`"}, ""), + pytest.param( + "naive", + {"delimiter": 1}, + "", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"task_page_size": -1}, + "task_page_size should be in range from 1 to 100000000", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"task_page_size": 0}, + "task_page_size should be in range from 1 to 100000000", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"task_page_size": 100000000}, + "task_page_size should be in range from 1 to 100000000", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"task_page_size": 3.14}, + "", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"task_page_size": "1024"}, + "", + marks=pytest.mark.skip(reason="issues/6098"), + ), + ("naive", {"raptor": {"use_raptor": True}}, ""), + ("naive", {"raptor": {"use_raptor": False}}, ""), + pytest.param( + "naive", + {"invalid_key": "invalid_value"}, + "Abnormal 'parser_config'. 
Invalid key: invalid_key", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"auto_keywords": -1}, + "auto_keywords should be in range from 0 to 32", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"auto_keywords": 32}, + "auto_keywords should be in range from 0 to 32", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"auto_keywords": 3.14}, + "", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"auto_keywords": "1024"}, + "", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"auto_questions": -1}, + "auto_questions should be in range from 0 to 10", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"auto_questions": 10}, + "auto_questions should be in range from 0 to 10", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"auto_questions": 3.14}, + "", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"auto_questions": "1024"}, + "", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"topn_tags": -1}, + "topn_tags should be in range from 0 to 10", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"topn_tags": 10}, + "topn_tags should be in range from 0 to 10", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"topn_tags": 3.14}, + "", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"topn_tags": "1024"}, + "", + marks=pytest.mark.skip(reason="issues/6098"), + ), + ], + ) + def test_parser_config(self, client, add_documents, chunk_method, parser_config, expected_message): + dataset, documents = add_documents + document = documents[0] + from operator import attrgetter + + update_data = {"chunk_method": chunk_method, "parser_config": parser_config} + + if expected_message: + with pytest.raises(Exception) as excinfo: + document.update(update_data) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + document.update(update_data) + updated_doc = dataset.list_documents(id=document.id)[0] + if parser_config: + for k, v in parser_config.items(): + if isinstance(v, dict): + for kk, vv in v.items(): + assert attrgetter(f"{k}.{kk}")(updated_doc.parser_config) == vv, str(updated_doc) + else: + assert attrgetter(k)(updated_doc.parser_config) == v, str(updated_doc) + else: + expected_config = DataSet.ParserConfig( + client, + { + "chunk_token_num": 128, + "delimiter": r"\n", + "html4excel": False, + "layout_recognize": "DeepDOC", + "raptor": {"use_raptor": False}, + }, + ) + assert str(updated_doc.parser_config) == str(expected_config), str(updated_doc) diff --git a/test/testcases/test_sdk_api/test_file_management_within_dataset/test_upload_documents.py b/test/testcases/test_sdk_api/test_file_management_within_dataset/test_upload_documents.py new file mode 100644 index 00000000000..72034e27df0 --- /dev/null +++ b/test/testcases/test_sdk_api/test_file_management_within_dataset/test_upload_documents.py @@ -0,0 +1,211 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import string +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from configs import DOCUMENT_NAME_LIMIT +from utils.file_utils import create_txt_file + + +class TestDocumentsUpload: + @pytest.mark.p1 + def test_valid_single_upload(self, add_dataset_func, tmp_path): + dataset = add_dataset_func + fp = create_txt_file(tmp_path / "ragflow_test.txt") + with fp.open("rb") as f: + blob = f.read() + + documents = dataset.upload_documents([{"display_name": fp.name, "blob": blob}]) + for document in documents: + assert document.dataset_id == dataset.id, str(document) + assert document.name == fp.name, str(document) + + @pytest.mark.p1 + @pytest.mark.parametrize( + "generate_test_files", + [ + "docx", + "excel", + "ppt", + "image", + "pdf", + "txt", + "md", + "json", + "eml", + "html", + ], + indirect=True, + ) + def test_file_type_validation(self, add_dataset_func, generate_test_files, request): + dataset = add_dataset_func + fp = generate_test_files[request.node.callspec.params["generate_test_files"]] + + with fp.open("rb") as f: + blob = f.read() + + documents = dataset.upload_documents([{"display_name": fp.name, "blob": blob}]) + for document in documents: + assert document.dataset_id == dataset.id, str(document) + assert document.name == fp.name, str(document) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "file_type", + ["exe", "unknown"], + ) + def test_unsupported_file_type(self, add_dataset_func, tmp_path, file_type): + dataset = add_dataset_func + fp = tmp_path / f"ragflow_test.{file_type}" + fp.touch() + + with fp.open("rb") as f: + blob = f.read() + + with pytest.raises(Exception) as excinfo: + dataset.upload_documents([{"display_name": fp.name, "blob": blob}]) + assert str(excinfo.value) == f"ragflow_test.{file_type}: This type of file has not been supported yet!", str(excinfo.value) + + @pytest.mark.p2 + def test_missing_file(self, add_dataset_func): + dataset = add_dataset_func + with pytest.raises(Exception) as excinfo: + dataset.upload_documents([]) + assert str(excinfo.value) == "No file part!", str(excinfo.value) + + @pytest.mark.p3 + def test_empty_file(self, add_dataset_func, tmp_path): + dataset = add_dataset_func + fp = tmp_path / "empty.txt" + fp.touch() + + with fp.open("rb") as f: + blob = f.read() + + documents = dataset.upload_documents([{"display_name": fp.name, "blob": blob}]) + for document in documents: + assert document.size == 0, str(document) + + @pytest.mark.p3 + def test_filename_empty(self, add_dataset_func, tmp_path): + dataset = add_dataset_func + fp = create_txt_file(tmp_path / "ragflow_test.txt") + + with fp.open("rb") as f: + blob = f.read() + + with pytest.raises(Exception) as excinfo: + dataset.upload_documents([{"display_name": "", "blob": blob}]) + assert str(excinfo.value) == "No file selected!", str(excinfo.value) + + @pytest.mark.p2 + def test_filename_max_length(self, add_dataset_func, tmp_path): + dataset = add_dataset_func + fp = create_txt_file(tmp_path / f"{'a' * (DOCUMENT_NAME_LIMIT - 4)}.txt") + + with fp.open("rb") as f: + blob = f.read() + + documents = 
dataset.upload_documents([{"display_name": fp.name, "blob": blob}]) + for document in documents: + assert document.dataset_id == dataset.id, str(document) + assert document.name == fp.name, str(document) + + @pytest.mark.p2 + def test_duplicate_files(self, add_dataset_func, tmp_path): + dataset = add_dataset_func + fp = create_txt_file(tmp_path / "ragflow_test.txt") + + with fp.open("rb") as f: + blob = f.read() + + documents = dataset.upload_documents([{"display_name": fp.name, "blob": blob}, {"display_name": fp.name, "blob": blob}]) + + assert len(documents) == 2, str(documents) + for i, document in enumerate(documents): + assert document.dataset_id == dataset.id, str(document) + expected_name = fp.name if i == 0 else f"{fp.stem}({i}){fp.suffix}" + assert document.name == expected_name, str(document) + + @pytest.mark.p2 + def test_same_file_repeat(self, add_dataset_func, tmp_path): + dataset = add_dataset_func + fp = create_txt_file(tmp_path / "ragflow_test.txt") + + with fp.open("rb") as f: + blob = f.read() + + for i in range(3): + documents = dataset.upload_documents([{"display_name": fp.name, "blob": blob}]) + assert len(documents) == 1, str(documents) + document = documents[0] + assert document.dataset_id == dataset.id, str(document) + expected_name = fp.name if i == 0 else f"{fp.stem}({i}){fp.suffix}" + assert document.name == expected_name, str(document) + + @pytest.mark.p3 + def test_filename_special_characters(self, add_dataset_func, tmp_path): + dataset = add_dataset_func + illegal_chars = '<>:"/\\|?*' + translation_table = str.maketrans({char: "_" for char in illegal_chars}) + safe_filename = string.punctuation.translate(translation_table) + fp = tmp_path / f"{safe_filename}.txt" + fp.write_text("Sample text content") + + with fp.open("rb") as f: + blob = f.read() + + documents = dataset.upload_documents([{"display_name": fp.name, "blob": blob}]) + assert len(documents) == 1, str(documents) + document = documents[0] + assert document.dataset_id == dataset.id, str(document) + assert document.name == fp.name, str(document) + + @pytest.mark.p1 + def test_multiple_files(self, client, add_dataset_func, tmp_path): + dataset = add_dataset_func + expected_document_count = 20 + document_infos = [] + for i in range(expected_document_count): + fp = create_txt_file(tmp_path / f"ragflow_test_upload_{i}.txt") + with fp.open("rb") as f: + blob = f.read() + document_infos.append({"display_name": fp.name, "blob": blob}) + documents = dataset.upload_documents(document_infos) + assert len(documents) == expected_document_count, str(documents) + + retrieved_dataset = client.get_dataset(name=dataset.name) + assert retrieved_dataset.document_count == expected_document_count, str(retrieved_dataset) + + @pytest.mark.p3 + def test_concurrent_upload(self, client, add_dataset_func, tmp_path): + dataset = add_dataset_func + count = 20 + fps = [create_txt_file(tmp_path / f"ragflow_test_{i}.txt") for i in range(count)] + + def upload_file(fp): + with fp.open("rb") as f: + blob = f.read() + return dataset.upload_documents([{"display_name": fp.name, "blob": blob}]) + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(upload_file, fp) for fp in fps] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + + retrieved_dataset = client.get_dataset(name=dataset.name) + assert retrieved_dataset.document_count == count, str(retrieved_dataset) diff --git a/test/testcases/test_sdk_api/test_session_management/conftest.py 
b/test/testcases/test_sdk_api/test_session_management/conftest.py new file mode 100644 index 00000000000..08dfd08d251 --- /dev/null +++ b/test/testcases/test_sdk_api/test_session_management/conftest.py @@ -0,0 +1,49 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +from common import batch_add_sessions_with_chat_assistant +from pytest import FixtureRequest +from ragflow_sdk import Chat, DataSet, Document, Session + + +@pytest.fixture(scope="class") +def add_sessions_with_chat_assistant(request: FixtureRequest, add_chat_assistants: tuple[DataSet, Document, list[Chat]]) -> tuple[Chat, list[Session]]: + def cleanup(): + for chat_assistant in chat_assistants: + try: + chat_assistant.delete_sessions(ids=None) + except Exception: + pass + + request.addfinalizer(cleanup) + + _, _, chat_assistants = add_chat_assistants + return chat_assistants[0], batch_add_sessions_with_chat_assistant(chat_assistants[0], 5) + + +@pytest.fixture(scope="function") +def add_sessions_with_chat_assistant_func(request: FixtureRequest, add_chat_assistants: tuple[DataSet, Document, list[Chat]]) -> tuple[Chat, list[Session]]: + def cleanup(): + for chat_assistant in chat_assistants: + try: + chat_assistant.delete_sessions(ids=None) + except Exception: + pass + + request.addfinalizer(cleanup) + + _, _, chat_assistants = add_chat_assistants + return chat_assistants[0], batch_add_sessions_with_chat_assistant(chat_assistants[0], 5) diff --git a/test/testcases/test_sdk_api/test_session_management/test_create_session_with_chat_assistant.py b/test/testcases/test_sdk_api/test_session_management/test_create_session_with_chat_assistant.py new file mode 100644 index 00000000000..084c7122c20 --- /dev/null +++ b/test/testcases/test_sdk_api/test_session_management/test_create_session_with_chat_assistant.py @@ -0,0 +1,76 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from configs import SESSION_WITH_CHAT_NAME_LIMIT + + +@pytest.mark.usefixtures("clear_session_with_chat_assistants") +class TestSessionWithChatAssistantCreate: + @pytest.mark.p1 + @pytest.mark.parametrize( + "name, expected_message", + [ + ("valid_name", ""), + pytest.param("a" * (SESSION_WITH_CHAT_NAME_LIMIT + 1), "", marks=pytest.mark.skip(reason="issues/")), + pytest.param(1, "", marks=pytest.mark.skip(reason="issues/")), + ("", "`name` can not be empty."), + ("duplicated_name", ""), + ("case insensitive", ""), + ], + ) + def test_name(self, add_chat_assistants, name, expected_message): + _, _, chat_assistants = add_chat_assistants + chat_assistant = chat_assistants[0] + + if name == "duplicated_name": + chat_assistant.create_session(name=name) + elif name == "case insensitive": + chat_assistant.create_session(name=name.upper()) + + if expected_message: + with pytest.raises(Exception) as excinfo: + chat_assistant.create_session(name=name) + assert expected_message in str(excinfo.value) + else: + session = chat_assistant.create_session(name=name) + assert session.name == name, str(session) + assert session.chat_id == chat_assistant.id, str(session) + + @pytest.mark.p3 + def test_concurrent_create_session(self, add_chat_assistants): + count = 1000 + _, _, chat_assistants = add_chat_assistants + chat_assistant = chat_assistants[0] + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(chat_assistant.create_session, name=f"session with chat assistant test {i}") for i in range(count)] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + + updated_sessions = chat_assistant.list_sessions(page_size=count * 2) + assert len(updated_sessions) == count + + @pytest.mark.p3 + def test_add_session_to_deleted_chat_assistant(self, client, add_chat_assistants): + _, _, chat_assistants = add_chat_assistants + chat_assistant = chat_assistants[0] + + client.delete_chats(ids=[chat_assistant.id]) + with pytest.raises(Exception) as excinfo: + chat_assistant.create_session(name="valid_name") + assert "You do not own the assistant" in str(excinfo.value) diff --git a/test/testcases/test_sdk_api/test_session_management/test_delete_sessions_with_chat_assistant.py b/test/testcases/test_sdk_api/test_session_management/test_delete_sessions_with_chat_assistant.py new file mode 100644 index 00000000000..0c3c0ebe42a --- /dev/null +++ b/test/testcases/test_sdk_api/test_session_management/test_delete_sessions_with_chat_assistant.py @@ -0,0 +1,108 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from common import batch_add_sessions_with_chat_assistant + + +class TestSessionWithChatAssistantDelete: + @pytest.mark.parametrize( + "payload", + [ + pytest.param(lambda r: {"ids": ["invalid_id"] + r}, marks=pytest.mark.p3), + pytest.param(lambda r: {"ids": r[:1] + ["invalid_id"] + r[1:5]}, marks=pytest.mark.p1), + pytest.param(lambda r: {"ids": r + ["invalid_id"]}, marks=pytest.mark.p3), + ], + ) + def test_delete_partial_invalid_id(self, add_sessions_with_chat_assistant_func, payload): + chat_assistant, sessions = add_sessions_with_chat_assistant_func + if callable(payload): + payload = payload([session.id for session in sessions]) + + chat_assistant.delete_sessions(**payload) + + sessions = chat_assistant.list_sessions() + assert len(sessions) == 0 + + @pytest.mark.p3 + def test_repeated_deletion(self, add_sessions_with_chat_assistant_func): + chat_assistant, sessions = add_sessions_with_chat_assistant_func + session_ids = {"ids": [session.id for session in sessions]} + + chat_assistant.delete_sessions(**session_ids) + + with pytest.raises(Exception) as excinfo: + chat_assistant.delete_sessions(**session_ids) + assert "The chat doesn't own the session" in str(excinfo.value) + + @pytest.mark.p3 + def test_duplicate_deletion(self, add_sessions_with_chat_assistant_func): + chat_assistant, sessions = add_sessions_with_chat_assistant_func + session_ids = {"ids": [session.id for session in sessions] * 2} + chat_assistant.delete_sessions(**session_ids) + sessions = chat_assistant.list_sessions() + assert len(sessions) == 0 + + @pytest.mark.p3 + def test_concurrent_deletion(self, add_chat_assistants): + count = 100 + _, _, chat_assistants = add_chat_assistants + chat_assistant = chat_assistants[0] + sessions = batch_add_sessions_with_chat_assistant(chat_assistant, count) + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(chat_assistant.delete_sessions, ids=[sessions[i].id]) for i in range(count)] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + + @pytest.mark.p3 + def test_delete_1k(self, add_chat_assistants): + count = 1_000 + _, _, chat_assistants = add_chat_assistants + chat_assistant = chat_assistants[0] + sessions = batch_add_sessions_with_chat_assistant(chat_assistant, count) + chat_assistant.delete_sessions(ids=[session.id for session in sessions]) + + sessions = chat_assistant.list_sessions() + assert len(sessions) == 0 + + @pytest.mark.parametrize( + "payload, expected_message, remaining", + [ + pytest.param(None, """TypeError("argument of type \'NoneType\' is not iterable")""", 0, marks=pytest.mark.skip), + pytest.param({"ids": ["invalid_id"]}, "The chat doesn't own the session invalid_id", 5, marks=pytest.mark.p3), + pytest.param("not json", """AttributeError("\'str\' object has no attribute \'get\'")""", 5, marks=pytest.mark.skip), + pytest.param(lambda r: {"ids": r[:1]}, "", 4, marks=pytest.mark.p3), + pytest.param(lambda r: {"ids": r}, "", 0, marks=pytest.mark.p1), + pytest.param({"ids": []}, "", 0, marks=pytest.mark.p3), + ], + ) + def test_basic_scenarios(self, add_sessions_with_chat_assistant_func, payload, expected_message, remaining): + chat_assistant, sessions = add_sessions_with_chat_assistant_func + if callable(payload): + payload = payload([session.id for session in sessions]) + + if expected_message: + with pytest.raises(Exception) as excinfo: + chat_assistant.delete_sessions(**payload) + assert 
expected_message in str(excinfo.value) + else: + chat_assistant.delete_sessions(**payload) + + sessions = chat_assistant.list_sessions() + assert len(sessions) == remaining diff --git a/test/testcases/test_sdk_api/test_session_management/test_list_sessions_with_chat_assistant.py b/test/testcases/test_sdk_api/test_session_management/test_list_sessions_with_chat_assistant.py new file mode 100644 index 00000000000..c2918848c7f --- /dev/null +++ b/test/testcases/test_sdk_api/test_session_management/test_list_sessions_with_chat_assistant.py @@ -0,0 +1,203 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +from concurrent.futures import ThreadPoolExecutor, as_completed + + +class TestSessionsWithChatAssistantList: + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_page_size, expected_message", + [ + ({"page": None, "page_size": 2}, 0, "not instance of"), + pytest.param({"page": 0, "page_size": 2}, 0, "ValueError('Search does not support negative slicing.')", marks=pytest.mark.skip), + ({"page": 2, "page_size": 2}, 2, ""), + ({"page": 3, "page_size": 2}, 1, ""), + ({"page": "3", "page_size": 2}, 0, "not instance of"), + pytest.param({"page": -1, "page_size": 2}, 0, "ValueError('Search does not support negative slicing.')", marks=pytest.mark.skip), + pytest.param({"page": "a", "page_size": 2}, 0, """ValueError("invalid literal for int() with base 10: \'a\'")""", marks=pytest.mark.skip), + ], + ) + def test_page(self, add_sessions_with_chat_assistant, params, expected_page_size, expected_message): + chat_assistant, _ = add_sessions_with_chat_assistant + if expected_message: + with pytest.raises(Exception) as excinfo: + chat_assistant.list_sessions(**params) + assert expected_message in str(excinfo.value) + else: + sessions = chat_assistant.list_sessions(**params) + assert len(sessions) == expected_page_size + + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_page_size, expected_message", + [ + ({"page_size": None}, 0, "not instance of"), + ({"page_size": 0}, 0, ""), + ({"page_size": 1}, 1, ""), + ({"page_size": 6}, 5, ""), + ({"page_size": "1"}, 0, "not instance of"), + pytest.param({"page_size": -1}, 5, "", marks=pytest.mark.skip), + pytest.param({"page_size": "a"}, 0, """ValueError("invalid literal for int() with base 10: \'a\'")""", marks=pytest.mark.skip), + ], + ) + def test_page_size(self, add_sessions_with_chat_assistant, params, expected_page_size, expected_message): + chat_assistant, _ = add_sessions_with_chat_assistant + if expected_message: + with pytest.raises(Exception) as excinfo: + chat_assistant.list_sessions(**params) + assert expected_message in str(excinfo.value) + else: + sessions = chat_assistant.list_sessions(**params) + assert len(sessions) == expected_page_size + + @pytest.mark.p3 + @pytest.mark.parametrize( + "params, expected_message", + [ + ({"orderby": None}, "not instance of"), + ({"orderby": "create_time"}, ""), + ({"orderby": "update_time"}, ""), + 
({"orderby": "name", "desc": "False"}, "not instance of"), + pytest.param({"orderby": "unknown"}, "orderby should be create_time or update_time", marks=pytest.mark.skip(reason="issues/")), + ], + ) + def test_orderby(self, add_sessions_with_chat_assistant, params, expected_message): + chat_assistant, _ = add_sessions_with_chat_assistant + if expected_message: + with pytest.raises(Exception) as excinfo: + chat_assistant.list_sessions(**params) + assert expected_message in str(excinfo.value) + else: + chat_assistant.list_sessions(**params) + + @pytest.mark.p3 + @pytest.mark.parametrize( + "params, expected_message", + [ + ({"desc": None}, "not instance of"), + ({"desc": "true"}, "not instance of"), + ({"desc": "True"}, "not instance of"), + ({"desc": True}, ""), + ({"desc": "false"}, "not instance of"), + ({"desc": "False"}, "not instance of"), + ({"desc": False}, ""), + ({"desc": "False", "orderby": "update_time"}, "not instance of"), + pytest.param({"desc": "unknown"}, "desc should be true or false", marks=pytest.mark.skip(reason="issues/")), + ], + ) + def test_desc(self, add_sessions_with_chat_assistant, params, expected_message): + chat_assistant, _ = add_sessions_with_chat_assistant + if expected_message: + with pytest.raises(Exception) as excinfo: + chat_assistant.list_sessions(**params) + assert expected_message in str(excinfo.value) + else: + chat_assistant.list_sessions(**params) + + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_num, expected_message", + [ + ({"name": None}, 0, "not instance of"), + ({"name": ""}, 5, ""), + ({"name": "session_with_chat_assistant_1"}, 1, ""), + ({"name": "unknown"}, 0, ""), + ], + ) + def test_name(self, add_sessions_with_chat_assistant, params, expected_num, expected_message): + chat_assistant, _ = add_sessions_with_chat_assistant + if expected_message: + with pytest.raises(Exception) as excinfo: + chat_assistant.list_sessions(**params) + assert expected_message in str(excinfo.value) + else: + sessions = chat_assistant.list_sessions(**params) + if params["name"] == "session_with_chat_assistant_1": + assert sessions[0].name == params["name"] + else: + assert len(sessions) == expected_num + + @pytest.mark.p1 + @pytest.mark.parametrize( + "session_id, expected_num, expected_message", + [ + (None, 0, "not instance of"), + ("", 5, ""), + (lambda r: r[0], 1, ""), + ("unknown", 0, ""), + ], + ) + def test_id(self, add_sessions_with_chat_assistant, session_id, expected_num, expected_message): + chat_assistant, sessions = add_sessions_with_chat_assistant + if callable(session_id): + params = {"id": session_id([s.id for s in sessions])} + else: + params = {"id": session_id} + + if expected_message: + with pytest.raises(Exception) as excinfo: + chat_assistant.list_sessions(**params) + assert expected_message in str(excinfo.value) + else: + list_sessions = chat_assistant.list_sessions(**params) + if "id" in params and params["id"] == sessions[0].id: + assert list_sessions[0].id == params["id"] + else: + assert len(list_sessions) == expected_num + + @pytest.mark.p3 + @pytest.mark.parametrize( + "session_id, name, expected_num, expected_message", + [ + (lambda r: r[0], "session_with_chat_assistant_0", 1, ""), + (lambda r: r[0], "session_with_chat_assistant_100", 0, ""), + (lambda r: r[0], "unknown", 0, ""), + ("id", "session_with_chat_assistant_0", 0, ""), + ], + ) + def test_name_and_id(self, add_sessions_with_chat_assistant, session_id, name, expected_num, expected_message): + chat_assistant, sessions = add_sessions_with_chat_assistant + if 
callable(session_id): + params = {"id": session_id([s.id for s in sessions]), "name": name} + else: + params = {"id": session_id, "name": name} + + if expected_message: + with pytest.raises(Exception) as excinfo: + chat_assistant.list_sessions(**params) + assert expected_message in str(excinfo.value) + else: + list_sessions = chat_assistant.list_sessions(**params) + assert len(list_sessions) == expected_num + + @pytest.mark.p3 + def test_concurrent_list(self, add_sessions_with_chat_assistant): + count = 100 + chat_assistant, _ = add_sessions_with_chat_assistant + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(chat_assistant.list_sessions) for i in range(count)] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + + @pytest.mark.p3 + def test_list_chats_after_deleting_associated_chat_assistant(self, client, add_sessions_with_chat_assistant): + chat_assistant, _ = add_sessions_with_chat_assistant + client.delete_chats(ids=[chat_assistant.id]) + + with pytest.raises(Exception) as excinfo: + chat_assistant.list_sessions() + assert "You don't own the assistant" in str(excinfo.value) diff --git a/test/testcases/test_sdk_api/test_session_management/test_update_session_with_chat_assistant.py b/test/testcases/test_sdk_api/test_session_management/test_update_session_with_chat_assistant.py new file mode 100644 index 00000000000..d0fa3e187cb --- /dev/null +++ b/test/testcases/test_sdk_api/test_session_management/test_update_session_with_chat_assistant.py @@ -0,0 +1,98 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from concurrent.futures import ThreadPoolExecutor, as_completed +from random import randint + +import pytest +from configs import SESSION_WITH_CHAT_NAME_LIMIT + + +class TestSessionWithChatAssistantUpdate: + @pytest.mark.parametrize( + "payload, expected_message", + [ + pytest.param({"name": "valid_name"}, "", marks=pytest.mark.p1), + pytest.param({"name": "a" * (SESSION_WITH_CHAT_NAME_LIMIT + 1)}, "", marks=pytest.mark.skip(reason="issues/")), + pytest.param({"name": 1}, "", marks=pytest.mark.skip(reason="issues/")), + pytest.param({"name": ""}, "`name` can not be empty.", marks=pytest.mark.p3), + pytest.param({"name": "duplicated_name"}, "", marks=pytest.mark.p3), + pytest.param({"name": "case insensitive"}, "", marks=pytest.mark.p3), + ], + ) + def test_name(self, add_sessions_with_chat_assistant_func, payload, expected_message): + chat_assistant, sessions = add_sessions_with_chat_assistant_func + session = sessions[0] + + if payload["name"] == "duplicated_name": + session.update(payload) + elif payload["name"] == "case insensitive": + session.update({"name": payload["name"].upper()}) + + if expected_message: + with pytest.raises(Exception) as excinfo: + session.update(payload) + assert expected_message in str(excinfo.value) + else: + session.update(payload) + updated_session = chat_assistant.list_sessions(id=session.id)[0] + assert updated_session.name == payload["name"] + + @pytest.mark.p3 + def test_repeated_update_session(self, add_sessions_with_chat_assistant_func): + _, sessions = add_sessions_with_chat_assistant_func + session = sessions[0] + + session.update({"name": "valid_name_1"}) + session.update({"name": "valid_name_2"}) + + @pytest.mark.p3 + @pytest.mark.parametrize( + "payload, expected_message", + [ + pytest.param({"unknown_key": "unknown_value"}, "ValueError", marks=pytest.mark.skip), + ({}, ""), + pytest.param(None, "TypeError", marks=pytest.mark.skip), + ], + ) + def test_invalid_params(self, add_sessions_with_chat_assistant_func, payload, expected_message): + _, sessions = add_sessions_with_chat_assistant_func + session = sessions[0] + + if expected_message: + with pytest.raises(Exception) as excinfo: + session.update(payload) + assert expected_message in str(excinfo.value) + else: + session.update(payload) + + @pytest.mark.p3 + def test_concurrent_update_session(self, add_sessions_with_chat_assistant_func): + count = 50 + _, sessions = add_sessions_with_chat_assistant_func + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(sessions[randint(0, 4)].update, {"name": f"update session test {i}"}) for i in range(count)] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + + @pytest.mark.p3 + def test_update_session_to_deleted_chat_assistant(self, client, add_sessions_with_chat_assistant_func): + chat_assistant, sessions = add_sessions_with_chat_assistant_func + client.delete_chats(ids=[chat_assistant.id]) + + with pytest.raises(Exception) as excinfo: + sessions[0].update({"name": "valid_name"}) + assert "You do not own the session" in str(excinfo.value) diff --git a/test/testcases/test_web_api/common.py b/test/testcases/test_web_api/common.py new file mode 100644 index 00000000000..69eba070d61 --- /dev/null +++ b/test/testcases/test_web_api/common.py @@ -0,0 +1,93 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import requests +from configs import HOST_ADDRESS + +HEADERS = {"Content-Type": "application/json"} + +KB_APP_URL = "/v1/kb" +# FILE_API_URL = "/api/v1/datasets/{dataset_id}/documents" +# FILE_CHUNK_API_URL = "/api/v1/datasets/{dataset_id}/chunks" +# CHUNK_API_URL = "/api/v1/datasets/{dataset_id}/documents/{document_id}/chunks" +# CHAT_ASSISTANT_API_URL = "/api/v1/chats" +# SESSION_WITH_CHAT_ASSISTANT_API_URL = "/api/v1/chats/{chat_id}/sessions" +# SESSION_WITH_AGENT_API_URL = "/api/v1/agents/{agent_id}/sessions" + + +# DATASET MANAGEMENT +def create_kb(auth, payload=None, *, headers=HEADERS, data=None): + res = requests.post(url=f"{HOST_ADDRESS}{KB_APP_URL}/create", headers=headers, auth=auth, json=payload, data=data) + return res.json() + + +def list_kbs(auth, params=None, payload=None, *, headers=HEADERS, data=None): + if payload is None: + payload = {} + res = requests.post(url=f"{HOST_ADDRESS}{KB_APP_URL}/list", headers=headers, auth=auth, params=params, json=payload, data=data) + return res.json() + + +def update_kb(auth, payload=None, *, headers=HEADERS, data=None): + res = requests.post(url=f"{HOST_ADDRESS}{KB_APP_URL}/update", headers=headers, auth=auth, json=payload, data=data) + return res.json() + + +def rm_kb(auth, payload=None, *, headers=HEADERS, data=None): + res = requests.post(url=f"{HOST_ADDRESS}{KB_APP_URL}/rm", headers=headers, auth=auth, json=payload, data=data) + return res.json() + + +def detail_kb(auth, params=None, *, headers=HEADERS): + res = requests.get(url=f"{HOST_ADDRESS}{KB_APP_URL}/detail", headers=headers, auth=auth, params=params) + return res.json() + + +def list_tags_from_kbs(auth, params=None, *, headers=HEADERS): + res = requests.get(url=f"{HOST_ADDRESS}{KB_APP_URL}/tags", headers=headers, auth=auth, params=params) + return res.json() + + +def list_tags(auth, dataset_id, params=None, *, headers=HEADERS): + res = requests.get(url=f"{HOST_ADDRESS}{KB_APP_URL}/{dataset_id}/tags", headers=headers, auth=auth, params=params) + return res.json() + + +def rm_tags(auth, dataset_id, payload=None, *, headers=HEADERS, data=None): + res = requests.post(url=f"{HOST_ADDRESS}{KB_APP_URL}/{dataset_id}/rm_tags", headers=headers, auth=auth, json=payload, data=data) + return res.json() + + +def rename_tags(auth, dataset_id, payload=None, *, headers=HEADERS, data=None): + res = requests.post(url=f"{HOST_ADDRESS}{KB_APP_URL}/{dataset_id}/rename_tags", headers=headers, auth=auth, json=payload, data=data) + return res.json() + + +def knowledge_graph(auth, dataset_id, params=None, *, headers=HEADERS): + res = requests.get(url=f"{HOST_ADDRESS}{KB_APP_URL}/{dataset_id}/knowledge_graph", headers=headers, auth=auth, params=params) + return res.json() + + +def delete_knowledge_graph(auth, dataset_id, payload=None, *, headers=HEADERS, data=None): + res = requests.delete(url=f"{HOST_ADDRESS}{KB_APP_URL}/{dataset_id}/delete_knowledge_graph", headers=headers, auth=auth, json=payload, data=data) + return res.json() + + +def batch_create_datasets(auth, num): + ids = [] + for i in range(num): + res = create_kb(auth, {"name": f"kb_{i}"}) + ids.append(res["data"]["kb_id"]) + 
return ids diff --git a/test/testcases/test_web_api/conftest.py b/test/testcases/test_web_api/conftest.py new file mode 100644 index 00000000000..44c80d9aff8 --- /dev/null +++ b/test/testcases/test_web_api/conftest.py @@ -0,0 +1,100 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +from common import ( + batch_create_datasets, +) +from configs import HOST_ADDRESS, VERSION +from libs.auth import RAGFlowWebApiAuth +from pytest import FixtureRequest +from ragflow_sdk import RAGFlow +from utils.file_utils import ( + create_docx_file, + create_eml_file, + create_excel_file, + create_html_file, + create_image_file, + create_json_file, + create_md_file, + create_pdf_file, + create_ppt_file, + create_txt_file, +) + + +@pytest.fixture +def generate_test_files(request: FixtureRequest, tmp_path): + file_creators = { + "docx": (tmp_path / "ragflow_test.docx", create_docx_file), + "excel": (tmp_path / "ragflow_test.xlsx", create_excel_file), + "ppt": (tmp_path / "ragflow_test.pptx", create_ppt_file), + "image": (tmp_path / "ragflow_test.png", create_image_file), + "pdf": (tmp_path / "ragflow_test.pdf", create_pdf_file), + "txt": (tmp_path / "ragflow_test.txt", create_txt_file), + "md": (tmp_path / "ragflow_test.md", create_md_file), + "json": (tmp_path / "ragflow_test.json", create_json_file), + "eml": (tmp_path / "ragflow_test.eml", create_eml_file), + "html": (tmp_path / "ragflow_test.html", create_html_file), + } + + files = {} + for file_type, (file_path, creator_func) in file_creators.items(): + if request.param in ["", file_type]: + creator_func(file_path) + files[file_type] = file_path + return files + + +@pytest.fixture(scope="class") +def ragflow_tmp_dir(request, tmp_path_factory): + class_name = request.cls.__name__ + return tmp_path_factory.mktemp(class_name) + + +@pytest.fixture(scope="session") +def WebApiAuth(auth): + return RAGFlowWebApiAuth(auth) + + +@pytest.fixture(scope="session") +def client(token: str) -> RAGFlow: + return RAGFlow(api_key=token, base_url=HOST_ADDRESS, version=VERSION) + + +@pytest.fixture(scope="function") +def clear_datasets(request: FixtureRequest, client: RAGFlow): + def cleanup(): + client.delete_datasets(ids=None) + + request.addfinalizer(cleanup) + + +@pytest.fixture(scope="class") +def add_dataset(request: FixtureRequest, client: RAGFlow, WebApiAuth: RAGFlowWebApiAuth) -> str: + def cleanup(): + client.delete_datasets(ids=None) + + request.addfinalizer(cleanup) + return batch_create_datasets(WebApiAuth, 1)[0] + + +@pytest.fixture(scope="function") +def add_dataset_func(request: FixtureRequest, client: RAGFlow, WebApiAuth: RAGFlowWebApiAuth) -> str: + def cleanup(): + client.delete_datasets(ids=None) + + request.addfinalizer(cleanup) + return batch_create_datasets(WebApiAuth, 1)[0] diff --git a/test/testcases/test_web_api/test_kb_app/conftest.py b/test/testcases/test_web_api/test_kb_app/conftest.py new file mode 100644 index 00000000000..0a435483ce8 --- /dev/null +++ 
b/test/testcases/test_web_api/test_kb_app/conftest.py @@ -0,0 +1,38 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +from common import batch_create_datasets +from libs.auth import RAGFlowWebApiAuth +from pytest import FixtureRequest +from ragflow_sdk import RAGFlow + + +@pytest.fixture(scope="class") +def add_datasets(request: FixtureRequest, client: RAGFlow, WebApiAuth: RAGFlowWebApiAuth) -> list[str]: + def cleanup(): + client.delete_datasets(ids=None) + + request.addfinalizer(cleanup) + return batch_create_datasets(WebApiAuth, 5) + + +@pytest.fixture(scope="function") +def add_datasets_func(request: FixtureRequest, client: RAGFlow, WebApiAuth: RAGFlowWebApiAuth) -> list[str]: + def cleanup(): + client.delete_datasets(ids=None) + + request.addfinalizer(cleanup) + return batch_create_datasets(WebApiAuth, 3) diff --git a/test/testcases/test_web_api/test_kb_app/test_create_kb.py b/test/testcases/test_web_api/test_kb_app/test_create_kb.py new file mode 100644 index 00000000000..82f596491fc --- /dev/null +++ b/test/testcases/test_web_api/test_kb_app/test_create_kb.py @@ -0,0 +1,109 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from common import create_kb +from configs import DATASET_NAME_LIMIT, INVALID_API_TOKEN +from hypothesis import example, given, settings +from libs.auth import RAGFlowWebApiAuth +from utils.hypothesis_utils import valid_names + + +@pytest.mark.usefixtures("clear_datasets") +class TestAuthorization: + @pytest.mark.p1 + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 401, ""), + (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, ""), + ], + ids=["empty_auth", "invalid_api_token"], + ) + def test_auth_invalid(self, invalid_auth, expected_code, expected_message): + res = create_kb(invalid_auth, {"name": "auth_test"}) + assert res["code"] == expected_code, res + assert res["message"] == expected_message, res + + +@pytest.mark.usefixtures("clear_datasets") +class TestCapability: + @pytest.mark.p3 + def test_create_kb_1k(self, WebApiAuth): + for i in range(1_000): + payload = {"name": f"dataset_{i}"} + res = create_kb(WebApiAuth, payload) + assert res["code"] == 0, f"Failed to create dataset {i}" + + @pytest.mark.p3 + def test_create_kb_concurrent(self, WebApiAuth): + count = 100 + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(create_kb, WebApiAuth, {"name": f"dataset_{i}"}) for i in range(count)] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + assert all(future.result()["code"] == 0 for future in futures) + + +@pytest.mark.usefixtures("clear_datasets") +class TestDatasetCreate: + @pytest.mark.p1 + @given(name=valid_names()) + @example("a" * 128) + @settings(max_examples=20) + def test_name(self, WebApiAuth, name): + res = create_kb(WebApiAuth, {"name": name}) + assert res["code"] == 0, res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "name, expected_message", + [ + ("", "Dataset name can't be empty."), + (" ", "Dataset name can't be empty."), + ("a" * (DATASET_NAME_LIMIT + 1), "Dataset name length is 129 which is large than 128"), + (0, "Dataset name must be string."), + (None, "Dataset name must be string."), + ], + ids=["empty_name", "space_name", "too_long_name", "invalid_name", "None_name"], + ) + def test_name_invalid(self, WebApiAuth, name, expected_message): + payload = {"name": name} + res = create_kb(WebApiAuth, payload) + assert res["code"] == 102, res + assert expected_message in res["message"], res + + @pytest.mark.p3 + def test_name_duplicated(self, WebApiAuth): + name = "duplicated_name" + payload = {"name": name} + res = create_kb(WebApiAuth, payload) + assert res["code"] == 0, res + + res = create_kb(WebApiAuth, payload) + assert res["code"] == 0, res + + @pytest.mark.p3 + def test_name_case_insensitive(self, WebApiAuth): + name = "CaseInsensitive" + payload = {"name": name.upper()} + res = create_kb(WebApiAuth, payload) + assert res["code"] == 0, res + + payload = {"name": name.lower()} + res = create_kb(WebApiAuth, payload) + assert res["code"] == 0, res diff --git a/test/testcases/test_web_api/test_kb_app/test_detail_kb.py b/test/testcases/test_web_api/test_kb_app/test_detail_kb.py new file mode 100644 index 00000000000..a3c0f82b1a8 --- /dev/null +++ b/test/testcases/test_web_api/test_kb_app/test_detail_kb.py @@ -0,0 +1,53 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +from common import ( + detail_kb, +) +from configs import INVALID_API_TOKEN +from libs.auth import RAGFlowWebApiAuth + + +class TestAuthorization: + @pytest.mark.p1 + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 401, ""), + (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, ""), + ], + ) + def test_auth_invalid(self, invalid_auth, expected_code, expected_message): + res = detail_kb(invalid_auth) + assert res["code"] == expected_code, res + assert res["message"] == expected_message, res + + +class TestDatasetsDetail: + @pytest.mark.p1 + def test_kb_id(self, WebApiAuth, add_dataset): + kb_id = add_dataset + payload = {"kb_id": kb_id} + res = detail_kb(WebApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["name"] == "kb_0" + + @pytest.mark.p2 + def test_id_wrong_uuid(self, WebApiAuth): + payload = {"kb_id": "d94a8dc02c9711f0930f7fbc369eab6d"} + res = detail_kb(WebApiAuth, payload) + assert res["code"] == 103, res + assert "Only owner of knowledgebase authorized for this operation." in res["message"], res diff --git a/test/testcases/test_web_api/test_kb_app/test_list_kbs.py b/test/testcases/test_web_api/test_kb_app/test_list_kbs.py new file mode 100644 index 00000000000..5d29968d975 --- /dev/null +++ b/test/testcases/test_web_api/test_kb_app/test_list_kbs.py @@ -0,0 +1,184 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from common import list_kbs +from configs import INVALID_API_TOKEN +from libs.auth import RAGFlowWebApiAuth +from utils import is_sorted + + +class TestAuthorization: + @pytest.mark.p1 + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 401, ""), + (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, ""), + ], + ) + def test_auth_invalid(self, invalid_auth, expected_code, expected_message): + res = list_kbs(invalid_auth) + assert res["code"] == expected_code, res + assert res["message"] == expected_message, res + + +class TestCapability: + @pytest.mark.p3 + def test_concurrent_list(self, WebApiAuth): + count = 100 + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(list_kbs, WebApiAuth) for i in range(count)] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + assert all(future.result()["code"] == 0 for future in futures) + + +@pytest.mark.usefixtures("add_datasets") +class TestDatasetsList: + @pytest.mark.p1 + def test_params_unset(self, WebApiAuth): + res = list_kbs(WebApiAuth, None) + assert res["code"] == 0, res + assert len(res["data"]["kbs"]) == 5, res + + @pytest.mark.p2 + def test_params_empty(self, WebApiAuth): + res = list_kbs(WebApiAuth, {}) + assert res["code"] == 0, res + assert len(res["data"]["kbs"]) == 5, res + + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_page_size", + [ + ({"page": 2, "page_size": 2}, 2), + ({"page": 3, "page_size": 2}, 1), + ({"page": 4, "page_size": 2}, 0), + ({"page": "2", "page_size": 2}, 2), + ({"page": 1, "page_size": 10}, 5), + ], + ids=["normal_middle_page", "normal_last_partial_page", "beyond_max_page", "string_page_number", "full_data_single_page"], + ) + def test_page(self, WebApiAuth, params, expected_page_size): + res = list_kbs(WebApiAuth, params) + assert res["code"] == 0, res + assert len(res["data"]["kbs"]) == expected_page_size, res + + @pytest.mark.skip + @pytest.mark.p2 + @pytest.mark.parametrize( + "params, expected_code, expected_message", + [ + ({"page": 0}, 101, "Input should be greater than or equal to 1"), + ({"page": "a"}, 101, "Input should be a valid integer, unable to parse string as an integer"), + ], + ids=["page_0", "page_a"], + ) + def test_page_invalid(self, WebApiAuth, params, expected_code, expected_message): + res = list_kbs(WebApiAuth, params=params) + assert res["code"] == expected_code, res + assert expected_message in res["message"], res + + @pytest.mark.p2 + def test_page_none(self, WebApiAuth): + params = {"page": None} + res = list_kbs(WebApiAuth, params) + assert res["code"] == 0, res + assert len(res["data"]["kbs"]) == 5, res + + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_page_size", + [ + ({"page": 1, "page_size": 1}, 1), + ({"page": 1, "page_size": 3}, 3), + ({"page": 1, "page_size": 5}, 5), + ({"page": 1, "page_size": 6}, 5), + ({"page": 1, "page_size": "1"}, 1), + ], + ids=["min_valid_page_size", "medium_page_size", "page_size_equals_total", "page_size_exceeds_total", "string_type_page_size"], + ) + def test_page_size(self, WebApiAuth, params, expected_page_size): + res = list_kbs(WebApiAuth, params) + assert res["code"] == 0, res + assert len(res["data"]["kbs"]) == expected_page_size, res + + @pytest.mark.skip + @pytest.mark.p2 + @pytest.mark.parametrize( + "params, expected_code, expected_message", + [ + ({"page_size": 0}, 101, "Input should be greater than or equal to 1"), + 
({"page_size": "a"}, 101, "Input should be a valid integer, unable to parse string as an integer"), + ], + ) + def test_page_size_invalid(self, WebApiAuth, params, expected_code, expected_message): + res = list_kbs(WebApiAuth, params) + assert res["code"] == expected_code, res + assert expected_message in res["message"], res + + @pytest.mark.p2 + def test_page_size_none(self, WebApiAuth): + params = {"page_size": None} + res = list_kbs(WebApiAuth, params) + assert res["code"] == 0, res + assert len(res["data"]["kbs"]) == 5, res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "params, assertions", + [ + ({"orderby": "update_time"}, lambda r: (is_sorted(r["data"]["kbs"], "update_time", True))), + ], + ids=["orderby_update_time"], + ) + def test_orderby(self, WebApiAuth, params, assertions): + res = list_kbs(WebApiAuth, params) + assert res["code"] == 0, res + if callable(assertions): + assert assertions(res), res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "params, assertions", + [ + ({"desc": "True"}, lambda r: (is_sorted(r["data"]["kbs"], "update_time", True))), + ({"desc": "False"}, lambda r: (is_sorted(r["data"]["kbs"], "update_time", False))), + ], + ids=["desc=True", "desc=False"], + ) + def test_desc(self, WebApiAuth, params, assertions): + res = list_kbs(WebApiAuth, params) + + assert res["code"] == 0, res + if callable(assertions): + assert assertions(res), res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "params, expected_page_size", + [ + ({"parser_id": "naive"}, 5), + ({"parser_id": "qa"}, 0), + ], + ids=["naive", "dqa"], + ) + def test_parser_id(self, WebApiAuth, params, expected_page_size): + res = list_kbs(WebApiAuth, params) + assert res["code"] == 0, res + assert len(res["data"]["kbs"]) == expected_page_size, res diff --git a/test/testcases/test_web_api/test_kb_app/test_rm_kb.py b/test/testcases/test_web_api/test_kb_app/test_rm_kb.py new file mode 100644 index 00000000000..ff20ea8c36b --- /dev/null +++ b/test/testcases/test_web_api/test_kb_app/test_rm_kb.py @@ -0,0 +1,61 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import pytest +from common import ( + list_kbs, + rm_kb, +) +from configs import INVALID_API_TOKEN +from libs.auth import RAGFlowWebApiAuth + + +class TestAuthorization: + @pytest.mark.p1 + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 401, ""), + (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, ""), + ], + ) + def test_auth_invalid(self, invalid_auth, expected_code, expected_message): + res = rm_kb(invalid_auth) + assert res["code"] == expected_code, res + assert res["message"] == expected_message, res + + +class TestDatasetsDelete: + @pytest.mark.p1 + def test_kb_id(self, WebApiAuth, add_datasets_func): + kb_ids = add_datasets_func + payload = {"kb_id": kb_ids[0]} + res = rm_kb(WebApiAuth, payload) + assert res["code"] == 0, res + + res = list_kbs(WebApiAuth) + assert len(res["data"]["kbs"]) == 2, res + + @pytest.mark.p2 + @pytest.mark.usefixtures("add_dataset_func") + def test_id_wrong_uuid(self, WebApiAuth): + payload = {"kb_id": "d94a8dc02c9711f0930f7fbc369eab6d"} + res = rm_kb(WebApiAuth, payload) + assert res["code"] == 109, res + assert "No authorization." in res["message"], res + + res = list_kbs(WebApiAuth) + assert len(res["data"]["kbs"]) == 1, res diff --git a/test/testcases/test_web_api/test_kb_app/test_update_kb.py b/test/testcases/test_web_api/test_kb_app/test_update_kb.py new file mode 100644 index 00000000000..6505dd1b9e1 --- /dev/null +++ b/test/testcases/test_web_api/test_kb_app/test_update_kb.py @@ -0,0 +1,378 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import os +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from common import update_kb +from configs import DATASET_NAME_LIMIT, INVALID_API_TOKEN +from hypothesis import HealthCheck, example, given, settings +from libs.auth import RAGFlowWebApiAuth +from utils import encode_avatar +from utils.file_utils import create_image_file +from utils.hypothesis_utils import valid_names + + +class TestAuthorization: + @pytest.mark.p1 + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 401, ""), + (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, ""), + ], + ids=["empty_auth", "invalid_api_token"], + ) + def test_auth_invalid(self, invalid_auth, expected_code, expected_message): + res = update_kb(invalid_auth, "dataset_id") + assert res["code"] == expected_code, res + assert res["message"] == expected_message, res + + +class TestCapability: + @pytest.mark.p3 + def test_update_dataset_concurrent(self, WebApiAuth, add_dataset_func): + dataset_id = add_dataset_func + count = 100 + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [ + executor.submit( + update_kb, + WebApiAuth, + { + "kb_id": dataset_id, + "name": f"dataset_{i}", + "description": "", + "parser_id": "naive", + }, + ) + for i in range(count) + ] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + assert all(future.result()["code"] == 0 for future in futures) + + +class TestDatasetUpdate: + @pytest.mark.p3 + def test_dataset_id_not_uuid(self, WebApiAuth): + payload = {"name": "not uuid", "description": "", "parser_id": "naive", "kb_id": "not_uuid"} + res = update_kb(WebApiAuth, payload) + assert res["code"] == 109, res + assert "No authorization." in res["message"], res + + @pytest.mark.p1 + @given(name=valid_names()) + @example("a" * 128) + @settings(max_examples=20, suppress_health_check=[HealthCheck.function_scoped_fixture]) + def test_name(self, WebApiAuth, add_dataset_func, name): + dataset_id = add_dataset_func + payload = {"name": name, "description": "", "parser_id": "naive", "kb_id": dataset_id} + res = update_kb(WebApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["name"] == name, res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "name, expected_message", + [ + ("", "Dataset name can't be empty."), + (" ", "Dataset name can't be empty."), + ("a" * (DATASET_NAME_LIMIT + 1), "Dataset name length is 129 which is large than 128"), + (0, "Dataset name must be string."), + (None, "Dataset name must be string."), + ], + ids=["empty_name", "space_name", "too_long_name", "invalid_name", "None_name"], + ) + def test_name_invalid(self, WebApiAuth, add_dataset_func, name, expected_message): + kb_id = add_dataset_func + payload = {"name": name, "description": "", "parser_id": "naive", "kb_id": kb_id} + res = update_kb(WebApiAuth, payload) + assert res["code"] == 102, res + assert expected_message in res["message"], res + + @pytest.mark.p3 + def test_name_duplicated(self, WebApiAuth, add_datasets_func): + kb_id = add_datasets_func[0] + name = "kb_1" + payload = {"name": name, "description": "", "parser_id": "naive", "kb_id": kb_id} + res = update_kb(WebApiAuth, payload) + assert res["code"] == 102, res + assert res["message"] == "Duplicated knowledgebase name.", res + + @pytest.mark.p3 + def test_name_case_insensitive(self, WebApiAuth, add_datasets_func): + kb_id = add_datasets_func[0] + name = "KB_1" + payload = {"name": name, "description": "", "parser_id": "naive", "kb_id": kb_id} + res = 
update_kb(WebApiAuth, payload) + assert res["code"] == 102, res + assert res["message"] == "Duplicated knowledgebase name.", res + + @pytest.mark.p2 + def test_avatar(self, WebApiAuth, add_dataset_func, tmp_path): + kb_id = add_dataset_func + fn = create_image_file(tmp_path / "ragflow_test.png") + payload = { + "name": "avatar", + "description": "", + "parser_id": "naive", + "kb_id": kb_id, + "avatar": f"data:image/png;base64,{encode_avatar(fn)}", + } + res = update_kb(WebApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["avatar"] == f"data:image/png;base64,{encode_avatar(fn)}", res + + @pytest.mark.p2 + def test_description(self, WebApiAuth, add_dataset_func): + kb_id = add_dataset_func + payload = {"name": "description", "description": "description", "parser_id": "naive", "kb_id": kb_id} + res = update_kb(WebApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["description"] == "description", res + + @pytest.mark.p1 + @pytest.mark.parametrize( + "embedding_model", + [ + "BAAI/bge-large-zh-v1.5@BAAI", + "maidalun1020/bce-embedding-base_v1@Youdao", + "embedding-3@ZHIPU-AI", + ], + ids=["builtin_baai", "builtin_youdao", "tenant_zhipu"], + ) + def test_embedding_model(self, WebApiAuth, add_dataset_func, embedding_model): + kb_id = add_dataset_func + payload = {"name": "embedding_model", "description": "", "parser_id": "naive", "kb_id": kb_id, "embd_id": embedding_model} + res = update_kb(WebApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["embd_id"] == embedding_model, res + + @pytest.mark.p1 + @pytest.mark.parametrize( + "permission", + [ + "me", + "team", + ], + ids=["me", "team"], + ) + def test_permission(self, WebApiAuth, add_dataset_func, permission): + kb_id = add_dataset_func + payload = {"name": "permission", "description": "", "parser_id": "naive", "kb_id": kb_id, "permission": permission} + res = update_kb(WebApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["permission"] == permission.lower().strip(), res + + @pytest.mark.p1 + @pytest.mark.parametrize( + "chunk_method", + [ + "naive", + "book", + "email", + "laws", + "manual", + "one", + "paper", + "picture", + "presentation", + "qa", + "table", + pytest.param("tag", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="Infinity does not support parser_id=tag")), + ], + ids=["naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"], + ) + def test_chunk_method(self, WebApiAuth, add_dataset_func, chunk_method): + kb_id = add_dataset_func + payload = {"name": "chunk_method", "description": "", "parser_id": chunk_method, "kb_id": kb_id} + res = update_kb(WebApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["parser_id"] == chunk_method, res + + @pytest.mark.p1 + @pytest.mark.skipif(os.getenv("DOC_ENGINE") != "infinity", reason="Infinity does not support parser_id=tag") + def test_chunk_method_tag_with_infinity(self, WebApiAuth, add_dataset_func): + kb_id = add_dataset_func + payload = {"name": "chunk_method", "description": "", "parser_id": "tag", "kb_id": kb_id} + res = update_kb(WebApiAuth, payload) + assert res["code"] == 103, res + assert res["message"] == "The chunking method Tag has not been supported by Infinity yet.", res + + @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208") + @pytest.mark.p2 + @pytest.mark.parametrize("pagerank", [0, 50, 100], ids=["min", "mid", "max"]) + def test_pagerank(self, WebApiAuth, add_dataset_func, pagerank): 
+ kb_id = add_dataset_func + payload = {"name": "pagerank", "description": "", "parser_id": "naive", "kb_id": kb_id, "pagerank": pagerank} + res = update_kb(WebApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["pagerank"] == pagerank, res + + @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208") + @pytest.mark.p2 + def test_pagerank_set_to_0(self, WebApiAuth, add_dataset_func): + kb_id = add_dataset_func + payload = {"name": "pagerank", "description": "", "parser_id": "naive", "kb_id": kb_id, "pagerank": 50} + res = update_kb(WebApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["pagerank"] == 50, res + + payload = {"name": "pagerank", "description": "", "parser_id": "naive", "kb_id": kb_id, "pagerank": 0} + res = update_kb(WebApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["pagerank"] == 0, res + + @pytest.mark.skipif(os.getenv("DOC_ENGINE") != "infinity", reason="#8208") + @pytest.mark.p2 + def test_pagerank_infinity(self, WebApiAuth, add_dataset_func): + kb_id = add_dataset_func + payload = {"name": "pagerank", "description": "", "parser_id": "naive", "kb_id": kb_id, "pagerank": 50} + res = update_kb(WebApiAuth, payload) + assert res["code"] == 102, res + assert res["message"] == "'pagerank' can only be set when doc_engine is elasticsearch", res + + @pytest.mark.p1 + @pytest.mark.parametrize( + "parser_config", + [ + {"auto_keywords": 0}, + {"auto_keywords": 16}, + {"auto_keywords": 32}, + {"auto_questions": 0}, + {"auto_questions": 5}, + {"auto_questions": 10}, + {"chunk_token_num": 1}, + {"chunk_token_num": 1024}, + {"chunk_token_num": 2048}, + {"delimiter": "\n"}, + {"delimiter": " "}, + {"html4excel": True}, + {"html4excel": False}, + {"layout_recognize": "DeepDOC"}, + {"layout_recognize": "Plain Text"}, + {"tag_kb_ids": ["1", "2"]}, + {"topn_tags": 1}, + {"topn_tags": 5}, + {"topn_tags": 10}, + {"filename_embd_weight": 0.1}, + {"filename_embd_weight": 0.5}, + {"filename_embd_weight": 1.0}, + {"task_page_size": 1}, + {"task_page_size": None}, + {"pages": [[1, 100]]}, + {"pages": None}, + {"graphrag": {"use_graphrag": True}}, + {"graphrag": {"use_graphrag": False}}, + {"graphrag": {"entity_types": ["age", "sex", "height", "weight"]}}, + {"graphrag": {"method": "general"}}, + {"graphrag": {"method": "light"}}, + {"graphrag": {"community": True}}, + {"graphrag": {"community": False}}, + {"graphrag": {"resolution": True}}, + {"graphrag": {"resolution": False}}, + {"raptor": {"use_raptor": True}}, + {"raptor": {"use_raptor": False}}, + {"raptor": {"prompt": "Who are you?"}}, + {"raptor": {"max_token": 1}}, + {"raptor": {"max_token": 1024}}, + {"raptor": {"max_token": 2048}}, + {"raptor": {"threshold": 0.0}}, + {"raptor": {"threshold": 0.5}}, + {"raptor": {"threshold": 1.0}}, + {"raptor": {"max_cluster": 1}}, + {"raptor": {"max_cluster": 512}}, + {"raptor": {"max_cluster": 1024}}, + {"raptor": {"random_seed": 0}}, + ], + ids=[ + "auto_keywords_min", + "auto_keywords_mid", + "auto_keywords_max", + "auto_questions_min", + "auto_questions_mid", + "auto_questions_max", + "chunk_token_num_min", + "chunk_token_num_mid", + "chunk_token_num_max", + "delimiter", + "delimiter_space", + "html4excel_true", + "html4excel_false", + "layout_recognize_DeepDOC", + "layout_recognize_navie", + "tag_kb_ids", + "topn_tags_min", + "topn_tags_mid", + "topn_tags_max", + "filename_embd_weight_min", + "filename_embd_weight_mid", + "filename_embd_weight_max", + "task_page_size_min", + "task_page_size_None", + "pages", + "pages_none", 
+ "graphrag_true", + "graphrag_false", + "graphrag_entity_types", + "graphrag_method_general", + "graphrag_method_light", + "graphrag_community_true", + "graphrag_community_false", + "graphrag_resolution_true", + "graphrag_resolution_false", + "raptor_true", + "raptor_false", + "raptor_prompt", + "raptor_max_token_min", + "raptor_max_token_mid", + "raptor_max_token_max", + "raptor_threshold_min", + "raptor_threshold_mid", + "raptor_threshold_max", + "raptor_max_cluster_min", + "raptor_max_cluster_mid", + "raptor_max_cluster_max", + "raptor_random_seed_min", + ], + ) + def test_parser_config(self, WebApiAuth, add_dataset_func, parser_config): + kb_id = add_dataset_func + payload = {"name": "parser_config", "description": "", "parser_id": "naive", "kb_id": kb_id, "parser_config": parser_config} + res = update_kb(WebApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["parser_config"] == parser_config, res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "payload", + [ + {"id": "id"}, + {"tenant_id": "e57c1966f99211efb41e9e45646e0111"}, + {"created_by": "created_by"}, + {"create_date": "Tue, 11 Mar 2025 13:37:23 GMT"}, + {"create_time": 1741671443322}, + {"update_date": "Tue, 11 Mar 2025 13:37:23 GMT"}, + {"update_time": 1741671443339}, + ], + ) + def test_field_unsupported(self, WebApiAuth, add_dataset_func, payload): + kb_id = add_dataset_func + full_payload = {"name": "field_unsupported", "description": "", "parser_id": "naive", "kb_id": kb_id, **payload} + res = update_kb(WebApiAuth, full_payload) + assert res["code"] == 101, res + assert "isn't allowed" in res["message"], res diff --git a/test/utils/__init__.py b/test/utils/__init__.py new file mode 100644 index 00000000000..7620fdac266 --- /dev/null +++ b/test/utils/__init__.py @@ -0,0 +1,63 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import base64 +import functools +import hashlib +import time +from pathlib import Path + + +def encode_avatar(image_path): + with Path.open(image_path, "rb") as file: + binary_data = file.read() + base64_encoded = base64.b64encode(binary_data).decode("utf-8") + return base64_encoded + + +def compare_by_hash(file1, file2, algorithm="sha256"): + def _calc_hash(file_path): + hash_func = hashlib.new(algorithm) + with open(file_path, "rb") as f: + while chunk := f.read(8192): + hash_func.update(chunk) + return hash_func.hexdigest() + + return _calc_hash(file1) == _calc_hash(file2) + + +def wait_for(timeout=10, interval=1, error_msg="Timeout"): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + start_time = time.time() + while True: + result = func(*args, **kwargs) + if result is True: + return result + elapsed = time.time() - start_time + if elapsed > timeout: + assert False, error_msg + time.sleep(interval) + + return wrapper + + return decorator + + +def is_sorted(data, field, descending=True): + timestamps = [ds[field] for ds in data] + return all(a >= b for a, b in zip(timestamps, timestamps[1:])) if descending else all(a <= b for a, b in zip(timestamps, timestamps[1:])) diff --git a/test/utils/file_utils.py b/test/utils/file_utils.py new file mode 100644 index 00000000000..6ccfb02dc2b --- /dev/null +++ b/test/utils/file_utils.py @@ -0,0 +1,107 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import json + +from docx import Document # pip install python-docx +from openpyxl import Workbook # pip install openpyxl +from PIL import Image, ImageDraw # pip install Pillow +from pptx import Presentation # pip install python-pptx +from reportlab.pdfgen import canvas # pip install reportlab + + +def create_docx_file(path): + doc = Document() + doc.add_paragraph("This is a test DOCX file.") + doc.save(path) + return path + + +def create_excel_file(path): + wb = Workbook() + ws = wb.active + ws["A1"] = "Test Excel File" + wb.save(path) + return path + + +def create_ppt_file(path): + prs = Presentation() + slide = prs.slides.add_slide(prs.slide_layouts[0]) + slide.shapes.title.text = "Test PPT File" + prs.save(path) + return path + + +def create_image_file(path): + img = Image.new("RGB", (100, 100), color="blue") + draw = ImageDraw.Draw(img) + draw.text((10, 40), "Test", fill="white") + img.save(path) + return path + + +def create_pdf_file(path): + if not isinstance(path, str): + path = str(path) + c = canvas.Canvas(path) + c.drawString(100, 750, "Test PDF File") + c.save() + return path + + +def create_txt_file(path): + with open(path, "w", encoding="utf-8") as f: + f.write("This is the content of a test TXT file.") + return path + + +def create_md_file(path): + md_content = "# Test MD File\n\nThis is a test Markdown file." 
+ with open(path, "w", encoding="utf-8") as f: + f.write(md_content) + return path + + +def create_json_file(path): + data = {"message": "This is a test JSON file", "value": 123} + with open(path, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2) + return path + + +def create_eml_file(path): + eml_content = ( + "From: sender@example.com\n" + "To: receiver@example.com\n" + "Subject: Test EML File\n\n" + "This is a test email content.\n" + ) + with open(path, "w", encoding="utf-8") as f: + f.write(eml_content) + return path + + +def create_html_file(path): + html_content = ( + "\n" + "Test HTML File\n" + "
This is a test HTML file
\n" + "" + ) + with open(path, "w", encoding="utf-8") as f: + f.write(html_content) + return path diff --git a/test/utils/hypothesis_utils.py b/test/utils/hypothesis_utils.py new file mode 100644 index 00000000000..736e6cbdf55 --- /dev/null +++ b/test/utils/hypothesis_utils.py @@ -0,0 +1,28 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import hypothesis.strategies as st + + +@st.composite +def valid_names(draw): + base_chars = "abcdefghijklmnopqrstuvwxyz_" + first_char = draw(st.sampled_from([c for c in base_chars if c.isalpha() or c == "_"])) + remaining = draw(st.text(alphabet=st.sampled_from(base_chars), min_size=0, max_size=128 - 2)) + + name = (first_char + remaining)[:128] + return name.encode("utf-8").decode("utf-8") diff --git a/uv.lock b/uv.lock index 6f0a57ac7fd..5cf632156c6 100644 --- a/uv.lock +++ b/uv.lock @@ -2329,6 +2329,20 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/48/30/47d0bf6072f7252e6521f3447ccfa40b421b6824517f82854703d0f5a98b/hyperframe-6.1.0-py3-none-any.whl", hash = "sha256:b03380493a519fce58ea5af42e4a42317bf9bd425596f7a0835ffce80f1a42e5" }, ] +[[package]] +name = "hypothesis" +version = "6.132.0" +source = { registry = "https://mirrors.aliyun.com/pypi/simple" } +dependencies = [ + { name = "attrs" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, + { name = "sortedcontainers" }, +] +sdist = { url = "https://mirrors.aliyun.com/pypi/packages/35/ff/8a67f7217f86d0bc597a2d8c958b273729592d5b2cb40430506b8fb4acbd/hypothesis-6.132.0.tar.gz", hash = "sha256:55868060add41baa6176ed9c3456655678d140c74e3514bdf03381dae6391403" } +wheels = [ + { url = "https://mirrors.aliyun.com/pypi/packages/36/16/92696e0da87d9799e22742afdd44978440c7a5738191929def39fc8115ac/hypothesis-6.132.0-py3-none-any.whl", hash = "sha256:9d11f81664c0688d27d37c871cee8baf4349383cf9ef9938ef6b3ae836962595" }, +] + [[package]] name = "hyppo" version = "0.4.0" @@ -3648,6 +3662,30 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/26/d0/22f68eb23eea053a31655960f133c0be9726c6a881547e6e9e7e2a946c4f/opencv_python_headless-4.10.0.84-cp37-abi3-win_amd64.whl", hash = "sha256:afcf28bd1209dd58810d33defb622b325d3cbe49dcd7a43a902982c33e5fad05" }, ] +[[package]] +name = "opendal" +version = "0.45.20" +source = { registry = "https://mirrors.aliyun.com/pypi/simple" } +sdist = { url = "https://mirrors.aliyun.com/pypi/packages/2f/3f/927dfe1349ae58b9238b8eafba747af648d660a9425f486dda01a10f0b78/opendal-0.45.20.tar.gz", hash = "sha256:9f6f90d9e9f9d6e9e5a34aa7729169ef34d2f1869ad1e01ddc39b1c0ce0c9405" } +wheels = [ + { url = "https://mirrors.aliyun.com/pypi/packages/f7/d9/b74575762bd9178b0498125f270268e0fb122ee991188e053048da7f002c/opendal-0.45.20-cp310-cp310-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:d6069cef67f501eda221da63320bd1291aee967f5f8678ccee9e6e566ab37c78" }, + { url = 
"https://mirrors.aliyun.com/pypi/packages/56/f6/0af7d8a4afe5bae6222c4715f0563fa8c257f0525802da47120e28314353/opendal-0.45.20-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c52c4bf9433a3fa17d1f7b18f386a8f601c4b41e3fae9a839d0a861867d6086a" }, + { url = "https://mirrors.aliyun.com/pypi/packages/96/16/cf0cfc0838c7837f5642824738ad57f84cee658b4cfdd2b25fdfb52ca8a7/opendal-0.45.20-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:088bc9b20c5f07bbb19a9ff45c32dd3d42cf2d0b4ef40a2319ca27cdc635bf0f" }, + { url = "https://mirrors.aliyun.com/pypi/packages/b0/76/e903436877895fcf948e36aa728b4b56a3a600c4fd3297d8e4bc38a843be/opendal-0.45.20-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:55efb4388fa03f309de497bf9b9854377fc4045da069c72c9d2df21d24c686cb" }, + { url = "https://mirrors.aliyun.com/pypi/packages/34/10/7863a90a592ed6bfb2ddde104db23a00586004e2197f86a255ad9f8a9401/opendal-0.45.20-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:49c966cda40dc6b7b100ea6150d2f29e01ed7db694c5a5168c5fc451872ec77c" }, + { url = "https://mirrors.aliyun.com/pypi/packages/b4/a3/b77497101e320bcaebb7e99c43d61ca1886795c6a83001d4426cdbc3683d/opendal-0.45.20-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:e81af55e1d8c145119dfa4c9cacd1fd60c1c1fba2207ec5064cb6baae8c3c86b" }, + { url = "https://mirrors.aliyun.com/pypi/packages/fc/36/21495e4a405d47ece52df98c323ba9467f43e0641e04819ab5732bf0f370/opendal-0.45.20-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:3bbdfcb6840ab8bbd29c36a2a329c1f691023b3cd6a26f8a285dc89f39526017" }, + { url = "https://mirrors.aliyun.com/pypi/packages/50/28/bb822cad3f3ef15836227751dad46554c499bbefcf0eb34b4cc7e9975e9b/opendal-0.45.20-cp310-cp310-win_amd64.whl", hash = "sha256:e3987c4766a3611ea8cb3a216f21d083ac3e7fa91eb2ff7c0ebe5dc6e6958cce" }, + { url = "https://mirrors.aliyun.com/pypi/packages/84/77/6427e16b8630f0cc71f4a1b01648ed3264f1e04f1f6d9b5d09e5c6a4dd2f/opendal-0.45.20-cp311-abi3-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:35acdd8001e4a741532834fdbff3020ffb10b40028bb49fbe93c4f8197d66d8c" }, + { url = "https://mirrors.aliyun.com/pypi/packages/12/1f/83e415334739f1ab4dba55cdd349abf0b66612249055afb422a354b96ac8/opendal-0.45.20-cp311-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:629bfe8d384364bced6cbeb01f49b99779fa5151c68048a1869ff645ddcfcb25" }, + { url = "https://mirrors.aliyun.com/pypi/packages/49/94/c5de6ed54a02d7413636c2ccefa71d8dd09c2ada1cd6ecab202feb1fdeda/opendal-0.45.20-cp311-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d12cc5ac7e441fb93d86d1673112d9fb08580fc3226f864434f4a56a72efec53" }, + { url = "https://mirrors.aliyun.com/pypi/packages/c6/83/713a1e1de8cbbd69af50e26644bbdeef3c1068b89f442417376fa3c0f591/opendal-0.45.20-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:45a3adae1f473052234fc4054a6f210df3ded9aff10db8d545d0a37eff3b13cc" }, + { url = "https://mirrors.aliyun.com/pypi/packages/c7/78/c9651e753aaf6eb61887ca372a3f9c2ae57dae03c3159d24deaf018c26dc/opendal-0.45.20-cp311-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:d8947857052c85a4b0e251d50e23f5f68f0cdd9e509e32e614a5e4b2fc7424c4" }, + { url = "https://mirrors.aliyun.com/pypi/packages/3c/9d/5d8c20c0fc93df5e349e5694167de30afdc54c5755704cc64764a6cbb309/opendal-0.45.20-cp311-abi3-musllinux_1_1_armv7l.whl", hash = "sha256:891d2f9114efeef648973049ed15e56477e8feb9e48b540bd8d6105ea22a253c" }, + { url = 
"https://mirrors.aliyun.com/pypi/packages/21/39/05262f748a2085522e0c85f03eab945589313dc9caedc002872c39162776/opendal-0.45.20-cp311-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:539de9b825f6783d6289d88c0c9ac5415daa4d892d761e3540c565bda51e8997" }, + { url = "https://mirrors.aliyun.com/pypi/packages/74/83/cc7c6de29b0a7585cd445258d174ca204d37729c3874ad08e515b0bf331c/opendal-0.45.20-cp311-abi3-win_amd64.whl", hash = "sha256:145efd56aa33b493d5b652c3e4f5ae5097ab69d38c132d80f108e9f5c1e4d863" }, +] + [[package]] name = "openpyxl" version = "3.1.5" @@ -4813,7 +4851,7 @@ wheels = [ [[package]] name = "ragflow" -version = "0.19.0" +version = "0.19.1" source = { virtual = "." } dependencies = [ { name = "akshare" }, @@ -4877,6 +4915,7 @@ dependencies = [ { name = "openai" }, { name = "opencv-python" }, { name = "opencv-python-headless" }, + { name = "opendal" }, { name = "openpyxl" }, { name = "opensearch-py" }, { name = "ormsgpack" }, @@ -4894,7 +4933,6 @@ dependencies = [ { name = "pyodbc" }, { name = "pypdf" }, { name = "pypdf2" }, - { name = "pytest" }, { name = "python-dateutil" }, { name = "python-docx" }, { name = "python-dotenv" }, @@ -4948,6 +4986,19 @@ full = [ { name = "transformers" }, ] +[package.dev-dependencies] +test = [ + { name = "hypothesis" }, + { name = "openpyxl" }, + { name = "pillow" }, + { name = "pytest" }, + { name = "python-docx" }, + { name = "python-pptx" }, + { name = "reportlab" }, + { name = "requests" }, + { name = "requests-toolbelt" }, +] + [package.metadata] requires-dist = [ { name = "akshare", specifier = ">=1.15.78,<2.0.0" }, @@ -5015,6 +5066,7 @@ requires-dist = [ { name = "openai", specifier = "==1.45.0" }, { name = "opencv-python", specifier = "==4.10.0.84" }, { name = "opencv-python-headless", specifier = "==4.10.0.84" }, + { name = "opendal", specifier = ">=0.45.0,<0.46.0" }, { name = "openpyxl", specifier = ">=3.1.0,<4.0.0" }, { name = "opensearch-py", specifier = "==2.7.1" }, { name = "ormsgpack", specifier = "==1.5.0" }, @@ -5032,7 +5084,6 @@ requires-dist = [ { name = "pyodbc", specifier = ">=5.2.0,<6.0.0" }, { name = "pypdf", specifier = ">=5.0.0,<6.0.0" }, { name = "pypdf2", specifier = ">=3.0.1,<4.0.0" }, - { name = "pytest", specifier = ">=8.3.0,<9.0.0" }, { name = "python-dateutil", specifier = "==2.8.2" }, { name = "python-docx", specifier = ">=1.1.2,<2.0.0" }, { name = "python-dotenv", specifier = "==1.0.1" }, @@ -5071,7 +5122,7 @@ requires-dist = [ { name = "werkzeug", specifier = "==3.0.6" }, { name = "wikipedia", specifier = "==1.4.0" }, { name = "word2number", specifier = "==1.1" }, - { name = "xgboost", specifier = "==1.5.0" }, + { name = "xgboost", specifier = "==1.6.0" }, { name = "xpinyin", specifier = "==0.7.6" }, { name = "xxhash", specifier = ">=3.5.0,<4.0.0" }, { name = "yfinance", specifier = "==0.1.96" }, @@ -5079,6 +5130,19 @@ requires-dist = [ ] provides-extras = ["full"] +[package.metadata.requires-dev] +test = [ + { name = "hypothesis", specifier = ">=6.132.0" }, + { name = "openpyxl", specifier = ">=3.1.5" }, + { name = "pillow", specifier = ">=10.4.0" }, + { name = "pytest", specifier = ">=8.3.5" }, + { name = "python-docx", specifier = ">=1.1.2" }, + { name = "python-pptx", specifier = ">=1.0.2" }, + { name = "reportlab", specifier = ">=4.4.1" }, + { name = "requests", specifier = ">=2.32.2" }, + { name = "requests-toolbelt", specifier = ">=1.0.0" }, +] + [[package]] name = "ranx" version = "0.3.20" @@ -5212,6 +5276,19 @@ wheels = [ { url = 
"https://mirrors.aliyun.com/pypi/packages/77/0f/f6067b7076faee22aef6190f703524e8ba8eac490191352c5cb0253c4823/replicate-0.31.0-py3-none-any.whl", hash = "sha256:27ee067ccb4c37d8c2fc5ab87bb312da36447dfcd12527002bbd0b78f6ef195a" }, ] +[[package]] +name = "reportlab" +version = "4.4.1" +source = { registry = "https://mirrors.aliyun.com/pypi/simple" } +dependencies = [ + { name = "chardet" }, + { name = "pillow" }, +] +sdist = { url = "https://mirrors.aliyun.com/pypi/packages/7b/d8/c3366bf10a5a5fcc3467eefa9504f6aa24fcda5817b5b147eabd37a385e1/reportlab-4.4.1.tar.gz", hash = "sha256:5f9b9fc0b7a48e8912c25ccf69d26b82980ab0da718e4f583fa720e8f8f5073f" } +wheels = [ + { url = "https://mirrors.aliyun.com/pypi/packages/a1/2e/7994a139150abf11c8dd258feb091ad654823a83cfd9720bfdded27185c3/reportlab-4.4.1-py3-none-any.whl", hash = "sha256:9217a1c8c1917217f819718b24972a96ad0c485a1c494749562d097b58d974b7" }, +] + [[package]] name = "requests" version = "2.32.2" @@ -5232,6 +5309,18 @@ socks = [ { name = "pysocks" }, ] +[[package]] +name = "requests-toolbelt" +version = "1.0.0" +source = { registry = "https://mirrors.aliyun.com/pypi/simple" } +dependencies = [ + { name = "requests" }, +] +sdist = { url = "https://mirrors.aliyun.com/pypi/packages/f3/61/d7545dafb7ac2230c70d38d31cbfe4cc64f7144dc41f6e4e4b78ecd9f5bb/requests-toolbelt-1.0.0.tar.gz", hash = "sha256:7681a0a3d047012b5bdc0ee37d7f8f07ebe76ab08caeccfc3921ce23c88d5bc6" } +wheels = [ + { url = "https://mirrors.aliyun.com/pypi/packages/3f/51/d4db610ef29373b879047326cbf6fa98b6c1969d6f6dc423279de2b1be2c/requests_toolbelt-1.0.0-py2.py3-none-any.whl", hash = "sha256:cccfdd665f0a24fcf4726e690f65639d272bb0637b9b92dfd91a5568ccf6bd06" }, +] + [[package]] name = "retry" version = "0.9.2" @@ -6583,18 +6672,19 @@ wheels = [ [[package]] name = "xgboost" -version = "1.5.0" +version = "1.6.0" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } dependencies = [ { name = "numpy" }, { name = "scipy" }, ] -sdist = { url = "https://mirrors.aliyun.com/pypi/packages/15/70/7308768fdd4a35477efb01098ffd426e455b600837ed0dd70c3293cd3e03/xgboost-1.5.0.tar.gz", hash = "sha256:6b7c34a18474c1b73b5c6dcf1231f5fe102d6f26490a65465e3879048bf1f3d4" } +sdist = { url = "https://mirrors.aliyun.com/pypi/packages/77/89/92b399140a7688443fc182b54240822c903e906121d63446eb2f84350e99/xgboost-1.6.0.tar.gz", hash = "sha256:9c944c2495cb426b8a365021565755c39ee0b53156cf5e53a4346bdad2e3b734" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/9e/0d/376d5491ba5f06c6224304d572e1131521115d423a6c6dc347e85129c55c/xgboost-1.5.0-py3-none-macosx_10_14_x86_64.macosx_10_15_x86_64.macosx_11_0_x86_64.whl", hash = "sha256:ae122406c2c1d2a407c18a49c2a05f9bbbaa3249fd39e91c11cb223d538dba4e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/dd/e2/43d1d96344d8733bbebde084bbf95ef8f2a3aab9f8d3779c3ee381754514/xgboost-1.5.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:ebe36ee21516a37f645bcd1f3ca1247485fe77d96f1c3d605f970c469b6a9015" }, - { url = "https://mirrors.aliyun.com/pypi/packages/a7/c9/4968ca0434c313aed71fc4cc2339aa8844482d5eefdcc8989c985a19ea2e/xgboost-1.5.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:edaad84317a53671069b4477152915c25b4121d997bfa531711f0a18ba1402fe" }, - { url = "https://mirrors.aliyun.com/pypi/packages/a2/8a/dbfcab37ea93951fa85e5746de0fc7378b2a4c99a32ddad36a7fb504ed57/xgboost-1.5.0-py3-none-win_amd64.whl", hash = "sha256:9502a9ed5c52669c83207380eeaaa3862cbf38b271040c6d5226914e07cf196c" }, + { url = 
"https://mirrors.aliyun.com/pypi/packages/f1/71/abca2240b5d19aa3e90c8228cf307962fc9f598acc3c623fb49db83b4092/xgboost-1.6.0-py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.macosx_12_0_x86_64.whl", hash = "sha256:5f7fd61024c41d0c424a8272dfd27797a0393a616b717c05c0f981a49a47b4fd" }, + { url = "https://mirrors.aliyun.com/pypi/packages/49/d0/85c9c40e7ca1a4bc05278c1e57a89c43ab846be4cb5227871ca7605921a6/xgboost-1.6.0-py3-none-macosx_12_0_arm64.whl", hash = "sha256:ad27c6a72f6abef6d20e67f957fb25553bb09a6d1c4eaf08cb8ee7efca288255" }, + { url = "https://mirrors.aliyun.com/pypi/packages/c3/be/18970943eb7e9d9ded5e37e87c1dc02c8a961416f725f2734629f26d69d5/xgboost-1.6.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:b1d532b8d548dd3acb4bd5f56632339e48167d9e2ec0eda2d8d6b4cd772e03b4" }, + { url = "https://mirrors.aliyun.com/pypi/packages/bf/64/c467a20848adc3d1c3f45d60df9c7cd0c40a548fd534a9f842a35114039d/xgboost-1.6.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:640b9649104f22f0dc43c7202d22cde5531cc303801a9c75cad3f2b6e413dcf7" }, + { url = "https://mirrors.aliyun.com/pypi/packages/64/51/3e33a4df0ca66474e7f4e357328a5c7b35fb52cbc48b312c64d276d37da8/xgboost-1.6.0-py3-none-win_amd64.whl", hash = "sha256:e2f9baca0b7cbc208ad4fbafa4cd70b50b292717ee8ba817a3ba7a0fe49de958" }, ] [[package]] diff --git a/web/.umirc.ts b/web/.umirc.ts index cae37519ddd..dab80fcb222 100644 --- a/web/.umirc.ts +++ b/web/.umirc.ts @@ -16,6 +16,7 @@ export default defineConfig({ icons: {}, hash: true, favicons: ['/logo.svg'], + headScripts: [{ src: '/iconfont.js', defer: true }], clickToComponent: {}, history: { type: 'browser', diff --git a/web/package-lock.json b/web/package-lock.json index b55ee249ae8..ad3faec8fcd 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -38,6 +38,8 @@ "@radix-ui/react-switch": "^1.1.1", "@radix-ui/react-tabs": "^1.1.1", "@radix-ui/react-toast": "^1.2.6", + "@radix-ui/react-toggle": "^1.1.9", + "@radix-ui/react-toggle-group": "^1.1.10", "@radix-ui/react-tooltip": "^1.1.4", "@tailwindcss/line-clamp": "^0.4.4", "@tanstack/react-query": "^5.40.0", @@ -72,7 +74,7 @@ "react-copy-to-clipboard": "^5.1.0", "react-dropzone": "^14.3.5", "react-error-boundary": "^4.0.13", - "react-hook-form": "^7.53.1", + "react-hook-form": "^7.56.4", "react-i18next": "^14.0.0", "react-infinite-scroll-component": "^6.1.0", "react-markdown": "^9.0.1", @@ -112,6 +114,7 @@ "@types/webpack-env": "^1.18.4", "@umijs/lint": "^4.1.1", "@umijs/plugins": "^4.1.0", + "@welldone-software/why-did-you-render": "^8.0.3", "cross-env": "^7.0.3", "html-loader": "^5.1.0", "husky": "^9.0.11", @@ -7590,6 +7593,372 @@ } } }, + "node_modules/@radix-ui/react-toggle": { + "version": "1.1.9", + "resolved": "https://registry.npmmirror.com/@radix-ui/react-toggle/-/react-toggle-1.1.9.tgz", + "integrity": "sha512-ZoFkBBz9zv9GWer7wIjvdRxmh2wyc2oKWw6C6CseWd6/yq1DK/l5lJ+wnsmFwJZbBYqr02mrf8A2q/CVCuM3ZA==", + "license": "MIT", + "dependencies": { + "@radix-ui/primitive": "1.1.2", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-use-controllable-state": "1.2.2" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-toggle-group": { + "version": "1.1.10", + "resolved": 
"https://registry.npmmirror.com/@radix-ui/react-toggle-group/-/react-toggle-group-1.1.10.tgz", + "integrity": "sha512-kiU694Km3WFLTC75DdqgM/3Jauf3rD9wxeS9XtyWFKsBUeZA337lC+6uUazT7I1DhanZ5gyD5Stf8uf2dbQxOQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/primitive": "1.1.2", + "@radix-ui/react-context": "1.1.2", + "@radix-ui/react-direction": "1.1.1", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-roving-focus": "1.1.10", + "@radix-ui/react-toggle": "1.1.9", + "@radix-ui/react-use-controllable-state": "1.2.2" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-toggle-group/node_modules/@radix-ui/primitive": { + "version": "1.1.2", + "resolved": "https://registry.npmmirror.com/@radix-ui/primitive/-/primitive-1.1.2.tgz", + "integrity": "sha512-XnbHrrprsNqZKQhStrSwgRUQzoCI1glLzdw79xiZPoofhGICeZRSQ3dIxAKH1gb3OHfNf4d6f+vAv3kil2eggA==", + "license": "MIT" + }, + "node_modules/@radix-ui/react-toggle-group/node_modules/@radix-ui/react-collection": { + "version": "1.1.7", + "resolved": "https://registry.npmmirror.com/@radix-ui/react-collection/-/react-collection-1.1.7.tgz", + "integrity": "sha512-Fh9rGN0MoI4ZFUNyfFVNU4y9LUz93u9/0K+yLgA2bwRojxM8JU1DyvvMBabnZPBgMWREAJvU2jjVzq+LrFUglw==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-compose-refs": "1.1.2", + "@radix-ui/react-context": "1.1.2", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-slot": "1.2.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-toggle-group/node_modules/@radix-ui/react-compose-refs": { + "version": "1.1.2", + "resolved": "https://registry.npmmirror.com/@radix-ui/react-compose-refs/-/react-compose-refs-1.1.2.tgz", + "integrity": "sha512-z4eqJvfiNnFMHIIvXP3CY57y2WJs5g2v3X0zm9mEJkrkNv4rDxu+sg9Jh8EkXyeqBkB7SOcboo9dMVqhyrACIg==", + "license": "MIT", + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-toggle-group/node_modules/@radix-ui/react-context": { + "version": "1.1.2", + "resolved": "https://registry.npmmirror.com/@radix-ui/react-context/-/react-context-1.1.2.tgz", + "integrity": "sha512-jCi/QKUM2r1Ju5a3J64TH2A5SpKAgh0LpknyqdQ4m6DCV0xJ2HG1xARRwNGPQfi1SLdLWZ1OJz6F4OMBBNiGJA==", + "license": "MIT", + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-toggle-group/node_modules/@radix-ui/react-direction": { + "version": "1.1.1", + "resolved": "https://registry.npmmirror.com/@radix-ui/react-direction/-/react-direction-1.1.1.tgz", + "integrity": "sha512-1UEWRX6jnOA2y4H5WczZ44gOOjTEmlqv1uNW4GAJEO5+bauCBhv8snY65Iw5/VOS/ghKN9gr2KjnLKxrsvoMVw==", + "license": "MIT", + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 
|| ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-toggle-group/node_modules/@radix-ui/react-id": { + "version": "1.1.1", + "resolved": "https://registry.npmmirror.com/@radix-ui/react-id/-/react-id-1.1.1.tgz", + "integrity": "sha512-kGkGegYIdQsOb4XjsfM97rXsiHaBwco+hFI66oO4s9LU+PLAC5oJ7khdOVFxkhsmlbpUqDAvXw11CluXP+jkHg==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-use-layout-effect": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-toggle-group/node_modules/@radix-ui/react-primitive": { + "version": "2.1.3", + "resolved": "https://registry.npmmirror.com/@radix-ui/react-primitive/-/react-primitive-2.1.3.tgz", + "integrity": "sha512-m9gTwRkhy2lvCPe6QJp4d3G1TYEUHn/FzJUtq9MjH46an1wJU+GdoGC5VLof8RX8Ft/DlpshApkhswDLZzHIcQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-slot": "1.2.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-toggle-group/node_modules/@radix-ui/react-roving-focus": { + "version": "1.1.10", + "resolved": "https://registry.npmmirror.com/@radix-ui/react-roving-focus/-/react-roving-focus-1.1.10.tgz", + "integrity": "sha512-dT9aOXUen9JSsxnMPv/0VqySQf5eDQ6LCk5Sw28kamz8wSOW2bJdlX2Bg5VUIIcV+6XlHpWTIuTPCf/UNIyq8Q==", + "license": "MIT", + "dependencies": { + "@radix-ui/primitive": "1.1.2", + "@radix-ui/react-collection": "1.1.7", + "@radix-ui/react-compose-refs": "1.1.2", + "@radix-ui/react-context": "1.1.2", + "@radix-ui/react-direction": "1.1.1", + "@radix-ui/react-id": "1.1.1", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-use-callback-ref": "1.1.1", + "@radix-ui/react-use-controllable-state": "1.2.2" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-toggle-group/node_modules/@radix-ui/react-slot": { + "version": "1.2.3", + "resolved": "https://registry.npmmirror.com/@radix-ui/react-slot/-/react-slot-1.2.3.tgz", + "integrity": "sha512-aeNmHnBxbi2St0au6VBVC7JXFlhLlOnvIIlePNniyUNAClzmtAUEY8/pBiK3iHjufOlwA+c20/8jngo7xcrg8A==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-compose-refs": "1.1.2" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-toggle-group/node_modules/@radix-ui/react-use-callback-ref": { + "version": "1.1.1", + "resolved": "https://registry.npmmirror.com/@radix-ui/react-use-callback-ref/-/react-use-callback-ref-1.1.1.tgz", + "integrity": "sha512-FkBMwD+qbGQeMu1cOHnuGB6x4yzPjho8ap5WtbEJ26umhgqVXbhekKUQO+hZEL1vU92a3wHwdp0HAcqAUF5iDg==", + "license": "MIT", + "peerDependencies": { + "@types/react": "*", + "react": 
"^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-toggle-group/node_modules/@radix-ui/react-use-controllable-state": { + "version": "1.2.2", + "resolved": "https://registry.npmmirror.com/@radix-ui/react-use-controllable-state/-/react-use-controllable-state-1.2.2.tgz", + "integrity": "sha512-BjasUjixPFdS+NKkypcyyN5Pmg83Olst0+c6vGov0diwTEo6mgdqVR6hxcEgFuh4QrAs7Rc+9KuGJ9TVCj0Zzg==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-use-effect-event": "0.0.2", + "@radix-ui/react-use-layout-effect": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-toggle-group/node_modules/@radix-ui/react-use-layout-effect": { + "version": "1.1.1", + "resolved": "https://registry.npmmirror.com/@radix-ui/react-use-layout-effect/-/react-use-layout-effect-1.1.1.tgz", + "integrity": "sha512-RbJRS4UWQFkzHTTwVymMTUv8EqYhOp8dOOviLj2ugtTiXRaRQS7GLGxZTLL1jWhMeoSCf5zmcZkqTl9IiYfXcQ==", + "license": "MIT", + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-toggle/node_modules/@radix-ui/primitive": { + "version": "1.1.2", + "resolved": "https://registry.npmmirror.com/@radix-ui/primitive/-/primitive-1.1.2.tgz", + "integrity": "sha512-XnbHrrprsNqZKQhStrSwgRUQzoCI1glLzdw79xiZPoofhGICeZRSQ3dIxAKH1gb3OHfNf4d6f+vAv3kil2eggA==", + "license": "MIT" + }, + "node_modules/@radix-ui/react-toggle/node_modules/@radix-ui/react-compose-refs": { + "version": "1.1.2", + "resolved": "https://registry.npmmirror.com/@radix-ui/react-compose-refs/-/react-compose-refs-1.1.2.tgz", + "integrity": "sha512-z4eqJvfiNnFMHIIvXP3CY57y2WJs5g2v3X0zm9mEJkrkNv4rDxu+sg9Jh8EkXyeqBkB7SOcboo9dMVqhyrACIg==", + "license": "MIT", + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-toggle/node_modules/@radix-ui/react-primitive": { + "version": "2.1.3", + "resolved": "https://registry.npmmirror.com/@radix-ui/react-primitive/-/react-primitive-2.1.3.tgz", + "integrity": "sha512-m9gTwRkhy2lvCPe6QJp4d3G1TYEUHn/FzJUtq9MjH46an1wJU+GdoGC5VLof8RX8Ft/DlpshApkhswDLZzHIcQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-slot": "1.2.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-toggle/node_modules/@radix-ui/react-slot": { + "version": "1.2.3", + "resolved": "https://registry.npmmirror.com/@radix-ui/react-slot/-/react-slot-1.2.3.tgz", + "integrity": "sha512-aeNmHnBxbi2St0au6VBVC7JXFlhLlOnvIIlePNniyUNAClzmtAUEY8/pBiK3iHjufOlwA+c20/8jngo7xcrg8A==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-compose-refs": "1.1.2" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + 
"optional": true + } + } + }, + "node_modules/@radix-ui/react-toggle/node_modules/@radix-ui/react-use-controllable-state": { + "version": "1.2.2", + "resolved": "https://registry.npmmirror.com/@radix-ui/react-use-controllable-state/-/react-use-controllable-state-1.2.2.tgz", + "integrity": "sha512-BjasUjixPFdS+NKkypcyyN5Pmg83Olst0+c6vGov0diwTEo6mgdqVR6hxcEgFuh4QrAs7Rc+9KuGJ9TVCj0Zzg==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-use-effect-event": "0.0.2", + "@radix-ui/react-use-layout-effect": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-toggle/node_modules/@radix-ui/react-use-layout-effect": { + "version": "1.1.1", + "resolved": "https://registry.npmmirror.com/@radix-ui/react-use-layout-effect/-/react-use-layout-effect-1.1.1.tgz", + "integrity": "sha512-RbJRS4UWQFkzHTTwVymMTUv8EqYhOp8dOOviLj2ugtTiXRaRQS7GLGxZTLL1jWhMeoSCf5zmcZkqTl9IiYfXcQ==", + "license": "MIT", + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, "node_modules/@radix-ui/react-tooltip": { "version": "1.1.4", "resolved": "https://registry.npmmirror.com/@radix-ui/react-tooltip/-/react-tooltip-1.1.4.tgz", @@ -11187,6 +11556,19 @@ "@xtuc/long": "4.2.2" } }, + "node_modules/@welldone-software/why-did-you-render": { + "version": "8.0.3", + "resolved": "https://registry.npmmirror.com/@welldone-software/why-did-you-render/-/why-did-you-render-8.0.3.tgz", + "integrity": "sha512-bb5bKPMStYnocyTBVBu4UTegZdBqzV1mPhxc0UIV/S43KFUSRflux9gvzJfu2aM4EWLJ3egTvdjOi+viK+LKGA==", + "dev": true, + "license": "MIT", + "dependencies": { + "lodash": "^4" + }, + "peerDependencies": { + "react": "^18" + } + }, "node_modules/@xmldom/xmldom": { "version": "0.8.10", "resolved": "https://registry.npmmirror.com/@xmldom/xmldom/-/xmldom-0.8.10.tgz", @@ -27245,9 +27627,10 @@ } }, "node_modules/react-hook-form": { - "version": "7.53.1", - "resolved": "https://registry.npmmirror.com/react-hook-form/-/react-hook-form-7.53.1.tgz", - "integrity": "sha512-6aiQeBda4zjcuaugWvim9WsGqisoUk+etmFEsSUMm451/Ic8L/UAb7sRtMj3V+Hdzm6mMjU1VhiSzYUZeBm0Vg==", + "version": "7.56.4", + "resolved": "https://registry.npmmirror.com/react-hook-form/-/react-hook-form-7.56.4.tgz", + "integrity": "sha512-Rob7Ftz2vyZ/ZGsQZPaRdIefkgOSrQSPXfqBdvOPwJfoGnjwRJUs7EM7Kc1mcoDv3NOtqBzPGbcMB8CGn9CKgw==", + "license": "MIT", "engines": { "node": ">=18.0.0" }, diff --git a/web/package.json b/web/package.json index 6a2faf96a8a..6184959ef86 100644 --- a/web/package.json +++ b/web/package.json @@ -49,6 +49,8 @@ "@radix-ui/react-switch": "^1.1.1", "@radix-ui/react-tabs": "^1.1.1", "@radix-ui/react-toast": "^1.2.6", + "@radix-ui/react-toggle": "^1.1.9", + "@radix-ui/react-toggle-group": "^1.1.10", "@radix-ui/react-tooltip": "^1.1.4", "@tailwindcss/line-clamp": "^0.4.4", "@tanstack/react-query": "^5.40.0", @@ -83,7 +85,7 @@ "react-copy-to-clipboard": "^5.1.0", "react-dropzone": "^14.3.5", "react-error-boundary": "^4.0.13", - "react-hook-form": "^7.53.1", + "react-hook-form": "^7.56.4", "react-i18next": "^14.0.0", "react-infinite-scroll-component": "^6.1.0", "react-markdown": "^9.0.1", @@ -123,6 +125,7 @@ "@types/webpack-env": "^1.18.4", "@umijs/lint": "^4.1.1", "@umijs/plugins": "^4.1.0", + "@welldone-software/why-did-you-render": "^8.0.3", "cross-env": "^7.0.3", 
"html-loader": "^5.1.0", "husky": "^9.0.11", diff --git a/web/public/iconfont.js b/web/public/iconfont.js index 66daadffbaa..b405874429f 100644 --- a/web/public/iconfont.js +++ b/web/public/iconfont.js @@ -1,5 +1,5 @@ (window._iconfont_svg_string_4909832 = - ''), + ''), ((h) => { var a = (l = (l = document.getElementsByTagName('script'))[ l.length - 1 diff --git a/web/src/app.tsx b/web/src/app.tsx index ddafd4a161b..1ea4025b6fd 100644 --- a/web/src/app.tsx +++ b/web/src/app.tsx @@ -39,6 +39,15 @@ const AntLanguageMap = { de: deDE, }; +if (process.env.NODE_ENV === 'development') { + const whyDidYouRender = require('@welldone-software/why-did-you-render'); + whyDidYouRender(React, { + trackAllPureComponents: true, + trackExtraHooks: [], + logOnDifferentValues: true, + }); +} + const queryClient = new QueryClient(); type Locale = ConfigProviderProps['locale']; diff --git a/web/src/components/chunk-method-dialog/hooks.ts b/web/src/components/chunk-method-dialog/hooks.ts index 3e00161d697..b53088577df 100644 --- a/web/src/components/chunk-method-dialog/hooks.ts +++ b/web/src/components/chunk-method-dialog/hooks.ts @@ -1,5 +1,5 @@ import { useSelectParserList } from '@/hooks/user-setting-hooks'; -import { useCallback, useEffect, useMemo, useState } from 'react'; +import { useCallback, useMemo } from 'react'; const ParserListMap = new Map([ [ @@ -80,15 +80,8 @@ const getParserList = ( return parserList.filter((x) => values?.some((y) => y === x.value)); }; -export const useFetchParserListOnMount = ( - documentId: string, - parserId: string, - documentExtension: string, - // form: FormInstance, -) => { - const [selectedTag, setSelectedTag] = useState(''); +export const useFetchParserListOnMount = (documentExtension: string) => { const parserList = useSelectParserList(); - // const handleChunkMethodSelectChange = useHandleChunkMethodSelectChange(form); // TODO const nextParserList = useMemo(() => { const key = [...ParserListMap.keys()].find((x) => @@ -105,16 +98,7 @@ export const useFetchParserListOnMount = ( ); }, [parserList, documentExtension]); - useEffect(() => { - setSelectedTag(parserId); - }, [parserId, documentId]); - - const handleChange = (tag: string) => { - // handleChunkMethodSelectChange(tag); - setSelectedTag(tag); - }; - - return { parserList: nextParserList, handleChange, selectedTag }; + return { parserList: nextParserList }; }; const hideAutoKeywords = ['qa', 'table', 'resume', 'knowledge_graph', 'tag']; diff --git a/web/src/components/chunk-method-dialog/index.tsx b/web/src/components/chunk-method-dialog/index.tsx index 4822914432f..404ffb223d2 100644 --- a/web/src/components/chunk-method-dialog/index.tsx +++ b/web/src/components/chunk-method-dialog/index.tsx @@ -88,12 +88,7 @@ export function ChunkMethodDialog({ }: IProps) { const { t } = useTranslation(); - const { parserList } = useFetchParserListOnMount( - documentId, - parserId, - documentExtension, - // form, - ); + const { parserList } = useFetchParserListOnMount(documentExtension); const { data: knowledgeDetails } = useFetchKnowledgeBaseConfiguration(); diff --git a/web/src/components/collapse.tsx b/web/src/components/collapse.tsx new file mode 100644 index 00000000000..35e85e257c0 --- /dev/null +++ b/web/src/components/collapse.tsx @@ -0,0 +1,28 @@ +import { + Collapsible, + CollapsibleContent, + CollapsibleTrigger, +} from '@/components/ui/collapsible'; +import { ListCollapse } from 'lucide-react'; +import { PropsWithChildren, ReactNode } from 'react'; + +type CollapseProps = { + title?: ReactNode; + rightContent?: 
ReactNode; +} & PropsWithChildren; + +export function Collapse({ title, children, rightContent }: CollapseProps) { + return ( + + +
+
+ {title} +
+
{rightContent}
+
+
+ {children} +
+ ); +} diff --git a/web/src/components/cross-language-item-ui.tsx b/web/src/components/cross-language-item-ui.tsx new file mode 100644 index 00000000000..3fbf5430d84 --- /dev/null +++ b/web/src/components/cross-language-item-ui.tsx @@ -0,0 +1,48 @@ +import { FormLabel } from '@/components/ui/form'; +import { MultiSelect } from '@/components/ui/multi-select'; +import { useTranslation } from 'react-i18next'; + +const Languages = [ + 'English', + 'Chinese', + 'Spanish', + 'French', + 'German', + 'Japanese', + 'Korean', +]; + +const options = Languages.map((x) => ({ label: x, value: x })); + +type CrossLanguageItemProps = { + name?: string | Array; + onChange: (arg: string[]) => void; +}; + +export const CrossLanguageItem = ({ + name = ['prompt_config', 'cross_languages'], + onChange = () => {}, +}: CrossLanguageItemProps) => { + const { t } = useTranslation(); + + return ( +
+
+ + {t('chat.crossLanguage')} + +
+ { + onChange(val); + }} + // defaultValue={field.value} + placeholder={t('fileManager.pleaseSelect')} + maxCount={100} + // {...field} + modalPopover + /> +
+ ); +}; diff --git a/web/src/components/delimiter-form-field.tsx b/web/src/components/delimiter-form-field.tsx index 479aa917c60..cc1affd56bf 100644 --- a/web/src/components/delimiter-form-field.tsx +++ b/web/src/components/delimiter-form-field.tsx @@ -43,17 +43,33 @@ export function DelimiterFormField() { ( - - - {t('knowledgeDetails.delimiter')} - - - - - - - )} + render={({ field }) => { + if (typeof field.value === 'undefined') { + // default value set + form.setValue('parser_config.delimiter', '\n'); + } + return ( + +
+ + {t('knowledgeDetails.delimiter')} + +
+ + + +
+
+
+
+ +
+
+ ); + }} /> ); } diff --git a/web/src/components/edit-tag/index.less b/web/src/components/edit-tag/index.less index ef67d102383..1c4314cfbad 100644 --- a/web/src/components/edit-tag/index.less +++ b/web/src/components/edit-tag/index.less @@ -2,14 +2,15 @@ display: flex; gap: 8px; flex-wrap: wrap; - width: 100%; + // width: 100%; margin-bottom: 8px; } .tag { max-width: 100%; margin: 0; - padding: 2px 20px 2px 4px; + padding: 2px 20px 0px 4px; + height: 26px; font-size: 14px; .textEllipsis(); position: relative; diff --git a/web/src/components/entity-types-form-field.tsx b/web/src/components/entity-types-form-field.tsx index 6cf3e364bf2..3a5de8ce586 100644 --- a/web/src/components/entity-types-form-field.tsx +++ b/web/src/components/entity-types-form-field.tsx @@ -24,12 +24,21 @@ export function EntityTypesFormField({ control={form.control} name={name} render={({ field }) => ( - - {t('entityTypes')} - - - - + +
+ + * {t('entityTypes')} + +
+ + + +
+
+
+
+ +
)} /> diff --git a/web/src/components/excel-to-html-form-field.tsx b/web/src/components/excel-to-html-form-field.tsx index 2e4ba113937..90fa5de8196 100644 --- a/web/src/components/excel-to-html-form-field.tsx +++ b/web/src/components/excel-to-html-form-field.tsx @@ -17,18 +17,37 @@ export function ExcelToHtmlFormField() { ( - - {t('html4excel')} - - - - - - )} + render={({ field }) => { + if (typeof field.value === 'undefined') { + // default value set + form.setValue('parser_config.html4excel', false); + } + + return ( + +
+ + {t('html4excel')} + +
+ + + +
+
+
+
+ +
+
+ ); + }} /> ); } diff --git a/web/src/components/large-model-form-field.tsx b/web/src/components/large-model-form-field.tsx index f6b2a4314c7..58d58f385b4 100644 --- a/web/src/components/large-model-form-field.tsx +++ b/web/src/components/large-model-form-field.tsx @@ -7,7 +7,7 @@ import { } from '@/components/ui/form'; import { useFormContext } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; -import { NextLLMSelect } from './llm-select'; +import { NextLLMSelect } from './llm-select/next'; export function LargeModelFormField() { const form = useFormContext(); diff --git a/web/src/components/layout-recognize-form-field.tsx b/web/src/components/layout-recognize-form-field.tsx index 0c95ef66056..50eccbffb41 100644 --- a/web/src/components/layout-recognize-form-field.tsx +++ b/web/src/components/layout-recognize-form-field.tsx @@ -54,17 +54,37 @@ export function LayoutRecognizeFormField() { ( - - - {t('layoutRecognize')} - - - - - - - )} + render={({ field }) => { + if (typeof field.value === 'undefined') { + // default value set + form.setValue( + 'parser_config.layout_recognize', + form.formState.defaultValues?.parser_config?.layout_recognize ?? + 'DeepDOC', + ); + } + return ( + +
+ + {t('layoutRecognize')} + +
+ + + +
+
+
+
+ +
+
+ ); + }} /> ); } diff --git a/web/src/components/llm-select/index.tsx b/web/src/components/llm-select/index.tsx index fc31f3a6cae..fa65b95c34d 100644 --- a/web/src/components/llm-select/index.tsx +++ b/web/src/components/llm-select/index.tsx @@ -1,12 +1,7 @@ import { LlmModelType } from '@/constants/knowledge'; import { useComposeLlmOptionsByModelTypes } from '@/hooks/llm-hooks'; -import * as SelectPrimitive from '@radix-ui/react-select'; import { Popover as AntPopover, Select as AntSelect } from 'antd'; -import { forwardRef, useState } from 'react'; import LlmSettingItems from '../llm-setting-items'; -import { LlmSettingFieldItems } from '../llm-setting-items/next'; -import { Popover, PopoverContent, PopoverTrigger } from '../ui/popover'; -import { Select, SelectTrigger, SelectValue } from '../ui/select'; interface IProps { id?: string; @@ -16,7 +11,13 @@ interface IProps { disabled?: boolean; } -const LLMSelect = ({ id, value, onInitialValue, onChange, disabled }: IProps) => { +const LLMSelect = ({ + id, + value, + onInitialValue, + onChange, + disabled, +}: IProps) => { const modelOptions = useComposeLlmOptionsByModelTypes([ LlmModelType.Chat, LlmModelType.Image2text, @@ -31,11 +32,12 @@ const LLMSelect = ({ id, value, onInitialValue, onChange, disabled }: IProps) => } } } - } + } const content = (
-
@@ -63,43 +65,3 @@ const LLMSelect = ({ id, value, onInitialValue, onChange, disabled }: IProps) => }; export default LLMSelect; - -export const NextLLMSelect = forwardRef< - React.ElementRef, - IProps ->(({ value, disabled }, ref) => { - const [isPopoverOpen, setIsPopoverOpen] = useState(false); - const modelOptions = useComposeLlmOptionsByModelTypes([ - LlmModelType.Chat, - LlmModelType.Image2text, - ]); - - return ( - - ); -}); - -NextLLMSelect.displayName = 'LLMSelect'; diff --git a/web/src/components/llm-select/llm-label.tsx b/web/src/components/llm-select/llm-label.tsx index 29c75c91f93..1840ae95797 100644 --- a/web/src/components/llm-select/llm-label.tsx +++ b/web/src/components/llm-select/llm-label.tsx @@ -1,4 +1,5 @@ import { getLLMIconName, getLlmNameAndFIdByLlmId } from '@/utils/llm-util'; +import { memo } from 'react'; import { LlmIcon } from '../svg-icon'; interface IProps { @@ -24,4 +25,4 @@ const LLMLabel = ({ value }: IProps) => { ); }; -export default LLMLabel; +export default memo(LLMLabel); diff --git a/web/src/components/llm-select/next.tsx b/web/src/components/llm-select/next.tsx new file mode 100644 index 00000000000..4b0c0b07a41 --- /dev/null +++ b/web/src/components/llm-select/next.tsx @@ -0,0 +1,57 @@ +import { LlmModelType } from '@/constants/knowledge'; +import { useComposeLlmOptionsByModelTypes } from '@/hooks/llm-hooks'; +import * as SelectPrimitive from '@radix-ui/react-select'; +import { forwardRef, memo, useState } from 'react'; +import { LlmSettingFieldItems } from '../llm-setting-items/next'; +import { Popover, PopoverContent, PopoverTrigger } from '../ui/popover'; +import { Select, SelectTrigger, SelectValue } from '../ui/select'; + +interface IProps { + id?: string; + value?: string; + onInitialValue?: (value: string, option: any) => void; + onChange?: (value: string, option: any) => void; + disabled?: boolean; +} + +const NextInnerLLMSelect = forwardRef< + React.ElementRef, + IProps +>(({ value, disabled }, ref) => { + const [isPopoverOpen, setIsPopoverOpen] = useState(false); + const modelOptions = useComposeLlmOptionsByModelTypes([ + LlmModelType.Chat, + LlmModelType.Image2text, + ]); + + return ( + + ); +}); + +NextInnerLLMSelect.displayName = 'LLMSelect'; + +export const NextLLMSelect = memo(NextInnerLLMSelect); diff --git a/web/src/components/llm-setting-items/index.tsx b/web/src/components/llm-setting-items/index.tsx index 81c4adf08eb..caa471da5b6 100644 --- a/web/src/components/llm-setting-items/index.tsx +++ b/web/src/components/llm-setting-items/index.tsx @@ -8,6 +8,7 @@ import camelCase from 'lodash/camelCase'; import { useTranslate } from '@/hooks/common-hooks'; import { useComposeLlmOptionsByModelTypes } from '@/hooks/llm-hooks'; +import { setChatVariableEnabledFieldValuePage } from '@/utils/chat'; import { QuestionCircleOutlined } from '@ant-design/icons'; import { useCallback, useMemo } from 'react'; import styles from './index.less'; @@ -34,7 +35,8 @@ const LlmSettingItems = ({ prefix, formItemLayout = {}, onChange }: IProps) => { if (prefix) { nextVariable = { [prefix]: variable }; } - form.setFieldsValue(nextVariable); + const variableCheckBoxFieldMap = setChatVariableEnabledFieldValuePage(); + form.setFieldsValue({ ...nextVariable, ...variableCheckBoxFieldMap }); }, [form, prefix], ); @@ -102,7 +104,11 @@ const LlmSettingItems = ({ prefix, formItemLayout = {}, onChange }: IProps) => { > - + {({ getFieldValue }) => { const disabled = !getFieldValue('temperatureEnabled'); return ( @@ -147,7 +153,7 @@ const LlmSettingItems = ({ prefix, 
formItemLayout = {}, onChange }: IProps) => { - + {({ getFieldValue }) => { const disabled = !getFieldValue('topPEnabled'); return ( @@ -190,7 +196,11 @@ const LlmSettingItems = ({ prefix, formItemLayout = {}, onChange }: IProps) => { > - + {({ getFieldValue }) => { const disabled = !getFieldValue('presencePenaltyEnabled'); return ( @@ -239,7 +249,11 @@ const LlmSettingItems = ({ prefix, formItemLayout = {}, onChange }: IProps) => { > - + {({ getFieldValue }) => { const disabled = !getFieldValue('frequencyPenaltyEnabled'); return ( @@ -275,6 +289,58 @@ const LlmSettingItems = ({ prefix, formItemLayout = {}, onChange }: IProps) => { + + + + + + + {({ getFieldValue }) => { + const disabled = !getFieldValue('maxTokensEnabled'); + + return ( + <> + + + + + + + + + + ); + }} + + + diff --git a/web/src/components/llm-setting-items/next.tsx b/web/src/components/llm-setting-items/next.tsx index 31e63068bc9..c47cc296901 100644 --- a/web/src/components/llm-setting-items/next.tsx +++ b/web/src/components/llm-setting-items/next.tsx @@ -4,6 +4,7 @@ import { useComposeLlmOptionsByModelTypes } from '@/hooks/llm-hooks'; import { camelCase } from 'lodash'; import { useCallback } from 'react'; import { useFormContext } from 'react-hook-form'; +import { z } from 'zod'; import { FormControl, FormField, @@ -11,7 +12,6 @@ import { FormLabel, FormMessage, } from '../ui/form'; -import { Input } from '../ui/input'; import { Select, SelectContent, @@ -21,86 +21,26 @@ import { SelectTrigger, SelectValue, } from '../ui/select'; -import { FormSlider } from '../ui/slider'; -import { Switch } from '../ui/switch'; - -interface SliderWithInputNumberFormFieldProps { - name: string; - label: string; - checkName: string; - max: number; - min?: number; - step?: number; -} - -function SliderWithInputNumberFormField({ - name, - label, - checkName, - max, - min = 0, - step = 1, -}: SliderWithInputNumberFormFieldProps) { - const { control, watch } = useFormContext(); - const { t } = useTranslate('chat'); - const disabled = !watch(checkName); - - return ( - ( - -
- {t(label)} - ( - - - - - - - )} - /> -
- -
- - -
-
- -
- )} - /> - ); -} +import { SliderInputSwitchFormField } from './slider'; +import { useHandleFreedomChange } from './use-watch-change'; interface LlmSettingFieldItemsProps { prefix?: string; } +export const LlmSettingSchema = { + llm_id: z.string(), + temperature: z.coerce.number(), + top_p: z.string(), + presence_penalty: z.coerce.number(), + frequency_penalty: z.coerce.number(), + temperatureEnabled: z.boolean(), + topPEnabled: z.boolean(), + presencePenaltyEnabled: z.boolean(), + frequencyPenaltyEnabled: z.boolean(), + maxTokensEnabled: z.boolean(), +}; + export function LlmSettingFieldItems({ prefix }: LlmSettingFieldItemsProps) { const form = useFormContext(); const { t } = useTranslate('chat'); @@ -109,6 +49,10 @@ export function LlmSettingFieldItems({ prefix }: LlmSettingFieldItemsProps) { LlmModelType.Image2text, ]); + // useWatchFreedomChange(); + + const handleChange = useHandleFreedomChange(); + const parameterOptions = Object.values(ModelVariableType).map((x) => ({ label: t(camelCase(x)), value: x, @@ -116,13 +60,13 @@ export function LlmSettingFieldItems({ prefix }: LlmSettingFieldItemsProps) { const getFieldWithPrefix = useCallback( (name: string) => { - return `${prefix}.${name}`; + return prefix ? `${prefix}.${name}` : name; }, [prefix], ); return ( -
+
{t('freedom')} - { + handleChange(val); + field.onChange(val); + }} + > @@ -180,40 +130,40 @@ export function LlmSettingFieldItems({ prefix }: LlmSettingFieldItemsProps) { )} /> - - + - + - + - + + >
); } diff --git a/web/src/components/llm-setting-items/slider.tsx b/web/src/components/llm-setting-items/slider.tsx new file mode 100644 index 00000000000..a9137b3538c --- /dev/null +++ b/web/src/components/llm-setting-items/slider.tsx @@ -0,0 +1,92 @@ +import { useTranslate } from '@/hooks/common-hooks'; +import { cn } from '@/lib/utils'; +import { useFormContext } from 'react-hook-form'; +import { SingleFormSlider } from '../ui/dual-range-slider'; +import { + FormControl, + FormField, + FormItem, + FormLabel, + FormMessage, +} from '../ui/form'; +import { Input } from '../ui/input'; +import { Switch } from '../ui/switch'; + +type SliderInputSwitchFormFieldProps = { + max?: number; + min?: number; + step?: number; + name: string; + label: string; + defaultValue?: number; + className?: string; + checkName: string; +}; + +export function SliderInputSwitchFormField({ + max, + min, + step, + label, + name, + defaultValue, + className, + checkName, +}: SliderInputSwitchFormFieldProps) { + const form = useFormContext(); + const disabled = !form.watch(checkName); + const { t } = useTranslate('chat'); + + return ( + ( + + {t(label)} +
+ ( + + + + + + + )} + /> + + + + + + +
+ +
+ )} + /> + ); +} diff --git a/web/src/components/llm-setting-items/use-watch-change.ts b/web/src/components/llm-setting-items/use-watch-change.ts new file mode 100644 index 00000000000..bf3fa595c1c --- /dev/null +++ b/web/src/components/llm-setting-items/use-watch-change.ts @@ -0,0 +1,38 @@ +import { settledModelVariableMap } from '@/constants/knowledge'; +import { AgentFormContext } from '@/pages/agent/context'; +import useGraphStore from '@/pages/agent/store'; +import { useCallback, useContext } from 'react'; +import { useFormContext } from 'react-hook-form'; + +export function useHandleFreedomChange() { + const form = useFormContext(); + const node = useContext(AgentFormContext); + const updateNodeForm = useGraphStore((state) => state.updateNodeForm); + + const handleChange = useCallback( + (parameter: string) => { + const currentValues = { ...form.getValues() }; + const values = + settledModelVariableMap[ + parameter as keyof typeof settledModelVariableMap + ]; + + const nextValues = { ...currentValues, ...values }; + + if (node?.id) { + updateNodeForm(node?.id, nextValues); + } + + for (const key in values) { + if (Object.prototype.hasOwnProperty.call(values, key)) { + const element = values[key]; + + form.setValue(key, element); + } + } + }, + [form, node, updateNodeForm], + ); + + return handleChange; +} diff --git a/web/src/components/message-history-window-size-item.tsx b/web/src/components/message-history-window-size-item.tsx index bab9f4ccb6e..2f2043f3c0f 100644 --- a/web/src/components/message-history-window-size-item.tsx +++ b/web/src/components/message-history-window-size-item.tsx @@ -1,4 +1,5 @@ import { Form, InputNumber } from 'antd'; +import { useMemo } from 'react'; import { useFormContext } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; import { @@ -8,7 +9,7 @@ import { FormLabel, FormMessage, } from './ui/form'; -import { Input } from './ui/input'; +import { BlurInput, Input } from './ui/input'; const MessageHistoryWindowSizeItem = ({ initialValue, @@ -31,10 +32,20 @@ const MessageHistoryWindowSizeItem = ({ export default MessageHistoryWindowSizeItem; -export function MessageHistoryWindowSizeFormField() { +type MessageHistoryWindowSizeFormFieldProps = { + useBlurInput?: boolean; +}; + +export function MessageHistoryWindowSizeFormField({ + useBlurInput = false, +}: MessageHistoryWindowSizeFormFieldProps) { const form = useFormContext(); const { t } = useTranslation(); + const NextInput = useMemo(() => { + return useBlurInput ? BlurInput : Input; + }, [useBlurInput]); + return ( - + diff --git a/web/src/components/next-message-item/feedback-modal.tsx b/web/src/components/next-message-item/feedback-modal.tsx new file mode 100644 index 00000000000..c2a471c7b4e --- /dev/null +++ b/web/src/components/next-message-item/feedback-modal.tsx @@ -0,0 +1,51 @@ +import { Form, Input, Modal } from 'antd'; + +import { IModalProps } from '@/interfaces/common'; +import { IFeedbackRequestBody } from '@/interfaces/request/chat'; +import { useCallback } from 'react'; + +type FieldType = { + feedback?: string; +}; + +const FeedbackModal = ({ + visible, + hideModal, + onOk, + loading, +}: IModalProps) => { + const [form] = Form.useForm(); + + const handleOk = useCallback(async () => { + const ret = await form.validateFields(); + return onOk?.({ thumbup: false, feedback: ret.feedback }); + }, [onOk, form]); + + return ( + +
+ + name="feedback" + rules={[{ required: true, message: 'Please input your feedback!' }]} + > + + + +
+ ); +}; + +export default FeedbackModal; diff --git a/web/src/components/next-message-item/group-button.tsx b/web/src/components/next-message-item/group-button.tsx new file mode 100644 index 00000000000..3e1b21e888f --- /dev/null +++ b/web/src/components/next-message-item/group-button.tsx @@ -0,0 +1,220 @@ +import { PromptIcon } from '@/assets/icon/Icon'; +import CopyToClipboard from '@/components/copy-to-clipboard'; +import { useSetModalState } from '@/hooks/common-hooks'; +import { IRemoveMessageById } from '@/hooks/logic-hooks'; +import { AgentChatContext } from '@/pages/agent/context'; +import { + DeleteOutlined, + DislikeOutlined, + LikeOutlined, + PauseCircleOutlined, + SoundOutlined, + SyncOutlined, +} from '@ant-design/icons'; +import { Radio, Tooltip } from 'antd'; +import { NotebookText } from 'lucide-react'; +import { useCallback, useContext } from 'react'; +import { useTranslation } from 'react-i18next'; +import { ToggleGroup, ToggleGroupItem } from '../ui/toggle-group'; +import FeedbackModal from './feedback-modal'; +import { useRemoveMessage, useSendFeedback, useSpeech } from './hooks'; +import PromptModal from './prompt-modal'; + +interface IProps { + messageId: string; + content: string; + prompt?: string; + showLikeButton: boolean; + audioBinary?: string; + showLoudspeaker?: boolean; +} + +export const AssistantGroupButton = ({ + messageId, + content, + prompt, + audioBinary, + showLikeButton, + showLoudspeaker = true, +}: IProps) => { + const { visible, hideModal, showModal, onFeedbackOk, loading } = + useSendFeedback(messageId); + const { + visible: promptVisible, + hideModal: hidePromptModal, + showModal: showPromptModal, + } = useSetModalState(); + const { t } = useTranslation(); + const { handleRead, ref, isPlaying } = useSpeech(content, audioBinary); + + const handleLike = useCallback(() => { + onFeedbackOk({ thumbup: true }); + }, [onFeedbackOk]); + + const { showLogSheet } = useContext(AgentChatContext); + + const handleShowLogSheet = useCallback(() => { + showLogSheet(messageId); + }, [messageId, showLogSheet]); + + return ( + <> + + + + + {showLoudspeaker && ( + + + {isPlaying ? : } + + + + )} + {showLikeButton && ( + <> + + + + + + + + )} + {prompt && ( + + + + )} + + + + + {visible && ( + + )} + {promptVisible && ( + + )} + + ); + + return ( + <> + + + + + {showLoudspeaker && ( + + + {isPlaying ? 
: } + + + + )} + {showLikeButton && ( + <> + + + + + + + + )} + {prompt && ( + + + + )} + { + e.preventDefault(); + e.stopPropagation(); + handleShowLogSheet(); + }} + > + + + + {visible && ( + + )} + {promptVisible && ( + + )} + + ); +}; + +interface UserGroupButtonProps extends Partial { + messageId: string; + content: string; + regenerateMessage?: () => void; + sendLoading: boolean; +} + +export const UserGroupButton = ({ + content, + messageId, + sendLoading, + removeMessageById, + regenerateMessage, +}: UserGroupButtonProps) => { + const { onRemoveMessage, loading } = useRemoveMessage( + messageId, + removeMessageById, + ); + const { t } = useTranslation(); + + return ( + + + + + {regenerateMessage && ( + + + + + + )} + {removeMessageById && ( + + + + + + )} + + ); +}; diff --git a/web/src/components/next-message-item/hooks.ts b/web/src/components/next-message-item/hooks.ts new file mode 100644 index 00000000000..f77da050114 --- /dev/null +++ b/web/src/components/next-message-item/hooks.ts @@ -0,0 +1,116 @@ +import { useDeleteMessage, useFeedback } from '@/hooks/chat-hooks'; +import { useSetModalState } from '@/hooks/common-hooks'; +import { IRemoveMessageById, useSpeechWithSse } from '@/hooks/logic-hooks'; +import { IFeedbackRequestBody } from '@/interfaces/request/chat'; +import { hexStringToUint8Array } from '@/utils/common-util'; +import { SpeechPlayer } from 'openai-speech-stream-player'; +import { useCallback, useEffect, useRef, useState } from 'react'; + +export const useSendFeedback = (messageId: string) => { + const { visible, hideModal, showModal } = useSetModalState(); + const { feedback, loading } = useFeedback(); + + const onFeedbackOk = useCallback( + async (params: IFeedbackRequestBody) => { + const ret = await feedback({ + ...params, + messageId: messageId, + }); + + if (ret === 0) { + hideModal(); + } + }, + [feedback, hideModal, messageId], + ); + + return { + loading, + onFeedbackOk, + visible, + hideModal, + showModal, + }; +}; + +export const useRemoveMessage = ( + messageId: string, + removeMessageById?: IRemoveMessageById['removeMessageById'], +) => { + const { deleteMessage, loading } = useDeleteMessage(); + + const onRemoveMessage = useCallback(async () => { + if (messageId) { + const code = await deleteMessage(messageId); + if (code === 0) { + removeMessageById?.(messageId); + } + } + }, [deleteMessage, messageId, removeMessageById]); + + return { onRemoveMessage, loading }; +}; + +export const useSpeech = (content: string, audioBinary?: string) => { + const ref = useRef(null); + const { read } = useSpeechWithSse(); + const player = useRef(); + const [isPlaying, setIsPlaying] = useState(false); + + const initialize = useCallback(async () => { + player.current = new SpeechPlayer({ + audio: ref.current!, + onPlaying: () => { + setIsPlaying(true); + }, + onPause: () => { + setIsPlaying(false); + }, + onChunkEnd: () => {}, + mimeType: MediaSource.isTypeSupported('audio/mpeg') + ? 
'audio/mpeg' + : 'audio/mp4; codecs="mp4a.40.2"', // https://stackoverflow.com/questions/64079424/cannot-replay-mp3-in-firefox-using-mediasource-even-though-it-works-in-chrome + }); + await player.current.init(); + }, []); + + const pause = useCallback(() => { + player.current?.pause(); + }, []); + + const speech = useCallback(async () => { + const response = await read({ text: content }); + if (response) { + player?.current?.feedWithResponse(response); + } + }, [read, content]); + + const handleRead = useCallback(async () => { + if (isPlaying) { + setIsPlaying(false); + pause(); + } else { + setIsPlaying(true); + speech(); + } + }, [setIsPlaying, speech, isPlaying, pause]); + + useEffect(() => { + if (audioBinary) { + const units = hexStringToUint8Array(audioBinary); + if (units) { + try { + player.current?.feed(units); + } catch (error) { + console.warn(error); + } + } + } + }, [audioBinary]); + + useEffect(() => { + initialize(); + }, [initialize]); + + return { ref, handleRead, isPlaying }; +}; diff --git a/web/src/components/next-message-item/index.less b/web/src/components/next-message-item/index.less new file mode 100644 index 00000000000..a4812bd61d1 --- /dev/null +++ b/web/src/components/next-message-item/index.less @@ -0,0 +1,63 @@ +.messageItem { + padding: 24px 0; + .messageItemSection { + display: inline-block; + } + .messageItemSectionLeft { + width: 80%; + } + .messageItemContent { + display: inline-flex; + gap: 20px; + } + .messageItemContentReverse { + flex-direction: row-reverse; + } + + .messageTextBase() { + padding: 6px 10px; + border-radius: 8px; + & > p { + margin: 0; + } + } + .messageText { + .chunkText(); + .messageTextBase(); + background-color: #e6f4ff; + word-break: break-word; + } + .messageTextDark { + .chunkText(); + .messageTextBase(); + background-color: #1668dc; + word-break: break-word; + :global(section.think) { + color: rgb(166, 166, 166); + border-left-color: rgb(78, 78, 86); + } + } + + .messageUserText { + .chunkText(); + .messageTextBase(); + background-color: rgba(255, 255, 255, 0.3); + word-break: break-word; + text-align: justify; + } + .messageEmpty { + width: 300px; + } + + .thumbnailImg { + max-width: 20px; + } +} + +.messageItemLeft { + text-align: left; +} + +.messageItemRight { + text-align: right; +} diff --git a/web/src/components/next-message-item/index.tsx b/web/src/components/next-message-item/index.tsx new file mode 100644 index 00000000000..664f46afb0a --- /dev/null +++ b/web/src/components/next-message-item/index.tsx @@ -0,0 +1,244 @@ +import { ReactComponent as AssistantIcon } from '@/assets/svg/assistant.svg'; +import { MessageType } from '@/constants/chat'; +import { useSetModalState } from '@/hooks/common-hooks'; +import { IReference, IReferenceChunk } from '@/interfaces/database/chat'; +import classNames from 'classnames'; +import { memo, useCallback, useEffect, useMemo, useState } from 'react'; + +import { + useFetchDocumentInfosByIds, + useFetchDocumentThumbnailsByIds, +} from '@/hooks/document-hooks'; +import { IRegenerateMessage, IRemoveMessageById } from '@/hooks/logic-hooks'; +import { IMessage } from '@/pages/chat/interface'; +import MarkdownContent from '@/pages/chat/markdown-content'; +import { getExtension, isImage } from '@/utils/document-util'; +import { Avatar, Button, Flex, List, Space, Typography } from 'antd'; +import FileIcon from '../file-icon'; +import IndentedTreeModal from '../indented-tree/modal'; +import NewDocumentLink from '../new-document-link'; +import { useTheme } from '../theme-provider'; +import { 
AssistantGroupButton, UserGroupButton } from './group-button'; +import styles from './index.less'; + +const { Text } = Typography; + +interface IProps extends Partial, IRegenerateMessage { + item: IMessage; + reference: IReference; + loading?: boolean; + sendLoading?: boolean; + visibleAvatar?: boolean; + nickname?: string; + avatar?: string; + avatarDialog?: string | null; + clickDocumentButton?: (documentId: string, chunk: IReferenceChunk) => void; + index: number; + showLikeButton?: boolean; + showLoudspeaker?: boolean; +} + +const MessageItem = ({ + item, + reference, + loading = false, + avatar, + avatarDialog, + sendLoading = false, + clickDocumentButton, + index, + removeMessageById, + regenerateMessage, + showLikeButton = true, + showLoudspeaker = true, + visibleAvatar = true, +}: IProps) => { + const { theme } = useTheme(); + const isAssistant = item.role === MessageType.Assistant; + const isUser = item.role === MessageType.User; + const { data: documentList, setDocumentIds } = useFetchDocumentInfosByIds(); + const { data: documentThumbnails, setDocumentIds: setIds } = + useFetchDocumentThumbnailsByIds(); + const { visible, hideModal, showModal } = useSetModalState(); + const [clickedDocumentId, setClickedDocumentId] = useState(''); + + const referenceDocumentList = useMemo(() => { + return reference?.doc_aggs ?? []; + }, [reference?.doc_aggs]); + + const handleUserDocumentClick = useCallback( + (id: string) => () => { + setClickedDocumentId(id); + showModal(); + }, + [showModal], + ); + + const handleRegenerateMessage = useCallback(() => { + regenerateMessage?.(item); + }, [regenerateMessage, item]); + + useEffect(() => { + const ids = item?.doc_ids ?? []; + if (ids.length) { + setDocumentIds(ids); + const documentIds = ids.filter((x) => !(x in documentThumbnails)); + if (documentIds.length) { + setIds(documentIds); + } + } + }, [item.doc_ids, setDocumentIds, setIds, documentThumbnails]); + + return ( +
+
+
+ {visibleAvatar && + (item.role === MessageType.User ? ( + + ) : avatarDialog ? ( + + ) : ( + + ))} + + + + {isAssistant ? ( + index !== 0 && ( + + ) + ) : ( + + )} + + {/* {isAssistant ? '' : nickname} */} + +
+ +
+ {isAssistant && referenceDocumentList.length > 0 && ( + { + return ( + + + + + + {item.doc_name} + + + + ); + }} + /> + )} + {isUser && documentList.length > 0 && ( + { + // TODO: + // const fileThumbnail = + // documentThumbnails[item.id] || documentThumbnails[item.id]; + const fileExtension = getExtension(item.name); + return ( + + + + + {isImage(fileExtension) ? ( + + {item.name} + + ) : ( + + )} + + + ); + }} + /> + )} +
+
+
+ {visible && ( + + )} +
+ ); +}; + +export default memo(MessageItem); diff --git a/web/src/components/next-message-item/prompt-modal.tsx b/web/src/components/next-message-item/prompt-modal.tsx new file mode 100644 index 00000000000..f5222e59bb1 --- /dev/null +++ b/web/src/components/next-message-item/prompt-modal.tsx @@ -0,0 +1,30 @@ +import { IModalProps } from '@/interfaces/common'; +import { IFeedbackRequestBody } from '@/interfaces/request/chat'; +import { Modal, Space } from 'antd'; +import HightLightMarkdown from '../highlight-markdown'; +import SvgIcon from '../svg-icon'; + +const PromptModal = ({ + visible, + hideModal, + prompt, +}: IModalProps & { prompt?: string }) => { + return ( + + + Prompt + + } + width={'80%'} + open={visible} + onCancel={hideModal} + footer={null} + > + {prompt} + + ); +}; + +export default PromptModal; diff --git a/web/src/components/originui/input.tsx b/web/src/components/originui/input.tsx new file mode 100644 index 00000000000..d8c125f9115 --- /dev/null +++ b/web/src/components/originui/input.tsx @@ -0,0 +1,25 @@ +import * as React from 'react'; + +import { cn } from '@/lib/utils'; + +function Input({ className, type, ...props }: React.ComponentProps<'input'>) { + return ( + + ); +} + +export { Input }; diff --git a/web/src/components/originui/select-with-search.tsx b/web/src/components/originui/select-with-search.tsx new file mode 100644 index 00000000000..fbf1584e094 --- /dev/null +++ b/web/src/components/originui/select-with-search.tsx @@ -0,0 +1,170 @@ +'use client'; + +import { CheckIcon, ChevronDownIcon } from 'lucide-react'; +import { + Fragment, + forwardRef, + useCallback, + useEffect, + useId, + useState, +} from 'react'; + +import { Button } from '@/components/ui/button'; +import { + Command, + CommandEmpty, + CommandGroup, + CommandInput, + CommandItem, + CommandList, +} from '@/components/ui/command'; +import { + Popover, + PopoverContent, + PopoverTrigger, +} from '@/components/ui/popover'; +import { RAGFlowSelectOptionType } from '../ui/select'; + +const countries = [ + { + label: 'America', + options: [ + { value: 'United States', label: '🇺🇸' }, + { value: 'Canada', label: '🇨🇦' }, + { value: 'Mexico', label: '🇲🇽' }, + ], + }, + { + label: 'Africa', + options: [ + { value: 'South Africa', label: '🇿🇦' }, + { value: 'Nigeria', label: '🇳🇬' }, + { value: 'Morocco', label: '🇲🇦' }, + ], + }, + { + label: 'Asia', + options: [ + { value: 'China', label: '🇨🇳' }, + { value: 'Japan', label: '🇯🇵' }, + { value: 'India', label: '🇮🇳' }, + ], + }, + { + label: 'Europe', + options: [ + { value: 'United Kingdom', label: '🇬🇧' }, + { value: 'France', label: '🇫🇷' }, + { value: 'Germany', label: '🇩🇪' }, + ], + }, + { + label: 'Oceania', + options: [ + { value: 'Australia', label: '🇦🇺' }, + { value: 'New Zealand', label: '🇳🇿' }, + ], + }, +]; + +export type SelectWithSearchFlagOptionType = { + label: string; + options: RAGFlowSelectOptionType[]; +}; + +export type SelectWithSearchFlagProps = { + options?: SelectWithSearchFlagOptionType[]; + value?: string; + onChange?(value: string): void; +}; + +export const SelectWithSearch = forwardRef< + React.ElementRef, + SelectWithSearchFlagProps +>(({ value: val = '', onChange, options = countries }, ref) => { + const id = useId(); + const [open, setOpen] = useState(false); + const [value, setValue] = useState(''); + + const handleSelect = useCallback( + (val: string) => { + setValue(val); + setOpen(false); + onChange?.(val); + }, + [onChange], + ); + + useEffect(() => { + setValue(val); + }, [val]); + + return ( + + + + + + + + + No data 
found. + {options.map((group) => ( + + + {group.options.map((option) => ( + + + {option.label} + + + {value === option.value && ( + + )} + + ))} + + + ))} + + + + + ); +}); diff --git a/web/src/components/originui/timeline.tsx b/web/src/components/originui/timeline.tsx new file mode 100644 index 00000000000..8b4bbea4273 --- /dev/null +++ b/web/src/components/originui/timeline.tsx @@ -0,0 +1,209 @@ +'use client'; + +import { cn } from '@/lib/utils'; +import { Slot } from '@radix-ui/react-slot'; +import * as React from 'react'; + +// Types +type TimelineContextValue = { + activeStep: number; + setActiveStep: (step: number) => void; +}; + +// Context +const TimelineContext = React.createContext( + undefined, +); + +const useTimeline = () => { + const context = React.useContext(TimelineContext); + if (!context) { + throw new Error('useTimeline must be used within a Timeline'); + } + return context; +}; + +// Components +interface TimelineProps extends React.HTMLAttributes { + defaultValue?: number; + value?: number; + onValueChange?: (value: number) => void; + orientation?: 'horizontal' | 'vertical'; +} + +function Timeline({ + defaultValue = 1, + value, + onValueChange, + orientation = 'vertical', + className, + ...props +}: TimelineProps) { + const [activeStep, setInternalStep] = React.useState(defaultValue); + + const setActiveStep = React.useCallback( + (step: number) => { + if (value === undefined) { + setInternalStep(step); + } + onValueChange?.(step); + }, + [value, onValueChange], + ); + + const currentStep = value ?? activeStep; + + return ( + +
+ + ); +} + +// TimelineContent +function TimelineContent({ + className, + ...props +}: React.HTMLAttributes) { + return ( +
+ ); +} + +// TimelineDate +interface TimelineDateProps extends React.HTMLAttributes { + asChild?: boolean; +} + +function TimelineDate({ + asChild = false, + className, + ...props +}: TimelineDateProps) { + const Comp = asChild ? Slot : 'time'; + + return ( + + ); +} + +// TimelineHeader +function TimelineHeader({ + className, + ...props +}: React.HTMLAttributes) { + return ( +
+ ); +} + +// TimelineIndicator +interface TimelineIndicatorProps extends React.HTMLAttributes { + asChild?: boolean; +} + +function TimelineIndicator({ + asChild = false, + className, + children, + ...props +}: TimelineIndicatorProps) { + return ( + + ); +} + +// TimelineItem +interface TimelineItemProps extends React.HTMLAttributes { + step: number; +} + +function TimelineItem({ step, className, ...props }: TimelineItemProps) { + const { activeStep } = useTimeline(); + + return ( +
+ ); +} + +// TimelineSeparator +function TimelineSeparator({ + className, + ...props +}: React.HTMLAttributes) { + return ( + , - )} - > - {t('graphRagMethod')} - - - - - + +
+
, + )} + > + {t('graphRagMethod')} + +
+ + + +
+
+
+
+ +
)} /> @@ -139,17 +157,27 @@ const GraphRagItems = ({ control={form.control} name="parser_config.graphrag.resolution" render={({ field }) => ( - - - {t('resolution')} - - - - - + +
+ + {t('resolution')} + +
+ + + +
+
+
+
+ +
)} /> @@ -158,17 +186,27 @@ const GraphRagItems = ({ control={form.control} name="parser_config.graphrag.community" render={({ field }) => ( - - - {t('community')} - - - - - + +
+ + {t('community')} + +
+ + + +
+
+
+
+ +
)} /> diff --git a/web/src/components/parse-configuration/raptor-form-fields-old.tsx b/web/src/components/parse-configuration/raptor-form-fields-old.tsx new file mode 100644 index 00000000000..28cd79cedbc --- /dev/null +++ b/web/src/components/parse-configuration/raptor-form-fields-old.tsx @@ -0,0 +1,146 @@ +import { DocumentParserType } from '@/constants/knowledge'; +import { useTranslate } from '@/hooks/common-hooks'; +import random from 'lodash/random'; +import { Plus } from 'lucide-react'; +import { useCallback } from 'react'; +import { useFormContext, useWatch } from 'react-hook-form'; +import { SliderInputFormField } from '../slider-input-form-field'; +import { Button } from '../ui/button'; +import { + FormControl, + FormField, + FormItem, + FormLabel, + FormMessage, +} from '../ui/form'; +import { Input } from '../ui/input'; +import { Switch } from '../ui/switch'; +import { Textarea } from '../ui/textarea'; + +export const excludedParseMethods = [ + DocumentParserType.Table, + DocumentParserType.Resume, + DocumentParserType.One, + DocumentParserType.Picture, + DocumentParserType.KnowledgeGraph, + DocumentParserType.Qa, + DocumentParserType.Tag, +]; + +export const showRaptorParseConfiguration = ( + parserId: DocumentParserType | undefined, +) => { + return !excludedParseMethods.some((x) => x === parserId); +}; + +export const excludedTagParseMethods = [ + DocumentParserType.Table, + DocumentParserType.KnowledgeGraph, + DocumentParserType.Tag, +]; + +export const showTagItems = (parserId: DocumentParserType) => { + return !excludedTagParseMethods.includes(parserId); +}; + +const UseRaptorField = 'parser_config.raptor.use_raptor'; +const RandomSeedField = 'parser_config.raptor.random_seed'; + +// The three types "table", "resume" and "one" do not display this configuration. + +const RaptorFormFields = () => { + const form = useFormContext(); + const { t } = useTranslate('knowledgeConfiguration'); + const useRaptor = useWatch({ name: UseRaptorField }); + + const handleGenerate = useCallback(() => { + form.setValue(RandomSeedField, random(10000)); + }, [form]); + + return ( + <> + ( + + {t('useRaptor')} + + + + + + )} + /> + {useRaptor && ( +
+ ( + + {t('prompt')} + + + ); +}); diff --git a/web/src/components/ui/toggle-group.tsx b/web/src/components/ui/toggle-group.tsx new file mode 100644 index 00000000000..a9977938501 --- /dev/null +++ b/web/src/components/ui/toggle-group.tsx @@ -0,0 +1,73 @@ +'use client'; + +import * as ToggleGroupPrimitive from '@radix-ui/react-toggle-group'; +import { type VariantProps } from 'class-variance-authority'; +import * as React from 'react'; + +import { toggleVariants } from '@/components/ui/toggle'; +import { cn } from '@/lib/utils'; + +const ToggleGroupContext = React.createContext< + VariantProps +>({ + size: 'default', + variant: 'default', +}); + +function ToggleGroup({ + className, + variant, + size, + children, + ...props +}: React.ComponentProps & + VariantProps) { + return ( + + + {children} + + + ); +} + +function ToggleGroupItem({ + className, + children, + variant, + size, + ...props +}: React.ComponentProps & + VariantProps) { + const context = React.useContext(ToggleGroupContext); + + return ( + + {children} + + ); +} + +export { ToggleGroup, ToggleGroupItem }; diff --git a/web/src/components/ui/toggle.tsx b/web/src/components/ui/toggle.tsx new file mode 100644 index 00000000000..72f15cf6060 --- /dev/null +++ b/web/src/components/ui/toggle.tsx @@ -0,0 +1,47 @@ +'use client'; + +import * as TogglePrimitive from '@radix-ui/react-toggle'; +import { cva, type VariantProps } from 'class-variance-authority'; +import * as React from 'react'; + +import { cn } from '@/lib/utils'; + +const toggleVariants = cva( + "inline-flex items-center justify-center gap-2 rounded-md text-sm font-medium hover:bg-muted hover:text-muted-foreground disabled:pointer-events-none disabled:opacity-50 data-[state=on]:bg-accent data-[state=on]:text-accent-foreground [&_svg]:pointer-events-none [&_svg:not([class*='size-'])]:size-4 [&_svg]:shrink-0 focus-visible:border-ring focus-visible:ring-ring/50 focus-visible:ring-[3px] outline-none transition-[color,box-shadow] aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive whitespace-nowrap", + { + variants: { + variant: { + default: 'bg-transparent', + outline: + 'border border-input bg-transparent shadow-xs hover:bg-accent hover:text-accent-foreground', + }, + size: { + default: 'h-9 px-2 min-w-9', + sm: 'h-8 px-1.5 min-w-8', + lg: 'h-10 px-2.5 min-w-10', + }, + }, + defaultVariants: { + variant: 'default', + size: 'default', + }, + }, +); + +function Toggle({ + className, + variant, + size, + ...props +}: React.ComponentProps & + VariantProps) { + return ( + + ); +} + +export { Toggle, toggleVariants }; diff --git a/web/src/components/ui/tooltip.tsx b/web/src/components/ui/tooltip.tsx index 7569004539f..7668b4baf6e 100644 --- a/web/src/components/ui/tooltip.tsx +++ b/web/src/components/ui/tooltip.tsx @@ -4,6 +4,7 @@ import * as TooltipPrimitive from '@radix-ui/react-tooltip'; import * as React from 'react'; import { cn } from '@/lib/utils'; +import { Info } from 'lucide-react'; const TooltipProvider = TooltipPrimitive.Provider; @@ -28,3 +29,16 @@ const TooltipContent = React.forwardRef< TooltipContent.displayName = TooltipPrimitive.Content.displayName; export { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger }; + +export const FormTooltip = ({ tooltip }: { tooltip: React.ReactNode }) => { + return ( + + + + + +

{tooltip}

+
+
+ ); +}; diff --git a/web/src/components/xyflow/base-node.tsx b/web/src/components/xyflow/base-node.tsx new file mode 100644 index 00000000000..222afdaa10d --- /dev/null +++ b/web/src/components/xyflow/base-node.tsx @@ -0,0 +1,22 @@ +import { forwardRef, HTMLAttributes } from 'react'; + +import { cn } from '@/lib/utils'; + +export const BaseNode = forwardRef< + HTMLDivElement, + HTMLAttributes & { selected?: boolean } +>(({ className, selected, ...props }, ref) => ( +
+)); + +BaseNode.displayName = 'BaseNode'; diff --git a/web/src/components/xyflow/tooltip-node.tsx b/web/src/components/xyflow/tooltip-node.tsx new file mode 100644 index 00000000000..b53c67617e7 --- /dev/null +++ b/web/src/components/xyflow/tooltip-node.tsx @@ -0,0 +1,101 @@ +import { NodeProps, NodeToolbar, NodeToolbarProps } from '@xyflow/react'; +import { + HTMLAttributes, + ReactNode, + createContext, + forwardRef, + useCallback, + useContext, + useState, +} from 'react'; +import { BaseNode } from './base-node'; + +/* TOOLTIP CONTEXT ---------------------------------------------------------- */ + +const TooltipContext = createContext(false); + +/* TOOLTIP NODE ------------------------------------------------------------- */ + +export type TooltipNodeProps = Partial & { + children?: ReactNode; +}; + +/** + * A component that wraps a node and provides tooltip visibility context. + */ +export const TooltipNode = forwardRef( + ({ selected, children }, ref) => { + const [isTooltipVisible, setTooltipVisible] = useState(false); + + const showTooltip = useCallback(() => setTooltipVisible(true), []); + const hideTooltip = useCallback(() => setTooltipVisible(false), []); + + return ( + + + {children} + + + ); + }, +); + +TooltipNode.displayName = 'TooltipNode'; + +/* TOOLTIP CONTENT ---------------------------------------------------------- */ + +export type TooltipContentProps = NodeToolbarProps; + +/** + * A component that displays the tooltip content based on visibility context. + */ +export const TooltipContent = forwardRef( + ({ position, children }, ref) => { + const isTooltipVisible = useContext(TooltipContext); + + return ( +
+ + {children} + +
+ ); + }, +); + +TooltipContent.displayName = 'TooltipContent'; + +/* TOOLTIP TRIGGER ---------------------------------------------------------- */ + +export type TooltipTriggerProps = HTMLAttributes; + +/** + * A component that triggers the tooltip visibility. + */ +export const TooltipTrigger = forwardRef< + HTMLParagraphElement, + TooltipTriggerProps +>(({ children, ...props }, ref) => { + return ( +
+ {children} +
+ ); +}); + +TooltipTrigger.displayName = 'TooltipTrigger'; diff --git a/web/src/constants/agent.ts b/web/src/constants/agent.ts index 91c12880470..28a34111b82 100644 --- a/web/src/constants/agent.ts +++ b/web/src/constants/agent.ts @@ -20,3 +20,10 @@ async function main(args) { module.exports = { main }; `, }; + +export enum AgentGlobals { + SysQuery = 'sys.query', + SysUserId = 'sys.user_id', + SysConversationTurns = 'sys.conversation_turns', + SysFiles = 'sys.files', +} diff --git a/web/src/constants/knowledge.ts b/web/src/constants/knowledge.ts index ed2f6b31987..6839f383fba 100644 --- a/web/src/constants/knowledge.ts +++ b/web/src/constants/knowledge.ts @@ -23,25 +23,25 @@ export enum ModelVariableType { export const settledModelVariableMap = { [ModelVariableType.Improvise]: { - temperature: 0.9, + temperature: 0.8, top_p: 0.9, - frequency_penalty: 0.2, - presence_penalty: 0.4, - max_tokens: 512, + frequency_penalty: 0.1, + presence_penalty: 0.1, + max_tokens: 4096, }, [ModelVariableType.Precise]: { - temperature: 0.1, - top_p: 0.3, - frequency_penalty: 0.7, - presence_penalty: 0.4, - max_tokens: 512, + temperature: 0.2, + top_p: 0.75, + frequency_penalty: 0.5, + presence_penalty: 0.5, + max_tokens: 4096, }, [ModelVariableType.Balance]: { temperature: 0.5, - top_p: 0.5, - frequency_penalty: 0.7, - presence_penalty: 0.4, - max_tokens: 512, + top_p: 0.85, + frequency_penalty: 0.3, + presence_penalty: 0.2, + max_tokens: 4096, }, }; diff --git a/web/src/custom.d.ts b/web/src/custom.d.ts new file mode 100644 index 00000000000..f73d61b396c --- /dev/null +++ b/web/src/custom.d.ts @@ -0,0 +1,4 @@ +declare module '*.md' { + const content: string; + export default content; +} diff --git a/web/src/hooks/knowledge-hooks.ts b/web/src/hooks/knowledge-hooks.ts index 92df31685b5..f4cfd125f91 100644 --- a/web/src/hooks/knowledge-hooks.ts +++ b/web/src/hooks/knowledge-hooks.ts @@ -261,6 +261,51 @@ export const useTestChunkRetrieval = (): ResponsePostType & { }; }; +export const useTestChunkAllRetrieval = (): ResponsePostType & { + testChunkAll: (...params: any[]) => void; +} => { + const knowledgeBaseId = useKnowledgeBaseId(); + const { page, size: pageSize } = useSetPaginationParams(); + + const { + data, + isPending: loading, + mutateAsync, + } = useMutation({ + mutationKey: ['testChunkAll'], // This method is invalid + gcTime: 0, + mutationFn: async (values: any) => { + const { data } = await kbService.retrieval_test({ + ...values, + kb_id: values.kb_id ?? knowledgeBaseId, + doc_ids: [], + page, + size: pageSize, + }); + if (data.code === 0) { + const res = data.data; + return { + ...res, + documents: res.doc_aggs, + }; + } + return ( + data?.data ?? { + chunks: [], + documents: [], + total: 0, + } + ); + }, + }); + + return { + data: data ?? 
{ chunks: [], documents: [], total: 0 }, + loading, + testChunkAll: mutateAsync, + }; +}; + export const useChunkIsTesting = () => { return useIsMutating({ mutationKey: ['testChunk'] }) > 0; }; @@ -288,6 +333,30 @@ export const useSelectIsTestingSuccess = () => { }); return status.at(-1) === 'success'; }; + +export const useAllTestingSuccess = () => { + const status = useMutationState({ + filters: { mutationKey: ['testChunkAll'] }, + select: (mutation) => { + return mutation.state.status; + }, + }); + return status.at(-1) === 'success'; +}; + +export const useAllTestingResult = (): ITestingResult => { + const data = useMutationState({ + filters: { mutationKey: ['testChunkAll'] }, + select: (mutation) => { + return mutation.state.data; + }, + }); + return (data.at(-1) ?? { + chunks: [], + documents: [], + total: 0, + }) as ITestingResult; +}; //#endregion //#region tags diff --git a/web/src/hooks/logic-hooks/navigate-hooks.ts b/web/src/hooks/logic-hooks/navigate-hooks.ts index 5cdff81b997..c5131f2cf00 100644 --- a/web/src/hooks/logic-hooks/navigate-hooks.ts +++ b/web/src/hooks/logic-hooks/navigate-hooks.ts @@ -64,7 +64,8 @@ export const useNavigatePage = () => { const navigateToChunkParsedResult = useCallback( (id: string, knowledgeId?: string) => () => { navigate( - `${Routes.ParsedResult}/${id}?${QueryStringMap.KnowledgeId}=${knowledgeId}`, + // `${Routes.ParsedResult}/${id}?${QueryStringMap.KnowledgeId}=${knowledgeId}`, + `${Routes.ParsedResult}/chunks?id=${knowledgeId}&doc_id=${id}`, ); }, [navigate], diff --git a/web/src/hooks/use-agent-request.ts b/web/src/hooks/use-agent-request.ts index ee8b2144c35..016718af86f 100644 --- a/web/src/hooks/use-agent-request.ts +++ b/web/src/hooks/use-agent-request.ts @@ -1,22 +1,270 @@ -import { IFlow } from '@/interfaces/database/flow'; +import { AgentGlobals } from '@/constants/agent'; +import { DSL, IFlow, IFlowTemplate } from '@/interfaces/database/flow'; +import i18n from '@/locales/config'; +import { BeginId } from '@/pages/agent/constant'; +import { useGetSharedChatSearchParams } from '@/pages/chat/shared-hooks'; import flowService from '@/services/flow-service'; -import { useQuery } from '@tanstack/react-query'; +import { buildMessageListWithUuid } from '@/utils/chat'; +import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query'; +import { useDebounce } from 'ahooks'; +import { message } from 'antd'; +import { get, set } from 'lodash'; +import { useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { useParams } from 'umi'; +import { v4 as uuid } from 'uuid'; +import { + useGetPaginationWithRouter, + useHandleSearchChange, +} from './logic-hooks'; export const enum AgentApiAction { FetchAgentList = 'fetchAgentList', + UpdateAgentSetting = 'updateAgentSetting', + DeleteAgent = 'deleteAgent', + FetchAgentDetail = 'fetchAgentDetail', + ResetAgent = 'resetAgent', + SetAgent = 'setAgent', + FetchAgentTemplates = 'fetchAgentTemplates', } -export const useFetchAgentList = () => { - const { data, isFetching: loading } = useQuery({ - queryKey: [AgentApiAction.FetchAgentList], +export const EmptyDsl = { + graph: { + nodes: [ + { + id: BeginId, + type: 'beginNode', + position: { + x: 50, + y: 200, + }, + data: { + label: 'Begin', + name: 'begin', + }, + sourcePosition: 'left', + targetPosition: 'right', + }, + ], + edges: [], + }, + components: { + begin: { + obj: { + component_name: 'Begin', + params: {}, + }, + downstream: ['Answer:China'], // other edge target is downstream, edge source is current node 
id + upstream: [], // edge source is upstream, edge target is current node id + }, + }, + retrieval: [], // reference + history: [], + path: [], + globals: { + [AgentGlobals.SysQuery]: '', + [AgentGlobals.SysUserId]: '', + [AgentGlobals.SysConversationTurns]: 0, + [AgentGlobals.SysFiles]: [], + }, +}; + +export const useFetchAgentTemplates = () => { + const { t } = useTranslation(); + + const { data } = useQuery({ + queryKey: [AgentApiAction.FetchAgentTemplates], initialData: [], + queryFn: async () => { + const { data } = await flowService.listTemplates(); + if (Array.isArray(data?.data)) { + data.data.unshift({ + id: uuid(), + title: t('flow.blank'), + description: t('flow.createFromNothing'), + dsl: EmptyDsl, + }); + } + + return data.data; + }, + }); + + return data; +}; + +export const useFetchAgentListByPage = () => { + const { searchString, handleInputChange } = useHandleSearchChange(); + const { pagination, setPagination } = useGetPaginationWithRouter(); + const debouncedSearchString = useDebounce(searchString, { wait: 500 }); + + const { data, isFetching: loading } = useQuery<{ + kbs: IFlow[]; + total: number; + }>({ + queryKey: [ + AgentApiAction.FetchAgentList, + { + debouncedSearchString, + ...pagination, + }, + ], + initialData: { kbs: [], total: 0 }, gcTime: 0, queryFn: async () => { - const { data } = await flowService.listCanvas(); + const { data } = await flowService.listCanvasTeam({ + keywords: debouncedSearchString, + page_size: pagination.pageSize, + page: pagination.current, + }); + + return data?.data ?? []; + }, + }); + + const onInputChange: React.ChangeEventHandler = useCallback( + (e) => { + // setPagination({ page: 1 }); // TODO: 这里导致重复请求 + handleInputChange(e); + }, + [handleInputChange], + ); + + return { + data: data.kbs, + loading, + searchString, + handleInputChange: onInputChange, + pagination: { ...pagination, total: data?.total }, + setPagination, + }; +}; + +export const useUpdateAgentSetting = () => { + const queryClient = useQueryClient(); + + const { + data, + isPending: loading, + mutateAsync, + } = useMutation({ + mutationKey: [AgentApiAction.UpdateAgentSetting], + mutationFn: async (params: any) => { + const ret = await flowService.settingCanvas(params); + if (ret?.data?.code === 0) { + message.success('success'); + queryClient.invalidateQueries({ + queryKey: [AgentApiAction.FetchAgentList], + }); + } else { + message.error(ret?.data?.data); + } + return ret?.data?.code; + }, + }); + + return { data, loading, updateAgentSetting: mutateAsync }; +}; +export const useDeleteAgent = () => { + const queryClient = useQueryClient(); + const { + data, + isPending: loading, + mutateAsync, + } = useMutation({ + mutationKey: [AgentApiAction.DeleteAgent], + mutationFn: async (canvasIds: string[]) => { + const { data } = await flowService.removeCanvas({ canvasIds }); + if (data.code === 0) { + queryClient.invalidateQueries({ + queryKey: [AgentApiAction.FetchAgentList], + }); + } return data?.data ?? 
[]; }, }); - return { data, loading }; + return { data, loading, deleteAgent: mutateAsync }; +}; + +export const useFetchAgent = (): { + data: IFlow; + loading: boolean; + refetch: () => void; +} => { + const { id } = useParams(); + const { sharedId } = useGetSharedChatSearchParams(); + + const { + data, + isFetching: loading, + refetch, + } = useQuery({ + queryKey: [AgentApiAction.FetchAgentDetail], + initialData: {} as IFlow, + refetchOnReconnect: false, + refetchOnMount: false, + refetchOnWindowFocus: false, + gcTime: 0, + queryFn: async () => { + const { data } = await flowService.getCanvas({}, sharedId || id); + + const messageList = buildMessageListWithUuid( + get(data, 'data.dsl.messages', []), + ); + set(data, 'data.dsl.messages', messageList); + + return data?.data ?? {}; + }, + }); + + return { data, loading, refetch }; +}; + +export const useResetAgent = () => { + const { id } = useParams(); + const { + data, + isPending: loading, + mutateAsync, + } = useMutation({ + mutationKey: [AgentApiAction.ResetAgent], + mutationFn: async () => { + const { data } = await flowService.resetCanvas({ id }); + return data; + }, + }); + + return { data, loading, resetAgent: mutateAsync }; +}; + +export const useSetAgent = () => { + const queryClient = useQueryClient(); + const { + data, + isPending: loading, + mutateAsync, + } = useMutation({ + mutationKey: [AgentApiAction.SetAgent], + mutationFn: async (params: { + id?: string; + title?: string; + dsl?: DSL; + avatar?: string; + }) => { + const { data = {} } = await flowService.setCanvas(params); + if (data.code === 0) { + message.success( + i18n.t(`message.${params?.id ? 'modified' : 'created'}`), + ); + queryClient.invalidateQueries({ + queryKey: [AgentApiAction.FetchAgentList], + }); + } + return data; + }, + }); + + return { data, loading, setAgent: mutateAsync }; }; diff --git a/web/src/hooks/use-chunk-request.ts b/web/src/hooks/use-chunk-request.ts new file mode 100644 index 00000000000..1cb80242386 --- /dev/null +++ b/web/src/hooks/use-chunk-request.ts @@ -0,0 +1,91 @@ +import { ResponseGetType } from '@/interfaces/database/base'; +import { IChunk, IKnowledgeFile } from '@/interfaces/database/knowledge'; +import kbService from '@/services/knowledge-service'; +import { useQuery } from '@tanstack/react-query'; +import { useDebounce } from 'ahooks'; +import { useCallback, useState } from 'react'; +import { IChunkListResult } from './chunk-hooks'; +import { + useGetPaginationWithRouter, + useHandleSearchChange, +} from './logic-hooks'; +import { useGetKnowledgeSearchParams } from './route-hook'; + +export const useFetchNextChunkList = (): ResponseGetType<{ + data: IChunk[]; + total: number; + documentInfo: IKnowledgeFile; +}> & + IChunkListResult => { + const { pagination, setPagination } = useGetPaginationWithRouter(); + const { documentId } = useGetKnowledgeSearchParams(); + const { searchString, handleInputChange } = useHandleSearchChange(); + const [available, setAvailable] = useState(); + const debouncedSearchString = useDebounce(searchString, { wait: 500 }); + + const { data, isFetching: loading } = useQuery({ + queryKey: [ + 'fetchChunkList', + documentId, + pagination.current, + pagination.pageSize, + debouncedSearchString, + available, + ], + placeholderData: (previousData: any) => + previousData ?? 
{ data: [], total: 0, documentInfo: {} }, // https://github.com/TanStack/query/issues/8183 + gcTime: 0, + queryFn: async () => { + const { data } = await kbService.chunk_list({ + doc_id: documentId, + page: pagination.current, + size: pagination.pageSize, + available_int: available, + keywords: searchString, + }); + if (data.code === 0) { + const res = data.data; + return { + data: res.chunks, + total: res.total, + documentInfo: res.doc, + }; + } + + return ( + data?.data ?? { + data: [], + total: 0, + documentInfo: {}, + } + ); + }, + }); + + const onInputChange: React.ChangeEventHandler = useCallback( + (e) => { + setPagination({ page: 1 }); + handleInputChange(e); + }, + [handleInputChange, setPagination], + ); + + const handleSetAvailable = useCallback( + (a: number | undefined) => { + setPagination({ page: 1 }); + setAvailable(a); + }, + [setAvailable, setPagination], + ); + + return { + data, + loading, + pagination, + setPagination, + searchString, + handleInputChange: onInputChange, + available, + handleSetAvailable, + }; +}; diff --git a/web/src/hooks/use-knowledge-request.ts b/web/src/hooks/use-knowledge-request.ts index f57c04a2549..101e8150989 100644 --- a/web/src/hooks/use-knowledge-request.ts +++ b/web/src/hooks/use-knowledge-request.ts @@ -228,11 +228,18 @@ export const useUpdateKnowledge = (shouldFetchList = false) => { return { data, loading, saveKnowledgeConfiguration: mutateAsync }; }; -export const useFetchKnowledgeBaseConfiguration = () => { +export const useFetchKnowledgeBaseConfiguration = (refreshCount?: number) => { const { id } = useParams(); + let queryKey: (KnowledgeApiAction | number)[] = [ + KnowledgeApiAction.FetchKnowledgeDetail, + ]; + if (typeof refreshCount === 'number') { + queryKey = [KnowledgeApiAction.FetchKnowledgeDetail, refreshCount]; + } + const { data, isFetching: loading } = useQuery({ - queryKey: [KnowledgeApiAction.FetchKnowledgeDetail], + queryKey, initialData: {} as IKnowledge, gcTime: 0, queryFn: async () => { diff --git a/web/src/hooks/use-send-message.ts b/web/src/hooks/use-send-message.ts new file mode 100644 index 00000000000..0253f29bfc2 --- /dev/null +++ b/web/src/hooks/use-send-message.ts @@ -0,0 +1,140 @@ +import { Authorization } from '@/constants/authorization'; +import api from '@/utils/api'; +import { getAuthorization } from '@/utils/authorization-util'; +import { EventSourceParserStream } from 'eventsource-parser/stream'; +import { useCallback, useRef, useState } from 'react'; + +export enum MessageEventType { + WorkflowStarted = 'workflow_started', + NodeStarted = 'node_started', + NodeFinished = 'node_finished', + Message = 'message', + MessageEnd = 'message_end', + WorkflowFinished = 'workflow_finished', +} + +export interface IAnswerEvent { + event: MessageEventType; + message_id: string; + created_at: number; + task_id: string; + data: T; +} + +export interface INodeData { + inputs: Record; + outputs: Record; + component_id: string; + error: null | string; + elapsed_time: number; + created_at: number; +} + +export interface IMessageData { + content: string; +} + +export type INodeEvent = IAnswerEvent; + +export type IMessageEvent = IAnswerEvent; + +export type IChatEvent = INodeEvent | IMessageEvent; + +export type IEventList = Array; + +export const useSendMessageBySSE = (url: string = api.completeConversation) => { + const [answerList, setAnswerList] = useState([]); + const [done, setDone] = useState(true); + const timer = useRef(); + const sseRef = useRef(); + + const initializeSseRef = useCallback(() => { + 
sseRef.current = new AbortController(); + }, []); + + const resetAnswerList = useCallback(() => { + if (timer.current) { + clearTimeout(timer.current); + } + timer.current = setTimeout(() => { + setAnswerList([]); + clearTimeout(timer.current); + }, 1000); + }, []); + + const send = useCallback( + async ( + body: any, + controller?: AbortController, + ): Promise<{ response: Response; data: ResponseType } | undefined> => { + initializeSseRef(); + try { + setDone(false); + const response = await fetch(url, { + method: 'POST', + headers: { + [Authorization]: getAuthorization(), + 'Content-Type': 'application/json', + }, + body: JSON.stringify(body), + signal: controller?.signal || sseRef.current?.signal, + }); + + const res = response.clone().json(); + + const reader = response?.body + ?.pipeThrough(new TextDecoderStream()) + .pipeThrough(new EventSourceParserStream()) + .getReader(); + + while (true) { + const x = await reader?.read(); + if (x) { + const { done, value } = x; + if (done) { + console.info('done'); + resetAnswerList(); + break; + } + try { + const val = JSON.parse(value?.data || ''); + + console.info('data:', val); + + setAnswerList((list) => { + const nextList = [...list]; + nextList.push(val); + return nextList; + }); + } catch (e) { + console.warn(e); + } + } + } + console.info('done?'); + setDone(true); + resetAnswerList(); + return { data: await res, response }; + } catch (e) { + setDone(true); + resetAnswerList(); + + console.warn(e); + } + }, + [initializeSseRef, url, resetAnswerList], + ); + + const stopOutputMessage = useCallback(() => { + sseRef.current?.abort(); + }, []); + + return { + send, + answerList, + done, + setDone, + resetAnswerList, + stopOutputMessage, + }; +}; diff --git a/web/src/interfaces/database/agent.ts b/web/src/interfaces/database/agent.ts new file mode 100644 index 00000000000..8fd95d83d7c --- /dev/null +++ b/web/src/interfaces/database/agent.ts @@ -0,0 +1,217 @@ +export interface ICategorizeItem { + name: string; + description?: string; + examples?: { value: string }[]; + index: number; +} + +export type ICategorizeItemResult = Record< + string, + Omit & { examples: string[] } +>; + +export interface ISwitchCondition { + items: ISwitchItem[]; + logical_operator: string; + to: string[]; +} + +export interface ISwitchItem { + cpn_id: string; + operator: string; + value: string; +} + +export interface ISwitchForm { + conditions: ISwitchCondition[]; + end_cpn_ids: string[]; + no: string; +} + +import { Edge, Node } from '@xyflow/react'; +import { IReference, Message } from './chat'; + +export type DSLComponents = Record; + +export interface DSL { + components: DSLComponents; + history: any[]; + path?: string[][]; + answer?: any[]; + graph?: IGraph; + messages: Message[]; + reference: IReference[]; + globals: Record; + retrieval: IReference[]; +} + +export interface IOperator { + obj: IOperatorNode; + downstream: string[]; + upstream: string[]; + parent_id?: string; +} + +export interface IOperatorNode { + component_name: string; + params: Record; +} + +export declare interface IFlow { + avatar?: string; + canvas_type: null; + create_date: string; + create_time: number; + description: null; + dsl: DSL; + id: string; + title: string; + update_date: string; + update_time: number; + user_id: string; + permission: string; + nickname: string; +} + +export interface IFlowTemplate { + avatar: string; + canvas_type: string; + create_date: string; + create_time: number; + description: string; + dsl: DSL; + id: string; + title: string; + update_date: string; 
+ update_time: number; +} + +export interface IGenerateForm { + max_tokens?: number; + temperature?: number; + top_p?: number; + presence_penalty?: number; + frequency_penalty?: number; + cite?: boolean; + prompt: number; + llm_id: string; + parameters: { key: string; component_id: string }; +} + +export interface ICategorizeForm extends IGenerateForm { + category_description: ICategorizeItemResult; +} + +export interface IRelevantForm extends IGenerateForm { + yes: string; + no: string; +} + +export interface ISwitchItem { + cpn_id: string; + operator: string; + value: string; +} + +export interface ISwitchForm { + conditions: ISwitchCondition[]; + end_cpn_id: string; + no: string; +} + +export interface IBeginForm { + prologue?: string; +} + +export interface IRetrievalForm { + similarity_threshold?: number; + keywords_similarity_weight?: number; + top_n?: number; + top_k?: number; + rerank_id?: string; + empty_response?: string; + kb_ids: string[]; +} + +export interface ICodeForm { + inputs?: Array<{ name?: string; component_id?: string }>; + lang: string; + script?: string; +} + +export interface IAgentForm { + sys_prompt: string; + prompts: Array<{ + role: string; + content: string; + }>; + max_retries: number; + delay_after_error: number; + visual_files_var: string; + max_rounds: number; + exception_method: Nullable<'comment' | 'go'>; + exception_comment: any; + exception_goto: any; + tools: Array<{ + component_name: string; + params: Record; + }>; + outputs: { + structured_output: Record>; + content: Record; + }; +} + +export type BaseNodeData = { + label: string; // operator type + name: string; // operator name + color?: string; + form?: TForm; +}; + +export type BaseNode = Node>; + +export type IBeginNode = BaseNode; +export type IRetrievalNode = BaseNode; +export type IGenerateNode = BaseNode; +export type ICategorizeNode = BaseNode; +export type ISwitchNode = BaseNode; +export type IRagNode = BaseNode; +export type IRelevantNode = BaseNode; +export type ILogicNode = BaseNode; +export type INoteNode = BaseNode; +export type IMessageNode = BaseNode; +export type IRewriteNode = BaseNode; +export type IInvokeNode = BaseNode; +export type ITemplateNode = BaseNode; +export type IEmailNode = BaseNode; +export type IIterationNode = BaseNode; +export type IIterationStartNode = BaseNode; +export type IKeywordNode = BaseNode; +export type ICodeNode = BaseNode; +export type IAgentNode = BaseNode; +export type IToolNode = BaseNode; + +export type RAGFlowNodeType = + | IBeginNode + | IRetrievalNode + | IGenerateNode + | ICategorizeNode + | ISwitchNode + | IRagNode + | IRelevantNode + | ILogicNode + | INoteNode + | IMessageNode + | IRewriteNode + | IInvokeNode + | ITemplateNode + | IEmailNode + | IIterationNode + | IIterationStartNode + | IKeywordNode; + +export interface IGraph { + nodes: RAGFlowNodeType[]; + edges: Edge[]; +} diff --git a/web/src/interfaces/database/flow.ts b/web/src/interfaces/database/flow.ts index 8d324c373a2..2d4aa4cbda2 100644 --- a/web/src/interfaces/database/flow.ts +++ b/web/src/interfaces/database/flow.ts @@ -11,6 +11,8 @@ export interface DSL { graph?: IGraph; messages: Message[]; reference: IReference[]; + globals: Record; + retrieval: IReference[]; } export interface IOperator { @@ -90,7 +92,7 @@ export interface IRelevantForm extends IGenerateForm { export interface ISwitchCondition { items: ISwitchItem[]; logical_operator: string; - to: string; + to: string[] | string; } export interface ISwitchItem { @@ -152,6 +154,7 @@ export type IIterationNode = BaseNode; 
export type IIterationStartNode = BaseNode; export type IKeywordNode = BaseNode; export type ICodeNode = BaseNode; +export type IAgentNode = BaseNode; export type RAGFlowNodeType = | IBeginNode diff --git a/web/src/interfaces/database/user-setting.ts b/web/src/interfaces/database/user-setting.ts index 24c6257cafe..ff4094d4b4a 100644 --- a/web/src/interfaces/database/user-setting.ts +++ b/web/src/interfaces/database/user-setting.ts @@ -16,6 +16,7 @@ export interface IUserInfo { nickname: string; password: string; status: string; + timezone: string; update_date: string; update_time: number; } diff --git a/web/src/interfaces/request/base.ts b/web/src/interfaces/request/base.ts index b780abe8dec..789be810462 100644 --- a/web/src/interfaces/request/base.ts +++ b/web/src/interfaces/request/base.ts @@ -1,7 +1,7 @@ export interface IPaginationRequestBody { keywords?: string; page?: number; - page_size?: number; // name|create|doc_num|create_time|update_time,default:create_time + page_size?: number; // name|create|doc_num|create_time|update_time, default:create_time orderby?: string; desc?: string; } diff --git a/web/src/locales/de.ts b/web/src/locales/de.ts index 858bb4b0c77..572c6ff340e 100644 --- a/web/src/locales/de.ts +++ b/web/src/locales/de.ts @@ -144,7 +144,7 @@ export default { toMessage: 'Endseitennummer fehlt (ausgeschlossen)', layoutRecognize: 'Dokumentenparser', layoutRecognizeTip: - 'Verwendet ein visuelles Modell für die PDF-Layout-Analyse, um Dokumententitel, Textblöcke, Bilder und Tabellen effektiv zu lokalisieren. Wenn die einfache Option gewählt wird, wird nur der reine Text im PDF abgerufen. Bitte beachten Sie, dass diese Option derzeit NUR für PDF-Dokumente funktioniert.', + 'Verwendet ein visuelles Modell für die PDF-Layout-Analyse, um Dokumententitel, Textblöcke, Bilder und Tabellen effektiv zu lokalisieren. Wenn die einfache Option gewählt wird, wird nur der reine Text im PDF abgerufen. Bitte beachten Sie, dass diese Option derzeit NUR für PDF-Dokumente funktioniert. Weitere Informationen finden Sie unter https://ragflow.io/docs/dev/select_pdf_parser.', taskPageSize: 'Aufgabenseitengröße', taskPageSizeMessage: 'Bitte geben Sie die Größe der Aufgabenseite ein!', taskPageSizeTip: @@ -176,10 +176,10 @@ export default { 'Verwenden Sie dies zusammen mit der General-Schnittmethode. Wenn deaktiviert, werden Tabellenkalkulationsdateien (XLSX, XLS (Excel 97-2003)) zeilenweise in Schlüssel-Wert-Paare analysiert. Wenn aktiviert, werden Tabellenkalkulationsdateien in HTML-Tabellen umgewandelt. Wenn die ursprüngliche Tabelle mehr als 12 Zeilen enthält, teilt das System sie automatisch alle 12 Zeilen in mehrere HTML-Tabellen auf. Für weitere Informationen siehe https://ragflow.io/docs/dev/enable_excel2html.', autoKeywords: 'Auto-Schlüsselwort', autoKeywordsTip: - 'Extrahieren Sie automatisch N Schlüsselwörter für jeden Abschnitt, um deren Ranking in Abfragen mit diesen Schlüsselwörtern zu verbessern. Beachten Sie, dass zusätzliche Tokens vom in den "Systemmodelleinstellungen" angegebenen Chat-Modell verbraucht werden. Sie können die hinzugefügten Schlüsselwörter eines Abschnitts in der Abschnittsliste überprüfen oder aktualisieren.', + 'Extrahieren Sie automatisch N Schlüsselwörter für jeden Abschnitt, um deren Ranking in Abfragen mit diesen Schlüsselwörtern zu verbessern. Beachten Sie, dass zusätzliche Tokens vom in den "Systemmodelleinstellungen" angegebenen Chat-Modell verbraucht werden. 
Sie können die hinzugefügten Schlüsselwörter eines Abschnitts in der Abschnittsliste überprüfen oder aktualisieren. Für weitere Informationen siehe https://ragflow.io/docs/dev/autokeyword_autoquestion.', autoQuestions: 'Auto-Frage', autoQuestionsTip: - 'Um die Ranking-Ergebnisse zu verbessern, extrahieren Sie N Fragen für jeden Wissensdatenbank-Chunk mithilfe des im "Systemmodell-Setup" definierten Chatmodells. Beachten Sie, dass dies zusätzliche Token verbraucht. Die Ergebnisse können in der Chunk-Liste eingesehen und bearbeitet werden. Fehler bei der Fragenextraktion blockieren den Chunking-Prozess nicht; leere Ergebnisse werden dem ursprünglichen Chunk hinzugefügt.', + 'Um die Ranking-Ergebnisse zu verbessern, extrahieren Sie N Fragen für jeden Wissensdatenbank-Chunk mithilfe des im "Systemmodell-Setup" definierten Chatmodells. Beachten Sie, dass dies zusätzliche Token verbraucht. Die Ergebnisse können in der Chunk-Liste eingesehen und bearbeitet werden. Fehler bei der Fragenextraktion blockieren den Chunking-Prozess nicht; leere Ergebnisse werden dem ursprünglichen Chunk hinzugefügt. Für weitere Informationen siehe https://ragflow.io/docs/dev/autokeyword_autoquestion.', redo: 'Möchten Sie die vorhandenen {{chunkNum}} Chunks löschen?', setMetaData: 'Metadaten festlegen', pleaseInputJson: 'Bitte JSON eingeben', @@ -241,7 +241,7 @@ export default { methodTitle: 'Beschreibung der Chunk-Methode', methodExamples: 'Beispiele', methodExamplesDescription: - 'Die folgenden Screenshots dienen zur Verdeutlichung.', + 'Um Ihnen das Verständnis zu erleichtern, haben wir relevante Screenshots als Referenz bereitgestellt.', dialogueExamplesTitle: 'Dialogbeispiele', methodEmpty: 'Hier wird eine visuelle Erklärung der Wissensdatenbank-Kategorien angezeigt', @@ -255,7 +255,7 @@ export default { manual: `

Nur PDF wird unterstützt.

Wir gehen davon aus, dass das Handbuch eine hierarchische Abschnittsstruktur aufweist und verwenden die Titel der untersten Abschnitte als Grundeinheit für die Aufteilung der Dokumente. Daher werden Abbildungen und Tabellen im selben Abschnitt nicht getrennt, was zu größeren Chunk-Größen führen kann.

`, - naive: `

Unterstützte Dateiformate sind DOCX, XLSX, XLS (Excel 97-2003), PPT, PDF, TXT, JPEG, JPG, PNG, TIF, GIF, CSV, JSON, EML, HTML.

+ naive: `

Unterstützte Dateiformate sind MD, MDX, DOCX, XLSX, XLS (Excel 97-2003), PPT, PDF, TXT, JPEG, JPG, PNG, TIF, GIF, CSV, JSON, EML, HTML.

Diese Methode teilt Dateien mit einer 'naiven' Methode auf:

  • Verwenden eines Erkennungsmodells, um die Texte in kleinere Segmente aufzuteilen.
  • @@ -565,6 +565,7 @@ export default { }, setting: { profile: 'Profil', + avatar: 'Avatar', profileDescription: 'Aktualisieren Sie hier Ihr Foto und Ihre persönlichen Daten.', maxTokens: 'Maximale Tokens', diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts index 08b086e0f22..63fa4516369 100644 --- a/web/src/locales/en.ts +++ b/web/src/locales/en.ts @@ -162,7 +162,7 @@ export default { cancel: 'Cancel', rerankModel: 'Rerank model', rerankPlaceholder: 'Please select', - rerankTip: `If left empty, RAGFlow will use a combination of weighted keyword similarity and weighted vector cosine similarity; if a rerank model is selected, a weighted reranking score will replace the weighted vector cosine similarity. Please be aware that using a rerank model will significantly increase the system's response time.`, + rerankTip: `Optional. If left empty, RAGFlow will use a combination of weighted keyword similarity and weighted vector cosine similarity; if a rerank model is selected, a weighted reranking score will replace the weighted vector cosine similarity. Please be aware that using a rerank model will significantly increase the system's response time. If you wish to use a rerank model, ensure you use a SaaS reranker; if you prefer a locally deployed rerank model, ensure you start RAGFlow with docker-compose-gpu.yml.`, topK: 'Top-K', topKTip: `Used together with the Rerank model, this setting defines the number of text chunks to be sent to the specified reranking model.`, delimiter: `Delimiter for text`, @@ -171,9 +171,9 @@ export default { html4excel: 'Excel to HTML', html4excelTip: `Use with the General chunking method. When disabled, spreadsheets (XLSX or XLS(Excel 97-2003)) in the knowledge base will be parsed into key-value pairs. When enabled, they will be parsed into HTML tables, splitting every 12 rows if the original table has more than 12 rows.`, autoKeywords: 'Auto-keyword', - autoKeywordsTip: `Automatically extract N keywords for each chunk to increase their ranking for queries containing those keywords. Be aware that extra tokens will be consumed by the chat model specified in 'System model settings'. You can check or update the added keywords for a chunk from the chunk list. `, + autoKeywordsTip: `Automatically extract N keywords for each chunk to increase their ranking for queries containing those keywords. Be aware that extra tokens will be consumed by the chat model specified in 'System model settings'. You can check or update the added keywords for a chunk from the chunk list. For details, see https://ragflow.io/docs/dev/autokeyword_autoquestion.`, autoQuestions: 'Auto-question', - autoQuestionsTip: `Automatically extract N questions for each chunk to increase their ranking for queries containing those questions. You can check or update the added questions for a chunk from the chunk list. This feature will not disrupt the chunking process if an error occurs, except that it may add an empty result to the original chunk. Be aware that extra tokens will be consumed by the LLM specified in 'System model settings'.`, + autoQuestionsTip: `Automatically extract N questions for each chunk to increase their ranking for queries containing those questions. You can check or update the added questions for a chunk from the chunk list. This feature will not disrupt the chunking process if an error occurs, except that it may add an empty result to the original chunk. Be aware that extra tokens will be consumed by the LLM specified in 'System model settings'. 
For details, see https://ragflow.io/docs/dev/autokeyword_autoquestion.`, redo: 'Do you want to clear the existing {{chunkNum}} chunks?', setMetaData: 'Set Meta Data', pleaseInputJson: 'Please enter JSON', @@ -236,7 +236,7 @@ export default { methodTitle: 'Chunking method description', methodExamples: 'Examples', methodExamplesDescription: - 'The following screenshots are provided for clarity.', + 'The following screenshots are provided for clarification.', dialogueExamplesTitle: 'view', methodEmpty: 'This will display a visual explanation of the knowledge base categories', @@ -250,7 +250,7 @@ export default { manual: `

    Only PDF is supported.

    We assume that the manual has a hierarchical section structure, using the lowest section titles as the basic unit for chunking documents. Therefore, figures and tables in the same section will not be separated, which may result in larger chunk sizes.

    `, - naive: `

    Supported file formats are DOCX, XLSX, XLS (Excel 97-2003), PPT, PDF, TXT, JPEG, JPG, PNG, TIF, GIF, CSV, JSON, EML, HTML.

    + naive: `

    Supported file formats are MD, MDX, DOCX, XLSX, XLS (Excel 97-2003), PPT, PDF, TXT, JPEG, JPG, PNG, TIF, GIF, CSV, JSON, EML, HTML.

    This method chunks files using a 'naive' method:

  • Use a vision detection model to split the texts into smaller segments.
  • @@ -455,7 +455,8 @@ This auto-tagging feature enhances retrieval by adding another layer of domain-s modelTip: 'Large language chat model', modelMessage: 'Please select!', modelEnabledTools: 'Enabled tools', - modelEnabledToolsTip: 'Please select one or more tools for the chat model to use. It takes no effect for models not supporting tool call.', + modelEnabledToolsTip: + 'Please select one or more tools for the chat model to use. It takes no effect for models not supporting tool call.', freedom: 'Freedom', improvise: 'Improvise', precise: 'Precise', @@ -550,6 +551,7 @@ This auto-tagging feature enhances retrieval by adding another layer of domain-s }, setting: { profile: 'Profile', + avatar: 'Avatar', profileDescription: 'Update your photo and personal details here.', maxTokens: 'Max Tokens', maxTokensMessage: 'Max Tokens is required', @@ -788,7 +790,9 @@ This auto-tagging feature enhances retrieval by adding another layer of domain-s examples: 'Examples', to: 'To', msg: 'Messages', - messagePlaceholder: 'message', + msgTip: + 'Output the variable content of the upstream component or the text entered by yourself.', + messagePlaceholder: `Please enter your message content, use '/' to quickly insert variables.`, messageMsg: 'Please input message or delete this field.', addField: 'Add option', addMessage: 'Add message', @@ -812,7 +816,7 @@ This auto-tagging feature enhances retrieval by adding another layer of domain-s relevantDescription: `A component that uses the LLM to assess whether the upstream output is relevant to the user's latest query. Ensure you specify the next component for each judge result.`, rewriteQuestionDescription: `A component that rewrites a user query from the Interact component, based on the context of previous dialogues.`, messageDescription: - "A component that sends out a static message. If multiple messages are supplied, it randomly selects one to send. Ensure its downstream is 'Interact', the interface component.", + 'This component returns the final data output of the workflow along with predefined message content. ', keywordDescription: `A component that retrieves top N search results from user's input. Ensure the TopN value is set properly before use.`, switchDescription: `A component that evaluates conditions based on the output of previous components and directs the flow of execution accordingly. It allows for complex branching logic by defining cases and specifying actions for each case or default action if no conditions are met.`, wikipediaDescription: `A component that searches from wikipedia.org, using TopN to specify the number of search results. It supplements the existing knowledge bases.`, @@ -1268,14 +1272,26 @@ This delimiter is used to split the input text into several text pieces echo of codeDescription: 'It allows developers to write custom Python logic.', inputVariables: 'Input variables', runningHintText: 'is running...🕞', + openingSwitch: 'Opening switch', + openingCopy: 'Opening copy', + openingSwitchTip: + 'Your users will see this welcome message at the beginning.', + modeTip: 'The mode defines how the workflow is initiated.', + beginInputTip: + 'By defining input parameters, this content can be accessed by other components in subsequent processes.', + query: 'Query variables', + agent: 'Agent', + agentDescription: + 'Builds agent components equipped with reasoning, tool usage, and multi-agent collaboration. 
', }, llmTools: { bad_calculator: { - name: "Calculator", - description: "A tool to calculate the sum of two numbers (will give wrong answer)", + name: 'Calculator', + description: + 'A tool to calculate the sum of two numbers (will give wrong answer)', params: { - a: "The first number", - b: "The second number", + a: 'The first number', + b: 'The second number', }, }, }, diff --git a/web/src/locales/es.ts b/web/src/locales/es.ts index 57a8a1f3ac3..274dbbbb11e 100644 --- a/web/src/locales/es.ts +++ b/web/src/locales/es.ts @@ -133,7 +133,7 @@ export default { toMessage: 'Falta el número de página final (excluido)', layoutRecognize: 'Reconocimiento de disposición', layoutRecognizeTip: - 'Usa modelos visuales para el análisis de disposición y así identificar mejor la estructura del documento, encontrar dónde están los títulos, bloques de texto, imágenes y tablas. Sin esta función, solo se obtendrá el texto plano del PDF.', + 'Usa modelos visuales para el análisis de disposición y así identificar mejor la estructura del documento, encontrar dónde están los títulos, bloques de texto, imágenes y tablas. Sin esta función, solo se obtendrá el texto plano del PDF. Para más información, consulte https://ragflow.io/docs/dev/select_pdf_parser.', taskPageSize: 'Tamaño de la tarea por página', taskPageSizeMessage: '¡Por favor ingresa el tamaño de la tarea por página!', @@ -151,7 +151,7 @@ export default { cancel: 'Cancelar', rerankModel: 'Modelo de reordenamiento', rerankPlaceholder: 'Por favor selecciona', - rerankTip: `Si está vacío, se utilizan los embeddings de la consulta y los fragmentos para calcular la similitud coseno del vector. De lo contrario, se usa la puntuación de reordenamiento en lugar de la similitud coseno del vector.`, + rerankTip: `Opcional. Si se deja vacío, RAGFlow utilizará una combinación de similitud ponderada de palabras clave y similitud ponderada del coseno vectorial; si se selecciona un modelo de reordenamiento, una puntuación ponderada de reordenamiento reemplazará la similitud ponderada del coseno vectorial. Tenga en cuenta que usar un modelo de reordenamiento aumentará significativamente el tiempo de respuesta del sistema. Si desea usar un modelo de reordenamiento, asegúrese de usar un reranker SaaS; si prefiere un modelo de reordenamiento desplegado localmente, asegúrese de iniciar RAGFlow con docker-compose-gpu.yml.`, topK: 'Top-K', topKTip: `Utilizado junto con el Rerank model, esta configuración define el número de fragmentos de texto que se enviarán al modelo reranking especificado.`, delimiter: `Delimitadores para segmentación de texto`, @@ -282,6 +282,7 @@ export default { }, setting: { profile: 'Perfil', + avatar: 'Avatar', profileDescription: 'Actualiza tu foto y tus datos personales aquí.', maxTokens: 'Máximo de tokens', maxTokensMessage: 'El máximo de tokens es obligatorio', diff --git a/web/src/locales/id.ts b/web/src/locales/id.ts index dbd89561bea..60cff62bd2e 100644 --- a/web/src/locales/id.ts +++ b/web/src/locales/id.ts @@ -138,7 +138,7 @@ export default { toMessage: 'Nomor halaman akhir hilang (tidak termasuk)', layoutRecognize: 'Pengenalan tata letak', layoutRecognizeTip: - 'Gunakan model visual untuk analisis tata letak untuk lebih mengidentifikasi struktur dokumen, menemukan di mana judul, blok teks, gambar, dan tabel berada. Tanpa fitur ini, hanya teks biasa dari PDF yang dapat diperoleh.', + 'Gunakan model visual untuk analisis tata letak untuk lebih mengidentifikasi struktur dokumen, menemukan di mana judul, blok teks, gambar, dan tabel berada. 
Tanpa fitur ini, hanya teks biasa dari PDF yang dapat diperoleh. Untuk informasi lebih lanjut, lihat https://ragflow.io/docs/dev/select_pdf_parser.', taskPageSize: 'Ukuran halaman tugas', taskPageSizeMessage: 'Silakan masukkan ukuran halaman tugas Anda!', taskPageSizeTip: `Jika menggunakan pengenalan tata letak, file PDF akan dibagi menjadi kelompok berturut-turut. Analisis tata letak akan dilakukan secara paralel antar kelompok untuk meningkatkan kecepatan pemrosesan. 'Ukuran halaman tugas' menentukan ukuran kelompok. Semakin besar ukuran halaman, semakin kecil kemungkinan teks berkelanjutan antara halaman dibagi menjadi potongan yang berbeda.`, @@ -155,7 +155,7 @@ export default { cancel: 'Batal', rerankModel: 'Model Rerank', rerankPlaceholder: 'Silakan pilih', - rerankTip: `Jika kosong. Ini menggunakan embedding dari kueri dan potongan untuk menghitung kesamaan kosinus vektor. Jika tidak, ini menggunakan skor rerank sebagai pengganti kesamaan kosinus vektor.`, + rerankTip: `Opsional. Jika dikosongkan, RAGFlow akan menggunakan kombinasi kesamaan kata kunci berbobot dan kesamaan kosinus vektor berbobot; jika model rerank dipilih, skor reranking berbobot akan menggantikan kesamaan kosinus vektor berbobot. Harap diperhatikan bahwa menggunakan model rerank akan secara signifikan meningkatkan waktu respons sistem. Jika Anda ingin menggunakan model rerank, pastikan menggunakan SaaS reranker; jika Anda lebih memilih model rerank yang dijalankan secara lokal, pastikan memulai RAGFlow dengan docker-compose-gpu.yml.`, topK: 'Top-K', topKTip: `Digunakan bersama dengan Rerank model, pengaturan ini menentukan jumlah potongan teks yang akan dikirim ke model reranking yang ditentukan.`, delimiter: `Pemisah untuk segmentasi teks`, @@ -195,7 +195,7 @@ export default { methodTitle: 'Deskripsi Metode Pemotongan', methodExamples: 'Contoh', methodExamplesDescription: - 'Cuplikan berikut disajikan untuk memudahkan pemahaman.', + 'Untuk membantu Anda memahami lebih baik, kami menyediakan tangkapan layar terkait sebagai referensi.', dialogueExamplesTitle: 'Contoh Dialog', methodEmpty: 'Ini akan menampilkan penjelasan visual dari kategori basis pengetahuan', @@ -211,7 +211,7 @@ export default { Kami mengasumsikan manual memiliki struktur bagian hierarkis. Kami menggunakan judul bagian terendah sebagai poros untuk memotong dokumen. Jadi, gambar dan tabel dalam bagian yang sama tidak akan dipisahkan, dan ukuran potongan mungkin besar.

    `, - naive: `

    Format file yang didukung adalah DOCX, XLSX, XLS (Excel 97-2003), PPT, PDF, TXT, JPEG, JPG, PNG, TIF, GIF, CSV, JSON, EML, HTML.

    + naive: `

    Format file yang didukung adalah MD, MDX, DOCX, XLSX, XLS (Excel 97-2003), PPT, PDF, TXT, JPEG, JPG, PNG, TIF, GIF, CSV, JSON, EML, HTML.

    Metode ini menerapkan cara naif untuk memotong file:

  • Teks berturut-turut akan dipotong menjadi potongan menggunakan model deteksi visual.
  • @@ -456,6 +456,7 @@ export default { }, setting: { profile: 'Profil', + avatar: 'Avatar', profileDescription: 'Perbarui foto dan detail pribadi Anda di sini.', maxTokens: 'Token Maksimum', maxTokensMessage: 'Token Maksimum diperlukan', diff --git a/web/src/locales/ja.ts b/web/src/locales/ja.ts index aa57438ecbf..0a044b74217 100644 --- a/web/src/locales/ja.ts +++ b/web/src/locales/ja.ts @@ -138,7 +138,7 @@ export default { toMessage: '終了ページ番号が不足しています(除外)', layoutRecognize: 'レイアウト認識', layoutRecognizeTip: - 'レイアウト分析のためにビジュアルモデルを使用し、文書の構造を理解しやすくします。', + 'レイアウト分析のためにビジュアルモデルを使用し、文書の構造を理解しやすくします。詳細については、https://ragflow.io/docs/dev/select_pdf_parser をご覧ください。', taskPageSize: 'タスクページサイズ', taskPageSizeMessage: 'タスクページサイズを入力してください', taskPageSizeTip: `レイアウト認識中、PDFファイルはチャンクに分割され、処理速度を向上させるために並列処理されます。`, @@ -156,7 +156,7 @@ export default { cancel: 'キャンセル', rerankModel: 'リランキングモデル', rerankPlaceholder: '選択してください', - rerankTip: `オプション:Rerankモデルを選択しない場合、システムはデフォルトでキーワードの類似度とベクトルのコサイン類似度を組み合わせたハイブリッド検索方式を採用します。Rerankモデルを設定した場合、ハイブリッド検索のベクトル類似度部分はrerankのスコアに置き換えられます。`, + rerankTip: `任意です。空欄の場合、RAGFlowは加重キーワード類似度と加重ベクトルコサイン類似度の組み合わせを使用します。リランキングモデルが選択された場合は、加重リランキングスコアが加重ベクトルコサイン類似度に代わります。リランキングモデルを使用すると、システムの応答時間が大幅に増加することにご注意ください。リランキングモデルを使用する場合は、SaaSリランカーを使用してください。ローカルにデプロイされたリランキングモデルを使用する場合は、docker-compose-gpu.ymlでRAGFlowを起動してください。`, topK: 'トップK', topKTip: `Rerank modelと一緒に使用する場合、この設定は指定されたreranking modelに送信するテキストのチャンク数を定義します。`, delimiter: `テキストセグメンテーションの区切り文字`, @@ -165,9 +165,9 @@ export default { html4excel: 'ExcelをHTMLに変換', html4excelTip: `General切片方法と併用してください。無効の場合、表計算ファイル(XLSX、XLS(Excel 97-2003))は行ごとにキーと値のペアとして解析されます。有効の場合、表計算ファイルはHTML表として解析されます。元の表が12行を超える場合、システムは自動的に12行ごとに複数のHTML表に分割します。詳細については、https://ragflow.io/docs/dev/enable_excel2html をご覧ください。`, autoKeywords: '自動キーワード', - autoKeywordsTip: `各チャンクに含まれるキーワードのランキングを向上させるために、自動的にN個のキーワードを抽出します。「システムモデル設定」で指定されたチャットモデルによって追加のトークンが消費されることに注意してください。チャンクリストから追加されたキーワードを確認または更新することができます。`, + autoKeywordsTip: `各チャンクに含まれるキーワードのランキングを向上させるために、自動的にN個のキーワードを抽出します。「システムモデル設定」で指定されたチャットモデルによって追加のトークンが消費されることに注意してください。チャンクリストから追加されたキーワードを確認または更新することができます。詳細は https://ragflow.io/docs/dev/autokeyword_autoquestion をご覧ください。`, autoQuestions: '自動質問', - autoQuestionsTip: `ランキングスコアを向上させるために、「システムモデル設定」で定義されたチャットモデルを使用して、ナレッジベースのチャンクごとにN個の質問を抽出します。 これにより、追加のトークンが消費されることに注意してください。 結果はチャンクリストで表示および編集できます。 質問抽出エラーはチャンク処理をブロックしません。空の結果が元のチャンクに追加されます。`, + autoQuestionsTip: `ランキングスコアを向上させるために、「システムモデル設定」で定義されたチャットモデルを使用して、ナレッジベースのチャンクごとにN個の質問を抽出します。 これにより、追加のトークンが消費されることに注意してください。 結果はチャンクリストで表示および編集できます。 質問抽出エラーはチャンク処理をブロックしません。空の結果が元のチャンクに追加されます。詳細は https://ragflow.io/docs/dev/autokeyword_autoquestion をご覧ください。`, }, knowledgeConfiguration: { titleDescription: @@ -202,7 +202,7 @@ export default { methodTitle: 'チャンク方法の説明', methodExamples: '例', methodExamplesDescription: - '以下のスクリーンショットは明確な説明のために提供されています。', + '理解を深めるために、関連するスクリーンショットを参考として提供しております。', dialogueExamplesTitle: '会話の例', methodEmpty: 'ナレッジベースカテゴリの視覚的説明がここに表示されます', book: `

    対応ファイル形式はDOCX, PDF, TXTです。

    @@ -215,7 +215,7 @@ export default { manual: `

    対応するのはPDFのみです。

    マニュアルは階層的なセクション構造を持つと仮定され、最下位のセクションタイトルを基にチャンク分割を行います。そのため、同じセクション内の図表は分割されませんが、大きなチャンクサイズになる可能性があります。

    `, - naive: `

    対応ファイル形式はDOCX, XLSX, XLS (Excel 97-2003), PPT, PDF, TXT, JPEG, JPG, PNG, TIF, GIF, CSV, JSON, EML, HTMLです。

    + naive: `

    対応ファイル形式はMD, MDX, DOCX, XLSX, XLS (Excel 97-2003), PPT, PDF, TXT, JPEG, JPG, PNG, TIF, GIF, CSV, JSON, EML, HTMLです。

    この方法では、'ナイーブ'な方法でファイルを分割します:

  • 視覚認識モデルを使用してテキストを小さなセグメントに分割します。
  • @@ -453,6 +453,7 @@ export default { }, setting: { profile: 'プロファイル', + avatar: 'アバター‌', profileDescription: 'ここで写真と個人情報を更新してください。', maxTokens: '最大トークン数', maxTokensMessage: '最大トークン数は必須です', diff --git a/web/src/locales/pt-br.ts b/web/src/locales/pt-br.ts index de20891e881..c0144d2ab2e 100644 --- a/web/src/locales/pt-br.ts +++ b/web/src/locales/pt-br.ts @@ -141,7 +141,7 @@ export default { toMessage: 'Página final ausente (excluída)', layoutRecognize: 'Reconhecimento de layout', layoutRecognizeTip: - 'Use modelos visuais para análise de layout para entender melhor a estrutura do documento e localizar efetivamente títulos, blocos de texto, imagens e tabelas. Se desativado, apenas o texto simples no PDF será recuperado.', + 'Use modelos visuais para análise de layout para entender melhor a estrutura do documento e localizar efetivamente títulos, blocos de texto, imagens e tabelas. Se desativado, apenas o texto simples no PDF será recuperado. Para mais informações, acesse https://ragflow.io/docs/dev/select_pdf_parser.', taskPageSize: 'Tamanho da página da tarefa', taskPageSizeMessage: 'Por favor, insira o tamanho da página da tarefa!', taskPageSizeTip: @@ -161,7 +161,7 @@ export default { rerankModel: 'Modelo de reranking', rerankPlaceholder: 'Por favor, selecione', rerankTip: - 'Se deixado vazio, o RAGFlow usará uma combinação de similaridade de palavras-chave ponderada e similaridade de cosseno vetorial ponderada; se um modelo de reranking for selecionado, uma pontuação de reranking ponderada substituirá a similaridade de cosseno vetorial ponderada. Esteja ciente de que usar um modelo de reranking aumentará significativamente o tempo de resposta do sistema.', + 'Opcional. Se deixar em branco, o RAGFlow usará uma combinação de similaridade ponderada por palavra-chave e similaridade ponderada do cosseno vetorial; se um modelo de rerank for selecionado, uma pontuação ponderada de reranking substituirá a similaridade ponderada do cosseno vetorial. Esteja ciente de que usar um modelo de rerank aumentará significativamente o tempo de resposta do sistema. Se desejar usar um modelo de rerank, certifique-se de usar um reranker SaaS; se preferir um modelo de rerank implantado localmente, certifique-se de iniciar o RAGFlow com docker-compose-gpu.yml.', topK: 'Top-K', topKTip: 'Usado em conjunto com o Rerank model, essa configuração define o número de trechos de texto a serem enviados ao modelo reranking especificado.', @@ -173,9 +173,9 @@ export default { 'Use em conjunto com o método de fragmentação General. Quando desativado, arquivos de planilhas (XLSX, XLS (Excel 97-2003)) serão analisados linha por linha como pares chave-valor. Quando ativado, os arquivos de planilhas serão convertidos em tabelas HTML. Se a tabela original tiver mais de 12 linhas, o sistema dividirá automaticamente em várias tabelas HTML a cada 12 linhas. Para mais informações, consulte https://ragflow.io/docs/dev/enable_excel2html.', autoKeywords: 'Palavras-chave automáticas', autoKeywordsTip: - 'Extraia automaticamente N palavras-chave de cada bloco para aumentar sua classificação em consultas que contenham essas palavras-chave. Esteja ciente de que o modelo de chat especificado nas "Configurações do modelo do sistema" consumirá tokens adicionais. Você pode verificar ou atualizar as palavras-chave adicionadas a um bloco na lista de blocos.', + 'Extraia automaticamente N palavras-chave de cada bloco para aumentar sua classificação em consultas que contenham essas palavras-chave. 
Esteja ciente de que o modelo de chat especificado nas "Configurações do modelo do sistema" consumirá tokens adicionais. Você pode verificar ou atualizar as palavras-chave adicionadas a um bloco na lista de blocos. Para mais detalhes, consulte https://ragflow.io/docs/dev/autokeyword_autoquestion.', autoQuestions: 'Perguntas automáticas', - autoQuestionsTip: `Para aumentar as pontuações de classificação, extraia N perguntas para cada bloco da base de conhecimento usando o modelo de bate-papo definido em "Configurações do Modelo do Sistema". Observe que isso consome tokens extras. Os resultados podem ser visualizados e editados na lista de blocos. Erros na extração de perguntas não bloquearão o processo de fragmentação; resultados vazios serão adicionados ao bloco original.`, + autoQuestionsTip: `Para aumentar as pontuações de classificação, extraia N perguntas para cada bloco da base de conhecimento usando o modelo de bate-papo definido em "Configurações do Modelo do Sistema". Observe que isso consome tokens extras. Os resultados podem ser visualizados e editados na lista de blocos. Erros na extração de perguntas não bloquearão o processo de fragmentação; resultados vazios serão adicionados ao bloco original. Para mais detalhes, consulte https://ragflow.io/docs/dev/autokeyword_autoquestion.`, redo: 'Deseja limpar os {{chunkNum}} fragmentos existentes?', setMetaData: 'Definir Metadados', pleaseInputJson: 'Por favor, insira um JSON', @@ -235,7 +235,7 @@ export default { methodTitle: 'Descrição do método de fragmentação', methodExamples: 'Exemplos', methodExamplesDescription: - 'As capturas de tela a seguir são fornecidas para maior clareza.', + 'Para ajudá-lo(a) a entender melhor, disponibilizamos capturas de tela relevantes para referência.', dialogueExamplesTitle: 'Exemplos de diálogos', methodEmpty: 'Aqui será exibida uma explicação visual das categorias da base de conhecimento', @@ -246,7 +246,7 @@ export default { Os fragmentos terão granularidade compatível com 'ARTIGO', garantindo que todo o texto de nível superior seja incluído no fragmento.

    `, manual: `

    Apenas PDF é suportado.

    Assumimos que o manual tem uma estrutura hierárquica de seções, usando os títulos das seções inferiores como unidade básica para fragmentação. Assim, figuras e tabelas na mesma seção não serão separadas, o que pode resultar em fragmentos maiores.

    `, - naive: `

    Os formatos de arquivo suportados são DOCX, XLSX, XLS (Excel 97-2003), PPT, PDF, TXT, JPEG, JPG, PNG, TIF, GIF, CSV, JSON, EML, HTML.

    + naive: `

    Os formatos de arquivo suportados são MD, MDX, DOCX, XLSX, XLS (Excel 97-2003), PPT, PDF, TXT, JPEG, JPG, PNG, TIF, GIF, CSV, JSON, EML, HTML.

    Este método fragmenta arquivos de maneira 'simples':

  • Usa um modelo de detecção visual para dividir os textos em segmentos menores.
  • @@ -451,6 +451,7 @@ export default { }, setting: { profile: 'Perfil', + avatar: 'Avatar', profileDescription: 'Atualize sua foto e detalhes pessoais aqui.', maxTokens: 'Máximo de Tokens', maxTokensMessage: 'Máximo de Tokens é obrigatório', diff --git a/web/src/locales/vi.ts b/web/src/locales/vi.ts index 29f9aabc91b..1b643bab346 100644 --- a/web/src/locales/vi.ts +++ b/web/src/locales/vi.ts @@ -144,7 +144,7 @@ export default { toMessage: 'Thiếu số trang kết thúc (được loại trừ)', layoutRecognize: 'Nhận dạng bố cục', layoutRecognizeTip: - 'Sử dụng các mô hình trực quan để phân tích bố cục nhằm xác định tốt hơn cấu trúc tài liệu, tìm vị trí của tiêu đề, khối văn bản, hình ảnh và bảng. Nếu không có tính năng này, chỉ có thể lấy được văn bản thuần của PDF.', + 'Sử dụng các mô hình trực quan để phân tích bố cục nhằm xác định tốt hơn cấu trúc tài liệu, tìm vị trí của tiêu đề, khối văn bản, hình ảnh và bảng. Nếu không có tính năng này, chỉ có thể lấy được văn bản thuần của PDF. Để biết thêm thông tin, hãy xem https://ragflow.io/docs/dev/select_pdf_parser.', taskPageSize: 'Kích thước trang tác vụ', taskPageSizeMessage: 'Vui lòng nhập kích thước trang tác vụ của bạn!', taskPageSizeTip: `Nếu sử dụng nhận dạng bố cục, tệp PDF sẽ được chia thành các nhóm trang liên tiếp. Phân tích bố cục sẽ được thực hiện song song giữa các nhóm để tăng tốc độ xử lý. 'Kích thước trang tác vụ' xác định kích thước của các nhóm. Kích thước trang càng lớn, khả năng chia tách văn bản liên tục giữa các trang thành các khối khác nhau càng thấp.`, @@ -161,16 +161,16 @@ export default { cancel: 'Hủy bỏ', rerankModel: 'Mô hình xếp hạng lại', rerankPlaceholder: 'Vui lòng chọn', - rerankTip: `Nếu để trống, RAGFlow sẽ sử dụng kết hợp giữa độ tương đồng từ khóa được trọng số và độ tương đồng vectơ cosin được trọng số; nếu chọn mô hình xếp hạng lại, điểm xếp hạng được tính lại sẽ thay thế độ tương đồng vectơ cosin được trọng số.`, + rerankTip: `Tùy chọn. Nếu để trống, RAGFlow sẽ sử dụng kết hợp giữa độ tương đồng từ khóa có trọng số và độ tương đồng cosine vector có trọng số; nếu chọn mô hình rerank, điểm rerank có trọng số sẽ thay thế độ tương đồng cosine vector có trọng số. Xin lưu ý rằng việc sử dụng mô hình rerank sẽ làm tăng đáng kể thời gian phản hồi của hệ thống. Nếu bạn muốn sử dụng mô hình rerank, hãy đảm bảo sử dụng SaaS reranker; nếu bạn muốn sử dụng mô hình rerank triển khai cục bộ, hãy khởi động RAGFlow bằng docker-compose-gpu.yml.`, topK: 'Top-K', topKTip: `Sử dụng cùng với Rerank model, thiết lập này xác định số lượng đoạn văn cần gửi đến mô hình reranking được chỉ định.`, delimiter: 'Dấu phân cách cho phân đoạn văn bản', html4excel: 'Excel sang HTML', html4excelTip: `Sử dụng cùng với phương pháp cắt khúc General. Khi chưa được bật, tệp bảng tính (XLSX, XLS (Excel 97-2003)) sẽ được phân tích theo dòng thành các cặp khóa-giá trị. Khi bật, tệp bảng tính sẽ được phân tích thành bảng HTML. Nếu bảng gốc vượt quá 12 dòng, hệ thống sẽ tự động chia thành nhiều bảng HTML mỗi 12 dòng. Để biết thêm thông tin, vui lòng xem https://ragflow.io/docs/dev/enable_excel2html.`, autoKeywords: 'Từ khóa tự động', - autoKeywordsTip: `Tự động trích xuất N từ khóa cho mỗi khối để tăng thứ hạng của chúng trong các truy vấn chứa các từ khóa đó. Lưu ý rằng các token bổ sung sẽ được tiêu thụ bởi mô hình trò chuyện được chỉ định trong "Cài đặt mô hình hệ thống". 
Bạn có thể kiểm tra hoặc cập nhật các từ khóa đã thêm cho một khối từ danh sách khối.`, + autoKeywordsTip: `Tự động trích xuất N từ khóa cho mỗi khối để tăng thứ hạng của chúng trong các truy vấn chứa các từ khóa đó. Lưu ý rằng các token bổ sung sẽ được tiêu thụ bởi mô hình trò chuyện được chỉ định trong "Cài đặt mô hình hệ thống". Bạn có thể kiểm tra hoặc cập nhật các từ khóa đã thêm cho một khối từ danh sách khối. Để biết chi tiết, vui lòng xem https://ragflow.io/docs/dev/autokeyword_autoquestion.`, autoQuestions: 'Câu hỏi tự động', - autoQuestionsTip: `Để tăng điểm xếp hạng, hãy trích xuất N câu hỏi cho mỗi đoạn kiến thức bằng mô hình trò chuyện được xác định trong "Cài đặt mô hình hệ thống". Lưu ý rằng việc này sẽ tiêu tốn thêm token. Kết quả có thể được xem và chỉnh sửa trong danh sách các đoạn. Lỗi trích xuất câu hỏi sẽ không chặn quá trình phân đoạn; kết quả trống sẽ được thêm vào đoạn gốc.`, + autoQuestionsTip: `Để tăng điểm xếp hạng, hãy trích xuất N câu hỏi cho mỗi đoạn kiến thức bằng mô hình trò chuyện được xác định trong "Cài đặt mô hình hệ thống". Lưu ý rằng việc này sẽ tiêu tốn thêm token. Kết quả có thể được xem và chỉnh sửa trong danh sách các đoạn. Lỗi trích xuất câu hỏi sẽ không chặn quá trình phân đoạn; kết quả trống sẽ được thêm vào đoạn gốc. Để biết chi tiết, vui lòng xem https://ragflow.io/docs/dev/autokeyword_autoquestion.`, delimiterTip: `Hỗ trợ nhiều ký tự phân cách, và các ký tự phân cách nhiều ký tự được bao bọc bởi dấu . Ví dụ: nếu được cấu hình như thế này: "##"; thì văn bản sẽ được phân tách bởi dấu xuống dòng, hai dấu # và dấu chấm phẩy, sau đó được lắp ráp theo kích thước của "số token". Thiết lập các dấu phân cách chỉ sau khi hiểu cơ chế phân đoạn và phân khối văn bản.`, redo: `Bạn có muốn xóa các đoạn {{chunkNum}} hiện có không?`, knowledgeGraph: 'Đồ thị tri thức', @@ -214,7 +214,7 @@ export default { methodTitle: 'Mô tả phương thức phân khối', methodExamples: 'Ví dụ', methodExamplesDescription: - 'Các ảnh chụp màn hình sau được cung cấp để minh họa.', + 'Để giúp bạn hiểu rõ hơn, chúng tôi đã cung cấp ảnh chụp màn hình liên quan để tham khảo.', dialogueExamplesTitle: 'Ví dụ hội thoại', methodEmpty: 'Mô tả bằng hình ảnh các danh mục cơ sở kiến thức', book: `

    Các định dạng tệp được hỗ trợ là DOCX, PDF, TXT.

    @@ -231,7 +231,7 @@ export default {

  • Sử dụng mô hình nhận dạng thị giác để chia các văn bản thành các phân đoạn nhỏ hơn.
  • Sau đó, kết hợp các phân đoạn liền kề cho đến khi số lượng token vượt quá ngưỡng được chỉ định bởi 'Số token khối', tại thời điểm đó, một khối được tạo.
  • -

    Các định dạng tệp được hỗ trợ là DOCX, XLSX, XLS (Excel 97-2003), PPT, PDF, TXT, JPEG, JPG, PNG, TIF, GIF, CSV, JSON, EML, HTML.

    `, +

    Các định dạng tệp được hỗ trợ là MD, MDX, DOCX, XLSX, XLS (Excel 97-2003), PPT, PDF, TXT, JPEG, JPG, PNG, TIF, GIF, CSV, JSON, EML, HTML.

    `, paper: `

    Chỉ hỗ trợ tệp PDF.

    Bài báo sẽ được chia theo các phần, chẳng hạn như tóm tắt, 1.1, 1.2.

    Cách tiếp cận này cho phép LLM tóm tắt bài báo hiệu quả hơn và cung cấp các phản hồi toàn diện, dễ hiểu hơn. @@ -505,6 +505,7 @@ export default { }, setting: { profile: 'Hồ sơ', + avatar: 'Avatar', profileDescription: 'Cập nhật ảnh và thông tin cá nhân của bạn tại đây.', maxTokens: 'Token tối đa', maxTokensMessage: 'Token tối đa là bắt buộc', diff --git a/web/src/locales/zh-traditional.ts b/web/src/locales/zh-traditional.ts index 219386d2983..923dc1c0e25 100644 --- a/web/src/locales/zh-traditional.ts +++ b/web/src/locales/zh-traditional.ts @@ -143,7 +143,7 @@ export default { toMessage: '缺少結束頁碼(不包含)', layoutRecognize: 'PDF解析器', layoutRecognizeTip: - '使用視覺模型進行 PDF 布局分析,以更好地識別文檔結構,找到標題、文字塊、圖像和表格的位置。若選擇 Naive 選項,則只能取得 PDF 的純文字。請注意此功能僅適用於 PDF 文檔,對其他文檔不生效。', + '使用視覺模型進行 PDF 布局分析,以更好地識別文檔結構,找到標題、文字塊、圖像和表格的位置。若選擇 Naive 選項,則只能取得 PDF 的純文字。請注意此功能僅適用於 PDF 文檔,對其他文檔不生效。如需更多資訊,請參閱 https://ragflow.io/docs/dev/select_pdf_parser。', taskPageSize: '任務頁面大小', taskPageSizeMessage: '請輸入您的任務頁面大小!', taskPageSizeTip: `如果使用佈局識別,PDF 文件將被分成連續的組。佈局分析將在組之間並行執行,以提高處理速度。“任務頁面大小”決定組的大小。頁面大小越大,將頁面之間的連續文本分割成不同塊的機會就越低。`, @@ -160,7 +160,7 @@ export default { cancel: '取消', rerankModel: 'rerank模型', rerankPlaceholder: '請選擇', - rerankTip: `如果是空的。它使用查詢和塊的嵌入來構成矢量餘弦相似性。否則,它使用rerank評分代替矢量餘弦相似性。`, + rerankTip: `非必選項:若不選擇 rerank 模型,系統將默認採用關鍵詞相似度與向量餘弦相似度相結合的混合查詢方式;如果設定了 rerank 模型,則混合查詢中的向量相似度部分將被 rerank 打分替代。請注意:採用 rerank 模型會非常耗時。如需選用 rerank 模型,建議使用 SaaS 的 rerank 模型服務;如果你傾向使用本地部署的 rerank 模型,請務必確保你使用 docker-compose-gpu.yml 啟動 RAGFlow。`, topK: 'Top-K', topKTip: `與 Rerank 模型配合使用,用於設定傳給 Rerank 模型的文本塊數量。`, delimiter: `文字分段標識符`, @@ -169,9 +169,9 @@ export default { html4excel: '表格轉HTML', html4excelTip: `與 General 切片方法配合使用。未開啟狀態下,表格檔案(XLSX、XLS(Excel 97-2003)會按行解析為鍵值對。開啟後,表格檔案會被解析為 HTML 表格。若原始表格超過 12 行,系統會自動按每 12 行拆分為多個 HTML 表格。欲了解更多資訊,請參閱 https://ragflow.io/docs/dev/enable_excel2html。`, autoKeywords: '自動關鍵字', - autoKeywordsTip: `自動為每個文字區塊中提取 N 個關鍵詞,以提升查詢精度。請注意:此功能採用「系統模型設定」中設定的預設聊天模型提取關鍵詞,因此也會產生更多 Token 消耗。此外,你也可以手動更新生成的關鍵詞。`, + autoKeywordsTip: `自動為每個文字區塊中提取 N 個關鍵詞,以提升查詢精度。請注意:此功能採用「系統模型設定」中設定的預設聊天模型提取關鍵詞,因此也會產生更多 Token 消耗。此外,你也可以手動更新生成的關鍵詞。詳情請參見 https://ragflow.io/docs/dev/autokeyword_autoquestion。`, autoQuestions: '自動問題', - autoQuestionsTip: `為了提高排名分數,請使用「系統模型設定」中定義的聊天模型,為每個知識庫區塊提取 N 個問題。 請注意:這會消耗額外的 token。 結果可在區塊列表中查看和編輯。 問題提取錯誤不會阻止分塊過程; 空結果將被添加到原始區塊。 `, + autoQuestionsTip: `為了提高排名分數,請使用「系統模型設定」中定義的聊天模型,為每個知識庫區塊提取 N 個問題。 請注意:這會消耗額外的 token。 結果可在區塊列表中查看和編輯。 問題提取錯誤不會阻止分塊過程; 空結果將被添加到原始區塊。詳情請參見 https://ragflow.io/docs/dev/autokeyword_autoquestion。 `, redo: '是否清空已有 {{chunkNum}}個 chunk?', setMetaData: '設定元數據', pleaseInputJson: '請輸入JSON', @@ -231,7 +231,7 @@ export default { cancel: '取消', methodTitle: '分塊方法說明', methodExamples: '示例', - methodExamplesDescription: '提出以下屏幕截圖以促進理解。', + methodExamplesDescription: '為方便您理解,我們附上相關截圖供您參考。', dialogueExamplesTitle: '對話示例', methodEmpty: '這將顯示知識庫類別的可視化解釋', book: `

    支持的文件格式為DOCX、PDF、TXT。

    @@ -246,7 +246,7 @@ export default { 我們假設手冊具有分層部分結構。我們使用最低的部分標題作為對文檔進行切片的樞軸。 因此,同一部分中的圖和表不會被分割,並且塊大小可能會很大。

    `, - naive: `

    支持的文件格式為DOCX、XLSX、XLS (Excel 97-2003)、PPT、PDF、TXT、JPEG、JPG、PNG、TIF、GIF、CSV、JSON、EML、HTML

    + naive: `

    支持的文件格式為MD、MDX、DOCX、XLSX、XLS (Excel 97-2003)、PPT、PDF、TXT、JPEG、JPG、PNG、TIF、GIF、CSV、JSON、EML、HTML

    此方法將簡單的方法應用於塊文件:

  • 系統將使用視覺檢測模型將連續文本分割成多個片段。
  • @@ -534,6 +534,7 @@ export default { }, setting: { profile: '概述', + avatar: '头像', profileDescription: '在此更新您的照片和個人詳細信息。', maxTokens: '最大token數', maxTokensMessage: '最大token數是必填項', @@ -761,7 +762,8 @@ export default { examples: '範例', to: '下一步', msg: '訊息', - messagePlaceholder: '訊息', + msgTip: '輸出上游組件的變數內容或自行輸入的文字。', + messagePlaceholder: '請輸入您的訊息內容,使用‘/’快速插入變數。', messageMsg: '請輸入訊息或刪除此欄位。', addField: '新增字段', addMessage: '新增訊息', @@ -786,7 +788,7 @@ export default { relevantDescription: `此元件用來判斷upstream的輸出是否與使用者最新的問題相關,『是』代表相關,『否』代表不相關。`, rewriteQuestionDescription: `此元件用於細化使用者的提問。通常,當使用者的原始提問無法從知識庫中檢索相關資訊時,此元件可協助您將問題變更為更符合知識庫表達方式的適當問題。`, messageDescription: - '此元件用於向使用者發送靜態訊息。您可以準備幾條訊息,這些訊息將隨機選擇。', + '此元件用來傳回工作流程最後產生的資料內容和原先設定的文字內容。', keywordDescription: `該組件用於從用戶的問題中提取關鍵字。 Top N指定需要提取的關鍵字數量。`, switchDescription: `該組件用於根據前面組件的輸出評估條件,並相應地引導執行流程。通過定義各種情況並指定操作,或在不滿足條件時採取默認操作,實現複雜的分支邏輯。`, wikipediaDescription: `此元件用於從 https://www.wikipedia.org/ 取得搜尋結果。通常,它作為知識庫的補充。 Top N 指定您需要調整的搜尋結果數。`, @@ -1162,6 +1164,9 @@ export default { codeDescription: '它允許開發人員編寫自訂 Python 邏輯。', inputVariables: '輸入變數', runningHintText: '正在運行...🕞', + openingSwitchTip: '您的用戶將在開始時看到此歡迎訊息。', + modeTip: '模式定義工作流程如何啟動。 ', + beginInputTip: `透過定義輸入參數,這些內容可以在後續流程中被其他元件存取。`, }, footer: { profile: '“保留所有權利 @ react”', diff --git a/web/src/locales/zh.ts b/web/src/locales/zh.ts index 67c7472138f..2bbb025c154 100644 --- a/web/src/locales/zh.ts +++ b/web/src/locales/zh.ts @@ -143,7 +143,7 @@ export default { toMessage: '缺少结束页码(不包含)', layoutRecognize: 'PDF解析器', layoutRecognizeTip: - '使用视觉模型进行 PDF 布局分析,以更好地识别文档结构,找到标题、文本块、图像和表格的位置。 如果选择 Naive 选项,则只能获取 PDF 的纯文本。请注意该功能只适用于 PDF 文档,对其他文档不生效。', + '使用视觉模型进行 PDF 布局分析,以更好地识别文档结构,找到标题、文本块、图像和表格的位置。 如果选择 Naive 选项,则只能获取 PDF 的纯文本。请注意该功能只适用于 PDF 文档,对其他文档不生效。欲了解更多信息,请参阅 https://ragflow.io/docs/dev/select_pdf_parser。', taskPageSize: '任务页面大小', taskPageSizeMessage: '请输入您的任务页面大小!', taskPageSizeTip: `如果使用布局识别,PDF 文件将被分成连续的组。 布局分析将在组之间并行执行,以提高处理速度。 “任务页面大小”决定组的大小。 页面大小越大,将页面之间的连续文本分割成不同块的机会就越低。`, @@ -160,7 +160,7 @@ export default { cancel: '取消', rerankModel: 'Rerank模型', rerankPlaceholder: '请选择', - rerankTip: `非必选项:若不选择 rerank 模型,系统将默认采用关键词相似度与向量余弦相似度相结合的混合查询方式;如果设置了 rerank 模型,则混合查询中的向量相似度部分将被 rerank 打分替代。请注意:采用 rerank 模型会非常耗时。`, + rerankTip: `非必选项:若不选择 rerank 模型,系统将默认采用关键词相似度与向量余弦相似度相结合的混合查询方式;如果设置了 rerank 模型,则混合查询中的向量相似度部分将被 rerank 打分替代。请注意:采用 rerank 模型会非常耗时。如需选用 rerank 模型,建议使用 SaaS 的 rerank 模型服务;如果你倾向使用本地部署的 rerank 模型,请务必确保你使用 docker-compose-gpu.yml 启动 RAGFlow。`, topK: 'Top-K', topKTip: `与 Rerank 模型配合使用,用于设置传给 Rerank 模型的文本块数量。`, delimiter: `文本分段标识符`, @@ -169,9 +169,9 @@ export default { html4excel: '表格转HTML', html4excelTip: `与 General 切片方法配合使用。未开启状态下,表格文件(XLSX、XLS(Excel 97-2003))会按行解析为键值对。开启后,表格文件会被解析为 HTML 表格。若原始表格超过 12 行,系统会自动按每 12 行拆分为多个 HTML 表格。欲了解更多详情,请参阅 https://ragflow.io/docs/dev/enable_excel2html。`, autoKeywords: '自动关键词提取', - autoKeywordsTip: `自动为每个文本块中提取 N 个关键词,用以提升查询精度。请注意:该功能采用“系统模型设置”中设置的默认聊天模型提取关键词,因此也会产生更多 Token 消耗。另外,你也可以手动更新生成的关键词。`, + autoKeywordsTip: `自动为每个文本块中提取 N 个关键词,用以提升查询精度。请注意:该功能采用“系统模型设置”中设置的默认聊天模型提取关键词,因此也会产生更多 Token 消耗。另外,你也可以手动更新生成的关键词。详情请见 https://ragflow.io/docs/dev/autokeyword_autoquestion。`, autoQuestions: '自动问题提取', - autoQuestionsTip: `利用“系统模型设置”中设置的 chat model 对知识库的每个文本块提取 N 个问题以提高其排名得分。请注意,开启后将消耗额外的 token。您可以在块列表中查看、编辑结果。如果自动问题提取发生错误,不会妨碍整个分块过程,只会将空结果添加到原始文本块。`, + autoQuestionsTip: `利用“系统模型设置”中设置的 chat model 对知识库的每个文本块提取 N 个问题以提高其排名得分。请注意,开启后将消耗额外的 token。您可以在块列表中查看、编辑结果。如果自动问题提取发生错误,不会妨碍整个分块过程,只会将空结果添加到原始文本块。详情请见 https://ragflow.io/docs/dev/autokeyword_autoquestion。`, 
redo: '是否清空已有 {{chunkNum}}个 chunk?', setMetaData: '设置元数据', pleaseInputJson: '请输入JSON', @@ -232,7 +232,8 @@ export default { cancel: '取消', methodTitle: '分块方法说明', methodExamples: '示例', - methodExamplesDescription: '提出以下屏幕截图以促进理解。', + methodExamplesDescription: + '为帮助您更好地理解,我们提供了相关截图供您参考。', dialogueExamplesTitle: '对话示例', methodEmpty: '这将显示知识库类别的可视化解释', book: `

    支持的文件格式为DOCX、PDF、TXT。

    @@ -247,7 +248,7 @@ export default { 我们假设手册具有分层部分结构。 我们使用最低的部分标题作为对文档进行切片的枢轴。 因此,同一部分中的图和表不会被分割,并且块大小可能会很大。

    `, - naive: `

    支持的文件格式为DOCX、XLSX、XLS (Excel 97-2003)、PPT、PDF、TXT、JPEG、JPG、PNG、TIF、GIF、CSV、JSON、EML、HTML

    + naive: `

    支持的文件格式为MD、MDX、DOCX、XLSX、XLS (Excel 97-2003)、PPT、PDF、TXT、JPEG、JPG、PNG、TIF、GIF、CSV、JSON、EML、HTML

    此方法将简单的方法应用于块文件:

  • 系统将使用视觉检测模型将连续文本分割成多个片段。
  • @@ -462,7 +463,8 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于 modelTip: '大语言聊天模型', modelMessage: '请选择', modelEnabledTools: '可用的工具', - modelEnabledToolsTip: '请选择一个或多个可供该模型所使用的工具。仅对支持工具调用的模型生效。', + modelEnabledToolsTip: + '请选择一个或多个可供该模型所使用的工具。仅对支持工具调用的模型生效。', freedom: '自由度', improvise: '即兴创作', precise: '精确', @@ -553,6 +555,7 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于 }, setting: { profile: '概要', + avatar: '头像', profileDescription: '在此更新您的照片和个人详细信息。', maxTokens: '最大token数', maxTokensMessage: '最大token数是必填项', @@ -788,7 +791,8 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于 examples: '示例', to: '下一步', msg: '消息', - messagePlaceholder: '消息', + msgTip: '输出上游组件的变量内容或者自己输入的文本。', + messagePlaceholder: '请输入您的消息内容,使用‘/’快速插入变量。', messageMsg: '请输入消息或删除此字段。', addField: '新增字段', addMessage: '新增消息', @@ -812,7 +816,7 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于 relevantDescription: `该组件用来判断upstream的输出是否与用户最新的问题相关,‘是’代表相关,‘否’代表不相关。`, rewriteQuestionDescription: `此组件用于细化用户的提问。通常,当用户的原始提问无法从知识库中检索到相关信息时,此组件可帮助您将问题更改为更符合知识库表达方式的适当问题。`, messageDescription: - '此组件用于向用户发送静态信息。您可以准备几条消息,这些消息将被随机选择。', + '该组件用来返回工作流最后产生的数据内容和原先设置的文本内容。', keywordDescription: `该组件用于从用户的问题中提取关键词。Top N指定需要提取的关键词数量。`, switchDescription: `该组件用于根据前面组件的输出评估条件,并相应地引导执行流程。通过定义各种情况并指定操作,或在不满足条件时采取默认操作,实现复杂的分支逻辑。`, wikipediaDescription: `此组件用于从 https://www.wikipedia.org/ 获取搜索结果。通常,它作为知识库的补充。Top N 指定您需要调整的搜索结果数量。`, @@ -1224,6 +1228,14 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于 inputVariables: '输入变量', addVariable: '新增变量', runningHintText: '正在运行中...🕞', + openingSwitch: '开场白开关', + openingCopy: '开场白文案', + openingSwitchTip: '您的用户将在开始时看到此欢迎消息。', + modeTip: '模式定义了工作流的启动方式。', + beginInputTip: '通过定义输入参数,此内容可以被后续流程中的其他组件访问。', + query: '查询变量', + agent: 'Agent', + agentDescription: '构建具备推理、工具调用和多智能体协同的智能体组件。', }, footer: { profile: 'All rights reserved @ React', @@ -1235,11 +1247,11 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于 }, llmTools: { bad_calculator: { - name: "计算器", - description: "用于计算两个数的和的工具(会给出错误答案)", + name: '计算器', + description: '用于计算两个数的和的工具(会给出错误答案)', params: { - a: "第一个数", - b: "第二个数", + a: '第一个数', + b: '第二个数', }, }, }, diff --git a/web/src/pages/add-knowledge/components/knowledge-testing/index.tsx b/web/src/pages/add-knowledge/components/knowledge-testing/index.tsx index e5a22f13597..d7140c92618 100644 --- a/web/src/pages/add-knowledge/components/knowledge-testing/index.tsx +++ b/web/src/pages/add-knowledge/components/knowledge-testing/index.tsx @@ -1,13 +1,19 @@ -import { useTestChunkRetrieval } from '@/hooks/knowledge-hooks'; +import { + useTestChunkAllRetrieval, + useTestChunkRetrieval, +} from '@/hooks/knowledge-hooks'; import { Flex, Form } from 'antd'; import TestingControl from './testing-control'; import TestingResult from './testing-result'; +import { useState } from 'react'; import styles from './index.less'; const KnowledgeTesting = () => { const [form] = Form.useForm(); const { testChunk } = useTestChunkRetrieval(); + const { testChunkAll } = useTestChunkAllRetrieval(); + const [selectedDocumentIds, setSelectedDocumentIds] = useState([]); const handleTesting = async (documentIds: string[] = []) => { const values = await form.validateFields(); @@ -16,6 +22,12 @@ const KnowledgeTesting = () => { doc_ids: Array.isArray(documentIds) ? 
documentIds : [], vector_similarity_weight: 1 - values.vector_similarity_weight, }); + + testChunkAll({ + ...values, + doc_ids: [], + vector_similarity_weight: 1 - values.vector_similarity_weight, + }); }; return ( @@ -23,8 +35,13 @@ const KnowledgeTesting = () => { - + ); }; diff --git a/web/src/pages/add-knowledge/components/knowledge-testing/testing-control/index.tsx b/web/src/pages/add-knowledge/components/knowledge-testing/testing-control/index.tsx index 18f347051df..8efb4c195ef 100644 --- a/web/src/pages/add-knowledge/components/knowledge-testing/testing-control/index.tsx +++ b/web/src/pages/add-knowledge/components/knowledge-testing/testing-control/index.tsx @@ -18,10 +18,15 @@ type FieldType = { interface IProps { form: FormInstance; - handleTesting: () => Promise; + handleTesting: (documentIds?: string[]) => Promise; + selectedDocumentIds: string[]; } -const TestingControl = ({ form, handleTesting }: IProps) => { +const TestingControl = ({ + form, + handleTesting, + selectedDocumentIds, +}: IProps) => { const question = Form.useWatch('question', { form, preserve: true }); const loading = useChunkIsTesting(); const { t } = useTranslate('knowledgeDetails'); @@ -29,6 +34,10 @@ const TestingControl = ({ form, handleTesting }: IProps) => { const buttonDisabled = !question || (typeof question === 'string' && question.trim() === ''); + const onClick = () => { + handleTesting(selectedDocumentIds); + }; + return (
    @@ -53,7 +62,7 @@ const TestingControl = ({ form, handleTesting }: IProps) => {
    ); } + +export const EmailNode = memo(InnerEmailNode); diff --git a/web/src/pages/agent/canvas/node/generate-node.tsx b/web/src/pages/agent/canvas/node/generate-node.tsx index 255eccd993a..8ffbbd79c77 100644 --- a/web/src/pages/agent/canvas/node/generate-node.tsx +++ b/web/src/pages/agent/canvas/node/generate-node.tsx @@ -4,11 +4,12 @@ import { IGenerateNode } from '@/interfaces/database/flow'; import { Handle, NodeProps, Position } from '@xyflow/react'; import classNames from 'classnames'; import { get } from 'lodash'; +import { memo } from 'react'; import { LeftHandleStyle, RightHandleStyle } from './handle-icon'; import styles from './index.less'; import NodeHeader from './node-header'; -export function GenerateNode({ +export function InnerGenerateNode({ id, data, isConnectable = true, @@ -55,3 +56,5 @@ export function GenerateNode({ ); } + +export const GenerateNode = memo(InnerGenerateNode); diff --git a/web/src/pages/agent/canvas/node/handle.tsx b/web/src/pages/agent/canvas/node/handle.tsx new file mode 100644 index 00000000000..ad76a87eb63 --- /dev/null +++ b/web/src/pages/agent/canvas/node/handle.tsx @@ -0,0 +1,41 @@ +import { cn } from '@/lib/utils'; +import { Handle, HandleProps } from '@xyflow/react'; +import { Plus } from 'lucide-react'; +import { useMemo } from 'react'; +import { HandleContext } from '../../context'; +import { NextStepDropdown } from './dropdown/next-step-dropdown'; + +export function CommonHandle({ + className, + nodeId, + ...props +}: HandleProps & { nodeId: string }) { + const value = useMemo( + () => ({ + nodeId, + id: props.id, + type: props.type, + position: props.position, + }), + [nodeId, props.id, props.position, props.type], + ); + + return ( + + + { + e.stopPropagation(); + }} + > + + + + + ); +} diff --git a/web/src/pages/agent/canvas/node/index.tsx b/web/src/pages/agent/canvas/node/index.tsx index 32191f5ccc7..b68edf931aa 100644 --- a/web/src/pages/agent/canvas/node/index.tsx +++ b/web/src/pages/agent/canvas/node/index.tsx @@ -1,45 +1,43 @@ -import { useTheme } from '@/components/theme-provider'; import { IRagNode } from '@/interfaces/database/flow'; -import { Handle, NodeProps, Position } from '@xyflow/react'; -import classNames from 'classnames'; +import { NodeProps, Position } from '@xyflow/react'; +import { memo } from 'react'; +import { NodeHandleId } from '../../constant'; +import { CommonHandle } from './handle'; import { LeftHandleStyle, RightHandleStyle } from './handle-icon'; -import styles from './index.less'; import NodeHeader from './node-header'; +import { NodeWrapper } from './node-wrapper'; +import { ToolBar } from './toolbar'; -export function RagNode({ +function InnerRagNode({ id, data, isConnectable = true, selected, }: NodeProps) { - const { theme } = useTheme(); return ( -
    - - - -
    + + + + + + + ); } + +export const RagNode = memo(InnerRagNode); diff --git a/web/src/pages/agent/canvas/node/invoke-node.tsx b/web/src/pages/agent/canvas/node/invoke-node.tsx index 42d109f3d06..cf1e28d0264 100644 --- a/web/src/pages/agent/canvas/node/invoke-node.tsx +++ b/web/src/pages/agent/canvas/node/invoke-node.tsx @@ -4,12 +4,13 @@ import { Handle, NodeProps, Position } from '@xyflow/react'; import { Flex } from 'antd'; import classNames from 'classnames'; import { get } from 'lodash'; +import { memo } from 'react'; import { useTranslation } from 'react-i18next'; import { LeftHandleStyle, RightHandleStyle } from './handle-icon'; import styles from './index.less'; import NodeHeader from './node-header'; -export function InvokeNode({ +function InnerInvokeNode({ id, data, isConnectable = true, @@ -57,3 +58,5 @@ export function InvokeNode({ ); } + +export const InvokeNode = memo(InnerInvokeNode); diff --git a/web/src/pages/agent/canvas/node/iteration-node.tsx b/web/src/pages/agent/canvas/node/iteration-node.tsx index c15b4fc6c6a..53a84835ded 100644 --- a/web/src/pages/agent/canvas/node/iteration-node.tsx +++ b/web/src/pages/agent/canvas/node/iteration-node.tsx @@ -6,6 +6,7 @@ import { import { cn } from '@/lib/utils'; import { Handle, NodeProps, NodeResizeControl, Position } from '@xyflow/react'; import { ListRestart } from 'lucide-react'; +import { memo } from 'react'; import { LeftHandleStyle, RightHandleStyle } from './handle-icon'; import styles from './index.less'; import NodeHeader from './node-header'; @@ -43,7 +44,7 @@ const controlStyle = { cursor: 'nwse-resize', }; -export function IterationNode({ +export function InnerIterationNode({ id, data, isConnectable = true, @@ -98,7 +99,7 @@ export function IterationNode({ ); } -export function IterationStartNode({ +function InnerIterationStartNode({ isConnectable = true, selected, }: NodeProps) { @@ -125,3 +126,7 @@ export function IterationStartNode({ ); } + +export const IterationStartNode = memo(InnerIterationStartNode); + +export const IterationNode = memo(InnerIterationNode); diff --git a/web/src/pages/agent/canvas/node/keyword-node.tsx b/web/src/pages/agent/canvas/node/keyword-node.tsx index f607d431780..012dcf26cd7 100644 --- a/web/src/pages/agent/canvas/node/keyword-node.tsx +++ b/web/src/pages/agent/canvas/node/keyword-node.tsx @@ -4,11 +4,12 @@ import { IKeywordNode } from '@/interfaces/database/flow'; import { Handle, NodeProps, Position } from '@xyflow/react'; import classNames from 'classnames'; import { get } from 'lodash'; +import { memo } from 'react'; import { LeftHandleStyle, RightHandleStyle } from './handle-icon'; import styles from './index.less'; import NodeHeader from './node-header'; -export function KeywordNode({ +export function InnerKeywordNode({ id, data, isConnectable = true, @@ -55,3 +56,5 @@ export function KeywordNode({ ); } + +export const KeywordNode = memo(InnerKeywordNode); diff --git a/web/src/pages/agent/canvas/node/logic-node.tsx b/web/src/pages/agent/canvas/node/logic-node.tsx index 28215617b4f..a98efc3719d 100644 --- a/web/src/pages/agent/canvas/node/logic-node.tsx +++ b/web/src/pages/agent/canvas/node/logic-node.tsx @@ -1,45 +1,41 @@ -import { useTheme } from '@/components/theme-provider'; import { ILogicNode } from '@/interfaces/database/flow'; -import { Handle, NodeProps, Position } from '@xyflow/react'; -import classNames from 'classnames'; +import { NodeProps, Position } from '@xyflow/react'; +import { memo } from 'react'; +import { CommonHandle } from './handle'; import { 
LeftHandleStyle, RightHandleStyle } from './handle-icon'; -import styles from './index.less'; import NodeHeader from './node-header'; +import { NodeWrapper } from './node-wrapper'; +import { ToolBar } from './toolbar'; -export function LogicNode({ +export function InnerLogicNode({ id, data, isConnectable = true, selected, }: NodeProps) { - const { theme } = useTheme(); return ( -
    - - - -
    + + + + + + + ); } + +export const LogicNode = memo(InnerLogicNode); diff --git a/web/src/pages/agent/canvas/node/message-node.tsx b/web/src/pages/agent/canvas/node/message-node.tsx index 5b3a1736eef..8d4c3199d71 100644 --- a/web/src/pages/agent/canvas/node/message-node.tsx +++ b/web/src/pages/agent/canvas/node/message-node.tsx @@ -1,65 +1,65 @@ -import { useTheme } from '@/components/theme-provider'; import { IMessageNode } from '@/interfaces/database/flow'; -import { Handle, NodeProps, Position } from '@xyflow/react'; +import { NodeProps, Position } from '@xyflow/react'; import { Flex } from 'antd'; import classNames from 'classnames'; import { get } from 'lodash'; +import { memo } from 'react'; +import { NodeHandleId } from '../../constant'; +import { CommonHandle } from './handle'; import { LeftHandleStyle, RightHandleStyle } from './handle-icon'; import styles from './index.less'; import NodeHeader from './node-header'; +import { NodeWrapper } from './node-wrapper'; +import { ToolBar } from './toolbar'; -export function MessageNode({ +function InnerMessageNode({ id, data, isConnectable = true, selected, }: NodeProps) { const messages: string[] = get(data, 'form.messages', []); - const { theme } = useTheme(); return ( -
    - - - 0, - })} - > + + + + + 0, + })} + > - - {messages.map((message, idx) => { - return ( -
    - {message} -
    - ); - })} -
    -
    + + {messages.map((message, idx) => { + return ( +
    + {message} +
    + ); + })} +
    + + ); } + +export const MessageNode = memo(InnerMessageNode); diff --git a/web/src/pages/agent/canvas/node/node-header.tsx b/web/src/pages/agent/canvas/node/node-header.tsx index 99a37dc1e6f..9647af1ed56 100644 --- a/web/src/pages/agent/canvas/node/node-header.tsx +++ b/web/src/pages/agent/canvas/node/node-header.tsx @@ -1,13 +1,7 @@ -import { useTranslate } from '@/hooks/common-hooks'; -import { Flex } from 'antd'; -import { Play } from 'lucide-react'; -import { Operator, operatorMap } from '../../constant'; +import { cn } from '@/lib/utils'; +import { memo } from 'react'; +import { Operator } from '../../constant'; import OperatorIcon from '../../operator-icon'; -import { needsSingleStepDebugging } from '../../utils'; -import NodeDropdown from './dropdown'; -import { NextNodePopover } from './popover'; - -import { RunTooltip } from '../../flow-tooltip'; interface IProps { id: string; label: string; @@ -17,57 +11,24 @@ interface IProps { wrapperClassName?: string; } -const ExcludedRunStateOperators = [Operator.Answer]; - -export function RunStatus({ id, name, label }: IProps) { - const { t } = useTranslate('flow'); - return ( -
    - {needsSingleStepDebugging(label) && ( - - - // data-play is used to trigger single step debugging - )} - - - {t('operationResults')} - - -
    - ); -} - -const NodeHeader = ({ +const InnerNodeHeader = ({ label, - id, name, - gap = 4, className, wrapperClassName, }: IProps) => { return ( -
    - {!ExcludedRunStateOperators.includes(label as Operator) && ( - - )} - - +
    +
    + {name} - - +
    ); }; +const NodeHeader = memo(InnerNodeHeader); + export default NodeHeader; diff --git a/web/src/pages/agent/canvas/node/node-wrapper.tsx b/web/src/pages/agent/canvas/node/node-wrapper.tsx new file mode 100644 index 00000000000..e0010cad28e --- /dev/null +++ b/web/src/pages/agent/canvas/node/node-wrapper.tsx @@ -0,0 +1,18 @@ +import { cn } from '@/lib/utils'; +import { HTMLAttributes, PropsWithChildren } from 'react'; + +export function NodeWrapper({ + children, + className, +}: PropsWithChildren & HTMLAttributes) { + return ( +
    + {children} +
    + ); +} diff --git a/web/src/pages/agent/canvas/node/note-node.tsx b/web/src/pages/agent/canvas/node/note-node.tsx index 1917a81509e..237942133e0 100644 --- a/web/src/pages/agent/canvas/node/note-node.tsx +++ b/web/src/pages/agent/canvas/node/note-node.tsx @@ -8,10 +8,8 @@ import { useTheme } from '@/components/theme-provider'; import { INoteNode } from '@/interfaces/database/flow'; import { memo, useEffect } from 'react'; import { useTranslation } from 'react-i18next'; -import { - useHandleFormValuesChange, - useHandleNodeNameChange, -} from '../../hooks'; +import { useHandleNodeNameChange } from '../../hooks'; +import { useHandleFormValuesChange } from '../../hooks/use-watch-form-change'; import styles from './index.less'; const { TextArea } = Input; diff --git a/web/src/pages/agent/canvas/node/popover.tsx b/web/src/pages/agent/canvas/node/popover.tsx index 342ce40ebab..d445386568e 100644 --- a/web/src/pages/agent/canvas/node/popover.tsx +++ b/web/src/pages/agent/canvas/node/popover.tsx @@ -1,4 +1,3 @@ -import { useFetchFlow } from '@/hooks/flow-hooks'; import get from 'lodash/get'; import React, { MouseEventHandler, useCallback, useMemo } from 'react'; import JsonView from 'react18-json-view'; @@ -20,6 +19,7 @@ import { TableRow, } from '@/components/ui/table'; import { useTranslate } from '@/hooks/common-hooks'; +import { useFetchAgent } from '@/hooks/use-agent-request'; import { useGetComponentLabelByValue } from '../../hooks/use-get-begin-query'; interface IProps extends React.PropsWithChildren { @@ -30,7 +30,7 @@ interface IProps extends React.PropsWithChildren { export function NextNodePopover({ children, nodeId, name }: IProps) { const { t } = useTranslate('flow'); - const { data } = useFetchFlow(); + const { data } = useFetchAgent(); const { theme } = useTheme(); const component = useMemo(() => { return get(data, ['dsl', 'components', nodeId], {}); diff --git a/web/src/pages/agent/canvas/node/relevant-node.tsx b/web/src/pages/agent/canvas/node/relevant-node.tsx index acc098d69b1..410a7accdd1 100644 --- a/web/src/pages/agent/canvas/node/relevant-node.tsx +++ b/web/src/pages/agent/canvas/node/relevant-node.tsx @@ -6,11 +6,12 @@ import { RightHandleStyle } from './handle-icon'; import { useTheme } from '@/components/theme-provider'; import { IRelevantNode } from '@/interfaces/database/flow'; import { get } from 'lodash'; +import { memo } from 'react'; import { useReplaceIdWithName } from '../../hooks'; import styles from './index.less'; import NodeHeader from './node-header'; -export function RelevantNode({ id, data, selected }: NodeProps) { +function InnerRelevantNode({ id, data, selected }: NodeProps) { const yes = get(data, 'form.yes'); const no = get(data, 'form.no'); const replaceIdWithName = useReplaceIdWithName(); @@ -68,3 +69,5 @@ export function RelevantNode({ id, data, selected }: NodeProps) {
    ); } + +export const RelevantNode = memo(InnerRelevantNode); diff --git a/web/src/pages/agent/canvas/node/retrieval-node.tsx b/web/src/pages/agent/canvas/node/retrieval-node.tsx index 0fd2760ede5..4187b2c44b4 100644 --- a/web/src/pages/agent/canvas/node/retrieval-node.tsx +++ b/web/src/pages/agent/canvas/node/retrieval-node.tsx @@ -1,24 +1,26 @@ -import { useTheme } from '@/components/theme-provider'; import { useFetchKnowledgeList } from '@/hooks/knowledge-hooks'; import { IRetrievalNode } from '@/interfaces/database/flow'; import { UserOutlined } from '@ant-design/icons'; -import { Handle, NodeProps, Position } from '@xyflow/react'; +import { NodeProps, Position } from '@xyflow/react'; import { Avatar, Flex } from 'antd'; import classNames from 'classnames'; import { get } from 'lodash'; -import { useMemo } from 'react'; +import { memo, useMemo } from 'react'; +import { NodeHandleId } from '../../constant'; +import { CommonHandle } from './handle'; import { LeftHandleStyle, RightHandleStyle } from './handle-icon'; import styles from './index.less'; import NodeHeader from './node-header'; +import { NodeWrapper } from './node-wrapper'; +import { ToolBar } from './toolbar'; -export function RetrievalNode({ +function InnerRetrievalNode({ id, data, isConnectable = true, selected, }: NodeProps) { const knowledgeBaseIds: string[] = get(data, 'form.kb_ids', []); - const { theme } = useTheme(); const { list: knowledgeList } = useFetchKnowledgeList(true); const knowledgeBases = useMemo(() => { return knowledgeBaseIds.map((x) => { @@ -32,57 +34,56 @@ export function RetrievalNode({ }, [knowledgeList, knowledgeBaseIds]); return ( -
    - - - 0, - })} - > - - {knowledgeBases.map((knowledge) => { - return ( -
    - - } - src={knowledge.avatar} - /> - - {knowledge.name} + + + + + 0, + })} + > + + {knowledgeBases.map((knowledge) => { + return ( +
    + + } + src={knowledge.avatar} + /> + + {knowledge.name} + - -
    - ); - })} -
    -
    +
    + ); + })} + + + ); } + +export const RetrievalNode = memo(InnerRetrievalNode); diff --git a/web/src/pages/agent/canvas/node/rewrite-node.tsx b/web/src/pages/agent/canvas/node/rewrite-node.tsx index 093b2c80ea3..134899c8bea 100644 --- a/web/src/pages/agent/canvas/node/rewrite-node.tsx +++ b/web/src/pages/agent/canvas/node/rewrite-node.tsx @@ -4,11 +4,12 @@ import { IRewriteNode } from '@/interfaces/database/flow'; import { Handle, NodeProps, Position } from '@xyflow/react'; import classNames from 'classnames'; import { get } from 'lodash'; +import { memo } from 'react'; import { LeftHandleStyle, RightHandleStyle } from './handle-icon'; import styles from './index.less'; import NodeHeader from './node-header'; -export function RewriteNode({ +function InnerRewriteNode({ id, data, isConnectable = true, @@ -55,3 +56,5 @@ export function RewriteNode({ ); } + +export const RewriteNode = memo(InnerRewriteNode); diff --git a/web/src/pages/agent/canvas/node/switch-node.tsx b/web/src/pages/agent/canvas/node/switch-node.tsx index 860a0ba9618..c386731a0e0 100644 --- a/web/src/pages/agent/canvas/node/switch-node.tsx +++ b/web/src/pages/agent/canvas/node/switch-node.tsx @@ -1,13 +1,16 @@ -import { useTheme } from '@/components/theme-provider'; +import { IconFont } from '@/components/icon-font'; +import { Card, CardContent } from '@/components/ui/card'; import { ISwitchCondition, ISwitchNode } from '@/interfaces/database/flow'; -import { Handle, NodeProps, Position } from '@xyflow/react'; -import { Divider, Flex } from 'antd'; -import classNames from 'classnames'; +import { NodeProps, Position } from '@xyflow/react'; +import { memo, useCallback } from 'react'; +import { NodeHandleId, SwitchOperatorOptions } from '../../constant'; import { useGetComponentLabelByValue } from '../../hooks/use-get-begin-query'; +import { CommonHandle } from './handle'; import { RightHandleStyle } from './handle-icon'; -import { useBuildSwitchHandlePositions } from './hooks'; -import styles from './index.less'; import NodeHeader from './node-header'; +import { NodeWrapper } from './node-wrapper'; +import { ToolBar } from './toolbar'; +import { useBuildSwitchHandlePositions } from './use-build-switch-handle-positions'; const getConditionKey = (idx: number, length: number) => { if (idx === 0 && length !== 1) { @@ -28,87 +31,80 @@ const ConditionBlock = ({ }) => { const items = condition?.items ?? []; const getLabel = useGetComponentLabelByValue(nodeId); + + const renderOperatorIcon = useCallback((operator?: string) => { + const name = SwitchOperatorOptions.find((x) => x.value === operator)?.icon; + return ; + }, []); + return ( - - {items.map((x, idx) => ( -
    - -
    - {getLabel(x?.cpn_id)} -
    - {x?.operator} - - {x?.value} - -
    - {idx + 1 < items.length && ( - - {condition?.logical_operator} - - )} -
    - ))} -
    + + + {items.map((x, idx) => ( +
    +
    +
    + {getLabel(x?.cpn_id)} +
    + {renderOperatorIcon(x?.operator)} +
    {x?.value}
    +
    +
    + ))} +
    +
    ); }; -export function SwitchNode({ id, data, selected }: NodeProps) { +function InnerSwitchNode({ id, data, selected }: NodeProps) { const { positions } = useBuildSwitchHandlePositions({ data, id }); - const { theme } = useTheme(); return ( -
    - - - - {positions.map((position, idx) => { - return ( -
    - - - {idx < positions.length - 1 && position.text} - {getConditionKey(idx, positions.length)} - - {position.condition && ( - - )} - - -
    - ); - })} -
    -
    + + + + +
    + {positions.map((position, idx) => { + return ( +
    +
    +
    + + {idx < positions.length - 1 && + position.condition?.logical_operator?.toUpperCase()} + + {getConditionKey(idx, positions.length)} +
    + {position.condition && ( + + )} +
    + +
    + ); + })} +
    +
    +
    ); } + +export const SwitchNode = memo(InnerSwitchNode); diff --git a/web/src/pages/agent/canvas/node/template-node.tsx b/web/src/pages/agent/canvas/node/template-node.tsx index 971fbab3842..b204717ab2a 100644 --- a/web/src/pages/agent/canvas/node/template-node.tsx +++ b/web/src/pages/agent/canvas/node/template-node.tsx @@ -9,9 +9,10 @@ import { LeftHandleStyle, RightHandleStyle } from './handle-icon'; import NodeHeader from './node-header'; import { ITemplateNode } from '@/interfaces/database/flow'; +import { memo } from 'react'; import styles from './index.less'; -export function TemplateNode({ +function InnerTemplateNode({ id, data, isConnectable = true, @@ -73,3 +74,5 @@ export function TemplateNode({ ); } + +export const TemplateNode = memo(InnerTemplateNode); diff --git a/web/src/pages/agent/canvas/node/tool-node.tsx b/web/src/pages/agent/canvas/node/tool-node.tsx new file mode 100644 index 00000000000..ba3f621be70 --- /dev/null +++ b/web/src/pages/agent/canvas/node/tool-node.tsx @@ -0,0 +1,52 @@ +import { IAgentForm, IToolNode } from '@/interfaces/database/agent'; +import { Handle, NodeProps, Position } from '@xyflow/react'; +import { get } from 'lodash'; +import { memo, useCallback } from 'react'; +import { NodeHandleId } from '../../constant'; +import { ToolCard } from '../../form/agent-form/agent-tools'; +import useGraphStore from '../../store'; +import { NodeWrapper } from './node-wrapper'; + +function InnerToolNode({ + id, + data, + isConnectable = true, + selected, +}: NodeProps) { + const { edges, getNode } = useGraphStore((state) => state); + const upstreamAgentNodeId = edges.find((x) => x.target === id)?.source; + const upstreamAgentNode = getNode(upstreamAgentNodeId); + + const handleClick = useCallback(() => {}, []); + + const tools: IAgentForm['tools'] = get( + upstreamAgentNode, + 'data.form.tools', + [], + ); + + return ( + + +
      + {tools.map((x) => ( + + {x.component_name} + + ))} +
    +
    + ); +} + +export const ToolNode = memo(InnerToolNode); diff --git a/web/src/pages/agent/canvas/node/toolbar.tsx b/web/src/pages/agent/canvas/node/toolbar.tsx new file mode 100644 index 00000000000..53a5f8a115c --- /dev/null +++ b/web/src/pages/agent/canvas/node/toolbar.tsx @@ -0,0 +1,74 @@ +import { + TooltipContent, + TooltipNode, + TooltipTrigger, +} from '@/components/xyflow/tooltip-node'; +import { Position } from '@xyflow/react'; +import { Copy, Play, Trash2 } from 'lucide-react'; +import { MouseEventHandler, PropsWithChildren, useCallback } from 'react'; +import { Operator } from '../../constant'; +import { useDuplicateNode } from '../../hooks'; +import useGraphStore from '../../store'; + +function IconWrapper({ children }: PropsWithChildren) { + return ( +
    + {children} +
    + ); +} + +type ToolBarProps = { + selected?: boolean | undefined; + label: string; + id: string; +} & PropsWithChildren; + +export function ToolBar({ selected, children, label, id }: ToolBarProps) { + const deleteNodeById = useGraphStore((store) => store.deleteNodeById); + const deleteIterationNodeById = useGraphStore( + (store) => store.deleteIterationNodeById, + ); + + const deleteNode: MouseEventHandler = useCallback( + (e) => { + e.stopPropagation(); + if (label === Operator.Iteration) { + deleteIterationNodeById(id); + } else { + deleteNodeById(id); + } + }, + [deleteIterationNodeById, deleteNodeById, id, label], + ); + + const duplicateNode = useDuplicateNode(); + + const handleDuplicate: MouseEventHandler = useCallback( + (e) => { + e.stopPropagation(); + duplicateNode(id, label); + }, + [duplicateNode, id, label], + ); + + return ( + + {children} + + +
    + + + + + + + + + +
    +
    +
    + ); +} diff --git a/web/src/pages/agent/canvas/node/use-build-categorize-handle-positions.ts b/web/src/pages/agent/canvas/node/use-build-categorize-handle-positions.ts new file mode 100644 index 00000000000..e5249ae56eb --- /dev/null +++ b/web/src/pages/agent/canvas/node/use-build-categorize-handle-positions.ts @@ -0,0 +1,45 @@ +import { ICategorizeItemResult } from '@/interfaces/database/agent'; +import { RAGFlowNodeType } from '@/interfaces/database/flow'; +import { useUpdateNodeInternals } from '@xyflow/react'; +import { get } from 'lodash'; +import { useEffect, useMemo } from 'react'; + +export const useBuildCategorizeHandlePositions = ({ + data, + id, +}: { + id: string; + data: RAGFlowNodeType['data']; +}) => { + const updateNodeInternals = useUpdateNodeInternals(); + + const categoryData: ICategorizeItemResult = useMemo(() => { + return get(data, `form.category_description`, {}); + }, [data]); + + const positions = useMemo(() => { + const list: Array<{ + text: string; + top: number; + idx: number; + }> = []; + + Object.keys(categoryData) + .sort((a, b) => categoryData[a].index - categoryData[b].index) + .forEach((x, idx) => { + list.push({ + text: x, + idx, + top: idx === 0 ? 86 : list[idx - 1].top + 8 + 24, + }); + }); + + return list; + }, [categoryData]); + + useEffect(() => { + updateNodeInternals(id); + }, [id, updateNodeInternals, categoryData]); + + return { positions }; +}; diff --git a/web/src/pages/agent/canvas/node/hooks.ts b/web/src/pages/agent/canvas/node/use-build-switch-handle-positions.ts similarity index 51% rename from web/src/pages/agent/canvas/node/hooks.ts rename to web/src/pages/agent/canvas/node/use-build-switch-handle-positions.ts index fbea8f1668b..eb0d0c5f108 100644 --- a/web/src/pages/agent/canvas/node/hooks.ts +++ b/web/src/pages/agent/canvas/node/use-build-switch-handle-positions.ts @@ -1,55 +1,10 @@ +import { ISwitchCondition, RAGFlowNodeType } from '@/interfaces/database/flow'; import { useUpdateNodeInternals } from '@xyflow/react'; import get from 'lodash/get'; import { useEffect, useMemo } from 'react'; import { SwitchElseTo } from '../../constant'; - -import { - ICategorizeItemResult, - ISwitchCondition, - RAGFlowNodeType, -} from '@/interfaces/database/flow'; import { generateSwitchHandleText } from '../../utils'; -export const useBuildCategorizeHandlePositions = ({ - data, - id, -}: { - id: string; - data: RAGFlowNodeType['data']; -}) => { - const updateNodeInternals = useUpdateNodeInternals(); - - const categoryData: ICategorizeItemResult = useMemo(() => { - return get(data, `form.category_description`, {}); - }, [data]); - - const positions = useMemo(() => { - const list: Array<{ - text: string; - top: number; - idx: number; - }> = []; - - Object.keys(categoryData) - .sort((a, b) => categoryData[a].index - categoryData[b].index) - .forEach((x, idx) => { - list.push({ - text: x, - idx, - top: idx === 0 ? 
98 + 20 : list[idx - 1].top + 8 + 26, - }); - }); - - return list; - }, [categoryData]); - - useEffect(() => { - updateNodeInternals(id); - }, [id, updateNodeInternals, categoryData]); - - return { positions }; -}; - export const useBuildSwitchHandlePositions = ({ data, id, @@ -63,6 +18,10 @@ export const useBuildSwitchHandlePositions = ({ return get(data, 'form.conditions', []); }, [data]); + useEffect(() => { + console.info('xxx0000'); + }, [conditions]); + const positions = useMemo(() => { const list: Array<{ text: string; @@ -72,13 +31,13 @@ export const useBuildSwitchHandlePositions = ({ }> = []; [...conditions, ''].forEach((x, idx) => { - let top = idx === 0 ? 58 + 20 : list[idx - 1].top + 32; // case number (Case 1) height + flex gap - if (idx - 1 >= 0) { + let top = idx === 0 ? 53 : list[idx - 1].top + 10 + 14; // case number (Case 1) height + flex gap + if (idx >= 1) { const previousItems = conditions[idx - 1]?.items ?? []; if (previousItems.length > 0) { - top += 12; // ConditionBlock padding - top += previousItems.length * 22; // condition variable height - top += (previousItems.length - 1) * 25; // operator height + // top += 12; // ConditionBlock padding + top += previousItems.length * 26; // condition variable height + // top += (previousItems.length - 1) * 25; // operator height } } diff --git a/web/src/pages/agent/chat/box.tsx b/web/src/pages/agent/chat/box.tsx new file mode 100644 index 00000000000..6daabd17e88 --- /dev/null +++ b/web/src/pages/agent/chat/box.tsx @@ -0,0 +1,91 @@ +import { MessageType } from '@/constants/chat'; +import { useGetFileIcon } from '@/pages/chat/hooks'; +import { buildMessageItemReference } from '@/pages/chat/utils'; +import { Spin } from 'antd'; + +import { useSendNextMessage } from './hooks'; + +import MessageInput from '@/components/message-input'; +import MessageItem from '@/components/next-message-item'; +import PdfDrawer from '@/components/pdf-drawer'; +import { useClickDrawer } from '@/components/pdf-drawer/hooks'; +import { useFetchAgent } from '@/hooks/use-agent-request'; +import { useFetchUserInfo } from '@/hooks/user-setting-hooks'; +import { buildMessageUuidWithRole } from '@/utils/chat'; + +const AgentChatBox = () => { + const { + sendLoading, + handleInputChange, + handlePressEnter, + value, + loading, + ref, + derivedMessages, + reference, + stopOutputMessage, + } = useSendNextMessage(); + + const { visible, hideModal, documentId, selectedChunk, clickDocumentButton } = + useClickDrawer(); + useGetFileIcon(); + const { data: userInfo } = useFetchUserInfo(); + const { data: canvasInfo } = useFetchAgent(); + + return ( + <> +
    +
    +
    + + {derivedMessages?.map((message, i) => { + return ( + + ); + })} + +
    +
    +
    + +
    + + + ); +}; + +export default AgentChatBox; diff --git a/web/src/pages/agent/chat/chat-sheet.tsx b/web/src/pages/agent/chat/chat-sheet.tsx new file mode 100644 index 00000000000..1050c460ff1 --- /dev/null +++ b/web/src/pages/agent/chat/chat-sheet.tsx @@ -0,0 +1,26 @@ +import { + Sheet, + SheetContent, + SheetHeader, + SheetTitle, +} from '@/components/ui/sheet'; +import { IModalProps } from '@/interfaces/common'; +import { cn } from '@/lib/utils'; +import AgentChatBox from './box'; + +export function ChatSheet({ hideModal }: IModalProps) { + return ( + + + e.preventDefault()} + > + + Are you absolutely sure? + + + + + ); +} diff --git a/web/src/pages/agent/chat/hooks.ts b/web/src/pages/agent/chat/hooks.ts new file mode 100644 index 00000000000..e813c90c479 --- /dev/null +++ b/web/src/pages/agent/chat/hooks.ts @@ -0,0 +1,181 @@ +import { MessageType } from '@/constants/chat'; +import { + useHandleMessageInputChange, + useSelectDerivedMessages, +} from '@/hooks/logic-hooks'; +import { useFetchAgent } from '@/hooks/use-agent-request'; +import { + IEventList, + IMessageEvent, + MessageEventType, + useSendMessageBySSE, +} from '@/hooks/use-send-message'; +import { Message } from '@/interfaces/database/chat'; +import i18n from '@/locales/config'; +import api from '@/utils/api'; +import { message } from 'antd'; +import { get } from 'lodash'; +import trim from 'lodash/trim'; +import { useCallback, useContext, useEffect, useMemo } from 'react'; +import { useParams } from 'umi'; +import { v4 as uuid } from 'uuid'; +import { BeginId } from '../constant'; +import { AgentChatLogContext } from '../context'; +import useGraphStore from '../store'; +import { receiveMessageError } from '../utils'; + +const antMessage = message; + +export const useSelectNextMessages = () => { + const { data: flowDetail, loading } = useFetchAgent(); + const reference = flowDetail.dsl.retrieval; + const { + derivedMessages, + ref, + addNewestQuestion, + addNewestAnswer, + removeLatestMessage, + removeMessageById, + removeMessagesAfterCurrentMessage, + } = useSelectDerivedMessages(); + + return { + reference, + loading, + derivedMessages, + ref, + addNewestQuestion, + addNewestAnswer, + removeLatestMessage, + removeMessageById, + removeMessagesAfterCurrentMessage, + }; +}; + +function findMessageFromList(eventList: IEventList) { + const messageEventList = eventList.filter( + (x) => x.event === MessageEventType.Message, + ) as IMessageEvent[]; + return { + id: messageEventList[0]?.message_id, + content: messageEventList.map((x) => x.data.content).join(''), + }; +} + +const useGetBeginNodePrologue = () => { + const getNode = useGraphStore((state) => state.getNode); + + return useMemo(() => { + const formData = get(getNode(BeginId), 'data.form', {}); + if (formData?.enablePrologue) { + return formData?.prologue; + } + }, [getNode]); +}; + +export const useSendNextMessage = () => { + const { + reference, + loading, + derivedMessages, + ref, + addNewestQuestion, + addNewestAnswer, + removeLatestMessage, + removeMessageById, + } = useSelectNextMessages(); + const { id: agentId } = useParams(); + const { handleInputChange, value, setValue } = useHandleMessageInputChange(); + const { refetch } = useFetchAgent(); + const { addEventList } = useContext(AgentChatLogContext); + + const { send, answerList, done, stopOutputMessage } = useSendMessageBySSE( + api.runCanvas, + ); + + const prologue = useGetBeginNodePrologue(); + + const sendMessage = useCallback( + async ({ message }: { message: Message; messages?: Message[] }) => { + const 
params: Record = { + id: agentId, + }; + params.running_hint_text = i18n.t('flow.runningHintText', { + defaultValue: 'is running...🕞', + }); + if (message.content) { + params.query = message.content; + // params.message_id = message.id; + params.inputs = {}; // begin operator inputs + } + const res = await send(params); + + if (receiveMessageError(res)) { + antMessage.error(res?.data?.message); + + // cancel loading + setValue(message.content); + removeLatestMessage(); + } else { + refetch(); // pull the message list after sending the message successfully + } + }, + [agentId, send, setValue, removeLatestMessage, refetch], + ); + + const handleSendMessage = useCallback( + async (message: Message) => { + sendMessage({ message }); + }, + [sendMessage], + ); + + useEffect(() => { + const { content, id } = findMessageFromList(answerList); + if (content) { + addNewestAnswer({ + answer: content, + id: id, + }); + } + }, [answerList, addNewestAnswer]); + + const handlePressEnter = useCallback(() => { + if (trim(value) === '') return; + const id = uuid(); + if (done) { + setValue(''); + handleSendMessage({ id, content: value.trim(), role: MessageType.User }); + } + addNewestQuestion({ + content: value, + id, + role: MessageType.User, + }); + }, [addNewestQuestion, handleSendMessage, done, setValue, value]); + + useEffect(() => { + if (prologue) { + addNewestAnswer({ + answer: prologue, + }); + } + }, [addNewestAnswer, prologue]); + + useEffect(() => { + addEventList(answerList); + }, [addEventList, answerList]); + + return { + handlePressEnter, + handleInputChange, + value, + sendLoading: !done, + reference, + loading, + derivedMessages, + ref, + removeMessageById, + stopOutputMessage, + }; +}; diff --git a/web/src/pages/agent/constant.tsx b/web/src/pages/agent/constant.tsx index f308e07278f..3f1d192e1f2 100644 --- a/web/src/pages/agent/constant.tsx +++ b/web/src/pages/agent/constant.tsx @@ -1,39 +1,23 @@ import { - GitHubIcon, - KeywordIcon, - QWeatherIcon, - WikipediaIcon, -} from '@/assets/icon/Icon'; -import { ReactComponent as AkShareIcon } from '@/assets/svg/akshare.svg'; -import { ReactComponent as ArXivIcon } from '@/assets/svg/arxiv.svg'; -import { ReactComponent as baiduFanyiIcon } from '@/assets/svg/baidu-fanyi.svg'; -import { ReactComponent as BaiduIcon } from '@/assets/svg/baidu.svg'; -import { ReactComponent as BeginIcon } from '@/assets/svg/begin.svg'; -import { ReactComponent as BingIcon } from '@/assets/svg/bing.svg'; -import { ReactComponent as ConcentratorIcon } from '@/assets/svg/concentrator.svg'; -import { ReactComponent as CrawlerIcon } from '@/assets/svg/crawler.svg'; -import { ReactComponent as DeepLIcon } from '@/assets/svg/deepl.svg'; -import { ReactComponent as DuckIcon } from '@/assets/svg/duck.svg'; -import { ReactComponent as EmailIcon } from '@/assets/svg/email.svg'; -import { ReactComponent as ExeSqlIcon } from '@/assets/svg/exesql.svg'; -import { ReactComponent as GoogleScholarIcon } from '@/assets/svg/google-scholar.svg'; -import { ReactComponent as GoogleIcon } from '@/assets/svg/google.svg'; -import { ReactComponent as InvokeIcon } from '@/assets/svg/invoke-ai.svg'; -import { ReactComponent as Jin10Icon } from '@/assets/svg/jin10.svg'; -import { ReactComponent as NoteIcon } from '@/assets/svg/note.svg'; -import { ReactComponent as PubMedIcon } from '@/assets/svg/pubmed.svg'; -import { ReactComponent as SwitchIcon } from '@/assets/svg/switch.svg'; -import { ReactComponent as TemplateIcon } from '@/assets/svg/template.svg'; -import { ReactComponent as TuShareIcon 
} from '@/assets/svg/tushare.svg'; -import { ReactComponent as WenCaiIcon } from '@/assets/svg/wencai.svg'; -import { ReactComponent as YahooFinanceIcon } from '@/assets/svg/yahoo-finance.svg'; - -// 邮件功能 + initialKeywordsSimilarityWeightValue, + initialSimilarityThresholdValue, +} from '@/components/similarity-slider'; +import { + AgentGlobals, + CodeTemplateStrMap, + ProgrammingLanguage, +} from '@/constants/agent'; + +export enum AgentDialogueMode { + Conversational = 'conversational', + Task = 'task', +} import { ChatVariableEnabledField, variableEnabledFieldMap, } from '@/constants/chat'; +import { ModelVariableType } from '@/constants/knowledge'; import i18n from '@/locales/config'; import { setInitialChatVariableEnabledFieldValue } from '@/utils/chat'; @@ -43,20 +27,14 @@ export enum Channel { News = 'news', } -import { - BranchesOutlined, - DatabaseOutlined, - FormOutlined, - MergeCellsOutlined, - MessageOutlined, - RocketOutlined, - SendOutlined, -} from '@ant-design/icons'; +export enum PromptRole { + User = 'user', + Assistant = 'assistant', +} + import upperFirst from 'lodash/upperFirst'; import { - CirclePower, CloudUpload, - IterationCcw, ListOrdered, OptionIcon, TextCursorInput, @@ -103,8 +81,15 @@ export enum Operator { Email = 'Email', Iteration = 'Iteration', IterationStart = 'IterationItem', + Code = 'Code', + WaitingDialogue = 'WaitingDialogue', + Agent = 'Agent', + Tool = 'Tool', + Tavily = 'Tavily', } +export const SwitchLogicOperatorOptions = ['and', 'or']; + export const CommonOperatorList = Object.values(Operator).filter( (x) => x !== Operator.Note, ); @@ -121,48 +106,11 @@ export const AgentOperatorList = [ Operator.Concentrator, Operator.Template, Operator.Iteration, + Operator.WaitingDialogue, Operator.Note, + Operator.Agent, ]; -export const operatorIconMap = { - [Operator.Retrieval]: RocketOutlined, - [Operator.Generate]: MergeCellsOutlined, - [Operator.Answer]: SendOutlined, - [Operator.Begin]: BeginIcon, - [Operator.Categorize]: DatabaseOutlined, - [Operator.Message]: MessageOutlined, - [Operator.Relevant]: BranchesOutlined, - [Operator.RewriteQuestion]: FormOutlined, - [Operator.KeywordExtract]: KeywordIcon, - [Operator.DuckDuckGo]: DuckIcon, - [Operator.Baidu]: BaiduIcon, - [Operator.Wikipedia]: WikipediaIcon, - [Operator.PubMed]: PubMedIcon, - [Operator.ArXiv]: ArXivIcon, - [Operator.Google]: GoogleIcon, - [Operator.Bing]: BingIcon, - [Operator.GoogleScholar]: GoogleScholarIcon, - [Operator.DeepL]: DeepLIcon, - [Operator.GitHub]: GitHubIcon, - [Operator.BaiduFanyi]: baiduFanyiIcon, - [Operator.QWeather]: QWeatherIcon, - [Operator.ExeSQL]: ExeSqlIcon, - [Operator.Switch]: SwitchIcon, - [Operator.WenCai]: WenCaiIcon, - [Operator.AkShare]: AkShareIcon, - [Operator.YahooFinance]: YahooFinanceIcon, - [Operator.Jin10]: Jin10Icon, - [Operator.Concentrator]: ConcentratorIcon, - [Operator.TuShare]: TuShareIcon, - [Operator.Note]: NoteIcon, - [Operator.Crawler]: CrawlerIcon, - [Operator.Invoke]: InvokeIcon, - [Operator.Template]: TemplateIcon, - [Operator.Email]: EmailIcon, - [Operator.Iteration]: IterationCcw, - [Operator.IterationStart]: CirclePower, -}; - export const operatorMap: Record< Operator, { @@ -299,6 +247,10 @@ export const operatorMap: Record< [Operator.Email]: { backgroundColor: '#e6f7ff' }, [Operator.Iteration]: { backgroundColor: '#e6f7ff' }, [Operator.IterationStart]: { backgroundColor: '#e6f7ff' }, + [Operator.Code]: { backgroundColor: '#4c5458' }, + [Operator.WaitingDialogue]: { backgroundColor: '#a5d65c' }, + [Operator.Agent]: { 
backgroundColor: '#a5d65c' }, + [Operator.Tavily]: { backgroundColor: '#a5d65c' }, }; export const componentMenuList = [ @@ -336,6 +288,15 @@ export const componentMenuList = [ { name: Operator.Iteration, }, + { + name: Operator.Code, + }, + { + name: Operator.WaitingDialogue, + }, + { + name: Operator.Agent, + }, { name: Operator.Note, }, @@ -404,18 +365,46 @@ export const componentMenuList = [ }, ]; +export const SwitchOperatorOptions = [ + { value: '=', label: 'equal', icon: 'equal' }, + { value: '≠', label: 'notEqual', icon: 'not-equals' }, + { value: '>', label: 'gt', icon: 'Less' }, + { value: '≥', label: 'ge', icon: 'Greater-or-equal' }, + { value: '<', label: 'lt', icon: 'Less' }, + { value: '≤', label: 'le', icon: 'less-or-equal' }, + { value: 'contains', label: 'contains', icon: 'Contains' }, + { value: 'not contains', label: 'notContains', icon: 'not-contains' }, + { value: 'start with', label: 'startWith', icon: 'list-start' }, + { value: 'end with', label: 'endWith', icon: 'list-end' }, + { value: 'empty', label: 'empty', icon: 'circle' }, + { value: 'not empty', label: 'notEmpty', icon: 'circle-slash-2' }, +]; + +export const SwitchElseTo = 'end_cpn_ids'; + const initialQueryBaseValues = { query: [], }; export const initialRetrievalValues = { - similarity_threshold: 0.2, - keywords_similarity_weight: 0.3, + query: '', top_n: 8, - ...initialQueryBaseValues, + top_k: 1024, + kb_ids: [], + rerank_id: '', + empty_response: '', + ...initialSimilarityThresholdValue, + ...initialKeywordsSimilarityWeightValue, + outputs: { + formalized_content: { + type: 'string', + value: '', + }, + }, }; export const initialBeginValues = { + mode: AgentDialogueMode.Conversational, prologue: `Hi! I'm your assistant, what can I do for you?`, }; @@ -457,6 +446,7 @@ export const initialRelevantValues = { export const initialCategorizeValues = { ...initialLlmBaseValues, + parameter: ModelVariableType.Precise, message_history_window_size: 1, category_description: {}, ...initialQueryBaseValues, @@ -563,7 +553,20 @@ export const initialExeSqlValues = { ...initialQueryBaseValues, }; -export const initialSwitchValues = { conditions: [] }; +export const initialSwitchValues = { + conditions: [ + { + logical_operator: SwitchLogicOperatorOptions[0], + items: [ + { + operator: SwitchOperatorOptions[0].value, + }, + ], + to: [], + }, + ], + [SwitchElseTo]: [], +}; export const initialWenCaiValues = { top_n: 20, @@ -645,6 +648,44 @@ export const initialIterationValues = { }; export const initialIterationStartValues = {}; +export const initialCodeValues = { + lang: 'python', + script: CodeTemplateStrMap[ProgrammingLanguage.Python], + arguments: [ + { + name: 'arg1', + }, + { + name: 'arg2', + }, + ], +}; + +export const initialWaitingDialogueValues = {}; + +export const initialAgentValues = { + ...initialLlmBaseValues, + sys_prompt: ``, + prompts: [{ role: PromptRole.User, content: `{${AgentGlobals.SysQuery}}` }], + message_history_window_size: 12, + tools: [], + outputs: { + structured_output: { + // topic: { + // type: 'string', + // description: + // 'default:general. The category of the search.news is useful for retrieving real-time updates, particularly about politics, sports, and major current events covered by mainstream media sources. 
general is for broader, more general-purpose searches that may include a wide range of sources.', + // enum: ['general', 'news'], + // default: 'general', + // }, + }, + content: { + type: 'string', + value: '', + }, + }, +}; + export const CategorizeAnchorPointPositions = [ { top: 1, right: 34 }, { top: 8, right: 18 }, @@ -726,6 +767,9 @@ export const RestrictedUpstreamMap = { [Operator.Email]: [Operator.Begin], [Operator.Iteration]: [Operator.Begin], [Operator.IterationStart]: [Operator.Begin], + [Operator.Code]: [Operator.Begin], + [Operator.WaitingDialogue]: [Operator.Begin], + [Operator.Agent]: [Operator.Begin], }; export const NodeMap = { @@ -765,6 +809,10 @@ export const NodeMap = { [Operator.Email]: 'emailNode', [Operator.Iteration]: 'group', [Operator.IterationStart]: 'iterationStartNode', + [Operator.Code]: 'ragNode', + [Operator.WaitingDialogue]: 'ragNode', + [Operator.Agent]: 'agentNode', + [Operator.Tool]: 'toolNode', }; export const LanguageOptions = [ @@ -2903,25 +2951,6 @@ export const ExeSQLOptions = ['mysql', 'postgresql', 'mariadb', 'mssql'].map( }), ); -export const SwitchElseTo = 'end_cpn_id'; - -export const SwitchOperatorOptions = [ - { value: '=', label: 'equal' }, - { value: '≠', label: 'notEqual' }, - { value: '>', label: 'gt' }, - { value: '≥', label: 'ge' }, - { value: '<', label: 'lt' }, - { value: '≤', label: 'le' }, - { value: 'contains', label: 'contains' }, - { value: 'not contains', label: 'notContains' }, - { value: 'start with', label: 'startWith' }, - { value: 'end with', label: 'endWith' }, - { value: 'empty', label: 'empty' }, - { value: 'not empty', label: 'notEmpty' }, -]; - -export const SwitchLogicOperatorOptions = ['and', 'or']; - export const WenCaiQueryTypeOptions = [ 'stock', 'zhishu', @@ -2983,3 +3012,9 @@ export const NoDebugOperatorsList = [ Operator.Switch, Operator.Iteration, ]; + +export enum NodeHandleId { + Start = 'start', + End = 'end', + Tool = 'tool', +} diff --git a/web/src/pages/agent/context.ts b/web/src/pages/agent/context.ts index fe51d8d6dac..eda8c50ee03 100644 --- a/web/src/pages/agent/context.ts +++ b/web/src/pages/agent/context.ts @@ -1,6 +1,48 @@ import { RAGFlowNodeType } from '@/interfaces/database/flow'; +import { HandleType, Position } from '@xyflow/react'; import { createContext } from 'react'; +import { useAddNode } from './hooks/use-add-node'; +import { useCacheChatLog } from './hooks/use-cache-chat-log'; +import { useShowLogSheet } from './hooks/use-show-drawer'; -export const FlowFormContext = createContext( +export const AgentFormContext = createContext( undefined, ); + +type AgentInstanceContextType = Pick< + ReturnType, + 'addCanvasNode' +>; + +export const AgentInstanceContext = createContext( + {} as AgentInstanceContextType, +); + +type AgentChatContextType = Pick< + ReturnType, + 'showLogSheet' +>; + +export const AgentChatContext = createContext( + {} as AgentChatContextType, +); + +type AgentChatLogContextType = Pick< + ReturnType, + 'addEventList' | 'setCurrentMessageId' +>; + +export const AgentChatLogContext = createContext( + {} as AgentChatLogContextType, +); + +export type HandleContextType = { + nodeId?: string; + id?: string; + type: HandleType; + position: Position; +}; + +export const HandleContext = createContext( + {} as HandleContextType, +); diff --git a/web/src/pages/agent/debug-content/index.less b/web/src/pages/agent/debug-content/index.less deleted file mode 100644 index fda707810ba..00000000000 --- a/web/src/pages/agent/debug-content/index.less +++ /dev/null @@ -1,5 +0,0 @@ 
-.formWrapper { - :global(.ant-form-item-label) { - font-weight: 600 !important; - } -} diff --git a/web/src/pages/agent/debug-content/index.tsx b/web/src/pages/agent/debug-content/index.tsx index f5493c76483..377c53e04e3 100644 --- a/web/src/pages/agent/debug-content/index.tsx +++ b/web/src/pages/agent/debug-content/index.tsx @@ -1,30 +1,42 @@ -import { Authorization } from '@/constants/authorization'; -import { useSetModalState } from '@/hooks/common-hooks'; -import { useSetSelectedRecord } from '@/hooks/logic-hooks'; -import { useHandleSubmittable } from '@/hooks/login-hooks'; -import api from '@/utils/api'; -import { getAuthorization } from '@/utils/authorization-util'; -import { UploadOutlined } from '@ant-design/icons'; +import { FileUploader } from '@/components/file-uploader'; +import { ButtonLoading } from '@/components/ui/button'; import { - Button, Form, - FormItemProps, - Input, - InputNumber, - Select, - Switch, - Upload, -} from 'antd'; + FormControl, + FormField, + FormItem, + FormLabel, + FormMessage, +} from '@/components/ui/form'; +import { Input } from '@/components/ui/input'; +import { RAGFlowSelect } from '@/components/ui/select'; +import { Switch } from '@/components/ui/switch'; +import { Textarea } from '@/components/ui/textarea'; +import { useSetModalState } from '@/hooks/common-hooks'; +import { useSetSelectedRecord } from '@/hooks/logic-hooks'; +import { zodResolver } from '@hookform/resolvers/zod'; import { UploadChangeParam, UploadFile } from 'antd/es/upload'; -import { pick } from 'lodash'; -import { Link } from 'lucide-react'; -import React, { useCallback, useState } from 'react'; +import React, { useCallback, useMemo, useState } from 'react'; +import { useForm } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; +import { z } from 'zod'; import { BeginQueryType } from '../constant'; import { BeginQuery } from '../interface'; -import { PopoverForm } from './popover-form'; -import styles from './index.less'; +export const BeginQueryComponentMap = { + [BeginQueryType.Line]: 'string', + [BeginQueryType.Paragraph]: 'string', + [BeginQueryType.Options]: 'string', + [BeginQueryType.File]: 'file', + [BeginQueryType.Integer]: 'number', + [BeginQueryType.Boolean]: 'boolean', +}; + +const StringFields = [ + BeginQueryType.Line, + BeginQueryType.Paragraph, + BeginQueryType.Options, +]; interface IProps { parameters: BeginQuery[]; @@ -34,6 +46,8 @@ interface IProps { submitButtonDisabled?: boolean; } +const values = {}; + const DebugContent = ({ parameters, ok, @@ -42,15 +56,48 @@ const DebugContent = ({ submitButtonDisabled = false, }: IProps) => { const { t } = useTranslation(); - const [form] = Form.useForm(); + + const FormSchema = useMemo(() => { + const obj = parameters.reduce((pre, cur, idx) => { + const type = cur.type; + let fieldSchema; + if (StringFields.some((x) => x === type)) { + fieldSchema = z.string(); + } else if (type === BeginQueryType.Boolean) { + fieldSchema = z.boolean(); + } else if (type === BeginQueryType.Integer) { + fieldSchema = z.coerce.number(); + } else { + fieldSchema = z.instanceof(File); + } + + if (cur.optional) { + fieldSchema.optional(); + } + + pre[idx.toString()] = fieldSchema; + + return pre; + }, {}); + + return z.object(obj); + }, [parameters]); + + const form = useForm({ + defaultValues: values, + resolver: zodResolver(FormSchema), + }); + const { visible, hideModal: hidePopover, switchVisible, showModal: showPopover, } = useSetModalState(); + const { setRecord, currentRecord } = useSetSelectedRecord(); - 
const { submittable } = useHandleSubmittable(form); + // const { submittable } = useHandleSubmittable(form); + const submittable = true; const [isUploading, setIsUploading] = useState(false); const handleShowPopover = useCallback( @@ -79,8 +126,8 @@ const DebugContent = ({ ); const renderWidget = useCallback( - (q: BeginQuery, idx: number) => { - const props: FormItemProps & { key: number } = { + (q: BeginQuery, idx: string) => { + const props = { key: idx, label: q.name ?? q.key, name: idx, @@ -89,80 +136,119 @@ const DebugContent = ({ props.rules = [{ required: true }]; } - const urlList: { url: string; result: string }[] = - form.getFieldValue(idx) || []; + // const urlList: { url: string; result: string }[] = + // form.getFieldValue(idx) || []; + + const urlList: { url: string; result: string }[] = []; const BeginQueryTypeMap = { [BeginQueryType.Line]: ( - - - + ( + + {props.label} + + + + + + )} + /> ), [BeginQueryType.Paragraph]: ( - - - + ( + + {props.label} + + + + + + )} + /> ), [BeginQueryType.Options]: ( - - - + ( + + {props.label} + + ({ label: x, value: x })) ?? [] + } + {...field} + > + + + + )} + /> ), [BeginQueryType.File]: ( - -
    - - - - - - 0 ? 'mb-1' : ''} - noStyle - > - - - - -
    -
    - + ( +
    + + {t('assistantAvatar')} + + + + + +
    + )} + />
    ), [BeginQueryType.Integer]: ( - - - + ( + + {props.label} + + + + + + )} + /> ), [BeginQueryType.Boolean]: ( - - - + ( + + {props.label} + + + + + + )} + /> ), }; @@ -171,66 +257,53 @@ const DebugContent = ({ BeginQueryTypeMap[BeginQueryType.Paragraph] ); }, - [form, handleShowPopover, onChange, switchVisible, t, visible], + [form, t], ); - const onOk = useCallback(async () => { - const values = await form.validateFields(); - const nextValues = Object.entries(values).map(([key, value]) => { - const item = parameters[Number(key)]; - let nextValue = value; - if (Array.isArray(value)) { - nextValue = ``; - - value.forEach((x) => { - nextValue += - x?.originFileObj instanceof File - ? `${x.name}\n${x.response?.data}\n----\n` - : `${x.url}\n${x.result}\n----\n`; - }); - } - return { ...item, value: nextValue }; - }); + const onSubmit = useCallback( + (values: z.infer) => { + console.log('🚀 ~ values:', values); + return values; + const nextValues = Object.entries(values).map(([key, value]) => { + const item = parameters[Number(key)]; + let nextValue = value; + if (Array.isArray(value)) { + nextValue = ``; + + value.forEach((x) => { + nextValue += + x?.originFileObj instanceof File + ? `${x.name}\n${x.response?.data}\n----\n` + : `${x.url}\n${x.result}\n----\n`; + }); + } + return { ...item, value: nextValue }; + }); - ok(nextValues); - }, [form, ok, parameters]); + ok(nextValues); + }, + [ok, parameters], + ); return ( <> -
    - { - if (name === 'urlForm') { - const { basicForm } = forms; - const urlInfo = basicForm.getFieldValue(currentRecord) || []; - basicForm.setFieldsValue({ - [currentRecord]: [...urlInfo, { ...values, name: values.url }], - }); - hidePopover(); - } - }} - > -
    +
    + + {parameters.map((x, idx) => { - return renderWidget(x, idx); + return
    {renderWidget(x, idx.toString())}
    ; })} - - + + {t(isNext ? 'common.next' : 'flow.run')} + + +
    - ); }; diff --git a/web/src/pages/agent/debug-content/popover-form.tsx b/web/src/pages/agent/debug-content/popover-form.tsx index 557e3185bc3..9465d903b46 100644 --- a/web/src/pages/agent/debug-content/popover-form.tsx +++ b/web/src/pages/agent/debug-content/popover-form.tsx @@ -1,74 +1,103 @@ +import { + Form, + FormControl, + FormField, + FormItem, + FormMessage, +} from '@/components/ui/form'; +import { Input } from '@/components/ui/input'; +import { Popover, PopoverContent } from '@/components/ui/popover'; import { useParseDocument } from '@/hooks/document-hooks'; -import { useResetFormOnCloseModal } from '@/hooks/logic-hooks'; import { IModalProps } from '@/interfaces/common'; -import { Button, Form, Input, Popover } from 'antd'; +import { zodResolver } from '@hookform/resolvers/zod'; import { PropsWithChildren } from 'react'; +import { useForm } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; +import { z } from 'zod'; const reg = /^(((ht|f)tps?):\/\/)?([^!@#$%^&*?.\s-]([^!@#$%^&*?.\s]{0,63}[^!@#$%^&*?.\s])?\.)+[a-z]{2,6}\/?/; +const FormSchema = z.object({ + url: z.string(), + result: z.any(), +}); + +const values = { + url: '', + result: null, +}; + export const PopoverForm = ({ children, visible, switchVisible, }: PropsWithChildren>) => { - const [form] = Form.useForm(); + const form = useForm({ + defaultValues: values, + resolver: zodResolver(FormSchema), + }); const { parseDocument, loading } = useParseDocument(); const { t } = useTranslation(); - useResetFormOnCloseModal({ - form, - visible, - }); + // useResetFormOnCloseModal({ + // form, + // visible, + // }); - const onOk = async () => { - const values = await form.validateFields(); + async function onSubmit(values: z.infer) { const val = values.url; if (reg.test(val)) { const ret = await parseDocument(val); if (ret?.data?.code === 0) { - form.setFieldValue('result', ret?.data?.data); - form.submit(); + form.setValue('result', ret?.data?.data); } } - }; + } const content = ( -
    - - e.preventDefault()} - placeholder={t('flow.pasteFileLink')} - suffix={ - - } + + + ( + + + e.preventDefault()} + placeholder={t('flow.pasteFileLink')} + // suffix={ + // + // } + /> + + + + )} + /> + <>} /> - - + ); return ( - + {children} + {content} ); }; diff --git a/web/src/pages/agent/form-sheet/next.tsx b/web/src/pages/agent/form-sheet/next.tsx index 7335844c2d6..fb500458efd 100644 --- a/web/src/pages/agent/form-sheet/next.tsx +++ b/web/src/pages/agent/form-sheet/next.tsx @@ -9,20 +9,14 @@ import { useTranslate } from '@/hooks/common-hooks'; import { IModalProps } from '@/interfaces/common'; import { RAGFlowNodeType } from '@/interfaces/database/flow'; import { cn } from '@/lib/utils'; -import { zodResolver } from '@hookform/resolvers/zod'; -import { get, isPlainObject, lowerFirst } from 'lodash'; +import { lowerFirst } from 'lodash'; import { Play, X } from 'lucide-react'; -import { useEffect, useRef } from 'react'; -import { useForm } from 'react-hook-form'; -import { BeginId, Operator, operatorMap } from '../constant'; -import { FlowFormContext } from '../context'; +import { BeginId, Operator } from '../constant'; +import { AgentFormContext } from '../context'; import { RunTooltip } from '../flow-tooltip'; -import { useHandleFormValuesChange, useHandleNodeNameChange } from '../hooks'; +import { useHandleNodeNameChange } from '../hooks'; import OperatorIcon from '../operator-icon'; -import { - buildCategorizeListFromObject, - needsSingleStepDebugging, -} from '../utils'; +import { needsSingleStepDebugging } from '../utils'; import SingleDebugDrawer from './single-debug-drawer'; import { useFormConfigMap } from './use-form-config-map'; @@ -51,61 +45,21 @@ const FormSheet = ({ const OperatorForm = currentFormMap.component ?? EmptyContent; - const form = useForm({ - defaultValues: currentFormMap.defaultValues, - resolver: zodResolver(currentFormMap.schema), - }); - const { name, handleNameBlur, handleNameChange } = useHandleNodeNameChange({ id: node?.id, data: node?.data, }); - const previousId = useRef(node?.id); - const { t } = useTranslate('flow'); - const { handleValuesChange } = useHandleFormValuesChange( - operatorName, - node?.id, - form, - ); - - useEffect(() => { - if (visible && !form.formState.isDirty) { - if (node?.id !== previousId.current) { - form.reset(); - form.clearErrors(); - } - - if (operatorName === Operator.Categorize) { - const items = buildCategorizeListFromObject( - get(node, 'data.form.category_description', {}), - ); - const formData = node?.data?.form; - if (isPlainObject(formData)) { - // form.setFieldsValue({ ...formData, items }); - form.reset({ ...formData, items }); - } - } else { - // form.setFieldsValue(node?.data?.form); - form.reset(node?.data?.form); - } - previousId.current = node?.id; - } - }, [visible, form, node?.data?.form, node?.id, node, operatorName]); - return ( - - + +
    - +
    {node?.id === BeginId ? ( @@ -132,15 +86,11 @@ const FormSheet = ({ {t(`${lowerFirst(operatorName)}Description`)}
    -
    +
    {visible && ( - - - + + + )}
    diff --git a/web/src/pages/agent/form-sheet/use-form-config-map.tsx b/web/src/pages/agent/form-sheet/use-form-config-map.tsx index 4371b747171..d887bf58c35 100644 --- a/web/src/pages/agent/form-sheet/use-form-config-map.tsx +++ b/web/src/pages/agent/form-sheet/use-form-config-map.tsx @@ -1,6 +1,9 @@ +import { LlmSettingSchema } from '@/components/llm-setting-items/next'; +import { CodeTemplateStrMap, ProgrammingLanguage } from '@/constants/agent'; import { useTranslation } from 'react-i18next'; import { z } from 'zod'; import { Operator } from '../constant'; +import AgentForm from '../form/agent-form'; import AkShareForm from '../form/akshare-form'; import AnswerForm from '../form/answer-form'; import ArXivForm from '../form/arxiv-form'; @@ -9,6 +12,7 @@ import BaiduForm from '../form/baidu-form'; import BeginForm from '../form/begin-form'; import BingForm from '../form/bing-form'; import CategorizeForm from '../form/categorize-form'; +import CodeForm from '../form/code-form'; import CrawlerForm from '../form/crawler-form'; import DeepLForm from '../form/deepl-form'; import DuckDuckGoForm from '../form/duckduckgo-form'; @@ -30,6 +34,7 @@ import RetrievalForm from '../form/retrieval-form/next'; import RewriteQuestionForm from '../form/rewrite-question-form'; import SwitchForm from '../form/switch-form'; import TemplateForm from '../form/template-form'; +import ToolForm from '../form/tool-form'; import TuShareForm from '../form/tushare-form'; import WenCaiForm from '../form/wencai-form'; import WikipediaForm from '../form/wikipedia-form'; @@ -43,18 +48,27 @@ export function useFormConfigMap() { component: BeginForm, defaultValues: {}, schema: z.object({ - name: z + enablePrologue: z.boolean().optional(), + prologue: z .string() .min(1, { message: t('common.namePlaceholder'), }) - .trim(), - age: z - .string() - .min(1, { - message: t('common.namePlaceholder'), - }) - .trim(), + .trim() + .optional(), + mode: z.string(), + query: z + .array( + z.object({ + key: z.string(), + type: z.string(), + value: z.string(), + optional: z.boolean(), + name: z.string(), + options: z.array(z.union([z.number(), z.string(), z.boolean()])), + }), + ) + .optional(), }), }, [Operator.Retrieval]: { @@ -99,20 +113,40 @@ export function useFormConfigMap() { }, [Operator.Categorize]: { component: CategorizeForm, - defaultValues: { message_history_window_size: 1 }, + defaultValues: {}, schema: z.object({ - message_history_window_size: z.number(), + parameter: z.string().optional(), + ...LlmSettingSchema, + message_history_window_size: z.coerce.number(), items: z.array( - z.object({ - name: z.string().min(1, t('flow.nameMessage')).trim(), - }), + z + .object({ + name: z.string().min(1, t('flow.nameMessage')).trim(), + description: z.string().optional(), + // examples: z + // .array( + // z.object({ + // value: z.string(), + // }), + // ) + // .optional(), + }) + .optional(), ), }), }, [Operator.Message]: { component: MessageForm, defaultValues: {}, - schema: z.object({}), + schema: z.object({ + content: z + .array( + z.object({ + value: z.string(), + }), + ) + .optional(), + }), }, [Operator.Relevant]: { component: RelevantForm, @@ -130,6 +164,41 @@ export function useFormConfigMap() { language: z.string(), }), }, + [Operator.Code]: { + component: CodeForm, + defaultValues: { + lang: ProgrammingLanguage.Python, + script: CodeTemplateStrMap[ProgrammingLanguage.Python], + arguments: [], + }, + schema: z.object({ + lang: z.string(), + script: z.string(), + arguments: z.array( + z.object({ name: z.string(), 
component_id: z.string() }), + ), + return: z.union([ + z + .array(z.object({ name: z.string(), component_id: z.string() })) + .optional(), + z.object({ name: z.string(), component_id: z.string() }), + ]), + }), + }, + [Operator.WaitingDialogue]: { + component: CodeForm, + defaultValues: {}, + schema: z.object({ + arguments: z.array( + z.object({ name: z.string(), component_id: z.string() }), + ), + }), + }, + [Operator.Agent]: { + component: AgentForm, + defaultValues: {}, + schema: z.object({}), + }, [Operator.Baidu]: { component: BaiduForm, defaultValues: { top_n: 10 }, @@ -301,6 +370,11 @@ export function useFormConfigMap() { defaultValues: {}, schema: z.object({}), }, + [Operator.Tool]: { + component: ToolForm, + defaultValues: {}, + schema: z.object({}), + }, }; return FormConfigMap; diff --git a/web/src/pages/agent/form-sheet/use-values.ts b/web/src/pages/agent/form-sheet/use-values.ts new file mode 100644 index 00000000000..eccee9c7737 --- /dev/null +++ b/web/src/pages/agent/form-sheet/use-values.ts @@ -0,0 +1,42 @@ +import { RAGFlowNodeType } from '@/interfaces/database/flow'; +import { get, isEmpty, isPlainObject, omit } from 'lodash'; +import { useMemo, useRef } from 'react'; +import { Operator } from '../constant'; +import { buildCategorizeListFromObject, convertToObjectArray } from '../utils'; +import { useFormConfigMap } from './use-form-config-map'; + +export function useValues(node?: RAGFlowNodeType, isDirty?: boolean) { + const operatorName: Operator = node?.data.label as Operator; + const previousId = useRef(node?.id); + + const FormConfigMap = useFormConfigMap(); + + const currentFormMap = FormConfigMap[operatorName]; + + const values = useMemo(() => { + const formData = node?.data?.form; + if (operatorName === Operator.Categorize) { + const items = buildCategorizeListFromObject( + get(node, 'data.form.category_description', {}), + ); + if (isPlainObject(formData)) { + console.info('xxx'); + const nextValues = { + ...omit(formData, 'category_description'), + items, + }; + + return nextValues; + } + } else if (operatorName === Operator.Message) { + return { + ...formData, + content: convertToObjectArray(formData.content), + }; + } else { + return isEmpty(formData) ? currentFormMap : formData; + } + }, [currentFormMap, node, operatorName]); + + return values; +} diff --git a/web/src/pages/agent/form/agent-form/agent-tools.tsx b/web/src/pages/agent/form/agent-form/agent-tools.tsx new file mode 100644 index 00000000000..25bbead6d71 --- /dev/null +++ b/web/src/pages/agent/form/agent-form/agent-tools.tsx @@ -0,0 +1,53 @@ +import { BlockButton } from '@/components/ui/button'; +import { cn } from '@/lib/utils'; +import { PencilLine, X } from 'lucide-react'; +import { PropsWithChildren } from 'react'; +import { ToolPopover } from './tool-popover'; +import { useDeleteAgentNodeTools } from './tool-popover/use-update-tools'; +import { useGetAgentToolNames } from './use-get-tools'; + +export function ToolCard({ + children, + className, + ...props +}: PropsWithChildren & React.HTMLAttributes) { + return ( +
+ {children} +
+ ); +} + +export function AgentTools() { + const { toolNames } = useGetAgentToolNames(); + const { deleteNodeTool } = useDeleteAgentNodeTools(); + + return ( +
    + Tools +
      + {toolNames.map((x) => ( + + {x} +
      + + +
      +
      + ))} +
    + + Add Tool + +
    + ); +} diff --git a/web/src/pages/agent/form/agent-form/dynamic-prompt.tsx b/web/src/pages/agent/form/agent-form/dynamic-prompt.tsx new file mode 100644 index 00000000000..1cda9fbd508 --- /dev/null +++ b/web/src/pages/agent/form/agent-form/dynamic-prompt.tsx @@ -0,0 +1,93 @@ +import { BlockButton, Button } from '@/components/ui/button'; +import { + FormControl, + FormField, + FormItem, + FormLabel, + FormMessage, +} from '@/components/ui/form'; +import { RAGFlowSelect } from '@/components/ui/select'; +import { X } from 'lucide-react'; +import { memo } from 'react'; +import { useFieldArray, useFormContext } from 'react-hook-form'; +import { useTranslation } from 'react-i18next'; +import { PromptRole } from '../../constant'; +import { PromptEditor } from '../components/prompt-editor'; + +const options = [ + { label: 'User', value: PromptRole.User }, + { label: 'Assistant', value: PromptRole.Assistant }, +]; + +const DynamicPrompt = () => { + const { t } = useTranslation(); + const form = useFormContext(); + const name = 'prompts'; + + const { fields, append, remove } = useFieldArray({ + name: name, + control: form.control, + }); + + return ( + + {t('flow.msg')} +
    + {fields.map((field, index) => ( +
    +
    + ( + + + + + + + + )} + /> + + ( + + +
    + +
    +
    +
    + )} + /> +
    + +
    + ))} +
    + + append({ content: '', role: PromptRole.User })} + > + Add + +
    + ); +}; + +export default memo(DynamicPrompt); diff --git a/web/src/pages/agent/form/agent-form/dynamic-tool.tsx b/web/src/pages/agent/form/agent-form/dynamic-tool.tsx new file mode 100644 index 00000000000..afda465b652 --- /dev/null +++ b/web/src/pages/agent/form/agent-form/dynamic-tool.tsx @@ -0,0 +1,63 @@ +import { BlockButton, Button } from '@/components/ui/button'; +import { + FormControl, + FormField, + FormItem, + FormMessage, +} from '@/components/ui/form'; +import { X } from 'lucide-react'; +import { memo } from 'react'; +import { useFieldArray, useFormContext } from 'react-hook-form'; +import { PromptEditor } from '../components/prompt-editor'; + +const DynamicTool = () => { + const form = useFormContext(); + const name = 'tools'; + + const { fields, append, remove } = useFieldArray({ + name: name, + control: form.control, + }); + + return ( + +
    + {fields.map((field, index) => ( +
    +
    + ( + + +
    + +
    +
    +
    + )} + /> +
    + +
    + ))} +
    + + append({ component_name: '' })}> + Add + +
    + ); +}; + +export default memo(DynamicTool); diff --git a/web/src/pages/agent/form/agent-form/index.tsx b/web/src/pages/agent/form/agent-form/index.tsx new file mode 100644 index 00000000000..4825f20c908 --- /dev/null +++ b/web/src/pages/agent/form/agent-form/index.tsx @@ -0,0 +1,131 @@ +import { FormContainer } from '@/components/form-container'; +import { LargeModelFormField } from '@/components/large-model-form-field'; +import { LlmSettingSchema } from '@/components/llm-setting-items/next'; +import { MessageHistoryWindowSizeFormField } from '@/components/message-history-window-size-item'; +import { BlockButton } from '@/components/ui/button'; +import { + Form, + FormControl, + FormField, + FormItem, + FormLabel, +} from '@/components/ui/form'; +import { zodResolver } from '@hookform/resolvers/zod'; +import { Position } from '@xyflow/react'; +import { useContext, useMemo } from 'react'; +import { useForm } from 'react-hook-form'; +import { useTranslation } from 'react-i18next'; +import { z } from 'zod'; +import { Operator, initialAgentValues } from '../../constant'; +import { AgentInstanceContext } from '../../context'; +import { INextOperatorForm } from '../../interface'; +import { Output } from '../components/output'; +import { PromptEditor } from '../components/prompt-editor'; +import { AgentTools } from './agent-tools'; +import { useValues } from './use-values'; +import { useWatchFormChange } from './use-watch-change'; + +const FormSchema = z.object({ + sys_prompt: z.string(), + prompts: z.string().optional(), + // prompts: z + // .array( + // z.object({ + // role: z.string(), + // content: z.string(), + // }), + // ) + // .optional(), + message_history_window_size: z.coerce.number(), + tools: z + .array( + z.object({ + component_name: z.string(), + }), + ) + .optional(), + ...LlmSettingSchema, +}); + +const AgentForm = ({ node }: INextOperatorForm) => { + const { t } = useTranslation(); + + const defaultValues = useValues(node); + + const outputList = useMemo(() => { + return [ + { title: 'content', type: initialAgentValues.outputs.content.type }, + ]; + }, []); + + const form = useForm({ + defaultValues: defaultValues, + resolver: zodResolver(FormSchema), + }); + + useWatchFormChange(node?.id, form); + + const { addCanvasNode } = useContext(AgentInstanceContext); + + return ( +
    + { + e.preventDefault(); + }} + > + + + ( + + Prompt + + + + + )} + /> + + + + {/* */} + ( + + +
    + +
    +
    +
    + )} + /> +
    + + + + Add Agent + + + +
    + + ); +}; + +export default AgentForm; diff --git a/web/src/pages/agent/form/agent-form/tool-popover/index.tsx b/web/src/pages/agent/form/agent-form/tool-popover/index.tsx new file mode 100644 index 00000000000..7c6bbda79f7 --- /dev/null +++ b/web/src/pages/agent/form/agent-form/tool-popover/index.tsx @@ -0,0 +1,47 @@ +import { + Popover, + PopoverContent, + PopoverTrigger, +} from '@/components/ui/popover'; +import { Operator } from '@/pages/agent/constant'; +import { AgentFormContext, AgentInstanceContext } from '@/pages/agent/context'; +import { Position } from '@xyflow/react'; +import { PropsWithChildren, useCallback, useContext } from 'react'; +import { useDeleteToolNode } from '../use-delete-tool-node'; +import { useGetAgentToolNames } from '../use-get-tools'; +import { ToolCommand } from './tool-command'; +import { useUpdateAgentNodeTools } from './use-update-tools'; + +export function ToolPopover({ children }: PropsWithChildren) { + const { addCanvasNode } = useContext(AgentInstanceContext); + const node = useContext(AgentFormContext); + const { updateNodeTools } = useUpdateAgentNodeTools(); + const { toolNames } = useGetAgentToolNames(); + const { deleteToolNode } = useDeleteToolNode(); + + const handleChange = useCallback( + (value: string[]) => { + if (Array.isArray(value) && node?.id) { + updateNodeTools(value); + if (value.length > 0) { + addCanvasNode(Operator.Tool, { + position: Position.Bottom, + nodeId: node?.id, + })(); + } else { + deleteToolNode(node.id); // TODO: The tool node should be derived from the agent tools data + } + } + }, + [addCanvasNode, deleteToolNode, node?.id, updateNodeTools], + ); + + return ( + + {children} + + + + + ); +} diff --git a/web/src/pages/agent/form/agent-form/tool-popover/tool-command.tsx b/web/src/pages/agent/form/agent-form/tool-popover/tool-command.tsx new file mode 100644 index 00000000000..5d948fe4715 --- /dev/null +++ b/web/src/pages/agent/form/agent-form/tool-popover/tool-command.tsx @@ -0,0 +1,113 @@ +import { Calendar, CheckIcon } from 'lucide-react'; + +import { + Command, + CommandEmpty, + CommandGroup, + CommandInput, + CommandItem, + CommandList, +} from '@/components/ui/command'; +import { cn } from '@/lib/utils'; +import { Operator } from '@/pages/agent/constant'; +import { useCallback, useEffect, useState } from 'react'; + +const Menus = [ + { + label: 'Search', + list: [ + Operator.Tavily, + Operator.Google, + Operator.Bing, + Operator.DuckDuckGo, + Operator.Wikipedia, + Operator.YahooFinance, + Operator.PubMed, + Operator.GoogleScholar, + ], + }, + { + label: 'Communication', + list: [Operator.Email], + }, + { + label: 'Productivity', + list: [], + }, + { + label: 'Developer', + list: [ + Operator.GitHub, + Operator.ExeSQL, + Operator.Invoke, + Operator.Crawler, + Operator.Code, + ], + }, +]; + +type ToolCommandProps = { + value?: string[]; + onChange?(values: string[]): void; +}; + +export function ToolCommand({ value, onChange }: ToolCommandProps) { + const [currentValue, setCurrentValue] = useState([]); + + const toggleOption = useCallback( + (option: string) => { + const newSelectedValues = currentValue.includes(option) + ? currentValue.filter((value) => value !== option) + : [...currentValue, option]; + setCurrentValue(newSelectedValues); + onChange?.(newSelectedValues); + }, + [currentValue, onChange], + ); + + useEffect(() => { + if (Array.isArray(value)) { + setCurrentValue(value); + } + }, [value]); + + return ( + + + + No results found. 
+ {Menus.map((x) => ( + + {x.list.map((y) => { + const isSelected = currentValue.includes(y); + return ( + toggleOption(y)} + > +
    + +
    + {/* {option.icon && ( + + )} */} + {/* {option.label} */} + + {y} +
    + ); + })} +
    + ))} +
    +
    + ); +} diff --git a/web/src/pages/agent/form/agent-form/tool-popover/use-update-tools.ts b/web/src/pages/agent/form/agent-form/tool-popover/use-update-tools.ts new file mode 100644 index 00000000000..3bcf844847b --- /dev/null +++ b/web/src/pages/agent/form/agent-form/tool-popover/use-update-tools.ts @@ -0,0 +1,67 @@ +import { IAgentForm } from '@/interfaces/database/agent'; +import { AgentFormContext } from '@/pages/agent/context'; +import useGraphStore from '@/pages/agent/store'; +import { get } from 'lodash'; +import { useCallback, useContext, useMemo } from 'react'; +import { useDeleteToolNode } from '../use-delete-tool-node'; + +export function useGetNodeTools() { + const node = useContext(AgentFormContext); + + return useMemo(() => { + const tools: IAgentForm['tools'] = get(node, 'data.form.tools'); + return tools; + }, [node]); +} + +export function useUpdateAgentNodeTools() { + const { updateNodeForm } = useGraphStore((state) => state); + const node = useContext(AgentFormContext); + const tools = useGetNodeTools(); + + const updateNodeTools = useCallback( + (value: string[]) => { + if (node?.id) { + const nextValue = value.reduce((pre, cur) => { + const tool = tools.find((x) => x.component_name === cur); + pre.push(tool ? tool : { component_name: cur, params: {} }); + return pre; + }, []); + + updateNodeForm(node?.id, nextValue, ['tools']); + } + }, + [node?.id, tools, updateNodeForm], + ); + + const deleteNodeTool = useCallback( + (value: string) => { + updateNodeTools([value]); + }, + [updateNodeTools], + ); + + return { updateNodeTools, deleteNodeTool }; +} + +export function useDeleteAgentNodeTools() { + const { updateNodeForm } = useGraphStore((state) => state); + const tools = useGetNodeTools(); + const node = useContext(AgentFormContext); + const { deleteToolNode } = useDeleteToolNode(); + + const deleteNodeTool = useCallback( + (value: string) => () => { + const nextTools = tools.filter((x) => x.component_name !== value); + if (node?.id) { + updateNodeForm(node?.id, nextTools, ['tools']); + if (nextTools.length === 0) { + deleteToolNode(node?.id); + } + } + }, + [deleteToolNode, node?.id, tools, updateNodeForm], + ); + + return { deleteNodeTool }; +} diff --git a/web/src/pages/agent/form/agent-form/use-delete-tool-node.ts b/web/src/pages/agent/form/agent-form/use-delete-tool-node.ts new file mode 100644 index 00000000000..b3227459581 --- /dev/null +++ b/web/src/pages/agent/form/agent-form/use-delete-tool-node.ts @@ -0,0 +1,24 @@ +import { useCallback } from 'react'; +import { NodeHandleId } from '../../constant'; +import useGraphStore from '../../store'; + +export function useDeleteToolNode() { + const { edges, deleteEdgeById, deleteNodeById } = useGraphStore( + (state) => state, + ); + const deleteToolNode = useCallback( + (agentNodeId: string) => { + const edge = edges.find( + (x) => x.source === agentNodeId && x.sourceHandle === NodeHandleId.Tool, + ); + + if (edge) { + deleteEdgeById(edge.id); + deleteNodeById(edge.target); + } + }, + [deleteEdgeById, deleteNodeById, edges], + ); + + return { deleteToolNode }; +} diff --git a/web/src/pages/agent/form/agent-form/use-get-tools.ts b/web/src/pages/agent/form/agent-form/use-get-tools.ts new file mode 100644 index 00000000000..c9f37113de5 --- /dev/null +++ b/web/src/pages/agent/form/agent-form/use-get-tools.ts @@ -0,0 +1,15 @@ +import { IAgentForm } from '@/interfaces/database/agent'; +import { get } from 'lodash'; +import { useContext, useMemo } from 'react'; +import { AgentFormContext } from '../../context'; + +export 
function useGetAgentToolNames() { + const node = useContext(AgentFormContext); + + const toolNames = useMemo(() => { + const tools: IAgentForm['tools'] = get(node, 'data.form.tools', []); + return tools.map((x) => x.component_name); + }, [node]); + + return { toolNames }; +} diff --git a/web/src/pages/agent/form/agent-form/use-values.ts b/web/src/pages/agent/form/agent-form/use-values.ts new file mode 100644 index 00000000000..3e6b057a6a6 --- /dev/null +++ b/web/src/pages/agent/form/agent-form/use-values.ts @@ -0,0 +1,30 @@ +import { useFetchModelId } from '@/hooks/logic-hooks'; +import { RAGFlowNodeType } from '@/interfaces/database/flow'; +import { get, isEmpty } from 'lodash'; +import { useMemo } from 'react'; +import { initialAgentValues } from '../../constant'; + +export function useValues(node?: RAGFlowNodeType) { + const llmId = useFetchModelId(); + + const defaultValues = useMemo( + () => ({ + ...initialAgentValues, + llm_id: llmId, + prompts: '', + }), + [llmId], + ); + + const values = useMemo(() => { + const formData = node?.data?.form; + + if (isEmpty(formData)) { + return defaultValues; + } + + return { ...formData, prompts: get(formData, 'prompts.0.content', '') }; + }, [defaultValues, node?.data?.form]); + + return values; +} diff --git a/web/src/pages/agent/form/agent-form/use-watch-change.ts b/web/src/pages/agent/form/agent-form/use-watch-change.ts new file mode 100644 index 00000000000..8640a45184a --- /dev/null +++ b/web/src/pages/agent/form/agent-form/use-watch-change.ts @@ -0,0 +1,22 @@ +import { useEffect } from 'react'; +import { UseFormReturn, useWatch } from 'react-hook-form'; +import { PromptRole } from '../../constant'; +import useGraphStore from '../../store'; + +export function useWatchFormChange(id?: string, form?: UseFormReturn) { + let values = useWatch({ control: form?.control }); + const updateNodeForm = useGraphStore((state) => state.updateNodeForm); + + useEffect(() => { + // Manually triggered form updates are synchronized to the canvas + if (id && form?.formState.isDirty) { + values = form?.getValues(); + let nextValues: any = { + ...values, + prompts: [{ role: PromptRole.User, content: values.prompts }], + }; + + updateNodeForm(id, nextValues); + } + }, [form?.formState.isDirty, id, updateNodeForm, values]); +} diff --git a/web/src/pages/agent/form/begin-form/begin-dynamic-options.tsx b/web/src/pages/agent/form/begin-form/begin-dynamic-options.tsx index bcc9c578843..d71da8d2bbf 100644 --- a/web/src/pages/agent/form/begin-form/begin-dynamic-options.tsx +++ b/web/src/pages/agent/form/begin-form/begin-dynamic-options.tsx @@ -1,68 +1,57 @@ -import { MinusCircleOutlined, PlusOutlined } from '@ant-design/icons'; -import { Button, Form, Input } from 'antd'; +'use client'; + +import { BlockButton, Button } from '@/components/ui/button'; +import { + FormControl, + FormField, + FormItem, + FormMessage, +} from '@/components/ui/form'; +import { Input } from '@/components/ui/input'; +import { X } from 'lucide-react'; +import { useFieldArray, useFormContext } from 'react-hook-form'; +import { useTranslation } from 'react-i18next'; + +export function BeginDynamicOptions() { + const { t } = useTranslation(); + const form = useFormContext(); + const name = 'options'; + + const { fields, remove, append } = useFieldArray({ + name: name, + control: form.control, + }); -const BeginDynamicOptions = () => { return ( - { - if (!names || names.length < 1) { - return Promise.reject(new Error('At least 1 option')); - } - }, - }, - ]} - > - {(fields, { add, remove }, { errors 
}) => ( - <> - {fields.map((field, index) => ( - - - - - {fields.length > 1 ? ( - remove(field.name)} - /> - ) : null} - - ))} - - - - - - )} - +
    + ); + })} + append({ value: '' })} type="button"> + {t('flow.addVariable')} + +
    ); -}; - -export default BeginDynamicOptions; +} diff --git a/web/src/pages/agent/form/begin-form/index.less b/web/src/pages/agent/form/begin-form/index.less deleted file mode 100644 index 0a03d47433c..00000000000 --- a/web/src/pages/agent/form/begin-form/index.less +++ /dev/null @@ -1,24 +0,0 @@ -.dynamicInputVariable { - background-color: #ebe9e950; - :global(.ant-collapse-content) { - background-color: #f6f6f657; - } - :global(.ant-collapse-content-box) { - padding: 0 !important; - } - margin-bottom: 20px; - .title { - font-weight: 600; - font-size: 16px; - } - - .addButton { - color: rgb(22, 119, 255); - font-weight: 600; - } -} - -.addButton { - color: rgb(22, 119, 255); - font-weight: 600; -} diff --git a/web/src/pages/agent/form/begin-form/index.tsx b/web/src/pages/agent/form/begin-form/index.tsx index 8df1181f1fa..95c519abbab 100644 --- a/web/src/pages/agent/form/begin-form/index.tsx +++ b/web/src/pages/agent/form/begin-form/index.tsx @@ -1,20 +1,74 @@ -import { PlusOutlined } from '@ant-design/icons'; -import { Button, Form, Input } from 'antd'; -import { useCallback } from 'react'; +import { Collapse } from '@/components/collapse'; +import { Button } from '@/components/ui/button'; +import { + Form, + FormControl, + FormField, + FormItem, + FormLabel, + FormMessage, +} from '@/components/ui/form'; +import { RAGFlowSelect } from '@/components/ui/select'; +import { Switch } from '@/components/ui/switch'; +import { Textarea } from '@/components/ui/textarea'; +import { FormTooltip } from '@/components/ui/tooltip'; +import { buildSelectOptions } from '@/utils/component-util'; +import { zodResolver } from '@hookform/resolvers/zod'; +import { Plus } from 'lucide-react'; +import { useForm, useWatch } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; -import { BeginQuery, IOperatorForm } from '../../interface'; -import { useEditQueryRecord } from './hooks'; -import { ModalForm } from './paramater-modal'; -import QueryTable from './query-table'; +import { z } from 'zod'; +import { AgentDialogueMode } from '../../constant'; +import { INextOperatorForm } from '../../interface'; +import { ParameterDialog } from './parameter-dialog'; +import { QueryTable } from './query-table'; +import { useEditQueryRecord } from './use-edit-query'; +import { useValues } from './use-values'; +import { useWatchFormChange } from './use-watch-change'; -import styles from './index.less'; +const ModeOptions = buildSelectOptions([ + AgentDialogueMode.Conversational, + AgentDialogueMode.Task, +]); -type FieldType = { - prologue?: string; -}; - -const BeginForm = ({ onValuesChange, form }: IOperatorForm) => { +const BeginForm = ({ node }: INextOperatorForm) => { const { t } = useTranslation(); + + const values = useValues(node); + + const FormSchema = z.object({ + enablePrologue: z.boolean().optional(), + prologue: z.string().trim().optional(), + mode: z.string(), + inputs: z + .array( + z.object({ + key: z.string(), + type: z.string(), + value: z.string(), + optional: z.boolean(), + name: z.string(), + options: z.array(z.union([z.number(), z.string(), z.boolean()])), + }), + ) + .optional(), + }); + + const form = useForm({ + defaultValues: values, + resolver: zodResolver(FormSchema), + }); + + useWatchFormChange(node?.id, form); + + const inputs = useWatch({ control: form.control, name: 'inputs' }); + const mode = useWatch({ control: form.control, name: 'mode' }); + + const enablePrologue = useWatch({ + control: form.control, + name: 'enablePrologue', + }); + const { ok, currentRecord, @@ 
-22,89 +76,115 @@ const BeginForm = ({ onValuesChange, form }: IOperatorForm) => { hideModal, showModal, otherThanCurrentQuery, + handleDeleteRecord, } = useEditQueryRecord({ form, - onValuesChange, + node, }); - const handleDeleteRecord = useCallback( - (idx: number) => { - const query = form?.getFieldValue('query') || []; - const nextQuery = query.filter( - (item: BeginQuery, index: number) => index !== idx, - ); - onValuesChange?.( - { query: nextQuery }, - { query: nextQuery, prologue: form?.getFieldValue('prologue') }, - ); - }, - [form, onValuesChange], - ); - return ( - { - if (name === 'queryForm') { - ok(values as BeginQuery); - } - }} - > -
    - - name={'prologue'} - label={t('chat.setAnOpener')} - tooltip={t('chat.setAnOpenerTip')} - initialValue={t('chat.setAnOpenerInitial')} - > - - +
    + + ( + + Mode + + + + + + )} + /> + {mode === AgentDialogueMode.Conversational && ( + ( + + + {t('flow.openingSwitch')} + + + + + + + )} + /> + )} + {enablePrologue && ( + ( + + + {t('flow.openingCopy')} + + + + + + + )} + /> + )} {/* Create a hidden field to make Form instance record this */} - - - - prevValues.query !== curValues.query +
    } + /> + + {t('flow.input')} + +
    + } + rightContent={ + } > - {({ getFieldValue }) => { - const query: BeginQuery[] = getFieldValue('query') || []; - return ( - - ); - }} - + + - {visible && ( - + submit={ok} + > )} - + ); }; diff --git a/web/src/pages/agent/form/begin-form/paramater-modal.tsx b/web/src/pages/agent/form/begin-form/paramater-modal.tsx deleted file mode 100644 index 7d689601b14..00000000000 --- a/web/src/pages/agent/form/begin-form/paramater-modal.tsx +++ /dev/null @@ -1,124 +0,0 @@ -import { useResetFormOnCloseModal } from '@/hooks/logic-hooks'; -import { IModalProps } from '@/interfaces/common'; -import { Form, Input, Modal, Select, Switch } from 'antd'; -import { DefaultOptionType } from 'antd/es/select'; -import { useEffect, useMemo } from 'react'; -import { useTranslation } from 'react-i18next'; -import { BeginQueryType, BeginQueryTypeIconMap } from '../../constant'; -import { BeginQuery } from '../../interface'; -import BeginDynamicOptions from './begin-dynamic-options'; - -export const ModalForm = ({ - visible, - initialValue, - hideModal, - otherThanCurrentQuery, -}: IModalProps & { - initialValue: BeginQuery; - otherThanCurrentQuery: BeginQuery[]; -}) => { - const { t } = useTranslation(); - const [form] = Form.useForm(); - const options = useMemo(() => { - return Object.values(BeginQueryType).reduce( - (pre, cur) => { - const Icon = BeginQueryTypeIconMap[cur]; - - return [ - ...pre, - { - label: ( -
    - - {cur} -
    - ), - value: cur, - }, - ]; - }, - [], - ); - }, []); - - useResetFormOnCloseModal({ - form, - visible: visible, - }); - - useEffect(() => { - form.setFieldsValue(initialValue); - }, [form, initialValue]); - - const onOk = () => { - form.submit(); - }; - - return ( - -
    - - - - - - - - - - - prevValues.type !== curValues.type - } - > - {({ getFieldValue }) => { - const type: BeginQueryType = getFieldValue('type'); - return ( - type === BeginQueryType.Options && ( - - ) - ); - }} - -
    -
    - ); -}; diff --git a/web/src/pages/agent/form/begin-form/parameter-dialog.tsx b/web/src/pages/agent/form/begin-form/parameter-dialog.tsx new file mode 100644 index 00000000000..1ed08a132f3 --- /dev/null +++ b/web/src/pages/agent/form/begin-form/parameter-dialog.tsx @@ -0,0 +1,217 @@ +import { Button } from '@/components/ui/button'; +import { + Dialog, + DialogContent, + DialogFooter, + DialogHeader, + DialogTitle, +} from '@/components/ui/dialog'; +import { + Form, + FormControl, + FormField, + FormItem, + FormLabel, + FormMessage, +} from '@/components/ui/form'; +import { Input } from '@/components/ui/input'; +import { RAGFlowSelect, RAGFlowSelectOptionType } from '@/components/ui/select'; +import { Switch } from '@/components/ui/switch'; +import { IModalProps } from '@/interfaces/common'; +import { zodResolver } from '@hookform/resolvers/zod'; +import { isEmpty } from 'lodash'; +import { useEffect, useMemo } from 'react'; +import { useForm, useWatch } from 'react-hook-form'; +import { useTranslation } from 'react-i18next'; +import { z } from 'zod'; +import { BeginQueryType, BeginQueryTypeIconMap } from '../../constant'; +import { BeginQuery } from '../../interface'; +import { BeginDynamicOptions } from './begin-dynamic-options'; + +type ModalFormProps = { + initialValue: BeginQuery; + otherThanCurrentQuery: BeginQuery[]; + submit(values: any): void; +}; + +const FormId = 'BeginParameterForm'; + +function ParameterForm({ + initialValue, + otherThanCurrentQuery, + submit, +}: ModalFormProps) { + const FormSchema = z.object({ + type: z.string(), + key: z + .string() + .trim() + .min(1) + .refine( + (value) => + !value || !otherThanCurrentQuery.some((x) => x.key === value), + { message: 'The key cannot be repeated!' }, + ), + optional: z.boolean(), + name: z.string().trim().min(1), + options: z + .array(z.object({ value: z.string().or(z.boolean()).or(z.number()) })) + .optional(), + }); + + const form = useForm>({ + resolver: zodResolver(FormSchema), + mode: 'onChange', + defaultValues: { + type: BeginQueryType.Line, + optional: false, + key: '', + name: '', + options: [], + }, + }); + + const options = useMemo(() => { + return Object.values(BeginQueryType).reduce( + (pre, cur) => { + const Icon = BeginQueryTypeIconMap[cur]; + + return [ + ...pre, + { + label: ( +
    + + {cur} +
    + ), + value: cur, + }, + ]; + }, + [], + ); + }, []); + + const type = useWatch({ + control: form.control, + name: 'type', + }); + + useEffect(() => { + if (!isEmpty(initialValue)) { + form.reset({ + ...initialValue, + options: initialValue.options?.map((x) => ({ value: x })), + }); + } + }, [form, initialValue]); + + function onSubmit(data: z.infer) { + const values = { ...data, options: data.options?.map((x) => x.value) }; + console.log('🚀 ~ onSubmit ~ values:', values); + + submit(values); + } + + return ( +
    + + ( + + Type + + + + + + )} + /> + ( + + Key + + + + + + )} + /> + ( + + Name + + + + + + )} + /> + ( + + Optional + + + + + + )} + /> + {type === BeginQueryType.Options && ( + + )} + + + ); +} + +export function ParameterDialog({ + initialValue, + hideModal, + otherThanCurrentQuery, + submit, +}: ModalFormProps & IModalProps) { + const { t } = useTranslation(); + + return ( + + + + {t('flow.variableSettings')} + + + + + + + + ); +} diff --git a/web/src/pages/agent/form/begin-form/query-table.tsx b/web/src/pages/agent/form/begin-form/query-table.tsx index c7614e682b4..463a2fa645b 100644 --- a/web/src/pages/agent/form/begin-form/query-table.tsx +++ b/web/src/pages/agent/form/begin-form/query-table.tsx @@ -1,10 +1,38 @@ -import { DeleteOutlined, EditOutlined } from '@ant-design/icons'; -import type { TableProps } from 'antd'; -import { Collapse, Space, Table, Tooltip } from 'antd'; -import { BeginQuery } from '../../interface'; +'use client'; + +import { + ColumnDef, + ColumnFiltersState, + SortingState, + VisibilityState, + flexRender, + getCoreRowModel, + getFilteredRowModel, + getPaginationRowModel, + getSortedRowModel, + useReactTable, +} from '@tanstack/react-table'; +import { Pencil, Trash2 } from 'lucide-react'; +import * as React from 'react'; +import { TableEmpty } from '@/components/table-skeleton'; +import { Button } from '@/components/ui/button'; +import { + Table, + TableBody, + TableCell, + TableHead, + TableHeader, + TableRow, +} from '@/components/ui/table'; +import { + Tooltip, + TooltipContent, + TooltipTrigger, +} from '@/components/ui/tooltip'; +import { cn } from '@/lib/utils'; import { useTranslation } from 'react-i18next'; -import styles from './index.less'; +import { BeginQuery } from '../../interface'; interface IProps { data: BeginQuery[]; @@ -12,81 +40,150 @@ interface IProps { showModal(index: number, record: BeginQuery): void; } -const QueryTable = ({ data, deleteRecord, showModal }: IProps) => { +export function QueryTable({ data = [], deleteRecord, showModal }: IProps) { const { t } = useTranslation(); - const columns: TableProps['columns'] = [ + const [sorting, setSorting] = React.useState([]); + const [columnFilters, setColumnFilters] = React.useState( + [], + ); + const [columnVisibility, setColumnVisibility] = + React.useState({}); + + const columns: ColumnDef[] = [ { - title: 'Key', - dataIndex: 'key', - key: 'key', - ellipsis: { - showTitle: false, + accessorKey: 'key', + header: 'key', + meta: { cellClassName: 'max-w-16' }, + cell: ({ row }) => { + const key: string = row.getValue('key'); + return ( + + +
    {key}
    +
    + +

    {key}

    +
    +
    + ); }, - render: (key) => ( - - {key} - - ), }, { - title: t('flow.name'), - dataIndex: 'name', - key: 'name', - ellipsis: { - showTitle: false, + accessorKey: 'name', + header: t('flow.name'), + meta: { cellClassName: 'max-w-20' }, + cell: ({ row }) => { + const name: string = row.getValue('name'); + return ( + + +
    {name}
    +
    + +

    {name}

    +
    +
    + ); }, - render: (name) => ( - - {name} - - ), }, { - title: t('flow.type'), - dataIndex: 'type', - key: 'type', + accessorKey: 'type', + header: t('flow.type'), + cell: ({ row }) =>
    {row.getValue('type')}
    , }, { - title: t('flow.optional'), - dataIndex: 'optional', - key: 'optional', - render: (optional) => (optional ? 'Yes' : 'No'), + accessorKey: 'optional', + header: t('flow.optional'), + cell: ({ row }) =>
    {row.getValue('optional') ? 'Yes' : 'No'}
    , }, { - title: t('common.action'), - key: 'action', - render: (_, record, idx) => ( - - showModal(idx, record)} /> - deleteRecord(idx)} - /> - - ), + id: 'actions', + enableHiding: false, + header: t('common.action'), + cell: ({ row }) => { + const record = row.original; + const idx = row.index; + + return ( +
    + + +
    + ); + }, }, ]; + const table = useReactTable({ + data, + columns, + onSortingChange: setSorting, + onColumnFiltersChange: setColumnFilters, + getCoreRowModel: getCoreRowModel(), + getPaginationRowModel: getPaginationRowModel(), + getSortedRowModel: getSortedRowModel(), + getFilteredRowModel: getFilteredRowModel(), + onColumnVisibilityChange: setColumnVisibility, + state: { + sorting, + columnFilters, + columnVisibility, + }, + }); + return ( - {t('flow.input')}, - children: ( - - columns={columns} - dataSource={data} - pagination={false} - /> - ), - }, - ]} - /> +
    +
    + + + {table.getHeaderGroups().map((headerGroup) => ( + + {headerGroup.headers.map((header) => { + return ( + + {header.isPlaceholder + ? null + : flexRender( + header.column.columnDef.header, + header.getContext(), + )} + + ); + })} + + ))} + + + {table.getRowModel().rows?.length ? ( + table.getRowModel().rows.map((row) => ( + + {row.getVisibleCells().map((cell) => ( + + {flexRender( + cell.column.columnDef.cell, + cell.getContext(), + )} + + ))} + + )) + ) : ( + + )} + +
    +
    +
    ); -}; - -export default QueryTable; +} diff --git a/web/src/pages/agent/form/begin-form/hooks.ts b/web/src/pages/agent/form/begin-form/use-edit-query.ts similarity index 50% rename from web/src/pages/agent/form/begin-form/hooks.ts rename to web/src/pages/agent/form/begin-form/use-edit-query.ts index b045f5dc50d..a1bec8b3d10 100644 --- a/web/src/pages/agent/form/begin-form/hooks.ts +++ b/web/src/pages/agent/form/begin-form/use-edit-query.ts @@ -1,32 +1,34 @@ import { useSetModalState } from '@/hooks/common-hooks'; import { useSetSelectedRecord } from '@/hooks/logic-hooks'; import { useCallback, useMemo, useState } from 'react'; -import { BeginQuery, IOperatorForm } from '../../interface'; +import { BeginQuery, INextOperatorForm } from '../../interface'; -export const useEditQueryRecord = ({ form, onValuesChange }: IOperatorForm) => { +export const useEditQueryRecord = ({ form, node }: INextOperatorForm) => { const { setRecord, currentRecord } = useSetSelectedRecord(); const { visible, hideModal, showModal } = useSetModalState(); const [index, setIndex] = useState(-1); const otherThanCurrentQuery = useMemo(() => { - const query: BeginQuery[] = form?.getFieldValue('query') || []; - return query.filter((item, idx) => idx !== index); + const inputs: BeginQuery[] = form?.getValues('inputs') || []; + return inputs.filter((item, idx) => idx !== index); }, [form, index]); const handleEditRecord = useCallback( (record: BeginQuery) => { - const query: BeginQuery[] = form?.getFieldValue('query') || []; + const inputs: BeginQuery[] = form?.getValues('inputs') || []; + console.log('🚀 ~ useEditQueryRecord ~ inputs:', inputs); const nextQuery: BeginQuery[] = - index > -1 ? query.toSpliced(index, 1, record) : [...query, record]; + index > -1 ? inputs.toSpliced(index, 1, record) : [...inputs, record]; + + form.setValue('inputs', nextQuery, { + shouldDirty: true, + shouldTouch: true, + }); - onValuesChange?.( - { query: nextQuery }, - { query: nextQuery, prologue: form?.getFieldValue('prologue') }, - ); hideModal(); }, - [form, hideModal, index, onValuesChange], + [form, hideModal, index], ); const handleShowModal = useCallback( @@ -38,6 +40,18 @@ export const useEditQueryRecord = ({ form, onValuesChange }: IOperatorForm) => { [setRecord, showModal], ); + const handleDeleteRecord = useCallback( + (idx: number) => { + const inputs = form?.getValues('inputs') || []; + const nextQuery = inputs.filter( + (item: BeginQuery, index: number) => index !== idx, + ); + + form.setValue('inputs', nextQuery, { shouldDirty: true }); + }, + [form], + ); + return { ok: handleEditRecord, currentRecord, @@ -46,5 +60,6 @@ export const useEditQueryRecord = ({ form, onValuesChange }: IOperatorForm) => { hideModal, showModal: handleShowModal, otherThanCurrentQuery, + handleDeleteRecord, }; }; diff --git a/web/src/pages/agent/form/begin-form/use-values.ts b/web/src/pages/agent/form/begin-form/use-values.ts new file mode 100644 index 00000000000..10326bae83c --- /dev/null +++ b/web/src/pages/agent/form/begin-form/use-values.ts @@ -0,0 +1,34 @@ +import { RAGFlowNodeType } from '@/interfaces/database/flow'; +import { isEmpty } from 'lodash'; +import { useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { AgentDialogueMode } from '../../constant'; +import { buildBeginInputListFromObject } from './utils'; + +export function useValues(node?: RAGFlowNodeType) { + const { t } = useTranslation(); + + const defaultValues = useMemo( + () => ({ + enablePrologue: true, + prologue: t('chat.setAnOpenerInitial'), 
+ mode: AgentDialogueMode.Conversational, + inputs: [], + }), + [t], + ); + + const values = useMemo(() => { + const formData = node?.data?.form; + + if (isEmpty(formData)) { + return defaultValues; + } + + const inputs = buildBeginInputListFromObject(formData?.inputs); + + return { ...(formData || {}), inputs }; + }, [defaultValues, node?.data?.form]); + + return values; +} diff --git a/web/src/pages/agent/form/begin-form/use-watch-change.ts b/web/src/pages/agent/form/begin-form/use-watch-change.ts new file mode 100644 index 00000000000..3dc45126589 --- /dev/null +++ b/web/src/pages/agent/form/begin-form/use-watch-change.ts @@ -0,0 +1,31 @@ +import { omit } from 'lodash'; +import { useEffect } from 'react'; +import { UseFormReturn, useWatch } from 'react-hook-form'; +import { BeginQuery } from '../../interface'; +import useGraphStore from '../../store'; + +function transferInputsArrayToObject(inputs: BeginQuery[] = []) { + return inputs.reduce>>((pre, cur) => { + pre[cur.key] = omit(cur, 'key'); + + return pre; + }, {}); +} + +export function useWatchFormChange(id?: string, form?: UseFormReturn) { + let values = useWatch({ control: form?.control }); + const updateNodeForm = useGraphStore((state) => state.updateNodeForm); + + useEffect(() => { + if (id && form?.formState.isDirty) { + values = form?.getValues(); + + const nextValues = { + ...values, + inputs: transferInputsArrayToObject(values.inputs), + }; + + updateNodeForm(id, nextValues); + } + }, [form?.formState.isDirty, id, updateNodeForm, values]); +} diff --git a/web/src/pages/agent/form/begin-form/utils.ts b/web/src/pages/agent/form/begin-form/utils.ts new file mode 100644 index 00000000000..36038c4f6d2 --- /dev/null +++ b/web/src/pages/agent/form/begin-form/utils.ts @@ -0,0 +1,14 @@ +import { BeginQuery } from '../../interface'; + +export function buildBeginInputListFromObject( + inputs: Record>, +) { + return Object.entries(inputs || {}).reduce( + (pre, [key, value]) => { + pre.push({ ...(value || {}), key }); + + return pre; + }, + [], + ); +} diff --git a/web/src/pages/agent/form/categorize-form/dynamic-categorize.tsx b/web/src/pages/agent/form/categorize-form/dynamic-categorize.tsx index 6302e032a22..7deb4f4ff1e 100644 --- a/web/src/pages/agent/form/categorize-form/dynamic-categorize.tsx +++ b/web/src/pages/agent/form/categorize-form/dynamic-categorize.tsx @@ -12,8 +12,7 @@ import { FormMessage, } from '@/components/ui/form'; import { Input } from '@/components/ui/input'; -import { RAGFlowSelect } from '@/components/ui/select'; -import { Textarea } from '@/components/ui/textarea'; +import { BlurTextarea } from '@/components/ui/textarea'; import { useTranslate } from '@/hooks/common-hooks'; import { PlusOutlined } from '@ant-design/icons'; import { useUpdateNodeInternals } from '@xyflow/react'; @@ -23,6 +22,7 @@ import { ChevronsUpDown, X } from 'lucide-react'; import { ChangeEventHandler, FocusEventHandler, + memo, useCallback, useEffect, useState, @@ -30,6 +30,7 @@ import { import { UseFormReturn, useFieldArray, useFormContext } from 'react-hook-form'; import { Operator } from '../../constant'; import { useBuildFormSelectOptions } from '../../form-hooks'; +import DynamicExample from './dynamic-example'; interface IProps { nodeId?: string; @@ -55,7 +56,7 @@ const getOtherFieldValues = ( x !== form.getValues(`${formListName}.${index}.${latestField}`), ); -const NameInput = ({ +const InnerNameInput = ({ value, onChange, otherNames, @@ -104,7 +105,9 @@ const NameInput = ({ ); }; -const FormSet = ({ nodeId, index }: IProps & { 
index: number }) => { +const NameInput = memo(InnerNameInput); + +const InnerFormSet = ({ nodeId, index }: IProps & { index: number }) => { const form = useFormContext(); const { t } = useTranslate('flow'); const buildCategorizeToOptions = useBuildFormSelectOptions( @@ -152,61 +155,19 @@ const FormSet = ({ nodeId, index }: IProps & { index: number }) => { {t('description')} - + + + )} + /> + {index === 0 ? ( + + ) : ( + + )} +
    + ))} +
    + +
    + ); +}; + +export default memo(DynamicExample); diff --git a/web/src/pages/agent/form/categorize-form/index.tsx b/web/src/pages/agent/form/categorize-form/index.tsx index 220c17deddc..a4faef25453 100644 --- a/web/src/pages/agent/form/categorize-form/index.tsx +++ b/web/src/pages/agent/form/categorize-form/index.tsx @@ -1,22 +1,67 @@ +import { FormContainer } from '@/components/form-container'; import { LargeModelFormField } from '@/components/large-model-form-field'; +import { LlmSettingSchema } from '@/components/llm-setting-items/next'; import { MessageHistoryWindowSizeFormField } from '@/components/message-history-window-size-item'; import { Form } from '@/components/ui/form'; +import { zodResolver } from '@hookform/resolvers/zod'; +import { useForm } from 'react-hook-form'; +import { useTranslation } from 'react-i18next'; +import { z } from 'zod'; import { INextOperatorForm } from '../../interface'; -import { DynamicInputVariable } from '../components/next-dynamic-input-variable'; +import { QueryVariable } from '../components/query-variable'; import DynamicCategorize from './dynamic-categorize'; +import { useValues } from './use-values'; +import { useWatchFormChange } from './use-watch-change'; + +const CategorizeForm = ({ node }: INextOperatorForm) => { + const { t } = useTranslation(); + + const values = useValues(node); + + const FormSchema = z.object({ + query: z.string().optional(), + parameter: z.string().optional(), + ...LlmSettingSchema, + message_history_window_size: z.coerce.number(), + items: z.array( + z + .object({ + name: z.string().min(1, t('flow.nameMessage')).trim(), + description: z.string().optional(), + examples: z + .array( + z.object({ + value: z.string(), + }), + ) + .optional(), + }) + .optional(), + ), + }); + + const form = useForm({ + defaultValues: values, + resolver: zodResolver(FormSchema), + }); + + useWatchFormChange(node?.id, form); -const CategorizeForm = ({ form, node }: INextOperatorForm) => { return (
    { e.preventDefault(); }} > - - - + + + + +
    diff --git a/web/src/pages/agent/form/categorize-form/use-values.ts b/web/src/pages/agent/form/categorize-form/use-values.ts new file mode 100644 index 00000000000..ef75575dcc6 --- /dev/null +++ b/web/src/pages/agent/form/categorize-form/use-values.ts @@ -0,0 +1,38 @@ +import { ModelVariableType } from '@/constants/knowledge'; +import { RAGFlowNodeType } from '@/interfaces/database/flow'; +import { get, isEmpty, isPlainObject, omit } from 'lodash'; +import { useMemo } from 'react'; +import { buildCategorizeListFromObject } from '../../utils'; + +const defaultValues = { + parameter: ModelVariableType.Precise, + message_history_window_size: 1, + temperatureEnabled: true, + topPEnabled: true, + presencePenaltyEnabled: true, + frequencyPenaltyEnabled: true, + maxTokensEnabled: true, + items: [], +}; + +export function useValues(node?: RAGFlowNodeType) { + const values = useMemo(() => { + const formData = node?.data?.form; + if (isEmpty(formData)) { + return defaultValues; + } + const items = buildCategorizeListFromObject( + get(node, 'data.form.category_description', {}), + ); + if (isPlainObject(formData)) { + const nextValues = { + ...omit(formData, 'category_description'), + items, + }; + + return nextValues; + } + }, [node]); + + return values; +} diff --git a/web/src/pages/agent/form/categorize-form/use-watch-change.ts b/web/src/pages/agent/form/categorize-form/use-watch-change.ts new file mode 100644 index 00000000000..6f01dc1a9d9 --- /dev/null +++ b/web/src/pages/agent/form/categorize-form/use-watch-change.ts @@ -0,0 +1,30 @@ +import { omit } from 'lodash'; +import { useEffect } from 'react'; +import { UseFormReturn, useWatch } from 'react-hook-form'; +import useGraphStore from '../../store'; +import { buildCategorizeObjectFromList } from '../../utils'; + +export function useWatchFormChange(id?: string, form?: UseFormReturn) { + let values = useWatch({ control: form?.control }); + const updateNodeForm = useGraphStore((state) => state.updateNodeForm); + + useEffect(() => { + // Manually triggered form updates are synchronized to the canvas + if (id && form?.formState.isDirty) { + values = form?.getValues(); + let nextValues: any = values; + + const categoryDescription = Array.isArray(values.items) + ? 
buildCategorizeObjectFromList(values.items) + : {}; + if (categoryDescription) { + nextValues = { + ...omit(values, 'items'), + category_description: categoryDescription, + }; + } + + updateNodeForm(id, nextValues); + } + }, [form?.formState.isDirty, id, updateNodeForm, values]); +} diff --git a/web/src/pages/agent/form/code-form/dynamic-input-variable.tsx b/web/src/pages/agent/form/code-form/dynamic-input-variable.tsx new file mode 100644 index 00000000000..c207d280436 --- /dev/null +++ b/web/src/pages/agent/form/code-form/dynamic-input-variable.tsx @@ -0,0 +1,59 @@ +import { BlockButton } from '@/components/ui/button'; +import { RAGFlowNodeType } from '@/interfaces/database/flow'; +import { MinusCircleOutlined } from '@ant-design/icons'; +import { Form, Input, Select } from 'antd'; +import { useTranslation } from 'react-i18next'; +import { useBuildVariableOptions } from '../../hooks/use-get-begin-query'; +import { FormCollapse } from '../components/dynamic-input-variable'; + +type DynamicInputVariableProps = { + name?: string; + node?: RAGFlowNodeType; +}; + +export const DynamicInputVariable = ({ + name = 'arguments', + node, +}: DynamicInputVariableProps) => { + const { t } = useTranslation(); + + const valueOptions = useBuildVariableOptions(node?.id, node?.parentId); + + return ( + + + {(fields, { add, remove }) => ( + <> + {fields.map(({ key, name, ...restField }) => ( +
    + + + + + + + remove(name)} /> +
    + ))} + + add()}> + {t('flow.addVariable')} + + + + )} +
    +
    + ); +}; diff --git a/web/src/pages/agent/form/code-form/index.tsx b/web/src/pages/agent/form/code-form/index.tsx new file mode 100644 index 00000000000..71e0f6f52cd --- /dev/null +++ b/web/src/pages/agent/form/code-form/index.tsx @@ -0,0 +1,174 @@ +import Editor, { loader } from '@monaco-editor/react'; +import { INextOperatorForm } from '../../interface'; + +import { FormContainer } from '@/components/form-container'; +import { + Form, + FormControl, + FormField, + FormItem, + FormLabel, + FormMessage, +} from '@/components/ui/form'; +import { Input } from '@/components/ui/input'; +import { RAGFlowSelect } from '@/components/ui/select'; +import { ProgrammingLanguage } from '@/constants/agent'; +import { ICodeForm } from '@/interfaces/database/flow'; +import { zodResolver } from '@hookform/resolvers/zod'; +import { useForm } from 'react-hook-form'; +import { useTranslation } from 'react-i18next'; +import { z } from 'zod'; +import { + DynamicInputVariable, + TypeOptions, + VariableTitle, +} from './next-variable'; +import { useValues } from './use-values'; +import { + useHandleLanguageChange, + useWatchFormChange, +} from './use-watch-change'; + +loader.config({ paths: { vs: '/vs' } }); + +const options = [ + ProgrammingLanguage.Python, + ProgrammingLanguage.Javascript, +].map((x) => ({ value: x, label: x })); + +const CodeForm = ({ node }: INextOperatorForm) => { + const formData = node?.data.form as ICodeForm; + const { t } = useTranslation(); + const values = useValues(node); + + const FormSchema = z.object({ + lang: z.string(), + script: z.string(), + arguments: z.array( + z.object({ name: z.string(), component_id: z.string() }), + ), + return: z.union([ + z + .array(z.object({ name: z.string(), component_id: z.string() })) + .optional(), + z.object({ name: z.string(), component_id: z.string() }), + ]), + }); + + const form = useForm({ + defaultValues: values, + resolver: zodResolver(FormSchema), + }); + + useWatchFormChange(node?.id, form); + + const handleLanguageChange = useHandleLanguageChange(node?.id, form); + + return ( +
    + { + e.preventDefault(); + }} + > + + ( + + + Code + ( + + + { + field.onChange(val); + handleLanguageChange(val); + }} + options={options} + /> + + + + )} + /> + + + + + + + )} + /> + + {formData.lang === ProgrammingLanguage.Python ? ( + + ) : ( +
    + + + ( + + Name + + + + + + )} + /> + ( + + Type + + + + + + )} + /> + +
    + )} + + + ); +}; + +export default CodeForm; diff --git a/web/src/pages/agent/form/code-form/next-variable.tsx b/web/src/pages/agent/form/code-form/next-variable.tsx new file mode 100644 index 00000000000..fe668f41b75 --- /dev/null +++ b/web/src/pages/agent/form/code-form/next-variable.tsx @@ -0,0 +1,118 @@ +'use client'; + +import { FormContainer } from '@/components/form-container'; +import { BlockButton, Button } from '@/components/ui/button'; +import { + FormControl, + FormField, + FormItem, + FormMessage, +} from '@/components/ui/form'; +import { BlurInput } from '@/components/ui/input'; +import { RAGFlowSelect } from '@/components/ui/select'; +import { Separator } from '@/components/ui/separator'; +import { RAGFlowNodeType } from '@/interfaces/database/flow'; +import { X } from 'lucide-react'; +import { ReactNode } from 'react'; +import { useFieldArray, useFormContext } from 'react-hook-form'; +import { useTranslation } from 'react-i18next'; +import { useBuildVariableOptions } from '../../hooks/use-get-begin-query'; + +interface IProps { + node?: RAGFlowNodeType; + name?: string; +} + +export const TypeOptions = [ + 'String', + 'Number', + 'Boolean', + 'Array[String]', + 'Array[Number]', + 'Object', +].map((x) => ({ label: x, value: x })); + +export function DynamicVariableForm({ node, name = 'arguments' }: IProps) { + const { t } = useTranslation(); + const form = useFormContext(); + + const { fields, remove, append } = useFieldArray({ + name: name, + control: form.control, + }); + + const valueOptions = useBuildVariableOptions(node?.id, node?.parentId); + + return ( +
    + {fields.map((field, index) => { + const typeField = `${name}.${index}.name`; + return ( +
    + ( + + + + + + + )} + /> + + ( + + + + + + + )} + /> + +
    + ); + })} + append({ name: '', component_id: undefined })} + > + {t('flow.addVariable')} + +
    + ); +} + +export function VariableTitle({ title }: { title: ReactNode }) { + return
    {title}
    ; +} + +export function DynamicInputVariable({ + node, + name, + title, +}: IProps & { title: ReactNode }) { + return ( +
    + + + + +
    + ); +} diff --git a/web/src/pages/agent/form/code-form/use-values.ts b/web/src/pages/agent/form/code-form/use-values.ts new file mode 100644 index 00000000000..9dc7d68d4c1 --- /dev/null +++ b/web/src/pages/agent/form/code-form/use-values.ts @@ -0,0 +1,27 @@ +import { CodeTemplateStrMap, ProgrammingLanguage } from '@/constants/agent'; +import { RAGFlowNodeType } from '@/interfaces/database/flow'; +import { isEmpty } from 'lodash'; +import { useMemo } from 'react'; + +export function useValues(node?: RAGFlowNodeType) { + const defaultValues = useMemo( + () => ({ + lang: ProgrammingLanguage.Python, + script: CodeTemplateStrMap[ProgrammingLanguage.Python], + arguments: [], + }), + [], + ); + + const values = useMemo(() => { + const formData = node?.data?.form; + + if (isEmpty(formData)) { + return defaultValues; + } + + return formData; + }, [defaultValues, node?.data?.form]); + + return values; +} diff --git a/web/src/pages/agent/form/code-form/use-watch-change.ts b/web/src/pages/agent/form/code-form/use-watch-change.ts new file mode 100644 index 00000000000..7c97595602e --- /dev/null +++ b/web/src/pages/agent/form/code-form/use-watch-change.ts @@ -0,0 +1,36 @@ +import { CodeTemplateStrMap, ProgrammingLanguage } from '@/constants/agent'; +import { useCallback, useEffect } from 'react'; +import { UseFormReturn, useWatch } from 'react-hook-form'; +import useGraphStore from '../../store'; + +export function useWatchFormChange(id?: string, form?: UseFormReturn) { + let values = useWatch({ control: form?.control }); + const updateNodeForm = useGraphStore((state) => state.updateNodeForm); + + useEffect(() => { + // Manually triggered form updates are synchronized to the canvas + if (id && form?.formState.isDirty) { + values = form?.getValues(); + let nextValues: any = values; + + updateNodeForm(id, nextValues); + } + }, [form?.formState.isDirty, id, updateNodeForm, values]); +} + +export function useHandleLanguageChange(id?: string, form?: UseFormReturn) { + const updateNodeForm = useGraphStore((state) => state.updateNodeForm); + + const handleLanguageChange = useCallback( + (lang: string) => { + if (id) { + const script = CodeTemplateStrMap[lang as ProgrammingLanguage]; + form?.setValue('script', script); + updateNodeForm(id, script, ['script']); + } + }, + [form, id, updateNodeForm], + ); + + return handleLanguageChange; +} diff --git a/web/src/pages/agent/form/components/dynamic-input-variable.tsx b/web/src/pages/agent/form/components/dynamic-input-variable.tsx index 82082e6b846..a5781fd16f9 100644 --- a/web/src/pages/agent/form/components/dynamic-input-variable.tsx +++ b/web/src/pages/agent/form/components/dynamic-input-variable.tsx @@ -3,7 +3,7 @@ import { MinusCircleOutlined, PlusOutlined } from '@ant-design/icons'; import { Button, Collapse, Flex, Form, Input, Select } from 'antd'; import { PropsWithChildren, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; -import { useBuildComponentIdSelectOptions } from '../../hooks/use-get-begin-query'; +import { useBuildVariableOptions } from '../../hooks/use-get-begin-query'; import styles from './index.less'; @@ -21,10 +21,7 @@ const getVariableName = (type: string) => const DynamicVariableForm = ({ node }: IProps) => { const { t } = useTranslation(); - const valueOptions = useBuildComponentIdSelectOptions( - node?.id, - node?.parentId, - ); + const valueOptions = useBuildVariableOptions(node?.id, node?.parentId); const form = Form.useFormInstance(); const options = [ diff --git 
a/web/src/pages/agent/form/components/next-dynamic-input-variable.tsx b/web/src/pages/agent/form/components/next-dynamic-input-variable.tsx index d22bea406dc..341b49f371a 100644 --- a/web/src/pages/agent/form/components/next-dynamic-input-variable.tsx +++ b/web/src/pages/agent/form/components/next-dynamic-input-variable.tsx @@ -20,7 +20,7 @@ import { RAGFlowNodeType } from '@/interfaces/database/flow'; import { Plus, Trash2 } from 'lucide-react'; import { useFieldArray, useFormContext } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; -import { useBuildComponentIdSelectOptions } from '../../hooks/use-get-begin-query'; +import { useBuildVariableOptions } from '../../hooks/use-get-begin-query'; interface IProps { node?: RAGFlowNodeType; @@ -42,10 +42,7 @@ export function DynamicVariableForm({ node }: IProps) { control: form.control, }); - const valueOptions = useBuildComponentIdSelectOptions( - node?.id, - node?.parentId, - ); + const valueOptions = useBuildVariableOptions(node?.id, node?.parentId); const options = [ { value: VariableType.Reference, label: t('flow.reference') }, diff --git a/web/src/pages/agent/form/components/output.tsx b/web/src/pages/agent/form/components/output.tsx new file mode 100644 index 00000000000..3f7e4ae27d3 --- /dev/null +++ b/web/src/pages/agent/form/components/output.tsx @@ -0,0 +1,26 @@ +export type OutputType = { + title: string; + type: string; +}; + +type OutputProps = { + list: Array; +}; + +export function Output({ list }: OutputProps) { + return ( +
    +
    Output
    +
      + {list.map((x, idx) => ( +
+ {x.title}: {x.type} +
+ ))} +
    +
    + ); +} diff --git a/web/src/pages/agent/form/components/prompt-editor/constant.ts b/web/src/pages/agent/form/components/prompt-editor/constant.ts new file mode 100644 index 00000000000..b6cf30ed9cd --- /dev/null +++ b/web/src/pages/agent/form/components/prompt-editor/constant.ts @@ -0,0 +1 @@ +export const ProgrammaticTag = 'programmatic'; diff --git a/web/src/pages/agent/form/components/prompt-editor/index.css b/web/src/pages/agent/form/components/prompt-editor/index.css new file mode 100644 index 00000000000..8f305064721 --- /dev/null +++ b/web/src/pages/agent/form/components/prompt-editor/index.css @@ -0,0 +1,76 @@ +.typeahead-popover { + background: #fff; + box-shadow: 0px 5px 10px rgba(0, 0, 0, 0.3); + border-radius: 8px; + position: fixed; + z-index: 1000; +} + +.typeahead-popover ul { + list-style: none; + margin: 0; + max-height: 200px; + overflow-y: scroll; +} + +.typeahead-popover ul::-webkit-scrollbar { + display: none; +} + +.typeahead-popover ul { + -ms-overflow-style: none; + scrollbar-width: none; +} + +.typeahead-popover ul li { + margin: 0; + min-width: 180px; + font-size: 14px; + outline: none; + cursor: pointer; + border-radius: 8px; +} + +.typeahead-popover ul li.selected { + background: #eee; +} + +.typeahead-popover li { + margin: 0 8px 0 8px; + color: #050505; + cursor: pointer; + line-height: 16px; + font-size: 15px; + display: flex; + align-content: center; + flex-direction: row; + flex-shrink: 0; + background-color: #fff; + border: 0; +} + +.typeahead-popover li.active { + display: flex; + width: 20px; + height: 20px; + background-size: contain; +} + +.typeahead-popover li .text { + display: flex; + line-height: 20px; + flex-grow: 1; + min-width: 150px; +} + +.typeahead-popover li .icon { + display: flex; + width: 20px; + height: 20px; + user-select: none; + margin-right: 8px; + line-height: 16px; + background-size: contain; + background-repeat: no-repeat; + background-position: center; +} diff --git a/web/src/pages/agent/form/components/prompt-editor/index.tsx b/web/src/pages/agent/form/components/prompt-editor/index.tsx new file mode 100644 index 00000000000..ffda1e8c39b --- /dev/null +++ b/web/src/pages/agent/form/components/prompt-editor/index.tsx @@ -0,0 +1,164 @@ +import { CodeHighlightNode, CodeNode } from '@lexical/code'; +import { + InitialConfigType, + LexicalComposer, +} from '@lexical/react/LexicalComposer'; +import { ContentEditable } from '@lexical/react/LexicalContentEditable'; +import { LexicalErrorBoundary } from '@lexical/react/LexicalErrorBoundary'; +import { RichTextPlugin } from '@lexical/react/LexicalRichTextPlugin'; +import { HeadingNode, QuoteNode } from '@lexical/rich-text'; +import { + $getRoot, + $getSelection, + $nodesOfType, + EditorState, + Klass, + LexicalNode, +} from 'lexical'; + +import { + Tooltip, + TooltipContent, + TooltipTrigger, +} from '@/components/ui/tooltip'; +import { cn } from '@/lib/utils'; +import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'; +import { Variable } from 'lucide-react'; +import { ReactNode, useCallback, useState } from 'react'; +import { useTranslation } from 'react-i18next'; +import theme from './theme'; +import { VariableNode } from './variable-node'; +import { VariableOnChangePlugin } from './variable-on-change-plugin'; +import VariablePickerMenuPlugin from './variable-picker-plugin'; + +// Catch any errors that occur during Lexical updates and log them +// or throw them as needed. 
If you don't throw them, Lexical will +// try to recover gracefully without losing user data. +function onError(error: Error) { + console.error(error); +} + +const Nodes: Array> = [ + HeadingNode, + QuoteNode, + CodeHighlightNode, + CodeNode, + VariableNode, +]; + +type PromptContentProps = { showToolbar?: boolean }; + +type IProps = { + value?: string; + onChange?: (value?: string) => void; + placeholder?: ReactNode; +} & PromptContentProps; + +function PromptContent({ showToolbar = true }: PromptContentProps) { + const [editor] = useLexicalComposerContext(); + const [isBlur, setIsBlur] = useState(false); + const { t } = useTranslation(); + + const insertTextAtCursor = useCallback(() => { + editor.update(() => { + const selection = $getSelection(); + + if (selection !== null) { + selection.insertText(' /'); + } + }); + }, [editor]); + + const handleVariableIconClick = useCallback(() => { + insertTextAtCursor(); + }, [insertTextAtCursor]); + + const handleBlur = useCallback(() => { + setIsBlur(true); + }, []); + + const handleFocus = useCallback(() => { + setIsBlur(false); + }, []); + + return ( +
    + {showToolbar && ( +
    + + + + + + + +

    {t('flow.insertVariableTip')}

    +
    +
    +
    + )} + +
    + ); +} + +export function PromptEditor({ + value, + onChange, + placeholder, + showToolbar, +}: IProps) { + const { t } = useTranslation(); + const initialConfig: InitialConfigType = { + namespace: 'PromptEditor', + theme, + onError, + nodes: Nodes, + }; + + const onValueChange = useCallback( + (editorState: EditorState) => { + editorState?.read(() => { + const listNodes = $nodesOfType(VariableNode); // to be removed + // const allNodes = $dfs(); + console.log('🚀 ~ onChange ~ allNodes:', listNodes); + + const text = $getRoot().getTextContent(); + + onChange?.(text); + }); + }, + [onChange], + ); + + return ( +
    + + + } + placeholder={ +
    + {placeholder || t('common.pleaseInput')} +
    + } + ErrorBoundary={LexicalErrorBoundary} + /> + + +
    +
    + ); +} diff --git a/web/src/pages/agent/form/components/prompt-editor/theme.ts b/web/src/pages/agent/form/components/prompt-editor/theme.ts new file mode 100644 index 00000000000..1cc2bc15528 --- /dev/null +++ b/web/src/pages/agent/form/components/prompt-editor/theme.ts @@ -0,0 +1,43 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + * + */ + +export default { + code: 'editor-code', + heading: { + h1: 'editor-heading-h1', + h2: 'editor-heading-h2', + h3: 'editor-heading-h3', + h4: 'editor-heading-h4', + h5: 'editor-heading-h5', + }, + image: 'editor-image', + link: 'editor-link', + list: { + listitem: 'editor-listitem', + nested: { + listitem: 'editor-nested-listitem', + }, + ol: 'editor-list-ol', + ul: 'editor-list-ul', + }, + ltr: 'ltr', + paragraph: 'editor-paragraph', + placeholder: 'editor-placeholder', + quote: 'editor-quote', + rtl: 'rtl', + text: { + bold: 'editor-text-bold', + code: 'editor-text-code', + hashtag: 'editor-text-hashtag', + italic: 'editor-text-italic', + overflowed: 'editor-text-overflowed', + strikethrough: 'editor-text-strikethrough', + underline: 'editor-text-underline', + underlineStrikethrough: 'editor-text-underlineStrikethrough', + }, +}; diff --git a/web/src/pages/agent/form/components/prompt-editor/variable-node.tsx b/web/src/pages/agent/form/components/prompt-editor/variable-node.tsx new file mode 100644 index 00000000000..e2a8cc29f93 --- /dev/null +++ b/web/src/pages/agent/form/components/prompt-editor/variable-node.tsx @@ -0,0 +1,70 @@ +import i18n from '@/locales/config'; +import { BeginId } from '@/pages/flow/constant'; +import { DecoratorNode, LexicalNode, NodeKey } from 'lexical'; +import { ReactNode } from 'react'; +const prefix = BeginId + '@'; + +export class VariableNode extends DecoratorNode { + __value: string; + __label: string; + + static getType(): string { + return 'variable'; + } + + static clone(node: VariableNode): VariableNode { + return new VariableNode(node.__value, node.__label, node.__key); + } + + constructor(value: string, label: string, key?: NodeKey) { + super(key); + this.__value = value; + this.__label = label; + } + + createDOM(): HTMLElement { + const dom = document.createElement('span'); + dom.className = 'mr-1'; + + return dom; + } + + updateDOM(): false { + return false; + } + + decorate(): ReactNode { + let content: ReactNode = ( + {this.__label} + ); + if (this.__value.startsWith(prefix)) { + content = ( +
    + {i18n.t(`flow.begin`)} / {content} +
    + ); + } + return ( +
    + {content} +
    + ); + } + + getTextContent(): string { + return `{${this.__value}}`; + } +} + +export function $createVariableNode( + value: string, + label: string, +): VariableNode { + return new VariableNode(value, label); +} + +export function $isVariableNode( + node: LexicalNode | null | undefined, +): node is VariableNode { + return node instanceof VariableNode; +} diff --git a/web/src/pages/agent/form/components/prompt-editor/variable-on-change-plugin.tsx b/web/src/pages/agent/form/components/prompt-editor/variable-on-change-plugin.tsx new file mode 100644 index 00000000000..86fa66db4f8 --- /dev/null +++ b/web/src/pages/agent/form/components/prompt-editor/variable-on-change-plugin.tsx @@ -0,0 +1,35 @@ +import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'; +import { EditorState, LexicalEditor } from 'lexical'; +import { useEffect } from 'react'; +import { ProgrammaticTag } from './constant'; + +interface IProps { + onChange: ( + editorState: EditorState, + editor?: LexicalEditor, + tags?: Set, + ) => void; +} + +export function VariableOnChangePlugin({ onChange }: IProps) { + // Access the editor through the LexicalComposerContext + const [editor] = useLexicalComposerContext(); + // Wrap our listener in useEffect to handle the teardown and avoid stale references. + useEffect(() => { + // most listeners return a teardown function that can be called to clean them up. + return editor.registerUpdateListener( + ({ editorState, tags, dirtyElements }) => { + // Check if there is a "programmatic" tag + const isProgrammaticUpdate = tags.has(ProgrammaticTag); + + // The onchange event is only triggered when the data is manually updated + // Otherwise, the content will be displayed incorrectly. + if (dirtyElements.size > 0 && !isProgrammaticUpdate) { + onChange(editorState); + } + }, + ); + }, [editor, onChange]); + + return null; +} diff --git a/web/src/pages/agent/form/components/prompt-editor/variable-picker-plugin.tsx b/web/src/pages/agent/form/components/prompt-editor/variable-picker-plugin.tsx new file mode 100644 index 00000000000..afa54d47740 --- /dev/null +++ b/web/src/pages/agent/form/components/prompt-editor/variable-picker-plugin.tsx @@ -0,0 +1,269 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ * + */ + +import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'; +import { + LexicalTypeaheadMenuPlugin, + MenuOption, + useBasicTypeaheadTriggerMatch, +} from '@lexical/react/LexicalTypeaheadMenuPlugin'; +import { + $createParagraphNode, + $createTextNode, + $getRoot, + $getSelection, + $isRangeSelection, + TextNode, +} from 'lexical'; +import React, { ReactElement, useCallback, useEffect, useRef } from 'react'; +import * as ReactDOM from 'react-dom'; + +import { $createVariableNode } from './variable-node'; + +import { useBuildQueryVariableOptions } from '@/pages/agent/hooks/use-get-begin-query'; +import { ProgrammaticTag } from './constant'; +import './index.css'; +class VariableInnerOption extends MenuOption { + label: string; + value: string; + + constructor(label: string, value: string) { + super(value); + this.label = label; + this.value = value; + } +} + +class VariableOption extends MenuOption { + label: ReactElement | string; + title: string; + options: VariableInnerOption[]; + + constructor( + label: ReactElement | string, + title: string, + options: VariableInnerOption[], + ) { + super(title); + this.label = label; + this.title = title; + this.options = options; + } +} + +function VariablePickerMenuItem({ + index, + option, + selectOptionAndCleanUp, +}: { + index: number; + option: VariableOption; + selectOptionAndCleanUp: ( + option: VariableOption | VariableInnerOption, + ) => void; +}) { + console.info('xxxx'); + return ( +
+
    + {option.title} +
      + {option.options.map((x) => ( +
selectOptionAndCleanUp(x)} + className="hover:bg-slate-300 p-1" + > + {x.label} +
+ ))} +
    +
    +
+ ); +}
    +
      + {nextOptions.map((option, i: number) => ( + + ))} +
    +
    , + anchorElementRef.current, + ) + : null; + }} + /> + ); +} diff --git a/web/src/pages/agent/form/components/query-variable.tsx b/web/src/pages/agent/form/components/query-variable.tsx new file mode 100644 index 00000000000..8cb49b5cc24 --- /dev/null +++ b/web/src/pages/agent/form/components/query-variable.tsx @@ -0,0 +1,37 @@ +import { SelectWithSearch } from '@/components/originui/select-with-search'; +import { + FormControl, + FormField, + FormItem, + FormLabel, + FormMessage, +} from '@/components/ui/form'; +import { useFormContext } from 'react-hook-form'; +import { useTranslation } from 'react-i18next'; +import { useBuildQueryVariableOptions } from '../../hooks/use-get-begin-query'; + +export function QueryVariable() { + const { t } = useTranslation(); + const form = useFormContext(); + + const nextOptions = useBuildQueryVariableOptions(); + + return ( + ( + + {t('flow.query')} + + + + + + )} + /> + ); +} diff --git a/web/src/pages/agent/form/generate-form/index.tsx b/web/src/pages/agent/form/generate-form/index.tsx index d463e682c90..8ca644b876a 100644 --- a/web/src/pages/agent/form/generate-form/index.tsx +++ b/web/src/pages/agent/form/generate-form/index.tsx @@ -1,6 +1,5 @@ -import { NextLLMSelect } from '@/components/llm-select'; +import { NextLLMSelect } from '@/components/llm-select/next'; import { MessageHistoryWindowSizeFormField } from '@/components/message-history-window-size-item'; -import { PromptEditor } from '@/components/prompt-editor'; import { Form, FormControl, @@ -12,6 +11,7 @@ import { import { Switch } from '@/components/ui/switch'; import { useTranslation } from 'react-i18next'; import { INextOperatorForm } from '../../interface'; +import { PromptEditor } from '../components/prompt-editor'; const GenerateForm = ({ form }: INextOperatorForm) => { const { t } = useTranslation(); diff --git a/web/src/pages/agent/form/invoke-form/dynamic-variables.tsx b/web/src/pages/agent/form/invoke-form/dynamic-variables.tsx index 3538b8b728e..f98c04d63ce 100644 --- a/web/src/pages/agent/form/invoke-form/dynamic-variables.tsx +++ b/web/src/pages/agent/form/invoke-form/dynamic-variables.tsx @@ -3,10 +3,11 @@ import { useTranslate } from '@/hooks/common-hooks'; import { DeleteOutlined } from '@ant-design/icons'; import { Button, Collapse, Flex, Input, Select, Table, TableProps } from 'antd'; import { trim } from 'lodash'; -import { useBuildComponentIdSelectOptions } from '../../hooks/use-get-begin-query'; -import { IInvokeVariable, RAGFlowNodeType } from '../../interface'; +import { useBuildVariableOptions } from '../../hooks/use-get-begin-query'; +import { IInvokeVariable } from '../../interface'; import { useHandleOperateParameters } from './hooks'; +import { RAGFlowNodeType } from '@/interfaces/database/flow'; import styles from './index.less'; interface IProps { @@ -24,7 +25,7 @@ const DynamicVariablesForm = ({ node }: IProps) => { const nodeId = node?.id; const { t } = useTranslate('flow'); - const options = useBuildComponentIdSelectOptions(nodeId, node?.parentId); + const options = useBuildVariableOptions(nodeId, node?.parentId); const { dataSource, handleAdd, diff --git a/web/src/pages/agent/form/keyword-extract-form/index.tsx b/web/src/pages/agent/form/keyword-extract-form/index.tsx index 5ec092a35a3..bda5d44f510 100644 --- a/web/src/pages/agent/form/keyword-extract-form/index.tsx +++ b/web/src/pages/agent/form/keyword-extract-form/index.tsx @@ -1,4 +1,4 @@ -import { NextLLMSelect } from '@/components/llm-select'; +import { NextLLMSelect } from 
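Note: parseTextToVariableNodes in the prompt-editor plugin above walks the prompt text with the /{([^}]*)}/g pattern, turning every known {reference} into a VariableNode and leaving unknown references as literal text. A minimal standalone sketch of that splitting step (plain TypeScript, no Lexical; the Segment type is hypothetical and not part of this diff) could look like this:

// Sketch only: mirrors the regex walk in parseTextToVariableNodes, minus the Lexical nodes.
type Segment =
  | { kind: 'text'; text: string }
  | { kind: 'variable'; value: string };

function splitTemplate(text: string, knownValues: Set<string>): Segment[] {
  const regex = /{([^}]*)}/g; // same pattern as the plugin: content within {}
  const segments: Segment[] = [];
  let lastIndex = 0;
  let match: RegExpExecArray | null;

  while ((match = regex.exec(text)) !== null) {
    const { 0: template, 1: content, index } = match;
    // Plain text before the current {reference}
    if (index > lastIndex) {
      segments.push({ kind: 'text', text: text.slice(lastIndex, index) });
    }
    // Known references become variables; unknown ones stay literal text, as in the plugin.
    segments.push(
      knownValues.has(content)
        ? { kind: 'variable', value: content }
        : { kind: 'text', text: template },
    );
    lastIndex = regex.lastIndex;
  }
  // Trailing text after the last match
  if (lastIndex < text.length) {
    segments.push({ kind: 'text', text: text.slice(lastIndex) });
  }
  return segments;
}

For example, splitTemplate('Answer {begin@query} briefly', new Set(['begin@query'])) yields a text segment, a variable segment, and a trailing text segment.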
diff --git a/web/src/pages/agent/form/keyword-extract-form/index.tsx b/web/src/pages/agent/form/keyword-extract-form/index.tsx
index 5ec092a35a3..bda5d44f510 100644
--- a/web/src/pages/agent/form/keyword-extract-form/index.tsx
+++ b/web/src/pages/agent/form/keyword-extract-form/index.tsx
@@ -1,4 +1,4 @@
-import { NextLLMSelect } from '@/components/llm-select';
+import { NextLLMSelect } from '@/components/llm-select/next';
 import { TopNFormField } from '@/components/top-n-item';
 import {
   Form,
diff --git a/web/src/pages/agent/form/message-form/index.tsx b/web/src/pages/agent/form/message-form/index.tsx
index 01e744a91a3..b6f67b5d897 100644
--- a/web/src/pages/agent/form/message-form/index.tsx
+++ b/web/src/pages/agent/form/message-form/index.tsx
@@ -1,4 +1,5 @@
-import { Button } from '@/components/ui/button';
+import { FormContainer } from '@/components/form-container';
+import { BlockButton, Button } from '@/components/ui/button';
 import {
   Form,
   FormControl,
@@ -7,73 +8,95 @@ import {
   FormLabel,
   FormMessage,
 } from '@/components/ui/form';
-import { Textarea } from '@/components/ui/textarea';
-import { PlusCircle, Trash2 } from 'lucide-react';
-import { useFieldArray } from 'react-hook-form';
+import { zodResolver } from '@hookform/resolvers/zod';
+import { X } from 'lucide-react';
+import { useFieldArray, useForm } from 'react-hook-form';
 import { useTranslation } from 'react-i18next';
+import { z } from 'zod';
 import { INextOperatorForm } from '../../interface';
+import { PromptEditor } from '../components/prompt-editor';
+import { useValues } from './use-values';
+import { useWatchFormChange } from './use-watch-change';
 
-const MessageForm = ({ form }: INextOperatorForm) => {
+const MessageForm = ({ node }: INextOperatorForm) => {
   const { t } = useTranslation();
+
+  const values = useValues(node);
+
+  const FormSchema = z.object({
+    content: z
+      .array(
+        z.object({
+          value: z.string(),
+        }),
+      )
+      .optional(),
+  });
+
+  const form = useForm({
+    defaultValues: values,
+    resolver: zodResolver(FormSchema),
+  });
+
+  useWatchFormChange(node?.id, form);
+
   const { fields, append, remove } = useFieldArray({
-    name: 'messages',
+    name: 'content',
     control: form.control,
   });
 
   return (
     {
       e.preventDefault();
     }}
   >
-
-        {t('flow.msg')}
-
-        {fields.map((field, index) => (
-
-            (
-
-              */}
+
+
+
+            )}
+          />
+          {fields.length > 1 && (
+          )}
-          />
-          {fields.length > 1 && (
-
-          )}
-
-        ))}
+
+        ))}
-
-
-
-
+          append({ value: '' })} // "" will cause the inability to add, refer to: https://github.com/orgs/react-hook-form/discussions/8485#discussioncomment-2961861
+        >
+          {t('flow.addMessage')}
+
+
+
+
+  );
diff --git a/web/src/pages/agent/form/message-form/use-values.ts b/web/src/pages/agent/form/message-form/use-values.ts
new file mode 100644
index 00000000000..415314665ff
--- /dev/null
+++ b/web/src/pages/agent/form/message-form/use-values.ts
@@ -0,0 +1,25 @@
+import { RAGFlowNodeType } from '@/interfaces/database/flow';
+import { isEmpty } from 'lodash';
+import { useMemo } from 'react';
+import { convertToObjectArray } from '../../utils';
+
+const defaultValues = {
+  content: [],
+};
+
+export function useValues(node?: RAGFlowNodeType) {
+  const values = useMemo(() => {
+    const formData = node?.data?.form;
+
+    if (isEmpty(formData)) {
+      return defaultValues;
+    }
+
+    return {
+      ...formData,
+      content: convertToObjectArray(formData.content),
+    };
+  }, [node]);
+
+  return values;
+}
diff --git a/web/src/pages/agent/form/message-form/use-watch-change.ts b/web/src/pages/agent/form/message-form/use-watch-change.ts
new file mode 100644
index 00000000000..10c35c653c1
--- /dev/null
+++ b/web/src/pages/agent/form/message-form/use-watch-change.ts
@@ -0,0 +1,24 @@
+import { useEffect } from 'react';
+import { UseFormReturn, useWatch } from 'react-hook-form';
+import useGraphStore from '../../store';
+import { convertToStringArray } from '../../utils';
+
+export function useWatchFormChange(id?: string, form?: UseFormReturn) {
+  let values = useWatch({ control: form?.control });
+  const updateNodeForm = useGraphStore((state) => state.updateNodeForm);
+
+  useEffect(() => {
+    // Manually triggered form updates are synchronized to the canvas
+    if (id && form?.formState.isDirty) {
+      values = form?.getValues();
+      let nextValues: any = values;
+
+      nextValues = {
+        ...values,
+        content: convertToStringArray(values.content),
+      };
+
+      updateNodeForm(id, nextValues);
+    }
+  }, [form?.formState.isDirty, id, updateNodeForm, values]);
+}
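The two hooks above round-trip the message list between the string[] stored on the node form and the { value: string }[] shape that useFieldArray requires. The convertToObjectArray and convertToStringArray helpers they import from ../../utils are not shown in this diff; a plausible sketch, assuming exactly that shape, would be:

// Assumed implementations; the real helpers in web/src/pages/agent/utils are not part of this diff.
export function convertToObjectArray(list?: string[]): Array<{ value: string }> {
  // string[] from the node form -> objects that useFieldArray can track
  return Array.isArray(list) ? list.map((value) => ({ value })) : [];
}

export function convertToStringArray(list?: Array<{ value?: string }>): string[] {
  // field-array objects -> plain string[] written back to the node form
  return Array.isArray(list) ? list.map((x) => x.value ?? '') : [];
}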
diff --git a/web/src/pages/agent/form/retrieval-form/index.tsx b/web/src/pages/agent/form/retrieval-form/index.tsx
deleted file mode 100644
index 4a92a7f94fc..00000000000
--- a/web/src/pages/agent/form/retrieval-form/index.tsx
+++ /dev/null
@@ -1,54 +0,0 @@
-import KnowledgeBaseItem from '@/components/knowledge-base-item';
-import Rerank from '@/components/rerank';
-import SimilaritySlider from '@/components/similarity-slider';
-import TopNItem from '@/components/top-n-item';
-import { useTranslate } from '@/hooks/common-hooks';
-import type { FormProps } from 'antd';
-import { Form, Input } from 'antd';
-import { IOperatorForm } from '../../interface';
-import DynamicInputVariable from '../components/dynamic-input-variable';
-
-type FieldType = {
-  top_n?: number;
-};
-
-const onFinish: FormProps['onFinish'] = (values) => {
-  console.log('Success:', values);
-};
-
-const onFinishFailed: FormProps['onFinishFailed'] = (errorInfo) => {
-  console.log('Failed:', errorInfo);
-};
-
-const RetrievalForm = ({ onValuesChange, form, node }: IOperatorForm) => {
-  const { t } = useTranslate('flow');
-  return (
-
-
-
-
-
-
-
-
-
-
-  );
-};
-
-export default RetrievalForm;
diff --git a/web/src/pages/agent/form/retrieval-form/next.tsx b/web/src/pages/agent/form/retrieval-form/next.tsx
index 368c95a7e56..f30faf92df5 100644
--- a/web/src/pages/agent/form/retrieval-form/next.tsx
+++ b/web/src/pages/agent/form/retrieval-form/next.tsx
@@ -1,3 +1,4 @@
+import { FormContainer } from '@/components/form-container';
 import { KnowledgeBaseFormField } from '@/components/knowledge-base-item';
 import { RerankFormFields } from '@/components/rerank';
 import { SimilaritySliderFormField } from '@/components/similarity-slider';
@@ -11,46 +12,90 @@ import {
   FormMessage,
 } from '@/components/ui/form';
 import { Textarea } from '@/components/ui/textarea';
+import { zodResolver } from '@hookform/resolvers/zod';
+import { useMemo } from 'react';
+import { useForm } from 'react-hook-form';
 import { useTranslation } from 'react-i18next';
+import { z } from 'zod';
+import { initialRetrievalValues } from '../../constant';
+import { useWatchFormChange } from '../../hooks/use-watch-form-change';
 import { INextOperatorForm } from '../../interface';
-import { DynamicInputVariable } from '../components/next-dynamic-input-variable';
+import { Output } from '../components/output';
+import { QueryVariable } from '../components/query-variable';
+import { useValues } from './use-values';
 
-const RetrievalForm = ({ form, node }: INextOperatorForm) => {
+const FormSchema = z.object({
+  query: z.string().optional(),
+  similarity_threshold: z.coerce.number(),
+  keywords_similarity_weight: z.coerce.number(),
+  top_n: z.coerce.number(),
+  top_k: z.coerce.number(),
+  kb_ids: z.array(z.string()),
+  rerank_id: z.string(),
+  empty_response: z.string(),
+});
+
+const RetrievalForm = ({ node }: INextOperatorForm) => {
   const { t } = useTranslation();
+
+  const outputList = useMemo(() => {
+    return [
+      {
+        title: 'formalized_content',
+        type: initialRetrievalValues.outputs.formalized_content.type,
+      },
+    ];
+  }, []);
+
+  const defaultValues = useValues(node);
+
+  const form = useForm({
+    defaultValues: defaultValues,
+    resolver: zodResolver(FormSchema),
+  });
+
+  useWatchFormChange(node?.id, form);
+
   return (
     {
       e.preventDefault();
     }}
   >
-
-
-
-
-
-            (
-
-              {t('chat.emptyResponse')}
-
-
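The retrieval form imports a shared useWatchFormChange from ../../hooks/use-watch-form-change, which this diff references but does not include. Judging from the message-form variant above, it presumably watches the form and pushes dirty values into the graph store without any per-field conversion; a hedged sketch of that shared hook, under that assumption, might be:

// Sketch only: assumes the shared hook mirrors the message-form version minus the
// content conversion; the actual file is not part of this diff, and the store path is assumed.
import { useEffect } from 'react';
import { UseFormReturn, useWatch } from 'react-hook-form';
import useGraphStore from '../store'; // path assumed

export function useWatchFormChange(id?: string, form?: UseFormReturn<any>) {
  const values = useWatch({ control: form?.control });
  const updateNodeForm = useGraphStore((state) => state.updateNodeForm);

  useEffect(() => {
    // Only sync to the canvas node once the user has actually edited the form.
    if (id && form?.formState.isDirty) {
      updateNodeForm(id, form.getValues());
    }
  }, [form, id, updateNodeForm, values]);
}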