Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(go): Add ollama embeddings support #841

Merged
merged 9 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 155 additions & 0 deletions go/plugins/ollama/embed.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ollama

import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"strings"

"github.com/firebase/genkit/go/ai"
)

// EmbedOptions carries the per-request settings for an Ollama embedding call.
type EmbedOptions struct {
	// Model is the name of the embedding model; it is serialized as "model".
	Model string `json:"model"`
}

// ollamaEmbedRequest is the JSON body posted to the Ollama embed endpoint.
type ollamaEmbedRequest struct {
	// Model is the name of the embedding model.
	Model string `json:"model"`
	// Input holds either a single string or a []string.
	// TODO: interface{} is used to cover both shapes; find a better solution.
	Input interface{} `json:"input"`
	// Options is left out of the JSON when empty.
	Options map[string]interface{} `json:"options,omitempty"`
}

// ollamaEmbedResponse is the JSON body decoded from the Ollama embed endpoint.
type ollamaEmbedResponse struct {
	// Embeddings is the list of embedding vectors read from "embeddings".
	Embeddings [][]float32 `json:"embeddings"`
}

// embed posts the documents in req to the Ollama server at serverAddress and
// converts the returned vectors into an *ai.EmbedResponse.
//
// req.Options must hold a *EmbedOptions with a non-empty Model; a nil Options
// is rejected because no model can be determined from it.
//
// An error is returned when req is nil, the options have the wrong type, the
// model or server address is empty, the request cannot be marshaled or sent,
// the server answers with a status other than 200 OK, or the response body
// cannot be decoded.
func embed(ctx context.Context, serverAddress string, req *ai.EmbedRequest) (*ai.EmbedResponse, error) {
	if req == nil {
		return nil, fmt.Errorf("invalid embed request: request cannot be nil")
	}

	options, ok := req.Options.(*EmbedOptions)
	if !ok && req.Options != nil {
		return nil, fmt.Errorf("invalid options type: expected *EmbedOptions")
	}
	if options == nil || options.Model == "" {
		return nil, fmt.Errorf("invalid embedding model: model must be specified")
	}

	if serverAddress == "" {
		return nil, fmt.Errorf("invalid server address: address cannot be empty")
	}

	ollamaReq := newOllamaEmbedRequest(options.Model, req.Documents)

	jsonData, err := json.Marshal(ollamaReq)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal embed request: %w", err)
	}

	resp, err := sendEmbedRequest(ctx, serverAddress, jsonData)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Ollama is expected to describe failures as {"error": "..."}; when
		// such a message is present, include it so callers can see the cause
		// (for example an unknown model) instead of only the status code.
		var apiErr struct {
			Error string `json:"error"`
		}
		if decodeErr := json.NewDecoder(resp.Body).Decode(&apiErr); decodeErr == nil && apiErr.Error != "" {
			return nil, fmt.Errorf("ollama embed request failed with status code %d: %s", resp.StatusCode, apiErr.Error)
		}
		return nil, fmt.Errorf("ollama embed request failed with status code %d", resp.StatusCode)
	}

	var ollamaResp ollamaEmbedResponse
	if err := json.NewDecoder(resp.Body).Decode(&ollamaResp); err != nil {
		return nil, fmt.Errorf("failed to decode embed response: %w", err)
	}

	return newEmbedResponse(ollamaResp.Embeddings), nil
}

// sendEmbedRequest POSTs jsonData as application/json to the /api/embed
// endpoint of the server at serverAddress and returns the raw HTTP response.
//
// Trailing slashes on serverAddress are ignored, so "http://host:11434" and
// "http://host:11434/" address the same endpoint. The request is bound to ctx.
// On success the caller owns the response and must close its body.
func sendEmbedRequest(ctx context.Context, serverAddress string, jsonData []byte) (*http.Response, error) {
	// Without trimming, an address ending in "/" would yield "//api/embed".
	endpoint := strings.TrimRight(serverAddress, "/") + "/api/embed"

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")

	client := &http.Client{}
	return client.Do(httpReq)
}

// newOllamaEmbedRequest builds the request body for the given model.
//
// Exactly one document produces a plain string Input; any other count
// (including zero) produces a []string with one entry per document. Each
// entry is the document's text as returned by concatenateText. Options is
// left unset.
func newOllamaEmbedRequest(model string, documents []*ai.Document) ollamaEmbedRequest {
	request := ollamaEmbedRequest{Model: model}

	if len(documents) == 1 {
		request.Input = concatenateText(documents[0])
		return request
	}

	inputs := make([]string, 0, len(documents))
	for _, document := range documents {
		inputs = append(inputs, concatenateText(document))
	}
	request.Input = inputs
	return request
}

