-
Notifications
You must be signed in to change notification settings - Fork 3.8k
/
engine.py
611 lines (550 loc) · 29 KB
/
engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
# coding: utf-8
"""Library with training routines of LightGBM."""
import collections
import copy
from operator import attrgetter
import numpy as np
from . import callback
from .basic import Booster, Dataset, LightGBMError, _ConfigAliases, _InnerPredictor, _log_warning
from .compat import SKLEARN_INSTALLED, _LGBMGroupKFold, _LGBMStratifiedKFold
def train(params, train_set, num_boost_round=100,
          valid_sets=None, valid_names=None,
          fobj=None, feval=None, init_model=None,
          feature_name='auto', categorical_feature='auto',
          early_stopping_rounds=None, evals_result=None,
          verbose_eval=True, learning_rates=None,
          keep_training_booster=False, callbacks=None):
    """Perform the training with given parameters.

    Parameters
    ----------
    params : dict
        Parameters for training.
    train_set : Dataset
        Data to be trained on.
    num_boost_round : int, optional (default=100)
        Number of boosting iterations.
    valid_sets : list of Datasets or None, optional (default=None)
        List of data to be evaluated on during training.
    valid_names : list of strings or None, optional (default=None)
        Names of ``valid_sets``.
        If fewer names than validation sets are given, the unnamed sets fall back to
        default names (``valid_<index>``, or ``training`` for the training set itself).
    fobj : callable or None, optional (default=None)
        Customized objective function.
        Should accept two parameters: preds, train_data,
        and return (grad, hess).

            preds : list or numpy 1-D array
                The predicted values.
            train_data : Dataset
                The training dataset.
            grad : list or numpy 1-D array
                The value of the first order derivative (gradient) for each sample point.
            hess : list or numpy 1-D array
                The value of the second order derivative (Hessian) for each sample point.

        For binary task, the preds is margin.
        For multi-class task, the preds is group by class_id first, then group by row_id.
        If you want to get i-th row preds in j-th class, the access way is score[j * num_data + i]
        and you should group grad and hess in this way as well.

    feval : callable, list of callable functions or None, optional (default=None)
        Customized evaluation function.
        Each evaluation function should accept two parameters: preds, train_data,
        and return (eval_name, eval_result, is_higher_better) or list of such tuples.

            preds : list or numpy 1-D array
                The predicted values.
            train_data : Dataset
                The training dataset.
            eval_name : string
                The name of evaluation function (without whitespaces).
            eval_result : float
                The eval result.
            is_higher_better : bool
                Is eval result higher better, e.g. AUC is ``is_higher_better``.

        For binary task, the preds is probability of positive class (or margin in case of specified ``fobj``).
        For multi-class task, the preds is group by class_id first, then group by row_id.
        If you want to get i-th row preds in j-th class, the access way is preds[j * num_data + i].
        To ignore the default metric corresponding to the used objective,
        set the ``metric`` parameter to the string ``"None"`` in ``params``.
    init_model : string, Booster or None, optional (default=None)
        Filename of LightGBM model or Booster instance used for continue training.
    feature_name : list of strings or 'auto', optional (default="auto")
        Feature names.
        If 'auto' and data is pandas DataFrame, data columns names are used.
    categorical_feature : list of strings or int, or 'auto', optional (default="auto")
        Categorical features.
        If list of int, interpreted as indices.
        If list of strings, interpreted as feature names (need to specify ``feature_name`` as well).
        If 'auto' and data is pandas DataFrame, pandas unordered categorical columns are used.
        All values in categorical features should be less than int32 max value (2147483647).
        Large values could be memory consuming. Consider using consecutive integers starting from zero.
        All negative values in categorical features will be treated as missing values.
        The output cannot be monotonically constrained with respect to a categorical feature.
    early_stopping_rounds : int or None, optional (default=None)
        Activates early stopping. The model will train until the validation score stops improving.
        Validation score needs to improve at least every ``early_stopping_rounds`` round(s)
        to continue training.
        Requires at least one validation data and one metric.
        If there's more than one, will check all of them. But the training data is ignored anyway.
        To check only the first metric, set the ``first_metric_only`` parameter to ``True`` in ``params``.
        The index of iteration that has the best performance will be saved in the ``best_iteration`` field
        if early stopping logic is enabled by setting ``early_stopping_rounds``.
    evals_result: dict or None, optional (default=None)
        This dictionary used to store all evaluation results of all the items in ``valid_sets``.

        .. rubric:: Example

        With a ``valid_sets`` = [valid_set, train_set],
        ``valid_names`` = ['eval', 'train']
        and a ``params`` = {'metric': 'logloss'}
        returns {'train': {'logloss': ['0.48253', '0.35953', ...]},
        'eval': {'logloss': ['0.480385', '0.357756', ...]}}.

    verbose_eval : bool or int, optional (default=True)
        Requires at least one validation data.
        If True, the eval metric on the valid set is printed at each boosting stage.
        If int, the eval metric on the valid set is printed at every ``verbose_eval`` boosting stage.
        The last boosting stage or the boosting stage found by using ``early_stopping_rounds`` is also printed.

        .. rubric:: Example

        With ``verbose_eval`` = 4 and at least one item in ``valid_sets``,
        an evaluation metric is printed every 4 (instead of 1) boosting stages.

    learning_rates : list, callable or None, optional (default=None)
        List of learning rates for each boosting round
        or a customized function that calculates ``learning_rate``
        in terms of current number of round (e.g. yields learning rate decay).
    keep_training_booster : bool, optional (default=False)
        Whether the returned Booster will be used to keep training.
        If False, the returned value will be converted into _InnerPredictor before returning.
        When your model is very large and cause the memory error,
        you can try to set this param to ``True`` to avoid the model conversion performed during the internal call of ``model_to_string``.
        You can still use _InnerPredictor as ``init_model`` for future continue training.
    callbacks : list of callables or None, optional (default=None)
        List of callback functions that are applied at each iteration.
        See Callbacks in Python API for more information.

    Returns
    -------
    booster : Booster
        The trained Booster model.

    Raises
    ------
    ValueError
        If the effective number of boosting rounds is not greater than zero.
    TypeError
        If ``train_set`` or any item of ``valid_sets`` is not a Dataset.
    """
    # create predictor first
    # work on a private copy so that the caller's dict is never modified
    params = copy.deepcopy(params)
    if fobj is not None:
        # a custom objective overrides whatever objective was given in params
        for obj_alias in _ConfigAliases.get("objective"):
            params.pop(obj_alias, None)
        params['objective'] = 'none'
    # values found in params take precedence over the keyword arguments
    for alias in _ConfigAliases.get("num_iterations"):
        if alias in params:
            num_boost_round = params.pop(alias)
            _log_warning("Found `{}` in params. Will use it instead of argument".format(alias))
    params["num_iterations"] = num_boost_round
    for alias in _ConfigAliases.get("early_stopping_round"):
        if alias in params:
            early_stopping_rounds = params.pop(alias)
            _log_warning("Found `{}` in params. Will use it instead of argument".format(alias))
    params["early_stopping_round"] = early_stopping_rounds
    first_metric_only = params.get('first_metric_only', False)

    if num_boost_round <= 0:
        raise ValueError("num_boost_round should be greater than zero.")
    if isinstance(init_model, str):
        predictor = _InnerPredictor(model_file=init_model, pred_parameter=params)
    elif isinstance(init_model, Booster):
        predictor = init_model._to_predictor(dict(init_model.params, **params))
    else:
        predictor = None
    # when continuing training, iteration numbering starts after the initial model
    init_iteration = predictor.num_total_iteration if predictor is not None else 0
    # check dataset
    if not isinstance(train_set, Dataset):
        raise TypeError("Training only accepts Dataset object")

    train_set._update_params(params) \
             ._set_predictor(predictor) \
             .set_feature_name(feature_name) \
             .set_categorical_feature(categorical_feature)

    is_valid_contain_train = False
    train_data_name = "training"
    reduced_valid_sets = []
    name_valid_sets = []
    if valid_sets is not None:
        if isinstance(valid_sets, Dataset):
            valid_sets = [valid_sets]
        if isinstance(valid_names, str):
            valid_names = [valid_names]
        for i, valid_data in enumerate(valid_sets):
            # reduce cost for prediction training data
            if valid_data is train_set:
                is_valid_contain_train = True
                # guard the index exactly like the branch below does, so a short
                # ``valid_names`` falls back to the default name instead of raising IndexError
                if valid_names is not None and len(valid_names) > i:
                    train_data_name = valid_names[i]
                continue
            if not isinstance(valid_data, Dataset):
                raise TypeError("Training only accepts Dataset object")
            reduced_valid_sets.append(valid_data._update_params(params).set_reference(train_set))
            if valid_names is not None and len(valid_names) > i:
                name_valid_sets.append(valid_names[i])
            else:
                name_valid_sets.append('valid_' + str(i))
    # process callbacks
    if callbacks is None:
        callbacks = set()
    else:
        # user callbacks without an explicit ``order`` get negative values that keep
        # their relative position in the list they were passed in
        for i, cb in enumerate(callbacks):
            cb.__dict__.setdefault('order', i - len(callbacks))
        callbacks = set(callbacks)

    # Most of legacy advanced options becomes callbacks
    if verbose_eval is True:
        callbacks.add(callback.print_evaluation())
    elif isinstance(verbose_eval, int):
        callbacks.add(callback.print_evaluation(verbose_eval))

    if early_stopping_rounds is not None and early_stopping_rounds > 0:
        callbacks.add(callback.early_stopping(early_stopping_rounds, first_metric_only, verbose=bool(verbose_eval)))

    if learning_rates is not None:
        callbacks.add(callback.reset_parameter(learning_rate=learning_rates))

    if evals_result is not None:
        callbacks.add(callback.record_evaluation(evals_result))

    # callbacks flagged with ``before_iteration`` run before the update, all others after it
    callbacks_before_iter = {cb for cb in callbacks if getattr(cb, 'before_iteration', False)}
    callbacks_after_iter = callbacks - callbacks_before_iter
    callbacks_before_iter = sorted(callbacks_before_iter, key=attrgetter('order'))
    callbacks_after_iter = sorted(callbacks_after_iter, key=attrgetter('order'))

    # construct booster
    try:
        booster = Booster(params=params, train_set=train_set)
        if is_valid_contain_train:
            booster.set_train_data_name(train_data_name)
        for valid_set, name_valid_set in zip(reduced_valid_sets, name_valid_sets):
            booster.add_valid(valid_set, name_valid_set)
    finally:
        # undo the temporary parameter update on the datasets, even if construction failed
        train_set._reverse_update_params()
        for valid_set in reduced_valid_sets:
            valid_set._reverse_update_params()
    booster.best_iteration = 0

    # start training
    for i in range(init_iteration, init_iteration + num_boost_round):
        for cb in callbacks_before_iter:
            cb(callback.CallbackEnv(model=booster,
                                    params=params,
                                    iteration=i,
                                    begin_iteration=init_iteration,
                                    end_iteration=init_iteration + num_boost_round,
                                    evaluation_result_list=None))

        booster.update(fobj=fobj)

        evaluation_result_list = []
        # check evaluation result.
        if valid_sets is not None:
            if is_valid_contain_train:
                evaluation_result_list.extend(booster.eval_train(feval))
            evaluation_result_list.extend(booster.eval_valid(feval))
        try:
            for cb in callbacks_after_iter:
                cb(callback.CallbackEnv(model=booster,
                                        params=params,
                                        iteration=i,
                                        begin_iteration=init_iteration,
                                        end_iteration=init_iteration + num_boost_round,
                                        evaluation_result_list=evaluation_result_list))
        except callback.EarlyStopException as earlyStopException:
            # the exception carries a 0-based iteration index; best_iteration is stored 1-based
            booster.best_iteration = earlyStopException.best_iteration + 1
            evaluation_result_list = earlyStopException.best_score
            break
    # best_score[dataset_name][eval_name] -> score, taken from the last (or best) evaluation
    booster.best_score = collections.defaultdict(collections.OrderedDict)
    for dataset_name, eval_name, score, _ in evaluation_result_list:
        booster.best_score[dataset_name][eval_name] = score
    if not keep_training_booster:
        booster.model_from_string(booster.model_to_string(), False).free_dataset()
    return booster
