# Kubeflow Enhancement Proposal: Integrate JAX with Kubeflow Training Operator for Distributed Training on Kubernetes

<!-- toc -->
## Table of Contents
- [Release Signoff Checklist](#release-signoff-checklist)
- [Summary](#summary)
- [Motivation](#motivation)
- [Goals](#goals)
- [Non-Goals](#non-goals)
- [Proposal](#proposal)
- [User Stories (Optional)](#user-stories-optional)
- [Story 1](#story-1)
- [Story 2](#story-2)
- [Story 3](#story-3)
- [Design Details](#design-details)
- [Alternatives](#alternatives)
- [Implementation History](#implementation-history)
<!-- /toc -->

## Release Signoff Checklist

Items marked with (R) are required *prior to targeting to a milestone / release*.

- [ ] (R) Enhancement issue in release milestone, which links to KEP dir in [kubeflow/training-operator/docs/proposals] (not the initial KEP PR)
- [ ] (R) KEP approvers have approved the KEP status as `implementable`
- [ ] (R) Design details are appropriately documented
- [ ] (R) Graduation criteria is in place
- [ ] (R) Production readiness review completed
- [ ] (R) Production readiness review approved
- [ ] "Implementation History" section is up-to-date for milestone
- [ ] User-facing documentation has been created in [kubeflow/website/content/en/docs/components/training], for publication to [kubeflow.org]
- [ ] Supporting documentation—e.g., additional design documents, links to mailing list discussions/SIG meetings, relevant PRs/issues, release notes

## Summary

This Kubeflow Enhancement Proposal (KEP) aims to integrate [JAX](http://jax.readthedocs.io/), a popular machine learning framework, into the Kubeflow Training Operator to enable distributed training and fine-tuning jobs on Kubernetes. This will involve creating a new Kubernetes Custom Resource Definition (CRD) for JAX (JAXJob) and updating the Training Operator controller to support it. The enhancement will also include integrating JAX with the Training Operator Python SDK to provide simple APIs for Data Scientists to create and manage JAXJobs on Kubernetes clusters.

## Motivation

JAX has emerged as a popular machine learning framework for high-performance numerical computing and accelerated training on GPUs and TPUs. With its "multi-controller" programming model, JAX is particularly well-suited for distributed training using the Single Program, Multiple Data (SPMD) paradigm. However, running distributed JAX jobs on Kubernetes requires robust orchestration to handle the complexities of multi-process coordination.

Integrating JAX into the Kubeflow Training Operator will simplify distributed JAX training on Kubernetes, providing Data Scientists and ML Engineers with seamless APIs to deploy and manage JAX jobs. This proposal aims to create a new Kubernetes Custom Resource (CR) for JAX, update the Training Operator controller to support it, and provide an intuitive Python SDK for managing JAX jobs.

### Goals

- Introduce a new Custom Resource Definition (CRD) called `JAXJob` for managing JAX distributed training jobs on Kubernetes.
- Update the Kubeflow Training Operator to support the `JAXJob` CRD.
- Extend the Training Operator Python SDK to support JAXJob creation and management.
- Implement the solution to work in CPU environments using the Gloo backend, since GPU resources are not currently available for testing.

### Non-Goals

- Support for GPU environments is not included in this proposal due to the current limitation of not having GPU resources available for testing.

## Proposal

### User Stories (Optional)

#### Story 1

As a Data Scientist, I want to use the Kubeflow Training Operator to run distributed JAX training jobs on a Kubernetes cluster so that I can leverage Kubernetes' scalability and resource management capabilities.

#### Story 2

As a Machine Learning Engineer, I want to use a simple Python SDK to define and launch JAX training jobs on Kubernetes, abstracting away the complexity of Kubernetes configurations.
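
To make this story concrete, a hypothetical interaction with the Training Operator Python SDK might look like the sketch below. The `job_kind="JAXJob"` selector and the JAXJob-aware behaviour of `create_job` do not exist yet; they are placeholders for exactly the SDK support this proposal would add.

```python
# Hypothetical sketch only: JAXJob support in the SDK is what this proposal would add,
# so the JAXJob-specific arguments below are illustrative, not an existing API.
from kubeflow.training import TrainingClient

client = TrainingClient()

# Create a JAXJob from a container image; "job_kind" and the image are placeholders.
client.create_job(
    name="example-jaxjob",
    job_kind="JAXJob",
    base_image="ghcr.io/kubeflow/jax:latest",
    num_workers=2,
)

# Reuse the existing client helpers to observe the job.
client.get_job_logs(name="example-jaxjob", follow=True)
```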

#### Story 3

As a DevOps engineer, I want to manage JAX distributed training jobs using the Kubeflow Training Operator, so I can provide a consistent and scalable infrastructure for machine learning workloads.

## Design Details

- Create a new Custom Resource Definition (CRD) for JAX jobs (`JAXJob`).
- Update the Kubeflow Training Operator to manage `JAXJob` resources.
- Implement a mechanism to initialize and manage JAX distributed training processes using [`jax.distributed.initialize`](https://jax.readthedocs.io/en/latest/_autosummary/jax.distributed.initialize.html#jax.distributed.initialize); a minimal training-script sketch follows this list.
- Extend the Training Operator Python SDK to simplify the creation and management of `JAXJob` resources.
- Configure JAX to use the Gloo backend for CPU-based distributed training.
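
For the initialization mechanism, a minimal sketch of the training-script side is shown below. It assumes the `JAX_COORDINATOR_ADDRESS`, `JAX_NUM_PROCESSES`, and `JAX_PROCESS_ID` environment variables proposed in the CRD example later in this document; the operator would be responsible for injecting the correct values into each replica.

```python
# Minimal sketch of a training script entrypoint, assuming the environment
# variables proposed in this KEP are injected into each Pod by the operator.
import os

import jax


def main() -> None:
    # Values rendered by the Training Operator for this replica.
    coordinator_address = os.environ["JAX_COORDINATOR_ADDRESS"]  # e.g. "<coordinator-host>:6666"
    num_processes = int(os.environ["JAX_NUM_PROCESSES"])
    process_id = int(os.environ["JAX_PROCESS_ID"])

    # Bring up the JAX multi-process runtime; process 0 acts as the coordinator.
    jax.distributed.initialize(
        coordinator_address=coordinator_address,
        num_processes=num_processes,
        process_id=process_id,
    )

    # After initialization every process sees the global device topology.
    print(f"process {jax.process_index()}/{jax.process_count()}, "
          f"global device count: {jax.device_count()}")


if __name__ == "__main__":
    main()
```

On CPU-only clusters, the Gloo backend mentioned above would provide the cross-process collectives.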

#### API (CRD and resulting objects)

##### Custom Resource Definition

The JAXJob CRD will define the schema for JAX distributed training jobs, including specifications for the coordinator, worker processes, and necessary environment variables.

```yaml
apiVersion: kubeflow.org/v1
kind: JAXJob
metadata:
  name: example-jaxjob
spec:
  jaxReplicaSpecs:
    Coordinator:
      replicas: 1
      restartPolicy: OnFailure
      template:
        spec:
          containers:
            - name: jax-coordinator
              image: ghcr.io/kubeflow/jax:latest
              command:
                - python
              args:
                - "train.py"
              env:
                - name: JAX_COORDINATOR_ADDRESS
                  value: "127.0.0.1:6666"
                - name: JAX_NUM_PROCESSES
                  value: "2"
                - name: JAX_PROCESS_ID
                  value: "0"
    Worker:
      replicas: 1
      restartPolicy: OnFailure
      template:
        spec:
          containers:
            - name: jax-worker
              image: ghcr.io/kubeflow/jax:latest
              command:
                - python
              args:
                - "train.py"
              env:
                - name: JAX_COORDINATOR_ADDRESS
                  value: "127.0.0.1:6666"
                - name: JAX_NUM_PROCESSES
                  value: "2"
                - name: JAX_PROCESS_ID
                  value: "1"
```

Available image options:

1. Bitnami package for JAX: https://hub.docker.com/r/bitnami/jax
2. NVIDIA JAX Toolbox: https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax

##### JAX API Definition

```go
// Copyright 2024 The Kubeflow Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package v1

import (
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

const (
	// JAXJobDefaultPortName is name of the port used to communicate between Coordinator and Workers.
	JAXJobDefaultPortName = "jaxjob-port"
	// JAXJobDefaultContainerName is the name of the JAXJob container.
	JAXJobDefaultContainerName = "jax"
	// JAXJobDefaultPort is default value of the port.
	JAXJobDefaultPort = 6666
	// JAXJobDefaultRestartPolicy is default RestartPolicy for JAXReplicaSpecs.
	JAXJobDefaultRestartPolicy = RestartPolicyNever
	// JAXJobKind is the kind name.
	JAXJobKind = "JAXJob"
	// JAXJobPlural is the JAXJobPlural for JAXJob.
	JAXJobPlural = "jaxjobs"
	// JAXJobSingular is the singular for JAXJob.
	JAXJobSingular = "jaxjob"
	// JAXJobFrameworkName is the name of the ML Framework
	JAXJobFrameworkName = "jax"
	// JAXJobReplicaTypeCoordinator is the type of Coordinator of distributed JAX
	JAXJobReplicaTypeCoordinator ReplicaType = "Coordinator"
	// JAXJobReplicaTypeWorker is the type for workers of distributed JAX.
	JAXJobReplicaTypeWorker ReplicaType = "Worker"
)

// +genclient
// +k8s:deepcopy-gen:interfaces=k8s.io/apimachinery/pkg/runtime.Object
// +resource:path=jaxjob
//+kubebuilder:object:root=true
//+kubebuilder:subresource:status
//+kubebuilder:printcolumn:name="State",type=string,JSONPath=`.status.conditions[-1:].type`
//+kubebuilder:printcolumn:name="Age",type=date,JSONPath=`.metadata.creationTimestamp`
// +kubebuilder:subresource:scale:specpath=.spec.jaxReplicaSpecs.Worker.replicas,statuspath=.status.replicaStatuses.Worker.active,selectorpath=.status.replicaStatuses.Worker.selector

// JAXJob Represents a JAXJob resource.
type JAXJob struct {
	// Standard Kubernetes type metadata.
	metav1.TypeMeta `json:",inline"`

	metav1.ObjectMeta `json:"metadata,omitempty"`

	// Specification of the desired state of the JAXJob.
	Spec JAXJobSpec `json:"spec,omitempty"`

	// Most recently observed status of the JAXJob.
	// Read-only (modified by the system).
	Status JobStatus `json:"status,omitempty"`
}

// JAXJobSpec is a desired state description of the JAXJob.
type JAXJobSpec struct {
	// RunPolicy encapsulates various runtime policies of the distributed training
	// job, for example how to clean up resources and how long the job can stay
	// active.
	//+kubebuilder:validation:Optional
	RunPolicy RunPolicy `json:"runPolicy"`

	// A map of JAXReplicaType (type) to ReplicaSpec (value). Specifies the JAX cluster configuration.
	// For example,
	//   {
	//     "Coordinator": JAXReplicaSpec,
	//     "Worker": JAXReplicaSpec,
	//   }
	JAXReplicaSpecs map[ReplicaType]*ReplicaSpec `json:"jaxReplicaSpecs"`
}

// +k8s:deepcopy-gen:interfaces=k8s.io/apimachinery/pkg/runtime.Object
// +resource:path=jaxjobs
//+kubebuilder:object:root=true

// JAXJobList is a list of JAXJobs.
type JAXJobList struct {
	// Standard type metadata.
	metav1.TypeMeta `json:",inline"`

	// Standard list metadata.
	metav1.ListMeta `json:"metadata,omitempty"`

	// List of JAXJobs.
	Items []JAXJob `json:"items"`
}

func init() {
	SchemeBuilder.Register(&JAXJob{}, &JAXJobList{})
	SchemeBuilder.SchemeBuilder.Register(addJAXJobDefaultingFuncs)
}

```

##### Resulting Worker

Upon creating a `JAXJob`, the Training Operator will generate the necessary Kubernetes resources, such as Pods and Services, to facilitate distributed training. Each Pod will be configured with the environment variables required by `jax.distributed.initialize`.

- **Coordinator Pod:** The pod with `JAX_PROCESS_ID=0` will act as the coordinator.
- **Worker Pods:** The remaining pods will act as workers, connecting to the coordinator.

```yaml
---
apiVersion: v1
kind: Pod
metadata:
  name: jaxjob-worker
spec:
  containers:
    - image: ghcr.io/kubeflow/jax:latest
      name: worker
      imagePullPolicy: IfNotPresent
      env:
        - name: JAX_COORDINATOR_ADDRESS
          value: "127.0.0.1:6666"
        - name: JAX_NUM_PROCESSES
          value: "1"
        - name: JAX_PROCESS_ID
          value: "0"
  restartPolicy: OnFailure

```

## Alternatives

- Integrate JAX into the Training Operator using JobSet and the `TrainJob` API once the `TrainJob` API is implemented in the Training Operator.
- Using MPI instead of Gloo: while MPI is a mature solution for distributed computing, it adds complexity in terms of setup and management. Gloo, being simpler and more lightweight, is preferred for the initial implementation.

## Implementation History

- 2024-05-22: Initial KEP draft created.
