feat: enhance TensorFusionConnection lifecycle and derive worker pod #3
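Adds an intermediate Starting phase to the TensorFusionConnection lifecycle, introduces a WorkerGenerator that derives a worker Pod from a configurable PodTemplate (defaulting to a busybox placeholder), and updates the controller to create and own that Pod, marking the connection Running only once the Pod reports PodRunning.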

Merged · 1 commit · Dec 10, 2024

5 changes: 3 additions & 2 deletions api/v1/tensorfusionconnection_types.go
@@ -40,8 +40,9 @@ type TensorFusionConnectionPhase string

// These are the valid phases of a GpuConnection.
const (
TensorFusionConnectionPending TensorFusionConnectionPhase = "Pending"
TensorFusionConnectionRunning TensorFusionConnectionPhase = "Running"
TensorFusionConnectionPending TensorFusionConnectionPhase = "Pending"
TensorFusionConnectionStarting TensorFusionConnectionPhase = "Starting"
TensorFusionConnectionRunning TensorFusionConnectionPhase = "Running"
)

// TensorFusionConnectionStatus defines the observed state of TensorFusionConnection.
4 changes: 4 additions & 0 deletions cmd/main.go
@@ -44,6 +44,7 @@ import (
"github.com/NexusGPU/tensor-fusion-operator/internal/server"
"github.com/NexusGPU/tensor-fusion-operator/internal/server/router"
webhookcorev1 "github.com/NexusGPU/tensor-fusion-operator/internal/webhook/v1"
"github.com/NexusGPU/tensor-fusion-operator/internal/worker"
// +kubebuilder:scaffold:imports
)

@@ -157,6 +158,9 @@ func main() {
Client: mgr.GetClient(),
Scheme: mgr.GetScheme(),
Scheduler: scheduler,
WorkerGenerator: &worker.WorkerGenerator{
PodTemplate: &config.WorkerTemplate,
},
}).SetupWithManager(mgr); err != nil {
setupLog.Error(err, "unable to create controller", "controller", "TensorFusionConnection")
os.Exit(1)
25 changes: 22 additions & 3 deletions internal/config/config.go
@@ -1,9 +1,13 @@
package config

import corev1 "k8s.io/api/core/v1"
import (
corev1 "k8s.io/api/core/v1"
"k8s.io/utils/ptr"
)

type Config struct {
PodMutator PodMutator `json:"podMutator"`
WorkerTemplate corev1.PodTemplate `json:"workerTemplate"`
PodMutator PodMutator `json:"podMutator"`
}

type PodMutator struct {
@@ -12,5 +16,20 @@ type PodMutator struct {
}

func NewDefaultConfig() Config {
return Config{}
return Config{
WorkerTemplate: corev1.PodTemplate{
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
TerminationGracePeriodSeconds: ptr.To[int64](0),
Containers: []corev1.Container{
{
Name: "tensorfusion-worker",
Image: "busybox:stable-glibc",
Command: []string{"sleep", "infinity"},
},
},
},
},
},
}
}
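
For reference, a minimal sketch (not part of this PR) of how a caller could override the default busybox worker template; the image name and entrypoint below are hypothetical placeholders:

package main

import (
	"fmt"

	corev1 "k8s.io/api/core/v1"
	"k8s.io/utils/ptr"

	"github.com/NexusGPU/tensor-fusion-operator/internal/config"
)

// customWorkerConfig starts from the default config and swaps the busybox
// placeholder for a hypothetical worker image and entrypoint.
func customWorkerConfig() config.Config {
	cfg := config.NewDefaultConfig()
	cfg.WorkerTemplate = corev1.PodTemplate{
		Template: corev1.PodTemplateSpec{
			Spec: corev1.PodSpec{
				TerminationGracePeriodSeconds: ptr.To[int64](0),
				Containers: []corev1.Container{{
					Name:    "tensorfusion-worker",
					Image:   "example.com/tensorfusion/worker:latest", // hypothetical image
					Command: []string{"/tensorfusion-worker"},         // hypothetical entrypoint
				}},
			},
		},
	}
	return cfg
}

func main() {
	cfg := customWorkerConfig()
	fmt.Println(cfg.WorkerTemplate.Template.Spec.Containers[0].Image)
}
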
47 changes: 43 additions & 4 deletions internal/controller/tensorfusionconnection_controller.go
@@ -18,9 +18,11 @@ package controller

import (
"context"
"fmt"

"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/types"
"k8s.io/client-go/util/retry"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
@@ -30,13 +32,15 @@ import (
"github.com/NexusGPU/tensor-fusion-operator/internal/constants"
scheduler "github.com/NexusGPU/tensor-fusion-operator/internal/scheduler"
"github.com/NexusGPU/tensor-fusion-operator/internal/worker"
corev1 "k8s.io/api/core/v1"
)

// TensorFusionConnectionReconciler reconciles a TensorFusionConnection object
type TensorFusionConnectionReconciler struct {
client.Client
Scheme *runtime.Scheme
Scheduler scheduler.Scheduler
Scheme *runtime.Scheme
Scheduler scheduler.Scheduler
WorkerGenerator *worker.WorkerGenerator
}

var (
@@ -102,25 +106,59 @@ func (r *TensorFusionConnectionReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
log.Info(err.Error())
connection.Status.Phase = tfv1.TensorFusionConnectionPending
} else if gpu != nil {
connection.Status.Phase = tfv1.TensorFusionConnectionRunning
connection.Status.ConnectionURL = worker.GenerateConnectionURL(gpu, connection)
connection.Status.Phase = tfv1.TensorFusionConnectionStarting
// Store the gpu name for cleanup
connection.Status.GPU = gpu.Name
} else {
// Init status
connection.Status.Phase = tfv1.TensorFusionConnectionPending
}
}

// Start worker job
phase, err := r.TryStartWorker(ctx, connection, types.NamespacedName{Name: connection.Name, Namespace: connection.Namespace})
if err != nil {
log.Error(err, "Failed to start worker pod")
return ctrl.Result{}, err
}

if phase == corev1.PodRunning {
connection.Status.Phase = tfv1.TensorFusionConnectionRunning
connection.Status.ConnectionURL = r.WorkerGenerator.GenerateConnectionURL(gpu, connection)
}
// TODO: Handle PodFailure

if err := r.MustUpdateStatus(ctx, connection, gpu); err != nil {
return ctrl.Result{}, err
}

if connection.Status.Phase == tfv1.TensorFusionConnectionPending {
// requeue
return ctrl.Result{RequeueAfter: constants.PendingRequeueDuration}, nil
}

return ctrl.Result{}, nil
}

func (r *TensorFusionConnectionReconciler) TryStartWorker(ctx context.Context, connection *tfv1.TensorFusionConnection, namespacedName types.NamespacedName) (corev1.PodPhase, error) {
// Try to get the Pod
pod := &corev1.Pod{}
if err := r.Get(ctx, namespacedName, pod); err != nil {
if errors.IsNotFound(err) {
// Pod doesn't exist, create a new one
pod = r.WorkerGenerator.GenerateWorkerPod(connection, namespacedName)
if err := ctrl.SetControllerReference(connection, pod, r.Scheme); err != nil {
return "", fmt.Errorf("set owner reference %w", err)
}
if err := r.Create(ctx, pod); err != nil {
return "", fmt.Errorf("create pod %w", err)
}
return corev1.PodPending, nil
}
// Propagate unexpected Get errors instead of reporting an empty phase
return "", fmt.Errorf("get pod %w", err)
}
return pod.Status.Phase, nil
}

// handleDeletion handles cleanup of external dependencies
func (r *TensorFusionConnectionReconciler) handleDeletion(ctx context.Context, connection *tfv1.TensorFusionConnection) error {
if connection.Status.GPU == "" {
@@ -209,6 +247,7 @@ func (r *TensorFusionConnectionReconciler) SetupWithManager(mgr ctrl.Manager) error {
func (r *TensorFusionConnectionReconciler) SetupWithManager(mgr ctrl.Manager) error {
return ctrl.NewControllerManagedBy(mgr).
For(&tfv1.TensorFusionConnection{}).
Owns(&corev1.Pod{}).
Named("tensorfusionconnection").
Complete(r)
}
@@ -28,6 +28,8 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

tensorfusionaiv1 "github.com/NexusGPU/tensor-fusion-operator/api/v1"
"github.com/NexusGPU/tensor-fusion-operator/internal/config"
"github.com/NexusGPU/tensor-fusion-operator/internal/worker"
)

var _ = Describe("TensorFusionConnection Controller", func() {
@@ -68,11 +70,14 @@ var _ = Describe("TensorFusionConnection Controller", func() {
})
It("should successfully reconcile the resource", func() {
By("Reconciling the created resource")
config := config.NewDefaultConfig()
controllerReconciler := &TensorFusionConnectionReconciler{
Client: k8sClient,
Scheme: k8sClient.Scheme(),
WorkerGenerator: &worker.WorkerGenerator{
PodTemplate: &config.WorkerTemplate,
},
}

_, err := controllerReconciler.Reconcile(ctx, reconcile.Request{
NamespacedName: typeNamespacedName,
})
22 changes: 21 additions & 1 deletion internal/worker/worker.go
@@ -2,8 +2,28 @@ package worker

import (
tfv1 "github.com/NexusGPU/tensor-fusion-operator/api/v1"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
)

func GenerateConnectionURL(_gpu *tfv1.GPU, _connection *tfv1.TensorFusionConnection) string {
type WorkerGenerator struct {
PodTemplate *corev1.PodTemplate
}

func (wg *WorkerGenerator) GenerateConnectionURL(_gpu *tfv1.GPU, _connection *tfv1.TensorFusionConnection) string {
return "TODO://"
}

func (wg *WorkerGenerator) GenerateWorkerPod(
connection *tfv1.TensorFusionConnection,
namespacedName types.NamespacedName,
) *corev1.Pod {
return &corev1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: namespacedName.Name,
Namespace: namespacedName.Namespace,
},
Spec: wg.PodTemplate.Template.Spec,
}
}
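
A test-style sketch (not included in this PR) of what GenerateWorkerPod produces from the default template; the connection argument is not dereferenced in the implementation above, so nil suffices here:

package worker_test

import (
	"testing"

	"k8s.io/apimachinery/pkg/types"

	"github.com/NexusGPU/tensor-fusion-operator/internal/config"
	"github.com/NexusGPU/tensor-fusion-operator/internal/worker"
)

func TestGenerateWorkerPod(t *testing.T) {
	cfg := config.NewDefaultConfig()
	wg := &worker.WorkerGenerator{PodTemplate: &cfg.WorkerTemplate}

	// Derive a Pod named after the connection in its namespace.
	pod := wg.GenerateWorkerPod(nil, types.NamespacedName{Name: "conn-1", Namespace: "default"})

	if pod.Name != "conn-1" || pod.Namespace != "default" {
		t.Fatalf("unexpected pod metadata: %s/%s", pod.Namespace, pod.Name)
	}
	if img := pod.Spec.Containers[0].Image; img != "busybox:stable-glibc" {
		t.Fatalf("unexpected worker image: %s", img)
	}
}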