Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(grpc): backend SPI pluggable in embedding mode #1621

Merged
merged 7 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api/backend/embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.

var fn func() ([]float32, error)
switch model := inferenceModel.(type) {
case *grpc.Client:
case grpc.Backend:
fn = func() ([]float32, error) {
predictOptions := gRPCPredictOpts(c, loader.ModelPath)
if len(tokens) > 0 {
Expand Down
2 changes: 1 addition & 1 deletion api/backend/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func ModelInference(ctx context.Context, s string, images []string, loader *mode

grpcOpts := gRPCModelOpts(c)

var inferenceModel *grpc.Client
var inferenceModel grpc.Backend
var err error

opts := modelOpts(c, o, []model.Option{
Expand Down
46 changes: 46 additions & 0 deletions pkg/grpc/backend.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package grpc

import (
"context"
"github.com/go-skynet/LocalAI/api/schema"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"google.golang.org/grpc"
)

// embeds maps a backend address to the in-process backend that Provide
// registered for it. NewClient looks addresses up here before falling back to
// a gRPC client. Provide and NewClient access the map without synchronization.
var embeds = make(map[string]*embedBackend)

// Provide registers llm as the in-process backend for addr by wrapping it in
// a server and storing the resulting embedBackend in embeds. Registering the
// same addr again replaces the earlier entry.
func Provide(addr string, llm LLM) {
	srv := &server{llm: llm}
	embeds[addr] = &embedBackend{s: srv}
}

// NewClient returns the Backend to use for address. If an in-process backend
// is registered in embeds for that address it is returned as-is (parallel, wd
// and enableWatchDog are then unused); otherwise a gRPC client is built with
// NewGrpcClient.
func NewClient(address string, parallel bool, wd WatchDog, enableWatchDog bool) Backend {
	embedded, registered := embeds[address]
	if !registered {
		return NewGrpcClient(address, parallel, wd, enableWatchDog)
	}
	return embedded
}

// NewGrpcClient builds a *Client for address and returns it as a Backend.
// The watchdog wd is attached only when enableWatchDog is true; otherwise the
// client's wd field is left nil.
func NewGrpcClient(address string, parallel bool, wd WatchDog, enableWatchDog bool) Backend {
	client := &Client{
		address:  address,
		parallel: parallel,
	}
	if enableWatchDog {
		client.wd = wd
	}
	return client
}

// Backend is the set of operations callers invoke on an inference backend.
// NewGrpcClient returns a *Client as a Backend, and embedBackend is asserted
// to implement it too, so callers can hold either one behind this interface.
//
// The opts parameters are gRPC call options; the in-process embedBackend
// implementation accepts them but does not use them.
type Backend interface {
// IsBusy reports whether the backend is busy.
IsBusy() bool
// HealthCheck reports whether the backend is healthy.
HealthCheck(ctx context.Context) (bool, error)
// Embeddings, Predict and PredictStream take prediction options;
// PredictStream hands each streamed chunk to f as raw bytes.
Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error)
Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error)
LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error)
PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error
GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error)
TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error)
// AudioTranscription returns a schema.Result rather than a protobuf message.
AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error)
TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error)
Status(ctx context.Context) (*pb.StatusResponse, error)
}
11 changes: 0 additions & 11 deletions pkg/grpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,6 @@ type WatchDog interface {
UnMark(address string)
}

// NewClient builds a Client for address. The watchdog wd is attached only
// when enableWatchDog is true; otherwise the client's wd field is left nil.
func NewClient(address string, parallel bool, wd WatchDog, enableWatchDog bool) *Client {
	c := &Client{
		address:  address,
		parallel: parallel,
	}
	if enableWatchDog {
		c.wd = wd
	}
	return c
}

func (c *Client) IsBusy() bool {
c.Lock()
defer c.Unlock()
Expand Down
121 changes: 121 additions & 0 deletions pkg/grpc/embed.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package grpc

import (
"context"
"github.com/go-skynet/LocalAI/api/schema"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"time"
)

// Compile-time checks: embedBackend must satisfy Backend, and
// embedBackendServerStream must satisfy the generated PredictStream server
// stream interface.
var _ Backend = (*embedBackend)(nil)
var _ pb.Backend_PredictStreamServer = (*embedBackendServerStream)(nil)

// embedBackend implements Backend by calling the wrapped server's handler
// methods directly in-process instead of going through a gRPC connection.
type embedBackend struct {
// s is the server whose handlers (and llm) the embedBackend methods call.
s *server
}

// IsBusy returns the Busy state reported by the wrapped server's LLM.
func (e *embedBackend) IsBusy() bool {
	llm := e.s.llm
	return llm.Busy()
}

// HealthCheck always reports healthy: it returns true and a nil error without
// consulting ctx or the wrapped server.
func (e *embedBackend) HealthCheck(ctx context.Context) (bool, error) {
return true, nil
}

// Embeddings forwards the request to the wrapped server's Embedding handler
// and returns its result unchanged. opts is not used.
func (e *embedBackend) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error) {
	result, err := e.s.Embedding(ctx, in)
	return result, err
}

