Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update JAX integration proposal #2165

Merged
merged 1 commit into from
Jul 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 8 additions & 58 deletions docs/proposals/jax-integration.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,12 @@ As a DevOps engineer, I want to manage JAX distributed training jobs using the K

##### Key Validations

1. **Coordinator Role Validation**:
- Ensure exactly one Coordinator role with `processId` set to `0` and its `replicas` is set to `1`.
2. **Worker Role Validation**:
- Ensure at least one Worker replica.
- Ensure the `replicas` field for each role is greater than `0`.
3. **JAX Parameters Validation**:
1. **Worker Role Validation**:
- Ensure at least one Worker replica with `processId` set to `0` that will work as coordinator.
- Ensure the `replicas` field is greater than `0`.
2. **JAX Parameters Validation**:
- Ensure `coordinatorAddress`, `numProcesses`, and `processId` are set and valid across all roles.
4. **Pod Specification Validation**:
3. **Pod Specification Validation**:
- Ensure necessary container specifications and `restartPolicy` are correctly set.
- Validate `coordinatorAddress` follows the `host:port` format.

Expand All @@ -100,14 +98,6 @@ metadata:
name: example-jaxjob
spec:
jaxReplicaSpecs:
Coordinator:
replicas: 1
restartPolicy: OnFailure
template:
spec:
containers:
- name: jax-coordinator
image: ghcr.io/kubeflow/jax:latest
Worker:
replicas: 1
restartPolicy: OnFailure
Expand Down Expand Up @@ -158,8 +148,6 @@ const (
JAXJobSingular = "jaxjob"
// JAXJobFrameworkName is the name of the ML Framework
JAXJobFrameworkName = "jax"
// JAXJobReplicaTypeCoordinator is the type of Coordinator of distributed JAX
JAXJobReplicaTypeCoordinator ReplicaType = "Coordinator"
// JAXJobReplicaTypeWorker is the type for workers of distributed JAX.
JAXJobReplicaTypeWorker ReplicaType = "Worker"
)
Expand Down Expand Up @@ -199,7 +187,6 @@ type JAXJobSpec struct {
// A map of JAXReplicaType (type) to ReplicaSpec (value). Specifies the JAX cluster configuration.
// For example,
// {
// "Coordinator": JAXReplicaSpec,
// "Worker": JAXReplicaSpec,
// }
JAXReplicaSpecs map[ReplicaType]*ReplicaSpec `json:"jaxReplicaSpecs"`
Expand All @@ -223,48 +210,10 @@ type JAXJobList struct {

func init() {
SchemeBuilder.Register(&JAXJob{}, &JAXJobList{})
SchemeBuilder.SchemeBuilder.Register(addJAXJobDefaultingFuncs)
SchemeBuilder.SchemeBuilder.Register(addJAXDefaultingFuncs)
}

```
##### Resulting Coordinator
```yaml
apiVersion: v1
kind: Service
metadata:
name: jax-coordinator
spec:
selector:
app: jax-coordinator
ports:
- port: 6666
targetPort: 6666
```
```yaml
apiVersion: v1
kind: Pod
metadata:
name: jax-coordinator
labels:
app: jax-coordinator
spec:
containers:
- image: ghcr.io/kubeflow/jax:latest
imagePullPolicy: IfNotPresent
name: coordinator
env:
- name: JAX_COORDINATOR_ADDRESS
value: '127.0.0.1:6666'
- name: JAX_NUM_PROCESSES
value: 1
- name: JAX_PROCESS_ID
value: 0
# process 0 is coordinator
ports:
- name: coordinatorPort
containerPort: 6666
restartPolicy: OnFailure
```

##### Resulting Worker

Expand All @@ -290,7 +239,8 @@ spec:
- name: JAX_NUM_PROCESSES
value: 1
- name: JAX_PROCESS_ID
value: 1
value: 0
# process 0 is coordinator
restartPolicy: OnFailure

```
Expand Down
Loading