diff --git a/quaterion/main.py b/quaterion/main.py index b0ecdae9..7a882761 100644 --- a/quaterion/main.py +++ b/quaterion/main.py @@ -158,8 +158,13 @@ def trainer_defaults( # If the cache is enabled and there are no # trainable encoders - checkpointing on each epoch might become a bottleneck cache_config = trainable_model.configure_caches() - all_encoders_frozen = all(not encoder.trainable for encoder in trainable_model.model.encoders.values()) - cache_configured = cache_config is not None and cache_config.cache_type != CacheType.NONE + all_encoders_frozen = all( + not encoder.trainable + for encoder in trainable_model.model.encoders.values() + ) + cache_configured = ( + cache_config is not None and cache_config.cache_type != CacheType.NONE + ) disable_checkpoints = all_encoders_frozen and cache_configured if disable_checkpoints: