Skip to content

Commit

Permalink
Add DataInstKind::Scalar for pure scalars->scalars ops.
Browse files Browse the repository at this point in the history
  • Loading branch information
eddyb committed Jan 12, 2024
1 parent 3061f59 commit 8af71b9
Show file tree
Hide file tree
Showing 12 changed files with 509 additions and 39 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,10 @@ global_var GV0 in spv.StorageClass.Output: s32

func F0() -> spv.OpTypeVoid {
loop(v0: s32 <- 1s32, v1: s32 <- 1s32) {
v2 = spv.OpSLessThan(v1, 10s32): bool
v2 = s.lt(v1, 10s32): bool
(v3: s32, v4: s32) = if v2 {
v5 = spv.OpIMul(v0, v1): s32
v6 = spv.OpIAdd(v1, 1s32): s32
v5 = i.mul(v0, v1): s32
v6 = i.add(v1, 1s32): s32
(v5, v6)
} else {
(undef: s32, undef: s32)
Expand Down
6 changes: 6 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -938,6 +938,12 @@ pub struct DataInstFormDef {

#[derive(Clone, PartialEq, Eq, Hash, derive_more::From)]
pub enum DataInstKind {
/// Scalar (`bool`, integer, and floating-point) pure operations.
///
/// See also the [`scalar`] module for more documentation and definitions.
#[from]
Scalar(scalar::Op),

// FIXME(eddyb) try to split this into recursive and non-recursive calls,
// to avoid needing special handling for recursion where it's impossible.
FuncCall(Func),
Expand Down
13 changes: 13 additions & 0 deletions src/print/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3044,6 +3044,19 @@ impl Print for FuncAt<'_, DataInst> {
let mut output_type_to_print = *output_type;

let def_without_type = match kind {
&DataInstKind::Scalar(op) => {
let name = op.name();
let (namespace_prefix, name) = name.split_at(name.find('.').unwrap() + 1);
pretty::Fragment::new([
printer
.demote_style_for_namespace_prefix(printer.declarative_keyword_style())
.apply(namespace_prefix)
.into(),
printer.declarative_keyword_style().apply(name).into(),
pretty::join_comma_sep("(", inputs.iter().map(|v| v.print(printer)), ")"),
])
}

&DataInstKind::FuncCall(func) => pretty::Fragment::new([
printer.declarative_keyword_style().apply("call").into(),
" ".into(),
Expand Down
2 changes: 2 additions & 0 deletions src/qptr/analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -906,6 +906,8 @@ impl<'a> InferUsage<'a> {
});
};
match &data_inst_form_def.kind {
DataInstKind::Scalar(_) => {}

&DataInstKind::FuncCall(callee) => {
match self.infer_usage_in_func(module, callee) {
FuncInferUsageState::Complete(callee_results) => {
Expand Down
2 changes: 2 additions & 0 deletions src/qptr/lift.rs
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,8 @@ impl LiftToSpvPtrInstsInFunc<'_> {
Ok((addr_space, self.lifter.layout_of(pointee_type)?))
};
let replacement_data_inst_def = match &data_inst_form_def.kind {
DataInstKind::Scalar(_) => return Ok(Transformed::Unchanged),

&DataInstKind::FuncCall(_callee) => {
for &v in &data_inst_def.inputs {
if self.lifter.as_spv_ptr_type(type_of_val(v)).is_some() {
Expand Down
2 changes: 1 addition & 1 deletion src/qptr/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ impl LowerFromSpvPtrInstsInFunc<'_> {

match data_inst_form_def.kind {
// Known semantics, no need to preserve SPIR-V pointer information.
DataInstKind::FuncCall(_) | DataInstKind::QPtr(_) => return,
DataInstKind::Scalar(_) | DataInstKind::FuncCall(_) | DataInstKind::QPtr(_) => return,

DataInstKind::SpvInst(_) | DataInstKind::SpvExtInst { .. } => {}
}
Expand Down
271 changes: 271 additions & 0 deletions src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,274 @@ impl Const {
self.int_as_u128()?.try_into().ok()
}
}

/// Pure operations with scalar inputs and outputs.
//
// FIXME(eddyb) these are not some "perfect" grouping, but allow for more
// flexibility in users of this `enum` (and its component `enum`s).
#[derive(Copy, Clone, PartialEq, Eq, Hash, derive_more::From)]
pub enum Op {
BoolUnary(BoolUnOp),
BoolBinary(BoolBinOp),

IntUnary(IntUnOp),
IntBinary(IntBinOp),

FloatUnary(FloatUnOp),
FloatBinary(FloatBinOp),
}

#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub enum BoolUnOp {
Not,
}

#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub enum BoolBinOp {
Eq,
// FIXME(eddyb) should this be `Xor` instead?
Ne,
Or,
And,
}

#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub enum IntUnOp {
Neg,
Not,
CountOnes,

// FIXME(eddyb) ideally `Trunc` should be separated and common.
TruncOrZeroExtend,
TruncOrSignExtend,
}

#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub enum IntBinOp {
// I×I→I
Add,
Sub,
Mul,
DivU,
DivS,
ModU,
RemS,
ModS,
ShrU,
ShrS,
Shl,
Or,
Xor,
And,

// I×I→I×I
CarryingAdd,
BorrowingSub,
WideningMulU,
WideningMulS,

// I×I→B
Eq,
Ne,
// FIXME(eddyb) deduplicate between signed and unsigned.
GtU,
GtS,
GeU,
GeS,
LtU,
LtS,
LeU,
LeS,
}

#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub enum FloatUnOp {
// F→F
Neg,

// F→B
IsNan,
IsInf,

// FIXME(eddyb) these are a complicated mix of signatures.
FromUInt,
FromSInt,
ToUInt,
ToSInt,
Convert,
QuantizeAsF16,
}

#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub enum FloatBinOp {
// F×F→F
Add,
Sub,
Mul,
Div,
Rem,
Mod,

// F×F→B
Cmp(FloatCmp),
// FIXME(eddyb) this doesn't properly convey that this is effectively the
// boolean flip of the opposite comparison, e.g. `CmpOrUnord(Ge)` is really
// a fused version of `Not(Cmp(Lt))`, because `x < y` is never `true` for
// unordered `x` and `y` (i.e. `PartialOrd::partial_cmp(x, y) == None`),
// but that maps to `!(x < y)` always being `true` for unordered `x` and `y`,
// and thus `x >= y` is only equivalent for the ordered cases.
CmpOrUnord(FloatCmp),
}

#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub enum FloatCmp {
Eq,
Ne,
Lt,
Gt,
Le,
Ge,
}

impl Op {
pub fn output_count(self) -> usize {
match self {
Op::IntBinary(op) => op.output_count(),
_ => 1,
}
}

pub fn name(self) -> &'static str {
match self {
Op::BoolUnary(op) => op.name(),
Op::BoolBinary(op) => op.name(),

Op::IntUnary(op) => op.name(),
Op::IntBinary(op) => op.name(),

Op::FloatUnary(op) => op.name(),
Op::FloatBinary(op) => op.name(),
}
}
}

impl BoolUnOp {
pub fn name(self) -> &'static str {
match self {
BoolUnOp::Not => "bool.not",
}
}
}

impl BoolBinOp {
pub fn name(self) -> &'static str {
match self {
BoolBinOp::Eq => "bool.eq",
BoolBinOp::Ne => "bool.ne",
BoolBinOp::Or => "bool.or",
BoolBinOp::And => "bool.and",
}
}
}

impl IntUnOp {
pub fn name(self) -> &'static str {
match self {
IntUnOp::Neg => "i.neg",
IntUnOp::Not => "i.not",
IntUnOp::CountOnes => "i.count_ones",

IntUnOp::TruncOrZeroExtend => "u.trunc_or_zext",
IntUnOp::TruncOrSignExtend => "s.trunc_or_sext",
}
}
}

impl IntBinOp {
pub fn output_count(self) -> usize {
// FIXME(eddyb) should these 4 go into a different `enum`?
match self {
IntBinOp::CarryingAdd
| IntBinOp::BorrowingSub
| IntBinOp::WideningMulU
| IntBinOp::WideningMulS => 2,
_ => 1,
}
}

pub fn name(self) -> &'static str {
match self {
IntBinOp::Add => "i.add",
IntBinOp::Sub => "i.sub",
IntBinOp::Mul => "i.mul",
IntBinOp::DivU => "u.div",
IntBinOp::DivS => "s.div",
IntBinOp::ModU => "u.mod",
IntBinOp::RemS => "s.rem",
IntBinOp::ModS => "s.mod",
IntBinOp::ShrU => "u.shr",
IntBinOp::ShrS => "s.shr",
IntBinOp::Shl => "i.shl",
IntBinOp::Or => "i.or",
IntBinOp::Xor => "i.xor",
IntBinOp::And => "i.and",
IntBinOp::CarryingAdd => "i.carrying_add",
IntBinOp::BorrowingSub => "i.borrowing_sub",
IntBinOp::WideningMulU => "u.widening_mul",
IntBinOp::WideningMulS => "s.widening_mul",
IntBinOp::Eq => "i.eq",
IntBinOp::Ne => "i.ne",
IntBinOp::GtU => "u.gt",
IntBinOp::GtS => "s.gt",
IntBinOp::GeU => "u.ge",
IntBinOp::GeS => "s.ge",
IntBinOp::LtU => "u.lt",
IntBinOp::LtS => "s.lt",
IntBinOp::LeU => "u.le",
IntBinOp::LeS => "s.le",
}
}
}

