Skip to content

Commit 1a0b483

Browse files
committed
Begin converting the localdocs to a tool.
Signed-off-by: Adam Treat <treat.adam@gmail.com>
1 parent 0c3577b commit 1a0b483

File tree

11 files changed

+163
-84
lines changed

11 files changed

+163
-84
lines changed

gpt4all-chat/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ qt_add_executable(chat
118118
database.h database.cpp
119119
download.h download.cpp
120120
embllm.cpp embllm.h
121-
localdocs.h localdocs.cpp localdocsmodel.h localdocsmodel.cpp
121+
localdocs.h localdocs.cpp localdocsmodel.h localdocsmodel.cpp localdocssearch.h localdocssearch.cpp
122122
llm.h llm.cpp
123123
modellist.h modellist.cpp
124124
mysettings.h mysettings.cpp

gpt4all-chat/bravesearch.cpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,8 @@ QString BraveSearch::run(const QJsonObject &parameters, qint64 timeout)
3535
return worker.response();
3636
}
3737

38-
void BraveAPIWorker::request(const QString &apiKey, const QString &query, int topK)
38+
void BraveAPIWorker::request(const QString &apiKey, const QString &query, int count)
3939
{
40-
m_topK = topK;
41-
4240
// Documentation on the brave web search:
4341
// https://api.search.brave.com/app/documentation/web-search/get-started
4442
QUrl jsonUrl("https://api.search.brave.com/res/v1/web/search");
@@ -47,7 +45,7 @@ void BraveAPIWorker::request(const QString &apiKey, const QString &query, int to
4745
//https://api.search.brave.com/app/documentation/web-search/query
4846
QUrlQuery urlQuery;
4947
urlQuery.addQueryItem("q", query);
50-
urlQuery.addQueryItem("count", QString::number(topK));
48+
urlQuery.addQueryItem("count", QString::number(count));
5149
urlQuery.addQueryItem("result_filter", "web");
5250
urlQuery.addQueryItem("extra_snippets", "true");
5351
jsonUrl.setQuery(urlQuery);
@@ -64,7 +62,7 @@ void BraveAPIWorker::request(const QString &apiKey, const QString &query, int to
6462
connect(reply, &QNetworkReply::errorOccurred, this, &BraveAPIWorker::handleErrorOccurred);
6563
}
6664

67-
static QString cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK = 1)
65+
static QString cleanBraveResponse(const QByteArray& jsonResponse)
6866
{
6967
// This parses the response from brave and formats it in json that conforms to the de facto
7068
// standard in SourceExcerpts::fromJson(...)
@@ -77,7 +75,6 @@ static QString cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK
7775

7876
QString query;
7977
QJsonObject searchResponse = document.object();
80-
QJsonObject cleanResponse;
8178
QJsonArray cleanArray;
8279

8380
if (searchResponse.contains("query")) {
@@ -99,14 +96,16 @@ static QString cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK
9996
const int idx = m["index"].toInt();
10097

10198
QJsonObject resultObj = resultsArray[idx].toObject();
102-
QStringList selectedKeys = {"type", "title", "url", "description"};
99+
QStringList selectedKeys = {"type", "title", "url"};
103100
QJsonObject result;
104101
for (const auto& key : selectedKeys)
105102
if (resultObj.contains(key))
106103
result.insert(key, resultObj[key]);
107104

108105
if (resultObj.contains("page_age"))
109106
result.insert("date", resultObj["page_age"]);
107+
else
108+
result.insert("date", QDate::currentDate().toString());
110109

111110
QJsonArray excerpts;
112111
if (resultObj.contains("extra_snippets")) {
@@ -117,12 +116,18 @@ static QString cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK
117116
excerpt.insert("text", snippet);
118117
excerpts.append(excerpt);
119118
}
119+
if (resultObj.contains("description"))
120+
result.insert("description", resultObj["description"]);
121+
} else {
122+
QJsonObject excerpt;
123+
excerpt.insert("text", resultObj["description"]);
120124
}
121125
result.insert("excerpts", excerpts);
122126
cleanArray.append(QJsonValue(result));
123127
}
124128
}
125129