// newEmbedResponse wraps each vector in embeddings in an
// *ai.DocumentEmbedding, preserving order, and returns them in an
// *ai.EmbedResponse.
func newEmbedResponse(embeddings [][]float32) *ai.EmbedResponse {
	wrapped := make([]*ai.DocumentEmbedding, 0, len(embeddings))
	for _, vector := range embeddings {
		wrapped = append(wrapped, &ai.DocumentEmbedding{Embedding: vector})
	}
	return &ai.EmbedResponse{Embeddings: wrapped}
}

// concatenateText joins the Text of every part in doc.Content, in order and
// with no separator, into a single string.
func concatenateText(doc *ai.Document) string {
	var sb strings.Builder
	for i := range doc.Content {
		sb.WriteString(doc.Content[i].Text)
	}
	return sb.String()
}

// DefineEmbedder defines an embedder with a given server address.
func DefineEmbedder(serverAddress string, model string) ai.Embedder {
state.mu.Lock()
defer state.mu.Unlock()
if !state.initted {
panic("ollama.Init not called")
}
return ai.DefineEmbedder(provider, serverAddress, func(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) {
if req.Options == nil {
huangjeff5 marked this conversation as resolved.
Show resolved Hide resolved
req.Options = &EmbedOptions{Model: model}
}
if req.Options.(*EmbedOptions).Model == "" {
req.Options.(*EmbedOptions).Model = model
}
return embed(ctx, serverAddress, req)
})
}

// IsDefinedEmbedder reports whether an embedder registered under this
// plugin's provider exists for the given server address.
func IsDefinedEmbedder(serverAddress string) bool {
	return ai.IsDefinedEmbedder(provider, serverAddress)
}

// Embedder looks up the [ai.Embedder] registered for the given server address.
// The result is nil when no such embedder was defined.
func Embedder(serverAddress string) ai.Embedder {
	embedder := ai.LookupEmbedder(provider, serverAddress)
	return embedder
}
66 changes: 66 additions & 0 deletions go/plugins/ollama/embed_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ollama

import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/firebase/genkit/go/ai"
)

// TestEmbedValidRequest checks that embed issues a POST to /api/embed on the
// configured server and converts a well-formed response into exactly one
// embedding carrying the vector the server returned.
func TestEmbedValidRequest(t *testing.T) {
	want := []float32{0.1, 0.2, 0.3}

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// t.Errorf (unlike t.Fatalf) is safe to call from the handler goroutine.
		if r.Method != http.MethodPost || r.URL.Path != "/api/embed" {
			t.Errorf("unexpected request: %s %s", r.Method, r.URL.Path)
			http.Error(w, "unexpected request", http.StatusNotFound)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		if err := json.NewEncoder(w).Encode(ollamaEmbedResponse{
			Embeddings: [][]float32{want},
		}); err != nil {
			t.Errorf("failed to encode response: %v", err)
		}
	}))
	defer server.Close()

	req := &ai.EmbedRequest{
		Documents: []*ai.Document{
			ai.DocumentFromText("test", nil),
		},
		Options: &EmbedOptions{Model: "all-minilm"},
	}

	resp, err := embed(context.Background(), server.URL, req)
	if err != nil {
		t.Fatalf("expected no error, got %v", err)
	}

	if len(resp.Embeddings) != 1 {
		t.Fatalf("expected 1 embedding, got %d", len(resp.Embeddings))
	}

	got := resp.Embeddings[0].Embedding
	if len(got) != len(want) {
		t.Fatalf("expected embedding of length %d, got %d", len(want), len(got))
	}
	for i := range want {
		if got[i] != want[i] {
			t.Fatalf("embedding[%d]: expected %v, got %v", i, want[i], got[i])
		}
	}
}

// TestEmbedInvalidServerAddress checks that embed rejects an empty server
// address with an "invalid server address" error.
func TestEmbedInvalidServerAddress(t *testing.T) {
	documents := []*ai.Document{ai.DocumentFromText("test", nil)}
	request := &ai.EmbedRequest{
		Documents: documents,
		Options:   &EmbedOptions{Model: "all-minilm"},
	}

	_, err := embed(context.Background(), "", request)
	if err != nil && strings.Contains(err.Error(), "invalid server address") {
		return
	}
	t.Fatalf("expected invalid server address error, got %v", err)
}
Loading