-
Notifications
You must be signed in to change notification settings - Fork 118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
torch image added #659
torch image added #659
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR! Overall it seems to me that the main contribution here is the addition of the data_format
. Please focus the PR on that. Note that it should default to None
, which gets converted to config.image_data_format()
(the global default), so that users don't have to pass the right value around everywhere.
@@ -50,6 +52,16 @@ def resize( | |||
return resized | |||
|
|||
|
|||
def decode_image(img, channels, expand_animations=False): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is not a priori need for backend specific functions for these ops
@@ -63,6 +75,130 @@ def resize( | |||
) | |||
|
|||
|
|||
def smart_resize(x, size, interpolation="bilinear"): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We already have a backend agnostic implementation of smart_resize
in image_utils
(I suppose we could make it a public instead but the need has not arisen).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@fchollet how about if we enable it in this PR by using inside load_image
.
@@ -319,13 +325,12 @@ def image_dataset_from_directory( | |||
if shuffle: | |||
dataset = dataset.shuffle(buffer_size=1024, seed=seed) | |||
|
|||
dataset = dataset.prefetch(tf.data.AUTOTUNE) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure that prefetching should be done before shuffling
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure that prefetching should be done before shuffling
so previously keras used to do local shuffle not global hence I kept'd same.
@@ -29,6 +46,7 @@ def image_dataset_from_directory( | |||
interpolation="bilinear", | |||
follow_links=False, | |||
crop_to_aspect_ratio=False, | |||
data_format="channels_last", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adding this argument seems useful!
@fchollet can we merge this ? |
@asingh9530 the PR seems to contain a lot of extraneous changes, in particular the |
Fixed extra changes, I forgot to remove them. I believe now it should be fine. 🤔 |
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Keras image dataset loading utilities.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are still a lot of extraneous changes here. Check out the diff.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The header was part of keras API that's why I added it here, but I have removed it now.
@@ -180,24 +185,6 @@ def image_dataset_from_directory( | |||
f"Received: color_mode={color_mode}" | |||
) | |||
|
|||
interpolation = interpolation.lower() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This block should not be removed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, Added it back again but since not all interpolation methods are supported in all backend types and if user uses smart_resize()
then backend.resize()
is called with current backend then already ValueError have been raised specific to backend, That's why I removed it.
) | ||
|
||
train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Prefetching should not be moved.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shifted back to after shuffle. But I have a question so why local shuffle is used in tf.keras
and we are using global shuffle in keras_core
@@ -288,6 +289,8 @@ def test_image_dataset_from_directory_no_images(self): | |||
_ = image_dataset_utils.image_dataset_from_directory(directory) | |||
|
|||
def test_image_dataset_from_directory_crop_to_aspect_ratio(self): | |||
if backend_config() == "torch": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test should not be skipped.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the update! The change looks good.
Hi @fchollet ,
I was adding this in my previous PR but since you have already ported code from keras, I am only adding torch functionalities.
This PR contains following
Also need your thoughts on following questions, as I am not sure regarding how do we want to support backend.
most functions in main branch for
paths_and_labels_to_dataset
functions are still returning and accepting intf
only and processing also is done in tf , my idea was to make it backend specific.since for this PR I am heavily using
backend
methods and since these are not present forjax
andnumpy
most test using this function will fail but since the main branch is currently only using tf for everything so not sure how exactly it fits condition of supporting multiple backends🤔 . My idea was to make sure we do processing backend specific but only making API interface similar.Since
Image_utils
contains many functions but still some of them depends on backend specific function i.ebackend.function()
which again are not present for all. 😅