130+
QJsonObject cleanResponse;
126131
cleanResponse.insert("query", query);
127132
cleanResponse.insert("results", cleanArray);
128133
QJsonDocument cleanedDoc(cleanResponse);
@@ -139,12 +144,13 @@ void BraveAPIWorker::handleFinished()
139144
if (jsonReply->error() == QNetworkReply::NoError && jsonReply->isFinished()) {
140145
QByteArray jsonData = jsonReply->readAll();
141146
jsonReply->deleteLater();
142-
m_response = cleanBraveResponse(jsonData, m_topK);
147+
m_response = cleanBraveResponse(jsonData);
143148
} else {
144149
QByteArray jsonData = jsonReply->readAll();
145150
qWarning() << "ERROR: Could not search brave" << jsonReply->error() << jsonReply->errorString() << jsonData;
146151
jsonReply->deleteLater();
147152
}
153+
emit finished();
148154
}
149155

150156
void BraveAPIWorker::handleErrorOccurred(QNetworkReply::NetworkError code)

gpt4all-chat/bravesearch.h

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#ifndef BRAVESEARCH_H
22
#define BRAVESEARCH_H
33

4-
#include "sourceexcerpt.h"
54
#include "tool.h"
65

76
#include <QObject>
@@ -14,14 +13,13 @@ class BraveAPIWorker : public QObject {
1413
public:
1514
BraveAPIWorker()
1615
: QObject(nullptr)
17-
, m_networkManager(nullptr)
18-
, m_topK(1) {}
16+
, m_networkManager(nullptr) {}
1917
virtual ~BraveAPIWorker() {}
2018

2119
QString response() const { return m_response; }
2220

2321
public Q_SLOTS:
24-
void request(const QString &apiKey, const QString &query, int topK);
22+
void request(const QString &apiKey, const QString &query, int count);
2523

2624
Q_SIGNALS:
2725
void finished();
@@ -33,7 +31,6 @@ private Q_SLOTS:
3331
private:
3432
QNetworkAccessManager *m_networkManager;
3533
QString m_response;
36-
int m_topK;
3734
};
3835

3936
class BraveSearch : public Tool {

gpt4all-chat/chatllm.cpp

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#include "bravesearch.h"
44
#include "chat.h"
55
#include "chatapi.h"
6-
#include "localdocs.h"
6+
#include "localdocssearch.h"
77
#include "mysettings.h"
88
#include "network.h"
99

@@ -13,6 +13,7 @@
1313
#include <QGlobalStatic>
1414
#include <QGuiApplication>
1515
#include <QIODevice>
16+
#include <QJsonArray>
1617
#include <QJsonDocument>
1718
#include <QJsonObject>
1819
#include <QMutex>
@@ -128,11 +129,6 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
128129
connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted);
129130
connect(MySettings::globalInstance(), &MySettings::forceMetalChanged, this, &ChatLLM::handleForceMetalChanged);
130131
connect(MySettings::globalInstance(), &MySettings::deviceChanged, this, &ChatLLM::handleDeviceChanged);
131-
132-
// The following are blocking operations and will block the llm thread
133-
connect(this, &ChatLLM::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB,
134-
Qt::BlockingQueuedConnection);
135-
136132
m_llmThread.setObjectName(parent->id());
137133
m_llmThread.start();
138134
}
@@ -778,21 +774,33 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
778774
if (!isModelLoaded())
779775
return false;
780776

781-
QList<SourceExcerpt> databaseResults;
782-
const int retrievalSize = MySettings::globalInstance()->localDocsRetrievalSize();
777+
QList<SourceExcerpt> localDocsExcerpts;
783778
if (!collectionList.isEmpty() && !isToolCallResponse) {
784-
emit requestRetrieveFromDB(collectionList, prompt, retrievalSize, &databaseResults); // blocks
785-
emit sourceExcerptsChanged(databaseResults);
779+
LocalDocsSearch localdocs;
780+
QJsonObject parameters;
781+
parameters.insert("text", prompt);
782+
parameters.insert("count", MySettings::globalInstance()->localDocsRetrievalSize());
783+
parameters.insert("collections", QJsonArray::fromStringList(collectionList));
784+
785+
// FIXME: This has to handle errors of the tool call
786+
const QString localDocsResponse = localdocs.run(parameters, 2000 /*msecs to timeout*/);
787+
788+
QString parseError;
789+
localDocsExcerpts = SourceExcerpt::fromJson(localDocsResponse, parseError);
790+
if (!parseError.isEmpty()) {
791+
qWarning() << "ERROR: Could not parse source excerpts for localdocs response:" << parseError;
792+
} else if (!localDocsExcerpts.isEmpty()) {
793+
emit sourceExcerptsChanged(localDocsExcerpts);
794+
}
786795
}
787796

