-
Notifications
You must be signed in to change notification settings - Fork 3.8k
/
Copy pathdask.py
1588 lines (1379 loc) · 63.5 KB
/
dask.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# coding: utf-8
"""Distributed training with LightGBM and dask.distributed.
This module enables you to perform distributed training with LightGBM on
dask.Array and dask.DataFrame collections.
It is based on dask-lightgbm, which was based on dask-xgboost.
"""
import socket
from collections import defaultdict, namedtuple
from copy import deepcopy
from enum import Enum, auto
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
from urllib.parse import urlparse
import numpy as np
import scipy.sparse as ss
from .basic import _LIB, LightGBMError, _choose_param_value, _ConfigAliases, _log_info, _log_warning, _safe_call
from .compat import (DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED, Client, LGBMNotFittedError, concat,
dask_Array, dask_array_from_delayed, dask_bag_from_delayed, dask_DataFrame, dask_Series,
default_client, delayed, pd_DataFrame, pd_Series, wait)
from .sklearn import (LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor, _lgbmmodel_doc_custom_eval_note,
_lgbmmodel_doc_fit, _lgbmmodel_doc_predict)
# Dask collection types accepted for training inputs (data / label / init_score).
_DaskCollection = Union[dask_Array, dask_DataFrame, dask_Series]
# Dask collections used for feature matrices (``data`` and eval_set ``X``).
_DaskMatrixLike = Union[dask_Array, dask_DataFrame]
# Dask collections used for per-row vectors such as sample weights and query groups.
_DaskVectorLike = Union[dask_Array, dask_Series]
# In-memory chunk types a Dask collection is materialized into on a worker;
# these are the inputs handled by _concat() and _predict_part().
_DaskPart = Union[np.ndarray, pd_DataFrame, pd_Series, ss.spmatrix]
# dtypes accepted for the ``dtype`` argument of _predict().
_PredictionDtype = Union[Type[np.float32], Type[np.float64], Type[np.int32], Type[np.int64]]
# Per-host grouping of Dask worker addresses (see _group_workers_by_host()):
# ``default`` is the first worker seen on that host, ``all`` lists every worker on it.
_HostWorkers = namedtuple('_HostWorkers', ['default', 'all'])
class _DatasetNames(Enum):
"""Placeholder names used by lightgbm.dask internals to say 'also evaluate the training data'.
Avoid duplicating the training data when the validation set refers to elements of training data.
"""
TRAINSET = auto()
SAMPLE_WEIGHT = auto()
INIT_SCORE = auto()
GROUP = auto()
def _get_dask_client(client: Optional[Client]) -> Client:
    """Pick the Dask client to run work on.

    Parameters
    ----------
    client : dask.distributed.Client or None
        Explicit Dask client. When None, ``default_client()`` is used instead.

    Returns
    -------
    client : dask.distributed.Client
        The explicitly passed client, or the default client.
    """
    return default_client() if client is None else client
def _find_n_open_ports(n: int) -> List[int]:
"""Find n random open ports on localhost.
Returns
-------
ports : list of int
n random open ports on localhost.
"""
sockets = []
for _ in range(n):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(('', 0))
sockets.append(s)
ports = []
for s in sockets:
ports.append(s.getsockname()[1])
s.close()
return ports
def _group_workers_by_host(worker_addresses: Iterable[str]) -> Dict[str, _HostWorkers]:
    """Bucket Dask worker addresses by the hostname parsed from each address.

    The first address seen for a host becomes that host's ``default`` worker;
    every address for the host, in input order, is collected in ``all``.

    Returns
    -------
    host_to_workers : dict
        mapping from hostname to all its workers.
    """
    grouped: Dict[str, _HostWorkers] = {}
    for worker_address in worker_addresses:
        host = urlparse(worker_address).hostname
        entry = grouped.get(host)
        if entry is None:
            grouped[host] = _HostWorkers(default=worker_address, all=[worker_address])
        else:
            entry.all.append(worker_address)
    return grouped
def _assign_open_ports_to_workers(
    client: Client,
    host_to_workers: Dict[str, _HostWorkers]
) -> Dict[str, int]:
    """Give every Dask worker its own open port.

    For each host, ``_find_n_open_ports`` is submitted (non-pure, pinned to that
    host's default worker) to find as many ports as there are workers on the host.
    The ports found are then handed out to that host's workers in order.

    Returns
    -------
    worker_to_port: dict
        mapping from worker address to an open port.
    """
    futures_by_host = {
        host: client.submit(
            _find_n_open_ports,
            n=len(group.all),
            workers=[group.default],
            pure=False,
            allow_other_workers=False,
        )
        for host, group in host_to_workers.items()
    }
    ports_by_host = client.gather(futures_by_host)
    return {
        worker: port
        for host, group in host_to_workers.items()
        for worker, port in zip(group.all, ports_by_host[host])
    }
