Skip to content

Commit

Permalink
Update JAX integration proposal (#2165)
Browse files Browse the repository at this point in the history
Signed-off-by: Sandipan Panda <samparksandipan@gmail.com>
  • Loading branch information
sandipanpanda committed Jul 15, 2024
1 parent bcba864 commit ee736a7
Showing 1 changed file with 8 additions and 58 deletions.
66 changes: 8 additions & 58 deletions docs/proposals/jax-integration.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,12 @@ As a DevOps engineer, I want to manage JAX distributed training jobs using the K

##### Key Validations

1. **Coordinator Role Validation**:
- Ensure exactly one Coordinator role with `processId` set to `0` and its `replicas` is set to `1`.
2. **Worker Role Validation**:
- Ensure at least one Worker replica.
- Ensure the `replicas` field for each role is greater than `0`.
3. **JAX Parameters Validation**:
1. **Worker Role Validation**:
- Ensure at least one Worker replica with `processId` set to `0` that will work as coordinator.
- Ensure the `replicas` field is greater than `0`.
2. **JAX Parameters Validation**:
- Ensure `coordinatorAddress`, `numProcesses`, and `processId` are set and valid across all roles.
4. **Pod Specification Validation**:
3. **Pod Specification Validation**:
- Ensure necessary container specifications and `restartPolicy` are correctly set.
- Validate `coordinatorAddress` follows the `host:port` format.

Expand All @@ -100,14 +98,6 @@ metadata:
name: example-jaxjob
spec:
jaxReplicaSpecs:
Coordinator:
replicas: 1
restartPolicy: OnFailure
template:
spec:
containers:
- name: jax-coordinator
image: ghcr.io/kubeflow/jax:latest
Worker:
replicas: 1
restartPolicy: OnFailure
Expand Down Expand Up @@ -158,8 +148,6 @@ const (
JAXJobSingular = "jaxjob"
// JAXJobFrameworkName is the name of the ML Framework
JAXJobFrameworkName = "jax"
// JAXJobReplicaTypeCoordinator is the type of Coordinator of distributed JAX
JAXJobReplicaTypeCoordinator ReplicaType = "Coordinator"
// JAXJobReplicaTypeWorker is the type for workers of distributed JAX.
JAXJobReplicaTypeWorker ReplicaType = "Worker"
)
Expand Down Expand Up @@ -199,7 +187,6 @@ type JAXJobSpec struct {
// A map of JAXReplicaType (type) to ReplicaSpec (value). Specifies the JAX cluster configuration.
// For example,
// {
// "Coordinator": JAXReplicaSpec,
// "Worker": JAXReplicaSpec,
// }
JAXReplicaSpecs map[ReplicaType]*ReplicaSpec `json:"jaxReplicaSpecs"`
Expand All @@ -223,48 +210,10 @@ type JAXJobList struct {

func init() {
SchemeBuilder.Register(&JAXJob{}, &JAXJobList{})
SchemeBuilder.SchemeBuilder.Register(addJAXJobDefaultingFuncs)
SchemeBuilder.SchemeBuilder.Register(addJAXDefaultingFuncs)
}

```
##### Resulting Coordinator
```yaml
apiVersion: v1
kind: Service
metadata:
name: jax-coordinator
spec:
selector:
app: jax-coordinator
ports:
- port: 6666
targetPort: 6666
```
```yaml
apiVersion: v1
kind: Pod
metadata:
name: jax-coordinator
labels:
app: jax-coordinator
spec:
containers:
- image: ghcr.io/kubeflow/jax:latest
imagePullPolicy: IfNotPresent
name: coordinator
env:
- name: JAX_COORDINATOR_ADDRESS
value: '127.0.0.1:6666'
- name: JAX_NUM_PROCESSES
value: 1
- name: JAX_PROCESS_ID
value: 0
# process 0 is coordinator
ports:
- name: coordinatorPort
containerPort: 6666
restartPolicy: OnFailure
```

##### Resulting Worker

Expand All @@ -290,7 +239,8 @@ spec:
- name: JAX_NUM_PROCESSES
value: 1
- name: JAX_PROCESS_ID
value: 1
value: 0
# process 0 is coordinator
restartPolicy: OnFailure

```
Expand Down

0 comments on commit ee736a7

Please sign in to comment.