Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: consolidate usage of GetURI #674

Merged
merged 2 commits into from
Jun 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,13 @@ func App(opts ...AppOption) (*fiber.App, error) {
app.Use(recover.New())

if options.preloadJSONModels != "" {
if err := ApplyGalleryFromString(options.loader.ModelPath, options.preloadJSONModels, cm); err != nil {
if err := ApplyGalleryFromString(options.loader.ModelPath, options.preloadJSONModels, cm, options.galleries); err != nil {
return nil, err
}
}

if options.preloadModelsFromPath != "" {
if err := ApplyGalleryFromFile(options.loader.ModelPath, options.preloadModelsFromPath, cm); err != nil {
if err := ApplyGalleryFromFile(options.loader.ModelPath, options.preloadModelsFromPath, cm, options.galleries); err != nil {
return nil, err
}
}
Expand Down
29 changes: 5 additions & 24 deletions api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
. "github.com/go-skynet/LocalAI/api"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
Expand Down Expand Up @@ -56,30 +57,10 @@ func getModelStatus(url string) (response map[string]interface{}) {
}

func getModels(url string) (response []gallery.GalleryModel) {

//url := "http://localhost:AI/models/apply"

// Create the request payload

// Create the HTTP request
resp, err := http.Get(url)
if err != nil {
return nil
}
defer resp.Body.Close()

body, err := ioutil.ReadAll(resp.Body)
if err != nil {
fmt.Println("Error reading response body:", err)
return
}

// Unmarshal the response into a map[string]interface{}
err = json.Unmarshal(body, &response)
if err != nil {
fmt.Println("Error unmarshaling JSON response:", err)
return
}
utils.GetURI(url, func(url string, i []byte) error {
// Unmarshal YAML data into a struct
return json.Unmarshal(i, &response)
})
return
}

Expand Down
36 changes: 15 additions & 21 deletions api/gallery.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,8 @@ func newGalleryApplier(modelPath string) *galleryApplier {

// prepareModel applies a
func prepareModel(modelPath string, req gallery.GalleryModel, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error {
var config gallery.Config

err := req.Get(&config)
config, err := gallery.GetGalleryConfigFromURL(req.URL)
if err != nil {
return err
}
Expand Down Expand Up @@ -144,40 +143,35 @@ func displayDownload(fileName string, current string, total string, percentage f
}
}

func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger) error {
type galleryModel struct {
gallery.GalleryModel
ID string `json:"id"`
}

func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger, galleries []gallery.Gallery) error {
dat, err := os.ReadFile(s)
if err != nil {
return err
}
var requests []gallery.GalleryModel
err = json.Unmarshal(dat, &requests)
if err != nil {
return err
}

for _, r := range requests {
if err := prepareModel(modelPath, r, cm, displayDownload); err != nil {
return err
}
}

return nil
return ApplyGalleryFromString(modelPath, string(dat), cm, galleries)
}

func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger) error {
var requests []gallery.GalleryModel
func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger, galleries []gallery.Gallery) error {
var requests []galleryModel
err := json.Unmarshal([]byte(s), &requests)
if err != nil {
return err
}

for _, r := range requests {
if err := prepareModel(modelPath, r, cm, displayDownload); err != nil {
return err
if r.ID == "" {
err = prepareModel(modelPath, r.GalleryModel, cm, displayDownload)
} else {
err = gallery.InstallModelFromGallery(galleries, r.ID, modelPath, r.GalleryModel, displayDownload)
}
}

return nil
return err
}

func getOpStatus(g *galleryApplier) func(c *fiber.Ctx) error {
Expand Down
6 changes: 2 additions & 4 deletions pkg/gallery/gallery.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string,
}

applyModel := func(model *GalleryModel) error {
var config Config

err := model.Get(&config)
config, err := GetGalleryConfigFromURL(model.URL)
if err != nil {
return err
}
Expand Down Expand Up @@ -79,7 +77,7 @@ func AvailableGalleryModels(galleries []Gallery, basePath string) ([]*GalleryMod
func getGalleryModels(gallery Gallery, basePath string) ([]*GalleryModel, error) {
var models []*GalleryModel = []*GalleryModel{}

err := utils.GetURI(gallery.URL, func(d []byte) error {
err := utils.GetURI(gallery.URL, func(url string, d []byte) error {
return yaml.Unmarshal(d, &models)
})
if err != nil {
Expand Down
11 changes: 11 additions & 0 deletions pkg/gallery/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,17 @@ type PromptTemplate struct {
Content string `yaml:"content"`
}

func GetGalleryConfigFromURL(url string) (Config, error) {
var config Config
err := utils.GetURI(url, func(url string, d []byte) error {
return yaml.Unmarshal(d, &config)
})
if err != nil {
return config, err
}
return config, nil
}

func ReadConfigFile(filePath string) (*Config, error) {
// Read the YAML file
yamlFile, err := os.ReadFile(filePath)
Expand Down
58 changes: 0 additions & 58 deletions pkg/gallery/request.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,5 @@
package gallery

import (
"fmt"
"net/url"
"strings"

"github.com/go-skynet/LocalAI/pkg/utils"
"gopkg.in/yaml.v2"
)

// GalleryModel is the struct used to represent a model in the gallery returned by the endpoint.
// It is used to install the model by resolving the URL and downloading the files.
// The other fields are used to override the configuration of the model.
Expand All @@ -34,52 +25,3 @@ type GalleryModel struct {
const (
githubURI = "github:"
)

func (request GalleryModel) DecodeURL() (string, error) {
input := request.URL
var rawURL string

if strings.HasPrefix(input, githubURI) {
parts := strings.Split(input, ":")
repoParts := strings.Split(parts[1], "@")
branch := "main"

if len(repoParts) > 1 {
branch = repoParts[1]
}

repoPath := strings.Split(repoParts[0], "/")
org := repoPath[0]
project := repoPath[1]
projectPath := strings.Join(repoPath[2:], "/")

rawURL = fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath)
} else if strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://") {
// Handle regular URLs
u, err := url.Parse(input)
if err != nil {
return "", fmt.Errorf("invalid URL: %w", err)
}
rawURL = u.String()
// check if it's a file path
} else if strings.HasPrefix(input, "file://") {
return input, nil
} else {

return "", fmt.Errorf("invalid URL format: %s", input)
}

return rawURL, nil
}

// Get fetches a model from a URL and unmarshals it into a struct
func (request GalleryModel) Get(i interface{}) error {
url, err := request.DecodeURL()
if err != nil {
return err
}

return utils.GetURI(url, func(d []byte) error {
return yaml.Unmarshal(d, i)
})
}
26 changes: 1 addition & 25 deletions pkg/gallery/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,37 +6,13 @@ import (
. "github.com/onsi/gomega"
)

type example struct {
Name string `yaml:"name"`
}

var _ = Describe("Gallery API tests", func() {

Context("requests", func() {
It("parses github with a branch", func() {
req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"}
var e example
err := req.Get(&e)
e, err := GetGalleryConfigFromURL(req.URL)
Expect(err).ToNot(HaveOccurred())
Expect(e.Name).To(Equal("gpt4all-j"))
})
It("parses github without a branch", func() {
req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"}
str, err := req.DecodeURL()
Expect(err).ToNot(HaveOccurred())
Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
})
It("parses github without a branch", func() {
req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml"}
str, err := req.DecodeURL()
Expect(err).ToNot(HaveOccurred())
Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
})
It("parses URLS", func() {
req := GalleryModel{URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"}
str, err := req.DecodeURL()
Expect(err).ToNot(HaveOccurred())
Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
})
})
})
28 changes: 25 additions & 3 deletions pkg/utils/uri.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,34 @@
package utils

import (
"fmt"
"io/ioutil"
"net/http"
"strings"
)

func GetURI(url string, f func(i []byte) error) error {
const (
githubURI = "github:"
)

func GetURI(url string, f func(url string, i []byte) error) error {
if strings.HasPrefix(url, githubURI) {
parts := strings.Split(url, ":")
repoParts := strings.Split(parts[1], "@")
branch := "main"

if len(repoParts) > 1 {
branch = repoParts[1]
}

repoPath := strings.Split(repoParts[0], "/")
org := repoPath[0]
project := repoPath[1]
projectPath := strings.Join(repoPath[2:], "/")

url = fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath)
}

if strings.HasPrefix(url, "file://") {
rawURL := strings.TrimPrefix(url, "file://")
// Read the response body
Expand All @@ -16,7 +38,7 @@ func GetURI(url string, f func(i []byte) error) error {
}

// Unmarshal YAML data into a struct
return f(body)
return f(url, body)
}

// Send a GET request to the URL
Expand All @@ -33,5 +55,5 @@ func GetURI(url string, f func(i []byte) error) error {
}

// Unmarshal YAML data into a struct
return f(body)
return f(url, body)
}
36 changes: 36 additions & 0 deletions pkg/utils/uri_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package utils_test

import (
. "github.com/go-skynet/LocalAI/pkg/utils"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

var _ = Describe("Gallery API tests", func() {
Context("URI", func() {
It("parses github with a branch", func() {
Expect(
GetURI("github:go-skynet/model-gallery/gpt4all-j.yaml", func(url string, i []byte) error {
Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
return nil
}),
).ToNot(HaveOccurred())
})
It("parses github without a branch", func() {
Expect(
GetURI("github:go-skynet/model-gallery/gpt4all-j.yaml@main", func(url string, i []byte) error {
Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
return nil
}),
).ToNot(HaveOccurred())
})
It("parses github with urls", func() {
Expect(
GetURI("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml", func(url string, i []byte) error {
Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
return nil
}),
).ToNot(HaveOccurred())
})
})
})
13 changes: 13 additions & 0 deletions pkg/utils/utils_suite_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package utils_test

import (
"testing"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

func TestUtils(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Utils test suite")
}