Skip to content

Commit

Permalink
Fixing tests (#6)
Browse files Browse the repository at this point in the history
* Fixing tests

* Fix src/bpe_model_test.cc (#7)

Co-authored-by: Kuba Podgórski <kuba--@users.noreply.github.com>

---------

Co-authored-by: rbehjati <razieh.behjati@gmail.com>
  • Loading branch information
kuba-- and rbehjati authored Oct 28, 2023
1 parent d1a4bb4 commit 5cec47d
Show file tree
Hide file tree
Showing 15 changed files with 132 additions and 71 deletions.
23 changes: 14 additions & 9 deletions src/bpe_model_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ TEST(BPEModelTest, EncodeTest) {
model_proto.mutable_pieces(12)->set_type( // r
ModelProto::SentencePiece::USER_DEFINED);

const Model model(model_proto);
const Model model(std::make_unique<const ModelProto>(model_proto));

EncodeResult result;

Expand Down Expand Up @@ -149,7 +149,7 @@ TEST(BPEModelTest, EncodeAmbiguousTest) {
AddPiece(&model_proto, "a", -0.4);
AddPiece(&model_proto, "b", -0.5);

const Model model(model_proto);
const Model model(std::make_unique<const ModelProto>(model_proto));

EncodeResult result;

Expand Down Expand Up @@ -188,7 +188,8 @@ TEST(BPEModelTest, EncodeAmbiguousTest) {

TEST(BPEModelTest, NotSupportedTest) {
ModelProto model_proto = MakeBaseModelProto();
const Model model(model_proto);

const Model model(std::make_unique<const ModelProto>(model_proto));
EXPECT_EQ(NBestEncodeResult(), model.NBestEncode("test", 10));
}

Expand All @@ -206,15 +207,16 @@ TEST(BPEModelTest, EncodeWithUnusedTest) {

// No unused.
{
const Model model(model_proto);
const Model model(std::make_unique<const ModelProto>(model_proto));
const auto result = model.Encode("abcd");
EXPECT_EQ(1, result.size());
EXPECT_EQ("abcd", result[0].first);
}

{
model_proto.mutable_pieces(3)->set_type(ModelProto::SentencePiece::UNUSED);
const Model model(model_proto);

const Model model(std::make_unique<const ModelProto>(model_proto));
const auto result = model.Encode("abcd");
EXPECT_EQ(2, result.size());
EXPECT_EQ("abc", result[0].first);
Expand All @@ -225,7 +227,8 @@ TEST(BPEModelTest, EncodeWithUnusedTest) {
// The parent rule "abc" is still alive even if the child "ab" is unused.
model_proto.mutable_pieces(3)->set_type(ModelProto::SentencePiece::UNUSED);
model_proto.mutable_pieces(5)->set_type(ModelProto::SentencePiece::UNUSED);
const Model model(model_proto);

const Model model(std::make_unique<const ModelProto>(model_proto));
const auto result = model.Encode("abcd");
EXPECT_EQ(2, result.size());
EXPECT_EQ("abc", result[0].first);
Expand All @@ -240,7 +243,8 @@ TEST(BPEModelTest, EncodeWithUnusedTest) {
model_proto.mutable_pieces(3)->set_type(ModelProto::SentencePiece::UNUSED);
model_proto.mutable_pieces(4)->set_type(ModelProto::SentencePiece::UNUSED);
model_proto.mutable_pieces(5)->set_type(ModelProto::SentencePiece::NORMAL);
const Model model(model_proto);

const Model model(std::make_unique<const ModelProto>(model_proto));
const auto result = model.Encode("abcd");
EXPECT_EQ(3, result.size());
EXPECT_EQ("ab", result[0].first);
Expand All @@ -259,7 +263,7 @@ TEST(SampleModelTest, EncodeTest) {

// No regularization
{
const Model model(model_proto);
const Model model(std::make_unique<const ModelProto>(model_proto));
const auto result = model.Encode("abcd");
EXPECT_EQ(1, result.size());
EXPECT_EQ("abcd", result[0].first);
Expand All @@ -275,7 +279,8 @@ TEST(SampleModelTest, EncodeTest) {
return out;
};

const Model model(model_proto);
const Model model(std::make_unique<const ModelProto>(model_proto));

const std::vector<double> kAlpha = {0.0, 0.1, 0.5, 0.7, 0.9};
for (const auto alpha : kAlpha) {
constexpr int kTrial = 100000;
Expand Down
6 changes: 3 additions & 3 deletions src/bpe_model_trainer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,12 @@ std::string RunTrainer(
SentencePieceProcessor processor;
EXPECT_TRUE(processor.Load(model_prefix + ".model").ok());

const auto &model = processor.model_proto();
const auto model = processor.model_proto();
std::vector<std::string> pieces;

// remove <unk>, <s>, </s>
for (int i = 3; i < model.pieces_size(); ++i) {
pieces.emplace_back(model.pieces(i).piece());
for (int i = 3; i < model->pieces_size(); ++i) {
pieces.emplace_back(model->pieces(i).piece());
}

return absl::StrJoin(pieces, " ");
Expand Down
6 changes: 3 additions & 3 deletions src/char_model_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ void AddPiece(ModelProto *model_proto, const std::string &piece,
}

TEST(ModelTest, EncodeTest) {
ModelProto model_proto = MakeBaseModelProto();

auto model_proto = MakeBaseModelProto();
AddPiece(&model_proto, WS, 0.0);
AddPiece(&model_proto, "a", 0.1);
AddPiece(&model_proto, "b", 0.2);
Expand All @@ -60,7 +60,7 @@ TEST(ModelTest, EncodeTest) {
model_proto.mutable_pieces(8)->set_type(
ModelProto::SentencePiece::USER_DEFINED);

const Model model(model_proto);
const Model model(std::make_unique<const ModelProto>(model_proto));

EncodeResult result;

Expand Down Expand Up @@ -108,7 +108,7 @@ TEST(ModelTest, EncodeTest) {

TEST(CharModelTest, NotSupportedTest) {
ModelProto model_proto = MakeBaseModelProto();
const Model model(model_proto);
const Model model(std::make_unique<const ModelProto>(model_proto));
EXPECT_EQ(NBestEncodeResult(), model.NBestEncode("test", 10));
EXPECT_EQ(EncodeResult(), model.SampleEncode("test", 0.1));
}
Expand Down
6 changes: 3 additions & 3 deletions src/char_model_trainer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,12 @@ std::string RunTrainer(const std::vector<std::string> &input, int size) {
SentencePieceProcessor processor;
EXPECT_TRUE(processor.Load(model_prefix + ".model").ok());

const auto &model = processor.model_proto();
const auto model = processor.model_proto();
std::vector<std::string> pieces;

// remove <unk>, <s>, </s>
for (int i = 3; i < model.pieces_size(); ++i) {
pieces.emplace_back(model.pieces(i).piece());
for (int i = 3; i < model->pieces_size(); ++i) {
pieces.emplace_back(model->pieces(i).piece());
}

return absl::StrJoin(pieces, " ");
Expand Down
8 changes: 4 additions & 4 deletions src/model_factory_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,22 +37,22 @@ TEST(ModelFactoryTest, BasicTest) {

{
model_proto.mutable_trainer_spec()->set_model_type(TrainerSpec::UNIGRAM);
auto m = ModelFactory::Create(model_proto);
auto m = ModelFactory::Create(std::make_unique<const ModelProto>(model_proto));
}

{
model_proto.mutable_trainer_spec()->set_model_type(TrainerSpec::BPE);
auto m = ModelFactory::Create(model_proto);
auto m = ModelFactory::Create(std::make_unique<const ModelProto>(model_proto));
}

{
model_proto.mutable_trainer_spec()->set_model_type(TrainerSpec::WORD);
auto m = ModelFactory::Create(model_proto);
auto m = ModelFactory::Create(std::make_unique<const ModelProto>(model_proto));
}

{
model_proto.mutable_trainer_spec()->set_model_type(TrainerSpec::CHAR);
auto m = ModelFactory::Create(model_proto);
auto m = ModelFactory::Create(std::make_unique<const ModelProto>(model_proto));
}
}
} // namespace sentencepiece
51 changes: 37 additions & 14 deletions src/model_interface_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ TEST(ModelInterfaceTest, GetDefaultPieceTest) {
{
ModelProto model_proto = MakeBaseModelProto(TrainerSpec::UNIGRAM);
AddPiece(&model_proto, "a");
auto model = ModelFactory::Create(model_proto);

auto model = ModelFactory::Create(std::make_unique<const ModelProto>(model_proto));

EXPECT_EQ("<unk>", model->unk_piece());
EXPECT_EQ("<s>", model->bos_piece());
EXPECT_EQ("</s>", model->eos_piece());
Expand All @@ -86,7 +88,9 @@ TEST(ModelInterfaceTest, GetDefaultPieceTest) {
model_proto.mutable_trainer_spec()->clear_bos_piece();
model_proto.mutable_trainer_spec()->clear_eos_piece();
model_proto.mutable_trainer_spec()->clear_pad_piece();
auto model = ModelFactory::Create(model_proto);

auto model = ModelFactory::Create(std::make_unique<const ModelProto>(model_proto));

EXPECT_EQ("<unk>", model->unk_piece());
EXPECT_EQ("<s>", model->bos_piece());
EXPECT_EQ("</s>", model->eos_piece());
Expand All @@ -100,7 +104,9 @@ TEST(ModelInterfaceTest, GetDefaultPieceTest) {
model_proto.mutable_trainer_spec()->set_bos_piece("BOS");
model_proto.mutable_trainer_spec()->set_eos_piece("EOS");
model_proto.mutable_trainer_spec()->set_pad_piece("PAD");
auto model = ModelFactory::Create(model_proto);

auto model = ModelFactory::Create(std::make_unique<const ModelProto>(model_proto));

EXPECT_EQ("UNK", model->unk_piece());
EXPECT_EQ("BOS", model->bos_piece());
EXPECT_EQ("EOS", model->eos_piece());
Expand All @@ -116,7 +122,8 @@ TEST(ModelInterfaceTest, SetModelInterfaceTest) {
AddPiece(&model_proto, "c");
AddPiece(&model_proto, "d");

auto model = ModelFactory::Create(model_proto);
auto model = ModelFactory::Create(std::make_unique<const ModelProto>(model_proto));

EXPECT_EQ(model_proto.SerializeAsString(),
model->model_proto().SerializeAsString());
}
Expand All @@ -135,7 +142,7 @@ TEST(ModelInterfaceTest, PieceToIdTest) {
model_proto.mutable_pieces(7)->set_type(
ModelProto::SentencePiece::USER_DEFINED);

auto model = ModelFactory::Create(model_proto);
auto model = ModelFactory::Create(std::make_unique<const ModelProto>(model_proto));

EXPECT_EQ(model_proto.SerializeAsString(),
model->model_proto().SerializeAsString());
Expand Down Expand Up @@ -212,7 +219,9 @@ TEST(ModelInterfaceTest, InvalidModelTest) {
{
ModelProto model_proto = MakeBaseModelProto(TrainerSpec::UNIGRAM);
AddPiece(&model_proto, "");
auto model = ModelFactory::Create(model_proto);

auto model = ModelFactory::Create(std::make_unique<const ModelProto>(model_proto));

EXPECT_FALSE(model->status().ok());
}

Expand All @@ -221,23 +230,29 @@ TEST(ModelInterfaceTest, InvalidModelTest) {
ModelProto model_proto = MakeBaseModelProto(TrainerSpec::UNIGRAM);
AddPiece(&model_proto, "a");
AddPiece(&model_proto, "a");
auto model = ModelFactory::Create(model_proto);

auto model = ModelFactory::Create(std::make_unique<const ModelProto>(model_proto));

EXPECT_FALSE(model->status().ok());
}

// Multiple unknowns.
{
ModelProto model_proto = MakeBaseModelProto(TrainerSpec::UNIGRAM);
model_proto.mutable_pieces(1)->set_type(ModelProto::SentencePiece::UNKNOWN);
auto model = ModelFactory::Create(model_proto);

auto model = ModelFactory::Create(std::make_unique<const ModelProto>(model_proto));

EXPECT_FALSE(model->status().ok());
}

// No unknown.
{
ModelProto model_proto = MakeBaseModelProto(TrainerSpec::UNIGRAM);
model_proto.mutable_pieces(0)->set_type(ModelProto::SentencePiece::CONTROL);
auto model = ModelFactory::Create(model_proto);

auto model = ModelFactory::Create(std::make_unique<const ModelProto>(model_proto));

EXPECT_FALSE(model->status().ok());
}
}
Expand All @@ -249,7 +264,9 @@ TEST(ModelInterfaceTest, ByteFallbackModelTest) {
AddBytePiece(&model_proto, i);
}
AddPiece(&model_proto, "a");
auto model = ModelFactory::Create(model_proto);

auto model = ModelFactory::Create(std::make_unique<const ModelProto>(model_proto));

EXPECT_TRUE(model->status().ok());
}

Expand All @@ -260,7 +277,9 @@ TEST(ModelInterfaceTest, ByteFallbackModelTest) {
AddBytePiece(&model_proto, i);
}
AddPiece(&model_proto, "a");
auto model = ModelFactory::Create(model_proto);

auto model = ModelFactory::Create(std::make_unique<const ModelProto>(model_proto));

EXPECT_FALSE(model->status().ok());
}

Expand All @@ -271,7 +290,9 @@ TEST(ModelInterfaceTest, ByteFallbackModelTest) {
AddBytePiece(&model_proto, i);
}
AddPiece(&model_proto, "a");
auto model = ModelFactory::Create(model_proto);

auto model = ModelFactory::Create(std::make_unique<const ModelProto>(model_proto));

EXPECT_FALSE(model->status().ok());
}
}
Expand Down Expand Up @@ -307,7 +328,8 @@ TEST(ModelInterfaceTest, PieceToIdStressTest) {
AddPiece(&model_proto, piece);
}

auto model = ModelFactory::Create(model_proto);
auto model = ModelFactory::Create(std::make_unique<const ModelProto>(model_proto));

for (const auto &it : expected_p2i) {
EXPECT_EQ(it.second, model->PieceToId(it.first));
}
Expand Down Expand Up @@ -505,7 +527,8 @@ TEST(ModelInterfaceTest, VerifyOutputsEquivalent) {
ModelProto model_proto = MakeBaseModelProto(type);
AddPiece(&model_proto, "a", 1.0);
AddPiece(&model_proto, "b", 2.0);
auto model = ModelFactory::Create(model_proto);

auto model = ModelFactory::Create(std::make_unique<const ModelProto>(model_proto));

// Equivalent outputs.
EXPECT_TRUE(model->VerifyOutputsEquivalent("", ""));
Expand Down
8 changes: 4 additions & 4 deletions src/sentencepiece_processor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -969,7 +969,7 @@ TEST(SentencePieceProcessorTest, LoadSerializedProtoTest) {
EXPECT_FALSE(sp.LoadFromSerializedProto("__NOT_A_PROTO__").ok());
EXPECT_TRUE(sp.LoadFromSerializedProto(model_proto.SerializeAsString()).ok());
EXPECT_EQ(model_proto.SerializeAsString(),
sp.model_proto().SerializeAsString());
sp.model_proto()->SerializeAsString());
}

TEST(SentencePieceProcessorTest, EndToEndTest) {
Expand Down Expand Up @@ -1004,7 +1004,7 @@ TEST(SentencePieceProcessorTest, EndToEndTest) {
sp.Load(util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "model")).ok());

EXPECT_EQ(model_proto.SerializeAsString(),
sp.model_proto().SerializeAsString());
sp.model_proto()->SerializeAsString());

EXPECT_EQ(8, sp.GetPieceSize());
EXPECT_EQ(0, sp.PieceToId("<unk>"));
Expand Down Expand Up @@ -1273,7 +1273,7 @@ TEST(SentencePieceProcessorTest, EndToEndTest) {

auto RunTest = [&model_proto](const SentencePieceProcessor &sp) {
EXPECT_EQ(model_proto.SerializeAsString(),
sp.model_proto().SerializeAsString());
sp.model_proto()->SerializeAsString());

EXPECT_EQ(8, sp.GetPieceSize());
EXPECT_EQ(0, sp.PieceToId("<unk>"));
Expand Down Expand Up @@ -1351,7 +1351,7 @@ TEST(SentencePieceProcessorTest, EndToEndTest) {
const ModelProto *moved_ptr = moved.get();
*moved = model_proto;
EXPECT_TRUE(sp.Load(std::move(moved)).ok());
EXPECT_EQ(moved_ptr, &sp.model_proto());
EXPECT_EQ(moved_ptr, sp.model_proto());
RunTest(sp);
}

Expand Down
10 changes: 5 additions & 5 deletions src/sentencepiece_trainer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,17 @@ static constexpr char kIdsDenormTsv[] = "ids_denorm.tsv";
void CheckVocab(absl::string_view filename, int expected_vocab_size) {
SentencePieceProcessor sp;
ASSERT_TRUE(sp.Load(filename.data()).ok());
EXPECT_EQ(expected_vocab_size, sp.model_proto().trainer_spec().vocab_size());
EXPECT_EQ(sp.model_proto().pieces_size(),
sp.model_proto().trainer_spec().vocab_size());
EXPECT_EQ(expected_vocab_size, sp.model_proto()->trainer_spec().vocab_size());
EXPECT_EQ(sp.model_proto()->pieces_size(),
sp.model_proto()->trainer_spec().vocab_size());
}

void CheckNormalizer(absl::string_view filename, bool expected_has_normalizer,
bool expected_has_denormalizer) {
SentencePieceProcessor sp;
ASSERT_TRUE(sp.Load(filename.data()).ok());
const auto &normalizer_spec = sp.model_proto().normalizer_spec();
const auto &denormalizer_spec = sp.model_proto().denormalizer_spec();
const auto &normalizer_spec = sp.model_proto()->normalizer_spec();
const auto &denormalizer_spec = sp.model_proto()->denormalizer_spec();
EXPECT_EQ(!normalizer_spec.precompiled_charsmap().empty(),
expected_has_normalizer);
EXPECT_EQ(!denormalizer_spec.precompiled_charsmap().empty(),
Expand Down
2 changes: 1 addition & 1 deletion src/testharness.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ std::vector<T> ValuesIn(const std::vector<T> &v) {
} \
ParamType param_; \
void SetParam(const ParamType &param) { param_ = param; } \
const ParamType GetParam() { return param_; } \
ParamType GetParam() const { return param_; } \
void _Run(); \
static void _RunIt() { \
TCONCAT(base, _Test_p_, name) t; \
Expand Down
Loading

0 comments on commit 5cec47d

Please sign in to comment.