Skip to content

Commit

Permalink
Add HSIC sensitivity analysis (#187)
Browse files Browse the repository at this point in the history
* Add HSIC computation

* Generate Sensitivity Analyser code (thrift)

* Add HSIC sensitivity analysis server

* Implement HSIC sensitivity analyser on ruby side

* Make HSIC sensitivity available in REST API

* Add var names in HSIC computation API

* Add HSIC sensitivity plot

* Relax test tolerance

* Linting
  • Loading branch information
relf authored Sep 25, 2023
1 parent 2164dbd commit 95c7a2b
Show file tree
Hide file tree
Showing 35 changed files with 2,062 additions and 157 deletions.
6 changes: 5 additions & 1 deletion app/controllers/api/v1/sensitivity_analyses_controller.rb
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,12 @@ def _get_sensitivity_analysis_infos(ope)
status, sa, err = analyser.run
return { statusOk: status, sensitivity: sa, error: err }
end
when Operation::CAT_DOE
analyser = WhatsOpt::HsicSensitivityAnalyser.new(ope)
status, sa, err = analyser.get_hsic_sensitivity
return { statusOk: status, sensitivity: sa, error: err }
end
{ statusOk: false, sensitivity: sa,
error: "Bad operation category: Should be #{Operation::CAT_SENSITIVITY} (got #{ope.category})" }
error: "Bad operation category: Should be #{Operation::CAT_SENSITIVITY} or #{Operation::CAT_DOE} (got #{ope.category})" }
end
end
79 changes: 79 additions & 0 deletions app/javascript/plotter/components/HsicScatterPlot.jsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import React from 'react';
import PropTypes from 'prop-types';
import createPlotlyComponent from 'react-plotly.js/factory';
import Plotly from './custom-plotly';

const Plot = createPlotlyComponent(Plotly);

class HsicScatterPlot extends React.PureComponent {
render() {
const { hsicData } = this.props;
const {
parameters_names: parameterNames, obj_name: objName, hsic: { r2, pvperm },
} = hsicData;
const traceR2 = {
name: 'R2',
x: parameterNames.map((_, i) => i + 0.9),
y: r2,
type: 'scatter',
mode: 'markers+text',
text: parameterNames,
textposition: 'left center',
marker: { size: 10 },
cliponaxis: false,
};

const tracePvperm = {
name: 'p-value by permutation',
x: parameterNames.map((_, i) => i + 1.1),
y: pvperm,
type: 'scatter',
mode: 'markers+text',
text: parameterNames,
textposition: 'right center',
marker: { size: 10 },
cliponaxis: false,
yaxis: 'y2',
};

const data = [traceR2, tracePvperm];
const layout = {
title: `${objName} optimization HSIC sensitivity`,
xaxis: {
rangemode: 'tozero',
title: { text: 'Design Variables' },
},
yaxis: {
rangemode: 'tozero',
title: { text: 'HSIC r2' },
titlefont: { color: '#1f77b4' },
tickfont: { color: '#1f77b4' },
},
yaxis2: {
rangemode: 'tozero',
title: 'HSIC p-value',
titlefont: { color: '#ff7f0e' },
tickfont: { color: '#ff7f0e' },
overlaying: 'y',
side: 'right',
},
};

return (<Plot data={data} layout={layout} />);
}
}

HsicScatterPlot.propTypes = {
hsicData: PropTypes.shape({
obj_name: PropTypes.string.isRequired,
hsic: PropTypes.shape({
indices: PropTypes.array.isRequired,
r2: PropTypes.array,
pvas: PropTypes.array.isRequired,
pvperm: PropTypes.array,
}),
parameters_names: PropTypes.array.isRequired,
}).isRequired,
};

export default HsicScatterPlot;
42 changes: 42 additions & 0 deletions app/javascript/plotter/index.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,37 @@ import MetaModelManager from 'plotter/components/MetaModelManager';
import AnalysisDatabase from '../utils/AnalysisDatabase';
import * as caseUtils from '../utils/cases';
import DistributionHistogram from './components/DistributionHistogram';
import HsicScatterPlot from './components/HsicScatterPlot';

const PLOTS_TAB = 'plots';
const VARIABLES_TAB = 'variables';
const METAMODEL_TAB = 'metamodel';

