Skip to content
This repository was archived by the owner on Sep 30, 2024. It is now read-only.

Commit 4613452

Browse files
authored
refactor(cody): Reshape the CompletionsClient interface (#63358)
This PR refactors the `CompletionsClient` interface, and all the corresponding call sites. There is no functional change, beyond bundling several function parameters into a new type. See `internal/completions/types/types.go`. But the gist is this putting 3x parameters into a single `CompletionRequest` type. ```diff Complete( context.Context, log.Logger - CompletionsFeature, - CompletionsVersion, - CompletionRequestParameters + CompletionRequest ) (*CompletionResponse, error) ``` ## Why? As part of reworking the codepath between receiving a completion request, dispatching it to the right `CompletionsClient` implementation, and serving the request, I need some "hooks" to inject new information. In a future PR I plan on adding a `*ServerSideModelConfig` as another field to the `CompletionRequest`, so that when the `CompletionClient`'s implementation is trying to serve that request it has any additional data it needs. (For example, the AWS Bedrock provisioned capacity ARN, etc.) ## Test plan Updated existing tests, relying on CI/CD for any other issues. ## Changelog NA, just some under the hood refactoring that shouldn't impact any functionality.
1 parent b717fd5 commit 4613452

File tree

18 files changed

+243
-105
lines changed

18 files changed

+243
-105
lines changed

cmd/cody-gateway/internal/httpapi/embeddings/metadata.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55

66
"github.com/sourcegraph/conc/iter"
77
"github.com/sourcegraph/log"
8+
89
"github.com/sourcegraph/sourcegraph/internal/codygateway"
910
"github.com/sourcegraph/sourcegraph/internal/completions/client/fireworks"
1011
"github.com/sourcegraph/sourcegraph/internal/completions/types"
@@ -35,8 +36,10 @@ Return your response in text format. Each entry name should be followed by a new
3536
Respond with nothing else, only the entry names and the documentation. Code: ` +
3637
"```\n" + *input + "\n```"
3738

38-
resp, err := c.completionsClient.Complete(c.ctx, types.CompletionsFeatureChat, types.CompletionsVersionLegacy,
39-
types.CompletionRequestParameters{
39+
compRequest := types.CompletionRequest{
40+
Feature: types.CompletionsFeatureChat,
41+
Version: types.CompletionsVersionLegacy,
42+
Parameters: types.CompletionRequestParameters{
4043
Messages: []types.Message{{
4144
Speaker: "user",
4245
Text: promptText,
@@ -45,7 +48,9 @@ Respond with nothing else, only the entry names and the documentation. Code: ` +
4548
Temperature: 0,
4649
TopP: 1,
4750
Model: fireworks.Llama38bInstruct,
48-
}, c.logger)
51+
},
52+
}
53+
resp, err := c.completionsClient.Complete(c.ctx, c.logger, compRequest)
4954

