From bae9939da0c54adbf6612038d5dbf91a985755a9 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Thu, 7 Nov 2024 10:11:39 +0100 Subject: [PATCH 1/2] chore: simplify passing options to ModelOptions Signed-off-by: Ettore Di Giacinto --- core/backend/embeddings.go | 2 +- core/backend/image.go | 2 +- core/backend/llm.go | 2 +- core/backend/options.go | 2 +- core/backend/rerank.go | 2 +- core/backend/soundgeneration.go | 2 +- core/backend/token_metrics.go | 4 +--- core/backend/tokenize.go | 4 +--- core/backend/transcript.go | 2 +- core/backend/tts.go | 6 +----- core/startup/startup.go | 2 +- 11 files changed, 11 insertions(+), 19 deletions(-) diff --git a/core/backend/embeddings.go b/core/backend/embeddings.go index 264d947b9913..5bf8eff90781 100644 --- a/core/backend/embeddings.go +++ b/core/backend/embeddings.go @@ -14,7 +14,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendCo var inferenceModel interface{} var err error - opts := ModelOptions(backendConfig, appConfig, []model.Option{}) + opts := ModelOptions(backendConfig, appConfig) if backendConfig.Backend == "" { inferenceModel, err = loader.GreedyLoader(opts...) diff --git a/core/backend/image.go b/core/backend/image.go index 72c0007c5842..d21feb39eac0 100644 --- a/core/backend/image.go +++ b/core/backend/image.go @@ -9,7 +9,7 @@ import ( func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) { - opts := ModelOptions(backendConfig, appConfig, []model.Option{}) + opts := ModelOptions(backendConfig, appConfig) inferenceModel, err := loader.BackendLoader( opts..., diff --git a/core/backend/llm.go b/core/backend/llm.go index 199a62338c84..26c1e4a0a2c0 100644 --- a/core/backend/llm.go +++ b/core/backend/llm.go @@ -38,7 +38,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im var inferenceModel grpc.Backend var err error - opts := ModelOptions(c, o, []model.Option{}) + opts := ModelOptions(c, o) if c.Backend != "" { opts = append(opts, model.WithBackendString(c.Backend)) diff --git a/core/backend/options.go b/core/backend/options.go index 6586eccf13fd..c65912222a58 100644 --- a/core/backend/options.go +++ b/core/backend/options.go @@ -11,7 +11,7 @@ import ( "github.com/rs/zerolog/log" ) -func ModelOptions(c config.BackendConfig, so *config.ApplicationConfig, opts []model.Option) []model.Option { +func ModelOptions(c config.BackendConfig, so *config.ApplicationConfig, opts ...model.Option) []model.Option { name := c.Name if name == "" { name = c.Model diff --git a/core/backend/rerank.go b/core/backend/rerank.go index f600e2e6eaff..fae97a81dbf6 100644 --- a/core/backend/rerank.go +++ b/core/backend/rerank.go @@ -11,7 +11,7 @@ import ( func Rerank(modelFile string, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) { - opts := ModelOptions(backendConfig, appConfig, []model.Option{model.WithModel(modelFile)}) + opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile)) rerankModel, err := loader.BackendLoader(opts...) if err != nil { return nil, err diff --git a/core/backend/soundgeneration.go b/core/backend/soundgeneration.go index b1b458b447ab..f79d271581ed 100644 --- a/core/backend/soundgeneration.go +++ b/core/backend/soundgeneration.go @@ -25,7 +25,7 @@ func SoundGeneration( backendConfig config.BackendConfig, ) (string, *proto.Result, error) { - opts := ModelOptions(backendConfig, appConfig, []model.Option{model.WithModel(modelFile)}) + opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile)) soundGenModel, err := loader.BackendLoader(opts...) if err != nil { diff --git a/core/backend/token_metrics.go b/core/backend/token_metrics.go index acd256634a0a..19c30e29b9b0 100644 --- a/core/backend/token_metrics.go +++ b/core/backend/token_metrics.go @@ -15,9 +15,7 @@ func TokenMetrics( appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.MetricsResponse, error) { - opts := ModelOptions(backendConfig, appConfig, []model.Option{ - model.WithModel(modelFile), - }) + opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile)) model, err := loader.BackendLoader(opts...) if err != nil { return nil, err diff --git a/core/backend/tokenize.go b/core/backend/tokenize.go index c8ec8d1cb260..ac63d85a55c4 100644 --- a/core/backend/tokenize.go +++ b/core/backend/tokenize.go @@ -14,9 +14,7 @@ func ModelTokenize(s string, loader *model.ModelLoader, backendConfig config.Bac var inferenceModel grpc.Backend var err error - opts := ModelOptions(backendConfig, appConfig, []model.Option{ - model.WithModel(modelFile), - }) + opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile)) if backendConfig.Backend == "" { inferenceModel, err = loader.GreedyLoader(opts...) diff --git a/core/backend/transcript.go b/core/backend/transcript.go index c6ad9b597795..8406d2ef7a1b 100644 --- a/core/backend/transcript.go +++ b/core/backend/transcript.go @@ -18,7 +18,7 @@ func ModelTranscription(audio, language string, translate bool, ml *model.ModelL backendConfig.Backend = model.WhisperBackend } - opts := ModelOptions(backendConfig, appConfig, []model.Option{}) + opts := ModelOptions(backendConfig, appConfig) transcriptionModel, err := ml.BackendLoader(opts...) if err != nil { diff --git a/core/backend/tts.go b/core/backend/tts.go index 20aa358e7257..a9f9612cfb65 100644 --- a/core/backend/tts.go +++ b/core/backend/tts.go @@ -28,11 +28,7 @@ func ModelTTS( bb = model.PiperBackend } - opts := ModelOptions(backendConfig, appConfig, []model.Option{ - model.WithBackendString(bb), - model.WithModel(modelFile), - }) - + opts := ModelOptions(backendConfig, appConfig, model.WithBackendString(bb), model.WithModel(modelFile)) ttsModel, err := loader.BackendLoader(opts...) if err != nil { return "", nil, err diff --git a/core/startup/startup.go b/core/startup/startup.go index 17e54bc0603b..941b73c3e9d5 100644 --- a/core/startup/startup.go +++ b/core/startup/startup.go @@ -160,7 +160,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode log.Debug().Msgf("Auto loading model %s into memory from file: %s", m, cfg.Model) - o := backend.ModelOptions(*cfg, options, []model.Option{}) + o := backend.ModelOptions(*cfg, options) var backendErr error if cfg.Backend != "" { From 97060ae9a5338973eedf716740aea625e4921796 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Fri, 8 Nov 2024 19:12:13 +0100 Subject: [PATCH 2/2] chore(refactor): do not expose internal backend Loader Signed-off-by: Ettore Di Giacinto --- core/backend/embeddings.go | 10 +--------- core/backend/image.go | 3 +-- core/backend/llm.go | 18 ++---------------- core/backend/rerank.go | 2 +- core/backend/soundgeneration.go | 3 +-- core/backend/stores.go | 19 +++++++++---------- core/backend/token_metrics.go | 2 +- core/backend/tokenize.go | 4 ++-- core/backend/transcript.go | 2 +- core/backend/tts.go | 2 +- core/startup/startup.go | 7 +------ pkg/model/initializers.go | 10 +++++++--- tests/integration/stores_test.go | 2 +- 13 files changed, 29 insertions(+), 55 deletions(-) diff --git a/core/backend/embeddings.go b/core/backend/embeddings.go index 5bf8eff90781..a96e9829af16 100644 --- a/core/backend/embeddings.go +++ b/core/backend/embeddings.go @@ -11,17 +11,9 @@ import ( func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) { - var inferenceModel interface{} - var err error - opts := ModelOptions(backendConfig, appConfig) - if backendConfig.Backend == "" { - inferenceModel, err = loader.GreedyLoader(opts...) - } else { - opts = append(opts, model.WithBackendString(backendConfig.Backend)) - inferenceModel, err = loader.BackendLoader(opts...) - } + inferenceModel, err := loader.Load(opts...) if err != nil { return nil, err } diff --git a/core/backend/image.go b/core/backend/image.go index d21feb39eac0..38ca43570fe8 100644 --- a/core/backend/image.go +++ b/core/backend/image.go @@ -10,8 +10,7 @@ import ( func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) { opts := ModelOptions(backendConfig, appConfig) - - inferenceModel, err := loader.BackendLoader( + inferenceModel, err := loader.Load( opts..., ) if err != nil { diff --git a/core/backend/llm.go b/core/backend/llm.go index 26c1e4a0a2c0..4491a191eeb4 100644 --- a/core/backend/llm.go +++ b/core/backend/llm.go @@ -16,7 +16,6 @@ import ( "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/gallery" - "github.com/mudler/LocalAI/pkg/grpc" "github.com/mudler/LocalAI/pkg/grpc/proto" model "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/utils" @@ -35,15 +34,6 @@ type TokenUsage struct { func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) { modelFile := c.Model - var inferenceModel grpc.Backend - var err error - - opts := ModelOptions(c, o) - - if c.Backend != "" { - opts = append(opts, model.WithBackendString(c.Backend)) - } - // Check if the modelFile exists, if it doesn't try to load it from the gallery if o.AutoloadGalleries { // experimental if _, err := os.Stat(modelFile); os.IsNotExist(err) { @@ -56,12 +46,8 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im } } - if c.Backend == "" { - inferenceModel, err = loader.GreedyLoader(opts...) - } else { - inferenceModel, err = loader.BackendLoader(opts...) - } - + opts := ModelOptions(c, o) + inferenceModel, err := loader.Load(opts...) if err != nil { return nil, err } diff --git a/core/backend/rerank.go b/core/backend/rerank.go index fae97a81dbf6..8152ef7fc357 100644 --- a/core/backend/rerank.go +++ b/core/backend/rerank.go @@ -12,7 +12,7 @@ import ( func Rerank(modelFile string, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) { opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile)) - rerankModel, err := loader.BackendLoader(opts...) + rerankModel, err := loader.Load(opts...) if err != nil { return nil, err } diff --git a/core/backend/soundgeneration.go b/core/backend/soundgeneration.go index f79d271581ed..a8d46478c7cd 100644 --- a/core/backend/soundgeneration.go +++ b/core/backend/soundgeneration.go @@ -26,8 +26,7 @@ func SoundGeneration( ) (string, *proto.Result, error) { opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile)) - - soundGenModel, err := loader.BackendLoader(opts...) + soundGenModel, err := loader.Load(opts...) if err != nil { return "", nil, err } diff --git a/core/backend/stores.go b/core/backend/stores.go index 1b514584cbeb..f5ee9166df8b 100644 --- a/core/backend/stores.go +++ b/core/backend/stores.go @@ -8,16 +8,15 @@ import ( ) func StoreBackend(sl *model.ModelLoader, appConfig *config.ApplicationConfig, storeName string) (grpc.Backend, error) { - if storeName == "" { - storeName = "default" - } + if storeName == "" { + storeName = "default" + } - sc := []model.Option{ - model.WithBackendString(model.LocalStoreBackend), - model.WithAssetDir(appConfig.AssetsDestination), - model.WithModel(storeName), - } + sc := []model.Option{ + model.WithBackendString(model.LocalStoreBackend), + model.WithAssetDir(appConfig.AssetsDestination), + model.WithModel(storeName), + } - return sl.BackendLoader(sc...) + return sl.Load(sc...) } - diff --git a/core/backend/token_metrics.go b/core/backend/token_metrics.go index 19c30e29b9b0..cc71c8681e54 100644 --- a/core/backend/token_metrics.go +++ b/core/backend/token_metrics.go @@ -16,7 +16,7 @@ func TokenMetrics( backendConfig config.BackendConfig) (*proto.MetricsResponse, error) { opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile)) - model, err := loader.BackendLoader(opts...) + model, err := loader.Load(opts...) if err != nil { return nil, err } diff --git a/core/backend/tokenize.go b/core/backend/tokenize.go index ac63d85a55c4..2f813e18736b 100644 --- a/core/backend/tokenize.go +++ b/core/backend/tokenize.go @@ -17,10 +17,10 @@ func ModelTokenize(s string, loader *model.ModelLoader, backendConfig config.Bac opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile)) if backendConfig.Backend == "" { - inferenceModel, err = loader.GreedyLoader(opts...) + inferenceModel, err = loader.Load(opts...) } else { opts = append(opts, model.WithBackendString(backendConfig.Backend)) - inferenceModel, err = loader.BackendLoader(opts...) + inferenceModel, err = loader.Load(opts...) } if err != nil { return schema.TokenizeResponse{}, err diff --git a/core/backend/transcript.go b/core/backend/transcript.go index 8406d2ef7a1b..372f6984237c 100644 --- a/core/backend/transcript.go +++ b/core/backend/transcript.go @@ -20,7 +20,7 @@ func ModelTranscription(audio, language string, translate bool, ml *model.ModelL opts := ModelOptions(backendConfig, appConfig) - transcriptionModel, err := ml.BackendLoader(opts...) + transcriptionModel, err := ml.Load(opts...) if err != nil { return nil, err } diff --git a/core/backend/tts.go b/core/backend/tts.go index a9f9612cfb65..f9be6955bcd6 100644 --- a/core/backend/tts.go +++ b/core/backend/tts.go @@ -29,7 +29,7 @@ func ModelTTS( } opts := ModelOptions(backendConfig, appConfig, model.WithBackendString(bb), model.WithModel(modelFile)) - ttsModel, err := loader.BackendLoader(opts...) + ttsModel, err := loader.Load(opts...) if err != nil { return "", nil, err } diff --git a/core/startup/startup.go b/core/startup/startup.go index 941b73c3e9d5..0eb5fa585585 100644 --- a/core/startup/startup.go +++ b/core/startup/startup.go @@ -163,12 +163,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode o := backend.ModelOptions(*cfg, options) var backendErr error - if cfg.Backend != "" { - o = append(o, model.WithBackendString(cfg.Backend)) - _, backendErr = ml.BackendLoader(o...) - } else { - _, backendErr = ml.GreedyLoader(o...) - } + _, backendErr = ml.Load(o...) if backendErr != nil { return nil, nil, nil, err } diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index 5723e3e41db2..a5bedf79a7a6 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -455,7 +455,7 @@ func (ml *ModelLoader) ListAvailableBackends(assetdir string) ([]string, error) return orderBackends(backends) } -func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err error) { +func (ml *ModelLoader) backendLoader(opts ...Option) (client grpc.Backend, err error) { o := NewOptions(opts...) log.Info().Msgf("Loading model '%s' with backend %s", o.modelID, o.backendString) @@ -500,7 +500,7 @@ func (ml *ModelLoader) stopActiveBackends(modelID string, singleActiveBackend bo } } -func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) { +func (ml *ModelLoader) Load(opts ...Option) (grpc.Backend, error) { o := NewOptions(opts...) // Return earlier if we have a model already loaded @@ -513,6 +513,10 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) { ml.stopActiveBackends(o.modelID, o.singleActiveBackend) + if o.backendString != "" { + return ml.backendLoader(opts...) + } + var err error // get backends embedded in the binary @@ -536,7 +540,7 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) { WithBackendString(key), }...) - model, modelerr := ml.BackendLoader(options...) + model, modelerr := ml.backendLoader(options...) if modelerr == nil && model != nil { log.Info().Msgf("[%s] Loads OK", key) return model, nil diff --git a/tests/integration/stores_test.go b/tests/integration/stores_test.go index 4244d817fc04..5ed46b19649f 100644 --- a/tests/integration/stores_test.go +++ b/tests/integration/stores_test.go @@ -57,7 +57,7 @@ var _ = Describe("Integration tests for the stores backend(s) and internal APIs" } sl = model.NewModelLoader("") - sc, err = sl.BackendLoader(storeOpts...) + sc, err = sl.Load(storeOpts...) Expect(err).ToNot(HaveOccurred()) Expect(sc).ToNot(BeNil()) })