feat(components): Implement new output format of inference component
PiperOrigin-RevId: 597621035
Googler committed Jan 11, 2024
1 parent db6fe05 commit 4e1491a
Showing 1 changed file with 2 additions and 16 deletions.
@@ -42,7 +42,6 @@ def model_inference_component_internal(
     request_params: Dict[str, Any] = {},
     max_request_per_minute: float = 3,
     max_tokens_per_minute: float = 10000,
-    target_field_name: str = '',
     query_field_name: str = '',
     display_name: str = 'third-party-inference',
     machine_type: str = 'e2-highmem-16',
@@ -68,11 +67,6 @@ def model_inference_component_internal(
     max_request_per_minute: Maximum number of requests can be sent in a
       minute.
     max_tokens_per_minute: float = 10000,
-    target_field_name: The full name path of the features target field in the
-      predictions file. Formatted to be able to find nested columns, delimited
-      by `.`. Alternatively referred to as the ground truth (or
-      ground_truth_column) field. If not set, defaulted to
-      `inputs.ground_truth`.
     query_field_name: The full name path of the features prompt field in the
       request file. Formatted to be able to find nested columns, delimited by
       `.`. Alternatively referred to as the ground truth (or
@@ -115,7 +109,7 @@ def model_inference_component_internal(
       custom_job_payload=utils.build_custom_job_payload(
           display_name=display_name,
           machine_type=machine_type,
-          image_uri=version.LLM_EVAL_IMAGE_TAG,
+          image_uri=version.LLM_EVAL_IMAGE_TAG, # for local test and validation, use _IMAGE_URI.
           args=[
               f'--3p_model_inference={True}',
               f'--project={project}',
@@ -127,7 +121,6 @@ def model_inference_component_internal(
               f'--client_api_key_path={client_api_key_path}',
               f'--max_request_per_minute={max_request_per_minute}',
               f'--max_tokens_per_minute={max_tokens_per_minute}',
-              f'--target_field_name={target_field_name}',
               f'--query_field_name={query_field_name}',
               f'--gcs_output_path={gcs_output_path.path}',
               '--executor_input={{$.json_escape[1]}}',
@@ -150,7 +143,6 @@ def model_inference_component(
     inference_platform: str = 'openai_chat_completions',
     model_id: str = 'gpt-3.5-turbo',
     request_params: Dict[str, Any] = {},
-    target_field_name: str = '',
     query_field_name: str = 'prompt',
     max_request_per_minute: float = 3,
     max_tokens_per_minute: float = 10000,
@@ -174,11 +166,6 @@ def model_inference_component(
     inference_platform: Name of the inference platform.
     model_id: Name of the model to send requests against.
     request_params: Parameters to confirgure requests.
-    target_field_name: The full name path of the features target field in the
-      predictions file. Formatted to be able to find nested columns, delimited
-      by `.`. Alternatively referred to as the ground truth (or
-      ground_truth_column) field. If not set, defaulted to
-      `inputs.ground_truth`.
     query_field_name: The full name path of the features prompt field in the
       request file. Formatted to be able to find nested columns, delimited by
       `.`. Alternatively referred to as the ground truth (or
@@ -234,7 +221,6 @@ def model_inference_component(
       max_tokens_per_minute=max_tokens_per_minute,
       display_name=display_name,
       query_field_name=query_field_name,
-      target_field_name=target_field_name,
       machine_type=machine_type,
       service_account=service_account,
       network=network,
@@ -341,7 +327,6 @@ def model_inference_and_evaluation_component(
       max_request_per_minute=max_request_per_minute,
       max_tokens_per_minute=max_tokens_per_minute,
       query_field_name=query_field_name,
-      target_field_name=target_field_name,
       display_name=display_name,
       machine_type=machine_type,
       service_account=service_account,
@@ -354,6 +339,7 @@ def model_inference_and_evaluation_component(
       project=project,
       location=location,
       evaluation_task='text-generation',
+      target_field_name='.'.join(['instance', str(target_field_name)]),
       predictions_format='jsonl',
       joined_predictions_gcs_source=inference_task.outputs['gcs_output_path'],
       machine_type=machine_type,
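
The only functional addition in this commit is the target_field_name expression passed to the downstream evaluation step. Below is a minimal sketch of what that expression produces and how a `.`-delimited field path resolves against a nested prediction record; the `instance` wrapper and the `ground_truth`/`prediction` keys are illustrative assumptions about the new output format, not taken from the diff.

# Sketch: what the added target_field_name expression evaluates to, and how a
# `.`-delimited path is resolved against a nested prediction record.
# NOTE: the record layout below (the 'instance' wrapper and the
# 'ground_truth'/'prediction' keys) is an assumption for illustration only.

target_field_name = 'ground_truth'  # hypothetical value supplied by the caller

# Expression added to model_inference_and_evaluation_component:
eval_target_field = '.'.join(['instance', str(target_field_name)])
print(eval_target_field)  # -> instance.ground_truth

# Hypothetical JSONL prediction record written by the inference step:
record = {
    'instance': {'prompt': 'What is 2 + 2?', 'ground_truth': '4'},
    'prediction': '4',
}

# Resolve the nested column path, delimited by `.`:
value = record
for key in eval_target_field.split('.'):
    value = value[key]
print(value)  # -> 4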