5055
if err != nil {
5156
return "", err

cmd/frontend/internal/completions/resolvers/resolver.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,17 @@ func (c *completionsResolver) Completions(ctx context.Context, args graphqlbacke
7474
return "", err
7575
}
7676

77-
// GraphQL API is considered a legacy API
78-
version := types.CompletionsVersionLegacy
79-
8077
params := convertParams(args)
8178
// No way to configure the model through the request, we hard code to chat.
8279
params.Model = chatModel
83-
resp, err := client.Complete(ctx, types.CompletionsFeatureChat, version, params, c.logger)
80+
81+
request := types.CompletionRequest{
82+
Feature: types.CompletionsFeatureChat,
83+
// GraphQL API is considered a legacy API.
84+
Version: types.CompletionsVersionLegacy,
85+
Parameters: params,
86+
}
87+
resp, err := client.Complete(ctx, c.logger, request)
8488
if err != nil {
8589
return "", errors.Wrap(err, "client.Complete")
8690
}

cmd/frontend/internal/httpapi/completions/handler.go

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -316,14 +316,21 @@ func newStreamingResponseHandler(logger log.Logger, db database.DB, feature type
316316
f = ff
317317
}
318318
}
319-
err := cc.Stream(ctx, feature, version, requestParams,
320-
func(event types.CompletionResponse) error {
321-
if !firstEventObserved {
322-
firstEventObserved = true
323-
timeToFirstEventMetrics.Observe(time.Since(start).Seconds(), 1, nil, requestParams.Model)
324-
}
325-
return f.Send(ctx, event)
326-
}, logger)
319+
320+
// Build and send the completions request.
321+
compReq := types.CompletionRequest{
322+
Feature: feature,
323+
Version: version,
324+
Parameters: requestParams,
325+
}
326+
sendEventFn := func(event types.CompletionResponse) error {
327+
if !firstEventObserved {
328+
firstEventObserved = true
329+
timeToFirstEventMetrics.Observe(time.Since(start).Seconds(), 1, nil, requestParams.Model)
330+
}
331+
return f.Send(ctx, event)
332+
}
333+
err := cc.Stream(ctx, logger, compReq, sendEventFn)
327334
if err != nil {
328335
l := trace.Logger(ctx, logger)
329336

@@ -394,7 +401,12 @@ func newStreamingResponseHandler(logger log.Logger, db database.DB, feature type
394401
// to the client.
395402
func newNonStreamingResponseHandler(logger log.Logger, db database.DB, feature types.CompletionsFeature) func(ctx context.Context, requestParams types.CompletionRequestParameters, version types.CompletionsVersion, cc types.CompletionsClient, w http.ResponseWriter, userStore database.UserStore) {
396403
return func(ctx context.Context, requestParams types.CompletionRequestParameters, version types.CompletionsVersion, cc types.CompletionsClient, w http.ResponseWriter, userStore database.UserStore) {
397-
completion, err := cc.Complete(ctx, feature, version, requestParams, logger)
404+
compRequest := types.CompletionRequest{
405+
Feature: feature,
406+
Version: version,
407+
Parameters: requestParams,
408+
}
409+
completion, err := cc.Complete(ctx, logger, compRequest)
398410
if err != nil {
399411
logFields := []log.Field{log.Error(err)}
400412

internal/completions/client/anthropic/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ go_test(
3333
deps = [
3434
"//internal/completions/tokenusage",
3535
"//internal/completions/types",
36+
"//lib/pointers",
3637
"@com_github_hexops_autogold_v2//:autogold",
3738
"@com_github_sourcegraph_log//:log",
3839
"@com_github_stretchr_testify//assert",

internal/completions/client/anthropic/anthropic.go

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,13 @@ type anthropicClient struct {
4343

4444
func (a *anthropicClient) Complete(
4545
ctx context.Context,
46-
feature types.CompletionsFeature,
47-
version types.CompletionsVersion,
48-
requestParams types.CompletionRequestParameters,
4946
logger log.Logger,
50-
) (*types.CompletionResponse, error) {
47+
request types.CompletionRequest) (*types.CompletionResponse, error) {
48+
49+
feature := request.Feature
50+
version := request.Version
51+
requestParams := request.Parameters
52+
5153
resp, err := a.makeRequest(ctx, requestParams, version, false)
5254
if err != nil {
5355
return nil, err
@@ -78,12 +80,14 @@ func (a *anthropicClient) Complete(
7880

7981
func (a *anthropicClient) Stream(
8082
ctx context.Context,
81-
feature types.CompletionsFeature,
82-
version types.CompletionsVersion,
83-
requestParams types.CompletionRequestParameters,
84-
sendEvent types.SendCompletionEvent,
8583
logger log.Logger,
86-
) error {
84+
request types.CompletionRequest,
85+
sendEvent types.SendCompletionEvent) error {
86+
87+
feature := request.Feature
88+
version := request.Version
89+
requestParams := request.Parameters
90+
8791
resp, err := a.makeRequest(ctx, requestParams, version, true)
8892
if err != nil {
8993
return err

internal/completions/client/anthropic/anthropic_test.go

Lines changed: 54 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515

1616
"github.com/sourcegraph/sourcegraph/internal/completions/tokenusage"
1717
"github.com/sourcegraph/sourcegraph/internal/completions/types"
18+
"github.com/sourcegraph/sourcegraph/lib/pointers"
1819
)
1920

2021
type mockDoer struct {
@@ -68,13 +69,21 @@ func TestValidAnthropicMessagesStream(t *testing.T) {
6869

6970
mockClient := getMockClient(linesToResponse(mockAnthropicMessagesResponseLines, "\n\n"))
7071
events := []types.CompletionResponse{}
71-
stream := true
72-
err := mockClient.Stream(context.Background(), types.CompletionsFeatureChat, types.CompletionsVersionLegacy, types.CompletionRequestParameters{
73-
Stream: &stream,
74-
}, func(event types.CompletionResponse) error {
72+
73+
sendEventFn := func(event types.CompletionResponse) error {
7574
events = append(events, event)
7675
return nil
77-
}, logger)
76+
}
77+
78+
compRequest := types.CompletionRequest{
79+
Feature: types.CompletionsFeatureChat,
80+
Version: types.CompletionsVersionLegacy,
81+
Parameters: types.CompletionRequestParameters{
82+
Stream: pointers.Ptr(true),
83+
},
84+
}
85+
86+
err := mockClient.Stream(context.Background(), logger, compRequest, sendEventFn)
7887
if err != nil {
7988
t.Fatal(err)
8089
}
@@ -86,7 +95,15 @@ func TestInvalidAnthropicMessagesStream(t *testing.T) {
8695
logger := log.Scoped("completions")
8796

8897
mockClient := getMockClient(linesToResponse(mockAnthropicInvalidResponseLines, "\r\n\r\n"))
89-
err := mockClient.Stream(context.Background(), types.CompletionsFeatureChat, types.CompletionsVersionLegacy, types.CompletionRequestParameters{}, func(event types.CompletionResponse) error { return nil }, logger)
98+
99+
compRequest := types.CompletionRequest{
100+
Feature: types.CompletionsFeatureChat,
101+
Version: types.CompletionsVersionLegacy,
102+
Parameters: types.CompletionRequestParameters{},
103+
}
104+
sendEventFn := func(event types.CompletionResponse) error { return nil }
105+
106+
err := mockClient.Stream(context.Background(), logger, compRequest, sendEventFn)
90107
if err == nil {
91108
t.Fatal("expected error, got nil")
92109
}
@@ -104,9 +121,15 @@ func TestErrStatusNotOK(t *testing.T) {
104121
},
105122
}, "", "", false, *tokenManager)
106123

124+
compRequest := types.CompletionRequest{
125+
Feature: types.CompletionsFeatureChat,
126+
Version: types.CompletionsVersionLegacy,
127+
Parameters: types.CompletionRequestParameters{},
128+
}
129+
107130
t.Run("Complete", func(t *testing.T) {
108131
logger := log.Scoped("completions")
109-
resp, err := mockClient.Complete(context.Background(), types.CompletionsFeatureChat, types.CompletionsVersionLegacy, types.CompletionRequestParameters{}, logger)
132+
resp, err := mockClient.Complete(context.Background(), logger, compRequest)
110133
require.Error(t, err)
111134
assert.Nil(t, resp)
112135

@@ -117,7 +140,11 @@ func TestErrStatusNotOK(t *testing.T) {
117140

118141
t.Run("Stream", func(t *testing.T) {
119142
logger := log.Scoped("completions")
120-
err := mockClient.Stream(context.Background(), types.CompletionsFeatureChat, types.CompletionsVersionLegacy, types.CompletionRequestParameters{}, func(event types.CompletionResponse) error { return nil }, logger)
143+
sendEventFn := func(event types.CompletionResponse) error {
144+
return nil
145+
}
146+
147+
err := mockClient.Stream(context.Background(), logger, compRequest, sendEventFn)
121148
require.Error(t, err)
122149

123150
autogold.Expect("Anthropic: unexpected status code 429: oh no, please slow down!").Equal(t, err.Error())
@@ -149,7 +176,15 @@ func TestCompleteApiToMessages(t *testing.T) {
149176

150177
t.Run("Complete", func(t *testing.T) {
151178
logger := log.Scoped("completions")
152-
resp, err := mockClient.Complete(context.Background(), types.CompletionsFeatureChat, types.CompletionsVersionLegacy, types.CompletionRequestParameters{Messages: messages}, logger)
179+
compRequest := types.CompletionRequest{
180+
Feature: types.CompletionsFeatureChat,
181+
Version: types.CompletionsVersionLegacy,
182+
Parameters: types.CompletionRequestParameters{
183+
Messages: messages,
184+
},
185+
}
186+
187+
resp, err := mockClient.Complete(context.Background(), logger, compRequest)
153188
require.Error(t, err)
154189
assert.Nil(t, resp)
155190

@@ -162,8 +197,16 @@ func TestCompleteApiToMessages(t *testing.T) {
162197

163198
t.Run("Stream", func(t *testing.T) {
164199
logger := log.Scoped("completions")
165-
stream := true
166-
err := mockClient.Stream(context.Background(), types.CompletionsFeatureChat, types.CompletionsVersionLegacy, types.CompletionRequestParameters{Messages: messages, Stream: &stream}, func(event types.CompletionResponse) error { return nil }, logger)
200+
compRequest := types.CompletionRequest{
201+
Feature: types.CompletionsFeatureChat,
202+
Version: types.CompletionsVersionLegacy,
203+
Parameters: types.CompletionRequestParameters{
204+
Messages: messages,
205+
Stream: pointers.Ptr(true),
206+
},
207+
}
208+
sendEventFn := func(event types.CompletionResponse) error { return nil }
209+
err := mockClient.Stream(context.Background(), logger, compRequest, sendEventFn)
167210
require.Error(t, err)
168211

169212
assert.NotNil(t, response)

internal/completions/client/awsbedrock/bedrock.go

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,13 @@ type awsBedrockAnthropicCompletionStreamClient struct {
4949

5050
func (c *awsBedrockAnthropicCompletionStreamClient) Complete(
5151
ctx context.Context,
52-
feature types.CompletionsFeature,
53-
version types.CompletionsVersion,
54-
requestParams types.CompletionRequestParameters,
5552
logger log.Logger,
56-
) (*types.CompletionResponse, error) {
53+
request types.CompletionRequest) (*types.CompletionResponse, error) {
54+
55+
feature := request.Feature
56+
version := request.Version
57+
requestParams := request.Parameters
58+
5759
resp, err := c.makeRequest(ctx, requestParams, version, false)
5860
if err != nil {
5961
return nil, errors.Wrap(err, "making request")
@@ -82,12 +84,14 @@ func (c *awsBedrockAnthropicCompletionStreamClient) Complete(
8284

8385
func (a *awsBedrockAnthropicCompletionStreamClient) Stream(
8486
ctx context.Context,
85-
feature types.CompletionsFeature,
86-
version types.CompletionsVersion,
87-
requestParams types.CompletionRequestParameters,
88-
sendEvent types.SendCompletionEvent,
8987
logger log.Logger,
90-
) error {
88+
request types.CompletionRequest,
89+
sendEvent types.SendCompletionEvent) error {
90+
91+
feature := request.Feature
92+
version := request.Version
93+
requestParams := request.Parameters
94+
9195
resp, err := a.makeRequest(ctx, requestParams, version, true)
9296
if err != nil {
9397
return errors.Wrap(err, "making request")

internal/completions/client/azureopenai/openai.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,10 @@ type azureCompletionClient struct {
129129

130130
func (c *azureCompletionClient) Complete(
131131
ctx context.Context,
132-
feature types.CompletionsFeature,
133-
_ types.CompletionsVersion,
134-
requestParams types.CompletionRequestParameters,
135132
log log.Logger,
136-
) (*types.CompletionResponse, error) {
133+
request types.CompletionRequest) (*types.CompletionResponse, error) {
134+
feature := request.Feature
135+
requestParams := request.Parameters
137136

138137
switch feature {
139138
case types.CompletionsFeatureCode:
@@ -189,12 +188,13 @@ func completeChat(
189188

190189
func (c *azureCompletionClient) Stream(
191190
ctx context.Context,
192-
feature types.CompletionsFeature,
193-
_ types.CompletionsVersion,
194-
requestParams types.CompletionRequestParameters,
195-
sendEvent types.SendCompletionEvent,
196191
log log.Logger,
192+
request types.CompletionRequest,
193+
sendEvent types.SendCompletionEvent,
197194
) error {
195+
feature := request.Feature
196+
requestParams := request.Parameters
197+
198198
switch feature {
199199
case types.CompletionsFeatureCode:
200200
return streamAutocomplete(ctx, c.client, requestParams, sendEvent, log)

internal/completions/client/azureopenai/openai_test.go

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,16 @@ func TestErrStatusNotOK(t *testing.T) {
7979
})
8080
tokenManager := tokenusage.NewManager()
8181
mockClient, _ := NewClient(getAzureAPIClient, "", "", *tokenManager)
82+
83+
compRequest := types.CompletionRequest{
84+
Feature: types.CompletionsFeatureChat,
85+
Version: types.CompletionsVersionLegacy,
86+
Parameters: types.CompletionRequestParameters{},
87+
}
88+
8289
t.Run("Complete", func(t *testing.T) {
8390
logger := log.Scoped("completions")
84-
resp, err := mockClient.Complete(context.Background(), types.CompletionsFeatureChat, types.CompletionsVersionLegacy, types.CompletionRequestParameters{}, logger)
91+
resp, err := mockClient.Complete(context.Background(), logger, compRequest)
8592
require.Error(t, err)
8693
assert.Nil(t, resp)
8794

@@ -92,7 +99,8 @@ func TestErrStatusNotOK(t *testing.T) {
9299

93100
t.Run("Stream", func(t *testing.T) {
94101
logger := log.Scoped("completions")
95-
err := mockClient.Stream(context.Background(), types.CompletionsFeatureChat, types.CompletionsVersionLegacy, types.CompletionRequestParameters{}, func(event types.CompletionResponse) error { return nil }, logger)
102+
sendEventFn := func(event types.CompletionResponse) error { return nil }
103+
err := mockClient.Stream(context.Background(), logger, compRequest, sendEventFn)
96104
require.Error(t, err)
97105

98106
autogold.Expect("AzureOpenAI: unexpected status code 429: too many requests").Equal(t, err.Error())
@@ -118,9 +126,16 @@ func TestGenericErr(t *testing.T) {
118126
})
119127
tokenManager := tokenusage.NewManager()
120128
mockClient, _ := NewClient(getAzureAPIClient, "", "", *tokenManager)
129+
130+
compRequest := types.CompletionRequest{
131+
Feature: types.CompletionsFeatureChat,
132+
Version: types.CompletionsVersionLegacy,
133+
Parameters: types.CompletionRequestParameters{},
134+
}
135+
121136
t.Run("Complete", func(t *testing.T) {
122137
logger := log.Scoped("completions")
123-
resp, err := mockClient.Complete(context.Background(), types.CompletionsFeatureChat, types.CompletionsVersionLegacy, types.CompletionRequestParameters{}, logger)
138+
resp, err := mockClient.Complete(context.Background(), logger, compRequest)
124139
require.Error(t, err)
125140
assert.Nil(t, resp)
126141

@@ -131,7 +146,8 @@ func TestGenericErr(t *testing.T) {
131146

132147
t.Run("Stream", func(t *testing.T) {
133148
logger := log.Scoped("completions")
134-
err := mockClient.Stream(context.Background(), types.CompletionsFeatureChat, types.CompletionsVersionLegacy, types.CompletionRequestParameters{}, func(event types.CompletionResponse) error { return nil }, logger)
149+
sendEventFn := func(event types.CompletionResponse) error { return nil }
150+
err := mockClient.Stream(context.Background(), logger, compRequest, sendEventFn)
135151
require.Error(t, err)
136152

137153
autogold.Expect("error").Equal(t, err.Error())

0 commit comments

Comments
 (0)