Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
gaocegege committed Nov 3, 2021
1 parent 2951971 commit df2c1ed
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 44 deletions.
7 changes: 3 additions & 4 deletions examples/pytorch/elastic/imagenet/imagenet.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ metadata:
spec:
elasticPolicy:
rdzvBackend: c10d
minReplicas: 1
maxReplicas: 2
maxRestarts: 100
pytorchReplicaSpecs:
Worker:
Expand All @@ -23,9 +25,6 @@ spec:
- python
- -m
- torch.distributed.run
- --rdzv_backend=c10d
- --nnodes=1:2
- --nproc_per_node=1
- /workspace/examples/imagenet.py
- "--arch=resnet18"
- "--epochs=20"
Expand All @@ -39,4 +38,4 @@ spec:
ports:
- containerPort: 29400
protocol: TCP
name: elastic-port
name: pytorchjob-port
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ require (
github.com/gogo/protobuf v1.3.1 // indirect
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7 // indirect
github.com/golang/protobuf v1.4.3 // indirect
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect
github.com/google/go-cmp v0.5.4 // indirect
github.com/google/gofuzz v1.1.0 // indirect
github.com/google/uuid v1.1.1 // indirect
Expand All @@ -62,6 +63,7 @@ require (
github.com/prometheus/common v0.18.0 // indirect
github.com/prometheus/procfs v0.6.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/syndtr/goleveldb v1.0.0 // indirect
go.uber.org/atomic v1.6.0 // indirect
go.uber.org/multierr v1.5.0 // indirect
go.uber.org/zap v1.15.0 // indirect
Expand Down
3 changes: 3 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QD
github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
github.com/golang/protobuf v1.4.3 h1:JjCZWpVbqXDqFVmTfYWEVTMIYrL/NPdPSCHPJ0T/raM=
github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pOkfl+p/TAqKOfFu+7KPlMVpok/w=
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
Expand Down Expand Up @@ -499,6 +500,8 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE=
github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ=
github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
Expand Down
18 changes: 9 additions & 9 deletions pkg/controller.v1/pytorch/elastic.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,27 +51,27 @@ const (
// EnvNProcPerNode is the environment variable name for the number of processes per node.
EnvNProcPerNode = "PET_N_PROC_PER_NODE"
// EnvNNodes is the environment variable name for the number of nodes.
EnvNNodes = "PET_NODES"
EnvNNodes = "PET_NNODES"

defaultRDZVPort int32 = 29400
)

var (
instance *ElasticEnvVarGenerator
once sync.Once
elasticGenerator EnvVarGenerator
onceElastic sync.Once
)

type EnvVarGenerator interface {
Generate(job *pytorchv1.PyTorchJob) []corev1.EnvVar
Generate(job *pytorchv1.PyTorchJob) ([]corev1.EnvVar, error)
}

type ElasticEnvVarGenerator struct{}

func GetElasticEnvVarGenerator() *ElasticEnvVarGenerator {
once.Do(func() {
instance = &ElasticEnvVarGenerator{}
func GetElasticEnvVarGenerator() EnvVarGenerator {
onceElastic.Do(func() {
elasticGenerator = &ElasticEnvVarGenerator{}
})
return instance
return elasticGenerator
}

func (e ElasticEnvVarGenerator) Generate(
Expand Down Expand Up @@ -148,7 +148,7 @@ func (e ElasticEnvVarGenerator) generateEnvNNodes(job *pytorchv1.PyTorchJob) (*c
return &corev1.EnvVar{
Name: EnvNNodes,
Value: fmt.Sprintf("%d:%d",
job.Spec.ElasticPolicy.MinReplicas, job.Spec.ElasticPolicy.MaxReplicas),
*job.Spec.ElasticPolicy.MinReplicas, *job.Spec.ElasticPolicy.MaxReplicas),
}, nil
}

Expand Down
48 changes: 48 additions & 0 deletions pkg/controller.v1/pytorch/master.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package pytorch

import (
"strconv"
"strings"
"sync"

pytorchv1 "github.com/kubeflow/training-operator/pkg/apis/pytorch/v1"
corev1 "k8s.io/api/core/v1"
)

var (
	// masterGenerator is the package-level EnvVarGenerator handed out by
	// GetMasterEnvVarGenerator. It is assigned inside once.Do.
	masterGenerator EnvVarGenerator
	// once guards the one-time assignment of masterGenerator.
	once sync.Once
)

// MasterEnvVarGenerator is the EnvVarGenerator implementation that produces
// the MASTER_PORT and MASTER_ADDR environment variables for a PyTorchJob.
// It carries no fields.
type MasterEnvVarGenerator struct {
}

// GetMasterEnvVarGenerator returns the package-level master EnvVarGenerator.
// The generator is created on the first call (guarded by once) and the same
// value is returned on every subsequent call.
func GetMasterEnvVarGenerator() EnvVarGenerator {
	initGenerator := func() {
		masterGenerator = &MasterEnvVarGenerator{}
	}
	once.Do(initGenerator)

	return masterGenerator
}

// Generate builds the master-related environment variables for job.
//
// When the job declares no Master replica spec, an empty (non-nil) slice is
// returned. Otherwise the result holds MASTER_PORT, taken from
// getPortFromPyTorchJob for the Master replica type, followed by MASTER_ADDR,
// the name produced by genGeneralName for the lower-cased Master replica type
// at index 0.
//
// If the port lookup fails, Generate returns nil together with that error.
func (e MasterEnvVarGenerator) Generate(
	job *pytorchv1.PyTorchJob) ([]corev1.EnvVar, error) {
	// No Master replica configured: there is nothing to inject.
	if job.Spec.PyTorchReplicaSpecs[pytorchv1.PyTorchReplicaTypeMaster] == nil {
		return []corev1.EnvVar{}, nil
	}

	port, err := getPortFromPyTorchJob(job, pytorchv1.PyTorchReplicaTypeMaster)
	if err != nil {
		return nil, err
	}

	replicaType := strings.ToLower(string(pytorchv1.PyTorchReplicaTypeMaster))
	addr := genGeneralName(job.Name, replicaType, strconv.Itoa(0))

	return []corev1.EnvVar{
		{Name: "MASTER_PORT", Value: strconv.Itoa(int(port))},
		{Name: "MASTER_ADDR", Value: addr},
	}, nil
}
42 changes: 11 additions & 31 deletions pkg/controller.v1/pytorch/pytorch.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package pytorch

import (
"fmt"
"strconv"
"strings"

commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
Expand All @@ -30,49 +29,30 @@ func SetPodEnv(obj interface{}, podTemplateSpec *corev1.PodTemplateSpec, rtype,
return fmt.Errorf("%+v is not a type of PyTorchJob", obj)
}

rank, err := strconv.Atoi(index)
if err != nil {
return err
}

masterPort, err := getPortFromPyTorchJob(pytorchjob, pytorchv1.PyTorchReplicaTypeMaster)
if err != nil {
return err
}

masterAddr := genGeneralName(pytorchjob.Name, strings.ToLower(string(pytorchv1.PyTorchReplicaTypeMaster)), strconv.Itoa(0))
if rtype == strings.ToLower(string(pytorchv1.PyTorchReplicaTypeMaster)) {
if rank != 0 {
return fmt.Errorf("invalid config: There should be only a single master with index=0")
}
masterAddr = "localhost"
}

for i := range podTemplateSpec.Spec.Containers {
if len(podTemplateSpec.Spec.Containers[i].Env) == 0 {
podTemplateSpec.Spec.Containers[i].Env = make([]corev1.EnvVar, 0)
}
podTemplateSpec.Spec.Containers[i].Env = append(podTemplateSpec.Spec.Containers[i].Env, corev1.EnvVar{
Name: "MASTER_PORT",
Value: strconv.Itoa(int(masterPort)),
})
podTemplateSpec.Spec.Containers[i].Env = append(podTemplateSpec.Spec.Containers[i].Env, corev1.EnvVar{
Name: "MASTER_ADDR",
Value: masterAddr,
})
podTemplateSpec.Spec.Containers[i].Env = append(podTemplateSpec.Spec.Containers[i].Env, corev1.EnvVar{
Name: "PYTHONUNBUFFERED",
Value: "0",
})

// TODO(gaocegege): claim MASTER_ADDRESS.
envVars, err := GetElasticEnvVarGenerator().Generate(pytorchjob)
envVars, err := GetMasterEnvVarGenerator().Generate(pytorchjob)
if err != nil {
return err
}
// Set master related environment variables.
podTemplateSpec.Spec.Containers[i].Env = append(
podTemplateSpec.Spec.Containers[i].Env, envVars...)

envVars, err = GetElasticEnvVarGenerator().Generate(pytorchjob)
if err != nil {
return err
}
// Set elastic related environment variables.
podTemplateSpec.Spec.Containers[i].Env = append(podTemplateSpec.Spec.Containers[i].Env,
envVars...)
podTemplateSpec.Spec.Containers[i].Env = append(
podTemplateSpec.Spec.Containers[i].Env, envVars...)
}

return nil
Expand Down

0 comments on commit df2c1ed

Please sign in to comment.