✨summary will now be written to disk
carefree0910 committed Mar 18, 2021
1 parent 40fa867 commit d5435e9
Showing 2 changed files with 34 additions and 17 deletions.
cflearn/misc/toolkit.py (40 changes: 25 additions & 15 deletions)
@@ -194,7 +194,12 @@ def flatten(d: Dict[str, Any], previous_keys: Tuple[str, ...]) -> None:

 # This is a modified version of https://github.com/sksq96/pytorch-summary
 # So it can summarize `carefree-learn` model structures better
-def summary(model: nn.Module, sample_batch: tensor_dict_type) -> None:
+def summary(
+    model: nn.Module,
+    sample_batch: tensor_dict_type,
+    *,
+    return_only: bool = False,
+) -> str:
     def register_hook(module: nn.Module) -> None:
         def hook(module_: nn.Module, inp: Any, output: Any) -> None:
             m_name = module_names.get(module_)
@@ -331,12 +336,12 @@ def _inject_summary(current_hierarchy: Any, previous_keys: List[str]) -> None:
     _inject_summary(hierarchy, [])

     line_length = 120
-    print("=" * line_length)
+    messages = ["=" * line_length]
     line_format = "{:30} {:>20} {:>40} {:>20}"
     headers = "Layer (type)", "Input Shape", "Output Shape", "Trainable Param #"
     line_new = line_format.format(*headers)
-    print(line_new)
-    print("-" * line_length)
+    messages.append(line_new)
+    messages.append("-" * line_length)
     total_output = 0
     for layer in summary_dict:
         # name, input_shape, output_shape, num_trainable_params
@@ -352,7 +357,7 @@ def _inject_summary(current_hierarchy: Any, previous_keys: List[str]) -> None:
             output_shape = [output_shape]
         for shape in output_shape:
             total_output += prod(shape)
-        print(line_new)
+        messages.append(line_new)

     # assume 4 bytes/number (float on cuda).
     x_batch = sample_batch["x_batch"]
@@ -362,16 +367,21 @@ def _inject_summary(current_hierarchy: Any, previous_keys: List[str]) -> None:
     total_params_size = abs(total_params * 4.0 / (1024 ** 2.0))
     total_size = total_params_size + total_output_size + total_input_size

-    print("=" * line_length)
-    print("Total params: {0:,}".format(total_params))
-    print("Trainable params: {0:,}".format(trainable_params))
-    print("Non-trainable params: {0:,}".format(total_params - trainable_params))
-    print("-" * line_length)
-    print("Input size (MB): %0.2f" % total_input_size)
-    print("Forward/backward pass size (MB): %0.2f" % total_output_size)
-    print("Params size (MB): %0.2f" % total_params_size)
-    print("Estimated Total Size (MB): %0.2f" % total_size)
-    print("-" * line_length)
+    non_trainable_params = total_params - trainable_params
+    messages.append("=" * line_length)
+    messages.append("Total params: {0:,}".format(total_params))
+    messages.append("Trainable params: {0:,}".format(trainable_params))
+    messages.append("Non-trainable params: {0:,}".format(non_trainable_params))
+    messages.append("-" * line_length)
+    messages.append("Input size (MB): %0.2f" % total_input_size)
+    messages.append("Forward/backward pass size (MB): %0.2f" % total_output_size)
+    messages.append("Params size (MB): %0.2f" % total_params_size)
+    messages.append("Estimated Total Size (MB): %0.2f" % total_size)
+    messages.append("-" * line_length)
+    msg = "\n".join(messages)
+    if not return_only:
+        print(msg)
+    return msg


 class LoggingMixinWithRank(LoggingMixin):
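The refactor follows a standard accumulate-then-emit pattern: every line that used to be printed eagerly is appended to a `messages` list, joined once, and returned, with console output gated behind `return_only`. A minimal self-contained sketch of the pattern (the `report` helper and its arguments are illustrative, not carefree-learn's API):

def report(total_params: int, trainable_params: int, *, return_only: bool = False) -> str:
    # Accumulate lines instead of printing them one by one.
    line_length = 120
    messages = ["=" * line_length]
    messages.append("Total params: {0:,}".format(total_params))
    messages.append("Trainable params: {0:,}".format(trainable_params))
    messages.append("-" * line_length)
    msg = "\n".join(messages)
    if not return_only:  # printing stays the default behavior
        print(msg)
    return msg  # callers may persist the string, as trainer.py now does

report(1_234_567, 1_200_000, return_only=True)  # builds the string silently

The size estimates in the hunk above keep pytorch-summary's 4-bytes-per-float convention, so e.g. a model with 1,000,000 parameters is reported as 1e6 * 4 / 1024 ** 2 ≈ 3.81 MB.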
cflearn/trainer.py (11 changes: 9 additions & 2 deletions)
@@ -986,15 +986,22 @@ def fit(
         show_summary = self.show_summary
         if show_summary is None:
             show_summary = not self.tqdm_settings.in_distributed
-        if show_summary and not self.is_loading:
+        if self.is_rank_0 and not self.is_loading:
             next_item = next(iter(self.tr_loader_copy))
             if self.tr_loader_copy.return_indices:
                 assert isinstance(next_item, tuple)
                 sample_batch = next_item[0]
             else:
                 assert isinstance(next_item, dict)
                 sample_batch = next_item
-            summary(self.model, sample_batch)
+            summary_msg = summary(
+                self.model,
+                sample_batch,
+                return_only=not show_summary,
+            )
+            logging_folder = self.environment.logging_folder
+            with open(os.path.join(logging_folder, "__summary__.txt"), "w") as f:
+                f.write(summary_msg)
         self._prepare_log()
         step_tqdm = None
         self._epoch_tqdm: Optional[tqdm] = None
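With `return_only=not show_summary`, rank 0 always obtains the summary string and persists it, while printing still honors `show_summary` (which defaults to false for distributed runs). A condensed sketch of the new flow, assuming `summary` is imported from `cflearn.misc.toolkit` and that `logging_folder` already exists by the time `fit` runs; the helper name and flattened arguments are illustrative:

import os

from cflearn.misc.toolkit import summary

def persist_summary(model, sample_batch, logging_folder: str, show_summary: bool) -> None:
    # Print only when requested, but always keep the string for the log folder.
    summary_msg = summary(model, sample_batch, return_only=not show_summary)
    with open(os.path.join(logging_folder, "__summary__.txt"), "w") as f:
        f.write(summary_msg)

Note the gate change from `show_summary` to `self.is_rank_0`: with the default settings, a distributed run previously skipped the summary entirely, whereas now rank 0 still writes `__summary__.txt` even when nothing is printed to the console.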