// Predict forwards the request to the wrapped server's Predict handler and
// returns its reply unchanged. opts is not used.
func (e *embedBackend) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error) {
	reply, err := e.s.Predict(ctx, in)
	return reply, err
}

// LoadModel forwards the model options to the wrapped server's LoadModel
// handler and returns its result unchanged. opts is not used.
func (e *embedBackend) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error) {
	result, err := e.s.LoadModel(ctx, in)
	return result, err
}

// PredictStream runs the wrapped server's PredictStream handler in-process.
// The handler is given an embedBackendServerStream built from ctx and f, so
// each reply it sends is delivered to f as message bytes. opts is not used.
func (e *embedBackend) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error {
	return e.s.PredictStream(in, &embedBackendServerStream{ctx: ctx, fn: f})
}

// GenerateImage forwards the request to the wrapped server's GenerateImage
// handler and returns its result unchanged. opts is not used.
func (e *embedBackend) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error) {
	result, err := e.s.GenerateImage(ctx, in)
	return result, err
}

// TTS forwards the request to the wrapped server's TTS handler and returns
// its result unchanged. opts is not used.
func (e *embedBackend) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) {
	result, err := e.s.TTS(ctx, in)
	return result, err
}

// AudioTranscription calls the wrapped server's AudioTranscription handler
// and converts its protobuf result into a schema.Result.
//
// Each protobuf segment becomes a schema.Segment: Id and the token values are
// widened to int, and Start/End are converted to time.Duration from their
// integer values. A handler error is returned as-is with a nil result.
// opts is not used.
//
// Fields are read through the generated getters, which are nil-safe, so a nil
// result or nil segment from the handler yields zero values instead of a
// panic. Segments and Tokens stay nil (not empty) when the source has none,
// matching what repeated append on a nil slice would produce.
func (e *embedBackend) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) {
	r, err := e.s.AudioTranscription(ctx, in)
	if err != nil {
		return nil, err
	}

	tr := &schema.Result{Text: r.GetText()}
	if segs := r.GetSegments(); len(segs) > 0 {
		// The final length is known, so size the slice once.
		tr.Segments = make([]schema.Segment, 0, len(segs))
		for _, s := range segs {
			var tks []int
			if src := s.GetTokens(); len(src) > 0 {
				tks = make([]int, len(src))
				for i, t := range src {
					tks[i] = int(t)
				}
			}
			tr.Segments = append(tr.Segments, schema.Segment{
				Text:   s.GetText(),
				Id:     int(s.GetId()),
				Start:  time.Duration(s.GetStart()),
				End:    time.Duration(s.GetEnd()),
				Tokens: tks,
			})
		}
	}
	return tr, nil
}

// TokenizeString forwards the request to the wrapped server's TokenizeString
// handler and returns its response unchanged. opts is not used.
func (e *embedBackend) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) {
	resp, err := e.s.TokenizeString(ctx, in)
	return resp, err
}

// Status calls the wrapped server's Status handler with an empty
// HealthMessage and returns its response unchanged.
func (e *embedBackend) Status(ctx context.Context) (*pb.StatusResponse, error) {
	req := new(pb.HealthMessage)
	return e.s.Status(ctx, req)
}

// embedBackendServerStream is an in-process stand-in for the gRPC
// PredictStream server stream: replies sent on it are handed to a callback
// instead of being written to a network stream.
type embedBackendServerStream struct {
// ctx is the value returned by Context.
ctx context.Context
// fn receives the message bytes of every reply passed to Send.
fn func(s []byte)
}

// Send passes the reply's message bytes to the stream's callback and always
// returns nil.
func (e *embedBackendServerStream) Send(reply *pb.Reply) error {
	msg := reply.GetMessage()
	e.fn(msg)
	return nil
}

// SetHeader is a no-op: md is discarded and nil is returned.
func (e *embedBackendServerStream) SetHeader(md metadata.MD) error {
return nil
}

// SendHeader is a no-op: md is discarded and nil is returned.
func (e *embedBackendServerStream) SendHeader(md metadata.MD) error {
return nil
}

// SetTrailer is a no-op: md is discarded.
func (e *embedBackendServerStream) SetTrailer(md metadata.MD) {
}

