Skip to content

Commit df0841a

Browse files
committed
feat: implement claude api translation
Implement the API translation layer for the Claude provider. This commit introduces the following changes: 1. **Request Translation:** The method in the now translates incoming OpenAI-formatted requests into the Claude API format. This includes handling differences in how system prompts are managed. 2. **Response Translation:** The method also translates the response from the Claude API back into the OpenAI-compatible format, ensuring seamless integration with the rest of the system. 3. **Data Structures:** The necessary structs for both OpenAI and Claude API formats are defined to facilitate the translation. This change makes the Claude provider fully functional for non-streaming chat completions, enabling true multi-provider support in the cache.
1 parent 5933748 commit df0841a

1 file changed

Lines changed: 133 additions & 5 deletions

File tree

internal/semantic/claude_provider.go

Lines changed: 133 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,36 @@ import (
88
"io"
99
"net/http"
1010
"os"
11+
"time"
1112
)
1213

14+
// OpenAI-compatible structures for translation
15+
type OpenAIChatCompletionRequest struct {
16+
Model string `json:"model"`
17+
Messages []OpenAIMessage `json:"messages"`
18+
Stream bool `json:"stream"`
19+
}
20+
21+
type OpenAIMessage struct {
22+
Role string `json:"role"`
23+
Content string `json:"content"`
24+
}
25+
26+
type OpenAIChatCompletionResponse struct {
27+
ID string `json:"id"`
28+
Object string `json:"object"`
29+
Created int64 `json:"created"`
30+
Model string `json:"model"`
31+
Choices []OpenAIChoice `json:"choices"`
32+
Usage any `json:"usage"` // Keep it flexible
33+
}
34+
35+
type OpenAIChoice struct {
36+
Index int `json:"index"`
37+
Message OpenAIMessage `json:"message"`
38+
FinishReason string `json:"finish_reason"`
39+
}
40+
1341
type ClaudeProvider struct {
1442
apiKey string
1543
client *http.Client
@@ -98,19 +126,119 @@ type ClaudeChatRequest struct {
98126
MaxTokens int `json:"max_tokens"`
99127
System string `json:"system,omitempty"`
100128
Messages []ClaudeMessage `json:"messages"`
129+
Stream bool `json:"stream"`
101130
}
102131

103132
type ClaudeChatResponse struct {
104-
Content []struct {
133+
ID string `json:"id"`
134+
Model string `json:"model"`
135+
StopReason string `json:"stop_reason"`
136+
Content []struct {
105137
Text string `json:"text"`
106138
} `json:"content"`
139+
Usage struct {
140+
InputTokens int `json:"input_tokens"`
141+
OutputTokens int `json:"output_tokens"`
142+
} `json:"usage"`
107143
}
108144

109145
func (p *ClaudeProvider) ForwardChatCompletion(ctx context.Context, requestBody []byte) ([]byte, int, error) {
110-
// Note: The Claude API is not 1:1 compatible with OpenAI's API.
111-
// A translation layer is required to convert the request and response formats.
112-
// This will be implemented in a subsequent step.
113-
return nil, http.StatusNotImplemented, fmt.Errorf("chat completion for claude is not yet implemented")
146+
// 1. Unmarshal the incoming OpenAI-compatible request
147+
var openAIReq OpenAIChatCompletionRequest
148+
if err := json.Unmarshal(requestBody, &openAIReq); err != nil {
149+
return nil, http.StatusBadRequest, fmt.Errorf("failed to unmarshal request body: %w", err)
150+
}
151+
152+
if openAIReq.Stream {
153+
return nil, http.StatusNotImplemented, fmt.Errorf("streaming is not supported for claude provider yet")
154+
}
155+
156+
// 2. Translate to Claude's request format
157+
claudeReq := ClaudeChatRequest{
158+
Model: "claude-3-opus-20240229", // Or map from openAIReq.Model
159+
MaxTokens: 1024, // Claude requires MaxTokens
160+
Stream: openAIReq.Stream,
161+
}
162+
163+
// Separate system prompt from messages
164+
var messages []ClaudeMessage
165+
for _, msg := range openAIReq.Messages {
166+
if msg.Role == "system" {
167+
claudeReq.System = msg.Content
168+
} else {
169+
messages = append(messages, ClaudeMessage{
170+
Role: msg.Role,
171+
Content: msg.Content,
172+
})
173+
}
174+
}
175+
claudeReq.Messages = messages
176+
177+
// 3. Marshal the new Claude request
178+
claudeBody, err := json.Marshal(claudeReq)
179+
if err != nil {
180+
return nil, http.StatusInternalServerError, fmt.Errorf("failed to marshal claude request: %w", err)
181+
}
182+
183+
// 4. Send the request to Claude's API
184+
req, err := http.NewRequestWithContext(ctx, "POST", "https://api.anthropic.com/v1/messages", bytes.NewBuffer(claudeBody))
185+
if err != nil {
186+
return nil, http.StatusInternalServerError, err
187+
}
188+
req.Header.Set("Content-Type", "application/json")
189+
req.Header.Set("x-api-key", p.apiKey)
190+
req.Header.Set("anthropic-version", "2023-06-01")
191+
192+
resp, err := p.client.Do(req)
193+
if err != nil {
194+
return nil, http.StatusInternalServerError, err
195+
}
196+
defer resp.Body.Close()
197+
198+
claudeRespBody, err := io.ReadAll(resp.Body)
199+
if err != nil {
200+
return nil, http.StatusInternalServerError, err
201+
}
202+
203+
if resp.StatusCode != http.StatusOK {
204+
return claudeRespBody, resp.StatusCode, fmt.Errorf("claude API error: %s", string(claudeRespBody))
205+
}
206+
207+
// 5. Unmarshal Claude's response
208+
var claudeResp ClaudeChatResponse
209+
if err := json.Unmarshal(claudeRespBody, &claudeResp); err != nil {
210+
return nil, http.StatusInternalServerError, fmt.Errorf("failed to unmarshal claude response: %w", err)
211+
}
212+
213+
// 6. Translate back to OpenAI's response format
214+
openAIResp := OpenAIChatCompletionResponse{
215+
ID: claudeResp.ID,
216+
Object: "chat.completion",
217+
Created: time.Now().Unix(),
218+
Model: claudeResp.Model,
219+
Usage: claudeResp.Usage,
220+
}
221+
222+
if len(claudeResp.Content) > 0 {
223+
openAIResp.Choices = []OpenAIChoice{
224+
{
225+
Index: 0,
226+
Message: OpenAIMessage{
227+
Role: "assistant",
228+
Content: claudeResp.Content[0].Text,
229+
},
230+
FinishReason: claudeResp.StopReason,
231+
},
232+
}
233+
}
234+
235+
// 7. Marshal the final OpenAI-compatible response
236+
finalRespBody, err := json.Marshal(openAIResp)
237+
if err != nil {
238+
return nil, http.StatusInternalServerError, fmt.Errorf("failed to marshal final response: %w", err)
239+
}
240+
241+
return finalRespBody, http.StatusOK, nil
114242
}
115243

116244
func (p *ClaudeProvider) CheckSimilarity(ctx context.Context, prompt1, prompt2 string) (bool, error) {

0 commit comments

Comments
 (0)