Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enforce interaction constraints with monotone_constraints_method = intermediate/advanced #4043

Merged
merged 6 commits into from
Apr 11, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/treelearner/serial_tree_learner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf,
best_split_per_leaf_);
// update leave outputs if needed
for (auto leaf : leaves_need_update) {
RecomputeBestSplitForLeaf(leaf, &best_split_per_leaf_[leaf]);
RecomputeBestSplitForLeaf(tree, leaf, &best_split_per_leaf_[leaf]);
}
}

Expand Down Expand Up @@ -768,7 +768,7 @@ double SerialTreeLearner::GetParentOutput(const Tree* tree, const LeafSplits* le
return parent_output;
}

void SerialTreeLearner::RecomputeBestSplitForLeaf(int leaf, SplitInfo* split) {
void SerialTreeLearner::RecomputeBestSplitForLeaf(Tree* tree, int leaf, SplitInfo* split) {
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
FeatureHistogram* histogram_array_;
if (!histogram_pool_.Get(leaf, &histogram_array_)) {
Log::Warning(
Expand All @@ -795,6 +795,7 @@ void SerialTreeLearner::RecomputeBestSplitForLeaf(int leaf, SplitInfo* split) {

OMP_INIT_EX();
// find splits
std::vector<int8_t> node_used_features = col_sampler_.GetByNode(tree, leaf);
#pragma omp parallel for schedule(static) num_threads(share_state_->num_threads)
for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
OMP_LOOP_EX_BEGIN();
Expand All @@ -804,7 +805,7 @@ void SerialTreeLearner::RecomputeBestSplitForLeaf(int leaf, SplitInfo* split) {
}
const int tid = omp_get_thread_num();
int real_fidx = train_data_->RealFeatureIndex(feature_index);
ComputeBestSplitForFeature(histogram_array_, feature_index, real_fidx, true,
ComputeBestSplitForFeature(histogram_array_, feature_index, real_fidx, node_used_features[feature_index],
num_data, &leaf_splits, &bests[tid], parent_output);

OMP_LOOP_EX_END();
Expand Down
2 changes: 1 addition & 1 deletion src/treelearner/serial_tree_learner.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class SerialTreeLearner: public TreeLearner {

void GetShareStates(const Dataset* dataset, bool is_constant_hessian, bool is_first_time);

void RecomputeBestSplitForLeaf(int leaf, SplitInfo* split);
void RecomputeBestSplitForLeaf(Tree* tree, int leaf, SplitInfo* split);

/*!
* \brief Some initial works before training
Expand Down
69 changes: 57 additions & 12 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1252,7 +1252,8 @@ def generate_trainset_for_monotone_constraints_tests(x3_to_category=True):
return trainset


def test_monotone_constraints():
@pytest.mark.parametrize("test_with_interaction_constraints", [True, False])
def test_monotone_constraints(test_with_interaction_constraints):
def is_increasing(y):
return (np.diff(y) >= 0.0).all()

Expand All @@ -1273,28 +1274,72 @@ def is_correctly_constrained(learner, x3_to_category=True):
monotonically_increasing_y = learner.predict(monotonically_increasing_x)
monotonically_decreasing_x = np.column_stack((fixed_x, variable_x, fixed_x))
monotonically_decreasing_y = learner.predict(monotonically_decreasing_x)
non_monotone_x = np.column_stack((fixed_x,
fixed_x,
categorize(variable_x) if x3_to_category else variable_x))
non_monotone_x = np.column_stack(
(
fixed_x,
fixed_x,
categorize(variable_x) if x3_to_category else variable_x,
)
)
non_monotone_y = learner.predict(non_monotone_x)
if not (is_increasing(monotonically_increasing_y)
and is_decreasing(monotonically_decreasing_y)
and is_non_monotone(non_monotone_y)):
if not (
is_increasing(monotonically_increasing_y)
and is_decreasing(monotonically_decreasing_y)
and is_non_monotone(non_monotone_y)
):
return False
return True

def are_interactions_enforced(gbm, feature_sets):
def parse_tree_features(gbm):
# trees start at position 1.
tree_str = gbm.model_to_string().split("Tree")[1:]
feature_sets = []
for i, tree in enumerate(tree_str):
ChristophAymannsQC marked this conversation as resolved.
Show resolved Hide resolved
# split_features are in 4th line.
features = tree.splitlines()[3].split("=")[1].split(" ")
features = set([f"Column_{f}" for f in features])
ChristophAymannsQC marked this conversation as resolved.
Show resolved Hide resolved
feature_sets.append(features)
return np.array(feature_sets)

def has_interaction(treef):
n = 0
for fs in feature_sets:
if len(treef.intersection(fs)) > 0:
n += 1
if n > 1:
return True
else:
return False
ChristophAymannsQC marked this conversation as resolved.
Show resolved Hide resolved

tree_features = parse_tree_features(gbm)
has_interaction_flag = np.array(
[has_interaction(treef) for treef in tree_features]
)

return not has_interaction_flag.any()

for test_with_categorical_variable in [True, False]:
trainset = generate_trainset_for_monotone_constraints_tests(test_with_categorical_variable)
trainset = generate_trainset_for_monotone_constraints_tests(
test_with_categorical_variable
)
for monotone_constraints_method in ["basic", "intermediate", "advanced"]:
params = {
'min_data': 20,
'num_leaves': 20,
'monotone_constraints': [1, -1, 0],
"min_data": 20,
"num_leaves": 20,
"monotone_constraints": [1, -1, 0],
"monotone_constraints_method": monotone_constraints_method,
"use_missing": False,
}
if test_with_interaction_constraints:
params["interaction_constraints"] = [[0], [1], [2]]
constrained_model = lgb.train(params, trainset)
assert is_correctly_constrained(constrained_model, test_with_categorical_variable)
assert is_correctly_constrained(
constrained_model, test_with_categorical_variable
)
if test_with_interaction_constraints:
feature_sets = [["Column_0"], ["Column_1"], "Column_2"]
assert are_interactions_enforced(constrained_model, feature_sets)


def test_monotone_penalty():
Expand Down