GlobalAveragePooling1D data_format Question #20627

Open
cptang2007 opened this issue Dec 11, 2024 · 0 comments · May be fixed by keras-team/keras-io#2006

@cptang2007

My rig

  • Ubuntu 24.04 VM, RTX 3060 Ti with NVIDIA driver 535
  • tensorflow-2.14-gpu / tensorflow-2.18, both pulled from Docker
  • NVIDIA Container Toolkit when running the GPU version

About this example

The transformer blocks of this example contain 2 Conv1D layers, so we have to reshape the input matrix to add the channel dimension at the end.
There is a GlobalAveragePooling1D layer after the transformer blocks:
x = layers.GlobalAveragePooling1D(data_format="channels_last")(x)

which should be correct, since our channel dimension is added last.
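For reference, here is a minimal sketch of which axis the layer averages over under each data_format, as I understand the Keras documentation: "channels_last" expects (batch, steps, features) and "channels_first" expects (batch, features, steps), and in both cases the steps axis is averaged away. The helper pooled_shape is hypothetical (plain Python, not a Keras call), and the input shape (None, 500, 1) is an assumption inferred from the parameter counts reported in this issue.

```python
def pooled_shape(input_shape, data_format="channels_last"):
    """Output shape of 1D global average pooling (keepdims=False).

    Hypothetical helper for illustration; it only mirrors the documented
    axis convention and does not call Keras.
    """
    batch, axis1, axis2 = input_shape
    if data_format == "channels_last":
        # Input read as (batch, steps, features): steps averaged out, features kept.
        return (batch, axis2)
    if data_format == "channels_first":
        # Input read as (batch, features, steps): steps averaged out, features kept.
        return (batch, axis1)
    raise ValueError(f"unknown data_format: {data_format!r}")


# Assumed tensor shape after the transformer blocks: 500 time steps, 1 channel.
x_shape = (None, 500, 1)
print(pooled_shape(x_shape, "channels_last"))
print(pooled_shape(x_shape, "channels_first"))
```

Under that assumption, "channels_last" leaves a single feature per sample, while "channels_first" treats the 500 time steps as features and keeps all of them.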

However, when running this example, the third-to-last line of the summary does not show 64,128 params:
dense (Dense) │ (None, 128) │ 64,128 │ global_average_pool…

Instead it has just 256 parameters, which makes the total params far lower, and the model only reaches an accuracy of ~50%.
Screenshot from 2024-12-11 13-32-38

This happens whether I am running tensorflow-2.14-gpu or the CPU version of tensorflow-2.18.

However, after changing to data_format="channels_first", everything becomes fine. The number of params in the Dense layer that follows GlobalAveragePooling1D becomes 64,128. The total params also match. The training accuracy is also more than 90%.
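The two parameter counts can be checked by hand. A Dense layer with a bias has inputs × units + units parameters, and the summary line shows 128 units. The sketch below is plain arithmetic with a hypothetical helper dense_params, not Keras code. The input widths 1 and 500 are assumptions inferred from the reported counts.

```python
def dense_params(input_dim, units):
    """Parameter count of a fully connected layer with bias: weights + biases."""
    return input_dim * units + units


# Assumed: pooling leaves 1 feature with "channels_last" and 500 with "channels_first".
print(dense_params(1, 128))    # 256
print(dense_params(500, 128))  # 64128
```

So 256 parameters would correspond to the Dense layer seeing a single pooled feature, and 64,128 to it seeing 500 values per sample.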

I discovered this when I found a very similar model here.
The only difference is the data_format.

But isn't data_format="channels_last" the right choice?

So what's wrong?
