Skip to content

Commit 3f8ee0e

Browse files
committed
Refactor to handle errors in tool calling better and add source comments.
Signed-off-by: Adam Treat <treat.adam@gmail.com>
1 parent ffd2bcc commit 3f8ee0e

File tree

2 files changed

+70
-60
lines changed

2 files changed

+70
-60
lines changed

gpt4all-chat/chatllm.cpp

Lines changed: 69 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -825,26 +825,83 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
825825
m_ctx.n_predict = old_n_predict; // now we are ready for a response
826826
}
827827

828-
m_checkToolCall = !isToolCallResponse; // We can't handle recursive tool calls right now
828+
// We can't handle recursive tool calls right now; otherwise we always try to check if we have a
829+
// tool call
830+
m_checkToolCall = !isToolCallResponse;
831+
829832
m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
833+
834+
// After the response has been handled reset this state
830835
m_checkToolCall = false;
831836
m_maybeToolCall = false;
837+
832838
#if defined(DEBUG)
833839
printf("\n");
834840
fflush(stdout);
835841
#endif
836842
m_timer->stop();
837843
qint64 elapsed = totalTime.elapsed();
838844
std::string trimmed = trim_whitespace(m_response);
845+
846+
// If we found a tool call, then deal with it
839847
if (m_foundToolCall) {
840848
m_foundToolCall = false;
849+
850+
const QString toolCall = QString::fromStdString(trimmed);
851+
const QString toolTemplate = MySettings::globalInstance()->modelToolTemplate(m_modelInfo);
852+
if (toolTemplate.isEmpty()) {
853+
qWarning() << "ERROR: No valid tool template for this model" << toolCall;
854+
return handleFailedToolCall(trimmed, elapsed);
855+
}
856+
857+
QJsonParseError err;
858+
const QJsonDocument toolCallDoc = QJsonDocument::fromJson(toolCall.toUtf8(), &err);
859+
860+
if (toolCallDoc.isNull() || err.error != QJsonParseError::NoError || !toolCallDoc.isObject()) {
861+
qWarning() << "ERROR: The tool call had null or invalid json " << toolCall;
862+
return handleFailedToolCall(trimmed, elapsed);
863+
}
864+
865+
QJsonObject rootObject = toolCallDoc.object();
866+
if (!rootObject.contains("name") || !rootObject.contains("arguments")) {
867+
qWarning() << "ERROR: The tool call did not have required name and argument objects " << toolCall;
868+
return handleFailedToolCall(trimmed, elapsed);
869+
}
870+
871+
const QString tool = toolCallDoc["name"].toString();
872+
const QJsonObject args = toolCallDoc["arguments"].toObject();
873+
874+
// FIXME: In the future this will try to match the tool call to a list of tools that are supported
875+
// according to MySettings, but for now only brave search is supported
876+
if (tool != "brave_search" || !args.contains("query")) {
877+
qWarning() << "ERROR: Could not find the tool and correct arguments for " << toolCall;
878+
return handleFailedToolCall(trimmed, elapsed);
879+
}
880+
881+
const QString query = args["query"].toString();
882+
883+
// FIXME: This has to handle errors of the tool call
884+
emit toolCalled(tr("searching web..."));
885+
const QString apiKey = MySettings::globalInstance()->braveSearchAPIKey();
886+
Q_ASSERT(apiKey != "");
887+
BraveSearch brave;
888+
const QPair<QString, QList<SourceExcerpt>> braveResponse = brave.search(apiKey, query, 2 /*topK*/,
889+
2000 /*msecs to timeout*/);
890+
emit sourceExcerptsChanged(braveResponse.second);
891+
892+
// Erase the context of the tool call
841893
m_ctx.n_past = std::max(0, m_ctx.n_past);
842894
m_ctx.tokens.erase(m_ctx.tokens.end() - m_promptResponseTokens, m_ctx.tokens.end());
843895
m_promptResponseTokens = 0;
844896
m_promptTokens = 0;
845897
m_response = std::string();
846-
return toolCallInternal(QString::fromStdString(trimmed), n_predict, top_k, top_p, min_p, temp,
847-
n_batch, repeat_penalty, repeat_penalty_tokens);
898+
899+
// This is a recursive call but isToolCallResponse is checked above to arrest infinite recursive
900+
// tool calls
901+
return promptInternal(QList<QString>()/*collectionList*/, braveResponse.first, toolTemplate,
902+
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens,
903+
true /*isToolCallResponse*/);
904+
848905
} else {
849906
if (trimmed != m_response) {
850907
m_response = trimmed;
@@ -856,65 +913,19 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
856913
generateQuestions(elapsed);
857914
else
858915
emit responseStopped(elapsed);
916+
m_pristineLoadedState = false;
917+
return true;
859918
}
860-
861-
m_pristineLoadedState = false;
862-
return true;
863919
}
864920

