Commit be8ffd1

[llm] Add generate_from_pos API to LLM runner (#11570)
As titled, this API allows us to support multi-turn conversations by passing a `start_pos` argument to `generate_from_pos`.

This pull request introduces a new feature to support text generation from a specific starting position (`generate_from_pos`) and adds error handling for the case where the resolved `max_new_tokens` is not positive. The changes primarily extend the `TextLLMRunner` class and its associated methods to accommodate the new feature while maintaining backward compatibility.

### New Feature: Text Generation from a Specific Starting Position

* **Added `generate_from_pos` method**: Introduced a new method `generate_from_pos` in `TextLLMRunner` to allow text generation starting from a specified position in the KV cache, including updates to the method signature, logic, and error handling. (`extension/llm/runner/text_llm_runner.cpp` [[1]](diffhunk://#diff-9b3bd38c0b1ad81b18afab15784634e2b394fda448f5e2dae03de58870751440L76-R78) [[2]](diffhunk://#diff-9b3bd38c0b1ad81b18afab15784634e2b394fda448f5e2dae03de58870751440R129-R156) [[3]](diffhunk://#diff-9b3bd38c0b1ad81b18afab15784634e2b394fda448f5e2dae03de58870751440L150-R165) [[4]](diffhunk://#diff-9b3bd38c0b1ad81b18afab15784634e2b394fda448f5e2dae03de58870751440R219-R225); `extension/llm/runner/text_llm_runner.h` [[5]](diffhunk://#diff-d1aa44a87ea9b7ec51250c2002466cb9bd57db153c1c8b58ffdf73e8f231a89bR98-R122))
* **Updated documentation**: Enhanced the method documentation in `TextLLMRunner` to describe the new functionality, including parameters such as `start_pos` and the expected behavior. (`extension/llm/runner/text_llm_runner.h` [[1]](diffhunk://#diff-d1aa44a87ea9b7ec51250c2002466cb9bd57db153c1c8b58ffdf73e8f231a89bL81-R83) [[2]](diffhunk://#diff-d1aa44a87ea9b7ec51250c2002466cb9bd57db153c1c8b58ffdf73e8f231a89bR98-R122))

### Error Handling Improvements

* **Validation of `max_new_tokens`**: Added a check that the resolved `max_new_tokens` is positive; if it is not, an `InvalidArgument` error is returned. This prevents invalid configurations during text generation. (`extension/llm/runner/text_llm_runner.cpp` [extension/llm/runner/text_llm_runner.cppR129-R156](diffhunk://#diff-9b3bd38c0b1ad81b18afab15784634e2b394fda448f5e2dae03de58870751440R129-R156))
* **Unit test for negative `max_new_tokens`**: Added a test case (`GenerateFromPosErrorsWithNegativeMaxNewTokens`) verifying that `generate_from_pos` returns an error when `max_new_tokens` resolves to a negative value. (`extension/llm/runner/test/test_text_llm_runner.cpp` [extension/llm/runner/test/test_text_llm_runner.cppR325-R379](diffhunk://#diff-0a1e69b4182878ccad887c4f4ba3929ef24082a26623e26a871d73f4e6cea503R325-R379))
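
Below is a minimal caller-side sketch of the multi-turn flow this enables. It is illustrative only: runner construction and loading are elided, the `chat_two_turns` helper is hypothetical, the include path just mirrors the repository layout shown below, and tracking the KV-cache position by summing the prompt and generated token counts reported through the `Stats` callback is an assumption made for the example rather than something this commit prescribes.

```cpp
#include <cstdint>
#include <string>

#include <executorch/extension/llm/runner/text_llm_runner.h>

using executorch::extension::llm::GenerationConfig;
using executorch::extension::llm::Stats;
using executorch::extension::llm::TextLLMRunner;

// Sketch of a two-turn conversation against an already-loaded runner.
// The position bookkeeping below is an assumption for illustration.
void chat_two_turns(TextLLMRunner& runner) {
  GenerationConfig config;
  config.max_new_tokens = 64;

  int64_t pos = 0;  // next free position in the KV cache
  auto on_token = [](const std::string& piece) { (void)piece; /* stream to UI */ };
  auto on_stats = [&pos](const Stats& stats) {
    // Assumed bookkeeping: advance by the tokens consumed this turn.
    pos += stats.num_prompt_tokens + stats.num_generated_tokens;
  };

  // Turn 1: plain generate() starts at KV-cache position 0.
  auto err = runner.generate("What is ExecuTorch?", config, on_token, on_stats);
  if (err != executorch::runtime::Error::Ok) {
    return;
  }

  // Turn 2: reuse the same KV cache, continuing from the tracked position.
  err = runner.generate_from_pos(
      "And how do I deploy it to Android?", pos, config, on_token, on_stats);
  (void)err;
}
```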
1 parent 1309849 commit be8ffd1

File tree

4 files changed (+134, -14 lines)


extension/llm/runner/irunner.h

Lines changed: 17 additions & 0 deletions
@@ -121,6 +121,23 @@ class ET_EXPERIMENTAL IRunner {
       std::function<void(const std::string&)> token_callback,
       std::function<void(const Stats&)> stats_callback) = 0;

+  /**
+   * Generate text based on the provided prompt and generation config, from a
+   * given position in KV cache.
+   *
+   * @param prompt The input prompt to generate from
+   * @param start_pos The starting position in KV cache of the input
+   * @param config Generation configuration parameters
+   * @param token_callback Callback function called for each generated token
+   * @param stats_callback Callback function for generation statistics
+   * @return Error::Ok if successful, an error otherwise
+   */
+  virtual runtime::Error generate_from_pos(
+      const std::string& prompt,
+      int64_t start_pos,
+      const GenerationConfig& config,
+      std::function<void(const std::string&)> token_callback,
+      std::function<void(const Stats&)> stats_callback) = 0;
   /**
    * Stop the generation process.
    */

extension/llm/runner/test/test_text_llm_runner.cpp

Lines changed: 55 additions & 0 deletions
@@ -322,3 +322,58 @@ TEST_F(RunnerTest, IsLoadedReturnsTrueWhenComponentsInitialized) {
   // Verify is_loaded returns true
   EXPECT_TRUE(runner.is_loaded());
 }
+
+// Test that generate_from_pos() errors out when max_new_tokens is negative
+TEST_F(RunnerTest, GenerateFromPosErrorsWithNegativeMaxNewTokens) {
+  // Create mock instances using helper functions
+  auto tokenizer = createMockTokenizer();
+  auto text_decoder_runner = createMockTextDecoderRunner();
+  auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get());
+
+  // Set up expectations for the tokenizer encode method
+  EXPECT_CALL(*tokenizer, encode(_, _, _))
+      .WillOnce(Return(::tokenizers::Result<std::vector<uint64_t>>(
+          std::vector<uint64_t>{1, 2, 3})));
+
+  // Set up expectations for load methods
+  EXPECT_CALL(*text_prefiller, is_loaded()).WillRepeatedly(Return(true));
+
+  std::unique_ptr<executorch::llm::Stats> stats =
+      std::make_unique<executorch::llm::Stats>();
+  // Create a real TextTokenGenerator
+  auto text_token_generator = createTextTokenGenerator(
+      tokenizer.get(), text_decoder_runner.get(), stats.get());
+
+  // Create a Runner with our mocked components
+  TextLLMRunner runner(
+      {
+          {"enable_dynamic_shape", false},
+          {"get_max_seq_len", 10},
+          {"get_max_context_len", 10},
+          {"use_kv_cache", true},
+      },
+      std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release()),
+      std::make_unique<MockModule>(),
+      std::move(text_decoder_runner),
+      std::unique_ptr<::executorch::extension::llm::TextPrefiller>(
+          text_prefiller.release()),
+      std::move(text_token_generator),
+      std::move(stats));
+
+  // Load
+  runner.load();
+
+  // Set up the generation config; max_new_tokens resolves to a negative value
+  GenerationConfig config;
+  config.max_new_tokens = 5;
+  config.echo = false;
+
+  // num_prompt_tokens = 3
+  // max_context_len = 10
+  // start_pos = 8: only 10 - 8 = 2 positions remain, which is fewer than the
+  // 3 prompt tokens, so this fails even though config.max_new_tokens = 5.
+  Error err = runner.generate_from_pos("test prompt", 8, config);
+
+  // Verify that an InvalidArgument error is returned
+  EXPECT_EQ(err, Error::InvalidArgument);
+}
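
For reference, here is the arithmetic this test relies on, written out as a standalone snippet. The exact behavior of `resolve_max_new_tokens` is not part of this diff; the snippet assumes it caps the configured value at the remaining budget `max_context_len - num_prompt_tokens`, so treat that as an assumption.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Walk-through of the test's numbers (illustration only, not part of the diff).
// Assumption: resolve_max_new_tokens caps config.max_new_tokens at the
// remaining budget max_context_len - num_prompt_tokens.
int main() {
  const int64_t kMaxContextLen = 10;    // {"get_max_context_len", 10}
  const int64_t start_pos = 8;
  const int64_t num_prompt_tokens = 3;  // tokens {1, 2, 3}
  const int64_t requested = 5;          // config.max_new_tokens

  const int64_t max_context_len = kMaxContextLen - start_pos;    // 2
  const int64_t budget = max_context_len - num_prompt_tokens;    // -1
  const int64_t max_new_tokens = std::min(requested, budget);    // -1

  std::printf("resolved max_new_tokens = %lld\n",
              static_cast<long long>(max_new_tokens));
  // max_new_tokens <= 0, so generate_from_pos returns Error::InvalidArgument.
  return 0;
}
```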

extension/llm/runner/text_llm_runner.cpp

Lines changed: 34 additions & 12 deletions
@@ -73,8 +73,9 @@ Error TextLLMRunner::load() {
   ET_LOG(Info, format, __VA_ARGS__); \
   }

-Error TextLLMRunner::generate(
+Error TextLLMRunner::generate_from_pos(
     const std::string& prompt,
+    int64_t start_pos,
     const GenerationConfig& config,
     std::function<void(const std::string&)> token_callback,
     std::function<void(const Stats&)> stats_callback) {
@@ -125,20 +126,34 @@ Error TextLLMRunner::generate(
   std::vector<uint64_t> prompt_tokens = encode_res.get();
   int num_prompt_tokens = prompt_tokens.size();

+  // Reduce max_context_len by start_pos
+  int64_t max_context_len = metadata_.at(kMaxContextLen) - start_pos;
   ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token");
   ET_CHECK_MSG(
-      num_prompt_tokens < metadata_.at(kMaxContextLen),
-      "num_prompt_tokens %d >= max_seq_len_ %" PRId64
+      num_prompt_tokens < max_context_len,
+      "num_prompt_tokens %d >= max_context_len %" PRId64
       ", Max seq length exceeded - please increase max seq len value in your export script",
       num_prompt_tokens,
-      metadata_.at(kMaxContextLen));
-
-  // Determine max_new_tokens using the GenerationConfig's resolve method
-  int max_new_tokens = config.resolve_max_new_tokens(
-      metadata_.at(kMaxContextLen), num_prompt_tokens);
-
-  ET_LOG(Info, "Max new tokens resolved: %d", max_new_tokens);
-
+      max_context_len);
+
+  // Determine max_new_tokens using the GenerationConfig's resolve method,
+  // with max_context_len already reduced by start_pos above.
+  int max_new_tokens =
+      config.resolve_max_new_tokens(max_context_len, num_prompt_tokens);
+
+  ET_LOG(
+      Info,
+      "Max new tokens resolved: %d, given start_pos %" PRId64
+      ", num_prompt_tokens %zu, max_context_len %" PRId64,
+      max_new_tokens,
+      start_pos,
+      prompt_tokens.size(),
+      max_context_len);
+  ET_CHECK_OR_RETURN_ERROR(
+      max_new_tokens > 0,
+      InvalidArgument,
+      "Max new tokens %d is less than or equal to 0",
+      max_new_tokens);
   // Prefill first
   // Here feed all tokens to the model and get the next predicted token
   // after the prompt. After that we will enter generate loop.
@@ -147,7 +162,7 @@ Error TextLLMRunner::generate(
   if (config.echo) {
     wrapped_callback(prompt);
   }
-  int64_t pos = 0;
+  int64_t pos = start_pos;
   auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
   ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
   uint64_t cur_token = prefill_res.get();
@@ -201,6 +216,13 @@ Error TextLLMRunner::generate(

   return Error::Ok;
 }
+Error TextLLMRunner::generate(
+    const std::string& prompt,
+    const GenerationConfig& config,
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const Stats&)> stats_callback) {
+  return generate_from_pos(prompt, 0, config, token_callback, stats_callback);
+}

 Error TextLLMRunner::warmup(const std::string& prompt, int32_t max_new_tokens) {
   // Create a GenerationConfig for warmup

extension/llm/runner/text_llm_runner.h

Lines changed: 28 additions & 2 deletions
@@ -78,8 +78,9 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner {
    * @brief Generates text based on the provided prompt
    *
    * This method performs text generation using the loaded model. It processes
-   * the input prompt, runs the model in prefill and decode phases, and returns
-   * generated text through callbacks.
+   * the input prompt, runs the model in prefill and decode phases until max
+   * tokens to generate is reached or eos token is generated, then returns
+   * generated text and perf stats through callbacks.
    *
    * @param prompt The input text to generate from
    * @param config Configuration parameters for text generation (e.g.,
@@ -94,6 +95,31 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner {
       const GenerationConfig& config,
       std::function<void(const std::string&)> token_callback = {},
       std::function<void(const Stats&)> stats_callback = {}) override;
+
+  /**
+   * @brief Generates text based on the provided prompt and start position
+   *
+   * This method performs text generation using the loaded model. It processes
+   * the input prompt, runs the model in prefill and decode phases using the
+   * start position until max tokens to generate is reached or eos token is
+   * generated, then returns generated text and perf stats through callbacks.
+   *
+   * @param prompt The input text to generate from
+   * @param start_pos The starting position in KV cache of the input
+   * @param config Configuration parameters for text generation (e.g.,
+   * max_new_tokens, temperature)
+   * @param token_callback Function called for each generated token with the
+   * decoded text
+   * @param stats_callback Function called with performance statistics
+   * @return ::executorch::runtime::Error Success or error status
+   */
+  ::executorch::runtime::Error generate_from_pos(
+      const std::string& prompt,
+      int64_t start_pos,
+      const GenerationConfig& config,
+      std::function<void(const std::string&)> token_callback = {},
+      std::function<void(const Stats&)> stats_callback = {}) override;
+
   /**
    * @brief Warms up the model with a sample prompt
    *
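
Because both callback parameters default to `{}` here, the new entry point can also be invoked with just a prompt, a start position, and a config, which is how the unit test above calls it. A small sketch, assuming an already-constructed and loaded runner (the `continue_turn` helper is hypothetical and the include path mirrors the repository layout above):

```cpp
#include <cstdint>
#include <string>

#include <executorch/extension/llm/runner/text_llm_runner.h>

using executorch::extension::llm::GenerationConfig;
using executorch::extension::llm::TextLLMRunner;

// Sketch: continue decoding from a given KV-cache position without callbacks.
executorch::runtime::Error continue_turn(
    TextLLMRunner& runner, int64_t start_pos) {
  GenerationConfig config;
  config.max_new_tokens = 32;
  config.echo = false;  // don't echo the follow-up prompt back to the caller
  // token_callback and stats_callback default to {} and can be omitted.
  return runner.generate_from_pos("Tell me more.", start_pos, config);
}
```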