Skip to content

Commit a89c04f

Browse files
committed
Add basic aga controller e2e tests
1 parent 150fe5c commit a89c04f

14 files changed

+1611
-8
lines changed
Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
package globalaccelerator
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"time"
7+
8+
awssdk "github.com/aws/aws-sdk-go-v2/aws"
9+
"github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2"
10+
"github.com/aws/aws-sdk-go-v2/service/globalaccelerator"
11+
"github.com/aws/aws-sdk-go-v2/service/globalaccelerator/types"
12+
"k8s.io/apimachinery/pkg/util/wait"
13+
"sigs.k8s.io/aws-load-balancer-controller/test/framework"
14+
"sigs.k8s.io/aws-load-balancer-controller/test/framework/utils"
15+
)
16+
17+
type PortRangeExpectation struct {
18+
FromPort int32
19+
ToPort int32
20+
}
21+
22+
type PortOverrideExpectation struct {
23+
ListenerPort int32
24+
EndpointPort int32
25+
}
26+
27+
type EndpointGroupExpectation struct {
28+
TrafficDialPercentage int32
29+
PortOverrides []PortOverrideExpectation
30+
NumEndpoints int
31+
}
32+
33+
type ListenerExpectation struct {
34+
Protocol string
35+
PortRanges []PortRangeExpectation
36+
ClientAffinity string
37+
EndpointGroups []EndpointGroupExpectation
38+
}
39+
40+
type GlobalAcceleratorExpectation struct {
41+
Name string
42+
IPAddressType string
43+
Status string
44+
Listeners []ListenerExpectation
45+
}
46+
47+
func verifyGlobalAcceleratorConfiguration(ctx context.Context, f *framework.Framework, acceleratorARN string, expected GlobalAcceleratorExpectation) error {
48+
agaClient := f.Cloud.GlobalAccelerator()
49+
50+
describeAccelResp, err := agaClient.DescribeAcceleratorWithContext(ctx, &globalaccelerator.DescribeAcceleratorInput{
51+
AcceleratorArn: awssdk.String(acceleratorARN),
52+
})
53+
if err != nil {
54+
return err
55+
}
56+
57+
accelerator := describeAccelResp.Accelerator
58+
if expected.Name != "" && awssdk.ToString(accelerator.Name) != expected.Name {
59+
return fmt.Errorf("name mismatch: expected %s, got %s", expected.Name, awssdk.ToString(accelerator.Name))
60+
}
61+
if expected.IPAddressType != "" && string(accelerator.IpAddressType) != expected.IPAddressType {
62+
return fmt.Errorf("IP address type mismatch: expected %s, got %s", expected.IPAddressType, string(accelerator.IpAddressType))
63+
}
64+
if expected.Status != "" && string(accelerator.Status) != expected.Status {
65+
return fmt.Errorf("status mismatch: expected %s, got %s", expected.Status, string(accelerator.Status))
66+
}
67+
68+
if len(expected.Listeners) > 0 {
69+
listListenersResp, err := agaClient.ListListenersForAcceleratorWithContext(ctx, &globalaccelerator.ListListenersInput{
70+
AcceleratorArn: awssdk.String(acceleratorARN),
71+
})
72+
if err != nil {
73+
return err
74+
}
75+
if len(listListenersResp.Listeners) != len(expected.Listeners) {
76+
return fmt.Errorf("listener count mismatch: expected %d, got %d", len(expected.Listeners), len(listListenersResp.Listeners))
77+
}
78+
79+
for i, expectedListener := range expected.Listeners {
80+
listener := listListenersResp.Listeners[i]
81+
82+
if expectedListener.Protocol != "" && string(listener.Protocol) != expectedListener.Protocol {
83+
return fmt.Errorf("listener[%d] protocol mismatch: expected %s, got %s", i, expectedListener.Protocol, string(listener.Protocol))
84+
}
85+
if expectedListener.ClientAffinity != "" && string(listener.ClientAffinity) != expectedListener.ClientAffinity {
86+
return fmt.Errorf("listener[%d] client affinity mismatch: expected %s, got %s", i, expectedListener.ClientAffinity, string(listener.ClientAffinity))
87+
}
88+
89+
if len(expectedListener.PortRanges) > 0 {
90+
if len(listener.PortRanges) != len(expectedListener.PortRanges) {
91+
return fmt.Errorf("listener[%d] port range count mismatch: expected %d, got %d", i, len(expectedListener.PortRanges), len(listener.PortRanges))
92+
}
93+
for j, expectedPortRange := range expectedListener.PortRanges {
94+
if awssdk.ToInt32(listener.PortRanges[j].FromPort) != expectedPortRange.FromPort {
95+
return fmt.Errorf("listener[%d] port range[%d] from port mismatch: expected %d, got %d", i, j, expectedPortRange.FromPort, awssdk.ToInt32(listener.PortRanges[j].FromPort))
96+
}
97+
if awssdk.ToInt32(listener.PortRanges[j].ToPort) != expectedPortRange.ToPort {
98+
return fmt.Errorf("listener[%d] port range[%d] to port mismatch: expected %d, got %d", i, j, expectedPortRange.ToPort, awssdk.ToInt32(listener.PortRanges[j].ToPort))
99+
}
100+
}
101+
}
102+
103+
if len(expectedListener.EndpointGroups) > 0 {
104+
listEGResp, err := agaClient.ListEndpointGroupsAsList(ctx, &globalaccelerator.ListEndpointGroupsInput{
105+
ListenerArn: listener.ListenerArn,
106+
})
107+
if err != nil {
108+
return err
109+
}
110+
if len(listEGResp) != len(expectedListener.EndpointGroups) {
111+
return fmt.Errorf("listener[%d] endpoint group count mismatch: expected %d, got %d", i, len(expectedListener.EndpointGroups), len(listEGResp))
112+
}
113+
114+
for k, expectedEG := range expectedListener.EndpointGroups {
115+
eg := listEGResp[k]
116+
117+
if expectedEG.TrafficDialPercentage > 0 && awssdk.ToFloat32(eg.TrafficDialPercentage) != float32(expectedEG.TrafficDialPercentage) {
118+
return fmt.Errorf("listener[%d] endpoint group[%d] traffic dial percentage mismatch: expected %d, got %f", i, k, expectedEG.TrafficDialPercentage, awssdk.ToFloat32(eg.TrafficDialPercentage))
119+
}
120+
121+
if len(expectedEG.PortOverrides) > 0 {
122+
if len(eg.PortOverrides) != len(expectedEG.PortOverrides) {
123+
return fmt.Errorf("listener[%d] endpoint group[%d] port override count mismatch: expected %d, got %d", i, k, len(expectedEG.PortOverrides), len(eg.PortOverrides))
124+
}
125+
for l, expectedPO := range expectedEG.PortOverrides {
126+
if awssdk.ToInt32(eg.PortOverrides[l].ListenerPort) != expectedPO.ListenerPort {
127+
return fmt.Errorf("listener[%d] endpoint group[%d] port override[%d] listener port mismatch: expected %d, got %d", i, k, l, expectedPO.ListenerPort, awssdk.ToInt32(eg.PortOverrides[l].ListenerPort))
128+
}
129+
if awssdk.ToInt32(eg.PortOverrides[l].EndpointPort) != expectedPO.EndpointPort {
130+
return fmt.Errorf("listener[%d] endpoint group[%d] port override[%d] endpoint port mismatch: expected %d, got %d", i, k, l, expectedPO.EndpointPort, awssdk.ToInt32(eg.PortOverrides[l].EndpointPort))
131+
}
132+
}
133+
}
134+
135+
if expectedEG.NumEndpoints > 0 && len(eg.EndpointDescriptions) != expectedEG.NumEndpoints {
136+
return fmt.Errorf("listener[%d] endpoint group[%d] endpoint count mismatch: expected %d, got %d", i, k, expectedEG.NumEndpoints, len(eg.EndpointDescriptions))
137+
}
138+
}
139+
}
140+
}
141+
}
142+
143+
return nil
144+
}
145+
146+
func waitForEndpointsHealthy(ctx context.Context, f *framework.Framework, acceleratorARN string) error {
147+
agaClient := f.Cloud.GlobalAccelerator()
148+
timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
149+
defer cancel()
150+
151+
return wait.PollImmediateUntil(utils.PollIntervalMedium, func() (bool, error) {
152+
listListenersResp, err := agaClient.ListListenersForAcceleratorWithContext(ctx, &globalaccelerator.ListListenersInput{
153+
AcceleratorArn: awssdk.String(acceleratorARN),
154+
})
155+
if err != nil {
156+
return false, err
157+
}
158+
159+
hasEndpoints := false
160+
for _, listener := range listListenersResp.Listeners {
161+
listEGResp, err := agaClient.ListEndpointGroupsAsList(ctx, &globalaccelerator.ListEndpointGroupsInput{
162+
ListenerArn: listener.ListenerArn,
163+
})
164+
if err != nil {
165+
return false, err
166+
}
167+
168+
for _, eg := range listEGResp {
169+
if len(eg.EndpointDescriptions) == 0 {
170+
f.Logger.Info("waiting for endpoints to be added", "endpointGroupArn", awssdk.ToString(eg.EndpointGroupArn))
171+
return false, nil
172+
}
173+
hasEndpoints = true
174+
for _, endpoint := range eg.EndpointDescriptions {
175+
if endpoint.HealthState != types.HealthStateHealthy {
176+
f.Logger.Info("waiting for endpoint to be healthy",
177+
"endpointId", awssdk.ToString(endpoint.EndpointId),
178+
"healthState", string(endpoint.HealthState))
179+
return false, nil
180+
}
181+
}
182+
}
183+
}
184+
if !hasEndpoints {
185+
f.Logger.Info("no endpoints found in any endpoint group")
186+
return false, nil
187+
}
188+
return true, nil
189+
}, timeoutCtx.Done())
190+
}
191+
192+
func verifyLoadBalancerScheme(ctx context.Context, f *framework.Framework, lbHostname, expectedScheme string) error {
193+
elbClient := f.Cloud.ELBV2()
194+
lbs, err := elbClient.DescribeLoadBalancersAsList(ctx, &elasticloadbalancingv2.DescribeLoadBalancersInput{})
195+
if err != nil {
196+
return fmt.Errorf("failed to describe load balancers: %w", err)
197+
}
198+
199+
for _, lb := range lbs {
200+
if awssdk.ToString(lb.DNSName) == lbHostname {
201+
actualScheme := string(lb.Scheme)
202+
if actualScheme != expectedScheme {
203+
return fmt.Errorf("load balancer scheme mismatch: expected %s, got %s", expectedScheme, actualScheme)
204+
}
205+
f.Logger.Info("verified load balancer scheme", "hostname", lbHostname, "scheme", actualScheme)
206+
return nil
207+
}
208+
}
209+
return fmt.Errorf("load balancer with hostname %s not found", lbHostname)
210+
}
211+
212+
func verifyEndpointPointsToLoadBalancer(ctx context.Context, f *framework.Framework, acceleratorARN, expectedLBHostname string) error {
213+
agaClient := f.Cloud.GlobalAccelerator()
214+
elbClient := f.Cloud.ELBV2()
215+
216+
lbs, err := elbClient.DescribeLoadBalancersAsList(ctx, &elasticloadbalancingv2.DescribeLoadBalancersInput{})
217+
if err != nil {
218+
return fmt.Errorf("failed to describe load balancers: %w", err)
219+
}
220+
221+
var expectedLBARN string
222+
for _, lb := range lbs {
223+
if awssdk.ToString(lb.DNSName) == expectedLBHostname {
224+
expectedLBARN = awssdk.ToString(lb.LoadBalancerArn)
225+
break
226+
}
227+
}
228+
if expectedLBARN == "" {
229+
return fmt.Errorf("load balancer with hostname %s not found", expectedLBHostname)
230+
}
231+
232+
listListenersResp, err := agaClient.ListListenersForAcceleratorWithContext(ctx, &globalaccelerator.ListListenersInput{
233+
AcceleratorArn: awssdk.String(acceleratorARN),
234+
})
235+
if err != nil {
236+
return err
237+
}
238+
239+
for _, listener := range listListenersResp.Listeners {
240+
listEGResp, err := agaClient.ListEndpointGroupsAsList(ctx, &globalaccelerator.ListEndpointGroupsInput{
241+
ListenerArn: listener.ListenerArn,
242+
})
243+
if err != nil {
244+
return err
245+
}
246+
247+
for _, eg := range listEGResp {
248+
if len(eg.EndpointDescriptions) == 0 {
249+
return fmt.Errorf("no endpoints in endpoint group %s", awssdk.ToString(eg.EndpointGroupArn))
250+
}
251+
for _, endpoint := range eg.EndpointDescriptions {
252+
if endpoint.HealthState != types.HealthStateHealthy {
253+
return fmt.Errorf("endpoint %s not healthy: %s", awssdk.ToString(endpoint.EndpointId), string(endpoint.HealthState))
254+
}
255+
if awssdk.ToString(endpoint.EndpointId) != expectedLBARN {
256+
return fmt.Errorf("endpoint ARN mismatch: expected %s, got %s", expectedLBARN, awssdk.ToString(endpoint.EndpointId))
257+
}
258+
f.Logger.Info("verified endpoint points to correct load balancer",
259+
"endpointId", awssdk.ToString(endpoint.EndpointId),
260+
"expectedLBARN", expectedLBARN,
261+
"healthState", string(endpoint.HealthState))
262+
}
263+
}
264+
}
265+
return nil
266+
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package globalaccelerator
2+
3+
import (
4+
"strings"
5+
"testing"
6+
7+
. "github.com/onsi/ginkgo/v2"
8+
. "github.com/onsi/gomega"
9+
"sigs.k8s.io/aws-load-balancer-controller/test/framework"
10+
)
11+
12+
var tf *framework.Framework
13+
14+
func TestGlobalAccelerator(t *testing.T) {
15+
RegisterFailHandler(Fail)
16+
RunSpecs(t, "GlobalAccelerator Suite")
17+
}
18+
19+
var _ = BeforeSuite(func() {
20+
var err error
21+
tf, err = framework.InitFramework()
22+
Expect(err).NotTo(HaveOccurred())
23+
24+
if !isCommercialPartition(tf.Options.AWSRegion) {
25+
Skip("GlobalAccelerator is only available in commercial AWS partition")
26+
}
27+
})
28+
29+
func isCommercialPartition(region string) bool {
30+
unsupportedPrefixes := []string{"cn-", "us-gov-", "us-iso", "eu-isoe-"}
31+
for _, prefix := range unsupportedPrefixes {
32+
if strings.HasPrefix(strings.ToLower(region), prefix) {
33+
return false
34+
}
35+
}
36+
return true
37+
}

0 commit comments

Comments
 (0)