From 7ca6f7e50da5ddecc32ff970d8616e5aab4732b5 Mon Sep 17 00:00:00 2001 From: Sandipan Panda Date: Wed, 22 May 2024 03:25:29 +0530 Subject: [PATCH] JAX Integration Enhancement Proposal Kubeflow Enhancement Proposal: Integrate JAX with Kubeflow Training Operator Signed-off-by: Sandipan Panda --- docs/proposals/jax-integration.md | 414 ++++++++++++++++++++++++++++++ 1 file changed, 414 insertions(+) create mode 100644 docs/proposals/jax-integration.md diff --git a/docs/proposals/jax-integration.md b/docs/proposals/jax-integration.md new file mode 100644 index 0000000000..1ea91035ca --- /dev/null +++ b/docs/proposals/jax-integration.md @@ -0,0 +1,414 @@ +# Kubeflow Enhancement Proposal: Integrate JAX with Kubeflow Training Operator for Distributed Training on Kubernetes + + +## Table of Contents +- [Release Signoff Checklist](#release-signoff-checklist) +- [Summary](#summary) +- [Motivation](#motivation) + - [Goals](#goals) + - [Non-Goals](#non-goals) +- [Proposal](#proposal) + - [User Stories (Optional)](#user-stories-optional) + - [Story 1](#story-1) + - [Story 2](#story-2) + - [Story 3](#story-3) + - [Notes/Constraints/Caveats (Optional)](#notesconstraintscaveats-optional) + - [Risks and Mitigations](#risks-and-mitigations) +- [Design Details](#design-details) + - [Test Plan](#test-plan) + - [Graduation Criteria](#graduation-criteria) + - [Upgrade / Downgrade Strategy](#upgrade--downgrade-strategy) + - [Version Skew Strategy](#version-skew-strategy) +- [Production Readiness Review Questionnaire](#production-readiness-review-questionnaire) + - [Feature Enablement and Rollback](#feature-enablement-and-rollback) + - [Rollout, Upgrade and Rollback Planning](#rollout-upgrade-and-rollback-planning) + - [Monitoring Requirements](#monitoring-requirements) + - [Dependencies](#dependencies) + - [Scalability](#scalability) + - [Troubleshooting](#troubleshooting) +- [Implementation History](#implementation-history) +- [Drawbacks](#drawbacks) +- [Alternatives](#alternatives) +- [Infrastructure Needed (Optional)](#infrastructure-needed-optional) + + +## Release Signoff Checklist + +Items marked with (R) are required *prior to targeting to a milestone / release*. + +- [ ] (R) Enhancement issue in release milestone, which links to KEP dir in [kubeflow/training-operator/docs/proposals] (not the initial KEP PR) +- [ ] (R) KEP approvers have approved the KEP status as `implementable` +- [ ] (R) Design details are appropriately documented +- [ ] (R) Graduation criteria is in place +- [ ] (R) Production readiness review completed +- [ ] (R) Production readiness review approved +- [ ] "Implementation History" section is up-to-date for milestone +- [ ] User-facing documentation has been created in [kubeflow/website/content/en/docs/components/training], for publication to [kubeflow.org] +- [ ] Supporting documentation—e.g., additional design documents, links to mailing list discussions/SIG meetings, relevant PRs/issues, release notes + +## Summary + +This Kubeflow Enhancement Proposal (KEP) aims to integrate [JAX](http://jax.readthedocs.io/), a popular machine learning framework, into the Kubeflow Training Operator to enable distributed training and fine-tuning jobs on Kubernetes. This will involve creating a new Kubernetes Custom Resource Definition (CRD) for JAX (JaxJob) and updating the Training Operator controller to support it. The enhancement will also include integrating JAX with the Training Operator Python SDK to provide simple APIs for Data Scientists to create and manage JaxJob on Kubernetes clusters. + +## Motivation + +JAX has emerged as a popular machine learning framework for high-performance numerical computing and accelerated training on GPUs and TPUs. With its "multi-controller" programming model, JAX is particularly well-suited for distributed training using the Single Program, Multiple Data (SPMD) paradigm. However, running distributed JAX jobs on Kubernetes requires robust orchestration to handle the complexities of multi-process coordination. + +Integrating JAX into the Kubeflow Training Operator will simplify distributed JAX training on Kubernetes, providing Data Scientists and ML Engineers with seamless APIs to deploy and manage JAX jobs. This proposal aims to create a new Kubernetes Custom Resource (CR) for JAX, update the Training Operator controller to support it, and provide an intuitive Python SDK for managing JAX jobs. + +### Goals + +- Develop a new Kubernetes CRD named `JaxJob` for managing JAX distributed training jobs. +- Update the Training Operator to manage `JaxJob` resources. +- Extend the Training Operator Python SDK to support JAX job creation and management. +- Ensure seamless integration with JAX's distributed training API. + +### Non-Goals + +- Support for non-distributed JAX training jobs (single-process training can continue to use existing mechanisms). +- General-purpose distributed computing support outside of JAX. + +## Proposal + +### User Stories (Optional) + +#### Story 1 + +As a Data Scientist, I want to use the Kubeflow Training Operator to run distributed JAX training jobs on a Kubernetes cluster so that I can leverage Kubernetes' scalability and resource management capabilities. + +#### Story 2 + +As a Machine Learning Engineer, I want to use a simple Python SDK to define and launch JAX training jobs on Kubernetes, abstracting away the complexity of Kubernetes configurations. + +#### Story 3 + +As a DevOps engineer, I want to manage JAX distributed training jobs using the Kubeflow Training Operator, so I can provide a consistent and scalable infrastructure for machine learning workloads. + +### Notes/Constraints/Caveats (Optional) + +- Ensuring compatibility with different versions of JAX and Kubernetes. +- Adequate documentation must be provided for users to understand how to configure and run `JaxJob` resources. + +### Risks and Mitigations + +- **Risk**: Compatibility issues with JAX updates. + - **Mitigation**: Regularly test the integration with new JAX releases and provide timely updates. +- **Risk**: Resource contention in large clusters. + - **Mitigation**: Implement robust scheduling and resource management policies in the Training Operator. + +## Design Details + +### Custom Resource Definition (CRD) for JaxJob + +Define a new CRD for JaxJob that includes specifications for: +- Number of processes +- Coordinator address +- Environment variables for JAX distributed training +- Container image and resource requirements +- Job scheduling and retries + +#### API (CRD and resulting objects) + +##### Custom Resource Definition + +The JaxJob CRD will define the schema for JAX distributed training jobs, including specifications for the coordinator, worker processes, and necessary environment variables. + +```yaml +--- +apiVersion: kubeflow.org/v1 +kind: JaxJob +metadata: + name: example-jaxjob + namespace: default +spec: + replicaSpecs: + - replicas: 1 + ReplicaType: MASTER + template: + spec: + containers: + - image: 'ghcr.io/nvidia/jax:upstream-pax-2024-03-12' + name: master + imagePullPolicy: IfNotPresent + restartPolicy: OnFailure + - replicas: 2 + ReplicaType: Worker + template: + metadata: + annotations: null + sidecar.istio.io/inject: 'false' + spec: + containers: + - name: jax-container + image: 'ghcr.io/nvidia/jax:upstream-pax-2024-03-12' + command: + - python + - train.py + env: + - name: JAX_COORDINATOR_ADDRESS + value: jax-coordinator.default.svc.cluster.local + - name: JAX_NUM_PROCESSES + value: '2' + - name: JAX_PROCESS_ID + valueFrom: null + fieldRef: null + fieldPath: 'metadata.labels[''statefulset.kubernetes.io/pod-name'']' + restartPolicy: OnFailure + +``` +Available image options: +1. Bitnami package for JAX: https://hub.docker.com/r/bitnami/jax +2. NVDIA JAX Toolbox: https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax + +This YAML file defines a JaxJob with two worker replicas. Each worker runs a container based on a JAX image and executes a Python script (e.g. train.py). The environment variables necessary for JAX's distributed training are set within each container. + +##### Resulting Master + +The Master component will act as the coordinator for the distributed JAX processes. It will be responsible for initializing the JAX distributed environment and managing communication between worker processes. + +```yaml +--- +kind: Service +apiVersion: v1 +metadata: + name: jaxjob-master-${job_id} +spec: + selector: + app: jaxjob-master-${job_id} + ports: + - port: 23456 + targetPort: 23456 +``` +```yaml +--- +apiVersion: v1 +kind: Pod +metadata: + name: jaxjob-master-${job_id} + labels: + app: jaxjobmaster-${job_id} +spec: + containers: + - image: ghcr.io/nvidia/jax:upstream-pax-2024-03-12 + imagePullPolicy: IfNotPresent + name: master + env: + - name: MASTER_PORT + value: "23456" + - name: MASTER_ADDR + value: "localhost" + - name: WORLD_SIZE + value: "3" + # Rank 0 is the master + - name: RANK + value: "0" + ports: + - name: masterPort + containerPort: 23456 +restartPolicy: OnFailure + +``` + +##### Resulting Worker + +Worker components will execute the distributed training tasks. Each worker will be assigned a unique process ID and will connect to the coordinator using the specified environment variables. + +```yaml +--- +apiVersion: v1 +kind: Pod +metadata: + name: jaxjob-worker-${job_id} +spec: + containers: + - image: ghcr.io/nvidia/jax:upstream-pax-2024-03-12 + imagePullPolicy: IfNotPresent + name: worker + env: + - name: MASTER_PORT + value: "23456" + - name: MASTER_ADDR + value: jaxjob-master-${job_id} + - name: WORLD_SIZE + value: "3" + - name: RANK + value: "1" +restartPolicy: OnFailure + +``` +### Controller Update + +Update the Training Operator controller to handle JaxJob resources: +- Watch for JaxJob resources and manage their lifecycle. +- Set up the required environment variables for each container in the JaxJob. +- Ensure that JAX distributed training initialization is handled correctly using `jax.distributed.initialize`. + +### Python SDK Integration + +Enhance the Training Operator Python SDK to include: +- APIs for defining and launching JaxJob resources. +- Utility functions for setting up distributed training environments. +- Example scripts and templates for common JAX training scenarios. + +### Test Plan + +- Develop unit tests for the new JaxJob controller functionalities. +- Create integration tests to validate end-to-end JaxJob execution on a Kubernetes cluster. +- Implement e2e tests to ensure JaxJob resources are correctly managed and distributed training works as expected. + +#### Prerequisite testing updates + +- Ensure existing tests for other training operators remain unaffected. + +#### Unit tests + +- Test JaxJob creation, update, and deletion. +- Validate environment variable setup for JAX distributed training. + +#### Integration tests + +- Implement integration tests to validate the end-to-end workflow of submitting and running a JaxJob on a Kubernetes cluster. + +#### e2e tests + +- Develop end-to-end tests to simulate real-world scenarios of running distributed JAX training jobs, including failure recovery and scaling. + +### Graduation Criteria + +#### Alpha + +- Initial implementation of JaxJob CRD and controller. +- Basic integration with the Python SDK. +- Initial set of unit and integration tests. + +#### Beta + +- Comprehensive e2e tests and feedback from early adopters. +- Performance benchmarks and optimizations. +- Detailed documentation and examples. + +#### GA + +- Proven stability and reliability in production environments. +- Wide adoption and positive feedback from the community. +- Complete set of monitoring and troubleshooting tools. + +### Upgrade / Downgrade Strategy + +- Provide clear instructions for upgrading and downgrading the Training Operator with JaxJob support. +- Ensure backward compatibility with existing Training Operator functionalities. + +### Version Skew Strategy + +- Ensure the `JaxJob` controller is compatible with the existing Kubernetes versions supported by Kubeflow. +- Handle version mismatches gracefully, providing clear error messages and guidance for resolution. + +## Production Readiness Review Questionnaire + +### Feature Enablement and Rollback + +- **How can this feature be enabled / disabled in a live cluster?** + - This feature can be enabled by deploying the updated Training Operator with support for the JaxJob CRD. It can be disabled by removing the JaxJob CRD and related controller code. + +- **Does enabling the feature change any default behavior?** + - No, it introduces new functionality without changing existing behaviors. + +- **Can the feature be disabled once it has been enabled (i.e. can we roll back the enablement)?** + - Yes, by removing the JaxJob CRD and associated controller logic. + +- **What happens if we reenable the feature if it was previously rolled back?** + - The JaxJob resources will be recognized and managed by the Training Operator as intended. + +- **Are there any tests for feature enablement/disablement?** + - Yes, integration tests will cover scenarios for enabling and disabling the feature. + +### Rollout, Upgrade and Rollback Planning + +- **How can a rollout or rollback fail? Can it impact already running workloads?** + - Rollout or rollback could fail if there are issues with the CRD definition or controller logic. This might impact running JaxJobs but not other workloads managed by the Training Operator. + +- **What specific metrics should inform a rollback?** + - Failure rates of JaxJob creations and completions, error logs from the Training Operator. + +- **Were upgrade and rollback tested? Was the upgrade->downgrade->upgrade path tested?** + - Manual testing will be performed initially, with plans to automate these tests. + +- **Is the rollout accompanied by any deprecations and/or removals of features, APIs, fields of API types, flags, etc.?** + - No, this is an additive feature. + +### Monitoring Requirements + +- **How can an operator determine if the feature is in use by workloads?** + - By monitoring the creation and status of JaxJob resources via Kubernetes API and custom metrics. + +- **How can someone using this feature know that it is working for their instance?** + - Successful creation and completion of JaxJob resources, logs, and metrics. + +- **What are the reasonable SLOs (Service Level Objectives) for the enhancement?** + - High availability and low failure rates for JaxJob submissions and executions. + +- **What are the SLIs (Service Level Indicators) an operator can use to determine the health of the service?** + - Metrics on job creation success rates, job completion times, and resource utilization. + +- **Are there any missing metrics that would be useful to have to improve observability of this feature?** + - Custom metrics for detailed tracking of JaxJob lifecycle events. + +### Dependencies + +- **Does this feature depend on any specific services running in the cluster?** + - Dependence on Kubernetes API server and etcd for resource management. + +### Scalability + +- **Will enabling / using this feature result in any new API calls?** + - Yes, new API calls related to the creation, update, and deletion of JaxJob resources. + +- **Will enabling / using this feature result in introducing new API types?** + - Yes, the JaxJob CRD. + +- **Will enabling / using this feature result in any new calls to the cloud provider?** + - No direct calls, but increased Kubernetes resource usage might indirectly lead to more cloud provider interactions. + +- **Will enabling / using this feature result in increasing size or count of the existing API objects?** + - Yes, additional JaxJob objects. + +- **Will enabling / using this feature result in increasing time taken by any operations covered by existing SLIs/SLOs?** + - Slight increase in controller processing time due to additional resource management. + +- **Will enabling / using this feature result in non-negligible increase of resource usage (CPU, RAM, disk, IO, ...) in any components?** + - Yes, additional resource usage by the Training Operator controller. + +- **Can enabling / using this feature result in resource exhaustion of some node resources (PIDs, sockets, inodes, etc.)?** + - Unlikely, but proper resource requests and limits should be set for JaxJob containers. + +### Troubleshooting + +- **How does this feature react if the API server and/or etcd is unavailable?** + - JaxJob creation and management will fail until the API server and etcd are available again. + +- **What are other known failure modes?** + - Failure mode: JaxJob creation errors. + - Detection: Monitor error logs from the Training Operator. + - Mitigations: Validate CRD and controller logic. + - Diagnostics: Error logs from the Training Operator controller. + - Testing: Integration tests to cover failure scenarios. + +- **What steps should be taken if SLOs are not being met to determine the problem?** + - Check the status of the Training Operator controller, inspect logs for errors, and verify the configuration of JaxJob resources. + +## Implementation History + +- 2024-05-22: Initial KEP draft created. + +## Drawbacks + +- Adds complexity to the Training Operator. +- Potential for increased maintenance burden with ongoing updates and compatibility checks. + +## Alternatives + +- Manual Job Management: Require users to manually manage JAX distributed training jobs using Kubernetes Job resources. This approach lacks the automation and ease-of-use provided by integrating with the Training Operator. +- Using Kubernetes JobSet + +## Infrastructure Needed (Optional) +- GPU and TPU cloud credits for testing JAX distributed training jobs.