def _concat(seq: List[_DaskPart]) -> _DaskPart:
    """Stack same-typed in-memory chunks along the first axis.

    The type of the first element selects the strategy:
    ``np.concatenate`` for numpy arrays, pandas ``concat`` for DataFrames/Series,
    and ``scipy.sparse.vstack`` (CSR output) for sparse matrices.

    Raises
    ------
    TypeError
        If the first element is none of the supported types.
    """
    first = seq[0]
    if isinstance(first, np.ndarray):
        return np.concatenate(seq, axis=0)
    if isinstance(first, (pd_DataFrame, pd_Series)):
        return concat(seq, axis=0)
    if isinstance(first, ss.spmatrix):
        return ss.vstack(seq, format='csr')
    raise TypeError(f'Data must be one of: numpy arrays, pandas dataframes, sparse matrices (from scipy). Got {type(first).__name__}.')
def _remove_list_padding(*args: Any) -> List[List[Any]]:
return [[z for z in arg if z is not None] for arg in args]
def _pad_eval_names(lgbm_model: LGBMModel, required_names: List[str]) -> LGBMModel:
    """Insert placeholder results for eval_set names the model never evaluated.

    For each name in ``required_names`` that is missing from the dicts returned by
    ``lgbm_model.evals_result_`` or ``lgbm_model.best_score_``, the string
    'not evaluated' is stored under that name. This lets users of DaskLGBM
    estimators rely on every expected ``eval_set`` name being present.
    Returns ``lgbm_model``.
    """
    for name in required_names:
        for results in (lgbm_model.evals_result_, lgbm_model.best_score_):
            if name not in results:
                results[name] = 'not evaluated'
    return lgbm_model
