Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion cmd/epp/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,7 @@ func (r *Runner) parseConfigurationPhaseOne(ctx context.Context) (*configapi.End

loader.RegisterFeatureGate(datalayer.FeatureGate)
loader.RegisterFeatureGate(flowcontrol.FeatureGate)
loader.RegisterFeatureGate(datalayer.PrepareDataPluginsFeatureGate)

r.registerInTreePlugins()

Expand Down Expand Up @@ -504,8 +505,9 @@ func (r *Runner) parseConfigurationPhaseTwo(ctx context.Context, rawConfig *conf

// Add requestControl plugins
r.requestControlConfig.AddPlugins(handle.GetAllPlugins()...)

// Sort prepare data plugins in DAG order (topological sort). Also check prepare data plugins for cycles.
if r.requestControlConfig.PrepareDataPluginGraph() != nil {
if r.requestControlConfig.PrepareDataPluginGraph(r.featureGates[datalayer.PrepareDataPluginsFeatureGate]) != nil {
return nil, errors.New("failed to load the configuration - prepare data plugins have cyclic dependencies")
}

Expand Down
3 changes: 2 additions & 1 deletion pkg/epp/datalayer/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ import (
)

const (
FeatureGate = "dataLayer"
FeatureGate = "dataLayer"
PrepareDataPluginsFeatureGate = "prepareDataPlugins"
)

// PoolInfo represents the DataStore information needed for endpoints.
Expand Down
45 changes: 45 additions & 0 deletions pkg/epp/datalayer/plugins/data_types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
Copyright 2025 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package plugins

import (
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
)

const (
PrefixCacheMatchInfoKey = "PrefixCacheMatchInfoKey"
)

type PrefixCacheMatchInfo struct {
matchPercentage float64
}

func NewPrefixCacheMatchInfo(matchPercentage float64) *PrefixCacheMatchInfo {
return &PrefixCacheMatchInfo{
matchPercentage: matchPercentage,
}
}

func (p *PrefixCacheMatchInfo) MatchPercentage() float64 {
return p.matchPercentage
}

func (p *PrefixCacheMatchInfo) Clone() datalayer.Cloneable {
return &PrefixCacheMatchInfo{
matchPercentage: p.matchPercentage,
}
}
6 changes: 5 additions & 1 deletion pkg/epp/requestcontrol/request_control_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,11 @@ func (c *Config) AddPlugins(pluginObjects ...plugins.Plugin) {

// PrepareDataPluginGraph creates data dependency graph and sorts the plugins in topological order.
// If a cycle is detected, it returns an error.
func (c *Config) PrepareDataPluginGraph() error {
func (c *Config) PrepareDataPluginGraph(enablePrepareDataPlugins bool) error {
if !enablePrepareDataPlugins {
c.prepareDataPlugins = []PrepareDataPlugin{}
return nil
}
dag := buildDAG(c.prepareDataPlugins)
plugins, err := sortPlugins(dag, c.prepareDataPlugins)
if err != nil {
Expand Down
35 changes: 35 additions & 0 deletions pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
k8stypes "k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/controller-runtime/pkg/log"

dplugins "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
Expand Down Expand Up @@ -206,6 +207,40 @@ func (p *Plugin) WithName(name string) *Plugin {
return p
}

func (p *Plugin) Produces() map[string]any {
return map[string]any{dplugins.PrefixCacheMatchInfoKey: dplugins.PrefixCacheMatchInfo{}}
}

func (p *Plugin) Consumes() map[string]any {
return map[string]any{}
}

func (p *Plugin) PrepareRequestData(ctx context.Context, request *types.LLMRequest, pods []types.Pod) error {
// pre score step, hashing prompt and find longest prefix match.
hashes := hashPrompt(ctx, request, getBlockSize(pods, p.config), p.config.MaxPrefixBlocksToMatch)
state := &SchedulingContextState{
PrefixHashes: hashes,
PrefixCacheServers: p.matchLongestPrefix(ctx, hashes),
}
for server, matchLen := range state.PrefixCacheServers {
log.FromContext(ctx).V(logutil.TRACE).Info("prefix cached state", "server", server, "longest-prefix-match", matchLen)

}

total := len(state.PrefixHashes)
podScoreFunc := func(pod types.Pod) float64 {
if total == 0 {
return 0
}
matchLen := state.PrefixCacheServers[ServerID(pod.GetPod().NamespacedName)]
return float64(matchLen) / float64(total)
}
for _, pod := range pods {
pod.Put(dplugins.PrefixCacheMatchInfoKey, dplugins.NewPrefixCacheMatchInfo(podScoreFunc(pod)))
}
return nil
}

// Score returns the scoring result for the given list of pods based on context.
func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
// pre score step, hashing prompt and find longest prefix match.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,13 @@ import (
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

// static check to ensure Plugin implements the PrepareDataPlugin interface.
var _ requestcontrol.PrepareDataPlugin = &Plugin{}

func TestPrefixPluginCompletion(t *testing.T) {
config := Config{
BlockSize: 4,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
Copyright 2025 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package scorer

import (
"context"
"encoding/json"

k8stypes "k8s.io/apimachinery/pkg/types"
dplugins "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

const (
PrefixCacheMatchScorerType = "prefix-cache-match-scorer"
)

type ServerID k8stypes.NamespacedName

// compile-time type assertion
var _ framework.Scorer = &PrefixCacheScorer{}

// PrefixCacheScorerFactory defines the factory function for PrefixCacheScorer.
func PrefixCacheScorerFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
return NewPrefixCacheScorer().WithName(name), nil
}

// NewPrefixCacheScorer initializes a new PrefixCacheScorer and returns its pointer.
func NewPrefixCacheScorer() *PrefixCacheScorer {
return &PrefixCacheScorer{
tn: plugins.TypedName{Type: PrefixCacheMatchScorerType, Name: PrefixCacheMatchScorerType},
}
}

// PrefixCacheScorer scores list of candidate pods based on Lora affinity and availability.
type PrefixCacheScorer struct {
tn plugins.TypedName
}

// TypedName returns the type and name tuple of this plugin instance.
func (s *PrefixCacheScorer) TypedName() plugins.TypedName {
return s.tn
}

// Consumes returns the list of data that is consumed by the plugin.
func (s *PrefixCacheScorer) Consumes() map[string]any {
return map[string]any{}
}

// WithName sets the name of the scorer.
func (s *PrefixCacheScorer) WithName(name string) *PrefixCacheScorer {
s.tn.Name = name
return s
}

func (s *PrefixCacheScorer) Score(_ context.Context, cycleState *types.CycleState, _ *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
// calculate the scores of pods
scores := make(map[types.Pod]float64, len(pods))

for _, pod := range pods {
matchPercent, ok := pod.Get(dplugins.PrefixCacheMatchInfoKey)
if !ok {
scores[pod] = 0.0
continue
}
scores[pod] = matchPercent.(*dplugins.PrefixCacheMatchInfo).MatchPercentage()
}
return scores
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
/*
Copyright 2025 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package scorer

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
dplugins "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

// mockPod is a mock implementation of types.Pod for testing purposes.
type mockPod struct {
data map[string]datalayer.Cloneable
}

func newMockPod() *mockPod {
return &mockPod{
data: make(map[string]datalayer.Cloneable),
}
}

func (p *mockPod) Get(key string) (datalayer.Cloneable, bool) {
val, ok := p.data[key]
return val, ok
}

func (p *mockPod) Put(key string, value datalayer.Cloneable) {
p.data[key] = value
}

func (p *mockPod) GetPod() *backend.Pod {
return nil
}

func (p *mockPod) GetMetrics() *backendmetrics.MetricsState {
return nil
}

func (p *mockPod) String() string {
return ""
}

func (p *mockPod) Keys() []string {
keys := make([]string, 0, len(p.data))
for k := range p.data {
keys = append(keys, k)
}
return keys
}

func TestPrefixCacheScorer_Score(t *testing.T) {
pod1 := newMockPod()
pod1.Put(dplugins.PrefixCacheMatchInfoKey, dplugins.NewPrefixCacheMatchInfo(50.0))

pod2 := newMockPod()
pod2.Put(dplugins.PrefixCacheMatchInfoKey, dplugins.NewPrefixCacheMatchInfo(100.0))

pod3 := newMockPod()

testCases := []struct {
name string
pods []types.Pod
expected map[types.Pod]float64
}{
{
name: "pods with prefix cache match percent",
pods: []types.Pod{pod1, pod2},
expected: map[types.Pod]float64{
pod1: 50.0,
pod2: 100.0,
},
},
{
name: "pod without prefix cache match percent",
pods: []types.Pod{pod3},
expected: map[types.Pod]float64{
pod3: 0.0,
},
},
{
name: "mixed pods",
pods: []types.Pod{pod1, pod3},
expected: map[types.Pod]float64{
pod1: 50.0,
pod3: 0.0,
},
},
{
name: "empty pods list",
pods: []types.Pod{},
expected: map[types.Pod]float64{},
},
}

scorer := NewPrefixCacheScorer()

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
scores := scorer.Score(context.Background(), nil, nil, tc.pods)
assert.Equal(t, tc.expected, scores)
})
}
}

func TestNewPrefixCacheScorer(t *testing.T) {
scorer := NewPrefixCacheScorer()
assert.NotNil(t, scorer)
assert.Equal(t, PrefixCacheMatchScorerType, scorer.tn.Type)
assert.Equal(t, PrefixCacheMatchScorerType, scorer.tn.Name)
}

func TestPrefixCacheScorer_WithName(t *testing.T) {
scorer := NewPrefixCacheScorer()
customName := "custom-scorer"
scorer.WithName(customName)
assert.Equal(t, customName, scorer.TypedName().Name)
}

func TestPrefixCacheScorer_TypedName(t *testing.T) {
scorer := NewPrefixCacheScorer()
tn := scorer.TypedName()
assert.Equal(t, PrefixCacheMatchScorerType, tn.Type)
assert.Equal(t, PrefixCacheMatchScorerType, tn.Name)
}

func TestPrefixCacheScorer_Consumes(t *testing.T) {
scorer := NewPrefixCacheScorer()
consumes := scorer.Consumes()
assert.Empty(t, consumes)
}