impl FloatUnOp {
pub fn name(self) -> &'static str {
match self {
FloatUnOp::Neg => "f.neg",

FloatUnOp::IsNan => "f.is_nan",
FloatUnOp::IsInf => "f.is_inf",

FloatUnOp::FromUInt => "f.from_uint",
FloatUnOp::FromSInt => "f.from_sint",
FloatUnOp::ToUInt => "f.to_uint",
FloatUnOp::ToSInt => "f.to_sint",
FloatUnOp::Convert => "f.convert",
FloatUnOp::QuantizeAsF16 => "f.quantize_as_f16",
}
}
}

impl FloatBinOp {
pub fn name(self) -> &'static str {
match self {
FloatBinOp::Add => "f.add",
FloatBinOp::Sub => "f.sub",
FloatBinOp::Mul => "f.mul",
FloatBinOp::Div => "f.div",
FloatBinOp::Rem => "f.rem",
FloatBinOp::Mod => "f.mod",
FloatBinOp::Cmp(FloatCmp::Eq) => "f.eq",
FloatBinOp::Cmp(FloatCmp::Ne) => "f.ne",
FloatBinOp::Cmp(FloatCmp::Lt) => "f.lt",
FloatBinOp::Cmp(FloatCmp::Gt) => "f.gt",
FloatBinOp::Cmp(FloatCmp::Le) => "f.le",
FloatBinOp::Cmp(FloatCmp::Ge) => "f.ge",
FloatBinOp::CmpOrUnord(FloatCmp::Eq) => "f.eq_or_unord",
FloatBinOp::CmpOrUnord(FloatCmp::Ne) => "f.ne_or_unord",
FloatBinOp::CmpOrUnord(FloatCmp::Lt) => "f.lt_or_unord",
FloatBinOp::CmpOrUnord(FloatCmp::Gt) => "f.gt_or_unord",
FloatBinOp::CmpOrUnord(FloatCmp::Le) => "f.le_or_unord",
FloatBinOp::CmpOrUnord(FloatCmp::Ge) => "f.ge_or_unord",
}
}
}
Loading

0 comments on commit 8af71b9

Please sign in to comment.