def _train_part(
    params: Dict[str, Any],
    model_factory: Type[LGBMModel],
    list_of_parts: List[Dict[str, _DaskPart]],
    machines: str,
    local_listen_port: int,
    num_machines: int,
    return_model: bool,
    time_out: int = 120,
    **kwargs: Any
) -> Optional[LGBMModel]:
    """Fit one worker's share of a distributed LightGBM training job.

    Concatenates the worker-local chunks in ``list_of_parts`` into single training
    (and, if present, validation) inputs. It then fits a ``model_factory`` instance
    with LightGBM's network parameters added to ``params``, and always calls
    ``LGBM_NetworkFree`` afterwards, even if fitting fails.

    Parameters
    ----------
    params : dict
        Model parameters. Updated in place with ``machines``, ``local_listen_port``,
        ``time_out`` and ``num_machines``.
    model_factory : type
        LightGBM sklearn estimator class to instantiate. Subclasses of LGBMRanker
        additionally receive ``group`` / ``eval_group`` in ``fit``.
    list_of_parts : list of dict
        Worker-local parts. Each has 'data' and 'label' and optionally 'weight',
        'group', 'init_score', 'eval_set', 'eval_sample_weight', 'eval_init_score'
        and 'eval_group'. Eval entries may be ``_DatasetNames`` markers that refer
        back to the part's own training components.
    machines : str
        Comma-delimited ``host:port`` list of all LightGBM workers.
    local_listen_port : int
        Port this worker listens on.
    num_machines : int
        Number of LightGBM workers.
    return_model : bool
        Whether to return the fitted model rather than None.
    time_out : int, optional (default=120)
        Passed to LightGBM as the ``time_out`` parameter.
    **kwargs
        Other ``fit`` keyword arguments. ``eval_names`` is popped and, like
        ``eval_class_weight``, narrowed to the eval sets present on this worker.

    Returns
    -------
    model : LGBMModel or None
        The fitted model if ``return_model`` is True, otherwise None.
    """
    network_params = {
        'machines': machines,
        'local_listen_port': local_listen_port,
        'time_out': time_out,
        'num_machines': num_machines
    }
    params.update(network_params)

    is_ranker = issubclass(model_factory, LGBMRanker)

    # Concatenate many parts into one
    data = _concat([x['data'] for x in list_of_parts])
    label = _concat([x['label'] for x in list_of_parts])

    # Optional training components are looked up on the first part only;
    # _train() adds each of them either to every part or to none.
    if 'weight' in list_of_parts[0]:
        weight = _concat([x['weight'] for x in list_of_parts])
    else:
        weight = None

    if 'group' in list_of_parts[0]:
        group = _concat([x['group'] for x in list_of_parts])
    else:
        group = None

    if 'init_score' in list_of_parts[0]:
        init_score = _concat([x['init_score'] for x in list_of_parts])
    else:
        init_score = None

    # construct local eval_set data.
    n_evals = max(len(x.get('eval_set', [])) for x in list_of_parts)
    eval_names = kwargs.pop('eval_names', None)
    eval_class_weight = kwargs.get('eval_class_weight')
    local_eval_set = None
    local_eval_names = None
    local_eval_sample_weight = None
    local_eval_init_score = None
    local_eval_group = None

    if n_evals:
        has_eval_sample_weight = any(x.get('eval_sample_weight') is not None for x in list_of_parts)
        has_eval_init_score = any(x.get('eval_init_score') is not None for x in list_of_parts)

        local_eval_set = []
        # names that must appear in evals_result_/best_score_ after fitting
        evals_result_names = []
        if has_eval_sample_weight:
            local_eval_sample_weight = []
        if has_eval_init_score:
            local_eval_init_score = []
        if is_ranker:
            local_eval_group = []

        # store indices of eval_set components that were not contained within local parts.
        missing_eval_component_idx = []

        # consolidate parts of each individual eval component.
        for i in range(n_evals):
            x_e = []
            y_e = []
            w_e = []
            init_score_e = []
            g_e = []
            for part in list_of_parts:
                if not part.get('eval_set'):
                    continue

                # require that eval_name exists in evaluated result data in case dropped due to padding.
                # in distributed training the 'training' eval_set is not detected, will have name 'valid_<index>'.
                if eval_names:
                    evals_result_name = eval_names[i]
                else:
                    evals_result_name = f'valid_{i}'

                eval_set = part['eval_set'][i]
                if eval_set is _DatasetNames.TRAINSET:
                    # marker: this eval set is the training data itself
                    x_e.append(part['data'])
                    y_e.append(part['label'])
                else:
                    # (list of X chunks, list of y chunks), possibly None-padded
                    x_e.extend(eval_set[0])
                    y_e.extend(eval_set[1])

                if evals_result_name not in evals_result_names:
                    evals_result_names.append(evals_result_name)

                eval_weight = part.get('eval_sample_weight')
                if eval_weight:
                    if eval_weight[i] is _DatasetNames.SAMPLE_WEIGHT:
                        w_e.append(part['weight'])
                    else:
                        w_e.extend(eval_weight[i])

                eval_init_score = part.get('eval_init_score')
                if eval_init_score:
                    if eval_init_score[i] is _DatasetNames.INIT_SCORE:
                        init_score_e.append(part['init_score'])
                    else:
                        init_score_e.extend(eval_init_score[i])

                eval_group = part.get('eval_group')
                if eval_group:
                    if eval_group[i] is _DatasetNames.GROUP:
                        g_e.append(part['group'])
                    else:
                        g_e.extend(eval_group[i])

            # filter padding from eval parts then _concat each eval_set component.
            x_e, y_e, w_e, init_score_e, g_e = _remove_list_padding(x_e, y_e, w_e, init_score_e, g_e)
            if x_e:
                local_eval_set.append((_concat(x_e), _concat(y_e)))
            else:
                # this worker holds no rows of eval set i; skip its other components too
                missing_eval_component_idx.append(i)
                continue

            if w_e:
                local_eval_sample_weight.append(_concat(w_e))
            if init_score_e:
                local_eval_init_score.append(_concat(init_score_e))
            if g_e:
                local_eval_group.append(_concat(g_e))

        # reconstruct eval_set fit args/kwargs depending on which components of eval_set are on worker.
        eval_component_idx = [i for i in range(n_evals) if i not in missing_eval_component_idx]
        if eval_names:
            local_eval_names = [eval_names[i] for i in eval_component_idx]
        if eval_class_weight:
            kwargs['eval_class_weight'] = [eval_class_weight[i] for i in eval_component_idx]

    try:
        model = model_factory(**params)
        if is_ranker:
            model.fit(
                data,
                label,
                sample_weight=weight,
                init_score=init_score,
                group=group,
                eval_set=local_eval_set,
                eval_sample_weight=local_eval_sample_weight,
                eval_init_score=local_eval_init_score,
                eval_group=local_eval_group,
                eval_names=local_eval_names,
                **kwargs
            )
        else:
            model.fit(
                data,
                label,
                sample_weight=weight,
                init_score=init_score,
                eval_set=local_eval_set,
                eval_sample_weight=local_eval_sample_weight,
                eval_init_score=local_eval_init_score,
                eval_names=local_eval_names,
                **kwargs
            )

    finally:
        # release LightGBM's network resources whether or not fitting succeeded
        _safe_call(_LIB.LGBM_NetworkFree())

    if n_evals:
        # ensure that expected keys for evals_result_ and best_score_ exist regardless of padding.
        model = _pad_eval_names(model, required_names=evals_result_names)

    return model if return_model else None
def _split_to_parts(data: _DaskCollection, is_matrix: bool) -> List[_DaskPart]:
    """Return the delayed chunks of a Dask collection as a flat list.

    ``to_delayed()`` on a Dask Array yields an ndarray of delayed chunks. That
    array must have a single chunk along axis 1 (for matrices) or be 1-D / a
    single column (for vectors), and it is flattened to a list. Any other
    ``to_delayed()`` result is returned unchanged.
    """
    delayed_parts = data.to_delayed()
    if not isinstance(delayed_parts, np.ndarray):
        return delayed_parts
    if is_matrix:
        assert delayed_parts.shape[1] == 1
    else:
        assert delayed_parts.ndim == 1 or delayed_parts.shape[1] == 1
    return delayed_parts.flatten().tolist()
