Skip to content

Commit dcae304

Browse files
authored
Merge pull request #153 from PaperDebugger/feat/byok
feat: BYOK enhancements
2 parents 56d815e + 549b5fb commit dcae304

11 files changed

Lines changed: 404 additions & 126 deletions

File tree

internal/api/chat/create_conversation_message_stream_v2.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ func (s *ChatServerV2) CreateConversationMessageStream(
321321
}
322322
}
323323

324-
openaiChatHistory, inappChatHistory, err := s.aiClientV2.ChatCompletionStreamV2(ctx, stream, conversation.ID.Hex(), modelSlug, conversation.OpenaiChatHistoryCompletion, llmProvider)
324+
openaiChatHistory, inappChatHistory, err := s.aiClientV2.ChatCompletionStreamV2(ctx, stream, conversation.ID.Hex(), modelSlug, conversation.OpenaiChatHistoryCompletion, llmProvider, customModel)
325325
if err != nil {
326326
return s.sendStreamError(stream, err)
327327
}
@@ -347,7 +347,7 @@ func (s *ChatServerV2) CreateConversationMessageStream(
347347
for i, bsonMsg := range conversation.InappChatHistory {
348348
protoMessages[i] = mapper.BSONToChatMessageV2(bsonMsg)
349349
}
350-
title, err := s.aiClientV2.GetConversationTitleV2(ctx, protoMessages, llmProvider, modelSlug)
350+
title, err := s.aiClientV2.GetConversationTitleV2(ctx, protoMessages, llmProvider, modelSlug, customModel)
351351
if err != nil {
352352
s.logger.Error("Failed to get conversation title", "error", err, "conversationID", conversation.ID.Hex())
353353
return

internal/api/mapper/user.go

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,18 @@ func MapProtoSettingsToModel(settings *userv1.Settings) *models.Settings {
1919
}
2020

2121
customModels[i] = models.CustomModel{
22-
Id: id,
23-
Slug: m.Slug,
24-
Name: m.Name,
25-
BaseUrl: m.BaseUrl,
26-
APIKey: m.ApiKey,
27-
ContextWindow: m.ContextWindow,
28-
MaxOutput: m.MaxOutput,
29-
InputPrice: m.InputPrice,
30-
OutputPrice: m.OutputPrice,
22+
Id: id,
23+
Slug: m.Slug,
24+
Name: m.Name,
25+
BaseUrl: m.BaseUrl,
26+
APIKey: m.ApiKey,
27+
ContextWindow: m.ContextWindow,
28+
MaxOutput: m.MaxOutput,
29+
InputPrice: m.InputPrice,
30+
OutputPrice: m.OutputPrice,
31+
Temperature: m.Temperature,
32+
ParallelToolCalls: m.ParallelToolCalls,
33+
Store: m.Store,
3134
}
3235
}
3336

@@ -47,15 +50,18 @@ func MapModelSettingsToProto(settings *models.Settings) *userv1.Settings {
4750
customModels := make([]*userv1.CustomModel, len(settings.CustomModels))
4851
for i, m := range settings.CustomModels {
4952
customModels[i] = &userv1.CustomModel{
50-
Id: m.Id.Hex(),
51-
Slug: m.Slug,
52-
Name: m.Name,
53-
BaseUrl: m.BaseUrl,
54-
ApiKey: m.APIKey,
55-
ContextWindow: m.ContextWindow,
56-
MaxOutput: m.MaxOutput,
57-
InputPrice: m.InputPrice,
58-
OutputPrice: m.OutputPrice,
53+
Id: m.Id.Hex(),
54+
Slug: m.Slug,
55+
Name: m.Name,
56+
BaseUrl: m.BaseUrl,
57+
ApiKey: m.APIKey,
58+
ContextWindow: m.ContextWindow,
59+
MaxOutput: m.MaxOutput,
60+
InputPrice: m.InputPrice,
61+
OutputPrice: m.OutputPrice,
62+
Temperature: m.Temperature,
63+
ParallelToolCalls: m.ParallelToolCalls,
64+
Store: m.Store,
5965
}
6066
}
6167

internal/models/user.go

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,18 @@ package models
33
import "go.mongodb.org/mongo-driver/v2/bson"
44

55
type CustomModel struct {
6-
Id bson.ObjectID `bson:"_id"`
7-
Slug string `bson:"slug"`
8-
Name string `bson:"name"`
9-
BaseUrl string `bson:"base_url"`
10-
APIKey string `bson:"api_key"`
11-
ContextWindow int32 `bson:"context_window"`
12-
MaxOutput int32 `bson:"max_output"`
13-
InputPrice int32 `bson:"input_price"`
14-
OutputPrice int32 `bson:"output_price"`
6+
Id bson.ObjectID `bson:"_id"`
7+
Slug string `bson:"slug"`
8+
Name string `bson:"name"`
9+
BaseUrl string `bson:"base_url"`
10+
APIKey string `bson:"api_key"`
11+
ContextWindow int32 `bson:"context_window"`
12+
MaxOutput int32 `bson:"max_output"`
13+
InputPrice int32 `bson:"input_price"`
14+
OutputPrice int32 `bson:"output_price"`
15+
Temperature float32 `bson:"temperature"`
16+
ParallelToolCalls bool `bson:"parallel_tool_calls"`
17+
Store bool `bson:"store"`
1518
}
1619

1720
type Settings struct {

internal/services/toolkit/client/completion_v2.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ import (
2525
// 1. The full chat history sent to the language model (including any tool call results).
2626
// 2. The incremental chat history visible to the user (including tool call results and assistant responses).
2727
// 3. An error, if any occurred during the process.
28-
func (a *AIClientV2) ChatCompletionV2(ctx context.Context, modelSlug string, messages OpenAIChatHistory, llmProvider *models.LLMProviderConfig) (OpenAIChatHistory, AppChatHistory, error) {
29-
openaiChatHistory, inappChatHistory, err := a.ChatCompletionStreamV2(ctx, nil, "", modelSlug, messages, llmProvider)
28+
func (a *AIClientV2) ChatCompletionV2(ctx context.Context, modelSlug string, messages OpenAIChatHistory, llmProvider *models.LLMProviderConfig, customModel *models.CustomModel) (OpenAIChatHistory, AppChatHistory, error) {
29+
openaiChatHistory, inappChatHistory, err := a.ChatCompletionStreamV2(ctx, nil, "", modelSlug, messages, llmProvider, customModel)
3030
if err != nil {
3131
return nil, nil, err
3232
}
@@ -54,7 +54,7 @@ func (a *AIClientV2) ChatCompletionV2(ctx context.Context, modelSlug string, mes
5454
// - If tool calls are required, it handles them and appends the results to the chat history, then continues the loop.
5555
// - If no tool calls are needed, it appends the assistant's response and exits the loop.
5656
// - Finally, it returns the updated chat histories and any error encountered.
57-
func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream chatv2.ChatService_CreateConversationMessageStreamServer, conversationId string, modelSlug string, messages OpenAIChatHistory, llmProvider *models.LLMProviderConfig) (OpenAIChatHistory, AppChatHistory, error) {
57+
func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream chatv2.ChatService_CreateConversationMessageStreamServer, conversationId string, modelSlug string, messages OpenAIChatHistory, llmProvider *models.LLMProviderConfig, customModel *models.CustomModel) (OpenAIChatHistory, AppChatHistory, error) {
5858
openaiChatHistory := messages
5959
inappChatHistory := AppChatHistory{}
6060

@@ -66,7 +66,7 @@ func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream
6666
}()
6767

6868
oaiClient := a.GetOpenAIClient(llmProvider)
69-
params := getDefaultParamsV2(modelSlug, a.toolCallHandler.Registry, llmProvider.IsCustomModel)
69+
params := getDefaultParamsV2(modelSlug, a.toolCallHandler.Registry, customModel)
7070

7171
for {
7272
params.Messages = openaiChatHistory

internal/services/toolkit/client/get_citation_keys.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ func (a *AIClientV2) GetCitationKeys(ctx context.Context, sentence string, userI
244244
_, resp, err := a.ChatCompletionV2(ctx, "gpt-5.2", OpenAIChatHistory{
245245
openai.SystemMessage("You are a helpful assistant that suggests relevant citation keys."),
246246
openai.UserMessage(message),
247-
}, llmProvider)
247+
}, llmProvider, nil)
248248

249249
if err != nil {
250250
return nil, err

internal/services/toolkit/client/get_conversation_title_v2.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import (
1313
"github.com/samber/lo"
1414
)
1515

16-
func (a *AIClientV2) GetConversationTitleV2(ctx context.Context, inappChatHistory []*chatv2.Message, llmProvider *models.LLMProviderConfig, modelSlug string) (string, error) {
16+
func (a *AIClientV2) GetConversationTitleV2(ctx context.Context, inappChatHistory []*chatv2.Message, llmProvider *models.LLMProviderConfig, modelSlug string, customModel *models.CustomModel) (string, error) {
1717
messages := lo.Map(inappChatHistory, func(message *chatv2.Message, _ int) string {
1818
if _, ok := message.Payload.MessageType.(*chatv2.MessagePayload_Assistant); ok {
1919
return fmt.Sprintf("Assistant: %s", message.Payload.GetAssistant().GetContent())
@@ -38,7 +38,7 @@ func (a *AIClientV2) GetConversationTitleV2(ctx context.Context, inappChatHistor
3838
_, resp, err := a.ChatCompletionV2(ctx, modelToUse, OpenAIChatHistory{
3939
openai.SystemMessage("You are a helpful assistant that generates a title for a conversation."),
4040
openai.UserMessage(message),
41-
}, llmProvider)
41+
}, llmProvider, customModel)
4242
if err != nil {
4343
return "", err
4444
}

internal/services/toolkit/client/utils_v2.go

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"paperdebugger/internal/libs/cfg"
1111
"paperdebugger/internal/libs/db"
1212
"paperdebugger/internal/libs/logger"
13+
"paperdebugger/internal/models"
1314
"paperdebugger/internal/services"
1415
"paperdebugger/internal/services/toolkit/registry"
1516
filetools "paperdebugger/internal/services/toolkit/tools/files"
@@ -53,7 +54,7 @@ func appendAssistantTextResponseV2(openaiChatHistory *OpenAIChatHistory, inappCh
5354
})
5455
}
5556

56-
func getDefaultParamsV2(modelSlug string, toolRegistry *registry.ToolRegistryV2, isCustomModel bool) openaiv3.ChatCompletionNewParams {
57+
func getDefaultParamsV2(modelSlug string, toolRegistry *registry.ToolRegistryV2, customModel *models.CustomModel) openaiv3.ChatCompletionNewParams {
5758
var reasoningModels = []string{
5859
"gpt-5",
5960
"gpt-5-mini",
@@ -67,15 +68,22 @@ func getDefaultParamsV2(modelSlug string, toolRegistry *registry.ToolRegistryV2,
6768
"codex-mini-latest",
6869
}
6970

70-
// Other model providers generally do not support the Store param
71-
if isCustomModel {
72-
return openaiv3.ChatCompletionNewParams{
73-
Model: modelSlug,
74-
Temperature: openaiv3.Float(0.7),
75-
MaxCompletionTokens: openaiv3.Int(4000),
71+
if customModel != nil {
72+
params := openaiv3.ChatCompletionNewParams{
73+
Model: customModel.Slug,
74+
Temperature: openaiv3.Float(float64(customModel.Temperature)),
75+
MaxCompletionTokens: openaiv3.Int(int64(customModel.MaxOutput)),
7676
Tools: toolRegistry.GetTools(),
77-
ParallelToolCalls: openaiv3.Bool(true),
77+
ParallelToolCalls: openaiv3.Bool(customModel.ParallelToolCalls),
7878
}
79+
80+
// Store param should only be included if it is true
81+
// Some providers like Gemini might not support the param at all even if false
82+
if customModel.Store {
83+
params.Store = openaiv3.Bool(customModel.Store)
84+
}
85+
86+
return params
7987
}
8088

8189
for _, model := range reasoningModels {

pkg/gen/api/user/v1/user.pb.go

Lines changed: 42 additions & 14 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

proto/user/v1/user.proto

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,9 @@ message CustomModel {
124124
int32 max_output = 7;
125125
int32 input_price = 8;
126126
int32 output_price = 9;
127+
float temperature = 10;
128+
bool parallel_tool_calls = 11;
129+
bool store = 12;
127130
}
128131

129132
message Settings {

0 commit comments

Comments
 (0)