|
| 1 | +// SPDX-License-Identifier: MIT |
| 2 | +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors |
| 3 | + |
| 4 | +package middleware |
| 5 | + |
| 6 | +import ( |
| 7 | + "bytes" |
| 8 | + "encoding/json" |
| 9 | + "errors" |
| 10 | + "fmt" |
| 11 | + "io" |
| 12 | + "net/http" |
| 13 | + "net/url" |
| 14 | + "strings" |
| 15 | + |
| 16 | + "github.com/labstack/echo/v5" |
| 17 | +) |
| 18 | + |
| 19 | +// MCPConfig defines the config for MCP middleware. The middleware mounts a |
| 20 | +// Model Context Protocol (https://modelcontextprotocol.io) JSON-RPC 2.0 endpoint |
| 21 | +// at MCPConfig.Path and exposes every route registered on the Echo instance as |
| 22 | +// an MCP "tool" that an AI client (Claude Desktop, Cursor, VS Code, ...) can |
| 23 | +// discover via tools/list and invoke via tools/call. |
| 24 | +type MCPConfig struct { |
| 25 | + // Skipper defines a function to skip middleware. |
| 26 | + Skipper Skipper |
| 27 | + |
| 28 | + // Name is the server name advertised in the initialize handshake. |
| 29 | + // Default value "echo-mcp". |
| 30 | + Name string |
| 31 | + |
| 32 | + // Version is the server version advertised in the initialize handshake. |
| 33 | + // Default value "0.0.0". |
| 34 | + Version string |
| 35 | + |
| 36 | + // Path is the URL path the MCP JSON-RPC endpoint is mounted on. Only POST |
| 37 | + // requests to this exact path are handled; every other request is passed |
| 38 | + // through to the next handler unchanged. |
| 39 | + // Default value "/mcp". |
| 40 | + Path string |
| 41 | +} |
| 42 | + |
| 43 | +// MCP returns a middleware that exposes registered Echo routes as MCP tools at |
| 44 | +// the path configured in MCPConfig.Path. |
| 45 | +func MCP(config MCPConfig) echo.MiddlewareFunc { |
| 46 | + return toMiddlewareOrPanic(config) |
| 47 | +} |
| 48 | + |
| 49 | +// ToMiddleware converts MCPConfig to middleware or returns an error for invalid configuration. |
| 50 | +func (config MCPConfig) ToMiddleware() (echo.MiddlewareFunc, error) { |
| 51 | + if config.Skipper == nil { |
| 52 | + config.Skipper = DefaultSkipper |
| 53 | + } |
| 54 | + if config.Path == "" { |
| 55 | + config.Path = "/mcp" |
| 56 | + } |
| 57 | + if config.Name == "" { |
| 58 | + config.Name = "echo-mcp" |
| 59 | + } |
| 60 | + if config.Version == "" { |
| 61 | + config.Version = "0.0.0" |
| 62 | + } |
| 63 | + |
| 64 | + return func(next echo.HandlerFunc) echo.HandlerFunc { |
| 65 | + return func(c *echo.Context) error { |
| 66 | + if config.Skipper(c) { |
| 67 | + return next(c) |
| 68 | + } |
| 69 | + if c.Request().URL.Path != config.Path { |
| 70 | + return next(c) |
| 71 | + } |
| 72 | + if c.Request().Method != http.MethodPost { |
| 73 | + return next(c) |
| 74 | + } |
| 75 | + return handleMCP(c, config) |
| 76 | + } |
| 77 | + }, nil |
| 78 | +} |
| 79 | + |
| 80 | +// --- JSON-RPC envelope types ------------------------------------------------ |
| 81 | + |
| 82 | +type rpcRequest struct { |
| 83 | + JSONRPC string `json:"jsonrpc"` |
| 84 | + ID json.RawMessage `json:"id,omitempty"` |
| 85 | + Method string `json:"method"` |
| 86 | + Params json.RawMessage `json:"params,omitempty"` |
| 87 | +} |
| 88 | + |
| 89 | +type rpcResponse struct { |
| 90 | + JSONRPC string `json:"jsonrpc"` |
| 91 | + ID json.RawMessage `json:"id,omitempty"` |
| 92 | + Result any `json:"result,omitempty"` |
| 93 | + Error *rpcError `json:"error,omitempty"` |
| 94 | +} |
| 95 | + |
| 96 | +type rpcError struct { |
| 97 | + Code int `json:"code"` |
| 98 | + Message string `json:"message"` |
| 99 | +} |
| 100 | + |
| 101 | +// --- request handling ------------------------------------------------------- |
| 102 | + |
| 103 | +func handleMCP(c *echo.Context, cfg MCPConfig) error { |
| 104 | + body, err := io.ReadAll(c.Request().Body) |
| 105 | + if err != nil { |
| 106 | + return c.JSON(http.StatusOK, rpcResponse{ |
| 107 | + JSONRPC: "2.0", |
| 108 | + Error: &rpcError{Code: -32700, Message: "parse error: " + err.Error()}, |
| 109 | + }) |
| 110 | + } |
| 111 | + |
| 112 | + var req rpcRequest |
| 113 | + if err := json.Unmarshal(body, &req); err != nil { |
| 114 | + return c.JSON(http.StatusOK, rpcResponse{ |
| 115 | + JSONRPC: "2.0", |
| 116 | + Error: &rpcError{Code: -32700, Message: "parse error: " + err.Error()}, |
| 117 | + }) |
| 118 | + } |
| 119 | + |
| 120 | + switch req.Method { |
| 121 | + case "initialize": |
| 122 | + return c.JSON(http.StatusOK, rpcResponse{ |
| 123 | + JSONRPC: "2.0", |
| 124 | + ID: req.ID, |
| 125 | + Result: map[string]any{ |
| 126 | + "protocolVersion": "2024-11-05", |
| 127 | + "serverInfo": map[string]string{ |
| 128 | + "name": cfg.Name, |
| 129 | + "version": cfg.Version, |
| 130 | + }, |
| 131 | + "capabilities": map[string]any{ |
| 132 | + "tools": map[string]any{}, |
| 133 | + }, |
| 134 | + }, |
| 135 | + }) |
| 136 | + |
| 137 | + case "notifications/initialized": |
| 138 | + // Notifications carry no ID and expect no response. |
| 139 | + return c.NoContent(http.StatusNoContent) |
| 140 | + |
| 141 | + case "tools/list": |
| 142 | + tools, _ := buildTools(c.Echo().Router().Routes(), cfg.Path) |
| 143 | + return c.JSON(http.StatusOK, rpcResponse{ |
| 144 | + JSONRPC: "2.0", |
| 145 | + ID: req.ID, |
| 146 | + Result: map[string]any{"tools": tools}, |
| 147 | + }) |
| 148 | + |
| 149 | + case "tools/call": |
| 150 | + result, err := callTool(c, cfg, req.Params) |
| 151 | + if err != nil { |
| 152 | + return c.JSON(http.StatusOK, rpcResponse{ |
| 153 | + JSONRPC: "2.0", |
| 154 | + ID: req.ID, |
| 155 | + Error: &rpcError{Code: -32603, Message: err.Error()}, |
| 156 | + }) |
| 157 | + } |
| 158 | + return c.JSON(http.StatusOK, rpcResponse{ |
| 159 | + JSONRPC: "2.0", |
| 160 | + ID: req.ID, |
| 161 | + Result: result, |
| 162 | + }) |
| 163 | + |
| 164 | + default: |
| 165 | + return c.JSON(http.StatusOK, rpcResponse{ |
| 166 | + JSONRPC: "2.0", |
| 167 | + ID: req.ID, |
| 168 | + Error: &rpcError{Code: -32601, Message: "method not found: " + req.Method}, |
| 169 | + }) |
| 170 | + } |
| 171 | +} |
| 172 | + |
| 173 | +// --- tool building ---------------------------------------------------------- |
| 174 | + |
| 175 | +// buildTools turns Echo's registered routes into MCP tool descriptors and |
| 176 | +// returns a parallel name->RouteInfo map used by callTool to look up the route |
| 177 | +// to dispatch to. |
| 178 | +func buildTools(routes echo.Routes, mcpPath string) ([]map[string]any, map[string]echo.RouteInfo) { |
| 179 | + tools := make([]map[string]any, 0, len(routes)) |
| 180 | + index := make(map[string]echo.RouteInfo, len(routes)) |
| 181 | + used := make(map[string]int, len(routes)) |
| 182 | + |
| 183 | + for _, ri := range routes { |
| 184 | + if ri.Path == mcpPath { |
| 185 | + continue // never expose the MCP endpoint itself |
| 186 | + } |
| 187 | + |
| 188 | + name := toolName(ri) |
| 189 | + if used[name] > 0 { |
| 190 | + name = fmt.Sprintf("%s_%d", name, used[name]+1) |
| 191 | + } |
| 192 | + used[toolName(ri)]++ |
| 193 | + index[name] = ri |
| 194 | + |
| 195 | + properties := map[string]any{} |
| 196 | + required := make([]string, 0, len(ri.Parameters)) |
| 197 | + for _, p := range ri.Parameters { |
| 198 | + properties[p] = map[string]any{ |
| 199 | + "type": "string", |
| 200 | + "description": "Path parameter :" + p, |
| 201 | + } |
| 202 | + required = append(required, p) |
| 203 | + } |
| 204 | + properties["query"] = map[string]any{ |
| 205 | + "type": "object", |
| 206 | + "description": "Optional query string parameters as a flat key/value object.", |
| 207 | + "additionalProperties": true, |
| 208 | + } |
| 209 | + if methodHasBody(ri.Method) { |
| 210 | + properties["body"] = map[string]any{ |
| 211 | + "type": "object", |
| 212 | + "description": "JSON request body.", |
| 213 | + "additionalProperties": true, |
| 214 | + } |
| 215 | + } |
| 216 | + |
| 217 | + schema := map[string]any{ |
| 218 | + "type": "object", |
| 219 | + "properties": properties, |
| 220 | + } |
| 221 | + if len(required) > 0 { |
| 222 | + schema["required"] = required |
| 223 | + } |
| 224 | + |
| 225 | + tools = append(tools, map[string]any{ |
| 226 | + "name": name, |
| 227 | + "description": fmt.Sprintf("%s %s", ri.Method, ri.Path), |
| 228 | + "inputSchema": schema, |
| 229 | + }) |
| 230 | + } |
| 231 | + |
| 232 | + return tools, index |
| 233 | +} |
| 234 | + |
| 235 | +func toolName(ri echo.RouteInfo) string { |
| 236 | + if ri.Name != "" && ri.Name != ri.Method+":"+ri.Path { |
| 237 | + return sanitize(ri.Name) |
| 238 | + } |
| 239 | + slug := strings.NewReplacer("/", "_", ":", "", "*", "wild").Replace(ri.Path) |
| 240 | + slug = strings.Trim(slug, "_") |
| 241 | + if slug == "" { |
| 242 | + slug = "root" |
| 243 | + } |
| 244 | + return ri.Method + "_" + slug |
| 245 | +} |
| 246 | + |
| 247 | +func sanitize(s string) string { |
| 248 | + var b strings.Builder |
| 249 | + for _, r := range s { |
| 250 | + switch { |
| 251 | + case r >= 'a' && r <= 'z', r >= 'A' && r <= 'Z', r >= '0' && r <= '9', r == '_', r == '-': |
| 252 | + b.WriteRune(r) |
| 253 | + default: |
| 254 | + b.WriteRune('_') |
| 255 | + } |
| 256 | + } |
| 257 | + return b.String() |
| 258 | +} |
| 259 | + |
| 260 | +func methodHasBody(method string) bool { |
| 261 | + switch method { |
| 262 | + case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete: |
| 263 | + return true |
| 264 | + } |
| 265 | + return false |
| 266 | +} |
| 267 | + |
| 268 | +// --- tool invocation -------------------------------------------------------- |
| 269 | + |
| 270 | +type toolCallParams struct { |
| 271 | + Name string `json:"name"` |
| 272 | + Arguments map[string]any `json:"arguments"` |
| 273 | +} |
| 274 | + |
| 275 | +func callTool(c *echo.Context, cfg MCPConfig, raw json.RawMessage) (any, error) { |
| 276 | + var p toolCallParams |
| 277 | + if err := json.Unmarshal(raw, &p); err != nil { |
| 278 | + return nil, fmt.Errorf("invalid tools/call params: %w", err) |
| 279 | + } |
| 280 | + if p.Name == "" { |
| 281 | + return nil, errors.New("tools/call: missing tool name") |
| 282 | + } |
| 283 | + |
| 284 | + _, index := buildTools(c.Echo().Router().Routes(), cfg.Path) |
| 285 | + ri, ok := index[p.Name] |
| 286 | + if !ok { |
| 287 | + return nil, fmt.Errorf("tools/call: unknown tool %q", p.Name) |
| 288 | + } |
| 289 | + |
| 290 | + // Substitute path parameters in order using RouteInfo.Reverse. |
| 291 | + pathValues := make([]any, 0, len(ri.Parameters)) |
| 292 | + for _, name := range ri.Parameters { |
| 293 | + v, ok := p.Arguments[name] |
| 294 | + if !ok { |
| 295 | + return nil, fmt.Errorf("tools/call: missing required argument %q", name) |
| 296 | + } |
| 297 | + pathValues = append(pathValues, v) |
| 298 | + } |
| 299 | + target := ri.Reverse(pathValues...) |
| 300 | + |
| 301 | + // Build query string from arguments["query"] if present. |
| 302 | + if q, ok := p.Arguments["query"].(map[string]any); ok && len(q) > 0 { |
| 303 | + values := url.Values{} |
| 304 | + for k, v := range q { |
| 305 | + values.Set(k, fmt.Sprintf("%v", v)) |
| 306 | + } |
| 307 | + if strings.Contains(target, "?") { |
| 308 | + target += "&" + values.Encode() |
| 309 | + } else { |
| 310 | + target += "?" + values.Encode() |
| 311 | + } |
| 312 | + } |
| 313 | + |
| 314 | + // Build body from arguments["body"] if present. |
| 315 | + var bodyReader io.Reader |
| 316 | + if b, ok := p.Arguments["body"]; ok && b != nil { |
| 317 | + raw, err := json.Marshal(b) |
| 318 | + if err != nil { |
| 319 | + return nil, fmt.Errorf("tools/call: cannot marshal body: %w", err) |
| 320 | + } |
| 321 | + bodyReader = bytes.NewReader(raw) |
| 322 | + } |
| 323 | + |
| 324 | + innerReq, err := http.NewRequestWithContext(c.Request().Context(), ri.Method, target, bodyReader) |
| 325 | + if err != nil { |
| 326 | + return nil, fmt.Errorf("tools/call: cannot build request: %w", err) |
| 327 | + } |
| 328 | + // http.NewRequest leaves RequestURI empty (it is only set on server-side |
| 329 | + // requests), but Echo's request logger and other middlewares read it. |
| 330 | + // Populate it so the synthesized request looks like a real one downstream. |
| 331 | + innerReq.RequestURI = target |
| 332 | + if bodyReader != nil { |
| 333 | + innerReq.Header.Set("Content-Type", "application/json") |
| 334 | + } |
| 335 | + innerReq.Header.Set("Accept", "application/json") |
| 336 | + |
| 337 | + rw := &bufferRW{header: http.Header{}} |
| 338 | + c.Echo().ServeHTTP(rw, innerReq) |
| 339 | + |
| 340 | + text := rw.body.String() |
| 341 | + if text == "" { |
| 342 | + text = fmt.Sprintf("(status %d, empty body)", rw.statusOrDefault()) |
| 343 | + } |
| 344 | + |
| 345 | + return map[string]any{ |
| 346 | + "content": []map[string]any{ |
| 347 | + {"type": "text", "text": text}, |
| 348 | + }, |
| 349 | + "isError": rw.statusOrDefault() >= 400, |
| 350 | + }, nil |
| 351 | +} |
| 352 | + |
| 353 | +// --- in-memory http.ResponseWriter ----------------------------------------- |
| 354 | + |
| 355 | +type bufferRW struct { |
| 356 | + header http.Header |
| 357 | + status int |
| 358 | + body bytes.Buffer |
| 359 | +} |
| 360 | + |
| 361 | +func (w *bufferRW) Header() http.Header { return w.header } |
| 362 | + |
| 363 | +func (w *bufferRW) WriteHeader(code int) { |
| 364 | + if w.status == 0 { |
| 365 | + w.status = code |
| 366 | + } |
| 367 | +} |
| 368 | + |
| 369 | +func (w *bufferRW) Write(b []byte) (int, error) { |
| 370 | + if w.status == 0 { |
| 371 | + w.status = http.StatusOK |
| 372 | + } |
| 373 | + return w.body.Write(b) |
| 374 | +} |
| 375 | + |
| 376 | +func (w *bufferRW) statusOrDefault() int { |
| 377 | + if w.status == 0 { |
| 378 | + return http.StatusOK |
| 379 | + } |
| 380 | + return w.status |
| 381 | +} |
0 commit comments