788797
// Augment the prompt template with the results if any
789798
QString docsContext;
790-
if (!databaseResults.isEmpty()) {
799+
if (!localDocsExcerpts.isEmpty()) {
800+
// FIXME(adam): we should be using the new tool template if available otherwise this I guess
791801
QStringList results;
792-
for (const SourceExcerpt &info : databaseResults)
802+
for (const SourceExcerpt &info : localDocsExcerpts)
793803
results << u"Collection: %1\nPath: %2\nExcerpt: %3"_s.arg(info.collection, info.path, info.text);
794-
795-
// FIXME(jared): use a Jinja prompt template instead of hardcoded Alpaca-style localdocs template
796804
docsContext = u"### Context:\n%1\n\n"_s.arg(results.join("\n\n"));
797805
}
798806

@@ -897,7 +905,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
897905
QString parseError;
898906
QList<SourceExcerpt> sourceExcerpts = SourceExcerpt::fromJson(braveResponse, parseError);
899907
if (!parseError.isEmpty()) {
900-
qWarning() << "ERROR: Could not parse source excerpts for brave response" << parseError;
908+
qWarning() << "ERROR: Could not parse source excerpts for brave response:" << parseError;
901909
} else if (!sourceExcerpts.isEmpty()) {
902910
emit sourceExcerptsChanged(sourceExcerpts);
903911
}
@@ -922,7 +930,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
922930
}
923931

924932
SuggestionMode mode = MySettings::globalInstance()->suggestionMode();
925-
if (mode == SuggestionMode::On || (mode == SuggestionMode::SourceExcerptsOnly && (!databaseResults.isEmpty() || isToolCallResponse)))
933+
if (mode == SuggestionMode::On || (mode == SuggestionMode::SourceExcerptsOnly && (!localDocsExcerpts.isEmpty() || isToolCallResponse)))
926934
generateQuestions(elapsed);
927935
else
928936
emit responseStopped(elapsed);

gpt4all-chat/chatllm.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,6 @@ public Q_SLOTS:
189189
void shouldBeLoadedChanged();
190190
void trySwitchContextRequested(const ModelInfo &modelInfo);
191191
void trySwitchContextOfLoadedModelCompleted(int value);
192-
void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<SourceExcerpt> *results);
193192
void reportSpeed(const QString &speed);
194193
void reportDevice(const QString &device);
195194
void reportFallbackReason(const QString &fallbackReason);

gpt4all-chat/database.cpp

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1938,7 +1938,7 @@ QList<int> Database::searchEmbeddings(const std::vector<float> &query, const QLi
19381938
}
19391939

19401940
void Database::retrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize,
1941-
QList<SourceExcerpt> *results)
1941+
QString &jsonResult)
19421942
{
19431943
#if defined(DEBUG)
19441944
qDebug() << "retrieveFromDB" << collections << text << retrievalSize;
@@ -1960,37 +1960,49 @@ void Database::retrieveFromDB(const QList<QString> &collections, const QString &
19601960
return;
19611961
}
19621962

1963+
QMap<QString, QJsonObject> results;
19631964
while (q.next()) {
19641965
#if defined(DEBUG)
19651966
const int rowid = q.value(0).toInt();
19661967
#endif
1967-
const QString document_path = q.value(2).toString();
1968-
const QString chunk_text = q.value(3).toString();
1969-
const QString date = QDateTime::fromMSecsSinceEpoch(q.value(1).toLongLong()).toString("yyyy, MMMM dd");
19701968
const QString file = q.value(4).toString();
1971-
const QString title = q.value(5).toString();
1972-
const QString author = q.value(6).toString();
1973-
const int page = q.value(7).toInt();
1974-
const int from = q.value(8).toInt();
1975-
const int to = q.value(9).toInt();
1976-
const QString collectionName = q.value(10).toString();
1977-
SourceExcerpt info;
1978-
info.collection = collectionName;
1979-
info.path = document_path;
1980-
info.file = file;
1981-
info.title = title;
1982-
info.author = author;
1983-
info.date = date;
1984-
info.text = chunk_text;
1985-
info.page = page;
1986-
info.from = from;
1987-
info.to = to;
1988-
results->append(info);
1969+
QJsonObject resultObject = results.value(file);
1970+
resultObject.insert("file", file);
1971+
resultObject.insert("path", q.value(2).toString());
1972+
resultObject.insert("date", QDateTime::fromMSecsSinceEpoch(q.value(1).toLongLong()).toString("yyyy, MMMM dd"));
1973+
resultObject.insert("title", q.value(5).toString());
1974+
resultObject.insert("author", q.value(6).toString());
1975+
resultObject.insert("collection", q.value(10).toString());
1976+
1977+
QJsonArray excerpts;
1978+
if (resultObject.contains("excerpts"))
1979+
excerpts = resultObject["excerpts"].toArray();
1980+
1981+
QJsonObject excerptObject;
1982+
excerptObject.insert("text", q.value(3).toString());
1983+
excerptObject.insert("page", q.value(7).toInt());
1984+
excerptObject.insert("from", q.value(8).toInt());
1985+
excerptObject.insert("to", q.value(9).toInt());
1986+
excerpts.append(excerptObject);
1987+
resultObject.insert("excerpts", excerpts);
1988+
results.insert(file, resultObject);
1989+
19891990
#if defined(DEBUG)
19901991
qDebug() << "retrieve rowid:" << rowid
19911992
<< "chunk_text:" << chunk_text;
19921993
#endif
19931994
}
1995+
1996+
QJsonArray resultsArray;
1997+
QList<QJsonObject> resultsList = results.values();
1998+
for (const QJsonObject &result : resultsList)
1999+
resultsArray.append(QJsonValue(result));
2000+
2001+
QJsonObject response;
2002+
response.insert("results", resultsArray);
2003+
QJsonDocument document(response);
2004+
// qDebug().noquote() << document.toJson(QJsonDocument::Indented);
2005+
jsonResult = document.toJson(QJsonDocument::Compact);
19942006
}
19952007

