Skip to content

Commit ca1a249

Browse files
authored
Run tests with two data layer implementations (#1930)
* Configurable data layer in tests Signed-off-by: irar2 <irar@il.ibm.com> * Configurable data layer in tests Signed-off-by: irar2 <irar@il.ibm.com> --------- Signed-off-by: irar2 <irar@il.ibm.com>
1 parent 779e85d commit ca1a249

File tree

11 files changed

+734
-626
lines changed

11 files changed

+734
-626
lines changed

pkg/epp/controller/inferencemodelrewrite_reconciler_test.go

Lines changed: 59 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ import (
3636
"sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2"
3737
"sigs.k8s.io/gateway-api-inference-extension/pkg/common"
3838
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
39+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
3940
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
4041
poolutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/pool"
4142
utiltest "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing"
@@ -167,58 +168,64 @@ func TestInferenceModelRewriteReconciler(t *testing.T) {
167168
},
168169
}
169170
for _, test := range tests {
170-
t.Run(test.name, func(t *testing.T) {
171-
scheme := runtime.NewScheme()
172-
_ = clientgoscheme.AddToScheme(scheme)
173-
_ = v1alpha2.Install(scheme)
174-
_ = v1.Install(scheme)
175-
initObjs := []client.Object{}
176-
if test.rewrite != nil {
177-
initObjs = append(initObjs, test.rewrite)
178-
}
179-
for _, r := range test.rewritesInAPIServer {
180-
initObjs = append(initObjs, r)
181-
}
182-
fakeClient := fake.NewClientBuilder().
183-
WithScheme(scheme).
184-
WithObjects(initObjs...).
185-
Build()
186-
pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second)
187-
ds := datastore.NewDatastore(t.Context(), pmf, 0)
188-
for _, r := range test.rewritesInStore {
189-
ds.ModelRewriteSet(r)
190-
}
191-
endpointPool := poolutil.InferencePoolToEndpointPool(poolForRewrite)
192-
_ = ds.PoolSet(context.Background(), fakeClient, endpointPool)
193-
reconciler := &InferenceModelRewriteReconciler{
194-
Reader: fakeClient,
195-
Datastore: ds,
196-
PoolGKNN: common.GKNN{
197-
NamespacedName: types.NamespacedName{Name: poolForRewrite.Name, Namespace: poolForRewrite.Namespace},
198-
GroupKind: schema.GroupKind{Group: poolForRewrite.GroupVersionKind().Group, Kind: poolForRewrite.GroupVersionKind().Kind},
199-
},
200-
}
201-
if test.incomingReq == nil {
202-
test.incomingReq = &types.NamespacedName{Name: test.rewrite.Name, Namespace: test.rewrite.Namespace}
203-
}
204-
205-
result, err := reconciler.Reconcile(context.Background(), ctrl.Request{NamespacedName: *test.incomingReq})
206-
if err != nil {
207-
t.Fatalf("expected no error, got %v", err)
208-
}
209-
210-
if diff := cmp.Diff(result, test.wantResult); diff != "" {
211-
t.Errorf("Unexpected result diff (+got/-want): %s", diff)
212-
}
213-
214-
if len(test.wantRewrites) != len(ds.ModelRewriteGetAll()) {
215-
t.Errorf("Unexpected number of rewrites; want: %d, got:%d", len(test.wantRewrites), len(ds.ModelRewriteGetAll()))
216-
}
217-
218-
if diff := diffStoreRewrites(ds, test.wantRewrites); diff != "" {
219-
t.Errorf("Unexpected diff (+got/-want): %s", diff)
220-
}
221-
})
171+
period := time.Second
172+
factories := []datalayer.EndpointFactory{
173+
backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, period),
174+
datalayer.NewEndpointFactory([]datalayer.DataSource{&datalayer.FakeDataSource{}}, period),
175+
}
176+
for _, epf := range factories {
177+
t.Run(test.name, func(t *testing.T) {
178+
scheme := runtime.NewScheme()
179+
_ = clientgoscheme.AddToScheme(scheme)
180+
_ = v1alpha2.Install(scheme)
181+
_ = v1.Install(scheme)
182+
initObjs := []client.Object{}
183+
if test.rewrite != nil {
184+
initObjs = append(initObjs, test.rewrite)
185+
}
186+
for _, r := range test.rewritesInAPIServer {
187+
initObjs = append(initObjs, r)
188+
}
189+
fakeClient := fake.NewClientBuilder().
190+
WithScheme(scheme).
191+
WithObjects(initObjs...).
192+
Build()
193+
ds := datastore.NewDatastore(t.Context(), epf, 0)
194+
for _, r := range test.rewritesInStore {
195+
ds.ModelRewriteSet(r)
196+
}
197+
endpointPool := poolutil.InferencePoolToEndpointPool(poolForRewrite)
198+
_ = ds.PoolSet(context.Background(), fakeClient, endpointPool)
199+
reconciler := &InferenceModelRewriteReconciler{
200+
Reader: fakeClient,
201+
Datastore: ds,
202+
PoolGKNN: common.GKNN{
203+
NamespacedName: types.NamespacedName{Name: poolForRewrite.Name, Namespace: poolForRewrite.Namespace},
204+
GroupKind: schema.GroupKind{Group: poolForRewrite.GroupVersionKind().Group, Kind: poolForRewrite.GroupVersionKind().Kind},
205+
},
206+
}
207+
if test.incomingReq == nil {
208+
test.incomingReq = &types.NamespacedName{Name: test.rewrite.Name, Namespace: test.rewrite.Namespace}
209+
}
210+
211+
result, err := reconciler.Reconcile(context.Background(), ctrl.Request{NamespacedName: *test.incomingReq})
212+
if err != nil {
213+
t.Fatalf("expected no error, got %v", err)
214+
}
215+
216+
if diff := cmp.Diff(result, test.wantResult); diff != "" {
217+
t.Errorf("Unexpected result diff (+got/-want): %s", diff)
218+
}
219+
220+
if len(test.wantRewrites) != len(ds.ModelRewriteGetAll()) {
221+
t.Errorf("Unexpected number of rewrites; want: %d, got:%d", len(test.wantRewrites), len(ds.ModelRewriteGetAll()))
222+
}
223+
224+
if diff := diffStoreRewrites(ds, test.wantRewrites); diff != "" {
225+
t.Errorf("Unexpected diff (+got/-want): %s", diff)
226+
}
227+
})
228+
}
222229
}
223230
}
224231

