Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion sdks/community/go/pkg/client/sse/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"strings"
"time"

"github.com/ag-ui-protocol/ag-ui/sdks/community/go/pkg/core/types"
"github.com/sirupsen/logrus"
)

Expand Down Expand Up @@ -38,7 +39,7 @@ type Frame struct {

type StreamOptions struct {
Context context.Context
Payload interface{}
Payload types.RunAgentInput
Headers map[string]string
}

Expand Down
94 changes: 44 additions & 50 deletions sdks/community/go/pkg/client/sse/client_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,45 +11,54 @@ import (
"testing"
"time"

"github.com/ag-ui-protocol/ag-ui/sdks/community/go/pkg/core/types"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// testPayload returns a simple RunAgentInput for testing
func testPayload() types.RunAgentInput {
return types.RunAgentInput{
ThreadId: "test-thread",
RunId: "test-run",
}
}

func TestStream(t *testing.T) {
t.Run("successful stream", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
assert.Equal(t, "text/event-stream", r.Header.Get("Accept"))

w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)

flusher, ok := w.(http.Flusher)
require.True(t, ok)

fmt.Fprintf(w, "data: first message\n\n")
flusher.Flush()

fmt.Fprintf(w, "data: second message\n\n")
flusher.Flush()

fmt.Fprintf(w, "data: {\"type\":\"json\",\"value\":123}\n\n")
flusher.Flush()
}))
defer server.Close()

client := NewClient(Config{
Endpoint: server.URL,
BufferSize: 10,
})

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

frames, errors, err := client.Stream(StreamOptions{
Context: ctx,
Payload: map[string]string{"test": "data"},
Payload: testPayload(),
})
require.NoError(t, err)

Expand Down Expand Up @@ -106,10 +115,10 @@ func TestStream(t *testing.T) {

frames, _, err := client.Stream(StreamOptions{
Context: ctx,
Payload: struct{}{},
Payload: testPayload(),
})
require.NoError(t, err)

select {
case frame := <-frames:
assert.Equal(t, "line1\nline2\nline3", string(frame.Data))
Expand Down Expand Up @@ -170,13 +179,13 @@ func TestStream(t *testing.T) {

_, _, err := client.Stream(StreamOptions{
Context: ctx,
Payload: struct{}{},
Payload: testPayload(),
})
require.NoError(t, err)
})
}
})

t.Run("custom headers", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "custom-value", r.Header.Get("X-Custom-Header"))
Expand All @@ -195,15 +204,15 @@ func TestStream(t *testing.T) {

_, _, err := client.Stream(StreamOptions{
Context: ctx,
Payload: struct{}{},
Payload: testPayload(),
Headers: map[string]string{
"X-Custom-Header": "custom-value",
"X-Another-Header": "another-value",
},
})
require.NoError(t, err)
})

t.Run("error responses", func(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -250,16 +259,16 @@ func TestStream(t *testing.T) {
client := NewClient(Config{
Endpoint: server.URL,
})

_, _, err := client.Stream(StreamOptions{
Payload: struct{}{},
Payload: testPayload(),
})
require.Error(t, err)
assert.Contains(t, err.Error(), tt.expectedErr)
})
}
})

t.Run("context cancellation", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
Expand All @@ -283,13 +292,13 @@ func TestStream(t *testing.T) {

ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()

frames, errors, err := client.Stream(StreamOptions{
Context: ctx,
Payload: struct{}{},
Payload: testPayload(),
})
require.NoError(t, err)

messageCount := 0
for {
select {
Expand All @@ -309,32 +318,17 @@ func TestStream(t *testing.T) {
}
})

t.Run("invalid payload marshaling", func(t *testing.T) {
client := NewClient(Config{
Endpoint: "http://localhost",
})

// Create an unmarshalable payload
invalidPayload := make(chan int)

_, _, err := client.Stream(StreamOptions{
Payload: invalidPayload,
})
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to marshal payload")
})

t.Run("invalid endpoint", func(t *testing.T) {
client := NewClient(Config{
Endpoint: "http://[::1]:namedport", // Invalid URL
})

_, _, err := client.Stream(StreamOptions{
Payload: struct{}{},
Payload: testPayload(),
})
require.Error(t, err)
})

t.Run("concurrent reads", func(t *testing.T) {
messageCount := 50
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -358,13 +352,13 @@ func TestStream(t *testing.T) {

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

frames, _, err := client.Stream(StreamOptions{
Context: ctx,
Payload: struct{}{},
Payload: testPayload(),
})
require.NoError(t, err)

var wg sync.WaitGroup
received := make(map[string]bool)
mu := sync.Mutex{}
Expand Down Expand Up @@ -410,13 +404,13 @@ func TestStream(t *testing.T) {

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

frames, errors, err := client.Stream(StreamOptions{
Context: ctx,
Payload: struct{}{},
Payload: testPayload(),
})
require.NoError(t, err)

// Should receive initial message
select {
case frame := <-frames:
Expand Down Expand Up @@ -463,13 +457,13 @@ func TestStream(t *testing.T) {

ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()

frames, _, err := client.Stream(StreamOptions{
Context: ctx,
Payload: struct{}{},
Payload: testPayload(),
})
require.NoError(t, err)

// Consume all frames
go func() {
for range frames {
Expand Down Expand Up @@ -691,12 +685,12 @@ func BenchmarkStream(b *testing.B) {

frames, _, err := client.Stream(StreamOptions{
Context: ctx,
Payload: struct{}{},
Payload: testPayload(),
})
if err != nil {
b.Fatal(err)
}

count := 0
for range frames {
count++
Expand Down
48 changes: 48 additions & 0 deletions sdks/community/go/pkg/core/types/types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package types

// Context represents additional context provided to the agent
type Context struct {
Description string `json:"description"`
Value string `json:"value"`
}

// Tool represents a tool available to the agent
type Tool struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters any `json:"parameters"` // JSON Schema for the tool parameters
}

// RunAgentInput represents the input payload for running an agent
type RunAgentInput struct {
ThreadId string `json:"threadId"`
RunId string `json:"runId"`
State any `json:"state,omitempty"`
Messages []Message `json:"messages"`
Tools []Tool `json:"tools,omitempty"`
Context []Context `json:"context,omitempty"`
ForwardedProps any `json:"forwardedProps,omitempty"`
}

// Message represents a message in the conversation
type Message struct {
ID string `json:"id"`
Role string `json:"role"`
Content *string `json:"content,omitempty"`
Name *string `json:"name,omitempty"`
ToolCalls []ToolCall `json:"toolCalls,omitempty"`
ToolCallID *string `json:"toolCallId,omitempty"`
}

// ToolCall represents a tool call within a message
type ToolCall struct {
ID string `json:"id"`
Type string `json:"type"`
Function Function `json:"function"`
}

// Function represents a function call
type Function struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}
Loading