Skip to content

Commit

Permalink
Removing tf-ranking as a dependency untill it supports tf 2.16 (#7725)
Browse files Browse the repository at this point in the history
  • Loading branch information
vkarampudi authored Dec 9, 2024
1 parent 4975229 commit a4f29a0
Show file tree
Hide file tree
Showing 4 changed files with 195 additions and 365 deletions.
35 changes: 18 additions & 17 deletions tfx/examples/ranking/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,36 +17,37 @@
These names will be shared between the transform and the model.
"""

import tensorflow as tf
from tfx.examples.ranking import struct2tensor_parsing_utils
# import tensorflow as tf
# This is due to TF Ranking not supporting TensorFlow 2.16, We should re-enable it when support is added.
# from tfx.examples.ranking import struct2tensor_parsing_utils

# Labels are expected to be dense. In case of a batch of ELWCs have different
# number of documents, the shape of the label is [N, D], where N is the batch
# size, D is the maximum number of documents in the batch. If an ELWC in the
# batch has D_0 < D documents, then the value of label at D0 <= d < D must be
# negative to indicate that the document is invalid.
LABEL_PADDING_VALUE = -1
#LABEL_PADDING_VALUE = -1

# Names of features in the ELWC.
QUERY_TOKENS = 'query_tokens'
DOCUMENT_TOKENS = 'document_tokens'
LABEL = 'relevance'
#QUERY_TOKENS = 'query_tokens'
#DOCUMENT_TOKENS = 'document_tokens'
#LABEL = 'relevance'

# This "feature" does not exist in the data but will be created on the fly.
LIST_SIZE_FEATURE_NAME = 'example_list_size'
# LIST_SIZE_FEATURE_NAME = 'example_list_size'


def get_features():
"""Defines the context features and example features spec for parsing."""
#def get_features():
# """Defines the context features and example features spec for parsing."""

context_features = [
struct2tensor_parsing_utils.Feature(QUERY_TOKENS, tf.string)
]
# context_features = [
# struct2tensor_parsing_utils.Feature(QUERY_TOKENS, tf.string)
# ]

example_features = [
struct2tensor_parsing_utils.Feature(DOCUMENT_TOKENS, tf.string)
]
# example_features = [
# struct2tensor_parsing_utils.Feature(DOCUMENT_TOKENS, tf.string)
# ]

label = struct2tensor_parsing_utils.Feature(LABEL, tf.int64)
# label = struct2tensor_parsing_utils.Feature(LABEL, tf.int64)

return context_features, example_features, label
# return context_features, example_features, label
47 changes: 25 additions & 22 deletions tfx/examples/ranking/ranking_pipeline_e2e_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@
import unittest

import tensorflow as tf
from tfx.examples.ranking import ranking_pipeline
from tfx.orchestration import metadata
from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner
# from tfx.orchestration import metadata
# from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner

# This is due to TF Ranking not supporting TensorFlow 2.16, We should re-enable it when support is added.
# from tfx.examples.ranking import ranking_pipeline


try:
import struct2tensor # pylint: disable=g-import-not-at-top
Expand Down Expand Up @@ -62,23 +65,23 @@ def assertExecutedOnce(self, component) -> None:
execution = tf.io.gfile.listdir(os.path.join(component_path, output))
self.assertEqual(1, len(execution))

def testPipeline(self):
BeamDagRunner().run(
ranking_pipeline._create_pipeline(
pipeline_name=self._pipeline_name,
pipeline_root=self._tfx_root,
data_root=self._data_root,
module_file=self._module_file,
serving_model_dir=self._serving_model_dir,
metadata_path=self._metadata_path,
beam_pipeline_args=['--direct_num_workers=1']))
self.assertTrue(tf.io.gfile.exists(self._serving_model_dir))
self.assertTrue(tf.io.gfile.exists(self._metadata_path))
#def testPipeline(self):
# BeamDagRunner().run(
# ranking_pipeline._create_pipeline(
# pipeline_name=self._pipeline_name,
# pipeline_root=self._tfx_root,
# data_root=self._data_root,
# module_file=self._module_file,
# serving_model_dir=self._serving_model_dir,
# metadata_path=self._metadata_path,
# beam_pipeline_args=['--direct_num_workers=1']))
# self.assertTrue(tf.io.gfile.exists(self._serving_model_dir))
# self.assertTrue(tf.io.gfile.exists(self._metadata_path))

metadata_config = metadata.sqlite_metadata_connection_config(
self._metadata_path)
with metadata.Metadata(metadata_config) as m:
artifact_count = len(m.store.get_artifacts())
execution_count = len(m.store.get_executions())
self.assertGreaterEqual(artifact_count, execution_count)
self.assertEqual(9, execution_count)
# metadata_config = metadata.sqlite_metadata_connection_config(
# self._metadata_path)
# with metadata.Metadata(metadata_config) as m:
# artifact_count = len(m.store.get_artifacts())
# execution_count = len(m.store.get_executions())
# self.assertGreaterEqual(artifact_count, execution_count)
# self.assertEqual(9, execution_count)
Loading

0 comments on commit a4f29a0

Please sign in to comment.