Skip to content

Commit 9dc9048

Browse files
committed
feat: add context injection for xtramcp tools
1 parent 99b1483 commit 9dc9048

2 files changed

Lines changed: 105 additions & 24 deletions

File tree

internal/services/toolkit/tools/xtramcp/loader_v2.go

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,17 +53,45 @@ func (loader *XtraMCPLoaderV2) LoadToolsFromBackend(toolRegistry *registry.ToolR
5353

5454
// Register each tool dynamically, passing the session ID
5555
for _, toolSchema := range toolSchemas {
56-
dynamicTool := NewDynamicToolV2(loader.db, loader.projectService, toolSchema, loader.baseURL, loader.sessionID)
56+
// some tools require secrutiy context injection e.g. user_id to authenticate
57+
requiresInjection := loader.requiresSecurityInjection(toolSchema)
58+
59+
dynamicTool := NewDynamicToolV2(
60+
loader.db,
61+
loader.projectService,
62+
toolSchema,
63+
loader.baseURL,
64+
loader.sessionID,
65+
requiresInjection,
66+
)
5767

5868
// Register the tool with the registry
5969
toolRegistry.Register(toolSchema.Name, dynamicTool.Description, dynamicTool.Call)
6070

61-
fmt.Printf("Registered dynamic tool: %s\n", toolSchema.Name)
71+
if requiresInjection {
72+
fmt.Printf("Registered dynamic tool with security injection: %s\n", toolSchema.Name)
73+
} else {
74+
fmt.Printf("Registered dynamic tool: %s\n", toolSchema.Name)
75+
}
6276
}
6377

6478
return nil
6579
}
6680

