sagemaker fixes and improvements (#708)
* adding aws sagemaker examples to examples readme

* refactoring and correcting documentation
pacman100 authored Sep 22, 2022
1 parent 82a7afd commit 6a39d01
Showing 3 changed files with 24 additions and 2 deletions.
21 changes: 20 additions & 1 deletion docs/source/usage_guides/sagemaker.mdx
@@ -129,7 +129,26 @@ You can find your model data at: s3://your-bucket/accelerate-sagemaker-1-2021-04

### Distributed Training: Data Parallelism

*currently in development, will be supported soon.*
Set up the accelerate config by running `accelerate config` and answer the SageMaker questions and set it up.
To use SageMaker DDP, select it when asked
`What is the distributed mode? ([0] No distributed training, [1] data parallelism):`.
Example config below:
```yaml
base_job_name: accelerate-sagemaker-1
compute_environment: AMAZON_SAGEMAKER
distributed_type: DATA_PARALLEL
ec2_instance_type: ml.p3.16xlarge
iam_role_name: xxxxx
image_uri: null
mixed_precision: fp16
num_machines: 1
profile: xxxxx
py_version: py38
pytorch_version: 1.10.2
region: us-east-1
transformers_version: 4.17.0
use_cpu: false
```
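
As an illustration that is not part of this diff, the config above can be checked for the two fields that select SageMaker data parallelism. This is a minimal sketch: `parse_flat_config` and `uses_sagemaker_data_parallel` are hypothetical helpers, and the parser only handles flat `key: value` lines like those in the example, so a real YAML parser would be the more robust choice.

```python
# Hypothetical sanity check for a flat accelerate SageMaker config.
# Only handles simple `key: value` lines, which is all the example uses.
CONFIG_TEXT = """\
base_job_name: accelerate-sagemaker-1
compute_environment: AMAZON_SAGEMAKER
distributed_type: DATA_PARALLEL
ec2_instance_type: ml.p3.16xlarge
mixed_precision: fp16
num_machines: 1
"""


def parse_flat_config(text):
    """Turn flat `key: value` lines into a dict of strings."""
    config = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # skip blank lines and comments
        key, _, value = line.partition(":")
        config[key.strip()] = value.strip()
    return config


def uses_sagemaker_data_parallel(config):
    """True when the config targets SageMaker with data parallelism."""
    return (
        config.get("compute_environment") == "AMAZON_SAGEMAKER"
        and config.get("distributed_type") == "DATA_PARALLEL"
    )


config = parse_flat_config(CONFIG_TEXT)
print(uses_sagemaker_data_parallel(config))
```
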
### Distributed Training: Model Parallelism
3 changes: 3 additions & 0 deletions examples/README.md
@@ -187,6 +187,9 @@ To run it in each of these various modes, use the following commands:
### Simple vision example (GANs)

- [huggan project](https://github.com/huggingface/community-events/tree/main/huggan)

### Using AWS SageMaker integration
- [Examples showcasing AWS SageMaker integration of 🤗 Accelerate.](https://github.com/pacman100/accelerate-aws-sagemaker)

## Finer Examples

2 changes: 1 addition & 1 deletion src/accelerate/commands/config/sagemaker.py
@@ -157,7 +157,7 @@ def get_sagemaker_input():
)

distributed_type = _ask_field(
"Which type of machine are you using? ([0] No distributed training, [1] data parallelism): ",
"What is the distributed mode? ([0] No distributed training, [1] data parallelism): ",
_convert_sagemaker_distributed_mode,
error_message="Please enter 0 or 1",
)
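
The prompt-and-convert pattern used by this call can be sketched roughly as follows. This is a simplified stand-in written for illustration, not the library's actual code: `DistributedMode`, `convert_distributed_mode`, `ask_field`, and the `read` parameter are all hypothetical names. The idea is that the converter raises on any answer other than `0` or `1`, and the prompt loop prints the error message and asks again.

```python
from enum import Enum


class DistributedMode(Enum):
    """Hypothetical stand-in for the distributed-mode enum."""
    NO = "NO"
    DATA_PARALLEL = "DATA_PARALLEL"


def convert_distributed_mode(value):
    """Map the answer "0" or "1" to a mode; raise ValueError otherwise."""
    index = int(value)  # raises ValueError for non-numeric answers
    if index not in (0, 1):
        raise ValueError(f"expected 0 or 1, got {value!r}")
    return [DistributedMode.NO, DistributedMode.DATA_PARALLEL][index]


def ask_field(prompt, convert, error_message, read=input):
    """Keep asking until `convert` accepts the answer."""
    while True:
        answer = read(prompt)
        try:
            return convert(answer)
        except ValueError:
            print(error_message)


# Scripted answers stand in for interactive input: two invalid, then "1".
answers = iter(["2", "yes", "1"])
mode = ask_field(
    "What is the distributed mode? ([0] No distributed training, [1] data parallelism): ",
    convert_distributed_mode,
    "Please enter 0 or 1",
    read=lambda _prompt: next(answers),
)
print(mode)
```

Passing `read` as a parameter is a choice made for this sketch so the loop can be exercised without a terminal; in interactive use the default `input` applies.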