From 1071b751b22925fd23eb0e595d17a81aa532d0f5 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Fri, 21 Jun 2024 19:19:30 +0200 Subject: [PATCH] Allow to install ollama models from CLI Signed-off-by: Ettore Di Giacinto --- core/cli/models.go | 15 +++++++++------ core/cli/util.go | 17 +++++++++++++++++ pkg/downloader/uri.go | 2 ++ pkg/startup/model_preload.go | 23 +++++++++++++++++++++++ 4 files changed, 51 insertions(+), 6 deletions(-) diff --git a/core/cli/models.go b/core/cli/models.go index a6ba39b9a091..b9e3d1b86e75 100644 --- a/core/cli/models.go +++ b/core/cli/models.go @@ -6,6 +6,7 @@ import ( cliContext "github.com/go-skynet/LocalAI/core/cli/context" + "github.com/go-skynet/LocalAI/pkg/downloader" "github.com/go-skynet/LocalAI/pkg/gallery" "github.com/go-skynet/LocalAI/pkg/startup" "github.com/rs/zerolog/log" @@ -79,13 +80,15 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error { return err } - model := gallery.FindModel(models, modelName, mi.ModelsPath) - if model == nil { - log.Error().Str("model", modelName).Msg("model not found") - return err - } + if !downloader.LooksLikeOCI(modelName) { + model := gallery.FindModel(models, modelName, mi.ModelsPath) + if model == nil { + log.Error().Str("model", modelName).Msg("model not found") + return err + } - log.Info().Str("model", modelName).Str("license", model.License).Msg("installing model") + log.Info().Str("model", modelName).Str("license", model.License).Msg("installing model") + } err = startup.InstallModels(galleries, "", mi.ModelsPath, progressCallback, modelName) if err != nil { return err diff --git a/core/cli/util.go b/core/cli/util.go index f0f78cf2cd53..7cff302c6dab 100644 --- a/core/cli/util.go +++ b/core/cli/util.go @@ -2,15 +2,19 @@ package cli import ( "fmt" + "path/filepath" "github.com/rs/zerolog/log" cliContext "github.com/go-skynet/LocalAI/core/cli/context" + "github.com/go-skynet/LocalAI/pkg/downloader" + "github.com/go-skynet/LocalAI/pkg/utils" gguf "github.com/thxcode/gguf-parser-go" ) type UtilCMD struct { GGUFInfo GGUFInfoCMD `cmd:"" name:"gguf-info" help:"Get information about a GGUF file"` + Download DownloadCMD `cmd:"" name:"download" help:"Download a file or a model from an OCI registry"` } type GGUFInfoCMD struct { @@ -53,3 +57,16 @@ func (u *GGUFInfoCMD) Run(ctx *cliContext.Context) error { return nil } + +type DownloadCMD struct { + Args []string `arg:"" optional:"" name:"args" help:"File URL and name to download"` + ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"` +} + +func (u *DownloadCMD) Run(ctx *cliContext.Context) error { + if len(u.Args) < 2 { + return fmt.Errorf("no URL or model name provided") + } + + return downloader.DownloadFile(u.Args[0], filepath.Join(u.ModelsPath, u.Args[1]), "", 1, 1, utils.DisplayDownloadFunction) +} diff --git a/pkg/downloader/uri.go b/pkg/downloader/uri.go index 10f057d29524..ceda08594700 100644 --- a/pkg/downloader/uri.go +++ b/pkg/downloader/uri.go @@ -173,9 +173,11 @@ func DownloadFile(url string, filePath, sha string, fileN, total int, downloadSt } if strings.HasPrefix(url, OllamaPrefix) { + url = strings.TrimPrefix(url, OllamaPrefix) return oci.OllamaFetchModel(url, filePath, progressStatus) } + url = strings.TrimPrefix(url, OCIPrefix) img, err := oci.GetImage(url, "", nil, nil) if err != nil { return fmt.Errorf("failed to get image %q: %v", url, err) diff --git a/pkg/startup/model_preload.go b/pkg/startup/model_preload.go index aa732ab0be8c..bb63ff53aee2 100644 --- a/pkg/startup/model_preload.go +++ b/pkg/startup/model_preload.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "path/filepath" + "strings" "github.com/go-skynet/LocalAI/embedded" "github.com/go-skynet/LocalAI/pkg/downloader" @@ -52,6 +53,28 @@ func InstallModels(galleries []gallery.Gallery, modelLibraryURL string, modelPat log.Error().Err(e).Str("filepath", modelDefinitionFilePath).Msg("error writing model definition") err = errors.Join(err, e) } + case downloader.LooksLikeOCI(url): + log.Debug().Msgf("[startup] resolved OCI model to download: %s", url) + + // convert OCI image name to a file name. + ociName := strings.TrimPrefix(url, downloader.OCIPrefix) + ociName = strings.TrimPrefix(ociName, downloader.OllamaPrefix) + ociName = strings.ReplaceAll(ociName, "/", "__") + ociName = strings.ReplaceAll(ociName, ":", "__") + + // check if file exists + if _, e := os.Stat(filepath.Join(modelPath, ociName)); errors.Is(e, os.ErrNotExist) { + modelDefinitionFilePath := filepath.Join(modelPath, ociName) + e := downloader.DownloadFile(url, modelDefinitionFilePath, "", 0, 0, func(fileName, current, total string, percent float64) { + utils.DisplayDownloadFunction(fileName, current, total, percent) + }) + if e != nil { + log.Error().Err(e).Str("url", url).Str("filepath", modelDefinitionFilePath).Msg("error downloading model") + err = errors.Join(err, e) + } + } + + log.Info().Msgf("[startup] installed model from OCI repository: %s", ociName) case downloader.LooksLikeURL(url): log.Debug().Msgf("[startup] resolved model to download: %s", url)