Skip to content

Commit

Permalink
feat: extend model configuration for llama.cpp (#536)
Browse files Browse the repository at this point in the history
  • Loading branch information
mudler authored Jun 7, 2023
1 parent 694dd4a commit 5abbb13
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 150 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ GOTEST=$(GOCMD) test
GOVET=$(GOCMD) vet
BINARY_NAME=local-ai

GOLLAMA_VERSION?=37ef81d01ae0848575e416e48b41d112ef0d520e
GOLLAMA_VERSION?=351aa714672fb09aa84396868d08934e8e477f25
GPT4ALL_REPO?=https://github.com/go-skynet/gpt4all
GPT4ALL_VERSION?=f7498c9
GOGGMLTRANSFORMERS_VERSION?=bd765bb6f3b38a63f915f3725e488aad492eedd4
Expand Down
60 changes: 35 additions & 25 deletions api/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,33 @@ import (
)

type Config struct {
OpenAIRequest `yaml:"parameters"`
Name string `yaml:"name"`
StopWords []string `yaml:"stopwords"`
Cutstrings []string `yaml:"cutstrings"`
TrimSpace []string `yaml:"trimspace"`
ContextSize int `yaml:"context_size"`
F16 bool `yaml:"f16"`
Threads int `yaml:"threads"`
Debug bool `yaml:"debug"`
Roles map[string]string `yaml:"roles"`
Embeddings bool `yaml:"embeddings"`
Backend string `yaml:"backend"`
TemplateConfig TemplateConfig `yaml:"template"`
MirostatETA float64 `yaml:"mirostat_eta"`
MirostatTAU float64 `yaml:"mirostat_tau"`
Mirostat int `yaml:"mirostat"`
NGPULayers int `yaml:"gpu_layers"`
ImageGenerationAssets string `yaml:"asset_dir"`
OpenAIRequest `yaml:"parameters"`
Name string `yaml:"name"`
StopWords []string `yaml:"stopwords"`
Cutstrings []string `yaml:"cutstrings"`
TrimSpace []string `yaml:"trimspace"`
ContextSize int `yaml:"context_size"`
F16 bool `yaml:"f16"`
Threads int `yaml:"threads"`
Debug bool `yaml:"debug"`
Roles map[string]string `yaml:"roles"`
Embeddings bool `yaml:"embeddings"`
Backend string `yaml:"backend"`
TemplateConfig TemplateConfig `yaml:"template"`
MirostatETA float64 `yaml:"mirostat_eta"`
MirostatTAU float64 `yaml:"mirostat_tau"`
Mirostat int `yaml:"mirostat"`
NGPULayers int `yaml:"gpu_layers"`
MMap bool `yaml:"mmap"`
MMlock bool `yaml:"mmlock"`

TensorSplit string `yaml:"tensor_split"`
MainGPU string `yaml:"main_gpu"`
ImageGenerationAssets string `yaml:"asset_dir"`

PromptCachePath string `yaml:"prompt_cache_path"`
PromptCacheAll bool `yaml:"prompt_cache_all"`
PromptCacheRO bool `yaml:"prompt_cache_ro"`

PromptStrings, InputStrings []string
InputToken [][]int
Expand All @@ -53,6 +59,12 @@ type ConfigMerger struct {
sync.Mutex
}

func defaultConfig(modelFile string) *Config {
return &Config{
OpenAIRequest: defaultRequest(modelFile),
}
}

func NewConfigMerger() *ConfigMerger {
return &ConfigMerger{
configs: make(map[string]Config),
Expand Down Expand Up @@ -308,13 +320,11 @@ func readConfig(modelFile string, input *OpenAIRequest, cm *ConfigMerger, loader
var config *Config
cfg, exists := cm.GetConfig(modelFile)
if !exists {
config = &Config{
OpenAIRequest: defaultRequest(modelFile),
ContextSize: ctx,
Threads: threads,
F16: f16,
Debug: debug,
}
config = defaultConfig(modelFile)
config.ContextSize = ctx
config.Threads = threads
config.F16 = f16
config.Debug = debug
} else {
config = &cfg
}
Expand Down
7 changes: 5 additions & 2 deletions api/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import (
"bufio"
"bytes"
"encoding/base64"
"errors"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
Expand Down Expand Up @@ -125,6 +125,9 @@ type OpenAIRequest struct {
MirostatTAU float64 `json:"mirostat_tau" yaml:"mirostat_tau"`
Mirostat int `json:"mirostat" yaml:"mirostat"`

FrequencyPenalty float64 `json:"frequency_penalty" yaml:"frequency_penalty"`
TFZ float64 `json:"tfz" yaml:"tfz"`

Seed int `json:"seed" yaml:"seed"`

// Image (not supported by OpenAI)
Expand Down Expand Up @@ -191,7 +194,7 @@ func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
}

if input.Stream {
if (len(config.PromptStrings) > 1) {
if len(config.PromptStrings) > 1 {
return errors.New("cannot handle more than 1 `PromptStrings` when `Stream`ing")
}

Expand Down
20 changes: 20 additions & 0 deletions api/prediction.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ func defaultLLamaOpts(c Config) []llama.ModelOption {
llamaOpts = append(llamaOpts, llama.SetGPULayers(c.NGPULayers))
}

llamaOpts = append(llamaOpts, llama.SetMMap(c.MMap))
llamaOpts = append(llamaOpts, llama.SetMainGPU(c.MainGPU))
llamaOpts = append(llamaOpts, llama.SetTensorSplit(c.TensorSplit))
if c.Batch != 0 {
llamaOpts = append(llamaOpts, llama.SetNBatch(c.Batch))
}

return llamaOpts
}

Expand Down Expand Up @@ -168,6 +175,10 @@ func buildLLamaPredictOptions(c Config, modelPath string) []llama.PredictOption
predictOptions = append(predictOptions, llama.EnablePromptCacheAll)
}

if c.PromptCacheRO {
predictOptions = append(predictOptions, llama.EnablePromptCacheRO)
}

if c.PromptCachePath != "" {
// Create parent directory
p := filepath.Join(modelPath, c.PromptCachePath)
Expand Down Expand Up @@ -217,6 +228,15 @@ func buildLLamaPredictOptions(c Config, modelPath string) []llama.PredictOption
predictOptions = append(predictOptions, llama.SetSeed(c.Seed))
}

//predictOptions = append(predictOptions, llama.SetLogitBias(c.Seed))

predictOptions = append(predictOptions, llama.SetFrequencyPenalty(c.FrequencyPenalty))
predictOptions = append(predictOptions, llama.SetMlock(c.MMlock))
predictOptions = append(predictOptions, llama.SetMemoryMap(c.MMap))
predictOptions = append(predictOptions, llama.SetPredictionMainGPU(c.MainGPU))
predictOptions = append(predictOptions, llama.SetPredictionTensorSplit(c.TensorSplit))
predictOptions = append(predictOptions, llama.SetTailFreeSamplingZ(c.TFZ))

return predictOptions
}

Expand Down
5 changes: 2 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ require (
github.com/go-skynet/bloomz.cpp v0.0.0-20230529155654-1834e77b83fa
github.com/go-skynet/go-bert.cpp v0.0.0-20230531070950-0548994371f7
github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230606131358-bd765bb6f3b3
github.com/go-skynet/go-llama.cpp v0.0.0-20230606152241-37ef81d01ae0
github.com/go-skynet/go-llama.cpp v0.0.0-20230607123950-351aa714672f
github.com/gofiber/fiber/v2 v2.46.0
github.com/google/uuid v1.3.0
github.com/hashicorp/go-multierror v1.1.1
github.com/imdario/mergo v0.3.16
github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230605194130-266f13aee9d8
github.com/onsi/ginkgo/v2 v2.9.7
github.com/onsi/gomega v1.27.7
github.com/onsi/gomega v1.27.8
github.com/otiai10/openaigo v1.1.0
github.com/rs/zerolog v1.29.1
github.com/sashabaranov/go-openai v1.10.0
Expand All @@ -42,7 +42,6 @@ require (
github.com/go-openapi/jsonreference v0.19.6 // indirect
github.com/go-openapi/spec v0.20.4 // indirect
github.com/go-openapi/swag v0.19.15 // indirect
github.com/go-skynet/go-gpt2.cpp v0.0.0-20230523153133-3eb3a32c0874 // indirect
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/google/go-cmp v0.5.9 // indirect
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect
Expand Down
Loading

0 comments on commit 5abbb13

Please sign in to comment.