Skip to content

Commit

Permalink
feat(go): Add ollama embeddings support (#841)
Browse files Browse the repository at this point in the history
  • Loading branch information
huangjeff5 authored Sep 16, 2024
1 parent 50cdf5c commit b86533a
Show file tree
Hide file tree
Showing 2 changed files with 221 additions and 0 deletions.
155 changes: 155 additions & 0 deletions go/plugins/ollama/embed.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ollama

import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"strings"

"github.com/firebase/genkit/go/ai"
)

// EmbedOptions carries the per-request options understood by the ollama embedder.
type EmbedOptions struct {
// Model is the name of the embedding model to use. embed rejects a request whose Model is empty.
Model string `json:"model"`
}

// ollamaEmbedRequest is the JSON payload that embed marshals and posts to the
// server's /api/embed endpoint.
type ollamaEmbedRequest struct {
// Model is the name of the embedding model, copied from EmbedOptions.Model.
Model string `json:"model"`
// Input holds the text to embed: newOllamaEmbedRequest stores a plain string
// for exactly one document and a []string otherwise.
Input interface{} `json:"input"` // TODO: interface{} is used so the field can hold either a string or a []string; find a better-typed solution.
// Options is left out of the JSON when empty. Nothing in this file sets it;
// presumably it is meant for model parameters — confirm before relying on it.
Options map[string]interface{} `json:"options,omitempty"`
}

// ollamaEmbedResponse is the portion of the /api/embed JSON response that
// embed decodes.
type ollamaEmbedResponse struct {
// Embeddings is the list of embedding vectors read from the "embeddings" key.
Embeddings [][]float32 `json:"embeddings"`
}

// embed asks the Ollama server at serverAddress to embed req.Documents and
// converts the reply into an *ai.EmbedResponse.
//
// req.Options must be nil or an *EmbedOptions; in either case a non-empty
// model name is required, so a nil Options is rejected as a missing model.
// An error is returned when the options are invalid, serverAddress is empty,
// the request cannot be marshaled or sent, the server replies with a status
// other than 200 OK, the reply cannot be decoded, or the number of returned
// embeddings differs from the number of documents.
func embed(ctx context.Context, serverAddress string, req *ai.EmbedRequest) (*ai.EmbedResponse, error) {
	options, ok := req.Options.(*EmbedOptions)
	if !ok && req.Options != nil {
		return nil, fmt.Errorf("invalid options type: expected *EmbedOptions")
	}
	if options == nil || options.Model == "" {
		return nil, fmt.Errorf("invalid embedding model: model must be specified")
	}

	if serverAddress == "" {
		return nil, fmt.Errorf("invalid server address: address cannot be empty")
	}

	ollamaReq := newOllamaEmbedRequest(options.Model, req.Documents)

	jsonData, err := json.Marshal(ollamaReq)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal embed request: %w", err)
	}

	resp, err := sendEmbedRequest(ctx, serverAddress, jsonData)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("ollama embed request failed with status code %d", resp.StatusCode)
	}

	var ollamaResp ollamaEmbedResponse
	if err := json.NewDecoder(resp.Body).Decode(&ollamaResp); err != nil {
		return nil, fmt.Errorf("failed to decode embed response: %w", err)
	}

	// Embeddings are matched to documents by position, so a count mismatch
	// would silently pair vectors with the wrong documents. Fail loudly instead.
	if got, want := len(ollamaResp.Embeddings), len(req.Documents); got != want {
		return nil, fmt.Errorf("ollama embed response contains %d embeddings for %d documents", got, want)
	}

	return newEmbedResponse(ollamaResp.Embeddings), nil
}

// sendEmbedRequest POSTs jsonData as an application/json body to the
// /api/embed endpoint of the server at serverAddress and returns the raw HTTP
// response. The caller is responsible for closing the response body.
//
// Trailing slashes on serverAddress are ignored, so "http://host" and
// "http://host/" both address "http://host/api/embed". The request is bound
// to ctx, so cancelling ctx aborts it.
func sendEmbedRequest(ctx context.Context, serverAddress string, jsonData []byte) (*http.Response, error) {
	// Without trimming, an address ending in "/" would produce "//api/embed".
	endpoint := strings.TrimRight(serverAddress, "/") + "/api/embed"

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")

	client := &http.Client{}
	return client.Do(httpReq)
}

