Merge pull request #935 from bghira/main
documentation updates, deepspeed config reference error fix
bghira authored Sep 3, 2024
2 parents 91e3f59 + f56e6f4 commit 9595e6e
Showing 5 changed files with 19 additions and 8 deletions.
2 changes: 2 additions & 0 deletions OPTIONS.md
@@ -12,6 +12,8 @@ The JSON filename expected is `config.json` and the key names are the same as th

The script `configure.py` in the project root can be used via `python configure.py` to set up a `config.json` file with mostly-ideal default settings.

+> ⚠️ For users located in countries where Hugging Face Hub is not readily accessible, add `HF_ENDPOINT=https://hf-mirror.com` to your `~/.bashrc` or `~/.zshrc`, depending on which `$SHELL` your system uses.
+---

## 🌟 Core Model Configuration
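The `HF_ENDPOINT` note added to OPTIONS.md above can also be applied per process rather than shell-wide. A minimal sketch, assuming only that the variable is set before `huggingface_hub` is imported; the repo and filename are just an example:

```python
import os

# Point Hugging Face Hub traffic at the mirror before anything reads HF_ENDPOINT.
# This is the per-process equivalent of exporting it from ~/.bashrc or ~/.zshrc.
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")

from huggingface_hub import hf_hub_download  # imported after the variable is set

# Example download; any public repo_id/filename pair behaves the same way.
path = hf_hub_download(repo_id="openai/clip-vit-large-patch14", filename="config.json")
print(path)
```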
2 changes: 1 addition & 1 deletion README.md
@@ -124,7 +124,7 @@ LoRA and full-rank tuning are tested to work on an M3 Max with 128G memory, taki
- A100-80G (Full tune with DeepSpeed)
- A100-40G (LoRA, LoKr)
- 3090 24G (LoRA, LoKr)
-- 3080 16G (int4, LoRA, LoKr)
+- 4060 Ti, 3080 16G (int8, LoRA, LoKr)

Flux prefers being trained with multiple large GPUs, but a single 16G card should be able to do it with quantisation of the transformer and text encoders.

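The 16G claim above depends on quantisation being enabled in the trainer config. A minimal sketch of a `config.json` that opts into it; only `--base_model_precision` is taken from the configure.py change below, while `--model_family` is assumed here for illustration and a real config needs many more keys:

```python
import json

# Hedged sketch: quantise the base model so a Flux LoRA/LoKr run fits on a 16G card.
config = {
    "--model_family": "flux",                 # assumed key/value for illustration
    "--base_model_precision": "int8-quanto",  # key shown in this commit's configure.py
}

with open("config.json", "w") as handle:
    json.dump(config, handle, indent=4)
```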
6 changes: 3 additions & 3 deletions configure.py
@@ -640,8 +640,8 @@ def configure_env():

    quantization = (
        prompt_user(
-            f"Would you like to enable model quantization? {'NOTE: Currently, a bug prevents multi-GPU training with LoRA' if use_lora else ''}. (y/n)",
-            "n",
+            f"Would you like to enable model quantization? {'NOTE: Currently, a bug prevents multi-GPU training with LoRA' if use_lora else ''}. ([y]/n)",
+            "y",
        ).lower()
        == "y"
    )
@@ -656,7 +656,7 @@ def configure_env():
        if quantization_type:
            print(f"Invalid quantization type: {quantization_type}")
        quantization_type = prompt_user(
-            f"Choose quantization type (Options: {'/'.join(quantised_precision_levels)})",
+            f"Choose quantization type. int4 may only work on A100, H100, or Apple systems. (Options: {'/'.join(quantised_precision_levels)})",
            "int8-quanto",
        )
    env_contents["--base_model_precision"] = quantization_type
8 changes: 4 additions & 4 deletions helpers/training/trainer.py
@@ -240,6 +240,9 @@ def _misc_init(self):
        self.grad_norm = None
        self.extra_lr_scheduler_kwargs = {}
        StateTracker.set_global_step(self.state["global_step"])
+        self.config.use_deepspeed_optimizer, self.config.use_deepspeed_scheduler = (
+            prepare_model_for_deepspeed(self.accelerator, self.config)
+        )

    def set_model_family(self, model_family: str = None):
        model_family = getattr(self.config, "model_family", model_family)
@@ -900,9 +903,6 @@ def _recalculate_training_steps(self):
    def init_optimizer(self):
        logger.info(f"Learning rate: {self.config.learning_rate}")
        extra_optimizer_args = {"lr": self.config.learning_rate}
-        self.config.use_deepspeed_optimizer, self.config.use_deepspeed_scheduler = (
-            prepare_model_for_deepspeed(self.accelerator, self.config)
-        )
        # Initialize the optimizer
        optimizer_args_from_config, optimizer_class = (
            determine_optimizer_class_with_config(
@@ -1203,7 +1203,7 @@ def init_benchmark_base_model(self):
            # on deepspeed, every process has to enter. otherwise, only the main process does.
            return
        logger.info(
-            f"Benchmarking base model for comparison. Set DISABLE_BENCHMARK=true to disable this behaviour."
+            f"Benchmarking base model for comparison. Supply `--disable_benchmark: true` to disable this behaviour."
        )
        if is_lr_scheduler_disabled(self.config.optimizer):
            self.optimizer.eval()
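For orientation, a stripped-down sketch of the restructuring shown above, not SimpleTuner's actual implementation: the DeepSpeed optimizer/scheduler flags are now computed once during early initialisation and only read later, rather than being recomputed inside `init_optimizer`:

```python
from types import SimpleNamespace


class TrainerSketch:
    """Illustrative stub; the method names mirror the diff above, the bodies do not."""

    def __init__(self, accelerator, config):
        self.accelerator = accelerator
        self.config = config
        self._misc_init()

    def _misc_init(self):
        # Decide once, up front, whether DeepSpeed will own the optimizer/scheduler.
        self.config.use_deepspeed_optimizer, self.config.use_deepspeed_scheduler = (
            self._detect_deepspeed()
        )

    def init_optimizer(self):
        # Later stages only consult the flags; they no longer re-run the detection.
        if self.config.use_deepspeed_optimizer:
            return None  # DeepSpeed constructs its own optimizer from its JSON config.
        return self._build_torch_optimizer()

    def _detect_deepspeed(self):
        # Stand-in for prepare_model_for_deepspeed(accelerator, config).
        return False, False

    def _build_torch_optimizer(self):
        return None  # placeholder for real optimizer construction


trainer = TrainerSketch(accelerator=None, config=SimpleNamespace())
print(trainer.config.use_deepspeed_optimizer, trainer.config.use_deepspeed_scheduler)
```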
9 changes: 9 additions & 0 deletions notebook.ipynb
@@ -675,6 +675,15 @@
"trainer.init_ema_model()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainer.move_models(destination=\"accelerator\")"
]
},
{
"cell_type": "code",
"execution_count": 21,
