@@ -1199,6 +1199,7 @@ def _create_completion(
11991199 logits_processor : Optional [LogitsProcessorList ] = None ,
12001200 grammar : Optional [LlamaGrammar ] = None ,
12011201 logit_bias : Optional [Dict [int , float ]] = None ,
1202+ special : bool = False ,
12021203 ) -> Union [
12031204 Iterator [CreateCompletionResponse ], Iterator [CreateCompletionStreamResponse ]
12041205 ]:
@@ -1390,13 +1391,17 @@ def logit_bias_processor(
13901391 grammar = grammar ,
13911392 ):
13921393 if llama_cpp .llama_token_is_eog (self ._model .vocab , token ):
1393- text = self .detokenize (completion_tokens , prev_tokens = prompt_tokens )
1394+ text = self .detokenize (
1395+ completion_tokens , prev_tokens = prompt_tokens , special = special
1396+ )
13941397 finish_reason = "stop"
13951398 break
13961399
13971400 completion_tokens .append (token )
13981401
1399- all_text = self .detokenize (completion_tokens , prev_tokens = prompt_tokens )
1402+ all_text = self .detokenize (
1403+ completion_tokens , prev_tokens = prompt_tokens , special = special
1404+ )
14001405
14011406 # Contains multi-byte UTF8
14021407 for k , char in enumerate (all_text [- 3 :]):
@@ -1423,6 +1428,7 @@ def logit_bias_processor(
14231428 remaining_text = self .detokenize (
14241429 remaining_tokens ,
14251430 prev_tokens = prompt_tokens + completion_tokens [:returned_tokens ],
1431+ special = special ,
14261432 )
14271433 remaining_length = len (remaining_text )
14281434
@@ -1450,6 +1456,7 @@ def logit_bias_processor(
14501456 [token ],
14511457 prev_tokens = prompt_tokens
14521458 + completion_tokens [:returned_tokens ],
1459+ special = special ,
14531460 )
14541461 )
14551462 # Check if stop sequence is in the token
@@ -1461,12 +1468,14 @@ def logit_bias_processor(
14611468 [token ],
14621469 prev_tokens = prompt_tokens
14631470 + completion_tokens [:returned_tokens ],
1471+ special = special ,
14641472 ).decode ("utf-8" , errors = "ignore" )
14651473 text_offset = len (prompt ) + len (
14661474 self .detokenize (
14671475 completion_tokens [:returned_tokens ],
14681476 prev_tokens = prompt_tokens
14691477 + completion_tokens [:returned_tokens ],
1478+ special = special ,
14701479 ).decode ("utf-8" , errors = "ignore" )
14711480 )
14721481 token_offset = len (prompt_tokens ) + returned_tokens
@@ -1479,7 +1488,7 @@ def logit_bias_processor(
14791488 )
14801489 )
14811490 top_logprob = {
1482- self .detokenize ([i ]).decode (
1491+ self .detokenize ([i ], special = special ).decode (
14831492 "utf-8" , errors = "ignore"
14841493 ): logprob
14851494 for logprob , i in sorted_logprobs [:logprobs ]
@@ -1491,6 +1500,7 @@ def logit_bias_processor(
14911500 [token ],
14921501 prev_tokens = prompt_tokens
14931502 + completion_tokens [:returned_tokens ],
1503+ special = special ,
14941504 ).decode ("utf-8" , errors = "ignore" )
14951505 ],
14961506 "text_offset" : [text_offset ],
@@ -1509,6 +1519,7 @@ def logit_bias_processor(
15091519 [token ],
15101520 prev_tokens = prompt_tokens
15111521 + completion_tokens [:returned_tokens ],
1522+ special = special ,
15121523 ).decode ("utf-8" , errors = "ignore" ),
15131524 "index" : 0 ,
15141525 "logprobs" : logprobs_or_none ,
@@ -1525,6 +1536,7 @@ def logit_bias_processor(
15251536 remaining_tokens [:i ],
15261537 prev_tokens = prompt_tokens
15271538 + completion_tokens [:returned_tokens ],
1539+ special = special ,
15281540 )
15291541 ts = bs .decode ("utf-8" )
15301542 decode_success = True
@@ -1560,14 +1572,18 @@ def logit_bias_processor(
15601572 }
15611573
15621574 if len (completion_tokens ) >= max_tokens :
1563- text = self .detokenize (completion_tokens , prev_tokens = prompt_tokens )
1575+ text = self .detokenize (
1576+ completion_tokens , prev_tokens = prompt_tokens , special = special
1577+ )
15641578 finish_reason = "length"
15651579 break
15661580
15671581 if stopping_criteria is not None and stopping_criteria (
15681582 self ._input_ids , self ._scores [- 1 , :]
15691583 ):
1570- text = self .detokenize (completion_tokens , prev_tokens = prompt_tokens )
1584+ text = self .detokenize (
1585+ completion_tokens , prev_tokens = prompt_tokens , special = special
1586+ )
15711587 finish_reason = "stop"
15721588
15731589 if self .verbose :
@@ -1578,6 +1594,7 @@ def logit_bias_processor(
15781594 remaining_text = self .detokenize (
15791595 remaining_tokens ,
15801596 prev_tokens = prompt_tokens + completion_tokens [:returned_tokens ],
1597+ special = special ,
15811598 )
15821599 any_stop = [s for s in stop_sequences if s in remaining_text ]
15831600 if len (any_stop ) > 0 :
@@ -1591,6 +1608,7 @@ def logit_bias_processor(
15911608 self .detokenize (
15921609 [token ],
15931610 prev_tokens = prompt_tokens + completion_tokens [:returned_tokens ],
1611+ special = special ,
15941612 )
15951613 )
15961614
@@ -1599,13 +1617,16 @@ def logit_bias_processor(
15991617 if token == bos_token_id :
16001618 continue
16011619 token_str = self .detokenize ([token ]).decode (
1602- "utf-8" , errors = "ignore"
1620+ "utf-8" ,
1621+ errors = "ignore" ,
1622+ special = special ,
16031623 )
16041624 text_offset = len (prompt ) + len (
16051625 self .detokenize (
16061626 completion_tokens [:returned_tokens ],
16071627 prev_tokens = prompt_tokens
16081628 + completion_tokens [:returned_tokens ],
1629+ special = special ,
16091630 )
16101631 )
16111632 token_offset = len (prompt_tokens ) + returned_tokens - 1
@@ -1618,21 +1639,26 @@ def logit_bias_processor(
16181639 )
16191640 )
16201641 top_logprob = {
1621- self .detokenize ([i ]).decode ("utf-8" , errors = "ignore" ): logprob
1642+ self .detokenize ([i ]).decode (
1643+ "utf-8" , errors = "ignore" , special = special
1644+ ): logprob
16221645 for logprob , i in sorted_logprobs [:logprobs ]
16231646 }
16241647 top_logprob .update ({token_str : current_logprobs [int (token )]})
16251648 logprobs_or_none = {
16261649 "tokens" : [
1627- self .detokenize ([token ]).decode ("utf-8" , errors = "ignore" )
1650+ self .detokenize (
1651+ [token ],
1652+ special = special ,
1653+ ).decode ("utf-8" , errors = "ignore" )
16281654 ],
16291655 "text_offset" : [text_offset ],
16301656 "token_logprobs" : [current_logprobs [int (token )]],
16311657 "top_logprobs" : [top_logprob ],
16321658 }
16331659
16341660 if token_end_position >= end :
1635- last_text = self .detokenize ([token ])
1661+ last_text = self .detokenize ([token ], special = special )
16361662 if token_end_position == end - 1 :
16371663 break
16381664 returned_tokens += 1
@@ -1661,7 +1687,7 @@ def logit_bias_processor(
16611687 "model" : model_name ,
16621688 "choices" : [
16631689 {
1664- "text" : self .detokenize ([token ]).decode (
1690+ "text" : self .detokenize ([token ], special = special ).decode (
16651691 "utf-8" , errors = "ignore"
16661692 ),
16671693 "index" : 0 ,
@@ -1725,7 +1751,7 @@ def logit_bias_processor(
17251751
17261752 all_token_strs = [
17271753 self .detokenize ([token ], prev_tokens = all_tokens [:i ]).decode (
1728- "utf-8" , errors = "ignore"
1754+ "utf-8" , errors = "ignore" , special = special
17291755 )
17301756 for i , token in enumerate (all_tokens )
17311757 ]
@@ -1740,7 +1766,7 @@ def logit_bias_processor(
17401766 text_offset
17411767 + len (
17421768 self .detokenize (all_tokens [:idx ]).decode (
1743- "utf-8" , errors = "ignore"
1769+ "utf-8" , errors = "ignore" , special = special
17441770 )
17451771 )
17461772 )
@@ -1752,9 +1778,9 @@ def logit_bias_processor(
17521778 )
17531779 token_logprobs .append (logprobs_token [int (token )])
17541780 top_logprob : Optional [Dict [str , float ]] = {
1755- self .detokenize ([ i ], prev_tokens = all_tokens [: idx ]). decode (
1756- "utf-8" , errors = "ignore"
1757- ): logprob
1781+ self .detokenize (
1782+ [ i ], prev_tokens = all_tokens [: idx ], special = special
1783+ ). decode ( "utf-8" , errors = "ignore" ) : logprob
17581784 for logprob , i in sorted_logprobs [:logprobs ]
17591785 }
17601786 top_logprob .update ({token_str : logprobs_token [int (token )]})
@@ -1819,6 +1845,7 @@ def create_completion(
18191845 logits_processor : Optional [LogitsProcessorList ] = None ,
18201846 grammar : Optional [LlamaGrammar ] = None ,
18211847 logit_bias : Optional [Dict [int , float ]] = None ,
1848+ special : bool = False ,
18221849 ) -> Union [CreateCompletionResponse , Iterator [CreateCompletionStreamResponse ]]:
18231850 """Generate text from a prompt.
18241851
@@ -1848,6 +1875,7 @@ def create_completion(
18481875 logits_processor: A list of logits processors to use.
18491876 grammar: A grammar to use for constrained sampling.
18501877 logit_bias: A logit bias to use.
1878+ special: Include special tokens in output.
18511879
18521880 Raises:
18531881 ValueError: If the requested tokens exceed the context window.
@@ -1882,6 +1910,7 @@ def create_completion(
18821910 logits_processor = logits_processor ,
18831911 grammar = grammar ,
18841912 logit_bias = logit_bias ,
1913+ special = special ,
18851914 )
18861915 if stream :
18871916 chunks : Iterator [CreateCompletionStreamResponse ] = completion_or_chunks
@@ -1916,6 +1945,7 @@ def __call__(
19161945 logits_processor : Optional [LogitsProcessorList ] = None ,
19171946 grammar : Optional [LlamaGrammar ] = None ,
19181947 logit_bias : Optional [Dict [int , float ]] = None ,
1948+ special : bool = False ,
19191949 ) -> Union [CreateCompletionResponse , Iterator [CreateCompletionStreamResponse ]]:
19201950 """Generate text from a prompt.
19211951
@@ -1945,6 +1975,7 @@ def __call__(
19451975 logits_processor: A list of logits processors to use.
19461976 grammar: A grammar to use for constrained sampling.
19471977 logit_bias: A logit bias to use.
1978+ special: Include special tokens in output.
19481979
19491980 Raises:
19501981 ValueError: If the requested tokens exceed the context window.
@@ -1979,6 +2010,7 @@ def __call__(
19792010 logits_processor = logits_processor ,
19802011 grammar = grammar ,
19812012 logit_bias = logit_bias ,
2013+ special = special ,
19822014 )
19832015
19842016 def create_chat_completion (
@@ -2011,6 +2043,7 @@ def create_chat_completion(
20112043 logit_bias : Optional [Dict [int , float ]] = None ,
20122044 logprobs : Optional [bool ] = None ,
20132045 top_logprobs : Optional [int ] = None ,
2046+ special : bool = False ,
20142047 ) -> Union [
20152048 CreateChatCompletionResponse , Iterator [CreateChatCompletionStreamResponse ]
20162049 ]:
@@ -2043,6 +2076,7 @@ def create_chat_completion(
20432076 logits_processor: A list of logits processors to use.
20442077 grammar: A grammar to use.
20452078 logit_bias: A logit bias to use.
2079+ special: Include special tokens in output.
20462080
20472081 Returns:
20482082 Generated chat completion or a stream of chat completion chunks.
@@ -2082,6 +2116,7 @@ def create_chat_completion(
20822116 logits_processor = logits_processor ,
20832117 grammar = grammar ,
20842118 logit_bias = logit_bias ,
2119+ special = special ,
20852120 )
20862121
20872122 def create_chat_completion_openai_v1 (
0 commit comments