app.py
from pydantic import BaseModel, Field
from transformers import pipeline

MODEL_NAME = "distilgpt2"

# Load the text-generation pipeline once at import time so it is reused across calls.
nlp = pipeline("text-generation", model=MODEL_NAME, tokenizer=MODEL_NAME)

class TextGenerationInput(BaseModel):
    text: str = Field(
        ...,
        title="Text Input",
        description="The input text to use as a basis for generating text.",
        max_length=1000,
    )
    temperature: float = Field(
        1.0,
        gt=0.0,
        multiple_of=0.001,
        description="The value used to modulate the next token probabilities.",
    )
    max_length: int = Field(
        30,
        ge=5,
        le=100,
        description="The maximum length of the sequence to be generated.",
    )
    repetition_penalty: float = Field(
        1.0,
        ge=0.0,
        le=1.0,
        description="The parameter for repetition penalty. 1.0 means no penalty.",
    )
    top_k: int = Field(
        50,
        ge=0,
        description="The number of highest-probability vocabulary tokens to keep for top-k filtering.",
    )
    do_sample: bool = Field(
        False,
        description="Whether or not to use sampling; use greedy decoding otherwise.",
    )

class TextGenerationOutput(BaseModel):
    generated_text: str = Field(...)

def generate_text(input: TextGenerationInput) -> TextGenerationOutput:
    """Generate text based on a given prompt."""
    res = nlp(
        input.text,
        temperature=input.temperature,
        max_length=input.max_length,
        repetition_penalty=input.repetition_penalty,
        top_k=input.top_k,
        do_sample=input.do_sample,
    )
    return TextGenerationOutput(generated_text=res[0]["generated_text"])
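

# Minimal usage sketch: builds a request with an example prompt and prints the
# generated text. It assumes the distilgpt2 weights are available (they are
# downloaded on first use); the prompt and parameter values are illustrative only.
if __name__ == "__main__":
    request = TextGenerationInput(
        text="Once upon a time",
        max_length=40,
        do_sample=True,
    )
    print(generate_text(request).generated_text)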