class CVBooster:
    """CVBooster in LightGBM.

    Auxiliary data structure to hold and redirect all boosters of ``cv`` function.
    This class has the same methods as Booster class.
    All method calls are actually performed for underlying Boosters and then all returned results are returned in a list.

    Special (double-underscore) attribute names are never redirected, so that
    protocols which probe for optional hooks (e.g. ``copy`` and ``pickle``)
    see a regular object.

    Attributes
    ----------
    boosters : list of Booster
        The list of underlying fitted models.
    best_iteration : int
        The best iteration of fitted model.
    """

    def __init__(self):
        """Initialize the CVBooster.

        Generally, no need to instantiate manually.
        """
        self.boosters = []
        self.best_iteration = -1

    def _append(self, booster):
        """Add a booster to CVBooster."""
        self.boosters.append(booster)

    def __getattr__(self, name):
        """Redirect methods call of CVBooster.

        Parameters
        ----------
        name : string
            Name of the attribute that was not found by normal lookup.

        Returns
        -------
        handler_function : callable
            Function that calls ``name`` on every underlying booster with the
            given arguments and returns the results in a list.

        Raises
        ------
        AttributeError
            If ``name`` is a special name of the form ``__xxx__``.
        """
        # __getattr__ is only reached after normal lookup failed. Answering probes
        # for optional hooks such as __deepcopy__, __getstate__ or __setstate__ with
        # a forwarding function makes copy/pickle call it as if it were the real hook.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, name))

        def handler_function(*args, **kwargs):
            """Call methods with each booster, and concatenate their results."""
            return [getattr(booster, name)(*args, **kwargs) for booster in self.boosters]

        return handler_function
