diff --git a/apis/v1alpha1/upstreamsettingspolicy_types.go b/apis/v1alpha1/upstreamsettingspolicy_types.go index 8504c9717d..4b8131fad1 100644 --- a/apis/v1alpha1/upstreamsettingspolicy_types.go +++ b/apis/v1alpha1/upstreamsettingspolicy_types.go @@ -36,6 +36,9 @@ type UpstreamSettingsPolicyList struct { } // UpstreamSettingsPolicySpec defines the desired state of the UpstreamSettingsPolicy. +// +kubebuilder:validation:XValidation:rule="!(has(self.loadBalancingMethod) && (self.loadBalancingMethod == 'hash' || self.loadBalancingMethod == 'hash consistent')) || has(self.hashMethodKey)",message="hashMethodKey is required when loadBalancingMethod is 'hash' or 'hash consistent'" +// +//nolint:lll type UpstreamSettingsPolicySpec struct { // ZoneSize is the size of the shared memory zone used by the upstream. This memory zone is used to share // the upstream configuration between nginx worker processes. The more servers that an upstream has, @@ -51,6 +54,19 @@ type UpstreamSettingsPolicySpec struct { // +optional KeepAlive *UpstreamKeepAlive `json:"keepAlive,omitempty"` + // LoadBalancingMethod specifies the load balancing algorithm to be used for the upstream. + // If not specified, NGINX Gateway Fabric defaults to `random two least_conn`, + // which differs from the standard NGINX default `round-robin`. + // + // +optional + LoadBalancingMethod *LoadBalancingType `json:"loadBalancingMethod,omitempty"` + + // HashMethodKey defines the key used for hash-based load balancing methods. + // This field is required when `LoadBalancingMethod` is set to `hash` or `hash consistent`. + // + // +optional + HashMethodKey *HashMethodKey `json:"hashMethodKey,omitempty"` + // TargetRefs identifies API object(s) to apply the policy to. // Objects must be in the same namespace as the policy. // Support: Service @@ -98,3 +114,99 @@ type UpstreamKeepAlive struct { // +optional Timeout *Duration `json:"timeout,omitempty"` } + +// LoadBalancingType defines the supported load balancing methods. +// +// +kubebuilder:validation:Enum=round_robin;least_conn;ip_hash;hash;hash consistent;random;random two;random two least_conn;random two least_time=header;random two least_time=last_byte;least_time header;least_time last_byte;least_time header inflight;least_time last_byte inflight +// +//nolint:lll +type LoadBalancingType string + +const ( + // Combination of NGINX directive + // - https://nginx.org/en/docs/http/ngx_http_upstream_module.html#random + // - https://nginx.org/en/docs/http/ngx_http_upstream_module.html#least_conn + // - https://nginx.org/en/docs/http/ngx_http_upstream_module.html#least_time + // - https://nginx.org/en/docs/http/ngx_http_upstream_module.html#upstream + // - https://nginx.org/en/docs/http/ngx_http_upstream_module.html#ip_hash + // - https://nginx.org/en/docs/http/ngx_http_upstream_module.html#hash + + // LoadBalancingMethods supported by NGINX OSS and NGINX Plus. + + // LoadBalancingTypeRoundRobin enables round-robin load balancing, + // distributing requests evenly across all upstream servers. + LoadBalancingTypeRoundRobin LoadBalancingType = "round_robin" + + // LoadBalancingTypeLeastConnection enables least-connections load balancing, + // routing requests to the upstream server with the fewest active connections. + LoadBalancingTypeLeastConnection LoadBalancingType = "least_conn" + + // LoadBalancingTypeIPHash enables IP hash-based load balancing, + // ensuring requests from the same client IP are routed to the same upstream server. + LoadBalancingTypeIPHash LoadBalancingType = "ip_hash" + + // LoadBalancingTypeHash enables generic hash-based load balancing, + // routing requests to upstream servers based on a hash of a specified key + // HashMethodKey field must be set when this method is selected. + // Example configuration: hash $binary_remote_addr;. + LoadBalancingTypeHash LoadBalancingType = "hash" + + // LoadBalancingTypeHashConsistent enables consistent hash-based load balancing, + // which minimizes the number of keys remapped when a server is added or removed. + // HashMethodKey field must be set when this method is selected. + // Example configuration: hash $binary_remote_addr consistent;. + LoadBalancingTypeHashConsistent LoadBalancingType = "hash consistent" + + // LoadBalancingTypeRandom enables random load balancing, + // routing requests to upstream servers in a random manner. + LoadBalancingTypeRandom LoadBalancingType = "random" + + // LoadBalancingTypeRandomTwo enables a variation of random load balancing + // that randomly selects two servers and forwards traffic to one of them. + // The default method is least_conn which passes a request to a server with the least number of active connections. + LoadBalancingTypeRandomTwo LoadBalancingType = "random two" + + // LoadBalancingTypeRandomTwoLeastConnection enables a variation of least-connections + // balancing that randomly selects two servers and forwards traffic to the one with + // fewer active connections. + LoadBalancingTypeRandomTwoLeastConnection LoadBalancingType = "random two least_conn" + + // LoadBalancingMethods supported by NGINX Plus. + + // LoadBalancingTypeRandomTwoLeastTimeHeader enables a variation of least-time load balancing + // that randomly selects two servers and forwards traffic to the one with the least + // time to receive the response header. + LoadBalancingTypeRandomTwoLeastTimeHeader LoadBalancingType = "random two least_time=header" + + // LoadBalancingTypeRandomTwoLeastTimeLastByte enables a variation of least-time load balancing + // that randomly selects two servers and forwards traffic to the one with the least time + // to receive the full response. + LoadBalancingTypeRandomTwoLeastTimeLastByte LoadBalancingType = "random two least_time=last_byte" + + // LoadBalancingTypeLeastTimeHeader enables least-time load balancing, + // routing requests to the upstream server with the least time to receive the response header. + LoadBalancingTypeLeastTimeHeader LoadBalancingType = "least_time header" + + // LoadBalancingTypeLeastTimeLastByte enables least-time load balancing, + // routing requests to the upstream server with the least time to receive the full response. + LoadBalancingTypeLeastTimeLastByte LoadBalancingType = "least_time last_byte" + + // LoadBalancingTypeLeastTimeHeaderInflight enables least-time load balancing, + // routing requests to the upstream server with the least time to receive the response header, + // considering the incomplete requests. + LoadBalancingTypeLeastTimeHeaderInflight LoadBalancingType = "least_time header inflight" + + // LoadBalancingTypeLeastTimeLastByteInflight enables least-time load balancing, + // routing requests to the upstream server with the least time to receive the full response, + // considering the incomplete requests. + LoadBalancingTypeLeastTimeLastByteInflight LoadBalancingType = "least_time last_byte inflight" +) + +// HashMethodKey defines the key used for hash-based load balancing methods. +// The key must be a valid NGINX variable name starting with '$' followed by lowercase +// letters and underscores only. +// For a full list of NGINX variables, +// refer to: https://nginx.org/en/docs/http/ngx_http_upstream_module.html#variables +// +// +kubebuilder:validation:Pattern=`^\$[a-z_]+$` +type HashMethodKey string diff --git a/apis/v1alpha1/zz_generated.deepcopy.go b/apis/v1alpha1/zz_generated.deepcopy.go index d07825de2d..164bef0cba 100644 --- a/apis/v1alpha1/zz_generated.deepcopy.go +++ b/apis/v1alpha1/zz_generated.deepcopy.go @@ -556,6 +556,16 @@ func (in *UpstreamSettingsPolicySpec) DeepCopyInto(out *UpstreamSettingsPolicySp *out = new(UpstreamKeepAlive) (*in).DeepCopyInto(*out) } + if in.LoadBalancingMethod != nil { + in, out := &in.LoadBalancingMethod, &out.LoadBalancingMethod + *out = new(LoadBalancingType) + **out = **in + } + if in.HashMethodKey != nil { + in, out := &in.HashMethodKey, &out.HashMethodKey + *out = new(HashMethodKey) + **out = **in + } if in.TargetRefs != nil { in, out := &in.TargetRefs, &out.TargetRefs *out = make([]apisv1.LocalPolicyTargetReference, len(*in)) diff --git a/config/crd/bases/gateway.nginx.org_upstreamsettingspolicies.yaml b/config/crd/bases/gateway.nginx.org_upstreamsettingspolicies.yaml index ce6b9603ed..c8cda0c218 100644 --- a/config/crd/bases/gateway.nginx.org_upstreamsettingspolicies.yaml +++ b/config/crd/bases/gateway.nginx.org_upstreamsettingspolicies.yaml @@ -51,6 +51,12 @@ spec: spec: description: Spec defines the desired state of the UpstreamSettingsPolicy. properties: + hashMethodKey: + description: |- + HashMethodKey defines the key used for hash-based load balancing methods. + This field is required when `LoadBalancingMethod` is set to `hash` or `hash consistent`. + pattern: ^\$[a-z_]+$ + type: string keepAlive: description: KeepAlive defines the keep-alive settings. properties: @@ -85,6 +91,27 @@ spec: pattern: ^[0-9]{1,4}(ms|s|m|h)?$ type: string type: object + loadBalancingMethod: + description: |- + LoadBalancingMethod specifies the load balancing algorithm to be used for the upstream. + If not specified, NGINX Gateway Fabric defaults to `random two least_conn`, + which differs from the standard NGINX default `round-robin`. + enum: + - round_robin + - least_conn + - ip_hash + - hash + - hash consistent + - random + - random two + - random two least_conn + - random two least_time=header + - random two least_time=last_byte + - least_time header + - least_time last_byte + - least_time header inflight + - least_time last_byte inflight + type: string targetRefs: description: |- TargetRefs identifies API object(s) to apply the policy to. @@ -143,6 +170,12 @@ spec: required: - targetRefs type: object + x-kubernetes-validations: + - message: hashMethodKey is required when loadBalancingMethod is 'hash' + or 'hash consistent' + rule: '!(has(self.loadBalancingMethod) && (self.loadBalancingMethod + == ''hash'' || self.loadBalancingMethod == ''hash consistent'')) || + has(self.hashMethodKey)' status: description: Status defines the state of the UpstreamSettingsPolicy. properties: diff --git a/deploy/crds.yaml b/deploy/crds.yaml index 2a526a961f..6cf54dd1f5 100644 --- a/deploy/crds.yaml +++ b/deploy/crds.yaml @@ -9578,6 +9578,12 @@ spec: spec: description: Spec defines the desired state of the UpstreamSettingsPolicy. properties: + hashMethodKey: + description: |- + HashMethodKey defines the key used for hash-based load balancing methods. + This field is required when `LoadBalancingMethod` is set to `hash` or `hash consistent`. + pattern: ^\$[a-z_]+$ + type: string keepAlive: description: KeepAlive defines the keep-alive settings. properties: @@ -9612,6 +9618,27 @@ spec: pattern: ^[0-9]{1,4}(ms|s|m|h)?$ type: string type: object + loadBalancingMethod: + description: |- + LoadBalancingMethod specifies the load balancing algorithm to be used for the upstream. + If not specified, NGINX Gateway Fabric defaults to `random two least_conn`, + which differs from the standard NGINX default `round-robin`. + enum: + - round_robin + - least_conn + - ip_hash + - hash + - hash consistent + - random + - random two + - random two least_conn + - random two least_time=header + - random two least_time=last_byte + - least_time header + - least_time last_byte + - least_time header inflight + - least_time last_byte inflight + type: string targetRefs: description: |- TargetRefs identifies API object(s) to apply the policy to. @@ -9670,6 +9697,12 @@ spec: required: - targetRefs type: object + x-kubernetes-validations: + - message: hashMethodKey is required when loadBalancingMethod is 'hash' + or 'hash consistent' + rule: '!(has(self.loadBalancingMethod) && (self.loadBalancingMethod + == ''hash'' || self.loadBalancingMethod == ''hash consistent'')) || + has(self.hashMethodKey)' status: description: Status defines the state of the UpstreamSettingsPolicy. properties: diff --git a/internal/controller/manager.go b/internal/controller/manager.go index d4e2114e8a..8b8c00b607 100644 --- a/internal/controller/manager.go +++ b/internal/controller/manager.go @@ -124,7 +124,7 @@ func StartManager(cfg config.Config) error { mustExtractGVK := kinds.NewMustExtractGKV(scheme) genericValidator := ngxvalidation.GenericValidator{} - policyManager := createPolicyManager(mustExtractGVK, genericValidator) + policyManager := createPolicyManager(mustExtractGVK, genericValidator, cfg.Plus) plusSecrets, err := createPlusSecretMetadata(cfg, mgr.GetAPIReader()) if err != nil { @@ -140,10 +140,13 @@ func StartManager(cfg config.Config) error { GenericValidator: genericValidator, PolicyValidator: policyManager, }, - EventRecorder: recorder, - MustExtractGVK: mustExtractGVK, - PlusSecrets: plusSecrets, - ExperimentalFeatures: cfg.ExperimentalFeatures, + EventRecorder: recorder, + MustExtractGVK: mustExtractGVK, + PlusSecrets: plusSecrets, + FeatureFlags: graph.FeatureFlags{ + Plus: cfg.Plus, + Experimental: cfg.ExperimentalFeatures, + }, }) var handlerCollector handlerMetricsCollector = collectors.NewControllerNoopCollector() @@ -323,6 +326,7 @@ func StartManager(cfg config.Config) error { func createPolicyManager( mustExtractGVK kinds.MustExtractGVK, validator validation.GenericValidator, + plusEnabled bool, ) *policies.CompositeValidator { cfgs := []policies.ManagerConfig{ { @@ -335,7 +339,7 @@ func createPolicyManager( }, { GVK: mustExtractGVK(&ngfAPIv1alpha1.UpstreamSettingsPolicy{}), - Validator: upstreamsettings.NewValidator(validator), + Validator: upstreamsettings.NewValidator(validator, plusEnabled), }, } diff --git a/internal/controller/nginx/config/http/config.go b/internal/controller/nginx/config/http/config.go index 355e45fe1b..a8bec2f1e0 100644 --- a/internal/controller/nginx/config/http/config.go +++ b/internal/controller/nginx/config/http/config.go @@ -1,6 +1,7 @@ package http //nolint:revive // ignoring conflicting package name import ( + ngfAPI "github.com/nginx/nginx-gateway-fabric/v2/apis/v1alpha1" "github.com/nginx/nginx-gateway-fabric/v2/internal/controller/nginx/config/shared" ) @@ -119,11 +120,22 @@ const ( // Upstream holds all configuration for an HTTP upstream. type Upstream struct { - Name string - ZoneSize string // format: 512k, 1m - StateFile string - KeepAlive UpstreamKeepAlive - Servers []UpstreamServer + SessionPersistence UpstreamSessionPersistence + Name string + ZoneSize string // format: 512k, 1m + StateFile string + LoadBalancingMethod string + HashMethodKey string + KeepAlive UpstreamKeepAlive + Servers []UpstreamServer +} + +// UpstreamSessionPersistence holds the session persistence configuration for an upstream. +type UpstreamSessionPersistence struct { + Name string + Expiry string + Path string + SessionType string } // UpstreamKeepAlive holds the keepalive configuration for an HTTP upstream. @@ -166,3 +178,33 @@ type ServerConfig struct { Plus bool DisableSNIHostValidation bool } + +var ( + OSSAllowedLBMethods = map[ngfAPI.LoadBalancingType]struct{}{ + ngfAPI.LoadBalancingTypeRoundRobin: {}, + ngfAPI.LoadBalancingTypeLeastConnection: {}, + ngfAPI.LoadBalancingTypeIPHash: {}, + ngfAPI.LoadBalancingTypeRandom: {}, + ngfAPI.LoadBalancingTypeHash: {}, + ngfAPI.LoadBalancingTypeHashConsistent: {}, + ngfAPI.LoadBalancingTypeRandomTwo: {}, + ngfAPI.LoadBalancingTypeRandomTwoLeastConnection: {}, + } + + PlusAllowedLBMethods = map[ngfAPI.LoadBalancingType]struct{}{ + ngfAPI.LoadBalancingTypeRoundRobin: {}, + ngfAPI.LoadBalancingTypeLeastConnection: {}, + ngfAPI.LoadBalancingTypeIPHash: {}, + ngfAPI.LoadBalancingTypeRandom: {}, + ngfAPI.LoadBalancingTypeHash: {}, + ngfAPI.LoadBalancingTypeHashConsistent: {}, + ngfAPI.LoadBalancingTypeRandomTwo: {}, + ngfAPI.LoadBalancingTypeRandomTwoLeastConnection: {}, + ngfAPI.LoadBalancingTypeLeastTimeHeader: {}, + ngfAPI.LoadBalancingTypeLeastTimeLastByte: {}, + ngfAPI.LoadBalancingTypeLeastTimeHeaderInflight: {}, + ngfAPI.LoadBalancingTypeLeastTimeLastByteInflight: {}, + ngfAPI.LoadBalancingTypeRandomTwoLeastTimeHeader: {}, + ngfAPI.LoadBalancingTypeRandomTwoLeastTimeLastByte: {}, + } +) diff --git a/internal/controller/nginx/config/policies/upstreamsettings/processor.go b/internal/controller/nginx/config/policies/upstreamsettings/processor.go index 1a646a45e1..9b4f23c7a7 100644 --- a/internal/controller/nginx/config/policies/upstreamsettings/processor.go +++ b/internal/controller/nginx/config/policies/upstreamsettings/processor.go @@ -13,6 +13,10 @@ type Processor struct{} type UpstreamSettings struct { // ZoneSize is the zone size setting. ZoneSize string + // LoadBalancingMethod is the load balancing method setting. + LoadBalancingMethod string + // HashMethodKey is the key to be used for hash-based load balancing methods. + HashMethodKey string // KeepAlive contains the keepalive settings. KeepAlive http.UpstreamKeepAlive } @@ -61,6 +65,14 @@ func processPolicies(pols []policies.Policy) UpstreamSettings { upstreamSettings.KeepAlive.Timeout = string(*usp.Spec.KeepAlive.Timeout) } } + + if usp.Spec.LoadBalancingMethod != nil { + upstreamSettings.LoadBalancingMethod = string(*usp.Spec.LoadBalancingMethod) + } + + if usp.Spec.HashMethodKey != nil { + upstreamSettings.HashMethodKey = string(*usp.Spec.HashMethodKey) + } } return upstreamSettings diff --git a/internal/controller/nginx/config/policies/upstreamsettings/processor_test.go b/internal/controller/nginx/config/policies/upstreamsettings/processor_test.go index a67c1df186..8473e59f40 100644 --- a/internal/controller/nginx/config/policies/upstreamsettings/processor_test.go +++ b/internal/controller/nginx/config/policies/upstreamsettings/processor_test.go @@ -37,6 +37,8 @@ func TestProcess(t *testing.T) { Time: helpers.GetPointer[ngfAPIv1alpha1.Duration]("5s"), Timeout: helpers.GetPointer[ngfAPIv1alpha1.Duration]("10s"), }), + LoadBalancingMethod: helpers.GetPointer(ngfAPIv1alpha1.LoadBalancingTypeIPHash), + HashMethodKey: helpers.GetPointer[ngfAPIv1alpha1.HashMethodKey]("$upstream_addr"), }, }, }, @@ -48,6 +50,44 @@ func TestProcess(t *testing.T) { Time: "5s", Timeout: "10s", }, + LoadBalancingMethod: string(ngfAPIv1alpha1.LoadBalancingTypeIPHash), + HashMethodKey: "$upstream_addr", + }, + }, + { + name: "load balancing method set", + policies: []policies.Policy{ + &ngfAPIv1alpha1.UpstreamSettingsPolicy{ + ObjectMeta: metav1.ObjectMeta{ + Name: "usp", + Namespace: "test", + }, + Spec: ngfAPIv1alpha1.UpstreamSettingsPolicySpec{ + LoadBalancingMethod: helpers.GetPointer(ngfAPIv1alpha1.LoadBalancingTypeRandomTwoLeastConnection), + }, + }, + }, + expUpstreamSettings: UpstreamSettings{ + LoadBalancingMethod: string(ngfAPIv1alpha1.LoadBalancingTypeRandomTwoLeastConnection), + }, + }, + { + name: "load balancing method set with hash key", + policies: []policies.Policy{ + &ngfAPIv1alpha1.UpstreamSettingsPolicy{ + ObjectMeta: metav1.ObjectMeta{ + Name: "usp", + Namespace: "test", + }, + Spec: ngfAPIv1alpha1.UpstreamSettingsPolicySpec{ + LoadBalancingMethod: helpers.GetPointer(ngfAPIv1alpha1.LoadBalancingTypeHashConsistent), + HashMethodKey: helpers.GetPointer[ngfAPIv1alpha1.HashMethodKey]("$request_time"), + }, + }, + }, + expUpstreamSettings: UpstreamSettings{ + LoadBalancingMethod: string(ngfAPIv1alpha1.LoadBalancingTypeHashConsistent), + HashMethodKey: "$request_time", }, }, { @@ -220,6 +260,16 @@ func TestProcess(t *testing.T) { }), }, }, + &ngfAPIv1alpha1.UpstreamSettingsPolicy{ + ObjectMeta: metav1.ObjectMeta{ + Name: "usp-loadBalancingMethod", + Namespace: "test", + }, + Spec: ngfAPIv1alpha1.UpstreamSettingsPolicySpec{ + LoadBalancingMethod: helpers.GetPointer(ngfAPIv1alpha1.LoadBalancingTypeHashConsistent), + HashMethodKey: helpers.GetPointer[ngfAPIv1alpha1.HashMethodKey]("$upstream_addr"), + }, + }, }, expUpstreamSettings: UpstreamSettings{ ZoneSize: "2m", @@ -229,6 +279,8 @@ func TestProcess(t *testing.T) { Time: "5s", Timeout: "10s", }, + LoadBalancingMethod: string(ngfAPIv1alpha1.LoadBalancingTypeHashConsistent), + HashMethodKey: "$upstream_addr", }, }, { @@ -310,6 +362,16 @@ func TestProcess(t *testing.T) { }, }, }, + &ngfAPIv1alpha1.UpstreamSettingsPolicy{ + ObjectMeta: metav1.ObjectMeta{ + Name: "usp-lb-method", + Namespace: "test", + }, + Spec: ngfAPIv1alpha1.UpstreamSettingsPolicySpec{ + LoadBalancingMethod: helpers.GetPointer(ngfAPIv1alpha1.LoadBalancingTypeHash), + HashMethodKey: helpers.GetPointer[ngfAPIv1alpha1.HashMethodKey]("$remote_addr"), + }, + }, }, expUpstreamSettings: UpstreamSettings{ ZoneSize: "2m", @@ -319,6 +381,8 @@ func TestProcess(t *testing.T) { Time: "5s", Timeout: "10s", }, + LoadBalancingMethod: string(ngfAPIv1alpha1.LoadBalancingTypeHash), + HashMethodKey: "$remote_addr", }, }, } diff --git a/internal/controller/nginx/config/policies/upstreamsettings/validator.go b/internal/controller/nginx/config/policies/upstreamsettings/validator.go index 9b54fb48e2..e56857f752 100644 --- a/internal/controller/nginx/config/policies/upstreamsettings/validator.go +++ b/internal/controller/nginx/config/policies/upstreamsettings/validator.go @@ -1,10 +1,14 @@ package upstreamsettings import ( + "fmt" + "strings" + "k8s.io/apimachinery/pkg/util/validation/field" gatewayv1 "sigs.k8s.io/gateway-api/apis/v1" ngfAPI "github.com/nginx/nginx-gateway-fabric/v2/apis/v1alpha1" + httpConfig "github.com/nginx/nginx-gateway-fabric/v2/internal/controller/nginx/config/http" "github.com/nginx/nginx-gateway-fabric/v2/internal/controller/nginx/config/policies" "github.com/nginx/nginx-gateway-fabric/v2/internal/controller/state/conditions" "github.com/nginx/nginx-gateway-fabric/v2/internal/controller/state/validation" @@ -16,11 +20,15 @@ import ( // Implements policies.Validator interface. type Validator struct { genericValidator validation.GenericValidator + plusEnabled bool } // NewValidator returns a new Validator. -func NewValidator(genericValidator validation.GenericValidator) Validator { - return Validator{genericValidator: genericValidator} +func NewValidator(genericValidator validation.GenericValidator, plusEnabled bool) Validator { + return Validator{ + genericValidator: genericValidator, + plusEnabled: plusEnabled, + } } // Validate validates the spec of an UpstreamsSettingsPolicy. @@ -83,6 +91,22 @@ func conflicts(a, b ngfAPI.UpstreamSettingsPolicySpec) bool { } } + if checkConflictsForLoadBalancingFields(a, b) { + return true + } + + return false +} + +func checkConflictsForLoadBalancingFields(a, b ngfAPI.UpstreamSettingsPolicySpec) bool { + if a.LoadBalancingMethod != nil && b.LoadBalancingMethod != nil { + return true + } + + if a.HashMethodKey != nil && b.HashMethodKey != nil { + return true + } + return false } @@ -103,6 +127,8 @@ func (v Validator) validateSettings(spec ngfAPI.UpstreamSettingsPolicySpec) erro allErrs = append(allErrs, v.validateUpstreamKeepAlive(*spec.KeepAlive, fieldPath.Child("keepAlive"))...) } + allErrs = append(allErrs, v.validateLoadBalancingMethod(spec)...) + return allErrs.ToAggregate() } @@ -130,3 +156,51 @@ func (v Validator) validateUpstreamKeepAlive( return allErrs } + +// ValidateLoadBalancingMethod validates the load balancing method for upstream servers. +func (v Validator) validateLoadBalancingMethod(spec ngfAPI.UpstreamSettingsPolicySpec) field.ErrorList { + if spec.LoadBalancingMethod == nil { + return nil + } + + var allErrs field.ErrorList + path := field.NewPath("spec") + lbPath := path.Child("loadBalancingMethod") + + allowedMethods := httpConfig.OSSAllowedLBMethods + nginxType := "NGINX OSS" + if v.plusEnabled { + allowedMethods = httpConfig.PlusAllowedLBMethods + nginxType = "NGINX Plus" + } + + if _, ok := allowedMethods[*spec.LoadBalancingMethod]; !ok { + allErrs = append(allErrs, field.Invalid( + lbPath, + *spec.LoadBalancingMethod, + fmt.Sprintf( + "%s supports the following load balancing methods: %s", + nginxType, + getLoadBalancingMethodList(allowedMethods), + ), + )) + } + + if spec.HashMethodKey != nil { + hashMethodKey := *spec.HashMethodKey + if err := v.genericValidator.ValidateNginxVariableName(string(hashMethodKey)); err != nil { + path := path.Child("hashMethodKey") + allErrs = append(allErrs, field.Invalid(path, hashMethodKey, err.Error())) + } + } + + return allErrs +} + +func getLoadBalancingMethodList(lbMethods map[ngfAPI.LoadBalancingType]struct{}) string { + methods := make([]string, 0, len(lbMethods)) + for method := range lbMethods { + methods = append(methods, string(method)) + } + return strings.Join(methods, ", ") +} diff --git a/internal/controller/nginx/config/policies/upstreamsettings/validator_test.go b/internal/controller/nginx/config/policies/upstreamsettings/validator_test.go index df4c4e0770..d7c81ffa17 100644 --- a/internal/controller/nginx/config/policies/upstreamsettings/validator_test.go +++ b/internal/controller/nginx/config/policies/upstreamsettings/validator_test.go @@ -16,6 +16,8 @@ import ( "github.com/nginx/nginx-gateway-fabric/v2/internal/framework/kinds" ) +const plusDisabled = false + type policyModFunc func(policy *ngfAPI.UpstreamSettingsPolicy) *ngfAPI.UpstreamSettingsPolicy func createValidPolicy() *ngfAPI.UpstreamSettingsPolicy { @@ -38,6 +40,8 @@ func createValidPolicy() *ngfAPI.UpstreamSettingsPolicy { Timeout: helpers.GetPointer[ngfAPI.Duration]("30s"), Connections: helpers.GetPointer[int32](100), }, + LoadBalancingMethod: helpers.GetPointer(ngfAPI.LoadBalancingTypeRandomTwoLeastConnection), + HashMethodKey: helpers.GetPointer[ngfAPI.HashMethodKey]("$upstream_addr"), }, Status: v1.PolicyStatus{}, } @@ -124,7 +128,7 @@ func TestValidator_Validate(t *testing.T) { }, } - v := upstreamsettings.NewValidator(validation.GenericValidator{}) + v := upstreamsettings.NewValidator(validation.GenericValidator{}, plusDisabled) for _, test := range tests { t.Run(test.name, func(t *testing.T) { @@ -139,7 +143,7 @@ func TestValidator_Validate(t *testing.T) { func TestValidator_ValidatePanics(t *testing.T) { t.Parallel() - v := upstreamsettings.NewValidator(nil) + v := upstreamsettings.NewValidator(nil, plusDisabled) validate := func() { _ = v.Validate(&policiesfakes.FakePolicy{}) @@ -154,7 +158,7 @@ func TestValidator_ValidateGlobalSettings(t *testing.T) { t.Parallel() g := NewWithT(t) - v := upstreamsettings.NewValidator(validation.GenericValidator{}) + v := upstreamsettings.NewValidator(validation.GenericValidator{}, plusDisabled) g.Expect(v.ValidateGlobalSettings(nil, nil)).To(BeNil()) } @@ -176,6 +180,7 @@ func TestValidator_Conflicts(t *testing.T) { Requests: helpers.GetPointer[int32](900), Time: helpers.GetPointer[ngfAPI.Duration]("50s"), }, + LoadBalancingMethod: helpers.GetPointer(ngfAPI.LoadBalancingTypeRandomTwoLeastConnection), }, }, polB: &ngfAPI.UpstreamSettingsPolicy{ @@ -246,9 +251,29 @@ func TestValidator_Conflicts(t *testing.T) { }, conflicts: true, }, + { + name: "load balancing method conflicts", + polA: createValidPolicy(), + polB: &ngfAPI.UpstreamSettingsPolicy{ + Spec: ngfAPI.UpstreamSettingsPolicySpec{ + LoadBalancingMethod: helpers.GetPointer(ngfAPI.LoadBalancingTypeIPHash), + }, + }, + conflicts: true, + }, + { + name: "hash key conflicts", + polA: createValidPolicy(), + polB: &ngfAPI.UpstreamSettingsPolicy{ + Spec: ngfAPI.UpstreamSettingsPolicySpec{ + HashMethodKey: helpers.GetPointer[ngfAPI.HashMethodKey]("$upstream_addr"), + }, + }, + conflicts: true, + }, } - v := upstreamsettings.NewValidator(nil) + v := upstreamsettings.NewValidator(nil, plusDisabled) for _, test := range tests { t.Run(test.name, func(t *testing.T) { @@ -262,7 +287,7 @@ func TestValidator_Conflicts(t *testing.T) { func TestValidator_ConflictsPanics(t *testing.T) { t.Parallel() - v := upstreamsettings.NewValidator(nil) + v := upstreamsettings.NewValidator(nil, plusDisabled) conflicts := func() { _ = v.Conflicts(&policiesfakes.FakePolicy{}, &policiesfakes.FakePolicy{}) @@ -272,3 +297,95 @@ func TestValidator_ConflictsPanics(t *testing.T) { g.Expect(conflicts).To(Panic()) } + +func TestValidate_ValidateLoadBalancingMethod(t *testing.T) { + t.Parallel() + + tests := []struct { + policy *ngfAPI.UpstreamSettingsPolicy + name string + expConditions []conditions.Condition + plusEnabled bool + }{ + { + name: "oss method random with Plus disabled", + policy: &ngfAPI.UpstreamSettingsPolicy{ + Spec: ngfAPI.UpstreamSettingsPolicySpec{ + LoadBalancingMethod: helpers.GetPointer(ngfAPI.LoadBalancingTypeRandom), + }, + }, + expConditions: nil, + }, + { + name: "oss method hash consistent with Plus disabled", + policy: &ngfAPI.UpstreamSettingsPolicy{ + Spec: ngfAPI.UpstreamSettingsPolicySpec{ + LoadBalancingMethod: helpers.GetPointer(ngfAPI.LoadBalancingTypeHashConsistent), + }, + }, + expConditions: nil, + }, + { + name: "plus load balancing method least_time last_byte not allowed with Plus disabled", + policy: &ngfAPI.UpstreamSettingsPolicy{ + Spec: ngfAPI.UpstreamSettingsPolicySpec{ + LoadBalancingMethod: helpers.GetPointer(ngfAPI.LoadBalancingTypeLeastTimeLastByte), + }, + }, + expConditions: []conditions.Condition{ + conditions.NewPolicyInvalid("spec.loadBalancingMethod: Invalid value: \"least_time last_byte\": " + + "NGINX OSS supports the following load balancing methods: "), + }, + }, + { + name: "plus load balancing method least_time header allowed with Plus enabled", + policy: &ngfAPI.UpstreamSettingsPolicy{ + Spec: ngfAPI.UpstreamSettingsPolicySpec{ + LoadBalancingMethod: helpers.GetPointer(ngfAPI.LoadBalancingTypeLeastTimeHeader), + }, + }, + plusEnabled: true, + expConditions: nil, + }, + { + name: "invalid load balancing method for NGINX OSS", + policy: &ngfAPI.UpstreamSettingsPolicy{ + Spec: ngfAPI.UpstreamSettingsPolicySpec{ + LoadBalancingMethod: helpers.GetPointer(ngfAPI.LoadBalancingType("invalid-method")), + }, + }, + expConditions: []conditions.Condition{ + conditions.NewPolicyInvalid("spec.loadBalancingMethod: Invalid value: \"invalid-method\": " + + "NGINX OSS supports the following load balancing methods: "), + }, + }, + { + name: "invalid load balancing method for NGINX Plus", + policy: &ngfAPI.UpstreamSettingsPolicy{ + Spec: ngfAPI.UpstreamSettingsPolicySpec{ + LoadBalancingMethod: helpers.GetPointer(ngfAPI.LoadBalancingType("invalid-method")), + }, + }, + expConditions: []conditions.Condition{ + conditions.NewPolicyInvalid("spec.loadBalancingMethod: Invalid value: \"invalid-method\": " + + "NGINX Plus supports the following load balancing methods: "), + }, + plusEnabled: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + g := NewWithT(t) + + v := upstreamsettings.NewValidator(validation.GenericValidator{}, test.plusEnabled) + conds := v.Validate(test.policy) + + if test.expConditions != nil { + g.Expect(conds).To(HaveLen(1)) + g.Expect(conds[0].Message).To(ContainSubstring(test.expConditions[0].Message)) + } + }) + } +} diff --git a/internal/controller/nginx/config/upstreams.go b/internal/controller/nginx/config/upstreams.go index 4b7e6bd16f..ef0c1ab1d4 100644 --- a/internal/controller/nginx/config/upstreams.go +++ b/internal/controller/nginx/config/upstreams.go @@ -4,6 +4,7 @@ import ( "fmt" gotemplate "text/template" + ngfAPI "github.com/nginx/nginx-gateway-fabric/v2/apis/v1alpha1" "github.com/nginx/nginx-gateway-fabric/v2/internal/controller/nginx/config/http" "github.com/nginx/nginx-gateway-fabric/v2/internal/controller/nginx/config/policies/upstreamsettings" "github.com/nginx/nginx-gateway-fabric/v2/internal/controller/nginx/config/stream" @@ -32,6 +33,8 @@ const ( plusZoneSizeStream = "1m" // stateDir is the directory for storing state files. stateDir = "/var/lib/nginx/state" + // default load balancing method. + defaultLBMethod = "random two least_conn" ) // keepAliveChecker takes an upstream name and returns if it has keep alive settings enabled. @@ -144,6 +147,7 @@ func (g GeneratorImpl) createUpstream( processor upstreamsettings.Processor, ) http.Upstream { var stateFile string + var sp http.UpstreamSessionPersistence upstreamPolicySettings := processor.Process(up.Policies) zoneSize := ossZoneSize @@ -152,14 +156,36 @@ func (g GeneratorImpl) createUpstream( // Only set state file if the upstream doesn't have resolve servers // Upstreams with resolve servers can't be managed via NGINX Plus API if !upstreamHasResolveServers(up) { - stateFile = fmt.Sprintf("%s/%s.conf", stateDir, up.Name) + base := up.StateFileKey + if base == "" { + base = up.Name + } + stateFile = fmt.Sprintf("%s/%s.conf", stateDir, base) } + + sp = getSessionPersistenceConfiguration(up.SessionPersistence) } if upstreamPolicySettings.ZoneSize != "" { zoneSize = upstreamPolicySettings.ZoneSize } + chosenLBMethod := defaultLBMethod + if upstreamPolicySettings.LoadBalancingMethod != "" { + lbMethod := upstreamPolicySettings.LoadBalancingMethod + + if lbMethod == string(ngfAPI.LoadBalancingTypeHash) { + lbMethod = fmt.Sprintf("hash %s", upstreamPolicySettings.HashMethodKey) + } + if lbMethod == string(ngfAPI.LoadBalancingTypeHashConsistent) { + lbMethod = fmt.Sprintf("hash %s consistent", upstreamPolicySettings.HashMethodKey) + } + if lbMethod == string(ngfAPI.LoadBalancingTypeRoundRobin) { + lbMethod = "" + } + chosenLBMethod = lbMethod + } + if len(up.Endpoints) == 0 { return http.Upstream{ Name: up.Name, @@ -170,6 +196,7 @@ func (g GeneratorImpl) createUpstream( Address: types.Nginx503Server, }, }, + LoadBalancingMethod: chosenLBMethod, } } @@ -186,11 +213,13 @@ func (g GeneratorImpl) createUpstream( } return http.Upstream{ - Name: up.Name, - ZoneSize: zoneSize, - StateFile: stateFile, - Servers: upstreamServers, - KeepAlive: upstreamPolicySettings.KeepAlive, + Name: up.Name, + ZoneSize: zoneSize, + StateFile: stateFile, + Servers: upstreamServers, + KeepAlive: upstreamPolicySettings.KeepAlive, + LoadBalancingMethod: chosenLBMethod, + SessionPersistence: sp, } } @@ -215,3 +244,17 @@ func upstreamHasResolveServers(upstream dataplane.Upstream) bool { } return false } + +// getSessionPersistenceConfiguration gets the session persistence configuration for an upstream. +// Supported only for NGINX Plus and cookie-based type. +func getSessionPersistenceConfiguration(sp dataplane.SessionPersistenceConfig) http.UpstreamSessionPersistence { + if sp.Name == "" { + return http.UpstreamSessionPersistence{} + } + return http.UpstreamSessionPersistence{ + Name: sp.Name, + Expiry: sp.Expiry, + Path: sp.Path, + SessionType: string(sp.SessionType), + } +} diff --git a/internal/controller/nginx/config/upstreams_template.go b/internal/controller/nginx/config/upstreams_template.go index 15e9b0c1fc..9f42702422 100644 --- a/internal/controller/nginx/config/upstreams_template.go +++ b/internal/controller/nginx/config/upstreams_template.go @@ -10,11 +10,19 @@ package config const upstreamsTemplateText = ` {{ range $u := . }} upstream {{ $u.Name }} { - random two least_conn; + {{ if $u.LoadBalancingMethod -}} + {{ $u.LoadBalancingMethod }}; + {{- end }} {{ if $u.ZoneSize -}} zone {{ $u.Name }} {{ $u.ZoneSize }}; {{ end -}} + {{ if $u.SessionPersistence.Name -}} + sticky {{ $u.SessionPersistence.SessionType }} {{ $u.SessionPersistence.Name }} + {{- if $u.SessionPersistence.Expiry }} expires={{ $u.SessionPersistence.Expiry }}{{- end }} + {{- if $u.SessionPersistence.Path }} path={{ $u.SessionPersistence.Path }}{{- end }}; + {{ end -}} + {{- if $u.StateFile }} state {{ $u.StateFile }}; {{- else }} diff --git a/internal/controller/nginx/config/upstreams_test.go b/internal/controller/nginx/config/upstreams_test.go index c13904d25d..f2a33ce296 100644 --- a/internal/controller/nginx/config/upstreams_test.go +++ b/internal/controller/nginx/config/upstreams_test.go @@ -1,6 +1,8 @@ package config import ( + "fmt" + "strings" "testing" . "github.com/onsi/gomega" @@ -17,9 +19,11 @@ import ( "github.com/nginx/nginx-gateway-fabric/v2/internal/framework/helpers" ) -func TestExecuteUpstreams(t *testing.T) { +func TestExecuteUpstreams_NginxOSS(t *testing.T) { t.Parallel() - gen := GeneratorImpl{} + gen := GeneratorImpl{ + plus: false, + } stateUpstreams := []dataplane.Upstream{ { Name: "up1", @@ -75,31 +79,40 @@ func TestExecuteUpstreams(t *testing.T) { Time: helpers.GetPointer[ngfAPI.Duration]("5s"), Timeout: helpers.GetPointer[ngfAPI.Duration]("10s"), }), + LoadBalancingMethod: helpers.GetPointer(ngfAPI.LoadBalancingTypeIPHash), }, }, }, }, } - expectedSubStrings := []string{ - "upstream up1", - "upstream up2", - "upstream up3", - "upstream up4-ipv6", - "upstream up5-usp", - "upstream invalid-backend-ref", + expectedSubStrings := map[string]int{ + "upstream up1": 1, + "upstream up2": 1, + "upstream up3": 1, + "upstream up4-ipv6": 1, + "upstream up5-usp": 1, + "upstream invalid-backend-ref": 1, - "server 10.0.0.0:80;", - "server 11.0.0.0:80;", - "server [2001:db8::1]:80", - "server 12.0.0.0:80;", - "server unix:/var/run/nginx/nginx-503-server.sock;", + "server 10.0.0.0:80;": 1, + "server 11.0.0.0:80;": 1, + "server [2001:db8::1]:80": 1, + "server 12.0.0.0:80;": 1, + "server unix:/var/run/nginx/nginx-503-server.sock;": 1, + + "keepalive 1;": 1, + "keepalive_requests 1;": 1, + "keepalive_time 5s;": 1, + "keepalive_timeout 10s;": 1, + "ip_hash;": 1, - "keepalive 1;", - "keepalive_requests 1;", - "keepalive_time 5s;", - "keepalive_timeout 10s;", - "zone up5-usp 2m;", + "zone up1 512k;": 1, + "zone up2 512k;": 1, + "zone up3 512k;": 1, + "zone up4-ipv6 512k;": 1, + "zone up5-usp 2m;": 1, + + "random two least_conn;": 4, } upstreams := gen.createUpstreams(stateUpstreams, upstreamsettings.NewProcessor()) @@ -107,11 +120,211 @@ func TestExecuteUpstreams(t *testing.T) { upstreamResults := executeUpstreams(upstreams) g := NewWithT(t) g.Expect(upstreamResults).To(HaveLen(1)) + g.Expect(upstreamResults[0].dest).To(Equal(httpConfigFile)) + nginxUpstreams := string(upstreamResults[0].data) + for expSubString, expectedCount := range expectedSubStrings { + actualCount := strings.Count(nginxUpstreams, expSubString) + g.Expect(actualCount).To( + Equal(expectedCount), + fmt.Sprintf("substring %q expected %d occurrence(s), got %d", expSubString, expectedCount, actualCount), + ) + } +} + +func TestExecuteUpstreams_NginxPlus(t *testing.T) { + t.Parallel() + gen := GeneratorImpl{ + plus: true, + } + stateUpstreams := []dataplane.Upstream{ + { + Name: "up1", + Endpoints: []resolver.Endpoint{ + { + Address: "10.0.0.0", + Port: 80, + }, + }, + }, + { + Name: "up2", + Endpoints: []resolver.Endpoint{ + { + Address: "11.0.0.0", + Port: 80, + }, + { + Address: "11.0.0.1", + Port: 80, + }, + { + Address: "11.0.0.2", + Port: 80, + }, + }, + }, + { + Name: "up3-ipv6", + Endpoints: []resolver.Endpoint{ + { + Address: "2001:db8::1", + Port: 80, + IPv6: true, + }, + }, + }, + { + Name: "up4-ipv6", + Endpoints: []resolver.Endpoint{ + { + Address: "2001:db8::2", + Port: 80, + IPv6: true, + }, + { + Address: "2001:db8::3", + Port: 80, + IPv6: true, + }, + }, + }, + { + Name: "up5", + Endpoints: []resolver.Endpoint{}, + }, + { + Name: "up6-usp-with-sp", + StateFileKey: "up6-usp-with-sp", + Endpoints: []resolver.Endpoint{ + { + Address: "12.0.0.1", + Port: 80, + }, + }, + Policies: []policies.Policy{ + &ngfAPI.UpstreamSettingsPolicy{ + ObjectMeta: metav1.ObjectMeta{ + Name: "usp", + Namespace: "test", + }, + Spec: ngfAPI.UpstreamSettingsPolicySpec{ + ZoneSize: helpers.GetPointer[ngfAPI.Size]("2m"), + KeepAlive: helpers.GetPointer(ngfAPI.UpstreamKeepAlive{ + Connections: helpers.GetPointer(int32(1)), + Requests: helpers.GetPointer(int32(1)), + Time: helpers.GetPointer[ngfAPI.Duration]("5s"), + Timeout: helpers.GetPointer[ngfAPI.Duration]("10s"), + }), + LoadBalancingMethod: helpers.GetPointer(ngfAPI.LoadBalancingTypeIPHash), + }, + }, + }, + SessionPersistence: dataplane.SessionPersistenceConfig{ + Name: "session-persistence", + Expiry: "30m", + Path: "/session", + SessionType: dataplane.CookieBasedSessionPersistence, + }, + }, + { + Name: "up6-with-same-state-file-key", + StateFileKey: "up6-usp-with-sp", + Endpoints: []resolver.Endpoint{ + { + Address: "12.0.0.1", + Port: 80, + }, + }, + }, + { + Name: "up7-with-sp", + Endpoints: []resolver.Endpoint{ + { + Address: "12.0.0.2", + Port: 80, + }, + }, + SessionPersistence: dataplane.SessionPersistenceConfig{ + Name: "session-persistence", + Expiry: "100h", + Path: "/v1/users", + SessionType: dataplane.CookieBasedSessionPersistence, + }, + }, + { + Name: "up8-with-sp-expiry-and-path-empty", + Endpoints: []resolver.Endpoint{ + { + Address: "12.0.0.3", + Port: 80, + }, + }, + SessionPersistence: dataplane.SessionPersistenceConfig{ + Name: "session-persistence", + SessionType: dataplane.CookieBasedSessionPersistence, + }, + }, + } + + expectedSubStrings := map[string]int{ + "upstream up1": 1, + "upstream up2": 1, + "upstream up3-ipv6": 1, + "upstream up4-ipv6": 1, + "upstream up5": 1, + "upstream up6-usp-with-sp": 1, + "upstream up7-with-sp": 1, + "upstream up8-with-sp-expiry-and-path-empty": 1, + "upstream invalid-backend-ref": 1, + + "random two least_conn;": 8, + "ip_hash;": 1, + + "zone up1 1m;": 1, + "zone up2 1m;": 1, + "zone up3-ipv6 1m;": 1, + "zone up4-ipv6 1m;": 1, + "zone up5 1m;": 1, + "zone up6-usp-with-sp 2m;": 1, + "zone up7-with-sp 1m;": 1, + "zone up8-with-sp-expiry-and-path-empty 1m;": 1, + + "sticky cookie session-persistence expires=30m path=/session;": 1, + "sticky cookie session-persistence expires=100h path=/v1/users;": 1, + "sticky cookie session-persistence;": 1, + "keepalive 1;": 1, + "keepalive_requests 1;": 1, + "keepalive_time 5s;": 1, + "keepalive_timeout 10s;": 1, + + "state /var/lib/nginx/state/up1.conf;": 1, + "state /var/lib/nginx/state/up2.conf;": 1, + "state /var/lib/nginx/state/up3-ipv6.conf;": 1, + "state /var/lib/nginx/state/up4-ipv6.conf;": 1, + "state /var/lib/nginx/state/up5.conf;": 1, + + "state /var/lib/nginx/state/up6-usp-with-sp.conf": 2, + "state /var/lib/nginx/state/up7-with-sp.conf;": 1, + "state /var/lib/nginx/state/up8-with-sp-expiry-and-path-empty.conf;": 1, + "server unix:/var/run/nginx/nginx-500-server.sock;": 1, + } + + upstreams := gen.createUpstreams(stateUpstreams, upstreamsettings.NewProcessor()) + + upstreamResults := executeUpstreams(upstreams) + g := NewWithT(t) + g.Expect(upstreamResults).To(HaveLen(1)) g.Expect(upstreamResults[0].dest).To(Equal(httpConfigFile)) - for _, expSubString := range expectedSubStrings { - g.Expect(nginxUpstreams).To(ContainSubstring(expSubString)) + + nginxUpstreams := string(upstreamResults[0].data) + for expSubString, expectedCount := range expectedSubStrings { + actualCount := strings.Count(nginxUpstreams, expSubString) + g.Expect(actualCount).To( + Equal(expectedCount), + fmt.Sprintf("substring %q expected %d occurrence(s), got %d", expSubString, expectedCount, actualCount), + ) } } @@ -181,6 +394,7 @@ func TestCreateUpstreams(t *testing.T) { Time: helpers.GetPointer[ngfAPI.Duration]("5s"), Timeout: helpers.GetPointer[ngfAPI.Duration]("10s"), }), + LoadBalancingMethod: helpers.GetPointer((ngfAPI.LoadBalancingTypeIPHash)), }, }, }, @@ -202,6 +416,7 @@ func TestCreateUpstreams(t *testing.T) { Address: "10.0.0.2:80", }, }, + LoadBalancingMethod: defaultLBMethod, }, { Name: "up2", @@ -211,6 +426,7 @@ func TestCreateUpstreams(t *testing.T) { Address: "11.0.0.0:80", }, }, + LoadBalancingMethod: defaultLBMethod, }, { Name: "up3", @@ -220,6 +436,7 @@ func TestCreateUpstreams(t *testing.T) { Address: types.Nginx503Server, }, }, + LoadBalancingMethod: defaultLBMethod, }, { Name: "up4-ipv6", @@ -229,6 +446,7 @@ func TestCreateUpstreams(t *testing.T) { Address: "[fd00:10:244:1::7]:80", }, }, + LoadBalancingMethod: defaultLBMethod, }, { Name: "up5-usp", @@ -244,6 +462,7 @@ func TestCreateUpstreams(t *testing.T) { Time: "5s", Timeout: "10s", }, + LoadBalancingMethod: string(ngfAPI.LoadBalancingTypeIPHash), }, { Name: invalidBackendRef, @@ -281,6 +500,7 @@ func TestCreateUpstream(t *testing.T) { Address: types.Nginx503Server, }, }, + LoadBalancingMethod: defaultLBMethod, }, msg: "nil endpoints", }, @@ -297,6 +517,7 @@ func TestCreateUpstream(t *testing.T) { Address: types.Nginx503Server, }, }, + LoadBalancingMethod: defaultLBMethod, }, msg: "no endpoints", }, @@ -332,6 +553,7 @@ func TestCreateUpstream(t *testing.T) { Address: "10.0.0.3:80", }, }, + LoadBalancingMethod: defaultLBMethod, }, msg: "multiple endpoints", }, @@ -354,6 +576,7 @@ func TestCreateUpstream(t *testing.T) { Address: "[fd00:10:244:1::7]:80", }, }, + LoadBalancingMethod: defaultLBMethod, }, msg: "endpoint ipv6", }, @@ -380,6 +603,7 @@ func TestCreateUpstream(t *testing.T) { Time: helpers.GetPointer[ngfAPI.Duration]("5s"), Timeout: helpers.GetPointer[ngfAPI.Duration]("10s"), }), + LoadBalancingMethod: helpers.GetPointer(ngfAPI.LoadBalancingTypeIPHash), }, }, }, @@ -398,6 +622,7 @@ func TestCreateUpstream(t *testing.T) { Time: "5s", Timeout: "10s", }, + LoadBalancingMethod: string(ngfAPI.LoadBalancingTypeIPHash), }, msg: "single upstreamSettingsPolicy", }, @@ -422,6 +647,7 @@ func TestCreateUpstream(t *testing.T) { Time: helpers.GetPointer[ngfAPI.Duration]("5s"), Timeout: helpers.GetPointer[ngfAPI.Duration]("10s"), }), + LoadBalancingMethod: helpers.GetPointer(ngfAPI.LoadBalancingTypeRandomTwoLeastConnection), }, }, &ngfAPI.UpstreamSettingsPolicy{ @@ -452,6 +678,7 @@ func TestCreateUpstream(t *testing.T) { Time: "5s", Timeout: "10s", }, + LoadBalancingMethod: string(ngfAPI.LoadBalancingTypeRandomTwoLeastConnection), }, msg: "multiple upstreamSettingsPolicies", }, @@ -481,6 +708,7 @@ func TestCreateUpstream(t *testing.T) { Address: "10.0.0.1:80", }, }, + LoadBalancingMethod: defaultLBMethod, }, msg: "empty upstreamSettingsPolicies", }, @@ -524,9 +752,43 @@ func TestCreateUpstream(t *testing.T) { Time: "5s", Timeout: "10s", }, + LoadBalancingMethod: defaultLBMethod, }, msg: "upstreamSettingsPolicy with only keep alive settings", }, + { + stateUpstream: dataplane.Upstream{ + Name: "upstreamSettingsPolicy with only load balancing settings", + Endpoints: []resolver.Endpoint{ + { + Address: "11.0.20.9", + Port: 80, + }, + }, + Policies: []policies.Policy{ + &ngfAPI.UpstreamSettingsPolicy{ + ObjectMeta: metav1.ObjectMeta{ + Name: "usp1", + Namespace: "test", + }, + Spec: ngfAPI.UpstreamSettingsPolicySpec{ + LoadBalancingMethod: helpers.GetPointer(ngfAPI.LoadBalancingTypeIPHash), + }, + }, + }, + }, + expectedUpstream: http.Upstream{ + Name: "upstreamSettingsPolicy with only load balancing settings", + ZoneSize: ossZoneSize, + Servers: []http.UpstreamServer{ + { + Address: "11.0.20.9:80", + }, + }, + LoadBalancingMethod: string(ngfAPI.LoadBalancingTypeIPHash), + }, + msg: "upstreamSettingsPolicy with only load balancing settings", + }, { stateUpstream: dataplane.Upstream{ Name: "external-name-service", @@ -547,6 +809,7 @@ func TestCreateUpstream(t *testing.T) { Resolve: true, }, }, + LoadBalancingMethod: defaultLBMethod, }, msg: "ExternalName service with DNS name", }, @@ -585,6 +848,7 @@ func TestCreateUpstream(t *testing.T) { Address: "[fd00:10:244:1::7]:80", }, }, + LoadBalancingMethod: defaultLBMethod, }, msg: "mixed IP addresses and DNS names", }, @@ -605,14 +869,15 @@ func TestCreateUpstreamPlus(t *testing.T) { gen := GeneratorImpl{plus: true} tests := []struct { + expectedUpstream http.Upstream msg string stateUpstream dataplane.Upstream - expectedUpstream http.Upstream }{ { msg: "with endpoints", stateUpstream: dataplane.Upstream{ - Name: "endpoints", + Name: "endpoints", + StateFileKey: "endpoints", Endpoints: []resolver.Endpoint{ { Address: "10.0.0.1", @@ -629,13 +894,15 @@ func TestCreateUpstreamPlus(t *testing.T) { Address: "10.0.0.1:80", }, }, + LoadBalancingMethod: defaultLBMethod, }, }, { msg: "no endpoints", stateUpstream: dataplane.Upstream{ - Name: "no-endpoints", - Endpoints: []resolver.Endpoint{}, + Name: "no-endpoints", + StateFileKey: "no-endpoints", + Endpoints: []resolver.Endpoint{}, }, expectedUpstream: http.Upstream{ Name: "no-endpoints", @@ -646,6 +913,43 @@ func TestCreateUpstreamPlus(t *testing.T) { Address: types.Nginx503Server, }, }, + LoadBalancingMethod: defaultLBMethod, + }, + }, + { + msg: "session persistence config with endpoints", + stateUpstream: dataplane.Upstream{ + Name: "sp-with-endpoints", + StateFileKey: "sp-with-endpoints", + Endpoints: []resolver.Endpoint{ + { + Address: "10.0.0.2", + Port: 80, + }, + }, + SessionPersistence: dataplane.SessionPersistenceConfig{ + Name: "session-persistence", + Expiry: "45m", + SessionType: dataplane.CookieBasedSessionPersistence, + Path: "/app", + }, + }, + expectedUpstream: http.Upstream{ + Name: "sp-with-endpoints", + ZoneSize: plusZoneSize, + StateFile: stateDir + "/sp-with-endpoints.conf", + Servers: []http.UpstreamServer{ + { + Address: "10.0.0.2:80", + }, + }, + LoadBalancingMethod: defaultLBMethod, + SessionPersistence: http.UpstreamSessionPersistence{ + Name: "session-persistence", + Expiry: "45m", + SessionType: string(dataplane.CookieBasedSessionPersistence), + Path: "/app", + }, }, }, } @@ -1139,3 +1443,206 @@ func TestKeepAliveChecker(t *testing.T) { }) } } + +func TestExecuteUpstreams_LoadBalancingMethod(t *testing.T) { + t.Parallel() + + tests := []struct { + expectedSubStrings map[string]int + name string + lbType ngfAPI.LoadBalancingType + HashMethodKey ngfAPI.HashMethodKey + }{ + { + name: "default load balancing method", + expectedSubStrings: map[string]int{ + "upstream up1-usp-ipv4": 1, + "upstream up2-usp-ipv6": 1, + "random two least_conn;": 2, + }, + }, + { + name: "round_robin load balancing method", + expectedSubStrings: map[string]int{ + "upstream up1-usp-ipv4": 1, + "upstream up2-usp-ipv6": 1, + }, + }, + { + name: "least_conn load balancing method", + lbType: ngfAPI.LoadBalancingTypeLeastConnection, + expectedSubStrings: map[string]int{ + "upstream up1-usp-ipv4": 1, + "upstream up2-usp-ipv6": 1, + "least_conn;": 2, + }, + }, + { + name: "ip_hash load balancing method", + lbType: ngfAPI.LoadBalancingTypeIPHash, + expectedSubStrings: map[string]int{ + "upstream up1-usp-ipv4": 1, + "upstream up2-usp-ipv6": 1, + "ip_hash;": 2, + }, + }, + { + name: "hash load balancing method with specific hash key", + lbType: ngfAPI.LoadBalancingTypeHash, + HashMethodKey: ngfAPI.HashMethodKey("$request_uri"), + expectedSubStrings: map[string]int{ + "upstream up1-usp-ipv4": 1, + "upstream up2-usp-ipv6": 1, + "hash $request_uri;": 2, + }, + }, + { + name: "hash consistent load balancing method with specific hash key", + lbType: ngfAPI.LoadBalancingTypeHashConsistent, + HashMethodKey: ngfAPI.HashMethodKey("$remote_addr"), + expectedSubStrings: map[string]int{ + "upstream up1-usp-ipv4": 1, + "upstream up2-usp-ipv6": 1, + "hash $remote_addr consistent;": 2, + }, + }, + { + name: "random load balancing method", + lbType: ngfAPI.LoadBalancingTypeRandom, + expectedSubStrings: map[string]int{ + "upstream up1-usp-ipv4": 1, + "upstream up2-usp-ipv6": 1, + "random;": 2, + }, + }, + { + name: "random two load balancing method", + lbType: ngfAPI.LoadBalancingTypeRandomTwo, + expectedSubStrings: map[string]int{ + "upstream up1-usp-ipv4": 1, + "upstream up2-usp-ipv6": 1, + "random two;": 2, + }, + }, + { + name: "random two least_time=header load balancing method", + lbType: ngfAPI.LoadBalancingTypeRandomTwoLeastTimeHeader, + expectedSubStrings: map[string]int{ + "upstream up1-usp-ipv4": 1, + "upstream up2-usp-ipv6": 1, + "random two least_time=header;": 2, + }, + }, + { + name: "random two least_time=last_byte load balancing method", + lbType: ngfAPI.LoadBalancingTypeRandomTwoLeastTimeLastByte, + expectedSubStrings: map[string]int{ + "upstream up1-usp-ipv4": 1, + "upstream up2-usp-ipv6": 1, + "random two least_time=last_byte;": 2, + }, + }, + { + name: "least_time header load balancing method", + lbType: ngfAPI.LoadBalancingTypeLeastTimeHeader, + expectedSubStrings: map[string]int{ + "upstream up1-usp-ipv4": 1, + "upstream up2-usp-ipv6": 1, + "least_time header;": 2, + }, + }, + { + name: "least_time last_byte load balancing method", + lbType: ngfAPI.LoadBalancingTypeLeastTimeLastByte, + expectedSubStrings: map[string]int{ + "upstream up1-usp-ipv4": 1, + "upstream up2-usp-ipv6": 1, + "least_time last_byte;": 2, + }, + }, + { + name: "least_time header inflight load balancing method", + lbType: ngfAPI.LoadBalancingTypeLeastTimeHeaderInflight, + expectedSubStrings: map[string]int{ + "upstream up1-usp-ipv4": 1, + "upstream up2-usp-ipv6": 1, + "least_time header inflight;": 2, + }, + }, + { + name: "least_time last_byte inflight load balancing method", + lbType: ngfAPI.LoadBalancingTypeLeastTimeLastByteInflight, + expectedSubStrings: map[string]int{ + "upstream up1-usp-ipv4": 1, + "upstream up2-usp-ipv6": 1, + "least_time last_byte inflight;": 2, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + g := NewWithT(t) + gen := GeneratorImpl{} + stateUpstreams := []dataplane.Upstream{ + { + Name: "up1-usp-ipv4", + Endpoints: []resolver.Endpoint{ + { + Address: "12.0.0.0", + Port: 80, + }, + }, + Policies: []policies.Policy{ + &ngfAPI.UpstreamSettingsPolicy{ + ObjectMeta: metav1.ObjectMeta{ + Name: "usp-ipv4", + Namespace: "test", + }, + Spec: ngfAPI.UpstreamSettingsPolicySpec{ + LoadBalancingMethod: helpers.GetPointer(tt.lbType), + HashMethodKey: helpers.GetPointer(tt.HashMethodKey), + }, + }, + }, + }, + { + Name: "up2-usp-ipv6", + Endpoints: []resolver.Endpoint{ + { + Address: "2001:db8::1", + Port: 80, + }, + }, + Policies: []policies.Policy{ + &ngfAPI.UpstreamSettingsPolicy{ + ObjectMeta: metav1.ObjectMeta{ + Name: "usp-ipv6", + Namespace: "test", + }, + Spec: ngfAPI.UpstreamSettingsPolicySpec{ + LoadBalancingMethod: helpers.GetPointer(tt.lbType), + HashMethodKey: helpers.GetPointer(tt.HashMethodKey), + }, + }, + }, + }, + } + + upstreams := gen.createUpstreams(stateUpstreams, upstreamsettings.NewProcessor()) + upstreamResults := executeUpstreams(upstreams) + + g.Expect(upstreamResults).To(HaveLen(1)) + nginxUpstreams := string(upstreamResults[0].data) + + for expSubString, expectedCount := range tt.expectedSubStrings { + actualCount := strings.Count(nginxUpstreams, expSubString) + g.Expect(actualCount).To( + Equal(expectedCount), + fmt.Sprintf("substring %q expected %d occurrence(s), got %d", expSubString, expectedCount, actualCount), + ) + } + }) + } +} diff --git a/internal/controller/nginx/config/validation/common.go b/internal/controller/nginx/config/validation/common.go index d5a9294fde..71eeab51fc 100644 --- a/internal/controller/nginx/config/validation/common.go +++ b/internal/controller/nginx/config/validation/common.go @@ -5,6 +5,7 @@ import ( "fmt" "regexp" "strings" + "time" "github.com/dlclark/regexp2" k8svalidation "k8s.io/apimachinery/pkg/util/validation" @@ -148,3 +149,68 @@ func validatePathInRegexMatch(path string) error { return nil } + +type HTTPDurationValidator struct{} + +func (d HTTPDurationValidator) ValidateDuration(duration string) (string, error) { + return d.validateDurationCanBeConvertedToNginxFormat(duration) +} + +// validateDurationCanBeConvertedToNginxFormat parses a Gateway API duration and returns a single-unit, +// NGINX-friendly duration that matches `^[0-9]{1,4}(ms|s|m|h)?$` +// The conversion rules are: +// - duration must be > 0 +// - ceil to the next millisecond +// - choose the smallest unit (ms→s→m→h) whose ceil value fits in 1–4 digits +// - always include a unit suffix +func (d HTTPDurationValidator) validateDurationCanBeConvertedToNginxFormat(in string) (string, error) { + // if the input already matches the NGINX format, return it as is + if durationStringFmtRegexp.MatchString(in) { + return in, nil + } + + td, err := time.ParseDuration(in) + if err != nil { + return "", fmt.Errorf("invalid duration: %w", err) + } + if td <= 0 { + return "", errors.New("duration must be > 0") + } + + ns := td.Nanoseconds() + ceilDivision := func(a, b int64) int64 { + return (a + b - 1) / b + } + + totalMS := ceilDivision(ns, int64(time.Millisecond)) + + type unit struct { + suffix string + step int64 + } + + units := []unit{ + {"ms", 1}, + {"s", 1000}, + {"m", 60 * 1000}, + {"h", 60 * 60 * 1000}, + } + + const maxValue = 9999 + var out string + for _, u := range units { + v := ceilDivision(totalMS, u.step) + if v >= 1 && v <= maxValue { + out = fmt.Sprintf("%d%s", v, u.suffix) + break + } + } + if out == "" { + return "", fmt.Errorf("duration is too large for NGINX format (exceeds %dh)", maxValue) + } + + if !durationStringFmtRegexp.MatchString(out) { + return "", fmt.Errorf("computed duration %q does not match NGINX format", out) + } + return out, nil +} diff --git a/internal/controller/nginx/config/validation/common_test.go b/internal/controller/nginx/config/validation/common_test.go index 6b944c3b95..d7f432535c 100644 --- a/internal/controller/nginx/config/validation/common_test.go +++ b/internal/controller/nginx/config/validation/common_test.go @@ -161,3 +161,31 @@ func TestValidatePathInRegexMatch(t *testing.T) { `^(([a-z])+)+$`, // nested quantifiers not allowed ) } + +func TestValidateDurationCanBeConvertedToNginxFormat(t *testing.T) { + t.Parallel() + validator := HTTPDurationValidator{} + + testValidValuesForResultValidator[string, string]( + t, + validator.validateDurationCanBeConvertedToNginxFormat, + resultTestCase[string, string]{input: "24h", expected: "24h"}, + resultTestCase[string, string]{input: "1ms", expected: "1ms"}, + resultTestCase[string, string]{input: "1.1ms", expected: "2ms"}, + resultTestCase[string, string]{input: "100s", expected: "100s"}, + resultTestCase[string, string]{input: "1m", expected: "1m"}, + resultTestCase[string, string]{input: "1h", expected: "1h"}, + resultTestCase[string, string]{input: "9999s", expected: "9999s"}, + resultTestCase[string, string]{input: "10000s", expected: "167m"}, + ) + + testInvalidValuesForResultValidator[string, string]( + t, + validator.validateDurationCanBeConvertedToNginxFormat, + "", + "foo", + "-1s", + "1000000h", // too large + "9999h1s", // just over max + ) +} diff --git a/internal/controller/nginx/config/validation/framework_test.go b/internal/controller/nginx/config/validation/framework_test.go index 7e0ff5d0ce..a9ceaaf117 100644 --- a/internal/controller/nginx/config/validation/framework_test.go +++ b/internal/controller/nginx/config/validation/framework_test.go @@ -149,3 +149,45 @@ func TestGetSortedKeysAsString(t *testing.T) { result := getSortedKeysAsString(values) g.Expect(result).To(Equal(expected)) } + +type resultValidatorFunc[T configValue, R any] func(v T) (R, error) + +type resultTestCase[T configValue, R any] struct { + input T + expected R +} + +func testInvalidValuesForResultValidator[T configValue, R any]( + t *testing.T, + f resultValidatorFunc[T, R], + values ...T, +) { + t.Helper() + runValidatorTests( + t, + func(g *WithT, v T) { + _, err := f(v) + g.Expect(err).To(HaveOccurred(), createFailureMessage(v)) + }, + "invalid_value", + values..., + ) +} + +func testValidValuesForResultValidator[T configValue, R any]( + t *testing.T, + f resultValidatorFunc[T, R], + cases ...resultTestCase[T, R], +) { + t.Helper() + for i, tc := range cases { + name := fmt.Sprintf("test-case=%d", i) + + t.Run(name, func(t *testing.T) { + g := NewWithT(t) + got, err := f(tc.input) + g.Expect(err).ToNot(HaveOccurred(), createFailureMessage(tc.input)) + g.Expect(got).To(Equal(tc.expected), createFailureMessage(tc.input)) + }) + } +} diff --git a/internal/controller/nginx/config/validation/generic.go b/internal/controller/nginx/config/validation/generic.go index 8342ab4134..f63073955b 100644 --- a/internal/controller/nginx/config/validation/generic.go +++ b/internal/controller/nginx/config/validation/generic.go @@ -106,3 +106,24 @@ func (GenericValidator) ValidateEndpoint(endpoint string) error { return nil } + +const ( + variableNameFmt = `\$[a-z_]+` + variableNameErrMsg = "must start with '$' followed by lowercase letters and underscores only" +) + +var variableNameRegexp = regexp.MustCompile("^" + variableNameFmt + "$") + +// ValidateNginxVariableName validates an nginx variable name. +func (GenericValidator) ValidateNginxVariableName(name string) error { + if !variableNameRegexp.MatchString(name) { + examples := []string{ + "$upstream_addr", + "$remote_addr", + } + + return errors.New(k8svalidation.RegexError(variableNameFmt, variableNameErrMsg, examples...)) + } + + return nil +} diff --git a/internal/controller/nginx/config/validation/generic_test.go b/internal/controller/nginx/config/validation/generic_test.go index 5f57b51c56..73be3f10cb 100644 --- a/internal/controller/nginx/config/validation/generic_test.go +++ b/internal/controller/nginx/config/validation/generic_test.go @@ -112,3 +112,25 @@ func TestValidateEndpoint(t *testing.T) { `my$endpoint`, ) } + +func TestValidateNginxVariableName(t *testing.T) { + t.Parallel() + validator := GenericValidator{} + + testValidValuesForSimpleValidator( + t, + validator.ValidateNginxVariableName, + `$upstream_bytes_sent`, + `$upstream_last_server_name`, + `$remote_addr`, + ) + + testInvalidValuesForSimpleValidator( + t, + validator.ValidateNginxVariableName, + `1varname`, + `var-name`, + `var name`, + `var$name`, + ) +} diff --git a/internal/controller/nginx/config/validation/http_validator.go b/internal/controller/nginx/config/validation/http_validator.go index 4837113458..779682bf4e 100644 --- a/internal/controller/nginx/config/validation/http_validator.go +++ b/internal/controller/nginx/config/validation/http_validator.go @@ -13,6 +13,7 @@ type HTTPValidator struct { HTTPURLRewriteValidator HTTPHeaderValidator HTTPPathValidator + HTTPDurationValidator } func (HTTPValidator) SkipValidation() bool { return false } diff --git a/internal/controller/state/change_processor.go b/internal/controller/state/change_processor.go index d661903b8c..74bef98779 100644 --- a/internal/controller/state/change_processor.go +++ b/internal/controller/state/change_processor.go @@ -64,8 +64,8 @@ type ChangeProcessorConfig struct { GatewayCtlrName string // GatewayClassName is the name of the GatewayClass resource. GatewayClassName string - // ExperimentalFeatures indicates if experimental features are enabled. - ExperimentalFeatures bool + // FeaturesFlags holds the feature flags for building the Graph. + FeatureFlags graph.FeatureFlags } // ChangeProcessorImpl is an implementation of ChangeProcessor. @@ -278,7 +278,7 @@ func (c *ChangeProcessorImpl) Process() *graph.Graph { c.cfg.PlusSecrets, c.cfg.Validators, c.cfg.Logger, - c.cfg.ExperimentalFeatures, + c.cfg.FeatureFlags, ) return c.latestGraph diff --git a/internal/controller/state/dataplane/configuration.go b/internal/controller/state/dataplane/configuration.go index def60605b6..7d9b78d5cb 100644 --- a/internal/controller/state/dataplane/configuration.go +++ b/internal/controller/state/dataplane/configuration.go @@ -791,6 +791,7 @@ func buildUpstreams( referencedServices, uniqueUpstreams, allowedAddressType, + br.SessionPersistence, ); upstream != nil { uniqueUpstreams[upstream.Name] = *upstream } @@ -826,6 +827,7 @@ func buildUpstream( referencedServices map[types.NamespacedName]*graph.ReferencedService, uniqueUpstreams map[string]Upstream, allowedAddressType []discoveryV1.AddressType, + sessionPersistence *graph.SessionPersistenceConfig, ) *Upstream { if !br.Valid { return nil @@ -844,7 +846,6 @@ func buildUpstream( } var errMsg string - eps, err := resolveUpstreamEndpoints( ctx, logger, @@ -865,11 +866,23 @@ func buildUpstream( upstreamPolicies = buildPolicies(gateway, graphSvc.Policies) } + var sp SessionPersistenceConfig + if sessionPersistence != nil { + sp = SessionPersistenceConfig{ + Name: sessionPersistence.Name, + Expiry: sessionPersistence.Expiry, + Path: sessionPersistence.Path, + SessionType: CookieBasedSessionPersistence, + } + } + return &Upstream{ - Name: upstreamName, - Endpoints: eps, - ErrorMsg: errMsg, - Policies: upstreamPolicies, + Name: upstreamName, + Endpoints: eps, + ErrorMsg: errMsg, + Policies: upstreamPolicies, + SessionPersistence: sp, + StateFileKey: br.BaseServicePortKey(), } } diff --git a/internal/controller/state/dataplane/configuration_test.go b/internal/controller/state/dataplane/configuration_test.go index b9c3bc27e8..f99b0df700 100644 --- a/internal/controller/state/dataplane/configuration_test.go +++ b/internal/controller/state/dataplane/configuration_test.go @@ -55,8 +55,9 @@ var ( } fooUpstream = Upstream{ - Name: fooUpstreamName, - Endpoints: fooEndpoints, + Name: fooUpstreamName, + Endpoints: fooEndpoints, + StateFileKey: fooUpstreamName, } // routes. @@ -172,7 +173,7 @@ func getNormalBackendRef() graph.BackendRef { } } -func getExpectedConfiguration() Configuration { +func getExpectedSPConfiguration() Configuration { return Configuration{ BaseHTTPConfig: defaultBaseHTTPConfig, HTTPServers: []VirtualServer{ @@ -237,7 +238,7 @@ func getModifiedGraph(mod func(g *graph.Graph) *graph.Graph) *graph.Graph { } func getModifiedExpectedConfiguration(mod func(conf Configuration) Configuration) Configuration { - return mod(getExpectedConfiguration()) + return mod(getExpectedSPConfiguration()) } func createFakePolicy(name string, kind string) policies.Policy { @@ -2885,10 +2886,10 @@ func TestGetListenerHostname(t *testing.T) { } } -func refsToValidRules(refs ...[]graph.BackendRef) []graph.RouteRule { - rules := make([]graph.RouteRule, 0, len(refs)) +func refsToValidRules(backendRefs ...[]graph.BackendRef) []graph.RouteRule { + rules := make([]graph.RouteRule, 0, len(backendRefs)) - for _, ref := range refs { + for _, ref := range backendRefs { rules = append(rules, graph.RouteRule{ ValidMatches: true, Filters: graph.RouteRuleFilters{Valid: true}, @@ -2993,44 +2994,66 @@ func TestBuildUpstreams(t *testing.T) { }, } - createBackendRefs := func(serviceNames ...string) []graph.BackendRef { + createBackendRefs := func(sp *graph.SessionPersistenceConfig, serviceNames ...string) []graph.BackendRef { var backends []graph.BackendRef for _, name := range serviceNames { backends = append(backends, graph.BackendRef{ - SvcNsName: types.NamespacedName{Namespace: "test", Name: name}, - ServicePort: apiv1.ServicePort{Port: 80}, - Valid: name != "", + SvcNsName: types.NamespacedName{Namespace: "test", Name: name}, + ServicePort: apiv1.ServicePort{Port: 80}, + Valid: name != "", + SessionPersistence: sp, }) } return backends } - hr1Refs0 := createBackendRefs("foo", "bar") + createSPConfig := func(idx string) *graph.SessionPersistenceConfig { + return &graph.SessionPersistenceConfig{ + Name: "session-persistence", + SessionType: v1.CookieBasedSessionPersistence, + Expiry: "24h", + Path: "/", + Valid: true, + Idx: idx, + } + } + + hr1Refs0 := createBackendRefs(createSPConfig("foo-bar-sp"), "foo", "bar") - hr1Refs1 := createBackendRefs("baz", "", "") // empty service names should be ignored + hr1Refs1 := createBackendRefs(nil, "baz", "", "") // empty service names should be ignored - hr1Refs2 := createBackendRefs("invalid-for-gateway") + hr1Refs2 := createBackendRefs(nil, "invalid-for-gateway") hr1Refs2[0].InvalidForGateways = map[types.NamespacedName]conditions.Condition{ {Namespace: "test", Name: "gateway"}: {}, } - hr2Refs0 := createBackendRefs("foo", "baz") // shouldn't duplicate foo and baz upstream + // should duplicate foo upstream because it has a different SP config + hr2Refs0 := createBackendRefs(createSPConfig("foo-baz-sp"), "foo", "baz") - hr2Refs1 := createBackendRefs("nil-endpoints") + hr2Refs1 := createBackendRefs(nil, "nil-endpoints") - hr3Refs0 := createBackendRefs("baz") // shouldn't duplicate baz upstream + hr3Refs0 := createBackendRefs(nil, "baz") // shouldn't duplicate baz upstream - hr4Refs0 := createBackendRefs("empty-endpoints", "") + hr4Refs0 := createBackendRefs(nil, "empty-endpoints", "") - hr4Refs1 := createBackendRefs("baz2") + hr4Refs1 := createBackendRefs(nil, "baz2") - hr5Refs0 := createBackendRefs("ipv6-endpoints") + hr5Refs0 := createBackendRefs(nil, "ipv6-endpoints") - nonExistingRefs := createBackendRefs("non-existing") + nonExistingRefs := createBackendRefs(nil, "non-existing") - invalidHRRefs := createBackendRefs("abc") + invalidHRRefs := createBackendRefs(nil, "abc") - refsWithPolicies := createBackendRefs("policies") + refsWithPolicies := createBackendRefs(createSPConfig("policies-sp"), "policies") + + getExpectedSPConfig := func() SessionPersistenceConfig { + return SessionPersistenceConfig{ + Name: "session-persistence", + SessionType: CookieBasedSessionPersistence, + Expiry: "24h", + Path: "/", + } + } routes := map[graph.RouteKey]*graph.L7Route{ {NamespacedName: types.NamespacedName{Name: "hr1", Namespace: "test"}}: { @@ -3175,39 +3198,67 @@ func TestBuildUpstreams(t *testing.T) { expUpstreams := []Upstream{ { - Name: "test_bar_80", - Endpoints: barEndpoints, + Name: "test_bar_80_foo-bar-sp", + Endpoints: barEndpoints, + SessionPersistence: getExpectedSPConfig(), + StateFileKey: "test_bar_80", }, { - Name: "test_baz2_80", - Endpoints: baz2Endpoints, + Name: "test_baz2_80", + Endpoints: baz2Endpoints, + SessionPersistence: SessionPersistenceConfig{}, + StateFileKey: "test_baz2_80", }, { - Name: "test_baz_80", - Endpoints: bazEndpoints, + Name: "test_baz_80", + Endpoints: bazEndpoints, + SessionPersistence: SessionPersistenceConfig{}, + StateFileKey: "test_baz_80", }, { - Name: "test_empty-endpoints_80", - Endpoints: []resolver.Endpoint{}, - ErrorMsg: emptyEndpointsErrMsg, + Name: "test_baz_80_foo-baz-sp", + Endpoints: bazEndpoints, + SessionPersistence: getExpectedSPConfig(), + StateFileKey: "test_baz_80", }, { - Name: "test_foo_80", - Endpoints: fooEndpoints, + Name: "test_empty-endpoints_80", + Endpoints: []resolver.Endpoint{}, + ErrorMsg: emptyEndpointsErrMsg, + SessionPersistence: SessionPersistenceConfig{}, + StateFileKey: "test_empty-endpoints_80", }, { - Name: "test_nil-endpoints_80", - Endpoints: nil, - ErrorMsg: nilEndpointsErrMsg, + Name: "test_foo_80_foo-bar-sp", + Endpoints: fooEndpoints, + SessionPersistence: getExpectedSPConfig(), + StateFileKey: "test_foo_80", }, { - Name: "test_ipv6-endpoints_80", - Endpoints: ipv6Endpoints, + Name: "test_foo_80_foo-baz-sp", + Endpoints: fooEndpoints, + SessionPersistence: getExpectedSPConfig(), + StateFileKey: "test_foo_80", }, { - Name: "test_policies_80", - Endpoints: policyEndpoints, - Policies: []policies.Policy{validPolicy1, validPolicy2}, + Name: "test_ipv6-endpoints_80", + Endpoints: ipv6Endpoints, + SessionPersistence: SessionPersistenceConfig{}, + StateFileKey: "test_ipv6-endpoints_80", + }, + { + Name: "test_nil-endpoints_80", + Endpoints: nil, + ErrorMsg: nilEndpointsErrMsg, + StateFileKey: "test_nil-endpoints_80", + }, + + { + Name: "test_policies_80_policies-sp", + Endpoints: policyEndpoints, + Policies: []policies.Policy{validPolicy1, validPolicy2}, + SessionPersistence: getExpectedSPConfig(), + StateFileKey: "test_policies_80", }, } diff --git a/internal/controller/state/dataplane/types.go b/internal/controller/state/dataplane/types.go index 9e1beeaa0f..4b7ff0927e 100644 --- a/internal/controller/state/dataplane/types.go +++ b/internal/controller/state/dataplane/types.go @@ -112,16 +112,40 @@ type Layer4VirtualServer struct { // Upstream is a pool of endpoints to be load balanced. type Upstream struct { + // SessionPersistence holds the session persistence configuration for the upstream. + SessionPersistence SessionPersistenceConfig // Name is the name of the Upstream. Will be unique for each service/port combination. Name string // ErrorMsg contains the error message if the Upstream is invalid. ErrorMsg string + // StateFileKey is the key for naming the state file for NGINX Plus upstreams. + StateFileKey string // Endpoints are the endpoints of the Upstream. Endpoints []resolver.Endpoint // Policies holds all the valid policies that apply to the Upstream. Policies []policies.Policy } +// SessionPersistenceConfig holds the session persistence configuration for an upstream. +type SessionPersistenceConfig struct { + // SessionType is the type of session persistence. + SessionType SessionPersistenceType + // Name is the name of the session. + Name string + // Expiry is the expiration time of the session. + Expiry string + // Path is the path for which session is applied. + Path string +} + +// SessionPersistenceType is the type of session persistence. +type SessionPersistenceType string + +const ( + // CookieBasedSessionPersistence indicates cookie-based session persistence. + CookieBasedSessionPersistence SessionPersistenceType = "cookie" +) + // SSL is the SSL configuration for a server. type SSL struct { // KeyPairID is the ID of the corresponding SSLKeyPair for the server. diff --git a/internal/controller/state/graph/backend_refs.go b/internal/controller/state/graph/backend_refs.go index fbfbdec550..4e8622a476 100644 --- a/internal/controller/state/graph/backend_refs.go +++ b/internal/controller/state/graph/backend_refs.go @@ -34,6 +34,8 @@ type BackendRef struct { // condition. Certain NginxProxy configurations may result in a backend not being valid for some Gateways, // but not others. InvalidForGateways map[types.NamespacedName]conditions.Condition + // SessionPersistence is the SessionPersistenceConfig of the backendRef. + SessionPersistence *SessionPersistenceConfig // SvcNsName is the NamespacedName of the Service referenced by the backendRef. SvcNsName types.NamespacedName // ServicePort is the ServicePort of the Service which is referenced by the backendRef. @@ -49,12 +51,24 @@ type BackendRef struct { IsInferencePool bool } -// ServicePortReference returns a string representation for the service and port that is referenced by the BackendRef. +// BaseServicePortKey returns a base unique string key for the Service port of the BackendRef. +func (b BackendRef) BaseServicePortKey() string { + return fmt.Sprintf("%s_%s_%d", b.SvcNsName.Namespace, b.SvcNsName.Name, b.ServicePort.Port) +} + +// ServicePortReference returns a unique string reference for the Service port of the BackendRef including +// session persistence index if applicable. func (b BackendRef) ServicePortReference() string { if !b.Valid { return "" } - return fmt.Sprintf("%s_%s_%d", b.SvcNsName.Namespace, b.SvcNsName.Name, b.ServicePort.Port) + + base := b.BaseServicePortKey() + if b.SessionPersistence == nil || b.SessionPersistence.Name == "" { + return base + } + + return fmt.Sprintf("%s_%s", base, b.SessionPersistence.Idx) } func addBackendRefsToRouteRules( @@ -313,6 +327,7 @@ func createBackendRef( IsInferencePool: ref.IsInferencePool, InvalidForGateways: invalidForGateways, EndpointPickerConfig: ref.EndpointPickerConfig, + SessionPersistence: ref.SessionPersistence, } return backendRef, conds diff --git a/internal/controller/state/graph/backend_refs_test.go b/internal/controller/state/graph/backend_refs_test.go index ab9da6c39f..c5f2461458 100644 --- a/internal/controller/state/graph/backend_refs_test.go +++ b/internal/controller/state/graph/backend_refs_test.go @@ -1413,6 +1413,14 @@ func TestCreateBackend(t *testing.T) { }, } + expectedSPConfig := SessionPersistenceConfig{ + Name: "test-persistence", + Idx: "test-persistence-idx", + Valid: true, + Expiry: "10m", + Path: "/test-path", + } + tests := []struct { nginxProxySpec *EffectiveNginxProxy name string @@ -1431,8 +1439,9 @@ func TestCreateBackend(t *testing.T) { Weight: 5, Valid: true, InvalidForGateways: map[types.NamespacedName]conditions.Condition{}, + SessionPersistence: &expectedSPConfig, }, - expectedServicePortReference: "test_service1_80", + expectedServicePortReference: "test_service1_80_test-persistence-idx", expectedConditions: nil, name: "normal case", }, @@ -1449,8 +1458,9 @@ func TestCreateBackend(t *testing.T) { Weight: 1, Valid: true, InvalidForGateways: map[types.NamespacedName]conditions.Condition{}, + SessionPersistence: &expectedSPConfig, }, - expectedServicePortReference: "test_service1_80", + expectedServicePortReference: "test_service1_80_test-persistence-idx", expectedConditions: nil, name: "normal with nil weight", }, @@ -1534,8 +1544,9 @@ func TestCreateBackend(t *testing.T) { `The Service configured with IPv4 family but NginxProxy is configured with IPv6`, ), }, + SessionPersistence: &expectedSPConfig, }, - expectedServicePortReference: "test_service1_80", + expectedServicePortReference: "test_service1_80_test-persistence-idx", nginxProxySpec: &EffectiveNginxProxy{IPFamily: helpers.GetPointer(ngfAPIv1alpha2.IPv6)}, expectedConditions: nil, name: "service IPFamily doesn't match NginxProxy IPFamily", @@ -1554,8 +1565,9 @@ func TestCreateBackend(t *testing.T) { Valid: true, BackendTLSPolicy: &btp, InvalidForGateways: map[types.NamespacedName]conditions.Condition{}, + SessionPersistence: &expectedSPConfig, }, - expectedServicePortReference: "test_service2_80", + expectedServicePortReference: "test_service2_80_test-persistence-idx", expectedConditions: nil, name: "normal case with policy", }, @@ -1598,8 +1610,9 @@ func TestCreateBackend(t *testing.T) { "ExternalName service requires DNS resolver configuration in Gateway's NginxProxy", ), }, + SessionPersistence: &expectedSPConfig, }, - expectedServicePortReference: "test_external-service_80", + expectedServicePortReference: "test_external-service_80_test-persistence-idx", expectedConditions: nil, name: "ExternalName service without DNS resolver", }, @@ -1623,8 +1636,9 @@ func TestCreateBackend(t *testing.T) { Weight: 5, Valid: true, InvalidForGateways: map[types.NamespacedName]conditions.Condition{}, + SessionPersistence: &expectedSPConfig, }, - expectedServicePortReference: "test_external-service_80", + expectedServicePortReference: "test_external-service_80_test-persistence-idx", expectedConditions: nil, name: "ExternalName service with DNS resolver", }, @@ -1645,8 +1659,9 @@ func TestCreateBackend(t *testing.T) { "ExternalName service requires DNS resolver configuration in Gateway's NginxProxy", ), }, + SessionPersistence: &expectedSPConfig, }, - expectedServicePortReference: "test_external-service_80", + expectedServicePortReference: "test_external-service_80_test-persistence-idx", expectedConditions: nil, name: "ExternalName service with multiple gateways - mixed DNS resolver config", }, @@ -1731,6 +1746,13 @@ func TestCreateBackend(t *testing.T) { IsInferencePool: false, BackendRef: test.ref.BackendRef, Filters: []any{}, + SessionPersistence: &SessionPersistenceConfig{ + Name: "test-persistence", + Idx: "test-persistence-idx", + Valid: true, + Expiry: "10m", + Path: "/test-path", + }, } route := &L7Route{ RouteType: RouteTypeHTTP, diff --git a/internal/controller/state/graph/graph.go b/internal/controller/state/graph/graph.go index ffd86ac6e0..c670ccc487 100644 --- a/internal/controller/state/graph/graph.go +++ b/internal/controller/state/graph/graph.go @@ -92,6 +92,14 @@ type NginxReloadResult struct { // ProtectedPorts are the ports that may not be configured by a listener with a descriptive name of each port. type ProtectedPorts map[int32]string +// FeatureFlags hold the feature flags for building the Graph. +type FeatureFlags struct { + // Plus indicates whether NGINX Plus features are enabled. + Plus bool + // Experimental indicates whether experimental features are enabled. + Experimental bool +} + // IsReferenced returns true if the Graph references the resource. func (g *Graph) IsReferenced(resourceType ngftypes.ObjectType, nsname types.NamespacedName) bool { switch obj := resourceType.(type) { @@ -208,7 +216,7 @@ func BuildGraph( plusSecrets map[types.NamespacedName][]PlusSecretFile, validators validation.Validators, logger logr.Logger, - experimentalEnabled bool, + featureFlags FeatureFlags, ) *Graph { processedGwClasses, gcExists := processGatewayClasses(state.GatewayClasses, gcName, controllerName) if gcExists && processedGwClasses.Winner == nil { @@ -228,7 +236,7 @@ func BuildGraph( processedGwClasses.Winner, processedNginxProxies, state.CRDMetadata, - experimentalEnabled, + featureFlags.Experimental, ) secretResolver := newSecretResolver(state.Secrets) @@ -242,7 +250,7 @@ func BuildGraph( gc, refGrantResolver, processedNginxProxies, - experimentalEnabled, + featureFlags.Experimental, ) processedBackendTLSPolicies := processBackendTLSPolicies( @@ -261,6 +269,7 @@ func BuildGraph( gws, processedSnippetsFilters, state.InferencePools, + featureFlags, ) referencedInferencePools := buildReferencedInferencePools(routes, gws, state.InferencePools, state.Services) diff --git a/internal/controller/state/graph/graph_test.go b/internal/controller/state/graph/graph_test.go index 60e11790d3..5f2d4e4369 100644 --- a/internal/controller/state/graph/graph_test.go +++ b/internal/controller/state/graph/graph_test.go @@ -165,7 +165,10 @@ func TestBuildGraph(t *testing.T) { }, } - createValidRuleWithBackendRefs := func(matches []gatewayv1.HTTPRouteMatch) RouteRule { + createValidRuleWithBackendRefs := func( + matches []gatewayv1.HTTPRouteMatch, + sessionPersistence *SessionPersistenceConfig, + ) RouteRule { refs := []BackendRef{ { SvcNsName: types.NamespacedName{Namespace: "service", Name: "foo"}, @@ -174,11 +177,13 @@ func TestBuildGraph(t *testing.T) { Weight: 1, BackendTLSPolicy: &btp, InvalidForGateways: map[types.NamespacedName]conditions.Condition{}, + SessionPersistence: sessionPersistence, }, } rbrs := []RouteBackendRef{ { - BackendRef: commonGWBackendRef, + BackendRef: commonGWBackendRef, + SessionPersistence: sessionPersistence, }, } return RouteRule{ @@ -196,8 +201,9 @@ func TestBuildGraph(t *testing.T) { createValidRuleWithBackendRefsAndFilters := func( matches []gatewayv1.HTTPRouteMatch, routeType RouteType, + sessionPersistence *SessionPersistenceConfig, ) RouteRule { - rule := createValidRuleWithBackendRefs(matches) + rule := createValidRuleWithBackendRefs(matches, sessionPersistence) rule.Filters = RouteRuleFilters{ Filters: []Filter{ { @@ -334,14 +340,23 @@ func TestBuildGraph(t *testing.T) { } } + spConfig := &gatewayv1.SessionPersistence{ + SessionName: helpers.GetPointer("session-persistence-httproute"), + Type: helpers.GetPointer(gatewayv1.CookieBasedSessionPersistence), + AbsoluteTimeout: helpers.GetPointer(gatewayv1.Duration("30m")), + CookieConfig: &gatewayv1.CookieConfig{ + LifetimeType: helpers.GetPointer(gatewayv1.PermanentCookieLifetimeType), + }, + } hr1 := createRoute("hr-1", "gateway-1", "listener-80-1") - addFilterToPath( + addElementsToPath( hr1, "/", gatewayv1.HTTPRouteFilter{ Type: gatewayv1.HTTPRouteFilterExtensionRef, ExtensionRef: refSnippetsFilterExtensionRef, }, + spConfig, ) hr2 := createRoute("hr-2", "wrong-gateway", "listener-80-1") @@ -382,6 +397,14 @@ func TestBuildGraph(t *testing.T) { ExtensionRef: refSnippetsFilterExtensionRef, }, }, + SessionPersistence: &gatewayv1.SessionPersistence{ + SessionName: helpers.GetPointer("session-persistence-grpcroute"), + Type: helpers.GetPointer(gatewayv1.CookieBasedSessionPersistence), + AbsoluteTimeout: helpers.GetPointer(gatewayv1.Duration("30m")), + CookieConfig: &gatewayv1.CookieConfig{ + LifetimeType: helpers.GetPointer(gatewayv1.PermanentCookieLifetimeType), + }, + }, }, }, }, @@ -859,6 +882,15 @@ func TestBuildGraph(t *testing.T) { } } + getExpectedSPConfig := &SessionPersistenceConfig{ + Name: "session-persistence-httproute", + SessionType: gatewayv1.CookieBasedSessionPersistence, + Expiry: "30m", + Valid: true, + Path: "/", + Idx: "hr-1_test_0", + } + routeHR1 := &L7Route{ RouteType: RouteTypeHTTP, Valid: true, @@ -886,7 +918,7 @@ func TestBuildGraph(t *testing.T) { }, Spec: L7RouteSpec{ Hostnames: hr1.Spec.Hostnames, - Rules: []RouteRule{createValidRuleWithBackendRefsAndFilters(routeMatches, RouteTypeHTTP)}, + Rules: []RouteRule{createValidRuleWithBackendRefsAndFilters(routeMatches, RouteTypeHTTP, getExpectedSPConfig)}, }, Policies: []*Policy{processedRoutePolicy}, Conditions: []conditions.Condition{ @@ -1050,6 +1082,13 @@ func TestBuildGraph(t *testing.T) { }, } + expectedSPgr := &SessionPersistenceConfig{ + Name: "session-persistence-grpcroute", + SessionType: gatewayv1.CookieBasedSessionPersistence, + Expiry: "30m", + Valid: true, + Idx: "gr_test_0", + } routeGR := &L7Route{ RouteType: RouteTypeGRPC, Valid: true, @@ -1078,7 +1117,7 @@ func TestBuildGraph(t *testing.T) { Spec: L7RouteSpec{ Hostnames: gr.Spec.Hostnames, Rules: []RouteRule{ - createValidRuleWithBackendRefsAndFilters(routeMatches, RouteTypeGRPC), + createValidRuleWithBackendRefsAndFilters(routeMatches, RouteTypeGRPC, expectedSPgr), }, }, } @@ -1110,7 +1149,7 @@ func TestBuildGraph(t *testing.T) { }, Spec: L7RouteSpec{ Hostnames: hr3.Spec.Hostnames, - Rules: []RouteRule{createValidRuleWithBackendRefs(routeMatches)}, + Rules: []RouteRule{createValidRuleWithBackendRefs(routeMatches, nil)}, }, } @@ -1449,15 +1488,16 @@ func TestBuildGraph(t *testing.T) { } tests := []struct { - store ClusterState - expected *Graph - name string - experimentalEnabled bool + store ClusterState + expected *Graph + name string + plus, experimentalEnabled bool }{ { store: createStateWithGatewayClass(normalGC), expected: createExpectedGraphWithGatewayClass(normalGC), experimentalEnabled: true, + plus: true, name: "normal case", }, { @@ -1476,6 +1516,12 @@ func TestBuildGraph(t *testing.T) { fakePolicyValidator := &validationfakes.FakePolicyValidator{} + createAllValidValidator := func() *validationfakes.FakeHTTPFieldsValidator { + v := &validationfakes.FakeHTTPFieldsValidator{} + v.ValidateDurationReturns("30m", nil) + return v + } + result := BuildGraph( test.store, controllerName, @@ -1489,12 +1535,15 @@ func TestBuildGraph(t *testing.T) { }, }, validation.Validators{ - HTTPFieldsValidator: &validationfakes.FakeHTTPFieldsValidator{}, + HTTPFieldsValidator: createAllValidValidator(), GenericValidator: &validationfakes.FakeGenericValidator{}, PolicyValidator: fakePolicyValidator, }, logr.Discard(), - test.experimentalEnabled, + FeatureFlags{ + Experimental: test.experimentalEnabled, + Plus: test.plus, + }, ) g.Expect(helpers.Diff(test.expected, result)).To(BeEmpty()) diff --git a/internal/controller/state/graph/grpcroute.go b/internal/controller/state/graph/grpcroute.go index 5d6e7489e3..4bfe3b9fdc 100644 --- a/internal/controller/state/graph/grpcroute.go +++ b/internal/controller/state/graph/grpcroute.go @@ -21,6 +21,7 @@ func buildGRPCRoute( ghr *v1.GRPCRoute, gws map[types.NamespacedName]*Gateway, snippetsFilters map[types.NamespacedName]*SnippetsFilter, + featureFlags FeatureFlags, ) *L7Route { r := &L7Route{ Source: ghr, @@ -52,10 +53,16 @@ func buildGRPCRoute( r.Spec.Hostnames = ghr.Spec.Hostnames r.Attachable = true + grpcRouteNsName := types.NamespacedName{ + Namespace: ghr.GetNamespace(), + Name: ghr.GetName(), + } rules, valid, conds := processGRPCRouteRules( ghr.Spec.Rules, validator, getSnippetsFilterResolverForNamespace(snippetsFilters, r.Source.GetNamespace()), + grpcRouteNsName, + featureFlags, ) r.Spec.Rules = rules @@ -71,6 +78,7 @@ func buildGRPCMirrorRoutes( route *v1.GRPCRoute, gateways map[types.NamespacedName]*Gateway, snippetsFilters map[types.NamespacedName]*SnippetsFilter, + featureFlags FeatureFlags, ) { for idx, rule := range l7route.Spec.Rules { if rule.Filters.Valid { @@ -107,6 +115,7 @@ func buildGRPCMirrorRoutes( tmpMirrorRoute, gateways, snippetsFilters, + featureFlags, ) if mirrorRoute != nil { @@ -157,15 +166,17 @@ func removeGRPCMirrorFilters(filters []v1.GRPCRouteFilter) []v1.GRPCRouteFilter func processGRPCRouteRule( specRule v1.GRPCRouteRule, - rulePath *field.Path, + ruleIdx int, validator validation.HTTPFieldsValidator, resolveExtRefFunc resolveExtRefFilter, + grpcRouteNsName types.NamespacedName, + featureFlags FeatureFlags, ) (RouteRule, routeRuleErrors) { - var errors routeRuleErrors - + rulePath := field.NewPath("spec").Child("rules").Index(ruleIdx) validMatches := true - unsupportedFieldsErrors := checkForUnsupportedGRPCFields(specRule, rulePath) + var errors routeRuleErrors + unsupportedFieldsErrors := checkForUnsupportedGRPCFields(specRule, rulePath, featureFlags) if len(unsupportedFieldsErrors) > 0 { errors.warn = append(errors.warn, unsupportedFieldsErrors...) } @@ -189,6 +200,26 @@ func processGRPCRouteRule( errors = errors.append(filterErrors) + var sp *SessionPersistenceConfig + if specRule.SessionPersistence != nil { + spConfig, spErrors := processSessionPersistenceConfig( + specRule.SessionPersistence, + specRule.Matches, + rulePath.Child("sessionPersistence"), + validator, + ) + errors = errors.append(spErrors) + + if spConfig != nil && spConfig.Valid { + spKey := getSessionPersistenceKey(ruleIdx, grpcRouteNsName) + spConfig.Idx = spKey + if spConfig.Name == "" { + spConfig.Name = fmt.Sprintf("sp_%s", spKey) + } + sp = spConfig + } + } + backendRefs := make([]RouteBackendRef, 0, len(specRule.BackendRefs)) // rule.BackendRefs are validated separately because of their special requirements @@ -201,8 +232,9 @@ func processGRPCRouteRule( } } rbr := RouteBackendRef{ - BackendRef: b.BackendRef, - Filters: interfaceFilters, + BackendRef: b.BackendRef, + Filters: interfaceFilters, + SessionPersistence: sp, } backendRefs = append(backendRefs, rbr) } @@ -235,6 +267,8 @@ func processGRPCRouteRules( specRules []v1.GRPCRouteRule, validator validation.HTTPFieldsValidator, resolveExtRefFunc resolveExtRefFilter, + grpcRouteNsName types.NamespacedName, + featureFlags FeatureFlags, ) (rules []RouteRule, valid bool, conds []conditions.Condition) { rules = make([]RouteRule, len(specRules)) @@ -243,14 +277,14 @@ func processGRPCRouteRules( atLeastOneValid bool ) - for i, rule := range specRules { - rulePath := field.NewPath("spec").Child("rules").Index(i) - + for ruleIdx, rule := range specRules { rr, errors := processGRPCRouteRule( rule, - rulePath, + ruleIdx, validator, resolveExtRefFunc, + grpcRouteNsName, + featureFlags, ) if rr.ValidMatches && rr.Filters.Valid { @@ -259,7 +293,7 @@ func processGRPCRouteRules( allRulesErrors = allRulesErrors.append(errors) - rules[i] = rr + rules[ruleIdx] = rr } conds = make([]conditions.Condition, 0, 2) @@ -455,7 +489,11 @@ func validateGRPCHeaderMatch( return allErrs } -func checkForUnsupportedGRPCFields(rule v1.GRPCRouteRule, rulePath *field.Path) field.ErrorList { +func checkForUnsupportedGRPCFields( + rule v1.GRPCRouteRule, + rulePath *field.Path, + featureFlags FeatureFlags, +) field.ErrorList { var ruleErrors field.ErrorList if rule.Name != nil { @@ -464,10 +502,21 @@ func checkForUnsupportedGRPCFields(rule v1.GRPCRouteRule, rulePath *field.Path) "Name", )) } - if rule.SessionPersistence != nil { + + if !featureFlags.Plus && rule.SessionPersistence != nil { + ruleErrors = append(ruleErrors, field.Forbidden( + rulePath.Child("sessionPersistence"), + fmt.Sprintf( + "%s OSS users can use `ip_hash` load balancing method via the UpstreamSettingsPolicy for session affinity.", + spErrMsg, + ), + )) + } + + if !featureFlags.Experimental && rule.SessionPersistence != nil { ruleErrors = append(ruleErrors, field.Forbidden( rulePath.Child("sessionPersistence"), - "SessionPersistence", + spErrMsg, )) } diff --git a/internal/controller/state/graph/grpcroute_test.go b/internal/controller/state/graph/grpcroute_test.go index 032b6a891f..2d0e493551 100644 --- a/internal/controller/state/graph/grpcroute_test.go +++ b/internal/controller/state/graph/grpcroute_test.go @@ -2,6 +2,7 @@ package graph import ( "errors" + "fmt" "testing" . "github.com/onsi/gomega" @@ -20,38 +21,56 @@ import ( "github.com/nginx/nginx-gateway-fabric/v2/internal/framework/kinds" ) -func createGRPCMethodMatch(serviceName, methodName, methodType string) v1.GRPCRouteRule { +func createGRPCMethodMatch( + serviceName, + methodName, + methodType string, + sp *v1.SessionPersistence, + backendRef []v1.GRPCBackendRef, +) v1.GRPCRouteRule { var mt *v1.GRPCMethodMatchType if methodType != "nilType" { mt = (*v1.GRPCMethodMatchType)(&methodType) } - return v1.GRPCRouteRule{ - Matches: []v1.GRPCRouteMatch{ - { - Method: &v1.GRPCMethodMatch{ - Type: mt, - Service: &serviceName, - Method: &methodName, - }, + matches := []v1.GRPCRouteMatch{ + { + Method: &v1.GRPCMethodMatch{ + Type: mt, + Service: &serviceName, + Method: &methodName, }, }, } + + if sp != nil { + return v1.GRPCRouteRule{ + Matches: matches, + SessionPersistence: sp, + BackendRefs: backendRef, + } + } + + return v1.GRPCRouteRule{ + Matches: matches, + } } func createGRPCHeadersMatch(headerType, headerName, headerValue string) v1.GRPCRouteRule { - return v1.GRPCRouteRule{ - Matches: []v1.GRPCRouteMatch{ - { - Headers: []v1.GRPCHeaderMatch{ - { - Type: (*v1.GRPCHeaderMatchType)(&headerType), - Name: v1.GRPCHeaderName(headerName), - Value: headerValue, - }, + matches := []v1.GRPCRouteMatch{ + { + Headers: []v1.GRPCHeaderMatch{ + { + Type: (*v1.GRPCHeaderMatchType)(&headerType), + Name: v1.GRPCHeaderName(headerName), + Value: headerValue, }, }, }, } + + return v1.GRPCRouteRule{ + Matches: matches, + } } func createGRPCRoute( @@ -114,11 +133,32 @@ func TestBuildGRPCRoutes(t *testing.T) { RequestHeaderModifier: &v1.HTTPHeaderFilter{}, } - grRuleWithFilters := v1.GRPCRouteRule{ - Filters: []v1.GRPCRouteFilter{snippetsFilterRef, requestHeaderFilter}, + unNamedSPConfig := v1.SessionPersistence{ + AbsoluteTimeout: helpers.GetPointer(v1.Duration("10m")), + Type: helpers.GetPointer(v1.CookieBasedSessionPersistence), + CookieConfig: &v1.CookieConfig{ + LifetimeType: helpers.GetPointer((v1.PermanentCookieLifetimeType)), + }, } - gr := createGRPCRoute("gr-1", gwNsName.Name, "example.com", []v1.GRPCRouteRule{grRuleWithFilters}) + grpcBackendRef := v1.GRPCBackendRef{ + BackendRef: v1.BackendRef{ + BackendObjectReference: v1.BackendObjectReference{ + Kind: helpers.GetPointer[v1.Kind]("Service"), + Name: "grpc-service", + Namespace: helpers.GetPointer[v1.Namespace]("test"), + Port: helpers.GetPointer[v1.PortNumber](80), + }, + }, + } + + grRuleWithFiltersAndSessionPersistence := v1.GRPCRouteRule{ + Filters: []v1.GRPCRouteFilter{snippetsFilterRef, requestHeaderFilter}, + SessionPersistence: &unNamedSPConfig, + BackendRefs: []v1.GRPCBackendRef{grpcBackendRef}, + } + + gr := createGRPCRoute("gr-1", gwNsName.Name, "example.com", []v1.GRPCRouteRule{grRuleWithFiltersAndSessionPersistence}) grWrongGateway := createGRPCRoute("gr-2", "some-gateway", "example.com", []v1.GRPCRouteRule{}) @@ -193,8 +233,21 @@ func TestBuildGRPCRoutes(t *testing.T) { }, }, }, - ValidMatches: true, - RouteBackendRefs: []RouteBackendRef{}, + ValidMatches: true, + RouteBackendRefs: []RouteBackendRef{ + { + BackendRef: v1.BackendRef{ + BackendObjectReference: grpcBackendRef.BackendObjectReference, + }, + SessionPersistence: &SessionPersistenceConfig{ + Valid: true, + Name: "sp_gr-1_test_0", + SessionType: *unNamedSPConfig.Type, + Expiry: "10m", + Idx: "gr-1_test_0", + }, + }, + }, }, }, }, @@ -209,7 +262,11 @@ func TestBuildGRPCRoutes(t *testing.T) { }, } - validator := &validationfakes.FakeHTTPFieldsValidator{} + createAllValidValidator := func() *validationfakes.FakeHTTPFieldsValidator { + v := &validationfakes.FakeHTTPFieldsValidator{} + v.ValidateDurationReturns("10m", nil) + return v + } for _, test := range tests { t.Run(test.name, func(t *testing.T) { @@ -225,14 +282,17 @@ func TestBuildGRPCRoutes(t *testing.T) { }, }, } - routes := buildRoutesForGateways( - validator, + createAllValidValidator(), map[types.NamespacedName]*v1.HTTPRoute{}, grRoutes, test.gateways, snippetsFilters, nil, + FeatureFlags{ + Plus: true, + Experimental: true, + }, ) g.Expect(helpers.Diff(test.expected, routes)).To(BeEmpty()) }) @@ -255,13 +315,40 @@ func TestBuildGRPCRoute(t *testing.T) { }, } gatewayNsName := client.ObjectKeyFromObject(gw.Source) + backendRef := v1.BackendRef{ + BackendObjectReference: v1.BackendObjectReference{ + Kind: helpers.GetPointer[v1.Kind]("Service"), + Name: "service1", + Namespace: helpers.GetPointer[v1.Namespace]("test"), + Port: helpers.GetPointer[v1.PortNumber](80), + }, + } + + grpcBackendRef := v1.GRPCBackendRef{ + BackendRef: backendRef, + } - methodMatchRule := createGRPCMethodMatch("myService", "myMethod", "Exact") + spMethod := &v1.SessionPersistence{ + SessionName: helpers.GetPointer("grpc-method-session"), + AbsoluteTimeout: helpers.GetPointer(v1.Duration("10h")), + Type: helpers.GetPointer(v1.CookieBasedSessionPersistence), + CookieConfig: &v1.CookieConfig{ + LifetimeType: helpers.GetPointer((v1.PermanentCookieLifetimeType)), + }, + } + + methodMatchRule := createGRPCMethodMatch( + "myService", + "myMethod", + "Exact", + spMethod, + []v1.GRPCBackendRef{grpcBackendRef}, + ) headersMatchRule := createGRPCHeadersMatch("Exact", "MyHeader", "SomeValue") - methodMatchEmptyFields := createGRPCMethodMatch("", "", "") - methodMatchInvalidFields := createGRPCMethodMatch("service{}", "method{}", "Exact") - methodMatchNilType := createGRPCMethodMatch("myService", "myMethod", "nilType") + methodMatchEmptyFields := createGRPCMethodMatch("", "", "", nil, nil) + methodMatchInvalidFields := createGRPCMethodMatch("service{}", "method{}", "Exact", nil, nil) + methodMatchNilType := createGRPCMethodMatch("myService", "myMethod", "nilType", nil, nil) headersMatchInvalid := createGRPCHeadersMatch("", "MyHeader", "SomeValue") headersMatchEmptyType := v1.GRPCRouteRule{ @@ -284,19 +371,6 @@ func TestBuildGRPCRoute(t *testing.T) { []v1.GRPCRouteRule{methodMatchRule, headersMatchRule}, ) - backendRef := v1.BackendRef{ - BackendObjectReference: v1.BackendObjectReference{ - Kind: helpers.GetPointer[v1.Kind]("Service"), - Name: "service1", - Namespace: helpers.GetPointer[v1.Namespace]("test"), - Port: helpers.GetPointer[v1.PortNumber](80), - }, - } - - grpcBackendRef := v1.GRPCBackendRef{ - BackendRef: backendRef, - } - grEmptyMatch := createGRPCRoute( "gr-1", gatewayNsName.Name, @@ -396,7 +470,7 @@ func TestBuildGRPCRoute(t *testing.T) { grDuplicateSectionName.Spec.ParentRefs[0], ) - grInvalidFilterRule := createGRPCMethodMatch("myService", "myMethod", "Exact") + grInvalidFilterRule := createGRPCMethodMatch("myService", "myMethod", "Exact", nil, nil) grInvalidFilterRule.Filters = []v1.GRPCRouteFilter{ { @@ -411,7 +485,7 @@ func TestBuildGRPCRoute(t *testing.T) { []v1.GRPCRouteRule{grInvalidFilterRule}, ) - grValidFilterRule := createGRPCMethodMatch("myService", "myMethod", "Exact") + grValidFilterRule := createGRPCMethodMatch("myService", "myMethod", "Exact", nil, nil) grValidHeaderMatch := createGRPCHeadersMatch("RegularExpression", "MyHeader", "headers-[a-z]+") validSnippetsFilterRef := &v1.LocalObjectReference{ Group: ngfAPIv1alpha1.GroupName, @@ -451,7 +525,7 @@ func TestBuildGRPCRoute(t *testing.T) { ) // route with invalid snippets filter extension ref - grInvalidSnippetsFilterRule := createGRPCMethodMatch("myService", "myMethod", "Exact") + grInvalidSnippetsFilterRule := createGRPCMethodMatch("myService", "myMethod", "Exact", nil, nil) grInvalidSnippetsFilterRule.Filters = []v1.GRPCRouteFilter{ { Type: v1.GRPCRouteFilterExtensionRef, @@ -470,7 +544,7 @@ func TestBuildGRPCRoute(t *testing.T) { ) // route with unresolvable snippets filter extension ref - grUnresolvableSnippetsFilterRule := createGRPCMethodMatch("myService", "myMethod", "Exact") + grUnresolvableSnippetsFilterRule := createGRPCMethodMatch("myService", "myMethod", "Exact", nil, nil) grUnresolvableSnippetsFilterRule.Filters = []v1.GRPCRouteFilter{ { Type: v1.GRPCRouteFilterExtensionRef, @@ -489,7 +563,7 @@ func TestBuildGRPCRoute(t *testing.T) { ) // route with two invalid snippets filter extensions refs: (1) invalid group (2) unresolvable - grInvalidAndUnresolvableSnippetsFilterRule := createGRPCMethodMatch("myService", "myMethod", "Exact") + grInvalidAndUnresolvableSnippetsFilterRule := createGRPCMethodMatch("myService", "myMethod", "Exact", nil, nil) grInvalidAndUnresolvableSnippetsFilterRule.Filters = []v1.GRPCRouteFilter{ { Type: v1.GRPCRouteFilterExtensionRef, @@ -521,6 +595,17 @@ func TestBuildGRPCRoute(t *testing.T) { return v } + createDurationValidator := func(duration *v1.Duration) *validationfakes.FakeHTTPFieldsValidator { + v := &validationfakes.FakeHTTPFieldsValidator{} + + if duration == nil { + v.ValidateDurationReturns("", nil) + } else { + v.ValidateDurationReturns(string(*duration), nil) + } + return v + } + routeFilters := []Filter{ { RouteType: RouteTypeGRPC, @@ -552,14 +637,16 @@ func TestBuildGRPCRoute(t *testing.T) { }, } + durationSP := v1.Duration("10h") tests := []struct { - validator *validationfakes.FakeHTTPFieldsValidator - gr *v1.GRPCRoute - expected *L7Route - name string + validator *validationfakes.FakeHTTPFieldsValidator + gr *v1.GRPCRoute + expected *L7Route + name string + plus, experimental bool }{ { - validator: createAllValidValidator(), + validator: createDurationValidator(&durationSP), gr: grBoth, expected: &L7Route{ RouteType: RouteTypeGRPC, @@ -582,8 +669,19 @@ func TestBuildGRPCRoute(t *testing.T) { Valid: true, Filters: []Filter{}, }, - Matches: ConvertGRPCMatches(grBoth.Spec.Rules[0].Matches), - RouteBackendRefs: []RouteBackendRef{}, + Matches: ConvertGRPCMatches(grBoth.Spec.Rules[0].Matches), + RouteBackendRefs: []RouteBackendRef{ + { + BackendRef: grpcBackendRef.BackendRef, + SessionPersistence: &SessionPersistenceConfig{ + Valid: true, + Name: "grpc-method-session", + SessionType: v1.CookieBasedSessionPersistence, + Expiry: "10h", + Idx: "gr-1_test_0", + }, + }, + }, }, { ValidMatches: true, @@ -597,7 +695,9 @@ func TestBuildGRPCRoute(t *testing.T) { }, }, }, - name: "normal case with both", + plus: true, + experimental: true, + name: "normal case with both", }, { validator: createAllValidValidator(), @@ -765,7 +865,7 @@ func TestBuildGRPCRoute(t *testing.T) { name: "invalid route with duplicate sectionName", }, { - validator: createAllValidValidator(), + validator: createDurationValidator(&durationSP), gr: grOneInvalid, expected: &L7Route{ Source: grOneInvalid, @@ -793,8 +893,19 @@ func TestBuildGRPCRoute(t *testing.T) { Valid: true, Filters: []Filter{}, }, - Matches: ConvertGRPCMatches(grOneInvalid.Spec.Rules[0].Matches), - RouteBackendRefs: []RouteBackendRef{}, + Matches: ConvertGRPCMatches(grOneInvalid.Spec.Rules[0].Matches), + RouteBackendRefs: []RouteBackendRef{ + { + BackendRef: grpcBackendRef.BackendRef, + SessionPersistence: &SessionPersistenceConfig{ + Valid: true, + Name: "grpc-method-session", + SessionType: v1.CookieBasedSessionPersistence, + Expiry: "10h", + Idx: "gr-1_test_0", + }, + }, + }, }, { ValidMatches: false, @@ -808,7 +919,9 @@ func TestBuildGRPCRoute(t *testing.T) { }, }, }, - name: "invalid headers and valid method", + plus: true, + experimental: true, + name: "invalid headers and valid method", }, { validator: createAllValidValidator(), @@ -1199,7 +1312,16 @@ func TestBuildGRPCRoute(t *testing.T) { snippetsFilters := map[types.NamespacedName]*SnippetsFilter{ {Namespace: "test", Name: "sf"}: {Valid: true}, } - route := buildGRPCRoute(test.validator, test.gr, gws, snippetsFilters) + route := buildGRPCRoute( + test.validator, + test.gr, + gws, + snippetsFilters, + FeatureFlags{ + Plus: test.plus, + Experimental: test.experimental, + }, + ) g.Expect(helpers.Diff(test.expected, route)).To(BeEmpty()) }) } @@ -1351,11 +1473,21 @@ func TestBuildGRPCRouteWithMirrorRoutes(t *testing.T) { g := NewWithT(t) + featureFlags := FeatureFlags{ + Plus: false, + Experimental: false, + } + routes := map[RouteKey]*L7Route{} - l7route := buildGRPCRoute(validator, gr, gateways, snippetsFilters) + l7route := buildGRPCRoute( + validator, + gr, + gateways, + snippetsFilters, + featureFlags, + ) g.Expect(l7route).NotTo(BeNil()) - - buildGRPCMirrorRoutes(routes, l7route, gr, gateways, snippetsFilters) + buildGRPCMirrorRoutes(routes, l7route, gr, gateways, snippetsFilters, featureFlags) obj, ok := expectedMirrorRoute.Source.(*v1.GRPCRoute) g.Expect(ok).To(BeTrue()) @@ -1366,7 +1498,7 @@ func TestBuildGRPCRouteWithMirrorRoutes(t *testing.T) { func TestConvertGRPCMatches(t *testing.T) { t.Parallel() - methodMatch := createGRPCMethodMatch("myService", "myMethod", "Exact").Matches + methodMatch := createGRPCMethodMatch("myService", "myMethod", "Exact", nil, nil).Matches headersMatch := createGRPCHeadersMatch("Exact", "MyHeader", "SomeValue").Matches @@ -1523,7 +1655,7 @@ func TestProcessGRPCRouteRule_UnsupportedFields(t *testing.T) { Type: helpers.GetPointer(v1.SessionPersistenceType("unsupported-session-persistence")), }), }, - expectedErrors: 2, + expectedErrors: 3, }, } @@ -1536,7 +1668,14 @@ func TestProcessGRPCRouteRule_UnsupportedFields(t *testing.T) { var errors routeRuleErrors // Wrap the rule in GRPCRouteRuleWrapper - unsupportedFieldsErrors := checkForUnsupportedGRPCFields(test.specRule, rulePath) + unsupportedFieldsErrors := checkForUnsupportedGRPCFields( + test.specRule, + rulePath, + FeatureFlags{ + Plus: false, + Experimental: false, + }, + ) if len(unsupportedFieldsErrors) > 0 { errors.warn = append(errors.warn, unsupportedFieldsErrors...) } @@ -1555,6 +1694,8 @@ func TestProcessGRPCRouteRules_UnsupportedFields(t *testing.T) { expectedConds []conditions.Condition expectedWarns int expectedValid bool + plusEnabled bool + experimental bool }{ { name: "No unsupported fields", @@ -1582,17 +1723,58 @@ func TestProcessGRPCRouteRules_UnsupportedFields(t *testing.T) { { Name: helpers.GetPointer[v1.SectionName]("unsupported-name"), SessionPersistence: helpers.GetPointer(v1.SessionPersistence{ - Type: helpers.GetPointer(v1.SessionPersistenceType("unsupported-session-persistence")), + Type: helpers.GetPointer(v1.CookieBasedSessionPersistence), + SessionName: helpers.GetPointer("session_id"), }), }, }, expectedValid: true, expectedConds: []conditions.Condition{ - conditions.NewRouteAcceptedUnsupportedField("[spec.rules[0].name: Forbidden: Name, " + - "spec.rules[0].sessionPersistence: Forbidden: SessionPersistence]"), + conditions.NewRouteAcceptedUnsupportedField(fmt.Sprintf("[spec.rules[0].name: Forbidden: Name, "+ + "spec.rules[0].sessionPersistence: Forbidden: "+ + "%s"+ + " OSS users can use `ip_hash` load balancing method via the UpstreamSettingsPolicy for session affinity.]", + spErrMsg, + )), }, + experimental: true, + plusEnabled: false, expectedWarns: 2, }, + { + name: "Session persistence unsupported with experimental disabled", + specRules: []v1.GRPCRouteRule{ + { + SessionPersistence: helpers.GetPointer(v1.SessionPersistence{ + Type: helpers.GetPointer(v1.CookieBasedSessionPersistence), + SessionName: helpers.GetPointer("session_id"), + }), + }, + }, + expectedValid: true, + expectedConds: []conditions.Condition{ + conditions.NewRouteAcceptedUnsupportedField(fmt.Sprintf("spec.rules[0].sessionPersistence: Forbidden: "+ + "%s", spErrMsg)), + }, + expectedWarns: 1, + plusEnabled: true, + experimental: false, + }, + { + name: "Session Persistence supported with Plus enabled and experimental enabled", + specRules: []v1.GRPCRouteRule{ + { + SessionPersistence: helpers.GetPointer(v1.SessionPersistence{ + Type: helpers.GetPointer(v1.CookieBasedSessionPersistence), + SessionName: helpers.GetPointer("session_id"), + }), + }, + }, + expectedValid: true, + plusEnabled: true, + experimental: true, + expectedWarns: 0, + }, } for _, test := range tests { @@ -1600,10 +1782,20 @@ func TestProcessGRPCRouteRules_UnsupportedFields(t *testing.T) { t.Parallel() g := NewWithT(t) + grpcRouteNsName := types.NamespacedName{ + Namespace: "test", + Name: "grpc-route", + } + _, valid, conds := processGRPCRouteRules( test.specRules, validation.SkipValidator{}, nil, + grpcRouteNsName, + FeatureFlags{ + Plus: test.plusEnabled, + Experimental: test.experimental, + }, ) g.Expect(valid).To(Equal(test.expectedValid)) diff --git a/internal/controller/state/graph/httproute.go b/internal/controller/state/graph/httproute.go index 1a4749b984..29a044d17d 100644 --- a/internal/controller/state/graph/httproute.go +++ b/internal/controller/state/graph/httproute.go @@ -31,6 +31,7 @@ func buildHTTPRoute( gws map[types.NamespacedName]*Gateway, snippetsFilters map[types.NamespacedName]*SnippetsFilter, inferencePools map[types.NamespacedName]*inference.InferencePool, + featureFlags FeatureFlags, ) *L7Route { r := &L7Route{ Source: ghr, @@ -63,12 +64,17 @@ func buildHTTPRoute( r.Spec.Hostnames = ghr.Spec.Hostnames r.Attachable = true + nsName := types.NamespacedName{ + Name: ghr.GetName(), + Namespace: ghr.GetNamespace(), + } rules, valid, conds := processHTTPRouteRules( ghr.Spec.Rules, validator, getSnippetsFilterResolverForNamespace(snippetsFilters, r.Source.GetNamespace()), inferencePools, - r.Source.GetNamespace(), + nsName, + featureFlags, ) r.Spec.Rules = rules @@ -84,6 +90,7 @@ func buildHTTPMirrorRoutes( route *v1.HTTPRoute, gateways map[types.NamespacedName]*Gateway, snippetsFilters map[types.NamespacedName]*SnippetsFilter, + featureFlags FeatureFlags, ) { for idx, rule := range l7route.Spec.Rules { if rule.Filters.Valid { @@ -121,6 +128,7 @@ func buildHTTPMirrorRoutes( gateways, snippetsFilters, nil, + featureFlags, ) if mirrorRoute != nil { @@ -171,15 +179,17 @@ func removeHTTPMirrorFilters(filters []v1.HTTPRouteFilter) []v1.HTTPRouteFilter func processHTTPRouteRule( specRule v1.HTTPRouteRule, - rulePath *field.Path, + ruleIdx int, validator validation.HTTPFieldsValidator, resolveExtRefFunc resolveExtRefFilter, inferencePools map[types.NamespacedName]*inference.InferencePool, - routeNamespace string, + routeNsName types.NamespacedName, + featureFlags FeatureFlags, ) (RouteRule, routeRuleErrors) { - var errors routeRuleErrors + rulePath := field.NewPath("spec").Child("rules").Index(ruleIdx) - unsupportedFieldsErrors := checkForUnsupportedHTTPFields(specRule, rulePath) + var errors routeRuleErrors + unsupportedFieldsErrors := checkForUnsupportedHTTPFields(specRule, rulePath, featureFlags) if len(unsupportedFieldsErrors) > 0 { errors.warn = append(errors.warn, unsupportedFieldsErrors...) } @@ -202,28 +212,77 @@ func processHTTPRouteRule( validator, resolveExtRefFunc, ) - errors = errors.append(filterErrors) - backendRefs := make([]RouteBackendRef, 0, len(specRule.BackendRefs)) + var sp *SessionPersistenceConfig + if specRule.SessionPersistence != nil { + spConfig, spErrors := processSessionPersistenceConfig( + specRule.SessionPersistence, + specRule.Matches, + rulePath.Child("sessionPersistence"), + validator, + ) + errors = errors.append(spErrors) + + if spConfig != nil && spConfig.Valid { + spKey := getSessionPersistenceKey(ruleIdx, routeNsName) + spConfig.Idx = spKey + if spConfig.Name == "" { + spConfig.Name = fmt.Sprintf("sp_%s", spKey) + } + sp = spConfig + } + } + + backendRefs, backendRefErrors := getBackendRefs(specRule, routeNsName.Namespace, inferencePools, rulePath, sp) + errors = errors.append(backendRefErrors) + + if routeFilters.Valid { + for i, filter := range routeFilters.Filters { + if filter.RequestMirror == nil { + continue + } + + rbr := RouteBackendRef{ + BackendRef: v1.BackendRef{ + BackendObjectReference: filter.RequestMirror.BackendRef, + }, + MirrorBackendIdx: helpers.GetPointer(i), + } + backendRefs = append(backendRefs, rbr) + } + } + + return RouteRule{ + ValidMatches: validMatches, + Matches: specRule.Matches, + Filters: routeFilters, + RouteBackendRefs: backendRefs, + }, errors +} + +func getBackendRefs( + routeRule v1.HTTPRouteRule, + routeNamespace string, + inferencePools map[types.NamespacedName]*inference.InferencePool, + rulePath *field.Path, + sp *SessionPersistenceConfig, +) ([]RouteBackendRef, routeRuleErrors) { + var errors routeRuleErrors + backendRefs := make([]RouteBackendRef, 0, len(routeRule.BackendRefs)) - if checkForMixedBackendTypes(specRule, routeNamespace, inferencePools) { + if checkForMixedBackendTypes(routeRule, routeNamespace, inferencePools) { err := field.Forbidden( rulePath.Child("backendRefs"), "mixing InferencePool and non-InferencePool backends in a rule is not supported", ) errors.invalid = append(errors.invalid, err) - return RouteRule{ - ValidMatches: validMatches, - Matches: specRule.Matches, - Filters: routeFilters, - RouteBackendRefs: backendRefs, - }, errors + return backendRefs, errors } // rule.BackendRefs are validated separately because of their special requirements - for _, b := range specRule.BackendRefs { + for _, b := range routeRule.BackendRefs { var interfaceFilters []any if len(b.Filters) > 0 { interfaceFilters = make([]any, 0, len(b.Filters)) @@ -233,7 +292,8 @@ func processHTTPRouteRule( } rbr := RouteBackendRef{ - BackendRef: b.BackendRef, + BackendRef: b.BackendRef, + SessionPersistence: sp, } // If route specifies an InferencePool backend, we need to convert it to its associated @@ -260,28 +320,7 @@ func processHTTPRouteRule( backendRefs = append(backendRefs, rbr) } - if routeFilters.Valid { - for i, filter := range routeFilters.Filters { - if filter.RequestMirror == nil { - continue - } - - rbr := RouteBackendRef{ - BackendRef: v1.BackendRef{ - BackendObjectReference: filter.RequestMirror.BackendRef, - }, - MirrorBackendIdx: helpers.GetPointer(i), - } - backendRefs = append(backendRefs, rbr) - } - } - - return RouteRule{ - ValidMatches: validMatches, - Matches: specRule.Matches, - Filters: routeFilters, - RouteBackendRefs: backendRefs, - }, errors + return backendRefs, errors } func processHTTPRouteRules( @@ -289,7 +328,8 @@ func processHTTPRouteRules( validator validation.HTTPFieldsValidator, resolveExtRefFunc resolveExtRefFilter, inferencePools map[types.NamespacedName]*inference.InferencePool, - routeNamespace string, + routeNsName types.NamespacedName, + featureFlags FeatureFlags, ) (rules []RouteRule, valid bool, conds []conditions.Condition) { rules = make([]RouteRule, len(specRules)) @@ -298,16 +338,15 @@ func processHTTPRouteRules( atLeastOneValid bool ) - for i, rule := range specRules { - rulePath := field.NewPath("spec").Child("rules").Index(i) - + for ruleIdx, rule := range specRules { rr, errors := processHTTPRouteRule( rule, - rulePath, + ruleIdx, validator, resolveExtRefFunc, inferencePools, - routeNamespace, + routeNsName, + featureFlags, ) if rr.ValidMatches && rr.Filters.Valid { @@ -316,7 +355,7 @@ func processHTTPRouteRules( allRulesErrors = allRulesErrors.append(errors) - rules[i] = rr + rules[ruleIdx] = rr } conds = make([]conditions.Condition, 0, 2) @@ -612,7 +651,11 @@ func validateFilterRewrite( return allErrs } -func checkForUnsupportedHTTPFields(rule v1.HTTPRouteRule, rulePath *field.Path) field.ErrorList { +func checkForUnsupportedHTTPFields( + rule v1.HTTPRouteRule, + rulePath *field.Path, + featureFlags FeatureFlags, +) field.ErrorList { var ruleErrors field.ErrorList if rule.Name != nil { @@ -633,10 +676,21 @@ func checkForUnsupportedHTTPFields(rule v1.HTTPRouteRule, rulePath *field.Path) "Retry", )) } - if rule.SessionPersistence != nil { + + if !featureFlags.Plus && rule.SessionPersistence != nil { + ruleErrors = append(ruleErrors, field.Forbidden( + rulePath.Child("sessionPersistence"), + fmt.Sprintf( + "%s OSS users can use `ip_hash` load balancing method via the UpstreamSettingsPolicy for session affinity.", + spErrMsg, + ), + )) + } + + if !featureFlags.Experimental && rule.SessionPersistence != nil { ruleErrors = append(ruleErrors, field.Forbidden( rulePath.Child("sessionPersistence"), - "SessionPersistence", + spErrMsg, )) } diff --git a/internal/controller/state/graph/httproute_test.go b/internal/controller/state/graph/httproute_test.go index ff59e84eba..dfc7797ae9 100644 --- a/internal/controller/state/graph/httproute_test.go +++ b/internal/controller/state/graph/httproute_test.go @@ -2,6 +2,7 @@ package graph import ( "errors" + "fmt" "testing" . "github.com/onsi/gomega" @@ -59,6 +60,7 @@ func createHTTPRoute( BackendObjectReference: gatewayv1.BackendObjectReference{ Kind: helpers.GetPointer[gatewayv1.Kind](kinds.Service), Name: "backend", + Port: helpers.GetPointer[gatewayv1.PortNumber](80), }, }, Filters: []gatewayv1.HTTPRouteFilter{ @@ -92,7 +94,12 @@ func createHTTPRoute( } } -func addFilterToPath(hr *gatewayv1.HTTPRoute, path string, filter gatewayv1.HTTPRouteFilter) { +func addElementsToPath( + hr *gatewayv1.HTTPRoute, + path string, + filter gatewayv1.HTTPRouteFilter, + sp *gatewayv1.SessionPersistence, +) { for i := range hr.Spec.Rules { for _, match := range hr.Spec.Rules[i].Matches { if match.Path == nil { @@ -100,6 +107,10 @@ func addFilterToPath(hr *gatewayv1.HTTPRoute, path string, filter gatewayv1.HTTP } if *match.Path.Value == path { hr.Spec.Rules[i].Filters = append(hr.Spec.Rules[i].Filters, filter) + + if sp != nil { + hr.Spec.Rules[i].SessionPersistence = sp + } } } } @@ -110,6 +121,7 @@ var expRouteBackendRef = RouteBackendRef{ BackendObjectReference: gatewayv1.BackendObjectReference{ Kind: helpers.GetPointer[gatewayv1.Kind](kinds.Service), Name: "backend", + Port: helpers.GetPointer[gatewayv1.PortNumber](80), }, }, Filters: []any{ @@ -130,6 +142,38 @@ func createInferencePoolBackend(name, namespace string) gatewayv1.BackendRef { } } +func getExpRouteBackendRefForPath(path string, spIdx string, sessionName string) RouteBackendRef { + var spName string + if sessionName == "" { + spName = fmt.Sprintf("sp_%s", spIdx) + } else { + spName = sessionName + } + + return RouteBackendRef{ + BackendRef: gatewayv1.BackendRef{ + BackendObjectReference: gatewayv1.BackendObjectReference{ + Kind: helpers.GetPointer[gatewayv1.Kind](kinds.Service), + Name: "backend", + Port: helpers.GetPointer[gatewayv1.PortNumber](80), + }, + }, + Filters: []any{ + gatewayv1.HTTPRouteFilter{ + Type: gatewayv1.HTTPRouteFilterExtensionRef, + }, + }, + SessionPersistence: &SessionPersistenceConfig{ + Valid: true, + Name: spName, + SessionType: gatewayv1.CookieBasedSessionPersistence, + Expiry: "1h", + Path: path, + Idx: spIdx, + }, + } +} + func TestBuildHTTPRoutes(t *testing.T) { t.Parallel() @@ -161,8 +205,16 @@ func TestBuildHTTPRoutes(t *testing.T) { RequestRedirect: &gatewayv1.HTTPRequestRedirectFilter{}, } - addFilterToPath(hr, "/", snippetsFilterRef) - addFilterToPath(hr, "/", requestRedirectFilter) + unNamedSPConfig := &gatewayv1.SessionPersistence{ + AbsoluteTimeout: helpers.GetPointer(gatewayv1.Duration("1h")), + Type: helpers.GetPointer(gatewayv1.CookieBasedSessionPersistence), + CookieConfig: &gatewayv1.CookieConfig{ + LifetimeType: helpers.GetPointer((gatewayv1.PermanentCookieLifetimeType)), + }, + } + + addElementsToPath(hr, "/", snippetsFilterRef, unNamedSPConfig) + addElementsToPath(hr, "/", requestRedirectFilter, nil) hrWrongGateway := createHTTPRoute("hr-2", "some-gateway", "example.com", "/") @@ -238,7 +290,7 @@ func TestBuildHTTPRoutes(t *testing.T) { }, }, Matches: hr.Spec.Rules[0].Matches, - RouteBackendRefs: []RouteBackendRef{expRouteBackendRef}, + RouteBackendRefs: []RouteBackendRef{getExpRouteBackendRefForPath("/", "hr-1_test_0", "")}, }, }, }, @@ -253,7 +305,11 @@ func TestBuildHTTPRoutes(t *testing.T) { }, } - validator := &validationfakes.FakeHTTPFieldsValidator{} + createAllValidValidator := func() *validationfakes.FakeHTTPFieldsValidator { + v := &validationfakes.FakeHTTPFieldsValidator{} + v.ValidateDurationReturns("1h", nil) + return v + } for _, test := range tests { t.Run(test.name, func(t *testing.T) { @@ -271,12 +327,16 @@ func TestBuildHTTPRoutes(t *testing.T) { } routes := buildRoutesForGateways( - validator, + createAllValidValidator(), hrRoutes, map[types.NamespacedName]*gatewayv1.GRPCRoute{}, test.gateways, snippetsFilters, nil, + FeatureFlags{ + Plus: true, + Experimental: true, + }, ) g.Expect(helpers.Diff(test.expected, routes)).To(BeEmpty()) }) @@ -305,13 +365,21 @@ func TestBuildHTTPRoute(t *testing.T) { hrValidWithUnsupportedField := createHTTPRoute("hr-valid-unsupported", gatewayNsName.Name, "example.com", "/") hrValidWithUnsupportedField.Spec.Rules[0].Name = helpers.GetPointer[gatewayv1.SectionName]("unsupported-name") + sp := &gatewayv1.SessionPersistence{ + SessionName: helpers.GetPointer("http-route-session"), + AbsoluteTimeout: helpers.GetPointer(gatewayv1.Duration("1h")), + Type: helpers.GetPointer(gatewayv1.CookieBasedSessionPersistence), + CookieConfig: &gatewayv1.CookieConfig{ + LifetimeType: helpers.GetPointer((gatewayv1.PermanentCookieLifetimeType)), + }, + } // route with valid filter validFilter := gatewayv1.HTTPRouteFilter{ Type: gatewayv1.HTTPRouteFilterRequestRedirect, RequestRedirect: &gatewayv1.HTTPRequestRedirectFilter{}, } hr := createHTTPRoute("hr", gatewayNsName.Name, "example.com", "/", "/filter") - addFilterToPath(hr, "/filter", validFilter) + addElementsToPath(hr, "/filter", validFilter, sp) // invalid routes without filters hrInvalidHostname := createHTTPRoute("hr", gatewayNsName.Name, "", "/") @@ -329,7 +397,7 @@ func TestBuildHTTPRoute(t *testing.T) { }, } hrInvalidFilters := createHTTPRoute("hr", gatewayNsName.Name, "example.com", "/filter") - addFilterToPath(hrInvalidFilters, "/filter", invalidFilter) + addElementsToPath(hrInvalidFilters, "/filter", invalidFilter, nil) // route with invalid matches and filters hrDroppedInvalidMatchesAndInvalidFilters := createHTTPRoute( @@ -340,12 +408,12 @@ func TestBuildHTTPRoute(t *testing.T) { "/filter", "/", ) - addFilterToPath(hrDroppedInvalidMatchesAndInvalidFilters, "/filter", invalidFilter) + addElementsToPath(hrDroppedInvalidMatchesAndInvalidFilters, "/filter", invalidFilter, sp) // route with both invalid and valid filters in the same rule hrDroppedInvalidFilters := createHTTPRoute("hr", gatewayNsName.Name, "example.com", "/filter", "/") - addFilterToPath(hrDroppedInvalidFilters, "/filter", validFilter) - addFilterToPath(hrDroppedInvalidFilters, "/", invalidFilter) + addElementsToPath(hrDroppedInvalidFilters, "/filter", validFilter, sp) + addElementsToPath(hrDroppedInvalidFilters, "/", invalidFilter, sp) // route with duplicate section names hrDuplicateSectionName := createHTTPRoute("hr", gatewayNsName.Name, "example.com", "/") @@ -364,7 +432,7 @@ func TestBuildHTTPRoute(t *testing.T) { Name: "sf", }, } - addFilterToPath(hrValidSnippetsFilter, "/filter", validSnippetsFilterExtRef) + addElementsToPath(hrValidSnippetsFilter, "/filter", validSnippetsFilterExtRef, sp) // route with invalid snippets filter extension ref hrInvalidSnippetsFilter := createHTTPRoute("hr", gatewayNsName.Name, "example.com", "/filter") @@ -376,7 +444,7 @@ func TestBuildHTTPRoute(t *testing.T) { Name: "sf", }, } - addFilterToPath(hrInvalidSnippetsFilter, "/filter", invalidSnippetsFilterExtRef) + addElementsToPath(hrInvalidSnippetsFilter, "/filter", invalidSnippetsFilterExtRef, nil) // route with unresolvable snippets filter extension ref hrUnresolvableSnippetsFilter := createHTTPRoute("hr", gatewayNsName.Name, "example.com", "/filter") @@ -388,12 +456,12 @@ func TestBuildHTTPRoute(t *testing.T) { Name: "does-not-exist", }, } - addFilterToPath(hrUnresolvableSnippetsFilter, "/filter", unresolvableSnippetsFilterExtRef) + addElementsToPath(hrUnresolvableSnippetsFilter, "/filter", unresolvableSnippetsFilterExtRef, nil) // route with two invalid snippets filter extensions refs: (1) invalid group (2) unresolvable hrInvalidAndUnresolvableSnippetsFilter := createHTTPRoute("hr", gatewayNsName.Name, "example.com", "/filter") - addFilterToPath(hrInvalidAndUnresolvableSnippetsFilter, "/filter", invalidSnippetsFilterExtRef) - addFilterToPath(hrInvalidAndUnresolvableSnippetsFilter, "/filter", unresolvableSnippetsFilterExtRef) + addElementsToPath(hrInvalidAndUnresolvableSnippetsFilter, "/filter", invalidSnippetsFilterExtRef, nil) + addElementsToPath(hrInvalidAndUnresolvableSnippetsFilter, "/filter", unresolvableSnippetsFilterExtRef, nil) // routes with an inference pool backend hrInferencePool := createHTTPRoute("hr", gatewayNsName.Name, "example.com", "/") @@ -402,6 +470,9 @@ func TestBuildHTTPRoute(t *testing.T) { BackendRef: createInferencePoolBackend("ipool", gatewayNsName.Namespace), }, } + + // session persistence should not be added for inference pool backends + hrInferencePool.Spec.Rules[0].SessionPersistence = sp // route with an inference pool backend that does not exist hrInferencePoolDoesNotExist := createHTTPRoute("hr", gatewayNsName.Name, "example.com", "/") hrInferencePoolDoesNotExist.Spec.Rules[0].BackendRefs = []gatewayv1.HTTPBackendRef{ @@ -423,16 +494,30 @@ func TestBuildHTTPRoute(t *testing.T) { } return nil }, + ValidateDurationStub: func(_ string) (string, error) { + return "1h", nil + }, + } + + createHTTPValidValidator := func(duration *gatewayv1.Duration) *validationfakes.FakeHTTPFieldsValidator { + v := &validationfakes.FakeHTTPFieldsValidator{} + if duration == nil { + v.ValidateDurationReturns("", nil) + } else { + v.ValidateDurationReturns(string(*duration), nil) + } + return v } tests := []struct { - validator *validationfakes.FakeHTTPFieldsValidator - hr *gatewayv1.HTTPRoute - expected *L7Route - name string + validator *validationfakes.FakeHTTPFieldsValidator + hr *gatewayv1.HTTPRoute + expected *L7Route + name string + plus, experimental bool }{ { - validator: &validationfakes.FakeHTTPFieldsValidator{}, + validator: createHTTPValidValidator(sp.AbsoluteTimeout), hr: hr, expected: &L7Route{ RouteType: RouteTypeHTTP, @@ -465,12 +550,14 @@ func TestBuildHTTPRoute(t *testing.T) { Filters: convertHTTPRouteFilters(hr.Spec.Rules[1].Filters), }, Matches: hr.Spec.Rules[1].Matches, - RouteBackendRefs: []RouteBackendRef{expRouteBackendRef}, + RouteBackendRefs: []RouteBackendRef{getExpRouteBackendRefForPath("/filter", "hr_test_1", "http-route-session")}, }, }, }, }, - name: "normal case", + plus: true, + experimental: true, + name: "normal case", }, { validator: &validationfakes.FakeHTTPFieldsValidator{}, @@ -705,7 +792,6 @@ func TestBuildHTTPRoute(t *testing.T) { }, name: "dropped invalid rule with invalid matches", }, - { validator: validatorInvalidFieldsInRule, hr: hrDroppedInvalidMatchesAndInvalidFilters, @@ -749,7 +835,7 @@ func TestBuildHTTPRoute(t *testing.T) { hrDroppedInvalidMatchesAndInvalidFilters.Spec.Rules[1].Filters, ), }, - RouteBackendRefs: []RouteBackendRef{expRouteBackendRef}, + RouteBackendRefs: []RouteBackendRef{getExpRouteBackendRefForPath("/filter", "hr_test_1", "http-route-session")}, }, { ValidMatches: true, @@ -763,7 +849,9 @@ func TestBuildHTTPRoute(t *testing.T) { }, }, }, - name: "dropped invalid rule with invalid filters and invalid rule with invalid matches", + plus: true, + experimental: true, + name: "dropped invalid rule with invalid filters and invalid rule with invalid matches", }, { validator: validatorInvalidFieldsInRule, @@ -796,7 +884,7 @@ func TestBuildHTTPRoute(t *testing.T) { Filters: convertHTTPRouteFilters(hrDroppedInvalidFilters.Spec.Rules[0].Filters), Valid: true, }, - RouteBackendRefs: []RouteBackendRef{expRouteBackendRef}, + RouteBackendRefs: []RouteBackendRef{getExpRouteBackendRefForPath("/filter", "hr_test_0", "http-route-session")}, }, { ValidMatches: true, @@ -805,12 +893,14 @@ func TestBuildHTTPRoute(t *testing.T) { Filters: convertHTTPRouteFilters(hrDroppedInvalidFilters.Spec.Rules[1].Filters), Valid: false, }, - RouteBackendRefs: []RouteBackendRef{expRouteBackendRef}, + RouteBackendRefs: []RouteBackendRef{getExpRouteBackendRefForPath("/", "hr_test_1", "http-route-session")}, }, }, }, }, - name: "dropped invalid rule with invalid filters", + plus: true, + experimental: true, + name: "dropped invalid rule with invalid filters", }, { validator: validatorInvalidFieldsInRule, @@ -847,12 +937,14 @@ func TestBuildHTTPRoute(t *testing.T) { }, Valid: true, }, - RouteBackendRefs: []RouteBackendRef{expRouteBackendRef}, + RouteBackendRefs: []RouteBackendRef{getExpRouteBackendRefForPath("/filter", "hr_test_0", "http-route-session")}, }, }, }, }, - name: "rule with valid snippets filter extension ref filter", + plus: true, + experimental: true, + name: "rule with valid snippets filter extension ref filter", }, { validator: validatorInvalidFieldsInRule, @@ -1054,7 +1146,9 @@ func TestBuildHTTPRoute(t *testing.T) { }, }, }, - name: "route with an inference pool backend gets converted to service", + plus: true, + experimental: true, + name: "route with an inference pool backend gets converted to service", }, { validator: &validationfakes.FakeHTTPFieldsValidator{}, @@ -1109,8 +1203,17 @@ func TestBuildHTTPRoute(t *testing.T) { inferencePools := map[types.NamespacedName]*inference.InferencePool{ {Namespace: "test", Name: "ipool"}: {}, } - - route := buildHTTPRoute(test.validator, test.hr, gws, snippetsFilters, inferencePools) + route := buildHTTPRoute( + test.validator, + test.hr, + gws, + snippetsFilters, + inferencePools, + FeatureFlags{ + Plus: test.plus, + Experimental: test.experimental, + }, + ) g.Expect(helpers.Diff(test.expected, route)).To(BeEmpty()) }) } @@ -1152,8 +1255,8 @@ func TestBuildHTTPRouteWithMirrorRoutes(t *testing.T) { }, } hr := createHTTPRoute("hr", gatewayNsName.Name, "example.com", "/mirror") - addFilterToPath(hr, "/mirror", mirrorFilter) - addFilterToPath(hr, "/mirror", urlRewriteFilter) + addElementsToPath(hr, "/mirror", mirrorFilter, nil) + addElementsToPath(hr, "/mirror", urlRewriteFilter, nil) // Expected mirror route expectedMirrorRoute := &L7Route{ @@ -1241,11 +1344,23 @@ func TestBuildHTTPRouteWithMirrorRoutes(t *testing.T) { g := NewWithT(t) + featureFlags := FeatureFlags{ + Plus: false, + Experimental: false, + } + routes := map[RouteKey]*L7Route{} - l7route := buildHTTPRoute(validator, hr, gateways, snippetsFilters, nil) + l7route := buildHTTPRoute( + validator, + hr, + gateways, + snippetsFilters, + nil, + featureFlags, + ) g.Expect(l7route).NotTo(BeNil()) - buildHTTPMirrorRoutes(routes, l7route, hr, gateways, snippetsFilters) + buildHTTPMirrorRoutes(routes, l7route, hr, gateways, snippetsFilters, featureFlags) obj, ok := expectedMirrorRoute.Source.(*gatewayv1.HTTPRoute) g.Expect(ok).To(BeTrue()) @@ -1260,10 +1375,10 @@ func TestProcessHTTPRouteRule_InferencePoolWithMultipleBackendRefs(t *testing.T) validator := &validationfakes.FakeHTTPFieldsValidator{} inferencePoolName1 := "primary-pool" inferencePoolName2 := "secondary-pool" - routeNamespace := "test" + routeNsName := types.NamespacedName{Namespace: "test", Name: "hr"} inferencePools := map[types.NamespacedName]*inference.InferencePool{ - {Namespace: routeNamespace, Name: inferencePoolName1}: {}, - {Namespace: routeNamespace, Name: inferencePoolName2}: {}, + {Namespace: routeNsName.Namespace, Name: inferencePoolName1}: {}, + {Namespace: routeNsName.Namespace, Name: inferencePoolName2}: {}, } tests := []struct { @@ -1291,7 +1406,7 @@ func TestProcessHTTPRouteRule_InferencePoolWithMultipleBackendRefs(t *testing.T) Group: helpers.GetPointer[gatewayv1.Group](inferenceAPIGroup), Kind: helpers.GetPointer[gatewayv1.Kind](kinds.InferencePool), Name: gatewayv1.ObjectName(inferencePoolName1), - Namespace: helpers.GetPointer(gatewayv1.Namespace(routeNamespace)), + Namespace: helpers.GetPointer(gatewayv1.Namespace(routeNsName.Namespace)), }, Weight: helpers.GetPointer(int32(70)), }, @@ -1302,7 +1417,7 @@ func TestProcessHTTPRouteRule_InferencePoolWithMultipleBackendRefs(t *testing.T) Group: helpers.GetPointer[gatewayv1.Group](inferenceAPIGroup), Kind: helpers.GetPointer[gatewayv1.Kind](kinds.InferencePool), Name: gatewayv1.ObjectName(inferencePoolName2), - Namespace: helpers.GetPointer(gatewayv1.Namespace(routeNamespace)), + Namespace: helpers.GetPointer(gatewayv1.Namespace(routeNsName.Namespace)), }, Weight: helpers.GetPointer(int32(30)), }, @@ -1330,7 +1445,7 @@ func TestProcessHTTPRouteRule_InferencePoolWithMultipleBackendRefs(t *testing.T) Group: helpers.GetPointer[gatewayv1.Group](inferenceAPIGroup), Kind: helpers.GetPointer[gatewayv1.Kind](kinds.InferencePool), Name: gatewayv1.ObjectName(inferencePoolName1), - Namespace: helpers.GetPointer(gatewayv1.Namespace(routeNamespace)), + Namespace: helpers.GetPointer(gatewayv1.Namespace(routeNsName.Namespace)), }, }, }, @@ -1355,15 +1470,18 @@ func TestProcessHTTPRouteRule_InferencePoolWithMultipleBackendRefs(t *testing.T) t.Parallel() g := NewWithT(t) - rulePath := field.NewPath("spec").Child("rules").Index(0) - + ruleIdx := 0 routeRule, errs := processHTTPRouteRule( tc.specRule, - rulePath, + ruleIdx, validator, nil, inferencePools, - routeNamespace, + routeNsName, + FeatureFlags{ + Plus: false, + Experimental: false, + }, ) if tc.expectValid { @@ -2002,7 +2120,7 @@ func TestUnsupportedFieldsErrors(t *testing.T) { Type: helpers.GetPointer(gatewayv1.SessionPersistenceType("unsupported-session-persistence")), }), }, - expectedErrors: 4, + expectedErrors: 5, }, } @@ -2014,7 +2132,14 @@ func TestUnsupportedFieldsErrors(t *testing.T) { rulePath := field.NewPath("spec").Child("rules") var errors routeRuleErrors - unsupportedFieldsErrors := checkForUnsupportedHTTPFields(test.specRule, rulePath) + unsupportedFieldsErrors := checkForUnsupportedHTTPFields( + test.specRule, + rulePath, + FeatureFlags{ + Plus: false, + Experimental: false, + }, + ) if len(unsupportedFieldsErrors) > 0 { errors.warn = append(errors.warn, unsupportedFieldsErrors...) } @@ -2026,7 +2151,10 @@ func TestUnsupportedFieldsErrors(t *testing.T) { func TestProcessHTTPRouteRules_UnsupportedFields(t *testing.T) { t.Parallel() - routeNamespace := "test" + routeNsName := types.NamespacedName{ + Namespace: "test", + Name: "route", + } tests := []struct { name string @@ -2034,6 +2162,8 @@ func TestProcessHTTPRouteRules_UnsupportedFields(t *testing.T) { expectedConds []conditions.Condition expectedWarns int expectedValid bool + plusEnabled bool + experimental bool }{ { name: "No unsupported fields", @@ -2065,18 +2195,60 @@ func TestProcessHTTPRouteRules_UnsupportedFields(t *testing.T) { }), Retry: helpers.GetPointer(gatewayv1.HTTPRouteRetry{Attempts: helpers.GetPointer(3)}), SessionPersistence: helpers.GetPointer(gatewayv1.SessionPersistence{ - Type: helpers.GetPointer(gatewayv1.SessionPersistenceType("unsupported-session-persistence")), + Type: helpers.GetPointer(gatewayv1.CookieBasedSessionPersistence), + SessionName: helpers.GetPointer("session_id"), }), }, }, expectedValid: true, expectedConds: []conditions.Condition{ - conditions.NewRouteAcceptedUnsupportedField("[spec.rules[0].name: Forbidden: Name, spec.rules[0].timeouts: " + - "Forbidden: Timeouts, spec.rules[0].retry: Forbidden: Retry, " + - "spec.rules[0].sessionPersistence: Forbidden: SessionPersistence]"), - }, + conditions.NewRouteAcceptedUnsupportedField( + fmt.Sprintf("[spec.rules[0].name: Forbidden: Name, spec.rules[0].timeouts: "+ + "Forbidden: Timeouts, spec.rules[0].retry: Forbidden: Retry, "+ + "spec.rules[0].sessionPersistence: Forbidden: "+ + "%s OSS users can use `ip_hash` load balancing method via the UpstreamSettingsPolicy for session affinity.]", + spErrMsg, + )), + }, + experimental: true, + plusEnabled: false, expectedWarns: 4, }, + { + name: "Session persistence unsupported with experimental disabled", + specRules: []gatewayv1.HTTPRouteRule{ + { + SessionPersistence: helpers.GetPointer(gatewayv1.SessionPersistence{ + Type: helpers.GetPointer(gatewayv1.CookieBasedSessionPersistence), + SessionName: helpers.GetPointer("session_id"), + }), + }, + }, + expectedValid: true, + expectedConds: []conditions.Condition{ + conditions.NewRouteAcceptedUnsupportedField(fmt.Sprintf("spec.rules[0].sessionPersistence: Forbidden: "+ + "%s", spErrMsg)), + }, + expectedWarns: 1, + plusEnabled: true, + experimental: false, + }, + { + name: "SessionPersistence field with Plus enabled and experimental enabled", + specRules: []gatewayv1.HTTPRouteRule{ + { + SessionPersistence: helpers.GetPointer(gatewayv1.SessionPersistence{ + Type: helpers.GetPointer(gatewayv1.CookieBasedSessionPersistence), + SessionName: helpers.GetPointer("session_id"), + }), + }, + }, + expectedValid: true, + expectedConds: nil, + expectedWarns: 0, + plusEnabled: true, + experimental: true, + }, } for _, test := range tests { @@ -2089,7 +2261,11 @@ func TestProcessHTTPRouteRules_UnsupportedFields(t *testing.T) { validation.SkipValidator{}, nil, nil, - routeNamespace, + routeNsName, + FeatureFlags{ + Plus: test.plusEnabled, + Experimental: test.experimental, + }, ) g.Expect(valid).To(Equal(test.expectedValid)) diff --git a/internal/controller/state/graph/multiple_gateways_test.go b/internal/controller/state/graph/multiple_gateways_test.go index c20fd98516..3d027661df 100644 --- a/internal/controller/state/graph/multiple_gateways_test.go +++ b/internal/controller/state/graph/multiple_gateways_test.go @@ -409,7 +409,9 @@ func Test_MultipleGateways_WithNginxProxy(t *testing.T) { PolicyValidator: fakePolicyValidator, }, logr.Discard(), - experimentalFeaturesEnabled, + FeatureFlags{ + Experimental: experimentalFeaturesEnabled, + }, ) g.Expect(helpers.Diff(test.expGraph, result)).To(BeEmpty()) @@ -899,7 +901,9 @@ func Test_MultipleGateways_WithListeners(t *testing.T) { PolicyValidator: fakePolicyValidator, }, logr.Discard(), - experimentalFeaturesEnabled, + FeatureFlags{ + Experimental: experimentalFeaturesEnabled, + }, ) g.Expect(helpers.Diff(test.expGraph, result)).To(BeEmpty()) diff --git a/internal/controller/state/graph/policies_test.go b/internal/controller/state/graph/policies_test.go index 251ab5aeb8..e4532f5a54 100644 --- a/internal/controller/state/graph/policies_test.go +++ b/internal/controller/state/graph/policies_test.go @@ -242,7 +242,7 @@ func TestAttachPolicies(t *testing.T) { NGFPolicies: test.ngfPolicies, } - graph.attachPolicies(nil, "nginx-gateway", logr.Discard()) + graph.attachPolicies(&policiesfakes.FakeValidator{}, "nginx-gateway", logr.Discard()) for _, expect := range test.expects { expect(g, graph) } diff --git a/internal/controller/state/graph/route_common.go b/internal/controller/state/graph/route_common.go index 531b516eb1..499cd97171 100644 --- a/internal/controller/state/graph/route_common.go +++ b/internal/controller/state/graph/route_common.go @@ -25,6 +25,9 @@ const ( inferenceAPIGroup = "inference.networking.k8s.io" ) +var spErrMsg = "SessionPersistence is only supported with NGINX Plus " + + "and when experimental features are enabled. This configuration will be ignored." + // ParentRef describes a reference to a parent in a Route. type ParentRef struct { // Attachment is the attachment status of the ParentRef. It could be nil. In that case, NGF didn't attempt to @@ -172,12 +175,30 @@ type RouteBackendRef struct { // InferencePoolName is the name of the InferencePool, if this backendRef is for an InferencePool. InferencePoolName string + // SessionPersistence holds the session persistence configuration for the route rule. + SessionPersistence *SessionPersistenceConfig + Filters []any // IsInferencePool indicates if this backend is an InferencePool disguised as a Service. IsInferencePool bool } +type SessionPersistenceConfig struct { + // Name is the name of the session. + Name string + // Expiry determines the expiry time of the session. + Expiry string + // SessionType is the type of session persistence. + SessionType v1.SessionPersistenceType + // Path is the path for which the session persistence is allowed. + Path string + // Idx is the unique identifier for this configuration in the route rule. + Idx string + // Valid indicates if the session persistence configuration is valid. + Valid bool +} + // CreateRouteKey takes a client.Object and creates a RouteKey. func CreateRouteKey(obj client.Object) RouteKey { nsName := types.NamespacedName{ @@ -259,6 +280,7 @@ func buildRoutesForGateways( gateways map[types.NamespacedName]*Gateway, snippetsFilters map[types.NamespacedName]*SnippetsFilter, inferencePools map[types.NamespacedName]*inference.InferencePool, + featureFlags FeatureFlags, ) map[RouteKey]*L7Route { if len(gateways) == 0 { return nil @@ -267,7 +289,7 @@ func buildRoutesForGateways( routes := make(map[RouteKey]*L7Route) for _, route := range httpRoutes { - r := buildHTTPRoute(validator, route, gateways, snippetsFilters, inferencePools) + r := buildHTTPRoute(validator, route, gateways, snippetsFilters, inferencePools, featureFlags) if r == nil { continue } @@ -275,11 +297,11 @@ func buildRoutesForGateways( routes[CreateRouteKey(route)] = r // if this route has a RequestMirror filter, build a duplicate route for the mirror - buildHTTPMirrorRoutes(routes, r, route, gateways, snippetsFilters) + buildHTTPMirrorRoutes(routes, r, route, gateways, snippetsFilters, featureFlags) } for _, route := range grpcRoutes { - r := buildGRPCRoute(validator, route, gateways, snippetsFilters) + r := buildGRPCRoute(validator, route, gateways, snippetsFilters, featureFlags) if r == nil { continue } @@ -287,7 +309,7 @@ func buildRoutesForGateways( routes[CreateRouteKey(route)] = r // if this route has a RequestMirror filter, build a duplicate route for the mirror - buildGRPCMirrorRoutes(routes, r, route, gateways, snippetsFilters) + buildGRPCMirrorRoutes(routes, r, route, gateways, snippetsFilters, featureFlags) } return routes @@ -1153,3 +1175,164 @@ func routeKeyForKind(kind v1.Kind, nsname types.NamespacedName) RouteKey { return key } + +func getSessionPersistenceKey(ruleIdx int, routeNsName types.NamespacedName) string { + return fmt.Sprintf("%s_%s_%d", routeNsName.Name, routeNsName.Namespace, ruleIdx) +} + +// processSessionPersistenceConfig processes the session persistence configuration. +func processSessionPersistenceConfig[T any]( + sp *v1.SessionPersistence, + routeMatches []T, + rulePath *field.Path, + validator validation.HTTPFieldsValidator, +) (*SessionPersistenceConfig, routeRuleErrors) { + var spConfig SessionPersistenceConfig + expiry, errors := validateSessionPersistenceConfig(sp, rulePath, validator) + + if len(errors.warn) > 0 { + errors.warn = append(errors.warn, field.Invalid( + rulePath, + rulePath.String(), + "session persistence is ignored because there are errors in the configuration", + )) + spConfig.Valid = false + return &spConfig, errors + } + + var sessionName string + if sp.SessionName != nil { + sessionName = *sp.SessionName + } + + var cookieLifetimeType v1.CookieLifetimeType + if sp.CookieConfig != nil { + cookieLifetimeType = *sp.CookieConfig.LifetimeType + } + + if sp.AbsoluteTimeout != nil && cookieLifetimeType == v1.SessionCookieLifetimeType { + expiry = "" + } + + var path string + switch rm := any(routeMatches).(type) { + case []v1.HTTPRouteMatch: + path = deriveCookiePathForHTTPMatches(rm) + case []v1.GRPCRouteMatch: + path = "" + default: + panic("unsupported route match type") + } + + spConfig = SessionPersistenceConfig{ + Valid: true, + Name: sessionName, + SessionType: *sp.Type, + Path: path, + Expiry: expiry, + } + + return &spConfig, errors +} + +// validateSessionPersistenceConfig validates the session persistence configuration. +// Returns warnings for any invalid session persistence configuration. +// but that does not make the route associated with it invalid. +func validateSessionPersistenceConfig( + sp *v1.SessionPersistence, + path *field.Path, + validator validation.HTTPFieldsValidator, +) (string, routeRuleErrors) { + if sp == nil { + return "", routeRuleErrors{} + } + + var errors routeRuleErrors + + if sp.Type != nil && *sp.Type != v1.CookieBasedSessionPersistence { + errors.warn = append(errors.warn, field.NotSupported( + path.Child("type"), + sp.Type, + []string{string(v1.CookieBasedSessionPersistence)}, + )) + } + + if sp.IdleTimeout != nil { + errors.warn = append(errors.warn, field.Forbidden( + path.Child("idleTimeout"), + "IdleTimeout", + )) + } + + var timeout string + if sp.AbsoluteTimeout != nil { + if absoluteTimeout, err := validator.ValidateDuration(string(*sp.AbsoluteTimeout)); err != nil { + errors.warn = append(errors.warn, field.Invalid( + path.Child("absoluteTimeout"), + sp.AbsoluteTimeout, + err.Error(), + )) + } else { + timeout = absoluteTimeout + } + } + + return timeout, errors +} + +func deriveCookiePathForHTTPMatches(matches []v1.HTTPRouteMatch) string { + paths := make([]string, 0, len(matches)) + for _, match := range matches { + paths = append(paths, getCookiePath(match)) + } + + return longestCommonPathPrefix(paths) +} + +// longestCommonPathPrefix returns the longest common path prefix of the given +// paths. +// Examples: +// +// ["/foo/bar", "/foo/baz"] -> "/foo" +// ["/foo/bar", "/foo/bar/b"] -> "/foo/bar" +// ["/foo", "/bar"] -> "" +// [] -> "" +func longestCommonPathPrefix(paths []string) string { + if len(paths) == 0 { + return "" + } + if len(paths) == 1 { + return paths[0] + } + + commonSegs := strings.Split(paths[0], "/") + for _, p := range paths[1:] { + segs := strings.Split(p, "/") + i := 0 + limit := len(commonSegs) + if len(segs) < limit { + limit = len(segs) + } + for i < limit && commonSegs[i] == segs[i] { + i++ + } + // truncate commonSegs to the common prefix + commonSegs = commonSegs[:i] + if len(commonSegs) == 0 { + return "" + } + } + + return strings.Join(commonSegs, "/") +} + +func getCookiePath(match v1.HTTPRouteMatch) string { + pathType := *match.Path.Type + + switch pathType { + case v1.PathMatchExact, v1.PathMatchPathPrefix: + return *match.Path.Value + default: + return "" + } +} diff --git a/internal/controller/state/graph/route_common_test.go b/internal/controller/state/graph/route_common_test.go index 52ff2ca2af..f827e5794e 100644 --- a/internal/controller/state/graph/route_common_test.go +++ b/internal/controller/state/graph/route_common_test.go @@ -15,6 +15,7 @@ import ( "sigs.k8s.io/gateway-api/apis/v1alpha2" "github.com/nginx/nginx-gateway-fabric/v2/internal/controller/state/conditions" + "github.com/nginx/nginx-gateway-fabric/v2/internal/controller/state/validation/validationfakes" "github.com/nginx/nginx-gateway-fabric/v2/internal/framework/helpers" "github.com/nginx/nginx-gateway-fabric/v2/internal/framework/kinds" ) @@ -3867,3 +3868,437 @@ func TestFindAttachableListenersWithPort(t *testing.T) { }) } } + +func TestProcessSessionPersistenceConfiguration(t *testing.T) { + t.Parallel() + + createDurationValidator := func(duration *gatewayv1.Duration) *validationfakes.FakeHTTPFieldsValidator { + v := &validationfakes.FakeHTTPFieldsValidator{} + if duration == nil { + v.ValidateDurationReturns("", nil) + } else { + v.ValidateDurationReturns(string(*duration), nil) + } + return v + } + + sessionPersistencePath := field.NewPath("sessionPersistence") + tests := []struct { + name string + sessionPersistence *gatewayv1.SessionPersistence + expectedResult SessionPersistenceConfig + expectedErrors routeRuleErrors + httpRouteMatches []gatewayv1.HTTPRouteMatch + grpcRouteMatches []gatewayv1.GRPCRouteMatch + }{ + { + name: "session persistence has errors in configuration", + sessionPersistence: &gatewayv1.SessionPersistence{ + Type: helpers.GetPointer(gatewayv1.HeaderBasedSessionPersistence), + }, + expectedErrors: routeRuleErrors{ + warn: field.ErrorList{ + field.NotSupported( + sessionPersistencePath.Child("type"), + helpers.GetPointer(gatewayv1.HeaderBasedSessionPersistence), + []string{string(gatewayv1.CookieBasedSessionPersistence)}, + ), + field.Invalid( + sessionPersistencePath, + sessionPersistencePath.String(), + "session persistence is ignored because there are errors in the configuration", + ), + }, + }, + expectedResult: SessionPersistenceConfig{ + Valid: false, + }, + httpRouteMatches: []gatewayv1.HTTPRouteMatch{ + { + Path: &gatewayv1.HTTPPathMatch{ + Type: helpers.GetPointer(gatewayv1.PathMatchExact), + Value: helpers.GetPointer("/coffee"), + }, + }, + }, + }, + { + name: "when lifetime type of a cookie is Session, timeout is set to 0", + sessionPersistence: &gatewayv1.SessionPersistence{ + SessionName: helpers.GetPointer("session-persistence"), + Type: helpers.GetPointer(gatewayv1.CookieBasedSessionPersistence), + AbsoluteTimeout: helpers.GetPointer(gatewayv1.Duration("20m")), + CookieConfig: &gatewayv1.CookieConfig{ + LifetimeType: helpers.GetPointer(gatewayv1.SessionCookieLifetimeType), + }, + }, + expectedResult: SessionPersistenceConfig{ + Valid: true, + SessionType: gatewayv1.CookieBasedSessionPersistence, + Name: "session-persistence", + Path: "/tea", + }, + httpRouteMatches: []gatewayv1.HTTPRouteMatch{ + { + Path: &gatewayv1.HTTPPathMatch{ + Type: helpers.GetPointer(gatewayv1.PathMatchExact), + Value: helpers.GetPointer("/tea"), + }, + }, + }, + }, + { + name: "valid session persistence configuration for HTTPRoute", + sessionPersistence: &gatewayv1.SessionPersistence{ + SessionName: helpers.GetPointer("session-persistence-http"), + Type: helpers.GetPointer(gatewayv1.CookieBasedSessionPersistence), + AbsoluteTimeout: helpers.GetPointer(gatewayv1.Duration("1h")), + CookieConfig: &gatewayv1.CookieConfig{ + LifetimeType: helpers.GetPointer(gatewayv1.PermanentCookieLifetimeType), + }, + }, + expectedResult: SessionPersistenceConfig{ + Valid: true, + SessionType: gatewayv1.CookieBasedSessionPersistence, + Name: "session-persistence-http", + Expiry: "1h", + Path: "/app/v1", + }, + httpRouteMatches: []gatewayv1.HTTPRouteMatch{ + { + Path: &gatewayv1.HTTPPathMatch{ + Type: helpers.GetPointer(gatewayv1.PathMatchExact), + Value: helpers.GetPointer("/app/v1/users/"), + }, + }, + { + Path: &gatewayv1.HTTPPathMatch{ + Type: helpers.GetPointer(gatewayv1.PathMatchPathPrefix), + Value: helpers.GetPointer("/app/v1/latte/"), + }, + }, + { + Path: &gatewayv1.HTTPPathMatch{ + Type: helpers.GetPointer(gatewayv1.PathMatchExact), + Value: helpers.GetPointer("/app/v1/tea"), + }, + }, + { + Path: &gatewayv1.HTTPPathMatch{ + Type: helpers.GetPointer(gatewayv1.PathMatchPathPrefix), + Value: helpers.GetPointer("/app/v1/coffee"), + }, + }, + }, + }, + { + name: "valid session persistence configuration for GRPCRoute", + sessionPersistence: &gatewayv1.SessionPersistence{ + SessionName: helpers.GetPointer("session-persistence-grpc"), + Type: helpers.GetPointer(gatewayv1.CookieBasedSessionPersistence), + AbsoluteTimeout: helpers.GetPointer(gatewayv1.Duration("30m")), + CookieConfig: &gatewayv1.CookieConfig{ + LifetimeType: helpers.GetPointer(gatewayv1.PermanentCookieLifetimeType), + }, + }, + expectedResult: SessionPersistenceConfig{ + Valid: true, + SessionType: gatewayv1.CookieBasedSessionPersistence, + Name: "session-persistence-grpc", + Expiry: "30m", + }, + grpcRouteMatches: []gatewayv1.GRPCRouteMatch{ + { + Method: &gatewayv1.GRPCMethodMatch{ + Type: helpers.GetPointer(gatewayv1.GRPCMethodMatchExact), + Service: helpers.GetPointer("mymethod.user"), + }, + }, + { + Method: &gatewayv1.GRPCMethodMatch{ + Type: helpers.GetPointer(gatewayv1.GRPCMethodMatchExact), + Service: helpers.GetPointer("mymethod.coffee"), + }, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + g := NewWithT(t) + + var result *SessionPersistenceConfig + var errors routeRuleErrors + if test.httpRouteMatches != nil { + result, errors = processSessionPersistenceConfig( + test.sessionPersistence, + test.httpRouteMatches, + sessionPersistencePath, + createDurationValidator(test.sessionPersistence.AbsoluteTimeout), + ) + } + + if test.grpcRouteMatches != nil { + result, errors = processSessionPersistenceConfig( + test.sessionPersistence, + test.grpcRouteMatches, + sessionPersistencePath, + createDurationValidator(test.sessionPersistence.AbsoluteTimeout), + ) + } + + g.Expect(result).To(HaveValue(Equal(test.expectedResult))) + g.Expect(errors).To(Equal(test.expectedErrors)) + }) + } +} + +func TestValidateSessionPersistence(t *testing.T) { + t.Parallel() + + createDurationValidator := func() *validationfakes.FakeHTTPFieldsValidator { + v := &validationfakes.FakeHTTPFieldsValidator{} + v.ValidateDurationReturns("", nil) + return v + } + + createInvalidDurationValidator := func() *validationfakes.FakeHTTPFieldsValidator { + v := &validationfakes.FakeHTTPFieldsValidator{} + v.ValidateDurationReturns("", errors.New("invalid duration format")) + return v + } + + sessionPersistencePath := field.NewPath("sessionPersistence") + tests := []struct { + sessionPersistence *gatewayv1.SessionPersistence + validator *validationfakes.FakeHTTPFieldsValidator + name string + expectedErrors routeRuleErrors + }{ + { + name: "session persistence returns error for invalid type", + sessionPersistence: &gatewayv1.SessionPersistence{ + Type: helpers.GetPointer(gatewayv1.HeaderBasedSessionPersistence), + }, + expectedErrors: routeRuleErrors{ + warn: field.ErrorList{ + field.NotSupported( + sessionPersistencePath.Child("type"), + helpers.GetPointer(gatewayv1.HeaderBasedSessionPersistence), + []string{string(gatewayv1.CookieBasedSessionPersistence)}, + ), + }, + }, + validator: createDurationValidator(), + }, + { + name: "session persistence returns error when idleTimeout is specified", + sessionPersistence: &gatewayv1.SessionPersistence{ + Type: helpers.GetPointer(gatewayv1.CookieBasedSessionPersistence), + IdleTimeout: helpers.GetPointer(gatewayv1.Duration("10m")), + }, + expectedErrors: routeRuleErrors{ + warn: field.ErrorList{ + field.Forbidden( + sessionPersistencePath.Child("idleTimeout"), + "IdleTimeout", + ), + }, + }, + validator: createDurationValidator(), + }, + { + name: "session persistence returns error when absoluteTimeout is invalid", + sessionPersistence: &gatewayv1.SessionPersistence{ + Type: helpers.GetPointer(gatewayv1.CookieBasedSessionPersistence), + AbsoluteTimeout: helpers.GetPointer(gatewayv1.Duration("invalid-duration")), + }, + expectedErrors: routeRuleErrors{ + warn: field.ErrorList{ + field.Invalid( + sessionPersistencePath.Child("absoluteTimeout"), + helpers.GetPointer(gatewayv1.Duration("invalid-duration")), + "invalid duration format", + ), + }, + }, + validator: createInvalidDurationValidator(), + }, + { + name: "valid session persistence returns no errors", + sessionPersistence: &gatewayv1.SessionPersistence{ + SessionName: helpers.GetPointer("session-persistence"), + Type: helpers.GetPointer(gatewayv1.CookieBasedSessionPersistence), + AbsoluteTimeout: helpers.GetPointer(gatewayv1.Duration("30m")), + CookieConfig: &gatewayv1.CookieConfig{ + LifetimeType: helpers.GetPointer(gatewayv1.PermanentCookieLifetimeType), + }, + }, + validator: createDurationValidator(), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + g := NewWithT(t) + + _, errors := validateSessionPersistenceConfig( + test.sessionPersistence, + field.NewPath("sessionPersistence"), + test.validator, + ) + g.Expect(errors).To(Equal(test.expectedErrors)) + }) + } +} + +func TestGetCookiePath(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + expectedPath string + matches []gatewayv1.HTTPRouteMatch + }{ + { + name: "no matches returns empty path", + matches: []gatewayv1.HTTPRouteMatch{}, + expectedPath: "", + }, + { + name: "single match with type Exact returns that path", + matches: []gatewayv1.HTTPRouteMatch{ + { + Path: &gatewayv1.HTTPPathMatch{ + Type: helpers.GetPointer(gatewayv1.PathMatchExact), + Value: helpers.GetPointer("/app/users"), + }, + }, + }, + expectedPath: "/app/users", + }, + { + name: "single match with type Prefix returns that path", + matches: []gatewayv1.HTTPRouteMatch{ + { + Path: &gatewayv1.HTTPPathMatch{ + Type: helpers.GetPointer(gatewayv1.PathMatchPathPrefix), + Value: helpers.GetPointer("/app/orders"), + }, + }, + }, + expectedPath: "/app/orders", + }, + { + name: "single match with type Regular Expression returns empty path", + matches: []gatewayv1.HTTPRouteMatch{ + { + Path: &gatewayv1.HTTPPathMatch{ + Type: helpers.GetPointer(gatewayv1.PathMatchRegularExpression), + Value: helpers.GetPointer("/app/[a-z]+/orders"), + }, + }, + }, + expectedPath: "", + }, + { + name: "multiple matches with all three types of matches returns empty path", + matches: []gatewayv1.HTTPRouteMatch{ + { + Path: &gatewayv1.HTTPPathMatch{ + Type: helpers.GetPointer(gatewayv1.PathMatchRegularExpression), + Value: helpers.GetPointer("/app/[a-z]+/orders"), + }, + }, + { + Path: &gatewayv1.HTTPPathMatch{ + Type: helpers.GetPointer(gatewayv1.PathMatchPathPrefix), + Value: helpers.GetPointer("/app/users/login"), + }, + }, + { + Path: &gatewayv1.HTTPPathMatch{ + Type: helpers.GetPointer(gatewayv1.PathMatchExact), + Value: helpers.GetPointer("/app/users/"), + }, + }, + }, + expectedPath: "", + }, + { + name: "multiple matches with all predefined path types Exact and PathPrefix " + + "returns longest common prefix", + matches: []gatewayv1.HTTPRouteMatch{ + { + Path: &gatewayv1.HTTPPathMatch{ + Type: helpers.GetPointer(gatewayv1.PathMatchPathPrefix), + Value: helpers.GetPointer("/app/users/profile/"), + }, + }, + { + Path: &gatewayv1.HTTPPathMatch{ + Type: helpers.GetPointer(gatewayv1.PathMatchPathPrefix), + Value: helpers.GetPointer("/app/users/login"), + }, + }, + { + Path: &gatewayv1.HTTPPathMatch{ + Type: helpers.GetPointer(gatewayv1.PathMatchExact), + Value: helpers.GetPointer("/app/users/orders/"), + }, + }, + { + Path: &gatewayv1.HTTPPathMatch{ + Type: helpers.GetPointer(gatewayv1.PathMatchPathPrefix), + Value: helpers.GetPointer("/app/users/history"), + }, + }, + }, + expectedPath: "/app/users", + }, + { + name: "multiple matches with all predefined path types Exact and PathPrefix " + + "returns empty path when there is no common prefix", + matches: []gatewayv1.HTTPRouteMatch{ + { + Path: &gatewayv1.HTTPPathMatch{ + Type: helpers.GetPointer(gatewayv1.PathMatchExact), + Value: helpers.GetPointer("/app/v1"), + }, + }, + { + Path: &gatewayv1.HTTPPathMatch{ + Type: helpers.GetPointer(gatewayv1.PathMatchPathPrefix), + Value: helpers.GetPointer("/coffee/latte/"), + }, + }, + { + Path: &gatewayv1.HTTPPathMatch{ + Type: helpers.GetPointer(gatewayv1.PathMatchExact), + Value: helpers.GetPointer("/coffee/espresso"), + }, + }, + { + Path: &gatewayv1.HTTPPathMatch{ + Type: helpers.GetPointer(gatewayv1.PathMatchPathPrefix), + Value: helpers.GetPointer("/tea/green"), + }, + }, + }, + expectedPath: "", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + g := NewWithT(t) + + result := deriveCookiePathForHTTPMatches(test.matches) + g.Expect(result).To(Equal(test.expectedPath)) + }) + } +} diff --git a/internal/controller/state/validation/validationfakes/fake_generic_validator.go b/internal/controller/state/validation/validationfakes/fake_generic_validator.go index 8c83a4ff9a..cd162c7359 100644 --- a/internal/controller/state/validation/validationfakes/fake_generic_validator.go +++ b/internal/controller/state/validation/validationfakes/fake_generic_validator.go @@ -52,6 +52,17 @@ type FakeGenericValidator struct { validateNginxSizeReturnsOnCall map[int]struct { result1 error } + ValidateNginxVariableNameStub func(string) error + validateNginxVariableNameMutex sync.RWMutex + validateNginxVariableNameArgsForCall []struct { + arg1 string + } + validateNginxVariableNameReturns struct { + result1 error + } + validateNginxVariableNameReturnsOnCall map[int]struct { + result1 error + } ValidateServiceNameStub func(string) error validateServiceNameMutex sync.RWMutex validateServiceNameArgsForCall []struct { @@ -311,6 +322,67 @@ func (fake *FakeGenericValidator) ValidateNginxSizeReturnsOnCall(i int, result1 }{result1} } +func (fake *FakeGenericValidator) ValidateNginxVariableName(arg1 string) error { + fake.validateNginxVariableNameMutex.Lock() + ret, specificReturn := fake.validateNginxVariableNameReturnsOnCall[len(fake.validateNginxVariableNameArgsForCall)] + fake.validateNginxVariableNameArgsForCall = append(fake.validateNginxVariableNameArgsForCall, struct { + arg1 string + }{arg1}) + stub := fake.ValidateNginxVariableNameStub + fakeReturns := fake.validateNginxVariableNameReturns + fake.recordInvocation("ValidateNginxVariableName", []interface{}{arg1}) + fake.validateNginxVariableNameMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeGenericValidator) ValidateNginxVariableNameCallCount() int { + fake.validateNginxVariableNameMutex.RLock() + defer fake.validateNginxVariableNameMutex.RUnlock() + return len(fake.validateNginxVariableNameArgsForCall) +} + +func (fake *FakeGenericValidator) ValidateNginxVariableNameCalls(stub func(string) error) { + fake.validateNginxVariableNameMutex.Lock() + defer fake.validateNginxVariableNameMutex.Unlock() + fake.ValidateNginxVariableNameStub = stub +} + +func (fake *FakeGenericValidator) ValidateNginxVariableNameArgsForCall(i int) string { + fake.validateNginxVariableNameMutex.RLock() + defer fake.validateNginxVariableNameMutex.RUnlock() + argsForCall := fake.validateNginxVariableNameArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeGenericValidator) ValidateNginxVariableNameReturns(result1 error) { + fake.validateNginxVariableNameMutex.Lock() + defer fake.validateNginxVariableNameMutex.Unlock() + fake.ValidateNginxVariableNameStub = nil + fake.validateNginxVariableNameReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeGenericValidator) ValidateNginxVariableNameReturnsOnCall(i int, result1 error) { + fake.validateNginxVariableNameMutex.Lock() + defer fake.validateNginxVariableNameMutex.Unlock() + fake.ValidateNginxVariableNameStub = nil + if fake.validateNginxVariableNameReturnsOnCall == nil { + fake.validateNginxVariableNameReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.validateNginxVariableNameReturnsOnCall[i] = struct { + result1 error + }{result1} +} + func (fake *FakeGenericValidator) ValidateServiceName(arg1 string) error { fake.validateServiceNameMutex.Lock() ret, specificReturn := fake.validateServiceNameReturnsOnCall[len(fake.validateServiceNameArgsForCall)] diff --git a/internal/controller/state/validation/validationfakes/fake_httpfields_validator.go b/internal/controller/state/validation/validationfakes/fake_httpfields_validator.go index cd5ff2d8f7..f82c40610d 100644 --- a/internal/controller/state/validation/validationfakes/fake_httpfields_validator.go +++ b/internal/controller/state/validation/validationfakes/fake_httpfields_validator.go @@ -18,6 +18,19 @@ type FakeHTTPFieldsValidator struct { skipValidationReturnsOnCall map[int]struct { result1 bool } + ValidateDurationStub func(string) (string, error) + validateDurationMutex sync.RWMutex + validateDurationArgsForCall []struct { + arg1 string + } + validateDurationReturns struct { + result1 string + result2 error + } + validateDurationReturnsOnCall map[int]struct { + result1 string + result2 error + } ValidateFilterHeaderNameStub func(string) error validateFilterHeaderNameMutex sync.RWMutex validateFilterHeaderNameArgsForCall []struct { @@ -235,6 +248,70 @@ func (fake *FakeHTTPFieldsValidator) SkipValidationReturnsOnCall(i int, result1 }{result1} } +func (fake *FakeHTTPFieldsValidator) ValidateDuration(arg1 string) (string, error) { + fake.validateDurationMutex.Lock() + ret, specificReturn := fake.validateDurationReturnsOnCall[len(fake.validateDurationArgsForCall)] + fake.validateDurationArgsForCall = append(fake.validateDurationArgsForCall, struct { + arg1 string + }{arg1}) + stub := fake.ValidateDurationStub + fakeReturns := fake.validateDurationReturns + fake.recordInvocation("ValidateDuration", []interface{}{arg1}) + fake.validateDurationMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeHTTPFieldsValidator) ValidateDurationCallCount() int { + fake.validateDurationMutex.RLock() + defer fake.validateDurationMutex.RUnlock() + return len(fake.validateDurationArgsForCall) +} + +func (fake *FakeHTTPFieldsValidator) ValidateDurationCalls(stub func(string) (string, error)) { + fake.validateDurationMutex.Lock() + defer fake.validateDurationMutex.Unlock() + fake.ValidateDurationStub = stub +} + +func (fake *FakeHTTPFieldsValidator) ValidateDurationArgsForCall(i int) string { + fake.validateDurationMutex.RLock() + defer fake.validateDurationMutex.RUnlock() + argsForCall := fake.validateDurationArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeHTTPFieldsValidator) ValidateDurationReturns(result1 string, result2 error) { + fake.validateDurationMutex.Lock() + defer fake.validateDurationMutex.Unlock() + fake.ValidateDurationStub = nil + fake.validateDurationReturns = struct { + result1 string + result2 error + }{result1, result2} +} + +func (fake *FakeHTTPFieldsValidator) ValidateDurationReturnsOnCall(i int, result1 string, result2 error) { + fake.validateDurationMutex.Lock() + defer fake.validateDurationMutex.Unlock() + fake.ValidateDurationStub = nil + if fake.validateDurationReturnsOnCall == nil { + fake.validateDurationReturnsOnCall = make(map[int]struct { + result1 string + result2 error + }) + } + fake.validateDurationReturnsOnCall[i] = struct { + result1 string + result2 error + }{result1, result2} +} + func (fake *FakeHTTPFieldsValidator) ValidateFilterHeaderName(arg1 string) error { fake.validateFilterHeaderNameMutex.Lock() ret, specificReturn := fake.validateFilterHeaderNameReturnsOnCall[len(fake.validateFilterHeaderNameArgsForCall)] diff --git a/internal/controller/state/validation/validator.go b/internal/controller/state/validation/validator.go index 10dc1fe8c3..44f2f268e9 100644 --- a/internal/controller/state/validation/validator.go +++ b/internal/controller/state/validation/validator.go @@ -37,6 +37,7 @@ type HTTPFieldsValidator interface { ValidateFilterHeaderName(name string) error ValidateFilterHeaderValue(value string) error ValidatePath(path string) error + ValidateDuration(duration string) (string, error) } // GenericValidator validates any generic values from NGF API resources from the perspective of a data-plane. @@ -49,6 +50,7 @@ type GenericValidator interface { ValidateNginxDuration(duration string) error ValidateNginxSize(size string) error ValidateEndpoint(endpoint string) error + ValidateNginxVariableName(name string) error } // PolicyValidator validates an NGF Policy. @@ -82,3 +84,4 @@ func (SkipValidator) ValidateHostname(string) error { return n func (SkipValidator) ValidateFilterHeaderName(string) error { return nil } func (SkipValidator) ValidateFilterHeaderValue(string) error { return nil } func (SkipValidator) ValidatePath(string) error { return nil } +func (SkipValidator) ValidateDuration(string) (string, error) { return "", nil } diff --git a/tests/cel/common.go b/tests/cel/common.go index 6208c1a75f..067888a2d3 100644 --- a/tests/cel/common.go +++ b/tests/cel/common.go @@ -56,9 +56,11 @@ const ( // UpstreamSettingsPolicy validation errors. const ( - expectedTargetRefKindServiceError = `TargetRefs Kind must be: Service` - expectedTargetRefGroupCoreError = `TargetRefs Group must be core` - expectedTargetRefNameUniqueError = `TargetRef Name must be unique` + expectedTargetRefKindServiceError = `TargetRefs Kind must be: Service` + expectedTargetRefGroupCoreError = `TargetRefs Group must be core` + expectedTargetRefNameUniqueError = `TargetRef Name must be unique` + expectedHashKeyLoadBalancingTypeError = `hashMethodKey is required when loadBalancingMethod ` + + `is 'hash' or 'hash consistent'` ) // SnippetsFilter validation errors. diff --git a/tests/cel/upstreamsettingspolicy_test.go b/tests/cel/upstreamsettingspolicy_test.go index 35a4fbd364..5b0267350f 100644 --- a/tests/cel/upstreamsettingspolicy_test.go +++ b/tests/cel/upstreamsettingspolicy_test.go @@ -7,6 +7,7 @@ import ( gatewayv1 "sigs.k8s.io/gateway-api/apis/v1" ngfAPIv1alpha1 "github.com/nginx/nginx-gateway-fabric/v2/apis/v1alpha1" + "github.com/nginx/nginx-gateway-fabric/v2/internal/framework/helpers" ) func TestUpstreamSettingsPolicyTargetRefKind(t *testing.T) { @@ -372,3 +373,86 @@ func TestUpstreamSettingsPolicyTargetRefNameUniqueness(t *testing.T) { }) } } + +func TestUpstreamSettingsPolicy_LoadBalancing(t *testing.T) { + t.Parallel() + k8sClient := getKubernetesClient(t) + + tests := []struct { + spec ngfAPIv1alpha1.UpstreamSettingsPolicySpec + name string + wantErrors []string + }{ + { + name: "when load balancing method is hash, hash key is required, error expected", + spec: ngfAPIv1alpha1.UpstreamSettingsPolicySpec{ + TargetRefs: []gatewayv1.LocalPolicyTargetReference{ + { + Kind: serviceKind, + Group: coreGroup, + }, + }, + LoadBalancingMethod: helpers.GetPointer(ngfAPIv1alpha1.LoadBalancingTypeHash), + }, + wantErrors: []string{expectedHashKeyLoadBalancingTypeError}, + }, + { + name: "when load balancing method is hash consistent, hash key is required, error expected", + spec: ngfAPIv1alpha1.UpstreamSettingsPolicySpec{ + TargetRefs: []gatewayv1.LocalPolicyTargetReference{ + { + Kind: serviceKind, + Group: coreGroup, + }, + }, + LoadBalancingMethod: helpers.GetPointer(ngfAPIv1alpha1.LoadBalancingTypeHashConsistent), + }, + wantErrors: []string{expectedHashKeyLoadBalancingTypeError}, + }, + { + name: "specify load balancing method as hash and set the hash key, no error expected", + spec: ngfAPIv1alpha1.UpstreamSettingsPolicySpec{ + TargetRefs: []gatewayv1.LocalPolicyTargetReference{ + { + Kind: serviceKind, + Group: coreGroup, + }, + }, + LoadBalancingMethod: helpers.GetPointer(ngfAPIv1alpha1.LoadBalancingTypeHash), + HashMethodKey: helpers.GetPointer(ngfAPIv1alpha1.HashMethodKey("$upstream_connect_time")), + }, + }, + { + name: "specify load balancing method as hash consistent and set the hash key, no error expected", + spec: ngfAPIv1alpha1.UpstreamSettingsPolicySpec{ + TargetRefs: []gatewayv1.LocalPolicyTargetReference{ + { + Kind: serviceKind, + Group: coreGroup, + }, + }, + LoadBalancingMethod: helpers.GetPointer(ngfAPIv1alpha1.LoadBalancingTypeHashConsistent), + HashMethodKey: helpers.GetPointer(ngfAPIv1alpha1.HashMethodKey("$upstream_bytes_sent")), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + for i := range tt.spec.TargetRefs { + tt.spec.TargetRefs[i].Name = gatewayv1.ObjectName(uniqueResourceName(testTargetRefName)) + } + + upstreamSettingsPolicy := &ngfAPIv1alpha1.UpstreamSettingsPolicy{ + ObjectMeta: controllerruntime.ObjectMeta{ + Name: uniqueResourceName(testResourceName), + Namespace: defaultNamespace, + }, + Spec: tt.spec, + } + validateCrd(t, tt.wantErrors, upstreamSettingsPolicy, k8sClient) + }) + } +} diff --git a/tests/framework/logging.go b/tests/framework/logging.go index ddc02672a3..306fc4b78d 100644 --- a/tests/framework/logging.go +++ b/tests/framework/logging.go @@ -13,7 +13,9 @@ func WithLoggingDisabled() Option { } func LogOptions(opts ...Option) *Options { - options := &Options{logEnabled: true} + options := &Options{ + logEnabled: true, + } for _, opt := range opts { opt(options) } diff --git a/tests/framework/ngf.go b/tests/framework/ngf.go index 0bb98d1956..e51e32dd49 100644 --- a/tests/framework/ngf.go +++ b/tests/framework/ngf.go @@ -39,12 +39,12 @@ type InstallationConfig struct { // InstallGatewayAPI installs the specified version of the Gateway API resources. func InstallGatewayAPI(apiVersion string) ([]byte, error) { - apiPath := fmt.Sprintf("%s/v%s/standard-install.yaml", gwInstallBasePath, apiVersion) - GinkgoWriter.Printf("Installing Gateway API version %q at API path %q\n", apiVersion, apiPath) + apiPath := fmt.Sprintf("%s/v%s/experimental-install.yaml", gwInstallBasePath, apiVersion) + GinkgoWriter.Printf("Installing Gateway API CRDs from experimental channel %q", apiVersion, apiPath) cmd := exec.CommandContext( context.Background(), - "kubectl", "apply", "-f", apiPath, + "kubectl", "apply", "--server-side", "--force-conflicts", "-f", apiPath, ) output, err := cmd.CombinedOutput() if err != nil { @@ -59,8 +59,8 @@ func InstallGatewayAPI(apiVersion string) ([]byte, error) { // UninstallGatewayAPI uninstalls the specified version of the Gateway API resources. func UninstallGatewayAPI(apiVersion string) ([]byte, error) { - apiPath := fmt.Sprintf("%s/v%s/standard-install.yaml", gwInstallBasePath, apiVersion) - GinkgoWriter.Printf("Uninstalling Gateway API version %q at API path %q\n", apiVersion, apiPath) + apiPath := fmt.Sprintf("%s/v%s/experimental-install.yaml", gwInstallBasePath, apiVersion) + GinkgoWriter.Printf("Uninstalling Gateway API CRDs from experimental channel for version %q\n", apiVersion) output, err := exec.CommandContext(context.Background(), "kubectl", "delete", "-f", apiPath).CombinedOutput() if err != nil && !strings.Contains(string(output), "not found") { @@ -84,6 +84,7 @@ func InstallNGF(cfg InstallationConfig, extraArgs ...string) ([]byte, error) { "--namespace", cfg.Namespace, "--wait", "--set", "nginxGateway.snippetsFilters.enable=true", + "--set", "nginxGateway.gwAPIExperimentalFeatures.enable=true", } if cfg.ChartVersion != "" { args = append(args, "--version", cfg.ChartVersion) diff --git a/tests/framework/prometheus.go b/tests/framework/prometheus.go index 37d72e1c49..f175dec899 100644 --- a/tests/framework/prometheus.go +++ b/tests/framework/prometheus.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "log/slog" + "net/http" "os" "os/exec" "time" @@ -542,7 +543,12 @@ func CreateResponseChecker(url, address string, requestTimeout time.Duration, op } return func() error { - status, _, err := Get(url, address, requestTimeout, nil, nil, opts...) + request := Request{ + URL: url, + Address: address, + Timeout: requestTimeout, + } + resp, err := Get(request, opts...) if err != nil { badReqErr := fmt.Errorf("bad response: %w", err) if options.logEnabled { @@ -552,8 +558,8 @@ func CreateResponseChecker(url, address string, requestTimeout time.Duration, op return badReqErr } - if status != 200 { - statusErr := fmt.Errorf("unexpected status code: %d", status) + if resp.StatusCode != http.StatusOK { + statusErr := fmt.Errorf("unexpected status code: %d", resp.StatusCode) if options.logEnabled { GinkgoWriter.Printf("ERROR during creating response checker: %v\n", statusErr) } diff --git a/tests/framework/request.go b/tests/framework/request.go index bd8663b8a3..27308e8e23 100644 --- a/tests/framework/request.go +++ b/tests/framework/request.go @@ -15,18 +15,28 @@ import ( . "github.com/onsi/ginkgo/v2" ) +type Response struct { + Headers http.Header + Body string + StatusCode int +} + +type Request struct { + Body io.Reader + Headers map[string]string + QueryParams map[string]string + URL string + Address string + Timeout time.Duration +} + // Get sends a GET request to the specified url. // It resolves to the specified address instead of using DNS. -// The status and body of the response is returned, or an error. -func Get( - url, address string, - timeout time.Duration, - headers, queryParams map[string]string, - opts ...Option, -) (int, string, error) { +// It returns the response body, headers, and status code. +func Get(request Request, opts ...Option) (Response, error) { options := LogOptions(opts...) - resp, err := makeRequest(http.MethodGet, url, address, nil, timeout, headers, queryParams, opts...) + resp, err := makeRequest(http.MethodGet, request, opts...) if err != nil { if options.logEnabled { GinkgoWriter.Printf( @@ -35,7 +45,7 @@ func Get( ) } - return 0, "", err + return Response{StatusCode: 0}, err } defer resp.Body.Close() @@ -43,24 +53,23 @@ func Get( _, err = body.ReadFrom(resp.Body) if err != nil { GinkgoWriter.Printf("ERROR in Body content: %v returning body: ''\n", err) - return resp.StatusCode, "", err + return Response{StatusCode: resp.StatusCode}, err } if options.logEnabled { GinkgoWriter.Printf("Successfully received response and parsed body: %s\n", body.String()) } - return resp.StatusCode, body.String(), nil + return Response{ + Body: body.String(), + Headers: resp.Header, + StatusCode: resp.StatusCode, + }, nil } // Post sends a POST request to the specified url with the body as the payload. // It resolves to the specified address instead of using DNS. -func Post( - url, address string, - body io.Reader, - timeout time.Duration, - headers, queryParams map[string]string, -) (*http.Response, error) { - response, err := makeRequest(http.MethodPost, url, address, body, timeout, headers, queryParams) +func Post(request Request) (*http.Response, error) { + response, err := makeRequest(http.MethodPost, request) if err != nil { GinkgoWriter.Printf("ERROR occurred during getting response, error: %s\n", err) } @@ -68,13 +77,7 @@ func Post( return response, err } -func makeRequest( - method, url, address string, - body io.Reader, - timeout time.Duration, - headers, queryParams map[string]string, - opts ...Option, -) (*http.Response, error) { +func makeRequest(method string, request Request, opts ...Option) (*http.Response, error) { dialer := &net.Dialer{} transport, ok := http.DefaultTransport.(*http.Transport) @@ -90,10 +93,10 @@ func makeRequest( ) (net.Conn, error) { split := strings.Split(addr, ":") port := split[len(split)-1] - return dialer.DialContext(ctx, network, fmt.Sprintf("%s:%s", address, port)) + return dialer.DialContext(ctx, network, fmt.Sprintf("%s:%s", request.Address, port)) } - ctx, cancel := context.WithTimeout(context.Background(), timeout) + ctx, cancel := context.WithTimeout(context.Background(), request.Timeout) defer cancel() options := LogOptions(opts...) @@ -101,39 +104,41 @@ func makeRequest( requestDetails := fmt.Sprintf( "Method: %s, URL: %s, Address: %s, Headers: %v, QueryParams: %v\n", strings.ToUpper(method), - url, - address, - headers, - queryParams, + request.URL, + request.Address, + request.Headers, + request.QueryParams, ) GinkgoWriter.Printf("Sending request: %s", requestDetails) } - req, err := http.NewRequestWithContext(ctx, method, url, body) + req, err := http.NewRequestWithContext(ctx, method, request.URL, request.Body) if err != nil { return nil, err } - for key, value := range headers { + for key, value := range request.Headers { req.Header.Add(key, value) } - if queryParams != nil { + if request.QueryParams != nil { q := req.URL.Query() - for key, value := range queryParams { + for key, value := range request.QueryParams { q.Add(key, value) } req.URL.RawQuery = q.Encode() } var resp *http.Response - if strings.HasPrefix(url, "https") { + if strings.HasPrefix(request.URL, "https") { // similar to how in our examples with https requests we run our curl command // we turn off verification of the certificate, we do the same here customTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} //nolint:gosec // for https test traffic } - client := &http.Client{Transport: customTransport} + client := &http.Client{ + Transport: customTransport, + } resp, err = client.Do(req) if err != nil { return nil, err diff --git a/tests/suite/advanced_routing_test.go b/tests/suite/advanced_routing_test.go index 944a9f0832..9be8ea7af1 100644 --- a/tests/suite/advanced_routing_test.go +++ b/tests/suite/advanced_routing_test.go @@ -120,19 +120,28 @@ func expectRequestToRespondFromExpectedServer( headers, queryParams map[string]string, ) error { GinkgoWriter.Printf("Expecting request to respond from the server %q\n", expServerName) - status, body, err := framework.Get(appURL, address, timeoutConfig.RequestTimeout, headers, queryParams) + + request := framework.Request{ + URL: appURL, + Address: address, + Timeout: timeoutConfig.RequestTimeout, + Headers: headers, + QueryParams: queryParams, + } + + resp, err := framework.Get(request) if err != nil { return err } - if status != http.StatusOK { + if resp.StatusCode != http.StatusOK { statusErr := errors.New("http status was not 200") GinkgoWriter.Printf("ERROR: %v\n", statusErr) return statusErr } - actualServerName, err := extractServerName(body) + actualServerName, err := extractServerName(resp.Body) if err != nil { GinkgoWriter.Printf("ERROR extracting server name from response body: %v\n", err) diff --git a/tests/suite/client_settings_test.go b/tests/suite/client_settings_test.go index 2422364cf6..5b3e74612b 100644 --- a/tests/suite/client_settings_test.go +++ b/tests/suite/client_settings_test.go @@ -254,7 +254,13 @@ var _ = Describe("ClientSettingsPolicy", Ordered, Label("functional", "cspolicy" _, err := rand.Read(payload) Expect(err).ToNot(HaveOccurred()) - resp, err := framework.Post(url, address, bytes.NewReader(payload), timeoutConfig.RequestTimeout, nil, nil) + request := framework.Request{ + URL: url, + Address: address, + Body: bytes.NewReader(payload), + Timeout: timeoutConfig.RequestTimeout, + } + resp, err := framework.Post(request) Expect(err).ToNot(HaveOccurred()) Expect(resp).To(HaveHTTPStatus(expStatus)) diff --git a/tests/suite/graceful_recovery_test.go b/tests/suite/graceful_recovery_test.go index a5d3348958..0b013f7cba 100644 --- a/tests/suite/graceful_recovery_test.go +++ b/tests/suite/graceful_recovery_test.go @@ -555,14 +555,19 @@ var _ = Describe("Graceful Recovery test", Ordered, FlakeAttempts(2), Label("gra }) func expectRequestToSucceed(appURL, address string, responseBodyMessage string) error { - status, body, err := framework.Get(appURL, address, timeoutConfig.RequestTimeout, nil, nil) + request := framework.Request{ + URL: appURL, + Address: address, + Timeout: timeoutConfig.RequestTimeout, + } + resp, err := framework.Get(request) - if status != http.StatusOK { - return fmt.Errorf("http status was not 200, got %d: %w", status, err) + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("http status was not 200, got %d: %w", resp.StatusCode, err) } - if !strings.Contains(body, responseBodyMessage) { - return fmt.Errorf("expected response body to contain correct body message, got: %s", body) + if !strings.Contains(resp.Body, responseBodyMessage) { + return fmt.Errorf("expected response body to contain correct body message, got: %s", resp.Body) } return err @@ -577,13 +582,18 @@ func expectRequestToSucceed(appURL, address string, responseBodyMessage string) // We only want an error returned from this particular function if it does not appear that NGINX has // stopped serving traffic. func expectRequestToFail(appURL, address string) error { - status, body, err := framework.Get(appURL, address, timeoutConfig.RequestTimeout, nil, nil) - if status != 0 { + request := framework.Request{ + URL: appURL, + Address: address, + Timeout: timeoutConfig.RequestTimeout, + } + resp, err := framework.Get(request) + if resp.StatusCode != 0 { return errors.New("expected http status to be 0") } - if body != "" { - return fmt.Errorf("expected response body to be empty, instead received: %s", body) + if resp.Body != "" { + return fmt.Errorf("expected response body to be empty, instead received: %s", resp.Body) } if err == nil { diff --git a/tests/suite/manifests/session-persistence/cafe.yaml b/tests/suite/manifests/session-persistence/cafe.yaml new file mode 100644 index 0000000000..9c1a83548a --- /dev/null +++ b/tests/suite/manifests/session-persistence/cafe.yaml @@ -0,0 +1,65 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: coffee +spec: + replicas: 3 + selector: + matchLabels: + app: coffee + template: + metadata: + labels: + app: coffee + spec: + containers: + - name: coffee + image: nginxdemos/nginx-hello:plain-text + ports: + - containerPort: 8080 +--- +apiVersion: v1 +kind: Service +metadata: + name: coffee +spec: + ports: + - port: 80 + targetPort: 8080 + protocol: TCP + name: http + selector: + app: coffee +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: tea +spec: + replicas: 3 + selector: + matchLabels: + app: tea + template: + metadata: + labels: + app: tea + spec: + containers: + - name: tea + image: nginxdemos/nginx-hello:plain-text + ports: + - containerPort: 8080 +--- +apiVersion: v1 +kind: Service +metadata: + name: tea +spec: + ports: + - port: 80 + targetPort: 8080 + protocol: TCP + name: http + selector: + app: tea diff --git a/tests/suite/manifests/session-persistence/gateway.yaml b/tests/suite/manifests/session-persistence/gateway.yaml new file mode 100644 index 0000000000..e6507f613b --- /dev/null +++ b/tests/suite/manifests/session-persistence/gateway.yaml @@ -0,0 +1,11 @@ +apiVersion: gateway.networking.k8s.io/v1 +kind: Gateway +metadata: + name: gateway +spec: + gatewayClassName: nginx + listeners: + - name: http + port: 80 + protocol: HTTP + hostname: "*.example.com" diff --git a/tests/suite/manifests/session-persistence/grpc-backends.yaml b/tests/suite/manifests/session-persistence/grpc-backends.yaml new file mode 100644 index 0000000000..fc5011f92b --- /dev/null +++ b/tests/suite/manifests/session-persistence/grpc-backends.yaml @@ -0,0 +1,44 @@ +apiVersion: v1 +kind: Service +metadata: + name: grpc-backend +spec: + selector: + app: grpc-backend + ports: + - protocol: TCP + port: 8080 + targetPort: 50051 +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: grpc-backend + labels: + app: grpc-backend +spec: + replicas: 3 + selector: + matchLabels: + app: grpc-backend + template: + metadata: + labels: + app: grpc-backend + spec: + containers: + - name: grpc-backend + image: ghcr.io/nginx/kic-test-grpc-server:0.2.6 + ports: + - containerPort: 50051 + env: + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + readinessProbe: + tcpSocket: + port: 50051 + resources: + requests: + cpu: 10m diff --git a/tests/suite/manifests/session-persistence/route-invalid-sp-config.yaml b/tests/suite/manifests/session-persistence/route-invalid-sp-config.yaml new file mode 100644 index 0000000000..0ff8d6481b --- /dev/null +++ b/tests/suite/manifests/session-persistence/route-invalid-sp-config.yaml @@ -0,0 +1,49 @@ +apiVersion: gateway.networking.k8s.io/v1 +kind: HTTPRoute +metadata: + name: route-invalid-sp +spec: + parentRefs: + - name: gateway + sectionName: http + hostnames: + - "cafe.example.com" + rules: + - matches: + - path: + type: Exact + value: / + backendRefs: + - name: tea + port: 80 + sessionPersistence: + sessionName: invalid-cookie + type: Header + idleTimeout: 30m + absoluteTimeout: 10000h # duration too long for NGINX + cookieConfig: + lifetimeType: Session +--- +apiVersion: gateway.networking.k8s.io/v1 +kind: GRPCRoute +metadata: + name: grpc-route-invalid-sp +spec: + parentRefs: + - name: gateway + sectionName: http + rules: + - matches: + - method: + service: helloworld.Greeter + method: SayHello + backendRefs: + - name: grpc-backend + port: 8080 + sessionPersistence: + sessionName: invalid-cookie + type: Header + idleTimeout: 30m + absoluteTimeout: 10000h # duration too long for NGINX + cookieConfig: + lifetimeType: Session diff --git a/tests/suite/manifests/session-persistence/routes-oss.yaml b/tests/suite/manifests/session-persistence/routes-oss.yaml new file mode 100644 index 0000000000..361542cfe1 --- /dev/null +++ b/tests/suite/manifests/session-persistence/routes-oss.yaml @@ -0,0 +1,35 @@ +apiVersion: gateway.networking.k8s.io/v1 +kind: HTTPRoute +metadata: + name: coffee +spec: + parentRefs: + - name: gateway + sectionName: http + hostnames: + - "cafe.example.com" + rules: + - matches: + - path: + type: PathPrefix + value: /coffee + backendRefs: + - name: coffee + port: 80 +--- +apiVersion: gateway.networking.k8s.io/v1 +kind: GRPCRoute +metadata: + name: grpc-route +spec: + parentRefs: + - name: gateway + sectionName: http + rules: + - matches: + - method: + service: helloworld.Greeter + method: SayHello + backendRefs: + - name: grpc-backend + port: 8080 diff --git a/tests/suite/manifests/session-persistence/routes-plus.yaml b/tests/suite/manifests/session-persistence/routes-plus.yaml new file mode 100644 index 0000000000..d1c370d02f --- /dev/null +++ b/tests/suite/manifests/session-persistence/routes-plus.yaml @@ -0,0 +1,81 @@ +# GRPC Route with method match and unnamed session persistence configuration +apiVersion: gateway.networking.k8s.io/v1 +kind: GRPCRoute +metadata: + name: grpc-route +spec: + parentRefs: + - name: gateway + sectionName: http + rules: + - matches: + - method: + service: helloworld.Greeter + method: SayHello + backendRefs: + - name: grpc-backend + port: 8080 + sessionPersistence: + type: Cookie + absoluteTimeout: 24h + cookieConfig: + lifetimeType: Permanent +--- +# Route with multiple path matches(common prefix /shop) and unnamed session persistence configuration +apiVersion: gateway.networking.k8s.io/v1 +kind: HTTPRoute +metadata: + name: coffee +spec: + parentRefs: + - name: gateway + hostnames: + - cafe.example.com + rules: + - matches: + - path: + type: PathPrefix + value: /coffee + - path: + type: PathPrefix + value: /coffee/snacks + - path: + type: PathPrefix + value: /coffee/orders/checkout + - path: + type: PathPrefix + value: /coffee/desserts + backendRefs: + - name: coffee + port: 80 + sessionPersistence: + type: Cookie + absoluteTimeout: 48h + cookieConfig: + lifetimeType: Permanent +--- +# Route with regex path match and named session persistence configuration +apiVersion: gateway.networking.k8s.io/v1 +kind: HTTPRoute +metadata: + name: tea +spec: + parentRefs: + - name: gateway + sectionName: http + hostnames: + - "cafe.example.com" + rules: + - matches: + - path: + type: RegularExpression + value: /tea/[a-z]+/flavors + backendRefs: + - name: tea + port: 80 + sessionPersistence: + sessionName: tea-cookie + type: Cookie + absoluteTimeout: 48h + cookieConfig: + lifetimeType: Session diff --git a/tests/suite/manifests/session-persistence/usp.yaml b/tests/suite/manifests/session-persistence/usp.yaml new file mode 100644 index 0000000000..4e7b3a19ae --- /dev/null +++ b/tests/suite/manifests/session-persistence/usp.yaml @@ -0,0 +1,13 @@ +apiVersion: gateway.nginx.org/v1alpha1 +kind: UpstreamSettingsPolicy +metadata: + name: usp-ip-hash +spec: + targetRefs: + - group: core + kind: Service + name: coffee + - group: core + kind: Service + name: grpc-backend + loadBalancingMethod: "ip_hash" diff --git a/tests/suite/sample_test.go b/tests/suite/sample_test.go index f133cc6acd..c5f13ebfff 100644 --- a/tests/suite/sample_test.go +++ b/tests/suite/sample_test.go @@ -64,16 +64,21 @@ var _ = Describe("Basic test example", Label("functional"), func() { Eventually( func() error { - status, body, err := framework.Get(url, address, timeoutConfig.RequestTimeout, nil, nil) + request := framework.Request{ + URL: url, + Address: address, + Timeout: timeoutConfig.RequestTimeout, + } + resp, err := framework.Get(request) if err != nil { return err } - if status != http.StatusOK { - return fmt.Errorf("status not 200; got %d", status) + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("status not 200; got %d", resp.StatusCode) } expBody := "URI: /hello" - if !strings.Contains(body, expBody) { - return fmt.Errorf("bad body: got %s; expected %s", body, expBody) + if !strings.Contains(resp.Body, expBody) { + return fmt.Errorf("bad body: got %s; expected %s", resp.Body, expBody) } return nil }). diff --git a/tests/suite/session_persistence_test.go b/tests/suite/session_persistence_test.go new file mode 100644 index 0000000000..4e0433b1d6 --- /dev/null +++ b/tests/suite/session_persistence_test.go @@ -0,0 +1,608 @@ +package main + +import ( + "context" + "fmt" + "net/http" + "strings" + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + core "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/wait" + "sigs.k8s.io/controller-runtime/pkg/client" + gatewayv1 "sigs.k8s.io/gateway-api/apis/v1" + + "github.com/nginx/nginx-gateway-fabric/v2/tests/framework" +) + +var invalidSPErrMsgs = "[spec.rules[0].sessionPersistence.type: Unsupported value: \"Header\": " + + "supported values: \"Cookie\", spec.rules[0].sessionPersistence.idleTimeout: " + + "Forbidden: IdleTimeout, spec.rules[0].sessionPersistence.absoluteTimeout: " + + "Invalid value: \"10000h\": duration is too large for NGINX format (exceeds 9999h), " + + "spec.rules[0].sessionPersistence: Invalid value: \"spec.rules[0].sessionPersistence\":" + + " session persistence is ignored because there are errors in the configuration]" + +var _ = Describe("SessionPersistence OSS", Ordered, Label("functional", "session-persistence-oss"), func() { + var ( + files = []string{ + "session-persistence/cafe.yaml", + "session-persistence/grpc-backends.yaml", + "session-persistence/gateway.yaml", + "session-persistence/routes-oss.yaml", + } + + namespace = "session-persistence-oss" + gatewayName = "gateway" + + nginxPodName string + ) + + BeforeAll(func() { + ns := &core.Namespace{ + ObjectMeta: metav1.ObjectMeta{ + Name: namespace, + }, + } + + Expect(resourceManager.Apply([]client.Object{ns})).To(Succeed()) + Expect(resourceManager.ApplyFromFiles(files, namespace)).To(Succeed()) + Expect(resourceManager.WaitForAppsToBeReady(namespace)).To(Succeed()) + + nginxPodNames, err := resourceManager.GetReadyNginxPodNames( + namespace, + timeoutConfig.GetStatusTimeout, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(nginxPodNames).To(HaveLen(1)) + + nginxPodName = nginxPodNames[0] + + setUpPortForward(nginxPodName, namespace) + }) + + AfterAll(func() { + framework.AddNginxLogsAndEventsToReport(resourceManager, namespace) + cleanUpPortForward() + + Expect(resourceManager.DeleteNamespace(namespace)).To(Succeed()) + }) + + When("LoadBalancingMethod `ip-hash` is used for session affinity", func() { + uspFiles := []string{ + "session-persistence/usp.yaml", + } + + BeforeAll(func() { + Expect(resourceManager.ApplyFromFiles(uspFiles, namespace)).To(Succeed()) + }) + + AfterAll(func() { + Expect(resourceManager.DeleteFromFiles(uspFiles, namespace)).To(Succeed()) + }) + + Specify("upstreamSettingsPolicies are accepted", func() { + usPolicy := "usp-ip-hash" + + uspolicyNsName := types.NamespacedName{Name: usPolicy, Namespace: namespace} + + err := waitForUSPolicyStatus( + uspolicyNsName, + gatewayName, + metav1.ConditionTrue, + gatewayv1.PolicyReasonAccepted, + ) + Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("%s was not accepted", usPolicy)) + }) + + Context("verify working traffic", func() { + It("should return 200 response for HTTPRoute `coffee` from the same backend", func() { + port := 80 + if portFwdPort != 0 { + port = portFwdPort + } + baseCoffeeURL := fmt.Sprintf("http://cafe.example.com:%d%s", port, "/coffee") + + Eventually( + func() error { + return expectRequestToSucceedAndRespondFromTheSameBackend(baseCoffeeURL, address, "URI: /coffee", 11) + }). + WithTimeout(timeoutConfig.RequestTimeout). + WithPolling(500 * time.Millisecond). + Should(Succeed()) + }) + }) + + Context("nginx directives", func() { + var conf *framework.Payload + + BeforeAll(func() { + var err error + conf, err = resourceManager.GetNginxConfig(nginxPodName, namespace, "") + Expect(err).ToNot(HaveOccurred()) + }) + + DescribeTable("are set properly for", + func(expCfgs []framework.ExpectedNginxField) { + for _, expCfg := range expCfgs { + Expect(framework.ValidateNginxFieldExists(conf, expCfg)).To(Succeed()) + } + }, + Entry("HTTP upstream", []framework.ExpectedNginxField{ + { + Directive: "upstream", + Value: "session-persistence-oss_coffee_80", + File: "http.conf", + }, + { + Directive: "ip_hash", + Upstream: "session-persistence-oss_coffee_80", + File: "http.conf", + }, + }), + Entry("GRPC upstream", []framework.ExpectedNginxField{ + { + Directive: "upstream", + Value: "session-persistence-oss_grpc-backend_8080", + File: "http.conf", + }, + { + Directive: "ip_hash", + Upstream: "session-persistence-oss_grpc-backend_8080", + File: "http.conf", + }, + }), + ) + }) + }) +}) + +var _ = Describe("SessionPersistence Plus", Ordered, Label("functional", "session-persistence-plus"), func() { + var ( + files = []string{ + "session-persistence/cafe.yaml", + "session-persistence/grpc-backends.yaml", + "session-persistence/gateway.yaml", + "session-persistence/routes-plus.yaml", + } + + namespace = "session-persistence-plus" + + nginxPodName string + ) + + BeforeAll(func() { + ns := &core.Namespace{ + ObjectMeta: metav1.ObjectMeta{ + Name: namespace, + }, + } + + Expect(resourceManager.Apply([]client.Object{ns})).To(Succeed()) + Expect(resourceManager.ApplyFromFiles(files, namespace)).To(Succeed()) + Expect(resourceManager.WaitForAppsToBeReady(namespace)).To(Succeed()) + + nginxPodNames, err := resourceManager.GetReadyNginxPodNames( + namespace, + timeoutConfig.GetStatusTimeout, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(nginxPodNames).To(HaveLen(1)) + + nginxPodName = nginxPodNames[0] + setUpPortForward(nginxPodName, namespace) + }) + + AfterAll(func() { + framework.AddNginxLogsAndEventsToReport(resourceManager, namespace) + cleanUpPortForward() + + Expect(resourceManager.DeleteNamespace(namespace)).To(Succeed()) + }) + + When("sticky cookies are used for session persistence in NGINX Plus", func() { + var baseCoffeeURL, baseTeaURL string + + BeforeAll(func() { + port := 80 + if portFwdPort != 0 { + port = portFwdPort + } + + baseCoffeeURL = fmt.Sprintf("http://cafe.example.com:%d%s", port, "/coffee") + baseTeaURL = fmt.Sprintf("http://cafe.example.com:%d%s", port, "/tea/location/flavors") + }) + + Context("verify working traffic", func() { + It("should return 200 responses from the same backend for HTTPRoutes `coffee` and `tea`", func() { + if !*plusEnabled { + Skip("Skipping Session Persistence Plus tests on NGINX OSS deployment") + } + Eventually( + func() error { + return expectRequestToSucceedAndReuseCookie(baseCoffeeURL, address, "URI: /coffee", 11) + }). + WithTimeout(timeoutConfig.RequestTimeout). + WithPolling(500 * time.Millisecond). + Should(Succeed()) + + Eventually( + func() error { + return expectRequestToSucceedAndReuseCookie(baseTeaURL, address, "URI: /tea/location/flavors", 11) + }). + WithTimeout(timeoutConfig.RequestTimeout). + WithPolling(500 * time.Millisecond). + Should(Succeed()) + }) + }) + + Context("nginx directives", func() { + var conf *framework.Payload + + BeforeAll(func() { + var err error + conf, err = resourceManager.GetNginxConfig(nginxPodName, namespace, "") + Expect(err).ToNot(HaveOccurred()) + }) + + DescribeTable("are set properly for", + func(expCfgs []framework.ExpectedNginxField) { + if !*plusEnabled { + Skip("Skipping Session Persistence Plus tests on NGINX OSS deployment") + } + for _, expCfg := range expCfgs { + Expect(framework.ValidateNginxFieldExists(conf, expCfg)).To(Succeed()) + } + }, + Entry("HTTP upstreams", []framework.ExpectedNginxField{ + { + Directive: "upstream", + Value: "session-persistence-plus_coffee_80_coffee_session-persistence-plus_0", + File: "http.conf", + }, + { + Directive: "sticky", + Value: "cookie sp_coffee_session-persistence-plus_0 expires=48h path=/coffee", + Upstream: "session-persistence-plus_coffee_80_coffee_session-persistence-plus_0", + File: "http.conf", + }, + { + Directive: "state", + Value: "/var/lib/nginx/state/session-persistence-plus_coffee_80.conf", + Upstream: "session-persistence-plus_coffee_80_coffee_session-persistence-plus_0", + File: "http.conf", + }, + { + Directive: "upstream", + Value: "session-persistence-plus_tea_80_tea_session-persistence-plus_0", + File: "http.conf", + }, + { + Directive: "sticky", + Value: "cookie tea-cookie", + Upstream: "session-persistence-plus_tea_80_tea_session-persistence-plus_0", + File: "http.conf", + }, + { + Directive: "state", + Value: "/var/lib/nginx/state/session-persistence-plus_tea_80.conf", + Upstream: "session-persistence-plus_tea_80_tea_session-persistence-plus_0", + File: "http.conf", + }, + }), + Entry("GRPC upstream", []framework.ExpectedNginxField{ + { + Directive: "upstream", + Value: "session-persistence-plus_grpc-backend_8080_grpc-route_session-persistence-plus_0", + File: "http.conf", + }, + { + Directive: "sticky", + Value: "cookie sp_grpc-route_session-persistence-plus_0 expires=24h", + Upstream: "session-persistence-plus_grpc-backend_8080_grpc-route_session-persistence-plus_0", + File: "http.conf", + }, + { + Directive: "state", + Upstream: "session-persistence-plus_grpc-backend_8080_grpc-route_session-persistence-plus_0", + Value: "/var/lib/nginx/state/session-persistence-plus_grpc-backend_8080.conf", + File: "http.conf", + }, + }), + ) + }) + }) + + When("Routes have an invalid session persistence configuration", func() { + BeforeAll(func() { + routeFile := "session-persistence/route-invalid-sp-config.yaml" + Expect(resourceManager.ApplyFromFiles([]string{routeFile}, namespace)).To(Succeed()) + }) + + It("updates the HTTPRoute status with all relevant validation errors", func() { + if !*plusEnabled { + Skip("Skipping Session Persistence Plus tests on NGINX OSS deployment") + } + routeNsName := types.NamespacedName{Name: "route-invalid-sp", Namespace: namespace} + err := waitForHTTPRouteToHaveErrorMessage(routeNsName) + Expect(err).ToNot(HaveOccurred(), "expected route to report invalid session persistence configuration") + }) + + It("updates the HTTPRoute status with all relevant validation errors", func() { + if !*plusEnabled { + Skip("Skipping Session Persistence Plus tests on NGINX OSS deployment") + } + routeNsName := types.NamespacedName{Name: "grpc-route-invalid-sp", Namespace: namespace} + err := waitForGRPCRouteToHaveErrorMessage(routeNsName) + Expect(err).ToNot(HaveOccurred(), "expected route to report invalid session persistence configuration") + }) + }) +}) + +func waitForHTTPRouteToHaveErrorMessage(routeNsName types.NamespacedName) error { + ctx, cancel := context.WithTimeout(context.Background(), timeoutConfig.GetStatusTimeout) + defer cancel() + + GinkgoWriter.Printf( + "Waiting for %q to have the condition Accepted/True/Accepted with the right error message\n", + routeNsName, + ) + + return wait.PollUntilContextCancel( + ctx, + 500*time.Millisecond, + true, /* poll immediately */ + func(ctx context.Context) (bool, error) { + var route gatewayv1.HTTPRoute + if err := resourceManager.Get(ctx, routeNsName, &route); err != nil { + return false, err + } + + return checkRouteStatus( + route.Status.RouteStatus, + gatewayv1.RouteConditionAccepted, + metav1.ConditionTrue, + invalidSPErrMsgs, + ) + }, + ) +} + +func waitForGRPCRouteToHaveErrorMessage(routeNsName types.NamespacedName) error { + ctx, cancel := context.WithTimeout(context.Background(), timeoutConfig.GetStatusTimeout) + defer cancel() + + GinkgoWriter.Printf( + "Waiting for %q to have the condition Accepted/True/Accepted with the right error message\n", + routeNsName, + ) + + return wait.PollUntilContextCancel( + ctx, + 500*time.Millisecond, + true, /* poll immediately */ + func(ctx context.Context) (bool, error) { + var route gatewayv1.GRPCRoute + if err := resourceManager.Get(ctx, routeNsName, &route); err != nil { + return false, err + } + + return checkRouteStatus( + route.Status.RouteStatus, + gatewayv1.RouteConditionAccepted, + metav1.ConditionTrue, + invalidSPErrMsgs) + }, + ) +} + +func checkRouteStatus( + rs gatewayv1.RouteStatus, + conditionType gatewayv1.RouteConditionType, + condStatus metav1.ConditionStatus, + expectedReasonSubstring string, +) (bool, error) { + var err error + if len(rs.Parents) == 0 { + GinkgoWriter.Printf("route does not have a status yet\n") + return false, nil + } + if len(rs.Parents) != 1 { + err := fmt.Errorf("route has %d parents, expected 1", len(rs.Parents)) + GinkgoWriter.Printf("ERROR: %v\n", err) + return false, err + } + + parent := rs.Parents[0] + if parent.Conditions == nil { + err := fmt.Errorf("route has no conditions in its status") + GinkgoWriter.Printf("ERROR: %v\n", err) + return false, err + } + if len(parent.Conditions) != 2 { + err := fmt.Errorf("expected route to have only two conditions, instead has %d", len(parent.Conditions)) + GinkgoWriter.Printf("ERROR: %v\n", err) + return false, err + } + + cond := parent.Conditions[1] + if cond.Type != string(conditionType) && + cond.Status != condStatus && + !strings.Contains(cond.Reason, expectedReasonSubstring) { + err := fmt.Errorf( + "expected route condition to be Type=%s, Status=%s, "+ + "Reason contains=%s; instead got Type=%s, Status=%s, Reason=%s", + conditionType, condStatus, expectedReasonSubstring, cond.Type, cond.Status, cond.Reason, + ) + GinkgoWriter.Printf("ERROR: %v\n", err) + return false, err + } + + return err == nil, nil +} + +func expectRequestToSucceedAndRespondFromTheSameBackend( + appURL, + address, + responseBodyMessage string, + totalRequests int, +) error { + var firstServerName string + + for i := range totalRequests { + request := framework.Request{ + URL: appURL, + Address: address, + Timeout: timeoutConfig.RequestTimeout, + } + resp, err := framework.Get(request) + if err != nil { + return fmt.Errorf("request %d to %s failed: %w", i+1, appURL, err) + } + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("request %d: http status was not 200, got %d", i+1, resp.StatusCode) + } + + if !strings.Contains(resp.Body, responseBodyMessage) { + return fmt.Errorf("request %d: expected response body to contain %q, got: %s", i+1, responseBodyMessage, resp.Body) + } + + serverName, err := extractServerName(resp.Body) + if err != nil { + return fmt.Errorf("request %d: failed to extract server name: %w; body: %s", i+1, err, resp.Body) + } + + if i == 0 { + firstServerName = serverName + continue + } + + // subsequent replies must come from the same backend. + if serverName != firstServerName { + return fmt.Errorf( + "request %d: expected server name %q, got %q resulting in `ip-hash` stickiness failure", + i+1, firstServerName, serverName, + ) + } + } + + return nil +} + +func expectRequestToSucceedAndReuseCookie( + appURL, + address, + responseBodyMessage string, + totalRequests int, +) error { + var firstServerName string + cookieAttr := make(map[string]string, 0) + + for i := range totalRequests { + headers := make(map[string]string, 0) + + // send cookie token after first response + if i > 0 { + if cookieAttr == nil { + return fmt.Errorf("request %d: cookie attributes are nil after first response", i+1) + } + + headers["Cookie"] = fmt.Sprintf("%s=%s", cookieAttr["name"], cookieAttr["value"]) + } + + request := framework.Request{ + URL: appURL, + Address: address, + Timeout: timeoutConfig.RequestTimeout, + Headers: headers, + } + + resp, err := framework.Get(request) + if err != nil { + return fmt.Errorf("request %d to %s failed: %w", i+1, appURL, err) + } + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("request %d: http status was not 200, got %d", i+1, resp.StatusCode) + } + + if !strings.Contains(resp.Body, responseBodyMessage) { + return fmt.Errorf( + "request %d: expected response body to contain %q, got: %s", + i+1, responseBodyMessage, resp.Body, + ) + } + + serverName, err := extractServerName(resp.Body) + if err != nil { + return fmt.Errorf( + "request %d: failed to extract server name: %w; body: %s", + i+1, err, resp.Body, + ) + } + + // get the cookie token from the first response + if i == 0 { + cookieAttr, err = extractCookieInformationFromResponseHeaders(resp.Headers) + if err != nil { + return fmt.Errorf( + "request %d: failed to extract cookie from response headers: %w; body: %s", + i+1, err, resp.Body, + ) + } + + firstServerName = serverName + continue + } + + if serverName != firstServerName { + return fmt.Errorf( + "request %d: expected server name %q, got %q (session persistence failed)", + i+1, firstServerName, serverName, + ) + } + } + + return nil +} + +func extractCookieInformationFromResponseHeaders(h http.Header) (map[string]string, error) { + values := h.Values("Set-Cookie") + if len(values) == 0 { + return nil, fmt.Errorf("no Set-Cookie header found in response") + } + + raw := strings.TrimSpace(values[0]) + if raw == "" { + return nil, fmt.Errorf("empty Set-Cookie header") + } + + parts := strings.Split(raw, ";") + if len(parts) == 0 { + return nil, fmt.Errorf("malformed Set-Cookie header: %q", raw) + } + + // first part is cookie-name=value + pair := strings.TrimSpace(parts[0]) + nv := strings.SplitN(pair, "=", 2) + if len(nv) != 2 { + return nil, fmt.Errorf("malformed Set-Cookie header (no name=value): %q", raw) + } + + name := strings.TrimSpace(nv[0]) + value := strings.TrimSpace(nv[1]) + if name == "" || value == "" { + return nil, fmt.Errorf("malformed Set-Cookie header (empty name or value): %q", raw) + } + + result := map[string]string{ + "name": name, + "value": value, + } + + return result, nil +} diff --git a/tests/suite/tracing_test.go b/tests/suite/tracing_test.go index 1df06598df..b52012fc17 100644 --- a/tests/suite/tracing_test.go +++ b/tests/suite/tracing_test.go @@ -133,19 +133,17 @@ var _ = Describe("Tracing", FlakeAttempts(2), Ordered, Label("functional", "trac for range count { Eventually( func() error { - status, _, err := framework.Get( - url, - address, - timeoutConfig.RequestTimeout, - nil, - nil, - framework.WithLoggingDisabled(), - ) + request := framework.Request{ + URL: url, + Address: address, + Timeout: timeoutConfig.RequestTimeout, + } + resp, err := framework.Get(request, framework.WithLoggingDisabled()) if err != nil { return err } - if status != http.StatusOK { - return fmt.Errorf("status not 200; got %d", status) + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("status not 200; got %d", resp.StatusCode) } return nil }).