[python-package] use 2d collections for predictions, grads and hess in multiclass custom objective #4925

Merged (7 commits, Feb 23, 2022)
45 changes: 25 additions & 20 deletions python-package/lightgbm/basic.py
@@ -2946,22 +2946,21 @@ def update(self, train_set=None, fobj=None):
Should accept two parameters: preds, train_data,
and return (grad, hess).

preds : numpy 1-D array
preds : numpy 1-D array or numpy 2-D array (for multi-class task)
The predicted values.
Predicted values are returned before any transformation,
e.g. they are raw margin instead of probability of positive class for binary task.
train_data : Dataset
The training dataset.
grad : list, numpy 1-D array or pandas Series
grad : numpy 1-D array or numpy 2-D array (for multi-class task)
jmoralez marked this conversation as resolved.
The value of the first order derivative (gradient) of the loss
with respect to the elements of preds for each sample point.
hess : list, numpy 1-D array or pandas Series
hess : numpy 1-D array or numpy 2-D array (for multi-class task)
The value of the second order derivative (Hessian) of the loss
with respect to the elements of preds for each sample point.

For multi-class task, the preds is group by class_id first, then group by row_id.
If you want to get i-th row preds in j-th class, the access way is score[j * num_data + i]
and you should group grad and hess in this way as well.
For multi-class task, preds is a numpy 2-D array of shape = [n_samples, n_classes],
and grad and hess should be returned in the same format.

Returns
-------
@@ -2999,6 +2998,9 @@ def update(self, train_set=None, fobj=None):
if not self.__set_objective_to_none:
self.reset_parameter({"objective": "none"}).__set_objective_to_none = True
grad, hess = fobj(self.__inner_predict(0), self.train_set)
if self.num_model_per_iteration() > 1:

@jmoralez (Collaborator, Author) commented on Feb 23, 2022:

Is it safe to use _Booster__num_class here instead to avoid the lib call? I don't fully understand where __num_class gets converted to _Booster__num_class.

Collaborator commented:

Yes. It is safe since Booster.__num_class comes from the lib call. See

_safe_call(_LIB.LGBM_BoosterGetNumClasses(
    self.handle,
    ctypes.byref(out_num_class)))
self.__num_class = out_num_class.value

and
out_num_class = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterGetNumClasses(
    self.handle,
    ctypes.byref(out_num_class)))
self.__num_class = out_num_class.value

@jmoralez (Collaborator, Author) replied:

I mean that the attribute changes name. I see it's used as self.__num_class in some places, but if I add a breakpoint at that line the object doesn't have that attribute; it has self._Booster__num_class instead, which is the part that confuses me. Do you think the performance impact of calling the lib on each iteration is noticeable and should be changed to use the attribute instead?
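
For background on the naming question: Python mangles attributes with two leading underscores by prefixing the class name, so self.__num_class defined inside Booster is stored on the instance as _Booster__num_class. A minimal standalone illustration (not LightGBM code):

class Booster:
    def __init__(self):
        self.__num_class = 3  # stored on the instance as _Booster__num_class

b = Booster()
print(b._Booster__num_class)  # 3; this mangled name is what a debugger shows
# print(b.__num_class)        # AttributeError: mangling only applies inside the class body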

grad = grad.ravel(order='F')
hess = hess.ravel(order='F')
return self.__boost(grad, hess)
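
For reference, a minimal sketch of a custom multiclass objective under the new 2-D contract (softmax cross-entropy; the function name is illustrative and the hessian is the usual diagonal approximation):

import numpy as np

def softmax_objective(preds, train_data):
    # preds: raw scores of shape [n_samples, n_classes] after this PR
    labels = train_data.get_label().astype(int)
    prob = np.exp(preds - preds.max(axis=1, keepdims=True))  # stabilized softmax
    prob /= prob.sum(axis=1, keepdims=True)
    grad = prob.copy()
    grad[np.arange(labels.size), labels] -= 1.0  # softmax(preds) - one_hot(labels)
    hess = prob * (1.0 - prob)                   # diagonal of the softmax Hessian
    return grad, hess  # both [n_samples, n_classes]; update() ravels them in Fortran order

Passed as fobj to Booster.update (or lgb.train), both arrays keep the natural row-per-sample layout instead of the old flat class-major one.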

def __boost(self, grad, hess):
@@ -3008,16 +3010,15 @@ def __boost(self, grad, hess):

Score is returned before any transformation,
e.g. it is raw margin instead of probability of positive class for binary task.
For multi-class task, the score is group by class_id first, then group by row_id.
If you want to get i-th row score in j-th class, the access way is score[j * num_data + i]
and you should group grad and hess in this way as well.
For multi-class task, the score is a numpy 2-D array of shape = [n_samples, n_classes],
and grad and hess should be provided in the same format.

Parameters
----------
grad : list, numpy 1-D array or pandas Series
grad : numpy 1-D array or numpy 2-D array (for multi-class task)
The value of the first order derivative (gradient) of the loss
with respect to the elements of score for each sample point.
hess : list, numpy 1-D array or pandas Series
hess : numpy 1-D array or numpy 2-D array (for multi-class task)
The value of the second order derivative (Hessian) of the loss
with respect to the elements of score for each sample point.

@@ -3159,8 +3160,8 @@ def eval(self, data, name, feval=None):
is_higher_better : bool
Is eval result higher better, e.g. AUC is ``is_higher_better``.

For multi-class task, the preds is group by class_id first, then group by row_id.
If you want to get i-th row preds in j-th class, the access way is preds[j * num_data + i].
For multi-class task, preds is a numpy 2-D array of shape = [n_samples, n_classes].

Collaborator commented:

Could you please also check that a customized evaluation function with multi-class works correctly? I've read the code, and it seems that the customized evaluation function will ultimately take the output of __inner_predict as input, which is a flat array of length n_samples * n_classes. This is inconsistent with the hint here.

feval_ret = eval_function(self.__inner_predict(data_idx), cur_data)

@jmoralez (Collaborator, Author) replied:

Hmm, you're right. I've only modified the portions required for fobj; I'll work on feval.

@jmoralez (Collaborator, Author) replied:

I moved the reshaping to __inner_predict in 5a56a30 so that it works in both places and added a test to check that we get the same result using the built-in log loss and computing it manually.
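
With the reshaping in __inner_predict, a custom eval function also receives 2-D preds for multiclass. A minimal sketch (the metric and its name are illustrative; with a built-in multiclass objective, preds arrive as probabilities):

import numpy as np

def mean_top_prob(preds, eval_data):
    # preds: [n_samples, n_classes] probabilities; no manual reshaping needed anymore
    return 'mean_top_prob', float(preds.max(axis=1).mean()), True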


Returns
-------
@@ -3194,7 +3195,7 @@ def eval_train(self, feval=None):
Should accept two parameters: preds, train_data,
and return (eval_name, eval_result, is_higher_better) or list of such tuples.

preds : numpy 1-D array
preds : numpy 1-D array or numpy 2-D array (for multi-class task)
The predicted values.
If ``fobj`` is specified, predicted values are returned before any transformation,
e.g. they are raw margin instead of probability of positive class for binary task in this case.
@@ -3207,8 +3208,8 @@ def eval_valid(self, feval=None):
is_higher_better : bool
Is eval result higher better, e.g. AUC is ``is_higher_better``.

For multi-class task, the preds is group by class_id first, then group by row_id.
If you want to get i-th row preds in j-th class, the access way is preds[j * num_data + i].
For multi-class task, preds is a numpy 2-D array of shape = [n_samples, n_classes].

Returns
-------
@@ -3227,7 +3228,7 @@ def eval_valid(self, feval=None):
Should accept two parameters: preds, valid_data,
and return (eval_name, eval_result, is_higher_better) or list of such tuples.

preds : numpy 1-D array
preds : numpy 1-D array or numpy 2-D array (for multi-class task)
The predicted values.
If ``fobj`` is specified, predicted values are returned before any transformation,
e.g. they are raw margin instead of probability of positive class for binary task in this case.
@@ -3240,8 +3241,8 @@ def eval_valid(self, feval=None):
is_higher_better : bool
Is eval result higher better, e.g. AUC is ``is_higher_better``.

For multi-class task, the preds is group by class_id first, then group by row_id.
If you want to get i-th row preds in j-th class, the access way is preds[j * num_data + i].
For multi-class task, preds is a numpy 2-D array of shape = [n_samples, n_classes].

Returns
-------
@@ -3866,7 +3867,11 @@ def __inner_predict(self, data_idx):
if tmp_out_len.value != len(self.__inner_predict_buffer[data_idx]):
raise ValueError(f"Wrong length of predict results for data {data_idx}")
self.__is_predicted_cur_iter[data_idx] = True
return self.__inner_predict_buffer[data_idx]
result = self.__inner_predict_buffer[data_idx]
if self.__num_class > 1:
num_data = result.size // self.__num_class
result = result.reshape(num_data, self.__num_class, order='F')
return result
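
The order='F' reshape keeps the old flat layout addressable: flat[j * num_data + i] equals reshaped[i, j]. A standalone check with made-up sizes:

import numpy as np

num_data, num_class = 4, 3
flat = np.arange(num_data * num_class, dtype=np.float64)  # class-major buffer from the C API
reshaped = flat.reshape(num_data, num_class, order='F')
for i in range(num_data):
    for j in range(num_class):
        assert reshaped[i, j] == flat[j * num_data + i]  # old access pattern still holds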

def __get_eval_info(self):
"""Get inner evaluation count and names."""
63 changes: 27 additions & 36 deletions python-package/lightgbm/sklearn.py
@@ -59,7 +59,7 @@ def __init__(self, func: _LGBM_ScikitCustomObjectiveFunction):

y_true : numpy 1-D array of shape = [n_samples]
The target values.
y_pred : numpy 1-D array of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task)
y_pred : numpy 1-D array of shape = [n_samples] or numpy 2-D array of shape = [n_samples, n_classes] (for multi-class task)
The predicted values.
Predicted values are returned before any transformation,
e.g. they are raw margin instead of probability of positive class for binary task.
@@ -69,18 +69,17 @@ def __init__(self, func: _LGBM_ScikitCustomObjectiveFunction):
sum(group) = n_samples.
For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.
grad : list, numpy 1-D array or pandas Series of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task)
grad : numpy 1-D array of shape = [n_samples] or numpy 2-D array of shape = [n_samples, n_classes] (for multi-class task)
The value of the first order derivative (gradient) of the loss
with respect to the elements of y_pred for each sample point.
hess : list, numpy 1-D array or pandas Series of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task)
hess : numpy 1-D array of shape = [n_samples] or numpy 2-D array of shape = [n_samples, n_classes] (for multi-class task)
The value of the second order derivative (Hessian) of the loss
with respect to the elements of y_pred for each sample point.

.. note::

For multi-class task, the y_pred is group by class_id first, then group by row_id.
If you want to get i-th row y_pred in j-th class, the access way is y_pred[j * num_data + i]
and you should group grad and hess in this way as well.
For multi-class task, y_pred is a numpy 2-D array of shape = [n_samples, n_classes],
and grad and hess should be returned in the same format.
"""
self.func = func
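
Conceptually, the wrapper adapts the scikit-learn style signature (y_true, y_pred) to the Booster-style (preds, dataset). A simplified sketch of the adaptation after this PR (the factory name is illustrative, and the group/ranking branch is omitted):

def make_booster_objective(func):
    # func: scikit-learn style objective, func(y_true, y_pred) -> (grad, hess)
    def wrapped(preds, dataset):
        grad, hess = func(dataset.get_label(), preds)
        weight = dataset.get_weight()
        if weight is not None:
            if grad.ndim == 2:                 # multiclass: broadcast row weights over classes
                weight = weight.reshape(-1, 1)
            grad = grad * weight
            hess = hess * weight
        return grad, hess
    return wrapped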

@@ -89,17 +88,17 @@ def __call__(self, preds, dataset):

Parameters
----------
preds : numpy 1-D array of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task)
preds : numpy 1-D array of shape = [n_samples] or numpy 2-D array of shape = [n_samples, n_classes] (for multi-class task)
The predicted values.
dataset : Dataset
The training dataset.

Returns
-------
grad : list, numpy 1-D array or pandas Series of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task)
grad : numpy 1-D array of shape = [n_samples] or numpy 2-D array of shape = [n_samples, n_classes] (for multi-class task)
The value of the first order derivative (gradient) of the loss
with respect to the elements of preds for each sample point.
hess : list, numpy 1-D array or pandas Series of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task)
hess : numpy 1-D array of shape = [n_samples] or numpy 2-D array of shape = [n_samples, n_classes] (for multi-class task)
The value of the second order derivative (Hessian) of the loss
with respect to the elements of preds for each sample point.
"""
@@ -114,20 +113,13 @@ def __call__(self, preds, dataset):
"""weighted for objective"""
weight = dataset.get_weight()
if weight is not None:
"""only one class"""
if len(weight) == len(grad):
grad = np.multiply(grad, weight)
hess = np.multiply(hess, weight)
else:
num_data = len(weight)
num_class = len(grad) // num_data
if num_class * num_data != len(grad):
raise ValueError("Length of grad and hess should equal to num_class * num_data")
for k in range(num_class):
for i in range(num_data):
idx = k * num_data + i
grad[idx] *= weight[i]
hess[idx] *= weight[i]
if grad.ndim == 2: # multi-class
num_data = grad.shape[0]
if weight.size != num_data:
raise ValueError("grad and hess should be of shape [n_samples, n_classes]")
weight = weight.reshape(num_data, 1)
grad *= weight
hess *= weight

@jmoralez (Collaborator, Author) commented on lines +116 to +122:

The grad and hess are weighted in the sklearn interface but they're not in basic; should we weight them there as well?

Collaborator commented:

@guolinke Hey! Do you remember the reason for doing this?

Collaborator commented:

I think with the interfaces in basic.py, the weighting is ultimately done on the C++ side. I'll double-check why weighting is done directly here in the sklearn interfaces.

Collaborator commented:

@shiyu1994 You've merged this PR without resolving this conversation. Could you please share your findings about weighting derivatives here?

@shiyu1994 (Collaborator) commented on Feb 23, 2022:

Sorry, I did not notice that what we discussed above is the customized objective; I thought we were discussing LightGBM's native objectives. I just noticed that weights with a customized objective function are not handled correctly in the Python API. See the code below.

import numpy as np
import lightgbm as lgb

def fobj(preds, train_data):
    labels = train_data.get_label()
    return preds - labels, np.ones_like(labels)

def test():
    np.random.seed(123)
    num_data = 10000
    num_feature = 100
    train_X = np.random.randn(num_data, num_feature)
    train_y = np.mean(train_X, axis=-1)
    valid_X = np.random.randn(num_data, num_feature)
    valid_y = np.mean(valid_X, axis=-1)
    weights = np.random.rand(num_data)
    train_data = lgb.Dataset(train_X, train_y, weight=weights)
    valid_data = lgb.Dataset(valid_X, valid_y)
    params = {
        "verbose": 2,
        "metric": "rmse",
        "learning_rate": 0.2,
        "num_trees": 20,
    }
    booster = lgb.train(train_set=train_data, valid_sets=[valid_data], valid_names=["valid"], params=params, fobj=fobj)

if __name__ == "__main__":
    test()

If we comment out the weights in the training dataset construction, the code produces exactly the same output as below.

[LightGBM] [Warning] Using self-defined objective function
[LightGBM] [Debug] Dataset::GetMultiBinFromAllFeatures: sparse rate 0.000000
[LightGBM] [Debug] init for col-wise cost 0.000012 seconds, init for row-wise cost 0.001697 seconds
[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.004134 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 25500
[LightGBM] [Info] Number of data points in the train set: 10000, number of used features: 100
[LightGBM] [Warning] Using self-defined objective function
[LightGBM] [Debug] Trained a tree with leaves = 31 and depth = 7
[1]	valid's rmse: 0.100043
[LightGBM] [Debug] Trained a tree with leaves = 31 and depth = 6
[2]	valid's rmse: 0.099099
[LightGBM] [Debug] Trained a tree with leaves = 31 and depth = 8
[3]	valid's rmse: 0.0982311
[LightGBM] [Debug] Trained a tree with leaves = 31 and depth = 7
[4]	valid's rmse: 0.0974867
[LightGBM] [Debug] Trained a tree with leaves = 31 and depth = 7
[5]	valid's rmse: 0.0965613
[LightGBM] [Debug] Trained a tree with leaves = 31 and depth = 7
[6]	valid's rmse: 0.0957191
[LightGBM] [Debug] Trained a tree with leaves = 31 and depth = 8
[7]	valid's rmse: 0.0949163
[LightGBM] [Debug] Trained a tree with leaves = 31 and depth = 6
[8]	valid's rmse: 0.0940159
[LightGBM] [Debug] Trained a tree with leaves = 31 and depth = 7
[9]	valid's rmse: 0.0932777
[LightGBM] [Debug] Trained a tree with leaves = 31 and depth = 8
[10]	valid's rmse: 0.0924858
[LightGBM] [Debug] Trained a tree with leaves = 31 and depth = 7
[11]	valid's rmse: 0.0917661
[LightGBM] [Debug] Trained a tree with leaves = 31 and depth = 8
[12]	valid's rmse: 0.0909356
[LightGBM] [Debug] Trained a tree with leaves = 31 and depth = 7
[13]	valid's rmse: 0.0901323
[LightGBM] [Debug] Trained a tree with leaves = 31 and depth = 8
[14]	valid's rmse: 0.0894671
[LightGBM] [Debug] Trained a tree with leaves = 31 and depth = 8
[15]	valid's rmse: 0.0888048
[LightGBM] [Debug] Trained a tree with leaves = 31 and depth = 8
[16]	valid's rmse: 0.0881257
[LightGBM] [Debug] Trained a tree with leaves = 31 and depth = 8
[17]	valid's rmse: 0.0874723
[LightGBM] [Debug] Trained a tree with leaves = 31 and depth = 7
[18]	valid's rmse: 0.0868133
[LightGBM] [Debug] Trained a tree with leaves = 31 and depth = 8
[19]	valid's rmse: 0.0862182
[LightGBM] [Debug] Trained a tree with leaves = 31 and depth = 7
[20]	valid's rmse: 0.0856057

We need a separate PR to fix this.

Collaborator commented:

BTW, I found an additional issue: the latest master branch does not produce the evaluation results shown in the log above. I got that log with version 3.3.2 instead. This is another issue we need to investigate.

@jmoralez (Collaborator, Author) replied:

> I just noticed that weights with a customized objective function are not handled correctly in the Python API.

Yes, that's what I noticed when I saw that in the scikit-learn interface grad and hess are weighted before boosting. I don't know if it's because in basic you get a Dataset, have access to the weights, and can weight them inside the objective function, whereas in sklearn you can't; if that's the case it's worth mentioning in the docs.

> The latest master branch does not produce the evaluation results shown in the log above.

I believe this is because callbacks are now preferred (#4878); to log the evaluation you have to specify callbacks=[lgb.log_evaluation(1)].

return grad, hess
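
The reshape-and-broadcast above is equivalent to the old per-element loop. A quick standalone check with made-up shapes:

import numpy as np

rng = np.random.default_rng(0)
num_data, num_class = 5, 3
grad = rng.normal(size=(num_data, num_class))
weight = rng.random(num_data)

broadcast = grad * weight.reshape(num_data, 1)  # new vectorized path

flat = grad.ravel(order='F').copy()             # old flat, class-major layout
for k in range(num_class):
    for i in range(num_data):
        flat[k * num_data + i] *= weight[i]     # old loop

np.testing.assert_allclose(broadcast.ravel(order='F'), flat)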


@@ -152,7 +144,7 @@ def __init__(self, func: _LGBM_ScikitCustomEvalFunction):

y_true : numpy 1-D array of shape = [n_samples]
The target values.
y_pred : numpy 1-D array of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task)
y_pred : numpy 1-D array of shape = [n_samples] or numpy 2-D array of shape = [n_samples, n_classes] (for multi-class task)
The predicted values.
In case of custom ``objective``, predicted values are returned before any transformation,
e.g. they are raw margin instead of probability of positive class for binary task in this case.
@@ -173,8 +165,8 @@ def __init__(self, func: _LGBM_ScikitCustomEvalFunction):

.. note::

For multi-class task, the y_pred is group by class_id first, then group by row_id.
If you want to get i-th row y_pred in j-th class, the access way is y_pred[j * num_data + i].
For multi-class task, y_pred is a numpy 2-D array of shape = [n_samples, n_classes].
"""
self.func = func

Expand All @@ -183,7 +175,7 @@ def __call__(self, preds, dataset):

Parameters
----------
preds : numpy 1-D array of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task)
preds : numpy 1-D array of shape = [n_samples] or numpy 2-D array of shape = [n_samples, n_classes] (for multi-class task)
The predicted values.
dataset : Dataset
The training dataset.
@@ -286,7 +278,7 @@ def __call__(self, preds, dataset):

y_true : numpy 1-D array of shape = [n_samples]
The target values.
y_pred : numpy 1-D array of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task)
y_pred : numpy 1-D array of shape = [n_samples] or numpy 2-D array of shape = [n_samples, n_classes] (for multi-class task)
The predicted values.
In case of custom ``objective``, predicted values are returned before any transformation,
e.g. they are raw margin instead of probability of positive class for binary task in this case.
@@ -305,8 +297,8 @@ def __call__(self, preds, dataset):
is_higher_better : bool
Is eval result higher better, e.g. AUC is ``is_higher_better``.

For multi-class task, the y_pred is group by class_id first, then group by row_id.
If you want to get i-th row y_pred in j-th class, the access way is y_pred[j * num_data + i].
For multi-class task, y_pred is a numpy 2-D array of shape = [n_samples, n_classes].
"""

_lgbmmodel_doc_predict = (
@@ -463,7 +455,7 @@ def __init__(

y_true : numpy 1-D array of shape = [n_samples]
The target values.
y_pred : numpy 1-D array of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task)
y_pred : numpy 1-D array of shape = [n_samples] or numpy 2-D array of shape = [n_samples, n_classes] (for multi-class task)
The predicted values.
Predicted values are returned before any transformation,
e.g. they are raw margin instead of probability of positive class for binary task.
@@ -473,16 +465,15 @@ def __init__(
sum(group) = n_samples.
For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.
grad : list, numpy 1-D array or pandas Series of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task)
grad : numpy 1-D array of shape = [n_samples] or numpy 2-D array of shape = [n_samples, n_classes] (for multi-class task)
The value of the first order derivative (gradient) of the loss
with respect to the elements of y_pred for each sample point.
hess : list, numpy 1-D array or pandas Series of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task)
hess : numpy 1-D array of shape = [n_samples] or numpy 2-D array of shape = [n_samples, n_classes] (for multi-class task)
The value of the second order derivative (Hessian) of the loss
with respect to the elements of y_pred for each sample point.

For multi-class task, the y_pred is group by class_id first, then group by row_id.
If you want to get i-th row y_pred in j-th class, the access way is y_pred[j * num_data + i]
and you should group grad and hess in this way as well.
For multi-class task, y_pred is a numpy 2-D array of shape = [n_samples, n_classes],
and grad and hess should be returned in the same format.
"""
if not SKLEARN_INSTALLED:
raise LightGBMError('scikit-learn is required for lightgbm.sklearn. '
55 changes: 52 additions & 3 deletions tests/python_package_test/test_basic.py
@@ -7,13 +7,14 @@
import numpy as np
import pytest
from scipy import sparse
from sklearn.datasets import dump_svmlight_file, load_svmlight_file
from sklearn.datasets import dump_svmlight_file, load_svmlight_file, make_blobs
from sklearn.metrics import log_loss
from sklearn.model_selection import train_test_split

import lightgbm as lgb
from lightgbm.compat import PANDAS_INSTALLED, pd_DataFrame, pd_Series

from .utils import load_breast_cancer
from .utils import load_breast_cancer, sklearn_multiclass_custom_objective, softmax


def test_basic(tmp_path):
@@ -587,7 +588,7 @@ def _bad_gradients(preds, _):


def _good_gradients(preds, _):
return np.random.randn(len(preds)), np.random.rand(len(preds))
return np.random.randn(*preds.shape), np.random.rand(*preds.shape)


def test_custom_objective_safety():
@@ -609,3 +610,51 @@ def test_custom_objective_safety():
good_bst_multi.update(fobj=_good_gradients)
with pytest.raises(ValueError, match=re.escape(f"number of models per one iteration ({nclass})")):
bad_bst_multi.update(fobj=_bad_gradients)


def test_multiclass_custom_objective():
def custom_obj(y_pred, ds):
y_true = ds.get_label()
return sklearn_multiclass_custom_objective(y_true, y_pred)

centers = [[-4, -4], [4, 4], [-4, 4]]
X, y = make_blobs(n_samples=1_000, centers=centers, random_state=42)
ds = lgb.Dataset(X, y)
params = {'objective': 'multiclass', 'num_class': 3, 'num_leaves': 7}
builtin_obj_bst = lgb.train(params, ds, num_boost_round=10)
builtin_obj_preds = builtin_obj_bst.predict(X)

custom_obj_bst = lgb.train(params, ds, num_boost_round=10, fobj=custom_obj)
custom_obj_preds = softmax(custom_obj_bst.predict(X))

np.testing.assert_allclose(builtin_obj_preds, custom_obj_preds, rtol=0.01)
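
The sklearn_multiclass_custom_objective and softmax helpers are imported from tests/python_package_test/utils.py and are not part of this diff. A sketch of what they plausibly look like (a reconstruction under the 2-D contract, not the file's verbatim content; the hessian's constant factor is an assumption):

import numpy as np

def softmax(x):
    exp_x = np.exp(x - x.max(axis=1, keepdims=True))  # subtract the row max for stability
    return exp_x / exp_x.sum(axis=1, keepdims=True)

def sklearn_multiclass_custom_objective(y_true, y_pred):
    # y_pred: raw scores, [n_samples, n_classes]
    prob = softmax(y_pred)
    grad = prob.copy()
    grad[np.arange(y_true.size), y_true.astype(int)] -= 1.0
    hess = 2.0 * prob * (1.0 - prob)  # assumed factor; it rescales, not reorders, the updates
    return grad, hess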


def test_multiclass_custom_eval():
def custom_eval(y_pred, ds):
y_true = ds.get_label()
return 'custom_logloss', log_loss(y_true, y_pred), False

centers = [[-4, -4], [4, 4], [-4, 4]]
X, y = make_blobs(n_samples=1_000, centers=centers, random_state=42)
X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2, random_state=0)
train_ds = lgb.Dataset(X_train, y_train)
valid_ds = lgb.Dataset(X_valid, y_valid, reference=train_ds)
params = {'objective': 'multiclass', 'num_class': 3, 'num_leaves': 7}
eval_result = {}
bst = lgb.train(
params,
train_ds,
num_boost_round=10,
valid_sets=[train_ds, valid_ds],
valid_names=['train', 'valid'],
feval=custom_eval,
callbacks=[lgb.record_evaluation(eval_result)],
keep_training_booster=True,
)

for key, ds in zip(['train', 'valid'], [train_ds, valid_ds]):
np.testing.assert_allclose(eval_result[key]['multi_logloss'], eval_result[key]['custom_logloss'])
_, metric, value, _ = bst.eval(ds, key, feval=custom_eval)[1] # first element is multi_logloss
assert metric == 'custom_logloss'
np.testing.assert_allclose(value, eval_result[key][metric][-1])