def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratified=True,
                  shuffle=True, eval_train_metric=False):
    """Make a n-fold list of Booster from random indices.

    Parameters
    ----------
    full_data : Dataset
        Data to split; it is constructed before splitting.
    folds : generator or iterator of (train_idx, test_idx) tuples, object with ``split`` method or None
        User-supplied folds. When not None, ``nfold``, ``stratified`` and ``shuffle`` are not used.
    nfold : int
        Number of folds to generate when ``folds`` is None.
    params : dict
        Parameters passed to every per-fold Booster.
    seed : int
        Seed used for shuffling when folds are generated here.
    fpreproc : callable or None, optional (default=None)
        Called as ``fpreproc(train_set, valid_set, params.copy())`` for every fold;
        must return ``(train_set, valid_set, params)``.
    stratified : bool, optional (default=True)
        Whether to generate stratified folds (ignored for ranking objectives).
    shuffle : bool, optional (default=True)
        Whether to shuffle the data before splitting.
    eval_train_metric : bool, optional (default=False)
        Whether to also register each fold's training subset for evaluation under the name ``train``.

    Returns
    -------
    cvbooster : CVBooster
        Container holding one Booster per fold.

    Raises
    ------
    AttributeError
        If ``folds`` is neither iterable nor has a ``split`` method.
    LightGBMError
        If scikit-learn is required for the requested split but is not installed.
    """
    full_data = full_data.construct()
    num_data = full_data.num_data()
    if folds is not None:
        if not hasattr(folds, '__iter__') and not hasattr(folds, 'split'):
            raise AttributeError("folds should be a generator or iterator of (train_idx, test_idx) tuples "
                                 "or scikit-learn splitter object with split method")
        if hasattr(folds, 'split'):
            group_info = full_data.get_group()
            if group_info is not None:
                # np.asarray copies only when needed; np.array(..., copy=False) raises
                # under NumPy 2.x whenever a copy cannot be avoided
                group_info = np.asarray(group_info, dtype=np.int32)
                # expand group sizes into one group id per row
                flatted_group = np.repeat(range(len(group_info)), repeats=group_info)
            else:
                flatted_group = np.zeros(num_data, dtype=np.int32)
            folds = folds.split(X=np.zeros(num_data), y=full_data.get_label(), groups=flatted_group)
    else:
        if any(params.get(obj_alias, "") in {"lambdarank", "rank_xendcg", "xendcg",
                                             "xe_ndcg", "xe_ndcg_mart", "xendcg_mart"}
               for obj_alias in _ConfigAliases.get("objective")):
            if not SKLEARN_INSTALLED:
                raise LightGBMError('scikit-learn is required for ranking cv')
            # ranking task, split according to groups
            group_info = np.asarray(full_data.get_group(), dtype=np.int32)
            flatted_group = np.repeat(range(len(group_info)), repeats=group_info)
            group_kfold = _LGBMGroupKFold(n_splits=nfold)
            folds = group_kfold.split(X=np.zeros(num_data), groups=flatted_group)
        elif stratified:
            if not SKLEARN_INSTALLED:
                raise LightGBMError('scikit-learn is required for stratified cv')
            skf = _LGBMStratifiedKFold(n_splits=nfold, shuffle=shuffle, random_state=seed)
            folds = skf.split(X=np.zeros(num_data), y=full_data.get_label())
        else:
            if shuffle:
                randidx = np.random.RandomState(seed).permutation(num_data)
            else:
                randidx = np.arange(num_data)
            kstep = int(num_data / nfold)
            # NOTE: when num_data is not a multiple of nfold this produces more than
            # nfold chunks; only the first nfold are used below (train_id indexes
            # range(nfold) and zip stops at the shorter sequence), so the trailing
            # remainder rows take part in no fold.
            test_id = [randidx[i: i + kstep] for i in range(0, num_data, kstep)]
            train_id = [np.concatenate([test_id[i] for i in range(nfold) if k != i]) for k in range(nfold)]
            folds = zip(train_id, test_id)

    ret = CVBooster()
    for train_idx, test_idx in folds:
        train_set = full_data.subset(sorted(train_idx))
        valid_set = full_data.subset(sorted(test_idx))
        # run preprocessing on the data set if needed
        if fpreproc is not None:
            train_set, valid_set, tparam = fpreproc(train_set, valid_set, params.copy())
        else:
            tparam = params
        cvbooster = Booster(tparam, train_set)
        if eval_train_metric:
            cvbooster.add_valid(train_set, 'train')
        cvbooster.add_valid(valid_set, 'valid')
        ret._append(cvbooster)
    return ret