81+
// checks if a tool schema contains parameters that should be inejected instead of LLM-generated
82+
func (loader *XtraMCPLoaderV2) requiresSecurityInjection(schema ToolSchemaV2) bool {
83+
properties, ok := schema.InputSchema["properties"].(map[string]interface{})
84+
if !ok {
85+
return false
86+
}
87+
88+
// injected parameters
89+
_, hasUserId := properties["user_id"]
90+
_, hasProjectId := properties["project_id"]
91+
92+
return hasUserId || hasProjectId
93+
}
94+
6795
// InitializeMCP performs the full MCP initialization handshake, stores session ID, and returns it
6896
func (loader *XtraMCPLoaderV2) InitializeMCP() (string, error) {
6997
// Step 1: Initialize

internal/services/toolkit/tools/xtramcp/tool_v2.go

Lines changed: 75 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@ import (
99
"net/http"
1010
"paperdebugger/internal/libs/db"
1111
"paperdebugger/internal/services"
12+
"paperdebugger/internal/services/toolkit"
1213
toolCallRecordDB "paperdebugger/internal/services/toolkit/db"
1314
"time"
1415

1516
"github.com/openai/openai-go/v3"
1617
"github.com/openai/openai-go/v3/packages/param"
18+
"go.mongodb.org/mongo-driver/v2/mongo"
1719
)
1820

1921
// ToolSchema represents the schema from your backend
@@ -40,41 +42,49 @@ type MCPParamsV2 struct {
4042

4143
// DynamicTool represents a generic tool that can handle any schema
4244
type DynamicToolV2 struct {
43-
Name string
44-
Description openai.ChatCompletionToolUnionParam
45-
toolCallRecordDB *toolCallRecordDB.ToolCallRecordDB
46-
projectService *services.ProjectService
47-
coolDownTime time.Duration
48-
baseURL string
49-
client *http.Client
50-
schema map[string]interface{}
51-
sessionID string // Reuse the session ID from initialization
45+
Name string
46+
Description openai.ChatCompletionToolUnionParam
47+
toolCallRecordDB *toolCallRecordDB.ToolCallRecordDB
48+
projectService *services.ProjectService
49+
coolDownTime time.Duration
50+
baseURL string
51+
client *http.Client
52+
schema map[string]interface{}
53+
sessionID string // Reuse the session ID from initialization
54+
requiresInjection bool // Indicates if this tool needs user/project injection
5255
}
5356

5457
// NewDynamicTool creates a new dynamic tool from a schema
55-
func NewDynamicToolV2(db *db.DB, projectService *services.ProjectService, toolSchema ToolSchemaV2, baseURL string, sessionID string) *DynamicToolV2 {
56-
// Create tool description with the schema
58+
func NewDynamicToolV2(db *db.DB, projectService *services.ProjectService, toolSchema ToolSchemaV2, baseURL string, sessionID string, requiresInjection bool) *DynamicToolV2 {
59+
// filter schema if injection is required (hide security context like user_id/project_id from LLM)
60+
schemaForLLM := toolSchema.InputSchema
61+
if requiresInjection {
62+
schemaForLLM = filterSecurityParameters(toolSchema.InputSchema)
63+
}
64+
5765
description := openai.ChatCompletionToolUnionParam{
5866
OfFunction: &openai.ChatCompletionFunctionToolParam{
5967
Function: openai.FunctionDefinitionParam{
6068
Name: toolSchema.Name,
6169
Description: param.NewOpt(toolSchema.Description),
62-
Parameters: openai.FunctionParameters(toolSchema.InputSchema),
70+
Parameters: openai.FunctionParameters(schemaForLLM), // Use filtered schema
6371
},
6472
},
6573
}
6674

6775
toolCallRecordDB := toolCallRecordDB.NewToolCallRecordDB(db)
76+
//TODO: consider letting llm client know of output schema too
6877
return &DynamicToolV2{
69-
Name: toolSchema.Name,
70-
Description: description,
71-
toolCallRecordDB: toolCallRecordDB,
72-
projectService: projectService,
73-
coolDownTime: 5 * time.Minute,
74-
baseURL: baseURL,
75-
client: &http.Client{},
76-
schema: toolSchema.InputSchema,
77-
sessionID: sessionID, // Store the session ID for reuse
78+
Name: toolSchema.Name,
79+
Description: description,
80+
toolCallRecordDB: toolCallRecordDB,
81+
projectService: projectService,
82+
coolDownTime: 5 * time.Minute,
83+
baseURL: baseURL,
84+
client: &http.Client{},
85+
schema: toolSchema.InputSchema, // Store original schema for validation
86+
sessionID: sessionID, // Store the session ID for reuse
87+
requiresInjection: requiresInjection,
7888
}
7989
}
8090

@@ -87,7 +97,14 @@ func (t *DynamicToolV2) Call(ctx context.Context, toolCallId string, args json.R
8797
return "", "", err
8898
}
8999

90-
// Create function call record
100+
// inject user/project context if required
101+
if t.requiresInjection {
102+
err := t.injectSecurityContext(ctx, argsMap)
103+
if err != nil {
104+
return "", "", fmt.Errorf("security context injection failed: %w", err)
105+
}
106+
}
107+
91108
record, err := t.toolCallRecordDB.Create(ctx, toolCallId, t.Name, argsMap)
92109
if err != nil {
93110
return "", "", err
@@ -112,6 +129,42 @@ func (t *DynamicToolV2) Call(ctx context.Context, toolCallId string, args json.R
112129
return respStr, "", nil
113130
}
114131

132+
// extracts user/project from context and injects into arguments
133+
func (t *DynamicToolV2) injectSecurityContext(ctx context.Context, argsMap map[string]interface{}) error {
134+
// 1. Extract from context
135+
actor, projectId, _ := toolkit.GetActorProjectConversationID(ctx)
136+
if actor == nil || projectId == "" {
137+
return fmt.Errorf("authentication required: user context not found")
138+
}
139+
140+
// 2. Validate user owns the project
141+
_, err := t.projectService.GetProject(ctx, actor.ID, projectId)
142+
if err != nil {
143+
if err == mongo.ErrNoDocuments {
144+
return fmt.Errorf("authorization failed: project not found or access denied")
145+
}
146+
return fmt.Errorf("authorization check failed: %w", err)
147+
}
148+
149+
// 3. Check if tool schema expects these parameters
150+
properties, ok := t.schema["properties"].(map[string]interface{})
151+
if !ok {
152+
return fmt.Errorf("invalid tool schema: properties not found")
153+
}
154+
155+
// 4. Inject user_id if expected by tool
156+
if _, hasUserId := properties["user_id"]; hasUserId {
157+
argsMap["user_id"] = actor.ID.Hex()
158+
}
159+
160+
// 5. Inject project_id if expected by tool
161+
if _, hasProjectId := properties["project_id"]; hasProjectId {
162+
argsMap["project_id"] = projectId
163+
}
164+
165+
return nil
166+
}
167+
115168
// executeTool makes the MCP request (generic for any tool)
116169
func (t *DynamicToolV2) executeTool(args map[string]interface{}) (string, error) {
117170

0 commit comments

Comments
 (0)