Fully consolidate tfjob-operator to training-operator (#1850)
Signed-off-by: Yuki Iwai <yuki.iwai.tz@gmail.com>
tenzen-y authored Jul 5, 2023
1 parent 59cc98c commit fcdf9a3
Showing 8 changed files with 69 additions and 240 deletions.
3 changes: 3 additions & 0 deletions pkg/common/interface.go
@@ -88,4 +88,7 @@ type ControllerInterface interface {
 	// It will requeue the job in case of an error while creating/deleting services.
 	// Common implementation will be provided and User can still override this to implement their own reconcile logic
 	ReconcileServices(job metav1.Object, services []*v1.Service, rtype apiv1.ReplicaType, spec *apiv1.ReplicaSpec) error
+
+	// GetFrameworkName returns framework name (e.g., tensorflow).
+	GetFrameworkName() string
 }
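
Each framework reconciler satisfies this new interface method with a one-line accessor, as the four controller diffs below show. The TFJob controller is among the changed files whose diff is not rendered on this page; by analogy with the rendered implementations, its version would presumably look like the following sketch (assuming a kubeflowv1.TFJobFrameworkName constant alongside the other framework-name constants):

func (r *TFJobReconciler) GetFrameworkName() string {
	// Sketch inferred from the MPIJob/MXJob/PaddleJob/PyTorchJob
	// implementations in this commit, not taken from the rendered diff.
	return kubeflowv1.TFJobFrameworkName
}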
9 changes: 9 additions & 0 deletions pkg/controller.v1/common/pod.go
@@ -21,6 +21,7 @@ import (
 	"strings"
 
 	apiv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
+	trainingoperatorcommon "github.com/kubeflow/training-operator/pkg/common"
 	"github.com/kubeflow/training-operator/pkg/controller.v1/control"
 	"github.com/kubeflow/training-operator/pkg/controller.v1/expectation"
 	"github.com/kubeflow/training-operator/pkg/core"
@@ -356,6 +357,14 @@ func (jc *JobController) ReconcilePods(
 			// Deletion is expected
 			jc.Expectations.RaiseExpectations(expectationPodsKey, 0, 1)
 
+			msg := fmt.Sprintf("job %s is restarting because %s replica(s) failed.",
+				metaObject.GetName(), rType)
+			jc.Recorder.Event(runtimeObject, v1.EventTypeWarning, "JobRestarting", msg)
+			if err := commonutil.UpdateJobConditions(jobStatus, apiv1.JobRestarting, "JobRestarting", msg); err != nil {
+				commonutil.LoggerForJob(metaObject).Infof("Append job condition error: %v", err)
+				return err
+			}
+			trainingoperatorcommon.RestartedJobsCounterInc(metaObject.GetNamespace(), jc.Controller.GetFrameworkName())
 		}
 
 		updateJobReplicaStatuses(jobStatus, rType, pod)
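
The call sites in this commit rely on a family of counter helpers in pkg/common (CreatedJobsCounterInc, DeletedJobsCounterInc, SuccessfulJobsCounterInc, RestartedJobsCounterInc, FailedJobsCounterInc), each taking a namespace and a framework label. A minimal sketch of how such a helper could be defined with prometheus/client_golang; the metric and label names here are illustrative assumptions, not necessarily the operator's actual ones:

package common

import (
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"
)

// restartedJobsCounter counts restarted jobs, partitioned by namespace
// and framework (assumed metric and label names, for illustration only).
var restartedJobsCounter = promauto.NewCounterVec(
	prometheus.CounterOpts{
		Name: "training_operator_jobs_restarted_total",
		Help: "Number of restarted jobs, by namespace and framework.",
	},
	[]string{"job_namespace", "framework"},
)

// RestartedJobsCounterInc increments the counter for one namespace/framework pair.
func RestartedJobsCounterInc(namespace, framework string) {
	restartedJobsCounter.WithLabelValues(namespace, framework).Inc()
}

With GetFrameworkName on the controller interface, the shared ReconcilePods above can pass the right framework label without importing any framework-specific constant.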
14 changes: 9 additions & 5 deletions pkg/controller.v1/mpi/mpijob_controller.go
@@ -277,6 +277,10 @@ func (jc *MPIJobReconciler) GetGroupNameLabelValue() string {
 	return kubeflowv1.GroupVersion.Group
 }
 
+func (jc *MPIJobReconciler) GetFrameworkName() string {
+	return kubeflowv1.MPIJobFrameworkName
+}
+
 // SetClusterSpec is overridden because no cluster spec is needed for MPIJob
 func (jc *MPIJobReconciler) SetClusterSpec(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype, index string) error {
 	return nil
@@ -314,7 +318,7 @@ func (jc *MPIJobReconciler) onOwnerCreateFunc() func(event.CreateEvent) bool {
 		jc.Scheme.Default(mpiJob)
 		msg := fmt.Sprintf("MPIJob %s/%s is created.", mpiJob.Namespace, e.Object.GetName())
 		logrus.Info(msg)
-		trainingoperatorcommon.CreatedJobsCounterInc(mpiJob.Namespace, kubeflowv1.MPIJobFrameworkName)
+		trainingoperatorcommon.CreatedJobsCounterInc(mpiJob.Namespace, jc.GetFrameworkName())
 		if err := commonutil.UpdateJobConditions(&mpiJob.Status, kubeflowv1.JobCreated, mpiJobCreatedReason, msg); err != nil {
 			log.Log.Error(err, "append job condition error")
 			return false
@@ -546,7 +550,7 @@ func (jc *MPIJobReconciler) DeleteJob(job interface{}) error {
 
 	jc.Recorder.Eventf(mpiJob, corev1.EventTypeNormal, SuccessfulDeleteJobReason, "Deleted job: %v", mpiJob.Name)
 	log.Infof("job %s/%s has been deleted", mpiJob.Namespace, mpiJob.Name)
-	trainingoperatorcommon.DeletedJobsCounterInc(mpiJob.Namespace, kubeflowv1.MPIJobFrameworkName)
+	trainingoperatorcommon.DeletedJobsCounterInc(mpiJob.Namespace, jc.GetFrameworkName())
 	return nil
 }
 
@@ -597,7 +601,7 @@ func (jc *MPIJobReconciler) UpdateJobStatus(job interface{}, replicas map[kubefl
 			commonutil.LoggerForJob(mpiJob).Infof("Append job condition error: %v", err)
 			return err
 		}
-		trainingoperatorcommon.SuccessfulJobsCounterInc(mpiJob.Namespace, kubeflowv1.MPIJobFrameworkName)
+		trainingoperatorcommon.SuccessfulJobsCounterInc(mpiJob.Namespace, jc.GetFrameworkName())
 		return nil
 	}
 }
@@ -610,7 +614,7 @@ func (jc *MPIJobReconciler) UpdateJobStatus(job interface{}, replicas map[kubefl
 			commonutil.LoggerForJob(mpiJob).Infof("Append job condition error: %v", err)
 			return err
 		}
-		trainingoperatorcommon.RestartedJobsCounterInc(mpiJob.Namespace, kubeflowv1.MPIJobFrameworkName)
+		trainingoperatorcommon.RestartedJobsCounterInc(mpiJob.Namespace, jc.GetFrameworkName())
 	} else {
 		msg := fmt.Sprintf("MPIJob %s is failed because %d %s replica(s) failed.", mpiJob.Name, failed, rtype)
 		jc.Recorder.Event(mpiJob, corev1.EventTypeNormal, commonutil.JobFailedReason, msg)
@@ -623,7 +627,7 @@ func (jc *MPIJobReconciler) UpdateJobStatus(job interface{}, replicas map[kubefl
 			commonutil.LoggerForJob(mpiJob).Infof("Append job condition error: %v", err)
 			return err
 		}
-		trainingoperatorcommon.FailedJobsCounterInc(mpiJob.Namespace, kubeflowv1.MPIJobFrameworkName)
+		trainingoperatorcommon.FailedJobsCounterInc(mpiJob.Namespace, jc.GetFrameworkName())
 		}
 	}
 }
14 changes: 9 additions & 5 deletions pkg/controller.v1/mxnet/mxjob_controller.go
@@ -257,6 +257,10 @@ func (r *MXJobReconciler) GetGroupNameLabelValue() string {
 	return kubeflowv1.GroupVersion.Group
 }
 
+func (r *MXJobReconciler) GetFrameworkName() string {
+	return kubeflowv1.MXJobFrameworkName
+}
+
 func (r *MXJobReconciler) GetJobFromInformerCache(namespace, name string) (metav1.Object, error) {
 	job := &kubeflowv1.MXJob{}
 	err := r.Get(context.Background(), types.NamespacedName{Namespace: namespace, Name: name}, job)
@@ -331,7 +335,7 @@ func (r *MXJobReconciler) DeleteJob(job interface{}) error {
 	}
 	r.Recorder.Eventf(mxjob, corev1.EventTypeNormal, control.SuccessfulDeletePodReason, "Deleted job: %v", mxjob.Name)
 	logrus.Info("job deleted", "namespace", mxjob.Namespace, "name", mxjob.Name)
-	trainingoperatorcommon.DeletedJobsCounterInc(mxjob.Namespace, kubeflowv1.MXJobFrameworkName)
+	trainingoperatorcommon.DeletedJobsCounterInc(mxjob.Namespace, r.GetFrameworkName())
 	return nil
 }
 
@@ -394,7 +398,7 @@ func (r *MXJobReconciler) UpdateJobStatus(job interface{}, replicas map[kubeflow
 			logrus.Infof("Append mxjob condition error: %v", err)
 			return err
 		}
-		trainingoperatorcommon.SuccessfulJobsCounterInc(mxjob.Namespace, kubeflowv1.MXJobFrameworkName)
+		trainingoperatorcommon.SuccessfulJobsCounterInc(mxjob.Namespace, r.GetFrameworkName())
 		return nil
 	}
 }
@@ -407,7 +411,7 @@ func (r *MXJobReconciler) UpdateJobStatus(job interface{}, replicas map[kubeflow
 			logrus.Infof("Append job condition error: %v", err)
 			return err
 		}
-		trainingoperatorcommon.RestartedJobsCounterInc(mxjob.Namespace, kubeflowv1.MXJobFrameworkName)
+		trainingoperatorcommon.RestartedJobsCounterInc(mxjob.Namespace, r.GetFrameworkName())
 	} else {
 		msg := fmt.Sprintf("mxjob %s is failed because %d %s replica(s) failed.", mxjob.Name, failed, rtype)
 		r.Recorder.Event(mxjob, corev1.EventTypeNormal, mxJobFailedReason, msg)
@@ -420,7 +424,7 @@ func (r *MXJobReconciler) UpdateJobStatus(job interface{}, replicas map[kubeflow
 			logrus.Infof("Append job condition error: %v", err)
 			return err
 		}
-		trainingoperatorcommon.FailedJobsCounterInc(mxjob.Namespace, kubeflowv1.MXJobFrameworkName)
+		trainingoperatorcommon.FailedJobsCounterInc(mxjob.Namespace, r.GetFrameworkName())
 		}
 	}
 }
@@ -481,7 +485,7 @@ func (r *MXJobReconciler) onOwnerCreateFunc() func(event.CreateEvent) bool {
 		r.Scheme.Default(mxJob)
 		msg := fmt.Sprintf("MXJob %s is created.", e.Object.GetName())
 		logrus.Info(msg)
-		trainingoperatorcommon.CreatedJobsCounterInc(mxJob.Namespace, kubeflowv1.MXJobFrameworkName)
+		trainingoperatorcommon.CreatedJobsCounterInc(mxJob.Namespace, r.GetFrameworkName())
 		if err := commonutil.UpdateJobConditions(&mxJob.Status, kubeflowv1.JobCreated, "MXJobCreated", msg); err != nil {
 			logrus.Error(err, "append job condition error")
 			return false
16 changes: 10 additions & 6 deletions pkg/controller.v1/paddlepaddle/paddlepaddle_controller.go
@@ -252,6 +252,10 @@ func (r *PaddleJobReconciler) GetGroupNameLabelValue() string {
 	return kubeflowv1.GroupVersion.Group
 }
 
+func (r *PaddleJobReconciler) GetFrameworkName() string {
+	return kubeflowv1.PaddleJobFrameworkName
+}
+
 func (r *PaddleJobReconciler) GetJobFromInformerCache(namespace, name string) (metav1.Object, error) {
 	job := &kubeflowv1.PaddleJob{}
 	err := r.Get(context.Background(), types.NamespacedName{Namespace: namespace, Name: name}, job)
@@ -328,7 +332,7 @@ func (r *PaddleJobReconciler) DeleteJob(job interface{}) error {
 	}
 	r.recorder.Eventf(paddlejob, corev1.EventTypeNormal, control.SuccessfulDeletePodReason, "Deleted job: %v", paddlejob.Name)
 	logrus.Info("job deleted", "namespace", paddlejob.Namespace, "name", paddlejob.Name)
-	trainingoperatorcommon.DeletedJobsCounterInc(paddlejob.Namespace, kubeflowv1.PaddleJobFrameworkName)
+	trainingoperatorcommon.DeletedJobsCounterInc(paddlejob.Namespace, r.GetFrameworkName())
 	return nil
 }
 
@@ -408,7 +412,7 @@ func (r *PaddleJobReconciler) UpdateJobStatus(job interface{},
 			commonutil.LoggerForJob(paddlejob).Infof("Append job condition error: %v", err)
 			return err
 		}
-		trainingoperatorcommon.SuccessfulJobsCounterInc(paddlejob.Namespace, kubeflowv1.PaddleJobFrameworkName)
+		trainingoperatorcommon.SuccessfulJobsCounterInc(paddlejob.Namespace, r.GetFrameworkName())
 		return nil
 	}
 }
@@ -429,7 +433,7 @@ func (r *PaddleJobReconciler) UpdateJobStatus(job interface{},
 			commonutil.LoggerForJob(paddlejob).Infof("Append paddlejob condition error: %v", err)
 			return err
 		}
-		trainingoperatorcommon.SuccessfulJobsCounterInc(paddlejob.Namespace, kubeflowv1.PaddleJobFrameworkName)
+		trainingoperatorcommon.SuccessfulJobsCounterInc(paddlejob.Namespace, r.GetFrameworkName())
 	} else if running > 0 {
 		// Some workers are still running, leave a running condition.
 		msg := fmt.Sprintf("PaddleJob %s/%s is running.",
@@ -452,7 +456,7 @@ func (r *PaddleJobReconciler) UpdateJobStatus(job interface{},
 			commonutil.LoggerForJob(paddlejob).Infof("Append job condition error: %v", err)
 			return err
 		}
-		trainingoperatorcommon.RestartedJobsCounterInc(paddlejob.Namespace, kubeflowv1.PaddleJobFrameworkName)
+		trainingoperatorcommon.RestartedJobsCounterInc(paddlejob.Namespace, r.GetFrameworkName())
 	} else {
 		msg := fmt.Sprintf("PaddleJob %s is failed because %d %s replica(s) failed.", paddlejob.Name, failed, rtype)
 		r.Recorder.Event(paddlejob, corev1.EventTypeNormal, commonutil.JobFailedReason, msg)
@@ -465,7 +469,7 @@ func (r *PaddleJobReconciler) UpdateJobStatus(job interface{},
 			commonutil.LoggerForJob(paddlejob).Infof("Append job condition error: %v", err)
 			return err
 		}
-		trainingoperatorcommon.FailedJobsCounterInc(paddlejob.Namespace, kubeflowv1.PaddleJobFrameworkName)
+		trainingoperatorcommon.FailedJobsCounterInc(paddlejob.Namespace, r.GetFrameworkName())
 		}
 	}
 }
@@ -544,7 +548,7 @@ func (r *PaddleJobReconciler) onOwnerCreateFunc() func(event.CreateEvent) bool {
 		r.Scheme.Default(paddlejob)
 		msg := fmt.Sprintf("PaddleJob %s is created.", e.Object.GetName())
 		logrus.Info(msg)
-		trainingoperatorcommon.CreatedJobsCounterInc(paddlejob.Namespace, kubeflowv1.PaddleJobFrameworkName)
+		trainingoperatorcommon.CreatedJobsCounterInc(paddlejob.Namespace, r.GetFrameworkName())
 		if err := commonutil.UpdateJobConditions(&paddlejob.Status, kubeflowv1.JobCreated, "PaddleJobCreated", msg); err != nil {
 			logrus.Error(err, "append job condition error")
 			return false
16 changes: 10 additions & 6 deletions pkg/controller.v1/pytorch/pytorchjob_controller.go
@@ -252,6 +252,10 @@ func (r *PyTorchJobReconciler) GetGroupNameLabelValue() string {
 	return kubeflowv1.GroupVersion.Group
 }
 
+func (r *PyTorchJobReconciler) GetFrameworkName() string {
+	return kubeflowv1.PytorchJobFrameworkName
+}
+
 func (r *PyTorchJobReconciler) GetJobFromInformerCache(namespace, name string) (metav1.Object, error) {
 	job := &kubeflowv1.PyTorchJob{}
 	err := r.Get(context.Background(), types.NamespacedName{Namespace: namespace, Name: name}, job)
@@ -328,7 +332,7 @@ func (r *PyTorchJobReconciler) DeleteJob(job interface{}) error {
 	}
 	r.recorder.Eventf(pytorchjob, corev1.EventTypeNormal, control.SuccessfulDeletePodReason, "Deleted job: %v", pytorchjob.Name)
 	logrus.Info("job deleted", "namespace", pytorchjob.Namespace, "name", pytorchjob.Name)
-	trainingoperatorcommon.DeletedJobsCounterInc(pytorchjob.Namespace, kubeflowv1.PytorchJobFrameworkName)
+	trainingoperatorcommon.DeletedJobsCounterInc(pytorchjob.Namespace, r.GetFrameworkName())
 	return nil
 }
 
@@ -407,7 +411,7 @@ func (r *PyTorchJobReconciler) UpdateJobStatus(job interface{},
 			commonutil.LoggerForJob(pytorchjob).Infof("Append job condition error: %v", err)
 			return err
 		}
-		trainingoperatorcommon.SuccessfulJobsCounterInc(pytorchjob.Namespace, kubeflowv1.PytorchJobFrameworkName)
+		trainingoperatorcommon.SuccessfulJobsCounterInc(pytorchjob.Namespace, r.GetFrameworkName())
 		return nil
 	}
 }
@@ -431,7 +435,7 @@ func (r *PyTorchJobReconciler) UpdateJobStatus(job interface{},
 			commonutil.LoggerForJob(pytorchjob).Infof("Append pytorchjob condition error: %v", err)
 			return err
 		}
-		trainingoperatorcommon.SuccessfulJobsCounterInc(pytorchjob.Namespace, kubeflowv1.PytorchJobFrameworkName)
+		trainingoperatorcommon.SuccessfulJobsCounterInc(pytorchjob.Namespace, r.GetFrameworkName())
 	} else if running > 0 {
 		// Some workers are still running, leave a running condition.
 		msg := fmt.Sprintf("PyTorchJob %s/%s is running.",
@@ -454,7 +458,7 @@ func (r *PyTorchJobReconciler) UpdateJobStatus(job interface{},
 			commonutil.LoggerForJob(pytorchjob).Infof("Append job condition error: %v", err)
 			return err
 		}
-		trainingoperatorcommon.RestartedJobsCounterInc(pytorchjob.Namespace, kubeflowv1.PytorchJobFrameworkName)
+		trainingoperatorcommon.RestartedJobsCounterInc(pytorchjob.Namespace, r.GetFrameworkName())
 	} else {
 		msg := fmt.Sprintf("PyTorchJob %s is failed because %d %s replica(s) failed.", pytorchjob.Name, failed, rtype)
 		r.Recorder.Event(pytorchjob, corev1.EventTypeNormal, commonutil.JobFailedReason, msg)
@@ -467,7 +471,7 @@ func (r *PyTorchJobReconciler) UpdateJobStatus(job interface{},
 			commonutil.LoggerForJob(pytorchjob).Infof("Append job condition error: %v", err)
 			return err
 		}
-		trainingoperatorcommon.FailedJobsCounterInc(pytorchjob.Namespace, kubeflowv1.PytorchJobFrameworkName)
+		trainingoperatorcommon.FailedJobsCounterInc(pytorchjob.Namespace, r.GetFrameworkName())
 		}
 	}
 }
@@ -547,7 +551,7 @@ func (r *PyTorchJobReconciler) onOwnerCreateFunc() func(event.CreateEvent) bool
 		r.Scheme.Default(pytorchjob)
 		msg := fmt.Sprintf("PyTorchJob %s is created.", e.Object.GetName())
 		logrus.Info(msg)
-		trainingoperatorcommon.CreatedJobsCounterInc(pytorchjob.Namespace, kubeflowv1.PytorchJobFrameworkName)
+		trainingoperatorcommon.CreatedJobsCounterInc(pytorchjob.Namespace, r.GetFrameworkName())
 		if err := commonutil.UpdateJobConditions(&pytorchjob.Status, kubeflowv1.JobCreated, "PyTorchJobCreated", msg); err != nil {
 			logrus.Error(err, "append job condition error")
 			return false
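
The pattern is identical across all four rendered controllers: each hard-coded kubeflowv1.*FrameworkName constant at a metrics call site is replaced by the reconciler's own GetFrameworkName(). Because the method is part of the common ControllerInterface, shared code can label metrics for any framework generically. A hypothetical helper, purely to illustrate the dispatch (not code from this commit):

// recordRestart works for every framework reconciler because the
// interface call resolves to the framework-specific value, e.g.
// kubeflowv1.MPIJobFrameworkName for the MPIJob reconciler.
func recordRestart(c common.ControllerInterface, namespace string) {
	trainingoperatorcommon.RestartedJobsCounterInc(namespace, c.GetFrameworkName())
}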
(Diffs for the remaining 2 of the 8 changed files are not rendered on this page.)