def _agg_cv_result(raw_results, eval_train_metric=False):
"""Aggregate cross-validation results."""
cvmap = collections.OrderedDict()
metric_type = {}
for one_result in raw_results:
for one_line in one_result:
if eval_train_metric:
key = "{} {}".format(one_line[0], one_line[1])
else:
key = one_line[1]
metric_type[key] = one_line[3]
cvmap.setdefault(key, [])
cvmap[key].append(one_line[2])
return [('cv_agg', k, np.mean(v), metric_type[k], np.std(v)) for k, v in cvmap.items()]
def cv(params, train_set, num_boost_round=100,
       folds=None, nfold=5, stratified=True, shuffle=True,
       metrics=None, fobj=None, feval=None, init_model=None,
       feature_name='auto', categorical_feature='auto',
       early_stopping_rounds=None, fpreproc=None,
       verbose_eval=None, show_stdv=True, seed=0,
       callbacks=None, eval_train_metric=False,
       return_cvbooster=False):
    """Perform the cross-validation with given parameters.

    Parameters
    ----------
    params : dict
        Parameters for Booster.
    train_set : Dataset
        Data to be trained on.
    num_boost_round : int, optional (default=100)
        Number of boosting iterations.
    folds : generator or iterator of (train_idx, test_idx) tuples, scikit-learn splitter object or None, optional (default=None)
        If generator or iterator, it should yield the train and test indices for each fold.
        If object, it should be one of the scikit-learn splitter classes
        (https://scikit-learn.org/stable/modules/classes.html#splitter-classes)
        and have ``split`` method.
        This argument has highest priority over other data split arguments.
    nfold : int, optional (default=5)
        Number of folds in CV.
    stratified : bool, optional (default=True)
        Whether to perform stratified sampling.
    shuffle : bool, optional (default=True)
        Whether to shuffle before splitting data.
    metrics : string, list of strings or None, optional (default=None)
        Evaluation metrics to be monitored while CV.
        If not None, the metric in ``params`` will be overridden.
    fobj : callable or None, optional (default=None)
        Customized objective function.
        Should accept two parameters: preds, train_data,
        and return (grad, hess).

            preds : list or numpy 1-D array
                The predicted values.
            train_data : Dataset
                The training dataset.
            grad : list or numpy 1-D array
                The value of the first order derivative (gradient) for each sample point.
            hess : list or numpy 1-D array
                The value of the second order derivative (Hessian) for each sample point.

        For binary task, the preds is margin.
        For multi-class task, the preds is group by class_id first, then group by row_id.
        If you want to get i-th row preds in j-th class, the access way is score[j * num_data + i]
        and you should group grad and hess in this way as well.

    feval : callable, list of callable functions or None, optional (default=None)
        Customized evaluation function.
        Each evaluation function should accept two parameters: preds, train_data,
        and return (eval_name, eval_result, is_higher_better) or list of such tuples.

            preds : list or numpy 1-D array
                The predicted values.
            train_data : Dataset
                The training dataset.
            eval_name : string
                The name of evaluation function (without whitespaces).
            eval_result : float
                The eval result.
            is_higher_better : bool
                Is eval result higher better, e.g. AUC is ``is_higher_better``.

        For binary task, the preds is probability of positive class (or margin in case of specified ``fobj``).
        For multi-class task, the preds is group by class_id first, then group by row_id.
        If you want to get i-th row preds in j-th class, the access way is preds[j * num_data + i].
        To ignore the default metric corresponding to the used objective,
        set ``metrics`` to the string ``"None"``.
    init_model : string, Booster or None, optional (default=None)
        Filename of LightGBM model or Booster instance used for continue training.
    feature_name : list of strings or 'auto', optional (default="auto")
        Feature names.
        If 'auto' and data is pandas DataFrame, data columns names are used.
    categorical_feature : list of strings or int, or 'auto', optional (default="auto")
        Categorical features.
        If list of int, interpreted as indices.
        If list of strings, interpreted as feature names (need to specify ``feature_name`` as well).
        If 'auto' and data is pandas DataFrame, pandas unordered categorical columns are used.
        All values in categorical features should be less than int32 max value (2147483647).
        Large values could be memory consuming. Consider using consecutive integers starting from zero.
        All negative values in categorical features will be treated as missing values.
        The output cannot be monotonically constrained with respect to a categorical feature.
    early_stopping_rounds : int or None, optional (default=None)
        Activates early stopping.
        CV score needs to improve at least every ``early_stopping_rounds`` round(s)
        to continue.
        Requires at least one metric. If there's more than one, will check all of them.
        To check only the first metric, set the ``first_metric_only`` parameter to ``True`` in ``params``.
        Last entry in evaluation history is the one from the best iteration.
    fpreproc : callable or None, optional (default=None)
        Preprocessing function that takes (dtrain, dtest, params)
        and returns transformed versions of those.
    verbose_eval : bool, int, or None, optional (default=None)
        Whether to display the progress.
        If None, progress will be displayed when np.ndarray is returned.
        If True, progress will be displayed at every boosting stage.
        If int, progress will be displayed at every given ``verbose_eval`` boosting stage.
    show_stdv : bool, optional (default=True)
        Whether to display the standard deviation in progress.
        Results are not affected by this parameter, and always contain std.
    seed : int, optional (default=0)
        Seed used to generate the folds (passed to numpy.random.seed).
    callbacks : list of callables or None, optional (default=None)
        List of callback functions that are applied at each iteration.
        See Callbacks in Python API for more information.
    eval_train_metric : bool, optional (default=False)
        Whether to display the train metric in progress.
        The score of the metric is calculated again after each training step, so there is some impact on performance.
    return_cvbooster : bool, optional (default=False)
        Whether to return Booster models trained on each fold through ``CVBooster``.

    Returns
    -------
    eval_hist : dict
        Evaluation history.
        The dictionary has the following format:
        {'metric1-mean': [values], 'metric1-stdv': [values],
        'metric2-mean': [values], 'metric2-stdv': [values],
        ...}.
        If ``return_cvbooster=True``, also returns trained boosters via ``cvbooster`` key.
    """
    if not isinstance(train_set, Dataset):
        raise TypeError("Training only accepts Dataset object")

    # never touch the dict handed in by the caller
    params = copy.deepcopy(params)
    if fobj is not None:
        # a custom objective supersedes every objective alias present in params
        for objective_alias in _ConfigAliases.get("objective"):
            params.pop(objective_alias, None)
        params['objective'] = 'none'

    # settings found in params win over the corresponding keyword arguments
    for iterations_alias in _ConfigAliases.get("num_iterations"):
        if iterations_alias not in params:
            continue
        _log_warning("Found `{}` in params. Will use it instead of argument".format(iterations_alias))
        num_boost_round = params.pop(iterations_alias)
    params["num_iterations"] = num_boost_round

    for stopping_alias in _ConfigAliases.get("early_stopping_round"):
        if stopping_alias not in params:
            continue
        _log_warning("Found `{}` in params. Will use it instead of argument".format(stopping_alias))
        early_stopping_rounds = params.pop(stopping_alias)
    params["early_stopping_round"] = early_stopping_rounds
    first_metric_only = params.get('first_metric_only', False)

    if num_boost_round <= 0:
        raise ValueError("num_boost_round should be greater than zero.")

    predictor = None
    if isinstance(init_model, str):
        predictor = _InnerPredictor(model_file=init_model, pred_parameter=params)
    elif isinstance(init_model, Booster):
        predictor = init_model._to_predictor(dict(init_model.params, **params))

    if metrics is not None:
        # explicit ``metrics`` replaces every metric alias given in params
        for metric_alias in _ConfigAliases.get("metric"):
            params.pop(metric_alias, None)
        params['metric'] = metrics

    train_set._update_params(params) \
             ._set_predictor(predictor) \
             .set_feature_name(feature_name) \
             .set_categorical_feature(categorical_feature)

    results = collections.defaultdict(list)
    cvfolds = _make_n_folds(train_set, folds=folds, nfold=nfold,
                            params=params, seed=seed, fpreproc=fpreproc,
                            stratified=stratified, shuffle=shuffle,
                            eval_train_metric=eval_train_metric)

    # setup callbacks
    if callbacks is None:
        callbacks = set()
    else:
        # user callbacks lacking ``order`` receive negative values preserving list position
        for position, user_cb in enumerate(callbacks):
            user_cb.__dict__.setdefault('order', position - len(callbacks))
        callbacks = set(callbacks)
    if early_stopping_rounds is not None and early_stopping_rounds > 0:
        callbacks.add(callback.early_stopping(early_stopping_rounds, first_metric_only, verbose=False))
    if verbose_eval is True:
        callbacks.add(callback.print_evaluation(show_stdv=show_stdv))
    elif isinstance(verbose_eval, int):
        callbacks.add(callback.print_evaluation(verbose_eval, show_stdv=show_stdv))

    # split by the ``before_iteration`` flag, then order each group by ``order``
    callbacks_before_iter = {cb for cb in callbacks if getattr(cb, 'before_iteration', False)}
    callbacks_after_iter = callbacks - callbacks_before_iter
    callbacks_before_iter = sorted(callbacks_before_iter, key=attrgetter('order'))
    callbacks_after_iter = sorted(callbacks_after_iter, key=attrgetter('order'))

    def _build_env(iteration, evaluation_results):
        """Create the environment object handed to a single callback invocation."""
        return callback.CallbackEnv(model=cvfolds,
                                    params=params,
                                    iteration=iteration,
                                    begin_iteration=0,
                                    end_iteration=num_boost_round,
                                    evaluation_result_list=evaluation_results)

    for round_idx in range(num_boost_round):
        for before_cb in callbacks_before_iter:
            before_cb(_build_env(round_idx, None))
        cvfolds.update(fobj=fobj)
        res = _agg_cv_result(cvfolds.eval_valid(feval), eval_train_metric)
        for _, key, mean, _, std in res:
            results[key + '-mean'].append(mean)
            results[key + '-stdv'].append(std)
        try:
            for after_cb in callbacks_after_iter:
                after_cb(_build_env(round_idx, res))
        except callback.EarlyStopException as earlyStopException:
            # exception carries a 0-based index; keep history only up to the best round
            cvfolds.best_iteration = earlyStopException.best_iteration + 1
            for history_key, history in results.items():
                results[history_key] = history[:cvfolds.best_iteration]
            break

    if return_cvbooster:
        results['cvbooster'] = cvfolds

    return dict(results)