|
1 | 1 | package main |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "bytes" |
| 5 | + "context" |
| 6 | + "encoding/json" |
4 | 7 | "flag" |
5 | 8 | "fmt" |
| 9 | + "io" |
| 10 | + "net/http" |
6 | 11 | "strings" |
7 | 12 |
|
| 13 | + "github.com/sourcegraph/src-cli/internal/api" |
8 | 14 | "github.com/sourcegraph/src-cli/internal/mcp" |
| 15 | + |
| 16 | + "github.com/sourcegraph/sourcegraph/lib/errors" |
9 | 17 | ) |
10 | 18 |
|
| 19 | +const McpPath = ".api/mcp/v1" |
| 20 | + |
11 | 21 | func init() { |
12 | 22 | flagSet := flag.NewFlagSet("mcp", flag.ExitOnError) |
13 | 23 | commands = append(commands, &command{ |
@@ -44,37 +54,110 @@ func mcpMain(args []string) error { |
44 | 54 | if !ok { |
45 | 55 | return fmt.Errorf("tool definition for %q not found - run src mcp list-tools to see a list of available tools", subcmd) |
46 | 56 | } |
47 | | - return handleMcpTool(tool, args[1:]) |
48 | | -} |
49 | 57 |
|
50 | | -func handleMcpTool(tool *mcp.ToolDef, args []string) error { |
51 | | - fs, vars, err := mcp.BuildArgFlagSet(tool) |
| 58 | + flagArgs := args[1:] // skip subcommand name |
| 59 | + if len(args) > 1 && args[1] == "schema" { |
| 60 | + return printSchemas(tool) |
| 61 | + } |
| 62 | + |
| 63 | + flags, vars, err := mcp.BuildToolFlagSet(tool) |
52 | 64 | if err != nil { |
53 | 65 | return err |
54 | 66 | } |
| 67 | + if err := flags.Parse(flagArgs); err != nil { |
| 68 | + return err |
| 69 | + } |
| 70 | + sanitizeFlagValues(vars) |
| 71 | + |
| 72 | + if err := validateToolArgs(tool.InputSchema, args, vars); err != nil { |
| 73 | + return err |
| 74 | + } |
55 | 75 |
|
56 | | - if err := fs.Parse(args); err != nil { |
| 76 | + apiClient := cfg.apiClient(nil, flags.Output()) |
| 77 | + return handleMcpTool(context.Background(), apiClient, tool, vars) |
| 78 | +} |
| 79 | + |
| 80 | +func printSchemas(tool *mcp.ToolDef) error { |
| 81 | + input, err := json.MarshalIndent(tool.InputSchema, "", " ") |
| 82 | + if err != nil { |
| 83 | + return err |
| 84 | + } |
| 85 | + output, err := json.MarshalIndent(tool.OutputSchema, "", " ") |
| 86 | + if err != nil { |
57 | 87 | return err |
58 | 88 | } |
59 | 89 |
|
60 | | - inputSchema := tool.InputSchema |
| 90 | + fmt.Printf("Input:\n%v\nOutput:\n%v\n", string(input), string(output)) |
| 91 | + return nil |
| 92 | +} |
61 | 93 |
|
| 94 | +func validateToolArgs(inputSchema mcp.Schema, args []string, vars map[string]any) error { |
62 | 95 | for _, reqName := range inputSchema.Required { |
63 | 96 | if vars[reqName] == nil { |
64 | | - return fmt.Errorf("no value provided for required flag --%s", reqName) |
| 97 | + return errors.Newf("no value provided for required flag --%s", reqName) |
65 | 98 | } |
66 | 99 | } |
67 | 100 |
|
68 | 101 | if len(args) < len(inputSchema.Required) { |
69 | | - return fmt.Errorf("not enough arguments provided - the following flags are required:\n%s", strings.Join(inputSchema.Required, "\n")) |
| 102 | + return errors.Newf("not enough arguments provided - the following flags are required:\n%s", strings.Join(inputSchema.Required, "\n")) |
| 103 | + } |
| 104 | + |
| 105 | + return nil |
| 106 | +} |
| 107 | + |
| 108 | +func handleMcpTool(ctx context.Context, client api.Client, tool *mcp.ToolDef, vars map[string]any) error { |
| 109 | + jsonRPC := struct { |
| 110 | + Version string `json:"jsonrpc"` |
| 111 | + ID int `json:"id"` |
| 112 | + Method string `json:"method"` |
| 113 | + Params any `json:"params"` |
| 114 | + }{ |
| 115 | + Version: "2.0", |
| 116 | + ID: 1, |
| 117 | + Method: "tools/call", |
| 118 | + Params: struct { |
| 119 | + Name string `json:"name"` |
| 120 | + Arguments map[string]any `json:"arguments"` |
| 121 | + }{ |
| 122 | + Name: tool.RawName, |
| 123 | + Arguments: vars, |
| 124 | + }, |
| 125 | + } |
| 126 | + |
| 127 | + buf := bytes.NewBuffer(nil) |
| 128 | + data, err := json.Marshal(jsonRPC) |
| 129 | + if err != nil { |
| 130 | + return err |
| 131 | + } |
| 132 | + buf.Write(data) |
| 133 | + |
| 134 | + req, err := client.NewHTTPRequest(ctx, http.MethodPost, McpPath, buf) |
| 135 | + if err != nil { |
| 136 | + return err |
70 | 137 | } |
| 138 | + req.Header.Add("Content-Type", "application/json") |
| 139 | + req.Header.Add("Accept", "*/*") |
71 | 140 |
|
72 | | - mcp.DerefFlagValues(vars) |
| 141 | + resp, err := client.Do(req) |
| 142 | + if err != nil { |
| 143 | + return err |
| 144 | + } |
73 | 145 |
|
74 | | - fmt.Println("Flags") |
75 | | - for name, val := range vars { |
76 | | - fmt.Printf("--%s=%v\n", name, val) |
| 146 | + jsonData, err := parseSSEResponse(data) |
| 147 | + if err != nil { |
| 148 | + return err |
77 | 149 | } |
78 | 150 |
|
| 151 | + fmt.Println(string(jsonData)) |
79 | 152 | return nil |
80 | 153 | } |
| 154 | + |
| 155 | +func parseSSEResponse(data []byte) ([]byte, error) { |
| 156 | + lines := bytes.SplitSeq(data, []byte("\n")) |
| 157 | + for line := range lines { |
| 158 | + if jsonData, ok := bytes.CutPrefix(line, []byte("data: ")); ok { |
| 159 | + return jsonData, nil |
| 160 | + } |
| 161 | + } |
| 162 | + return nil, errors.New("no data found in SSE response") |
| 163 | +} |
0 commit comments