-
Notifications
You must be signed in to change notification settings - Fork 0
/
mlpcr_full.m
165 lines (155 loc) · 6.85 KB
/
mlpcr_full.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
% [weights, Intercept, lme, hp] = mlpcr_full(kfold,grps,bayesopt_opts,X,Y,mlpcr_options)
%
% Performs MLPCR with hyperparameters estimated using bayesian
% optimization of mean squared error (MSE). See mlpcr help for additional
% details on usage and what to pass as mlpcr_options arguments.
%
% Hyperparameters can be specified explicitly or as type
% optimizableVariable. If specified as optimizableVariable they will be
% optimized and the best result will be returned. Explicitly specified
% hyperparameters will simply be passed on to mlpcr as is.
%
% Note: mlpcr_full() relies on cross validation (CV) to estimate
% hyperparameters. Do not use on data where CV estimates are invalid, for
% instance on resampled datasets.
%
% Input ::
%
% kfold - number of folds to use when evaluating objective
% function.
%
% grps - Optimization relies on kfold cross validation. Grps
% is an n x 1 vector of labels (e.g. subjects) which
% are not fragmented across cv folds. Elements with
% the same label will either be in a test fold or a
% training fold, but not both. Thus, in a
% multisubject study providing subject labels will
% result in hyperparameter optimization for
% out-of-subject performance.
%
% bayesopt_opts - Options to pass to bayesopt() call. "help bayesopt"
% for details. If no acquisition function is
% specified expected-improvement-plus is used by
% default. Pass empty cell array to use defaults.
% Note: if running with {'UseParallel',true} invoke
% parpool manually and then run multithreadWorkers()
% before invoking mlpcr_full for best performance.
%
% mlpcr_options - Options to pass to mlpcr. Substitute
% optimizableVariable objects for any hyperparameter
% you want optimized.
% Note: optimizableVariable objects must each have a
% unique 'name' field value or you will encounter
% obscure errors ('Variable index exceeds table
% dimensions').
%
% Output ::
%
% weights - Cell array of weights returned by mlpcr (using
% optimal hyperparameters). See "help mlpcr" for
% details.
%
% Intercept - Cell array of intercepts returned by mlpcr (using
% optimal hyperparameters). See "help mlpcr" for
% details.
%
% lme - LinearMixedModel object returned by fitlme() when
% fitting PCA components to data after optimizing
% hyperparameters.
%
% hp - BayesianOptimization object returned by bayesopt().
% Useful for inspecting optimization results,
% resuming optimization for additional iterations,
% or getting estimate of objective function
% at optimal hyperparameters.
% Note: MSE estimate for final model is stored in
% hp.MinObjective.
%
% Dependencies:
% R2016b machine learning toolbox (required by this script for bayesopt)
% mlpcr_out_of_id_mse (required by this script)
% mlpcr_cv_pred (required by mlpcr_out_of_id_mse)
% mlpcr (required by mlpcr_cv_pred)
% mlpca (required by mlpcr)
% get_nested_var_comps (required by mlpcr_cv_pred)
% get_cntrng_mats (required by mlpca and get_nested_var_comps)
%
% Written by Bogdan Petre (Feb 20, 2018)
function [weights, Intercept, lme, hp] = mlpcr_full(kfold,grps,bayesopt_opts,X,Y,varargin)
% Optimizes the optimizableVariable entries found in varargin with
% bayesopt(), using mlpcr_out_of_id_mse(kfold,grps,...) as the objective,
% then calls mlpcr(X,Y,...) with the values in hp.XAtMinEstimatedObjective.
%
% kfold, grps   - forwarded unchanged to mlpcr_out_of_id_mse.
% bayesopt_opts - cell array of name/value options forwarded to bayesopt().
%                 An 'AcquisitionFunctionName' entry overrides the default
%                 ('expected-improvement-plus'). Empty ({} or []) means
%                 "no extra options".
% X, Y          - first two arguments of every mlpcr call.
% varargin      - remaining mlpcr arguments; optimizableVariable objects
%                 (also inside nested cell arrays) are optimized, the rest
%                 is passed through as is.

    % Outputs are assigned inside eval() below; declare them up front so
    % they exist as variables in this workspace.
    weights = {};
    Intercept = {};
    lme = {};

    % Tolerate [] in place of {}: brace expansion below requires a cell.
    if isempty(bayesopt_opts)
        bayesopt_opts = {};
    end

    % Split the mlpcr arguments into fixed values (optimizable entries are
    % replaced by nan placeholders) and the list of variables to optimize.
    optVars = [];
    mlpcr_arg = {X,Y};
    for i = 1:length(varargin)
        [new_arg, new_optVars] = extractOptVars(varargin{i});
        mlpcr_arg = [mlpcr_arg, new_arg];
        optVars = [optVars, new_optVars];
    end

    % construct obj fxn invocation
    % execstr references mlpcr_arg{...} for fixed arguments and columns of
    % the table x1 (supplied by bayesopt) for optimized ones.
    execstr = construct_obj_fxn(1, 'mlpcr_arg', mlpcr_arg{:});
    eval(['objfxn = @(x1)(mlpcr_out_of_id_mse(kfold,grps,', execstr, '));']);

    % default acquisition function
    AcFxn = 'expected-improvement-plus';

    % If acquisition function was specified, use that instead, and remove
    % the name/value pair from bayesopt_opts (it is passed explicitly).
    % A while loop is used because the list shrinks when a pair is removed.
    i = 1;
    while i < numel(bayesopt_opts)
        if ischar(bayesopt_opts{i}) && strcmpi(bayesopt_opts{i}, 'AcquisitionFunctionName')
            AcFxn = bayesopt_opts{i+1};
            bayesopt_opts(i:i+1) = [];
        else
            i = i + 1;
        end
    end

    hp = bayesopt(objfxn,optVars,'AcquisitionFunctionName',AcFxn,bayesopt_opts{:});

    % Re-use the same argument string for the final fit, reading the
    % optimized values from the bayesopt result instead of x1.
    execstr = strrep(execstr,'(x1(','(hp.XAtMinEstimatedObjective(');
    eval(['[weights,Intercept,lme] = mlpcr(', execstr, ');']);
end
% Parses arguments from mlpcr_full
% fixedArgs will be identical to varargin, except that optimizableVariable
% types will be replaced with nans, and the optimizableVariable will be
% copied to optVars.
function [fixedArgs, optVars] = extractOptVars(varargin)
% Separates optimizableVariable objects from the remaining arguments.
%
% fixedArgs - cell array mirroring varargin in which every
%             optimizableVariable is replaced by nan; nested cell arrays
%             are processed recursively and keep their nesting.
% optVars   - the optimizableVariable objects, in the order encountered.
%
% Raises an error when a numeric argument already tests true for isnan,
% because nan is reserved as the placeholder value.
    fixedArgs = {};
    optVars = [];
    nArgs = numel(varargin);
    k = 0;
    while k < nArgs
        k = k + 1;
        item = varargin{k};

        % nan is the internal placeholder, so reject pre-existing ones.
        if isnumeric(item)
            if isnan(item)
                error('Found ''nan'' in argument list. This is not supported.');
            end
        end

        if strcmp(class(item), 'optimizableVariable')
            % Record the variable and leave a placeholder in its position.
            optVars = [optVars, item];
            fixedArgs(end+1) = {nan};
        elseif iscell(item)
            % Recurse into nested argument lists, keeping the nesting.
            [nestedFixed, nestedVars] = extractOptVars(item{:});
            optVars = [optVars, nestedVars];
            fixedArgs(end+1) = {nestedFixed};
        else
            fixedArgs(end+1) = varargin(k);
        end
    end
end
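% Builds the comma-separated argument string used in the eval'd calls.
% ov_idx - optimization variable index (column of the table x1 to read next)
% name - expression used to reference fixed arguments, as name{i}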
function [execstr, ov_idx] = construct_obj_fxn(ov_idx, name, varargin)
% Builds the comma-separated argument string used in the eval'd calls.
%
% Input ::
%   ov_idx   - index of the next optimization variable, i.e. the column of
%              the table x1 that the next placeholder is read from.
%   name     - expression (e.g. 'mlpcr_arg') the generated string uses to
%              reference fixed arguments as name{i}.
%   varargin - argument list in which a numeric scalar nan marks an
%              optimized value; cell arrays are processed recursively.
%
% Output ::
%   execstr  - generated argument string, without trailing comma.
%   ov_idx   - index of the next unused optimization variable.
    execstr = '';
    for i = 1:length(varargin)
        arg = varargin{i};
        if iscell(arg)
            % Note: ov_idx's carry through
            [new_execstr, ov_idx] = construct_obj_fxn(ov_idx, [name, '{' int2str(i) '}'], arg{:});
            execstr = [execstr, '{', new_execstr, '},'];
        elseif isnumeric(arg) && isscalar(arg) && isnan(arg)
            % Placeholder: read the value from column ov_idx of x1. The
            % type and size guards keep isnan from being called on types
            % it does not support (e.g. struct, function_handle).
            execstr = [execstr, 'table2array(x1(1,' int2str(ov_idx) ')),'];
            ov_idx = ov_idx + 1;
        else
            % Fixed argument: reference it by position in the source cell.
            execstr = [execstr, name, '{', int2str(i), '},'];
        end
    end
    execstr = execstr(1:end-1); % drop trailing comma
end