Skip to content

Commit

Permalink
feat(pytorch): Support init container config
Browse files Browse the repository at this point in the history
Signed-off-by: Ce Gao <ce.gao@outlook.com>
  • Loading branch information
gaocegege committed Nov 30, 2021
1 parent b3c2b4c commit aefa0c0
Show file tree
Hide file tree
Showing 7 changed files with 194 additions and 20 deletions.
8 changes: 8 additions & 0 deletions cmd/training-operator.v1/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"fmt"
"os"

"github.com/kubeflow/training-operator/pkg/config"
controller_v1 "github.com/kubeflow/training-operator/pkg/controller.v1"

// Import all Kubernetes client auth plugins (e.g. Azure, GCP, OIDC, etc.)
Expand Down Expand Up @@ -70,6 +71,13 @@ func main() {
flag.Var(&enabledSchemes, "enable-scheme", "Enable scheme(s) as --enable-scheme=tfjob --enable-scheme=pytorchjob, case insensitive."+
" Now supporting TFJob, PyTorchJob, MXNetJob, XGBoostJob. By default, all supported schemes will be enabled.")
flag.BoolVar(&enableGangScheduling, "enable-gang-scheduling", false, "Set true to enable gang scheduling")

// PyTorch related flags
flag.StringVar(&config.Config.PyTorchInitContainerImage, "pytorch-init-container-image",
config.PyTorchInitContainerImageDefault, "The image for pytorch init container")
flag.StringVar(&config.Config.PyTorchInitContainerTemplateFile, "pytorch-init-container-template-file",
config.PyTorchInitContainerTemplateFileDefault, "The template file for pytorch init container")

opts := zap.Options{
Development: true,
}
Expand Down
16 changes: 16 additions & 0 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Package config holds the process-wide configuration for the training
// operator. The fields of Config are populated from command-line flags
// (see cmd/training-operator.v1/main.go).
package config

// Config is the global configuration for the training operator.
// It is written once during flag parsing at startup and read afterwards,
// e.g. by the PyTorch controller when building init containers.
var Config struct {
	// PyTorchInitContainerTemplateFile is the path of the template file used
	// to render the pytorch init container.
	PyTorchInitContainerTemplateFile string
	// PyTorchInitContainerImage is the image run by the pytorch init container.
	PyTorchInitContainerImage string
}

const (
	// PyTorchInitContainerImageDefault is the default image for the pytorch
	// init container.
	PyTorchInitContainerImageDefault = "alpine:3.10"
	// PyTorchInitContainerTemplateFileDefault is the default template file for
	// the pytorch init container.
	PyTorchInitContainerTemplateFileDefault = "/etc/config/initContainer.yaml"
)
5 changes: 1 addition & 4 deletions pkg/controller.v1/pytorch/elastic.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,7 @@ var (
onceElastic sync.Once
)

type EnvVarGenerator interface {
Generate(job *pytorchv1.PyTorchJob) ([]corev1.EnvVar, error)
}

// ElasticEnvVarGenerator is the environment variable generator for Elastic
// related arguments. The struct carries no state; callers obtain a shared
// instance via GetElasticEnvVarGenerator.
type ElasticEnvVarGenerator struct{}

func GetElasticEnvVarGenerator() EnvVarGenerator {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,31 @@ import (
corev1 "k8s.io/api/core/v1"
)

func SetPodEnv(obj interface{}, podTemplateSpec *corev1.PodTemplateSpec, rtype, index string) error {
// EnvVarGenerator is the environment variable generator interface.
// Implementations (e.g. the master and elastic generators in this package)
// produce the environment variables to append to a PyTorchJob's containers.
type EnvVarGenerator interface {
	// Generate returns the environment variables derived from the given job.
	Generate(job *pytorchv1.PyTorchJob) ([]corev1.EnvVar, error)
}

func setPodEnv(obj interface{}, podTemplateSpec *corev1.PodTemplateSpec, rtype, index string) error {
pytorchjob, ok := obj.(*pytorchv1.PyTorchJob)
if !ok {
return fmt.Errorf("%+v is not a type of PyTorchJob", obj)
}

for i := range podTemplateSpec.Spec.Containers {
// Initialize the environment variables.
if len(podTemplateSpec.Spec.Containers[i].Env) == 0 {
podTemplateSpec.Spec.Containers[i].Env = make([]corev1.EnvVar, 0)
}
podTemplateSpec.Spec.Containers[i].Env = append(podTemplateSpec.Spec.Containers[i].Env, corev1.EnvVar{
Name: "PYTHONUNBUFFERED",
Value: "0",
})
// Set PYTHONUNBUFFERED to true, to disable output buffering.
// Ref https://stackoverflow.com/questions/59812009/what-is-the-use-of-pythonunbuffered-in-docker-file.
podTemplateSpec.Spec.Containers[i].Env = append(
podTemplateSpec.Spec.Containers[i].Env, corev1.EnvVar{
Name: "PYTHONUNBUFFERED",
Value: "0",
})

// If the master is not null, then we need to set the MASTER_ADDR and RANK.
if pytorchjob.Spec.PyTorchReplicaSpecs[pytorchv1.PyTorchReplicaTypeMaster] != nil {
envVars, err := GetMasterEnvVarGenerator().Generate(pytorchjob)
if err != nil {
Expand Down Expand Up @@ -69,14 +79,16 @@ func SetPodEnv(obj interface{}, podTemplateSpec *corev1.PodTemplateSpec, rtype,
})
}

envVars, err := GetElasticEnvVarGenerator().Generate(pytorchjob)
if err != nil {
return err
// Set the elastic environment variables if the elasticPolicy is not null.
if pytorchjob.Spec.ElasticPolicy != nil {
envVars, err := GetElasticEnvVarGenerator().Generate(pytorchjob)
if err != nil {
return err
}
// Set elastic related environment variables.
podTemplateSpec.Spec.Containers[i].Env = append(
podTemplateSpec.Spec.Containers[i].Env, envVars...)
}
// Set elastic related environment variables.
podTemplateSpec.Spec.Containers[i].Env = append(
podTemplateSpec.Spec.Containers[i].Env, envVars...)

}

return nil
Expand Down
134 changes: 134 additions & 0 deletions pkg/controller.v1/pytorch/initcontainer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
// Copyright 2021 The Kubeflow Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License

package pytorch

import (
"bytes"
"fmt"
"html/template"
"io/ioutil"
"strconv"
"strings"
"sync"

"github.com/go-logr/logr"
corev1 "k8s.io/api/core/v1"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/yaml"

pytorchv1 "github.com/kubeflow/training-operator/pkg/apis/pytorch/v1"
"github.com/kubeflow/training-operator/pkg/config"
)

var (
	// initContainerTemplate is the built-in fallback template for the pytorch
	// init container, used when the configured template file cannot be read
	// (see getInitContainerTemplateOrDefault). It renders a YAML list with a
	// single container that loops until the master address ({{.MasterAddr}})
	// resolves via nslookup, so workers wait for the master pod's DNS entry.
	initContainerTemplate = `
- name: init-pytorch
  image: {{.InitContainerImage}}
  imagePullPolicy: IfNotPresent
  resources:
    limits:
      cpu: 100m
      memory: 20Mi
    requests:
      cpu: 50m
      memory: 10Mi
  command: ['sh', '-c', 'until nslookup {{.MasterAddr}}; do echo waiting for master; sleep 2; done;']`
	// onceInitContainer guards the one-time construction of icGenerator.
	onceInitContainer sync.Once
	// icGenerator is the lazily-initialized package-wide generator singleton.
	icGenerator *initContainerGenerator
)

// initContainerGenerator renders init containers for PyTorchJob worker pods
// from a Go template. Both fields are fixed at construction time from the
// global operator configuration (see getInitContainerGenerator).
type initContainerGenerator struct {
	// template is the raw template text (file contents or the built-in default).
	template string
	// image is the value substituted for {{.InitContainerImage}}.
	image string
}

// getInitContainerGenerator returns the package-wide init-container
// generator, constructing it exactly once on first use from the global
// operator configuration (template file path and image).
func getInitContainerGenerator() *initContainerGenerator {
	onceInitContainer.Do(func() {
		g := &initContainerGenerator{}
		g.template = getInitContainerTemplateOrDefault(config.Config.PyTorchInitContainerTemplateFile)
		g.image = config.Config.PyTorchInitContainerImage
		icGenerator = g
	})
	return icGenerator
}

// GetInitContainer renders the generator's template for a worker pod:
// masterAddr is substituted for {{.MasterAddr}} and the configured image
// for {{.InitContainerImage}}, and the rendered YAML is unmarshalled into
// a list of containers to prepend to the pod spec.
//
// NOTE(review): the template engine here is html/template (per the file's
// imports), which HTML-escapes substituted values. For plain DNS names
// such as masterAddr this is a no-op, but text/template would be the more
// natural choice for YAML output — confirm before changing.
func (i *initContainerGenerator) GetInitContainer(masterAddr string) ([]v1.Container, error) {
	tpl, err := template.New("container").Parse(i.template)
	if err != nil {
		return nil, err
	}

	// Render with the per-job master address and operator-wide image.
	data := struct {
		MasterAddr         string
		InitContainerImage string
	}{
		MasterAddr:         masterAddr,
		InitContainerImage: i.image,
	}
	var rendered bytes.Buffer
	if err := tpl.Execute(&rendered, data); err != nil {
		return nil, err
	}

	var containers []v1.Container
	if err := yaml.Unmarshal(rendered.Bytes(), &containers); err != nil {
		return nil, err
	}
	return containers, nil
}

// getInitContainerTemplateOrDefault returns the contents of the given
// template file when it is readable, falling back to the built-in
// initContainerTemplate otherwise. Read errors are deliberately swallowed:
// a missing or unreadable file simply means "use the default template".
func getInitContainerTemplateOrDefault(file string) string {
	// Note: the local was renamed from "bytes" — the original shadowed the
	// imported bytes package, which is used elsewhere in this file.
	data, err := ioutil.ReadFile(file)
	if err != nil {
		// Best-effort fallback: any read failure selects the default.
		return initContainerTemplate
	}
	return string(data)
}

// setInitContainer appends the pytorch init container to worker pods so
// they block until the master pod's DNS name resolves. It is a no-op when
// the job has no master replica or when rtype is not the worker type.
// index is currently unused but kept for the common SetClusterSpec
// signature — presumably per-replica behavior may need it later (verify).
func setInitContainer(obj interface{}, podTemplate *corev1.PodTemplateSpec,
	rtype, index string, log logr.Logger) error {
	job, ok := obj.(*pytorchv1.PyTorchJob)
	if !ok {
		return fmt.Errorf("%+v is not a type of PyTorchJob", obj)
	}
	logger := log.WithValues(pytorchv1.Singular, types.NamespacedName{
		Namespace: job.Namespace,
		Name:      job.Name,
	})

	// There is no need to set init container if no master is specified.
	if job.Spec.PyTorchReplicaSpecs[pytorchv1.PyTorchReplicaTypeMaster] == nil {
		logger.V(1).Info("No master is specified, skip setting init container")
		return nil
	}
	// Only worker pods need to wait for the master.
	if rtype != strings.ToLower(string(pytorchv1.PyTorchReplicaTypeWorker)) {
		return nil
	}

	// The init container polls the DNS name of master replica 0.
	masterAddr := genGeneralName(job.Name,
		strings.ToLower(string(pytorchv1.PyTorchReplicaTypeMaster)), strconv.Itoa(0))
	containers, err := getInitContainerGenerator().GetInitContainer(masterAddr)
	if err != nil {
		return err
	}
	podTemplate.Spec.InitContainers = append(podTemplate.Spec.InitContainers,
		containers...)
	return nil
}
5 changes: 3 additions & 2 deletions pkg/controller.v1/pytorch/master.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,17 @@ import (

var (
	// masterGenerator is the lazily-created singleton returned by
	// GetMasterEnvVarGenerator.
	masterGenerator EnvVarGenerator
	// onceMaster guards the one-time construction of masterGenerator.
	onceMaster sync.Once

	// Names of the environment variables that point workers at the master.
	EnvMasterPort = "MASTER_PORT"
	EnvMasterAddr = "MASTER_ADDR"
)

// MasterEnvVarGenerator is the environment variable generator for Master
// related arguments (see EnvMasterAddr / EnvMasterPort above).
// The struct is stateless; a shared instance is handed out by
// GetMasterEnvVarGenerator.
type MasterEnvVarGenerator struct {
}

func GetMasterEnvVarGenerator() EnvVarGenerator {
once.Do(func() {
onceMaster.Do(func() {
masterGenerator = &MasterEnvVarGenerator{}
})
return masterGenerator
Expand Down
10 changes: 8 additions & 2 deletions pkg/controller.v1/pytorch/pytorchjob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -465,9 +465,15 @@ func (r *PyTorchJobReconciler) UpdateJobStatusInApiServer(job interface{}, jobSt
return nil
}

// SetClusterSpec sets the cluster spec and init container for the pod:
// first the distributed-training environment variables are applied to the
// pod template, then (for worker replicas) the init container that waits
// for the master to become resolvable is added.
func (r *PyTorchJobReconciler) SetClusterSpec(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype, index string) error {
	if err := setPodEnv(job, podTemplate, rtype, index); err != nil {
		return err
	}
	return setInitContainer(job, podTemplate, rtype, index, r.Log)
}

func (r *PyTorchJobReconciler) GetDefaultContainerName() string {
Expand Down

0 comments on commit aefa0c0

Please sign in to comment.