Made Embeddings work #1

Merged
merged 8 commits on Mar 24, 2024
5 changes: 4 additions & 1 deletion .gitignore
@@ -16,6 +16,9 @@
# Ignore Byebug command history file.
.byebug_history

# Ignore the RustRover project file
.idea

## Specific to RubyMotion:
.dat*
.repl_history
@@ -64,4 +67,4 @@ target
*.o
*.lock

lib.py.rs
lib.py.rs
45 changes: 45 additions & 0 deletions README.md
@@ -18,6 +18,50 @@ x = x.reshape([3, 2])
# Tensor[[3, 2], f32]
```

```ruby
require 'candle'
model = Candle::Model.new
embedding = model.embedding("Hi there!")
```

## A note on memory usage
The `Candle::Model` defaults to the `jinaai/jina-embeddings-v2-base-en` model with the `sentence-transformers/all-MiniLM-L6-v2` tokenizer (both from [HuggingFace](https://huggingface.co)). With this configuration the model takes a little more than 3GB of memory running on my Mac. The memory is held by the instantiated `Candle::Model` object, so each additional instance costs the same again. Likewise, if you let an instance go out of scope and run the garbage collector, its memory is freed. For example:

```ruby
> require 'candle'
# Ruby memory = 25.9 MB
> model = Candle::Model.new
# Ruby memory = 3.50 GB
> model2 = Candle::Model.new
# Ruby memory = 7.04 GB
> model2 = nil
> GC.start
# Ruby memory = 3.56 GB
> model = nil
> GC.start
# Ruby memory = 55.2 MB
```
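
If you only need embeddings transiently, one pattern (just a sketch, with a hypothetical `embed_all` helper) is to scope the model to a method so it becomes collectible as soon as you're done:

```ruby
def embed_all(texts)
  model = Candle::Model.new            # ~3.5 GB while this instance is alive
  texts.map { |t| model.embedding(t) } # returned tensors don't retain the model
end

embeddings = embed_all(["Hi there!", "Bye!"])
GC.start                               # model is unreachable; its memory is returned
```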

## A note on returned embeddings

The embeddings should match those generated by the Python `transformers` library. For instance, I was able to generate the same embedding locally for the text "Hi there!" using this Python code:

```python
from transformers import AutoModel
model = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True)
sentence = ['Hi there!']
embedding = model.encode(sentence)
print(embedding)
```

And the following Ruby:

```ruby
require 'candle'
model = Candle::Model.new
embedding = model.embedding("Hi there!")
```
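
To eyeball the parity, you can dump the leading values on the Ruby side and compare them against the printed numpy array. A minimal sketch, assuming the embedding comes back as a rank-2 `[1, hidden_size]` tensor:

```ruby
embedding = model.embedding("Hi there!")
row = embedding.first               # Tensor is Enumerable; the first row is rank 1
puts row.values.first(5).inspect    # compare against print(embedding) above
```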

## Development

FORK IT!
@@ -29,6 +73,7 @@ bundle
bundle exec rake compile
```


Implemented with [Magnus](https://github.com/matsadler/magnus), with reference to [Polars Ruby](https://github.com/ankane/polars-ruby)

Policies
11 changes: 6 additions & 5 deletions ext/candle/src/lib.rs
@@ -1,6 +1,6 @@
use magnus::{function, method, prelude::*, Ruby};

use crate::model::{candle_utils, ModelConfig, RbDType, RbDevice, RbQTensor, RbResult, RbTensor};
use crate::model::{candle_utils, RbModel, RbDType, RbDevice, RbQTensor, RbResult, RbTensor};

pub mod model;

@@ -22,6 +22,7 @@ fn init(ruby: &Ruby) -> RbResult<()> {
rb_tensor.define_method("dtype", method!(RbTensor::dtype, 0))?;
rb_tensor.define_method("device", method!(RbTensor::device, 0))?;
rb_tensor.define_method("rank", method!(RbTensor::rank, 0))?;
rb_tensor.define_method("elem_count", method!(RbTensor::elem_count, 0))?;
rb_tensor.define_method("sin", method!(RbTensor::sin, 0))?;
rb_tensor.define_method("cos", method!(RbTensor::cos, 0))?;
rb_tensor.define_method("log", method!(RbTensor::log, 0))?;
@@ -93,10 +94,10 @@ fn init(ruby: &Ruby) -> RbResult<()> {
rb_qtensor.define_method("dequantize", method!(RbQTensor::dequantize, 0))?;

let rb_model = rb_candle.define_class("Model", Ruby::class_object(ruby))?;
rb_model.define_singleton_method("new", function!(ModelConfig::new, 0))?;
rb_model.define_method("embedding", method!(ModelConfig::embedding, 1))?;
rb_model.define_method("to_s", method!(ModelConfig::__str__, 0))?;
rb_model.define_method("inspect", method!(ModelConfig::__repr__, 0))?;
rb_model.define_singleton_method("new", function!(RbModel::new, 0))?;
rb_model.define_method("embedding", method!(RbModel::embedding, 1))?;
rb_model.define_method("to_s", method!(RbModel::__str__, 0))?;
rb_model.define_method("inspect", method!(RbModel::__repr__, 0))?;

Ok(())
}
4 changes: 2 additions & 2 deletions ext/candle/src/model/mod.rs
@@ -1,5 +1,5 @@
mod config;
pub use config::*;
mod rb_model;
pub use rb_model::*;

mod errors;
pub use errors::*;
50 changes: 32 additions & 18 deletions ext/candle/src/model/config.rs → ext/candle/src/model/rb_model.rs
@@ -16,24 +16,24 @@ use crate::model::RbResult;
use tokenizers::Tokenizer;

#[magnus::wrap(class = "Candle::Model", free_immediately, size)]
pub struct ModelConfig(pub ModelConfigInner);
pub struct RbModel(pub RbModelInner);

pub struct ModelConfigInner {
pub struct RbModelInner {
device: Device,
tokenizer_path: Option<String>,
model_path: Option<String>,
model: Option<BertModel>,
tokenizer: Option<Tokenizer>,
}

impl ModelConfig {
impl RbModel {
pub fn new() -> RbResult<Self> {
Self::new2(Some("jinaai/jina-embeddings-v2-base-en".to_string()), Some("sentence-transformers/all-MiniLM-L6-v2".to_string()), Some(Device::Cpu))
}

pub fn new2(model_path: Option<String>, tokenizer_path: Option<String>, device: Option<Device>) -> RbResult<Self> {
let device = device.unwrap_or(Device::Cpu);
Ok(ModelConfig(ModelConfigInner {
Ok(RbModel(RbModelInner {
device: device.clone(),
model_path: model_path.clone(),
tokenizer_path: tokenizer_path.clone(),
@@ -92,10 +92,14 @@ impl ModelConfig {
))
.get("tokenizer.json")
.map_err(wrap_hf_err)?;
let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
// .with_padding(None)
// .with_truncation(None)
let mut tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
.map_err(wrap_std_err)?;
let pp = tokenizers::PaddingParams {
strategy: tokenizers::PaddingStrategy::BatchLongest,
..Default::default()
};
tokenizer.with_padding(Some(pp));
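// BatchLongest pads every sequence in a batch to the longest member, keeping
// batched encodes rectangular for the pooling step downstream.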

Ok(tokenizer)
}

@@ -105,9 +109,6 @@ impl ModelConfig {
model: &BertModel,
tokenizer: &Tokenizer,
) -> Result<Tensor, Error> {
let start: std::time::Instant = std::time::Instant::now();
// let tokenizer_impl = tokenizer
// .map_err(wrap_std_err)?;
let tokens = tokenizer
.encode(prompt, true)
.map_err(wrap_std_err)?
@@ -117,16 +118,29 @@ impl ModelConfig {
.map_err(wrap_candle_err)?
.unsqueeze(0)
.map_err(wrap_candle_err)?;
println!("Loaded and encoded {:?}", start.elapsed());
let start: std::time::Instant = std::time::Instant::now();

// let start: std::time::Instant = std::time::Instant::now();
let result = model.forward(&token_ids).map_err(wrap_candle_err)?;
// println!("{result}");
println!("Took {:?}", start.elapsed());
Ok(result)

// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) = result.dims3()
.map_err(wrap_candle_err)?;
let sum = result.sum(1)
.map_err(wrap_candle_err)?;
let embeddings = (sum / (n_tokens as f64))
.map_err(wrap_candle_err)?;
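// `result` has shape [n_sentence, n_tokens, hidden]; summing over dim 1 and
// dividing by n_tokens leaves one [n_sentence, hidden] embedding per sentence.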
// let embeddings = Self::normalize_l2(&embeddings).map_err(wrap_candle_err)?;

Ok(embeddings)
}

#[allow(dead_code)]
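// Unused for now: divides each row of `v` by its L2 norm, for callers
// that want unit-length embeddings.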
fn normalize_l2(v: &Tensor) -> Result<Tensor, candle_core::Error> {
v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
}

pub fn __repr__(&self) -> String {
format!("Candle::Model(path={})", self.0.model_path.as_deref().unwrap_or("None"))
format!("#<Candle::Model model_path: {} tokenizer_path: {})", self.0.model_path.as_deref().unwrap_or("nil"), self.0.tokenizer_path.as_deref().unwrap_or("nil"))
}

pub fn __str__(&self) -> String {
@@ -143,14 +157,14 @@ impl ModelConfig {

// #[test]
// fn test_build_model_and_tokenizer() {
// let config = super::ModelConfig::build();
// let config = super::RbModel::build();
// let (_model, tokenizer) = config.build_model_and_tokenizer().unwrap();
// assert_eq!(tokenizer.get_vocab_size(true), 30522);
// }

// #[test]
// fn test_embedding() {
// let config = super::ModelConfig::build();
// let config = super::RbModel::build();
// // let (_model, tokenizer) = config.build_model_and_tokenizer().unwrap();
// // assert_eq!(config.embedding("Scientist.com is a marketplace for pharmaceutical services.")?, None);
// }
7 changes: 6 additions & 1 deletion ext/candle/src/model/rb_tensor.rs
@@ -40,7 +40,6 @@ impl RbTensor {
))
}

// FXIME: Do not use `to_f64` here.
pub fn values(&self) -> RbResult<Vec<f64>> {
let values = self
.0
@@ -83,6 +82,12 @@ impl RbTensor {
self.0.rank()
}

/// The number of elements stored in this tensor.
/// &RETURNS&: int
pub fn elem_count(&self) -> usize {
self.0.elem_count()
}

pub fn __repr__(&self) -> String {
format!("{}", self.0)
}
2 changes: 2 additions & 0 deletions ext/candle/src/model/utils.rs
@@ -71,6 +71,7 @@ pub fn candle_utils(rb_candle: magnus::RModule) -> Result<(), Error> {

/// Applies the Softmax function to a given tensor.
/// &RETURNS&: Tensor
#[allow(dead_code)]
fn softmax(tensor: RbTensor, dim: i64) -> RbResult<RbTensor> {
let dim = actual_dim(&tensor, dim).map_err(wrap_candle_err)?;
let sm = candle_nn::ops::softmax(&tensor.0, dim).map_err(wrap_candle_err)?;
@@ -79,6 +80,7 @@ fn softmax(tensor: RbTensor, dim: i64) -> RbResult<RbTensor> {

/// Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
/// &RETURNS&: Tensor
#[allow(dead_code)]
fn silu(tensor: RbTensor) -> RbResult<RbTensor> {
let s = candle_nn::ops::silu(&tensor.0).map_err(wrap_candle_err)?;
Ok(RbTensor(s))
1 change: 1 addition & 0 deletions lib/candle.rb
@@ -1 +1,2 @@
require_relative 'candle/candle'
require_relative 'candle/tensor'
17 changes: 17 additions & 0 deletions lib/candle/tensor.rb
@@ -0,0 +1,17 @@
module Candle
class Tensor
include Enumerable

def each
if self.rank == 1
self.values.each do |value|
yield value
end
else
shape.first.times do |i|
yield self[i]
end
end
end
end
end
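
With `Enumerable` mixed in, tensors compose with ordinary Ruby iteration. A usage sketch, assuming a rank-2 embedding tensor as returned by `Candle::Model#embedding`:

```ruby
require 'candle'

model = Candle::Model.new
tensor = model.embedding("Hi there!")      # rank 2, shape [1, hidden_size]
vector = tensor.first.values               # Enumerable#first, then raw Floats
norm = Math.sqrt(vector.sum { |v| v * v }) # embeddings are not L2-normalized
puts "dims=#{vector.length} l2=#{norm.round(4)}"
```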