Skip to content
This repository has been archived by the owner on Oct 22, 2020. It is now read-only.

Commit

Permalink
Improved merge block reconstruction
Browse files Browse the repository at this point in the history
  • Loading branch information
MaikKlein committed Mar 3, 2018
1 parent 7fa3af1 commit 76c1d30
Show file tree
Hide file tree
Showing 6 changed files with 246 additions and 119 deletions.
16 changes: 10 additions & 6 deletions quad/shader/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ fn color_frag(
(false, true) => Vec3::new(1.0, 0.0, 0.0),
(false, false) => Vec3::new(1.0, 1.0, 1.0),
};
// let color: Vec3<f32> = match right {
// true => Vec3::new(0.0, 0.0, 1.0),
// false => Vec3::new(0.0, 1.0, 0.0),
// };

// let coord = uv.extend(uv.x)
// .add(offset)
Expand Down Expand Up @@ -52,13 +56,13 @@ fn color_frag(
// Output::new(coord)
//}

fn test(f: f32) -> f32 {
f + 1.0
}
// fn test(f: f32) -> f32 {
// f + 1.0
// }

fn test_mut(f: &mut f32) {
*f += 1.0;
}
// fn test_mut(f: &mut f32) {
// *f += 1.0;
// }
//fn test_add(f: &mut f32){
// *f += 0.5f32;
//}
Expand Down
25 changes: 11 additions & 14 deletions src/collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,18 @@ use rustc::mir::mono::MonoItem;
use rustc::ty::{Instance, ParamEnv, TyCtxt};
use context::MirContext;
pub struct CollectCrateItems<'a, 'tcx: 'a> {
mtx: MirContext<'a, 'tcx>,
tcx: TyCtxt<'a, 'tcx, 'tcx>,
items: Vec<MonoItem<'tcx>>,
}
pub fn collect_crate_items<'a, 'tcx>(mtx: MirContext<'a, 'tcx>) -> Vec<MonoItem<'tcx>> {
pub fn collect_crate_items<'a, 'tcx>(
tcx: TyCtxt<'a, 'tcx, 'tcx>,
mir: &mir::Mir<'tcx>,
) -> Vec<MonoItem<'tcx>> {
let mut collector = CollectCrateItems {
mtx,
tcx,
items: Vec::new(),
};
collector.visit_mir(&mtx.mir);
collector.visit_mir(mir);
collector.items
}
impl<'a, 'tcx> rustc::mir::visit::Visitor<'tcx> for CollectCrateItems<'a, 'tcx> {
Expand All @@ -30,12 +33,12 @@ impl<'a, 'tcx> rustc::mir::visit::Visitor<'tcx> for CollectCrateItems<'a, 'tcx>
if let mir::Literal::Value { ref value } = constant.literal {
use rustc::middle::const_val::ConstVal;
if let ConstVal::Function(def_id, ref substs) = value.val {
let mono_substs = self.mtx.monomorphize(substs);
//let mono_substs = self.mtx.monomorphize(substs);
let instance = Instance::resolve(
self.mtx.tcx,
self.tcx,
ParamEnv::empty(rustc::traits::Reveal::All),
def_id,
&mono_substs,
substs,
).unwrap();
self.items.push(MonoItem::Fn(instance));
}
Expand All @@ -59,13 +62,7 @@ pub fn trans_all_items<'a, 'tcx>(
if let &MonoItem::Fn(ref instance) = item {
let mir = tcx.maybe_optimized_mir(instance.def_id());
if let Some(mir) = mir {
let mtx = MirContext {
tcx,
mir,
substs: instance.substs,
def_id: instance.def_id(),
};
let new_items = collect_crate_items(mtx);
let new_items = collect_crate_items(tcx, &mir);
if !new_items.is_empty() {
uncollected_items.push(new_items)
}
Expand Down
124 changes: 117 additions & 7 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,21 +50,21 @@ impl<'a, 'tcx> CodegenCx<'a, 'tcx> {
.map(|&id| FunctionCall::Intrinsic(id)))
}

pub fn constant_f32(&mut self, mcx: MirContext<'a, 'tcx>, value: f32) -> Value {
pub fn constant_f32(&mut self, value: f32) -> Value {
use std::convert::TryFrom;
let val = ConstValue::Float(ConstFloat::from_u128(
TryFrom::try_from(value.to_bits()).expect("Could not convert from f32 to u128"),
syntax::ast::FloatTy::F32,
));
self.constant(mcx, val)
self.constant(val)
}

pub fn constant_u32(&mut self, mcx: MirContext<'a, 'tcx>, value: u32) -> Value {
pub fn constant_u32(&mut self, value: u32) -> Value {
let val = ConstValue::Integer(ConstInt::U32(value));
self.constant(mcx, val)
self.constant(val)
}

pub fn constant(&mut self, mcx: MirContext<'a, 'tcx>, val: ConstValue) -> Value {
pub fn constant(&mut self, val: ConstValue) -> Value {
if let Some(val) = self.const_cache.get(&val) {
return *val;
}
Expand All @@ -76,7 +76,7 @@ impl<'a, 'tcx> CodegenCx<'a, 'tcx> {
let value = const_int.to_u128_unchecked() as u32;
self.builder.constant_u32(spirv_ty.word, value)
}
ConstValue::Bool(b) => self.constant_u32(mcx, b as u32).word,
ConstValue::Bool(b) => self.constant_u32(b as u32).word,
ConstValue::Float(const_float) => {
use rustc::infer::unify_key::ToType;
let value: f32 = unsafe { ::std::mem::transmute(const_float.bits as u32) };
Expand Down Expand Up @@ -465,14 +465,124 @@ impl<'a, 'tcx> CodegenCx<'a, 'tcx> {
}
}

#[derive(Copy, Clone)]
#[derive(Clone)]
pub struct SpirvMir<'a, 'tcx: 'a> {
pub def_id: hir::def_id::DefId,
pub mir: mir::Mir<'tcx>,
pub substs: &'tcx rustc::ty::subst::Substs<'tcx>,
pub merge_blocks: HashMap<mir::BasicBlock, mir::BasicBlock>,
pub tcx: TyCtxt<'a, 'tcx, 'tcx>,
}

impl<'a, 'tcx> SpirvMir<'a, 'tcx> {
pub fn mir(&self) -> &mir::Mir<'tcx> {
&self.mir
}
pub fn from_mir(mcx: &::MirContext<'a, 'tcx>) -> Self {
use mir::visit::Visitor;
struct FindMergeBlocks<'a, 'tcx: 'a> {
mir: &'a mir::Mir<'tcx>,
merge_blocks: HashMap<mir::BasicBlock, mir::BasicBlock>,
}

impl<'a, 'tcx> Visitor<'tcx> for FindMergeBlocks<'a, 'tcx> {
fn visit_terminator_kind(
&mut self,
block: mir::BasicBlock,
kind: &mir::TerminatorKind<'tcx>,
location: mir::Location,
) {
match kind {
&mir::TerminatorKind::SwitchInt {
ref discr,
switch_ty,
ref targets,
..
} => {
let merge_block =
::find_merge_block(self.mir, block, targets).expect("no merge block");
self.merge_blocks.insert(block, merge_block);
}
_ => (),
};
}
}

let mut visitor = FindMergeBlocks {
mir: mcx.mir,
merge_blocks: HashMap::new(),
};

visitor.visit_mir(mcx.mir);
let merge_blocks = visitor.merge_blocks;
let mut spirv_mir = mcx.mir.clone();
let mut fixed_merge_blocks = HashMap::new();
use syntax_pos::DUMMY_SP;
for (block, merge_block) in merge_blocks {
use rustc_data_structures::control_flow_graph::ControlFlowGraph;
use std::collections::HashSet;
let pred: HashSet<_> = ControlFlowGraph::predecessors(&spirv_mir, merge_block)
.into_iter()
.collect();
let suc: HashSet<_> = ControlFlowGraph::successors(&spirv_mir, block)
.into_iter()
.collect();
let previous_blocks: HashSet<_> = pred.intersection(&suc).collect();
if previous_blocks.is_empty() {
fixed_merge_blocks.insert(block, merge_block);
} else {
let terminator = mir::Terminator {
source_info: mir::SourceInfo {
span: DUMMY_SP,
scope: mir::ARGUMENT_VISIBILITY_SCOPE,
},
kind: mir::TerminatorKind::Goto {
target: merge_block,
},
};
let goto_data = mir::BasicBlockData::new(Some(terminator));
let goto_block = spirv_mir.basic_blocks_mut().push(goto_data);
fixed_merge_blocks.insert(block, goto_block);
for &previous_block in previous_blocks {
if let mir::TerminatorKind::Goto { ref mut target } = spirv_mir
.basic_blocks_mut()[previous_block]
.terminator_mut()
.kind
{
*target = goto_block;
} else {
panic!("Should be a goto");
}
}
}
}
SpirvMir {
mir: spirv_mir,
merge_blocks: fixed_merge_blocks,
substs: mcx.substs,
tcx: mcx.tcx,
def_id: mcx.def_id,
}
}
pub fn monomorphize<T>(&self, value: &T) -> T
where
T: rustc::infer::TransNormalize<'tcx>,
{
self.tcx.trans_apply_param_substs(self.substs, value)
}
}

#[derive(Clone)]
pub struct MirContext<'a, 'tcx: 'a> {
pub def_id: hir::def_id::DefId,
pub tcx: TyCtxt<'a, 'tcx, 'tcx>,
pub mir: &'a mir::Mir<'tcx>,
pub substs: &'tcx subst::Substs<'tcx>,
}
impl<'a, 'tcx> MirContext<'a, 'tcx> {
pub fn mir(&self) -> &'a mir::Mir<'tcx> {
self.mir
}
pub fn monomorphize<T>(&self, value: &T) -> T
where
T: rustc::infer::TransNormalize<'tcx>,
Expand Down
Loading

0 comments on commit 76c1d30

Please sign in to comment.