|
1 | 1 | package main |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "bytes" |
| 5 | + "context" |
| 6 | + "encoding/json" |
4 | 7 | "flag" |
5 | 8 | "fmt" |
| 9 | + "io" |
| 10 | + "net/http" |
6 | 11 | "strings" |
7 | 12 |
|
| 13 | + "github.com/sourcegraph/src-cli/internal/api" |
8 | 14 | "github.com/sourcegraph/src-cli/internal/mcp" |
| 15 | + |
| 16 | + "github.com/sourcegraph/sourcegraph/lib/errors" |
9 | 17 | ) |
10 | 18 |
|
| 19 | +const McpPath = ".api/mcp/v1" |
| 20 | + |
11 | 21 | func init() { |
12 | 22 | flagSet := flag.NewFlagSet("mcp", flag.ExitOnError) |
13 | 23 | commands = append(commands, &command{ |
@@ -35,37 +45,110 @@ func mcpMain(args []string) error { |
35 | 45 | if !ok { |
36 | 46 | return fmt.Errorf("tool definition for %q not found - run src mcp list-tools to see a list of available tools", subcmd) |
37 | 47 | } |
38 | | - return handleMcpTool(tool, args[1:]) |
39 | | -} |
40 | 48 |
|
41 | | -func handleMcpTool(tool *mcp.ToolDef, args []string) error { |
42 | | - fs, vars, err := mcp.BuildArgFlagSet(tool) |
| 49 | + flagArgs := args[1:] // skip subcommand name |
| 50 | + if len(args) > 1 && args[1] == "schema" { |
| 51 | + return printSchemas(tool) |
| 52 | + } |
| 53 | + |
| 54 | + flags, vars, err := mcp.BuildToolFlagSet(tool) |
43 | 55 | if err != nil { |
44 | 56 | return err |
45 | 57 | } |
| 58 | + if err := flags.Parse(flagArgs); err != nil { |
| 59 | + return err |
| 60 | + } |
| 61 | + sanitizeFlagValues(vars) |
| 62 | + |
| 63 | + if err := validateToolArgs(tool.InputSchema, args, vars); err != nil { |
| 64 | + return err |
| 65 | + } |
46 | 66 |
|
47 | | - if err := fs.Parse(args); err != nil { |
| 67 | + apiClient := cfg.apiClient(nil, flags.Output()) |
| 68 | + return handleMcpTool(context.Background(), apiClient, tool, vars) |
| 69 | +} |
| 70 | + |
| 71 | +func printSchemas(tool *mcp.ToolDef) error { |
| 72 | + input, err := json.MarshalIndent(tool.InputSchema, "", " ") |
| 73 | + if err != nil { |
| 74 | + return err |
| 75 | + } |
| 76 | + output, err := json.MarshalIndent(tool.OutputSchema, "", " ") |
| 77 | + if err != nil { |
48 | 78 | return err |
49 | 79 | } |
50 | 80 |
|
51 | | - inputSchema := tool.InputSchema |
| 81 | + fmt.Printf("Input:\n%v\nOutput:\n%v\n", string(input), string(output)) |
| 82 | + return nil |
| 83 | +} |
52 | 84 |
|
| 85 | +func validateToolArgs(inputSchema mcp.Schema, args []string, vars map[string]any) error { |
53 | 86 | for _, reqName := range inputSchema.Required { |
54 | 87 | if vars[reqName] == nil { |
55 | | - return fmt.Errorf("no value provided for required flag --%s", reqName) |
| 88 | + return errors.Newf("no value provided for required flag --%s", reqName) |
56 | 89 | } |
57 | 90 | } |
58 | 91 |
|
59 | 92 | if len(args) < len(inputSchema.Required) { |
60 | | - return fmt.Errorf("not enough arguments provided - the following flags are required:\n%s", strings.Join(inputSchema.Required, "\n")) |
| 93 | + return errors.Newf("not enough arguments provided - the following flags are required:\n%s", strings.Join(inputSchema.Required, "\n")) |
| 94 | + } |
| 95 | + |
| 96 | + return nil |
| 97 | +} |
| 98 | + |
| 99 | +func handleMcpTool(ctx context.Context, client api.Client, tool *mcp.ToolDef, vars map[string]any) error { |
| 100 | + jsonRPC := struct { |
| 101 | + Version string `json:"jsonrpc"` |
| 102 | + ID int `json:"id"` |
| 103 | + Method string `json:"method"` |
| 104 | + Params any `json:"params"` |
| 105 | + }{ |
| 106 | + Version: "2.0", |
| 107 | + ID: 1, |
| 108 | + Method: "tools/call", |
| 109 | + Params: struct { |
| 110 | + Name string `json:"name"` |
| 111 | + Arguments map[string]any `json:"arguments"` |
| 112 | + }{ |
| 113 | + Name: tool.RawName, |
| 114 | + Arguments: vars, |
| 115 | + }, |
| 116 | + } |
| 117 | + |
| 118 | + buf := bytes.NewBuffer(nil) |
| 119 | + data, err := json.Marshal(jsonRPC) |
| 120 | + if err != nil { |
| 121 | + return err |
| 122 | + } |
| 123 | + buf.Write(data) |
| 124 | + |
| 125 | + req, err := client.NewHTTPRequest(ctx, http.MethodPost, McpPath, buf) |
| 126 | + if err != nil { |
| 127 | + return err |
61 | 128 | } |
| 129 | + req.Header.Add("Content-Type", "application/json") |
| 130 | + req.Header.Add("Accept", "*/*") |
62 | 131 |
|
63 | | - mcp.DerefFlagValues(vars) |
| 132 | + resp, err := client.Do(req) |
| 133 | + if err != nil { |
| 134 | + return err |
| 135 | + } |
64 | 136 |
|
65 | | - fmt.Println("Flags") |
66 | | - for name, val := range vars { |
67 | | - fmt.Printf("--%s=%v\n", name, val) |
| 137 | + jsonData, err := parseSSEResponse(data) |
| 138 | + if err != nil { |
| 139 | + return err |
68 | 140 | } |
69 | 141 |
|
| 142 | + fmt.Println(string(jsonData)) |
70 | 143 | return nil |
71 | 144 | } |
| 145 | + |
| 146 | +func parseSSEResponse(data []byte) ([]byte, error) { |
| 147 | + lines := bytes.SplitSeq(data, []byte("\n")) |
| 148 | + for line := range lines { |
| 149 | + if jsonData, ok := bytes.CutPrefix(line, []byte("data: ")); ok { |
| 150 | + return jsonData, nil |
| 151 | + } |
| 152 | + } |
| 153 | + return nil, errors.New("no data found in SSE response") |
| 154 | +} |
0 commit comments