-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathopenai.go
155 lines (129 loc) · 3.46 KB
/
openai.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
package main
import (
"context"
"errors"
gogpt "github.com/sashabaranov/go-gpt3"
"log"
"os"
"regexp"
"strconv"
)
var (
	// MAX_COMPLETION_RETRIES is how many additional completion attempts are
	// made when the generated text is labelled unsafe by the content filter,
	// before giving up on the prompt.
	MAX_COMPLETION_RETRIES int = 5

	// DEFAULT_RESPONSE is the text used in place of a completion once
	// MAX_COMPLETION_RETRIES has been exhausted.
	DEFAULT_RESPONSE string = "*Yaaaawn*... eh, I dont really feel like it"
)
// CompletionRequest describes one completion job submitted to the
// completion backend.
type CompletionRequest struct {
	// Prompt is the text sent to the model; FilterRegex marks where the
	// generated text should be cut (everything from its first match onward
	// is dropped).
	Prompt, FilterRegex string
	// ResponseChan is where the backend delivers the CompletionResponse for
	// this request.
	ResponseChan chan CompletionResponse
	// Model selects the OpenAI model; it must satisfy Model.IsValid().
	Model ModelEnum
	// Temperature is the sampling temperature forwarded to the API.
	Temperature float32
	// Tokens is the maximum number of tokens to generate.
	Tokens int
}
// CompletionResponse is the result handed back to whoever submitted a
// CompletionRequest.
type CompletionResponse struct {
	// Response holds the generated (and filtered) text when Err is nil.
	Response string
	// Err is non-nil when the request could not be completed.
	Err error
}
// ModelEnum names an OpenAI model. It wraps a pointer to a model-name
// string, so two ModelEnum values compare equal only when they point at the
// same string, not merely at equal text.
type ModelEnum struct{ *string }

// String returns the model name, or "<void>" for the zero ModelEnum.
func (e ModelEnum) String() string {
	if name := e.string; name != nil {
		return *name
	}
	return "<void>"
}
// IsValid reports whether e is one of the predeclared models
// (Ada, Babbage, Curie, Davinci, DavinciInstruct, CurieInstruct).
func (e ModelEnum) IsValid() bool {
	switch e {
	case Ada, Babbage, Curie, Davinci, DavinciInstruct, CurieInstruct:
		return true
	default:
		return false
	}
}
// es is the backing table of OpenAI model names. Go has no native enum
// type, so each ModelEnum value below points at one slot of this slice;
// equality between ModelEnums is therefore identity of the slot.
var es = []string{
	"ada",
	"babbage",
	"curie",
	"davinci",
	"davinci-instruct-beta",
	"curie-instruct-beta",
}

// The supported models.
var (
	Ada             = ModelEnum{&es[0]}
	Babbage         = ModelEnum{&es[1]}
	Curie           = ModelEnum{&es[2]}
	Davinci         = ModelEnum{&es[3]}
	DavinciInstruct = ModelEnum{&es[4]}
	CurieInstruct   = ModelEnum{&es[5]}
)
// runCompletions is the completion backend loop. It reads CompletionRequests
// from buffer until buffer is closed. For each request it sends exactly one
// CompletionResponse on request.ResponseChan and then closes that channel.
//
// Failures (invalid model, API errors, content-filter errors) are reported to
// the requester through CompletionResponse.Err instead of ending the loop.
// A single failing request therefore cannot leave its caller blocked on a
// channel that is never written or closed, nor shut the backend down for
// every later request.
//
// The OpenAI API key is read from the OPENAI_API_KEY environment variable.
func runCompletions(buffer chan CompletionRequest) {
	c := gogpt.NewClient(os.Getenv("OPENAI_API_KEY"))
	ctx := context.Background()
	for request := range buffer {
		request.ResponseChan <- handleCompletionRequest(ctx, c, request)
		close(request.ResponseChan)
	}
	// The range loop only terminates once buffer has been closed.
	log.Printf("Request buffer closed. Closing completion backend")
}

// handleCompletionRequest builds the response for a single request.
//
// It validates the model and generates a completion. While the content filter
// labels the text unsafe, it regenerates, for up to MAX_COMPLETION_RETRIES
// extra attempts. If every attempt is unsafe, DEFAULT_RESPONSE is used
// instead. The chosen text is then trimmed with filterResponse using
// request.FilterRegex.
func handleCompletionRequest(ctx context.Context, c *gogpt.Client, request CompletionRequest) CompletionResponse {
	if !request.Model.IsValid() {
		return CompletionResponse{
			Response: "",
			Err: errors.New("Requested invalid model: " +
				request.Model.String()),
		}
	}
	req := gogpt.CompletionRequest{
		MaxTokens:   request.Tokens,
		Prompt:      request.Prompt,
		Temperature: request.Temperature,
	}
	// One initial attempt, then up to MAX_COMPLETION_RETRIES retries.
	for retries := 0; ; retries++ {
		resp, err := c.CreateCompletion(ctx, request.Model.String(), req)
		if err != nil {
			return CompletionResponse{Err: err}
		}
		if len(resp.Choices) == 0 {
			return CompletionResponse{Err: errors.New("completion returned no choices")}
		}
		text := resp.Choices[0].Text
		sensitivity, err := checkSensitivity(text, ctx, c)
		if err != nil {
			return CompletionResponse{Err: err}
		}
		// Safe is 0, sensitive is 1, unsafe is 2.
		if sensitivity < 2 {
			return CompletionResponse{
				Response: filterResponse(text, request.FilterRegex),
				Err:      nil,
			}
		}
		if retries >= MAX_COMPLETION_RETRIES {
			break
		}
	}
	log.Printf("Max retries reached for prompt: %v", req.Prompt)
	return CompletionResponse{
		Response: filterResponse(DEFAULT_RESPONSE, request.FilterRegex),
		Err:      nil,
	}
}
// filterResponse cuts text at the first match of regex and returns everything
// before it. This drops continuations where the model starts writing the
// user's next turn itself.
//
// An empty regex means "no filtering". Without this check the empty pattern
// would match at index 0 and discard the whole response. A pattern that fails
// to compile is logged and the text is returned unchanged. regex comes from
// the request, so panicking via MustCompile would take down the caller.
func filterResponse(text string, regex string) string {
	if regex == "" {
		return text
	}
	re, err := regexp.Compile(regex)
	if err != nil {
		log.Printf("Invalid filter regex %q: %v", regex, err)
		return text
	}
	if loc := re.FindStringIndex(text); loc != nil {
		return text[:loc[0]]
	}
	return text
}
// checkSensitivity classifies text with the "content-filter-alpha-c4" model
// and returns its label: 0 (safe), 1 (sensitive) or 2 (unsafe).
//
// The filter prompt is "<|endoftext|>" + text + "\n--\nLabel:". It is sampled
// deterministically (temperature 0, top_p 0) for a single token.
//
// API failures, an empty choice list, or a label that does not parse as an
// integer are returned as errors rather than handled here.
func checkSensitivity(text string, ctx context.Context, c *gogpt.Client) (int, error) {
	req := gogpt.CompletionRequest{
		MaxTokens:   1,
		Prompt:      "<|endoftext|>" + text + "\n--\nLabel:",
		Temperature: 0.0,
		TopP:        0,
	}
	resp, err := c.CreateCompletion(ctx, "content-filter-alpha-c4", req)
	if err != nil {
		return 0, err
	}
	if len(resp.Choices) == 0 {
		return 0, errors.New("content filter returned no choices")
	}
	sensitivity, err := strconv.Atoi(resp.Choices[0].Text)
	if err != nil {
		return 0, err
	}
	return sensitivity, nil
}