// Context returns the context stored in the stream's ctx field.
func (e *embedBackendServerStream) Context() context.Context {
return e.ctx
}

// SendMsg delivers m through Send when it is a *pb.Reply. Any other message
// type is ignored and nil is returned.
func (e *embedBackendServerStream) SendMsg(m any) error {
	reply, isReply := m.(*pb.Reply)
	if !isReply {
		return nil
	}
	return e.Send(reply)
}

// RecvMsg is a no-op: m is left untouched and nil is returned.
func (e *embedBackendServerStream) RecvMsg(m any) error {
return nil
}
20 changes: 20 additions & 0 deletions pkg/grpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,23 @@ func StartServer(address string, model LLM) error {

return nil
}

// RunServer starts a gRPC backend server for model listening on the TCP
// address and returns a function that stops it.
//
// The server is served on a background goroutine, so RunServer returns as
// soon as the listener is open. (Calling Serve inline would block until the
// server had already stopped, making the returned stop function useless.)
//
// If the listener cannot be opened, RunServer returns a nil stop function and
// the listen error. Otherwise the returned function gracefully stops the
// server, waits for Serve to return, and reports Serve's error, if any. It is
// safe to call more than once; calls after the first return nil.
func RunServer(address string, model LLM) (func() error, error) {
	lis, err := net.Listen("tcp", address)
	if err != nil {
		return nil, err
	}

	s := grpc.NewServer()
	pb.RegisterBackendServer(s, &server{llm: model})
	log.Printf("gRPC Server listening at %v", lis.Addr())

	// Buffered so the goroutine can exit even if stop is never called.
	serveErr := make(chan error, 1)
	go func() {
		serveErr <- s.Serve(lis)
		close(serveErr)
	}()

	return func() error {
		s.GracefulStop()
		// The first call receives Serve's result; later calls read the zero
		// value (nil) from the closed channel instead of blocking.
		return <-serveErr
	}, nil
}
6 changes: 3 additions & 3 deletions pkg/model/initializers.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
}
}

func (ml *ModelLoader) resolveAddress(addr ModelAddress, parallel bool) (*grpc.Client, error) {
func (ml *ModelLoader) resolveAddress(addr ModelAddress, parallel bool) (grpc.Backend, error) {
if parallel {
return addr.GRPC(parallel, ml.wd), nil
}
Expand All @@ -177,7 +177,7 @@ func (ml *ModelLoader) resolveAddress(addr ModelAddress, parallel bool) (*grpc.C
return ml.grpcClients[string(addr)], nil
}

func (ml *ModelLoader) BackendLoader(opts ...Option) (client *grpc.Client, err error) {
func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err error) {
o := NewOptions(opts...)

if o.model != "" {
Expand Down Expand Up @@ -220,7 +220,7 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (client *grpc.Client, err e
return ml.resolveAddress(addr, o.parallelRequests)
}

func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) {
func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) {
o := NewOptions(opts...)

ml.mu.Lock()
Expand Down
8 changes: 4 additions & 4 deletions pkg/model/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ type ModelLoader struct {
ModelPath string
mu sync.Mutex
// TODO: this needs generics
grpcClients map[string]*grpc.Client
grpcClients map[string]grpc.Backend
models map[string]ModelAddress
grpcProcesses map[string]*process.Process
templates map[TemplateType]map[string]*template.Template
Expand All @@ -68,7 +68,7 @@ type ModelLoader struct {

type ModelAddress string

func (m ModelAddress) GRPC(parallel bool, wd *WatchDog) *grpc.Client {
func (m ModelAddress) GRPC(parallel bool, wd *WatchDog) grpc.Backend {
enableWD := false
if wd != nil {
enableWD = true
Expand All @@ -79,7 +79,7 @@ func (m ModelAddress) GRPC(parallel bool, wd *WatchDog) *grpc.Client {
func NewModelLoader(modelPath string) *ModelLoader {
nml := &ModelLoader{
ModelPath: modelPath,
grpcClients: make(map[string]*grpc.Client),
grpcClients: make(map[string]grpc.Backend),
models: make(map[string]ModelAddress),
templates: make(map[TemplateType]map[string]*template.Template),
grpcProcesses: make(map[string]*process.Process),
Expand Down Expand Up @@ -163,7 +163,7 @@ func (ml *ModelLoader) StopModel(modelName string) error {
}

func (ml *ModelLoader) CheckIsLoaded(s string) ModelAddress {
var client *grpc.Client
var client grpc.Backend
if m, ok := ml.models[s]; ok {
log.Debug().Msgf("Model already loaded in memory: %s", s)
if c, ok := ml.grpcClients[s]; ok {
Expand Down