19962008
// FIXME This is very slow and non-interruptible and when we close the application and we're

gpt4all-chat/database.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ public Q_SLOTS:
101101
void forceRebuildFolder(const QString &path);
102102
bool addFolder(const QString &collection, const QString &path, const QString &embedding_model);
103103
void removeFolder(const QString &collection, const QString &path);
104-
void retrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<SourceExcerpt> *results);
104+
void retrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QString &jsonResult);
105105
void changeChunkSize(int chunkSize);
106106
void changeFileExtensions(const QStringList &extensions);
107107

@@ -168,7 +168,6 @@ private Q_SLOTS:
168168
QStringList m_scannedFileExtensions;
169169
QTimer *m_scanTimer;
170170
QMap<int, QQueue<DocumentInfo>> m_docsToScan;
171-
QList<SourceExcerpt> m_retrieve;
172171
QThread m_dbThread;
173172
QFileSystemWatcher *m_watcher;
174173
QSet<QString> m_watchedPaths;

gpt4all-chat/localdocssearch.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#include "localdocssearch.h"
2+
#include "database.h"
3+
#include "localdocs.h"
4+
5+
#include <QCoreApplication>
6+
#include <QDebug>
7+
#include <QGuiApplication>
8+
#include <QJsonArray>
9+
#include <QJsonObject>
10+
#include <QThread>
11+
12+
using namespace Qt::Literals::StringLiterals;
13+
14+
QString LocalDocsSearch::run(const QJsonObject &parameters, qint64 timeout)
15+
{
16+
QList<QString> collections;
17+
QJsonArray collectionsArray = parameters["collections"].toArray();
18+
for (int i = 0; i < collectionsArray.size(); ++i)
19+
collections.append(collectionsArray[i].toString());
20+
const QString text = parameters["text"].toString();
21+
const int count = parameters["count"].toInt();
22+
QThread workerThread;
23+
LocalDocsWorker worker;
24+
worker.moveToThread(&workerThread);
25+
connect(&worker, &LocalDocsWorker::finished, &workerThread, &QThread::quit, Qt::DirectConnection);
26+
connect(&workerThread, &QThread::started, [&worker, collections, text, count]() {
27+
worker.request(collections, text, count);
28+
});
29+
workerThread.start();
30+
workerThread.wait(timeout);
31+
workerThread.quit();
32+
workerThread.wait();
33+
return worker.response();
34+
}
35+
36+
LocalDocsWorker::LocalDocsWorker()
37+
: QObject(nullptr)
38+
{
39+
// The following are blocking operations and will block the calling thread
40+
connect(this, &LocalDocsWorker::requestRetrieveFromDB, LocalDocs::globalInstance()->database(),
41+
&Database::retrieveFromDB, Qt::BlockingQueuedConnection);
42+
}
43+
44+
void LocalDocsWorker::request(const QList<QString> &collections, const QString &text, int count)
45+
{
46+
QString jsonResult;
47+
emit requestRetrieveFromDB(collections, text, count, jsonResult); // blocks
48+
m_response = jsonResult;
49+
emit finished();
50+
}

0 commit comments

Comments
 (0)