diff --git a/lazyllm/engine/engine.py b/lazyllm/engine/engine.py
--- a/lazyllm/engine/engine.py
+++ b/lazyllm/engine/engine.py
@@ -1,7 +1,8 @@
-from typing import List, Dict, Type, Optional, Union, Any, overload
+from typing import List, Dict, Type, Optional, Union, Any, overload, Callable
 import lazyllm
 from lazyllm import graph, switch, pipeline, package
 from lazyllm.tools import IntentClassifier, SqlManager
+from lazyllm.tools.rag import SimpleDirectoryReader, DocNode
 from lazyllm.common import compile_func
 from .node import all_nodes, Node
 from .node_meta_hook import NodeMetaHook
@@ -13,3 +14,5 @@
 from datetime import datetime, timedelta
+import string
+from fsspec import AbstractFileSystem
 import requests
 import json
@@ -772,3 +775,73 @@ def __call__(self, *args, **kw) -> Union[str, List[str]]:
 @NodeConstructor.register('File')
 def make_file(id: str):
     return FileResource(id)
+
+
+class ReaderResource(object):
+    """Callable node resource that loads files through SimpleDirectoryReader."""
+
+    def __call__(self, input_files: Union[str, List[str]] = "",
+                 exclude: Optional[List] = None, exclude_hidden: bool = True, recursive: bool = False,
+                 encoding: str = "utf-8", filename_as_id: bool = False, required_exts: Optional[List[str]] = None,
+                 file_extractor: Optional[Dict[str, Callable]] = None, fs: Optional[AbstractFileSystem] = None,
+                 metadata_genf: Optional[Callable[[str], Dict]] = None, num_files_limit: Optional[int] = None,
+                 return_trace: bool = False, metadatas: Optional[Dict] = None):
+        """Load ``input_files`` and return the result of ``SimpleDirectoryReader._load_data()``.
+
+        A single path string is wrapped into a one-element list. Empty input
+        (``""``, ``[]`` or ``None``) returns ``[]`` without building a reader.
+        All other arguments are forwarded to ``SimpleDirectoryReader`` unchanged.
+        """
+        if not input_files:
+            return []
+        if isinstance(input_files, str):
+            input_files = [input_files]
+        # NOTE(review): the leading "" is presumably the reader's input-directory
+        # argument, left empty because an explicit file list is supplied -- confirm.
+        return SimpleDirectoryReader("", input_files, exclude, exclude_hidden, recursive,
+                                     encoding, filename_as_id, required_exts, file_extractor, fs,
+                                     metadata_genf, num_files_limit, return_trace, metadatas)._load_data()
+
+
+@NodeConstructor.register('Reader')
+def make_simple_reader():
+    return ReaderResource()
+
+
+# ASCII punctuation plus CJK and typographic punctuation marks.
+punctuation = set(string.punctuation + ",。!?;:“”‘’()【】《》…—~、")
+
+
+def is_all_punctuation(s: str) -> bool:
+    """Return True when every character of ``s`` is punctuation (also True for ``""``)."""
+    return all(c in punctuation for c in s)
+
+
+class OCR(lazyllm.Module):
+    """Module that runs PaddleOCR on its input and wraps each kept text line in a DocNode."""
+
+    def __init__(self):
+        super().__init__()
+        # Imported here rather than at module level so that importing the engine
+        # does not require the optional paddleocr package.
+        try:
+            import paddleocr
+        except ImportError as e:
+            raise ImportError("The 'OCR' node requires paddleocr; install it with `pip install paddleocr`.") from e
+        self._m = paddleocr.PaddleOCR()
+
+    def forward(self, input, metadatas: Optional[Dict] = None):
+        """Return a DocNode for every recognized text that is neither empty nor punctuation-only."""
+        result = self._m.predict(input)
+        txt = []
+        for res in result:
+            for sentence in res['rec_texts']:
+                t = sentence.strip()
+                if t and not is_all_punctuation(t):
+                    txt.append(DocNode(text=t, global_metadata=metadatas or {}))
+        return txt
+
+
+@NodeConstructor.register('OCR')
+def make_ocr():
+    return OCR()
diff --git a/tests/basic_tests/test_engine.py b/tests/basic_tests/test_engine.py
--- a/tests/basic_tests/test_engine.py
+++ b/tests/basic_tests/test_engine.py
@@ -666,3 +666,45 @@ def test_engine_status(self):
         engine.release_node(gid)
         assert '__start__' in engine._nodes and '__end__' in engine._nodes
+
+    def test_engine_pdf_reader(self):
+        import os
+        import tempfile
+        from lazyllm.tools.rag import SimpleDirectoryReader
+
+        nodes = [dict(id='1', kind='Reader', name='m1', args=dict())]
+        edges = [dict(iid='__start__', oid='1'), dict(iid='1', oid='__end__')]
+        # Read a throwaway file so the test does not depend on a path on one developer's machine.
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            path = os.path.join(tmp_dir, 'sample.txt')
+            with open(path, 'w', encoding='utf-8') as f:
+                f.write('道可道,非常道。名可名,非常名。\n')
+            engine = LightEngine()
+            gid = engine.start(nodes, edges)
+            data = engine.run(gid, path)
+            entries = SimpleDirectoryReader(input_files=[path])._load_data()
+            assert len(data) == len(entries)
+            engine.stop(gid)
+            engine.reset()
+
+    def test_engine_ocr(self):
+        import os
+        import tempfile
+        pytest.importorskip('paddleocr')
+        pil_image = pytest.importorskip('PIL.Image')
+        pil_draw = pytest.importorskip('PIL.ImageDraw')
+
+        nodes = [dict(id='1', kind='OCR', name='m1', args=dict())]
+        edges = [dict(iid='__start__', oid='1'), dict(iid='1', oid='__end__')]
+        # Render a small image on the fly instead of relying on a local PDF.
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            path = os.path.join(tmp_dir, 'sample.png')
+            image = pil_image.new('RGB', (320, 80), 'white')
+            pil_draw.Draw(image).text((10, 30), 'LazyLLM OCR', fill='black')
+            image.save(path)
+            engine = LightEngine()
+            gid = engine.start(nodes, edges)
+            data = engine.run(gid, path)
+            assert all(isinstance(t.get_text(), str) for t in data)
+            engine.stop(gid)
+            engine.reset()
 
diff --git a/requirements.full.txt b/requirements.full.txt
index 82656464b..a71ce0c62 100644
--- a/requirements.full.txt
+++ b/requirements.full.txt
@@ -73,4 +73,4 @@ pymongo
 pymysql
 flagembedding
 mcp>=1.5.0
-
+paddleocr