Skip to content

Commit da0ad31

Browse files
committed
feat: enhance TensorFusionConnection lifecycle and worker configuration
- Add "Starting" phase to TensorFusionConnection states - Implement worker pod template configuration in config package - Update controller to use WorkerGenerator for pod management - Enhance main.go with worker template initialization
1 parent da891eb commit da0ad31

File tree

6 files changed

+95
-11
lines changed

6 files changed

+95
-11
lines changed

api/v1/tensorfusionconnection_types.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,9 @@ type TensorFusionConnectionPhase string
4040

4141
// These are the valid phases of a TensorFusionConnection.
4242
const (
43-
TensorFusionConnectionPending TensorFusionConnectionPhase = "Pending"
44-
TensorFusionConnectionRunning TensorFusionConnectionPhase = "Running"
43+
TensorFusionConnectionPending TensorFusionConnectionPhase = "Pending"
44+
TensorFusionConnectionStarting TensorFusionConnectionPhase = "Starting"
45+
TensorFusionConnectionRunning TensorFusionConnectionPhase = "Running"
4546
)
4647

4748
// TensorFusionConnectionStatus defines the observed state of TensorFusionConnection.

cmd/main.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ import (
4444
"github.com/NexusGPU/tensor-fusion-operator/internal/server"
4545
"github.com/NexusGPU/tensor-fusion-operator/internal/server/router"
4646
webhookcorev1 "github.com/NexusGPU/tensor-fusion-operator/internal/webhook/v1"
47+
"github.com/NexusGPU/tensor-fusion-operator/internal/worker"
4748
// +kubebuilder:scaffold:imports
4849
)
4950

