torch image added #659

Merged

Contributor

@asingh9530 asingh9530 commented Aug 2, 2023

Hi @fchollet ,

I was adding this in my previous PR, but since you have already ported the code from Keras, I am only adding the torch functionality here.

This PR contains the following:

  • torch and tensorflow functions added to the backend.
  • relevant torch and tf imports added.
  • backend-based image pre-processing.

I would also like your thoughts on the following questions, as I am not sure how we want to support backends.

  • Most functions on the main branch around paths_and_labels_to_dataset still accept and return tf objects only, and the processing is also done in tf. My idea was to make this backend-specific.

  • Since this PR relies heavily on backend methods that are not present for jax and numpy, most tests using this function will fail on those backends. But since the main branch currently uses tf for everything, I am not sure how exactly that fits the goal of supporting multiple backends 🤔. My idea was to do the processing in a backend-specific way while keeping the API interface the same.

  • Image_utils contains many functions, and some of them still depend on backend-specific functions, i.e. backend.function(), which again are not present for all backends. 😅
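
A rough sketch of the idea in the last two bullets: one public function with a fixed signature that dispatches to whichever backend implementation is registered. Everything here is hypothetical (the registry, `register`, `flip_up_down`, and the list-based "images" are illustrative names, not keras_core code). It also shows why a backend without an implementation, such as jax or numpy in this PR, would fail.

```python
# Hypothetical registry mapping a backend name to its implementation.
_IMPLEMENTATIONS = {}
_CURRENT_BACKEND = "torch"  # assumption: stands in for the configured backend


def register(backend_name):
    """Decorator that records a backend-specific implementation."""
    def decorator(fn):
        _IMPLEMENTATIONS[backend_name] = fn
        return fn
    return decorator


@register("tensorflow")
def _tf_flip_up_down(image):
    # Real code would call a tensorflow op; nested lists stand in for tensors.
    return image[::-1]


@register("torch")
def _torch_flip_up_down(image):
    # Real code would call a torch op; nested lists stand in for tensors.
    return list(reversed(image))


def flip_up_down(image, backend_name=None):
    """Backend-agnostic entry point with the same signature everywhere."""
    name = backend_name or _CURRENT_BACKEND
    if name not in _IMPLEMENTATIONS:
        raise NotImplementedError(
            f"flip_up_down is not implemented for backend {name!r}"
        )
    return _IMPLEMENTATIONS[name](image)
```

Calling `flip_up_down(rows, "jax")` raises `NotImplementedError` here, which mirrors the test failures mentioned above for backends that lack the underlying ops.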

Contributor

@fchollet fchollet left a comment

Thanks for the PR! Overall it seems to me that the main contribution here is the addition of the data_format. Please focus the PR on that. Note that it should default to None, which gets converted to config.image_data_format() (the global default), so that users don't have to pass the right value around everywhere.
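
A minimal sketch of the defaulting behaviour described above, assuming a module-level global setting. The names `_IMAGE_DATA_FORMAT`, `image_data_format`, `standardize_data_format`, and `image_shape` are hypothetical stand-ins for illustration, not the keras_core implementation; only the rule itself (None falls back to the global default) comes from the comment.

```python
# Hypothetical stand-in for the value returned by config.image_data_format().
_IMAGE_DATA_FORMAT = "channels_last"


def image_data_format():
    """Return the global default image data format."""
    return _IMAGE_DATA_FORMAT


def standardize_data_format(data_format=None):
    """Resolve data_format, falling back to the global default when None."""
    if data_format is None:
        return image_data_format()
    data_format = str(data_format).lower()
    if data_format not in ("channels_first", "channels_last"):
        raise ValueError(
            "data_format must be 'channels_first' or 'channels_last'. "
            f"Received: data_format={data_format}"
        )
    return data_format


def image_shape(height, width, channels, data_format=None):
    """Shape of a single image under the resolved data format."""
    data_format = standardize_data_format(data_format)
    if data_format == "channels_first":
        return (channels, height, width)
    return (height, width, channels)
```

With this pattern a caller only passes `data_format` when it wants to override the global setting; every function that accepts the argument resolves it the same way.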

@@ -50,6 +52,16 @@ def resize(
return resized


def decode_image(img, channels, expand_animations=False):
Contributor

There is no a priori need for backend-specific functions for these ops.

@@ -63,6 +75,130 @@ def resize(
)


def smart_resize(x, size, interpolation="bilinear"):
Contributor

We already have a backend-agnostic implementation of smart_resize in image_utils (I suppose we could make it public instead, but the need has not arisen).

Contributor Author

@fchollet how about we enable it in this PR by using it inside load_image?

@@ -319,13 +325,12 @@ def image_dataset_from_directory(
if shuffle:
dataset = dataset.shuffle(buffer_size=1024, seed=seed)

dataset = dataset.prefetch(tf.data.AUTOTUNE)
Contributor

Not sure that prefetching should be done before shuffling

Contributor Author

@asingh9530 asingh9530 Aug 7, 2023

> Not sure that prefetching should be done before shuffling

Previously Keras did a local shuffle rather than a global one, so I kept it the same.

@@ -29,6 +46,7 @@ def image_dataset_from_directory(
interpolation="bilinear",
follow_links=False,
crop_to_aspect_ratio=False,
data_format="channels_last",
Contributor

Adding this argument seems useful!

@asingh9530 asingh9530 requested a review from fchollet August 5, 2023 15:09
@asingh9530
Contributor Author

@fchollet can we merge this?

@fchollet
Contributor

fchollet commented Aug 6, 2023

@asingh9530 the PR seems to contain a lot of extraneous changes, in particular the smart_resize function. Please fix.

@asingh9530
Contributor Author

> @asingh9530 the PR seems to contain a lot of extraneous changes, in particular the smart_resize function. Please fix.

I have removed the extra changes; I had forgotten to take them out. I believe it should be fine now. 🤔

# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras image dataset loading utilities."""
Contributor

There are still a lot of extraneous changes here. Check out the diff.

Contributor Author

The header was part of the Keras API file, which is why I added it here, but I have removed it now.

@@ -180,24 +185,6 @@ def image_dataset_from_directory(
f"Received: color_mode={color_mode}"
)

interpolation = interpolation.lower()
Contributor

This block should not be removed.

Contributor Author

@asingh9530 asingh9530 Aug 8, 2023

Done, I have added it back. I had removed it because not all interpolation methods are supported by every backend: when smart_resize() is used, backend.resize() is called on the current backend, which already raises a backend-specific ValueError.
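
For context, a sketch of what an upfront check of this kind might look like, so an unknown method fails early with one consistent message instead of a backend-specific error later. The `SUPPORTED_INTERPOLATIONS` tuple and the `validate_interpolation` helper are assumptions for illustration; the actual list in the codebase may differ.

```python
# Illustrative subset only; the real list of supported methods is an assumption.
SUPPORTED_INTERPOLATIONS = ("bilinear", "nearest", "bicubic")


def validate_interpolation(interpolation):
    """Normalize the method name and reject unsupported values early."""
    interpolation = interpolation.lower()
    if interpolation not in SUPPORTED_INTERPOLATIONS:
        raise ValueError(
            "Argument `interpolation` should be one of "
            f"{SUPPORTED_INTERPOLATIONS}. "
            f"Received: interpolation={interpolation}"
        )
    return interpolation
```

A backend may still reject a method that passes this check; the upfront validation only guarantees a uniform error for names that no backend is expected to handle.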

)

train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)
Contributor

Prefetching should not be moved.

Contributor Author

Moved it back to after the shuffle. But I have a question: why is a local shuffle used in tf.keras while we use a global shuffle in keras_core?
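
To make the local versus global distinction concrete, here is a pure-Python sketch of buffer-based shuffling, the general technique behind a `shuffle(buffer_size=...)` call. The `buffered_shuffle` helper is hypothetical and is not the tf.data implementation; it only illustrates that a small buffer moves elements a limited distance, while a buffer at least as large as the dataset gives a full shuffle.

```python
import random


def buffered_shuffle(items, buffer_size, seed=None):
    """Yield items in shuffled order using a fixed-size buffer."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be >= 1")
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            # Fill the buffer before emitting anything.
            buffer.append(item)
            continue
        # Emit a random buffered element and put the new item in its slot.
        index = rng.randrange(buffer_size)
        yield buffer[index]
        buffer[index] = item
    # Input exhausted: drain what is left in random order.
    rng.shuffle(buffer)
    yield from buffer
```

With `buffer_size=1` the output order equals the input order, and with a buffer covering the whole input every permutation is possible. In between, the element at output position i can only come from the first i + buffer_size inputs, which is the "local" behaviour.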

@@ -288,6 +289,8 @@ def test_image_dataset_from_directory_no_images(self):
_ = image_dataset_utils.image_dataset_from_directory(directory)

def test_image_dataset_from_directory_crop_to_aspect_ratio(self):
if backend_config() == "torch":
Contributor

This test should not be skipped.

Contributor Author

Done.

Contributor

@fchollet fchollet left a comment

Thanks for the update! The change looks good.

@fchollet fchollet merged commit cb95f39 into keras-team:main Aug 8, 2023