
Commit 6236af8 (parent: fdcb01f)

feat: add special argument needed to make Granite-Docling useful

3 files changed: 54 additions & 15 deletions


examples/granite_docling/main.py (1 addition & 0 deletions)

```diff
@@ -23,6 +23,7 @@
         }
     ],
     stream=True,
+    special=True,
 )
 
 for chunk in response:
```
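Granite-Docling emits its structured output (DocTags markup) as special vocabulary tokens, and the default detokenization strips those tokens, so without this flag the model appears to return almost nothing. A minimal sketch of the call the example now makes; the model path and message content are illustrative placeholders, not taken from this commit:

```python
from llama_cpp import Llama

# Illustrative placeholder path -- substitute a local Granite-Docling GGUF.
llm = Llama(model_path="granite-docling.gguf", n_ctx=8192)

response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Convert this page to DocTags."}],
    stream=True,
    special=True,  # keep special tokens (the DocTags markup) in the output
)

for chunk in response:
    # Each stream chunk carries an OpenAI-style delta dict; special tokens
    # now arrive inline with the regular text.
    delta = chunk["choices"][0]["delta"]
    print(delta.get("content") or "", end="", flush=True)
```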

llama_cpp/llama.py (50 additions & 15 deletions)

```diff
@@ -1199,6 +1199,7 @@ def _create_completion(
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
         logit_bias: Optional[Dict[int, float]] = None,
+        special: bool = False,
     ) -> Union[
         Iterator[CreateCompletionResponse], Iterator[CreateCompletionStreamResponse]
     ]:
@@ -1390,13 +1391,17 @@ def logit_bias_processor(
             grammar=grammar,
         ):
             if llama_cpp.llama_token_is_eog(self._model.vocab, token):
-                text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
+                text = self.detokenize(
+                    completion_tokens, prev_tokens=prompt_tokens, special=special
+                )
                 finish_reason = "stop"
                 break
 
             completion_tokens.append(token)
 
-            all_text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
+            all_text = self.detokenize(
+                completion_tokens, prev_tokens=prompt_tokens, special=special
+            )
 
             # Contains multi-byte UTF8
             for k, char in enumerate(all_text[-3:]):
@@ -1423,6 +1428,7 @@ def logit_bias_processor(
                 remaining_text = self.detokenize(
                     remaining_tokens,
                     prev_tokens=prompt_tokens + completion_tokens[:returned_tokens],
+                    special=special,
                 )
                 remaining_length = len(remaining_text)
 
@@ -1450,6 +1456,7 @@ def logit_bias_processor(
                             [token],
                             prev_tokens=prompt_tokens
                             + completion_tokens[:returned_tokens],
+                            special=special,
                         )
                     )
                     # Check if stop sequence is in the token
@@ -1461,12 +1468,14 @@ def logit_bias_processor(
                                 [token],
                                 prev_tokens=prompt_tokens
                                 + completion_tokens[:returned_tokens],
+                                special=special,
                             ).decode("utf-8", errors="ignore")
                             text_offset = len(prompt) + len(
                                 self.detokenize(
                                     completion_tokens[:returned_tokens],
                                     prev_tokens=prompt_tokens
                                     + completion_tokens[:returned_tokens],
+                                    special=special,
                                 ).decode("utf-8", errors="ignore")
                             )
                             token_offset = len(prompt_tokens) + returned_tokens
@@ -1479,7 +1488,7 @@ def logit_bias_processor(
                                 )
                             )
                             top_logprob = {
-                                self.detokenize([i]).decode(
+                                self.detokenize([i], special=special).decode(
                                     "utf-8", errors="ignore"
                                 ): logprob
                                 for logprob, i in sorted_logprobs[:logprobs]
@@ -1491,6 +1500,7 @@ def logit_bias_processor(
                                         [token],
                                         prev_tokens=prompt_tokens
                                         + completion_tokens[:returned_tokens],
+                                        special=special,
                                     ).decode("utf-8", errors="ignore")
                                 ],
                                 "text_offset": [text_offset],
@@ -1509,6 +1519,7 @@ def logit_bias_processor(
                                     [token],
                                     prev_tokens=prompt_tokens
                                     + completion_tokens[:returned_tokens],
+                                    special=special,
                                 ).decode("utf-8", errors="ignore"),
                                 "index": 0,
                                 "logprobs": logprobs_or_none,
@@ -1525,6 +1536,7 @@ def logit_bias_processor(
                             remaining_tokens[:i],
                             prev_tokens=prompt_tokens
                             + completion_tokens[:returned_tokens],
+                            special=special,
                         )
                         ts = bs.decode("utf-8")
                         decode_success = True
@@ -1560,14 +1572,18 @@ def logit_bias_processor(
                 }
 
             if len(completion_tokens) >= max_tokens:
-                text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
+                text = self.detokenize(
+                    completion_tokens, prev_tokens=prompt_tokens, special=special
+                )
                 finish_reason = "length"
                 break
 
         if stopping_criteria is not None and stopping_criteria(
             self._input_ids, self._scores[-1, :]
         ):
-            text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
+            text = self.detokenize(
+                completion_tokens, prev_tokens=prompt_tokens, special=special
+            )
             finish_reason = "stop"
 
         if self.verbose:
@@ -1578,6 +1594,7 @@ def logit_bias_processor(
             remaining_text = self.detokenize(
                 remaining_tokens,
                 prev_tokens=prompt_tokens + completion_tokens[:returned_tokens],
+                special=special,
             )
             any_stop = [s for s in stop_sequences if s in remaining_text]
             if len(any_stop) > 0:
@@ -1591,6 +1608,7 @@ def logit_bias_processor(
                     self.detokenize(
                         [token],
                         prev_tokens=prompt_tokens + completion_tokens[:returned_tokens],
+                        special=special,
                     )
                 )
 
@@ -1599,13 +1617,16 @@ def logit_bias_processor(
                 if token == bos_token_id:
                     continue
-                token_str = self.detokenize([token]).decode(
-                    "utf-8", errors="ignore"
+                token_str = self.detokenize(
+                    [token], special=special
+                ).decode(
+                    "utf-8", errors="ignore"
                 )
                 text_offset = len(prompt) + len(
                     self.detokenize(
                         completion_tokens[:returned_tokens],
                         prev_tokens=prompt_tokens
                         + completion_tokens[:returned_tokens],
+                        special=special,
                     )
                 )
                 token_offset = len(prompt_tokens) + returned_tokens - 1
@@ -1618,21 +1639,26 @@ def logit_bias_processor(
                     )
                 )
                 top_logprob = {
-                    self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
+                    self.detokenize([i], special=special).decode(
+                        "utf-8", errors="ignore"
+                    ): logprob
                     for logprob, i in sorted_logprobs[:logprobs]
                 }
                 top_logprob.update({token_str: current_logprobs[int(token)]})
                 logprobs_or_none = {
                     "tokens": [
-                        self.detokenize([token]).decode("utf-8", errors="ignore")
+                        self.detokenize(
+                            [token],
+                            special=special,
+                        ).decode("utf-8", errors="ignore")
                     ],
                     "text_offset": [text_offset],
                     "token_logprobs": [current_logprobs[int(token)]],
                     "top_logprobs": [top_logprob],
                 }
 
                 if token_end_position >= end:
-                    last_text = self.detokenize([token])
+                    last_text = self.detokenize([token], special=special)
                     if token_end_position == end - 1:
                         break
                     returned_tokens += 1
@@ -1661,7 +1687,7 @@ def logit_bias_processor(
                     "model": model_name,
                     "choices": [
                         {
-                            "text": self.detokenize([token]).decode(
+                            "text": self.detokenize([token], special=special).decode(
                                 "utf-8", errors="ignore"
                             ),
                             "index": 0,
@@ -1725,7 +1751,7 @@ def logit_bias_processor(
 
             all_token_strs = [
-                self.detokenize([token], prev_tokens=all_tokens[:i]).decode(
-                    "utf-8", errors="ignore"
-                )
+                self.detokenize(
+                    [token], prev_tokens=all_tokens[:i], special=special
+                ).decode("utf-8", errors="ignore")
                 for i, token in enumerate(all_tokens)
             ]
@@ -1740,7 +1766,7 @@ def logit_bias_processor(
                     text_offset
                     + len(
-                        self.detokenize(all_tokens[:idx]).decode(
-                            "utf-8", errors="ignore"
-                        )
+                        self.detokenize(all_tokens[:idx], special=special).decode(
+                            "utf-8", errors="ignore"
+                        )
                     )
                 )
@@ -1752,9 +1778,9 @@ def logit_bias_processor(
             )
             token_logprobs.append(logprobs_token[int(token)])
             top_logprob: Optional[Dict[str, float]] = {
-                self.detokenize([i], prev_tokens=all_tokens[:idx]).decode(
-                    "utf-8", errors="ignore"
-                ): logprob
+                self.detokenize(
+                    [i], prev_tokens=all_tokens[:idx], special=special
+                ).decode("utf-8", errors="ignore"): logprob
                 for logprob, i in sorted_logprobs[:logprobs]
             }
             top_logprob.update({token_str: logprobs_token[int(token)]})
@@ -1819,6 +1845,7 @@ def create_completion(
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
         logit_bias: Optional[Dict[int, float]] = None,
+        special: bool = False,
     ) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]:
         """Generate text from a prompt.
 
@@ -1848,6 +1875,7 @@ def create_completion(
             logits_processor: A list of logits processors to use.
             grammar: A grammar to use for constrained sampling.
             logit_bias: A logit bias to use.
+            special: Include special tokens in output.
 
         Raises:
             ValueError: If the requested tokens exceed the context window.
@@ -1882,6 +1910,7 @@ def create_completion(
             logits_processor=logits_processor,
             grammar=grammar,
             logit_bias=logit_bias,
+            special=special,
         )
         if stream:
             chunks: Iterator[CreateCompletionStreamResponse] = completion_or_chunks
@@ -1916,6 +1945,7 @@ def __call__(
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
         logit_bias: Optional[Dict[int, float]] = None,
+        special: bool = False,
     ) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]:
         """Generate text from a prompt.
 
@@ -1945,6 +1975,7 @@ def __call__(
             logits_processor: A list of logits processors to use.
             grammar: A grammar to use for constrained sampling.
             logit_bias: A logit bias to use.
+            special: Include special tokens in output.
 
         Raises:
             ValueError: If the requested tokens exceed the context window.
@@ -1979,6 +2010,7 @@ def __call__(
             logits_processor=logits_processor,
             grammar=grammar,
             logit_bias=logit_bias,
+            special=special,
         )
 
     def create_chat_completion(
@@ -2011,6 +2043,7 @@ def create_chat_completion(
         logit_bias: Optional[Dict[int, float]] = None,
         logprobs: Optional[bool] = None,
         top_logprobs: Optional[int] = None,
+        special: bool = False,
     ) -> Union[
         CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse]
     ]:
@@ -2043,6 +2076,7 @@ def create_chat_completion(
             logits_processor: A list of logits processors to use.
             grammar: A grammar to use.
             logit_bias: A logit bias to use.
+            special: Include special tokens in output.
 
         Returns:
             Generated chat completion or a stream of chat completion chunks.
@@ -2082,6 +2116,7 @@ def create_chat_completion(
             logits_processor=logits_processor,
             grammar=grammar,
             logit_bias=logit_bias,
+            special=special,
         )
 
     def create_chat_completion_openai_v1(
```
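All of the changes above thread the single `special` flag from the public entry points (`create_completion`, `__call__`, `create_chat_completion`) down to every `detokenize` call inside `_create_completion`. `Llama.detokenize` forwards the flag to llama.cpp's detokenizer, which renders special tokens (BOS/EOS and other control markup) as text instead of dropping them. A rough sketch of the observable difference, assuming a model whose vocabulary defines `<|end_of_text|>`; the token text and model path are illustrative:

```python
from llama_cpp import Llama

llm = Llama(model_path="model.gguf")  # illustrative placeholder path

# Tokenize with special=True so the control token is parsed as one token.
tokens = llm.tokenize(b"<|end_of_text|>", add_bos=False, special=True)

print(llm.detokenize(tokens))                # b'' -- special tokens dropped
print(llm.detokenize(tokens, special=True))  # b'<|end_of_text|>' -- kept
```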

llama_cpp/llama_chat_format.py (3 additions & 0 deletions)

```diff
@@ -102,6 +102,7 @@ def __call__(
         grammar: Optional[llama.LlamaGrammar] = None,
         logprobs: Optional[bool] = None,
         top_logprobs: Optional[int] = None,
+        special: bool = False,
         **kwargs,  # type: ignore
     ) -> Union[
         llama_types.CreateChatCompletionResponse,
@@ -2798,6 +2799,7 @@ def __call__(
         logit_bias: Optional[Dict[str, float]] = None,
         logprobs: Optional[bool] = None,
         top_logprobs: Optional[int] = None,
+        special: bool = False,
         **kwargs,  # type: ignore
     ) -> Union[
         llama_types.CreateChatCompletionResponse,
@@ -3018,6 +3020,7 @@ def __call__(
             logits_processor=logits_processor,
             grammar=grammar,
             logit_bias=logit_bias,
+            special=special,
         )
 
         if tool is not None:
```
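The chat-format change matters because `create_chat_completion` routes through a registered chat handler before reaching `create_completion`; a handler that left the parameter unnamed would swallow it inside `**kwargs`. A minimal sketch of the forwarding pattern, with a hypothetical handler class and an oversimplified prompt builder; only the signature and forwarding shape mirror this commit:

```python
from typing import Any, Dict, List


class SketchChatHandler:
    """Hypothetical handler showing the forwarding pattern in this commit."""

    def __call__(
        self,
        llama: Any,
        messages: List[Dict[str, str]],
        special: bool = False,  # named explicitly so **kwargs cannot swallow it
        **kwargs: Any,
    ) -> Any:
        # Oversimplified, hypothetical prompt rendering.
        prompt = "\n".join(f"{m['role']}: {m['content']}" for m in messages)
        return llama.create_completion(
            prompt=prompt,
            special=special,  # forwarded into detokenization
            **kwargs,
        )
```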
