diff --git a/go/plugins/ollama/embed.go b/go/plugins/ollama/embed.go new file mode 100644 index 000000000..051514420 --- /dev/null +++ b/go/plugins/ollama/embed.go @@ -0,0 +1,155 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ollama + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + + "github.com/firebase/genkit/go/ai" +) + +type EmbedOptions struct { + Model string `json:"model"` +} + +type ollamaEmbedRequest struct { + Model string `json:"model"` + Input interface{} `json:"input"` // todo: using interface{} to handle both string and []string, figure out better solution + Options map[string]interface{} `json:"options,omitempty"` +} + +type ollamaEmbedResponse struct { + Embeddings [][]float32 `json:"embeddings"` +} + +func embed(ctx context.Context, serverAddress string, req *ai.EmbedRequest) (*ai.EmbedResponse, error) { + options, ok := req.Options.(*EmbedOptions) + if !ok && req.Options != nil { + return nil, fmt.Errorf("invalid options type: expected *EmbedOptions") + } + if options == nil || options.Model == "" { + return nil, fmt.Errorf("invalid embedding model: model must be specified") + } + + if serverAddress == "" { + return nil, fmt.Errorf("invalid server address: address cannot be empty") + } + + ollamaReq := newOllamaEmbedRequest(options.Model, req.Documents) + + jsonData, err := json.Marshal(ollamaReq) + if err != nil { + return nil, fmt.Errorf("failed to marshal embed request: %w", err) + } + + resp, err := sendEmbedRequest(ctx, serverAddress, jsonData) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("ollama embed request failed with status code %d", resp.StatusCode) + } + + var ollamaResp ollamaEmbedResponse + if err := json.NewDecoder(resp.Body).Decode(&ollamaResp); err != nil { + return nil, fmt.Errorf("failed to decode embed response: %w", err) + } + + return newEmbedResponse(ollamaResp.Embeddings), nil +} + +func sendEmbedRequest(ctx context.Context, serverAddress string, jsonData []byte) (*http.Response, error) { + client := &http.Client{} + httpReq, err := http.NewRequestWithContext(ctx, "POST", serverAddress+"/api/embed", bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + return client.Do(httpReq) +} + +func newOllamaEmbedRequest(model string, documents []*ai.Document) ollamaEmbedRequest { + var input interface{} + if len(documents) == 1 { + input = concatenateText(documents[0]) + } else { + texts := make([]string, len(documents)) + for i, doc := range documents { + texts[i] = concatenateText(doc) + } + input = texts + } + + return ollamaEmbedRequest{ + Model: model, + Input: input, + } +} + +func newEmbedResponse(embeddings [][]float32) *ai.EmbedResponse { + resp := &ai.EmbedResponse{ + Embeddings: make([]*ai.DocumentEmbedding, len(embeddings)), + } + for i, embedding := range embeddings { + resp.Embeddings[i] = &ai.DocumentEmbedding{Embedding: embedding} + } + return resp +} + +func concatenateText(doc *ai.Document) string { + var builder strings.Builder + for _, part := range doc.Content { + builder.WriteString(part.Text) + } + result := builder.String() + return result +} + +// DefineEmbedder defines an embedder with a given server address. +func DefineEmbedder(serverAddress string, model string) ai.Embedder { + state.mu.Lock() + defer state.mu.Unlock() + if !state.initted { + panic("ollama.Init not called") + } + return ai.DefineEmbedder(provider, serverAddress, func(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) { + if req.Options == nil { + req.Options = &EmbedOptions{Model: model} + } + if req.Options.(*EmbedOptions).Model == "" { + req.Options.(*EmbedOptions).Model = model + } + return embed(ctx, serverAddress, req) + }) +} + +// IsDefinedEmbedder reports whether the embedder with the given server address is defined by this plugin. +func IsDefinedEmbedder(serverAddress string) bool { + isDefined := ai.IsDefinedEmbedder(provider, serverAddress) + return isDefined +} + +// Embedder returns the [ai.Embedder] with the given server address. +// It returns nil if the embedder was not defined. +func Embedder(serverAddress string) ai.Embedder { + return ai.LookupEmbedder(provider, serverAddress) +} diff --git a/go/plugins/ollama/embed_test.go b/go/plugins/ollama/embed_test.go new file mode 100644 index 000000000..d5c0190a4 --- /dev/null +++ b/go/plugins/ollama/embed_test.go @@ -0,0 +1,66 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ollama + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/firebase/genkit/go/ai" +) + +func TestEmbedValidRequest(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(ollamaEmbedResponse{ + Embeddings: [][]float32{{0.1, 0.2, 0.3}}, + }) + })) + defer server.Close() + + req := &ai.EmbedRequest{ + Documents: []*ai.Document{ + ai.DocumentFromText("test", nil), + }, + Options: &EmbedOptions{Model: "all-minilm"}, + } + + resp, err := embed(context.Background(), server.URL, req) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if len(resp.Embeddings) != 1 { + t.Fatalf("expected 1 embedding, got %d", len(resp.Embeddings)) + } +} + +func TestEmbedInvalidServerAddress(t *testing.T) { + req := &ai.EmbedRequest{ + Documents: []*ai.Document{ + ai.DocumentFromText("test", nil), + }, + Options: &EmbedOptions{Model: "all-minilm"}, + } + + _, err := embed(context.Background(), "", req) + if err == nil || !strings.Contains(err.Error(), "invalid server address") { + t.Fatalf("expected invalid server address error, got %v", err) + } +}