def _machines_to_worker_map(machines: str, worker_addresses: List[str]) -> Dict[str, int]:
"""Create a worker_map from machines list.
Given ``machines`` and a list of Dask worker addresses, return a mapping where the keys are
``worker_addresses`` and the values are ports from ``machines``.
Parameters
----------
machines : str
A comma-delimited list of workers, of the form ``ip1:port,ip2:port``.
worker_addresses : list of str
A list of Dask worker addresses, of the form ``{protocol}{hostname}:{port}``, where ``port`` is the port Dask's scheduler uses to talk to that worker.
Returns
-------
result : Dict[str, int]
Dictionary where keys are work addresses in the form expected by Dask and values are a port for LightGBM to use.
"""
machine_addresses = machines.split(",")
if len(set(machine_addresses)) != len(machine_addresses):
raise ValueError(f"Found duplicates in 'machines' ({machines}). Each entry in 'machines' must be a unique IP-port combination.")
machine_to_port = defaultdict(set)
for address in machine_addresses:
host, port = address.split(":")
machine_to_port[host].add(int(port))
out = {}
for address in worker_addresses:
worker_host = urlparse(address).hostname
out[address] = machine_to_port[worker_host].pop()
return out
def _train(
    client: Client,
    data: _DaskMatrixLike,
    label: _DaskCollection,
    params: Dict[str, Any],
    model_factory: Type[LGBMModel],
    sample_weight: Optional[_DaskVectorLike] = None,
    init_score: Optional[_DaskCollection] = None,
    group: Optional[_DaskVectorLike] = None,
    eval_set: Optional[List[Tuple[_DaskMatrixLike, _DaskCollection]]] = None,
    eval_names: Optional[List[str]] = None,
    eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
    eval_class_weight: Optional[List[Union[dict, str]]] = None,
    eval_init_score: Optional[List[_DaskCollection]] = None,
    eval_group: Optional[List[_DaskVectorLike]] = None,
    eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
    eval_at: Optional[Iterable[int]] = None,
    **kwargs: Any
) -> LGBMModel:
    """Inner train routine.

    Parameters
    ----------
    client : dask.distributed.Client
        Dask client.
    data : Dask Array or Dask DataFrame of shape = [n_samples, n_features]
        Input feature matrix.
    label : Dask Array, Dask DataFrame or Dask Series of shape = [n_samples]
        The target values (class labels in classification, real numbers in regression).
    params : dict
        Parameters passed to constructor of the local underlying model.
    model_factory : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
        Class of the local underlying model.
    sample_weight : Dask Array or Dask Series of shape = [n_samples] or None, optional (default=None)
        Weights of training data.
    init_score : Dask Array or Dask Series of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task), or Dask Array or Dask DataFrame of shape = [n_samples, n_classes] (for multi-class task), or None, optional (default=None)
        Init score of training data.
    group : Dask Array or Dask Series or None, optional (default=None)
        Group/query data.
        Only used in the learning-to-rank task.
        sum(group) = n_samples.
        For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
        where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.
    eval_set : list of (X, y) tuples of Dask data collections, or None, optional (default=None)
        List of (X, y) tuple pairs to use as validation sets.
        Note that not all workers may receive chunks of every eval set within ``eval_set``. When the returned
        lightgbm estimator is not trained using any chunks of a particular eval set, its corresponding component
        of evals_result_ and best_score_ will be 'not evaluated'.
    eval_names : list of str, or None, optional (default=None)
        Names of eval_set.
    eval_sample_weight : list of Dask Array or Dask Series, or None, optional (default=None)
        Weights for each validation set in eval_set.
    eval_class_weight : list of dict or str, or None, optional (default=None)
        Class weights, one dict or str for each validation set in eval_set.
    eval_init_score : list of Dask Array, Dask Series or Dask DataFrame (for multi-class task), or None, optional (default=None)
        Initial model score for each validation set in eval_set.
    eval_group : list of Dask Array or Dask Series, or None, optional (default=None)
        Group/query for each validation set in eval_set.
    eval_metric : str, callable, list or None, optional (default=None)
        If str, it should be a built-in evaluation metric to use.
        If callable, it should be a custom evaluation metric, see note below for more details.
        If list, it can be a list of built-in metrics, a list of custom evaluation metrics, or a mix of both.
        In either case, the ``metric`` from the Dask model parameters (or inferred from the objective) will be evaluated and used as well.
        Default: 'l2' for DaskLGBMRegressor, 'binary(multi)_logloss' for DaskLGBMClassifier, 'ndcg' for DaskLGBMRanker.
    eval_at : iterable of int, optional (default=None)
        The evaluation positions of the specified ranking metric.
    **kwargs
        Other parameters passed to ``fit`` method of the local underlying model.

    Returns
    -------
    model : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
        Returns fitted underlying model.

    Raises
    ------
    LightGBMError
        If ``local_listen_port`` is provided (and ``machines`` is not) while some
        machine in the cluster runs more than one Dask worker.
    Exception
        Any exception raised on a Dask worker while materializing the training
        parts is re-raised locally.

    Note
    ----

    This method handles setting up the following network parameters based on information
    about the Dask cluster referenced by ``client``.

    * ``local_listen_port``: port that each LightGBM worker opens a listening socket on,
            to accept connections from other workers. This can differ from LightGBM worker
            to LightGBM worker, but does not have to.
    * ``machines``: a comma-delimited list of all workers in the cluster, in the
            form ``ip:port,ip:port``. If running multiple Dask workers on the same host, use different
            ports for each worker. For example, for ``LocalCluster(n_workers=3)``, you might
            pass ``"127.0.0.1:12400,127.0.0.1:12401,127.0.0.1:12402"``.
    * ``num_machines``: number of LightGBM workers.
    * ``time_out``: time in minutes to wait before closing unused sockets.

    The default behavior of this function is to generate ``machines`` from the list of
    Dask workers which hold some piece of the training data, and to search for an open
    port on each worker to be used as ``local_listen_port``.

    If ``machines`` is provided explicitly in ``params``, this function uses the hosts
    and ports in that list directly, and does not do any searching. This means that if
    any of the Dask workers are missing from the list or any of those ports are not free
    when training starts, training will fail.

    If ``local_listen_port`` is provided in ``params`` and ``machines`` is not, this function
    constructs ``machines`` from the list of Dask workers which hold some piece of the
    training data, assuming that each one will use the same ``local_listen_port``.
    """
    params = deepcopy(params)

    # capture whether local_listen_port or its aliases were provided
    listen_port_in_params = any(
        alias in params for alias in _ConfigAliases.get("local_listen_port")
    )

    # capture whether machines or its aliases were provided
    machines_in_params = any(
        alias in params for alias in _ConfigAliases.get("machines")
    )

    params = _choose_param_value(
        main_param_name="tree_learner",
        params=params,
        default_value="data"
    )
    allowed_tree_learners = {
        'data',
        'data_parallel',
        'feature',
        'feature_parallel',
        'voting',
        'voting_parallel'
    }
    if params["tree_learner"] not in allowed_tree_learners:
        _log_warning(f'Parameter tree_learner set to {params["tree_learner"]}, which is not allowed. Using "data" as default')
        params['tree_learner'] = 'data'

    # Some passed-in parameters can be removed:
    #   * 'num_machines': set automatically from Dask worker list
    #   * 'num_threads': overridden to match nthreads on each Dask process
    for param_alias in _ConfigAliases.get('num_machines', 'num_threads'):
        if param_alias in params:
            _log_warning(f"Parameter {param_alias} will be ignored.")
            params.pop(param_alias)

    # Split arrays/dataframes into parts. Arrange parts into dicts to enforce co-locality
    data_parts = _split_to_parts(data=data, is_matrix=True)
    label_parts = _split_to_parts(data=label, is_matrix=False)
    parts = [{'data': x, 'label': y} for (x, y) in zip(data_parts, label_parts)]
    n_parts = len(parts)

    if sample_weight is not None:
        weight_parts = _split_to_parts(data=sample_weight, is_matrix=False)
        for i in range(n_parts):
            parts[i]['weight'] = weight_parts[i]

    if group is not None:
        group_parts = _split_to_parts(data=group, is_matrix=False)
        for i in range(n_parts):
            parts[i]['group'] = group_parts[i]

    if init_score is not None:
        init_score_parts = _split_to_parts(data=init_score, is_matrix=False)
        for i in range(n_parts):
            parts[i]['init_score'] = init_score_parts[i]

    # evals_set will to be re-constructed into smaller lists of (X, y) tuples, where
    # X and y are each delayed sub-lists of original eval dask Collections.
    if eval_set:
        # find maximum number of parts in an individual eval set so that we can
        # pad eval sets when they come in different sizes.
        n_largest_eval_parts = max(x[0].npartitions for x in eval_set)

        eval_sets = defaultdict(list)
        if eval_sample_weight:
            eval_sample_weights = defaultdict(list)
        if eval_group:
            eval_groups = defaultdict(list)
        if eval_init_score:
            eval_init_scores = defaultdict(list)

        for i, (X_eval, y_eval) in enumerate(eval_set):
            n_this_eval_parts = X_eval.npartitions

            # when individual eval set is equivalent to training data, skip recomputing parts.
            if X_eval is data and y_eval is label:
                for parts_idx in range(n_parts):
                    eval_sets[parts_idx].append(_DatasetNames.TRAINSET)
            else:
                eval_x_parts = _split_to_parts(data=X_eval, is_matrix=True)
                eval_y_parts = _split_to_parts(data=y_eval, is_matrix=False)
                for j in range(n_largest_eval_parts):
                    parts_idx = j % n_parts

                    # add None-padding for individual eval_set member if it is smaller than the largest member.
                    if j < n_this_eval_parts:
                        x_e = eval_x_parts[j]
                        y_e = eval_y_parts[j]
                    else:
                        x_e = None
                        y_e = None

                    if j < n_parts:
                        # first time a chunk of this eval set is added to this part.
                        eval_sets[parts_idx].append(([x_e], [y_e]))
                    else:
                        # append additional chunks of this eval set to this part.
                        eval_sets[parts_idx][-1][0].append(x_e)
                        eval_sets[parts_idx][-1][1].append(y_e)

            if eval_sample_weight:
                if eval_sample_weight[i] is sample_weight:
                    for parts_idx in range(n_parts):
                        eval_sample_weights[parts_idx].append(_DatasetNames.SAMPLE_WEIGHT)
                else:
                    eval_w_parts = _split_to_parts(data=eval_sample_weight[i], is_matrix=False)

                    # ensure that all evaluation parts map uniquely to one part.
                    for j in range(n_largest_eval_parts):
                        if j < n_this_eval_parts:
                            w_e = eval_w_parts[j]
                        else:
                            w_e = None

                        parts_idx = j % n_parts
                        if j < n_parts:
                            eval_sample_weights[parts_idx].append([w_e])
                        else:
                            eval_sample_weights[parts_idx][-1].append(w_e)

            if eval_init_score:
                if eval_init_score[i] is init_score:
                    for parts_idx in range(n_parts):
                        eval_init_scores[parts_idx].append(_DatasetNames.INIT_SCORE)
                else:
                    eval_init_score_parts = _split_to_parts(data=eval_init_score[i], is_matrix=False)
                    for j in range(n_largest_eval_parts):
                        if j < n_this_eval_parts:
                            init_score_e = eval_init_score_parts[j]
                        else:
                            init_score_e = None

                        parts_idx = j % n_parts
                        if j < n_parts:
                            eval_init_scores[parts_idx].append([init_score_e])
                        else:
                            eval_init_scores[parts_idx][-1].append(init_score_e)

            if eval_group:
                if eval_group[i] is group:
                    for parts_idx in range(n_parts):
                        eval_groups[parts_idx].append(_DatasetNames.GROUP)
                else:
                    eval_g_parts = _split_to_parts(data=eval_group[i], is_matrix=False)
                    for j in range(n_largest_eval_parts):
                        if j < n_this_eval_parts:
                            g_e = eval_g_parts[j]
                        else:
                            g_e = None

                        parts_idx = j % n_parts
                        if j < n_parts:
                            eval_groups[parts_idx].append([g_e])
                        else:
                            eval_groups[parts_idx][-1].append(g_e)

        # assign sub-eval_set components to worker parts.
        for parts_idx, e_set in eval_sets.items():
            parts[parts_idx]['eval_set'] = e_set
            if eval_sample_weight:
                parts[parts_idx]['eval_sample_weight'] = eval_sample_weights[parts_idx]
            if eval_init_score:
                parts[parts_idx]['eval_init_score'] = eval_init_scores[parts_idx]
            if eval_group:
                parts[parts_idx]['eval_group'] = eval_groups[parts_idx]

    # Record which parts carry eval_set data while they are still plain dicts, so the
    # allocation check below does not have to pull every part's data to the client.
    part_has_eval_set = ['eval_set' in part for part in parts]

    # Start computation in the background
    parts = list(map(delayed, parts))
    parts = client.compute(parts)
    wait(parts)

    for part in parts:
        if part.status == 'error':  # type: ignore
            # Future.result() re-raises the worker's exception locally. (Returning the
            # Future, as before, handed callers a Future instead of a model.)
            part.result()  # type: ignore

    # Find locations of all parts and map them to particular Dask workers
    key_to_part_dict = {part.key: part for part in parts}  # type: ignore
    # client.compute() returns futures in input order, so zip lines them up with the flags
    key_has_eval_set = {part.key: has_eval for part, has_eval in zip(parts, part_has_eval_set)}  # type: ignore
    who_has = client.who_has(parts)
    worker_map = defaultdict(list)
    for key, workers in who_has.items():
        worker_map[next(iter(workers))].append(key_to_part_dict[key])

    # Check that all workers were provided some of eval_set. Otherwise warn user that validation
    # data artifacts may not be populated depending on worker returning final estimator.
    if eval_set:
        for worker, worker_parts in worker_map.items():
            if not any(key_has_eval_set[part.key] for part in worker_parts):
                _log_warning(
                    f"Worker {worker} was not allocated eval_set data. Therefore evals_result_ and best_score_ data may be unreliable. "
                    "Try rebalancing data across workers."
                )

    # assign general validation set settings to fit kwargs.
    if eval_names:
        kwargs['eval_names'] = eval_names
    if eval_class_weight:
        kwargs['eval_class_weight'] = eval_class_weight
    if eval_metric:
        kwargs['eval_metric'] = eval_metric
    if eval_at:
        kwargs['eval_at'] = eval_at

    # the first worker in worker_map is the only one whose _train_part() returns its model
    master_worker = next(iter(worker_map))
    worker_ncores = client.ncores()

    # resolve aliases for network parameters and pop the result off params.
    # these values are added back in calls to `_train_part()`
    params = _choose_param_value(
        main_param_name="local_listen_port",
        params=params,
        default_value=12400
    )
    local_listen_port = params.pop("local_listen_port")

    params = _choose_param_value(
        main_param_name="machines",
        params=params,
        default_value=None
    )
    machines = params.pop("machines")

    # figure out network params
    worker_addresses = worker_map.keys()
    if machines is not None:
        _log_info("Using passed-in 'machines' parameter")
        worker_address_to_port = _machines_to_worker_map(
            machines=machines,
            worker_addresses=worker_addresses
        )
    else:
        if listen_port_in_params:
            _log_info("Using passed-in 'local_listen_port' for all workers")
            unique_hosts = set(urlparse(a).hostname for a in worker_addresses)
            if len(unique_hosts) < len(worker_addresses):
                msg = (
                    "'local_listen_port' was provided in Dask training parameters, but at least one "
                    "machine in the cluster has multiple Dask worker processes running on it. Please omit "
                    "'local_listen_port' or pass 'machines'."
                )
                raise LightGBMError(msg)

            worker_address_to_port = {
                address: local_listen_port
                for address in worker_addresses
            }
        else:
            _log_info("Finding random open ports for workers")
            host_to_workers = _group_workers_by_host(worker_map.keys())
            worker_address_to_port = _assign_open_ports_to_workers(client, host_to_workers)

        machines = ','.join([
            f'{urlparse(worker_address).hostname}:{port}'
            for worker_address, port
            in worker_address_to_port.items()
        ])

    num_machines = len(worker_address_to_port)

    # Tell each worker to train on the parts that it has locally
    #
    # This code treats ``_train_part()`` calls as not "pure" because:
    #     1. there is randomness in the training process unless parameters ``seed``
    #        and ``deterministic`` are set
    #     2. even with those parameters set, the output of one ``_train_part()`` call
    #        relies on global state (it and all the other LightGBM training processes
    #        coordinate with each other)
    futures_classifiers = [
        client.submit(
            _train_part,
            model_factory=model_factory,
            params={**params, 'num_threads': worker_ncores[worker]},
            list_of_parts=list_of_parts,
            machines=machines,
            local_listen_port=worker_address_to_port[worker],
            num_machines=num_machines,
            time_out=params.get('time_out', 120),
            return_model=(worker == master_worker),
            workers=[worker],
            allow_other_workers=False,
            pure=False,
            **kwargs
        )
        for worker, list_of_parts in worker_map.items()
    ]

    results = client.gather(futures_classifiers)
    results = [v for v in results if v]
    model = results[0]

    # if network parameters were changed during training, remove them from the
    # returned model so that they're generated dynamically on every run based
    # on the Dask cluster you're connected to and which workers have pieces of
    # the training data
    if not listen_port_in_params:
        for param in _ConfigAliases.get('local_listen_port'):
            model._other_params.pop(param, None)

    if not machines_in_params:
        for param in _ConfigAliases.get('machines'):
            model._other_params.pop(param, None)

    # NOTE(review): the network parameter is named 'time_out'; confirm whether
    # 'timeout' here is meant to match it.
    for param in _ConfigAliases.get('num_machines', 'timeout'):
        model._other_params.pop(param, None)

    return model
