Fine-Tuning CNN models #3182
Answered by andsteing

IMvision12 asked this question in General
I have a Flax model:

```python
b = mlpmixer_b16(num_classes=10)
```

and pre-trained (ImageNet) weights (image size: 224x224):

```python
with open("imagenet21k_Mixer-B_16.msgpack", "rb") as f:
    content = f.read()
restored_params = flax.serialization.msgpack_restore(content)
```

I want to fine-tune this model with `restored_params` on a dataset with images of size 128x128:

```python
dummy_inputs = jnp.ones((1, 128, 128, 3), dtype=jnp.float32)
rng = jax.random.PRNGKey(0)
x = b.apply({"params": restored_params}, dummy_inputs)
```

This raises:

```
ScopeParamShapeError: Initializer expected to generate shape (196, 384) but got shape (64, 384) instead for parameter "kernel" in "/MixerBlock_0/token_mixing/Dense_0". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)
```

If I change the input shape to 224x224 it works fine:

```python
dummy_inputs = jnp.ones((1, 224, 224, 3), dtype=jnp.float32)
```

How do I properly fine-tune a model using Flax?
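For context on the error: in MLP-Mixer, the token-mixing MLP acts over the patch (token) axis, so its weight shapes are tied to the number of patches, which in turn depends on the input resolution. A minimal sketch of the arithmetic behind the `(196, 384)` vs `(64, 384)` mismatch (the helper name `num_patches` is illustrative, not from the original code):

```python
def num_patches(image_size, patch_size=16):
    """Number of non-overlapping patches for a square input.

    Mixer-B/16 splits the image into (image_size // patch_size)^2
    patches; the token-mixing Dense kernel's first dimension equals
    this count, so it changes with the input resolution.
    """
    return (image_size // patch_size) ** 2

print(num_patches(224))  # 196 -> matches the pre-trained checkpoint
print(num_patches(128))  # 64  -> the mismatch reported in the error
```

This is why restoring 224x224 parameters into a model applied to 128x128 inputs fails: the checkpoint's token-mixing kernel has 196 rows, but a 128x128 input only produces 64 tokens.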
Answered by andsteing on Jul 6, 2023
Why don't you try fine-tuning it with inputs of shape 224x224? Take a look at our Transfer Learning guide for some general recommendations.
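One way to follow this suggestion is to resize the 128x128 images to the 224x224 resolution the checkpoint was trained at before feeding them to the model. A minimal sketch using `jax.image.resize` (the batch size of 8 and bilinear interpolation are illustrative choices, not from the original thread):

```python
import jax
import jax.numpy as jnp

# Hypothetical batch of 128x128 RGB images from the fine-tuning dataset.
images_128 = jnp.ones((8, 128, 128, 3), dtype=jnp.float32)

# Upsample to 224x224 so the number of patches (and hence the
# token-mixing weight shapes) matches the pre-trained checkpoint.
images_224 = jax.image.resize(
    images_128, shape=(8, 224, 224, 3), method="bilinear")

print(images_224.shape)  # (8, 224, 224, 3)
```

With the inputs resized, `b.apply({"params": restored_params}, images_224)` should no longer trigger the shape error, since the model again sees 196 tokens.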
Answer copied from google-research/vision_transformer#274: