Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

feat: queue up requests if not running parallel requests #1296

Merged
merged 1 commit into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions api/localai/backend_monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,12 @@ func BackendMonitorEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error {
return err
}

client := bm.options.Loader.CheckIsLoaded(backendId)

if client == "" {
model := bm.options.Loader.CheckIsLoaded(backendId)
if model == "" {
return fmt.Errorf("backend %s is not currently loaded", backendId)
}

status, rpcErr := client.GRPC().Status(context.TODO())
status, rpcErr := model.GRPC(false).Status(context.TODO())
if rpcErr != nil {
log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error())
val, slbErr := bm.SampleLocalBackendProcess(backendId)
Expand Down
51 changes: 47 additions & 4 deletions pkg/grpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@ import (
)

type Client struct {
address string
busy bool
address string
busy bool
parallel bool
sync.Mutex
opMutex sync.Mutex
}

func NewClient(address string) *Client {
func NewClient(address string, parallel bool) *Client {
return &Client{
address: address,
address: address,
parallel: parallel,
}
}

Expand All @@ -38,6 +41,10 @@ func (c *Client) setBusy(v bool) {
}

func (c *Client) HealthCheck(ctx context.Context) bool {
if !c.parallel {
c.opMutex.Lock()
defer c.opMutex.Unlock()
}
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
Expand Down Expand Up @@ -66,6 +73,10 @@ func (c *Client) HealthCheck(ctx context.Context) bool {
}

func (c *Client) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error) {
if !c.parallel {
c.opMutex.Lock()
defer c.opMutex.Unlock()
}
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
Expand All @@ -79,6 +90,10 @@ func (c *Client) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...
}

func (c *Client) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error) {
if !c.parallel {
c.opMutex.Lock()
defer c.opMutex.Unlock()
}
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
Expand All @@ -92,6 +107,10 @@ func (c *Client) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grp
}

func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error) {
if !c.parallel {
c.opMutex.Lock()
defer c.opMutex.Unlock()
}
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
Expand All @@ -104,6 +123,10 @@ func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grp
}

func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error {
if !c.parallel {
c.opMutex.Lock()
defer c.opMutex.Unlock()
}
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
Expand Down Expand Up @@ -135,6 +158,10 @@ func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f fun
}

func (c *Client) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error) {
if !c.parallel {
c.opMutex.Lock()
defer c.opMutex.Unlock()
}
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
Expand All @@ -147,6 +174,10 @@ func (c *Client) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest,
}

func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) {
if !c.parallel {
c.opMutex.Lock()
defer c.opMutex.Unlock()
}
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
Expand All @@ -159,6 +190,10 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp
}

func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) {
if !c.parallel {
c.opMutex.Lock()
defer c.opMutex.Unlock()
}
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
Expand Down Expand Up @@ -191,6 +226,10 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
}

func (c *Client) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) {
if !c.parallel {
c.opMutex.Lock()
defer c.opMutex.Unlock()
}
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
Expand All @@ -209,6 +248,10 @@ func (c *Client) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts
}

func (c *Client) Status(ctx context.Context) (*pb.StatusResponse, error) {
if !c.parallel {
c.opMutex.Lock()
defer c.opMutex.Unlock()
}
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
Expand Down
8 changes: 4 additions & 4 deletions pkg/model/initializers.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
// Wait for the service to start up
ready := false
for i := 0; i < o.grpcAttempts; i++ {
if client.GRPC().HealthCheck(context.Background()) {
if client.GRPC(o.parallelRequests).HealthCheck(context.Background()) {
log.Debug().Msgf("GRPC Service Ready")
ready = true
break
Expand All @@ -140,7 +140,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string

log.Debug().Msgf("GRPC: Loading model with options: %+v", options)

res, err := client.GRPC().LoadModel(o.context, &options)
res, err := client.GRPC(o.parallelRequests).LoadModel(o.context, &options)
if err != nil {
return "", fmt.Errorf("could not load model: %w", err)
}
Expand All @@ -154,11 +154,11 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string

func (ml *ModelLoader) resolveAddress(addr ModelAddress, parallel bool) (*grpc.Client, error) {
if parallel {
return addr.GRPC(), nil
return addr.GRPC(parallel), nil
}

if _, ok := ml.grpcClients[string(addr)]; !ok {
ml.grpcClients[string(addr)] = addr.GRPC()
ml.grpcClients[string(addr)] = addr.GRPC(parallel)
}
return ml.grpcClients[string(addr)], nil
}
Expand Down
12 changes: 9 additions & 3 deletions pkg/model/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ type ModelLoader struct {

type ModelAddress string

func (m ModelAddress) GRPC() *grpc.Client {
return grpc.NewClient(string(m))
func (m ModelAddress) GRPC(parallel bool) *grpc.Client {
return grpc.NewClient(string(m), parallel)
}

func NewModelLoader(modelPath string) *ModelLoader {
Expand Down Expand Up @@ -147,10 +147,16 @@ func (ml *ModelLoader) ShutdownModel(modelName string) error {
}

func (ml *ModelLoader) CheckIsLoaded(s string) ModelAddress {
var client *grpc.Client
if m, ok := ml.models[s]; ok {
log.Debug().Msgf("Model already loaded in memory: %s", s)
if c, ok := ml.grpcClients[s]; ok {
client = c
} else {
client = m.GRPC(false)
}

if !m.GRPC().HealthCheck(context.Background()) {
if !client.HealthCheck(context.Background()) {
log.Debug().Msgf("GRPC Model not responding: %s", s)
if !ml.grpcProcesses[s].IsAlive() {
log.Debug().Msgf("GRPC Process is not responding: %s", s)
Expand Down
2 changes: 1 addition & 1 deletion pkg/model/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
func (ml *ModelLoader) StopAllExcept(s string) {
ml.StopGRPC(func(id string, p *process.Process) bool {
if id != s {
for ml.models[id].GRPC().IsBusy() {
for ml.models[id].GRPC(false).IsBusy() {
log.Debug().Msgf("%s busy. Waiting.", id)
time.Sleep(2 * time.Second)
}
Expand Down