
TPU VM trained weight release w/ PyTorch XLA

Released by @rwightman on 18 Mar 22:50 · commit 7c67d6a

A wide range of mid- to large-sized models trained with PyTorch XLA on TPU VM instances, demonstrating the viability of the TPU + PyTorch combination for producing excellent image model results. All models were trained with the bits_and_tpu branch of this codebase.

A big thanks to the TPU Research Cloud (https://sites.research.google/trc/about/) for the compute used in these experiments.

This set includes several novel weights, among them EvoNorm-S RegNetZ (timm C/D variants) and ResNet-V2 experiments, as well as custom pre-activation variants of RegNet-Y (called RegNet-V) and Xception (called Xception-P).

Many, if not all, of the included RegNet weights surpass the original paper results by a wide margin and remain above other known results (e.g. recent torchvision updates) on ImageNet-1k validation, and especially on OOD test set / robustness performance and when scaling to higher resolutions.
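
These weights load through the usual timm API. A minimal sketch (not part of the release itself), assuming a timm version that registers these checkpoints and using regnety_040 purely as an example:

```python
import torch
import timm
from timm.data import resolve_data_config, create_transform

# Load one of the released RegNet weights (downloads the checkpoint).
model = timm.create_model('regnety_040', pretrained=True)
model.eval()

# Build the matching eval preprocessing from the model's pretrained config;
# for real evaluation this transform would be applied to PIL images.
config = resolve_data_config({}, model=model)
transform = create_transform(**config)

# Dummy inference at the default 224 resolution. The RegNets are fully
# convolutional, so the higher-resolution numbers (e.g. @ 288) simply
# feed larger test inputs through the same weights.
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000]) -- ImageNet-1k logits
```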

RegNets

  • regnety_040 - 82.3 @ 224, 82.96 @ 288
  • regnety_064 - 83.0 @ 224, 83.65 @ 288
  • regnety_080 - 83.17 @ 224, 83.86 @ 288
  • regnetv_040 - 82.44 @ 224, 83.18 @ 288 (timm pre-act)
  • regnetv_064 - 83.1 @ 224, 83.71 @ 288 (timm pre-act)
  • regnetz_040 - 83.67 @ 256, 84.25 @ 320
  • regnetz_040h - 83.77 @ 256, 84.5 @ 320 (w/ extra fc in head)

Alternative norm layers (no BN!)

  • resnetv2_50d_gn - 80.8 @ 224, 81.96 @ 288 (pre-act GroupNorm)
  • resnetv2_50d_evos - 80.77 @ 224, 82.04 @ 288 (pre-act EvoNormS)
  • regnetz_c16_evos - 81.9 @ 256, 82.64 @ 320 (EvoNormS)
  • regnetz_d8_evos - 83.42 @ 256, 84.04 @ 320 (EvoNormS)
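
Not from the release notes, but one quick way to sanity-check the "no BN" claim is to enumerate the normalization modules of one of these models. A minimal sketch, assuming a timm build that registers these architectures (exact layer class names, e.g. GroupNormAct, may differ across versions):

```python
from collections import Counter

import torch.nn as nn
import timm

# Instantiate the GroupNorm ResNet-V2 variant without downloading weights.
model = timm.create_model('resnetv2_50d_gn', pretrained=False)

# Count normalization layer types: expect GroupNorm-style modules
# (or EvoNorm ones for the *_evos models) and no BatchNorm2d.
norm_counts = Counter(
    type(m).__name__
    for m in model.modules()
    if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm))
    or 'EvoNorm' in type(m).__name__
)
print(norm_counts)
```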

Xception redux

  • xception41p - 82 @ 299 (timm pre-act)
  • xception65 - 83.17 @ 299
  • xception65p - 83.14 @ 299 (timm pre-act)

ResNets (w/ SE and/or NeXT)

  • resnext101_64x4d - 82.46 @ 224, 83.16 @ 288
  • seresnext101_32x8d - 83.57 @ 224, 84.27 @ 288
  • seresnext101d_32x8d - 83.69 @ 224, 84.35 @ 288
  • seresnextaa101d_32x8d - 83.85 @ 224, 84.57 @ 288
  • resnetrs200 - 83.85 @ 256, 84.44 @ 320

Vision transformer experiments -- relative position bias (relpos), residual post-norm, layer-scale, fc-norm, and global average pooling (GAP)

  • vit_relpos_base_patch32_plus_rpn_256 - 79.5 @ 256, 80.6 @ 320 -- rel pos + extended width + res-post-norm, no class token, avg pool
  • vit_relpos_small_patch16_224 - 81.5 @ 224, 82.5 @ 320 -- rel pos, layer scale, no class token, avg pool
  • vit_relpos_medium_patch16_rpn_224 - 82.3 @ 224, 83.1 @ 320 -- rel pos + res-post-norm, no class token, avg pool
  • vit_base_patch16_rpn_224 - 82.3 @ 224 -- rel pos + res-post-norm, no class token, avg pool
  • vit_relpos_medium_patch16_224 - 82.5 @ 224, 83.3 @ 320 -- rel pos, layer scale, no class token, avg pool
  • vit_relpos_base_patch16_224 - 82.5 @ 224, 83.6 @ 320 -- rel pos, layer scale, no class token, avg pool
  • vit_relpos_base_patch16_gapcls_224 - 82.8 @ 224, 83.9 @ 320 -- rel pos, layer scale, class token, avg pool (by mistake)
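
As a hedged illustration of the two reported resolutions, the sketch below creates one of the relpos models at a 320 input size. Passing img_size as a constructor override is an assumption here, and whether the pretrained relative-position parameters adapt to the larger grid without extra handling depends on the timm version:

```python
import torch
import timm

# Build the relpos ViT at the larger eval resolution; set pretrained=True
# to pull the released weights (may require a timm version that supports
# loading them at a non-default img_size).
model = timm.create_model(
    'vit_relpos_medium_patch16_224',
    pretrained=False,
    img_size=320,
)
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 320, 320))
print(logits.shape)  # torch.Size([1, 1000])
```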