def _predict_part(
    part: _DaskPart,
    model: LGBMModel,
    raw_score: bool,
    pred_proba: bool,
    pred_leaf: bool,
    pred_contrib: bool,
    **kwargs: Any
) -> _DaskPart:
    """Predict on one in-memory chunk with ``model.predict`` or ``model.predict_proba``.

    A chunk with zero rows short-circuits to an empty 1-D numpy array. When
    ``part`` is a pandas DataFrame, the result is wrapped to share ``part``'s
    index: a DataFrame for 2-D output, or a Series named 'predictions' otherwise.
    """
    if part.shape[0] == 0:
        result = np.array([])
    else:
        predict_method = model.predict_proba if pred_proba else model.predict
        result = predict_method(
            part,
            raw_score=raw_score,
            pred_leaf=pred_leaf,
            pred_contrib=pred_contrib,
            **kwargs
        )

    if not isinstance(part, pd_DataFrame):
        return result

    # dask.DataFrame.map_partitions() expects each call to return a pandas DataFrame or Series
    if len(result.shape) == 2:
        return pd_DataFrame(result, index=part.index)
    return pd_Series(result, index=part.index, name='predictions')
def _predict(
    model: LGBMModel,
    data: _DaskMatrixLike,
    client: Client,
    raw_score: bool = False,
    pred_proba: bool = False,
    pred_leaf: bool = False,
    pred_contrib: bool = False,
    dtype: _PredictionDtype = np.float32,
    **kwargs: Any
) -> Union[dask_Array, List[dask_Array]]:
    """Inner predict routine.

    Parameters
    ----------
    model : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
        Fitted underlying model.
    data : Dask Array or Dask DataFrame of shape = [n_samples, n_features]
        Input feature matrix.
    client : dask.distributed.Client
        Dask client. For (non-sparse-multiclass-contrib) Dask Array input, it is used to
        compute the first row of ``data`` locally so the output shape and meta can be inferred.
    raw_score : bool, optional (default=False)
        Whether to predict raw scores.
    pred_proba : bool, optional (default=False)
        Should method return results of ``predict_proba`` (``pred_proba=True``) or ``predict`` (``pred_proba=False``).
    pred_leaf : bool, optional (default=False)
        Whether to predict leaf index.
    pred_contrib : bool, optional (default=False)
        Whether to predict feature contributions.
    dtype : np.dtype, optional (default=np.float32)
        Dtype of the output.
    **kwargs
        Other parameters passed to ``predict`` or ``predict_proba`` method.

    Returns
    -------
    predicted_result : Dask Array of shape = [n_samples] or shape = [n_samples, n_classes]
        The predicted values.
    X_leaves : Dask Array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]
        If ``pred_leaf=True``, the predicted leaf of every tree for each sample.
    X_SHAP_values : Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or (if multi-class and using sparse inputs) a list of ``n_classes`` Dask Arrays of shape = [n_samples, n_features + 1]
        If ``pred_contrib=True``, the feature contributions for each sample.

    Raises
    ------
    LightGBMError
        If dask, pandas or scikit-learn is not installed.
    TypeError
        If ``data`` is neither a Dask Array nor a Dask DataFrame.
    """
    if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)):
        raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask')
    if isinstance(data, dask_DataFrame):
        # _predict_part wraps each partition's output in a pandas DataFrame/Series;
        # ``.values`` turns the resulting Dask collection into a Dask Array
        return data.map_partitions(
            _predict_part,
            model=model,
            raw_score=raw_score,
            pred_proba=pred_proba,
            pred_leaf=pred_leaf,
            pred_contrib=pred_contrib,
            **kwargs
        ).values
    elif isinstance(data, dask_Array):
        # for multi-class classification with sparse matrices, pred_contrib predictions
        # are returned as a list of sparse matrices (one per class)
        # a falsy model._n_classes is treated as -1 so the multi-class check below fails
        num_classes = model._n_classes or -1

        if (
            num_classes > 2
            and pred_contrib
            and isinstance(data._meta, ss.spmatrix)
        ):

            # note: this path always predicts contributions, with raw_score and pred_leaf forced off
            predict_function = partial(
                _predict_part,
                model=model,
                raw_score=False,
                pred_proba=pred_proba,
                pred_leaf=False,
                pred_contrib=True,
                **kwargs
            )

            # one delayed object per row-chunk (the first chunk column only)
            delayed_chunks = data.to_delayed()
            bag = dask_bag_from_delayed(delayed_chunks[:, 0])

            @delayed
            def _extract(items: List[Any], i: int) -> Any:
                # pick the i-th per-class matrix out of one partition's prediction list
                return items[i]

            preds = bag.map_partitions(predict_function)

            # pred_contrib output will have one column per feature,
            # plus one more for the base value
            num_cols = model.n_features_ + 1

            nrows_per_chunk = data.chunks[0]
            out = [[] for _ in range(num_classes)]

            # need to tell Dask the expected type and shape of individual preds
            pred_meta = data._meta

            for j, partition in enumerate(preds.to_delayed()):
                for i in range(num_classes):
                    part = dask_array_from_delayed(
                        value=_extract(partition, i),
                        shape=(nrows_per_chunk[j], num_cols),
                        meta=pred_meta
                    )
                    out[i].append(part)

            # by default, dask.array.concatenate() concatenates sparse arrays into a COO matrix
            # the code below is used instead to ensure that the sparse type is preserved during concatenation
            if isinstance(pred_meta, ss.csr_matrix):
                concat_fn = partial(ss.vstack, format='csr')
            elif isinstance(pred_meta, ss.csc_matrix):
                concat_fn = partial(ss.vstack, format='csc')
            else:
                concat_fn = ss.vstack

            # At this point, `out` is a list of lists of delayeds (each of which points to a matrix).
            # Concatenate them to return a list of Dask Arrays.
            for i in range(num_classes):
                out[i] = dask_array_from_delayed(
                    value=delayed(concat_fn)(out[i]),
                    shape=(data.shape[0], num_cols),
                    meta=pred_meta
                )

            return out

        # compute the first row locally and predict on it to learn the output's
        # shape and type, which map_blocks needs as chunks/meta
        data_row = client.compute(data[[0]]).result()
        predict_fn = partial(
            _predict_part,
            model=model,
            raw_score=raw_score,
            pred_proba=pred_proba,
            pred_leaf=pred_leaf,
            pred_contrib=pred_contrib,
            **kwargs,
        )
        pred_row = predict_fn(data_row)
        chunks = (data.chunks[0],)
        map_blocks_kwargs = {}
        if len(pred_row.shape) > 1:
            # 2-D output: a single column chunk as wide as the prediction row
            chunks += (pred_row.shape[1],)
        else:
            # 1-D output: drop the feature axis of the input
            map_blocks_kwargs['drop_axis'] = 1
        return data.map_blocks(
            predict_fn,
            chunks=chunks,
            meta=pred_row,
            dtype=dtype,
            **map_blocks_kwargs,
        )
    else:
        raise TypeError(f'Data must be either Dask Array or Dask DataFrame. Got {type(data).__name__}.')
class _DaskLGBMModel:
@property
def client_(self) -> Client:
""":obj:`dask.distributed.Client`: Dask client.