pkg/epp/controller/inferenceobjective_reconciler_test.go

Lines changed: 57 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ import (
3535
"sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2"
3636
"sigs.k8s.io/gateway-api-inference-extension/pkg/common"
3737
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
38+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
3839
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
3940
poolutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/pool"
4041
utiltest "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing"
@@ -143,59 +144,65 @@ func TestInferenceObjectiveReconciler(t *testing.T) {
143144
},
144145
}
145146
for _, test := range tests {
146-
t.Run(test.name, func(t *testing.T) {
147-
// Create a fake client with no InferenceObjective objects.
148-
scheme := runtime.NewScheme()
149-
_ = clientgoscheme.AddToScheme(scheme)
150-
_ = v1alpha2.Install(scheme)
151-
_ = v1.Install(scheme)
152-
initObjs := []client.Object{}
153-
if test.objective != nil {
154-
initObjs = append(initObjs, test.objective)
155-
}
156-
for _, m := range test.objectivesInAPIServer {
157-
initObjs = append(initObjs, m)
158-
}
159-
fakeClient := fake.NewClientBuilder().
160-
WithScheme(scheme).
161-
WithObjects(initObjs...).
162-
Build()
163-
pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second)
164-
ds := datastore.NewDatastore(t.Context(), pmf, 0)
165-
for _, m := range test.objectivessInStore {
166-
ds.ObjectiveSet(m)
167-
}
168-
endpointPool := poolutil.InferencePoolToEndpointPool(inferencePool)
169-
_ = ds.PoolSet(context.Background(), fakeClient, endpointPool)
170-
reconciler := &InferenceObjectiveReconciler{
171-
Reader: fakeClient,
172-
Datastore: ds,
173-
PoolGKNN: common.GKNN{
174-
NamespacedName: types.NamespacedName{Name: inferencePool.Name, Namespace: inferencePool.Namespace},
175-
GroupKind: schema.GroupKind{Group: inferencePool.GroupVersionKind().Group, Kind: inferencePool.GroupVersionKind().Kind},
176-
},
177-
}
178-
if test.incomingReq == nil {
179-
test.incomingReq = &types.NamespacedName{Name: test.objective.Name, Namespace: test.objective.Namespace}
180-
}
147+
period := time.Second
148+
factories := []datalayer.EndpointFactory{
149+
backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, period),
150+
datalayer.NewEndpointFactory([]datalayer.DataSource{&datalayer.FakeDataSource{}}, period),
151+
}
152+
for _, epf := range factories {
153+
t.Run(test.name, func(t *testing.T) {
154+
// Create a fake client with no InferenceObjective objects.
155+
scheme := runtime.NewScheme()
156+
_ = clientgoscheme.AddToScheme(scheme)
157+
_ = v1alpha2.Install(scheme)
158+
_ = v1.Install(scheme)
159+
initObjs := []client.Object{}
160+
if test.objective != nil {
161+
initObjs = append(initObjs, test.objective)
162+
}
163+
for _, m := range test.objectivesInAPIServer {
164+
initObjs = append(initObjs, m)
165+
}
166+
fakeClient := fake.NewClientBuilder().
167+
WithScheme(scheme).
168+
WithObjects(initObjs...).
169+
Build()
170+
ds := datastore.NewDatastore(t.Context(), epf, 0)
171+
for _, m := range test.objectivessInStore {
172+
ds.ObjectiveSet(m)
173+
}
174+
endpointPool := poolutil.InferencePoolToEndpointPool(inferencePool)
175+
_ = ds.PoolSet(context.Background(), fakeClient, endpointPool)
176+
reconciler := &InferenceObjectiveReconciler{
177+
Reader: fakeClient,
178+
Datastore: ds,
179+
PoolGKNN: common.GKNN{
180+
NamespacedName: types.NamespacedName{Name: inferencePool.Name, Namespace: inferencePool.Namespace},
181+
GroupKind: schema.GroupKind{Group: inferencePool.GroupVersionKind().Group, Kind: inferencePool.GroupVersionKind().Kind},
182+
},
183+
}
184+
if test.incomingReq == nil {
185+
test.incomingReq = &types.NamespacedName{Name: test.objective.Name, Namespace: test.objective.Namespace}
186+
}
181187

182-
// Call Reconcile.
183-
result, err := reconciler.Reconcile(context.Background(), ctrl.Request{NamespacedName: *test.incomingReq})
184-
if err != nil {
185-
t.Fatalf("expected no error when resource is not found, got %v", err)
186-
}
188+
// Call Reconcile.
189+
result, err := reconciler.Reconcile(context.Background(), ctrl.Request{NamespacedName: *test.incomingReq})
190+
if err != nil {
191+
t.Fatalf("expected no error when resource is not found, got %v", err)
192+
}
187193

188-
if diff := cmp.Diff(result, test.wantResult); diff != "" {
189-
t.Errorf("Unexpected result diff (+got/-want): %s", diff)
190-
}
194+
if diff := cmp.Diff(result, test.wantResult); diff != "" {
195+
t.Errorf("Unexpected result diff (+got/-want): %s", diff)
196+
}
191197

192-
if len(test.wantObjectives) != len(ds.ObjectiveGetAll()) {
193-
t.Errorf("Unexpected; want: %d, got:%d", len(test.wantObjectives), len(ds.ObjectiveGetAll()))
194-
}
195-
if diff := diffStore(ds, diffStoreParams{wantPool: endpointPool, wantObjectives: test.wantObjectives}); diff != "" {
196-
t.Errorf("Unexpected diff (+got/-want): %s", diff)
197-
}
198+
if len(test.wantObjectives) != len(ds.ObjectiveGetAll()) {
199+
t.Errorf("Unexpected; want: %d, got:%d", len(test.wantObjectives), len(ds.ObjectiveGetAll()))
200+
}
201+
if diff := diffStore(ds, diffStoreParams{wantPool: endpointPool, wantObjectives: test.wantObjectives}); diff != "" {
202+
t.Errorf("Unexpected diff (+got/-want): %s", diff)
203+
}
198204

199-
})
205+
})
206+
}
200207
}
201208
}

0 commit comments

Comments
 (0)