Skip to content

Commit

Permalink
Fast text fixes
Browse files Browse the repository at this point in the history
Signed-off-by: Dusan Malusev <dusan@dusanmalusev.dev>
  • Loading branch information
CodeLieutenant committed Jul 7, 2023
1 parent 9017dfb commit 1c5e5d4
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 12 deletions.
13 changes: 6 additions & 7 deletions cbits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,7 @@ FastText_Predict_t FastText_Predict(const FastText_Handle_t handle, FastText_Str
std::istream in(&sbuf);

auto predictions = new std::vector<std::pair<fasttext::real, std::string>>();
model->predictLine(in, reinterpret_cast<std::vector<std::pair<fasttext::real, std::string>> &>(predictions), k,
threshold);
model->predictLine(in, *predictions, k, threshold);

free(query.data);
query.data = nullptr;
Expand Down Expand Up @@ -112,7 +111,7 @@ FastText_FloatVector_t FastText_Wordvec(const FastText_Handle_t handle, FastText
int64_t dimensions = model->getDimension();

auto vec = new fasttext::Vector(dimensions);
model->getWordVector(reinterpret_cast<fasttext::Vector &>(vec), std::string(word.data, word.size));
model->getWordVector(*vec, std::string(word.data, word.size));

free(word.data);
word.data = nullptr;
Expand All @@ -133,7 +132,7 @@ FastText_FloatVector_t FastText_Sentencevec(const FastText_Handle_t handle, Fast
std::istream in(&sbuf);

auto vec = new fasttext::Vector(model->getDimension());
model->getSentenceVector(in, reinterpret_cast<fasttext::Vector &>(vec));
model->getSentenceVector(in, *vec);
free(sentence.data);
sentence.data = nullptr;
sentence.size = 0;
Expand Down Expand Up @@ -163,12 +162,12 @@ FastText_PredictItem_t FastText_PredictItemAt(FastText_Predict_t predict, size_t
const auto &data = vec->at(idx);

auto str = FastText_String_t{
data.second.size(),
(char *)data.second.c_str(),
data.second.size() - sizeof("__label__") + 1,
(char *)(data.second.c_str() + sizeof("__label__") - 1),
};

return FastText_PredictItem_t{
std::exp(data.first),
data.first,
str,
};
}
5 changes: 1 addition & 4 deletions fasttext.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,8 @@ func (handle Model) MultiLinePredict(query string, k int32, threshoad float32) [
}

func (handle Model) PredictOne(query string, threshoad float32) Prediction {
cquery := C.CString(query)
defer C.free(unsafe.Pointer(cquery))

r := C.FastText_PredictOne(handle.p, C.FastText_String_t{
data: cquery,
data: C.CString(query),
size: C.size_t(len(query)),
}, C.float(threshoad))
defer C.FastText_FreePredict(r)
Expand Down
16 changes: 15 additions & 1 deletion fasttext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,24 @@ import (

func TestOpen(t *testing.T) {
t.Parallel()

assert := require.New(t)

_, err := fasttext.Open("testdata/lid.176.ftz")

assert.NoError(err)
}


func TestPredictOne(t *testing.T) {

assert := require.New(t)

model, err := fasttext.Open("testdata/lid.176.ftz")

assert.NoError(err)

prediction := model.PredictOne("hello world from my dear C++", 0.7)

assert.Equal("en", prediction.Label)
assert.Greater(prediction.Probability, float32(0.7))
}

0 comments on commit 1c5e5d4

Please sign in to comment.