
Commit 95454bf

Improve batch server prompt and context config
1 parent 9e6711a

1 file changed: examples/batch-processing/server.py (84 additions, 28 deletions)
@@ -1360,6 +1360,7 @@ def resolve_model_path(self) -> str:
     class ModelOptions(BaseModel):
         path: Optional[str] = None
         alias: Optional[str] = None
+        chat_template: Optional[str] = None
         from_pretrained: Optional["ConfigFile.FromPretrainedOptions"] = None
         n_gpu_layers: Optional[int] = None
         split_mode: Optional[int] = None
@@ -1369,12 +1370,16 @@ class ModelOptions(BaseModel):
         use_mmap: Optional[bool] = None
         use_mlock: Optional[bool] = None
         kv_overrides: Optional[Dict[str, Union[bool, int, float, str]]] = None
-        n_ctx: int = 1024
-        n_batch: int = 256
+        n_ctx: Optional[int] = None
+        n_batch: Optional[int] = None
         n_ubatch: Optional[int] = None
-        n_seq_max: int = 64
-        threads: int = Field(default_factory=lambda: max(multiprocessing.cpu_count() // 2, 1))
-        threads_batch: int = Field(default_factory=lambda: max(multiprocessing.cpu_count(), 1))
+        n_seq_max: Optional[int] = None
+        threads: Optional[int] = Field(
+            default_factory=lambda: max(multiprocessing.cpu_count(), 1)
+        )
+        threads_batch: Optional[int] = Field(
+            default_factory=lambda: max(multiprocessing.cpu_count(), 1)
+        )
         rope_scaling_type: Optional[int] = None
         pooling_type: Optional[int] = None
         attention_type: Optional[int] = None
@@ -1394,6 +1399,7 @@ class ModelOptions(BaseModel):
         type_v: Optional[int] = None
         prompt_chunk_size: int = 32
         max_seq_len: Optional[int] = None
+        max_output_tokens: Optional[int] = Field(default=None, ge=0)
         kv_unified: bool = True
         draft_model: Optional[Literal["prompt-lookup-decoding"]] = None
         draft_model_num_pred_tokens: int = 10
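
The ModelOptions hunks above make the context knobs optional (unset values are resolved later against the loaded model or left to llama.cpp defaults) and validate max_output_tokens as non-negative. A minimal sketch of how these Pydantic fields behave, reduced to just the changed fields:

import multiprocessing
from typing import Optional
from pydantic import BaseModel, Field

# Reduced stand-in for the ModelOptions fields touched by this commit.
class ModelOptionsSketch(BaseModel):
    n_ctx: Optional[int] = None        # None -> resolved from the model later
    n_batch: Optional[int] = None      # None -> keep llama.cpp's default
    max_output_tokens: Optional[int] = Field(default=None, ge=0)
    threads: Optional[int] = Field(
        default_factory=lambda: max(multiprocessing.cpu_count(), 1)
    )

opts = ModelOptionsSketch.model_validate({"max_output_tokens": 256})
assert opts.n_ctx is None and opts.max_output_tokens == 256
# model_validate({"max_output_tokens": -1}) would raise: ge=0 is violated.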
@@ -1429,16 +1435,24 @@ def __init__(self, template: str, *, bos_token: str, eos_token: str) -> None:
         self._eos_token = eos_token
         self._bos_token = bos_token
         self._template_text = template
-        self._template = ImmutableSandboxedEnvironment(
+        environment = ImmutableSandboxedEnvironment(
             loader=jinja2.BaseLoader(),
             trim_blocks=True,
             lstrip_blocks=True,
-        ).from_string(template)
+        )
+        environment.filters["from_json"] = self._from_json
+        self._template = environment.from_string(template)

     @staticmethod
     def _strftime_now(format_string: str) -> str:
         return datetime.now().strftime(format_string)

+    @staticmethod
+    def _from_json(value: Any) -> Any:
+        if isinstance(value, str):
+            return json.loads(value)
+        return value
+
     def format(
         self,
         *,
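
The formatter now registers a from_json filter on the sandboxed Jinja2 environment before compiling the template, so templates can decode JSON-encoded strings such as tool-call arguments. A small usage sketch; the template text here is illustrative, not from the commit:

import json
import jinja2
from jinja2.sandbox import ImmutableSandboxedEnvironment

def _from_json(value):
    # Mirrors the new filter: decode strings, pass other values through.
    return json.loads(value) if isinstance(value, str) else value

env = ImmutableSandboxedEnvironment(
    loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True
)
env.filters["from_json"] = _from_json
template = env.from_string("{{ (tool_call.arguments | from_json)['city'] }}")
print(template.render(tool_call={"arguments": '{"city": "Oslo"}'}))  # Oslo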
@@ -1651,6 +1665,8 @@ def from_prepared(
         on_error: Optional[Callable[[BaseException], None]] = None,
     ) -> "CompletionRequest":
         ctx_limit = model.max_seq_len
+        if model.max_output_tokens is not None:
+            ctx_limit = min(ctx_limit, len(prompt_tokens) + model.max_output_tokens)
         if payload.max_tokens is None:
             effective_max_len = ctx_limit
         else:
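
With max_output_tokens configured, the per-request context limit is clamped to the prompt length plus the output budget, so a single request cannot reserve the whole context window for generation. A worked example with illustrative numbers:

# Illustrative values; the logic is the clamp from the hunk above.
max_seq_len = 8192                 # model.max_seq_len
max_output_tokens = 512            # model.max_output_tokens
prompt_tokens = list(range(1000))  # a 1000-token prompt

ctx_limit = max_seq_len
if max_output_tokens is not None:
    ctx_limit = min(ctx_limit, len(prompt_tokens) + max_output_tokens)
print(ctx_limit)  # 1512 rather than 8192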
@@ -1818,6 +1834,7 @@ class StreamState:
         "_tools",
         "_completion_id",
         "_choice_index",
+        "_prompt_opens_leading_capture",
         "_tool_schemas",
         "_started",
         "_text_parts",
@@ -1848,11 +1865,13 @@ def __init__(
         tools: Optional[List[Dict[str, Any]]] = None,
         completion_id: str = "",
         choice_index: int = 0,
+        prompt_opens_leading_capture: bool = False,
     ) -> None:
         self._schema = schema
         self._tools = tools
         self._completion_id = completion_id
         self._choice_index = choice_index
+        self._prompt_opens_leading_capture = prompt_opens_leading_capture
         self._tool_schemas = self._cached_tool_schema_map(tools)
         self._started = False
         self._text_parts: List[str] = []
@@ -1922,6 +1941,16 @@ def __init__(
                 if self._direct.assistant_prefix
                 else self.DIRECT_MODE_PRELUDE
             )
+            if (
+                self._prompt_opens_leading_capture
+                and self._direct.leading_capture_field is not None
+                and self._direct.leading_capture_implicit
+            ):
+                self._direct.mode = (
+                    self.DIRECT_MODE_LEADING_CAPTURE
+                    if not self._direct.assistant_prefix
+                    else self.DIRECT_MODE_ASSISTANT_PREFIX
+                )
             self._stream_state = None
         else:
             self._stream_state = (
@@ -2963,6 +2992,7 @@ def _advance_direct_stream_state(self, text: str) -> Tuple[bool, List[Dict[str,
         leading_capture_end = self._direct.leading_capture_end
         leading_capture_strip_after = self._direct.leading_capture_strip_after
         leading_capture_implicit = self._direct.leading_capture_implicit
+        prompt_opens_leading_capture = self._prompt_opens_leading_capture
         iterator_start = self._direct.iterator_start
         iterator_end = self._direct.iterator_end
         content_end_markers = self._direct.content_end_markers
@@ -2982,6 +3012,9 @@ def _advance_direct_stream_state(self, text: str) -> Tuple[bool, List[Dict[str,
                 continue
             if mode == self.DIRECT_MODE_PRELUDE:
                 if leading_capture_field is not None:
+                    if prompt_opens_leading_capture and leading_capture_implicit:
+                        mode = self.DIRECT_MODE_LEADING_CAPTURE
+                        continue
                     if buffer.startswith(leading_capture_start):
                         buffer = buffer[len(leading_capture_start) :]
                         mode = self.DIRECT_MODE_LEADING_CAPTURE
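
The prelude shortcut above handles templates whose generation prompt already ends with the opening capture marker: the model's output then begins inside the block, so the parser must not wait for a start marker it will never see. A much-simplified, non-streaming sketch of the two entry paths (the real code works incrementally on a buffer):

# Simplified stand-in for the DIRECT_MODE_PRELUDE -> LEADING_CAPTURE
# transition; markers are hard-coded for illustration.
def split_capture(text: str, prompt_opens_capture: bool) -> tuple[str, str]:
    start, end = "<think>", "</think>"
    if prompt_opens_capture:
        captured, _, rest = text.partition(end)  # already inside the block
        return captured, rest
    if text.startswith(start):
        captured, _, rest = text[len(start):].partition(end)
        return captured, rest
    return "", text  # no leading capture block at all

assert split_capture("plan</think>hi", True) == ("plan", "hi")
assert split_capture("<think>plan</think>hi", False) == ("plan", "hi")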
@@ -5007,6 +5040,12 @@ def _uses_reasoning_content(self) -> bool:
         template = self._chat_template_text()
         return "reasoning_content" in template or "<think>" in template

+    def _prompt_opens_leading_capture(self) -> bool:
+        template = self._chat_template_text()
+        if "<think>" not in template:
+            return False
+        return "add_generation_prompt" in template
+
     @staticmethod
     def _chat_message(data: Dict[str, Any]) -> ChatCompletionRequestMessage:
         return ChatCompletionRequestMessage.model_validate(data)
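
The heuristic is purely textual: templates that mention <think> and branch on add_generation_prompt typically append an opening <think> to the generation prompt, so the first sampled tokens are already reasoning content. A standalone version of the check, with illustrative template strings:

def prompt_opens_leading_capture(template: str) -> bool:
    # Same two-substring test as the method above, minus self.
    if "<think>" not in template:
        return False
    return "add_generation_prompt" in template

t = "{% if add_generation_prompt %}<|assistant|>\n<think>\n{% endif %}"
assert prompt_opens_leading_capture(t) is True
assert prompt_opens_leading_capture("{{ messages }}") is False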
@@ -5562,6 +5601,7 @@ def _response_parser(
             tools=tools,
             completion_id=completion_id,
             choice_index=choice_index,
+            prompt_opens_leading_capture=self._prompt_opens_leading_capture(),
         )

     def parse_chat_response(
@@ -7306,6 +7346,7 @@ def __init__(
         *,
         model_path: str,
         model_alias: Optional[str] = None,
+        chat_template: Optional[str] = None,
         n_gpu_layers: Optional[int] = None,
         split_mode: Optional[int] = None,
         main_gpu: Optional[int] = None,
@@ -7314,12 +7355,12 @@ def __init__(
         use_mmap: Optional[bool] = None,
         use_mlock: Optional[bool] = None,
         kv_overrides: Optional[Dict[str, Union[bool, int, float, str]]] = None,
-        n_ctx: int,
-        n_batch: int,
+        n_ctx: Optional[int],
+        n_batch: Optional[int],
         n_ubatch: Optional[int] = None,
-        n_seq_max: int,
-        n_threads: int,
-        n_threads_batch: int,
+        n_seq_max: Optional[int],
+        n_threads: Optional[int],
+        n_threads_batch: Optional[int],
         rope_scaling_type: Optional[int] = None,
         pooling_type: Optional[int] = None,
         attention_type: Optional[int] = None,
@@ -7340,6 +7381,7 @@ def __init__(
         prompt_chunk_size: int,
         kv_unified: bool = True,
         max_seq_len: Optional[int] = None,
+        max_output_tokens: Optional[int] = None,
         draft_model: Optional[str] = None,
         draft_model_num_pred_tokens: int = 10,
         draft_model_max_ngram_size: int = 2,
@@ -7350,9 +7392,11 @@ def __init__(
         self.backend_initialized = True
         self.model_path = model_path
         self.model_alias = model_alias
+        self.chat_template_override = chat_template
         self.prompt_chunk_size = prompt_chunk_size
         self.response_schema = response_schema
         self.store_logits = store_logits
+        self.max_output_tokens = max_output_tokens
         model_params, self._c_tensor_split, self._kv_overrides_array = (
             self.build_model_params(
                 n_gpu_layers=n_gpu_layers,
@@ -7391,8 +7435,9 @@ def __init__(
             raise RuntimeError(
                 "speculative decoding is only supported for attention models"
             )
+        n_ctx_train = int(llama_cpp.llama_model_n_ctx_train(llama_model))
         context_params = self.build_context_params(
-            n_ctx=n_ctx,
+            n_ctx=n_ctx if n_ctx is not None else n_ctx_train,
             n_batch=n_batch,
             n_ubatch=n_ubatch,
             n_seq_max=n_seq_max,
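
Reading the trained context length once and reusing it both for the n_ctx fallback here and for self.n_ctx_train below avoids a second FFI call. The fallback in isolation, with a plain int standing in for llama_cpp.llama_model_n_ctx_train(llama_model):

from typing import Optional

def resolve_n_ctx(n_ctx: Optional[int], n_ctx_train: int) -> int:
    # Explicit config wins; unset config defers to the trained context.
    return n_ctx if n_ctx is not None else n_ctx_train

assert resolve_n_ctx(None, 32768) == 32768  # n_ctx left unset
assert resolve_n_ctx(4096, 32768) == 4096   # explicit n_ctx wins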
@@ -7426,7 +7471,7 @@ def __init__(
         self.n_ctx_seq = int(llama_cpp.llama_n_ctx_seq(ctx))
         self.n_seq_max = int(llama_cpp.llama_n_seq_max(ctx))
         self.n_batch = int(llama_cpp.llama_n_batch(ctx))
-        self.n_ctx_train = int(llama_cpp.llama_model_n_ctx_train(llama_model))
+        self.n_ctx_train = n_ctx_train
         self.n_vocab = int(llama_cpp.llama_vocab_n_tokens(self.vocab))
         self.kv_unified = kv_unified
         self.max_seq_len_limit = min(self.request_context_limit, self.n_ctx_train)
@@ -7548,12 +7593,12 @@ def build_model_params(
     @staticmethod
     def build_context_params(
         *,
-        n_ctx: int,
-        n_batch: int,
+        n_ctx: Optional[int],
+        n_batch: Optional[int],
         n_ubatch: Optional[int],
-        n_seq_max: int,
-        n_threads: int,
-        n_threads_batch: int,
+        n_seq_max: Optional[int],
+        n_threads: Optional[int],
+        n_threads_batch: Optional[int],
         rope_scaling_type: Optional[int],
         pooling_type: Optional[int],
         attention_type: Optional[int],
@@ -7574,13 +7619,18 @@ def build_context_params(
         kv_unified: bool,
     ) -> Any:
         context_params = llama_cpp.llama_context_default_params()
-        context_params.n_ctx = n_ctx
-        context_params.n_batch = min(n_ctx, n_batch)
+        if n_ctx is not None:
+            context_params.n_ctx = n_ctx
+        if n_batch is not None:
+            context_params.n_batch = min(int(context_params.n_ctx), n_batch)
         if n_ubatch is not None:
-            context_params.n_ubatch = min(context_params.n_batch, n_ubatch)
-        context_params.n_seq_max = n_seq_max
-        context_params.n_threads = n_threads
-        context_params.n_threads_batch = n_threads_batch
+            context_params.n_ubatch = min(int(context_params.n_batch), n_ubatch)
+        if n_seq_max is not None:
+            context_params.n_seq_max = n_seq_max
+        if n_threads is not None:
+            context_params.n_threads = n_threads
+        if n_threads_batch is not None:
+            context_params.n_threads_batch = n_threads_batch
         if rope_scaling_type is not None:
             context_params.rope_scaling_type = rope_scaling_type
         if pooling_type is not None:
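
Each context parameter is now applied only when provided, so unset values keep whatever llama_context_default_params() returned, while n_batch is still clamped to the effective n_ctx and n_ubatch to the effective n_batch, whichever source they came from. The guard-and-clamp pattern, with a plain object standing in for the llama.cpp params struct:

from types import SimpleNamespace
from typing import Optional

def apply(params: SimpleNamespace, n_ctx: Optional[int],
          n_batch: Optional[int], n_ubatch: Optional[int]) -> SimpleNamespace:
    if n_ctx is not None:
        params.n_ctx = n_ctx
    if n_batch is not None:
        params.n_batch = min(int(params.n_ctx), n_batch)      # clamp to effective n_ctx
    if n_ubatch is not None:
        params.n_ubatch = min(int(params.n_batch), n_ubatch)  # clamp to effective n_batch
    return params

defaults = SimpleNamespace(n_ctx=512, n_batch=2048, n_ubatch=512)  # illustrative defaults
p = apply(defaults, n_ctx=None, n_batch=256, n_ubatch=None)
assert (p.n_ctx, p.n_batch, p.n_ubatch) == (512, 256, 512)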
@@ -7668,8 +7718,12 @@ def _meta_value(self, key: str) -> Optional[str]:
             capacity = count + 1

     def _build_chat_formatter(self) -> Optional[Jinja2ChatFormatter]:
-        template = llama_cpp.llama_model_chat_template(self.llama_model, None)
-        if not template:
+        template_text = self.chat_template_override
+        if template_text is None:
+            template = llama_cpp.llama_model_chat_template(self.llama_model, None)
+            if template:
+                template_text = template.decode("utf-8", errors="ignore")
+        if not template_text:
             return None
         bos_token = ""
         eos_token = ""
@@ -7680,7 +7734,7 @@ def _build_chat_formatter(self) -> Optional[Jinja2ChatFormatter]:
         eos_text = llama_cpp.llama_vocab_get_text(self.vocab, self.eos_token)
         eos_token = eos_text.decode("utf-8", errors="ignore") if eos_text else ""
         return Jinja2ChatFormatter(
-            template=template.decode("utf-8", errors="ignore"),
+            template=template_text,
             bos_token=bos_token,
             eos_token=eos_token,
         )
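
A configured chat_template now takes precedence over the template embedded in the GGUF metadata, which is only read (and decoded from bytes) when no override is set; an empty result either way still means "no formatter". The resolution order in isolation, with gguf_template standing in for llama_cpp.llama_model_chat_template(...), which returns bytes or None:

from typing import Optional

def resolve_template(override: Optional[str],
                     gguf_template: Optional[bytes]) -> Optional[str]:
    text = override
    if text is None and gguf_template:
        text = gguf_template.decode("utf-8", errors="ignore")
    return text or None  # empty string -> no formatter

assert resolve_template("{{ messages }}", b"ignored") == "{{ messages }}"
assert resolve_template(None, b"{{ gguf }}") == "{{ gguf }}"
assert resolve_template(None, None) is None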
@@ -9598,6 +9652,7 @@ def main() -> None:
     model = Model(
         model_path=model_path,
         model_alias=config.model.alias,
+        chat_template=config.model.chat_template,
         n_gpu_layers=config.model.n_gpu_layers,
         split_mode=config.model.split_mode,
         main_gpu=config.model.main_gpu,
@@ -9632,6 +9687,7 @@ def main() -> None:
         prompt_chunk_size=config.model.prompt_chunk_size,
         kv_unified=config.model.kv_unified,
         max_seq_len=config.model.max_seq_len,
+        max_output_tokens=config.model.max_output_tokens,
         draft_model=config.model.draft_model,
         draft_model_num_pred_tokens=config.model.draft_model_num_pred_tokens,
         draft_model_max_ngram_size=config.model.draft_model_max_ngram_size,
