diff --git a/src/lib.rs b/src/lib.rs index ac67da49..e767790f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -895,13 +895,24 @@ pub enum ControlNodeKind { }, } +// FIXME(eddyb) consider interning this, perhaps in a similar vein to `DataInstForm`. #[derive(Clone)] pub enum SelectionKind { /// Two-case selection based on boolean condition, i.e. `if`-`else`, with /// the two cases being "then" and "else" (in that order). BoolCond, - SpvInst(spv::Inst), + /// `N+1`-case selection based on comparing an integer scrutinee against + /// `N` constants, i.e. `switch`, with the last case being the "default" + /// (making it the only case without a matching entry in `case_consts`). + Switch { + // FIXME(eddyb) avoid some of the `scalar::Const` overhead here, as there + // is only a single type and we shouldn't need to store more bits per case, + // than the actual width of the integer type. + // FIXME(eddyb) consider storing this more like sorted compressed keyset, + // as there can be no duplicates, and in many cases it may be contiguous. + case_consts: Vec, + }, } /// Entity handle for a [`DataInstDef`](crate::DataInstDef) (an SSA instruction). diff --git a/src/print/mod.rs b/src/print/mod.rs index cd64c5d2..59719995 100644 --- a/src/print/mod.rs +++ b/src/print/mod.rs @@ -2953,7 +2953,7 @@ impl Print for FuncAt<'_, ControlNode> { ( pretty::join_comma_sep( "(", - input_decls_and_uses.clone().zip(initial_inputs).map( + input_decls_and_uses.clone().zip_eq(initial_inputs).map( |((input_decl, input_use), initial)| { pretty::Fragment::new([ input_decl.print(printer).insert_name_before_def( @@ -3483,7 +3483,7 @@ impl SelectionKind { mut cases: impl ExactSizeIterator, ) -> pretty::Fragment { let kw = |kw| kw_style.apply(kw).into(); - match *self { + match self { SelectionKind::BoolCond => { assert_eq!(cases.len(), 2); let [then_case, else_case] = [cases.next().unwrap(), cases.next().unwrap()]; @@ -3500,27 +3500,36 @@ impl SelectionKind { "}".into(), ]) } - SelectionKind::SpvInst(spv::Inst { opcode, ref imms }) => { - let header = printer.pretty_spv_inst( - kw_style, - opcode, - imms, - [Some(scrutinee.print(printer))] - .into_iter() - .chain((0..cases.len()).map(|_| None)), - ); + SelectionKind::Switch { case_consts } => { + assert_eq!(cases.len(), case_consts.len() + 1); + + let case_patterns = case_consts + .iter() + .map(|&ct| { + let int_to_string = (ct.int_as_u128().map(|x| x.to_string())) + .or_else(|| ct.int_as_i128().map(|x| x.to_string())); + match int_to_string { + Some(v) => printer.numeric_literal_style().apply(v).into(), + None => { + let ct: Const = printer.cx.intern(ct); + ct.print(printer) + } + } + }) + .chain(["_".into()]); pretty::Fragment::new([ - header, + kw("switch"), + " ".into(), + scrutinee.print(printer), " {".into(), pretty::Node::IndentedBlock( - cases - .map(|case| { + case_patterns + .zip_eq(cases) + .map(|(case_pattern, case)| { pretty::Fragment::new([ pretty::Node::ForceLineSeparation.into(), - // FIXME(eddyb) this should pull information out - // of the instruction to be more precise. - kw("case"), + case_pattern, " => {".into(), pretty::Node::IndentedBlock(vec![case]).into(), "}".into(), diff --git a/src/spv/canonical.rs b/src/spv/canonical.rs index c170e047..e99d21dd 100644 --- a/src/spv/canonical.rs +++ b/src/spv/canonical.rs @@ -165,7 +165,11 @@ def_mappable_ops! { } impl scalar::Const { - fn try_decode_from_spv_imms(ty: scalar::Type, imms: &[spv::Imm]) -> Option { + // HACK(eddyb) this is not private so `spv::lower` can use it for `OpSwitch`. + pub(super) fn try_decode_from_spv_imms( + ty: scalar::Type, + imms: &[spv::Imm], + ) -> Option { // FIXME(eddyb) don't hardcode the 128-bit limitation, // but query `scalar::Const` somehow instead. if ty.bit_width() > 128 { @@ -198,7 +202,8 @@ impl scalar::Const { } } - fn encode_as_spv_imms(&self) -> impl Iterator { + // HACK(eddyb) this is not private so `spv::lift` can use it for `OpSwitch`. + pub(super) fn encode_as_spv_imms(&self) -> impl Iterator { let wk = &spec::Spec::get().well_known; let ty = self.ty(); diff --git a/src/spv/lift.rs b/src/spv/lift.rs index d84bf1cb..100a9f5c 100644 --- a/src/spv/lift.rs +++ b/src/spv/lift.rs @@ -1309,6 +1309,14 @@ impl LazyInst<'_, '_> { ids: [merge_label_id, continue_label_id].into_iter().collect(), }, Self::Terminator { parent_func, terminator } => { + let mut ids: SmallVec<[_; 4]> = terminator + .inputs + .iter() + .map(|&v| value_to_id(parent_func, v)) + .chain(terminator.targets.iter().map(|&target| parent_func.label_ids[&target])) + .collect(); + + // FIXME(eddyb) move some of this to `spv::canonical`. let inst = match &*terminator.kind { cfg::ControlInstKind::Unreachable => wk.OpUnreachable.into(), cfg::ControlInstKind::Return => { @@ -1327,23 +1335,21 @@ impl LazyInst<'_, '_> { cfg::ControlInstKind::SelectBranch(SelectionKind::BoolCond) => { wk.OpBranchConditional.into() } - cfg::ControlInstKind::SelectBranch(SelectionKind::SpvInst(inst)) => { - inst.clone() + cfg::ControlInstKind::SelectBranch(SelectionKind::Switch { case_consts }) => { + // HACK(eddyb) move the default case from last back to first. + let default_target = ids.pop().unwrap(); + ids.insert(1, default_target); + + spv::Inst { + opcode: wk.OpSwitch, + imms: case_consts + .iter() + .flat_map(|ct| ct.encode_as_spv_imms()) + .collect(), + } } }; - spv::InstWithIds { - without_ids: inst, - result_type_id: None, - result_id: None, - ids: terminator - .inputs - .iter() - .map(|&v| value_to_id(parent_func, v)) - .chain( - terminator.targets.iter().map(|&target| parent_func.label_ids[&target]), - ) - .collect(), - } + spv::InstWithIds { without_ids: inst, result_type_id: None, result_id: None, ids } } Self::OpFunctionEnd => spv::InstWithIds { without_ids: wk.OpFunctionEnd.into(), diff --git a/src/spv/lower.rs b/src/spv/lower.rs index 578fcaeb..d8b3d45e 100644 --- a/src/spv/lower.rs +++ b/src/spv/lower.rs @@ -3,11 +3,11 @@ use crate::spv::{self, spec}; // FIXME(eddyb) import more to avoid `crate::` everywhere. use crate::{ - cfg, print, AddrSpace, Attr, AttrSet, Const, ConstDef, ConstKind, Context, ControlNodeDef, - ControlNodeKind, ControlRegion, ControlRegionDef, ControlRegionInputDecl, DataInstDef, - DataInstFormDef, DataInstKind, DeclDef, Diag, EntityDefs, EntityList, ExportKey, Exportee, - Func, FuncDecl, FuncDefBody, FuncParam, FxIndexMap, GlobalVarDecl, GlobalVarDefBody, Import, - InternedStr, Module, SelectionKind, Type, TypeDef, TypeKind, TypeOrConst, Value, + cfg, print, scalar, AddrSpace, Attr, AttrSet, Const, ConstDef, ConstKind, Context, + ControlNodeDef, ControlNodeKind, ControlRegion, ControlRegionDef, ControlRegionInputDecl, + DataInstDef, DataInstFormDef, DataInstKind, DeclDef, Diag, EntityDefs, EntityList, ExportKey, + Exportee, Func, FuncDecl, FuncDefBody, FuncParam, FxIndexMap, GlobalVarDecl, GlobalVarDefBody, + Import, InternedStr, Module, SelectionKind, Type, TypeDef, TypeKind, TypeOrConst, Value, }; use rustc_hash::FxHashMap; use smallvec::SmallVec; @@ -85,6 +85,20 @@ fn invalid(reason: &str) -> io::Error { io::Error::new(io::ErrorKind::InvalidData, format!("malformed SPIR-V ({reason})")) } +fn invalid_factory_for_spv_inst( + inst: &spv::Inst, + result_id: Option, + ids: &[spv::Id], +) -> impl Fn(&str) -> io::Error { + let opcode = inst.opcode; + let first_id_operand = ids.first().copied(); + move |msg: &str| { + let result_prefix = result_id.map(|id| format!("%{id} = ")).unwrap_or_default(); + let operand_suffix = first_id_operand.map(|id| format!(" %{id} ...")).unwrap_or_default(); + invalid(&format!("in {result_prefix}{}{operand_suffix}: {msg}", opcode.name())) + } +} + // FIXME(eddyb) provide more information about any normalization that happened: // * stats about deduplication that occured through interning // * sets of unused global vars and functions (and types+consts only they use) @@ -195,7 +209,7 @@ impl Module { while let Some(mut inst) = spv_insts.next().transpose()? { let opcode = inst.opcode; - let invalid = |msg: &str| invalid(&format!("in {}: {}", opcode.name(), msg)); + let invalid = invalid_factory_for_spv_inst(&inst, inst.result_id, &inst.ids); // Handle line debuginfo early, as it doesn't have its own section, // but rather can go almost anywhere among globals and functions. @@ -861,7 +875,7 @@ impl Module { #[derive(Copy, Clone)] enum LocalIdDef { - Value(Value), + Value(Type, Value), BlockLabel(ControlRegion), } @@ -889,6 +903,7 @@ impl Module { let IntraFuncInst { without_ids: spv::Inst { opcode, ref imms }, result_id, + result_type, .. } = *raw_inst; @@ -903,10 +918,10 @@ impl Module { DeclDef::Present(def) => def.body, }; - LocalIdDef::Value(Value::ControlRegionInput { - region: body, - input_idx: idx, - }) + LocalIdDef::Value( + result_type.unwrap(), + Value::ControlRegionInput { region: body, input_idx: idx }, + ) } else { let is_entry_block = !has_blocks; has_blocks = true; @@ -957,10 +972,13 @@ impl Module { .push(value_id); } - LocalIdDef::Value(Value::ControlRegionInput { - region: current_block, - input_idx: phi_idx, - }) + LocalIdDef::Value( + result_type.unwrap(), + Value::ControlRegionInput { + region: current_block, + input_idx: phi_idx, + }, + ) } else { // HACK(eddyb) can't get a `DataInst` without // defining it (as a dummy) first. @@ -974,7 +992,7 @@ impl Module { } .into(), ); - LocalIdDef::Value(Value::DataInstOutput(inst)) + LocalIdDef::Value(result_type.unwrap(), Value::DataInstOutput(inst)) } }; local_id_defs.insert(id, local_id_def); @@ -1023,50 +1041,52 @@ impl Module { ref ids, } = *raw_inst; - let invalid = |msg: &str| invalid(&format!("in {}: {}", opcode.name(), msg)); + let invalid = invalid_factory_for_spv_inst(&raw_inst.without_ids, result_id, ids); // FIXME(eddyb) find a more compact name and/or make this a method. // FIXME(eddyb) this returns `LocalIdDef` even for global values. - let lookup_global_or_local_id_for_data_or_control_inst_input = - |id| match id_defs.get(&id) { - Some(&IdDef::Const(ct)) => Ok(LocalIdDef::Value(Value::Const(ct))), - Some(id_def @ IdDef::Type(_)) => Err(invalid(&format!( - "unsupported use of {} as an operand for \ + let lookup_global_or_local_id_for_data_or_control_inst_input = |id| match id_defs + .get(&id) + { + Some(&IdDef::Const(ct)) => Ok(LocalIdDef::Value(cx[ct].ty, Value::Const(ct))), + Some(id_def @ IdDef::Type(_)) => Err(invalid(&format!( + "unsupported use of {} as an operand for \ an instruction in a function", - id_def.descr(&cx), - ))), - Some(id_def @ IdDef::Func(_)) => Err(invalid(&format!( - "unsupported use of {} outside `OpFunctionCall`", - id_def.descr(&cx), - ))), - Some(id_def @ IdDef::SpvDebugString(s)) => { - if opcode == wk.OpExtInst { - // HACK(eddyb) intern `OpString`s as `Const`s on - // the fly, as it's a less likely usage than the - // `OpLine` one. - let ct = cx.intern(ConstDef { - attrs: AttrSet::default(), - ty: cx.intern(TypeKind::SpvStringLiteralForExtInst), - kind: ConstKind::SpvStringLiteralForExtInst(*s), - }); - Ok(LocalIdDef::Value(Value::Const(ct))) - } else { - Err(invalid(&format!( - "unsupported use of {} outside `OpSource`, \ + id_def.descr(&cx), + ))), + Some(id_def @ IdDef::Func(_)) => Err(invalid(&format!( + "unsupported use of {} outside `OpFunctionCall`", + id_def.descr(&cx), + ))), + Some(id_def @ IdDef::SpvDebugString(s)) => { + if opcode == wk.OpExtInst { + // HACK(eddyb) intern `OpString`s as `Const`s on + // the fly, as it's a less likely usage than the + // `OpLine` one. + let ty = cx.intern(TypeKind::SpvStringLiteralForExtInst); + let ct = cx.intern(ConstDef { + attrs: AttrSet::default(), + ty, + kind: ConstKind::SpvStringLiteralForExtInst(*s), + }); + Ok(LocalIdDef::Value(ty, Value::Const(ct))) + } else { + Err(invalid(&format!( + "unsupported use of {} outside `OpSource`, \ `OpLine`, or `OpExtInst`", - id_def.descr(&cx), - ))) - } + id_def.descr(&cx), + ))) } - Some(id_def @ IdDef::SpvExtInstImport(_)) => Err(invalid(&format!( - "unsupported use of {} outside `OpExtInst`", - id_def.descr(&cx), - ))), - None => local_id_defs - .get(&id) - .copied() - .ok_or_else(|| invalid(&format!("undefined ID %{id}",))), - }; + } + Some(id_def @ IdDef::SpvExtInstImport(_)) => Err(invalid(&format!( + "unsupported use of {} outside `OpExtInst`", + id_def.descr(&cx), + ))), + None => local_id_defs + .get(&id) + .copied() + .ok_or_else(|| invalid(&format!("undefined ID %{id}",))), + }; if opcode == wk.OpFunctionParameter { if current_block_control_region_and_details.is_some() { @@ -1104,7 +1124,7 @@ impl Module { // to be able to have an entry in `local_id_defs`. let control_region = match local_id_defs[&result_id.unwrap()] { LocalIdDef::BlockLabel(control_region) => control_region, - LocalIdDef::Value(_) => unreachable!(), + LocalIdDef::Value(..) => unreachable!(), }; let current_block_details = &block_details[&control_region]; assert_eq!(current_block_details.label_id, result_id.unwrap()); @@ -1140,7 +1160,7 @@ impl Module { }; let phi_value_id_to_value = |phi_key: &PhiKey, id| { match lookup_global_or_local_id_for_data_or_control_inst_input(id)? { - LocalIdDef::Value(v) => Ok(v), + LocalIdDef::Value(_, v) => Ok(v), LocalIdDef::BlockLabel { .. } => Err(invalid(&format!( "unsupported use of block label as the value for {}", descr_phi_case(phi_key) @@ -1190,10 +1210,11 @@ impl Module { // Split the operands into value inputs (e.g. a branch's // condition or an `OpSwitch`'s selector) and target blocks. let mut inputs = SmallVec::new(); + let mut input_types = SmallVec::<[_; 2]>::new(); let mut targets = SmallVec::new(); for &id in ids { match lookup_global_or_local_id_for_data_or_control_inst_input(id)? { - LocalIdDef::Value(v) => { + LocalIdDef::Value(ty, v) => { if !targets.is_empty() { return Err(invalid( "out of order: value operand \ @@ -1201,6 +1222,7 @@ impl Module { )); } inputs.push(v); + input_types.push(ty); } LocalIdDef::BlockLabel(target) => { record_cfg_edge(target)?; @@ -1209,6 +1231,7 @@ impl Module { } } + // FIXME(eddyb) move some of this to `spv::canonical`. let kind = if opcode == wk.OpUnreachable { assert!(targets.is_empty() && inputs.is_empty()); cfg::ControlInstKind::Unreachable @@ -1226,9 +1249,62 @@ impl Module { assert_eq!((targets.len(), inputs.len()), (2, 1)); cfg::ControlInstKind::SelectBranch(SelectionKind::BoolCond) } else if opcode == wk.OpSwitch { - cfg::ControlInstKind::SelectBranch(SelectionKind::SpvInst( - raw_inst.without_ids.clone(), - )) + assert_eq!(inputs.len(), 1); + + // HACK(eddyb) `spv::read` has to "redundantly" validate + // that such a type is `OpTypeInt`/`OpTypeFloat`, but + // there is still a limitation when it comes to `scalar::Const`. + // FIXME(eddyb) don't hardcode the 128-bit limitation, + // but query `scalar::Const` somehow instead. + let scrutinee_type = input_types[0]; + let scrutinee_type = scrutinee_type + .as_scalar(&cx) + .filter(|ty| { + matches!(ty, scalar::Type::UInt(_) | scalar::Type::SInt(_)) + && ty.bit_width() <= 128 + }) + .ok_or_else(|| { + invalid( + &print::Plan::for_root( + &cx, + &Diag::err([ + "unsupported `OpSwitch` scrutinee type `".into(), + scrutinee_type.into(), + "`".into(), + ]) + .message, + ) + .pretty_print() + .to_string(), + ) + })?; + + // FIXME(eddyb) move some of this to `spv::canonical`. + let imm_words_per_case = + usize::try_from(scrutinee_type.bit_width().div_ceil(32)).unwrap(); + + // NOTE(eddyb) these sanity-checks are redundant with `spv::read`. + assert_eq!(imms.len() % imm_words_per_case, 0); + assert_eq!(targets.len(), 1 + imms.len() / imm_words_per_case); + + let case_consts = imms + .chunks(imm_words_per_case) + .map(|case_imms| { + scalar::Const::try_decode_from_spv_imms(scrutinee_type, case_imms) + .ok_or_else(|| { + invalid(&format!( + "invalid {}-bit `OpSwitch` case constant", + scrutinee_type.bit_width() + )) + }) + }) + .collect::>()?; + + // HACK(eddyb) move the default case from first to last. + let default_target = targets.remove(0); + targets.push(default_target); + + cfg::ControlInstKind::SelectBranch(SelectionKind::Switch { case_consts }) } else { return Err(invalid("unsupported control-flow instruction")); }; @@ -1273,7 +1349,7 @@ impl Module { let loop_merge_target = match lookup_global_or_local_id_for_data_or_control_inst_input(ids[0])? { - LocalIdDef::Value(_) => return Err(invalid("expected label ID")), + LocalIdDef::Value(..) => return Err(invalid("expected label ID")), LocalIdDef::BlockLabel(target) => target, }; @@ -1372,7 +1448,7 @@ impl Module { .map(|&id| { match lookup_global_or_local_id_for_data_or_control_inst_input(id)? { - LocalIdDef::Value(v) => Ok(v), + LocalIdDef::Value(_, v) => Ok(v), LocalIdDef::BlockLabel { .. } => Err(invalid( "unsupported use of block label as a value, \ in non-terminator instruction", @@ -1383,7 +1459,7 @@ impl Module { }; let inst = match result_id { Some(id) => match local_id_defs[&id] { - LocalIdDef::Value(Value::DataInstOutput(inst)) => { + LocalIdDef::Value(_, Value::DataInstOutput(inst)) => { // A dummy was defined earlier, to be able to // have an entry in `local_id_defs`. func_def_body.data_insts[inst] = data_inst_def.into(); diff --git a/src/spv/read.rs b/src/spv/read.rs index 5acf2930..c532b804 100644 --- a/src/spv/read.rs +++ b/src/spv/read.rs @@ -173,11 +173,8 @@ impl InstParser<'_> { .and_then(|id| self.known_ids.get(&id)) .ok_or(Error::MissingContextSensitiveLiteralType)?; - let extra_word_count = match *contextual_type { - KnownIdDef::TypeIntOrFloat(width) => { - // HACK(eddyb) `(width + 31) / 32 - 1` but without overflow. - (width.get() - 1) / 32 - } + let word_count = match *contextual_type { + KnownIdDef::TypeIntOrFloat(width) => width.get().div_ceil(32), KnownIdDef::Uncategorized { opcode, .. } => { return Err(Error::UnsupportedContextSensitiveLiteralType { type_opcode: opcode, @@ -185,11 +182,11 @@ impl InstParser<'_> { } }; - if extra_word_count == 0 { + if word_count == 1 { self.inst.imms.push(spv::Imm::Short(kind, word)); } else { self.inst.imms.push(spv::Imm::LongStart(kind, word)); - for _ in 0..extra_word_count { + for _ in 1..word_count { let word = self.words.next().ok_or(Error::NotEnoughWords)?; self.inst.imms.push(spv::Imm::LongCont(kind, word)); } diff --git a/src/transform.rs b/src/transform.rs index 6cb697a2..eef7932a 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -640,7 +640,7 @@ impl InnerInPlaceTransform for FuncAtMut<'_, ControlNode> { } } ControlNodeKind::Select { - kind: SelectionKind::BoolCond | SelectionKind::SpvInst(_), + kind: SelectionKind::BoolCond | SelectionKind::Switch { case_consts: _ }, scrutinee, cases: _, } => { @@ -747,7 +747,7 @@ impl InnerInPlaceTransform for cfg::ControlInst { | cfg::ControlInstKind::ExitInvocation(cfg::ExitInvocationKind::SpvInst(_)) | cfg::ControlInstKind::Branch | cfg::ControlInstKind::SelectBranch( - SelectionKind::BoolCond | SelectionKind::SpvInst(_), + SelectionKind::BoolCond | SelectionKind::Switch { case_consts: _ }, ) => {} } for v in inputs { diff --git a/src/visit.rs b/src/visit.rs index 1b5d9718..19a7a48b 100644 --- a/src/visit.rs +++ b/src/visit.rs @@ -475,7 +475,7 @@ impl<'a> FuncAt<'a, ControlNode> { } } ControlNodeKind::Select { - kind: SelectionKind::BoolCond | SelectionKind::SpvInst(_), + kind: SelectionKind::BoolCond | SelectionKind::Switch { case_consts: _ }, scrutinee, cases, } => { @@ -556,7 +556,7 @@ impl InnerVisit for cfg::ControlInst { | cfg::ControlInstKind::ExitInvocation(cfg::ExitInvocationKind::SpvInst(_)) | cfg::ControlInstKind::Branch | cfg::ControlInstKind::SelectBranch( - SelectionKind::BoolCond | SelectionKind::SpvInst(_), + SelectionKind::BoolCond | SelectionKind::Switch { case_consts: _ }, ) => {} } for v in inputs {