Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 61 additions & 17 deletions cmd/src/mcp.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
package main

import (
"context"
"encoding/json"
"flag"
"fmt"
"os"
"strings"

"github.com/sourcegraph/src-cli/internal/api"
"github.com/sourcegraph/src-cli/internal/mcp"

"github.com/sourcegraph/sourcegraph/lib/errors"
)

func init() {
flagSet := flag.NewFlagSet("mcp", flag.ExitOnError)
commands = append(commands, &command{
flagSet: flagSet,
handler: mcpMain,
})
if os.Getenv("SRC_EXPERIMENT_MCP") == "true" {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mentioned in slack just not advertising this command. This also works, but would be nice to more easily experiment without the envvar.

flagSet := flag.NewFlagSet("mcp", flag.ExitOnError)
commands = append(commands, &command{
flagSet: flagSet,
handler: mcpMain,
})
}
}
func mcpMain(args []string) error {
fmt.Println("NOTE: This command is still experimental")
Expand Down Expand Up @@ -44,37 +52,73 @@ func mcpMain(args []string) error {
if !ok {
return fmt.Errorf("tool definition for %q not found - run src mcp list-tools to see a list of available tools", subcmd)
}
return handleMcpTool(tool, args[1:])
}

func handleMcpTool(tool *mcp.ToolDef, args []string) error {
fs, vars, err := mcp.BuildArgFlagSet(tool)
flagArgs := args[1:] // skip subcommand name
if len(args) > 1 && args[1] == "schema" {
return printSchemas(tool)
}

flags, vars, err := mcp.BuildArgFlagSet(tool)
if err != nil {
return err
}
if err := flags.Parse(flagArgs); err != nil {
return err
}
mcp.DerefFlagValues(vars)

if err := fs.Parse(args); err != nil {
if err := validateToolArgs(tool.InputSchema, args, vars); err != nil {
return err
}

inputSchema := tool.InputSchema
apiClient := cfg.apiClient(nil, flags.Output())
return handleMcpTool(context.Background(), apiClient, tool, vars)
}

func printSchemas(tool *mcp.ToolDef) error {
input, err := json.MarshalIndent(tool.InputSchema, "", " ")
if err != nil {
return err
}
output, err := json.MarshalIndent(tool.OutputSchema, "", " ")
if err != nil {
return err
}

fmt.Printf("Input:\n%v\nOutput:\n%v\n", string(input), string(output))
return nil
}

func validateToolArgs(inputSchema mcp.SchemaObject, args []string, vars map[string]any) error {
for _, reqName := range inputSchema.Required {
if vars[reqName] == nil {
return fmt.Errorf("no value provided for required flag --%s", reqName)
return errors.Newf("no value provided for required flag --%s", reqName)
}
}

if len(args) < len(inputSchema.Required) {
return fmt.Errorf("not enough arguments provided - the following flags are required:\n%s", strings.Join(inputSchema.Required, "\n"))
return errors.Newf("not enough arguments provided - the following flags are required:\n%s", strings.Join(inputSchema.Required, "\n"))
}

mcp.DerefFlagValues(vars)
return nil
}

fmt.Println("Flags")
for name, val := range vars {
fmt.Printf("--%s=%v\n", name, val)
func handleMcpTool(ctx context.Context, client api.Client, tool *mcp.ToolDef, vars map[string]any) error {
resp, err := mcp.DoToolRequest(ctx, client, tool, vars)
if err != nil {
return err
}

result, err := mcp.DecodeToolResponse(resp)
if err != nil {
return err
}
defer resp.Body.Close()

output, err := json.MarshalIndent(result, "", " ")
if err != nil {
return err
}
fmt.Println(string(output))
return nil
}
19 changes: 18 additions & 1 deletion internal/mcp/mcp_args.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,28 @@ func DerefFlagValues(vars map[string]any) {
if slice, ok := vv.(strSliceFlag); ok {
vv = slice.vals
}
vars[k] = vv
if isNil(vv) {
delete(vars, k)
} else {
vars[k] = vv
}
}
}
}

func isNil(v any) bool {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fun times

if v == nil {
return true
}
rv := reflect.ValueOf(v)
switch rv.Kind() {
case reflect.Slice, reflect.Map, reflect.Pointer, reflect.Interface:
return rv.IsNil()
default:
return false
}
}

func BuildArgFlagSet(tool *ToolDef) (*flag.FlagSet, map[string]any, error) {
if tool == nil {
return nil, nil, errors.New("cannot build flagset on nil Tool Definition")
Expand Down
93 changes: 93 additions & 0 deletions internal/mcp/mcp_request.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package mcp

import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"

"github.com/sourcegraph/src-cli/internal/api"

"github.com/sourcegraph/sourcegraph/lib/errors"
)

const McpURLPath = ".api/mcp/v1"

func DoToolRequest(ctx context.Context, client api.Client, tool *ToolDef, vars map[string]any) (*http.Response, error) {
jsonRPC := struct {
Version string `json:"jsonrpc"`
ID int `json:"id"`
Method string `json:"method"`
Params any `json:"params"`
}{
Version: "2.0",
ID: 1,
Method: "tools/call",
Params: struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments"`
}{
Name: tool.RawName,
Arguments: vars,
},
}

buf := bytes.NewBuffer(nil)
data, err := json.Marshal(jsonRPC)
if err != nil {
return nil, err
}
buf.Write(data)

req, err := client.NewHTTPRequest(ctx, http.MethodPost, McpURLPath, buf)
if err != nil {
return nil, err
}
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Accept", "*/*")

return client.Do(req)
}

func DecodeToolResponse(resp *http.Response) (map[string]json.RawMessage, error) {
data, err := readSSEResponseData(resp)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought our server was just http not sse. But then again we use some framework that I guess decides for us?

if err != nil {
return nil, err
}

if data == nil {
return map[string]json.RawMessage{}, nil
}

jsonRPCResp := struct {
Version string `json:"jsonrpc"`
ID int `json:"id"`
Result struct {
Content []json.RawMessage `json:"content"`
StructuredContent map[string]json.RawMessage `json:"structuredContent"`
} `json:"result"`
}{}
if err := json.Unmarshal(data, &jsonRPCResp); err != nil {
return nil, errors.Wrapf(err, "failed to unmarshal MCP JSON-RPC response")
}

return jsonRPCResp.Result.StructuredContent, nil
}
func readSSEResponseData(resp *http.Response) ([]byte, error) {
data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
// The response is an SSE reponse
// event:
// data:
lines := bytes.SplitSeq(data, []byte("\n"))
for line := range lines {
if jsonData, ok := bytes.CutPrefix(line, []byte("data: ")); ok {
return jsonData, nil
}
}
return nil, errors.New("no data found in SSE response")

}
Loading