|
1 | 1 | package chat |
2 | | - |
3 | | -import ( |
4 | | - "context" |
5 | | - |
6 | | - "paperdebugger/internal/libs/contextutil" |
7 | | - "paperdebugger/internal/libs/shared" |
8 | | - "paperdebugger/internal/models" |
9 | | - chatv1 "paperdebugger/pkg/gen/api/chat/v1" |
10 | | - |
11 | | - "github.com/google/uuid" |
12 | | - "github.com/openai/openai-go/v2/responses" |
13 | | - "go.mongodb.org/mongo-driver/v2/bson" |
14 | | - "go.mongodb.org/mongo-driver/v2/mongo" |
15 | | - "google.golang.org/protobuf/encoding/protojson" |
16 | | -) |
17 | | - |
18 | | -// 设计理念: |
19 | | -// 发送给 GPT 之前,消息列表已经构造进 Conversation 对象中(也保存在数据库里) |
20 | | -// 我们发送给 GPT 的就是从数据库里拿到的 Conversation 对象里面的内容(InputItemList) |
21 | | - |
22 | | -// buildUserMessage constructs both the user-facing message and the OpenAI input message |
23 | | -func (s *ChatServer) buildUserMessage(ctx context.Context, userMessage, userSelectedText string, conversationType chatv1.ConversationType) (*chatv1.Message, *responses.ResponseInputItemUnionParam, error) { |
24 | | - userPrompt, err := s.chatService.GetPrompt(ctx, userMessage, userSelectedText, conversationType) |
25 | | - if err != nil { |
26 | | - return nil, nil, err |
27 | | - } |
28 | | - |
29 | | - var inappMessage *chatv1.Message |
30 | | - switch conversationType { |
31 | | - case chatv1.ConversationType_CONVERSATION_TYPE_DEBUG: |
32 | | - inappMessage = &chatv1.Message{ |
33 | | - MessageId: "pd_msg_user_" + uuid.New().String(), |
34 | | - Payload: &chatv1.MessagePayload{ |
35 | | - MessageType: &chatv1.MessagePayload_User{ |
36 | | - User: &chatv1.MessageTypeUser{ |
37 | | - Content: userPrompt, |
38 | | - }, |
39 | | - }, |
40 | | - }, |
41 | | - } |
42 | | - default: |
43 | | - inappMessage = &chatv1.Message{ |
44 | | - MessageId: "pd_msg_user_" + uuid.New().String(), |
45 | | - Payload: &chatv1.MessagePayload{ |
46 | | - MessageType: &chatv1.MessagePayload_User{ |
47 | | - User: &chatv1.MessageTypeUser{ |
48 | | - Content: userMessage, |
49 | | - SelectedText: &userSelectedText, |
50 | | - }, |
51 | | - }, |
52 | | - }, |
53 | | - } |
54 | | - } |
55 | | - |
56 | | - openaiMessage := &responses.ResponseInputItemUnionParam{ |
57 | | - OfInputMessage: &responses.ResponseInputItemMessageParam{ |
58 | | - Role: "user", |
59 | | - Content: responses.ResponseInputMessageContentListParam{ |
60 | | - responses.ResponseInputContentParamOfInputText(userPrompt), |
61 | | - }, |
62 | | - }, |
63 | | - } |
64 | | - |
65 | | - return inappMessage, openaiMessage, nil |
66 | | -} |
67 | | - |
68 | | -// buildSystemMessage constructs both the user-facing system message and the OpenAI input message |
69 | | -func (s *ChatServer) buildSystemMessage(systemPrompt string) (*chatv1.Message, *responses.ResponseInputItemUnionParam) { |
70 | | - inappMessage := &chatv1.Message{ |
71 | | - MessageId: "pd_msg_system_" + uuid.New().String(), |
72 | | - Payload: &chatv1.MessagePayload{ |
73 | | - MessageType: &chatv1.MessagePayload_System{ |
74 | | - System: &chatv1.MessageTypeSystem{ |
75 | | - Content: systemPrompt, |
76 | | - }, |
77 | | - }, |
78 | | - }, |
79 | | - } |
80 | | - |
81 | | - openaiMessage := &responses.ResponseInputItemUnionParam{ |
82 | | - OfInputMessage: &responses.ResponseInputItemMessageParam{ |
83 | | - Role: "system", |
84 | | - Content: responses.ResponseInputMessageContentListParam{ |
85 | | - responses.ResponseInputContentParamOfInputText(systemPrompt), |
86 | | - }, |
87 | | - }, |
88 | | - } |
89 | | - |
90 | | - return inappMessage, openaiMessage |
91 | | -} |
92 | | - |
93 | | -// convertToBSON converts a protobuf message to BSON |
94 | | -func convertToBSON(msg *chatv1.Message) (bson.M, error) { |
95 | | - jsonBytes, err := protojson.Marshal(msg) |
96 | | - if err != nil { |
97 | | - return nil, err |
98 | | - } |
99 | | - var bsonMsg bson.M |
100 | | - if err := bson.UnmarshalExtJSON(jsonBytes, true, &bsonMsg); err != nil { |
101 | | - return nil, err |
102 | | - } |
103 | | - return bsonMsg, nil |
104 | | -} |
105 | | - |
106 | | -// 创建对话并写入数据库 |
107 | | -// 返回 Conversation 对象 |
108 | | -func (s *ChatServer) createConversation( |
109 | | - ctx context.Context, |
110 | | - userId bson.ObjectID, |
111 | | - projectId string, |
112 | | - latexFullSource string, |
113 | | - projectInstructions string, |
114 | | - userInstructions string, |
115 | | - userMessage string, |
116 | | - userSelectedText string, |
117 | | - modelSlug string, |
118 | | - conversationType chatv1.ConversationType, |
119 | | -) (*models.Conversation, error) { |
120 | | - systemPrompt, err := s.chatService.GetSystemPrompt(ctx, latexFullSource, projectInstructions, userInstructions, conversationType) |
121 | | - if err != nil { |
122 | | - return nil, err |
123 | | - } |
124 | | - |
125 | | - _, openaiSystemMsg := s.buildSystemMessage(systemPrompt) |
126 | | - inappUserMsg, openaiUserMsg, err := s.buildUserMessage(ctx, userMessage, userSelectedText, conversationType) |
127 | | - if err != nil { |
128 | | - return nil, err |
129 | | - } |
130 | | - |
131 | | - messages := []*chatv1.Message{inappUserMsg} |
132 | | - oaiHistory := responses.ResponseNewParamsInputUnion{ |
133 | | - OfInputItemList: responses.ResponseInputParam{*openaiSystemMsg, *openaiUserMsg}, |
134 | | - } |
135 | | - |
136 | | - return s.chatService.InsertConversationToDB( |
137 | | - ctx, userId, projectId, modelSlug, messages, oaiHistory.OfInputItemList, |
138 | | - ) |
139 | | -} |
140 | | - |
141 | | -// 追加消息到对话并写入数据库 |
142 | | -// 返回 Conversation 对象 |
143 | | -func (s *ChatServer) appendConversationMessage( |
144 | | - ctx context.Context, |
145 | | - userId bson.ObjectID, |
146 | | - conversationId string, |
147 | | - userMessage string, |
148 | | - userSelectedText string, |
149 | | - conversationType chatv1.ConversationType, |
150 | | -) (*models.Conversation, error) { |
151 | | - objectID, err := bson.ObjectIDFromHex(conversationId) |
152 | | - if err != nil { |
153 | | - return nil, err |
154 | | - } |
155 | | - |
156 | | - conversation, err := s.chatService.GetConversation(ctx, userId, objectID) |
157 | | - if err != nil { |
158 | | - return nil, err |
159 | | - } |
160 | | - |
161 | | - userMsg, userOaiMsg, err := s.buildUserMessage(ctx, userMessage, userSelectedText, conversationType) |
162 | | - if err != nil { |
163 | | - return nil, err |
164 | | - } |
165 | | - |
166 | | - bsonMsg, err := convertToBSON(userMsg) |
167 | | - if err != nil { |
168 | | - return nil, err |
169 | | - } |
170 | | - conversation.InappChatHistory = append(conversation.InappChatHistory, bsonMsg) |
171 | | - conversation.OpenaiChatHistory = append(conversation.OpenaiChatHistory, *userOaiMsg) |
172 | | - |
173 | | - if err := s.chatService.UpdateConversation(conversation); err != nil { |
174 | | - return nil, err |
175 | | - } |
176 | | - |
177 | | - return conversation, nil |
178 | | -} |
179 | | - |
180 | | -// 如果 conversationId 是 "", 就创建新对话,否则就追加消息到对话 |
181 | | -// conversationType 可以在一次 conversation 中多次切换 |
182 | | -func (s *ChatServer) prepare(ctx context.Context, projectId string, conversationId string, userMessage string, userSelectedText string, modelSlug string, conversationType chatv1.ConversationType) (context.Context, *models.Conversation, *models.Settings, error) { |
183 | | - actor, err := contextutil.GetActor(ctx) |
184 | | - if err != nil { |
185 | | - return ctx, nil, nil, err |
186 | | - } |
187 | | - |
188 | | - project, err := s.projectService.GetProject(ctx, actor.ID, projectId) |
189 | | - if err != nil && err != mongo.ErrNoDocuments { |
190 | | - return ctx, nil, nil, err |
191 | | - } |
192 | | - |
193 | | - userInstructions, err := s.userService.GetUserInstructions(ctx, actor.ID) |
194 | | - if err != nil { |
195 | | - return ctx, nil, nil, err |
196 | | - } |
197 | | - |
198 | | - var latexFullSource string |
199 | | - switch conversationType { |
200 | | - case chatv1.ConversationType_CONVERSATION_TYPE_DEBUG: |
201 | | - latexFullSource = "latex_full_source is not available in debug mode" |
202 | | - default: |
203 | | - if project == nil || project.IsOutOfDate() { |
204 | | - return ctx, nil, nil, shared.ErrProjectOutOfDate("project is out of date") |
205 | | - } |
206 | | - |
207 | | - latexFullSource, err = project.GetFullContent() |
208 | | - if err != nil { |
209 | | - return ctx, nil, nil, err |
210 | | - } |
211 | | - } |
212 | | - |
213 | | - var conversation *models.Conversation |
214 | | - |
215 | | - if conversationId == "" { |
216 | | - conversation, err = s.createConversation( |
217 | | - ctx, |
218 | | - actor.ID, |
219 | | - projectId, |
220 | | - latexFullSource, |
221 | | - project.Instructions, |
222 | | - userInstructions, |
223 | | - userMessage, |
224 | | - userSelectedText, |
225 | | - modelSlug, |
226 | | - conversationType, |
227 | | - ) |
228 | | - } else { |
229 | | - conversation, err = s.appendConversationMessage( |
230 | | - ctx, |
231 | | - actor.ID, |
232 | | - conversationId, |
233 | | - userMessage, |
234 | | - userSelectedText, |
235 | | - conversationType, |
236 | | - ) |
237 | | - } |
238 | | - |
239 | | - if err != nil { |
240 | | - return ctx, nil, nil, err |
241 | | - } |
242 | | - |
243 | | - ctx = contextutil.SetProjectID(ctx, conversation.ProjectID) |
244 | | - ctx = contextutil.SetConversationID(ctx, conversation.ID.Hex()) |
245 | | - |
246 | | - settings, err := s.userService.GetUserSettings(ctx, actor.ID) |
247 | | - if err != nil { |
248 | | - return ctx, conversation, nil, err |
249 | | - } |
250 | | - |
251 | | - return ctx, conversation, settings, nil |
252 | | -} |
0 commit comments