diff --git a/src/learner.cc b/src/learner.cc index 1269a6b2b6d7..d21f071471f4 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -485,10 +485,10 @@ class LearnerImpl : public Learner { this->PerformTreeMethodHeuristic(train); monitor_.Start("PredictRaw"); - this->PredictRaw(train, &preds_); + this->PredictRaw(train, &preds_[train]); monitor_.Stop("PredictRaw"); monitor_.Start("GetGradient"); - obj_->GetGradient(preds_, train->Info(), iter, &gpair_); + obj_->GetGradient(preds_[train], train->Info(), iter, &gpair_); monitor_.Stop("GetGradient"); gbm_->DoBoost(train, &gpair_, obj_.get()); monitor_.Stop("UpdateOneIter"); @@ -520,11 +520,12 @@ class LearnerImpl : public Learner { metrics_.back()->Configure(cfg_.begin(), cfg_.end()); } for (size_t i = 0; i < data_sets.size(); ++i) { - this->PredictRaw(data_sets[i], &preds_); - obj_->EvalTransform(&preds_); + DMatrix * dmat = data_sets[i]; + this->PredictRaw(data_sets[i], &preds_[dmat]); + obj_->EvalTransform(&preds_[dmat]); for (auto& ev : metrics_) { os << '\t' << data_names[i] << '-' << ev->Name() << ':' - << ev->Eval(preds_, data_sets[i]->Info(), + << ev->Eval(preds_[dmat], data_sets[i]->Info(), tparam_.dsplit == DataSplitMode::kRow); } } @@ -565,10 +566,10 @@ class LearnerImpl : public Learner { std::string metric) { if (metric == "auto") metric = obj_->DefaultEvalMetric(); std::unique_ptr ev(Metric::Create(metric.c_str())); - this->PredictRaw(data, &preds_); - obj_->EvalTransform(&preds_); + this->PredictRaw(data, &preds_[data]); + obj_->EvalTransform(&preds_[data]); return std::make_pair(metric, - ev->Eval(preds_, data->Info(), + ev->Eval(preds_[data], data->Info(), tparam_.dsplit == DataSplitMode::kRow)); } @@ -771,7 +772,7 @@ class LearnerImpl : public Learner { // name of objective function std::string name_obj_; // temporal storages for prediction - HostDeviceVector preds_; + std::map> preds_; // gradient pairs HostDeviceVector gpair_;