feat: issue-187: support gpu model filter #196


Merged
merged 4 commits on May 27, 2025
4 changes: 4 additions & 0 deletions api/v1/workloadprofile_types.go
@@ -49,6 +49,10 @@ type WorkloadProfileSpec struct {
// Schedule the workload to the same GPU server that runs vGPU worker for best performance, default to false
IsLocalGPU bool `json:"isLocalGPU,omitempty"`

// +optional
// GPUModel specifies the required GPU model (e.g., "A100", "H100")
GPUModel string `json:"gpuModel,omitempty"`

// +optional
// TODO, not implemented
// The number of GPUs to be used by the workload, default to 1
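For context, a minimal sketch of how the new field could be set on a workload profile; the WorkloadProfile wrapper object and the field values here are illustrative assumptions, not part of this diff:

profile := &tfv1.WorkloadProfile{
	Spec: tfv1.WorkloadProfileSpec{
		IsLocalGPU: true,
		// Only GPUs whose Status.GPUModel equals "A100" should be considered.
		GPUModel: "A100",
	},
}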
3 changes: 3 additions & 0 deletions internal/constants/constants.go
@@ -49,6 +49,9 @@ const (
AutoScaleRequestsAnnotation = Domain + "/auto-requests"
AutoScaleReplicasAnnotation = Domain + "/auto-replicas"

// GPUModelAnnotation specifies the required GPU model (e.g., "A100", "H100")
GPUModelAnnotation = Domain + "/gpu-model"

GpuReleasedAnnotation = Domain + "/gpu-released"

TensorFusionPodCounterKeyAnnotation = Domain + "/pod-counter-key"
2 changes: 1 addition & 1 deletion internal/controller/tensorfusionworkload_controller.go
@@ -344,7 +344,7 @@ func (r *TensorFusionWorkloadReconciler) scaleUpWorkers(ctx context.Context, wor
// Create worker pods
for range count {
// Schedule GPU for the worker
gpus, err := r.Allocator.Alloc(ctx, workload.Spec.PoolName, workload.Spec.Resources.Requests, 1)
gpus, err := r.Allocator.Alloc(ctx, workload.Spec.PoolName, workload.Spec.Resources.Requests, 1, workload.Spec.GPUModel)
if err != nil {
r.Recorder.Eventf(workload, corev1.EventTypeWarning, "ScheduleGPUFailed", "Failed to schedule GPU: %v", err)
return ctrl.Result{RequeueAfter: constants.PendingRequeueDuration}, nil
47 changes: 47 additions & 0 deletions internal/controller/tensorfusionworkload_controller_test.go
@@ -189,6 +189,53 @@ var _ = Describe("TensorFusionWorkload Controller", func() {
})
})

Context("When specifying GPU model in workload", func() {
It("Should allocate GPUs of the specified model", func() {
pool := tfEnv.GetGPUPool(0)

// Create a workload requesting specific GPU model
workload := createTensorFusionWorkload(pool.Name, key, 1)
Eventually(func(g Gomega) {
// Get the latest version of the workload
g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(workload), workload)).To(Succeed())
// Set the GPU model
workload.Spec.GPUModel = "mock"
// Update the workload
g.Expect(k8sClient.Update(ctx, workload)).To(Succeed())
}, timeout, interval).Should(Succeed())

checkWorkerPodCount(workload)
checkWorkloadStatus(workload)

// Verify pods got GPUs of the correct model
podList := &corev1.PodList{}
// First make sure the pod exists
Eventually(func(g Gomega) {
g.Expect(k8sClient.List(ctx, podList,
client.InNamespace(key.Namespace),
client.MatchingLabels{constants.WorkloadKey: key.Name})).Should(Succeed())
g.Expect(podList.Items).Should(HaveLen(1))
}, timeout, interval).Should(Succeed())

// Now check if the pod has the correct GPU
Eventually(func(g Gomega) {
// Get the latest version of the pod
pod := &corev1.Pod{}
g.Expect(k8sClient.Get(ctx, client.ObjectKey{
Namespace: podList.Items[0].Namespace,
Name: podList.Items[0].Name,
}, pod)).Should(Succeed())
gpuName := pod.Labels[constants.GpuKey]
gpuList := tfEnv.GetPoolGpuList(0)
gpu, ok := lo.Find(gpuList.Items, func(gpu tfv1.GPU) bool {
return gpu.Name == gpuName
})
g.Expect(ok).To(BeTrue())
g.Expect(gpu.Status.GPUModel).To(Equal("mock"))
}, timeout, interval).Should(Succeed())
})
})

Context("When deleting workload directly", func() {
It("Should delete all pods and the workload itself", func() {
pool := tfEnv.GetGPUPool(0)
34 changes: 34 additions & 0 deletions internal/gpuallocator/filter/gpu_model_filter.go
@@ -0,0 +1,34 @@
package filter

import (
"context"

tfv1 "github.com/NexusGPU/tensor-fusion/api/v1"
)

// GPUModelFilter filters GPUs based on their model (e.g., A100, H100)
type GPUModelFilter struct {
requiredModel string
}

// NewGPUModelFilter creates a new filter that matches GPUs with the specified model
func NewGPUModelFilter(model string) *GPUModelFilter {
return &GPUModelFilter{
requiredModel: model,
}
}

// Filter implements GPUFilter interface
func (f *GPUModelFilter) Filter(ctx context.Context, gpus []tfv1.GPU) ([]tfv1.GPU, error) {
if f.requiredModel == "" {
return gpus, nil
}

var filtered []tfv1.GPU
for _, gpu := range gpus {
if gpu.Status.GPUModel == f.requiredModel {
filtered = append(filtered, gpu)
}
}
return filtered, nil
}
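A minimal usage sketch of the new filter; ctx and candidateGPUs are illustrative placeholders, not part of this diff:

// Keep only GPUs reporting the A100 model.
modelFilter := filter.NewGPUModelFilter("A100")
a100s, err := modelFilter.Filter(ctx, candidateGPUs)
if err != nil {
	return nil, err
}
// With an empty model string the filter is a no-op and returns all candidates unchanged.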
99 changes: 99 additions & 0 deletions internal/gpuallocator/filter/gpu_model_filter_test.go
@@ -0,0 +1,99 @@
package filter

import (
"context"
"testing"

tfv1 "github.com/NexusGPU/tensor-fusion/api/v1"
"github.com/stretchr/testify/assert"
"k8s.io/apimachinery/pkg/api/resource"
)

func TestGPUModelFilter(t *testing.T) {
tests := []struct {
name string
requiredModel string
gpus []tfv1.GPU
want int
wantErr bool
}{
{
name: "filter A100 GPUs",
requiredModel: "A100",
gpus: []tfv1.GPU{
{
Status: tfv1.GPUStatus{
GPUModel: "A100",
Available: &tfv1.Resource{
Tflops: resource.MustParse("100"),
Vram: resource.MustParse("40Gi"),
},
},
},
{
Status: tfv1.GPUStatus{
GPUModel: "H100",
Available: &tfv1.Resource{
Tflops: resource.MustParse("200"),
Vram: resource.MustParse("80Gi"),
},
},
},
},
want: 1,
wantErr: false,
},
{
name: "no model specified",
requiredModel: "",
gpus: []tfv1.GPU{
{
Status: tfv1.GPUStatus{
GPUModel: "A100",
Available: &tfv1.Resource{
Tflops: resource.MustParse("100"),
Vram: resource.MustParse("40Gi"),
},
},
},
},
want: 1,
wantErr: false,
},
{
name: "non-existent model",
requiredModel: "NonExistentModel",
gpus: []tfv1.GPU{
{
Status: tfv1.GPUStatus{
GPUModel: "A100",
Available: &tfv1.Resource{
Tflops: resource.MustParse("100"),
Vram: resource.MustParse("40Gi"),
},
},
},
},
want: 0,
wantErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
filter := NewGPUModelFilter(tt.requiredModel)
got, err := filter.Filter(context.Background(), tt.gpus)
if tt.wantErr {
assert.Error(t, err)
return
}
assert.NoError(t, err)
assert.Len(t, got, tt.want)
if tt.want > 0 && tt.requiredModel != "" {
for _, gpu := range got {
assert.Equal(t, tt.requiredModel, gpu.Status.GPUModel)
}
}
})
}
}
7 changes: 7 additions & 0 deletions internal/gpuallocator/gpuallocator.go
@@ -53,12 +53,19 @@ func (s *GpuAllocator) Alloc(
poolName string,
request tfv1.Resource,
count uint,
gpuModel string,
) ([]*tfv1.GPU, error) {
// Get GPUs from the pool using the in-memory store
poolGPUs := s.listGPUsFromPool(poolName)

// Add SameNodeFilter if count > 1 to ensure GPUs are from the same node
filterRegistry := s.filterRegistry.With(filter.NewResourceFilter(request))

// Add GPU model filter if specified
if gpuModel != "" {
filterRegistry = filterRegistry.With(filter.NewGPUModelFilter(gpuModel))
}

if count > 1 {
filterRegistry = filterRegistry.With(filter.NewSameNodeFilter(count))
}
29 changes: 23 additions & 6 deletions internal/gpuallocator/gpuallocator_test.go
@@ -59,7 +59,7 @@ var _ = Describe("GPU Allocator", func() {
Vram: resource.MustParse("8Gi"),
}

gpus, err := allocator.Alloc(ctx, "test-pool", request, 1)
gpus, err := allocator.Alloc(ctx, "test-pool", request, 1, "")
Expect(err).NotTo(HaveOccurred())
Expect(gpus).To(HaveLen(1))

@@ -78,7 +78,7 @@ var _ = Describe("GPU Allocator", func() {
Vram: resource.MustParse("4Gi"),
}

gpus, err := allocator.Alloc(ctx, "test-pool", request, 2)
gpus, err := allocator.Alloc(ctx, "test-pool", request, 2, "")
Expect(err).NotTo(HaveOccurred())
Expect(gpus).To(HaveLen(2))

@@ -95,7 +95,7 @@ var _ = Describe("GPU Allocator", func() {
Vram: resource.MustParse("2Gi"),
}

_, err := allocator.Alloc(ctx, "test-pool", request, 10)
_, err := allocator.Alloc(ctx, "test-pool", request, 10, "")
Expect(err).To(HaveOccurred())
})

@@ -105,7 +105,7 @@ var _ = Describe("GPU Allocator", func() {
Vram: resource.MustParse("64Gi"),
}

_, err := allocator.Alloc(ctx, "test-pool", request, 1)
_, err := allocator.Alloc(ctx, "test-pool", request, 1, "")
Expect(err).To(HaveOccurred())
})

@@ -115,7 +115,24 @@ var _ = Describe("GPU Allocator", func() {
Vram: resource.MustParse("2Gi"),
}

_, err := allocator.Alloc(ctx, "nonexistent-pool", request, 1)
_, err := allocator.Alloc(ctx, "nonexistent-pool", request, 1, "")
Expect(err).To(HaveOccurred())
})

It("should filter GPUs by model", func() {
request := tfv1.Resource{
Tflops: resource.MustParse("50"),
Vram: resource.MustParse("8Gi"),
}

// Try allocating with a specific GPU model
gpus, err := allocator.Alloc(ctx, "test-pool", request, 1, "NVIDIA A100")
Expect(err).NotTo(HaveOccurred())
Expect(gpus).To(HaveLen(1))
Expect(gpus[0].Status.GPUModel).To(Equal("NVIDIA A100"))

// Try allocating with a non-existent GPU model
_, err = allocator.Alloc(ctx, "test-pool", request, 1, "NonExistentModel")
Expect(err).To(HaveOccurred())
})
})
@@ -128,7 +145,7 @@ var _ = Describe("GPU Allocator", func() {
Vram: resource.MustParse("6Gi"),
}

gpus, err := allocator.Alloc(ctx, "test-pool", request, 1)
gpus, err := allocator.Alloc(ctx, "test-pool", request, 1, "")
Expect(err).NotTo(HaveOccurred())
Expect(gpus).To(HaveLen(1))

6 changes: 6 additions & 0 deletions internal/webhook/v1/tf_parser.go
@@ -21,6 +21,7 @@ type TFResource struct {
VramRequest resource.Quantity
TflopsLimit resource.Quantity
VramLimit resource.Quantity
GPUModel string // Required GPU model (e.g., A100, H100)
}

type TensorFusionInfo struct {
@@ -138,6 +139,11 @@ func ParseTensorFusionInfo(ctx context.Context, k8sClient client.Client, pod *corev1.Pod) {
return info, fmt.Errorf("inject container not found")
}

gpuModel, ok := pod.Annotations[constants.GPUModelAnnotation]
if ok {
workloadProfile.Spec.GPUModel = gpuModel
}

info.Profile = &workloadProfile.Spec
info.ContainerNames = containerNames
return info, nil
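End to end, a pod can request a model via the annotation, and the webhook parser copies it into the workload profile that the allocator later filters on. A hedged sketch; the pod construction and the metav1 import are illustrative, not part of this diff:

// Annotate a pod so ParseTensorFusionInfo populates Spec.GPUModel.
pod := &corev1.Pod{
	ObjectMeta: metav1.ObjectMeta{
		Annotations: map[string]string{
			constants.GPUModelAnnotation: "H100", // resolves to <Domain>/gpu-model
		},
	},
}
// After parsing, info.Profile.GPUModel == "H100", so the allocator only
// considers GPUs whose Status.GPUModel matches.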