diff --git a/cflearn/misc/toolkit.py b/cflearn/misc/toolkit.py
index 59f09089c..cfa26418f 100644
--- a/cflearn/misc/toolkit.py
+++ b/cflearn/misc/toolkit.py
@@ -194,7 +194,12 @@ def flatten(d: Dict[str, Any], previous_keys: Tuple[str, ...]) -> None:
 
 # This is a modified version of https://github.com/sksq96/pytorch-summary
 # So it can summary `carefree-learn` model structures better
-def summary(model: nn.Module, sample_batch: tensor_dict_type) -> None:
+def summary(
+    model: nn.Module,
+    sample_batch: tensor_dict_type,
+    *,
+    return_only: bool = False,
+) -> str:
     def register_hook(module: nn.Module) -> None:
         def hook(module_: nn.Module, inp: Any, output: Any) -> None:
             m_name = module_names.get(module_)
@@ -331,12 +336,12 @@ def _inject_summary(current_hierarchy: Any, previous_keys: List[str]) -> None:
     _inject_summary(hierarchy, [])
 
     line_length = 120
-    print("=" * line_length)
+    messages = ["=" * line_length]
     line_format = "{:30} {:>20} {:>40} {:>20}"
     headers = "Layer (type)", "Input Shape", "Output Shape", "Trainable Param #"
     line_new = line_format.format(*headers)
-    print(line_new)
-    print("-" * line_length)
+    messages.append(line_new)
+    messages.append("-" * line_length)
     total_output = 0
     for layer in summary_dict:
         # name, input_shape, output_shape, num_trainable_params
@@ -352,7 +357,7 @@ def _inject_summary(current_hierarchy: Any, previous_keys: List[str]) -> None:
             output_shape = [output_shape]
         for shape in output_shape:
             total_output += prod(shape)
-        print(line_new)
+        messages.append(line_new)
 
     # assume 4 bytes/number (float on cuda).
     x_batch = sample_batch["x_batch"]
@@ -362,16 +367,21 @@ def _inject_summary(current_hierarchy: Any, previous_keys: List[str]) -> None:
     total_params_size = abs(total_params * 4.0 / (1024 ** 2.0))
     total_size = total_params_size + total_output_size + total_input_size
 
-    print("=" * line_length)
-    print("Total params: {0:,}".format(total_params))
-    print("Trainable params: {0:,}".format(trainable_params))
-    print("Non-trainable params: {0:,}".format(total_params - trainable_params))
-    print("-" * line_length)
-    print("Input size (MB): %0.2f" % total_input_size)
-    print("Forward/backward pass size (MB): %0.2f" % total_output_size)
-    print("Params size (MB): %0.2f" % total_params_size)
-    print("Estimated Total Size (MB): %0.2f" % total_size)
-    print("-" * line_length)
+    non_trainable_params = total_params - trainable_params
+    messages.append("=" * line_length)
+    messages.append("Total params: {0:,}".format(total_params))
+    messages.append("Trainable params: {0:,}".format(trainable_params))
+    messages.append("Non-trainable params: {0:,}".format(non_trainable_params))
+    messages.append("-" * line_length)
+    messages.append("Input size (MB): %0.2f" % total_input_size)
+    messages.append("Forward/backward pass size (MB): %0.2f" % total_output_size)
+    messages.append("Params size (MB): %0.2f" % total_params_size)
+    messages.append("Estimated Total Size (MB): %0.2f" % total_size)
+    messages.append("-" * line_length)
+    msg = "\n".join(messages)
+    if not return_only:
+        print(msg)
+    return msg
 
 
 class LoggingMixinWithRank(LoggingMixin):
diff --git a/cflearn/trainer.py b/cflearn/trainer.py
index e11ac0852..0cc75f2c4 100644
--- a/cflearn/trainer.py
+++ b/cflearn/trainer.py
@@ -986,7 +986,7 @@ def fit(
         show_summary = self.show_summary
         if show_summary is None:
             show_summary = not self.tqdm_settings.in_distributed
-        if show_summary and not self.is_loading:
+        if self.is_rank_0 and not self.is_loading:
             next_item = next(iter(self.tr_loader_copy))
             if self.tr_loader_copy.return_indices:
                 assert isinstance(next_item, tuple)
@@ -994,7 +994,14 @@ def fit(
             else:
                 assert isinstance(next_item, dict)
             sample_batch = next_item
-            summary(self.model, sample_batch)
+            summary_msg = summary(
+                self.model,
+                sample_batch,
+                return_only=not show_summary,
+            )
+            logging_folder = self.environment.logging_folder
+            with open(os.path.join(logging_folder, "__summary__.txt"), "w") as f:
+                f.write(summary_msg)
         self._prepare_log()
         step_tqdm = None
         self._epoch_tqdm: Optional[tqdm] = None