diff --git a/gpt_researcher/llm_provider/generic/base.py b/gpt_researcher/llm_provider/generic/base.py
index 8be95eca0..ff1fceee1 100644
--- a/gpt_researcher/llm_provider/generic/base.py
+++ b/gpt_researcher/llm_provider/generic/base.py
@@ -3,6 +3,7 @@
 from colorama import Fore, Style, init
 import os
 
+
 class GenericLLMProvider:
 
     def __init__(self, llm):
@@ -48,7 +49,7 @@ def from_provider(cls, provider: str, **kwargs: Any):
         elif provider == "ollama":
             _check_pkg("langchain_community")
             from langchain_community.chat_models import ChatOllama
-            
+
             llm = ChatOllama(base_url=os.environ["OLLAMA_BASE_URL"], **kwargs)
         elif provider == "together":
             _check_pkg("langchain_together")
@@ -81,6 +82,27 @@ def from_provider(cls, provider: str, **kwargs: Any):
             model_id = kwargs.pop("model", None) or kwargs.pop("model_name", None)
             kwargs = {"model_id": model_id, **kwargs}
             llm = ChatBedrock(**kwargs)
+        elif provider == "watsonxai":
+            # Valid generation parameters: decoding_method, length_penalty, temperature,
+            # top_p, top_k, random_seed, repetition_penalty, min_new_tokens, max_new_tokens,
+            # stop_sequences, time_limit, truncate_input_tokens, return_options, prompt_variables
+            _check_pkg("langchain_ibm")
+            from langchain_ibm import WatsonxLLM
+
+            if "max_tokens" in kwargs:
+                kwargs["max_new_tokens"] = kwargs.pop("max_tokens")
+
+            if "min_tokens" in kwargs:
+                kwargs["min_new_tokens"] = kwargs.pop("min_tokens")
+
+            if "model" in kwargs or "model_name" in kwargs:
+                model_id = kwargs.pop("model", None) or kwargs.pop("model_name", None)
+                kwargs = {"model_id": model_id, "params": kwargs}
+
+            wx_url = os.environ.get("WATSONX_URL")
+            wx_project_id = os.environ.get("WATSONX_PROJECT_ID")
+
+            llm = WatsonxLLM(url=wx_url, project_id=wx_project_id, **kwargs)
         else:
             supported = ", ".join(_SUPPORTED_PROVIDERS)
             raise ValueError(
@@ -89,12 +111,14 @@ def from_provider(cls, provider: str, **kwargs: Any):
         )
 
         return cls(llm)
 
-
     async def get_chat_response(self, messages, stream, websocket=None):
         if not stream:
             # Getting output from the model chain using ainvoke for asynchronous invoking
            output = await self.llm.ainvoke(messages)
+            if isinstance(output, str):
+                return output
+
             return output.content
 
         else:
@@ -106,7 +130,11 @@ async def stream_response(self, messages, websocket=None):
 
         # Streaming the response using the chain astream method from langchain
         async for chunk in self.llm.astream(messages):
-            content = chunk.content
+
+            if isinstance(chunk, str):
+                content = chunk
+            else:
+                content = chunk.content
             if content is not None:
                 response += content
                 paragraph += content
@@ -126,7 +154,6 @@ async def _send_output(self, content, websocket=None):
             print(f"{Fore.GREEN}{content}{Style.RESET_ALL}")
 
 
-
 _SUPPORTED_PROVIDERS = {
     "openai",
     "anthropic",
@@ -141,8 +168,10 @@ async def _send_output(self, content, websocket=None):
     "huggingface",
     "groq",
     "bedrock",
+    "watsonxai",
 }
 
+
 def _check_pkg(pkg: str) -> None:
     if not importlib.util.find_spec(pkg):
         pkg_kebab = pkg.replace("_", "-")
diff --git a/gpt_researcher/memory/embeddings.py b/gpt_researcher/memory/embeddings.py
index bc6b4be0d..611fc3982 100644
--- a/gpt_researcher/memory/embeddings.py
+++ b/gpt_researcher/memory/embeddings.py
@@ -1,7 +1,9 @@
 from langchain_community.vectorstores import FAISS
 import os
 
-OPENAI_EMBEDDING_MODEL = os.environ.get("OPENAI_EMBEDDING_MODEL","text-embedding-3-small")
+OPENAI_EMBEDDING_MODEL = os.environ.get(
+    "OPENAI_EMBEDDING_MODEL", "text-embedding-3-small"
+)
 
 
 class Memory:
@@ -35,7 +37,7 @@ def __init__(self, embedding_provider, headers=None, **kwargs):
                 _embeddings = OpenAIEmbeddings(
                     openai_api_key=headers.get("openai_api_key")
                     or os.environ.get("OPENAI_API_KEY"),
-                    model=OPENAI_EMBEDDING_MODEL
+                    model=OPENAI_EMBEDDING_MODEL,
                 )
             case "azure_openai":
                 from langchain_openai import AzureOpenAIEmbeddings
@@ -51,6 +53,16 @@ def __init__(self, embedding_provider, headers=None, **kwargs):
                     model_name="sentence-transformers/all-MiniLM-L6-v2"
                 )
 
+            case "watsonxai":
+                from langchain_ibm import WatsonxEmbeddings
+
+                _embeddings = WatsonxEmbeddings(
+                    url=os.environ["WATSONX_URL"],
+                    apikey=os.environ["WATSONX_APIKEY"],
+                    project_id=os.environ["WATSONX_PROJECT_ID"],
+                    model_id=os.environ["WATSONX_EMBEDDING_MODEL"],
+                )
+
             case _:
                 raise Exception("Embedding provider not found.")
 
diff --git a/requirements.txt b/requirements.txt
index b8962d7a0..07b6f23c6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -29,6 +29,7 @@ unstructured
 json_repair
 json5
 loguru
+langchain-ibm==0.1.12
 
 # uncomment for testing
 # pytest
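For reviewers, a minimal end-to-end sketch of the new provider path. It assumes the import path `gpt_researcher.llm_provider.generic.base` shown in this diff; the watsonx URL, project id, and credentials are placeholders, and the model id and parameter values are illustrative rather than project defaults:

```python
import asyncio
import os

from gpt_researcher.llm_provider.generic.base import GenericLLMProvider

# WatsonxLLM reads the API key from WATSONX_APIKEY on its own; the URL and
# project id are read explicitly inside from_provider(). Placeholder values;
# real credentials are required for the call to actually reach watsonx.ai.
os.environ["WATSONX_URL"] = "https://us-south.ml.cloud.ibm.com"
os.environ["WATSONX_PROJECT_ID"] = "<project-id>"
os.environ["WATSONX_APIKEY"] = "<api-key>"

# "model" is popped into model_id and the remaining kwargs become the `params`
# dict, with max_tokens renamed to max_new_tokens by the branch above.
provider = GenericLLMProvider.from_provider(
    "watsonxai",
    model="ibm/granite-13b-chat-v2",  # illustrative model id
    max_tokens=512,
    temperature=0.4,
)

answer = asyncio.run(
    provider.get_chat_response("Say hello in one sentence.", stream=False)
)
print(answer)
```

Because `WatsonxLLM` is a completion-style (non-chat) model, `ainvoke` and `astream` yield plain strings rather than message objects, which is exactly the case the new `isinstance` checks in `get_chat_response` and `stream_response` handle.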