@@ -1360,6 +1360,7 @@ def resolve_model_path(self) -> str:
 class ModelOptions(BaseModel):
     path: Optional[str] = None
     alias: Optional[str] = None
+    chat_template: Optional[str] = None
     from_pretrained: Optional["ConfigFile.FromPretrainedOptions"] = None
     n_gpu_layers: Optional[int] = None
     split_mode: Optional[int] = None
@@ -1369,12 +1370,16 @@ class ModelOptions(BaseModel):
     use_mmap: Optional[bool] = None
     use_mlock: Optional[bool] = None
     kv_overrides: Optional[Dict[str, Union[bool, int, float, str]]] = None
-    n_ctx: int = 1024
-    n_batch: int = 256
+    n_ctx: Optional[int] = None
+    n_batch: Optional[int] = None
     n_ubatch: Optional[int] = None
-    n_seq_max: int = 64
-    threads: int = Field(default_factory=lambda: max(multiprocessing.cpu_count() // 2, 1))
-    threads_batch: int = Field(default_factory=lambda: max(multiprocessing.cpu_count(), 1))
+    n_seq_max: Optional[int] = None
+    threads: Optional[int] = Field(
+        default_factory=lambda: max(multiprocessing.cpu_count(), 1)
+    )
+    threads_batch: Optional[int] = Field(
+        default_factory=lambda: max(multiprocessing.cpu_count(), 1)
+    )
     rope_scaling_type: Optional[int] = None
     pooling_type: Optional[int] = None
     attention_type: Optional[int] = None
@@ -1394,6 +1399,7 @@ class ModelOptions(BaseModel):
     type_v: Optional[int] = None
     prompt_chunk_size: int = 32
     max_seq_len: Optional[int] = None
+    max_output_tokens: Optional[int] = Field(default=None, ge=0)
     kv_unified: bool = True
     draft_model: Optional[Literal["prompt-lookup-decoding"]] = None
     draft_model_num_pred_tokens: int = 10
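
For context, a minimal sketch of how the reworked options might be populated (the path is hypothetical; None now means "defer to the defaults resolved later in Model.__init__ / llama.cpp" rather than a hard-coded value):

    opts = ModelOptions(
        path="/models/example.gguf",  # hypothetical model file
        chat_template=None,     # None -> use the template embedded in the GGUF
        n_ctx=None,             # None -> fall back to the model's trained context size
        n_batch=None,           # None -> keep llama.cpp's default batch size
        max_output_tokens=512,  # validated ge=0; caps generated tokens per request
    )
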
@@ -1429,16 +1435,24 @@ def __init__(self, template: str, *, bos_token: str, eos_token: str) -> None:
         self._eos_token = eos_token
         self._bos_token = bos_token
         self._template_text = template
-        self._template = ImmutableSandboxedEnvironment(
+        environment = ImmutableSandboxedEnvironment(
             loader=jinja2.BaseLoader(),
             trim_blocks=True,
             lstrip_blocks=True,
-        ).from_string(template)
+        )
+        environment.filters["from_json"] = self._from_json
+        self._template = environment.from_string(template)
 
     @staticmethod
     def _strftime_now(format_string: str) -> str:
         return datetime.now().strftime(format_string)
 
+    @staticmethod
+    def _from_json(value: Any) -> Any:
+        if isinstance(value, str):
+            return json.loads(value)
+        return value
+
     def format(
         self,
         *,
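
The new filter is deliberately tolerant: strings are parsed with json.loads, anything already deserialized passes through unchanged. That lets a template apply | from_json to a value (for example, tool-call arguments) whether it arrives as JSON text or as a dict:

    Jinja2ChatFormatter._from_json('{"unit": "celsius"}')  # -> {'unit': 'celsius'}
    Jinja2ChatFormatter._from_json({"unit": "celsius"})    # non-string, returned as-is
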
@@ -1651,6 +1665,8 @@ def from_prepared(
         on_error: Optional[Callable[[BaseException], None]] = None,
     ) -> "CompletionRequest":
         ctx_limit = model.max_seq_len
+        if model.max_output_tokens is not None:
+            ctx_limit = min(ctx_limit, len(prompt_tokens) + model.max_output_tokens)
         if payload.max_tokens is None:
             effective_max_len = ctx_limit
         else:
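
Worked example with hypothetical numbers: max_seq_len=4096, a 1000-token prompt, and max_output_tokens=256 give

    ctx_limit = min(4096, 1000 + 256)  # -> 1256

so the request is capped roughly 256 generated tokens in, instead of being allowed to run to the full 4096-token window.
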
@@ -1818,6 +1834,7 @@ class StreamState:
18181834 "_tools" ,
18191835 "_completion_id" ,
18201836 "_choice_index" ,
1837+ "_prompt_opens_leading_capture" ,
18211838 "_tool_schemas" ,
18221839 "_started" ,
18231840 "_text_parts" ,
@@ -1848,11 +1865,13 @@ def __init__(
         tools: Optional[List[Dict[str, Any]]] = None,
         completion_id: str = "",
         choice_index: int = 0,
+        prompt_opens_leading_capture: bool = False,
     ) -> None:
         self._schema = schema
         self._tools = tools
         self._completion_id = completion_id
         self._choice_index = choice_index
+        self._prompt_opens_leading_capture = prompt_opens_leading_capture
         self._tool_schemas = self._cached_tool_schema_map(tools)
         self._started = False
         self._text_parts: List[str] = []
@@ -1922,6 +1941,16 @@ def __init__(
                 if self._direct.assistant_prefix
                 else self.DIRECT_MODE_PRELUDE
             )
+            if (
+                self._prompt_opens_leading_capture
+                and self._direct.leading_capture_field is not None
+                and self._direct.leading_capture_implicit
+            ):
+                self._direct.mode = (
+                    self.DIRECT_MODE_LEADING_CAPTURE
+                    if not self._direct.assistant_prefix
+                    else self.DIRECT_MODE_ASSISTANT_PREFIX
+                )
             self._stream_state = None
         else:
             self._stream_state = (
@@ -2963,6 +2992,7 @@ def _advance_direct_stream_state(self, text: str) -> Tuple[bool, List[Dict[str,
         leading_capture_end = self._direct.leading_capture_end
         leading_capture_strip_after = self._direct.leading_capture_strip_after
         leading_capture_implicit = self._direct.leading_capture_implicit
+        prompt_opens_leading_capture = self._prompt_opens_leading_capture
         iterator_start = self._direct.iterator_start
         iterator_end = self._direct.iterator_end
         content_end_markers = self._direct.content_end_markers
@@ -2982,6 +3012,9 @@ def _advance_direct_stream_state(self, text: str) -> Tuple[bool, List[Dict[str,
                 continue
             if mode == self.DIRECT_MODE_PRELUDE:
                 if leading_capture_field is not None:
+                    if prompt_opens_leading_capture and leading_capture_implicit:
+                        mode = self.DIRECT_MODE_LEADING_CAPTURE
+                        continue
                     if buffer.startswith(leading_capture_start):
                         buffer = buffer[len(leading_capture_start) :]
                         mode = self.DIRECT_MODE_LEADING_CAPTURE
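
A toy reduction of the PRELUDE decision above (names follow this diff; the real state machine also manages buffering and marker stripping):

    def next_mode(buffer: str, prompt_opens_leading_capture: bool,
                  leading_capture_implicit: bool, leading_capture_start: str) -> str:
        # The prompt itself already opened the capture block (e.g. a chat
        # template ending in "<think>"), so enter capture immediately
        # instead of waiting for the model to emit the marker.
        if prompt_opens_leading_capture and leading_capture_implicit:
            return "LEADING_CAPTURE"
        # Otherwise the opening marker must appear in the generated text.
        if buffer.startswith(leading_capture_start):
            return "LEADING_CAPTURE"
        return "PRELUDE"
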
@@ -5007,6 +5040,12 @@ def _uses_reasoning_content(self) -> bool:
         template = self._chat_template_text()
         return "reasoning_content" in template or "<think>" in template
 
+    def _prompt_opens_leading_capture(self) -> bool:
+        template = self._chat_template_text()
+        if "<think>" not in template:
+            return False
+        return "add_generation_prompt" in template
+
     @staticmethod
     def _chat_message(data: Dict[str, Any]) -> ChatCompletionRequestMessage:
         return ChatCompletionRequestMessage.model_validate(data)
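
The heuristic is purely substring-based. A hypothetical R1-style template fragment that satisfies both checks:

    template = (
        "{% if add_generation_prompt %}"
        "<|Assistant|><think>\n"  # the generation prompt itself opens the block
        "{% endif %}"
    )

With such a template the response begins inside the reasoning block, so the implicit leading-capture path above applies.
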
@@ -5562,6 +5601,7 @@ def _response_parser(
             tools=tools,
             completion_id=completion_id,
             choice_index=choice_index,
+            prompt_opens_leading_capture=self._prompt_opens_leading_capture(),
         )
 
     def parse_chat_response(
@@ -7306,6 +7346,7 @@ def __init__(
         *,
         model_path: str,
         model_alias: Optional[str] = None,
+        chat_template: Optional[str] = None,
         n_gpu_layers: Optional[int] = None,
         split_mode: Optional[int] = None,
         main_gpu: Optional[int] = None,
@@ -7314,12 +7355,12 @@ def __init__(
         use_mmap: Optional[bool] = None,
         use_mlock: Optional[bool] = None,
         kv_overrides: Optional[Dict[str, Union[bool, int, float, str]]] = None,
-        n_ctx: int,
-        n_batch: int,
+        n_ctx: Optional[int],
+        n_batch: Optional[int],
         n_ubatch: Optional[int] = None,
-        n_seq_max: int,
-        n_threads: int,
-        n_threads_batch: int,
+        n_seq_max: Optional[int],
+        n_threads: Optional[int],
+        n_threads_batch: Optional[int],
         rope_scaling_type: Optional[int] = None,
         pooling_type: Optional[int] = None,
         attention_type: Optional[int] = None,
@@ -7340,6 +7381,7 @@ def __init__(
         prompt_chunk_size: int,
         kv_unified: bool = True,
         max_seq_len: Optional[int] = None,
+        max_output_tokens: Optional[int] = None,
         draft_model: Optional[str] = None,
         draft_model_num_pred_tokens: int = 10,
         draft_model_max_ngram_size: int = 2,
@@ -7350,9 +7392,11 @@ def __init__(
         self.backend_initialized = True
         self.model_path = model_path
         self.model_alias = model_alias
+        self.chat_template_override = chat_template
         self.prompt_chunk_size = prompt_chunk_size
         self.response_schema = response_schema
         self.store_logits = store_logits
+        self.max_output_tokens = max_output_tokens
         model_params, self._c_tensor_split, self._kv_overrides_array = (
             self.build_model_params(
                 n_gpu_layers=n_gpu_layers,
@@ -7391,8 +7435,9 @@ def __init__(
             raise RuntimeError(
                 "speculative decoding is only supported for attention models"
             )
+        n_ctx_train = int(llama_cpp.llama_model_n_ctx_train(llama_model))
         context_params = self.build_context_params(
-            n_ctx=n_ctx,
+            n_ctx=n_ctx if n_ctx is not None else n_ctx_train,
             n_batch=n_batch,
             n_ubatch=n_ubatch,
             n_seq_max=n_seq_max,
@@ -7426,7 +7471,7 @@ def __init__(
         self.n_ctx_seq = int(llama_cpp.llama_n_ctx_seq(ctx))
         self.n_seq_max = int(llama_cpp.llama_n_seq_max(ctx))
         self.n_batch = int(llama_cpp.llama_n_batch(ctx))
-        self.n_ctx_train = int(llama_cpp.llama_model_n_ctx_train(llama_model))
+        self.n_ctx_train = n_ctx_train
         self.n_vocab = int(llama_cpp.llama_vocab_n_tokens(self.vocab))
         self.kv_unified = kv_unified
         self.max_seq_len_limit = min(self.request_context_limit, self.n_ctx_train)
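
With n_ctx now optional, the trained context size is read once, serves as the fallback when sizing the context, and is reused for self.n_ctx_train instead of a second FFI call. A sketch with hypothetical values:

    n_ctx_train = 32768  # int(llama_cpp.llama_model_n_ctx_train(llama_model))
    n_ctx = None         # not set in the config
    effective = n_ctx if n_ctx is not None else n_ctx_train  # -> 32768
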
@@ -7548,12 +7593,12 @@ def build_model_params(
     @staticmethod
     def build_context_params(
         *,
-        n_ctx: int,
-        n_batch: int,
+        n_ctx: Optional[int],
+        n_batch: Optional[int],
         n_ubatch: Optional[int],
-        n_seq_max: int,
-        n_threads: int,
-        n_threads_batch: int,
+        n_seq_max: Optional[int],
+        n_threads: Optional[int],
+        n_threads_batch: Optional[int],
         rope_scaling_type: Optional[int],
         pooling_type: Optional[int],
         attention_type: Optional[int],
@@ -7574,13 +7619,18 @@ def build_context_params(
         kv_unified: bool,
     ) -> Any:
         context_params = llama_cpp.llama_context_default_params()
-        context_params.n_ctx = n_ctx
-        context_params.n_batch = min(n_ctx, n_batch)
+        if n_ctx is not None:
+            context_params.n_ctx = n_ctx
+        if n_batch is not None:
+            context_params.n_batch = min(int(context_params.n_ctx), n_batch)
         if n_ubatch is not None:
-            context_params.n_ubatch = min(context_params.n_batch, n_ubatch)
-        context_params.n_seq_max = n_seq_max
-        context_params.n_threads = n_threads
-        context_params.n_threads_batch = n_threads_batch
+            context_params.n_ubatch = min(int(context_params.n_batch), n_ubatch)
+        if n_seq_max is not None:
+            context_params.n_seq_max = n_seq_max
+        if n_threads is not None:
+            context_params.n_threads = n_threads
+        if n_threads_batch is not None:
+            context_params.n_threads_batch = n_threads_batch
         if rope_scaling_type is not None:
             context_params.rope_scaling_type = rope_scaling_type
         if pooling_type is not None:
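
Each guard leaves the field at whatever llama_context_default_params() returned when the argument is None. When values are supplied, the min-chaining keeps them consistent; with hypothetical n_ctx=8192, n_batch=4096, n_ubatch=8192:

    context_params.n_batch = min(8192, 4096)   # -> 4096, capped by n_ctx
    context_params.n_ubatch = min(4096, 8192)  # -> 4096, capped by n_batch
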
@@ -7668,8 +7718,12 @@ def _meta_value(self, key: str) -> Optional[str]:
             capacity = count + 1
 
     def _build_chat_formatter(self) -> Optional[Jinja2ChatFormatter]:
-        template = llama_cpp.llama_model_chat_template(self.llama_model, None)
-        if not template:
+        template_text = self.chat_template_override
+        if template_text is None:
+            template = llama_cpp.llama_model_chat_template(self.llama_model, None)
+            if template:
+                template_text = template.decode("utf-8", errors="ignore")
+        if not template_text:
             return None
         bos_token = ""
         eos_token = ""
@@ -7680,7 +7734,7 @@ def _build_chat_formatter(self) -> Optional[Jinja2ChatFormatter]:
             eos_text = llama_cpp.llama_vocab_get_text(self.vocab, self.eos_token)
             eos_token = eos_text.decode("utf-8", errors="ignore") if eos_text else ""
         return Jinja2ChatFormatter(
-            template=template.decode("utf-8", errors="ignore"),
+            template=template_text,
             bos_token=bos_token,
             eos_token=eos_token,
         )
@@ -9598,6 +9652,7 @@ def main() -> None:
     model = Model(
         model_path=model_path,
         model_alias=config.model.alias,
+        chat_template=config.model.chat_template,
         n_gpu_layers=config.model.n_gpu_layers,
         split_mode=config.model.split_mode,
         main_gpu=config.model.main_gpu,
@@ -9632,6 +9687,7 @@ def main() -> None:
         prompt_chunk_size=config.model.prompt_chunk_size,
         kv_unified=config.model.kv_unified,
         max_seq_len=config.model.max_seq_len,
+        max_output_tokens=config.model.max_output_tokens,
         draft_model=config.model.draft_model,
         draft_model_num_pred_tokens=config.model.draft_model_num_pred_tokens,
         draft_model_max_ngram_size=config.model.draft_model_max_ngram_size,