-
Notifications
You must be signed in to change notification settings - Fork 133
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(go): Add ollama embeddings support (#841)
- Loading branch information
1 parent
50cdf5c
commit b86533a
Showing
2 changed files
with
221 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
// Copyright 2024 Google LLC | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
package ollama | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"encoding/json" | ||
"fmt" | ||
"net/http" | ||
"strings" | ||
|
||
"github.com/firebase/genkit/go/ai" | ||
) | ||
|
||
// EmbedOptions carries the per-request settings for an Ollama embedding call.
type EmbedOptions struct {
	// Model is the name of the embedding model; it is serialized as "model".
	Model string `json:"model"`
}
|
||
// ollamaEmbedRequest is the JSON body posted to the embed endpoint.
type ollamaEmbedRequest struct {
	// Model is the name of the embedding model to use.
	Model string `json:"model"`
	// Input holds either a single string or a []string.
	// todo: using interface{} to handle both string and []string, figure out better solution
	Input interface{} `json:"input"`
	// Options is left out of the JSON when empty.
	Options map[string]interface{} `json:"options,omitempty"`
}
|
||
// ollamaEmbedResponse is the JSON body decoded from the embed endpoint's reply.
type ollamaEmbedResponse struct {
	// Embeddings is a list of float32 vectors, read from the "embeddings" key.
	Embeddings [][]float32 `json:"embeddings"`
}
|
||
func embed(ctx context.Context, serverAddress string, req *ai.EmbedRequest) (*ai.EmbedResponse, error) { | ||
options, ok := req.Options.(*EmbedOptions) | ||
if !ok && req.Options != nil { | ||
return nil, fmt.Errorf("invalid options type: expected *EmbedOptions") | ||
} | ||
if options == nil || options.Model == "" { | ||
return nil, fmt.Errorf("invalid embedding model: model must be specified") | ||
} | ||
|
||
if serverAddress == "" { | ||
return nil, fmt.Errorf("invalid server address: address cannot be empty") | ||
} | ||
|
||
ollamaReq := newOllamaEmbedRequest(options.Model, req.Documents) | ||
|
||
jsonData, err := json.Marshal(ollamaReq) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to marshal embed request: %w", err) | ||
} | ||
|
||
resp, err := sendEmbedRequest(ctx, serverAddress, jsonData) | ||
if err != nil { | ||
return nil, err | ||
} | ||
defer resp.Body.Close() | ||
|
||
if resp.StatusCode != http.StatusOK { | ||
return nil, fmt.Errorf("ollama embed request failed with status code %d", resp.StatusCode) | ||
} | ||
|
||
var ollamaResp ollamaEmbedResponse | ||
if err := json.NewDecoder(resp.Body).Decode(&ollamaResp); err != nil { | ||
return nil, fmt.Errorf("failed to decode embed response: %w", err) | ||
} | ||
|
||
return newEmbedResponse(ollamaResp.Embeddings), nil | ||
} | ||
|
||
// sendEmbedRequest POSTs jsonData as application/json to serverAddress+"/api/embed"
// using a fresh http.Client, with ctx attached to the request.
//
// The response is returned as-is; the caller must close its body. An error is
// returned if the request cannot be built or the round trip fails.
func sendEmbedRequest(ctx context.Context, serverAddress string, jsonData []byte) (*http.Response, error) {
	endpoint := serverAddress + "/api/embed"

	request, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	request.Header.Set("Content-Type", "application/json")

	return new(http.Client).Do(request)
}
|
||
func newOllamaEmbedRequest(model string, documents []*ai.Document) ollamaEmbedRequest { | ||
var input interface{} | ||
if len(documents) == 1 { | ||
input = concatenateText(documents[0]) | ||
} else { | ||
texts := make([]string, len(documents)) | ||
for i, doc := range documents { | ||
texts[i] = concatenateText(doc) | ||
} | ||
input = texts | ||
} | ||
|
||
return ollamaEmbedRequest{ | ||
Model: model, | ||
Input: input, | ||
} | ||
} | ||
|
||
func newEmbedResponse(embeddings [][]float32) *ai.EmbedResponse { | ||
resp := &ai.EmbedResponse{ | ||
Embeddings: make([]*ai.DocumentEmbedding, len(embeddings)), | ||
} | ||
for i, embedding := range embeddings { | ||
resp.Embeddings[i] = &ai.DocumentEmbedding{Embedding: embedding} | ||
} | ||
return resp | ||
} | ||
|
||
func concatenateText(doc *ai.Document) string { | ||
var builder strings.Builder | ||
for _, part := range doc.Content { | ||
builder.WriteString(part.Text) | ||
} | ||
result := builder.String() | ||
return result | ||
} | ||
|
||
// DefineEmbedder defines an embedder with a given server address. | ||
func DefineEmbedder(serverAddress string, model string) ai.Embedder { | ||
state.mu.Lock() | ||
defer state.mu.Unlock() | ||
if !state.initted { | ||
panic("ollama.Init not called") | ||
} | ||
return ai.DefineEmbedder(provider, serverAddress, func(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) { | ||
if req.Options == nil { | ||
req.Options = &EmbedOptions{Model: model} | ||
} | ||
if req.Options.(*EmbedOptions).Model == "" { | ||
req.Options.(*EmbedOptions).Model = model | ||
} | ||
return embed(ctx, serverAddress, req) | ||
}) | ||
} | ||
|
||
// IsDefinedEmbedder reports whether the embedder with the given server address is defined by this plugin. | ||
func IsDefinedEmbedder(serverAddress string) bool { | ||
isDefined := ai.IsDefinedEmbedder(provider, serverAddress) | ||
return isDefined | ||
} | ||
|
||
// Embedder returns the [ai.Embedder] with the given server address. | ||
// It returns nil if the embedder was not defined. | ||
func Embedder(serverAddress string) ai.Embedder { | ||
return ai.LookupEmbedder(provider, serverAddress) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
// Copyright 2024 Google LLC | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
package ollama | ||
|
||
import ( | ||
"context" | ||
"encoding/json" | ||
"net/http" | ||
"net/http/httptest" | ||
"strings" | ||
"testing" | ||
|
||
"github.com/firebase/genkit/go/ai" | ||
) | ||
|
||
func TestEmbedValidRequest(t *testing.T) { | ||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
w.WriteHeader(http.StatusOK) | ||
json.NewEncoder(w).Encode(ollamaEmbedResponse{ | ||
Embeddings: [][]float32{{0.1, 0.2, 0.3}}, | ||
}) | ||
})) | ||
defer server.Close() | ||
|
||
req := &ai.EmbedRequest{ | ||
Documents: []*ai.Document{ | ||
ai.DocumentFromText("test", nil), | ||
}, | ||
Options: &EmbedOptions{Model: "all-minilm"}, | ||
} | ||
|
||
resp, err := embed(context.Background(), server.URL, req) | ||
if err != nil { | ||
t.Fatalf("expected no error, got %v", err) | ||
} | ||
|
||
if len(resp.Embeddings) != 1 { | ||
t.Fatalf("expected 1 embedding, got %d", len(resp.Embeddings)) | ||
} | ||
} | ||
|
||
func TestEmbedInvalidServerAddress(t *testing.T) { | ||
req := &ai.EmbedRequest{ | ||
Documents: []*ai.Document{ | ||
ai.DocumentFromText("test", nil), | ||
}, | ||
Options: &EmbedOptions{Model: "all-minilm"}, | ||
} | ||
|
||
_, err := embed(context.Background(), "", req) | ||
if err == nil || !strings.Contains(err.Error(), "invalid server address") { | ||
t.Fatalf("expected invalid server address error, got %v", err) | ||
} | ||
} |