Format YAML in JAX integration proposal
sandipanpanda committed May 22, 2024
1 parent 4f3e223 commit c45d88e
Showing 1 changed file with 66 additions and 61 deletions.
127 changes: 66 additions & 61 deletions docs/proposals/jax-integration.md
@@ -114,48 +114,47 @@ Define a new CRD for JaxJob that includes specifications for:
The JaxJob CRD will define the schema for JAX distributed training jobs, including specifications for the coordinator, worker processes, and necessary environment variables.

```yaml
apiVersion: "kubeflow.org/v1"
---
apiVersion: kubeflow.org/v1
kind: JaxJob
metadata:
name: example-jaxjob
namespace: default
spec:
replicaSpecs:
- replicas: 1
ReplicaType: MASTER
template:
spec:
containers:
- image: ghcr.io/nvidia/jax:upstream-pax-2024-03-12
name: master
imagePullPolicy: IfNotPresent
restartPolicy: OnFailure
ReplicaType: MASTER
template:
spec:
containers:
- image: 'ghcr.io/nvidia/jax:upstream-pax-2024-03-12'
name: master
imagePullPolicy: IfNotPresent
restartPolicy: OnFailure
- replicas: 2
ReplicaType: Worker
template:
metadata:
annotations:
sidecar.istio.io/inject: "false"
spec:
containers:
- name: jax-container
image: ghcr.io/nvidia/jax:upstream-pax-2024-03-12
command:
- "python"
- "train.py"
env:
# Environment variables for distributed training
- name: JAX_COORDINATOR_ADDRESS
value: "jax-coordinator.default.svc.cluster.local"
# Coordinator address
- name: JAX_NUM_PROCESSES
value: "2"
- name: JAX_PROCESS_ID
valueFrom:
fieldRef:
fieldPath: metadata.labels['statefulset.kubernetes.io/pod-name']
# Unique process ID for each replica
restartPolicy: OnFailure
ReplicaType: Worker
template:
metadata:
annotations: null
sidecar.istio.io/inject: 'false'
spec:
containers:
- name: jax-container
image: 'ghcr.io/nvidia/jax:upstream-pax-2024-03-12'
command:
- python
- train.py
env:
- name: JAX_COORDINATOR_ADDRESS
value: jax-coordinator.default.svc.cluster.local
- name: JAX_NUM_PROCESSES
value: '2'
- name: JAX_PROCESS_ID
valueFrom: null
fieldRef: null
fieldPath: 'metadata.labels[''statefulset.kubernetes.io/pod-name'']'
restartPolicy: OnFailure

```
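The `JAX_PROCESS_ID` entry above takes its value from the `statefulset.kubernetes.io/pod-name` label, which holds a pod name (a string), whereas JAX expects an integer process index. Below is a minimal sketch of one way a training script might bridge that gap, assuming pod names end in an ordinal suffix such as `-0` or `-1`. The helper name `process_id_from_pod_name` and the example pod name are hypothetical and not part of the proposal.

```python
import re


def process_id_from_pod_name(pod_name: str) -> int:
    """Extract the trailing ordinal from a pod name such as 'example-jaxjob-worker-1'.

    Assumption: the pod name ends in '-<ordinal>', as StatefulSet pod names do.
    """
    match = re.search(r"-(\d+)$", pod_name)
    if match is None:
        raise ValueError(f"pod name {pod_name!r} has no trailing ordinal")
    return int(match.group(1))


if __name__ == "__main__":
    # Hypothetical pod name; the real naming scheme is decided by the controller.
    print(process_id_from_pod_name("example-jaxjob-worker-1"))
```

If the controller injects a numeric index directly, this parsing step is unnecessary.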
Available image options:
1. Bitnami package for JAX: https://hub.docker.com/r/bitnami/jax
@@ -168,6 +167,7 @@ This YAML file defines a JaxJob with two worker replicas. Each worker runs a con
The Master component will act as the coordinator for the distributed JAX processes. It will be responsible for initializing the JAX distributed environment and managing communication between worker processes.

```yaml
---
kind: Service
apiVersion: v1
metadata:
@@ -180,6 +180,7 @@ spec:
targetPort: 23456
```
```yaml
---
apiVersion: v1
kind: Pod
metadata:
@@ -188,48 +189,52 @@ metadata:
app: jaxjobmaster-${job_id}
spec:
containers:
- image: ghcr.io/nvidia/jax:upstream-pax-2024-03-12
imagePullPolicy: IfNotPresent
name: master
env:
- name: MASTER_PORT
value: "23456"
- name: MASTER_ADDR
value: "localhost"
- name: WORLD_SIZE
value: "3" # Rank 0 is the master
- name: RANK
value: "0"
ports:
- name: masterPort
containerPort: 23456
- image: ghcr.io/nvidia/jax:upstream-pax-2024-03-12
imagePullPolicy: IfNotPresent
name: master
env:
- name: MASTER_PORT
value: "23456"
- name: MASTER_ADDR
value: "localhost"
- name: WORLD_SIZE
value: "3"
# Rank 0 is the master
- name: RANK
value: "0"
ports:
- name: masterPort
containerPort: 23456
restartPolicy: OnFailure

```

##### Resulting Worker

Worker components will execute the distributed training tasks. Each worker will be assigned a unique process ID and will connect to the coordinator using the specified environment variables.

```yaml
---
apiVersion: v1
kind: Pod
metadata:
name: jaxjob-worker-${job_id}
spec:
containers:
- image: ghcr.io/nvidia/jax:upstream-pax-2024-03-12
imagePullPolicy: IfNotPresent
name: worker
env:
- name: MASTER_PORT
value: "23456"
- name: MASTER_ADDR
value: jaxjob-master-${job_id}
- name: WORLD_SIZE
value: "3"
- name: RANK
value: "1"
- image: ghcr.io/nvidia/jax:upstream-pax-2024-03-12
imagePullPolicy: IfNotPresent
name: worker
env:
- name: MASTER_PORT
value: "23456"
- name: MASTER_ADDR
value: jaxjob-master-${job_id}
- name: WORLD_SIZE
value: "3"
- name: RANK
value: "1"
restartPolicy: OnFailure

```
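As an illustration of how a worker might consume the variables above, the following sketch maps `MASTER_ADDR`, `MASTER_PORT`, `WORLD_SIZE` and `RANK` onto the `coordinator_address`, `num_processes` and `process_id` arguments of `jax.distributed.initialize`. The helper names `coordinator_config` and `init_distributed` are hypothetical, and the proposal does not specify how `train.py` reads its environment, so treat this as one possible wiring rather than the intended implementation.

```python
import os


def coordinator_config(env=None):
    """Build keyword arguments for jax.distributed.initialize from pod environment variables."""
    env = os.environ if env is None else env
    host = env["MASTER_ADDR"]
    port = int(env["MASTER_PORT"])
    return {
        # The coordinator is reached at <host>:<port>.
        "coordinator_address": f"{host}:{port}",
        # Total number of processes, master included.
        "num_processes": int(env["WORLD_SIZE"]),
        # Index of this process; 0 is the master in the pod specs above.
        "process_id": int(env["RANK"]),
    }


def init_distributed(env=None):
    """Initialize JAX's distributed runtime using the values parsed above."""
    # Imported lazily so coordinator_config can be used without JAX installed.
    import jax

    jax.distributed.initialize(**coordinator_config(env))
```

In a worker pod, `init_distributed()` would be called once at the start of the training script, before any other JAX computation.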
### Controller Update