865-
bool ChatLLM::toolCallInternal(const QString &toolCall, int32_t n_predict, int32_t top_k, float top_p,
866-
float min_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens)
921+
bool ChatLLM::handleFailedToolCall(const std::string &response, qint64 elapsed)
867922
{
868-
QString toolTemplate = MySettings::globalInstance()->modelToolTemplate(m_modelInfo);
869-
if (toolTemplate.isEmpty()) {
870-
// FIXME: Not sure what to do here. The model attempted a tool call, but there is no way for
871-
// us to process it. We should probably not even attempt further generation and just show an
872-
// error in the chat somehow?
873-
qWarning() << "WARNING: The model attempted a toolcall, but there is no valid tool template for this model" << toolCall;
874-
return promptInternal(QList<QString>()/*collectionList*/, QString() /*prompt*/,
875-
MySettings::globalInstance()->modelPromptTemplate(m_modelInfo),
876-
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/);
877-
}
878-
879-
QJsonParseError err;
880-
QJsonDocument toolCallDoc = QJsonDocument::fromJson(toolCall.toUtf8(), &err);
881-
882-
if (toolCallDoc.isNull() || err.error != QJsonParseError::NoError || !toolCallDoc.isObject()) {
883-
qWarning() << "WARNING: The tool call had null or invalid json " << toolCall;
884-
return promptInternal(QList<QString>()/*collectionList*/, QString() /*prompt*/, toolTemplate,
885-
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/);
886-
}
887-
888-
QJsonObject rootObject = toolCallDoc.object();
889-
if (!rootObject.contains("name") || !rootObject.contains("arguments")) {
890-
qWarning() << "WARNING: The tool call did not have required name and argument objects " << toolCall;
891-
return promptInternal(QList<QString>()/*collectionList*/, QString() /*prompt*/, toolTemplate,
892-
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/);
893-
}
894-
895-
const QString tool = toolCallDoc["name"].toString();
896-
const QJsonObject args = toolCallDoc["arguments"].toObject();
897-
898-
if (tool != "brave_search" || !args.contains("query")) {
899-
qWarning() << "WARNING: Could not find the tool and correct arguments for " << toolCall;
900-
return promptInternal(QList<QString>()/*collectionList*/, QString() /*prompt*/, toolTemplate,
901-
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/);
902-
}
903-
904-
const QString query = args["query"].toString();
905-
906-
emit toolCalled(tr("searching web..."));
907-
908-
const QString apiKey = MySettings::globalInstance()->braveSearchAPIKey();
909-
Q_ASSERT(apiKey != "");
910-
911-
BraveSearch brave;
912-
const QPair<QString, QList<SourceExcerpt>> braveResponse = brave.search(apiKey, query, 2 /*topK*/, 2000 /*msecs to timeout*/);
913-
914-
emit sourceExcerptsChanged(braveResponse.second);
915-
916-
return promptInternal(QList<QString>()/*collectionList*/, braveResponse.first, toolTemplate,
917-
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/);
923+
// Restore the strings that we excluded previously when detecting the tool call
924+
m_response = "<tool_call>" + response + "</tool_call>";
925+
emit responseChanged(QString::fromStdString(m_response));
926+
emit responseStopped(elapsed);
927+
m_pristineLoadedState = false;
928+
return true;
918929
}
919930

920931
void ChatLLM::setShouldBeLoaded(bool b)

gpt4all-chat/chatllm.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,7 @@ public Q_SLOTS:
200200
bool promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
201201
int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
202202
int32_t repeat_penalty_tokens, bool isToolCallResponse = false);
203-
bool toolCallInternal(const QString &toolcall, int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
204-
int32_t repeat_penalty_tokens);
203+
bool handleFailedToolCall(const std::string &toolCall, qint64 elapsed);
205204
bool handlePrompt(int32_t token);
206205
bool handleResponse(int32_t token, const std::string &response);
207206
bool handleRecalculate(bool isRecalc);

0 commit comments

Comments
 (0)