Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(prepare): allow to specify additional files to download #1526

Merged
merged 1 commit into from
Jan 1, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions api/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ type Config struct {
// CUDA
// Explicitly enable CUDA or not (some backends might need it)
CUDA bool `yaml:"cuda"`

DownloadFiles []File `yaml:"download_files"`
}

type File struct {
Filename string `yaml:"filename" json:"filename"`
SHA256 string `yaml:"sha256" json:"sha256"`
URI string `yaml:"uri" json:"uri"`
}

type VallE struct {
Expand Down Expand Up @@ -272,10 +280,29 @@ func (cm *ConfigLoader) Preload(modelPath string) error {
cm.Lock()
defer cm.Unlock()

status := func(fileName, current, total string, percent float64) {
utils.DisplayDownloadFunction(fileName, current, total, percent)
}

log.Info().Msgf("Preloading models from %s", modelPath)

for i, config := range cm.configs {

// Download files and verify their SHA
for _, file := range config.DownloadFiles {
log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)

if err := utils.VerifyPath(file.Filename, modelPath); err != nil {
return err
}
// Create file path
filePath := filepath.Join(modelPath, file.Filename)

if err := utils.DownloadFile(file.URI, filePath, file.SHA256, status); err != nil {
return err
}
}

modelURL := config.PredictionOptions.Model
modelURL = utils.ConvertURL(modelURL)

Expand All @@ -285,9 +312,7 @@ func (cm *ConfigLoader) Preload(modelPath string) error {

// check if file exists
if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) {
err := utils.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", func(fileName, current, total string, percent float64) {
utils.DisplayDownloadFunction(fileName, current, total, percent)
})
err := utils.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", status)
if err != nil {
return err
}
Expand Down