Skip to content

Commit

Permalink
Even more validation. (huggingface#20762)
Browse files Browse the repository at this point in the history
* Even more validation.

* Fixing order.
  • Loading branch information
Narsil authored and amyeroberts committed Jan 4, 2023
1 parent 3ed1666 commit 0be5bfc
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/transformers/pipelines/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,12 @@ def _sanitize_parameters(
if return_full_text is not None and return_type is None:
if return_text is not None:
raise ValueError("`return_text` is mutually exclusive with `return_full_text`")
if return_tensors is not None:
raise ValueError("`return_full_text` is mutually exclusive with `return_tensors`")
return_type = ReturnType.FULL_TEXT if return_full_text else ReturnType.NEW_TEXT
if return_tensors is not None and return_type is None:
if return_text is not None:
raise ValueError("`return_text` is mutually exclusive with `return_tensors`")
return_type = ReturnType.TENSORS
if return_type is not None:
postprocess_params["return_type"] = return_type
Expand Down
4 changes: 4 additions & 0 deletions tests/pipelines/test_pipelines_text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,10 @@ def run_pipeline_test(self, text_generator, _):

with self.assertRaises(ValueError):
outputs = text_generator("test", return_full_text=True, return_text=True)
with self.assertRaises(ValueError):
outputs = text_generator("test", return_full_text=True, return_tensors=True)
with self.assertRaises(ValueError):
outputs = text_generator("test", return_text=True, return_tensors=True)

# Empty prompt is slighly special
# it requires BOS token to exist.
Expand Down

0 comments on commit 0be5bfc

Please sign in to comment.