@@ -825,26 +825,83 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
825
825
m_ctx.n_predict = old_n_predict; // now we are ready for a response
826
826
}
827
827
828
- m_checkToolCall = !isToolCallResponse; // We can't handle recursive tool calls right now
828
+ // We can't handle recursive tool calls right now; otherwise, we always check whether the
829
+ // response contains a tool call
830
+ m_checkToolCall = !isToolCallResponse;
831
+
829
832
m_llModelInfo.model ->prompt (prompt.toStdString (), promptTemplate.toStdString (), promptFunc, responseFunc, recalcFunc, m_ctx);
833
+
834
+ // After the response has been handled reset this state
830
835
m_checkToolCall = false ;
831
836
m_maybeToolCall = false ;
837
+
832
838
#if defined(DEBUG)
833
839
printf (" \n " );
834
840
fflush (stdout);
835
841
#endif
836
842
m_timer->stop ();
837
843
qint64 elapsed = totalTime.elapsed ();
838
844
std::string trimmed = trim_whitespace (m_response);
845
+
846
+ // If we found a tool call, then deal with it
839
847
if (m_foundToolCall) {
840
848
m_foundToolCall = false ;
849
+
850
+ const QString toolCall = QString::fromStdString (trimmed);
851
+ const QString toolTemplate = MySettings::globalInstance ()->modelToolTemplate (m_modelInfo);
852
+ if (toolTemplate.isEmpty ()) {
853
+ qWarning () << " ERROR: No valid tool template for this model" << toolCall;
854
+ return handleFailedToolCall (trimmed, elapsed);
855
+ }
856
+
857
+ QJsonParseError err;
858
+ const QJsonDocument toolCallDoc = QJsonDocument::fromJson (toolCall.toUtf8 (), &err);
859
+
860
+ if (toolCallDoc.isNull () || err.error != QJsonParseError::NoError || !toolCallDoc.isObject ()) {
861
+ qWarning () << " ERROR: The tool call had null or invalid json " << toolCall;
862
+ return handleFailedToolCall (trimmed, elapsed);
863
+ }
864
+
865
+ QJsonObject rootObject = toolCallDoc.object ();
866
+ if (!rootObject.contains (" name" ) || !rootObject.contains (" arguments" )) {
867
+ qWarning () << " ERROR: The tool call did not have required name and argument objects " << toolCall;
868
+ return handleFailedToolCall (trimmed, elapsed);
869
+ }
870
+
871
+ const QString tool = toolCallDoc[" name" ].toString ();
872
+ const QJsonObject args = toolCallDoc[" arguments" ].toObject ();
873
+
874
+ // FIXME: In the future this will try to match the tool call to a list of tools that are supported
875
+ // according to MySettings, but for now only brave search is supported
876
+ if (tool != " brave_search" || !args.contains (" query" )) {
877
+ qWarning () << " ERROR: Could not find the tool and correct arguments for " << toolCall;
878
+ return handleFailedToolCall (trimmed, elapsed);
879
+ }
880
+
881
+ const QString query = args[" query" ].toString ();
882
+
883
+ // FIXME: This has to handle errors of the tool call
884
+ emit toolCalled (tr (" searching web..." ));
885
+ const QString apiKey = MySettings::globalInstance ()->braveSearchAPIKey ();
886
+ Q_ASSERT (apiKey != " " );
887
+ BraveSearch brave;
888
+ const QPair<QString, QList<SourceExcerpt>> braveResponse = brave.search (apiKey, query, 2 /* topK*/ ,
889
+ 2000 /* msecs to timeout*/ );
890
+ emit sourceExcerptsChanged (braveResponse.second );
891
+
892
+ // Erase the context of the tool call
841
893
m_ctx.n_past = std::max (0 , m_ctx.n_past );
842
894
m_ctx.tokens .erase (m_ctx.tokens .end () - m_promptResponseTokens, m_ctx.tokens .end ());
843
895
m_promptResponseTokens = 0 ;
844
896
m_promptTokens = 0 ;
845
897
m_response = std::string ();
846
- return toolCallInternal (QString::fromStdString (trimmed), n_predict, top_k, top_p, min_p, temp,
847
- n_batch, repeat_penalty, repeat_penalty_tokens);
898
+
899
+ // This is a recursive call but isToolCallResponse is checked above to arrest infinite recursive
900
+ // tool calls
901
+ return promptInternal (QList<QString>()/* collectionList*/ , braveResponse.first , toolTemplate,
902
+ n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens,
903
+ true /* isToolCallResponse*/ );
904
+
848
905
} else {
849
906
if (trimmed != m_response) {
850
907
m_response = trimmed;
@@ -856,65 +913,19 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
856
913
generateQuestions (elapsed);
857
914
else
858
915
emit responseStopped (elapsed);
916
+ m_pristineLoadedState = false ;
917
+ return true ;
859
918
}
860
-
861
- m_pristineLoadedState = false ;
862
- return true ;
863
919
}
864
920
865
- bool ChatLLM::toolCallInternal (const QString &toolCall, int32_t n_predict, int32_t top_k, float top_p,
866
- float min_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens)
921
// handleFailedToolCall: finishes a response whose tool call could not be carried out. In this
// diff promptInternal calls it when the tool template is empty, the tool call JSON is null or
// invalid, the required name/arguments keys are missing, or the tool/arguments are unsupported.
// It re-wraps `response` in the tool_call opening and closing tag strings, stores the result in
// m_response, emits responseChanged with that text and then responseStopped(elapsed), sets
// m_pristineLoadedState to false, and always returns true.
// NOTE(review): the "-" lines below are the removed body of the former toolCallInternal; only
// the "+" lines belong to handleFailedToolCall.
+ bool ChatLLM::handleFailedToolCall (const std::string &response, qint64 elapsed)
867
922
{
868
- QString toolTemplate = MySettings::globalInstance ()->modelToolTemplate (m_modelInfo);
869
- if (toolTemplate.isEmpty ()) {
870
- // FIXME: Not sure what to do here. The model attempted a tool call, but there is no way for
871
- // us to process it. We should probably not even attempt further generation and just show an
872
- // error in the chat somehow?
873
- qWarning () << " WARNING: The model attempted a toolcall, but there is no valid tool template for this model" << toolCall;
874
- return promptInternal (QList<QString>()/* collectionList*/ , QString () /* prompt*/ ,
875
- MySettings::globalInstance ()->modelPromptTemplate (m_modelInfo),
876
- n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /* isToolCallResponse*/ );
877
- }
878
-
879
- QJsonParseError err;
880
- QJsonDocument toolCallDoc = QJsonDocument::fromJson (toolCall.toUtf8 (), &err);
881
-
882
- if (toolCallDoc.isNull () || err.error != QJsonParseError::NoError || !toolCallDoc.isObject ()) {
883
- qWarning () << " WARNING: The tool call had null or invalid json " << toolCall;
884
- return promptInternal (QList<QString>()/* collectionList*/ , QString () /* prompt*/ , toolTemplate,
885
- n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /* isToolCallResponse*/ );
886
- }
887
-
888
- QJsonObject rootObject = toolCallDoc.object ();
889
- if (!rootObject.contains (" name" ) || !rootObject.contains (" arguments" )) {
890
- qWarning () << " WARNING: The tool call did not have required name and argument objects " << toolCall;
891
- return promptInternal (QList<QString>()/* collectionList*/ , QString () /* prompt*/ , toolTemplate,
892
- n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /* isToolCallResponse*/ );
893
- }
894
-
895
- const QString tool = toolCallDoc[" name" ].toString ();
896
- const QJsonObject args = toolCallDoc[" arguments" ].toObject ();
897
-
898
- if (tool != " brave_search" || !args.contains (" query" )) {
899
- qWarning () << " WARNING: Could not find the tool and correct arguments for " << toolCall;
900
- return promptInternal (QList<QString>()/* collectionList*/ , QString () /* prompt*/ , toolTemplate,
901
- n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /* isToolCallResponse*/ );
902
- }
903
-
904
- const QString query = args[" query" ].toString ();
905
-
906
- emit toolCalled (tr (" searching web..." ));
907
-
908
- const QString apiKey = MySettings::globalInstance ()->braveSearchAPIKey ();
909
- Q_ASSERT (apiKey != " " );
910
-
911
- BraveSearch brave;
912
- const QPair<QString, QList<SourceExcerpt>> braveResponse = brave.search (apiKey, query, 2 /* topK*/ , 2000 /* msecs to timeout*/ );
913
-
914
- emit sourceExcerptsChanged (braveResponse.second );
915
-
916
- return promptInternal (QList<QString>()/* collectionList*/ , braveResponse.first , toolTemplate,
917
- n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /* isToolCallResponse*/ );
923
+ // Restore the strings that we excluded previously when detecting the tool call
924
+ m_response = " <tool_call>" + response + " </tool_call>" ;
925
+ emit responseChanged (QString::fromStdString (m_response));
926
+ emit responseStopped (elapsed);
927
+ m_pristineLoadedState = false ;
928
+ return true ;
918
929
}
919
930
920
931
void ChatLLM::setShouldBeLoaded (bool b)
0 commit comments