@@ -157,6 +158,9 @@ func main() {
157158
Client: mgr.GetClient(),
158159
Scheme: mgr.GetScheme(),
159160
Scheduler: scheduler,
161+
WorkerGenerator: &worker.WorkerGenerator{
162+
PodTemplate: &config.WorkerTemplate,
163+
},
160164
}).SetupWithManager(mgr); err != nil {
161165
setupLog.Error(err, "unable to create controller", "controller", "TensorFusionConnection")
162166
os.Exit(1)

internal/config/config.go

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
package config
22

3-
import corev1 "k8s.io/api/core/v1"
3+
import (
4+
corev1 "k8s.io/api/core/v1"
5+
)
46

57
type Config struct {
6-
PodMutator PodMutator `json:"podMutator"`
8+
WorkerTemplate corev1.PodTemplate `json:"workerTemplate"`
9+
PodMutator PodMutator `json:"podMutator"`
710
}
811

912
type PodMutator struct {
@@ -12,5 +15,19 @@ type PodMutator struct {
1215
}
1316

1417
func NewDefaultConfig() Config {
15-
return Config{}
18+
return Config{
19+
WorkerTemplate: corev1.PodTemplate{
20+
Template: corev1.PodTemplateSpec{
21+
Spec: corev1.PodSpec{
22+
Containers: []corev1.Container{
23+
{
24+
Name: "tensorfusion-worker",
25+
Image: "busybox:stable-glibc",
26+
Command: []string{"sleep", "infinity"},
27+
},
28+
},
29+
},
30+
},
31+
},
32+
}
1633
}

internal/controller/tensorfusionconnection_controller.go

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121

2222
"k8s.io/apimachinery/pkg/api/errors"
2323
"k8s.io/apimachinery/pkg/runtime"
24+
"k8s.io/apimachinery/pkg/types"
2425
"k8s.io/client-go/util/retry"
2526
ctrl "sigs.k8s.io/controller-runtime"
2627
"sigs.k8s.io/controller-runtime/pkg/client"
@@ -30,13 +31,15 @@ import (
3031
"github.com/NexusGPU/tensor-fusion-operator/internal/constants"
3132
scheduler "github.com/NexusGPU/tensor-fusion-operator/internal/scheduler"
3233
"github.com/NexusGPU/tensor-fusion-operator/internal/worker"
34+
corev1 "k8s.io/api/core/v1"
3335
)
3436

3537
// TensorFusionConnectionReconciler reconciles a TensorFusionConnection object
3638
type TensorFusionConnectionReconciler struct {
3739
client.Client
38-
Scheme *runtime.Scheme
39-
Scheduler scheduler.Scheduler
40+
Scheme *runtime.Scheme
41+
Scheduler scheduler.Scheduler
42+
WorkerGenerator *worker.WorkerGenerator
4043
}
4144

4245
var (
@@ -102,25 +105,58 @@ func (r *TensorFusionConnectionReconciler) Reconcile(ctx context.Context, req ct
102105
log.Info(err.Error())
103106
connection.Status.Phase = tfv1.TensorFusionConnectionPending
104107
} else if gpu != nil {
105-
connection.Status.Phase = tfv1.TensorFusionConnectionRunning
106-
connection.Status.ConnectionURL = worker.GenerateConnectionURL(gpu, connection)
108+
connection.Status.Phase = tfv1.TensorFusionConnectionStarting
107109
// Store the gpu name for cleanup
108110
connection.Status.GPU = gpu.Name
109111
} else {
112+
// Init status
110113
connection.Status.Phase = tfv1.TensorFusionConnectionPending
111114
}
112115
}
113116

117+
if connection.Status.Phase == tfv1.TensorFusionConnectionStarting {
118+
// Start worker job
119+
phase, err := r.StartWorker(ctx, connection, types.NamespacedName{Name: gpu.Name, Namespace: gpu.Namespace})
120+
if err != nil {
121+
log.Error(err, "Failed to start worker pod")
122+
return ctrl.Result{}, err
123+
}
124+
125+
if phase == corev1.PodRunning {
126+
connection.Status.Phase = tfv1.TensorFusionConnectionRunning
127+
connection.Status.ConnectionURL = r.WorkerGenerator.GenerateConnectionURL(gpu, connection)
128+
}
129+
// TODO: Handle PodFailure
130+
}
131+
114132
if err := r.MustUpdateStatus(ctx, connection, gpu); err != nil {
115133
return ctrl.Result{}, err
116134
}
117135

118136
if connection.Status.Phase == tfv1.TensorFusionConnectionPending {
137+
// requeue
119138
return ctrl.Result{RequeueAfter: constants.PendingRequeueDuration}, nil
120139
}
140+
121141
return ctrl.Result{}, nil
122142
}
123143

144+
func (r *TensorFusionConnectionReconciler) StartWorker(ctx context.Context, connection *tfv1.TensorFusionConnection, namespacedName types.NamespacedName) (corev1.PodPhase, error) {
145+
// Try to get the Pod
146+
pod := &corev1.Pod{}
147+
if err := r.Get(ctx, namespacedName, pod); err != nil {
148+
if errors.IsNotFound(err) {
149+
// Pod doesn't exist, create a new one
150+
pod = r.WorkerGenerator.GenerateWorkerPod(connection, namespacedName)
151+
if err := r.Create(ctx, pod); err != nil {
152+
return "", err
153+
}
154+
return corev1.PodPending, nil
155+
}
156+
}
157+
return pod.Status.Phase, nil
158+
}
159+
124160
// handleDeletion handles cleanup of external dependencies
125161
func (r *TensorFusionConnectionReconciler) handleDeletion(ctx context.Context, connection *tfv1.TensorFusionConnection) error {
126162
if connection.Status.GPU == "" {
@@ -209,6 +245,7 @@ func (r *TensorFusionConnectionReconciler) MustUpdateStatus(ctx context.Context,
209245
func (r *TensorFusionConnectionReconciler) SetupWithManager(mgr ctrl.Manager) error {
210246
return ctrl.NewControllerManagedBy(mgr).
211247
For(&tfv1.TensorFusionConnection{}).
248+
Owns(&corev1.Pod{}).
212249
Named("tensorfusionconnection").
213250
Complete(r)
214251
}

internal/controller/tensorfusionconnection_controller_test.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ import (
2828
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
2929

3030
tensorfusionaiv1 "github.com/NexusGPU/tensor-fusion-operator/api/v1"
31+
"github.com/NexusGPU/tensor-fusion-operator/internal/config"
32+
"github.com/NexusGPU/tensor-fusion-operator/internal/worker"
3133
)
3234

3335
var _ = Describe("TensorFusionConnection Controller", func() {
@@ -68,11 +70,14 @@ var _ = Describe("TensorFusionConnection Controller", func() {
6870
})
6971
It("should successfully reconcile the resource", func() {
7072
By("Reconciling the created resource")
73+
config := config.NewDefaultConfig()
7174
controllerReconciler := &TensorFusionConnectionReconciler{
7275
Client: k8sClient,
7376
Scheme: k8sClient.Scheme(),
77+
WorkerGenerator: &worker.WorkerGenerator{
78+
PodTemplate: &config.WorkerTemplate,
79+
},
7480
}
75-
7681
_, err := controllerReconciler.Reconcile(ctx, reconcile.Request{
7782
NamespacedName: typeNamespacedName,
7883
})

internal/worker/worker.go

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,28 @@ package worker
22

33
import (
44
tfv1 "github.com/NexusGPU/tensor-fusion-operator/api/v1"
5+
corev1 "k8s.io/api/core/v1"
6+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
7+
"k8s.io/apimachinery/pkg/types"
58
)
69

7-
func GenerateConnectionURL(_gpu *tfv1.GPU, _connection *tfv1.TensorFusionConnection) string {
10+
// WorkerGenerator builds worker Pods and connection URLs for
// TensorFusionConnections.
type WorkerGenerator struct {
11+
// PodTemplate is the template whose Template.Spec is used as the spec of
// every generated worker Pod.
PodTemplate *corev1.PodTemplate
12+
}
13+
14+
// GenerateConnectionURL returns the URL a client should use to reach the
// worker serving the connection.
//
// Not implemented yet: both arguments are ignored and a fixed placeholder
// value is returned.
func (wg *WorkerGenerator) GenerateConnectionURL(_gpu *tfv1.GPU, _connection *tfv1.TensorFusionConnection) string {
	const placeholderURL = "TODO://"
	return placeholderURL
}
17+
18+
func (wg *WorkerGenerator) GenerateWorkerPod(
19+
connection *tfv1.TensorFusionConnection,
20+
namespacedName types.NamespacedName,
21+
) *corev1.Pod {
22+
return &corev1.Pod{
23+
ObjectMeta: metav1.ObjectMeta{
24+
Name: namespacedName.Name,
25+
Namespace: namespacedName.Namespace,
26+
},
27+
Spec: wg.PodTemplate.Template.Spec,
28+
}
29+
}

0 commit comments

Comments
 (0)