
feat: Allow clients to request multiple GPUs #204

Merged · 14 commits · May 31, 2025

6 changes: 2 additions & 4 deletions api/v1/workloadprofile_types.go
@@ -39,8 +39,8 @@ type WorkloadProfileSpec struct {
PoolName string `json:"poolName,omitempty"`

// +optional
Resources Resources `json:"resources,omitempty"`

Resources Resources `json:"resources,omitempty"`
// +optional
// Qos defines the quality of service level for the client.
Qos QoSLevel `json:"qos,omitempty"`
@@ -53,10 +53,8 @@ type WorkloadProfileSpec struct {
// GPUModel specifies the required GPU model (e.g., "A100", "H100")
GPUModel string `json:"gpuModel,omitempty"`

// +optional
// TODO, not implemented
// The number of GPUs to be used by the workload, default to 1
GPUCount int `json:"gpuCount,omitempty"`
GPUCount uint `json:"gpuCount,omitempty"`

// +optional
// TODO, not implemented
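For orientation, the now-implemented GPUCount field sits next to the existing resource requests in the same spec. A minimal sketch of a spec asking for two GPUs, using only field names from this file and the tests below; the concrete values and the tfv1/resource package aliases are illustrative assumptions:

    // Sketch: a WorkloadProfileSpec requesting two GPUs of a given model.
    // GPUCount is now a uint and defaults to 1 when omitted.
    spec := tfv1.WorkloadProfileSpec{
        PoolName: "example-pool", // assumed pool name
        GPUCount: 2,
        GPUModel: "A100",
        Resources: tfv1.Resources{
            Requests: tfv1.Resource{
                Tflops: resource.MustParse("10"),
                Vram:   resource.MustParse("8Gi"),
            },
            Limits: tfv1.Resource{
                Tflops: resource.MustParse("20"),
                Vram:   resource.MustParse("16Gi"),
            },
        },
    }
    _ = spec // placeholder so the sketch compiles inside a function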
@@ -175,6 +175,10 @@ spec:
description: The number of GPUs to be used by the workload, default
to 1
type: integer
gpuModel:
description: GPUModel specifies the required GPU model (e.g., "A100",
"H100")
type: string
isLocalGPU:
description: Schedule the workload to the same GPU server that runs
vGPU worker for best performance, default to false
@@ -174,6 +174,10 @@ spec:
description: The number of GPUs to be used by the workload, default
to 1
type: integer
gpuModel:
description: GPUModel specifies the required GPU model (e.g., "A100",
"H100")
type: string
isLocalGPU:
description: Schedule the workload to the same GPU server that runs
vGPU worker for best performance, default to false
4 changes: 4 additions & 0 deletions config/crd/bases/tensor-fusion.ai_tensorfusionworkloads.yaml
@@ -175,6 +175,10 @@ spec:
description: The number of GPUs to be used by the workload, default
to 1
type: integer
gpuModel:
description: GPUModel specifies the required GPU model (e.g., "A100",
"H100")
type: string
isLocalGPU:
description: Schedule the workload to the same GPU server that runs
vGPU worker for best performance, default to false
4 changes: 4 additions & 0 deletions config/crd/bases/tensor-fusion.ai_workloadprofiles.yaml
@@ -174,6 +174,10 @@ spec:
description: The number of GPUs to be used by the workload, default
to 1
type: integer
gpuModel:
description: GPUModel specifies the required GPU model (e.g., "A100",
"H100")
type: string
isLocalGPU:
description: Schedule the workload to the same GPU server that runs
vGPU worker for best performance, default to false
68 changes: 36 additions & 32 deletions internal/controller/tensorfusionworkload_controller.go
@@ -20,11 +20,13 @@ import (
"context"
"fmt"
"sort"
"strings"

corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/equality"
"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/types"
"k8s.io/client-go/tools/record"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
@@ -217,12 +219,12 @@ func (r *TensorFusionWorkloadReconciler) Reconcile(ctx context.Context, req ctrl
func (r *TensorFusionWorkloadReconciler) tryStartWorker(
ctx context.Context,
workerGenerator *worker.WorkerGenerator,
gpu *tfv1.GPU,
gpus []*tfv1.GPU,
workload *tfv1.TensorFusionWorkload,
hash string,
) (*corev1.Pod, error) {
port := workerGenerator.AllocPort()
pod, hash, err := workerGenerator.GenerateWorkerPod(gpu, fmt.Sprintf("%s-tf-worker-", workload.Name), workload.Namespace, port, workload.Spec.Resources.Limits, hash)
pod, hash, err := workerGenerator.GenerateWorkerPod(gpus, fmt.Sprintf("%s-tf-worker-", workload.Name), workload.Namespace, port, workload.Spec.Resources.Limits, hash)
if err != nil {
return nil, fmt.Errorf("generate worker pod %w", err)
}
@@ -231,9 +233,18 @@ func (r *TensorFusionWorkloadReconciler) tryStartWorker(
if pod.Labels == nil {
pod.Labels = make(map[string]string)
}

if pod.Annotations == nil {
pod.Annotations = make(map[string]string)
}

gpuNames := lo.Map(gpus, func(gpu *tfv1.GPU, _ int) string {
return gpu.Name
})

pod.Labels[constants.WorkloadKey] = workload.Name
pod.Labels[constants.GpuKey] = gpu.Name
pod.Labels[constants.LabelKeyPodTemplateHash] = hash
pod.Annotations[constants.GpuKey] = strings.Join(gpuNames, ",")

// Add finalizer for GPU resource cleanup
pod.Finalizers = append(pod.Finalizers, constants.Finalizer)
@@ -269,6 +280,7 @@ func (r *TensorFusionWorkloadReconciler) scaleDownWorkers(ctx context.Context, w
metrics.GpuTflopsLimit.Delete(labels)
metrics.VramBytesRequest.Delete(labels)
metrics.VramBytesLimit.Delete(labels)
metrics.GpuCount.Delete(labels)
}
return nil
}
@@ -279,26 +291,24 @@ func (r *TensorFusionWorkloadReconciler) handlePodGPUCleanup(ctx context.Context

log.Info("Processing pod with GPU resource cleanup finalizer", "pod", pod.Name)

// Get GPU name from pod label
gpuName, ok := pod.Labels[constants.GpuKey]
// read the GPU names from the pod annotations
gpuNamesStr, ok := pod.Annotations[constants.GpuKey]
if !ok {
log.Info("Pod has finalizer but no GPU label", "pod", pod.Name)
return true, nil
}

// Get the GPU
gpu := &tfv1.GPU{}
if err := r.Get(ctx, client.ObjectKey{Name: gpuName}, gpu); err != nil {
if errors.IsNotFound(err) {
// GPU not found, just continue
log.Info("GPU not found", "gpu", gpuName, "pod", pod.Name)
return true, nil
}
// Error getting GPU, retry later
log.Error(err, "Failed to get GPU", "gpu", gpuName, "pod", pod.Name)
// Split GPU names by comma
gpuNames := strings.Split(gpuNamesStr, ",")
gpus := lo.Map(gpuNames, func(gpuName string, _ int) types.NamespacedName {
return types.NamespacedName{Name: gpuName}
})
// Release GPU resources
if err := r.Allocator.Dealloc(ctx, workload.Spec.Resources.Requests, gpus); err != nil {
log.Error(err, "Failed to release GPU resources, will retry", "gpus", gpus, "pod", pod.Name)
return false, err
}

log.Info("Released GPU resources via finalizer", "gpus", gpus, "pod", pod.Name)
if pod.Annotations == nil {
pod.Annotations = make(map[string]string)
}
@@ -310,17 +320,10 @@ func (r *TensorFusionWorkloadReconciler) handlePodGPUCleanup(ctx context.Context
// not yet reflecting the finalizer's removal), Then this r.Update pod will fail.
// Will not cause duplicate releases
if err := r.Update(ctx, pod); err != nil {
log.Error(err, "Failed to mark that GPU cleanup of pod", "gpu", gpuName, "pod", pod.Name)
return false, err
}

// Release GPU resources
if err := r.Allocator.Dealloc(ctx, workload.Spec.Resources.Requests, gpu); err != nil {
log.Error(err, "Failed to release GPU resources, will retry", "gpu", gpuName, "pod", pod.Name)
log.Error(err, "Failed to mark that GPU cleanup of pod")
return false, err
}

log.Info("Released GPU resources via finalizer", "gpu", gpuName, "pod", pod.Name)
return true, nil
}

@@ -344,21 +347,21 @@ func (r *TensorFusionWorkloadReconciler) scaleUpWorkers(ctx context.Context, wor
// Create worker pods
for range count {
// Schedule GPU for the worker
gpus, err := r.Allocator.Alloc(ctx, workload.Spec.PoolName, workload.Spec.Resources.Requests, 1, workload.Spec.GPUModel)
gpus, err := r.Allocator.Alloc(ctx, workload.Spec.PoolName, workload.Spec.Resources.Requests, workload.Spec.GPUCount, workload.Spec.GPUModel)
if err != nil {
r.Recorder.Eventf(workload, corev1.EventTypeWarning, "ScheduleGPUFailed", "Failed to schedule GPU: %v", err)
return ctrl.Result{RequeueAfter: constants.PendingRequeueDuration}, nil
}

// Use the first GPU from the allocated array
gpu := gpus[0]

pod, err := r.tryStartWorker(ctx, workerGenerator, gpu, workload, hash)
pod, err := r.tryStartWorker(ctx, workerGenerator, gpus, workload, hash)
if err != nil {
// Try to release the GPU resource if pod creation fails
releaseErr := r.Allocator.Dealloc(ctx, workload.Spec.Resources.Requests, gpu)
// Try to release all allocated GPUs if pod creation fails
gpus := lo.Map(gpus, func(gpu *tfv1.GPU, _ int) types.NamespacedName {
return client.ObjectKeyFromObject(gpu)
})
releaseErr := r.Allocator.Dealloc(ctx, workload.Spec.Resources.Requests, gpus)
if releaseErr != nil {
log.Error(releaseErr, "Failed to release GPU after pod creation failure")
log.Error(releaseErr, "Failed to release GPU after pod creation failure", "gpus", gpus)
}
return ctrl.Result{}, fmt.Errorf("create worker pod: %w", err)
}
Expand All @@ -372,6 +375,7 @@ func (r *TensorFusionWorkloadReconciler) scaleUpWorkers(ctx context.Context, wor
metrics.GpuTflopsLimit.With(labels).Set(workload.Spec.Resources.Limits.Tflops.AsApproximateFloat64())
metrics.VramBytesRequest.With(labels).Set(workload.Spec.Resources.Requests.Vram.AsApproximateFloat64())
metrics.VramBytesLimit.With(labels).Set(workload.Spec.Resources.Limits.Vram.AsApproximateFloat64())
metrics.GpuCount.With(labels).Set(float64(workload.Spec.GPUCount))
}

return ctrl.Result{}, nil
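In short, the worker pod now records every allocated GPU in one comma-separated annotation rather than a single-GPU label, and the finalizer path parses that annotation back into store keys before calling Dealloc. A condensed sketch of the round trip, reusing only identifiers that appear in this diff (the example GPU names are made up):

    // Write side (tryStartWorker): stamp all allocated GPU names onto the pod.
    gpuNames := lo.Map(gpus, func(gpu *tfv1.GPU, _ int) string { return gpu.Name })
    pod.Annotations[constants.GpuKey] = strings.Join(gpuNames, ",") // e.g. "node1-gpu-0,node1-gpu-1"

    // Read side (handlePodGPUCleanup): rebuild name-only keys and release them in one call.
    names := strings.Split(pod.Annotations[constants.GpuKey], ",")
    keys := lo.Map(names, func(name string, _ int) types.NamespacedName {
        return types.NamespacedName{Name: name}
    })
    err := r.Allocator.Dealloc(ctx, workload.Spec.Resources.Requests, keys)

Releasing from the annotation also drops the per-GPU Get that the old cleanup path needed before deallocating.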
54 changes: 52 additions & 2 deletions internal/controller/tensorfusionworkload_controller_test.go
@@ -17,6 +17,7 @@ limitations under the License.
package controller

import (
"strings"
"time"

"github.com/aws/smithy-go/ptr"
@@ -57,6 +58,51 @@ var _ = Describe("TensorFusionWorkload Controller", func() {
checkWorkerPodCount(workload)
checkWorkloadStatus(workload)
})

It("Should allocate multiple GPUs per workload when GPUCount > 1", func() {
pool := tfEnv.GetGPUPool(0)
By("creating a workload that requests 2 GPUs")
workload := &tfv1.TensorFusionWorkload{
ObjectMeta: metav1.ObjectMeta{
Name: key.Name,
Namespace: key.Namespace,
Labels: map[string]string{
constants.LabelKeyOwner: pool.Name,
},
},
Spec: tfv1.WorkloadProfileSpec{
Replicas: ptr.Int32(1),
PoolName: pool.Name,
GPUCount: 2,
Resources: tfv1.Resources{
Requests: tfv1.Resource{
Tflops: resource.MustParse("10"),
Vram: resource.MustParse("8Gi"),
},
Limits: tfv1.Resource{
Tflops: resource.MustParse("20"),
Vram: resource.MustParse("16Gi"),
},
},
},
}

Expect(k8sClient.Create(ctx, workload)).To(Succeed())

// Check that pod is created with 2 GPUs
podList := &corev1.PodList{}
Eventually(func(g Gomega) {
g.Expect(k8sClient.List(ctx, podList,
client.InNamespace(key.Namespace),
client.MatchingLabels{constants.WorkloadKey: key.Name})).Should(Succeed())
g.Expect(podList.Items).Should(HaveLen(1))

gpuNames := strings.Split(podList.Items[0].Annotations[constants.GpuKey], ",")
g.Expect(gpuNames).Should(HaveLen(2))
}, timeout, interval).Should(Succeed())

checkWorkloadStatus(workload)
})
})

Context("When scaling up a workload", func() {
@@ -214,6 +260,10 @@ var _ = Describe("TensorFusionWorkload Controller", func() {
g.Expect(k8sClient.List(ctx, podList,
client.InNamespace(key.Namespace),
client.MatchingLabels{constants.WorkloadKey: key.Name})).Should(Succeed())
// Filter out pods that are being deleted
podList.Items = lo.Filter(podList.Items, func(pod corev1.Pod, _ int) bool {
return pod.DeletionTimestamp == nil
})
g.Expect(podList.Items).Should(HaveLen(1))
}, timeout, interval).Should(Succeed())

@@ -225,10 +275,10 @@ var _ = Describe("TensorFusionWorkload Controller", func() {
Namespace: podList.Items[0].Namespace,
Name: podList.Items[0].Name,
}, pod)).Should(Succeed())
gpuName := pod.Labels[constants.GpuKey]
gpuNames := strings.Split(pod.Annotations[constants.GpuKey], ",")
gpuList := tfEnv.GetPoolGpuList(0)
gpu, ok := lo.Find(gpuList.Items, func(gpu tfv1.GPU) bool {
return gpu.Name == gpuName
return gpu.Name == gpuNames[0]
})
g.Expect(ok).To(BeTrue())
g.Expect(gpu.Status.GPUModel).To(Equal("mock"))
27 changes: 14 additions & 13 deletions internal/gpuallocator/gpuallocator.go
@@ -128,25 +128,26 @@ func (s *GpuAllocator) Alloc(
return result, nil
}

// Dealloc deallocates a request from a gpu.
func (s *GpuAllocator) Dealloc(ctx context.Context, request tfv1.Resource, gpu *tfv1.GPU) error {
// Dealloc deallocates a request from one or multiple gpus.
func (s *GpuAllocator) Dealloc(ctx context.Context, request tfv1.Resource, gpus []types.NamespacedName) error {
log := log.FromContext(ctx)
s.storeMutex.Lock()
defer s.storeMutex.Unlock()

// Get the GPU from the store
key := types.NamespacedName{Name: gpu.Name, Namespace: gpu.Namespace}
storeGPU, exists := s.gpuStore[key]
if !exists {
log.Info("GPU not found in store during deallocation", "name", key.String())
return fmt.Errorf("GPU %s not found in store", key.String())
}
for _, gpu := range gpus {
// Get the GPU from the store
storeGPU, exists := s.gpuStore[gpu]
if !exists {
log.Error(fmt.Errorf("GPU not found in store"), "Failed to deallocate GPU", "name", gpu.String())
continue
}

// Add resources back to the GPU
storeGPU.Status.Available.Tflops.Add(request.Tflops)
storeGPU.Status.Available.Vram.Add(request.Vram)
// Add resources back to the GPU
storeGPU.Status.Available.Tflops.Add(request.Tflops)
storeGPU.Status.Available.Vram.Add(request.Vram)

s.markGPUDirty(key)
s.markGPUDirty(gpu)
}

return nil
}
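From the caller's point of view, allocation and release are now symmetric over slices. A brief usage sketch, assuming Alloc keeps returning []*tfv1.GPU as the controller changes above rely on; allocator, workload, and ctx are stand-in variables:

    // Allocate workload.Spec.GPUCount GPUs from the pool in a single call.
    gpus, err := allocator.Alloc(ctx, workload.Spec.PoolName, workload.Spec.Resources.Requests,
        workload.Spec.GPUCount, workload.Spec.GPUModel)
    if err != nil {
        return err
    }

    // Release the whole batch later; a GPU missing from the store is logged and
    // skipped inside Dealloc instead of failing the remaining ones.
    keys := lo.Map(gpus, func(gpu *tfv1.GPU, _ int) types.NamespacedName {
        return client.ObjectKeyFromObject(gpu)
    })
    return allocator.Dealloc(ctx, workload.Spec.Resources.Requests, keys)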