Skip to content

Commit

Permalink
Fixed issue #4272 and added tests for partition (#4280)
Browse files Browse the repository at this point in the history
  • Loading branch information
kruda authored Jun 18, 2021
1 parent f126db6 commit c7134fa
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 1 deletion.
4 changes: 3 additions & 1 deletion include/LightGBM/utils/array_args.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ class ArrayArgs {
int j = end - 1;
int p = i;
int q = j;
if (start >= end) {
if (start >= end - 1) {
*l = start - 1;
*r = end;
return;
}
std::vector<VAL_T>& ref = *arr;
Expand Down
52 changes: 52 additions & 0 deletions tests/cpp_tests/test_array_args.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*!
* Copyright (c) 2021 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*/

#include <gtest/gtest.h>
#include <LightGBM/meta.h>
#include <LightGBM/utils/array_args.h>

#include <random>

using LightGBM::data_size_t;
using LightGBM::score_t;
using LightGBM::ArrayArgs;


TEST(Partition, JustWorks) {
std::vector<score_t> gradients({0.5f, 5.0f, 1.0f, 2.0f, 2.0f});
data_size_t middle_begin, middle_end;

ArrayArgs<score_t>::Partition(&gradients, 0, gradients.size(), &middle_begin, &middle_end);

EXPECT_EQ(gradients[middle_begin + 1], gradients[middle_end - 1]);
EXPECT_GT(gradients[0], gradients[middle_begin + 1]);
EXPECT_GT(gradients[middle_begin + 1], gradients.back());
}

TEST(Partition, PartitionOneElement) {
std::vector<score_t> gradients({0.5f});
data_size_t middle_begin, middle_end;
ArrayArgs<score_t>::Partition(&gradients, 0, gradients.size(), &middle_begin, &middle_end);
EXPECT_EQ(gradients[middle_begin + 1], gradients[middle_end - 1]);
}

TEST(Partition, Empty) {
std::vector<score_t> gradients;
data_size_t middle_begin, middle_end;
ArrayArgs<score_t>::Partition(&gradients, 0, gradients.size(), &middle_begin, &middle_end);

EXPECT_EQ(middle_begin, -1);
EXPECT_EQ(middle_end, 0);
}

TEST(Partition, AllEqual) {
std::vector<score_t> gradients({0.5f, 0.5f, 0.5f});
data_size_t middle_begin, middle_end;
ArrayArgs<score_t>::Partition(&gradients, 0, gradients.size(), &middle_begin, &middle_end);

EXPECT_EQ(gradients[middle_begin + 1], gradients[middle_end - 1]);
EXPECT_EQ(middle_begin, -1);
EXPECT_EQ(middle_end, 3);
}

0 comments on commit c7134fa

Please sign in to comment.