Skip to content

Commit 7ac8e72

Browse files
author
yuanhang zhao
authored
feature: add execution timeout for remote UDFs (#827)
* feature: add execution timeout for remote UDFs * small adjustments * change the way to store the props and add more unit tests * Fix: disallow duplicate key in remote function settings * add some comments on timeout * add some logics about dealing with illegal queries
1 parent 76a96ce commit 7ac8e72

13 files changed

+392
-59
lines changed

src/Common/ErrorCodes.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,7 @@
694694
M(2611, UNKNOWN_FORMAT_SCHEMA) \
695695
M(2612, AMBIGUOUS_FORMAT_SCHEMA) \
696696
M(2613, UNKNOWN_FORMAT_SCHEMA_TYPE) \
697+
M(2631, DUPLICATE_KEY) \
697698
/* See END */
698699

699700
namespace DB

src/Common/sendRequest.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,10 @@ std::pair<String, Int32> sendRequest(
3333
const String & password,
3434
const String & payload,
3535
const std::vector<std::pair<String, String>> & headers,
36+
/// Timeout second for connect/send/receive
37+
ConnectionTimeouts timeouts,
3638
Poco::Logger * log)
3739
{
38-
/// One second for connect/send/receive
39-
ConnectionTimeouts timeouts({2, 0}, {5, 0}, {10, 0});
4040

4141
PooledHTTPSessionPtr session;
4242
try
@@ -109,6 +109,18 @@ std::pair<String, Int32> sendRequest(
109109
{
110110
session->attachSessionData(e.message());
111111
}
112+
if (e.code() == 1000){
113+
LOG_ERROR(
114+
log,
115+
"Execution timeout from uri={} method={} payload={} query_id={} error={} exception={}",
116+
uri.toString(),
117+
method,
118+
payload,
119+
query_id,
120+
e.message(),
121+
getCurrentExceptionMessage(true, true));
122+
return {"Execution timeout", toHTTPCode(e)};
123+
}
112124

113125
LOG_ERROR(
114126
log,

src/Common/sendRequest.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
#include <base/types.h>
44

5+
#include <IO/ConnectionTimeouts.h>
6+
57
#include <Poco/URI.h>
68

79
#include <utility>
@@ -23,5 +25,6 @@ std::pair<String, Int32> sendRequest(
2325
const String & password,
2426
const String & payload,
2527
const std::vector<std::pair<String, String>> & headers,
28+
ConnectionTimeouts timeouts,
2629
Poco::Logger * log);
2730
}

src/Coordination/MetaStoreConnection.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ std::pair<String, Int32> MetaStoreConnection::forwardRequest(
215215
for (auto it = active_servers.begin(); it != active_servers.end();)
216216
{
217217
Poco::URI uri{fmt::format({METASTORE_URL}, it->getHost(), it->getPort(), uri_parameter)};
218-
auto [response, http_status] = sendRequest(uri, method, query_id, user, password, body, {}, log);
218+
auto [response, http_status] = sendRequest(uri, method, query_id, user, password, body, {}, ConnectionTimeouts({2, 0}/* connect timeout */, {5, 0}/* send timeout */, {10, 0}/* receive timeout */) , log);
219219

220220
if (http_status != Poco::Net::HTTPResponse::HTTP_INTERNAL_SERVER_ERROR)
221221
{

src/Functions/UserDefined/RemoteUserDefinedFunction.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <Functions/UserDefined/UserDefinedFunctionBase.h>
44
#include <IO/ReadBufferFromString.h>
5+
#include <IO/ConnectionTimeouts.h>
56
#include <Processors/Formats/IOutputFormat.h>
67
#include <Processors/Formats/IRowInputFormat.h>
78
#include <Common/sendRequest.h>
@@ -53,6 +54,7 @@ class RemoteUserDefinedFunction final : public UserDefinedFunctionBase
5354
"",
5455
out,
5556
{{config.auth_context.key_name, config.auth_context.key_value}, {"", context->getCurrentQueryId()}},
57+
ConnectionTimeouts({2, 0}/* connect timeout */, {5, 0} /* send timeout */, {static_cast<long>(config.command_execution_timeout_milliseconds / 1000), static_cast<long>((config.command_execution_timeout_milliseconds % 1000u) * 1000u)/* receive timeout */}), /// timeout and limit for connect/send/receive ...
5658
&Poco::Logger::get("UserDefinedFunction"));
5759

5860
if (http_status != Poco::Net::HTTPResponse::HTTP_OK)

src/Functions/UserDefined/UDFHelper.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ createUserDefinedExecutableFunction(ContextPtr context, const std::string & name
255255
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Invalid url for remote UDF, msg: {}", e.message());
256256
}
257257
cfg->command_read_timeout_milliseconds = config.getUInt64(key_in_config + ".command_read_timeout", 10000);
258+
cfg->command_execution_timeout_milliseconds = config.getUInt64(key_in_config + ".command_execution_timeout", 10000);
258259
cfg->auth_method = std::move(auth_method);
259260
cfg->auth_context = std::move(auth_ctx);
260261
};

src/Functions/UserDefined/UserDefinedFunctionConfiguration.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ struct RemoteUserDefinedFunctionConfiguration : public UserDefinedFunctionConfig
8080
/// Timeout for reading data from input format
8181
size_t command_read_timeout_milliseconds = 10000;
8282

83+
/// Timeout for receiving response from remote endpoint
84+
size_t command_execution_timeout_milliseconds = 10000;
85+
8386
/// url of remote endpoint, only available when 'type' is 'remote'
8487
Poco::URI url;
8588
enum AuthMethod

src/Parsers/ASTCreateFunctionQuery.cpp

Lines changed: 42 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include <Parsers/ASTFunction.h>
66

77
/// proton: starts
8+
#include <optional>
9+
#include <Parsers/ASTFunctionWithKeyValueArguments.h>
810
#include <Parsers/ASTLiteral.h>
911
#include <Parsers/ASTNameTypePair.h>
1012
#include <Parsers/formatAST.h>
@@ -72,15 +74,8 @@ void ASTCreateFunctionQuery::formatImpl(const IAST::FormatSettings & settings, I
7274
/// proton: starts
7375
if (is_remote)
7476
{
75-
settings.ostr << fmt::format("\nURL '{}'\n", function_core->as<ASTLiteral>()->value.safeGet<String>());
76-
auto auth_method
77-
= !function_core->children.empty() ? function_core->children[0]->as<ASTLiteral>()->value.safeGet<String>() : "none";
78-
settings.ostr << fmt::format("AUTH_METHOD '{}'\n", auth_method);
79-
if (auth_method != "none")
80-
{
81-
settings.ostr << fmt::format("AUTH_HEADER '{}'\n", function_core->children[1]->as<ASTLiteral>()->value.safeGet<String>());
82-
settings.ostr << fmt::format("AUTH_KEY '{}'\n", function_core->children[2]->as<ASTLiteral>()->value.safeGet<String>());
83-
}
77+
settings.ostr << ' ';
78+
function_core->formatImpl(settings, state, frame);
8479
return;
8580
}
8681
/// proton: ends
@@ -153,21 +148,50 @@ Poco::JSON::Object::Ptr ASTCreateFunctionQuery::toJSON() const
153148
/// remote function
154149
if (is_remote)
155150
{
156-
assert(function_core != nullptr);
157-
inner_func->set("url", function_core->as<ASTLiteral>()->value.safeGet<String>());
158-
// auth
159-
if (!function_core->children.empty())
151+
assert(function_core != nullptr && function_core->as<ASTExpressionList>());
152+
auto keyvalue_list = function_core->as<ASTExpressionList>();
153+
std::optional<String> url;
154+
std::optional<String> auth_method;
155+
std::optional<String> auth_header;
156+
std::optional<String> auth_key;
157+
std::optional<UInt64> execution_timeout;
158+
for (ASTPtr child : keyvalue_list->children)
160159
{
161-
auto auth_method = function_core->children[0]->as<ASTLiteral>()->value.safeGet<String>();
162-
inner_func->set("auth_method", auth_method);
163-
if (auth_method == "auth_header")
160+
auto pair = child->as<ASTPair>();
161+
if (pair != nullptr){
162+
if (pair->first == "url")
163+
url = pair->second->as<ASTLiteral>()->value.safeGet<String>();
164+
else if (pair->first == "auth_method")
165+
auth_method = pair->second->as<ASTLiteral>()->value.safeGet<String>();
166+
else if (pair->first == "auth_header")
167+
auth_header = pair->second->as<ASTLiteral>()->value.safeGet<String>();
168+
else if (pair->first == "auth_key")
169+
auth_key = pair->second->as<ASTLiteral>()->value.safeGet<String>();
170+
else if (pair->first == "execution_timeout")
171+
execution_timeout = pair->second->as<ASTLiteral>()->value.safeGet<UInt64>();
172+
}
173+
}
174+
175+
inner_func->set("url", url.value());
176+
if (auth_method.has_value())
177+
{
178+
inner_func->set("auth_method", auth_method.value());
179+
if (auth_method.value() == "auth_header")
164180
{
165181
Poco::JSON::Object::Ptr auth_context = new Poco::JSON::Object();
166-
auth_context->set("key_name", function_core->children[1]->as<ASTLiteral>()->value.safeGet<String>());
167-
auth_context->set("key_value", function_core->children[2]->as<ASTLiteral>()->value.safeGet<String>());
182+
auth_context->set("key_name", auth_header.value_or(""));
183+
auth_context->set("key_value", auth_key.value_or(""));
168184
inner_func->set("auth_context", auth_context);
169185
}
170186
}
187+
else
188+
{
189+
inner_func->set("auth_method", "none");
190+
}
191+
if (execution_timeout.has_value())
192+
{
193+
inner_func->set("command_execution_timeout", execution_timeout.value());
194+
}
171195
func->set("function", inner_func);
172196
/// Remote function don't have source, return early.
173197
return func;

src/Parsers/ParserCreateFunctionQuery.cpp

Lines changed: 65 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <Parsers/ASTLiteral.h>
1212
#include <Parsers/ParserCreateQuery.h>
1313
#include <Parsers/Streaming/ParserArguments.h>
14+
#include <Parsers/ParserKeyValuePairsSet.h>
1415

1516
#include <Poco/JSON/Object.h>
1617
/// proton: ends
@@ -40,11 +41,14 @@ bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Exp
4041
ParserKeyword s_auth_method("AUTH_METHOD");
4142
ParserKeyword s_auth_header("AUTH_HEADER");
4243
ParserKeyword s_auth_key("AUTH_KEY");
44+
ParserKeyword s_execution_timeout("EXECUTION_TIMEOUT");
4345
ParserLiteral value;
46+
ASTPtr kv_list;
4447
ASTPtr url;
4548
ASTPtr auth_method;
4649
ASTPtr auth_header;
4750
ASTPtr auth_key;
51+
ASTPtr execution_timeout;
4852
ParserArguments arguments_p;
4953
ParserDataType return_p;
5054
ParserStringLiteral js_src_p;
@@ -139,35 +143,74 @@ bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Exp
139143
{
140144
throw Exception("Remote udf can not be an aggregate function", ErrorCodes::AGGREGATE_FUNCTION_NOT_APPLICABLE);
141145
}
142-
if (!s_url.ignore(pos, expected))
146+
ParserKeyValuePairsSet kv_pairs_list;
147+
if (!kv_pairs_list.parse(pos, kv_list, expected))
143148
return false;
144-
if (!value.parse(pos, url, expected))
145-
return false;
146-
if (s_auth_method.ignore(pos, expected))
149+
150+
/// check if the parameters are valid and no unsupported or unknown parameters.
151+
std::optional<String> ast_url;
152+
std::optional<String> ast_auth_method;
153+
std::optional<String> ast_auth_header;
154+
std::optional<String> ast_auth_key;
155+
std::optional<UInt64> ast_execution_timeout;
156+
for (const auto & kv : kv_list->children)
157+
{
158+
auto * kv_pair = kv->as<ASTPair>();
159+
auto key = kv_pair->first;
160+
auto pair_value = kv_pair->second->as<ASTLiteral>()->value;
161+
if (!kv_pair)
162+
throw Exception("Key-value pair expected", ErrorCodes::UNKNOWN_FUNCTION);
163+
164+
if (key == "url")
165+
{
166+
ast_url = pair_value.safeGet<String>();
167+
}
168+
else if (key == "auth_method")
169+
{
170+
ast_auth_method = pair_value.safeGet<String>();
171+
if (ast_auth_method.value() != "none" && ast_auth_method.value() != "auth_header")
172+
throw Exception("Unknown auth method", ErrorCodes::UNKNOWN_FUNCTION);
173+
}
174+
else if (key == "auth_header")
175+
{
176+
ast_auth_header = pair_value.safeGet<String>();
177+
}
178+
else if (key == "auth_key")
179+
{
180+
ast_auth_key = pair_value.safeGet<String>();
181+
}
182+
else if (key == "execution_timeout")
183+
{
184+
ast_execution_timeout = pair_value.safeGet<UInt64>();
185+
}
186+
}
187+
/// check if URL is set
188+
if (!ast_url)
189+
throw Exception("URL is required for remote function", ErrorCodes::UNKNOWN_FUNCTION);
190+
/// check if auth_method is "auth_header" or "none"
191+
if (ast_auth_method)
147192
{
148-
if (!value.parse(pos, auth_method, expected))
149-
return false;
150-
auto method_str = auth_method->as<ASTLiteral>()->value.safeGet<String>();
151-
url->children.push_back(std::move(auth_method));
152-
if (method_str == "auth_header")
193+
if (ast_auth_method.value() == "auth_header")
153194
{
154-
if (!s_auth_header.ignore(pos, expected))
155-
return false;
156-
if (!value.parse(pos, auth_header, expected))
157-
return false;
158-
if (!s_auth_key.ignore(pos, expected))
159-
return false;
160-
if (!value.parse(pos, auth_key, expected))
161-
return false;
162-
url->children.push_back(std::move(auth_header));
163-
url->children.push_back(std::move(auth_key));
195+
if (!ast_auth_header || !ast_auth_key)
196+
throw Exception("Auth header and auth key are required for auth_header auth method", ErrorCodes::UNKNOWN_FUNCTION);
164197
}
165-
else if (method_str != "none")
198+
else if (ast_auth_method.value() == "none")
166199
{
167-
throw Exception("AUTH_METHOD must be 'none' or 'auth_header'", ErrorCodes::UNKNOWN_FUNCTION);
200+
if (ast_auth_header || ast_auth_key)
201+
throw Exception("Auth method is 'none', but auth header or auth key is set.", ErrorCodes::UNKNOWN_FUNCTION);
168202
}
203+
else
204+
{
205+
throw Exception("Unknown auth method " + ast_auth_method.value(), ErrorCodes::UNKNOWN_FUNCTION);
206+
}
207+
}
208+
else
209+
{
210+
if (ast_auth_header || ast_auth_key)
211+
throw Exception("Auth method is 'none', but auth header or auth key is set.", ErrorCodes::UNKNOWN_FUNCTION);
169212
}
170-
function_core = std::move(url);
213+
function_core = std::move(kv_list);
171214
}
172215
/// proton: ends
173216

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#include <Parsers/ParserKeyValuePairsSet.h>
2+
#include <Parsers/ASTExpressionList.h>
3+
#include <Common/ErrorCodes.h>
4+
#include <unordered_map>
5+
6+
7+
8+
namespace DB
9+
{
10+
11+
namespace ErrorCodes
12+
{
13+
extern const int DUPLICATE_KEY;
14+
}
15+
16+
bool ParserKeyValuePairsSet::parseImpl(Pos & pos, ASTPtr & node, Expected & expected, [[ maybe_unused ]] bool hint)
17+
{
18+
ASTs elements;
19+
std::unordered_map<String, size_t> exists_keys;
20+
21+
auto parse_element = [&]
22+
{
23+
ASTPtr element;
24+
if (!elem_parser->parse(pos, element, expected))
25+
return false;
26+
auto key_value = element->as<ASTPair>();
27+
if (exists_keys.find(key_value->first) == exists_keys.end())
28+
{
29+
exists_keys[key_value->first] = elements.size();
30+
elements.push_back(element);
31+
}
32+
else
33+
{
34+
if (allow_duplicate)
35+
{
36+
size_t index = exists_keys[key_value->first];
37+
elements.erase(elements.begin() + index);
38+
exists_keys[key_value->first] = elements.size();
39+
elements.push_back(element);
40+
}
41+
else
42+
{
43+
throw Exception("Duplicate key \"" + key_value->first + "\" has existed previously", ErrorCodes::DUPLICATE_KEY);
44+
}
45+
}
46+
47+
return true;
48+
};
49+
50+
if (!ParserList::parseUtil(pos, expected, parse_element, *separator_parser, allow_empty))
51+
return false;
52+
53+
auto list = std::make_shared<ASTExpressionList>(result_separator);
54+
list->children = std::move(elements);
55+
node = list;
56+
57+
return true;
58+
}
59+
60+
}
61+

src/Parsers/ParserKeyValuePairsSet.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#pragma once
2+
3+
#include <Parsers/ExpressionElementParsers.h>
4+
#include <Parsers/ExpressionListParsers.h>
5+
#include <Parsers/ASTFunctionWithKeyValueArguments.h>
6+
7+
#ifdef __clang__
8+
#pragma clang diagnostic push
9+
#pragma clang diagnostic ignored "-Wc99-extensions"
10+
#endif
11+
12+
namespace DB
13+
{
14+
15+
/// Parser for list of key-value pairs.
16+
class ParserKeyValuePairsSet : public IParserBase
17+
{
18+
protected:
19+
bool allow_duplicate = false;
20+
21+
ParserPtr elem_parser = std::make_unique<ParserKeyValuePair>();
22+
ParserPtr separator_parser = std::make_unique<ParserNothing>();
23+
bool allow_empty = true;
24+
char result_separator = '\0';
25+
const char * getName() const override { return "set of pairs"; }
26+
bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected, [[ maybe_unused ]] bool hint) override;
27+
};
28+
29+
}
30+

0 commit comments

Comments
 (0)