diff --git a/cmd/src/mcp.go b/cmd/src/mcp.go index 458ff0ce05..89ab4b2dce 100644 --- a/cmd/src/mcp.go +++ b/cmd/src/mcp.go @@ -1,19 +1,27 @@ package main import ( + "context" + "encoding/json" "flag" "fmt" + "os" "strings" + "github.com/sourcegraph/src-cli/internal/api" "github.com/sourcegraph/src-cli/internal/mcp" + + "github.com/sourcegraph/sourcegraph/lib/errors" ) func init() { - flagSet := flag.NewFlagSet("mcp", flag.ExitOnError) - commands = append(commands, &command{ - flagSet: flagSet, - handler: mcpMain, - }) + if os.Getenv("SRC_EXPERIMENT_MCP") == "true" { + flagSet := flag.NewFlagSet("mcp", flag.ExitOnError) + commands = append(commands, &command{ + flagSet: flagSet, + handler: mcpMain, + }) + } } func mcpMain(args []string) error { fmt.Println("NOTE: This command is still experimental") @@ -44,37 +52,73 @@ func mcpMain(args []string) error { if !ok { return fmt.Errorf("tool definition for %q not found - run src mcp list-tools to see a list of available tools", subcmd) } - return handleMcpTool(tool, args[1:]) -} -func handleMcpTool(tool *mcp.ToolDef, args []string) error { - fs, vars, err := mcp.BuildArgFlagSet(tool) + flagArgs := args[1:] // skip subcommand name + if len(args) > 1 && args[1] == "schema" { + return printSchemas(tool) + } + + flags, vars, err := mcp.BuildArgFlagSet(tool) if err != nil { return err } + if err := flags.Parse(flagArgs); err != nil { + return err + } + mcp.DerefFlagValues(vars) - if err := fs.Parse(args); err != nil { + if err := validateToolArgs(tool.InputSchema, args, vars); err != nil { return err } - inputSchema := tool.InputSchema + apiClient := cfg.apiClient(nil, flags.Output()) + return handleMcpTool(context.Background(), apiClient, tool, vars) +} + +func printSchemas(tool *mcp.ToolDef) error { + input, err := json.MarshalIndent(tool.InputSchema, "", " ") + if err != nil { + return err + } + output, err := json.MarshalIndent(tool.OutputSchema, "", " ") + if err != nil { + return err + } + fmt.Printf("Input:\n%v\nOutput:\n%v\n", string(input), string(output)) + return nil +} + +func validateToolArgs(inputSchema mcp.SchemaObject, args []string, vars map[string]any) error { for _, reqName := range inputSchema.Required { if vars[reqName] == nil { - return fmt.Errorf("no value provided for required flag --%s", reqName) + return errors.Newf("no value provided for required flag --%s", reqName) } } if len(args) < len(inputSchema.Required) { - return fmt.Errorf("not enough arguments provided - the following flags are required:\n%s", strings.Join(inputSchema.Required, "\n")) + return errors.Newf("not enough arguments provided - the following flags are required:\n%s", strings.Join(inputSchema.Required, "\n")) } - mcp.DerefFlagValues(vars) + return nil +} - fmt.Println("Flags") - for name, val := range vars { - fmt.Printf("--%s=%v\n", name, val) +func handleMcpTool(ctx context.Context, client api.Client, tool *mcp.ToolDef, vars map[string]any) error { + resp, err := mcp.DoToolRequest(ctx, client, tool, vars) + if err != nil { + return err } + result, err := mcp.DecodeToolResponse(resp) + if err != nil { + return err + } + defer resp.Body.Close() + + output, err := json.MarshalIndent(result, "", " ") + if err != nil { + return err + } + fmt.Println(string(output)) return nil } diff --git a/internal/mcp/mcp_args.go b/internal/mcp/mcp_args.go index fe2ed00337..09efb2a371 100644 --- a/internal/mcp/mcp_args.go +++ b/internal/mcp/mcp_args.go @@ -32,11 +32,28 @@ func DerefFlagValues(vars map[string]any) { if slice, ok := vv.(strSliceFlag); ok { vv = slice.vals } - vars[k] = vv + if isNil(vv) { + delete(vars, k) + } else { + vars[k] = vv + } } } } +func isNil(v any) bool { + if v == nil { + return true + } + rv := reflect.ValueOf(v) + switch rv.Kind() { + case reflect.Slice, reflect.Map, reflect.Pointer, reflect.Interface: + return rv.IsNil() + default: + return false + } +} + func BuildArgFlagSet(tool *ToolDef) (*flag.FlagSet, map[string]any, error) { if tool == nil { return nil, nil, errors.New("cannot build flagset on nil Tool Definition") diff --git a/internal/mcp/mcp_request.go b/internal/mcp/mcp_request.go new file mode 100644 index 0000000000..dbcb0ed97b --- /dev/null +++ b/internal/mcp/mcp_request.go @@ -0,0 +1,93 @@ +package mcp + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + + "github.com/sourcegraph/src-cli/internal/api" + + "github.com/sourcegraph/sourcegraph/lib/errors" +) + +const McpURLPath = ".api/mcp/v1" + +func DoToolRequest(ctx context.Context, client api.Client, tool *ToolDef, vars map[string]any) (*http.Response, error) { + jsonRPC := struct { + Version string `json:"jsonrpc"` + ID int `json:"id"` + Method string `json:"method"` + Params any `json:"params"` + }{ + Version: "2.0", + ID: 1, + Method: "tools/call", + Params: struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` + }{ + Name: tool.RawName, + Arguments: vars, + }, + } + + buf := bytes.NewBuffer(nil) + data, err := json.Marshal(jsonRPC) + if err != nil { + return nil, err + } + buf.Write(data) + + req, err := client.NewHTTPRequest(ctx, http.MethodPost, McpURLPath, buf) + if err != nil { + return nil, err + } + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Accept", "*/*") + + return client.Do(req) +} + +func DecodeToolResponse(resp *http.Response) (map[string]json.RawMessage, error) { + data, err := readSSEResponseData(resp) + if err != nil { + return nil, err + } + + if data == nil { + return map[string]json.RawMessage{}, nil + } + + jsonRPCResp := struct { + Version string `json:"jsonrpc"` + ID int `json:"id"` + Result struct { + Content []json.RawMessage `json:"content"` + StructuredContent map[string]json.RawMessage `json:"structuredContent"` + } `json:"result"` + }{} + if err := json.Unmarshal(data, &jsonRPCResp); err != nil { + return nil, errors.Wrapf(err, "failed to unmarshal MCP JSON-RPC response") + } + + return jsonRPCResp.Result.StructuredContent, nil +} +func readSSEResponseData(resp *http.Response) ([]byte, error) { + data, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + // The response is an SSE reponse + // event: + // data: + lines := bytes.SplitSeq(data, []byte("\n")) + for line := range lines { + if jsonData, ok := bytes.CutPrefix(line, []byte("data: ")); ok { + return jsonData, nil + } + } + return nil, errors.New("no data found in SSE response") + +}