From f4fb3df437f32737f6975d218d1f2deabe4e37ee Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Thu, 25 Jul 2024 10:04:47 -0400 Subject: [PATCH 01/30] Brave search tool calling. Signed-off-by: Adam Treat --- gpt4all-chat/CMakeLists.txt | 3 + gpt4all-chat/bravesearch.cpp | 221 ++++++++++++++++++++++++++++++ gpt4all-chat/bravesearch.h | 51 +++++++ gpt4all-chat/chat.cpp | 29 +++- gpt4all-chat/chat.h | 15 +- gpt4all-chat/chatlistmodel.cpp | 2 +- gpt4all-chat/chatllm.cpp | 109 +++++++++++++-- gpt4all-chat/chatllm.h | 9 +- gpt4all-chat/chatmodel.h | 38 +++-- gpt4all-chat/database.cpp | 4 +- gpt4all-chat/database.h | 63 +-------- gpt4all-chat/mysettings.cpp | 2 + gpt4all-chat/mysettings.h | 6 + gpt4all-chat/qml/ChatView.qml | 27 ++-- gpt4all-chat/qml/SettingsView.qml | 9 ++ gpt4all-chat/qml/ToolSettings.qml | 71 ++++++++++ gpt4all-chat/server.cpp | 28 +--- gpt4all-chat/server.h | 6 +- gpt4all-chat/sourceexcerpt.h | 95 +++++++++++++ 19 files changed, 650 insertions(+), 138 deletions(-) create mode 100644 gpt4all-chat/bravesearch.cpp create mode 100644 gpt4all-chat/bravesearch.h create mode 100644 gpt4all-chat/qml/ToolSettings.qml create mode 100644 gpt4all-chat/sourceexcerpt.h diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt index e0f6b6f80914..994ecb936d8b 100644 --- a/gpt4all-chat/CMakeLists.txt +++ b/gpt4all-chat/CMakeLists.txt @@ -107,6 +107,7 @@ endif() qt_add_executable(chat main.cpp + bravesearch.h bravesearch.cpp chat.h chat.cpp chatllm.h chatllm.cpp chatmodel.h chatlistmodel.h chatlistmodel.cpp @@ -120,6 +121,7 @@ qt_add_executable(chat modellist.h modellist.cpp mysettings.h mysettings.cpp network.h network.cpp + sourceexcerpt.h server.h server.cpp logger.h logger.cpp ${APP_ICON_RESOURCE} @@ -153,6 +155,7 @@ qt_add_qml_module(chat qml/ThumbsDownDialog.qml qml/Toast.qml qml/ToastManager.qml + qml/ToolSettings.qml qml/MyBusyIndicator.qml qml/MyButton.qml qml/MyCheckBox.qml diff --git a/gpt4all-chat/bravesearch.cpp b/gpt4all-chat/bravesearch.cpp 
new file mode 100644 index 000000000000..a9c0df47ab9b --- /dev/null +++ b/gpt4all-chat/bravesearch.cpp @@ -0,0 +1,221 @@ +#include "bravesearch.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace Qt::Literals::StringLiterals; + +QPair> BraveSearch::search(const QString &apiKey, const QString &query, int topK, unsigned long timeout) +{ + QThread workerThread; + BraveAPIWorker worker; + worker.moveToThread(&workerThread); + connect(&worker, &BraveAPIWorker::finished, &workerThread, &QThread::quit, Qt::DirectConnection); + connect(this, &BraveSearch::request, &worker, &BraveAPIWorker::request, Qt::QueuedConnection); + workerThread.start(); + emit request(apiKey, query, topK); + workerThread.wait(timeout); + workerThread.quit(); + workerThread.wait(); + return worker.response(); +} + +void BraveAPIWorker::request(const QString &apiKey, const QString &query, int topK) +{ + m_topK = topK; + QUrl jsonUrl("https://api.search.brave.com/res/v1/web/search"); + QUrlQuery urlQuery; + urlQuery.addQueryItem("q", query); + jsonUrl.setQuery(urlQuery); + QNetworkRequest request(jsonUrl); + QSslConfiguration conf = request.sslConfiguration(); + conf.setPeerVerifyMode(QSslSocket::VerifyNone); + request.setSslConfiguration(conf); + + request.setRawHeader("X-Subscription-Token", apiKey.toUtf8()); +// request.setRawHeader("Accept-Encoding", "gzip"); + request.setRawHeader("Accept", "application/json"); + + m_networkManager = new QNetworkAccessManager(this); + QNetworkReply *reply = m_networkManager->get(request); + connect(qGuiApp, &QCoreApplication::aboutToQuit, reply, &QNetworkReply::abort); + connect(reply, &QNetworkReply::finished, this, &BraveAPIWorker::handleFinished); + connect(reply, &QNetworkReply::errorOccurred, this, &BraveAPIWorker::handleErrorOccurred); +} + +static QPair> cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK = 1) +{ + QJsonParseError err; + 
QJsonDocument document = QJsonDocument::fromJson(jsonResponse, &err); + if (err.error != QJsonParseError::NoError) { + qWarning() << "ERROR: Couldn't parse: " << jsonResponse << err.errorString(); + return QPair>(); + } + + QJsonObject searchResponse = document.object(); + QJsonObject cleanResponse; + QString query; + QJsonArray cleanArray; + + QList infos; + + if (searchResponse.contains("query")) { + QJsonObject queryObj = searchResponse["query"].toObject(); + if (queryObj.contains("original")) { + query = queryObj["original"].toString(); + } + } + + if (searchResponse.contains("mixed")) { + QJsonObject mixedResults = searchResponse["mixed"].toObject(); + QJsonArray mainResults = mixedResults["main"].toArray(); + + for (int i = 0; i < std::min(mainResults.size(), topK); ++i) { + QJsonObject m = mainResults[i].toObject(); + QString r_type = m["type"].toString(); + int idx = m["index"].toInt(); + QJsonObject resultsObject = searchResponse[r_type].toObject(); + QJsonArray resultsArray = resultsObject["results"].toArray(); + + QJsonValue cleaned; + SourceExcerpt info; + if (r_type == "web") { + // For web data - add a single output from the search + QJsonObject resultObj = resultsArray[idx].toObject(); + QStringList selectedKeys = {"type", "title", "url", "description", "date", "extra_snippets"}; + QJsonObject cleanedObj; + for (const auto& key : selectedKeys) { + if (resultObj.contains(key)) { + cleanedObj.insert(key, resultObj[key]); + } + } + + info.date = resultObj["date"].toString(); + info.text = resultObj["description"].toString(); // fixme + info.url = resultObj["url"].toString(); + QJsonObject meta_url = resultObj["meta_url"].toObject(); + info.favicon = meta_url["favicon"].toString(); + info.title = resultObj["title"].toString(); + + cleaned = cleanedObj; + } else if (r_type == "faq") { + // For faq data - take a list of all the questions & answers + QStringList selectedKeys = {"type", "question", "answer", "title", "url"}; + QJsonArray cleanedArray; + for 
(const auto& q : resultsArray) { + QJsonObject qObj = q.toObject(); + QJsonObject cleanedObj; + for (const auto& key : selectedKeys) { + if (qObj.contains(key)) { + cleanedObj.insert(key, qObj[key]); + } + } + cleanedArray.append(cleanedObj); + } + cleaned = cleanedArray; + } else if (r_type == "infobox") { + QJsonObject resultObj = resultsArray[idx].toObject(); + QStringList selectedKeys = {"type", "title", "url", "description", "long_desc"}; + QJsonObject cleanedObj; + for (const auto& key : selectedKeys) { + if (resultObj.contains(key)) { + cleanedObj.insert(key, resultObj[key]); + } + } + cleaned = cleanedObj; + } else if (r_type == "videos") { + QStringList selectedKeys = {"type", "url", "title", "description", "date"}; + QJsonArray cleanedArray; + for (const auto& q : resultsArray) { + QJsonObject qObj = q.toObject(); + QJsonObject cleanedObj; + for (const auto& key : selectedKeys) { + if (qObj.contains(key)) { + cleanedObj.insert(key, qObj[key]); + } + } + cleanedArray.append(cleanedObj); + } + cleaned = cleanedArray; + } else if (r_type == "locations") { + QStringList selectedKeys = {"type", "title", "url", "description", "coordinates", "postal_address", "contact", "rating", "distance", "zoom_level"}; + QJsonArray cleanedArray; + for (const auto& q : resultsArray) { + QJsonObject qObj = q.toObject(); + QJsonObject cleanedObj; + for (const auto& key : selectedKeys) { + if (qObj.contains(key)) { + cleanedObj.insert(key, qObj[key]); + } + } + cleanedArray.append(cleanedObj); + } + cleaned = cleanedArray; + } else if (r_type == "news") { + QStringList selectedKeys = {"type", "title", "url", "description"}; + QJsonArray cleanedArray; + for (const auto& q : resultsArray) { + QJsonObject qObj = q.toObject(); + QJsonObject cleanedObj; + for (const auto& key : selectedKeys) { + if (qObj.contains(key)) { + cleanedObj.insert(key, qObj[key]); + } + } + cleanedArray.append(cleanedObj); + } + cleaned = cleanedArray; + } else { + cleaned = QJsonValue(); + } + + 
infos.append(info); + cleanArray.append(cleaned); + } + } + + cleanResponse.insert("query", query); + cleanResponse.insert("top_k", cleanArray); + QJsonDocument cleanedDoc(cleanResponse); + +// qDebug().noquote() << document.toJson(QJsonDocument::Indented); +// qDebug().noquote() << cleanedDoc.toJson(QJsonDocument::Indented); + + return qMakePair(cleanedDoc.toJson(QJsonDocument::Indented), infos); +} + +void BraveAPIWorker::handleFinished() +{ + QNetworkReply *jsonReply = qobject_cast(sender()); + Q_ASSERT(jsonReply); + + if (jsonReply->error() == QNetworkReply::NoError && jsonReply->isFinished()) { + QByteArray jsonData = jsonReply->readAll(); + jsonReply->deleteLater(); + m_response = cleanBraveResponse(jsonData, m_topK); + } else { + QByteArray jsonData = jsonReply->readAll(); + qWarning() << "ERROR: Could not search brave" << jsonReply->error() << jsonReply->errorString() << jsonData; + jsonReply->deleteLater(); + } +} + +void BraveAPIWorker::handleErrorOccurred(QNetworkReply::NetworkError code) +{ + QNetworkReply *reply = qobject_cast(sender()); + Q_ASSERT(reply); + qWarning().noquote() << "ERROR: BraveAPIWorker::handleErrorOccurred got HTTP Error" << code << "response:" + << reply->errorString(); + emit finished(); +} diff --git a/gpt4all-chat/bravesearch.h b/gpt4all-chat/bravesearch.h new file mode 100644 index 000000000000..482b29a6199d --- /dev/null +++ b/gpt4all-chat/bravesearch.h @@ -0,0 +1,51 @@ +#ifndef BRAVESEARCH_H +#define BRAVESEARCH_H + +#include "sourceexcerpt.h" + +#include +#include +#include +#include + +class BraveAPIWorker : public QObject { + Q_OBJECT +public: + BraveAPIWorker() + : QObject(nullptr) + , m_networkManager(nullptr) + , m_topK(1) {} + virtual ~BraveAPIWorker() {} + + QPair> response() const { return m_response; } + +public Q_SLOTS: + void request(const QString &apiKey, const QString &query, int topK); + +Q_SIGNALS: + void finished(); + +private Q_SLOTS: + void handleFinished(); + void 
handleErrorOccurred(QNetworkReply::NetworkError code); + +private: + QNetworkAccessManager *m_networkManager; + QPair> m_response; + int m_topK; +}; + +class BraveSearch : public QObject { + Q_OBJECT +public: + BraveSearch() + : QObject(nullptr) {} + virtual ~BraveSearch() {} + + QPair> search(const QString &apiKey, const QString &query, int topK, unsigned long timeout = 2000); + +Q_SIGNALS: + void request(const QString &apiKey, const QString &query, int topK); +}; + +#endif // BRAVESEARCH_H diff --git a/gpt4all-chat/chat.cpp b/gpt4all-chat/chat.cpp index a44022c0bc24..7da8274cfd7d 100644 --- a/gpt4all-chat/chat.cpp +++ b/gpt4all-chat/chat.cpp @@ -59,6 +59,7 @@ void Chat::connectLLM() connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::promptProcessing, this, &Chat::promptProcessing, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::generatingQuestions, this, &Chat::generatingQuestions, Qt::QueuedConnection); + connect(m_llmodel, &ChatLLM::toolCalled, this, &Chat::toolCalled, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::modelLoadingError, this, &Chat::handleModelLoadingError, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::modelLoadingWarning, this, &Chat::modelLoadingWarning, Qt::QueuedConnection); @@ -67,7 +68,7 @@ void Chat::connectLLM() connect(m_llmodel, &ChatLLM::generatedQuestionFinished, this, &Chat::generatedQuestionFinished, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::loadedModelInfoChanged, this, &Chat::loadedModelInfoChanged, Qt::QueuedConnection); - connect(m_llmodel, &ChatLLM::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection); + connect(m_llmodel, &ChatLLM::sourceExcerptsChanged, this, 
&Chat::handleSourceExcerptsChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::modelInfoChanged, this, &Chat::handleModelInfoChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::trySwitchContextOfLoadedModelCompleted, this, &Chat::handleTrySwitchContextOfLoadedModelCompleted, Qt::QueuedConnection); @@ -121,6 +122,7 @@ void Chat::resetResponseState() emit tokenSpeedChanged(); m_responseInProgress = true; m_responseState = m_collections.empty() ? Chat::PromptProcessing : Chat::LocalDocsRetrieval; + m_toolDescription = QString(); emit responseInProgressChanged(); emit responseStateChanged(); } @@ -134,7 +136,7 @@ void Chat::prompt(const QString &prompt) void Chat::regenerateResponse() { const int index = m_chatModel->count() - 1; - m_chatModel->updateSources(index, QList()); + m_chatModel->updateSources(index, QList()); emit regenerateResponseRequested(); } @@ -189,8 +191,13 @@ void Chat::handleModelLoadingPercentageChanged(float loadingPercentage) void Chat::promptProcessing() { - m_responseState = !databaseResults().isEmpty() ? 
Chat::LocalDocsProcessing : Chat::PromptProcessing; - emit responseStateChanged(); + if (sourceExcerpts().isEmpty()) + m_responseState = Chat::PromptProcessing; + else if (m_responseState == Chat::ToolCalled) + m_responseState = Chat::ToolProcessing; + else + m_responseState = Chat::LocalDocsProcessing; + emit responseStateChanged(); } void Chat::generatingQuestions() @@ -199,6 +206,14 @@ void Chat::generatingQuestions() emit responseStateChanged(); } +void Chat::toolCalled(const QString &description) +{ + m_responseState = Chat::ToolCalled; + m_toolDescription = description; + emit toolDescriptionChanged(); + emit responseStateChanged(); +} + void Chat::responseStopped(qint64 promptResponseMs) { m_tokenSpeed = QString(); @@ -357,11 +372,11 @@ QString Chat::fallbackReason() const return m_llmodel->fallbackReason(); } -void Chat::handleDatabaseResultsChanged(const QList &results) +void Chat::handleSourceExcerptsChanged(const QList &sourceExcerpts) { - m_databaseResults = results; + m_sourceExcerpts = sourceExcerpts; const int index = m_chatModel->count() - 1; - m_chatModel->updateSources(index, m_databaseResults); + m_chatModel->updateSources(index, m_sourceExcerpts); } void Chat::handleModelInfoChanged(const ModelInfo &modelInfo) diff --git a/gpt4all-chat/chat.h b/gpt4all-chat/chat.h index 065c624eef31..8f2b7410ec4b 100644 --- a/gpt4all-chat/chat.h +++ b/gpt4all-chat/chat.h @@ -40,6 +40,7 @@ class Chat : public QObject // 0=no, 1=waiting, 2=working Q_PROPERTY(int trySwitchContextInProgress READ trySwitchContextInProgress NOTIFY trySwitchContextInProgressChanged) Q_PROPERTY(QList generatedQuestions READ generatedQuestions NOTIFY generatedQuestionsChanged) + Q_PROPERTY(QString toolDescription READ toolDescription NOTIFY toolDescriptionChanged) QML_ELEMENT QML_UNCREATABLE("Only creatable from c++!") @@ -50,7 +51,9 @@ class Chat : public QObject LocalDocsProcessing, PromptProcessing, GeneratingQuestions, - ResponseGeneration + ResponseGeneration, + ToolCalled, + 
ToolProcessing }; Q_ENUM(ResponseState) @@ -81,9 +84,10 @@ class Chat : public QObject Q_INVOKABLE void stopGenerating(); Q_INVOKABLE void newPromptResponsePair(const QString &prompt); - QList databaseResults() const { return m_databaseResults; } + QList sourceExcerpts() const { return m_sourceExcerpts; } QString response() const; + QString toolDescription() const { return m_toolDescription; } bool responseInProgress() const { return m_responseInProgress; } ResponseState responseState() const; ModelInfo modelInfo() const; @@ -158,19 +162,21 @@ public Q_SLOTS: void trySwitchContextInProgressChanged(); void loadedModelInfoChanged(); void generatedQuestionsChanged(); + void toolDescriptionChanged(); private Q_SLOTS: void handleResponseChanged(const QString &response); void handleModelLoadingPercentageChanged(float); void promptProcessing(); void generatingQuestions(); + void toolCalled(const QString &description); void responseStopped(qint64 promptResponseMs); void generatedNameChanged(const QString &name); void generatedQuestionFinished(const QString &question); void handleRestoringFromText(); void handleModelLoadingError(const QString &error); void handleTokenSpeedChanged(const QString &tokenSpeed); - void handleDatabaseResultsChanged(const QList &results); + void handleSourceExcerptsChanged(const QList &sourceExcerpts); void handleModelInfoChanged(const ModelInfo &modelInfo); void handleTrySwitchContextOfLoadedModelCompleted(int value); @@ -185,6 +191,7 @@ private Q_SLOTS: QString m_device; QString m_fallbackReason; QString m_response; + QString m_toolDescription; QList m_collections; QList m_generatedQuestions; ChatModel *m_chatModel; @@ -192,7 +199,7 @@ private Q_SLOTS: ResponseState m_responseState; qint64 m_creationDate; ChatLLM *m_llmodel; - QList m_databaseResults; + QList m_sourceExcerpts; bool m_isServer = false; bool m_shouldDeleteLater = false; float m_modelLoadingPercentage = 0.0f; diff --git a/gpt4all-chat/chatlistmodel.cpp 
b/gpt4all-chat/chatlistmodel.cpp index b4afb39f6ce0..c5be4338ff3e 100644 --- a/gpt4all-chat/chatlistmodel.cpp +++ b/gpt4all-chat/chatlistmodel.cpp @@ -19,7 +19,7 @@ #include #define CHAT_FORMAT_MAGIC 0xF5D553CC -#define CHAT_FORMAT_VERSION 9 +#define CHAT_FORMAT_VERSION 10 class MyChatListModel: public ChatListModel { }; Q_GLOBAL_STATIC(MyChatListModel, chatListModelInstance) diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp index e9fb7f3132f9..a6dea3959dbc 100644 --- a/gpt4all-chat/chatllm.cpp +++ b/gpt4all-chat/chatllm.cpp @@ -1,5 +1,6 @@ #include "chatllm.h" +#include "bravesearch.h" #include "chat.h" #include "chatapi.h" #include "localdocs.h" @@ -10,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -113,6 +115,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer) , m_reloadingToChangeVariant(false) , m_processedSystemPrompt(false) , m_restoreStateFromText(false) + , m_maybeToolCall(false) { moveToThread(&m_llmThread); connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged, @@ -702,13 +705,44 @@ bool ChatLLM::handleResponse(int32_t token, const std::string &response) return false; } + // Only valid for llama 3.1 instruct + if (m_modelInfo.filename().startsWith("Meta-Llama-3.1-8B-Instruct")) { + // Based on https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#built-in-python-based-tool-calling + // For brave_search and wolfram_alpha ipython is always used + + // <|python_tag|> + // brave_search.call(query="...") + // <|eom_id|> + const int eom_id = 128008; + const int python_tag = 128010; + + // If we have a built-in tool call, then it should be the first token + const bool isFirstResponseToken = m_promptResponseTokens == m_promptTokens; + Q_ASSERT(token != python_tag || isFirstResponseToken); + if (isFirstResponseToken && token == python_tag) { + m_maybeToolCall = true; + ++m_promptResponseTokens; + return !m_stopGenerating; + } + + // Check for end of built-in tool call 
+ Q_ASSERT(token != eom_id || !m_maybeToolCall); + if (token == eom_id) { + ++m_promptResponseTokens; + return false; + } + } + // m_promptResponseTokens is related to last prompt/response not // the entire context window which we can reset on regenerate prompt ++m_promptResponseTokens; m_timer->inc(); Q_ASSERT(!response.empty()); m_response.append(response); - emit responseChanged(QString::fromStdString(remove_leading_whitespace(m_response))); + + if (!m_maybeToolCall) + emit responseChanged(QString::fromStdString(remove_leading_whitespace(m_response))); + return !m_stopGenerating; } @@ -735,24 +769,24 @@ bool ChatLLM::prompt(const QList &collectionList, const QString &prompt } bool ChatLLM::promptInternal(const QList &collectionList, const QString &prompt, const QString &promptTemplate, - int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty, - int32_t repeat_penalty_tokens) + int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty, + int32_t repeat_penalty_tokens) { if (!isModelLoaded()) return false; - QList databaseResults; + QList databaseResults; const int retrievalSize = MySettings::globalInstance()->localDocsRetrievalSize(); if (!collectionList.isEmpty()) { emit requestRetrieveFromDB(collectionList, prompt, retrievalSize, &databaseResults); // blocks - emit databaseResultsChanged(databaseResults); + emit sourceExcerptsChanged(databaseResults); } // Augment the prompt template with the results if any QString docsContext; if (!databaseResults.isEmpty()) { QStringList results; - for (const ResultInfo &info : databaseResults) + for (const SourceExcerpt &info : databaseResults) results << u"Collection: %1\nPath: %2\nExcerpt: %3"_s.arg(info.collection, info.path, info.text); // FIXME(jared): use a Jinja prompt template instead of hardcoded Alpaca-style localdocs template @@ -797,21 +831,66 @@ bool ChatLLM::promptInternal(const QList &collectionList, 
const QString m_timer->stop(); qint64 elapsed = totalTime.elapsed(); std::string trimmed = trim_whitespace(m_response); - if (trimmed != m_response) { - m_response = trimmed; - emit responseChanged(QString::fromStdString(m_response)); - } + if (m_maybeToolCall) { + m_maybeToolCall = false; + m_ctx.n_past = std::max(0, m_ctx.n_past); + m_ctx.tokens.erase(m_ctx.tokens.end() - m_promptResponseTokens, m_ctx.tokens.end()); + m_promptResponseTokens = 0; + m_promptTokens = 0; + m_response = std::string(); + return toolCallInternal(QString::fromStdString(trimmed), n_predict, top_k, top_p, min_p, temp, + n_batch, repeat_penalty, repeat_penalty_tokens); + } else { + if (trimmed != m_response) { + m_response = trimmed; + emit responseChanged(QString::fromStdString(m_response)); + } - SuggestionMode mode = MySettings::globalInstance()->suggestionMode(); - if (mode == SuggestionMode::On || (!databaseResults.isEmpty() && mode == SuggestionMode::LocalDocsOnly)) - generateQuestions(elapsed); - else - emit responseStopped(elapsed); + SuggestionMode mode = MySettings::globalInstance()->suggestionMode(); + if (mode == SuggestionMode::On || (!databaseResults.isEmpty() && mode == SuggestionMode::LocalDocsOnly)) + generateQuestions(elapsed); + else + emit responseStopped(elapsed); + } m_pristineLoadedState = false; return true; } +bool ChatLLM::toolCallInternal(const QString &toolCall, int32_t n_predict, int32_t top_k, float top_p, + float min_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens) +{ + Q_ASSERT(m_modelInfo.filename().startsWith("Meta-Llama-3.1-8B-Instruct")); + emit toolCalled(tr("searching web...")); + + // Based on https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#built-in-python-based-tool-calling + // For brave_search and wolfram_alpha ipython is always used + + static QRegularExpression re(R"(brave_search\.call\(query=\"([^\"]+)\"\))"); + QRegularExpressionMatch match = re.match(toolCall); + + QString 
prompt("<|start_header_id|>ipython<|end_header_id|>\n\n%1<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n%2"); + QString query; + if (match.hasMatch()) { + query = match.captured(1); + } else { + qWarning() << "WARNING: Could not find the tool for " << toolCall; + return promptInternal(QList()/*collectionList*/, prompt.arg(QString()), QString("%1") /*promptTemplate*/, + n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens); + } + + const QString apiKey = MySettings::globalInstance()->braveSearchAPIKey(); + Q_ASSERT(apiKey != ""); + + BraveSearch brave; + const QPair> braveResponse = brave.search(apiKey, query, 2 /*topK*/, 2000 /*msecs to timeout*/); + + emit sourceExcerptsChanged(braveResponse.second); + + return promptInternal(QList()/*collectionList*/, prompt.arg(braveResponse.first), QString("%1") /*promptTemplate*/, + n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens); +} + void ChatLLM::setShouldBeLoaded(bool b) { #if defined(DEBUG_MODEL_LOADING) diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/chatllm.h index d123358ad58e..31e99c716dea 100644 --- a/gpt4all-chat/chatllm.h +++ b/gpt4all-chat/chatllm.h @@ -180,6 +180,7 @@ public Q_SLOTS: void responseChanged(const QString &response); void promptProcessing(); void generatingQuestions(); + void toolCalled(const QString &description); void responseStopped(qint64 promptResponseMs); void generatedNameChanged(const QString &name); void generatedQuestionFinished(const QString &generatedQuestion); @@ -188,17 +189,19 @@ public Q_SLOTS: void shouldBeLoadedChanged(); void trySwitchContextRequested(const ModelInfo &modelInfo); void trySwitchContextOfLoadedModelCompleted(int value); - void requestRetrieveFromDB(const QList &collections, const QString &text, int retrievalSize, QList *results); + void requestRetrieveFromDB(const QList &collections, const QString &text, int retrievalSize, QList *results); void reportSpeed(const QString &speed); void 
reportDevice(const QString &device); void reportFallbackReason(const QString &fallbackReason); - void databaseResultsChanged(const QList&); + void sourceExcerptsChanged(const QList&); void modelInfoChanged(const ModelInfo &modelInfo); protected: bool promptInternal(const QList &collectionList, const QString &prompt, const QString &promptTemplate, int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens); + bool toolCallInternal(const QString &toolcall, int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty, + int32_t repeat_penalty_tokens); bool handlePrompt(int32_t token); bool handleResponse(int32_t token, const std::string &response); bool handleNamePrompt(int32_t token); @@ -239,11 +242,13 @@ public Q_SLOTS: bool m_reloadingToChangeVariant; bool m_processedSystemPrompt; bool m_restoreStateFromText; + bool m_maybeToolCall; // m_pristineLoadedState is set if saveSate is unnecessary, either because: // - an unload was queued during LLModel::restoreState() // - the chat will be restored from text and hasn't been interacted with yet bool m_pristineLoadedState = false; QVector> m_stateFromText; + QNetworkAccessManager m_networkManager; // FIXME REMOVE }; #endif // CHATLLM_H diff --git a/gpt4all-chat/chatmodel.h b/gpt4all-chat/chatmodel.h index f061ccf7dc0c..7982d0146d0a 100644 --- a/gpt4all-chat/chatmodel.h +++ b/gpt4all-chat/chatmodel.h @@ -28,8 +28,8 @@ struct ChatItem Q_PROPERTY(bool stopped MEMBER stopped) Q_PROPERTY(bool thumbsUpState MEMBER thumbsUpState) Q_PROPERTY(bool thumbsDownState MEMBER thumbsDownState) - Q_PROPERTY(QList sources MEMBER sources) - Q_PROPERTY(QList consolidatedSources MEMBER consolidatedSources) + Q_PROPERTY(QList sources MEMBER sources) + Q_PROPERTY(QList consolidatedSources MEMBER consolidatedSources) public: // TODO: Maybe we should include the model name here as well as timestamp? 
@@ -38,8 +38,8 @@ struct ChatItem QString value; QString prompt; QString newResponse; - QList sources; - QList consolidatedSources; + QList sources; + QList consolidatedSources; bool currentResponse = false; bool stopped = false; bool thumbsUpState = false; @@ -200,20 +200,20 @@ class ChatModel : public QAbstractListModel } } - QList consolidateSources(const QList &sources) { - QMap groupedData; - for (const ResultInfo &info : sources) { + QList consolidateSources(const QList &sources) { + QMap groupedData; + for (const SourceExcerpt &info : sources) { if (groupedData.contains(info.file)) { groupedData[info.file].text += "\n---\n" + info.text; } else { groupedData[info.file] = info; } } - QList consolidatedSources = groupedData.values(); + QList consolidatedSources = groupedData.values(); return consolidatedSources; } - Q_INVOKABLE void updateSources(int index, const QList &sources) + Q_INVOKABLE void updateSources(int index, const QList &sources) { if (index < 0 || index >= m_chatItems.size()) return; @@ -274,7 +274,7 @@ class ChatModel : public QAbstractListModel stream << c.thumbsDownState; if (version > 7) { stream << c.sources.size(); - for (const ResultInfo &info : c.sources) { + for (const SourceExcerpt &info : c.sources) { Q_ASSERT(!info.file.isEmpty()); stream << info.collection; stream << info.path; @@ -286,12 +286,16 @@ class ChatModel : public QAbstractListModel stream << info.page; stream << info.from; stream << info.to; + if (version > 9) { + stream << info.url; + stream << info.favicon; + } } } else if (version > 2) { QList references; QList referencesContext; int validReferenceNumber = 1; - for (const ResultInfo &info : c.sources) { + for (const SourceExcerpt &info : c.sources) { if (info.file.isEmpty()) continue; @@ -345,9 +349,9 @@ class ChatModel : public QAbstractListModel if (version > 7) { qsizetype count; stream >> count; - QList sources; + QList sources; for (int i = 0; i < count; ++i) { - ResultInfo info; + SourceExcerpt info; stream >> 
info.collection; stream >> info.path; stream >> info.file; @@ -358,6 +362,10 @@ class ChatModel : public QAbstractListModel stream >> info.page; stream >> info.from; stream >> info.to; + if (version > 9) { + stream >> info.url; + stream >> info.favicon; + } sources.append(info); } c.sources = sources; @@ -369,7 +377,7 @@ class ChatModel : public QAbstractListModel stream >> referencesContext; if (!references.isEmpty()) { - QList sources; + QList sources; QList referenceList = references.split("\n"); // Ignore empty lines and those that begin with "---" which is no longer used @@ -384,7 +392,7 @@ class ChatModel : public QAbstractListModel for (int j = 0; j < referenceList.size(); ++j) { QString reference = referenceList[j]; QString context = referencesContext[j]; - ResultInfo info; + SourceExcerpt info; QTextStream refStream(&reference); QString dummy; int validReferenceNumber; diff --git a/gpt4all-chat/database.cpp b/gpt4all-chat/database.cpp index 02ab5364f6d4..f2fc50941253 100644 --- a/gpt4all-chat/database.cpp +++ b/gpt4all-chat/database.cpp @@ -1938,7 +1938,7 @@ QList Database::searchEmbeddings(const std::vector &query, const QLi } void Database::retrieveFromDB(const QList &collections, const QString &text, int retrievalSize, - QList *results) + QList *results) { #if defined(DEBUG) qDebug() << "retrieveFromDB" << collections << text << retrievalSize; @@ -1974,7 +1974,7 @@ void Database::retrieveFromDB(const QList &collections, const QString & const int from = q.value(8).toInt(); const int to = q.value(9).toInt(); const QString collectionName = q.value(10).toString(); - ResultInfo info; + SourceExcerpt info; info.collection = collectionName; info.path = document_path; info.file = file; diff --git a/gpt4all-chat/database.h b/gpt4all-chat/database.h index 9031229079aa..fd0a78d9d5a7 100644 --- a/gpt4all-chat/database.h +++ b/gpt4all-chat/database.h @@ -2,6 +2,7 @@ #define DATABASE_H #include "embllm.h" // IWYU pragma: keep +#include "sourceexcerpt.h" #include 
#include @@ -49,64 +50,6 @@ struct DocumentInfo } }; -struct ResultInfo { - Q_GADGET - Q_PROPERTY(QString collection MEMBER collection) - Q_PROPERTY(QString path MEMBER path) - Q_PROPERTY(QString file MEMBER file) - Q_PROPERTY(QString title MEMBER title) - Q_PROPERTY(QString author MEMBER author) - Q_PROPERTY(QString date MEMBER date) - Q_PROPERTY(QString text MEMBER text) - Q_PROPERTY(int page MEMBER page) - Q_PROPERTY(int from MEMBER from) - Q_PROPERTY(int to MEMBER to) - Q_PROPERTY(QString fileUri READ fileUri STORED false) - -public: - QString collection; // [Required] The name of the collection - QString path; // [Required] The full path - QString file; // [Required] The name of the file, but not the full path - QString title; // [Optional] The title of the document - QString author; // [Optional] The author of the document - QString date; // [Required] The creation or the last modification date whichever is latest - QString text; // [Required] The text actually used in the augmented context - int page = -1; // [Optional] The page where the text was found - int from = -1; // [Optional] The line number where the text begins - int to = -1; // [Optional] The line number where the text ends - - QString fileUri() const { - // QUrl reserved chars that are not UNSAFE_PATH according to glib/gconvert.c - static const QByteArray s_exclude = "!$&'()*+,/:=@~"_ba; - - Q_ASSERT(!QFileInfo(path).isRelative()); -#ifdef Q_OS_WINDOWS - Q_ASSERT(!path.contains('\\')); // Qt normally uses forward slash as path separator -#endif - - auto escaped = QString::fromUtf8(QUrl::toPercentEncoding(path, s_exclude)); - if (escaped.front() != '/') - escaped = '/' + escaped; - return u"file://"_s + escaped; - } - - bool operator==(const ResultInfo &other) const { - return file == other.file && - title == other.title && - author == other.author && - date == other.date && - text == other.text && - page == other.page && - from == other.from && - to == other.to; - } - bool operator!=(const 
ResultInfo &other) const { - return !(*this == other); - } -}; - -Q_DECLARE_METATYPE(ResultInfo) - struct CollectionItem { // -- Fields persisted to database -- @@ -158,7 +101,7 @@ public Q_SLOTS: void forceRebuildFolder(const QString &path); bool addFolder(const QString &collection, const QString &path, const QString &embedding_model); void removeFolder(const QString &collection, const QString &path); - void retrieveFromDB(const QList &collections, const QString &text, int retrievalSize, QList *results); + void retrieveFromDB(const QList &collections, const QString &text, int retrievalSize, QList *results); void changeChunkSize(int chunkSize); void changeFileExtensions(const QStringList &extensions); @@ -225,7 +168,7 @@ private Q_SLOTS: QStringList m_scannedFileExtensions; QTimer *m_scanTimer; QMap> m_docsToScan; - QList m_retrieve; + QList m_retrieve; QThread m_dbThread; QFileSystemWatcher *m_watcher; QSet m_watchedPaths; diff --git a/gpt4all-chat/mysettings.cpp b/gpt4all-chat/mysettings.cpp index 525ccc1e9dc7..57f94ca21575 100644 --- a/gpt4all-chat/mysettings.cpp +++ b/gpt4all-chat/mysettings.cpp @@ -456,6 +456,7 @@ bool MySettings::localDocsUseRemoteEmbed() const { return getBasicSetting QString MySettings::localDocsNomicAPIKey() const { return getBasicSetting("localdocs/nomicAPIKey" ).toString(); } QString MySettings::localDocsEmbedDevice() const { return getBasicSetting("localdocs/embedDevice" ).toString(); } QString MySettings::networkAttribution() const { return getBasicSetting("network/attribution" ).toString(); } +QString MySettings::braveSearchAPIKey() const { return getBasicSetting("bravesearch/APIKey" ).toString(); } ChatTheme MySettings::chatTheme() const { return ChatTheme (getEnumSetting("chatTheme", chatThemeNames)); } FontSize MySettings::fontSize() const { return FontSize (getEnumSetting("fontSize", fontSizeNames)); } @@ -474,6 +475,7 @@ void MySettings::setLocalDocsUseRemoteEmbed(bool value) { setBasic void 
MySettings::setLocalDocsNomicAPIKey(const QString &value) { setBasicSetting("localdocs/nomicAPIKey", value, "localDocsNomicAPIKey"); } void MySettings::setLocalDocsEmbedDevice(const QString &value) { setBasicSetting("localdocs/embedDevice", value, "localDocsEmbedDevice"); } void MySettings::setNetworkAttribution(const QString &value) { setBasicSetting("network/attribution", value, "networkAttribution"); } +void MySettings::setBraveSearchAPIKey(const QString &value) { setBasicSetting("bravesearch/APIKey", value, "braveSearchAPIKey"); } void MySettings::setChatTheme(ChatTheme value) { setBasicSetting("chatTheme", chatThemeNames .value(int(value))); } void MySettings::setFontSize(FontSize value) { setBasicSetting("fontSize", fontSizeNames .value(int(value))); } diff --git a/gpt4all-chat/mysettings.h b/gpt4all-chat/mysettings.h index 3db8b2345c1e..d4aa4c5a3410 100644 --- a/gpt4all-chat/mysettings.h +++ b/gpt4all-chat/mysettings.h @@ -72,6 +72,7 @@ class MySettings : public QObject Q_PROPERTY(int networkPort READ networkPort WRITE setNetworkPort NOTIFY networkPortChanged) Q_PROPERTY(SuggestionMode suggestionMode READ suggestionMode WRITE setSuggestionMode NOTIFY suggestionModeChanged) Q_PROPERTY(QStringList uiLanguages MEMBER m_uiLanguages CONSTANT) + Q_PROPERTY(QString braveSearchAPIKey READ braveSearchAPIKey WRITE setBraveSearchAPIKey NOTIFY braveSearchAPIKeyChanged) public: static MySettings *globalInstance(); @@ -185,6 +186,10 @@ class MySettings : public QObject QString localDocsEmbedDevice() const; void setLocalDocsEmbedDevice(const QString &value); + // Tool settings + QString braveSearchAPIKey() const; + void setBraveSearchAPIKey(const QString &value); + // Network settings QString networkAttribution() const; void setNetworkAttribution(const QString &value); @@ -239,6 +244,7 @@ class MySettings : public QObject void deviceChanged(); void suggestionModeChanged(); void languageAndLocaleChanged(); + void braveSearchAPIKeyChanged(); private: QSettings m_settings; 
diff --git a/gpt4all-chat/qml/ChatView.qml b/gpt4all-chat/qml/ChatView.qml index 920e2759982d..75a6fc146732 100644 --- a/gpt4all-chat/qml/ChatView.qml +++ b/gpt4all-chat/qml/ChatView.qml @@ -881,6 +881,8 @@ Rectangle { case Chat.PromptProcessing: return qsTr("processing ...") case Chat.ResponseGeneration: return qsTr("generating response ..."); case Chat.GeneratingQuestions: return qsTr("generating questions ..."); + case Chat.ToolCalled: return currentChat.toolDescription; + case Chat.ToolProcessing: return qsTr("processing web results ..."); // FIXME should not be hardcoded! default: return ""; // handle unexpected values } } @@ -1131,7 +1133,7 @@ Rectangle { sourceSize.width: 24 sourceSize.height: 24 mipmap: true - source: "qrc:/gpt4all/icons/db.svg" + source: consolidatedSources[0].url === "" ? "qrc:/gpt4all/icons/db.svg" : "qrc:/gpt4all/icons/globe.svg" } ColorOverlay { @@ -1243,11 +1245,15 @@ Rectangle { MouseArea { id: ma - enabled: modelData.path !== "" + enabled: modelData.path !== "" || modelData.url !== "" anchors.fill: parent hoverEnabled: true onClicked: function() { - Qt.openUrlExternally(modelData.fileUri) + if (modelData.url !== "") { + console.log("opening url") + Qt.openUrlExternally(modelData.url) + } else + Qt.openUrlExternally(modelData.fileUri) } } @@ -1287,22 +1293,27 @@ Rectangle { Image { id: fileIcon anchors.fill: parent - visible: false + visible: modelData.favicon !== "" sourceSize.width: 24 sourceSize.height: 24 mipmap: true source: { - if (modelData.file.toLowerCase().endsWith(".txt")) + if (modelData.favicon !== "") + return modelData.favicon; + else if (modelData.file.toLowerCase().endsWith(".txt")) return "qrc:/gpt4all/icons/file-txt.svg" else if (modelData.file.toLowerCase().endsWith(".pdf")) return "qrc:/gpt4all/icons/file-pdf.svg" else if (modelData.file.toLowerCase().endsWith(".md")) return "qrc:/gpt4all/icons/file-md.svg" - else + else if (modelData.file !== "") return "qrc:/gpt4all/icons/file.svg" + else + return 
"qrc:/gpt4all/icons/globe.svg" } } ColorOverlay { + visible: !fileIcon.visible anchors.fill: fileIcon source: fileIcon color: theme.textColor @@ -1310,7 +1321,7 @@ Rectangle { } Text { Layout.maximumWidth: 156 - text: modelData.collection !== "" ? modelData.collection : qsTr("LocalDocs") + text: modelData.collection !== "" ? modelData.collection : modelData.title font.pixelSize: theme.fontSizeLarge font.bold: true color: theme.styledTextColor @@ -1326,7 +1337,7 @@ Rectangle { Layout.fillHeight: true Layout.maximumWidth: 180 Layout.maximumHeight: 55 - title.height - text: modelData.file + text: modelData.file !== "" ? modelData.file : modelData.url color: theme.textColor font.pixelSize: theme.fontSizeSmall elide: Qt.ElideRight diff --git a/gpt4all-chat/qml/SettingsView.qml b/gpt4all-chat/qml/SettingsView.qml index 176d04188c6e..d421dad2f9ed 100644 --- a/gpt4all-chat/qml/SettingsView.qml +++ b/gpt4all-chat/qml/SettingsView.qml @@ -34,6 +34,9 @@ Rectangle { ListElement { title: qsTr("LocalDocs") } + ListElement { + title: qsTr("Tools") + } } ColumnLayout { @@ -152,6 +155,12 @@ Rectangle { Component { LocalDocsSettings { } } ] } + + MySettingsStack { + tabs: [ + Component { ToolSettings { } } + ] + } } } } diff --git a/gpt4all-chat/qml/ToolSettings.qml b/gpt4all-chat/qml/ToolSettings.qml new file mode 100644 index 000000000000..2fc1cd3210da --- /dev/null +++ b/gpt4all-chat/qml/ToolSettings.qml @@ -0,0 +1,71 @@ +import QtCore +import QtQuick +import QtQuick.Controls +import QtQuick.Controls.Basic +import QtQuick.Layouts +import QtQuick.Dialogs +import localdocs +import modellist +import mysettings +import network + +MySettingsTab { + onRestoreDefaultsClicked: { + MySettings.restoreLocalDocsDefaults(); + } + + showRestoreDefaultsButton: true + + title: qsTr("Tools") + contentItem: ColumnLayout { + id: root + spacing: 30 + + ColumnLayout { + spacing: 10 + Label { + color: theme.grayRed900 + font.pixelSize: theme.fontSizeLarge + font.bold: true + text: qsTr("Brave Search") 
+ } + + Rectangle { + Layout.fillWidth: true + height: 1 + color: theme.grayRed500 + } + } + + RowLayout { + MySettingsLabel { + id: apiKeyLabel + text: qsTr("Brave AI API key") + helpText: qsTr('The API key to use for Brave Web Search. Get one from the Brave for free API keys page.') + onLinkActivated: function(link) { Qt.openUrlExternally(link) } + } + + MyTextField { + id: apiKeyField + text: MySettings.braveSearchAPIKey + color: theme.textColor + font.pixelSize: theme.fontSizeLarge + Layout.alignment: Qt.AlignRight + Layout.minimumWidth: 200 + onEditingFinished: { + MySettings.braveSearchAPIKey = apiKeyField.text; + } + Accessible.role: Accessible.EditableText + Accessible.name: apiKeyLabel.text + Accessible.description: apiKeyLabel.helpText + } + } + + Rectangle { + Layout.topMargin: 15 + Layout.fillWidth: true + height: 1 + color: theme.settingsDivider + } + } +} diff --git a/gpt4all-chat/server.cpp b/gpt4all-chat/server.cpp index c8485d93e11d..e655bf9feff9 100644 --- a/gpt4all-chat/server.cpp +++ b/gpt4all-chat/server.cpp @@ -56,27 +56,13 @@ static inline QJsonObject modelToJson(const ModelInfo &info) return model; } -static inline QJsonObject resultToJson(const ResultInfo &info) -{ - QJsonObject result; - result.insert("file", info.file); - result.insert("title", info.title); - result.insert("author", info.author); - result.insert("date", info.date); - result.insert("text", info.text); - result.insert("page", info.page); - result.insert("from", info.from); - result.insert("to", info.to); - return result; -} - Server::Server(Chat *chat) : ChatLLM(chat, true /*isServer*/) , m_chat(chat) , m_server(nullptr) { connect(this, &Server::threadStarted, this, &Server::start); - connect(this, &Server::databaseResultsChanged, this, &Server::handleDatabaseResultsChanged); + connect(this, &Server::sourceExcerptsChanged, this, &Server::handleSourceExcerptsChanged); connect(chat, &Chat::collectionListChanged, this, &Server::handleCollectionListChanged, 
Qt::QueuedConnection); } @@ -373,7 +359,7 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re int promptTokens = 0; int responseTokens = 0; - QList>> responses; + QList>> responses; for (int i = 0; i < n; ++i) { if (!promptInternal( m_collections, @@ -394,7 +380,7 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re QString echoedPrompt = actualPrompt; if (!echoedPrompt.endsWith("\n")) echoedPrompt += "\n"; - responses.append(qMakePair((echo ? u"%1\n"_s.arg(actualPrompt) : QString()) + response(), m_databaseResults)); + responses.append(qMakePair((echo ? u"%1\n"_s.arg(actualPrompt) : QString()) + response(), m_sourceExcerpts)); if (!promptTokens) promptTokens += m_promptTokens; responseTokens += m_promptResponseTokens - m_promptTokens; @@ -414,7 +400,7 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re int index = 0; for (const auto &r : responses) { QString result = r.first; - QList infos = r.second; + QList infos = r.second; QJsonObject choice; choice.insert("index", index++); choice.insert("finish_reason", responseTokens == max_tokens ? 
"length" : "stop"); @@ -425,7 +411,7 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re if (MySettings::globalInstance()->localDocsShowReferences()) { QJsonArray references; for (const auto &ref : infos) - references.append(resultToJson(ref)); + references.append(ref.toJson()); choice.insert("references", references); } choices.append(choice); @@ -434,7 +420,7 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re int index = 0; for (const auto &r : responses) { QString result = r.first; - QList infos = r.second; + QList infos = r.second; QJsonObject choice; choice.insert("text", result); choice.insert("index", index++); @@ -443,7 +429,7 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re if (MySettings::globalInstance()->localDocsShowReferences()) { QJsonArray references; for (const auto &ref : infos) - references.append(resultToJson(ref)); + references.append(ref.toJson()); choice.insert("references", references); } choices.append(choice); diff --git a/gpt4all-chat/server.h b/gpt4all-chat/server.h index 689f0b6061e3..0c6f6eb6e195 100644 --- a/gpt4all-chat/server.h +++ b/gpt4all-chat/server.h @@ -2,7 +2,7 @@ #define SERVER_H #include "chatllm.h" -#include "database.h" +#include "sourceexcerpt.h" #include #include @@ -29,13 +29,13 @@ public Q_SLOTS: private Q_SLOTS: QHttpServerResponse handleCompletionRequest(const QHttpServerRequest &request, bool isChat); - void handleDatabaseResultsChanged(const QList &results) { m_databaseResults = results; } + void handleSourceExcerptsChanged(const QList &sourceExcerpts) { m_sourceExcerpts = sourceExcerpts; } void handleCollectionListChanged(const QList &collectionList) { m_collections = collectionList; } private: Chat *m_chat; QHttpServer *m_server; - QList m_databaseResults; + QList m_sourceExcerpts; QList m_collections; }; diff --git a/gpt4all-chat/sourceexcerpt.h b/gpt4all-chat/sourceexcerpt.h new file mode 100644 index 
000000000000..91497e9daf17 --- /dev/null +++ b/gpt4all-chat/sourceexcerpt.h @@ -0,0 +1,95 @@ +#ifndef SOURCEEXCERT_H +#define SOURCEEXCERT_H + +#include +#include +#include +#include + +using namespace Qt::Literals::StringLiterals; + +struct SourceExcerpt { + Q_GADGET + Q_PROPERTY(QString date MEMBER date) + Q_PROPERTY(QString text MEMBER text) + Q_PROPERTY(QString collection MEMBER collection) + Q_PROPERTY(QString path MEMBER path) + Q_PROPERTY(QString file MEMBER file) + Q_PROPERTY(QString url MEMBER url) + Q_PROPERTY(QString favicon MEMBER favicon) + Q_PROPERTY(QString title MEMBER title) + Q_PROPERTY(QString author MEMBER author) + Q_PROPERTY(int page MEMBER page) + Q_PROPERTY(int from MEMBER from) + Q_PROPERTY(int to MEMBER to) + Q_PROPERTY(QString fileUri READ fileUri STORED false) + +public: + QString date; // [Required] The creation or the last modification date whichever is latest + QString text; // [Required] The text actually used in the augmented context + QString collection; // [Optional] The name of the collection + QString path; // [Optional] The full path + QString file; // [Optional] The name of the file, but not the full path + QString url; // [Optional] The name of the remote url + QString favicon; // [Optional] The favicon + QString title; // [Optional] The title of the document + QString author; // [Optional] The author of the document + int page = -1; // [Optional] The page where the text was found + int from = -1; // [Optional] The line number where the text begins + int to = -1; // [Optional] The line number where the text ends + + QString fileUri() const { + // QUrl reserved chars that are not UNSAFE_PATH according to glib/gconvert.c + static const QByteArray s_exclude = "!$&'()*+,/:=@~"_ba; + + Q_ASSERT(!QFileInfo(path).isRelative()); +#ifdef Q_OS_WINDOWS + Q_ASSERT(!path.contains('\\')); // Qt normally uses forward slash as path separator +#endif + + auto escaped = QString::fromUtf8(QUrl::toPercentEncoding(path, s_exclude)); + if 
(escaped.front() != '/') + escaped = '/' + escaped; + return u"file://"_s + escaped; + } + + QJsonObject toJson() const + { + QJsonObject result; + result.insert("date", date); + result.insert("text", text); + result.insert("collection", collection); + result.insert("path", path); + result.insert("file", file); + result.insert("url", url); + result.insert("favicon", favicon); + result.insert("title", title); + result.insert("author", author); + result.insert("page", page); + result.insert("from", from); + result.insert("to", to); + return result; + } + + bool operator==(const SourceExcerpt &other) const { + return date == other.date && + text == other.text && + collection == other.collection && + path == other.path && + file == other.file && + url == other.url && + favicon == other.favicon && + title == other.title && + author == other.author && + page == other.page && + from == other.from && + to == other.to; + } + bool operator!=(const SourceExcerpt &other) const { + return !(*this == other); + } +}; + +Q_DECLARE_METATYPE(SourceExcerpt) + +#endif // SOURCEEXCERT_H From a71db10124360afd467f3656498991d6649d2f72 Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Thu, 25 Jul 2024 15:44:40 -0400 Subject: [PATCH 02/30] Change the name to announce the beta. 
Signed-off-by: Adam Treat --- gpt4all-chat/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt index 994ecb936d8b..a3e06386a5c6 100644 --- a/gpt4all-chat/CMakeLists.txt +++ b/gpt4all-chat/CMakeLists.txt @@ -20,7 +20,7 @@ set(APP_VERSION_MAJOR 3) set(APP_VERSION_MINOR 2) set(APP_VERSION_PATCH 2) set(APP_VERSION_BASE "${APP_VERSION_MAJOR}.${APP_VERSION_MINOR}.${APP_VERSION_PATCH}") -set(APP_VERSION "${APP_VERSION_BASE}-dev0") +set(APP_VERSION "${APP_VERSION_BASE}-web_search_beta") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/cmake/Modules") From 1c9911a4ac461935644fa43b6a0b3937f88618d1 Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Thu, 25 Jul 2024 19:14:57 -0400 Subject: [PATCH 03/30] Fix problem with only displaying one source for tool call excerpts. Signed-off-by: Adam Treat --- gpt4all-chat/chatmodel.h | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/gpt4all-chat/chatmodel.h b/gpt4all-chat/chatmodel.h index 7982d0146d0a..1031ded5a266 100644 --- a/gpt4all-chat/chatmodel.h +++ b/gpt4all-chat/chatmodel.h @@ -203,10 +203,11 @@ class ChatModel : public QAbstractListModel QList consolidateSources(const QList &sources) { QMap groupedData; for (const SourceExcerpt &info : sources) { - if (groupedData.contains(info.file)) { - groupedData[info.file].text += "\n---\n" + info.text; + QString key = !info.file.isEmpty() ? info.file : info.url; + if (groupedData.contains(key)) { + groupedData[key].text += "\n---\n" + info.text; } else { - groupedData[info.file] = info; + groupedData[key] = info; } } QList consolidatedSources = groupedData.values(); From c78c95ab425263af31bd5454824013a8e3ab5682 Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Thu, 25 Jul 2024 19:40:18 -0400 Subject: [PATCH 04/30] Add the extra snippets to the source excerpts. 
Signed-off-by: Adam Treat --- gpt4all-chat/bravesearch.cpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/gpt4all-chat/bravesearch.cpp b/gpt4all-chat/bravesearch.cpp index a9c0df47ab9b..5c05c24f8a42 100644 --- a/gpt4all-chat/bravesearch.cpp +++ b/gpt4all-chat/bravesearch.cpp @@ -101,8 +101,17 @@ static QPair> cleanBraveResponse(const QByteArray& } } + QStringList textKeys = {"description", "extra_snippets"}; + QJsonObject textObj; + for (const auto& key : textKeys) { + if (resultObj.contains(key)) { + textObj.insert(key, resultObj[key]); + } + } + + QJsonDocument textObjDoc(textObj); info.date = resultObj["date"].toString(); - info.text = resultObj["description"].toString(); // fixme + info.text = textObjDoc.toJson(QJsonDocument::Indented); info.url = resultObj["url"].toString(); QJsonObject meta_url = resultObj["meta_url"].toObject(); info.favicon = meta_url["favicon"].toString(); From dda59a97a64e5290171ebad0275d68c1518c4478 Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Sat, 27 Jul 2024 11:11:41 -0400 Subject: [PATCH 05/30] Fix the way we're injecting the context back into the model for web search. 
Signed-off-by: Adam Treat --- gpt4all-chat/chat.cpp | 2 ++ gpt4all-chat/chatllm.cpp | 6 +++--- gpt4all-chat/chatmodel.h | 9 +++++++-- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/gpt4all-chat/chat.cpp b/gpt4all-chat/chat.cpp index 7da8274cfd7d..2730eaf3b8c2 100644 --- a/gpt4all-chat/chat.cpp +++ b/gpt4all-chat/chat.cpp @@ -116,6 +116,7 @@ void Chat::resetResponseState() if (m_responseInProgress && m_responseState == Chat::LocalDocsRetrieval) return; + m_sourceExcerpts = QList(); m_generatedQuestions = QList(); emit generatedQuestionsChanged(); m_tokenSpeed = QString(); @@ -136,6 +137,7 @@ void Chat::prompt(const QString &prompt) void Chat::regenerateResponse() { const int index = m_chatModel->count() - 1; + m_sourceExcerpts = QList(); m_chatModel->updateSources(index, QList()); emit regenerateResponseRequested(); } diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp index a6dea3959dbc..073a1d7ae484 100644 --- a/gpt4all-chat/chatllm.cpp +++ b/gpt4all-chat/chatllm.cpp @@ -869,13 +869,13 @@ bool ChatLLM::toolCallInternal(const QString &toolCall, int32_t n_predict, int32 static QRegularExpression re(R"(brave_search\.call\(query=\"([^\"]+)\"\))"); QRegularExpressionMatch match = re.match(toolCall); - QString prompt("<|start_header_id|>ipython<|end_header_id|>\n\n%1<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n%2"); + QString promptTemplate("<|start_header_id|>ipython<|end_header_id|>\n\n%1<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n%2"); QString query; if (match.hasMatch()) { query = match.captured(1); } else { qWarning() << "WARNING: Could not find the tool for " << toolCall; - return promptInternal(QList()/*collectionList*/, prompt.arg(QString()), QString("%1") /*promptTemplate*/, + return promptInternal(QList()/*collectionList*/, QString() /*prompt*/, promptTemplate, n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens); } @@ -887,7 +887,7 @@ bool 
ChatLLM::toolCallInternal(const QString &toolCall, int32_t n_predict, int32 emit sourceExcerptsChanged(braveResponse.second); - return promptInternal(QList()/*collectionList*/, prompt.arg(braveResponse.first), QString("%1") /*promptTemplate*/, + return promptInternal(QList()/*collectionList*/, braveResponse.first, promptTemplate, n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens); } diff --git a/gpt4all-chat/chatmodel.h b/gpt4all-chat/chatmodel.h index 1031ded5a266..97b812750b4e 100644 --- a/gpt4all-chat/chatmodel.h +++ b/gpt4all-chat/chatmodel.h @@ -219,8 +219,13 @@ class ChatModel : public QAbstractListModel if (index < 0 || index >= m_chatItems.size()) return; ChatItem &item = m_chatItems[index]; - item.sources = sources; - item.consolidatedSources = consolidateSources(sources); + if (sources.isEmpty()) { + item.sources.clear(); + item.consolidatedSources.clear(); + } else { + item.sources << sources; + item.consolidatedSources << consolidateSources(sources); + } emit dataChanged(createIndex(index, 0), createIndex(index, 0), {SourcesRole}); emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ConsolidatedSourcesRole}); } From d2ee235388ac0550abc1b37e15fca79b6bbdde08 Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Sat, 27 Jul 2024 11:43:32 -0400 Subject: [PATCH 06/30] Change the suggestion mode to turn on for tool calls by default. 
Signed-off-by: Adam Treat --- gpt4all-chat/chatllm.cpp | 10 +++++----- gpt4all-chat/chatllm.h | 2 +- gpt4all-chat/mysettings.cpp | 2 +- gpt4all-chat/mysettings.h | 6 +++--- gpt4all-chat/qml/ApplicationSettings.qml | 2 +- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp index 073a1d7ae484..b54a86962a2d 100644 --- a/gpt4all-chat/chatllm.cpp +++ b/gpt4all-chat/chatllm.cpp @@ -770,14 +770,14 @@ bool ChatLLM::prompt(const QList &collectionList, const QString &prompt bool ChatLLM::promptInternal(const QList &collectionList, const QString &prompt, const QString &promptTemplate, int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty, - int32_t repeat_penalty_tokens) + int32_t repeat_penalty_tokens, bool isToolCallResponse) { if (!isModelLoaded()) return false; QList databaseResults; const int retrievalSize = MySettings::globalInstance()->localDocsRetrievalSize(); - if (!collectionList.isEmpty()) { + if (!collectionList.isEmpty() && !isToolCallResponse) { emit requestRetrieveFromDB(collectionList, prompt, retrievalSize, &databaseResults); // blocks emit sourceExcerptsChanged(databaseResults); } @@ -847,7 +847,7 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString } SuggestionMode mode = MySettings::globalInstance()->suggestionMode(); - if (mode == SuggestionMode::On || (!databaseResults.isEmpty() && mode == SuggestionMode::LocalDocsOnly)) + if (mode == SuggestionMode::On || (mode == SuggestionMode::SourceExcerptsOnly && (!databaseResults.isEmpty() || isToolCallResponse))) generateQuestions(elapsed); else emit responseStopped(elapsed); @@ -876,7 +876,7 @@ bool ChatLLM::toolCallInternal(const QString &toolCall, int32_t n_predict, int32 } else { qWarning() << "WARNING: Could not find the tool for " << toolCall; return promptInternal(QList()/*collectionList*/, QString() /*prompt*/, promptTemplate, - n_predict, top_k, top_p, min_p, temp, 
n_batch, repeat_penalty, repeat_penalty_tokens); + n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/); } const QString apiKey = MySettings::globalInstance()->braveSearchAPIKey(); @@ -888,7 +888,7 @@ bool ChatLLM::toolCallInternal(const QString &toolCall, int32_t n_predict, int32 emit sourceExcerptsChanged(braveResponse.second); return promptInternal(QList()/*collectionList*/, braveResponse.first, promptTemplate, - n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens); + n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/); } void ChatLLM::setShouldBeLoaded(bool b) diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/chatllm.h index 31e99c716dea..b4874b702dfa 100644 --- a/gpt4all-chat/chatllm.h +++ b/gpt4all-chat/chatllm.h @@ -199,7 +199,7 @@ public Q_SLOTS: protected: bool promptInternal(const QList &collectionList, const QString &prompt, const QString &promptTemplate, int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty, - int32_t repeat_penalty_tokens); + int32_t repeat_penalty_tokens, bool isToolCallResponse = false); bool toolCallInternal(const QString &toolcall, int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens); bool handlePrompt(int32_t token); diff --git a/gpt4all-chat/mysettings.cpp b/gpt4all-chat/mysettings.cpp index 57f94ca21575..e1530a8539c6 100644 --- a/gpt4all-chat/mysettings.cpp +++ b/gpt4all-chat/mysettings.cpp @@ -51,7 +51,7 @@ static const QVariantMap basicDefaults { { "saveChatsContext", false }, { "serverChat", false }, { "userDefaultModel", "Application default" }, - { "suggestionMode", QVariant::fromValue(SuggestionMode::LocalDocsOnly) }, + { "suggestionMode", QVariant::fromValue(SuggestionMode::SourceExcerptsOnly) }, { "localdocs/chunkSize", 
512 }, { "localdocs/retrievalSize", 3 }, { "localdocs/showReferences", true }, diff --git a/gpt4all-chat/mysettings.h b/gpt4all-chat/mysettings.h index d4aa4c5a3410..59301cf2fd93 100644 --- a/gpt4all-chat/mysettings.h +++ b/gpt4all-chat/mysettings.h @@ -21,9 +21,9 @@ namespace MySettingsEnums { * ApplicationSettings.qml, as well as the corresponding name lists in mysettings.cpp */ enum class SuggestionMode { - LocalDocsOnly = 0, - On = 1, - Off = 2, + SourceExcerptsOnly = 0, + On = 1, + Off = 2, }; Q_ENUM_NS(SuggestionMode) diff --git a/gpt4all-chat/qml/ApplicationSettings.qml b/gpt4all-chat/qml/ApplicationSettings.qml index 1e459e99c462..5fe4bea5e1cf 100644 --- a/gpt4all-chat/qml/ApplicationSettings.qml +++ b/gpt4all-chat/qml/ApplicationSettings.qml @@ -350,7 +350,7 @@ MySettingsTab { Layout.alignment: Qt.AlignRight // NOTE: indices match values of SuggestionMode enum, keep them in sync model: ListModel { - ListElement { name: qsTr("When chatting with LocalDocs") } + ListElement { name: qsTr("When source excerpts are cited") } ListElement { name: qsTr("Whenever possible") } ListElement { name: qsTr("Never") } } From b0578e28b9f405d5932ff857848bb607472c52ba Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Sat, 27 Jul 2024 14:38:42 -0400 Subject: [PATCH 07/30] Change the name to inc the beta. 
Signed-off-by: Adam Treat --- gpt4all-chat/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt index a3e06386a5c6..600d4c80d391 100644 --- a/gpt4all-chat/CMakeLists.txt +++ b/gpt4all-chat/CMakeLists.txt @@ -20,7 +20,7 @@ set(APP_VERSION_MAJOR 3) set(APP_VERSION_MINOR 2) set(APP_VERSION_PATCH 2) set(APP_VERSION_BASE "${APP_VERSION_MAJOR}.${APP_VERSION_MINOR}.${APP_VERSION_PATCH}") -set(APP_VERSION "${APP_VERSION_BASE}-web_search_beta") +set(APP_VERSION "${APP_VERSION_BASE}-web_search_beta_2") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/cmake/Modules") From b9684ff74199a3a125ffb7b91dcfb2e1510fc79d Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Sun, 28 Jul 2024 11:48:53 -0400 Subject: [PATCH 08/30] Inc again for new beta. Signed-off-by: Adam Treat --- gpt4all-chat/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt index 600d4c80d391..074859ad24a2 100644 --- a/gpt4all-chat/CMakeLists.txt +++ b/gpt4all-chat/CMakeLists.txt @@ -20,7 +20,7 @@ set(APP_VERSION_MAJOR 3) set(APP_VERSION_MINOR 2) set(APP_VERSION_PATCH 2) set(APP_VERSION_BASE "${APP_VERSION_MAJOR}.${APP_VERSION_MINOR}.${APP_VERSION_PATCH}") -set(APP_VERSION "${APP_VERSION_BASE}-web_search_beta_2") +set(APP_VERSION "${APP_VERSION_BASE}-web_search_beta_3") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/cmake/Modules") From 7c7558eed3db18a6c8c25307b9cd8d7d01b36a13 Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Tue, 30 Jul 2024 10:30:56 -0400 Subject: [PATCH 09/30] Stop hardcoding the tool call checking and rely upon the format advocated by ollama for tool calling. 
Signed-off-by: Adam Treat --- gpt4all-chat/chatllm.cpp | 107 +++++++++++++++++------------ gpt4all-chat/chatllm.h | 2 + gpt4all-chat/modellist.cpp | 24 +++++++ gpt4all-chat/modellist.h | 6 ++ gpt4all-chat/mysettings.cpp | 7 ++ gpt4all-chat/mysettings.h | 3 + gpt4all-chat/qml/ModelSettings.qml | 70 +++++++++++++++---- 7 files changed, 162 insertions(+), 57 deletions(-) diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp index b54a86962a2d..06c9223e2e19 100644 --- a/gpt4all-chat/chatllm.cpp +++ b/gpt4all-chat/chatllm.cpp @@ -115,7 +115,9 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer) , m_reloadingToChangeVariant(false) , m_processedSystemPrompt(false) , m_restoreStateFromText(false) + , m_checkToolCall(false) , m_maybeToolCall(false) + , m_foundToolCall(false) { moveToThread(&m_llmThread); connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged, @@ -705,34 +707,6 @@ bool ChatLLM::handleResponse(int32_t token, const std::string &response) return false; } - // Only valid for llama 3.1 instruct - if (m_modelInfo.filename().startsWith("Meta-Llama-3.1-8B-Instruct")) { - // Based on https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#built-in-python-based-tool-calling - // For brave_search and wolfram_alpha ipython is always used - - // <|python_tag|> - // brave_search.call(query="...") - // <|eom_id|> - const int eom_id = 128008; - const int python_tag = 128010; - - // If we have a built-in tool call, then it should be the first token - const bool isFirstResponseToken = m_promptResponseTokens == m_promptTokens; - Q_ASSERT(token != python_tag || isFirstResponseToken); - if (isFirstResponseToken && token == python_tag) { - m_maybeToolCall = true; - ++m_promptResponseTokens; - return !m_stopGenerating; - } - - // Check for end of built-in tool call - Q_ASSERT(token != eom_id || !m_maybeToolCall); - if (token == eom_id) { - ++m_promptResponseTokens; - return false; - } - } - // m_promptResponseTokens is 
related to last prompt/response not // the entire context window which we can reset on regenerate prompt ++m_promptResponseTokens; @@ -740,7 +714,25 @@ bool ChatLLM::handleResponse(int32_t token, const std::string &response) Q_ASSERT(!response.empty()); m_response.append(response); - if (!m_maybeToolCall) + // If we're checking for a tool called and the response is equal or exceeds 11 chars + // then we check + if (m_checkToolCall && m_response.size() >= 11) { + if (m_response.starts_with("")) { + m_maybeToolCall = true; + m_response.erase(0, 11); + } + m_checkToolCall = false; + } + + // Check if we're at the end of tool call and erase the end tag + if (m_maybeToolCall && m_response.ends_with("")) { + m_foundToolCall = true; + m_response.erase(m_response.length() - 12); + return false; + } + + // If we're not checking for tool call and haven't detected one, then send along the response + if (!m_checkToolCall && !m_maybeToolCall) emit responseChanged(QString::fromStdString(remove_leading_whitespace(m_response))); return !m_stopGenerating; @@ -822,8 +814,12 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString /*allowContextShift*/ true, m_ctx); m_ctx.n_predict = old_n_predict; // now we are ready for a response } + + m_checkToolCall = !isToolCallResponse; // We can't handle recursive tool calls right now m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc, /*allowContextShift*/ true, m_ctx); + m_checkToolCall = false; + m_maybeToolCall = false; #if defined(DEBUG) printf("\n"); fflush(stdout); @@ -831,8 +827,8 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString m_timer->stop(); qint64 elapsed = totalTime.elapsed(); std::string trimmed = trim_whitespace(m_response); - if (m_maybeToolCall) { - m_maybeToolCall = false; + if (m_foundToolCall) { + m_foundToolCall = false; m_ctx.n_past = std::max(0, m_ctx.n_past); m_ctx.tokens.erase(m_ctx.tokens.end() - m_promptResponseTokens, 
m_ctx.tokens.end()); m_promptResponseTokens = 0; @@ -860,25 +856,46 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString bool ChatLLM::toolCallInternal(const QString &toolCall, int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens) { - Q_ASSERT(m_modelInfo.filename().startsWith("Meta-Llama-3.1-8B-Instruct")); - emit toolCalled(tr("searching web...")); + QString toolTemplate = MySettings::globalInstance()->modelToolTemplate(m_modelInfo); + if (toolTemplate.isEmpty()) { + // FIXME: Not sure what to do here. The model attempted a tool call, but there is no way for + // us to process it. We should probably not even attempt further generation and just show an + // error in the chat somehow? + qWarning() << "WARNING: The model attempted a toolcall, but there is no valid tool template for this model" << toolCall; + return promptInternal(QList()/*collectionList*/, QString() /*prompt*/, + MySettings::globalInstance()->modelPromptTemplate(m_modelInfo), + n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/); + } - // Based on https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#built-in-python-based-tool-calling - // For brave_search and wolfram_alpha ipython is always used + QJsonParseError err; + QJsonDocument toolCallDoc = QJsonDocument::fromJson(toolCall.toUtf8(), &err); - static QRegularExpression re(R"(brave_search\.call\(query=\"([^\"]+)\"\))"); - QRegularExpressionMatch match = re.match(toolCall); + if (toolCallDoc.isNull() || err.error != QJsonParseError::NoError || !toolCallDoc.isObject()) { + qWarning() << "WARNING: The tool call had null or invalid json " << toolCall; + return promptInternal(QList()/*collectionList*/, QString() /*prompt*/, toolTemplate, + n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/); + } - 
QString promptTemplate("<|start_header_id|>ipython<|end_header_id|>\n\n%1<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n%2"); - QString query; - if (match.hasMatch()) { - query = match.captured(1); - } else { - qWarning() << "WARNING: Could not find the tool for " << toolCall; - return promptInternal(QList()/*collectionList*/, QString() /*prompt*/, promptTemplate, + QJsonObject rootObject = toolCallDoc.object(); + if (!rootObject.contains("name") || !rootObject.contains("arguments")) { + qWarning() << "WARNING: The tool call did not have required name and argument objects " << toolCall; + return promptInternal(QList()/*collectionList*/, QString() /*prompt*/, toolTemplate, n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/); } + const QString tool = toolCallDoc["name"].toString(); + const QJsonObject args = toolCallDoc["arguments"].toObject(); + + if (tool != "brave_search" || !args.contains("query")) { + qWarning() << "WARNING: Could not find the tool and correct arguments for " << toolCall; + return promptInternal(QList()/*collectionList*/, QString() /*prompt*/, toolTemplate, + n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/); + } + + const QString query = args["query"].toString(); + + emit toolCalled(tr("searching web...")); + const QString apiKey = MySettings::globalInstance()->braveSearchAPIKey(); Q_ASSERT(apiKey != ""); @@ -887,7 +904,7 @@ bool ChatLLM::toolCallInternal(const QString &toolCall, int32_t n_predict, int32 emit sourceExcerptsChanged(braveResponse.second); - return promptInternal(QList()/*collectionList*/, braveResponse.first, promptTemplate, + return promptInternal(QList()/*collectionList*/, braveResponse.first, toolTemplate, n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/); } diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/chatllm.h index 
b4874b702dfa..d9d47ae94c8a 100644 --- a/gpt4all-chat/chatllm.h +++ b/gpt4all-chat/chatllm.h @@ -242,7 +242,9 @@ public Q_SLOTS: bool m_reloadingToChangeVariant; bool m_processedSystemPrompt; bool m_restoreStateFromText; + bool m_checkToolCall; bool m_maybeToolCall; + bool m_foundToolCall; // m_pristineLoadedState is set if saveSate is unnecessary, either because: // - an unload was queued during LLModel::restoreState() // - the chat will be restored from text and hasn't been interacted with yet diff --git a/gpt4all-chat/modellist.cpp b/gpt4all-chat/modellist.cpp index 580b615ff4e6..6ed911508e92 100644 --- a/gpt4all-chat/modellist.cpp +++ b/gpt4all-chat/modellist.cpp @@ -323,6 +323,17 @@ void ModelInfo::setPromptTemplate(const QString &t) m_promptTemplate = t; } +QString ModelInfo::toolTemplate() const +{ + return MySettings::globalInstance()->modelToolTemplate(*this); +} + +void ModelInfo::setToolTemplate(const QString &t) +{ + if (shouldSaveMetadata()) MySettings::globalInstance()->setModelToolTemplate(*this, t, true /*force*/); + m_toolTemplate = t; +} + QString ModelInfo::systemPrompt() const { return MySettings::globalInstance()->modelSystemPrompt(*this); @@ -385,6 +396,7 @@ QVariantMap ModelInfo::getFields() const { "repeatPenalty", m_repeatPenalty }, { "repeatPenaltyTokens", m_repeatPenaltyTokens }, { "promptTemplate", m_promptTemplate }, + { "toolTemplate", m_toolTemplate }, { "systemPrompt", m_systemPrompt }, { "chatNamePrompt", m_chatNamePrompt }, { "suggestedFollowUpPrompt", m_suggestedFollowUpPrompt }, @@ -504,6 +516,7 @@ ModelList::ModelList() connect(MySettings::globalInstance(), &MySettings::repeatPenaltyChanged, this, &ModelList::updateDataForSettings); connect(MySettings::globalInstance(), &MySettings::repeatPenaltyTokensChanged, this, &ModelList::updateDataForSettings);; connect(MySettings::globalInstance(), &MySettings::promptTemplateChanged, this, &ModelList::updateDataForSettings); + connect(MySettings::globalInstance(), 
&MySettings::toolTemplateChanged, this, &ModelList::updateDataForSettings); connect(MySettings::globalInstance(), &MySettings::systemPromptChanged, this, &ModelList::updateDataForSettings); connect(&m_networkManager, &QNetworkAccessManager::sslErrors, this, &ModelList::handleSslErrors); @@ -776,6 +789,8 @@ QVariant ModelList::dataInternal(const ModelInfo *info, int role) const return info->repeatPenaltyTokens(); case PromptTemplateRole: return info->promptTemplate(); + case ToolTemplateRole: + return info->toolTemplate(); case SystemPromptRole: return info->systemPrompt(); case ChatNamePromptRole: @@ -952,6 +967,8 @@ void ModelList::updateData(const QString &id, const QVector info->setRepeatPenaltyTokens(value.toInt()); break; case PromptTemplateRole: info->setPromptTemplate(value.toString()); break; + case ToolTemplateRole: + info->setToolTemplate(value.toString()); break; case SystemPromptRole: info->setSystemPrompt(value.toString()); break; case ChatNamePromptRole: @@ -1107,6 +1124,7 @@ QString ModelList::clone(const ModelInfo &model) { ModelList::RepeatPenaltyRole, model.repeatPenalty() }, { ModelList::RepeatPenaltyTokensRole, model.repeatPenaltyTokens() }, { ModelList::PromptTemplateRole, model.promptTemplate() }, + { ModelList::ToolTemplateRole, model.toolTemplate() }, { ModelList::SystemPromptRole, model.systemPrompt() }, { ModelList::ChatNamePromptRole, model.chatNamePrompt() }, { ModelList::SuggestedFollowUpPromptRole, model.suggestedFollowUpPrompt() }, @@ -1551,6 +1569,8 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save) data.append({ ModelList::RepeatPenaltyTokensRole, obj["repeatPenaltyTokens"].toInt() }); if (obj.contains("promptTemplate")) data.append({ ModelList::PromptTemplateRole, obj["promptTemplate"].toString() }); + if (obj.contains("toolTemplate")) + data.append({ ModelList::ToolTemplateRole, obj["toolTemplate"].toString() }); if (obj.contains("systemPrompt")) data.append({ ModelList::SystemPromptRole, 
obj["systemPrompt"].toString() }); updateData(id, data); @@ -1852,6 +1872,10 @@ void ModelList::updateModelsFromSettings() const QString promptTemplate = settings.value(g + "/promptTemplate").toString(); data.append({ ModelList::PromptTemplateRole, promptTemplate }); } + if (settings.contains(g + "/toolTemplate")) { + const QString toolTemplate = settings.value(g + "/toolTemplate").toString(); + data.append({ ModelList::ToolTemplateRole, toolTemplate }); + } if (settings.contains(g + "/systemPrompt")) { const QString systemPrompt = settings.value(g + "/systemPrompt").toString(); data.append({ ModelList::SystemPromptRole, systemPrompt }); diff --git a/gpt4all-chat/modellist.h b/gpt4all-chat/modellist.h index 7c13da8ef4fd..9e9f088f76cb 100644 --- a/gpt4all-chat/modellist.h +++ b/gpt4all-chat/modellist.h @@ -68,6 +68,7 @@ struct ModelInfo { Q_PROPERTY(double repeatPenalty READ repeatPenalty WRITE setRepeatPenalty) Q_PROPERTY(int repeatPenaltyTokens READ repeatPenaltyTokens WRITE setRepeatPenaltyTokens) Q_PROPERTY(QString promptTemplate READ promptTemplate WRITE setPromptTemplate) + Q_PROPERTY(QString toolTemplate READ toolTemplate WRITE setToolTemplate) Q_PROPERTY(QString systemPrompt READ systemPrompt WRITE setSystemPrompt) Q_PROPERTY(QString chatNamePrompt READ chatNamePrompt WRITE setChatNamePrompt) Q_PROPERTY(QString suggestedFollowUpPrompt READ suggestedFollowUpPrompt WRITE setSuggestedFollowUpPrompt) @@ -178,6 +179,8 @@ struct ModelInfo { void setRepeatPenaltyTokens(int t); QString promptTemplate() const; void setPromptTemplate(const QString &t); + QString toolTemplate() const; + void setToolTemplate(const QString &t); QString systemPrompt() const; void setSystemPrompt(const QString &p); QString chatNamePrompt() const; @@ -215,6 +218,7 @@ struct ModelInfo { double m_repeatPenalty = 1.18; int m_repeatPenaltyTokens = 64; QString m_promptTemplate = "### Human:\n%1\n\n### Assistant:\n"; + QString m_toolTemplate = ""; QString m_systemPrompt = "### System:\nYou are an 
AI assistant who gives a quality response to whatever humans ask of you.\n\n"; QString m_chatNamePrompt = "Describe the above conversation in seven words or less."; QString m_suggestedFollowUpPrompt = "Suggest three very short factual follow-up questions that have not been answered yet or cannot be found inspired by the previous conversation and excerpts."; @@ -339,6 +343,7 @@ class ModelList : public QAbstractListModel RepeatPenaltyRole, RepeatPenaltyTokensRole, PromptTemplateRole, + ToolTemplateRole, SystemPromptRole, ChatNamePromptRole, SuggestedFollowUpPromptRole, @@ -393,6 +398,7 @@ class ModelList : public QAbstractListModel roles[RepeatPenaltyRole] = "repeatPenalty"; roles[RepeatPenaltyTokensRole] = "repeatPenaltyTokens"; roles[PromptTemplateRole] = "promptTemplate"; + roles[ToolTemplateRole] = "toolTemplate"; roles[SystemPromptRole] = "systemPrompt"; roles[ChatNamePromptRole] = "chatNamePrompt"; roles[SuggestedFollowUpPromptRole] = "suggestedFollowUpPrompt"; diff --git a/gpt4all-chat/mysettings.cpp b/gpt4all-chat/mysettings.cpp index e1530a8539c6..0b94989edfb7 100644 --- a/gpt4all-chat/mysettings.cpp +++ b/gpt4all-chat/mysettings.cpp @@ -194,6 +194,7 @@ void MySettings::restoreModelDefaults(const ModelInfo &info) setModelRepeatPenalty(info, info.m_repeatPenalty); setModelRepeatPenaltyTokens(info, info.m_repeatPenaltyTokens); setModelPromptTemplate(info, info.m_promptTemplate); + setModelToolTemplate(info, info.m_toolTemplate); setModelSystemPrompt(info, info.m_systemPrompt); setModelChatNamePrompt(info, info.m_chatNamePrompt); setModelSuggestedFollowUpPrompt(info, info.m_suggestedFollowUpPrompt); @@ -296,6 +297,7 @@ int MySettings::modelGpuLayers (const ModelInfo &info) const double MySettings::modelRepeatPenalty (const ModelInfo &info) const { return getModelSetting("repeatPenalty", info).toDouble(); } int MySettings::modelRepeatPenaltyTokens (const ModelInfo &info) const { return getModelSetting("repeatPenaltyTokens", info).toInt(); } QString 
MySettings::modelPromptTemplate (const ModelInfo &info) const { return getModelSetting("promptTemplate", info).toString(); } +QString MySettings::modelToolTemplate (const ModelInfo &info) const { return getModelSetting("toolTemplate", info).toString(); } QString MySettings::modelSystemPrompt (const ModelInfo &info) const { return getModelSetting("systemPrompt", info).toString(); } QString MySettings::modelChatNamePrompt (const ModelInfo &info) const { return getModelSetting("chatNamePrompt", info).toString(); } QString MySettings::modelSuggestedFollowUpPrompt(const ModelInfo &info) const { return getModelSetting("suggestedFollowUpPrompt", info).toString(); } @@ -405,6 +407,11 @@ void MySettings::setModelPromptTemplate(const ModelInfo &info, const QString &va setModelSetting("promptTemplate", info, value, force, true); } +void MySettings::setModelToolTemplate(const ModelInfo &info, const QString &value, bool force) +{ + setModelSetting("toolTemplate", info, value, force, true); +} + void MySettings::setModelSystemPrompt(const ModelInfo &info, const QString &value, bool force) { setModelSetting("systemPrompt", info, value, force, true); diff --git a/gpt4all-chat/mysettings.h b/gpt4all-chat/mysettings.h index 59301cf2fd93..205c21d03192 100644 --- a/gpt4all-chat/mysettings.h +++ b/gpt4all-chat/mysettings.h @@ -126,6 +126,8 @@ class MySettings : public QObject Q_INVOKABLE void setModelRepeatPenaltyTokens(const ModelInfo &info, int value, bool force = false); QString modelPromptTemplate(const ModelInfo &info) const; Q_INVOKABLE void setModelPromptTemplate(const ModelInfo &info, const QString &value, bool force = false); + QString modelToolTemplate(const ModelInfo &info) const; + Q_INVOKABLE void setModelToolTemplate(const ModelInfo &info, const QString &value, bool force = false); QString modelSystemPrompt(const ModelInfo &info) const; Q_INVOKABLE void setModelSystemPrompt(const ModelInfo &info, const QString &value, bool force = false); int modelContextLength(const 
ModelInfo &info) const; @@ -217,6 +219,7 @@ class MySettings : public QObject void repeatPenaltyChanged(const ModelInfo &info); void repeatPenaltyTokensChanged(const ModelInfo &info); void promptTemplateChanged(const ModelInfo &info); + void toolTemplateChanged(const ModelInfo &info); void systemPromptChanged(const ModelInfo &info); void chatNamePromptChanged(const ModelInfo &info); void suggestedFollowUpPromptChanged(const ModelInfo &info); diff --git a/gpt4all-chat/qml/ModelSettings.qml b/gpt4all-chat/qml/ModelSettings.qml index 5e896eb17495..9948934988f8 100644 --- a/gpt4all-chat/qml/ModelSettings.qml +++ b/gpt4all-chat/qml/ModelSettings.qml @@ -209,7 +209,7 @@ MySettingsTab { id: promptTemplateLabelHelp text: qsTr("Must contain the string \"%1\" to be replaced with the user's input.") color: theme.textErrorColor - visible: templateTextArea.text.indexOf("%1") === -1 + visible: promptTemplateTextArea.text.indexOf("%1") === -1 wrapMode: TextArea.Wrap } } @@ -220,27 +220,27 @@ MySettingsTab { Layout.column: 0 Layout.columnSpan: 2 Layout.fillWidth: true - Layout.minimumHeight: Math.max(100, templateTextArea.contentHeight + 20) + Layout.minimumHeight: Math.max(100, promptTemplateTextArea.contentHeight + 20) color: "transparent" clip: true MyTextArea { - id: templateTextArea + id: promptTemplateTextArea anchors.fill: parent text: root.currentModelInfo.promptTemplate Connections { target: MySettings function onPromptTemplateChanged() { - templateTextArea.text = root.currentModelInfo.promptTemplate; + promptTemplateTextArea.text = root.currentModelInfo.promptTemplate; } } Connections { target: root function onCurrentModelInfoChanged() { - templateTextArea.text = root.currentModelInfo.promptTemplate; + promptTemplateTextArea.text = root.currentModelInfo.promptTemplate; } } onTextChanged: { - if (templateTextArea.text.indexOf("%1") !== -1) { + if (promptTemplateTextArea.text.indexOf("%1") !== -1) { MySettings.setModelPromptTemplate(root.currentModelInfo, text) } } @@ 
-250,18 +250,64 @@ MySettingsTab { } } + MySettingsLabel { + Layout.row: 11 + Layout.column: 0 + Layout.columnSpan: 2 + Layout.topMargin: 15 + id: toolTemplateLabel + text: qsTr("Tool Template") + helpText: qsTr("The template that allows tool calls to inject information into the context.") + } + + Rectangle { + id: toolTemplate + Layout.row: 12 + Layout.column: 0 + Layout.columnSpan: 2 + Layout.fillWidth: true + Layout.minimumHeight: Math.max(100, toolTemplateTextArea.contentHeight + 20) + color: "transparent" + clip: true + MyTextArea { + id: toolTemplateTextArea + anchors.fill: parent + text: root.currentModelInfo.toolTemplate + Connections { + target: MySettings + function onToolTemplateChanged() { + toolTemplateTextArea.text = root.currentModelInfo.toolTemplate; + } + } + Connections { + target: root + function onCurrentModelInfoChanged() { + toolTemplateTextArea.text = root.currentModelInfo.toolTemplate; + } + } + onTextChanged: { + if (toolTemplateTextArea.text.indexOf("%1") !== -1) { + MySettings.setModelToolTemplate(root.currentModelInfo, text) + } + } + Accessible.role: Accessible.EditableText + Accessible.name: toolTemplateLabel.text + Accessible.description: toolTemplateLabel.text + } + } + MySettingsLabel { id: chatNamePromptLabel text: qsTr("Chat Name Prompt") helpText: qsTr("Prompt used to automatically generate chat names.") - Layout.row: 11 + Layout.row: 13 Layout.column: 0 Layout.topMargin: 15 } Rectangle { id: chatNamePrompt - Layout.row: 12 + Layout.row: 14 Layout.column: 0 Layout.columnSpan: 2 Layout.fillWidth: true @@ -297,14 +343,14 @@ MySettingsTab { id: suggestedFollowUpPromptLabel text: qsTr("Suggested FollowUp Prompt") helpText: qsTr("Prompt used to generate suggested follow-up questions.") - Layout.row: 13 + Layout.row: 15 Layout.column: 0 Layout.topMargin: 15 } Rectangle { id: suggestedFollowUpPrompt - Layout.row: 14 + Layout.row: 16 Layout.column: 0 Layout.columnSpan: 2 Layout.fillWidth: true @@ -337,7 +383,7 @@ MySettingsTab { } 
GridLayout { - Layout.row: 15 + Layout.row: 17 Layout.column: 0 Layout.columnSpan: 2 Layout.topMargin: 15 @@ -833,7 +879,7 @@ MySettingsTab { } Rectangle { - Layout.row: 16 + Layout.row: 18 Layout.column: 0 Layout.columnSpan: 2 Layout.topMargin: 15 From fffd9f341a3a4c60e43b98db3e08d94950942bd4 Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Tue, 30 Jul 2024 14:51:32 -0400 Subject: [PATCH 10/30] Refactor to handle errors in tool calling better and add source comments. Signed-off-by: Adam Treat --- gpt4all-chat/chatllm.cpp | 127 +++++++++++++++++++++------------------ gpt4all-chat/chatllm.h | 3 +- 2 files changed, 70 insertions(+), 60 deletions(-) diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp index 06c9223e2e19..b02592d9d5f4 100644 --- a/gpt4all-chat/chatllm.cpp +++ b/gpt4all-chat/chatllm.cpp @@ -815,11 +815,17 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString m_ctx.n_predict = old_n_predict; // now we are ready for a response } - m_checkToolCall = !isToolCallResponse; // We can't handle recursive tool calls right now + // We can't handle recursive tool calls right now otherwise we always try to check if we have a + // tool call + m_checkToolCall = !isToolCallResponse; + m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc, /*allowContextShift*/ true, m_ctx); + + // After the response has been handled reset this state m_checkToolCall = false; m_maybeToolCall = false; + #if defined(DEBUG) printf("\n"); fflush(stdout); @@ -827,15 +833,66 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString m_timer->stop(); qint64 elapsed = totalTime.elapsed(); std::string trimmed = trim_whitespace(m_response); + + // If we found a tool call, then deal with it if (m_foundToolCall) { m_foundToolCall = false; + + const QString toolCall = QString::fromStdString(trimmed); + const QString toolTemplate = MySettings::globalInstance()->modelToolTemplate(m_modelInfo); + if 
(toolTemplate.isEmpty()) { + qWarning() << "ERROR: No valid tool template for this model" << toolCall; + return handleFailedToolCall(trimmed, elapsed); + } + + QJsonParseError err; + const QJsonDocument toolCallDoc = QJsonDocument::fromJson(toolCall.toUtf8(), &err); + + if (toolCallDoc.isNull() || err.error != QJsonParseError::NoError || !toolCallDoc.isObject()) { + qWarning() << "ERROR: The tool call had null or invalid json " << toolCall; + return handleFailedToolCall(trimmed, elapsed); + } + + QJsonObject rootObject = toolCallDoc.object(); + if (!rootObject.contains("name") || !rootObject.contains("arguments")) { + qWarning() << "ERROR: The tool call did not have required name and argument objects " << toolCall; + return handleFailedToolCall(trimmed, elapsed); + } + + const QString tool = toolCallDoc["name"].toString(); + const QJsonObject args = toolCallDoc["arguments"].toObject(); + + // FIXME: In the future this will try to match the tool call to a list of tools that are supported + // according to MySettings, but for now only brave search is supported + if (tool != "brave_search" || !args.contains("query")) { + qWarning() << "ERROR: Could not find the tool and correct arguments for " << toolCall; + return handleFailedToolCall(trimmed, elapsed); + } + + const QString query = args["query"].toString(); + + // FIXME: This has to handle errors of the tool call + emit toolCalled(tr("searching web...")); + const QString apiKey = MySettings::globalInstance()->braveSearchAPIKey(); + Q_ASSERT(apiKey != ""); + BraveSearch brave; + const QPair> braveResponse = brave.search(apiKey, query, 2 /*topK*/, + 2000 /*msecs to timeout*/); + emit sourceExcerptsChanged(braveResponse.second); + + // Erase the context of the tool call m_ctx.n_past = std::max(0, m_ctx.n_past); m_ctx.tokens.erase(m_ctx.tokens.end() - m_promptResponseTokens, m_ctx.tokens.end()); m_promptResponseTokens = 0; m_promptTokens = 0; m_response = std::string(); - return 
toolCallInternal(QString::fromStdString(trimmed), n_predict, top_k, top_p, min_p, temp, - n_batch, repeat_penalty, repeat_penalty_tokens); + + // This is a recursive call but isToolCallResponse is checked above to arrest infinite recursive + // tool calls + return promptInternal(QList()/*collectionList*/, braveResponse.first, toolTemplate, + n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, + true /*isToolCallResponse*/); + } else { if (trimmed != m_response) { m_response = trimmed; @@ -847,65 +904,19 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString generateQuestions(elapsed); else emit responseStopped(elapsed); + m_pristineLoadedState = false; + return true; } - - m_pristineLoadedState = false; - return true; } -bool ChatLLM::toolCallInternal(const QString &toolCall, int32_t n_predict, int32_t top_k, float top_p, - float min_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens) +bool ChatLLM::handleFailedToolCall(const std::string &response, qint64 elapsed) { - QString toolTemplate = MySettings::globalInstance()->modelToolTemplate(m_modelInfo); - if (toolTemplate.isEmpty()) { - // FIXME: Not sure what to do here. The model attempted a tool call, but there is no way for - // us to process it. We should probably not even attempt further generation and just show an - // error in the chat somehow? 
- qWarning() << "WARNING: The model attempted a toolcall, but there is no valid tool template for this model" << toolCall; - return promptInternal(QList()/*collectionList*/, QString() /*prompt*/, - MySettings::globalInstance()->modelPromptTemplate(m_modelInfo), - n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/); - } - - QJsonParseError err; - QJsonDocument toolCallDoc = QJsonDocument::fromJson(toolCall.toUtf8(), &err); - - if (toolCallDoc.isNull() || err.error != QJsonParseError::NoError || !toolCallDoc.isObject()) { - qWarning() << "WARNING: The tool call had null or invalid json " << toolCall; - return promptInternal(QList()/*collectionList*/, QString() /*prompt*/, toolTemplate, - n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/); - } - - QJsonObject rootObject = toolCallDoc.object(); - if (!rootObject.contains("name") || !rootObject.contains("arguments")) { - qWarning() << "WARNING: The tool call did not have required name and argument objects " << toolCall; - return promptInternal(QList()/*collectionList*/, QString() /*prompt*/, toolTemplate, - n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/); - } - - const QString tool = toolCallDoc["name"].toString(); - const QJsonObject args = toolCallDoc["arguments"].toObject(); - - if (tool != "brave_search" || !args.contains("query")) { - qWarning() << "WARNING: Could not find the tool and correct arguments for " << toolCall; - return promptInternal(QList()/*collectionList*/, QString() /*prompt*/, toolTemplate, - n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/); - } - - const QString query = args["query"].toString(); - - emit toolCalled(tr("searching web...")); - - const QString apiKey = MySettings::globalInstance()->braveSearchAPIKey(); - Q_ASSERT(apiKey != ""); 
- - BraveSearch brave; - const QPair> braveResponse = brave.search(apiKey, query, 2 /*topK*/, 2000 /*msecs to timeout*/); - - emit sourceExcerptsChanged(braveResponse.second); - - return promptInternal(QList()/*collectionList*/, braveResponse.first, toolTemplate, - n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/); + // Restore the strings that we excluded previously when detecting the tool call + m_response = "" + response + ""; + emit responseChanged(QString::fromStdString(m_response)); + emit responseStopped(elapsed); + m_pristineLoadedState = false; + return true; } void ChatLLM::setShouldBeLoaded(bool b) diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/chatllm.h index d9d47ae94c8a..050c0b7db6bc 100644 --- a/gpt4all-chat/chatllm.h +++ b/gpt4all-chat/chatllm.h @@ -200,8 +200,7 @@ public Q_SLOTS: bool promptInternal(const QList &collectionList, const QString &prompt, const QString &promptTemplate, int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens, bool isToolCallResponse = false); - bool toolCallInternal(const QString &toolcall, int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty, - int32_t repeat_penalty_tokens); + bool handleFailedToolCall(const std::string &toolCall, qint64 elapsed); bool handlePrompt(int32_t token); bool handleResponse(int32_t token, const std::string &response); bool handleNamePrompt(int32_t token); From dfe3e951d4103831fd07405a2dddeda136ab3588 Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Wed, 31 Jul 2024 14:54:38 -0400 Subject: [PATCH 11/30] Refactor the brave search and introduce an abstraction for tool calls. 
Signed-off-by: Adam Treat --- gpt4all-chat/CMakeLists.txt | 3 +- gpt4all-chat/bravesearch.cpp | 175 ++++++++++----------------------- gpt4all-chat/bravesearch.h | 15 ++- gpt4all-chat/chatllm.cpp | 22 ++++- gpt4all-chat/qml/ChatView.qml | 9 +- gpt4all-chat/sourceexcerpt.cpp | 93 ++++++++++++++++++ gpt4all-chat/sourceexcerpt.h | 6 ++ gpt4all-chat/tool.cpp | 1 + gpt4all-chat/tool.h | 87 ++++++++++++++++ gpt4all-chat/toolinfo.h | 95 ++++++++++++++++++ 10 files changed, 366 insertions(+), 140 deletions(-) create mode 100644 gpt4all-chat/sourceexcerpt.cpp create mode 100644 gpt4all-chat/tool.cpp create mode 100644 gpt4all-chat/tool.h create mode 100644 gpt4all-chat/toolinfo.h diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt index 074859ad24a2..7f0c51436213 100644 --- a/gpt4all-chat/CMakeLists.txt +++ b/gpt4all-chat/CMakeLists.txt @@ -121,9 +121,10 @@ qt_add_executable(chat modellist.h modellist.cpp mysettings.h mysettings.cpp network.h network.cpp - sourceexcerpt.h + sourceexcerpt.h sourceexcerpt.cpp server.h server.cpp logger.h logger.cpp + tool.h tool.cpp ${APP_ICON_RESOURCE} ${CHAT_EXE_RESOURCES} ) diff --git a/gpt4all-chat/bravesearch.cpp b/gpt4all-chat/bravesearch.cpp index 5c05c24f8a42..85e8497190ab 100644 --- a/gpt4all-chat/bravesearch.cpp +++ b/gpt4all-chat/bravesearch.cpp @@ -16,15 +16,19 @@ using namespace Qt::Literals::StringLiterals; -QPair> BraveSearch::search(const QString &apiKey, const QString &query, int topK, unsigned long timeout) +QString BraveSearch::run(const QJsonObject ¶meters, qint64 timeout) { + const QString apiKey = parameters["apiKey"].toString(); + const QString query = parameters["query"].toString(); + const int count = parameters["count"].toInt(); QThread workerThread; BraveAPIWorker worker; worker.moveToThread(&workerThread); connect(&worker, &BraveAPIWorker::finished, &workerThread, &QThread::quit, Qt::DirectConnection); - connect(this, &BraveSearch::request, &worker, &BraveAPIWorker::request, Qt::QueuedConnection); 
+ connect(&workerThread, &QThread::started, [&worker, apiKey, query, count]() { + worker.request(apiKey, query, count); + }); workerThread.start(); - emit request(apiKey, query, topK); workerThread.wait(timeout); workerThread.quit(); workerThread.wait(); @@ -34,19 +38,25 @@ QPair> BraveSearch::search(const QString &apiKey, void BraveAPIWorker::request(const QString &apiKey, const QString &query, int topK) { m_topK = topK; + + // Documentation on the brave web search: + // https://api.search.brave.com/app/documentation/web-search/get-started QUrl jsonUrl("https://api.search.brave.com/res/v1/web/search"); + + // Documentation on the query options: + //https://api.search.brave.com/app/documentation/web-search/query QUrlQuery urlQuery; urlQuery.addQueryItem("q", query); + urlQuery.addQueryItem("count", QString::number(topK)); + urlQuery.addQueryItem("result_filter", "web"); + urlQuery.addQueryItem("extra_snippets", "true"); jsonUrl.setQuery(urlQuery); QNetworkRequest request(jsonUrl); QSslConfiguration conf = request.sslConfiguration(); conf.setPeerVerifyMode(QSslSocket::VerifyNone); request.setSslConfiguration(conf); - request.setRawHeader("X-Subscription-Token", apiKey.toUtf8()); -// request.setRawHeader("Accept-Encoding", "gzip"); request.setRawHeader("Accept", "application/json"); - m_networkManager = new QNetworkAccessManager(this); QNetworkReply *reply = m_networkManager->get(request); connect(qGuiApp, &QCoreApplication::aboutToQuit, reply, &QNetworkReply::abort); @@ -54,154 +64,71 @@ void BraveAPIWorker::request(const QString &apiKey, const QString &query, int to connect(reply, &QNetworkReply::errorOccurred, this, &BraveAPIWorker::handleErrorOccurred); } -static QPair> cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK = 1) +static QString cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK = 1) { + // This parses the response from brave and formats it in json that conforms to the de facto + // standard in 
SourceExcerpts::fromJson(...) QJsonParseError err; QJsonDocument document = QJsonDocument::fromJson(jsonResponse, &err); if (err.error != QJsonParseError::NoError) { - qWarning() << "ERROR: Couldn't parse: " << jsonResponse << err.errorString(); - return QPair>(); + qWarning() << "ERROR: Couldn't parse brave response: " << jsonResponse << err.errorString(); + return QString(); } + QString query; QJsonObject searchResponse = document.object(); QJsonObject cleanResponse; - QString query; QJsonArray cleanArray; - QList infos; - if (searchResponse.contains("query")) { QJsonObject queryObj = searchResponse["query"].toObject(); - if (queryObj.contains("original")) { + if (queryObj.contains("original")) query = queryObj["original"].toString(); - } } if (searchResponse.contains("mixed")) { QJsonObject mixedResults = searchResponse["mixed"].toObject(); QJsonArray mainResults = mixedResults["main"].toArray(); + QJsonObject resultsObject = searchResponse["web"].toObject(); + QJsonArray resultsArray = resultsObject["results"].toArray(); - for (int i = 0; i < std::min(mainResults.size(), topK); ++i) { + for (int i = 0; i < std::min(mainResults.size(), resultsArray.size()); ++i) { QJsonObject m = mainResults[i].toObject(); QString r_type = m["type"].toString(); - int idx = m["index"].toInt(); - QJsonObject resultsObject = searchResponse[r_type].toObject(); - QJsonArray resultsArray = resultsObject["results"].toArray(); - - QJsonValue cleaned; - SourceExcerpt info; - if (r_type == "web") { - // For web data - add a single output from the search - QJsonObject resultObj = resultsArray[idx].toObject(); - QStringList selectedKeys = {"type", "title", "url", "description", "date", "extra_snippets"}; - QJsonObject cleanedObj; - for (const auto& key : selectedKeys) { - if (resultObj.contains(key)) { - cleanedObj.insert(key, resultObj[key]); - } - } - - QStringList textKeys = {"description", "extra_snippets"}; - QJsonObject textObj; - for (const auto& key : textKeys) { - if 
(resultObj.contains(key)) { - textObj.insert(key, resultObj[key]); - } + Q_ASSERT(r_type == "web"); + const int idx = m["index"].toInt(); + + QJsonObject resultObj = resultsArray[idx].toObject(); + QStringList selectedKeys = {"type", "title", "url", "description"}; + QJsonObject result; + for (const auto& key : selectedKeys) + if (resultObj.contains(key)) + result.insert(key, resultObj[key]); + + if (resultObj.contains("page_age")) + result.insert("date", resultObj["page_age"]); + + QJsonArray excerpts; + if (resultObj.contains("extra_snippets")) { + QJsonArray snippets = resultObj["extra_snippets"].toArray(); + for (int i = 0; i < snippets.size(); ++i) { + QString snippet = snippets[i].toString(); + QJsonObject excerpt; + excerpt.insert("text", snippet); + excerpts.append(excerpt); } - - QJsonDocument textObjDoc(textObj); - info.date = resultObj["date"].toString(); - info.text = textObjDoc.toJson(QJsonDocument::Indented); - info.url = resultObj["url"].toString(); - QJsonObject meta_url = resultObj["meta_url"].toObject(); - info.favicon = meta_url["favicon"].toString(); - info.title = resultObj["title"].toString(); - - cleaned = cleanedObj; - } else if (r_type == "faq") { - // For faq data - take a list of all the questions & answers - QStringList selectedKeys = {"type", "question", "answer", "title", "url"}; - QJsonArray cleanedArray; - for (const auto& q : resultsArray) { - QJsonObject qObj = q.toObject(); - QJsonObject cleanedObj; - for (const auto& key : selectedKeys) { - if (qObj.contains(key)) { - cleanedObj.insert(key, qObj[key]); - } - } - cleanedArray.append(cleanedObj); - } - cleaned = cleanedArray; - } else if (r_type == "infobox") { - QJsonObject resultObj = resultsArray[idx].toObject(); - QStringList selectedKeys = {"type", "title", "url", "description", "long_desc"}; - QJsonObject cleanedObj; - for (const auto& key : selectedKeys) { - if (resultObj.contains(key)) { - cleanedObj.insert(key, resultObj[key]); - } - } - cleaned = cleanedObj; - } else if 
(r_type == "videos") { - QStringList selectedKeys = {"type", "url", "title", "description", "date"}; - QJsonArray cleanedArray; - for (const auto& q : resultsArray) { - QJsonObject qObj = q.toObject(); - QJsonObject cleanedObj; - for (const auto& key : selectedKeys) { - if (qObj.contains(key)) { - cleanedObj.insert(key, qObj[key]); - } - } - cleanedArray.append(cleanedObj); - } - cleaned = cleanedArray; - } else if (r_type == "locations") { - QStringList selectedKeys = {"type", "title", "url", "description", "coordinates", "postal_address", "contact", "rating", "distance", "zoom_level"}; - QJsonArray cleanedArray; - for (const auto& q : resultsArray) { - QJsonObject qObj = q.toObject(); - QJsonObject cleanedObj; - for (const auto& key : selectedKeys) { - if (qObj.contains(key)) { - cleanedObj.insert(key, qObj[key]); - } - } - cleanedArray.append(cleanedObj); - } - cleaned = cleanedArray; - } else if (r_type == "news") { - QStringList selectedKeys = {"type", "title", "url", "description"}; - QJsonArray cleanedArray; - for (const auto& q : resultsArray) { - QJsonObject qObj = q.toObject(); - QJsonObject cleanedObj; - for (const auto& key : selectedKeys) { - if (qObj.contains(key)) { - cleanedObj.insert(key, qObj[key]); - } - } - cleanedArray.append(cleanedObj); - } - cleaned = cleanedArray; - } else { - cleaned = QJsonValue(); } - - infos.append(info); - cleanArray.append(cleaned); + result.insert("excerpts", excerpts); + cleanArray.append(QJsonValue(result)); } } cleanResponse.insert("query", query); - cleanResponse.insert("top_k", cleanArray); + cleanResponse.insert("results", cleanArray); QJsonDocument cleanedDoc(cleanResponse); - // qDebug().noquote() << document.toJson(QJsonDocument::Indented); // qDebug().noquote() << cleanedDoc.toJson(QJsonDocument::Indented); - - return qMakePair(cleanedDoc.toJson(QJsonDocument::Indented), infos); + return cleanedDoc.toJson(QJsonDocument::Compact); } void BraveAPIWorker::handleFinished() diff --git 
a/gpt4all-chat/bravesearch.h b/gpt4all-chat/bravesearch.h index 482b29a6199d..6f617c4ae3f9 100644 --- a/gpt4all-chat/bravesearch.h +++ b/gpt4all-chat/bravesearch.h @@ -2,6 +2,7 @@ #define BRAVESEARCH_H #include "sourceexcerpt.h" +#include "tool.h" #include #include @@ -17,7 +18,7 @@ class BraveAPIWorker : public QObject { , m_topK(1) {} virtual ~BraveAPIWorker() {} - QPair> response() const { return m_response; } + QString response() const { return m_response; } public Q_SLOTS: void request(const QString &apiKey, const QString &query, int topK); @@ -31,21 +32,17 @@ private Q_SLOTS: private: QNetworkAccessManager *m_networkManager; - QPair> m_response; + QString m_response; int m_topK; }; -class BraveSearch : public QObject { +class BraveSearch : public Tool { Q_OBJECT public: - BraveSearch() - : QObject(nullptr) {} + BraveSearch() : Tool() {} virtual ~BraveSearch() {} - QPair> search(const QString &apiKey, const QString &query, int topK, unsigned long timeout = 2000); - -Q_SIGNALS: - void request(const QString &apiKey, const QString &query, int topK); + QString run(const QJsonObject ¶meters, qint64 timeout = 2000) override; }; #endif // BRAVESEARCH_H diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp index b02592d9d5f4..5c4f168c31b5 100644 --- a/gpt4all-chat/chatllm.cpp +++ b/gpt4all-chat/chatllm.cpp @@ -871,14 +871,26 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString const QString query = args["query"].toString(); - // FIXME: This has to handle errors of the tool call emit toolCalled(tr("searching web...")); const QString apiKey = MySettings::globalInstance()->braveSearchAPIKey(); Q_ASSERT(apiKey != ""); BraveSearch brave; - const QPair> braveResponse = brave.search(apiKey, query, 2 /*topK*/, - 2000 /*msecs to timeout*/); - emit sourceExcerptsChanged(braveResponse.second); + + QJsonObject parameters; + parameters.insert("apiKey", apiKey); + parameters.insert("query", query); + parameters.insert("count", 2); + + // FIXME: This 
has to handle errors of the tool call + const QString braveResponse = brave.run(parameters, 2000 /*msecs to timeout*/); + + QString parseError; + QList sourceExcerpts = SourceExcerpt::fromJson(braveResponse, parseError); + if (!parseError.isEmpty()) { + qWarning() << "ERROR: Could not parse source excerpts for brave response" << parseError; + } else if (!sourceExcerpts.isEmpty()) { + emit sourceExcerptsChanged(sourceExcerpts); + } // Erase the context of the tool call m_ctx.n_past = std::max(0, m_ctx.n_past); @@ -889,7 +901,7 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString // This is a recursive call but isToolCallResponse is checked above to arrest infinite recursive // tool calls - return promptInternal(QList()/*collectionList*/, braveResponse.first, toolTemplate, + return promptInternal(QList()/*collectionList*/, braveResponse, toolTemplate, n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/); diff --git a/gpt4all-chat/qml/ChatView.qml b/gpt4all-chat/qml/ChatView.qml index 75a6fc146732..c35843cf50c4 100644 --- a/gpt4all-chat/qml/ChatView.qml +++ b/gpt4all-chat/qml/ChatView.qml @@ -1133,7 +1133,14 @@ Rectangle { sourceSize.width: 24 sourceSize.height: 24 mipmap: true - source: consolidatedSources[0].url === "" ? 
"qrc:/gpt4all/icons/db.svg" : "qrc:/gpt4all/icons/globe.svg" + source: { + if (typeof consolidatedSources === 'undefined' + || typeof consolidatedSources[0] === 'undefined' + || consolidatedSources[0].url === "") + return "qrc:/gpt4all/icons/db.svg"; + else + return "qrc:/gpt4all/icons/globe.svg"; + } } ColorOverlay { diff --git a/gpt4all-chat/sourceexcerpt.cpp b/gpt4all-chat/sourceexcerpt.cpp new file mode 100644 index 000000000000..811d65f050d0 --- /dev/null +++ b/gpt4all-chat/sourceexcerpt.cpp @@ -0,0 +1,93 @@ +#include "sourceexcerpt.h" + +#include +#include +#include +#include + +QList SourceExcerpt::fromJson(const QString &json, QString &errorString) +{ + QJsonParseError err; + QJsonDocument document = QJsonDocument::fromJson(json.toUtf8(), &err); + if (err.error != QJsonParseError::NoError) { + errorString = err.errorString(); + return QList(); + } + + QJsonObject jsonObject = document.object(); + Q_ASSERT(jsonObject.contains("results")); + if (!jsonObject.contains("results")) { + errorString = "json does not contain results array"; + return QList(); + } + + QList excerpts; + QJsonArray results = jsonObject["results"].toArray(); + for (int i = 0; i < results.size(); ++i) { + QJsonObject result = results[i].toObject(); + if (!result.contains("date")) { + errorString = "result does not contain required date field"; + return QList(); + } + + if (!result.contains("excerpts") || !result["excerpts"].isArray()) { + errorString = "result does not contain required excerpts array"; + return QList(); + } + + QJsonArray textExcerpts = result["excerpts"].toArray(); + if (textExcerpts.isEmpty()) { + errorString = "result excerpts array is empty"; + return QList(); + } + + SourceExcerpt source; + source.date = result["date"].toString(); + if (result.contains("collection")) + source.collection = result["collection"].toString(); + if (result.contains("path")) + source.path = result["path"].toString(); + if (result.contains("file")) + source.file = result["file"].toString(); + if 
(result.contains("url")) + source.url = result["url"].toString(); + if (result.contains("favicon")) + source.favicon = result["favicon"].toString(); + if (result.contains("title")) + source.title = result["title"].toString(); + if (result.contains("author")) + source.author = result["author"].toString(); + if (result.contains("description")) + source.description = result["description"].toString(); + + for (int i = 0; i < textExcerpts.size(); ++i) { + SourceExcerpt excerpt; + excerpt.date = source.date; + excerpt.collection = source.collection; + excerpt.path = source.path; + excerpt.file = source.file; + excerpt.url = source.url; + excerpt.favicon = source.favicon; + excerpt.title = source.title; + excerpt.author = source.author; + if (!textExcerpts[i].isObject()) { + errorString = "result excerpt is not an object"; + return QList(); + } + QJsonObject excerptObj = textExcerpts[i].toObject(); + if (!excerptObj.contains("text")) { + errorString = "result excerpt does not have text field"; + return QList(); + } + excerpt.text = excerptObj["text"].toString(); + if (excerptObj.contains("page")) + excerpt.page = excerptObj["page"].toInt(); + if (excerptObj.contains("from")) + excerpt.from = excerptObj["from"].toInt(); + if (excerptObj.contains("to")) + excerpt.to = excerptObj["to"].toInt(); + excerpts.append(excerpt); + } + } + return excerpts; +} diff --git a/gpt4all-chat/sourceexcerpt.h b/gpt4all-chat/sourceexcerpt.h index 91497e9daf17..c66007f05aca 100644 --- a/gpt4all-chat/sourceexcerpt.h +++ b/gpt4all-chat/sourceexcerpt.h @@ -19,6 +19,7 @@ struct SourceExcerpt { Q_PROPERTY(QString favicon MEMBER favicon) Q_PROPERTY(QString title MEMBER title) Q_PROPERTY(QString author MEMBER author) + Q_PROPERTY(QString description MEMBER description) Q_PROPERTY(int page MEMBER page) Q_PROPERTY(int from MEMBER from) Q_PROPERTY(int to MEMBER to) @@ -34,6 +35,7 @@ struct SourceExcerpt { QString favicon; // [Optional] The favicon QString title; // [Optional] The title of the document 
QString author; // [Optional] The author of the document + QString description;// [Optional] The description of the source int page = -1; // [Optional] The page where the text was found int from = -1; // [Optional] The line number where the text begins int to = -1; // [Optional] The line number where the text ends @@ -65,12 +67,15 @@ struct SourceExcerpt { result.insert("favicon", favicon); result.insert("title", title); result.insert("author", author); + result.insert("description", description); result.insert("page", page); result.insert("from", from); result.insert("to", to); return result; } + static QList fromJson(const QString &json, QString &errorString); + bool operator==(const SourceExcerpt &other) const { return date == other.date && text == other.text && @@ -81,6 +86,7 @@ struct SourceExcerpt { favicon == other.favicon && title == other.title && author == other.author && + description == other.description && page == other.page && from == other.from && to == other.to; diff --git a/gpt4all-chat/tool.cpp b/gpt4all-chat/tool.cpp new file mode 100644 index 000000000000..7261f42bf170 --- /dev/null +++ b/gpt4all-chat/tool.cpp @@ -0,0 +1 @@ +#include "tool.h" diff --git a/gpt4all-chat/tool.h b/gpt4all-chat/tool.h new file mode 100644 index 000000000000..240d2aa25d5b --- /dev/null +++ b/gpt4all-chat/tool.h @@ -0,0 +1,87 @@ +#ifndef TOOL_H +#define TOOL_H + +#include "sourceexcerpt.h" + +#include +#include + +using namespace Qt::Literals::StringLiterals; + +namespace ToolEnums { + Q_NAMESPACE + enum class ConnectionType { + Builtin = 0, // A built-in tool with bespoke connection type + Local = 1, // Starts a local process and communicates via stdin/stdout/stderr + LocalServer = 2, // Connects to an existing local process and communicates via stdin/stdout/stderr + Remote = 3, // Starts a remote process and communicates via some networking protocol TBD + RemoteServer = 4 // Connects to an existing remote process and communicates via some networking protocol TBD + }; 
+ Q_ENUM_NS(ConnectionType) +} +using namespace ToolEnums; + +struct ToolInfo { + Q_GADGET + Q_PROPERTY(QString name MEMBER name) + Q_PROPERTY(QString description MEMBER description) + Q_PROPERTY(QJsonObject parameters MEMBER parameters) + Q_PROPERTY(bool isEnabled MEMBER isEnabled) + Q_PROPERTY(ConnectionType connectionType MEMBER connectionType) + +public: + QString name; + QString description; + QJsonObject parameters; + bool isEnabled; + ConnectionType connectionType; + + // FIXME: Should we go with essentially the OpenAI/ollama consensus for these tool + // info files? If you install a tool in GPT4All should it need to meet the spec for these: + // https://platform.openai.com/docs/api-reference/runs/createRun#runs-createrun-tools + // https://github.com/ollama/ollama/blob/main/docs/api.md#chat-request-with-tools + QJsonObject toJson() const + { + QJsonObject result; + result.insert("name", name); + result.insert("description", description); + result.insert("parameters", parameters); + return result; + } + + static ToolInfo fromJson(const QString &json); + + bool operator==(const ToolInfo &other) const { + return name == other.name; + } + bool operator!=(const ToolInfo &other) const { + return !(*this == other); + } +}; +Q_DECLARE_METATYPE(ToolInfo) + +class Tool : public QObject { + Q_OBJECT +public: + Tool() : QObject(nullptr) {} + virtual ~Tool() {} + + // FIXME: How to handle errors? 
+ virtual QString run(const QJsonObject ¶meters, qint64 timeout = 2000) = 0; +}; + +//class BuiltinTool : public Tool { +// Q_OBJECT +//public: +// BuiltinTool() : Tool() {} +// virtual QString run(const QJsonObject ¶meters, qint64 timeout = 2000); +//}; + +//class LocalTool : public Tool { +// Q_OBJECT +//public: +// LocalTool() : Tool() {} +// virtual QString run(const QJsonObject ¶meters, qint64 timeout = 2000); +//}; + +#endif // TOOL_H diff --git a/gpt4all-chat/toolinfo.h b/gpt4all-chat/toolinfo.h new file mode 100644 index 000000000000..91497e9daf17 --- /dev/null +++ b/gpt4all-chat/toolinfo.h @@ -0,0 +1,95 @@ +#ifndef SOURCEEXCERT_H +#define SOURCEEXCERT_H + +#include +#include +#include +#include + +using namespace Qt::Literals::StringLiterals; + +struct SourceExcerpt { + Q_GADGET + Q_PROPERTY(QString date MEMBER date) + Q_PROPERTY(QString text MEMBER text) + Q_PROPERTY(QString collection MEMBER collection) + Q_PROPERTY(QString path MEMBER path) + Q_PROPERTY(QString file MEMBER file) + Q_PROPERTY(QString url MEMBER url) + Q_PROPERTY(QString favicon MEMBER favicon) + Q_PROPERTY(QString title MEMBER title) + Q_PROPERTY(QString author MEMBER author) + Q_PROPERTY(int page MEMBER page) + Q_PROPERTY(int from MEMBER from) + Q_PROPERTY(int to MEMBER to) + Q_PROPERTY(QString fileUri READ fileUri STORED false) + +public: + QString date; // [Required] The creation or the last modification date whichever is latest + QString text; // [Required] The text actually used in the augmented context + QString collection; // [Optional] The name of the collection + QString path; // [Optional] The full path + QString file; // [Optional] The name of the file, but not the full path + QString url; // [Optional] The name of the remote url + QString favicon; // [Optional] The favicon + QString title; // [Optional] The title of the document + QString author; // [Optional] The author of the document + int page = -1; // [Optional] The page where the text was found + int from = -1; // 
[Optional] The line number where the text begins + int to = -1; // [Optional] The line number where the text ends + + QString fileUri() const { + // QUrl reserved chars that are not UNSAFE_PATH according to glib/gconvert.c + static const QByteArray s_exclude = "!$&'()*+,/:=@~"_ba; + + Q_ASSERT(!QFileInfo(path).isRelative()); +#ifdef Q_OS_WINDOWS + Q_ASSERT(!path.contains('\\')); // Qt normally uses forward slash as path separator +#endif + + auto escaped = QString::fromUtf8(QUrl::toPercentEncoding(path, s_exclude)); + if (escaped.front() != '/') + escaped = '/' + escaped; + return u"file://"_s + escaped; + } + + QJsonObject toJson() const + { + QJsonObject result; + result.insert("date", date); + result.insert("text", text); + result.insert("collection", collection); + result.insert("path", path); + result.insert("file", file); + result.insert("url", url); + result.insert("favicon", favicon); + result.insert("title", title); + result.insert("author", author); + result.insert("page", page); + result.insert("from", from); + result.insert("to", to); + return result; + } + + bool operator==(const SourceExcerpt &other) const { + return date == other.date && + text == other.text && + collection == other.collection && + path == other.path && + file == other.file && + url == other.url && + favicon == other.favicon && + title == other.title && + author == other.author && + page == other.page && + from == other.from && + to == other.to; + } + bool operator!=(const SourceExcerpt &other) const { + return !(*this == other); + } +}; + +Q_DECLARE_METATYPE(SourceExcerpt) + +#endif // SOURCEEXCERT_H From 01f67c74ea16150ad11ade1354bfbedbb9defe71 Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Wed, 31 Jul 2024 22:46:30 -0400 Subject: [PATCH 12/30] Begin converting the localdocs to a tool. 
Signed-off-by: Adam Treat --- gpt4all-chat/CMakeLists.txt | 2 +- gpt4all-chat/bravesearch.cpp | 22 ++++++++----- gpt4all-chat/bravesearch.h | 7 ++-- gpt4all-chat/chatllm.cpp | 40 ++++++++++++++--------- gpt4all-chat/chatllm.h | 1 - gpt4all-chat/database.cpp | 56 +++++++++++++++++++------------- gpt4all-chat/database.h | 3 +- gpt4all-chat/localdocssearch.cpp | 50 ++++++++++++++++++++++++++++ gpt4all-chat/localdocssearch.h | 36 ++++++++++++++++++++ gpt4all-chat/sourceexcerpt.h | 14 +------- gpt4all-chat/tool.h | 16 --------- 11 files changed, 163 insertions(+), 84 deletions(-) create mode 100644 gpt4all-chat/localdocssearch.cpp create mode 100644 gpt4all-chat/localdocssearch.h diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt index 7f0c51436213..1ab4db20f89c 100644 --- a/gpt4all-chat/CMakeLists.txt +++ b/gpt4all-chat/CMakeLists.txt @@ -116,7 +116,7 @@ qt_add_executable(chat database.h database.cpp download.h download.cpp embllm.cpp embllm.h - localdocs.h localdocs.cpp localdocsmodel.h localdocsmodel.cpp + localdocs.h localdocs.cpp localdocsmodel.h localdocsmodel.cpp localdocssearch.h localdocssearch.cpp llm.h llm.cpp modellist.h modellist.cpp mysettings.h mysettings.cpp diff --git a/gpt4all-chat/bravesearch.cpp b/gpt4all-chat/bravesearch.cpp index 85e8497190ab..e691810bb567 100644 --- a/gpt4all-chat/bravesearch.cpp +++ b/gpt4all-chat/bravesearch.cpp @@ -35,10 +35,8 @@ QString BraveSearch::run(const QJsonObject ¶meters, qint64 timeout) return worker.response(); } -void BraveAPIWorker::request(const QString &apiKey, const QString &query, int topK) +void BraveAPIWorker::request(const QString &apiKey, const QString &query, int count) { - m_topK = topK; - // Documentation on the brave web search: // https://api.search.brave.com/app/documentation/web-search/get-started QUrl jsonUrl("https://api.search.brave.com/res/v1/web/search"); @@ -47,7 +45,7 @@ void BraveAPIWorker::request(const QString &apiKey, const QString &query, int to 
//https://api.search.brave.com/app/documentation/web-search/query QUrlQuery urlQuery; urlQuery.addQueryItem("q", query); - urlQuery.addQueryItem("count", QString::number(topK)); + urlQuery.addQueryItem("count", QString::number(count)); urlQuery.addQueryItem("result_filter", "web"); urlQuery.addQueryItem("extra_snippets", "true"); jsonUrl.setQuery(urlQuery); @@ -64,7 +62,7 @@ void BraveAPIWorker::request(const QString &apiKey, const QString &query, int to connect(reply, &QNetworkReply::errorOccurred, this, &BraveAPIWorker::handleErrorOccurred); } -static QString cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK = 1) +static QString cleanBraveResponse(const QByteArray& jsonResponse) { // This parses the response from brave and formats it in json that conforms to the de facto // standard in SourceExcerpts::fromJson(...) @@ -77,7 +75,6 @@ static QString cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK QString query; QJsonObject searchResponse = document.object(); - QJsonObject cleanResponse; QJsonArray cleanArray; if (searchResponse.contains("query")) { @@ -99,7 +96,7 @@ static QString cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK const int idx = m["index"].toInt(); QJsonObject resultObj = resultsArray[idx].toObject(); - QStringList selectedKeys = {"type", "title", "url", "description"}; + QStringList selectedKeys = {"type", "title", "url"}; QJsonObject result; for (const auto& key : selectedKeys) if (resultObj.contains(key)) @@ -107,6 +104,8 @@ static QString cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK if (resultObj.contains("page_age")) result.insert("date", resultObj["page_age"]); + else + result.insert("date", QDate::currentDate().toString()); QJsonArray excerpts; if (resultObj.contains("extra_snippets")) { @@ -117,12 +116,18 @@ static QString cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK excerpt.insert("text", snippet); excerpts.append(excerpt); } + if 
(resultObj.contains("description")) + result.insert("description", resultObj["description"]); + } else { + QJsonObject excerpt; + excerpt.insert("text", resultObj["description"]); excerpts.append(excerpt); } result.insert("excerpts", excerpts); cleanArray.append(QJsonValue(result)); } } + QJsonObject cleanResponse; cleanResponse.insert("query", query); cleanResponse.insert("results", cleanArray); QJsonDocument cleanedDoc(cleanResponse); @@ -139,12 +144,13 @@ void BraveAPIWorker::handleFinished() if (jsonReply->error() == QNetworkReply::NoError && jsonReply->isFinished()) { QByteArray jsonData = jsonReply->readAll(); jsonReply->deleteLater(); - m_response = cleanBraveResponse(jsonData, m_topK); + m_response = cleanBraveResponse(jsonData); } else { QByteArray jsonData = jsonReply->readAll(); qWarning() << "ERROR: Could not search brave" << jsonReply->error() << jsonReply->errorString() << jsonData; jsonReply->deleteLater(); } + emit finished(); } void BraveAPIWorker::handleErrorOccurred(QNetworkReply::NetworkError code) diff --git a/gpt4all-chat/bravesearch.h b/gpt4all-chat/bravesearch.h index 6f617c4ae3f9..28f84b159e01 100644 --- a/gpt4all-chat/bravesearch.h +++ b/gpt4all-chat/bravesearch.h @@ -1,7 +1,6 @@ #ifndef BRAVESEARCH_H #define BRAVESEARCH_H -#include "sourceexcerpt.h" #include "tool.h" #include @@ -14,14 +13,13 @@ class BraveAPIWorker : public QObject { public: BraveAPIWorker() : QObject(nullptr) - , m_networkManager(nullptr) - , m_topK(1) {} + , m_networkManager(nullptr) {} virtual ~BraveAPIWorker() {} QString response() const { return m_response; } public Q_SLOTS: - void request(const QString &apiKey, const QString &query, int topK); + void request(const QString &apiKey, const QString &query, int count); Q_SIGNALS: void finished(); @@ -33,7 +31,6 @@ private Q_SLOTS: private: QNetworkAccessManager *m_networkManager; QString m_response; - int m_topK; }; class BraveSearch : public Tool { diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp index 5c4f168c31b5..31b6e92f2680 
100644 --- a/gpt4all-chat/chatllm.cpp +++ b/gpt4all-chat/chatllm.cpp @@ -3,7 +3,7 @@ #include "bravesearch.h" #include "chat.h" #include "chatapi.h" -#include "localdocs.h" +#include "localdocssearch.h" #include "mysettings.h" #include "network.h" @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -128,11 +129,6 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer) connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted); connect(MySettings::globalInstance(), &MySettings::forceMetalChanged, this, &ChatLLM::handleForceMetalChanged); connect(MySettings::globalInstance(), &MySettings::deviceChanged, this, &ChatLLM::handleDeviceChanged); - - // The following are blocking operations and will block the llm thread - connect(this, &ChatLLM::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB, - Qt::BlockingQueuedConnection); - m_llmThread.setObjectName(parent->id()); m_llmThread.start(); } @@ -767,21 +763,33 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString if (!isModelLoaded()) return false; - QList databaseResults; - const int retrievalSize = MySettings::globalInstance()->localDocsRetrievalSize(); + QList localDocsExcerpts; if (!collectionList.isEmpty() && !isToolCallResponse) { - emit requestRetrieveFromDB(collectionList, prompt, retrievalSize, &databaseResults); // blocks - emit sourceExcerptsChanged(databaseResults); + LocalDocsSearch localdocs; + QJsonObject parameters; + parameters.insert("text", prompt); + parameters.insert("count", MySettings::globalInstance()->localDocsRetrievalSize()); + parameters.insert("collections", QJsonArray::fromStringList(collectionList)); + + // FIXME: This has to handle errors of the tool call + const QString localDocsResponse = localdocs.run(parameters, 2000 /*msecs to timeout*/); + + QString parseError; + localDocsExcerpts = SourceExcerpt::fromJson(localDocsResponse, parseError); + if (!parseError.isEmpty()) { + qWarning() 
<< "ERROR: Could not parse source excerpts for localdocs response:" << parseError; + } else if (!localDocsExcerpts.isEmpty()) { + emit sourceExcerptsChanged(localDocsExcerpts); + } } // Augment the prompt template with the results if any QString docsContext; - if (!databaseResults.isEmpty()) { + if (!localDocsExcerpts.isEmpty()) { + // FIXME(adam): we should be using the new tool template if available otherwise this I guess QStringList results; - for (const SourceExcerpt &info : databaseResults) + for (const SourceExcerpt &info : localDocsExcerpts) results << u"Collection: %1\nPath: %2\nExcerpt: %3"_s.arg(info.collection, info.path, info.text); - - // FIXME(jared): use a Jinja prompt template instead of hardcoded Alpaca-style localdocs template docsContext = u"### Context:\n%1\n\n"_s.arg(results.join("\n\n")); } @@ -887,7 +895,7 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString QString parseError; QList sourceExcerpts = SourceExcerpt::fromJson(braveResponse, parseError); if (!parseError.isEmpty()) { - qWarning() << "ERROR: Could not parse source excerpts for brave response" << parseError; + qWarning() << "ERROR: Could not parse source excerpts for brave response:" << parseError; } else if (!sourceExcerpts.isEmpty()) { emit sourceExcerptsChanged(sourceExcerpts); } @@ -912,7 +920,7 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString } SuggestionMode mode = MySettings::globalInstance()->suggestionMode(); - if (mode == SuggestionMode::On || (mode == SuggestionMode::SourceExcerptsOnly && (!databaseResults.isEmpty() || isToolCallResponse))) + if (mode == SuggestionMode::On || (mode == SuggestionMode::SourceExcerptsOnly && (!localDocsExcerpts.isEmpty() || isToolCallResponse))) generateQuestions(elapsed); else emit responseStopped(elapsed); diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/chatllm.h index 050c0b7db6bc..feacd744f228 100644 --- a/gpt4all-chat/chatllm.h +++ b/gpt4all-chat/chatllm.h @@ -189,7 +189,6 @@ public 
Q_SLOTS: void shouldBeLoadedChanged(); void trySwitchContextRequested(const ModelInfo &modelInfo); void trySwitchContextOfLoadedModelCompleted(int value); - void requestRetrieveFromDB(const QList &collections, const QString &text, int retrievalSize, QList *results); void reportSpeed(const QString &speed); void reportDevice(const QString &device); void reportFallbackReason(const QString &fallbackReason); diff --git a/gpt4all-chat/database.cpp b/gpt4all-chat/database.cpp index f2fc50941253..0713fc9f34be 100644 --- a/gpt4all-chat/database.cpp +++ b/gpt4all-chat/database.cpp @@ -1938,7 +1938,7 @@ QList Database::searchEmbeddings(const std::vector &query, const QLi } void Database::retrieveFromDB(const QList &collections, const QString &text, int retrievalSize, - QList *results) + QString &jsonResult) { #if defined(DEBUG) qDebug() << "retrieveFromDB" << collections << text << retrievalSize; @@ -1960,37 +1960,49 @@ void Database::retrieveFromDB(const QList &collections, const QString & return; } + QMap results; while (q.next()) { #if defined(DEBUG) const int rowid = q.value(0).toInt(); #endif - const QString document_path = q.value(2).toString(); - const QString chunk_text = q.value(3).toString(); - const QString date = QDateTime::fromMSecsSinceEpoch(q.value(1).toLongLong()).toString("yyyy, MMMM dd"); const QString file = q.value(4).toString(); - const QString title = q.value(5).toString(); - const QString author = q.value(6).toString(); - const int page = q.value(7).toInt(); - const int from = q.value(8).toInt(); - const int to = q.value(9).toInt(); - const QString collectionName = q.value(10).toString(); - SourceExcerpt info; - info.collection = collectionName; - info.path = document_path; - info.file = file; - info.title = title; - info.author = author; - info.date = date; - info.text = chunk_text; - info.page = page; - info.from = from; - info.to = to; - results->append(info); + QJsonObject resultObject = results.value(file); + resultObject.insert("file", file); + 
resultObject.insert("path", q.value(2).toString()); + resultObject.insert("date", QDateTime::fromMSecsSinceEpoch(q.value(1).toLongLong()).toString("yyyy, MMMM dd")); + resultObject.insert("title", q.value(5).toString()); + resultObject.insert("author", q.value(6).toString()); + resultObject.insert("collection", q.value(10).toString()); + + QJsonArray excerpts; + if (resultObject.contains("excerpts")) + excerpts = resultObject["excerpts"].toArray(); + + QJsonObject excerptObject; + excerptObject.insert("text", q.value(3).toString()); + excerptObject.insert("page", q.value(7).toInt()); + excerptObject.insert("from", q.value(8).toInt()); + excerptObject.insert("to", q.value(9).toInt()); + excerpts.append(excerptObject); + resultObject.insert("excerpts", excerpts); + results.insert(file, resultObject); + #if defined(DEBUG) qDebug() << "retrieve rowid:" << rowid << "chunk_text:" << chunk_text; #endif } + + QJsonArray resultsArray; + QList resultsList = results.values(); + for (const QJsonObject &result : resultsList) + resultsArray.append(QJsonValue(result)); + + QJsonObject response; + response.insert("results", resultsArray); + QJsonDocument document(response); +// qDebug().noquote() << document.toJson(QJsonDocument::Indented); + jsonResult = document.toJson(QJsonDocument::Compact); } // FIXME This is very slow and non-interruptible and when we close the application and we're diff --git a/gpt4all-chat/database.h b/gpt4all-chat/database.h index fd0a78d9d5a7..a2e7cf227dc3 100644 --- a/gpt4all-chat/database.h +++ b/gpt4all-chat/database.h @@ -101,7 +101,7 @@ public Q_SLOTS: void forceRebuildFolder(const QString &path); bool addFolder(const QString &collection, const QString &path, const QString &embedding_model); void removeFolder(const QString &collection, const QString &path); - void retrieveFromDB(const QList &collections, const QString &text, int retrievalSize, QList *results); + void retrieveFromDB(const QList &collections, const QString &text, int retrievalSize, 
QString &jsonResult); void changeChunkSize(int chunkSize); void changeFileExtensions(const QStringList &extensions); @@ -168,7 +168,6 @@ private Q_SLOTS: QStringList m_scannedFileExtensions; QTimer *m_scanTimer; QMap> m_docsToScan; - QList m_retrieve; QThread m_dbThread; QFileSystemWatcher *m_watcher; QSet m_watchedPaths; diff --git a/gpt4all-chat/localdocssearch.cpp b/gpt4all-chat/localdocssearch.cpp new file mode 100644 index 000000000000..77becf4ec9b7 --- /dev/null +++ b/gpt4all-chat/localdocssearch.cpp @@ -0,0 +1,50 @@ +#include "localdocssearch.h" +#include "database.h" +#include "localdocs.h" + +#include +#include +#include +#include +#include +#include + +using namespace Qt::Literals::StringLiterals; + +QString LocalDocsSearch::run(const QJsonObject ¶meters, qint64 timeout) +{ + QList collections; + QJsonArray collectionsArray = parameters["collections"].toArray(); + for (int i = 0; i < collectionsArray.size(); ++i) + collections.append(collectionsArray[i].toString()); + const QString text = parameters["text"].toString(); + const int count = parameters["count"].toInt(); + QThread workerThread; + LocalDocsWorker worker; + worker.moveToThread(&workerThread); + connect(&worker, &LocalDocsWorker::finished, &workerThread, &QThread::quit, Qt::DirectConnection); + connect(&workerThread, &QThread::started, [&worker, collections, text, count]() { + worker.request(collections, text, count); + }); + workerThread.start(); + workerThread.wait(timeout); + workerThread.quit(); + workerThread.wait(); + return worker.response(); +} + +LocalDocsWorker::LocalDocsWorker() + : QObject(nullptr) +{ + // The following are blocking operations and will block the calling thread + connect(this, &LocalDocsWorker::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), + &Database::retrieveFromDB, Qt::BlockingQueuedConnection); +} + +void LocalDocsWorker::request(const QList &collections, const QString &text, int count) +{ + QString jsonResult; + emit 
requestRetrieveFromDB(collections, text, count, jsonResult); // blocks + m_response = jsonResult; + emit finished(); +} diff --git a/gpt4all-chat/localdocssearch.h b/gpt4all-chat/localdocssearch.h new file mode 100644 index 000000000000..4f757a521801 --- /dev/null +++ b/gpt4all-chat/localdocssearch.h @@ -0,0 +1,36 @@ +#ifndef LOCALDOCSSEARCH_H +#define LOCALDOCSSEARCH_H + +#include "tool.h" + +#include +#include + +class LocalDocsWorker : public QObject { + Q_OBJECT +public: + LocalDocsWorker(); + virtual ~LocalDocsWorker() {} + + QString response() const { return m_response; } + + void request(const QList &collections, const QString &text, int count); + +Q_SIGNALS: + void requestRetrieveFromDB(const QList &collections, const QString &text, int count, QString &jsonResponse); + void finished(); + +private: + QString m_response; +}; + +class LocalDocsSearch : public Tool { + Q_OBJECT +public: + LocalDocsSearch() : Tool() {} + virtual ~LocalDocsSearch() {} + + QString run(const QJsonObject ¶meters, qint64 timeout = 2000) override; +}; + +#endif // LOCALDOCSSEARCH_H diff --git a/gpt4all-chat/sourceexcerpt.h b/gpt4all-chat/sourceexcerpt.h index c66007f05aca..8276923940d1 100644 --- a/gpt4all-chat/sourceexcerpt.h +++ b/gpt4all-chat/sourceexcerpt.h @@ -77,19 +77,7 @@ struct SourceExcerpt { static QList fromJson(const QString &json, QString &errorString); bool operator==(const SourceExcerpt &other) const { - return date == other.date && - text == other.text && - collection == other.collection && - path == other.path && - file == other.file && - url == other.url && - favicon == other.favicon && - title == other.title && - author == other.author && - description == other.description && - page == other.page && - from == other.from && - to == other.to; + return file == other.file || url == other.url; } bool operator!=(const SourceExcerpt &other) const { return !(*this == other); diff --git a/gpt4all-chat/tool.h b/gpt4all-chat/tool.h index 240d2aa25d5b..3dac88e7b3a5 100644 --- 
a/gpt4all-chat/tool.h +++ b/gpt4all-chat/tool.h @@ -1,8 +1,6 @@ #ifndef TOOL_H #define TOOL_H -#include "sourceexcerpt.h" - #include #include @@ -70,18 +68,4 @@ class Tool : public QObject { virtual QString run(const QJsonObject ¶meters, qint64 timeout = 2000) = 0; }; -//class BuiltinTool : public Tool { -// Q_OBJECT -//public: -// BuiltinTool() : Tool() {} -// virtual QString run(const QJsonObject ¶meters, qint64 timeout = 2000); -//}; - -//class LocalTool : public Tool { -// Q_OBJECT -//public: -// LocalTool() : Tool() {} -// virtual QString run(const QJsonObject ¶meters, qint64 timeout = 2000); -//}; - #endif // TOOL_H From 27b86dae2143bf660be1e156f04ea754f728c478 Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Thu, 1 Aug 2024 11:18:50 -0400 Subject: [PATCH 13/30] Serialize the source excerpts from and to pure json Signed-off-by: Adam Treat --- gpt4all-chat/chatllm.cpp | 9 +- gpt4all-chat/chatmodel.h | 149 +++++++++------------------------ gpt4all-chat/qml/ChatView.qml | 18 ++-- gpt4all-chat/server.cpp | 16 +--- gpt4all-chat/sourceexcerpt.cpp | 60 ++++++++++--- gpt4all-chat/sourceexcerpt.h | 82 ++++++++++-------- 6 files changed, 149 insertions(+), 185 deletions(-) diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp index 31b6e92f2680..ed209ed03aaa 100644 --- a/gpt4all-chat/chatllm.cpp +++ b/gpt4all-chat/chatllm.cpp @@ -787,10 +787,8 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString QString docsContext; if (!localDocsExcerpts.isEmpty()) { // FIXME(adam): we should be using the new tool template if available otherwise this I guess - QStringList results; - for (const SourceExcerpt &info : localDocsExcerpts) - results << u"Collection: %1\nPath: %2\nExcerpt: %3"_s.arg(info.collection, info.path, info.text); - docsContext = u"### Context:\n%1\n\n"_s.arg(results.join("\n\n")); + QString json = SourceExcerpt::toJson(localDocsExcerpts); + docsContext = u"### Context:\n%1\n\n"_s.arg(json); } int n_threads = 
MySettings::globalInstance()->threadCount(); @@ -900,9 +898,6 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString emit sourceExcerptsChanged(sourceExcerpts); } - // Erase the context of the tool call - m_ctx.n_past = std::max(0, m_ctx.n_past); - m_ctx.tokens.erase(m_ctx.tokens.end() - m_promptResponseTokens, m_ctx.tokens.end()); m_promptResponseTokens = 0; m_promptTokens = 0; m_response = std::string(); diff --git a/gpt4all-chat/chatmodel.h b/gpt4all-chat/chatmodel.h index 97b812750b4e..19be1cf31216 100644 --- a/gpt4all-chat/chatmodel.h +++ b/gpt4all-chat/chatmodel.h @@ -29,7 +29,6 @@ struct ChatItem Q_PROPERTY(bool thumbsUpState MEMBER thumbsUpState) Q_PROPERTY(bool thumbsDownState MEMBER thumbsDownState) Q_PROPERTY(QList sources MEMBER sources) - Q_PROPERTY(QList consolidatedSources MEMBER consolidatedSources) public: // TODO: Maybe we should include the model name here as well as timestamp? @@ -39,7 +38,6 @@ struct ChatItem QString prompt; QString newResponse; QList sources; - QList consolidatedSources; bool currentResponse = false; bool stopped = false; bool thumbsUpState = false; @@ -65,8 +63,7 @@ class ChatModel : public QAbstractListModel StoppedRole, ThumbsUpStateRole, ThumbsDownStateRole, - SourcesRole, - ConsolidatedSourcesRole + SourcesRole }; int rowCount(const QModelIndex &parent = QModelIndex()) const override @@ -102,8 +99,6 @@ class ChatModel : public QAbstractListModel return item.thumbsDownState; case SourcesRole: return QVariant::fromValue(item.sources); - case ConsolidatedSourcesRole: - return QVariant::fromValue(item.consolidatedSources); } return QVariant(); @@ -122,7 +117,6 @@ class ChatModel : public QAbstractListModel roles[ThumbsUpStateRole] = "thumbsUpState"; roles[ThumbsDownStateRole] = "thumbsDownState"; roles[SourcesRole] = "sources"; - roles[ConsolidatedSourcesRole] = "consolidatedSources"; return roles; } @@ -200,20 +194,6 @@ class ChatModel : public QAbstractListModel } } - QList consolidateSources(const QList 
&sources) { - QMap groupedData; - for (const SourceExcerpt &info : sources) { - QString key = !info.file.isEmpty() ? info.file : info.url; - if (groupedData.contains(key)) { - groupedData[key].text += "\n---\n" + info.text; - } else { - groupedData[key] = info; - } - } - QList consolidatedSources = groupedData.values(); - return consolidatedSources; - } - Q_INVOKABLE void updateSources(int index, const QList &sources) { if (index < 0 || index >= m_chatItems.size()) return; @@ -221,13 +201,10 @@ class ChatModel : public QAbstractListModel ChatItem &item = m_chatItems[index]; if (sources.isEmpty()) { item.sources.clear(); - item.consolidatedSources.clear(); } else { item.sources << sources; - item.consolidatedSources << consolidateSources(sources); } emit dataChanged(createIndex(index, 0), createIndex(index, 0), {SourcesRole}); - emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ConsolidatedSourcesRole}); } Q_INVOKABLE void updateThumbsUpState(int index, bool b) @@ -278,61 +255,7 @@ class ChatModel : public QAbstractListModel stream << c.stopped; stream << c.thumbsUpState; stream << c.thumbsDownState; - if (version > 7) { - stream << c.sources.size(); - for (const SourceExcerpt &info : c.sources) { - Q_ASSERT(!info.file.isEmpty()); - stream << info.collection; - stream << info.path; - stream << info.file; - stream << info.title; - stream << info.author; - stream << info.date; - stream << info.text; - stream << info.page; - stream << info.from; - stream << info.to; - if (version > 9) { - stream << info.url; - stream << info.favicon; - } - } - } else if (version > 2) { - QList references; - QList referencesContext; - int validReferenceNumber = 1; - for (const SourceExcerpt &info : c.sources) { - if (info.file.isEmpty()) - continue; - - QString reference; - { - QTextStream stream(&reference); - stream << (validReferenceNumber++) << ". "; - if (!info.title.isEmpty()) - stream << "\"" << info.title << "\". 
"; - if (!info.author.isEmpty()) - stream << "By " << info.author << ". "; - if (!info.date.isEmpty()) - stream << "Date: " << info.date << ". "; - stream << "In " << info.file << ". "; - if (info.page != -1) - stream << "Page " << info.page << ". "; - if (info.from != -1) { - stream << "Lines " << info.from; - if (info.to != -1) - stream << "-" << info.to; - stream << ". "; - } - stream << "[Context](context://" << validReferenceNumber - 1 << ")"; - } - references.append(reference); - referencesContext.append(info.text); - } - - stream << references.join("\n"); - stream << referencesContext; - } + stream << SourceExcerpt::toJson(c.sources); } return stream.status() == QDataStream::Ok; } @@ -352,31 +275,36 @@ class ChatModel : public QAbstractListModel stream >> c.stopped; stream >> c.thumbsUpState; stream >> c.thumbsDownState; - if (version > 7) { + if (version > 9) { + QList sources; + QString json; + stream >> json; + QString errorString; + sources = SourceExcerpt::fromJson(json, errorString); + Q_ASSERT(errorString.isEmpty()); + c.sources = sources; + } else if (version > 7) { qsizetype count; stream >> count; QList sources; for (int i = 0; i < count; ++i) { - SourceExcerpt info; - stream >> info.collection; - stream >> info.path; - stream >> info.file; - stream >> info.title; - stream >> info.author; - stream >> info.date; - stream >> info.text; - stream >> info.page; - stream >> info.from; - stream >> info.to; - if (version > 9) { - stream >> info.url; - stream >> info.favicon; - } - sources.append(info); + SourceExcerpt source; + stream >> source.collection; + stream >> source.path; + stream >> source.file; + stream >> source.title; + stream >> source.author; + stream >> source.date; + Excerpt excerpt; + stream >> excerpt.text; + stream >> excerpt.page; + stream >> excerpt.from; + stream >> excerpt.to; + source.excerpts = QList{ excerpt }; + sources.append(source); } c.sources = sources; - c.consolidatedSources = consolidateSources(sources); - }else if 
(version > 2) { + } else if (version > 2) { QString references; QList referencesContext; stream >> references; @@ -398,7 +326,8 @@ class ChatModel : public QAbstractListModel for (int j = 0; j < referenceList.size(); ++j) { QString reference = referenceList[j]; QString context = referencesContext[j]; - SourceExcerpt info; + SourceExcerpt source; + Excerpt excerpt; QTextStream refStream(&reference); QString dummy; int validReferenceNumber; @@ -407,28 +336,28 @@ class ChatModel : public QAbstractListModel if (reference.contains("\"")) { int startIndex = reference.indexOf('"') + 1; int endIndex = reference.indexOf('"', startIndex); - info.title = reference.mid(startIndex, endIndex - startIndex); + source.title = reference.mid(startIndex, endIndex - startIndex); } // Extract author (after "By " and before the next period) if (reference.contains("By ")) { int startIndex = reference.indexOf("By ") + 3; int endIndex = reference.indexOf('.', startIndex); - info.author = reference.mid(startIndex, endIndex - startIndex).trimmed(); + source.author = reference.mid(startIndex, endIndex - startIndex).trimmed(); } // Extract date (after "Date: " and before the next period) if (reference.contains("Date: ")) { int startIndex = reference.indexOf("Date: ") + 6; int endIndex = reference.indexOf('.', startIndex); - info.date = reference.mid(startIndex, endIndex - startIndex).trimmed(); + source.date = reference.mid(startIndex, endIndex - startIndex).trimmed(); } // Extract file name (after "In " and before the "[Context]") if (reference.contains("In ") && reference.contains(". [Context]")) { int startIndex = reference.indexOf("In ") + 3; int endIndex = reference.indexOf(". 
[Context]", startIndex); - info.file = reference.mid(startIndex, endIndex - startIndex).trimmed(); + source.file = reference.mid(startIndex, endIndex - startIndex).trimmed(); } // Extract page number (after "Page " and before the next space) @@ -436,7 +365,7 @@ class ChatModel : public QAbstractListModel int startIndex = reference.indexOf("Page ") + 5; int endIndex = reference.indexOf(' ', startIndex); if (endIndex == -1) endIndex = reference.length(); - info.page = reference.mid(startIndex, endIndex - startIndex).toInt(); + excerpt.page = reference.mid(startIndex, endIndex - startIndex).toInt(); } // Extract lines (after "Lines " and before the next space or hyphen) @@ -446,18 +375,18 @@ class ChatModel : public QAbstractListModel if (endIndex == -1) endIndex = reference.length(); int hyphenIndex = reference.indexOf('-', startIndex); if (hyphenIndex != -1 && hyphenIndex < endIndex) { - info.from = reference.mid(startIndex, hyphenIndex - startIndex).toInt(); - info.to = reference.mid(hyphenIndex + 1, endIndex - hyphenIndex - 1).toInt(); + excerpt.from = reference.mid(startIndex, hyphenIndex - startIndex).toInt(); + excerpt.to = reference.mid(hyphenIndex + 1, endIndex - hyphenIndex - 1).toInt(); } else { - info.from = reference.mid(startIndex, endIndex - startIndex).toInt(); + excerpt.from = reference.mid(startIndex, endIndex - startIndex).toInt(); } } - info.text = context; - sources.append(info); + excerpt.text = context; + source.excerpts = QList{ excerpt }; + sources.append(source); } c.sources = sources; - c.consolidatedSources = consolidateSources(sources); } } beginInsertRows(QModelIndex(), m_chatItems.size(), m_chatItems.size()); diff --git a/gpt4all-chat/qml/ChatView.qml b/gpt4all-chat/qml/ChatView.qml index c35843cf50c4..1417f42b0971 100644 --- a/gpt4all-chat/qml/ChatView.qml +++ b/gpt4all-chat/qml/ChatView.qml @@ -1106,7 +1106,7 @@ Rectangle { Layout.preferredWidth: childrenRect.width Layout.preferredHeight: childrenRect.height visible: { - if 
(consolidatedSources.length === 0) + if (sources.length === 0) return false if (!MySettings.localDocsShowReferences) return false @@ -1134,9 +1134,9 @@ Rectangle { sourceSize.height: 24 mipmap: true source: { - if (typeof consolidatedSources === 'undefined' - || typeof consolidatedSources[0] === 'undefined' - || consolidatedSources[0].url === "") + if (typeof sources === 'undefined' + || typeof sources[0] === 'undefined' + || sources[0].url === "") return "qrc:/gpt4all/icons/db.svg"; else return "qrc:/gpt4all/icons/globe.svg"; @@ -1151,7 +1151,7 @@ Rectangle { } Text { - text: qsTr("%1 Sources").arg(consolidatedSources.length) + text: qsTr("%1 Sources").arg(sources.length) padding: 0 font.pixelSize: theme.fontSizeLarge font.bold: true @@ -1199,7 +1199,7 @@ Rectangle { Layout.column: 1 Layout.topMargin: 5 visible: { - if (consolidatedSources.length === 0) + if (sources.length === 0) return false if (!MySettings.localDocsShowReferences) return false @@ -1240,9 +1240,9 @@ Rectangle { id: flow Layout.fillWidth: true spacing: 10 - visible: consolidatedSources.length !== 0 + visible: sources.length !== 0 Repeater { - model: consolidatedSources + model: sources delegate: Rectangle { radius: 10 @@ -1361,7 +1361,7 @@ Rectangle { return false; if (MySettings.suggestionMode === 2) // Off return false; - if (MySettings.suggestionMode === 0 && consolidatedSources.length === 0) // LocalDocs only + if (MySettings.suggestionMode === 0 && sources.length === 0) // LocalDocs only return false; return currentChat.responseState === Chat.GeneratingQuestions || currentChat.generatedQuestions.length !== 0; } diff --git a/gpt4all-chat/server.cpp b/gpt4all-chat/server.cpp index e655bf9feff9..af266afbfe86 100644 --- a/gpt4all-chat/server.cpp +++ b/gpt4all-chat/server.cpp @@ -408,12 +408,8 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re message.insert("role", "assistant"); message.insert("content", result); choice.insert("message", message); - if 
(MySettings::globalInstance()->localDocsShowReferences()) { - QJsonArray references; - for (const auto &ref : infos) - references.append(ref.toJson()); - choice.insert("references", references); - } + if (MySettings::globalInstance()->localDocsShowReferences()) + choice.insert("references", SourceExcerpt::toJson(infos)); choices.append(choice); } } else { @@ -426,12 +422,8 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re choice.insert("index", index++); choice.insert("logprobs", QJsonValue::Null); // We don't support choice.insert("finish_reason", responseTokens == max_tokens ? "length" : "stop"); - if (MySettings::globalInstance()->localDocsShowReferences()) { - QJsonArray references; - for (const auto &ref : infos) - references.append(ref.toJson()); - choice.insert("references", references); - } + if (MySettings::globalInstance()->localDocsShowReferences()) + choice.insert("references", SourceExcerpt::toJson(infos)); choices.append(choice); } } diff --git a/gpt4all-chat/sourceexcerpt.cpp b/gpt4all-chat/sourceexcerpt.cpp index 811d65f050d0..0c702ce9ff1b 100644 --- a/gpt4all-chat/sourceexcerpt.cpp +++ b/gpt4all-chat/sourceexcerpt.cpp @@ -5,8 +5,53 @@ #include #include +QString SourceExcerpt::toJson(const QList &sources) +{ + if (sources.isEmpty()) + return QString(); + + QJsonArray resultsArray; + for (const auto &source : sources) { + QJsonObject sourceObj; + sourceObj["date"] = source.date; + sourceObj["collection"] = source.collection; + sourceObj["path"] = source.path; + sourceObj["file"] = source.file; + sourceObj["url"] = source.url; + sourceObj["favicon"] = source.favicon; + sourceObj["title"] = source.title; + sourceObj["author"] = source.author; + sourceObj["description"] = source.description; + + QJsonArray excerptsArray; + for (const auto &excerpt : source.excerpts) { + QJsonObject excerptObj; + excerptObj["text"] = excerpt.text; + if (excerpt.page != -1) + excerptObj["page"] = excerpt.page; + if (excerpt.from != -1) + 
excerptObj["from"] = excerpt.from; + if (excerpt.to != -1) + excerptObj["to"] = excerpt.to; + excerptsArray.append(excerptObj); + } + sourceObj["excerpts"] = excerptsArray; + + resultsArray.append(sourceObj); + } + + QJsonObject jsonObj; + jsonObj["results"] = resultsArray; + + QJsonDocument doc(jsonObj); + return doc.toJson(QJsonDocument::Compact); +} + QList SourceExcerpt::fromJson(const QString &json, QString &errorString) { + if (json.isEmpty()) + return QList(); + QJsonParseError err; QJsonDocument document = QJsonDocument::fromJson(json.toUtf8(), &err); if (err.error != QJsonParseError::NoError) { @@ -44,7 +89,7 @@ QList SourceExcerpt::fromJson(const QString &json, QString &error SourceExcerpt source; source.date = result["date"].toString(); if (result.contains("collection")) - source.collection = result["text"].toString(); + source.collection = result["collection"].toString(); if (result.contains("path")) source.path = result["path"].toString(); if (result.contains("file")) @@ -61,15 +106,6 @@ QList SourceExcerpt::fromJson(const QString &json, QString &error source.author = result["description"].toString(); for (int i = 0; i < textExcerpts.size(); ++i) { - SourceExcerpt excerpt; - excerpt.date = source.date; - excerpt.collection = source.collection; - excerpt.path = source.path; - excerpt.file = source.file; - excerpt.url = source.url; - excerpt.favicon = source.favicon; - excerpt.title = source.title; - excerpt.author = source.author; if (!textExcerpts[i].isObject()) { errorString = "result excerpt is not an object"; return QList(); @@ -79,6 +115,7 @@ QList SourceExcerpt::fromJson(const QString &json, QString &error errorString = "result excerpt is does not have text field"; return QList(); } + Excerpt excerpt; excerpt.text = excerptObj["text"].toString(); if (excerptObj.contains("page")) excerpt.page = excerptObj["page"].toInt(); @@ -86,8 +123,9 @@ QList SourceExcerpt::fromJson(const QString &json, QString &error excerpt.from = excerptObj["from"].toInt(); 
if (excerptObj.contains("to")) excerpt.to = excerptObj["to"].toInt(); - excerpts.append(excerpt); + source.excerpts.append(excerpt); } + excerpts.append(source); } return excerpts; } diff --git a/gpt4all-chat/sourceexcerpt.h b/gpt4all-chat/sourceexcerpt.h index 8276923940d1..3f02457ca4dc 100644 --- a/gpt4all-chat/sourceexcerpt.h +++ b/gpt4all-chat/sourceexcerpt.h @@ -8,10 +8,23 @@ using namespace Qt::Literals::StringLiterals; +struct Excerpt { + QString text; // [Required] The text actually used in the augmented context + int page = -1; // [Optional] The page where the text was found + int from = -1; // [Optional] The line number where the text begins + int to = -1; // [Optional] The line number where the text ends + bool operator==(const Excerpt &other) const { + return text == other.text && page == other.page && from == other.from && to == other.to; + } + bool operator!=(const Excerpt &other) const { + return !(*this == other); + } +}; +Q_DECLARE_METATYPE(Excerpt) + struct SourceExcerpt { Q_GADGET Q_PROPERTY(QString date MEMBER date) - Q_PROPERTY(QString text MEMBER text) Q_PROPERTY(QString collection MEMBER collection) Q_PROPERTY(QString path MEMBER path) Q_PROPERTY(QString file MEMBER file) @@ -20,25 +33,40 @@ struct SourceExcerpt { Q_PROPERTY(QString title MEMBER title) Q_PROPERTY(QString author MEMBER author) Q_PROPERTY(QString description MEMBER description) - Q_PROPERTY(int page MEMBER page) - Q_PROPERTY(int from MEMBER from) - Q_PROPERTY(int to MEMBER to) Q_PROPERTY(QString fileUri READ fileUri STORED false) + Q_PROPERTY(QString text READ text STORED false) + Q_PROPERTY(QList excerpts MEMBER excerpts) public: - QString date; // [Required] The creation or the last modification date whichever is latest - QString text; // [Required] The text actually used in the augmented context - QString collection; // [Optional] The name of the collection - QString path; // [Optional] The full path - QString file; // [Optional] The name of the file, but not the full path - 
QString url; // [Optional] The name of the remote url - QString favicon; // [Optional] The favicon - QString title; // [Optional] The title of the document - QString author; // [Optional] The author of the document - QString description;// [Optional] The description of the source - int page = -1; // [Optional] The page where the text was found - int from = -1; // [Optional] The line number where the text begins - int to = -1; // [Optional] The line number where the text ends + QString date; // [Required] The creation or the last modification date whichever is latest + QString collection; // [Optional] The name of the collection + QString path; // [Optional] The full path + QString file; // [Optional] The name of the file, but not the full path + QString url; // [Optional] The name of the remote url + QString favicon; // [Optional] The favicon + QString title; // [Optional] The title of the document + QString author; // [Optional] The author of the document + QString description; // [Optional] The description of the source + QList excerpts;// [Required] The list of excerpts + + // Returns a human readable string containing all the excerpts + QString text() const { + QStringList formattedExcerpts; + for (const auto& excerpt : excerpts) { + QString formattedExcerpt = excerpt.text; + if (excerpt.page != -1) { + formattedExcerpt += QStringLiteral(" (Page: %1").arg(excerpt.page); + if (excerpt.from != -1 && excerpt.to != -1) { + formattedExcerpt += QStringLiteral(", Lines: %1-%2").arg(excerpt.from).arg(excerpt.to); + } + formattedExcerpt += QStringLiteral(")"); + } else if (excerpt.from != -1 && excerpt.to != -1) { + formattedExcerpt += QStringLiteral(" (Lines: %1-%2)").arg(excerpt.from).arg(excerpt.to); + } + formattedExcerpts.append(formattedExcerpt); + } + return formattedExcerpts.join(QStringLiteral("\n---\n")); + } QString fileUri() const { // QUrl reserved chars that are not UNSAFE_PATH according to glib/gconvert.c @@ -55,25 +83,7 @@ struct SourceExcerpt { return 
u"file://"_s + escaped; } - QJsonObject toJson() const - { - QJsonObject result; - result.insert("date", date); - result.insert("text", text); - result.insert("collection", collection); - result.insert("path", path); - result.insert("file", file); - result.insert("url", url); - result.insert("favicon", favicon); - result.insert("title", title); - result.insert("author", author); - result.insert("description", description); - result.insert("page", page); - result.insert("from", from); - result.insert("to", to); - return result; - } - + static QString toJson(const QList &sources); static QList fromJson(const QString &json, QString &errorString); bool operator==(const SourceExcerpt &other) const { From 5fc2ff8e69f7895974d8d821633fa44a846df820 Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Thu, 1 Aug 2024 18:08:38 -0400 Subject: [PATCH 14/30] Use parameters which is in keeping with other standard practices. Signed-off-by: Adam Treat --- gpt4all-chat/chatllm.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp index ed209ed03aaa..bea454eebeea 100644 --- a/gpt4all-chat/chatllm.cpp +++ b/gpt4all-chat/chatllm.cpp @@ -860,18 +860,18 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString } QJsonObject rootObject = toolCallDoc.object(); - if (!rootObject.contains("name") || !rootObject.contains("arguments")) { + if (!rootObject.contains("name") || !rootObject.contains("parameters")) { qWarning() << "ERROR: The tool call did not have required name and argument objects " << toolCall; return handleFailedToolCall(trimmed, elapsed); } const QString tool = toolCallDoc["name"].toString(); - const QJsonObject args = toolCallDoc["arguments"].toObject(); + const QJsonObject args = toolCallDoc["parameters"].toObject(); // FIXME: In the future this will try to match the tool call to a list of tools that are supported // according to MySettings, but for now only brave search is supported if 
(tool != "brave_search" || !args.contains("query")) { - qWarning() << "ERROR: Could not find the tool and correct arguments for " << toolCall; + qWarning() << "ERROR: Could not find the tool and correct parameters for " << toolCall; return handleFailedToolCall(trimmed, elapsed); } From 244b82622c79c4e7a9fa90f74d992b56f72b83ce Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Wed, 7 Aug 2024 10:44:52 -0400 Subject: [PATCH 15/30] Implement error handling for tool calls. Signed-off-by: Adam Treat --- gpt4all-chat/bravesearch.cpp | 19 +++++++++++++------ gpt4all-chat/bravesearch.h | 13 ++++++++++++- gpt4all-chat/chatllm.cpp | 3 ++- gpt4all-chat/localdocssearch.cpp | 6 ++++++ gpt4all-chat/localdocssearch.h | 8 +++++++- gpt4all-chat/tool.h | 24 +++++++++++++++--------- 6 files changed, 55 insertions(+), 18 deletions(-) diff --git a/gpt4all-chat/bravesearch.cpp b/gpt4all-chat/bravesearch.cpp index e691810bb567..a9c950b0b80f 100644 --- a/gpt4all-chat/bravesearch.cpp +++ b/gpt4all-chat/bravesearch.cpp @@ -29,7 +29,14 @@ QString BraveSearch::run(const QJsonObject ¶meters, qint64 timeout) worker.request(apiKey, query, count); }); workerThread.start(); - workerThread.wait(timeout); + bool timedOut = !workerThread.wait(timeout); + if (timedOut) { + m_error = ToolEnums::Error::TimeoutError; + m_errorString = tr("ERROR: brave search timeout"); + } else { + m_error = worker.error(); + m_errorString = worker.errorString(); + } workerThread.quit(); workerThread.wait(); return worker.response(); @@ -62,14 +69,15 @@ void BraveAPIWorker::request(const QString &apiKey, const QString &query, int co connect(reply, &QNetworkReply::errorOccurred, this, &BraveAPIWorker::handleErrorOccurred); } -static QString cleanBraveResponse(const QByteArray& jsonResponse) +QString BraveAPIWorker::cleanBraveResponse(const QByteArray& jsonResponse) { // This parses the response from brave and formats it in json that conforms to the de facto // standard in SourceExcerpts::fromJson(...) 
QJsonParseError err; QJsonDocument document = QJsonDocument::fromJson(jsonResponse, &err); if (err.error != QJsonParseError::NoError) { - qWarning() << "ERROR: Couldn't parse brave response: " << jsonResponse << err.errorString(); + m_error = ToolEnums::Error::UnknownError; + m_errorString = QString(tr("ERROR: brave search could not parse json response: %1")).arg(jsonResponse); return QString(); } @@ -147,7 +155,6 @@ void BraveAPIWorker::handleFinished() m_response = cleanBraveResponse(jsonData); } else { QByteArray jsonData = jsonReply->readAll(); - qWarning() << "ERROR: Could not search brave" << jsonReply->error() << jsonReply->errorString() << jsonData; jsonReply->deleteLater(); } emit finished(); @@ -157,7 +164,7 @@ void BraveAPIWorker::handleErrorOccurred(QNetworkReply::NetworkError code) { QNetworkReply *reply = qobject_cast(sender()); Q_ASSERT(reply); - qWarning().noquote() << "ERROR: BraveAPIWorker::handleErrorOccurred got HTTP Error" << code << "response:" - << reply->errorString(); + m_error = ToolEnums::Error::UnknownError; + m_errorString = QString(tr("ERROR: brave search code: %1 response: %2")).arg(code).arg(reply->errorString()); emit finished(); } diff --git a/gpt4all-chat/bravesearch.h b/gpt4all-chat/bravesearch.h index 28f84b159e01..dc0cc0423688 100644 --- a/gpt4all-chat/bravesearch.h +++ b/gpt4all-chat/bravesearch.h @@ -17,6 +17,8 @@ class BraveAPIWorker : public QObject { virtual ~BraveAPIWorker() {} QString response() const { return m_response; } + ToolEnums::Error error() const { return m_error; } + QString errorString() const { return m_errorString; } public Q_SLOTS: void request(const QString &apiKey, const QString &query, int count); @@ -25,21 +27,30 @@ public Q_SLOTS: void finished(); private Q_SLOTS: + QString cleanBraveResponse(const QByteArray& jsonResponse); void handleFinished(); void handleErrorOccurred(QNetworkReply::NetworkError code); private: QNetworkAccessManager *m_networkManager; QString m_response; + ToolEnums::Error 
m_error; + QString m_errorString; }; class BraveSearch : public Tool { Q_OBJECT public: - BraveSearch() : Tool() {} + BraveSearch() : Tool(), m_error(ToolEnums::Error::NoError) {} virtual ~BraveSearch() {} QString run(const QJsonObject ¶meters, qint64 timeout = 2000) override; + ToolEnums::Error error() const override { return m_error; } + QString errorString() const override { return m_errorString; } + +private: + ToolEnums::Error m_error; + QString m_errorString; }; #endif // BRAVESEARCH_H diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp index bea454eebeea..381514961fd6 100644 --- a/gpt4all-chat/chatllm.cpp +++ b/gpt4all-chat/chatllm.cpp @@ -871,6 +871,7 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString // FIXME: In the future this will try to match the tool call to a list of tools that are supported // according to MySettings, but for now only brave search is supported if (tool != "brave_search" || !args.contains("query")) { + // FIXME: Need to surface errors to the UI qWarning() << "ERROR: Could not find the tool and correct parameters for " << toolCall; return handleFailedToolCall(trimmed, elapsed); } @@ -887,7 +888,7 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString parameters.insert("query", query); parameters.insert("count", 2); - // FIXME: This has to handle errors of the tool call + // FIXME: Need to surface errors to the UI const QString braveResponse = brave.run(parameters, 2000 /*msecs to timeout*/); QString parseError; diff --git a/gpt4all-chat/localdocssearch.cpp b/gpt4all-chat/localdocssearch.cpp index 77becf4ec9b7..2fc9dbe217bd 100644 --- a/gpt4all-chat/localdocssearch.cpp +++ b/gpt4all-chat/localdocssearch.cpp @@ -27,6 +27,12 @@ QString LocalDocsSearch::run(const QJsonObject ¶meters, qint64 timeout) worker.request(collections, text, count); }); workerThread.start(); + bool timedOut = !workerThread.wait(timeout); + if (timedOut) { + m_error = ToolEnums::Error::TimeoutError; + 
m_errorString = tr("ERROR: localdocs timeout"); + } + workerThread.wait(timeout); workerThread.quit(); workerThread.wait(); diff --git a/gpt4all-chat/localdocssearch.h b/gpt4all-chat/localdocssearch.h index 4f757a521801..61e56b737cb3 100644 --- a/gpt4all-chat/localdocssearch.h +++ b/gpt4all-chat/localdocssearch.h @@ -27,10 +27,16 @@ class LocalDocsWorker : public QObject { class LocalDocsSearch : public Tool { Q_OBJECT public: - LocalDocsSearch() : Tool() {} + LocalDocsSearch() : Tool(), m_error(ToolEnums::Error::NoError) {} virtual ~LocalDocsSearch() {} QString run(const QJsonObject ¶meters, qint64 timeout = 2000) override; + ToolEnums::Error error() const override { return m_error; } + QString errorString() const override { return m_errorString; } + +private: + ToolEnums::Error m_error; + QString m_errorString; }; #endif // LOCALDOCSSEARCH_H diff --git a/gpt4all-chat/tool.h b/gpt4all-chat/tool.h index 3dac88e7b3a5..aa8ea3b28800 100644 --- a/gpt4all-chat/tool.h +++ b/gpt4all-chat/tool.h @@ -9,15 +9,20 @@ using namespace Qt::Literals::StringLiterals; namespace ToolEnums { Q_NAMESPACE enum class ConnectionType { - Builtin = 0, // A built-in tool with bespoke connection type - Local = 1, // Starts a local process and communicates via stdin/stdout/stderr - LocalServer = 2, // Connects to an existing local process and communicates via stdin/stdout/stderr - Remote = 3, // Starts a remote process and communicates via some networking protocol TBD - RemoteServer = 4 // Connects to an existing remote process and communicates via some networking protocol TBD + BuiltinConnection = 0, // A built-in tool with bespoke connection type + LocalConnection = 1, // Starts a local process and communicates via stdin/stdout/stderr + LocalServerConnection = 2, // Connects to an existing local process and communicates via stdin/stdout/stderr + RemoteConnection = 3, // Starts a remote process and communicates via some networking protocol TBD + RemoteServerConnection = 4 // Connects to an 
existing remote process and communicates via some networking protocol TBD }; Q_ENUM_NS(ConnectionType) + + enum class Error { + NoError = 0, + TimeoutError = 2, + UnknownError = 499, + }; } -using namespace ToolEnums; struct ToolInfo { Q_GADGET @@ -25,14 +30,14 @@ struct ToolInfo { Q_PROPERTY(QString description MEMBER description) Q_PROPERTY(QJsonObject parameters MEMBER parameters) Q_PROPERTY(bool isEnabled MEMBER isEnabled) - Q_PROPERTY(ConnectionType connectionType MEMBER connectionType) + Q_PROPERTY(ToolEnums::ConnectionType connectionType MEMBER connectionType) public: QString name; QString description; QJsonObject parameters; bool isEnabled; - ConnectionType connectionType; + ToolEnums::ConnectionType connectionType; // FIXME: Should we go with essentially the OpenAI/ollama consensus for these tool // info files? If you install a tool in GPT4All should it need to meet the spec for these: @@ -64,8 +69,9 @@ class Tool : public QObject { Tool() : QObject(nullptr) {} virtual ~Tool() {} - // FIXME: How to handle errors? virtual QString run(const QJsonObject ¶meters, qint64 timeout = 2000) = 0; + virtual ToolEnums::Error error() const { return ToolEnums::Error::NoError; } + virtual QString errorString() const { return QString(); } }; #endif // TOOL_H From cedba6cd10042bcdbc749501e8dbe066e137d3a5 Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Thu, 8 Aug 2024 10:37:53 -0400 Subject: [PATCH 16/30] Tool model. 
Signed-off-by: Adam Treat --- gpt4all-chat/CMakeLists.txt | 2 +- gpt4all-chat/chatlistmodel.cpp | 2 +- gpt4all-chat/qml/ToolSettings.qml | 2 +- gpt4all-chat/tool.h | 72 ++++++++----------- gpt4all-chat/toolmodel.cpp | 103 +++++++++++++++++++++++++++ gpt4all-chat/toolmodel.h | 112 ++++++++++++++++++++++++++++++ 6 files changed, 247 insertions(+), 46 deletions(-) create mode 100644 gpt4all-chat/toolmodel.cpp create mode 100644 gpt4all-chat/toolmodel.h diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt index 1ab4db20f89c..08030af646e9 100644 --- a/gpt4all-chat/CMakeLists.txt +++ b/gpt4all-chat/CMakeLists.txt @@ -124,7 +124,7 @@ qt_add_executable(chat sourceexcerpt.h sourceexcerpt.cpp server.h server.cpp logger.h logger.cpp - tool.h tool.cpp + tool.h tool.cpp toolmodel.h toolmodel.cpp ${APP_ICON_RESOURCE} ${CHAT_EXE_RESOURCES} ) diff --git a/gpt4all-chat/chatlistmodel.cpp b/gpt4all-chat/chatlistmodel.cpp index c5be4338ff3e..d5547e3b63ab 100644 --- a/gpt4all-chat/chatlistmodel.cpp +++ b/gpt4all-chat/chatlistmodel.cpp @@ -31,7 +31,7 @@ ChatListModel *ChatListModel::globalInstance() ChatListModel::ChatListModel() : QAbstractListModel(nullptr) { - QCoreApplication::instance()->installEventFilter(this); + QCoreApplication::instance()->installEventFilter(this); } bool ChatListModel::eventFilter(QObject *obj, QEvent *ev) diff --git a/gpt4all-chat/qml/ToolSettings.qml b/gpt4all-chat/qml/ToolSettings.qml index 2fc1cd3210da..f9b5e7277662 100644 --- a/gpt4all-chat/qml/ToolSettings.qml +++ b/gpt4all-chat/qml/ToolSettings.qml @@ -67,5 +67,5 @@ MySettingsTab { height: 1 color: theme.settingsDivider } - } + } } diff --git a/gpt4all-chat/tool.h b/gpt4all-chat/tool.h index aa8ea3b28800..7598af6885a0 100644 --- a/gpt4all-chat/tool.h +++ b/gpt4all-chat/tool.h @@ -8,70 +8,56 @@ using namespace Qt::Literals::StringLiterals; namespace ToolEnums { Q_NAMESPACE - enum class ConnectionType { - BuiltinConnection = 0, // A built-in tool with bespoke connection type - 
LocalConnection = 1, // Starts a local process and communicates via stdin/stdout/stderr - LocalServerConnection = 2, // Connects to an existing local process and communicates via stdin/stdout/stderr - RemoteConnection = 3, // Starts a remote process and communicates via some networking protocol TBD - RemoteServerConnection = 4 // Connects to an existing remote process and communicates via some networking protocol TBD - }; - Q_ENUM_NS(ConnectionType) - enum class Error { NoError = 0, TimeoutError = 2, UnknownError = 499, }; + Q_ENUM_NS(Error) } -struct ToolInfo { - Q_GADGET +class Tool : public QObject { + Q_OBJECT Q_PROPERTY(QString name MEMBER name) Q_PROPERTY(QString description MEMBER description) - Q_PROPERTY(QJsonObject parameters MEMBER parameters) + Q_PROPERTY(QString function MEMBER function) + Q_PROPERTY(QJsonObject paramSchema MEMBER paramSchema) + Q_PROPERTY(QUrl url MEMBER url) Q_PROPERTY(bool isEnabled MEMBER isEnabled) - Q_PROPERTY(ToolEnums::ConnectionType connectionType MEMBER connectionType) + Q_PROPERTY(bool isBuiltin MEMBER isBuiltin) + Q_PROPERTY(bool forceUsage MEMBER forceUsage) + Q_PROPERTY(bool excerpts MEMBER excerpts) public: - QString name; - QString description; - QJsonObject parameters; - bool isEnabled; - ToolEnums::ConnectionType connectionType; + Tool() : QObject(nullptr) {} + virtual ~Tool() {} + + virtual QString run(const QJsonObject ¶meters, qint64 timeout = 2000) = 0; + virtual ToolEnums::Error error() const { return ToolEnums::Error::NoError; } + virtual QString errorString() const { return QString(); } + + QString name; // [Required] Human readable name of the tool. + QString description; // [Required] Human readable description of the tool. + QString function; // [Required] Must be unique. Name of the function to invoke. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64. + QJsonObject paramSchema; // [Optional] Json schema describing the tool's parameters. 
An empty object specifies no parameters. + // https://json-schema.org/understanding-json-schema/ + QUrl url; // [Optional] The local file or remote resource used to invoke the tool. + bool isEnabled = false; // [Optional] Whether the tool is currently enabled + bool isBuiltin = false; // [Optional] Whether the tool is built-in + bool forceUsage = false; // [Optional] Whether we should attempt to force usage of the tool rather than let the LLM decide. NOTE: Not always possible. + bool excerpts = false; // [Optional] Whether the json result produces source excerpts. // FIXME: Should we go with essentially the OpenAI/ollama consensus for these tool // info files? If you install a tool in GPT4All should it need to meet the spec for these: // https://platform.openai.com/docs/api-reference/runs/createRun#runs-createrun-tools // https://github.com/ollama/ollama/blob/main/docs/api.md#chat-request-with-tools - QJsonObject toJson() const - { - QJsonObject result; - result.insert("name", name); - result.insert("description", description); - result.insert("parameters", parameters); - return result; - } - static ToolInfo fromJson(const QString &json); - - bool operator==(const ToolInfo &other) const { - return name == other.name; + bool operator==(const Tool &other) const { + return function == other.function; } - bool operator!=(const ToolInfo &other) const { + bool operator!=(const Tool &other) const { return !(*this == other); } }; -Q_DECLARE_METATYPE(ToolInfo) - -class Tool : public QObject { - Q_OBJECT -public: - Tool() : QObject(nullptr) {} - virtual ~Tool() {} - - virtual QString run(const QJsonObject ¶meters, qint64 timeout = 2000) = 0; - virtual ToolEnums::Error error() const { return ToolEnums::Error::NoError; } - virtual QString errorString() const { return QString(); } -}; #endif // TOOL_H diff --git a/gpt4all-chat/toolmodel.cpp b/gpt4all-chat/toolmodel.cpp new file mode 100644 index 000000000000..4a6d61c32f06 --- /dev/null +++ b/gpt4all-chat/toolmodel.cpp @@ -0,0 +1,103
@@ +#include "toolmodel.h" + +#include +#include +#include + +#include "bravesearch.h" +#include "localdocssearch.h" + +class MyToolModel: public ToolModel { }; +Q_GLOBAL_STATIC(MyToolModel, toolModelInstance) +ToolModel *ToolModel::globalInstance() +{ + return toolModelInstance(); +} + +ToolModel::ToolModel() + : QAbstractListModel(nullptr) { + + QCoreApplication::instance()->installEventFilter(this); + + Tool* localDocsSearch = new LocalDocsSearch; + localDocsSearch->name = tr("LocalDocs search"); + localDocsSearch->description = tr("Search the local docs"); + localDocsSearch->function = "localdocs_search"; + localDocsSearch->isBuiltin = true; + localDocsSearch->excerpts = true; + localDocsSearch->forceUsage = true; // FIXME: persistent setting + localDocsSearch->isEnabled = true; // FIXME: persistent setting + + QString localParamSchema = R"({ + "collections": { + "type": "array", + "items": { + "type": "string" + }, + "description": "The collections to search", + "required": true, + "modelGenerated": false, + "userConfigured": false + }, + "query": { + "type": "string", + "description": "The query to search", + "required": true + }, + "count": { + "type": "integer", + "description": "The number of excerpts to return", + "required": true, + "modelGenerated": false + } + })"; + + QJsonDocument localJsonDoc = QJsonDocument::fromJson(localParamSchema.toUtf8()); + Q_ASSERT(!localJsonDoc.isNull() && localJsonDoc.isObject()); + localDocsSearch->paramSchema = localJsonDoc.object(); + m_tools.append(localDocsSearch); + m_toolMap.insert(localDocsSearch->function, localDocsSearch); + + Tool *braveSearch = new BraveSearch; + braveSearch->name = tr("Brave web search"); + braveSearch->description = tr("Search the web using brave.com"); + braveSearch->function = "brave_search"; + braveSearch->isBuiltin = true; + braveSearch->excerpts = true; + braveSearch->forceUsage = false; // FIXME: persistent setting + braveSearch->isEnabled = false; // FIXME: persistent setting + + 
QString braveParamSchema = R"({ + "apiKey": { + "type": "string", + "description": "The api key to use", + "required": true, + "modelGenerated": false, + "userConfigured": true + }, + "query": { + "type": "string", + "description": "The query to search", + "required": true + }, + "count": { + "type": "integer", + "description": "The number of excerpts to return", + "required": true, + "modelGenerated": false + } + })"; + + QJsonDocument braveJsonDoc = QJsonDocument::fromJson(braveParamSchema.toUtf8()); + Q_ASSERT(!braveJsonDoc.isNull() && braveJsonDoc.isObject()); + braveSearch->paramSchema = braveJsonDoc.object(); + m_tools.append(braveSearch); + m_toolMap.insert(braveSearch->function, braveSearch); +} + +bool ToolModel::eventFilter(QObject *obj, QEvent *ev) +{ + if (obj == QCoreApplication::instance() && ev->type() == QEvent::LanguageChange) + emit dataChanged(index(0, 0), index(m_tools.size() - 1, 0)); + return false; +} diff --git a/gpt4all-chat/toolmodel.h b/gpt4all-chat/toolmodel.h new file mode 100644 index 000000000000..3da557936a06 --- /dev/null +++ b/gpt4all-chat/toolmodel.h @@ -0,0 +1,112 @@ +#ifndef TOOLMODEL_H +#define TOOLMODEL_H + +#include "tool.h" + +#include + +class ToolModel : public QAbstractListModel +{ + Q_OBJECT + Q_PROPERTY(int count READ count NOTIFY countChanged) + +public: + static ToolModel *globalInstance(); + + enum Roles { + NameRole = Qt::UserRole + 1, + DescriptionRole, + FunctionRole, + ParametersRole, + UrlRole, + ApiKeyRole, + KeyRequiredRole, + IsEnabledRole, + IsBuiltinRole, + ForceUsageRole, + ExcerptsRole, + }; + + int rowCount(const QModelIndex &parent = QModelIndex()) const override + { + Q_UNUSED(parent) + return m_tools.size(); + } + + QVariant data(const QModelIndex &index, int role = Qt::DisplayRole) const override + { + if (!index.isValid() || index.row() < 0 || index.row() >= m_tools.size()) + return QVariant(); + + const Tool *item = m_tools.at(index.row()); + switch (role) { + case NameRole: + return item->name; + 
case DescriptionRole: + return item->description; + case FunctionRole: + return item->function; + case ParametersRole: + return item->paramSchema; + case UrlRole: + return item->url; + case IsEnabledRole: + return item->isEnabled; + case IsBuiltinRole: + return item->isBuiltin; + case ForceUsageRole: + return item->forceUsage; + case ExcerptsRole: + return item->excerpts; + } + + return QVariant(); + } + + QHash roleNames() const override + { + QHash roles; + roles[NameRole] = "name"; + roles[DescriptionRole] = "description"; + roles[FunctionRole] = "function"; + roles[ParametersRole] = "parameters"; + roles[UrlRole] = "url"; + roles[ApiKeyRole] = "apiKey"; + roles[KeyRequiredRole] = "keyRequired"; + roles[IsEnabledRole] = "isEnabled"; + roles[IsBuiltinRole] = "isBuiltin"; + roles[ForceUsageRole] = "forceUsage"; + roles[ExcerptsRole] = "excerpts"; + return roles; + } + + Q_INVOKABLE Tool* get(int index) const + { + if (index < 0 || index >= m_tools.size()) return nullptr; + return m_tools.at(index); + } + + Q_INVOKABLE Tool *get(const QString &id) const + { + if (!m_toolMap.contains(id)) return nullptr; + return m_toolMap.value(id); + } + + int count() const { return m_tools.size(); } + +Q_SIGNALS: + void countChanged(); + void valueChanged(int index, const QString &value); + +protected: + bool eventFilter(QObject *obj, QEvent *ev) override; + +private: + explicit ToolModel(); + ~ToolModel() {} + friend class MyToolModel; + QList m_tools; + QHash m_toolMap; +}; + +#endif // TOOLMODEL_H From f93b76438e7857420883e3b72534dcabcc30d95c Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Sat, 10 Aug 2024 09:48:54 -0400 Subject: [PATCH 17/30] Move the usearch submodule to third_party dir. 
Signed-off-by: Adam Treat --- .gitmodules | 2 +- gpt4all-chat/CMakeLists.txt | 4 ++-- gpt4all-chat/{ => third_party}/usearch | 0 3 files changed, 3 insertions(+), 3 deletions(-) rename gpt4all-chat/{ => third_party}/usearch (100%) diff --git a/.gitmodules b/.gitmodules index 98c9a2142a21..0bb4ae91fe11 100644 --- a/.gitmodules +++ b/.gitmodules @@ -3,5 +3,5 @@ url = https://github.com/nomic-ai/llama.cpp.git branch = master [submodule "gpt4all-chat/usearch"] - path = gpt4all-chat/usearch + path = gpt4all-chat/third_party/usearch url = https://github.com/nomic-ai/usearch.git diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt index 08030af646e9..ba1f09774aa0 100644 --- a/gpt4all-chat/CMakeLists.txt +++ b/gpt4all-chat/CMakeLists.txt @@ -293,8 +293,8 @@ target_compile_definitions(chat # usearch uses the identifier 'slots' which conflicts with Qt's 'slots' keyword target_compile_definitions(chat PRIVATE QT_NO_SIGNALS_SLOTS_KEYWORDS) -target_include_directories(chat PRIVATE usearch/include - usearch/fp16/include) +target_include_directories(chat PRIVATE third_party/usearch/include + third_party/usearch/fp16/include) if(LINUX) target_link_libraries(chat diff --git a/gpt4all-chat/usearch b/gpt4all-chat/third_party/usearch similarity index 100% rename from gpt4all-chat/usearch rename to gpt4all-chat/third_party/usearch From 00ecbb75b4fa0d939147e4fac6700e75116ec2fa Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Sat, 10 Aug 2024 10:01:35 -0400 Subject: [PATCH 18/30] Add jinja third party dependency. 
Signed-off-by: Adam Treat --- .gitmodules | 3 +++ gpt4all-chat/CMakeLists.txt | 2 ++ gpt4all-chat/third_party/jinja2cpp | 1 + 3 files changed, 6 insertions(+) create mode 160000 gpt4all-chat/third_party/jinja2cpp diff --git a/.gitmodules b/.gitmodules index 0bb4ae91fe11..b1ac1f5301a8 100644 --- a/.gitmodules +++ b/.gitmodules @@ -5,3 +5,6 @@ [submodule "gpt4all-chat/usearch"] path = gpt4all-chat/third_party/usearch url = https://github.com/nomic-ai/usearch.git +[submodule "gpt4all-chat/third_party/jinja2cpp"] + path = gpt4all-chat/third_party/jinja2cpp + url = https://github.com/nomic-ai/jinja2cpp.git diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt index ba1f09774aa0..25c5817f1d82 100644 --- a/gpt4all-chat/CMakeLists.txt +++ b/gpt4all-chat/CMakeLists.txt @@ -296,6 +296,8 @@ target_compile_definitions(chat PRIVATE QT_NO_SIGNALS_SLOTS_KEYWORDS) target_include_directories(chat PRIVATE third_party/usearch/include third_party/usearch/fp16/include) +add_subdirectory(third_party/jinja2cpp ${CMAKE_BINARY_DIR}/jinja2cpp) + if(LINUX) target_link_libraries(chat PRIVATE Qt6::Quick Qt6::Svg Qt6::HttpServer Qt6::Sql Qt6::Pdf Qt6::WaylandCompositor) diff --git a/gpt4all-chat/third_party/jinja2cpp b/gpt4all-chat/third_party/jinja2cpp new file mode 160000 index 000000000000..e97a54e51336 --- /dev/null +++ b/gpt4all-chat/third_party/jinja2cpp @@ -0,0 +1 @@ +Subproject commit e97a54e51336938470eacb4ce261bde903e22e54 From c3cfaff80385d20a0ec08af93aa1ba0376790981 Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Fri, 9 Aug 2024 11:02:10 -0400 Subject: [PATCH 19/30] Refactor and make use of jinja templates. 
Signed-off-by: Adam Treat --- gpt4all-chat/CMakeLists.txt | 2 +- gpt4all-chat/bravesearch.cpp | 50 +++++++++++++++++++++ gpt4all-chat/bravesearch.h | 10 +++++ gpt4all-chat/chatllm.cpp | 33 +++++++++++++- gpt4all-chat/localdocssearch.cpp | 32 ++++++++++++++ gpt4all-chat/localdocssearch.h | 9 ++++ gpt4all-chat/modellist.cpp | 18 ++++---- gpt4all-chat/modellist.h | 9 ++-- gpt4all-chat/mysettings.cpp | 6 +-- gpt4all-chat/mysettings.h | 4 +- gpt4all-chat/tool.cpp | 30 +++++++++++++ gpt4all-chat/tool.h | 74 ++++++++++++++++++++++---------- gpt4all-chat/toolmodel.cpp | 71 +----------------------------- gpt4all-chat/toolmodel.h | 18 ++++---- 14 files changed, 245 insertions(+), 121 deletions(-) diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt index 25c5817f1d82..2b2fb4dbd1c2 100644 --- a/gpt4all-chat/CMakeLists.txt +++ b/gpt4all-chat/CMakeLists.txt @@ -306,7 +306,7 @@ else() PRIVATE Qt6::Quick Qt6::Svg Qt6::HttpServer Qt6::Sql Qt6::Pdf) endif() target_link_libraries(chat - PRIVATE llmodel) + PRIVATE llmodel jinja2cpp) # -- install -- diff --git a/gpt4all-chat/bravesearch.cpp b/gpt4all-chat/bravesearch.cpp index a9c950b0b80f..2b8c20ce149d 100644 --- a/gpt4all-chat/bravesearch.cpp +++ b/gpt4all-chat/bravesearch.cpp @@ -42,6 +42,56 @@ QString BraveSearch::run(const QJsonObject ¶meters, qint64 timeout) return worker.response(); } +QJsonObject BraveSearch::paramSchema() const +{ + static const QString braveParamSchema = R"({ + "apiKey": { + "type": "string", + "description": "The api key to use", + "required": true, + "modelGenerated": false, + "userConfigured": true + }, + "query": { + "type": "string", + "description": "The query to search", + "required": true + }, + "count": { + "type": "integer", + "description": "The number of excerpts to return", + "required": true, + "modelGenerated": false + } + })"; + + static const QJsonDocument braveJsonDoc = QJsonDocument::fromJson(braveParamSchema.toUtf8()); + Q_ASSERT(!braveJsonDoc.isNull() && 
braveJsonDoc.isObject()); + return braveJsonDoc.object(); +} + +QJsonObject BraveSearch::exampleParams() const +{ + static const QString example = R"({ + "query": "the 44th president of the United States" + })"; + static const QJsonDocument exampleDoc = QJsonDocument::fromJson(example.toUtf8()); + Q_ASSERT(!exampleDoc.isNull() && exampleDoc.isObject()); + return exampleDoc.object(); +} + +bool BraveSearch::isEnabled() const +{ + // FIXME: Refer to mysettings + return true; +} + +bool BraveSearch::forceUsage() const +{ + // FIXME: Refer to mysettings + return false; +} + void BraveAPIWorker::request(const QString &apiKey, const QString &query, int count) { // Documentation on the brave web search: diff --git a/gpt4all-chat/bravesearch.h b/gpt4all-chat/bravesearch.h index dc0cc0423688..a1817412c7f3 100644 --- a/gpt4all-chat/bravesearch.h +++ b/gpt4all-chat/bravesearch.h @@ -48,6 +48,16 @@ class BraveSearch : public Tool { ToolEnums::Error error() const override { return m_error; } QString errorString() const override { return m_errorString; } + QString name() const override { return tr("Brave web search"); } + QString description() const override { return tr("Search the web using brave"); } + QString function() const override { return "brave_search"; } + QJsonObject paramSchema() const override; + QJsonObject exampleParams() const override; + bool isEnabled() const override; + bool isBuiltin() const override { return true; } + bool forceUsage() const override; + bool excerpts() const override { return true; } + private: ToolEnums::Error m_error; QString m_errorString; diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp index 381514961fd6..7d42dbf57652 100644 --- a/gpt4all-chat/chatllm.cpp +++ b/gpt4all-chat/chatllm.cpp @@ -6,6 +6,8 @@ #include "localdocssearch.h" #include "mysettings.h" #include "network.h" +#include "tool.h" +#include "toolmodel.h" #include #include @@ -29,6 +31,7 @@ #include #include #include +#include #include #include #include @@ 
-1332,7 +1335,35 @@ void ChatLLM::processSystemPrompt() if (!isModelLoaded() || m_processedSystemPrompt || m_restoreStateFromText || m_isServer) return; - const std::string systemPrompt = MySettings::globalInstance()->modelSystemPrompt(m_modelInfo).toStdString(); + const std::string systemPromptTemplate = MySettings::globalInstance()->modelSystemPromptTemplate(m_modelInfo).toStdString(); + + // FIXME: This needs to be moved to settings probably and the same code used for validation + jinja2::ValuesMap params; + params.insert({"currentDate", QDate::currentDate().toString().toStdString()}); + + jinja2::ValuesList toolList; + int c = ToolModel::globalInstance()->count(); + for (int i = 0; i < c; ++i) { + Tool *t = ToolModel::globalInstance()->get(i); + if (t->isEnabled() && !t->forceUsage()) + toolList.push_back(t->jinjaValue()); + } + params.insert({"toolList", toolList}); + + std::string systemPrompt; + + jinja2::Template t; + t.Load(systemPromptTemplate); + const auto renderResult = t.RenderAsString(params); + + // The GUI should not allow setting an improper template, but it is always possible someone hand + // edits the settings file to produce an improper one. 
+ Q_ASSERT(renderResult); + if (renderResult) + systemPrompt = renderResult.value(); + else + qWarning() << "ERROR: Could not parse system prompt template:" << renderResult.error().ToString(); + if (QString::fromStdString(systemPrompt).trimmed().isEmpty()) { m_processedSystemPrompt = true; return; diff --git a/gpt4all-chat/localdocssearch.cpp b/gpt4all-chat/localdocssearch.cpp index 2fc9dbe217bd..424c99a1cb7f 100644 --- a/gpt4all-chat/localdocssearch.cpp +++ b/gpt4all-chat/localdocssearch.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -39,6 +40,37 @@ QString LocalDocsSearch::run(const QJsonObject ¶meters, qint64 timeout) return worker.response(); } +QJsonObject LocalDocsSearch::paramSchema() const +{ + static const QString localParamSchema = R"({ + "collections": { + "type": "array", + "items": { + "type": "string" + }, + "description": "The collections to search", + "required": true, + "modelGenerated": false, + "userConfigured": false + }, + "query": { + "type": "string", + "description": "The query to search", + "required": true + }, + "count": { + "type": "integer", + "description": "The number of excerpts to return", + "required": true, + "modelGenerated": false + } + })"; + + static const QJsonDocument localJsonDoc = QJsonDocument::fromJson(localParamSchema.toUtf8()); + Q_ASSERT(!localJsonDoc.isNull() && localJsonDoc.isObject()); + return localJsonDoc.object(); +} + LocalDocsWorker::LocalDocsWorker() : QObject(nullptr) { diff --git a/gpt4all-chat/localdocssearch.h b/gpt4all-chat/localdocssearch.h index 61e56b737cb3..23f64215401d 100644 --- a/gpt4all-chat/localdocssearch.h +++ b/gpt4all-chat/localdocssearch.h @@ -34,6 +34,15 @@ class LocalDocsSearch : public Tool { ToolEnums::Error error() const override { return m_error; } QString errorString() const override { return m_errorString; } + QString name() const override { return tr("LocalDocs search"); } + QString description() const override { return tr("Search the local docs"); } 
+ QString function() const override { return "localdocs_search"; } + QJsonObject paramSchema() const override; + bool isEnabled() const override { return true; } + bool isBuiltin() const override { return true; } + bool forceUsage() const override { return true; } + bool excerpts() const override { return true; } + private: ToolEnums::Error m_error; QString m_errorString; diff --git a/gpt4all-chat/modellist.cpp b/gpt4all-chat/modellist.cpp index 6ed911508e92..5d3530c1c650 100644 --- a/gpt4all-chat/modellist.cpp +++ b/gpt4all-chat/modellist.cpp @@ -334,15 +334,15 @@ void ModelInfo::setToolTemplate(const QString &t) m_toolTemplate = t; } -QString ModelInfo::systemPrompt() const +QString ModelInfo::systemPromptTemplate() const { - return MySettings::globalInstance()->modelSystemPrompt(*this); + return MySettings::globalInstance()->modelSystemPromptTemplate(*this); } -void ModelInfo::setSystemPrompt(const QString &p) +void ModelInfo::setSystemPromptTemplate(const QString &p) { - if (shouldSaveMetadata()) MySettings::globalInstance()->setModelSystemPrompt(*this, p, true /*force*/); - m_systemPrompt = p; + if (shouldSaveMetadata()) MySettings::globalInstance()->setModelSystemPromptTemplate(*this, p, true /*force*/); + m_systemPromptTemplate = p; } QString ModelInfo::chatNamePrompt() const @@ -397,7 +397,7 @@ QVariantMap ModelInfo::getFields() const { "repeatPenaltyTokens", m_repeatPenaltyTokens }, { "promptTemplate", m_promptTemplate }, { "toolTemplate", m_toolTemplate }, - { "systemPrompt", m_systemPrompt }, + { "systemPromptTemplate",m_systemPromptTemplate }, { "chatNamePrompt", m_chatNamePrompt }, { "suggestedFollowUpPrompt", m_suggestedFollowUpPrompt }, }; @@ -792,7 +792,7 @@ QVariant ModelList::dataInternal(const ModelInfo *info, int role) const case ToolTemplateRole: return info->toolTemplate(); case SystemPromptRole: - return info->systemPrompt(); + return info->systemPromptTemplate(); case ChatNamePromptRole: return info->chatNamePrompt(); case 
SuggestedFollowUpPromptRole: @@ -970,7 +970,7 @@ void ModelList::updateData(const QString &id, const QVector case ToolTemplateRole: info->setToolTemplate(value.toString()); break; case SystemPromptRole: - info->setSystemPrompt(value.toString()); break; + info->setSystemPromptTemplate(value.toString()); break; case ChatNamePromptRole: info->setChatNamePrompt(value.toString()); break; case SuggestedFollowUpPromptRole: @@ -1125,7 +1125,7 @@ QString ModelList::clone(const ModelInfo &model) { ModelList::RepeatPenaltyTokensRole, model.repeatPenaltyTokens() }, { ModelList::PromptTemplateRole, model.promptTemplate() }, { ModelList::ToolTemplateRole, model.toolTemplate() }, - { ModelList::SystemPromptRole, model.systemPrompt() }, + { ModelList::SystemPromptRole, model.systemPromptTemplate() }, { ModelList::ChatNamePromptRole, model.chatNamePrompt() }, { ModelList::SuggestedFollowUpPromptRole, model.suggestedFollowUpPrompt() }, }; diff --git a/gpt4all-chat/modellist.h b/gpt4all-chat/modellist.h index 9e9f088f76cb..766807461da5 100644 --- a/gpt4all-chat/modellist.h +++ b/gpt4all-chat/modellist.h @@ -69,7 +69,7 @@ struct ModelInfo { Q_PROPERTY(int repeatPenaltyTokens READ repeatPenaltyTokens WRITE setRepeatPenaltyTokens) Q_PROPERTY(QString promptTemplate READ promptTemplate WRITE setPromptTemplate) Q_PROPERTY(QString toolTemplate READ toolTemplate WRITE setToolTemplate) - Q_PROPERTY(QString systemPrompt READ systemPrompt WRITE setSystemPrompt) + Q_PROPERTY(QString systemPromptTemplate READ systemPromptTemplate WRITE setSystemPromptTemplate) Q_PROPERTY(QString chatNamePrompt READ chatNamePrompt WRITE setChatNamePrompt) Q_PROPERTY(QString suggestedFollowUpPrompt READ suggestedFollowUpPrompt WRITE setSuggestedFollowUpPrompt) Q_PROPERTY(int likes READ likes WRITE setLikes) @@ -181,8 +181,9 @@ struct ModelInfo { void setPromptTemplate(const QString &t); QString toolTemplate() const; void setToolTemplate(const QString &t); - QString systemPrompt() const; - void setSystemPrompt(const 
QString &p); + QString systemPromptTemplate() const; + void setSystemPromptTemplate(const QString &p); + // FIXME (adam): The chatname and suggested follow-up should also be templates I guess? QString chatNamePrompt() const; void setChatNamePrompt(const QString &p); QString suggestedFollowUpPrompt() const; @@ -219,7 +220,7 @@ struct ModelInfo { int m_repeatPenaltyTokens = 64; QString m_promptTemplate = "### Human:\n%1\n\n### Assistant:\n"; QString m_toolTemplate = ""; - QString m_systemPrompt = "### System:\nYou are an AI assistant who gives a quality response to whatever humans ask of you.\n\n"; + QString m_systemPromptTemplate = "### System:\nYou are an AI assistant who gives a quality response to whatever humans ask of you.\n\n"; QString m_chatNamePrompt = "Describe the above conversation in seven words or less."; QString m_suggestedFollowUpPrompt = "Suggest three very short factual follow-up questions that have not been answered yet or cannot be found inspired by the previous conversation and excerpts."; friend class MySettings; diff --git a/gpt4all-chat/mysettings.cpp b/gpt4all-chat/mysettings.cpp index 0b94989edfb7..b9bb76b546c2 100644 --- a/gpt4all-chat/mysettings.cpp +++ b/gpt4all-chat/mysettings.cpp @@ -195,7 +195,7 @@ void MySettings::restoreModelDefaults(const ModelInfo &info) setModelRepeatPenaltyTokens(info, info.m_repeatPenaltyTokens); setModelPromptTemplate(info, info.m_promptTemplate); setModelToolTemplate(info, info.m_toolTemplate); - setModelSystemPrompt(info, info.m_systemPrompt); + setModelSystemPromptTemplate(info, info.m_systemPromptTemplate); setModelChatNamePrompt(info, info.m_chatNamePrompt); setModelSuggestedFollowUpPrompt(info, info.m_suggestedFollowUpPrompt); } @@ -298,7 +298,7 @@ double MySettings::modelRepeatPenalty (const ModelInfo &info) const int MySettings::modelRepeatPenaltyTokens (const ModelInfo &info) const { return getModelSetting("repeatPenaltyTokens", info).toInt(); } QString MySettings::modelPromptTemplate (const ModelInfo 
&info) const { return getModelSetting("promptTemplate", info).toString(); } QString MySettings::modelToolTemplate (const ModelInfo &info) const { return getModelSetting("toolTemplate", info).toString(); } -QString MySettings::modelSystemPrompt (const ModelInfo &info) const { return getModelSetting("systemPrompt", info).toString(); } +QString MySettings::modelSystemPromptTemplate (const ModelInfo &info) const { return getModelSetting("systemPrompt", info).toString(); } QString MySettings::modelChatNamePrompt (const ModelInfo &info) const { return getModelSetting("chatNamePrompt", info).toString(); } QString MySettings::modelSuggestedFollowUpPrompt(const ModelInfo &info) const { return getModelSetting("suggestedFollowUpPrompt", info).toString(); } @@ -412,7 +412,7 @@ void MySettings::setModelToolTemplate(const ModelInfo &info, const QString &valu setModelSetting("toolTemplate", info, value, force, true); } -void MySettings::setModelSystemPrompt(const ModelInfo &info, const QString &value, bool force) +void MySettings::setModelSystemPromptTemplate(const ModelInfo &info, const QString &value, bool force) { setModelSetting("systemPrompt", info, value, force, true); } diff --git a/gpt4all-chat/mysettings.h b/gpt4all-chat/mysettings.h index 205c21d03192..09da6681da03 100644 --- a/gpt4all-chat/mysettings.h +++ b/gpt4all-chat/mysettings.h @@ -128,8 +128,8 @@ class MySettings : public QObject Q_INVOKABLE void setModelPromptTemplate(const ModelInfo &info, const QString &value, bool force = false); QString modelToolTemplate(const ModelInfo &info) const; Q_INVOKABLE void setModelToolTemplate(const ModelInfo &info, const QString &value, bool force = false); - QString modelSystemPrompt(const ModelInfo &info) const; - Q_INVOKABLE void setModelSystemPrompt(const ModelInfo &info, const QString &value, bool force = false); + QString modelSystemPromptTemplate(const ModelInfo &info) const; + Q_INVOKABLE void setModelSystemPromptTemplate(const ModelInfo &info, const QString &value, bool 
force = false); int modelContextLength(const ModelInfo &info) const; Q_INVOKABLE void setModelContextLength(const ModelInfo &info, int value, bool force = false); int modelGpuLayers(const ModelInfo &info) const; diff --git a/gpt4all-chat/tool.cpp b/gpt4all-chat/tool.cpp index 7261f42bf170..e2d13dda9987 100644 --- a/gpt4all-chat/tool.cpp +++ b/gpt4all-chat/tool.cpp @@ -1 +1,31 @@ #include "tool.h" + +#include + +QJsonObject filterModelGeneratedProperties(const QJsonObject &inputObject) { + QJsonObject filteredObject; + for (const QString &key : inputObject.keys()) { + QJsonObject propertyObject = inputObject.value(key).toObject(); + if (!propertyObject.contains("modelGenerated") || propertyObject["modelGenerated"].toBool()) + filteredObject.insert(key, propertyObject); + } + return filteredObject; +} + +jinja2::Value Tool::jinjaValue() const +{ + QJsonDocument doc(filterModelGeneratedProperties(paramSchema())); + QString p(doc.toJson(QJsonDocument::Compact)); + + QJsonDocument exampleDoc(exampleParams()); + QString e(exampleDoc.toJson(QJsonDocument::Compact)); + + jinja2::ValuesMap params { + { "name", name().toStdString() }, + { "description", description().toStdString() }, + { "function", function().toStdString() }, + { "paramSchema", p.toStdString() }, + { "exampleParams", e.toStdString() } + }; + return params; +} diff --git a/gpt4all-chat/tool.h b/gpt4all-chat/tool.h index 7598af6885a0..6728f58762e3 100644 --- a/gpt4all-chat/tool.h +++ b/gpt4all-chat/tool.h @@ -3,6 +3,7 @@ #include #include +#include using namespace Qt::Literals::StringLiterals; @@ -18,15 +19,16 @@ namespace ToolEnums { class Tool : public QObject { Q_OBJECT - Q_PROPERTY(QString name MEMBER name) - Q_PROPERTY(QString description MEMBER description) - Q_PROPERTY(QString function MEMBER function) - Q_PROPERTY(QJsonObject paramSchema MEMBER paramSchema) - Q_PROPERTY(QUrl url MEMBER url) - Q_PROPERTY(bool isEnabled MEMBER isEnabled) - Q_PROPERTY(bool isBuiltin MEMBER isBuiltin) - Q_PROPERTY(bool 
forceUsage MEMBER forceUsage) - Q_PROPERTY(bool excerpts MEMBER excerpts) + Q_PROPERTY(QString name READ name CONSTANT) + Q_PROPERTY(QString description READ description CONSTANT) + Q_PROPERTY(QString function READ function CONSTANT) + Q_PROPERTY(QJsonObject paramSchema READ paramSchema CONSTANT) + Q_PROPERTY(QJsonObject exampleParams READ exampleParams CONSTANT) + Q_PROPERTY(QUrl url READ url CONSTANT) + Q_PROPERTY(bool isEnabled READ isEnabled NOTIFY isEnabledChanged) + Q_PROPERTY(bool isBuiltin READ isBuiltin CONSTANT) + Q_PROPERTY(bool forceUsage READ forceUsage NOTIFY forceUsageChanged) + Q_PROPERTY(bool excerpts READ excerpts CONSTANT) public: Tool() : QObject(nullptr) {} @@ -36,28 +38,54 @@ class Tool : public QObject { virtual ToolEnums::Error error() const { return ToolEnums::Error::NoError; } virtual QString errorString() const { return QString(); } - QString name; // [Required] Human readable name of the tool. - QString description; // [Required] Human readable description of the tool. - QString function; // [Required] Must be unique. Name of the function to invoke. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64. - QJsonObject paramSchema; // [Optional] Json schema describing the tool's parameters. An empty object specifies no parameters. - // https://json-schema.org/understanding-json-schema/ - QUrl url; // [Optional] The local file or remote resource use to invoke the tool. - bool isEnabled = false; // [Optional] Whether the tool is currently enabled - bool isBuiltin = false; // [Optional] Whether the tool is built-in - bool forceUsage = false; // [Optional] Whether we should attempt to force usage of the tool rather than let the LLM decide. NOTE: Not always possible. - bool excerpts = false; // [Optional] Whether json result produces source excerpts. - - // FIXME: Should we go with essentially the OpenAI/ollama consensus for these tool - // info files? 
If you install a tool in GPT4All should it need to meet the spec for these: + // [Required] Human readable name of the tool. + virtual QString name() const = 0; + + // [Required] Human readable description of the tool. + virtual QString description() const = 0; + + // [Required] Must be unique. Name of the function to invoke. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64. + virtual QString function() const = 0; + + // [Optional] Json schema describing the tool's parameters. An empty object specifies no parameters. + // https://json-schema.org/understanding-json-schema/ // https://platform.openai.com/docs/api-reference/runs/createRun#runs-createrun-tools // https://github.com/ollama/ollama/blob/main/docs/api.md#chat-request-with-tools + // FIXME: This should be validated against json schema + virtual QJsonObject paramSchema() const { return QJsonObject(); } + + // [Optional] An example of the parameters for this tool call. NOTE: This should only include parameters + // that the model is responsible for generating. + virtual QJsonObject exampleParams() const { return QJsonObject(); } + + // [Optional] The local file or remote resource use to invoke the tool. + virtual QUrl url() const { return QUrl(); } + + // [Optional] Whether the tool is currently enabled + virtual bool isEnabled() const { return false; } + + // [Optional] Whether the tool is built-in + virtual bool isBuiltin() const { return false; } + + // [Optional] Whether we should attempt to force usage of the tool rather than let the LLM decide. NOTE: Not always possible. + virtual bool forceUsage() const { return false; } + + // [Optional] Whether json result produces source excerpts. 
+ virtual bool excerpts() const { return false; } bool operator==(const Tool &other) const { - return function == other.function; + return function() == other.function(); } bool operator!=(const Tool &other) const { return !(*this == other); } + + jinja2::Value jinjaValue() const; + +Q_SIGNALS: + void isEnabledChanged(); + void forceUsageChanged(); + }; #endif // TOOL_H diff --git a/gpt4all-chat/toolmodel.cpp b/gpt4all-chat/toolmodel.cpp index 4a6d61c32f06..9f8547eaa2f0 100644 --- a/gpt4all-chat/toolmodel.cpp +++ b/gpt4all-chat/toolmodel.cpp @@ -20,79 +20,12 @@ ToolModel::ToolModel() QCoreApplication::instance()->installEventFilter(this); Tool* localDocsSearch = new LocalDocsSearch; - localDocsSearch->name = tr("LocalDocs search"); - localDocsSearch->description = tr("Search the local docs"); - localDocsSearch->function = "localdocs_search"; - localDocsSearch->isBuiltin = true; - localDocsSearch->excerpts = true; - localDocsSearch->forceUsage = true; // FIXME: persistent setting - localDocsSearch->isEnabled = true; // FIXME: persistent setting - - QString localParamSchema = R"({ - "collections": { - "type": "array", - "items": { - "type": "string" - }, - "description": "The collections to search", - "required": true, - "modelGenerated": false, - "userConfigured": false - }, - "query": { - "type": "string", - "description": "The query to search", - "required": true - }, - "count": { - "type": "integer", - "description": "The number of excerpts to return", - "required": true, - "modelGenerated": false - } - })"; - - QJsonDocument localJsonDoc = QJsonDocument::fromJson(localParamSchema.toUtf8()); - Q_ASSERT(!localJsonDoc.isNull() && localJsonDoc.isObject()); - localDocsSearch->paramSchema = localJsonDoc.object(); m_tools.append(localDocsSearch); - m_toolMap.insert(localDocsSearch->function, localDocsSearch); + m_toolMap.insert(localDocsSearch->function(), localDocsSearch); Tool *braveSearch = new BraveSearch; - braveSearch->name = tr("Brave web search"); - 
braveSearch->description = tr("Search the web using brave.com"); - braveSearch->function = "brave_search"; - braveSearch->isBuiltin = true; - braveSearch->excerpts = true; - braveSearch->forceUsage = false; // FIXME: persistent setting - braveSearch->isEnabled = false; // FIXME: persistent setting - - QString braveParamSchema = R"({ - "apiKey": { - "type": "string", - "description": "The api key to use", - "required": true, - "modelGenerated": false, - "userConfigured": true - }, - "query": { - "type": "string", - "description": "The query to search", - "required": true - }, - "count": { - "type": "integer", - "description": "The number of excerpts to return", - "required": true, - "modelGenerated": false - } - })"; - - QJsonDocument braveJsonDoc = QJsonDocument::fromJson(braveParamSchema.toUtf8()); - Q_ASSERT(!braveJsonDoc.isNull() && braveJsonDoc.isObject()); - braveSearch->paramSchema = braveJsonDoc.object(); m_tools.append(braveSearch); - m_toolMap.insert(braveSearch->function, braveSearch); + m_toolMap.insert(braveSearch->function(), braveSearch); } bool ToolModel::eventFilter(QObject *obj, QEvent *ev) diff --git a/gpt4all-chat/toolmodel.h b/gpt4all-chat/toolmodel.h index 3da557936a06..d31bb1bc7fbe 100644 --- a/gpt4all-chat/toolmodel.h +++ b/gpt4all-chat/toolmodel.h @@ -41,23 +41,23 @@ class ToolModel : public QAbstractListModel const Tool *item = m_tools.at(index.row()); switch (role) { case NameRole: - return item->name; + return item->name(); case DescriptionRole: - return item->description; + return item->description(); case FunctionRole: - return item->function; + return item->function(); case ParametersRole: - return item->paramSchema; + return item->paramSchema(); case UrlRole: - return item->url; + return item->url(); case IsEnabledRole: - return item->isEnabled; + return item->isEnabled(); case IsBuiltinRole: - return item->isBuiltin; + return item->isBuiltin(); case ForceUsageRole: - return item->forceUsage; + return item->forceUsage(); case 
ExcerptsRole: - return item->excerpts; + return item->excerpts(); } return QVariant(); From 587dd55b7377fee6697b27e4271f417a960c82d8 Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Mon, 12 Aug 2024 21:48:51 -0400 Subject: [PATCH 20/30] Get rid of the name change now that 3.2.0 has been released. Signed-off-by: Adam Treat --- gpt4all-chat/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt index 2b2fb4dbd1c2..de1cf1fb2802 100644 --- a/gpt4all-chat/CMakeLists.txt +++ b/gpt4all-chat/CMakeLists.txt @@ -20,7 +20,7 @@ set(APP_VERSION_MAJOR 3) set(APP_VERSION_MINOR 2) set(APP_VERSION_PATCH 2) set(APP_VERSION_BASE "${APP_VERSION_MAJOR}.${APP_VERSION_MINOR}.${APP_VERSION_PATCH}") -set(APP_VERSION "${APP_VERSION_BASE}-web_search_beta_3") +set(APP_VERSION "${APP_VERSION_BASE}-dev0") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/cmake/Modules") From 227dbfd18b74372600db3f7caf2ccbe1a880d50e Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Tue, 13 Aug 2024 12:27:59 -0400 Subject: [PATCH 21/30] Use an enum for tool usage mode. 
Signed-off-by: Adam Treat --- gpt4all-chat/bravesearch.cpp | 12 +++--------- gpt4all-chat/bravesearch.h | 3 +-- gpt4all-chat/chatllm.cpp | 4 +++- gpt4all-chat/localdocssearch.h | 3 +-- gpt4all-chat/qml/ModelSettings.qml | 8 ++++---- gpt4all-chat/tool.h | 22 ++++++++++++---------- gpt4all-chat/toolmodel.h | 12 ++++-------- 7 files changed, 28 insertions(+), 36 deletions(-) diff --git a/gpt4all-chat/bravesearch.cpp b/gpt4all-chat/bravesearch.cpp index 2b8c20ce149d..197ed383af5c 100644 --- a/gpt4all-chat/bravesearch.cpp +++ b/gpt4all-chat/bravesearch.cpp @@ -80,16 +80,10 @@ QJsonObject BraveSearch::exampleParams() const return exampleDoc.object(); } -bool BraveSearch::isEnabled() const +ToolEnums::UsageMode BraveSearch::usageMode() const { - // FIXME: Refer to mysettings - return true; -} - -bool BraveSearch::forceUsage() const -{ - // FIXME: Refer to mysettings - return false; + // FIXME: This needs to be a setting + return ToolEnums::UsageMode::Enabled; } void BraveAPIWorker::request(const QString &apiKey, const QString &query, int count) diff --git a/gpt4all-chat/bravesearch.h b/gpt4all-chat/bravesearch.h index a1817412c7f3..45cd0a6e9333 100644 --- a/gpt4all-chat/bravesearch.h +++ b/gpt4all-chat/bravesearch.h @@ -53,9 +53,8 @@ class BraveSearch : public Tool { QString function() const override { return "brave_search"; } QJsonObject paramSchema() const override; QJsonObject exampleParams() const override; - bool isEnabled() const override; bool isBuiltin() const override { return true; } - bool forceUsage() const override; + ToolEnums::UsageMode usageMode() const override; bool excerpts() const override { return true; } private: diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp index 7d42dbf57652..5a0e03ffb082 100644 --- a/gpt4all-chat/chatllm.cpp +++ b/gpt4all-chat/chatllm.cpp @@ -766,6 +766,8 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString if (!isModelLoaded()) return false; + // FIXME: This should be made agnostic to 
localdocs and rely upon the force usage usage mode + // and also we have to honor the ask before running mode. QList localDocsExcerpts; if (!collectionList.isEmpty() && !isToolCallResponse) { LocalDocsSearch localdocs; @@ -1345,7 +1347,7 @@ void ChatLLM::processSystemPrompt() int c = ToolModel::globalInstance()->count(); for (int i = 0; i < c; ++i) { Tool *t = ToolModel::globalInstance()->get(i); - if (t->isEnabled() && !t->forceUsage()) + if (t->usageMode() == ToolEnums::UsageMode::Enabled) toolList.push_back(t->jinjaValue()); } params.insert({"toolList", toolList}); diff --git a/gpt4all-chat/localdocssearch.h b/gpt4all-chat/localdocssearch.h index 23f64215401d..66c63414521c 100644 --- a/gpt4all-chat/localdocssearch.h +++ b/gpt4all-chat/localdocssearch.h @@ -38,9 +38,8 @@ class LocalDocsSearch : public Tool { QString description() const override { return tr("Search the local docs"); } QString function() const override { return "localdocs_search"; } QJsonObject paramSchema() const override; - bool isEnabled() const override { return true; } bool isBuiltin() const override { return true; } - bool forceUsage() const override { return true; } + ToolEnums::UsageMode usageMode() const override { return ToolEnums::UsageMode::ForceUsage; } bool excerpts() const override { return true; } private: diff --git a/gpt4all-chat/qml/ModelSettings.qml b/gpt4all-chat/qml/ModelSettings.qml index 9948934988f8..d2f47976f6ba 100644 --- a/gpt4all-chat/qml/ModelSettings.qml +++ b/gpt4all-chat/qml/ModelSettings.qml @@ -174,21 +174,21 @@ MySettingsTab { MyTextArea { id: systemPromptArea anchors.fill: parent - text: root.currentModelInfo.systemPrompt + text: root.currentModelInfo.systemPromptTemplate Connections { target: MySettings function onSystemPromptChanged() { - systemPromptArea.text = root.currentModelInfo.systemPrompt; + systemPromptArea.text = root.currentModelInfo.systemPromptTemplate; } } Connections { target: root function onCurrentModelInfoChanged() { - systemPromptArea.text = 
root.currentModelInfo.systemPrompt; + systemPromptArea.text = root.currentModelInfo.systemPromptTemplate; } } onTextChanged: { - MySettings.setModelSystemPrompt(root.currentModelInfo, text) + MySettings.setModelSystemPromptTemplate(root.currentModelInfo, text) } Accessible.role: Accessible.EditableText } diff --git a/gpt4all-chat/tool.h b/gpt4all-chat/tool.h index 6728f58762e3..92a3b758efbc 100644 --- a/gpt4all-chat/tool.h +++ b/gpt4all-chat/tool.h @@ -15,6 +15,14 @@ namespace ToolEnums { UnknownError = 499, }; Q_ENUM_NS(Error) + + enum class UsageMode { + Disabled, // Completely disabled + Enabled, // Enabled and the model decides whether to run + AskBeforeRunning, // Enabled and model decides but the user is queried whether they want the tool to run in every instance + ForceUsage, // Attempt to force usage of the tool rather than let the LLM decide. NOTE: Not always possible. + }; + Q_ENUM_NS(UsageMode) } class Tool : public QObject { @@ -25,9 +33,8 @@ class Tool : public QObject { Q_PROPERTY(QJsonObject paramSchema READ paramSchema CONSTANT) Q_PROPERTY(QJsonObject exampleParams READ exampleParams CONSTANT) Q_PROPERTY(QUrl url READ url CONSTANT) - Q_PROPERTY(bool isEnabled READ isEnabled NOTIFY isEnabledChanged) Q_PROPERTY(bool isBuiltin READ isBuiltin CONSTANT) - Q_PROPERTY(bool forceUsage READ forceUsage NOTIFY forceUsageChanged) + Q_PROPERTY(ToolEnums::UsageMode usageMode READ usageMode NOTIFY usageModeChanged) Q_PROPERTY(bool excerpts READ excerpts CONSTANT) public: @@ -61,14 +68,11 @@ class Tool : public QObject { // [Optional] The local file or remote resource use to invoke the tool. virtual QUrl url() const { return QUrl(); } - // [Optional] Whether the tool is currently enabled - virtual bool isEnabled() const { return false; } - // [Optional] Whether the tool is built-in virtual bool isBuiltin() const { return false; } - // [Optional] Whether we should attempt to force usage of the tool rather than let the LLM decide. NOTE: Not always possible. 
- virtual bool forceUsage() const { return false; } + // [Optional] The current usage mode + virtual ToolEnums::UsageMode usageMode() const { return ToolEnums::UsageMode::Disabled; } // [Optional] Whether json result produces source excerpts. virtual bool excerpts() const { return false; } @@ -83,9 +87,7 @@ class Tool : public QObject { jinja2::Value jinjaValue() const; Q_SIGNALS: - void isEnabledChanged(); - void forceUsageChanged(); - + void usageModeChanged(); }; #endif // TOOL_H diff --git a/gpt4all-chat/toolmodel.h b/gpt4all-chat/toolmodel.h index d31bb1bc7fbe..f1b599229fef 100644 --- a/gpt4all-chat/toolmodel.h +++ b/gpt4all-chat/toolmodel.h @@ -21,9 +21,8 @@ class ToolModel : public QAbstractListModel UrlRole, ApiKeyRole, KeyRequiredRole, - IsEnabledRole, IsBuiltinRole, - ForceUsageRole, + UsageModeRole, ExcerptsRole, }; @@ -50,12 +49,10 @@ class ToolModel : public QAbstractListModel return item->paramSchema(); case UrlRole: return item->url(); - case IsEnabledRole: - return item->isEnabled(); case IsBuiltinRole: return item->isBuiltin(); - case ForceUsageRole: - return item->forceUsage(); + case UsageModeRole: + return QVariant::fromValue(item->usageMode()); case ExcerptsRole: return item->excerpts(); } @@ -73,9 +70,8 @@ class ToolModel : public QAbstractListModel roles[UrlRole] = "url"; roles[ApiKeyRole] = "apiKey"; roles[KeyRequiredRole] = "keyRequired"; - roles[IsEnabledRole] = "isEnabled"; roles[IsBuiltinRole] = "isBuiltin"; - roles[ForceUsageRole] = "forceUsage"; + roles[UsageModeRole] = "usageMode"; roles[ExcerptsRole] = "excerpts"; return roles; } From 48117cda46ceb75c4ef0f1871b180eef8797829a Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Tue, 13 Aug 2024 13:23:12 -0400 Subject: [PATCH 22/30] Move the jinja processing to mysettings and validation. 
Signed-off-by: Adam Treat --- gpt4all-chat/chatllm.cpp | 31 ++-------- gpt4all-chat/mysettings.cpp | 45 ++++++++++++++ gpt4all-chat/mysettings.h | 6 ++ gpt4all-chat/qml/ModelSettings.qml | 35 +++++++++-- gpt4all-chat/toolinfo.h | 95 ------------------------------ 5 files changed, 86 insertions(+), 126 deletions(-) delete mode 100644 gpt4all-chat/toolinfo.h diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp index 5a0e03ffb082..00bf86a715f4 100644 --- a/gpt4all-chat/chatllm.cpp +++ b/gpt4all-chat/chatllm.cpp @@ -31,7 +31,6 @@ #include #include #include -#include #include #include #include @@ -1337,34 +1336,14 @@ void ChatLLM::processSystemPrompt() if (!isModelLoaded() || m_processedSystemPrompt || m_restoreStateFromText || m_isServer) return; - const std::string systemPromptTemplate = MySettings::globalInstance()->modelSystemPromptTemplate(m_modelInfo).toStdString(); - - // FIXME: This needs to be moved to settings probably and the same code used for validation - jinja2::ValuesMap params; - params.insert({"currentDate", QDate::currentDate().toString().toStdString()}); - - jinja2::ValuesList toolList; - int c = ToolModel::globalInstance()->count(); - for (int i = 0; i < c; ++i) { - Tool *t = ToolModel::globalInstance()->get(i); - if (t->usageMode() == ToolEnums::UsageMode::Enabled) - toolList.push_back(t->jinjaValue()); - } - params.insert({"toolList", toolList}); - - std::string systemPrompt; - - jinja2::Template t; - t.Load(systemPromptTemplate); - const auto renderResult = t.RenderAsString(params); + QString error; + const std::string systemPrompt = MySettings::globalInstance()->modelSystemPrompt(m_modelInfo, error).toStdString(); // The GUI should not allow setting an improper template, but it is always possible someone hand // edits the settings file to produce an improper one. 
- Q_ASSERT(renderResult); - if (renderResult) - systemPrompt = renderResult.value(); - else - qWarning() << "ERROR: Could not parse system prompt template:" << renderResult.error().ToString(); + Q_ASSERT(error.isEmpty()); + if (!error.isEmpty()) + qWarning() << "ERROR: Could not parse system prompt template:" << error; if (QString::fromStdString(systemPrompt).trimmed().isEmpty()) { m_processedSystemPrompt = true; diff --git a/gpt4all-chat/mysettings.cpp b/gpt4all-chat/mysettings.cpp index b9bb76b546c2..0d6429792f4e 100644 --- a/gpt4all-chat/mysettings.cpp +++ b/gpt4all-chat/mysettings.cpp @@ -1,6 +1,8 @@ #include "mysettings.h" #include "../gpt4all-backend/llmodel.h" +#include "tool.h" +#include "toolmodel.h" #include #include @@ -18,6 +20,7 @@ #include #include +#include #include #include #include @@ -676,3 +679,45 @@ void MySettings::setLanguageAndLocale(const QString &bcp47Name) QLocale::setDefault(locale); emit languageAndLocaleChanged(); } + +QString MySettings::validateModelSystemPromptTemplate(const QString &proposedTemplate) +{ + QString error; + systemPromptInternal(proposedTemplate, error); + return error; +} + +QString MySettings::modelSystemPrompt(const ModelInfo &info, QString &error) +{ + return systemPromptInternal(modelSystemPromptTemplate(info), error); +} + +QString MySettings::systemPromptInternal(const QString &proposedTemplate, QString &error) +{ + jinja2::ValuesMap params; + params.insert({"currentDate", QDate::currentDate().toString().toStdString()}); + + jinja2::ValuesList toolList; + int c = ToolModel::globalInstance()->count(); + for (int i = 0; i < c; ++i) { + Tool *t = ToolModel::globalInstance()->get(i); + if (t->usageMode() == ToolEnums::UsageMode::Enabled) + toolList.push_back(t->jinjaValue()); + } + params.insert({"toolList", toolList}); + + QString systemPrompt; + jinja2::Template t; + const auto loadResult = t.Load(proposedTemplate.toStdString(), "systemPromptTemplate" /*Used in error messages*/); + if (!loadResult) { + error = 
QString::fromStdString(loadResult.error().ToString()); + return systemPrompt; + } + + const auto renderResult = t.RenderAsString(params); + if (renderResult) + systemPrompt = QString::fromStdString(renderResult.value()); + else + error = QString::fromStdString(renderResult.error().ToString()); + return systemPrompt; +} diff --git a/gpt4all-chat/mysettings.h b/gpt4all-chat/mysettings.h index 09da6681da03..6e1e8056b06a 100644 --- a/gpt4all-chat/mysettings.h +++ b/gpt4all-chat/mysettings.h @@ -204,6 +204,10 @@ class MySettings : public QObject int networkPort() const; void setNetworkPort(int value); + // Jinja aware methods for validating and parsing/rendering the system prompt + Q_INVOKABLE QString validateModelSystemPromptTemplate(const QString &proposedTemplate); + QString modelSystemPrompt(const ModelInfo &info, QString &error); + Q_SIGNALS: void nameChanged(const ModelInfo &info); void filenameChanged(const ModelInfo &info); @@ -269,6 +273,8 @@ class MySettings : public QObject void setModelSetting(const QString &name, const ModelInfo &info, const QVariant &value, bool force, bool signal = false); QString filePathForLocale(const QLocale &locale); + QString systemPromptInternal(const QString &proposedTemplate, QString &error); + }; #endif // MYSETTINGS_H diff --git a/gpt4all-chat/qml/ModelSettings.qml b/gpt4all-chat/qml/ModelSettings.qml index d2f47976f6ba..1f96807d397f 100644 --- a/gpt4all-chat/qml/ModelSettings.qml +++ b/gpt4all-chat/qml/ModelSettings.qml @@ -153,13 +153,30 @@ MySettingsTab { Layout.fillWidth: true } - MySettingsLabel { - visible: !root.currentModelInfo.isOnline - text: qsTr("System Prompt") - helpText: qsTr("Prefixed at the beginning of every conversation. Must contain the appropriate framing tokens.") + RowLayout { Layout.row: 7 Layout.column: 0 + Layout.columnSpan: 2 Layout.topMargin: 15 + spacing: 10 + MySettingsLabel { + text: qsTr("System Prompt Template") + helpText: qsTr("Prefixed at the beginning of every conversation. 
Must contain the appropriate framing tokens.") + } + MySettingsLabel { + id: systemPromptTemplateError + color: theme.textErrorColor + wrapMode: TextArea.Wrap + Timer { + id: errorTimer + interval: 500 // 500 ms delay + repeat: false + property string text: "" + onTriggered: { + systemPromptTemplateError.text = errorTimer.text; + } + } + } } Rectangle { @@ -188,7 +205,15 @@ MySettingsTab { } } onTextChanged: { - MySettings.setModelSystemPromptTemplate(root.currentModelInfo, text) + var errorString = MySettings.validateModelSystemPromptTemplate(text); + if (errorString === "") { + errorTimer.stop(); + systemPromptTemplateError.text = ""; // Clear any previous error + MySettings.setModelSystemPromptTemplate(root.currentModelInfo, text); + } else { + errorTimer.text = errorString; + errorTimer.restart(); + } } Accessible.role: Accessible.EditableText } diff --git a/gpt4all-chat/toolinfo.h b/gpt4all-chat/toolinfo.h deleted file mode 100644 index 91497e9daf17..000000000000 --- a/gpt4all-chat/toolinfo.h +++ /dev/null @@ -1,95 +0,0 @@ -#ifndef SOURCEEXCERT_H -#define SOURCEEXCERT_H - -#include -#include -#include -#include - -using namespace Qt::Literals::StringLiterals; - -struct SourceExcerpt { - Q_GADGET - Q_PROPERTY(QString date MEMBER date) - Q_PROPERTY(QString text MEMBER text) - Q_PROPERTY(QString collection MEMBER collection) - Q_PROPERTY(QString path MEMBER path) - Q_PROPERTY(QString file MEMBER file) - Q_PROPERTY(QString url MEMBER url) - Q_PROPERTY(QString favicon MEMBER favicon) - Q_PROPERTY(QString title MEMBER title) - Q_PROPERTY(QString author MEMBER author) - Q_PROPERTY(int page MEMBER page) - Q_PROPERTY(int from MEMBER from) - Q_PROPERTY(int to MEMBER to) - Q_PROPERTY(QString fileUri READ fileUri STORED false) - -public: - QString date; // [Required] The creation or the last modification date whichever is latest - QString text; // [Required] The text actually used in the augmented context - QString collection; // [Optional] The name of the collection - 
QString path; // [Optional] The full path - QString file; // [Optional] The name of the file, but not the full path - QString url; // [Optional] The name of the remote url - QString favicon; // [Optional] The favicon - QString title; // [Optional] The title of the document - QString author; // [Optional] The author of the document - int page = -1; // [Optional] The page where the text was found - int from = -1; // [Optional] The line number where the text begins - int to = -1; // [Optional] The line number where the text ends - - QString fileUri() const { - // QUrl reserved chars that are not UNSAFE_PATH according to glib/gconvert.c - static const QByteArray s_exclude = "!$&'()*+,/:=@~"_ba; - - Q_ASSERT(!QFileInfo(path).isRelative()); -#ifdef Q_OS_WINDOWS - Q_ASSERT(!path.contains('\\')); // Qt normally uses forward slash as path separator -#endif - - auto escaped = QString::fromUtf8(QUrl::toPercentEncoding(path, s_exclude)); - if (escaped.front() != '/') - escaped = '/' + escaped; - return u"file://"_s + escaped; - } - - QJsonObject toJson() const - { - QJsonObject result; - result.insert("date", date); - result.insert("text", text); - result.insert("collection", collection); - result.insert("path", path); - result.insert("file", file); - result.insert("url", url); - result.insert("favicon", favicon); - result.insert("title", title); - result.insert("author", author); - result.insert("page", page); - result.insert("from", from); - result.insert("to", to); - return result; - } - - bool operator==(const SourceExcerpt &other) const { - return date == other.date && - text == other.text && - collection == other.collection && - path == other.path && - file == other.file && - url == other.url && - favicon == other.favicon && - title == other.title && - author == other.author && - page == other.page && - from == other.from && - to == other.to; - } - bool operator!=(const SourceExcerpt &other) const { - return !(*this == other); - } -}; - -Q_DECLARE_METATYPE(SourceExcerpt) 
- -#endif // SOURCEEXCERT_H From a6730879dde3f4d84d4913b767ed0b22829c0d3b Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Tue, 13 Aug 2024 14:48:34 -0400 Subject: [PATCH 23/30] Move to a brave search specific settings page. Signed-off-by: Adam Treat --- gpt4all-chat/CMakeLists.txt | 2 +- gpt4all-chat/bravesearch.h | 2 +- gpt4all-chat/localdocssearch.h | 2 +- ...olSettings.qml => BraveSearchSettings.qml} | 34 +++++++++++++++++-- gpt4all-chat/qml/SettingsView.qml | 4 +-- 5 files changed, 36 insertions(+), 8 deletions(-) rename gpt4all-chat/qml/{ToolSettings.qml => BraveSearchSettings.qml} (59%) diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt index de1cf1fb2802..f41afbd737bd 100644 --- a/gpt4all-chat/CMakeLists.txt +++ b/gpt4all-chat/CMakeLists.txt @@ -138,6 +138,7 @@ qt_add_qml_module(chat qml/AddCollectionView.qml qml/AddModelView.qml qml/ApplicationSettings.qml + qml/BraveSearchSettings.qml qml/ChatDrawer.qml qml/ChatView.qml qml/CollectionsDrawer.qml @@ -156,7 +157,6 @@ qt_add_qml_module(chat qml/ThumbsDownDialog.qml qml/Toast.qml qml/ToastManager.qml - qml/ToolSettings.qml qml/MyBusyIndicator.qml qml/MyButton.qml qml/MyCheckBox.qml diff --git a/gpt4all-chat/bravesearch.h b/gpt4all-chat/bravesearch.h index 45cd0a6e9333..7450b7c8d6f1 100644 --- a/gpt4all-chat/bravesearch.h +++ b/gpt4all-chat/bravesearch.h @@ -48,7 +48,7 @@ class BraveSearch : public Tool { ToolEnums::Error error() const override { return m_error; } QString errorString() const override { return m_errorString; } - QString name() const override { return tr("Brave web search"); } + QString name() const override { return tr("Brave Web Search"); } QString description() const override { return tr("Search the web using brave"); } QString function() const override { return "brave_search"; } QJsonObject paramSchema() const override; diff --git a/gpt4all-chat/localdocssearch.h b/gpt4all-chat/localdocssearch.h index 66c63414521c..b4ea3e94b0a0 100644 --- a/gpt4all-chat/localdocssearch.h 
+++ b/gpt4all-chat/localdocssearch.h @@ -34,7 +34,7 @@ class LocalDocsSearch : public Tool { ToolEnums::Error error() const override { return m_error; } QString errorString() const override { return m_errorString; } - QString name() const override { return tr("LocalDocs search"); } + QString name() const override { return tr("LocalDocs Search"); } QString description() const override { return tr("Search the local docs"); } QString function() const override { return "localdocs_search"; } QJsonObject paramSchema() const override; diff --git a/gpt4all-chat/qml/ToolSettings.qml b/gpt4all-chat/qml/BraveSearchSettings.qml similarity index 59% rename from gpt4all-chat/qml/ToolSettings.qml rename to gpt4all-chat/qml/BraveSearchSettings.qml index f9b5e7277662..bc189248ea38 100644 --- a/gpt4all-chat/qml/ToolSettings.qml +++ b/gpt4all-chat/qml/BraveSearchSettings.qml @@ -16,7 +16,7 @@ MySettingsTab { showRestoreDefaultsButton: true - title: qsTr("Tools") + title: qsTr("Brave Web Search") contentItem: ColumnLayout { id: root spacing: 30 @@ -27,7 +27,7 @@ MySettingsTab { color: theme.grayRed900 font.pixelSize: theme.fontSizeLarge font.bold: true - text: qsTr("Brave Search") + text: qsTr("Brave Web Search") } Rectangle { @@ -37,6 +37,33 @@ MySettingsTab { } } + RowLayout { + MySettingsLabel { + id: usageModeLabel + text: qsTr("Usage Mode") + helpText: qsTr("When and how the brave search tool is executed.") + } + MyComboBox { + id: usageModeBox + Layout.minimumWidth: 400 + Layout.maximumWidth: 400 + Layout.alignment: Qt.AlignRight + // NOTE: indices match values of UsageMode enum, keep them in sync + model: ListModel { + ListElement { name: qsTr("Never") } + ListElement { name: qsTr("Model decides") } + ListElement { name: qsTr("Ask for confirmation before executing") } + ListElement { name: qsTr("Force usage for every response when possible") } + } + Accessible.name: usageModeLabel.text + Accessible.description: usageModeLabel.helpText + onActivated: { + } + 
Component.onCompleted: { + } + } + } + RowLayout { MySettingsLabel { id: apiKeyLabel @@ -51,7 +78,8 @@ MySettingsTab { color: theme.textColor font.pixelSize: theme.fontSizeLarge Layout.alignment: Qt.AlignRight - Layout.minimumWidth: 200 + Layout.minimumWidth: 400 + Layout.maximumWidth: 400 onEditingFinished: { MySettings.braveSearchAPIKey = apiKeyField.text; } diff --git a/gpt4all-chat/qml/SettingsView.qml b/gpt4all-chat/qml/SettingsView.qml index d421dad2f9ed..587a31b2901c 100644 --- a/gpt4all-chat/qml/SettingsView.qml +++ b/gpt4all-chat/qml/SettingsView.qml @@ -35,7 +35,7 @@ Rectangle { title: qsTr("LocalDocs") } ListElement { - title: qsTr("Tools") + title: qsTr("Brave Web Search") } } @@ -158,7 +158,7 @@ Rectangle { MySettingsStack { tabs: [ - Component { ToolSettings { } } + Component { BraveSearchSettings { } } ] } } From f1187207171f3034454c7fef818ee2fbd03722be Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Wed, 14 Aug 2024 07:58:31 -0400 Subject: [PATCH 24/30] Don't advertise brave. 
Signed-off-by: Adam Treat --- gpt4all-chat/CMakeLists.txt | 2 +- gpt4all-chat/bravesearch.h | 6 +++--- gpt4all-chat/chatllm.cpp | 2 +- gpt4all-chat/qml/SettingsView.qml | 4 ++-- .../qml/{BraveSearchSettings.qml => WebSearchSettings.qml} | 4 ++-- 5 files changed, 9 insertions(+), 9 deletions(-) rename gpt4all-chat/qml/{BraveSearchSettings.qml => WebSearchSettings.qml} (97%) diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt index f41afbd737bd..817497a5cf16 100644 --- a/gpt4all-chat/CMakeLists.txt +++ b/gpt4all-chat/CMakeLists.txt @@ -138,7 +138,6 @@ qt_add_qml_module(chat qml/AddCollectionView.qml qml/AddModelView.qml qml/ApplicationSettings.qml - qml/BraveSearchSettings.qml qml/ChatDrawer.qml qml/ChatView.qml qml/CollectionsDrawer.qml @@ -178,6 +177,7 @@ qt_add_qml_module(chat qml/MyTextField.qml qml/MyToolButton.qml qml/MyWelcomeButton.qml + qml/WebSearchSettings.qml RESOURCES icons/antenna_1.svg icons/antenna_2.svg diff --git a/gpt4all-chat/bravesearch.h b/gpt4all-chat/bravesearch.h index 7450b7c8d6f1..872e230a95e3 100644 --- a/gpt4all-chat/bravesearch.h +++ b/gpt4all-chat/bravesearch.h @@ -48,9 +48,9 @@ class BraveSearch : public Tool { ToolEnums::Error error() const override { return m_error; } QString errorString() const override { return m_errorString; } - QString name() const override { return tr("Brave Web Search"); } - QString description() const override { return tr("Search the web using brave"); } - QString function() const override { return "brave_search"; } + QString name() const override { return tr("Web Search"); } + QString description() const override { return tr("Search the web"); } + QString function() const override { return "web_search"; } QJsonObject paramSchema() const override; QJsonObject exampleParams() const override; bool isBuiltin() const override { return true; } diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp index 00bf86a715f4..3500c858669e 100644 --- a/gpt4all-chat/chatllm.cpp +++ 
b/gpt4all-chat/chatllm.cpp @@ -874,7 +874,7 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString // FIXME: In the future this will try to match the tool call to a list of tools that are supported // according to MySettings, but for now only brave search is supported - if (tool != "brave_search" || !args.contains("query")) { + if (tool != "web_search" || !args.contains("query")) { // FIXME: Need to surface errors to the UI qWarning() << "ERROR: Could not find the tool and correct parameters for " << toolCall; return handleFailedToolCall(trimmed, elapsed); diff --git a/gpt4all-chat/qml/SettingsView.qml b/gpt4all-chat/qml/SettingsView.qml index 587a31b2901c..9a5093134c9e 100644 --- a/gpt4all-chat/qml/SettingsView.qml +++ b/gpt4all-chat/qml/SettingsView.qml @@ -35,7 +35,7 @@ Rectangle { title: qsTr("LocalDocs") } ListElement { - title: qsTr("Brave Web Search") + title: qsTr("Web Search") } } @@ -158,7 +158,7 @@ Rectangle { MySettingsStack { tabs: [ - Component { BraveSearchSettings { } } + Component { WebSearchSettings { } } ] } } diff --git a/gpt4all-chat/qml/BraveSearchSettings.qml b/gpt4all-chat/qml/WebSearchSettings.qml similarity index 97% rename from gpt4all-chat/qml/BraveSearchSettings.qml rename to gpt4all-chat/qml/WebSearchSettings.qml index bc189248ea38..222f9e30ee1d 100644 --- a/gpt4all-chat/qml/BraveSearchSettings.qml +++ b/gpt4all-chat/qml/WebSearchSettings.qml @@ -16,7 +16,7 @@ MySettingsTab { showRestoreDefaultsButton: true - title: qsTr("Brave Web Search") + title: qsTr("Web Search") contentItem: ColumnLayout { id: root spacing: 30 @@ -27,7 +27,7 @@ MySettingsTab { color: theme.grayRed900 font.pixelSize: theme.fontSizeLarge font.bold: true - text: qsTr("Brave Web Search") + text: qsTr("Web Search") } Rectangle { From 75dbf9de7d84dd8ee480844eee676cf181c85b51 Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Wed, 14 Aug 2024 09:38:53 -0400 Subject: [PATCH 25/30] Handle the forced usage of tool calls outside of the recursive prompt 
method. Signed-off-by: Adam Treat --- gpt4all-chat/chatllm.cpp | 79 ++++++++++++++++++++++++---------------- gpt4all-chat/chatllm.h | 6 ++- 2 files changed, 52 insertions(+), 33 deletions(-) diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp index 3500c858669e..7450ab1fc402 100644 --- a/gpt4all-chat/chatllm.cpp +++ b/gpt4all-chat/chatllm.cpp @@ -759,16 +759,13 @@ bool ChatLLM::prompt(const QList &collectionList, const QString &prompt } bool ChatLLM::promptInternal(const QList &collectionList, const QString &prompt, const QString &promptTemplate, - int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty, - int32_t repeat_penalty_tokens, bool isToolCallResponse) + int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty, + int32_t repeat_penalty_tokens) { - if (!isModelLoaded()) - return false; - - // FIXME: This should be made agnostic to localdocs and rely upon the force usage usage mode - // and also we have to honor the ask before running mode. + // FIXME: The only localdocs specific thing here should be the injection of the parameters + // FIXME: Get the list of tools ... if force usage is set, then we *try* and force usage here. 
QList localDocsExcerpts; - if (!collectionList.isEmpty() && !isToolCallResponse) { + if (!collectionList.isEmpty()) { LocalDocsSearch localdocs; QJsonObject parameters; parameters.insert("text", prompt); @@ -795,6 +792,27 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString docsContext = u"### Context:\n%1\n\n"_s.arg(json); } + qint64 totalTime = 0; + bool producedSourceExcerpts; + bool success = promptRecursive({ docsContext }, prompt, promptTemplate, n_predict, top_k, top_p, + min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, totalTime, producedSourceExcerpts); + + SuggestionMode mode = MySettings::globalInstance()->suggestionMode(); + if (mode == SuggestionMode::On || (mode == SuggestionMode::SourceExcerptsOnly && (!localDocsExcerpts.isEmpty() || producedSourceExcerpts))) + generateQuestions(totalTime); + else + emit responseStopped(totalTime); + + return success; +} + +bool ChatLLM::promptRecursive(const QList &toolContexts, const QString &prompt, + const QString &promptTemplate, int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, + int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens, qint64 &totalTime, bool &producedSourceExcerpts, bool isRecursiveCall) +{ + if (!isModelLoaded()) + return false; + int n_threads = MySettings::globalInstance()->threadCount(); m_stopGenerating = false; @@ -815,19 +833,22 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString printf("%s", qPrintable(prompt)); fflush(stdout); #endif - QElapsedTimer totalTime; - totalTime.start(); + + QElapsedTimer elapsedTimer; + elapsedTimer.start(); m_timer->start(); - if (!docsContext.isEmpty()) { - auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode localdocs context without a response - m_llModelInfo.model->prompt(docsContext.toStdString(), "%1", promptFunc, responseFunc, + + // The list of possible additional contexts that come from previous usage of tool calls + for (const QString &context 
: toolContexts) { + auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode context without a response + m_llModelInfo.model->prompt(context.toStdString(), "%1", promptFunc, responseFunc, /*allowContextShift*/ true, m_ctx); m_ctx.n_predict = old_n_predict; // now we are ready for a response } - // We can't handle recursive tool calls right now otherwise we always try to check if we have a - // tool call - m_checkToolCall = !isToolCallResponse; + // We can't handle recursive tool calls right now due to the possibility of the model causing + // infinite recursion through repeated tool calls + m_checkToolCall = !isRecursiveCall; m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc, /*allowContextShift*/ true, m_ctx); @@ -841,7 +862,7 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString fflush(stdout); #endif m_timer->stop(); - qint64 elapsed = totalTime.elapsed(); + totalTime = elapsedTimer.elapsed(); std::string trimmed = trim_whitespace(m_response); // If we found a tool call, then deal with it @@ -852,7 +873,7 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString const QString toolTemplate = MySettings::globalInstance()->modelToolTemplate(m_modelInfo); if (toolTemplate.isEmpty()) { qWarning() << "ERROR: No valid tool template for this model" << toolCall; - return handleFailedToolCall(trimmed, elapsed); + return handleFailedToolCall(trimmed, totalTime); } QJsonParseError err; @@ -860,13 +881,13 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString if (toolCallDoc.isNull() || err.error != QJsonParseError::NoError || !toolCallDoc.isObject()) { qWarning() << "ERROR: The tool call had null or invalid json " << toolCall; - return handleFailedToolCall(trimmed, elapsed); + return handleFailedToolCall(trimmed, totalTime); } QJsonObject rootObject = toolCallDoc.object(); if (!rootObject.contains("name") || !rootObject.contains("parameters")) { 
qWarning() << "ERROR: The tool call did not have required name and argument objects " << toolCall; - return handleFailedToolCall(trimmed, elapsed); + return handleFailedToolCall(trimmed, totalTime); } const QString tool = toolCallDoc["name"].toString(); @@ -877,7 +898,7 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString if (tool != "web_search" || !args.contains("query")) { // FIXME: Need to surface errors to the UI qWarning() << "ERROR: Could not find the tool and correct parameters for " << toolCall; - return handleFailedToolCall(trimmed, elapsed); + return handleFailedToolCall(trimmed, totalTime); } const QString query = args["query"].toString(); @@ -900,6 +921,7 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString if (!parseError.isEmpty()) { qWarning() << "ERROR: Could not parse source excerpts for brave response:" << parseError; } else if (!sourceExcerpts.isEmpty()) { + producedSourceExcerpts = true; emit sourceExcerptsChanged(sourceExcerpts); } @@ -907,23 +929,16 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString m_promptTokens = 0; m_response = std::string(); - // This is a recursive call but isToolCallResponse is checked above to arrest infinite recursive + // This is a recursive call but isRecursiveCall is checked above to arrest infinite recursive // tool calls - return promptInternal(QList()/*collectionList*/, braveResponse, toolTemplate, - n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, - true /*isToolCallResponse*/); - + return promptRecursive(QList()/*collectionList*/, braveResponse, toolTemplate, + n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, totalTime, + producedSourceExcerpts, true /*isRecursiveCall*/); } else { if (trimmed != m_response) { m_response = trimmed; emit responseChanged(QString::fromStdString(m_response)); } - - SuggestionMode mode = MySettings::globalInstance()->suggestionMode(); - if 
(mode == SuggestionMode::On || (mode == SuggestionMode::SourceExcerptsOnly && (!localDocsExcerpts.isEmpty() || isToolCallResponse))) - generateQuestions(elapsed); - else - emit responseStopped(elapsed); m_pristineLoadedState = false; return true; } diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/chatllm.h index feacd744f228..7622c3661446 100644 --- a/gpt4all-chat/chatllm.h +++ b/gpt4all-chat/chatllm.h @@ -196,9 +196,10 @@ public Q_SLOTS: void modelInfoChanged(const ModelInfo &modelInfo); protected: + // FIXME: This is only available because of server which sucks bool promptInternal(const QList &collectionList, const QString &prompt, const QString &promptTemplate, int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty, - int32_t repeat_penalty_tokens, bool isToolCallResponse = false); + int32_t repeat_penalty_tokens); bool handleFailedToolCall(const std::string &toolCall, qint64 elapsed); bool handlePrompt(int32_t token); bool handleResponse(int32_t token, const std::string &response); @@ -219,6 +220,9 @@ public Q_SLOTS: quint32 m_promptResponseTokens; private: + bool promptRecursive(const QList &toolContexts, const QString &prompt, const QString &promptTemplate, + int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty, + int32_t repeat_penalty_tokens, qint64 &totalTime, bool &producedSourceExcerpts, bool isRecursiveCall = false); bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps); std::string m_response; From 991afc6ef248c1ea52081799e77a90e2364d6a74 Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Wed, 14 Aug 2024 10:24:06 -0400 Subject: [PATCH 26/30] Abstract the built-in web search completely away from ChatLLM. 
Signed-off-by: Adam Treat --- gpt4all-chat/bravesearch.cpp | 5 ++-- gpt4all-chat/bravesearch.h | 2 +- gpt4all-chat/chatllm.cpp | 49 ++++++++++++++++------------------- gpt4all-chat/qml/ChatView.qml | 4 +-- 4 files changed, 28 insertions(+), 32 deletions(-) diff --git a/gpt4all-chat/bravesearch.cpp b/gpt4all-chat/bravesearch.cpp index 197ed383af5c..0a0ef8f9f138 100644 --- a/gpt4all-chat/bravesearch.cpp +++ b/gpt4all-chat/bravesearch.cpp @@ -1,4 +1,5 @@ #include "bravesearch.h" +#include "mysettings.h" #include #include @@ -18,9 +19,9 @@ using namespace Qt::Literals::StringLiterals; QString BraveSearch::run(const QJsonObject ¶meters, qint64 timeout) { - const QString apiKey = parameters["apiKey"].toString(); + const QString apiKey = MySettings::globalInstance()->braveSearchAPIKey(); const QString query = parameters["query"].toString(); - const int count = parameters["count"].toInt(); + const int count = 2; // FIXME: This should be a setting QThread workerThread; BraveAPIWorker worker; worker.moveToThread(&workerThread); diff --git a/gpt4all-chat/bravesearch.h b/gpt4all-chat/bravesearch.h index 872e230a95e3..0cf1118b1496 100644 --- a/gpt4all-chat/bravesearch.h +++ b/gpt4all-chat/bravesearch.h @@ -34,7 +34,7 @@ private Q_SLOTS: private: QNetworkAccessManager *m_networkManager; QString m_response; - ToolEnums::Error m_error; + ToolEnums::Error m_error = ToolEnums::Error::NoError; QString m_errorString; }; diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp index 7450ab1fc402..d86de0d01c6a 100644 --- a/gpt4all-chat/chatllm.cpp +++ b/gpt4all-chat/chatllm.cpp @@ -1,6 +1,5 @@ #include "chatllm.h" -#include "bravesearch.h" #include "chat.h" #include "chatapi.h" #include "localdocssearch.h" @@ -893,36 +892,31 @@ bool ChatLLM::promptRecursive(const QList &toolContexts, const QString const QString tool = toolCallDoc["name"].toString(); const QJsonObject args = toolCallDoc["parameters"].toObject(); - // FIXME: In the future this will try to match the tool call to a 
list of tools that are supported - // according to MySettings, but for now only brave search is supported - if (tool != "web_search" || !args.contains("query")) { - // FIXME: Need to surface errors to the UI - qWarning() << "ERROR: Could not find the tool and correct parameters for " << toolCall; + Tool *toolInstance = ToolModel::globalInstance()->get(tool); + if (!toolInstance) { + qWarning() << "ERROR: Could not find the tool for " << toolCall; return handleFailedToolCall(trimmed, totalTime); } - const QString query = args["query"].toString(); + // Inform the chat that we're executing a tool call + emit toolCalled(toolInstance->name().toLower()); - emit toolCalled(tr("searching web...")); - const QString apiKey = MySettings::globalInstance()->braveSearchAPIKey(); - Q_ASSERT(apiKey != ""); - BraveSearch brave; - - QJsonObject parameters; - parameters.insert("apiKey", apiKey); - parameters.insert("query", query); - parameters.insert("count", 2); - - // FIXME: Need to surface errors to the UI - const QString braveResponse = brave.run(parameters, 2000 /*msecs to timeout*/); + const QString response = toolInstance->run(args, 2000 /*msecs to timeout*/); + if (toolInstance->error() != ToolEnums::Error::NoError) { + qWarning() << "ERROR: Tool call produced error:" << toolInstance->errorString(); + return handleFailedToolCall(trimmed, totalTime); + } - QString parseError; - QList sourceExcerpts = SourceExcerpt::fromJson(braveResponse, parseError); - if (!parseError.isEmpty()) { - qWarning() << "ERROR: Could not parse source excerpts for brave response:" << parseError; - } else if (!sourceExcerpts.isEmpty()) { - producedSourceExcerpts = true; - emit sourceExcerptsChanged(sourceExcerpts); + // If the tool supports excerpts then try to parse them here + if (toolInstance->excerpts()) { + QString parseError; + QList sourceExcerpts = SourceExcerpt::fromJson(response, parseError); + if (!parseError.isEmpty()) { + qWarning() << "ERROR: Could not parse source excerpts for 
response:" << parseError; + } else if (!sourceExcerpts.isEmpty()) { + producedSourceExcerpts = true; + emit sourceExcerptsChanged(sourceExcerpts); + } } m_promptResponseTokens = 0; @@ -931,7 +925,7 @@ bool ChatLLM::promptRecursive(const QList &toolContexts, const QString // This is a recursive call but isRecursiveCall is checked above to arrest infinite recursive // tool calls - return promptRecursive(QList()/*collectionList*/, braveResponse, toolTemplate, + return promptRecursive(QList()/*tool context*/, response, toolTemplate, n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, totalTime, producedSourceExcerpts, true /*isRecursiveCall*/); } else { @@ -946,6 +940,7 @@ bool ChatLLM::promptRecursive(const QList &toolContexts, const QString bool ChatLLM::handleFailedToolCall(const std::string &response, qint64 elapsed) { + // FIXME: Need to surface errors to the UI // Restore the strings that we excluded previously when detecting the tool call m_response = "" + response + ""; emit responseChanged(QString::fromStdString(m_response)); diff --git a/gpt4all-chat/qml/ChatView.qml b/gpt4all-chat/qml/ChatView.qml index 1417f42b0971..58926b7e398e 100644 --- a/gpt4all-chat/qml/ChatView.qml +++ b/gpt4all-chat/qml/ChatView.qml @@ -881,8 +881,8 @@ Rectangle { case Chat.PromptProcessing: return qsTr("processing ...") case Chat.ResponseGeneration: return qsTr("generating response ..."); case Chat.GeneratingQuestions: return qsTr("generating questions ..."); - case Chat.ToolCalled: return currentChat.toolDescription; - case Chat.ToolProcessing: return qsTr("processing web results ..."); // FIXME should not be hardcoded! 
+ case Chat.ToolCalled: return qsTr("executing %1 ...").arg(currentChat.toolDescription); + case Chat.ToolProcessing: return qsTr("processing %1 results ...").arg(currentChat.toolDescription); default: return ""; // handle unexpected values } } From 4cb95694ffd515575268f202872bd16dc0519fef Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Wed, 14 Aug 2024 11:20:53 -0400 Subject: [PATCH 27/30] Breakout the ask before running which can be thought of as a security feature. Signed-off-by: Adam Treat --- gpt4all-chat/chatllm.cpp | 2 ++ gpt4all-chat/tool.h | 6 +++++- gpt4all-chat/toolmodel.h | 4 ++++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp index d86de0d01c6a..09a8928797f0 100644 --- a/gpt4all-chat/chatllm.cpp +++ b/gpt4all-chat/chatllm.cpp @@ -761,6 +761,7 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens) { + // FIXME: Honor the ask before running feature // FIXME: The only localdocs specific thing here should be the injection of the parameters // FIXME: Get the list of tools ... if force usage is set, then we *try* and force usage here. 
QList localDocsExcerpts; @@ -898,6 +899,7 @@ bool ChatLLM::promptRecursive(const QList &toolContexts, const QString return handleFailedToolCall(trimmed, totalTime); } + // FIXME: Honor the ask before running feature // Inform the chat that we're executing a tool call emit toolCalled(toolInstance->name().toLower()); diff --git a/gpt4all-chat/tool.h b/gpt4all-chat/tool.h index 92a3b758efbc..86f95904fa25 100644 --- a/gpt4all-chat/tool.h +++ b/gpt4all-chat/tool.h @@ -19,7 +19,6 @@ namespace ToolEnums { enum class UsageMode { Disabled, // Completely disabled Enabled, // Enabled and the model decides whether to run - AskBeforeRunning, // Enabled and model decides but the user is queried whether they want the tool to run in every instance ForceUsage, // Attempt to force usage of the tool rather than let the LLM decide. NOTE: Not always possible. }; Q_ENUM_NS(UsageMode) @@ -35,6 +34,7 @@ class Tool : public QObject { Q_PROPERTY(QUrl url READ url CONSTANT) Q_PROPERTY(bool isBuiltin READ isBuiltin CONSTANT) Q_PROPERTY(ToolEnums::UsageMode usageMode READ usageMode NOTIFY usageModeChanged) + Q_PROPERTY(bool askBeforeRunning READ askBeforeRunning NOTIFY askBeforeRunningChanged) Q_PROPERTY(bool excerpts READ excerpts CONSTANT) public: @@ -74,6 +74,9 @@ class Tool : public QObject { // [Optional] The current usage mode virtual ToolEnums::UsageMode usageMode() const { return ToolEnums::UsageMode::Disabled; } + // [Optional] The user is queried whether they want the tool to run in every instance + virtual bool askBeforeRunning() const { return false; } + // [Optional] Whether json result produces source excerpts. 
virtual bool excerpts() const { return false; } @@ -88,6 +91,7 @@ class Tool : public QObject { Q_SIGNALS: void usageModeChanged(); + void askBeforeRunningChanged(); }; #endif // TOOL_H diff --git a/gpt4all-chat/toolmodel.h b/gpt4all-chat/toolmodel.h index f1b599229fef..315b73604fa9 100644 --- a/gpt4all-chat/toolmodel.h +++ b/gpt4all-chat/toolmodel.h @@ -23,6 +23,7 @@ class ToolModel : public QAbstractListModel KeyRequiredRole, IsBuiltinRole, UsageModeRole, + AskBeforeRole, ExcerptsRole, }; @@ -53,6 +54,8 @@ class ToolModel : public QAbstractListModel return item->isBuiltin(); case UsageModeRole: return QVariant::fromValue(item->usageMode()); + case AskBeforeRole: + return item->askBeforeRunning(); case ExcerptsRole: return item->excerpts(); } @@ -72,6 +75,7 @@ class ToolModel : public QAbstractListModel roles[KeyRequiredRole] = "keyRequired"; roles[IsBuiltinRole] = "isBuiltin"; roles[UsageModeRole] = "usageMode"; + roles[AskBeforeRole] = "askBeforeRunning"; roles[ExcerptsRole] = "excerpts"; return roles; } From 054ca43d52437ee086fca59db7d1d2e143f3d729 Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Wed, 14 Aug 2024 12:01:48 -0400 Subject: [PATCH 28/30] Display the antenna by introducing notion of privacy scopes to tools. 
Signed-off-by: Adam Treat --- gpt4all-chat/bravesearch.h | 1 + gpt4all-chat/localdocssearch.h | 1 + gpt4all-chat/main.cpp | 3 +++ gpt4all-chat/main.qml | 12 ++++++++++-- gpt4all-chat/tool.h | 12 ++++++++++++ gpt4all-chat/toolmodel.cpp | 11 +++++++++++ gpt4all-chat/toolmodel.h | 9 +++++++++ 7 files changed, 47 insertions(+), 2 deletions(-) diff --git a/gpt4all-chat/bravesearch.h b/gpt4all-chat/bravesearch.h index 0cf1118b1496..ef7c14e15776 100644 --- a/gpt4all-chat/bravesearch.h +++ b/gpt4all-chat/bravesearch.h @@ -51,6 +51,7 @@ class BraveSearch : public Tool { QString name() const override { return tr("Web Search"); } QString description() const override { return tr("Search the web"); } QString function() const override { return "web_search"; } + ToolEnums::PrivacyScope privacyScope() const override { return ToolEnums::PrivacyScope::None; } QJsonObject paramSchema() const override; QJsonObject exampleParams() const override; bool isBuiltin() const override { return true; } diff --git a/gpt4all-chat/localdocssearch.h b/gpt4all-chat/localdocssearch.h index b4ea3e94b0a0..be9ed04f3b37 100644 --- a/gpt4all-chat/localdocssearch.h +++ b/gpt4all-chat/localdocssearch.h @@ -37,6 +37,7 @@ class LocalDocsSearch : public Tool { QString name() const override { return tr("LocalDocs Search"); } QString description() const override { return tr("Search the local docs"); } QString function() const override { return "localdocs_search"; } + ToolEnums::PrivacyScope privacyScope() const override { return ToolEnums::PrivacyScope::Local; } QJsonObject paramSchema() const override; bool isBuiltin() const override { return true; } ToolEnums::UsageMode usageMode() const override { return ToolEnums::UsageMode::ForceUsage; } diff --git a/gpt4all-chat/main.cpp b/gpt4all-chat/main.cpp index 4546a95bcf32..b3cef897a4ca 100644 --- a/gpt4all-chat/main.cpp +++ b/gpt4all-chat/main.cpp @@ -7,6 +7,7 @@ #include "modellist.h" #include "mysettings.h" #include "network.h" +#include "toolmodel.h" #include 
"../gpt4all-backend/llmodel.h" @@ -67,6 +68,8 @@ int main(int argc, char *argv[]) qmlRegisterSingletonInstance("download", 1, 0, "Download", Download::globalInstance()); qmlRegisterSingletonInstance("network", 1, 0, "Network", Network::globalInstance()); qmlRegisterSingletonInstance("localdocs", 1, 0, "LocalDocs", LocalDocs::globalInstance()); + qmlRegisterSingletonInstance("toollist", 1, 0, "ToolList", ToolModel::globalInstance()); + qmlRegisterUncreatableMetaObject(ToolEnums::staticMetaObject, "toolenums", 1, 0, "ToolEnums", "Error: only enums"); qmlRegisterUncreatableMetaObject(MySettingsEnums::staticMetaObject, "mysettingsenums", 1, 0, "MySettingsEnums", "Error: only enums"); const QUrl url(u"qrc:/gpt4all/main.qml"_qs); diff --git a/gpt4all-chat/main.qml b/gpt4all-chat/main.qml index 16b5b65cf502..ad1f0295a0f1 100644 --- a/gpt4all-chat/main.qml +++ b/gpt4all-chat/main.qml @@ -12,6 +12,8 @@ import network import gpt4all import localdocs import mysettings +import toollist +import toolenums Window { id: window @@ -413,7 +415,11 @@ Window { ColorOverlay { id: antennaColored - visible: ModelList.selectableModels.count !== 0 && (currentChat.isServer || currentChat.modelInfo.isOnline || MySettings.networkIsActive) + visible: ModelList.selectableModels.count !== 0 + && (MySettings.networkIsActive + || currentChat.modelInfo.isOnline + || currentChat.isServer + || ToolList.privacyScope === ToolEnums.PrivacyScope.None) anchors.fill: antennaImage source: antennaImage color: theme.styledTextColor @@ -422,8 +428,10 @@ Window { return qsTr("The datalake is enabled") else if (currentChat.modelInfo.isOnline) return qsTr("Using a network model") - else if (currentChat.modelInfo.isOnline) + else if (currentChat.isServer) return qsTr("Server mode is enabled") + else if (ToolList.privacyScope === ToolEnums.PrivacyScope.None) + return qsTr("One or more enabled tools is not private") return "" } ToolTip.visible: maAntenna.containsMouse diff --git a/gpt4all-chat/tool.h 
b/gpt4all-chat/tool.h index 86f95904fa25..88a3dbc9de87 100644 --- a/gpt4all-chat/tool.h +++ b/gpt4all-chat/tool.h @@ -22,6 +22,14 @@ namespace ToolEnums { ForceUsage, // Attempt to force usage of the tool rather than let the LLM decide. NOTE: Not always possible. }; Q_ENUM_NS(UsageMode) + + // Ordered in increasing levels of privacy + enum class PrivacyScope { + None = 0, // Tool call data does not have any privacy scope + LocalOrg = 1, // Tool call data does not leave the local organization + Local = 2 // Tool call data does not leave the machine + }; + Q_ENUM_NS(PrivacyScope) } class Tool : public QObject { @@ -29,6 +37,7 @@ class Tool : public QObject { Q_PROPERTY(QString name READ name CONSTANT) Q_PROPERTY(QString description READ description CONSTANT) Q_PROPERTY(QString function READ function CONSTANT) + Q_PROPERTY(ToolEnums::PrivacyScope privacyScope READ privacyScope CONSTANT) Q_PROPERTY(QJsonObject paramSchema READ paramSchema CONSTANT) Q_PROPERTY(QJsonObject exampleParams READ exampleParams CONSTANT) Q_PROPERTY(QUrl url READ url CONSTANT) @@ -54,6 +63,9 @@ class Tool : public QObject { // [Required] Must be unique. Name of the function to invoke. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64. virtual QString function() const = 0; + // [Required] The privacy scope + virtual ToolEnums::PrivacyScope privacyScope() const = 0; + // [Optional] Json schema describing the tool's parameters. An empty object specifies no parameters. 
// https://json-schema.org/understanding-json-schema/ // https://platform.openai.com/docs/api-reference/runs/createRun#runs-createrun-tools diff --git a/gpt4all-chat/toolmodel.cpp b/gpt4all-chat/toolmodel.cpp index 9f8547eaa2f0..a2fdbdc020da 100644 --- a/gpt4all-chat/toolmodel.cpp +++ b/gpt4all-chat/toolmodel.cpp @@ -22,10 +22,12 @@ ToolModel::ToolModel() Tool* localDocsSearch = new LocalDocsSearch; m_tools.append(localDocsSearch); m_toolMap.insert(localDocsSearch->function(), localDocsSearch); + connect(localDocsSearch, &Tool::usageModeChanged, this, &ToolModel::privacyScopeChanged); Tool *braveSearch = new BraveSearch; m_tools.append(braveSearch); m_toolMap.insert(braveSearch->function(), braveSearch); + connect(braveSearch, &Tool::usageModeChanged, this, &ToolModel::privacyScopeChanged); } bool ToolModel::eventFilter(QObject *obj, QEvent *ev) @@ -34,3 +36,12 @@ bool ToolModel::eventFilter(QObject *obj, QEvent *ev) emit dataChanged(index(0, 0), index(m_tools.size() - 1, 0)); return false; } + +ToolEnums::PrivacyScope ToolModel::privacyScope() const +{ + ToolEnums::PrivacyScope scope = ToolEnums::PrivacyScope::Local; // highest scope + for (const Tool *t : m_tools) + if (t->usageMode() != ToolEnums::UsageMode::Disabled) + scope = std::min(scope, t->privacyScope()); + return scope; +} diff --git a/gpt4all-chat/toolmodel.h b/gpt4all-chat/toolmodel.h index 315b73604fa9..4135771e8f9a 100644 --- a/gpt4all-chat/toolmodel.h +++ b/gpt4all-chat/toolmodel.h @@ -9,6 +9,7 @@ class ToolModel : public QAbstractListModel { Q_OBJECT Q_PROPERTY(int count READ count NOTIFY countChanged) + Q_PROPERTY(ToolEnums::PrivacyScope privacyScope READ privacyScope NOTIFY privacyScopeChanged) public: static ToolModel *globalInstance(); @@ -17,6 +18,7 @@ class ToolModel : public QAbstractListModel NameRole = Qt::UserRole + 1, DescriptionRole, FunctionRole, + PrivacyScopeRole, ParametersRole, UrlRole, ApiKeyRole, @@ -46,6 +48,8 @@ class ToolModel : public QAbstractListModel return 
item->description(); case FunctionRole: return item->function(); + case PrivacyScopeRole: + return QVariant::fromValue(item->privacyScope()); case ParametersRole: return item->paramSchema(); case UrlRole: @@ -69,6 +73,7 @@ class ToolModel : public QAbstractListModel roles[NameRole] = "name"; roles[DescriptionRole] = "description"; roles[FunctionRole] = "function"; + roles[PrivacyScopeRole] = "privacyScope"; roles[ParametersRole] = "parameters"; roles[UrlRole] = "url"; roles[ApiKeyRole] = "apiKey"; @@ -94,8 +99,12 @@ class ToolModel : public QAbstractListModel int count() const { return m_tools.size(); } + // Returns the least private scope of all enabled tools + ToolEnums::PrivacyScope privacyScope() const; + Q_SIGNALS: void countChanged(); + void privacyScopeChanged(); void valueChanged(int index, const QString &value); protected: From 3a564688b1fa23ab4240737c7f1cda37ed4b462c Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Wed, 14 Aug 2024 16:19:28 -0400 Subject: [PATCH 29/30] Implement all the settings for the web search. 
Signed-off-by: Adam Treat --- gpt4all-chat/bravesearch.cpp | 12 +++-- gpt4all-chat/bravesearch.h | 2 +- gpt4all-chat/mysettings.cpp | 53 +++++++++++++++------- gpt4all-chat/mysettings.h | 17 ++++++- gpt4all-chat/qml/ApplicationSettings.qml | 11 ++++- gpt4all-chat/qml/LocalDocsSettings.qml | 10 +++-- gpt4all-chat/qml/WebSearchSettings.qml | 57 ++++++++++++++++++++++-- gpt4all-chat/tool.h | 6 +-- 8 files changed, 134 insertions(+), 34 deletions(-) diff --git a/gpt4all-chat/bravesearch.cpp b/gpt4all-chat/bravesearch.cpp index 0a0ef8f9f138..e5505437f362 100644 --- a/gpt4all-chat/bravesearch.cpp +++ b/gpt4all-chat/bravesearch.cpp @@ -17,11 +17,18 @@ using namespace Qt::Literals::StringLiterals; +BraveSearch::BraveSearch() + : Tool(), m_error(ToolEnums::Error::NoError) +{ + connect(MySettings::globalInstance(), &MySettings::webSearchUsageModeChanged, + this, &Tool::usageModeChanged); +} + QString BraveSearch::run(const QJsonObject &parameters, qint64 timeout) { const QString apiKey = MySettings::globalInstance()->braveSearchAPIKey(); const QString query = parameters["query"].toString(); - const int count = 2; // FIXME: This should be a setting + const int count = MySettings::globalInstance()->webSearchRetrievalSize(); QThread workerThread; BraveAPIWorker worker; worker.moveToThread(&workerThread); @@ -83,8 +90,7 @@ QJsonObject BraveSearch::exampleParams() const ToolEnums::UsageMode BraveSearch::usageMode() const { - // FIXME: This needs to be a setting - return ToolEnums::UsageMode::Enabled; + return MySettings::globalInstance()->webSearchUsageMode(); } void BraveAPIWorker::request(const QString &apiKey, const QString &query, int count) diff --git a/gpt4all-chat/bravesearch.h b/gpt4all-chat/bravesearch.h index ef7c14e15776..78a73b60a893 100644 --- a/gpt4all-chat/bravesearch.h +++ b/gpt4all-chat/bravesearch.h @@ -41,7 +41,7 @@ private Q_SLOTS: class BraveSearch : public Tool { Q_OBJECT public: - BraveSearch() : Tool(), m_error(ToolEnums::Error::NoError) {} + BraveSearch(); 
virtual ~BraveSearch() {} QString run(const QJsonObject ¶meters, qint64 timeout = 2000) override; diff --git a/gpt4all-chat/mysettings.cpp b/gpt4all-chat/mysettings.cpp index 0d6429792f4e..999efd6b321f 100644 --- a/gpt4all-chat/mysettings.cpp +++ b/gpt4all-chat/mysettings.cpp @@ -26,6 +26,7 @@ #include using namespace Qt::Literals::StringLiterals; +using namespace ToolEnums; // used only for settings serialization, do not translate static const QStringList suggestionModeNames { "LocalDocsOnly", "On", "Off" }; @@ -47,22 +48,26 @@ static const QString languageAndLocale = "System Locale"; } // namespace defaults static const QVariantMap basicDefaults { - { "chatTheme", QVariant::fromValue(ChatTheme::Light) }, - { "fontSize", QVariant::fromValue(FontSize::Small) }, - { "lastVersionStarted", "" }, - { "networkPort", 4891, }, - { "saveChatsContext", false }, - { "serverChat", false }, - { "userDefaultModel", "Application default" }, - { "suggestionMode", QVariant::fromValue(SuggestionMode::SourceExcerptsOnly) }, - { "localdocs/chunkSize", 512 }, - { "localdocs/retrievalSize", 3 }, - { "localdocs/showReferences", true }, - { "localdocs/fileExtensions", QStringList { "txt", "pdf", "md", "rst" } }, - { "localdocs/useRemoteEmbed", false }, - { "localdocs/nomicAPIKey", "" }, - { "localdocs/embedDevice", "Auto" }, - { "network/attribution", "" }, + { "chatTheme", QVariant::fromValue(ChatTheme::Light) }, + { "fontSize", QVariant::fromValue(FontSize::Small) }, + { "lastVersionStarted", "" }, + { "networkPort", 4891, }, + { "saveChatsContext", false }, + { "serverChat", false }, + { "userDefaultModel", "Application default" }, + { "suggestionMode", QVariant::fromValue(SuggestionMode::SourceExcerptsOnly) }, + { "localdocs/chunkSize", 512 }, + { "localdocs/retrievalSize", 3 }, + { "localdocs/showReferences", true }, + { "localdocs/fileExtensions", QStringList { "txt", "pdf", "md", "rst" } }, + { "localdocs/useRemoteEmbed", false }, + { "localdocs/nomicAPIKey", "" }, + { 
"localdocs/embedDevice", "Auto" }, + { "network/attribution", "" }, + { "websearch/usageMode", QVariant::fromValue(UsageMode::Disabled) }, + { "websearch/retrievalSize", 2 }, + { "websearch/askBeforeRunning", false }, + { "bravesearch/APIKey", "" }, }; static QString defaultLocalModelsPath() @@ -230,6 +235,14 @@ void MySettings::restoreLocalDocsDefaults() setLocalDocsEmbedDevice(basicDefaults.value("localdocs/embedDevice").toString()); } +void MySettings::restoreWebSearchDefaults() +{ + setWebSearchUsageMode(basicDefaults.value("websearch/usageMode").value()); + setWebSearchRetrievalSize(basicDefaults.value("websearch/retrievalSize").toInt()); + setWebSearchAskBeforeRunning(basicDefaults.value("websearch/askBeforeRunning").toBool()); + setBraveSearchAPIKey(basicDefaults.value("bravesearch/APIKey").toString()); +} + void MySettings::eraseModel(const ModelInfo &info) { m_settings.remove(u"model-%1"_s.arg(info.id())); @@ -467,6 +480,9 @@ QString MySettings::localDocsNomicAPIKey() const { return getBasicSetting QString MySettings::localDocsEmbedDevice() const { return getBasicSetting("localdocs/embedDevice" ).toString(); } QString MySettings::networkAttribution() const { return getBasicSetting("network/attribution" ).toString(); } QString MySettings::braveSearchAPIKey() const { return getBasicSetting("bravesearch/APIKey" ).toString(); } +int MySettings::webSearchRetrievalSize() const { return getBasicSetting("websearch/retrievalSize").toInt(); } +bool MySettings::webSearchAskBeforeRunning() const { return getBasicSetting("websearch/askBeforeRunning").toBool(); } +UsageMode MySettings::webSearchUsageMode() const { return getBasicSetting("websearch/usageMode").value(); } ChatTheme MySettings::chatTheme() const { return ChatTheme (getEnumSetting("chatTheme", chatThemeNames)); } FontSize MySettings::fontSize() const { return FontSize (getEnumSetting("fontSize", fontSizeNames)); } @@ -486,6 +502,9 @@ void MySettings::setLocalDocsNomicAPIKey(const QString &value) { setBasic 
void MySettings::setLocalDocsEmbedDevice(const QString &value) { setBasicSetting("localdocs/embedDevice", value, "localDocsEmbedDevice"); } void MySettings::setNetworkAttribution(const QString &value) { setBasicSetting("network/attribution", value, "networkAttribution"); } void MySettings::setBraveSearchAPIKey(const QString &value) { setBasicSetting("bravesearch/APIKey", value, "braveSearchAPIKey"); } +void MySettings::setWebSearchUsageMode(ToolEnums::UsageMode value) { setBasicSetting("websearch/usageMode", int(value), "webSearchUsageMode"); } +void MySettings::setWebSearchRetrievalSize(int value) { setBasicSetting("websearch/retrievalSize", value, "webSearchRetrievalSize"); } +void MySettings::setWebSearchAskBeforeRunning(bool value) { setBasicSetting("websearch/askBeforeRunning", value, "webSearchAskBeforeRunning"); } void MySettings::setChatTheme(ChatTheme value) { setBasicSetting("chatTheme", chatThemeNames .value(int(value))); } void MySettings::setFontSize(FontSize value) { setBasicSetting("fontSize", fontSizeNames .value(int(value))); } @@ -701,7 +720,7 @@ QString MySettings::systemPromptInternal(const QString &proposedTemplate, QStrin int c = ToolModel::globalInstance()->count(); for (int i = 0; i < c; ++i) { Tool *t = ToolModel::globalInstance()->get(i); - if (t->usageMode() == ToolEnums::UsageMode::Enabled) + if (t->usageMode() == UsageMode::Enabled) toolList.push_back(t->jinjaValue()); } params.insert({"toolList", toolList}); diff --git a/gpt4all-chat/mysettings.h b/gpt4all-chat/mysettings.h index 6e1e8056b06a..1518412b01f5 100644 --- a/gpt4all-chat/mysettings.h +++ b/gpt4all-chat/mysettings.h @@ -2,6 +2,7 @@ #define MYSETTINGS_H #include "modellist.h" // IWYU pragma: keep +#include "tool.h" #include #include @@ -72,6 +73,9 @@ class MySettings : public QObject Q_PROPERTY(int networkPort READ networkPort WRITE setNetworkPort NOTIFY networkPortChanged) Q_PROPERTY(SuggestionMode suggestionMode READ suggestionMode WRITE setSuggestionMode NOTIFY 
suggestionModeChanged) Q_PROPERTY(QStringList uiLanguages MEMBER m_uiLanguages CONSTANT) + Q_PROPERTY(ToolEnums::UsageMode webSearchUsageMode READ webSearchUsageMode WRITE setWebSearchUsageMode NOTIFY webSearchUsageModeChanged) + Q_PROPERTY(int webSearchRetrievalSize READ webSearchRetrievalSize WRITE setWebSearchRetrievalSize NOTIFY webSearchRetrievalSizeChanged) + Q_PROPERTY(bool webSearchAskBeforeRunning READ webSearchAskBeforeRunning WRITE setWebSearchAskBeforeRunning NOTIFY webSearchAskBeforeRunningChanged) Q_PROPERTY(QString braveSearchAPIKey READ braveSearchAPIKey WRITE setBraveSearchAPIKey NOTIFY braveSearchAPIKeyChanged) public: @@ -81,6 +85,7 @@ class MySettings : public QObject Q_INVOKABLE void restoreModelDefaults(const ModelInfo &info); Q_INVOKABLE void restoreApplicationDefaults(); Q_INVOKABLE void restoreLocalDocsDefaults(); + Q_INVOKABLE void restoreWebSearchDefaults(); // Model/Character settings void eraseModel(const ModelInfo &info); @@ -188,7 +193,13 @@ class MySettings : public QObject QString localDocsEmbedDevice() const; void setLocalDocsEmbedDevice(const QString &value); - // Tool settings + // Web search settings + ToolEnums::UsageMode webSearchUsageMode() const; + void setWebSearchUsageMode(ToolEnums::UsageMode value); + int webSearchRetrievalSize() const; + void setWebSearchRetrievalSize(int value); + bool webSearchAskBeforeRunning() const; + void setWebSearchAskBeforeRunning(bool value); QString braveSearchAPIKey() const; void setBraveSearchAPIKey(const QString &value); @@ -251,6 +262,9 @@ class MySettings : public QObject void deviceChanged(); void suggestionModeChanged(); void languageAndLocaleChanged(); + void webSearchUsageModeChanged(); + void webSearchRetrievalSizeChanged() const; + void webSearchAskBeforeRunningChanged() const; void braveSearchAPIKeyChanged(); private: @@ -274,7 +288,6 @@ class MySettings : public QObject bool signal = false); QString filePathForLocale(const QLocale &locale); QString systemPromptInternal(const 
QString &proposedTemplate, QString &error); - }; #endif // MYSETTINGS_H diff --git a/gpt4all-chat/qml/ApplicationSettings.qml b/gpt4all-chat/qml/ApplicationSettings.qml index 5fe4bea5e1cf..57290606adee 100644 --- a/gpt4all-chat/qml/ApplicationSettings.qml +++ b/gpt4all-chat/qml/ApplicationSettings.qml @@ -354,13 +354,22 @@ MySettingsTab { ListElement { name: qsTr("Whenever possible") } ListElement { name: qsTr("Never") } } + function updateModel() { + suggestionModeBox.currentIndex = MySettings.suggestionMode; + } Accessible.name: suggestionModeLabel.text Accessible.description: suggestionModeLabel.helpText onActivated: { MySettings.suggestionMode = suggestionModeBox.currentIndex; } Component.onCompleted: { - suggestionModeBox.currentIndex = MySettings.suggestionMode; + suggestionModeBox.updateModel(); + } + Connections { + target: MySettings + function onSuggestionModeChanged() { + suggestionModeBox.updateModel(); + } } } MySettingsLabel { diff --git a/gpt4all-chat/qml/LocalDocsSettings.qml b/gpt4all-chat/qml/LocalDocsSettings.qml index 47f4409058f3..c3c98e7c6b1a 100644 --- a/gpt4all-chat/qml/LocalDocsSettings.qml +++ b/gpt4all-chat/qml/LocalDocsSettings.qml @@ -255,13 +255,14 @@ MySettingsTab { MySettingsLabel { id: chunkLabel Layout.fillWidth: true - text: qsTr("Document snippet size (characters)") - helpText: qsTr("Number of characters per document snippet. Larger numbers increase likelihood of factual responses, but also result in slower generation.") + text: qsTr("Document excerpt size (characters)") + helpText: qsTr("Number of characters per document excerpt. 
Larger numbers increase likelihood of factual responses, but also result in slower generation.") } MyTextField { id: chunkSizeTextField text: MySettings.localDocsChunkSize + font.pixelSize: theme.fontSizeLarge validator: IntValidator { bottom: 1 } @@ -281,13 +282,14 @@ MySettingsTab { Layout.topMargin: 15 MySettingsLabel { id: contextItemsPerPrompt - text: qsTr("Max document snippets per prompt") - helpText: qsTr("Max best N matches of retrieved document snippets to add to the context for prompt. Larger numbers increase likelihood of factual responses, but also result in slower generation.") + text: qsTr("Max source excerpts per prompt") + helpText: qsTr("Max best N matches of retrieved source excerpts to add to the context for prompt. Larger numbers increase likelihood of factual responses, but also result in slower generation.") } MyTextField { text: MySettings.localDocsRetrievalSize + font.pixelSize: theme.fontSizeLarge validator: IntValidator { bottom: 1 } diff --git a/gpt4all-chat/qml/WebSearchSettings.qml b/gpt4all-chat/qml/WebSearchSettings.qml index 222f9e30ee1d..901883cded59 100644 --- a/gpt4all-chat/qml/WebSearchSettings.qml +++ b/gpt4all-chat/qml/WebSearchSettings.qml @@ -11,7 +11,7 @@ import network MySettingsTab { onRestoreDefaultsClicked: { - MySettings.restoreLocalDocsDefaults(); + MySettings.restoreWebSearchDefaults(); } showRestoreDefaultsButton: true @@ -52,14 +52,24 @@ MySettingsTab { model: ListModel { ListElement { name: qsTr("Never") } ListElement { name: qsTr("Model decides") } - ListElement { name: qsTr("Ask for confirmation before executing") } - ListElement { name: qsTr("Force usage for every response when possible") } + ListElement { name: qsTr("Force usage for every response where possible") } + } + function updateModel() { + usageModeBox.currentIndex = MySettings.webSearchUsageMode; } Accessible.name: usageModeLabel.text Accessible.description: usageModeLabel.helpText onActivated: { + MySettings.webSearchUsageMode = 
usageModeBox.currentIndex; } Component.onCompleted: { + usageModeBox.updateModel(); + } + Connections { + target: MySettings + function onWebSearchUsageModeChanged() { + usageModeBox.updateModel(); + } } } } @@ -74,6 +84,7 @@ MySettingsTab { MyTextField { id: apiKeyField + enabled: usageModeBox.currentIndex !== 0 text: MySettings.braveSearchAPIKey color: theme.textColor font.pixelSize: theme.fontSizeLarge @@ -89,6 +100,46 @@ MySettingsTab { } } + RowLayout { + MySettingsLabel { + id: contextItemsPerPrompt + text: qsTr("Max source excerpts per prompt") + helpText: qsTr("Max best N matches of retrieved source excerpts to add to the context for prompt. Larger numbers increase likelihood of factual responses, but also result in slower generation.") + } + + MyTextField { + text: MySettings.webSearchRetrievalSize + font.pixelSize: theme.fontSizeLarge + validator: IntValidator { + bottom: 1 + } + onEditingFinished: { + var val = parseInt(text) + if (!isNaN(val)) { + MySettings.webSearchRetrievalSize = val + focus = false + } else { + text = MySettings.webSearchRetrievalSize + } + } + } + } + + RowLayout { + MySettingsLabel { + id: askBeforeRunningLabel + text: qsTr("Ask before running") + helpText: qsTr("The user is queried whether they want the tool to run in every instance") + } + MyCheckBox { + id: askBeforeRunningBox + checked: MySettings.webSearchAskBeforeRunning + onClicked: { + MySettings.webSearchAskBeforeRunning = !MySettings.webSearchAskBeforeRunning + } + } + } + Rectangle { Layout.topMargin: 15 Layout.fillWidth: true diff --git a/gpt4all-chat/tool.h b/gpt4all-chat/tool.h index 88a3dbc9de87..23c4c152a3a1 100644 --- a/gpt4all-chat/tool.h +++ b/gpt4all-chat/tool.h @@ -17,9 +17,9 @@ namespace ToolEnums { Q_ENUM_NS(Error) enum class UsageMode { - Disabled, // Completely disabled - Enabled, // Enabled and the model decides whether to run - ForceUsage, // Attempt to force usage of the tool rather than let the LLM decide. NOTE: Not always possible. 
+ Disabled = 0, // Completely disabled + Enabled = 1, // Enabled and the model decides whether to run + ForceUsage = 2, // Attempt to force usage of the tool rather than let the LLM decide. NOTE: Not always possible. }; Q_ENUM_NS(UsageMode) From 4ae6acdedc98a31175f590de60924b77bccafbd5 Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Thu, 15 Aug 2024 16:25:26 -0400 Subject: [PATCH 30/30] Force tool usage and refactor. Signed-off-by: Adam Treat --- gpt4all-chat/bravesearch.cpp | 17 +- gpt4all-chat/bravesearch.h | 1 + gpt4all-chat/chatllm.cpp | 314 +++++++++++++++++-------- gpt4all-chat/chatllm.h | 5 +- gpt4all-chat/localdocssearch.cpp | 19 +- gpt4all-chat/localdocssearch.h | 1 + gpt4all-chat/modellist.cpp | 24 +- gpt4all-chat/modellist.h | 8 +- gpt4all-chat/mysettings.cpp | 31 ++- gpt4all-chat/mysettings.h | 20 +- gpt4all-chat/qml/ModelSettings.qml | 48 +++- gpt4all-chat/qml/WebSearchSettings.qml | 29 +-- gpt4all-chat/tool.h | 21 +- gpt4all-chat/toolmodel.h | 8 +- 14 files changed, 385 insertions(+), 161 deletions(-) diff --git a/gpt4all-chat/bravesearch.cpp b/gpt4all-chat/bravesearch.cpp index e5505437f362..4eefb936e421 100644 --- a/gpt4all-chat/bravesearch.cpp +++ b/gpt4all-chat/bravesearch.cpp @@ -22,10 +22,16 @@ BraveSearch::BraveSearch() { connect(MySettings::globalInstance(), &MySettings::webSearchUsageModeChanged, this, &Tool::usageModeChanged); + connect(MySettings::globalInstance(), &MySettings::webSearchConfirmationModeChanged, + this, &Tool::confirmationModeChanged); } QString BraveSearch::run(const QJsonObject ¶meters, qint64 timeout) { + // Reset the error state + m_error = ToolEnums::Error::NoError; + m_errorString = QString(); + const QString apiKey = MySettings::globalInstance()->braveSearchAPIKey(); const QString query = parameters["query"].toString(); const int count = MySettings::globalInstance()->webSearchRetrievalSize(); @@ -93,6 +99,11 @@ ToolEnums::UsageMode BraveSearch::usageMode() const return 
MySettings::globalInstance()->webSearchUsageMode(); } +ToolEnums::ConfirmationMode BraveSearch::confirmationMode() const +{ + return MySettings::globalInstance()->webSearchConfirmationMode(); +} + void BraveAPIWorker::request(const QString &apiKey, const QString &query, int count) { // Documentation on the brave web search: @@ -181,8 +192,10 @@ QString BraveAPIWorker::cleanBraveResponse(const QByteArray& jsonResponse) QJsonObject excerpt; excerpt.insert("text", resultObj["description"]); } - result.insert("excerpts", excerpts); - cleanArray.append(QJsonValue(result)); + if (!excerpts.isEmpty()) { + result.insert("excerpts", excerpts); + cleanArray.append(QJsonValue(result)); + } } } diff --git a/gpt4all-chat/bravesearch.h b/gpt4all-chat/bravesearch.h index 78a73b60a893..12975681f5a9 100644 --- a/gpt4all-chat/bravesearch.h +++ b/gpt4all-chat/bravesearch.h @@ -56,6 +56,7 @@ class BraveSearch : public Tool { QJsonObject exampleParams() const override; bool isBuiltin() const override { return true; } ToolEnums::UsageMode usageMode() const override; + ToolEnums::ConfirmationMode confirmationMode() const override; bool excerpts() const override { return true; } private: diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp index 09a8928797f0..26f9289943b3 100644 --- a/gpt4all-chat/chatllm.cpp +++ b/gpt4all-chat/chatllm.cpp @@ -37,6 +37,7 @@ #include using namespace Qt::Literals::StringLiterals; +using namespace ToolEnums; //#define DEBUG //#define DEBUG_MODEL_LOADING @@ -761,52 +762,218 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens) { - // FIXME: Honor the ask before running feature - // FIXME: The only localdocs specific thing here should be the injection of the parameters - // FIXME: Get the list of tools ... if force usage is set, then we *try* and force usage here. 
- QList localDocsExcerpts; - if (!collectionList.isEmpty()) { - LocalDocsSearch localdocs; - QJsonObject parameters; - parameters.insert("text", prompt); - parameters.insert("count", MySettings::globalInstance()->localDocsRetrievalSize()); - parameters.insert("collections", QJsonArray::fromStringList(collectionList)); - - // FIXME: This has to handle errors of the tool call - const QString localDocsResponse = localdocs.run(parameters, 2000 /*msecs to timeout*/); + QString toolCallingTemplate = MySettings::globalInstance()->modelToolTemplate(m_modelInfo); + Q_ASSERT(toolCallingTemplate.isEmpty() || toolCallingTemplate.contains("%1")); + if (toolCallingTemplate.isEmpty() || !toolCallingTemplate.contains("%1")) + toolCallingTemplate = u"### Context:\n%1\n\n"_s; - QString parseError; - localDocsExcerpts = SourceExcerpt::fromJson(localDocsResponse, parseError); - if (!parseError.isEmpty()) { - qWarning() << "ERROR: Could not parse source excerpts for localdocs response:" << parseError; - } else if (!localDocsExcerpts.isEmpty()) { - emit sourceExcerptsChanged(localDocsExcerpts); + const bool isToolCallingModel = MySettings::globalInstance()->modelIsToolCalling(m_modelInfo); + + // Iterate over the list of tools and if force usage is set, then we *try* and force usage here + QList toolResponses; + qint64 totalTime = 0; + bool producedSourceExcerpts = false; + const int toolCount = ToolModel::globalInstance()->count(); + for (int i = 0; i < toolCount; ++i) { + Tool *t = ToolModel::globalInstance()->get(i); + if (t->usageMode() != UsageMode::ForceUsage) + continue; + + // Local docs search is unique. It is the _only_ tool where we try and force usage even if + // the model does not support tool calling. 
+ if (!isToolCallingModel && t->function() != "localdocs_search") + continue; + + // If this is the localdocs tool call, then we perform the search with the entire prompt as + // the query + if (t->function() == "localdocs_search") { + if (collectionList.isEmpty()) + continue; + + QJsonObject parameters; + parameters.insert("collections", QJsonArray::fromStringList(collectionList)); + parameters.insert("query", prompt); + parameters.insert("count", MySettings::globalInstance()->localDocsRetrievalSize()); + + // FIXME: Honor the confirmation mode feature + const QString response = t->run(parameters, 2000 /*msecs to timeout*/); + if (t->error() != Error::NoError) { + qWarning() << "ERROR: LocalDocs call produced error:" << t->errorString(); + continue; + } + + QString parseError; + QList localDocsExcerpts = SourceExcerpt::fromJson(response, parseError); + if (!parseError.isEmpty()) { + qWarning() << "ERROR: Could not parse source excerpts for localdocs response:" << parseError; + } else { + producedSourceExcerpts = true; + emit sourceExcerptsChanged(localDocsExcerpts); + } + toolResponses << QString(toolCallingTemplate).arg(response); + continue; } - } - // Augment the prompt template with the results if any - QString docsContext; - if (!localDocsExcerpts.isEmpty()) { - // FIXME(adam): we should be using the new tool template if available otherwise this I guess - QString json = SourceExcerpt::toJson(localDocsExcerpts); - docsContext = u"### Context:\n%1\n\n"_s.arg(json); + // For all other cases we should have a tool calling model + Q_ASSERT(isToolCallingModel); + + // Create the tool calling response as if the model has chosen this particular tool + const QString toolCallingResponse = QString("{\"name\": \"%1\", \"parameters\": {\"").arg(t->function()); + + // Mimic that the model has already responded like this to trigger our tool calling detection + // code and then rely upon it to complete the parameters correctly + m_response = toolCallingResponse.toStdString(); 
+ + // Insert this response as the tool prompt + const QString toolPrompt = QString(promptTemplate).arg(prompt, toolCallingResponse); + + const QString toolCall = completeToolCall(toolPrompt, n_predict, top_k, top_p, + min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, totalTime); + + // If the tool call is empty, then we failed in our attempt to force usage + if (toolCall.isEmpty()) { + qWarning() << "WARNING: Attempt to force usage of toolcall" << t->function() << "failed:" + << "model could not complete parameters for" << toolPrompt; + continue; + } + + QString errorString; + const QString response = executeToolCall(toolCall, producedSourceExcerpts, errorString); + if (response.isEmpty()) { + qWarning() << "WARNING: Attempt to force usage of toolcall" << t->function() << "failed:" << errorString; + continue; + } + + toolResponses << QString(toolCallingTemplate).arg(response); } - qint64 totalTime = 0; - bool producedSourceExcerpts; - bool success = promptRecursive({ docsContext }, prompt, promptTemplate, n_predict, top_k, top_p, + bool success = promptRecursive({ toolResponses }, prompt, promptTemplate, n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, totalTime, producedSourceExcerpts); - + Q_ASSERT(success); SuggestionMode mode = MySettings::globalInstance()->suggestionMode(); - if (mode == SuggestionMode::On || (mode == SuggestionMode::SourceExcerptsOnly && (!localDocsExcerpts.isEmpty() || producedSourceExcerpts))) + if (mode == SuggestionMode::On || (mode == SuggestionMode::SourceExcerptsOnly && producedSourceExcerpts)) generateQuestions(totalTime); else emit responseStopped(totalTime); - return success; } -bool ChatLLM::promptRecursive(const QList &toolContexts, const QString &prompt, +QString ChatLLM::completeToolCall(const QString &prompt, int32_t n_predict, int32_t top_k, float top_p, + float min_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens, + qint64 &totalTime) +{ + if 
(!isModelLoaded()) + return QString(); + + int n_threads = MySettings::globalInstance()->threadCount(); + + m_stopGenerating = false; + auto promptFunc = std::bind(&ChatLLM::handlePrompt, this, std::placeholders::_1); + auto responseFunc = std::bind(&ChatLLM::handleResponse, this, std::placeholders::_1, + std::placeholders::_2); + emit promptProcessing(); + m_ctx.n_predict = n_predict; + m_ctx.top_k = top_k; + m_ctx.top_p = top_p; + m_ctx.min_p = min_p; + m_ctx.temp = temp; + m_ctx.n_batch = n_batch; + m_ctx.repeat_penalty = repeat_penalty; + m_ctx.repeat_last_n = repeat_penalty_tokens; + m_llModelInfo.model->setThreadCount(n_threads); +#if defined(DEBUG) + printf("%s", qPrintable(prompt)); + fflush(stdout); +#endif + + QElapsedTimer elapsedTimer; + elapsedTimer.start(); + m_timer->start(); + + m_checkToolCall = true; + + // We pass in the prompt as the completed template as we're mimicking that the respone has already + // started + LLModel::PromptContext ctx = m_ctx; + m_llModelInfo.model->prompt(prompt.toStdString(), "%1", promptFunc, responseFunc, + /*allowContextShift*/ false, ctx); + + // After the response has been handled reset this state + m_checkToolCall = false; + m_maybeToolCall = false; + + m_timer->stop(); + totalTime = elapsedTimer.elapsed(); + + const QString toolCall = QString::fromStdString(trim_whitespace(m_response)); + m_promptResponseTokens = 0; + m_promptTokens = 0; + m_response = std::string(); + + if (!m_foundToolCall) + return QString(); + + m_foundToolCall = false; + return toolCall; +} + +QString ChatLLM::executeToolCall(const QString &toolCall, bool &producedSourceExcerpts, QString &errorString) +{ + const QString toolTemplate = MySettings::globalInstance()->modelToolTemplate(m_modelInfo); + if (toolTemplate.isEmpty()) { + errorString = QString("ERROR: No valid tool template for this model %1").arg(toolCall); + return QString(); + } + + QJsonParseError err; + const QJsonDocument toolCallDoc = QJsonDocument::fromJson(toolCall.toUtf8(), 
&err); + + if (toolCallDoc.isNull() || err.error != QJsonParseError::NoError || !toolCallDoc.isObject()) { + errorString = QString("ERROR: The tool call had null or invalid json %1").arg(toolCall); + return QString(); + } + + QJsonObject rootObject = toolCallDoc.object(); + if (!rootObject.contains("name") || !rootObject.contains("parameters")) { + errorString = QString("ERROR: The tool call did not have required name and argument objects %1").arg(toolCall); + return QString(); + } + + const QString tool = toolCallDoc["name"].toString(); + const QJsonObject args = toolCallDoc["parameters"].toObject(); + + Tool *toolInstance = ToolModel::globalInstance()->get(tool); + if (!toolInstance) { + errorString = QString("ERROR: Could not find the tool for %1").arg(toolCall); + return QString(); + } + + // FIXME: Honor the confirmation mode feature + // Inform the chat that we're executing a tool call + emit toolCalled(toolInstance->name().toLower()); + + const QString response = toolInstance->run(args, 2000 /*msecs to timeout*/); + if (toolInstance->error() != Error::NoError) { + errorString = QString("ERROR: Tool call produced error: %1").arg(toolInstance->errorString()); + return QString(); + } + + // If the tool supports excerpts then try to parse them here, but it isn't strictly an error + // but rather a warning + if (toolInstance->excerpts()) { + QString parseError; + QList sourceExcerpts = SourceExcerpt::fromJson(response, parseError); + if (!parseError.isEmpty()) { + qWarning() << "WARNING: Could not parse source excerpts for response:" << parseError; + } else if (!sourceExcerpts.isEmpty()) { + producedSourceExcerpts = true; + emit sourceExcerptsChanged(sourceExcerpts); + } + } + return response; +} + +bool ChatLLM::promptRecursive(const QList &toolResponses, const QString &prompt, const QString &promptTemplate, int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens, qint64 
&totalTime, bool &producedSourceExcerpts, bool isRecursiveCall) { @@ -838,8 +1005,8 @@ bool ChatLLM::promptRecursive(const QList &toolContexts, const QString elapsedTimer.start(); m_timer->start(); - // The list of possible additional contexts that come from previous usage of tool calls - for (const QString &context : toolContexts) { + // The list of possible additional responses that come from previous usage of tool calls + for (const QString &context : toolResponses) { auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode context without a response m_llModelInfo.model->prompt(context.toStdString(), "%1", promptFunc, responseFunc, /*allowContextShift*/ true, m_ctx); @@ -869,65 +1036,27 @@ bool ChatLLM::promptRecursive(const QList &toolContexts, const QString if (m_foundToolCall) { m_foundToolCall = false; + QString errorString; const QString toolCall = QString::fromStdString(trimmed); - const QString toolTemplate = MySettings::globalInstance()->modelToolTemplate(m_modelInfo); - if (toolTemplate.isEmpty()) { - qWarning() << "ERROR: No valid tool template for this model" << toolCall; - return handleFailedToolCall(trimmed, totalTime); - } - - QJsonParseError err; - const QJsonDocument toolCallDoc = QJsonDocument::fromJson(toolCall.toUtf8(), &err); - - if (toolCallDoc.isNull() || err.error != QJsonParseError::NoError || !toolCallDoc.isObject()) { - qWarning() << "ERROR: The tool call had null or invalid json " << toolCall; - return handleFailedToolCall(trimmed, totalTime); - } - - QJsonObject rootObject = toolCallDoc.object(); - if (!rootObject.contains("name") || !rootObject.contains("parameters")) { - qWarning() << "ERROR: The tool call did not have required name and argument objects " << toolCall; - return handleFailedToolCall(trimmed, totalTime); - } - - const QString tool = toolCallDoc["name"].toString(); - const QJsonObject args = toolCallDoc["parameters"].toObject(); - - Tool *toolInstance = ToolModel::globalInstance()->get(tool); - if 
(!toolInstance) { - qWarning() << "ERROR: Could not find the tool for " << toolCall; - return handleFailedToolCall(trimmed, totalTime); - } - - // FIXME: Honor the ask before running feature - // Inform the chat that we're executing a tool call - emit toolCalled(toolInstance->name().toLower()); - - const QString response = toolInstance->run(args, 2000 /*msecs to timeout*/); - if (toolInstance->error() != ToolEnums::Error::NoError) { - qWarning() << "ERROR: Tool call produced error:" << toolInstance->errorString(); - return handleFailedToolCall(trimmed, totalTime); - } - - // If the tool supports excerpts then try to parse them here - if (toolInstance->excerpts()) { - QString parseError; - QList sourceExcerpts = SourceExcerpt::fromJson(response, parseError); - if (!parseError.isEmpty()) { - qWarning() << "ERROR: Could not parse source excerpts for response:" << parseError; - } else if (!sourceExcerpts.isEmpty()) { - producedSourceExcerpts = true; - emit sourceExcerptsChanged(sourceExcerpts); - } + const QString toolResponse = executeToolCall(toolCall, producedSourceExcerpts, errorString); + if (toolResponse.isEmpty()) { + // FIXME: Need to surface errors to the UI + // Restore the strings that we excluded previously when detecting the tool call + qWarning() << errorString; + m_response = "" + toolCall.toStdString() + ""; + emit responseChanged(QString::fromStdString(m_response)); + emit responseStopped(totalTime); + m_pristineLoadedState = false; + return false; } + // Reset the state now that we've had a successful tool call response m_promptResponseTokens = 0; m_promptTokens = 0; m_response = std::string(); - // This is a recursive call but isRecursiveCall is checked above to arrest infinite recursive - // tool calls - return promptRecursive(QList()/*tool context*/, response, toolTemplate, + // This is a recursive call but flag is checked above to arrest infinite recursive tool calls + return promptRecursive({ toolResponse }, prompt, promptTemplate, n_predict, 
top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, totalTime, producedSourceExcerpts, true /*isRecursiveCall*/); } else { @@ -940,17 +1069,6 @@ bool ChatLLM::promptRecursive(const QList &toolContexts, const QString } } -bool ChatLLM::handleFailedToolCall(const std::string &response, qint64 elapsed) -{ - // FIXME: Need to surface errors to the UI - // Restore the strings that we excluded previously when detecting the tool call - m_response = "" + response + ""; - emit responseChanged(QString::fromStdString(m_response)); - emit responseStopped(elapsed); - m_pristineLoadedState = false; - return true; -} - void ChatLLM::setShouldBeLoaded(bool b) { #if defined(DEBUG_MODEL_LOADING) diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/chatllm.h index 7622c3661446..a4b8bfe8b9f7 100644 --- a/gpt4all-chat/chatllm.h +++ b/gpt4all-chat/chatllm.h @@ -200,7 +200,6 @@ public Q_SLOTS: bool promptInternal(const QList &collectionList, const QString &prompt, const QString &promptTemplate, int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens); - bool handleFailedToolCall(const std::string &toolCall, qint64 elapsed); bool handlePrompt(int32_t token); bool handleResponse(int32_t token, const std::string &response); bool handleNamePrompt(int32_t token); @@ -220,6 +219,10 @@ public Q_SLOTS: quint32 m_promptResponseTokens; private: + QString completeToolCall(const QString &promptTemplate, int32_t n_predict, int32_t top_k, float top_p, + float min_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens, + qint64 &totalTime); + QString executeToolCall(const QString &toolCall, bool &producedSourceExcerpts, QString &errorString); bool promptRecursive(const QList &toolContexts, const QString &prompt, const QString &promptTemplate, int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty, int32_t 
repeat_penalty_tokens, qint64 &totalTime, bool &producedSourceExcerpts, bool isRecursiveCall = false); diff --git a/gpt4all-chat/localdocssearch.cpp b/gpt4all-chat/localdocssearch.cpp index 424c99a1cb7f..d36a50eadfc0 100644 --- a/gpt4all-chat/localdocssearch.cpp +++ b/gpt4all-chat/localdocssearch.cpp @@ -1,6 +1,7 @@ #include "localdocssearch.h" #include "database.h" #include "localdocs.h" +#include "mysettings.h" #include #include @@ -14,12 +15,16 @@ using namespace Qt::Literals::StringLiterals; QString LocalDocsSearch::run(const QJsonObject ¶meters, qint64 timeout) { + // Reset the error state + m_error = ToolEnums::Error::NoError; + m_errorString = QString(); + QList collections; QJsonArray collectionsArray = parameters["collections"].toArray(); for (int i = 0; i < collectionsArray.size(); ++i) collections.append(collectionsArray[i].toString()); - const QString text = parameters["text"].toString(); - const int count = parameters["count"].toInt(); + const QString text = parameters["query"].toString(); + const int count = MySettings::globalInstance()->localDocsRetrievalSize(); QThread workerThread; LocalDocsWorker worker; worker.moveToThread(&workerThread); @@ -71,6 +76,16 @@ QJsonObject LocalDocsSearch::paramSchema() const return localJsonDoc.object(); } +QJsonObject LocalDocsSearch::exampleParams() const +{ + static const QString example = R"({ + "query": "the 44th president of the United States" + })"; + static const QJsonDocument exampleDoc = QJsonDocument::fromJson(example.toUtf8()); + Q_ASSERT(!exampleDoc.isNull() && exampleDoc.isObject()); + return exampleDoc.object(); +} + LocalDocsWorker::LocalDocsWorker() : QObject(nullptr) { diff --git a/gpt4all-chat/localdocssearch.h b/gpt4all-chat/localdocssearch.h index be9ed04f3b37..91e8e943938a 100644 --- a/gpt4all-chat/localdocssearch.h +++ b/gpt4all-chat/localdocssearch.h @@ -38,6 +38,7 @@ class LocalDocsSearch : public Tool { QString description() const override { return tr("Search the local docs"); } QString 
function() const override { return "localdocs_search"; } ToolEnums::PrivacyScope privacyScope() const override { return ToolEnums::PrivacyScope::Local; } + QJsonObject exampleParams() const override; QJsonObject paramSchema() const override; bool isBuiltin() const override { return true; } ToolEnums::UsageMode usageMode() const override { return ToolEnums::UsageMode::ForceUsage; } diff --git a/gpt4all-chat/modellist.cpp b/gpt4all-chat/modellist.cpp index 5d3530c1c650..a239a47c2152 100644 --- a/gpt4all-chat/modellist.cpp +++ b/gpt4all-chat/modellist.cpp @@ -367,6 +367,17 @@ void ModelInfo::setSuggestedFollowUpPrompt(const QString &p) m_suggestedFollowUpPrompt = p; } +bool ModelInfo::isToolCalling() const +{ + return MySettings::globalInstance()->modelIsToolCalling(*this); +} + +void ModelInfo::setIsToolCalling(bool b) +{ + if (shouldSaveMetadata()) MySettings::globalInstance()->setModelIsToolCalling(*this, b, true /*force*/); + m_isToolCalling = b; +} + bool ModelInfo::shouldSaveMetadata() const { return installed && (isClone() || isDiscovered() || description() == "" /*indicates sideloaded*/); @@ -400,6 +411,7 @@ QVariantMap ModelInfo::getFields() const { "systemPromptTemplate",m_systemPromptTemplate }, { "chatNamePrompt", m_chatNamePrompt }, { "suggestedFollowUpPrompt", m_suggestedFollowUpPrompt }, + { "isToolCalling", m_isToolCalling }, }; } @@ -518,6 +530,7 @@ ModelList::ModelList() connect(MySettings::globalInstance(), &MySettings::promptTemplateChanged, this, &ModelList::updateDataForSettings); connect(MySettings::globalInstance(), &MySettings::toolTemplateChanged, this, &ModelList::updateDataForSettings); connect(MySettings::globalInstance(), &MySettings::systemPromptChanged, this, &ModelList::updateDataForSettings); + connect(MySettings::globalInstance(), &MySettings::isToolCallingChanged, this, &ModelList::updateDataForSettings); connect(&m_networkManager, &QNetworkAccessManager::sslErrors, this, &ModelList::handleSslErrors); updateModelsFromJson(); @@ 
-803,7 +816,8 @@ QVariant ModelList::dataInternal(const ModelInfo *info, int role) const return info->downloads(); case RecencyRole: return info->recency(); - + case IsToolCallingRole: + return info->isToolCalling(); } return QVariant(); @@ -999,6 +1013,8 @@ void ModelList::updateData(const QString &id, const QVector } break; } + case IsToolCallingRole: + info->setIsToolCalling(value.toBool()); break; } } @@ -1573,6 +1589,8 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save) data.append({ ModelList::ToolTemplateRole, obj["toolTemplate"].toString() }); if (obj.contains("systemPrompt")) data.append({ ModelList::SystemPromptRole, obj["systemPrompt"].toString() }); + if (obj.contains("isToolCalling")) + data.append({ ModelList::IsToolCallingRole, obj["isToolCalling"].toBool() }); updateData(id, data); } @@ -1888,6 +1906,10 @@ void ModelList::updateModelsFromSettings() const QString suggestedFollowUpPrompt = settings.value(g + "/suggestedFollowUpPrompt").toString(); data.append({ ModelList::SuggestedFollowUpPromptRole, suggestedFollowUpPrompt }); } + if (settings.contains(g + "/isToolCalling")) { + const bool isToolCalling = settings.value(g + "/isToolCalling").toBool(); + data.append({ ModelList::IsToolCallingRole, isToolCalling }); + } updateData(id, data); } } diff --git a/gpt4all-chat/modellist.h b/gpt4all-chat/modellist.h index 766807461da5..9cc217bf1e15 100644 --- a/gpt4all-chat/modellist.h +++ b/gpt4all-chat/modellist.h @@ -75,6 +75,7 @@ struct ModelInfo { Q_PROPERTY(int likes READ likes WRITE setLikes) Q_PROPERTY(int downloads READ downloads WRITE setDownloads) Q_PROPERTY(QDateTime recency READ recency WRITE setRecency) + Q_PROPERTY(bool isToolCalling READ isToolCalling WRITE setIsToolCalling) public: enum HashAlgorithm { @@ -118,6 +119,9 @@ struct ModelInfo { QDateTime recency() const; void setRecency(const QDateTime &r); + bool isToolCalling() const; + void setIsToolCalling(bool b); + QString dirpath; QString filesize; QByteArray 
hash; @@ -223,6 +227,7 @@ struct ModelInfo { QString m_systemPromptTemplate = "### System:\nYou are an AI assistant who gives a quality response to whatever humans ask of you.\n\n"; QString m_chatNamePrompt = "Describe the above conversation in seven words or less."; QString m_suggestedFollowUpPrompt = "Suggest three very short factual follow-up questions that have not been answered yet or cannot be found inspired by the previous conversation and excerpts."; + bool m_isToolCalling = false; friend class MySettings; }; Q_DECLARE_METATYPE(ModelInfo) @@ -351,7 +356,8 @@ class ModelList : public QAbstractListModel MinPRole, LikesRole, DownloadsRole, - RecencyRole + RecencyRole, + IsToolCallingRole }; QHash roleNames() const override diff --git a/gpt4all-chat/mysettings.cpp b/gpt4all-chat/mysettings.cpp index 999efd6b321f..b70500328e1b 100644 --- a/gpt4all-chat/mysettings.cpp +++ b/gpt4all-chat/mysettings.cpp @@ -64,9 +64,9 @@ static const QVariantMap basicDefaults { { "localdocs/nomicAPIKey", "" }, { "localdocs/embedDevice", "Auto" }, { "network/attribution", "" }, - { "websearch/usageMode", QVariant::fromValue(UsageMode::Disabled) }, { "websearch/retrievalSize", 2 }, - { "websearch/askBeforeRunning", false }, + { "websearch/usageMode", QVariant::fromValue(UsageMode::Disabled) }, + { "websearch/confirmationMode", QVariant::fromValue(ConfirmationMode::NoConfirmation) }, { "bravesearch/APIKey", "" }, }; @@ -203,6 +203,7 @@ void MySettings::restoreModelDefaults(const ModelInfo &info) setModelRepeatPenaltyTokens(info, info.m_repeatPenaltyTokens); setModelPromptTemplate(info, info.m_promptTemplate); setModelToolTemplate(info, info.m_toolTemplate); + setModelIsToolCalling(info, info.m_isToolCalling); setModelSystemPromptTemplate(info, info.m_systemPromptTemplate); setModelChatNamePrompt(info, info.m_chatNamePrompt); setModelSuggestedFollowUpPrompt(info, info.m_suggestedFollowUpPrompt); @@ -239,7 +240,7 @@ void MySettings::restoreWebSearchDefaults() { 
setWebSearchUsageMode(basicDefaults.value("websearch/usageMode").value()); setWebSearchRetrievalSize(basicDefaults.value("websearch/retrievalSize").toInt()); - setWebSearchAskBeforeRunning(basicDefaults.value("websearch/askBeforeRunning").toBool()); + setWebSearchConfirmationMode(basicDefaults.value("websearch/confirmationMode").value()); setBraveSearchAPIKey(basicDefaults.value("bravesearch/APIKey").toString()); } @@ -314,6 +315,7 @@ double MySettings::modelRepeatPenalty (const ModelInfo &info) const int MySettings::modelRepeatPenaltyTokens (const ModelInfo &info) const { return getModelSetting("repeatPenaltyTokens", info).toInt(); } QString MySettings::modelPromptTemplate (const ModelInfo &info) const { return getModelSetting("promptTemplate", info).toString(); } QString MySettings::modelToolTemplate (const ModelInfo &info) const { return getModelSetting("toolTemplate", info).toString(); } +bool MySettings::modelIsToolCalling (const ModelInfo &info) const { return getModelSetting("isToolCalling", info).toBool(); } QString MySettings::modelSystemPromptTemplate (const ModelInfo &info) const { return getModelSetting("systemPrompt", info).toString(); } QString MySettings::modelChatNamePrompt (const ModelInfo &info) const { return getModelSetting("chatNamePrompt", info).toString(); } QString MySettings::modelSuggestedFollowUpPrompt(const ModelInfo &info) const { return getModelSetting("suggestedFollowUpPrompt", info).toString(); } @@ -428,6 +430,11 @@ void MySettings::setModelToolTemplate(const ModelInfo &info, const QString &valu setModelSetting("toolTemplate", info, value, force, true); } +void MySettings::setModelIsToolCalling(const ModelInfo &info, bool value, bool force) +{ + setModelSetting("isToolCalling", info, value, force, true); +} + void MySettings::setModelSystemPromptTemplate(const ModelInfo &info, const QString &value, bool force) { setModelSetting("systemPrompt", info, value, force, true); @@ -481,8 +488,8 @@ QString MySettings::localDocsEmbedDevice() 
const { return getBasicSetting QString MySettings::networkAttribution() const { return getBasicSetting("network/attribution" ).toString(); } QString MySettings::braveSearchAPIKey() const { return getBasicSetting("bravesearch/APIKey" ).toString(); } int MySettings::webSearchRetrievalSize() const { return getBasicSetting("websearch/retrievalSize").toInt(); } -bool MySettings::webSearchAskBeforeRunning() const { return getBasicSetting("websearch/askBeforeRunning").toBool(); } -UsageMode MySettings::webSearchUsageMode() const { return getBasicSetting("websearch/usageMode").value(); } +UsageMode MySettings::webSearchUsageMode() const { return getBasicSetting("websearch/usageMode").value(); } +ConfirmationMode MySettings::webSearchConfirmationMode() const { return getBasicSetting("websearch/confirmationMode").value(); } ChatTheme MySettings::chatTheme() const { return ChatTheme (getEnumSetting("chatTheme", chatThemeNames)); } FontSize MySettings::fontSize() const { return FontSize (getEnumSetting("fontSize", fontSizeNames)); } @@ -502,9 +509,9 @@ void MySettings::setLocalDocsNomicAPIKey(const QString &value) { setBasic void MySettings::setLocalDocsEmbedDevice(const QString &value) { setBasicSetting("localdocs/embedDevice", value, "localDocsEmbedDevice"); } void MySettings::setNetworkAttribution(const QString &value) { setBasicSetting("network/attribution", value, "networkAttribution"); } void MySettings::setBraveSearchAPIKey(const QString &value) { setBasicSetting("bravesearch/APIKey", value, "braveSearchAPIKey"); } -void MySettings::setWebSearchUsageMode(ToolEnums::UsageMode value) { setBasicSetting("websearch/usageMode", int(value), "webSearchUsageMode"); } void MySettings::setWebSearchRetrievalSize(int value) { setBasicSetting("websearch/retrievalSize", value, "webSearchRetrievalSize"); } -void MySettings::setWebSearchAskBeforeRunning(bool value) { setBasicSetting("websearch/askBeforeRunning", value, "webSearchAskBeforeRunning"); } +void 
MySettings::setWebSearchUsageMode(ToolEnums::UsageMode value) { setBasicSetting("websearch/usageMode", int(value), "webSearchUsageMode"); } +void MySettings::setWebSearchConfirmationMode(ToolEnums::ConfirmationMode value) { setBasicSetting("websearch/confirmationMode", int(value), "webSearchConfirmationMode"); } void MySettings::setChatTheme(ChatTheme value) { setBasicSetting("chatTheme", chatThemeNames .value(int(value))); } void MySettings::setFontSize(FontSize value) { setBasicSetting("fontSize", fontSizeNames .value(int(value))); } @@ -717,10 +724,14 @@ QString MySettings::systemPromptInternal(const QString &proposedTemplate, QStrin params.insert({"currentDate", QDate::currentDate().toString().toStdString()}); jinja2::ValuesList toolList; - int c = ToolModel::globalInstance()->count(); - for (int i = 0; i < c; ++i) { + const int toolCount = ToolModel::globalInstance()->count(); + for (int i = 0; i < toolCount; ++i) { Tool *t = ToolModel::globalInstance()->get(i); - if (t->usageMode() == UsageMode::Enabled) + // FIXME: For now we don't tell the model about the localdocs search in the system prompt because + // it will try to call the localdocs search even if no collection is selected. 
Ideally, we need + a way to update the model as to whether a tool is enabled/disabled either via reprocessing the system + prompt or sending a system message as it happens + if (t->usageMode() != UsageMode::Disabled && t->function() != "localdocs_search") toolList.push_back(t->jinjaValue()); } params.insert({"toolList", toolList}); diff --git a/gpt4all-chat/mysettings.h b/gpt4all-chat/mysettings.h index 1518412b01f5..20f8913e6f57 100644 --- a/gpt4all-chat/mysettings.h +++ b/gpt4all-chat/mysettings.h @@ -73,9 +73,9 @@ class MySettings : public QObject Q_PROPERTY(int networkPort READ networkPort WRITE setNetworkPort NOTIFY networkPortChanged) Q_PROPERTY(SuggestionMode suggestionMode READ suggestionMode WRITE setSuggestionMode NOTIFY suggestionModeChanged) Q_PROPERTY(QStringList uiLanguages MEMBER m_uiLanguages CONSTANT) - Q_PROPERTY(ToolEnums::UsageMode webSearchUsageMode READ webSearchUsageMode WRITE setWebSearchUsageMode NOTIFY webSearchUsageModeChanged) Q_PROPERTY(int webSearchRetrievalSize READ webSearchRetrievalSize WRITE setWebSearchRetrievalSize NOTIFY webSearchRetrievalSizeChanged) - Q_PROPERTY(bool webSearchAskBeforeRunning READ webSearchAskBeforeRunning WRITE setWebSearchAskBeforeRunning NOTIFY webSearchAskBeforeRunningChanged) + Q_PROPERTY(ToolEnums::UsageMode webSearchUsageMode READ webSearchUsageMode WRITE setWebSearchUsageMode NOTIFY webSearchUsageModeChanged) + Q_PROPERTY(ToolEnums::ConfirmationMode webSearchConfirmationMode READ webSearchConfirmationMode WRITE setWebSearchConfirmationMode NOTIFY webSearchConfirmationModeChanged) Q_PROPERTY(QString braveSearchAPIKey READ braveSearchAPIKey WRITE setBraveSearchAPIKey NOTIFY braveSearchAPIKeyChanged) public: @@ -133,6 +133,8 @@ class MySettings : public QObject Q_INVOKABLE void setModelPromptTemplate(const ModelInfo &info, const QString &value, bool force = false); QString modelToolTemplate(const ModelInfo &info) const; Q_INVOKABLE void setModelToolTemplate(const ModelInfo &info, const QString &value, bool 
force = false); + bool modelIsToolCalling(const ModelInfo &info) const; + Q_INVOKABLE void setModelIsToolCalling(const ModelInfo &info, bool value, bool force = false); QString modelSystemPromptTemplate(const ModelInfo &info) const; Q_INVOKABLE void setModelSystemPromptTemplate(const ModelInfo &info, const QString &value, bool force = false); int modelContextLength(const ModelInfo &info) const; @@ -194,12 +196,12 @@ class MySettings : public QObject void setLocalDocsEmbedDevice(const QString &value); // Web search settings - ToolEnums::UsageMode webSearchUsageMode() const; - void setWebSearchUsageMode(ToolEnums::UsageMode value); int webSearchRetrievalSize() const; void setWebSearchRetrievalSize(int value); - bool webSearchAskBeforeRunning() const; - void setWebSearchAskBeforeRunning(bool value); + ToolEnums::UsageMode webSearchUsageMode() const; + void setWebSearchUsageMode(ToolEnums::UsageMode value); + ToolEnums::ConfirmationMode webSearchConfirmationMode() const; + void setWebSearchConfirmationMode(ToolEnums::ConfirmationMode value); QString braveSearchAPIKey() const; void setBraveSearchAPIKey(const QString &value); @@ -238,6 +240,7 @@ class MySettings : public QObject void systemPromptChanged(const ModelInfo &info); void chatNamePromptChanged(const ModelInfo &info); void suggestedFollowUpPromptChanged(const ModelInfo &info); + void isToolCallingChanged(const ModelInfo &info); void threadCountChanged(); void saveChatsContextChanged(); void serverChatChanged(); @@ -262,9 +265,10 @@ class MySettings : public QObject void deviceChanged(); void suggestionModeChanged(); void languageAndLocaleChanged(); + void webSearchRetrievalSizeChanged(); + // FIXME: These are never emitted along with a lot of the signals above probably with all kinds of bugs!! 
void webSearchUsageModeChanged(); - void webSearchRetrievalSizeChanged() const; - void webSearchAskBeforeRunningChanged() const; + void webSearchConfirmationModeChanged(); void braveSearchAPIKeyChanged(); private: diff --git a/gpt4all-chat/qml/ModelSettings.qml b/gpt4all-chat/qml/ModelSettings.qml index 1f96807d397f..49beb5c9c09a 100644 --- a/gpt4all-chat/qml/ModelSettings.qml +++ b/gpt4all-chat/qml/ModelSettings.qml @@ -153,9 +153,30 @@ MySettingsTab { Layout.fillWidth: true } - RowLayout { + MySettingsLabel { Layout.row: 7 Layout.column: 0 + Layout.columnSpan: 1 + Layout.topMargin: 15 + id: isToolCallingLabel + text: qsTr("Is Tool Calling Model") + helpText: qsTr("Whether the model is capable of tool calling and has tool calling instructions in system prompt.") + } + + MyCheckBox { + Layout.row: 7 + Layout.column: 1 + Layout.topMargin: 15 + id: isToolCallingBox + checked: root.currentModelInfo.isToolCalling + onClicked: { + MySettings.setModelIsToolCalling(root.currentModelInfo, isToolCallingBox.checked); + } + } + + RowLayout { + Layout.row: 8 + Layout.column: 0 Layout.columnSpan: 2 Layout.topMargin: 15 spacing: 10 @@ -182,7 +203,7 @@ MySettingsTab { Rectangle { id: systemPrompt visible: !root.currentModelInfo.isOnline - Layout.row: 8 + Layout.row: 9 Layout.column: 0 Layout.columnSpan: 2 Layout.fillWidth: true @@ -220,7 +241,7 @@ MySettingsTab { } RowLayout { - Layout.row: 9 + Layout.row: 10 Layout.column: 0 Layout.columnSpan: 2 Layout.topMargin: 15 @@ -241,7 +262,7 @@ MySettingsTab { Rectangle { id: promptTemplate - Layout.row: 10 + Layout.row: 11 Layout.column: 0 Layout.columnSpan: 2 Layout.fillWidth: true @@ -276,18 +297,19 @@ MySettingsTab { } MySettingsLabel { - Layout.row: 11 + Layout.row: 12 Layout.column: 0 Layout.columnSpan: 2 Layout.topMargin: 15 id: toolTemplateLabel text: qsTr("Tool Template") - helpText: qsTr("The template that allows tool calls to inject information into the context.") + helpText: qsTr("The template that allows tool calls to inject 
information into the context. Only enabled for tool calling models.") } Rectangle { id: toolTemplate - Layout.row: 12 + enabled: root.currentModelInfo.isToolCalling + Layout.row: 13 Layout.column: 0 Layout.columnSpan: 2 Layout.fillWidth: true @@ -325,14 +347,14 @@ MySettingsTab { id: chatNamePromptLabel text: qsTr("Chat Name Prompt") helpText: qsTr("Prompt used to automatically generate chat names.") - Layout.row: 13 + Layout.row: 14 Layout.column: 0 Layout.topMargin: 15 } Rectangle { id: chatNamePrompt - Layout.row: 14 + Layout.row: 15 Layout.column: 0 Layout.columnSpan: 2 Layout.fillWidth: true @@ -368,14 +390,14 @@ MySettingsTab { id: suggestedFollowUpPromptLabel text: qsTr("Suggested FollowUp Prompt") helpText: qsTr("Prompt used to generate suggested follow-up questions.") - Layout.row: 15 + Layout.row: 16 Layout.column: 0 Layout.topMargin: 15 } Rectangle { id: suggestedFollowUpPrompt - Layout.row: 16 + Layout.row: 17 Layout.column: 0 Layout.columnSpan: 2 Layout.fillWidth: true @@ -408,7 +430,7 @@ MySettingsTab { } GridLayout { - Layout.row: 17 + Layout.row: 18 Layout.column: 0 Layout.columnSpan: 2 Layout.topMargin: 15 @@ -904,7 +926,7 @@ MySettingsTab { } Rectangle { - Layout.row: 18 + Layout.row: 19 Layout.column: 0 Layout.columnSpan: 2 Layout.topMargin: 15 diff --git a/gpt4all-chat/qml/WebSearchSettings.qml b/gpt4all-chat/qml/WebSearchSettings.qml index 901883cded59..b24891148287 100644 --- a/gpt4all-chat/qml/WebSearchSettings.qml +++ b/gpt4all-chat/qml/WebSearchSettings.qml @@ -125,20 +125,21 @@ MySettingsTab { } } - RowLayout { - MySettingsLabel { - id: askBeforeRunningLabel - text: qsTr("Ask before running") - helpText: qsTr("The user is queried whether they want the tool to run in every instance") - } - MyCheckBox { - id: askBeforeRunningBox - checked: MySettings.webSearchAskBeforeRunning - onClicked: { - MySettings.webSearchAskBeforeRunning = !MySettings.webSearchAskBeforeRunning - } - } - } +// FIXME: +// RowLayout { +// MySettingsLabel { +// id: 
askBeforeRunningLabel +// text: qsTr("Ask before running") +// helpText: qsTr("The user is queried whether they want the tool to run in every instance.") +// } +// MyCheckBox { +// id: askBeforeRunningBox +// checked: MySettings.webSearchConfirmationMode +// onClicked: { +// MySettings.webSearchConfirmationMode = !MySettings.webSearchAskBeforeRunning +// } +// } +// } Rectangle { Layout.topMargin: 15 diff --git a/gpt4all-chat/tool.h b/gpt4all-chat/tool.h index 23c4c152a3a1..19ae75f23359 100644 --- a/gpt4all-chat/tool.h +++ b/gpt4all-chat/tool.h @@ -23,6 +23,13 @@ namespace ToolEnums { }; Q_ENUM_NS(UsageMode) + enum class ConfirmationMode { + NoConfirmation = 0, // No confirmation required + AskBeforeRunning = 1, // User is queried on every execution + AskBeforeRunningRecursive = 2, // User is queried if the tool is invoked in a recursive tool call + }; + Q_ENUM_NS(ConfirmationMode) + // Ordered in increasing levels of privacy enum class PrivacyScope { None = 0, // Tool call data does not have any privacy scope @@ -43,7 +50,7 @@ class Tool : public QObject { Q_PROPERTY(QUrl url READ url CONSTANT) Q_PROPERTY(bool isBuiltin READ isBuiltin CONSTANT) Q_PROPERTY(ToolEnums::UsageMode usageMode READ usageMode NOTIFY usageModeChanged) - Q_PROPERTY(bool askBeforeRunning READ askBeforeRunning NOTIFY askBeforeRunningChanged) + Q_PROPERTY(ToolEnums::ConfirmationMode confirmationMode READ confirmationMode NOTIFY confirmationModeChanged) Q_PROPERTY(bool excerpts READ excerpts CONSTANT) public: @@ -63,7 +70,7 @@ class Tool : public QObject { // [Required] Must be unique. Name of the function to invoke. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64. virtual QString function() const = 0; - // [Required] The privacy scope + // [Required] The privacy scope. virtual ToolEnums::PrivacyScope privacyScope() const = 0; // [Optional] Json schema describing the tool's parameters. An empty object specifies no parameters. 
@@ -80,14 +87,14 @@ class Tool : public QObject { // [Optional] The local file or remote resource use to invoke the tool. virtual QUrl url() const { return QUrl(); } - // [Optional] Whether the tool is built-in + // [Optional] Whether the tool is built-in. virtual bool isBuiltin() const { return false; } - // [Optional] The current usage mode + // [Optional] The usage mode. virtual ToolEnums::UsageMode usageMode() const { return ToolEnums::UsageMode::Disabled; } - // [Optional] The user is queried whether they want the tool to run in every instance - virtual bool askBeforeRunning() const { return false; } + // [Optional] The confirmation mode. + virtual ToolEnums::ConfirmationMode confirmationMode() const { return ToolEnums::ConfirmationMode::NoConfirmation; } // [Optional] Whether json result produces source excerpts. virtual bool excerpts() const { return false; } @@ -103,7 +110,7 @@ class Tool : public QObject { Q_SIGNALS: void usageModeChanged(); - void askBeforeRunningChanged(); + void confirmationModeChanged(); }; #endif // TOOL_H diff --git a/gpt4all-chat/toolmodel.h b/gpt4all-chat/toolmodel.h index 4135771e8f9a..bb61b25d9962 100644 --- a/gpt4all-chat/toolmodel.h +++ b/gpt4all-chat/toolmodel.h @@ -25,7 +25,7 @@ class ToolModel : public QAbstractListModel KeyRequiredRole, IsBuiltinRole, UsageModeRole, - AskBeforeRole, + ConfirmationModeRole, ExcerptsRole, }; @@ -58,8 +58,8 @@ class ToolModel : public QAbstractListModel return item->isBuiltin(); case UsageModeRole: return QVariant::fromValue(item->usageMode()); - case AskBeforeRole: - return item->askBeforeRunning(); + case ConfirmationModeRole: + return QVariant::fromValue(item->confirmationMode()); case ExcerptsRole: return item->excerpts(); } @@ -80,7 +80,7 @@ class ToolModel : public QAbstractListModel roles[KeyRequiredRole] = "keyRequired"; roles[IsBuiltinRole] = "isBuiltin"; roles[UsageModeRole] = "usageMode"; - roles[AskBeforeRole] = "askBeforeRunning"; + roles[ConfirmationModeRole] = 
"confirmationMode"; roles[ExcerptsRole] = "excerpts"; return roles; }