-
Notifications
You must be signed in to change notification settings - Fork 4
/
sessions.go
120 lines (100 loc) · 2.61 KB
/
sessions.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
package main
import (
"log"
"sync"
"time"
ai "github.com/sashabaranov/go-openai"
)
// Session is one conversation: an ordered list of chat messages that can be
// read, appended to, and reset.
type Session interface {
	// Clear resets the conversation.
	Clear()
	// AddMessage appends msg to the conversation.
	AddMessage(msg ai.ChatCompletionMessage)
	// GetHistory returns the messages of the conversation in order.
	GetHistory() []ai.ChatCompletionMessage
}
// SessionStore holds Sessions keyed by id.
type SessionStore interface {
	// Get returns the session stored under id (the sync.Map-backed
	// implementation creates one if none exists).
	Get(id string) Session
	// Range calls f for each stored key/value pair, stopping early if f
	// returns false.
	Range(f func(key, value interface{}) bool)
	// Expire removes sessions that have been idle for too long.
	Expire()
}
// SyncMapSessionStore is a SessionStore backed by an embedded sync.Map.
// Get stores entries with the session id (string) as key and a
// *LocalSession as value; Expire relies on that value type.
type SyncMapSessionStore struct {
sync.Map
// config provides Session.TTL for Expire and is passed to every
// LocalSession created by Get.
config *Configuration
}
// LocalSession is an in-memory Session holding one conversation.
type LocalSession struct {
// history is the conversation; Clear seeds index 0 with the system prompt.
history []ai.ChatCompletionMessage
// config provides Bot.Prompt (used by Clear) and Session.MaxHistory
// (used by trimHistory).
config *Configuration
// last is the time of creation, the last Clear, or the last AddMessage;
// Expire compares it against Session.TTL.
last time.Time
// mu is taken by GetHistory (read), AddMessage and Clear (write) around
// accesses to history and last.
mu sync.RWMutex
// name is the id the session was created under in SyncMapSessionStore.Get.
name string
}
// Compile-time proof that the concrete types satisfy their interfaces.
var (
	_ SessionStore = (*SyncMapSessionStore)(nil)
	_ Session      = (*LocalSession)(nil)
)
// NewSessionStore returns a sync.Map-backed SessionStore and, when
// config.Session.TTL is positive, starts a background goroutine that calls
// Expire once per TTL.
//
// The sweep goroutine runs for the lifetime of the process; there is no way
// to stop it. A non-positive TTL disables expiry entirely. Sleeping for a
// zero or negative duration returns immediately, which would otherwise turn
// the sweep into a busy loop calling Expire continuously.
func NewSessionStore(config *Configuration) SessionStore {
	log.Printf("sessionstore: %s", "syncmap")
	store := &SyncMapSessionStore{
		config: config,
	}

	ttl := config.Session.TTL
	if ttl <= 0 {
		log.Printf("sessionstore: non-positive session TTL %v, expiry disabled", ttl)
		return store
	}

	// start expiry goroutine
	go func() {
		ticker := time.NewTicker(ttl)
		for range ticker.C {
			store.Expire()
		}
	}()
	return store
}
// Expire deletes every stored session whose last activity is more than
// config.Session.TTL in the past.
//
// Values are asserted to *LocalSession, which is the only type Get stores.
func (sessions *SyncMapSessionStore) Expire() {
	ttl := sessions.config.Session.TTL
	sessions.Range(func(key, value interface{}) bool {
		session := value.(*LocalSession)

		// last is written by AddMessage and Clear while holding session.mu,
		// so it must be read under the same lock to avoid a data race.
		session.mu.RLock()
		idle := time.Since(session.last)
		session.mu.RUnlock()

		if idle > ttl {
			log.Printf("syncmapsessionstore: %s expired after %f seconds", key, ttl.Seconds())
			sessions.Delete(key)
		}
		return true
	})
}
// Get returns the session stored under id. If none exists, it creates one
// seeded with the system prompt (via Clear) and stores it.
//
// Concurrent first calls for the same id all receive the same session.
func (sessions *SyncMapSessionStore) Get(id string) Session {
	// Fast path: avoid building a new session when one already exists.
	if value, ok := sessions.Load(id); ok {
		return value.(*LocalSession)
	}
	session := &LocalSession{
		name:   id,
		last:   time.Now(),
		config: sessions.config,
	}
	session.Clear()
	// Another goroutine may have created a session for id since the Load
	// above. LoadOrStore keeps the first one stored instead of overwriting
	// it (a plain Store would silently discard that session's history).
	actual, _ := sessions.LoadOrStore(id, session)
	return actual.(*LocalSession)
}
// GetHistory returns a snapshot of the conversation. The result is a fresh
// (shallow) copy of the session's slice, so callers can append to it or
// overwrite its elements without affecting the session.
func (s *LocalSession) GetHistory() []ai.ChatCompletionMessage {
	s.mu.RLock()
	defer s.mu.RUnlock()
	snapshot := make([]ai.ChatCompletionMessage, 0, len(s.history))
	return append(snapshot, s.history...)
}
// AddMessage records msg at the end of the conversation, marks the session
// as active now, and trims the history to the configured maximum.
func (s *LocalSession) AddMessage(msg ai.ChatCompletionMessage) {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.last = time.Now()
	s.history = append(s.history, msg)
	// trimHistory expects the write lock to be held, as it is here.
	s.trimHistory()
}
// trimHistory caps the conversation at the system prompt (index 0) plus the
// newest config.Session.MaxHistory messages. The caller must hold s.mu for
// writing.
//
// A negative MaxHistory is treated as 0, meaning only the system prompt is
// kept. Without this clamp, a zero or negative value would index past the
// end of the slice and panic.
//
// After trimming, every tool message left directly after the system prompt
// is dropped. The API rejects "messages with role 'tool' [that are not] a
// response to a preceding message with 'tool_calls'", and one assistant
// message with several tool calls is followed by several tool messages, so
// all of them must go, not just the first.
func (s *LocalSession) trimHistory() {
	limit := s.config.Session.MaxHistory
	if limit < 0 {
		limit = 0
	}
	if len(s.history) <= limit {
		return
	}
	// Keep the system message and the newest limit messages after it.
	s.history = append(s.history[:1], s.history[len(s.history)-limit:]...)
	for len(s.history) > 1 && s.history[1].Role == ai.ChatMessageRoleTool {
		s.history = append(s.history[:1], s.history[2:]...)
	}
}
// Clear resets the conversation so that it holds only the system prompt from
// config.Bot.Prompt, and marks the session as active now.
func (s *LocalSession) Clear() {
	s.mu.Lock()
	defer s.mu.Unlock()

	system := ai.ChatCompletionMessage{
		Role:    ai.ChatMessageRoleSystem,
		Content: s.config.Bot.Prompt,
	}
	// Truncate in place (reusing the backing array) and re-seed with the prompt.
	s.history = append(s.history[:0], system)
	s.last = time.Now()
}