class PlotPanel extends React.Component {
constructor(props) {
super(props);
this.state = {
sensitivity: null,
};
}

componentDidMount() {
const { db } = this.props;
console.log('component did mount');
const obj = db.getObjective();
if (obj) {
const { api, opeId } = this.props;
console.log('OCUCOU');
api.analyseSensitivity(
opeId,
(response) => {
console.log(response.data);
this.setState({ ...response.data });
},
);
}
}

shouldComponentUpdate(nextProps /* , nextState */) {
return nextProps.active;
}
Expand All @@ -25,6 +50,8 @@ class PlotPanel extends React.Component {
const {
db, optim, cases, success, title, uqMode,
} = this.props;
const { sensitivity } = this.state;
console.log(this.state);

let plotdist;
if (uqMode) {
Expand Down Expand Up @@ -98,18 +125,31 @@ class PlotPanel extends React.Component {
</div>
);
}
let plothsic;
if (sensitivity) {
plothsic = (
<HsicScatterPlot
hsicData={sensitivity}
title={title}
width={600}
/>
);
}

return (
<div className="tab-pane fade active show" id={PLOTS_TAB} role="tabpanel" aria-labelledby="plots-tab">
{plotdist}
{plotparall}
{plotoptim}
{plothsic}
</div>
);
}
}

PlotPanel.propTypes = {
api: PropTypes.object.isRequired,
opeId: PropTypes.number.isRequired,
active: PropTypes.bool.isRequired,
db: PropTypes.object.isRequired,
optim: PropTypes.bool.isRequired,
Expand Down Expand Up @@ -367,6 +407,8 @@ class Plotter extends React.Component {
</ul>
<div className="tab-content" id="myTabContent">
<PlotPanel
api={this.api}
opeId={ope.id}
db={this.db}
optim={isOptim}
uqMode={uqMode}
Expand Down
44 changes: 44 additions & 0 deletions app/lib/whats_opt/hsic_sensitivity_analyser.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# frozen_string_literal: true

require "matrix"

module WhatsOpt
class HsicSensitivityAnalyser
def initialize(ope_doe)
if ope_doe.doe?
@ope = ope_doe
@proxy = SensitivityAnalyserProxy.new
else
raise "Bad operation: doe operation required for hsic sensitivity analyzer (got #{ope_pce.base_operation.driver})"
end
end

def get_hsic_sensitivity(thresholding=Services::HsicThresholding::ZERO, quantile=0.2, g_threshold=0.0)
ok, err = true, ""

# xdoe from cases
xdata = @ope.input_cases.map {|c| c.values}
xdoe = Matrix[*xdata].t

# objective from cases
# XXX: suppose mono-objective
obj_vars = @ope.cases.with_role_case(WhatsOpt::Variable::MIN_OBJECTIVE_ROLE)
obj_vals = obj_vars.map {|c| c.values}
# cstrs from cases
# XXX: works only for negative constraints
cstrs_vars = @ope.cases.with_role_case(WhatsOpt::Variable::NEG_CONSTRAINT_ROLE)
cstrs_vals = cstrs_vars.map {|c| c.values}
ydoe = (Matrix[*obj_vals].vstack(Matrix[*cstrs_vals])).t

hsic = @proxy.compute_hsic(xdoe.to_a, ydoe.to_a, thresholding, quantile, g_threshold)
result = { obj_name: obj_vars.first.var_label, hsic: hsic, parameters_names: @ope.input_cases.map {|c| c.var_label} }
return ok, result, err
rescue StandardError => e
p e
return false, {}, e.to_s
end

end
end


34 changes: 34 additions & 0 deletions app/lib/whats_opt/sensitivity_analyser_proxy.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# frozen_string_literal: trueTypes

require "thrift"
require "sensitivity_analyser"

module WhatsOpt
class SensitivityAnalyserProxy < ServiceProxy
def _initialize
@analyser_protocol = Thrift::MultiplexedProtocol.new(
@protocol, "SensitivityAnalyserService"
)
@client = Services::SensitivityAnalyser::Client.new(@analyser_protocol)
end

def compute_hsic(xdoe, ydoe, thresholding, quantile, g_threshold)
hsic = nil
_send { hsic = @client.compute_hsic(xdoe, ydoe, thresholding, quantile, g_threshold) }
hsic
end

def _send
@transport.open()
yield
rescue Thrift::TransportException => e
# puts "#{e}"
Rails.logger.warn e
false
else
true
ensure
@transport.close()
end
end
end
98 changes: 98 additions & 0 deletions app/lib/whats_opt/services/sensitivity_analyser.rb

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 95c7a2b

Please sign in to comment.