Problem with kmnist dataset #273
I would simply repeat the channels in the input pipeline, in `vision_transformer/vit_jax/input_pipeline.py` (lines 195 to 216 at commit 297866a).
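Something like this:

```python
import tensorflow_datasets as tfds
import tensorflow as tf

# Works the same with 'kmnist', which is also 28x28x1 grayscale.
ds = tfds.load('mnist', split='train')
# Repeat the single grayscale channel 3 times: (28, 28, 1) -> (28, 28, 3).
ds = ds.map(lambda d: {
    'label': d['label'],
    'image': tf.repeat(d['image'], 3, axis=2),
})
ds = ds.batch(2)
b = next(iter(ds))
assert b['image'].shape.as_list() == [2, 28, 28, 3]
```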
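Repeating the channel only fixes the channel count. A pretrained B_16 checkpoint is usually fed larger, normalized images, so the grayscale batch probably also needs resizing and rescaling. Below is a minimal NumPy sketch of that whole conversion. The target size of 224 and the [-1, 1] scaling are assumptions, and `to_vit_input` is a hypothetical helper, not part of vit_jax; check your config and `input_pipeline.py` for the values your model actually uses.

```python
import numpy as np


def to_vit_input(images: np.ndarray, size: int = 224) -> np.ndarray:
    """Hypothetical helper: (N, H, W, 1) uint8 -> (N, size, size, 3) float32.

    Assumes the model wants `size` x `size` RGB inputs scaled to [-1, 1].
    """
    if images.ndim != 4 or images.shape[-1] != 1:
        raise ValueError(f"expected (N, H, W, 1), got {images.shape}")
    _, h, w, _ = images.shape
    # Nearest-neighbour resize: map each output row/column to a source one.
    rows = np.arange(size) * h // size
    cols = np.arange(size) * w // size
    resized = images[:, rows[:, None], cols[None, :], :]  # (N, size, size, 1)
    # Repeat the grayscale channel to get 3 identical "RGB" channels.
    rgb = np.repeat(resized, 3, axis=-1)                  # (N, size, size, 3)
    # Scale uint8 [0, 255] to float32 [-1, 1].
    return rgb.astype(np.float32) / 127.5 - 1.0


batch = np.random.randint(0, 256, size=(2, 28, 28, 1), dtype=np.uint8)
out = to_vit_input(batch)
print(out.shape)  # (2, 224, 224, 3)
```

Nearest-neighbour resizing keeps the sketch dependency-free; in a real tf.data pipeline `tf.image.resize` inside the `ds.map` call is the more natural choice.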
Hello,

I am trying to use the pretrained B_16 model on the tfds `kmnist` dataset (which, like MNIST, consists of 28x28 grayscale images). The problem is that I get an error, which is probably due to the images having only 1 color channel instead of 3.

I had no problem running the pretrained model on a custom color dataset. Is this method only available for 3-channel datasets, or are MNIST-like datasets also welcome?
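The exact error is not shown, so this is a guess at the cause: ViT's patch embedding is a convolution whose kernel size and stride equal the patch size (16 for B_16), and a checkpoint pretrained on RGB images has a kernel with 3 input channels. The NumPy sketch below is a re-implementation for illustration, not vit_jax's code; `patch_embed` and the small width `HIDDEN = 8` are hypothetical, and random weights stand in for pretrained ones. It shows that a 1-channel input cannot be contracted against a 3-channel kernel, and that repeating the channel gives exactly the same result as summing the pretrained kernel over its color axis, so the pretrained weights are still used sensibly.

```python
import numpy as np

PATCH = 16   # B_16 uses 16x16 patches
HIDDEN = 8   # hypothetical small width for illustration (B_16 itself uses 768)


def patch_embed(images: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """Non-overlapping patch embedding (conv with kernel_size == stride).

    images: (N, H, W, C), kernel: (P, P, C, D) -> (N, H/P, W/P, D).
    """
    n, h, w, c = images.shape
    p, _, kc, _ = kernel.shape
    if c != kc:
        raise ValueError(f"input has {c} channels but kernel expects {kc}")
    # (N, H, W, C) -> (N, H/P, P, W/P, P, C) -> (N, H/P, W/P, P, P, C)
    patches = images.reshape(n, h // p, p, w // p, p, c).transpose(0, 1, 3, 2, 4, 5)
    # Contract every (P, P, C) patch against the kernel.
    return np.einsum("nijpqc,pqcd->nijd", patches, kernel)


rng = np.random.default_rng(0)
kernel_rgb = rng.normal(size=(PATCH, PATCH, 3, HIDDEN))  # stands in for pretrained weights
gray = rng.random(size=(2, 32, 32, 1))

# 1) A 1-channel batch does not fit a kernel trained on 3 channels.
try:
    patch_embed(gray, kernel_rgb)
except ValueError as e:
    print("error:", e)

# 2) Repeating the channel works ...
out_repeat = patch_embed(np.repeat(gray, 3, axis=-1), kernel_rgb)
# 3) ... and equals a 1-channel kernel obtained by summing over the RGB axis.
out_summed = patch_embed(gray, kernel_rgb.sum(axis=2, keepdims=True))
print(np.allclose(out_repeat, out_summed))  # True
```

The equality holds because the convolution is linear in the channel sum: feeding the same value into all 3 channels is the same as adding up the 3 per-channel kernels once.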