diff --git a/sdks/community/go/pkg/client/sse/client.go b/sdks/community/go/pkg/client/sse/client.go index bf50ead22..d8c8707e3 100644 --- a/sdks/community/go/pkg/client/sse/client.go +++ b/sdks/community/go/pkg/client/sse/client.go @@ -11,6 +11,7 @@ import ( "strings" "time" + "github.com/ag-ui-protocol/ag-ui/sdks/community/go/pkg/core/types" "github.com/sirupsen/logrus" ) @@ -38,7 +39,7 @@ type Frame struct { type StreamOptions struct { Context context.Context - Payload interface{} + Payload types.RunAgentInput Headers map[string]string } diff --git a/sdks/community/go/pkg/client/sse/client_stream_test.go b/sdks/community/go/pkg/client/sse/client_stream_test.go index 5ddcb5660..7c8dc0f42 100644 --- a/sdks/community/go/pkg/client/sse/client_stream_test.go +++ b/sdks/community/go/pkg/client/sse/client_stream_test.go @@ -11,45 +11,54 @@ import ( "testing" "time" + "github.com/ag-ui-protocol/ag-ui/sdks/community/go/pkg/core/types" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +// testPayload returns a simple RunAgentInput for testing +func testPayload() types.RunAgentInput { + return types.RunAgentInput{ + ThreadId: "test-thread", + RunId: "test-run", + } +} + func TestStream(t *testing.T) { t.Run("successful stream", func(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "application/json", r.Header.Get("Content-Type")) assert.Equal(t, "text/event-stream", r.Header.Get("Accept")) - + w.Header().Set("Content-Type", "text/event-stream") w.WriteHeader(http.StatusOK) - + flusher, ok := w.(http.Flusher) require.True(t, ok) - + fmt.Fprintf(w, "data: first message\n\n") flusher.Flush() - + fmt.Fprintf(w, "data: second message\n\n") flusher.Flush() - + fmt.Fprintf(w, "data: {\"type\":\"json\",\"value\":123}\n\n") flusher.Flush() })) defer server.Close() - + client := NewClient(Config{ Endpoint: server.URL, BufferSize: 10, }) - + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - + frames, errors, err := client.Stream(StreamOptions{ Context: ctx, - Payload: map[string]string{"test": "data"}, + Payload: testPayload(), }) require.NoError(t, err) @@ -106,10 +115,10 @@ func TestStream(t *testing.T) { frames, _, err := client.Stream(StreamOptions{ Context: ctx, - Payload: struct{}{}, + Payload: testPayload(), }) require.NoError(t, err) - + select { case frame := <-frames: assert.Equal(t, "line1\nline2\nline3", string(frame.Data)) @@ -170,13 +179,13 @@ func TestStream(t *testing.T) { _, _, err := client.Stream(StreamOptions{ Context: ctx, - Payload: struct{}{}, + Payload: testPayload(), }) require.NoError(t, err) }) } }) - + t.Run("custom headers", func(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "custom-value", r.Header.Get("X-Custom-Header")) @@ -195,7 +204,7 @@ func TestStream(t *testing.T) { _, _, err := client.Stream(StreamOptions{ Context: ctx, - Payload: struct{}{}, + Payload: testPayload(), Headers: map[string]string{ "X-Custom-Header": "custom-value", "X-Another-Header": "another-value", @@ -203,7 +212,7 @@ func TestStream(t *testing.T) { }) require.NoError(t, err) }) - + t.Run("error responses", func(t *testing.T) { tests := []struct { name string @@ -250,16 +259,16 @@ func TestStream(t *testing.T) { client := NewClient(Config{ Endpoint: server.URL, }) - + _, _, err := client.Stream(StreamOptions{ - Payload: struct{}{}, + Payload: testPayload(), }) require.Error(t, err) assert.Contains(t, err.Error(), tt.expectedErr) }) } }) - + t.Run("context cancellation", func(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") @@ -283,13 +292,13 @@ func TestStream(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) defer cancel() - + frames, errors, err := client.Stream(StreamOptions{ Context: ctx, - Payload: struct{}{}, + Payload: testPayload(), }) require.NoError(t, err) - + messageCount := 0 for { select { @@ -309,32 +318,17 @@ func TestStream(t *testing.T) { } }) - t.Run("invalid payload marshaling", func(t *testing.T) { - client := NewClient(Config{ - Endpoint: "http://localhost", - }) - - // Create an unmarshalable payload - invalidPayload := make(chan int) - - _, _, err := client.Stream(StreamOptions{ - Payload: invalidPayload, - }) - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to marshal payload") - }) - t.Run("invalid endpoint", func(t *testing.T) { client := NewClient(Config{ Endpoint: "http://[::1]:namedport", // Invalid URL }) - + _, _, err := client.Stream(StreamOptions{ - Payload: struct{}{}, + Payload: testPayload(), }) require.Error(t, err) }) - + t.Run("concurrent reads", func(t *testing.T) { messageCount := 50 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -358,13 +352,13 @@ func TestStream(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - + frames, _, err := client.Stream(StreamOptions{ Context: ctx, - Payload: struct{}{}, + Payload: testPayload(), }) require.NoError(t, err) - + var wg sync.WaitGroup received := make(map[string]bool) mu := sync.Mutex{} @@ -410,13 +404,13 @@ func TestStream(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - + frames, errors, err := client.Stream(StreamOptions{ Context: ctx, - Payload: struct{}{}, + Payload: testPayload(), }) require.NoError(t, err) - + // Should receive initial message select { case frame := <-frames: @@ -463,13 +457,13 @@ func TestStream(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() - + frames, _, err := client.Stream(StreamOptions{ Context: ctx, - Payload: struct{}{}, + Payload: testPayload(), }) require.NoError(t, err) - + // Consume all frames go func() { for range frames { @@ -691,12 +685,12 @@ func BenchmarkStream(b *testing.B) { frames, _, err := client.Stream(StreamOptions{ Context: ctx, - Payload: struct{}{}, + Payload: testPayload(), }) if err != nil { b.Fatal(err) } - + count := 0 for range frames { count++ diff --git a/sdks/community/go/pkg/core/types/types.go b/sdks/community/go/pkg/core/types/types.go new file mode 100644 index 000000000..ab62917cd --- /dev/null +++ b/sdks/community/go/pkg/core/types/types.go @@ -0,0 +1,48 @@ +package types + +// Context represents additional context provided to the agent +type Context struct { + Description string `json:"description"` + Value string `json:"value"` +} + +// Tool represents a tool available to the agent +type Tool struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters any `json:"parameters"` // JSON Schema for the tool parameters +} + +// RunAgentInput represents the input payload for running an agent +type RunAgentInput struct { + ThreadId string `json:"threadId"` + RunId string `json:"runId"` + State any `json:"state,omitempty"` + Messages []Message `json:"messages"` + Tools []Tool `json:"tools,omitempty"` + Context []Context `json:"context,omitempty"` + ForwardedProps any `json:"forwardedProps,omitempty"` +} + +// Message represents a message in the conversation +type Message struct { + ID string `json:"id"` + Role string `json:"role"` + Content *string `json:"content,omitempty"` + Name *string `json:"name,omitempty"` + ToolCalls []ToolCall `json:"toolCalls,omitempty"` + ToolCallID *string `json:"toolCallId,omitempty"` +} + +// ToolCall represents a tool call within a message +type ToolCall struct { + ID string `json:"id"` + Type string `json:"type"` + Function Function `json:"function"` +} + +// Function represents a function call +type Function struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +}