@@ -28,7 +28,7 @@ DEFINE_string(tokenizer_path, "", "HuggingFace tokenizer.json path.");
 DEFINE_string(prompt, "Hello", "Prompt text.");
 DEFINE_double(temperature, 0.8, "Sampling temperature (0 = greedy).");
 DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate.");
-DEFINE_bool(decode_only, false, "Use decode method for everything (no prefill).");
+
 
 namespace llm = ::executorch::extension::llm;
 using ::executorch::extension::from_blob;
@@ -120,60 +120,36 @@ int main(int argc, char** argv) {
   uint64_t cur_token = 0;
   auto prefill_start = std::chrono::steady_clock::now();
 
-  if (FLAGS_decode_only) {
-    // Token-by-token using decode method
-    for (int64_t i = 0; i < num_prompt_tokens; i++) {
-      std::vector<int64_t> tok_data = {static_cast<int64_t>(prompt_tokens[i])};
-      std::vector<int64_t> pos_data = {i};
-      auto tok_t = from_blob(tok_data.data(), {1, 1}, executorch::aten::ScalarType::Long);
-      auto pos_t = from_blob(pos_data.data(), {1}, executorch::aten::ScalarType::Long);
-      std::vector<EValue> inputs;
-      inputs.push_back(tok_t);
-      inputs.push_back(pos_t);
-      auto result = module->execute("decode", inputs);
-      if (result.error() != Error::Ok) {
-        ET_LOG(Error, "Decode prefill step %ld failed", i);
-        return 1;
-      }
-      if (i == num_prompt_tokens - 1) {
-        auto& outputs = result.get();
-        auto logits = outputs[0].toTensor();
-        auto logits_ptr = std::make_shared<executorch::aten::Tensor>(std::move(logits));
-        cur_token = llm::logits_to_token(*logits_ptr, FLAGS_temperature);
-      }
-    }
-  } else {
-    // Chunked prefill
-    std::vector<int64_t> pos_data(num_prompt_tokens);
-    for (int64_t i = 0; i < num_prompt_tokens; i++) {
-      pos_data[i] = i;
-    }
-    std::vector<int64_t> token_data(prompt_tokens.begin(), prompt_tokens.end());
-    auto tokens_tensor = from_blob(
-        token_data.data(),
-        {1, S(num_prompt_tokens)},
-        executorch::aten::ScalarType::Long);
-    auto pos_tensor = from_blob(
-        pos_data.data(),
-        {S(num_prompt_tokens)},
-        executorch::aten::ScalarType::Long);
-
-    std::vector<EValue> prefill_inputs;
-    prefill_inputs.push_back(tokens_tensor);
-    prefill_inputs.push_back(pos_tensor);
-
-    auto prefill_result = module->execute("prefill", prefill_inputs);
-    if (prefill_result.error() != Error::Ok) {
-      ET_LOG(Error, "Prefill failed");
-      return 1;
-    }
-    auto& prefill_outputs = prefill_result.get();
-
-    auto logits_tensor = prefill_outputs[0].toTensor();
-    auto logits_ptr =
-        std::make_shared<executorch::aten::Tensor>(std::move(logits_tensor));
-    cur_token = llm::logits_to_token(*logits_ptr, FLAGS_temperature);
+  // Chunked prefill
+  std::vector<int64_t> pos_data(num_prompt_tokens);
+  for (int64_t i = 0; i < num_prompt_tokens; i++) {
+    pos_data[i] = i;
+  }
+  std::vector<int64_t> token_data(prompt_tokens.begin(), prompt_tokens.end());
+  auto tokens_tensor = from_blob(
+      token_data.data(),
+      {1, S(num_prompt_tokens)},
+      executorch::aten::ScalarType::Long);
+  auto pos_tensor = from_blob(
+      pos_data.data(),
+      {S(num_prompt_tokens)},
+      executorch::aten::ScalarType::Long);
+
+  std::vector<EValue> prefill_inputs;
+  prefill_inputs.push_back(tokens_tensor);
+  prefill_inputs.push_back(pos_tensor);
+
+  auto prefill_result = module->execute("prefill", prefill_inputs);
+  if (prefill_result.error() != Error::Ok) {
+    ET_LOG(Error, "Prefill failed");
+    return 1;
   }
+  auto& prefill_outputs = prefill_result.get();
+
+  auto logits_tensor = prefill_outputs[0].toTensor();
+  auto logits_ptr =
+      std::make_shared<executorch::aten::Tensor>(std::move(logits_tensor));
+  cur_token = llm::logits_to_token(*logits_ptr, FLAGS_temperature);
 
   auto prefill_end = std::chrono::steady_clock::now();
   double prefill_ms =
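
Note on context: this hunk ends before the decode loop that consumes cur_token; that code is unchanged by the diff and not shown. For orientation only, a minimal sketch of such a loop follows. It reuses names visible above (from_blob, module->execute("decode", ...), llm::logits_to_token) and assumes the exported decode method takes a [1, 1] token tensor and a [1] position tensor, as the removed decode_only path did; the file's actual continuation may differ.

  // Sketch of the autoregressive decode loop that typically follows prefill.
  for (int64_t pos = num_prompt_tokens;
       pos < num_prompt_tokens + FLAGS_max_new_tokens;
       pos++) {
    // One step: feed the most recently sampled token at its absolute position.
    std::vector<int64_t> tok_data = {static_cast<int64_t>(cur_token)};
    std::vector<int64_t> step_pos = {pos};
    auto tok_t = from_blob(tok_data.data(), {1, 1}, executorch::aten::ScalarType::Long);
    auto pos_t = from_blob(step_pos.data(), {1}, executorch::aten::ScalarType::Long);
    std::vector<EValue> inputs;
    inputs.push_back(tok_t);
    inputs.push_back(pos_t);
    auto result = module->execute("decode", inputs);
    if (result.error() != Error::Ok) {
      ET_LOG(Error, "Decode step %ld failed", pos);
      return 1;
    }
    // Sample the next token from this step's logits (greedy when temperature is 0).
    auto& outputs = result.get();
    auto logits = outputs[0].toTensor();
    cur_token = llm::logits_to_token(logits, FLAGS_temperature);
    // EOS handling and token streaming are omitted from this sketch.
  }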