// newOllamaEmbedRequest builds the request payload for model from documents.
// Exactly one document is sent as a single string; any other number of
// documents (including none) is sent as a list of strings, one per document.
// The Options field of the result is left unset.
func newOllamaEmbedRequest(model string, documents []*ai.Document) ollamaEmbedRequest {
	request := ollamaEmbedRequest{Model: model}

	if len(documents) == 1 {
		request.Input = concatenateText(documents[0])
		return request
	}

	inputs := make([]string, 0, len(documents))
	for _, document := range documents {
		inputs = append(inputs, concatenateText(document))
	}
	request.Input = inputs
	return request
}

// newEmbedResponse wraps each raw embedding vector in an ai.DocumentEmbedding,
// preserving order, and returns them as an *ai.EmbedResponse.
func newEmbedResponse(embeddings [][]float32) *ai.EmbedResponse {
	wrapped := make([]*ai.DocumentEmbedding, 0, len(embeddings))
	for _, vector := range embeddings {
		wrapped = append(wrapped, &ai.DocumentEmbedding{Embedding: vector})
	}
	return &ai.EmbedResponse{Embeddings: wrapped}
}

// concatenateText joins the Text of every part in doc.Content, in order and
// with no separator, into a single string.
func concatenateText(doc *ai.Document) string {
	var sb strings.Builder
	for i := range doc.Content {
		sb.WriteString(doc.Content[i].Text)
	}
	return sb.String()
}

// DefineEmbedder defines an embedder that is registered under serverAddress
// and sends its requests to the Ollama server at that address.
//
// model is the default embedding model: it is used whenever a request carries
// no options or options with an empty Model. A request whose Options is
// neither nil nor an *EmbedOptions is rejected with an error. The caller's
// request and options are never modified.
//
// DefineEmbedder panics if ollama.Init has not been called.
func DefineEmbedder(serverAddress string, model string) ai.Embedder {
	state.mu.Lock()
	defer state.mu.Unlock()
	if !state.initted {
		panic("ollama.Init not called")
	}
	return ai.DefineEmbedder(provider, serverAddress, func(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) {
		// Use a checked assertion: an unchecked one panics on a foreign
		// options type, and a typed-nil *EmbedOptions would panic on access.
		opts, ok := req.Options.(*EmbedOptions)
		if !ok && req.Options != nil {
			return nil, fmt.Errorf("invalid options type: expected *EmbedOptions")
		}
		if opts == nil || opts.Model == "" {
			// Apply the default model on copies so the caller's request and
			// options struct are left untouched.
			var resolved EmbedOptions
			if opts != nil {
				resolved = *opts
			}
			resolved.Model = model

			reqCopy := *req
			reqCopy.Options = &resolved
			req = &reqCopy
		}
		return embed(ctx, serverAddress, req)
	})
}

// IsDefinedEmbedder reports whether an embedder registered under the given
// server address has been defined by this plugin.
func IsDefinedEmbedder(serverAddress string) bool {
	return ai.IsDefinedEmbedder(provider, serverAddress)
}

// Embedder looks up the [ai.Embedder] that this plugin registered under the
// given server address.
// The result is nil when no such embedder was defined.
func Embedder(serverAddress string) ai.Embedder {
	embedder := ai.LookupEmbedder(provider, serverAddress)
	return embedder
}
66 changes: 66 additions & 0 deletions go/plugins/ollama/embed_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ollama

import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/firebase/genkit/go/ai"
)

func TestEmbedValidRequest(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(ollamaEmbedResponse{
Embeddings: [][]float32{{0.1, 0.2, 0.3}},
})
}))
defer server.Close()

req := &ai.EmbedRequest{
Documents: []*ai.Document{
ai.DocumentFromText("test", nil),
},
Options: &EmbedOptions{Model: "all-minilm"},
}

resp, err := embed(context.Background(), server.URL, req)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}

if len(resp.Embeddings) != 1 {
t.Fatalf("expected 1 embedding, got %d", len(resp.Embeddings))
}
}

// TestEmbedInvalidServerAddress checks that embed rejects an empty server
// address with an "invalid server address" error.
func TestEmbedInvalidServerAddress(t *testing.T) {
	const wantSubstring = "invalid server address"

	documents := []*ai.Document{ai.DocumentFromText("test", nil)}
	request := &ai.EmbedRequest{
		Documents: documents,
		Options:   &EmbedOptions{Model: "all-minilm"},
	}

	_, err := embed(context.Background(), "", request)
	if err != nil && strings.Contains(err.Error(), wantSubstring) {
		return
	}
	t.Fatalf("expected invalid server address error, got %v", err)
}

0 comments on commit b86533a

Please sign in to comment.