Skip to content

Commit

Permalink
chore(refactor): do not expose internal backend Loader
Browse files Browse the repository at this point in the history
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
  • Loading branch information
mudler committed Nov 8, 2024
1 parent bae9939 commit 97060ae
Show file tree
Hide file tree
Showing 13 changed files with 29 additions and 55 deletions.
10 changes: 1 addition & 9 deletions core/backend/embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,9 @@ import (

func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {

var inferenceModel interface{}
var err error

opts := ModelOptions(backendConfig, appConfig)

if backendConfig.Backend == "" {
inferenceModel, err = loader.GreedyLoader(opts...)
} else {
opts = append(opts, model.WithBackendString(backendConfig.Backend))
inferenceModel, err = loader.BackendLoader(opts...)
}
inferenceModel, err := loader.Load(opts...)
if err != nil {
return nil, err
}
Expand Down
3 changes: 1 addition & 2 deletions core/backend/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ import (
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) {

opts := ModelOptions(backendConfig, appConfig)

inferenceModel, err := loader.BackendLoader(
inferenceModel, err := loader.Load(
opts...,
)
if err != nil {
Expand Down
18 changes: 2 additions & 16 deletions core/backend/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
"github.com/mudler/LocalAI/core/schema"

"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/pkg/grpc"
"github.com/mudler/LocalAI/pkg/grpc/proto"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/utils"
Expand All @@ -35,15 +34,6 @@ type TokenUsage struct {
func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
modelFile := c.Model

var inferenceModel grpc.Backend
var err error

opts := ModelOptions(c, o)

if c.Backend != "" {
opts = append(opts, model.WithBackendString(c.Backend))
}

// Check if the modelFile exists, if it doesn't try to load it from the gallery
if o.AutoloadGalleries { // experimental
if _, err := os.Stat(modelFile); os.IsNotExist(err) {
Expand All @@ -56,12 +46,8 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
}
}

if c.Backend == "" {
inferenceModel, err = loader.GreedyLoader(opts...)
} else {
inferenceModel, err = loader.BackendLoader(opts...)
}

opts := ModelOptions(c, o)
inferenceModel, err := loader.Load(opts...)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion core/backend/rerank.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
func Rerank(modelFile string, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) {

opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile))
rerankModel, err := loader.BackendLoader(opts...)
rerankModel, err := loader.Load(opts...)
if err != nil {
return nil, err
}
Expand Down
3 changes: 1 addition & 2 deletions core/backend/soundgeneration.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ func SoundGeneration(
) (string, *proto.Result, error) {

opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile))

soundGenModel, err := loader.BackendLoader(opts...)
soundGenModel, err := loader.Load(opts...)
if err != nil {
return "", nil, err
}
Expand Down
19 changes: 9 additions & 10 deletions core/backend/stores.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,15 @@ import (
)

// StoreBackend starts (or reuses) the local key/value store backend for the
// given store name and returns a gRPC client handle to it.
//
// An empty storeName selects the "default" store. The loader resolves the
// backend via the LocalStoreBackend implementation, with its assets taken
// from the application's configured asset destination.
func StoreBackend(sl *model.ModelLoader, appConfig *config.ApplicationConfig, storeName string) (grpc.Backend, error) {
	// Fall back to the default store when no name was provided.
	if storeName == "" {
		storeName = "default"
	}

	sc := []model.Option{
		model.WithBackendString(model.LocalStoreBackend),
		model.WithAssetDir(appConfig.AssetsDestination),
		model.WithModel(storeName),
	}

	return sl.Load(sc...)
}

2 changes: 1 addition & 1 deletion core/backend/token_metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func TokenMetrics(
backendConfig config.BackendConfig) (*proto.MetricsResponse, error) {

opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile))
model, err := loader.BackendLoader(opts...)
model, err := loader.Load(opts...)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions core/backend/tokenize.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ func ModelTokenize(s string, loader *model.ModelLoader, backendConfig config.Bac
opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile))

if backendConfig.Backend == "" {
inferenceModel, err = loader.GreedyLoader(opts...)
inferenceModel, err = loader.Load(opts...)
} else {
opts = append(opts, model.WithBackendString(backendConfig.Backend))
inferenceModel, err = loader.BackendLoader(opts...)
inferenceModel, err = loader.Load(opts...)
}
if err != nil {
return schema.TokenizeResponse{}, err
Expand Down
2 changes: 1 addition & 1 deletion core/backend/transcript.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func ModelTranscription(audio, language string, translate bool, ml *model.ModelL

opts := ModelOptions(backendConfig, appConfig)

transcriptionModel, err := ml.BackendLoader(opts...)
transcriptionModel, err := ml.Load(opts...)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion core/backend/tts.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func ModelTTS(
}

opts := ModelOptions(backendConfig, appConfig, model.WithBackendString(bb), model.WithModel(modelFile))
ttsModel, err := loader.BackendLoader(opts...)
ttsModel, err := loader.Load(opts...)
if err != nil {
return "", nil, err
}
Expand Down
7 changes: 1 addition & 6 deletions core/startup/startup.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,12 +163,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
o := backend.ModelOptions(*cfg, options)

var backendErr error
if cfg.Backend != "" {
o = append(o, model.WithBackendString(cfg.Backend))
_, backendErr = ml.BackendLoader(o...)
} else {
_, backendErr = ml.GreedyLoader(o...)
}
_, backendErr = ml.Load(o...)
if backendErr != nil {
return nil, nil, nil, err
}
Expand Down
10 changes: 7 additions & 3 deletions pkg/model/initializers.go
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ func (ml *ModelLoader) ListAvailableBackends(assetdir string) ([]string, error)
return orderBackends(backends)
}

func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err error) {
func (ml *ModelLoader) backendLoader(opts ...Option) (client grpc.Backend, err error) {
o := NewOptions(opts...)

log.Info().Msgf("Loading model '%s' with backend %s", o.modelID, o.backendString)
Expand Down Expand Up @@ -500,7 +500,7 @@ func (ml *ModelLoader) stopActiveBackends(modelID string, singleActiveBackend bo
}
}

func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) {
func (ml *ModelLoader) Load(opts ...Option) (grpc.Backend, error) {
o := NewOptions(opts...)

// Return earlier if we have a model already loaded
Expand All @@ -513,6 +513,10 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) {

ml.stopActiveBackends(o.modelID, o.singleActiveBackend)

if o.backendString != "" {
return ml.backendLoader(opts...)
}

var err error

// get backends embedded in the binary
Expand All @@ -536,7 +540,7 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) {
WithBackendString(key),
}...)

model, modelerr := ml.BackendLoader(options...)
model, modelerr := ml.backendLoader(options...)
if modelerr == nil && model != nil {
log.Info().Msgf("[%s] Loads OK", key)
return model, nil
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/stores_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ var _ = Describe("Integration tests for the stores backend(s) and internal APIs"
}

sl = model.NewModelLoader("")
sc, err = sl.BackendLoader(storeOpts...)
sc, err = sl.Load(storeOpts...)
Expect(err).ToNot(HaveOccurred())
Expect(sc).ToNot(BeNil())
})
Expand Down

0 comments on commit 97060ae

Please sign in to comment.