Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

int8 and uint8 support in Pad and other ops #387

Merged
merged 3 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions src/ops/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,27 @@ fn cast(pool: &TensorPool, input: Input, dtype: DataType) -> Result<Output, OpEr
DataType::Int32 => match input {
Input::Int32Tensor(t) => Ok(t.map_in(pool, |x| *x).into()),
Input::FloatTensor(t) => Ok(t.map_in(pool, |x| *x as i32).into()),
_ => Err(OpError::UnsupportedType),
Input::Int8Tensor(t) => Ok(t.map_in(pool, |x| *x as i32).into()),
Input::UInt8Tensor(t) => Ok(t.map_in(pool, |x| *x as i32).into()),
},
DataType::Float => match input {
Input::FloatTensor(t) => Ok(t.map_in(pool, |x| *x).into()),
Input::Int32Tensor(t) => Ok(t.map_in(pool, |x| *x as f32).into()),
_ => Err(OpError::UnsupportedType),
Input::Int8Tensor(t) => Ok(t.map_in(pool, |x| *x as f32).into()),
Input::UInt8Tensor(t) => Ok(t.map_in(pool, |x| *x as f32).into()),
},
DataType::Int8 => match input {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that using as for casts is fine for all cases here. Just making some notes for myself:

Converting non-i8 types to i8 will truncate in various ways, and the same for u8. The Cast op specs are here. The main cases of interest are:

  • Float -> int when out of range: ONNX spec says undefined, so an as cast is fine
  • Int -> Int when out of range: ONNX spec says when OOR, discard higher bits and reinterpret (with respect to two’s complement representation for signed types). For example, 200 (int16) -> -56 (int8).. This matches the behavior of as in Rust

Input::FloatTensor(t) => Ok(t.map_in(pool, |x| *x as i8).into()),
Input::Int32Tensor(t) => Ok(t.map_in(pool, |x| *x as i8).into()),
Input::Int8Tensor(t) => Ok(t.map_in(pool, |x| *x).into()),
Input::UInt8Tensor(t) => Ok(t.map_in(pool, |x| *x as i8).into()),
},
DataType::UInt8 => match input {
Input::FloatTensor(t) => Ok(t.map_in(pool, |x| *x as u8).into()),
Input::Int32Tensor(t) => Ok(t.map_in(pool, |x| *x as u8).into()),
Input::Int8Tensor(t) => Ok(t.map_in(pool, |x| *x as u8).into()),
Input::UInt8Tensor(t) => Ok(t.map_in(pool, |x| *x).into()),
},
_ => Err(OpError::UnsupportedValue("Unsupported cast")),
}
}

Expand Down
30 changes: 27 additions & 3 deletions src/ops/gather.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ impl Operator for Gather {
Input::Int32Tensor(input) => gather(pool, input, self.axis, indices).into_op_result(),
Input::FloatTensor(input) => gather(pool, input, self.axis, indices).into_op_result(),
Input::UInt8Tensor(input) => gather(pool, input, self.axis, indices).into_op_result(),
_ => Err(OpError::UnsupportedType),
Input::Int8Tensor(input) => gather(pool, input, self.axis, indices).into_op_result(),
}
}
}
Expand Down Expand Up @@ -238,7 +238,12 @@ impl Operator for GatherElements {
Input::FloatTensor(input) => {
gather_elements(pool, input, indices, self.axis).into_op_result()
}
_ => Err(OpError::UnsupportedType),
Input::Int8Tensor(input) => {
gather_elements(pool, input, indices, self.axis).into_op_result()
}
Input::UInt8Tensor(input) => {
gather_elements(pool, input, indices, self.axis).into_op_result()
}
}
}
}
Expand Down Expand Up @@ -336,7 +341,12 @@ impl Operator for GatherND {
Input::FloatTensor(input) => {
gather_nd(pool, input, indices, self.batch_dims).into_op_result()
}
_ => Err(OpError::UnsupportedType),
Input::Int8Tensor(input) => {
gather_nd(pool, input, indices, self.batch_dims).into_op_result()
}
Input::UInt8Tensor(input) => {
gather_nd(pool, input, indices, self.batch_dims).into_op_result()
}
}
}
}
Expand Down Expand Up @@ -451,6 +461,14 @@ impl Operator for ScatterElements {
scatter_elements(pool, data, indices, updates, self.axis, self.reduction)
.into_op_result()
}
(Input::Int8Tensor(data), Input::Int8Tensor(updates)) => {
scatter_elements(pool, data, indices, updates, self.axis, self.reduction)
.into_op_result()
}
(Input::UInt8Tensor(data), Input::UInt8Tensor(updates)) => {
scatter_elements(pool, data, indices, updates, self.axis, self.reduction)
.into_op_result()
}
_ => Err(OpError::UnsupportedType),
}
}
Expand Down Expand Up @@ -547,6 +565,12 @@ impl Operator for ScatterND {
(Input::FloatTensor(data), Input::FloatTensor(updates)) => {
scatter_nd(pool, data, indices, updates, self.reduction).into_op_result()
}
(Input::Int8Tensor(data), Input::Int8Tensor(updates)) => {
scatter_nd(pool, data, indices, updates, self.reduction).into_op_result()
}
(Input::UInt8Tensor(data), Input::UInt8Tensor(updates)) => {
scatter_nd(pool, data, indices, updates, self.reduction).into_op_result()
}
_ => Err(OpError::UnsupportedType),
}
}
Expand Down
51 changes: 40 additions & 11 deletions src/ops/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ impl Operator for Expand {
match input {
Input::FloatTensor(input) => expand(pool, input, &shape).into_op_result(),
Input::Int32Tensor(input) => expand(pool, input, &shape).into_op_result(),
_ => Err(OpError::UnsupportedType),
Input::UInt8Tensor(input) => expand(pool, input, &shape).into_op_result(),
Input::Int8Tensor(input) => expand(pool, input, &shape).into_op_result(),
}
}

Expand All @@ -122,7 +123,8 @@ impl Operator for Expand {
let output: Output = match input {
Output::FloatTensor(input) => expand_to(pool, input.view(), &out_shape).into(),
Output::Int32Tensor(input) => expand_to(pool, input.view(), &out_shape).into(),
_ => return Err(OpError::UnsupportedType),
Output::Int8Tensor(input) => expand_to(pool, input.view(), &out_shape).into(),
Output::UInt8Tensor(input) => expand_to(pool, input.view(), &out_shape).into(),
};
Ok(output)
}
Expand Down Expand Up @@ -172,7 +174,8 @@ impl Operator for Flatten {
match input {
Input::FloatTensor(input) => flatten(pool, input, self.axis).into_op_result(),
Input::Int32Tensor(input) => flatten(pool, input, self.axis).into_op_result(),
_ => Err(OpError::UnsupportedType),
Input::Int8Tensor(input) => flatten(pool, input, self.axis).into_op_result(),
Input::UInt8Tensor(input) => flatten(pool, input, self.axis).into_op_result(),
}
}

Expand All @@ -195,7 +198,14 @@ impl Operator for Flatten {
flatten_in_place(pool, &mut output, self.axis)?;
Ok(output.into())
}
_ => Err(OpError::UnsupportedType),
Output::Int8Tensor(mut output) => {
flatten_in_place(pool, &mut output, self.axis)?;
Ok(output.into())
}
Output::UInt8Tensor(mut output) => {
flatten_in_place(pool, &mut output, self.axis)?;
Ok(output.into())
}
}
}
}
Expand Down Expand Up @@ -314,7 +324,8 @@ impl Operator for Reshape {
match input {
Input::Int32Tensor(t) => reshape(pool, t, &shape, self.allow_zero).into_op_result(),
Input::FloatTensor(t) => reshape(pool, t, &shape, self.allow_zero).into_op_result(),
_ => Err(OpError::UnsupportedType),
Input::Int8Tensor(t) => reshape(pool, t, &shape, self.allow_zero).into_op_result(),
Input::UInt8Tensor(t) => reshape(pool, t, &shape, self.allow_zero).into_op_result(),
}
}

Expand All @@ -340,7 +351,14 @@ impl Operator for Reshape {
reshape_in_place(pool, &mut output, &shape, self.allow_zero)?;
Ok(output.into())
}
_ => Err(OpError::UnsupportedType),
Output::Int8Tensor(mut output) => {
reshape_in_place(pool, &mut output, &shape, self.allow_zero)?;
Ok(output.into())
}
Output::UInt8Tensor(mut output) => {
reshape_in_place(pool, &mut output, &shape, self.allow_zero)?;
Ok(output.into())
}
}
}
}
Expand Down Expand Up @@ -449,7 +467,8 @@ impl Operator for Squeeze {
match input {
Input::FloatTensor(t) => squeeze(pool, t, axes).into_op_result(),
Input::Int32Tensor(t) => squeeze(pool, t, axes).into_op_result(),
_ => Err(OpError::UnsupportedType),
Input::Int8Tensor(t) => squeeze(pool, t, axes).into_op_result(),
Input::UInt8Tensor(t) => squeeze(pool, t, axes).into_op_result(),
}
}

Expand All @@ -475,7 +494,14 @@ impl Operator for Squeeze {
squeeze_in_place(&mut t, axes)?;
Ok(t.into())
}
_ => Err(OpError::UnsupportedType),
Output::UInt8Tensor(mut t) => {
squeeze_in_place(&mut t, axes)?;
Ok(t.into())
}
Output::Int8Tensor(mut t) => {
squeeze_in_place(&mut t, axes)?;
Ok(t.into())
}
}
}
}
Expand Down Expand Up @@ -519,7 +545,8 @@ impl Operator for Transpose {
match input {
Input::FloatTensor(input) => transpose(pool, input, perm_slice).into_op_result(),
Input::Int32Tensor(input) => transpose(pool, input, perm_slice).into_op_result(),
_ => Err(OpError::UnsupportedType),
Input::Int8Tensor(input) => transpose(pool, input, perm_slice).into_op_result(),
Input::UInt8Tensor(input) => transpose(pool, input, perm_slice).into_op_result(),
}
}
}
Expand Down Expand Up @@ -577,7 +604,8 @@ impl Operator for Unsqueeze {
match input {
Input::FloatTensor(input) => unsqueeze(pool, input, &axes).into_op_result(),
Input::Int32Tensor(input) => unsqueeze(pool, input, &axes).into_op_result(),
_ => Err(OpError::UnsupportedType),
Input::Int8Tensor(input) => unsqueeze(pool, input, &axes).into_op_result(),
Input::UInt8Tensor(input) => unsqueeze(pool, input, &axes).into_op_result(),
}
}

Expand All @@ -597,7 +625,8 @@ impl Operator for Unsqueeze {
match output {
Output::FloatTensor(t) => unsqueeze_in_place(t, &axes).map(Output::FloatTensor),
Output::Int32Tensor(t) => unsqueeze_in_place(t, &axes).map(Output::Int32Tensor),
_ => Err(OpError::UnsupportedType),
Output::Int8Tensor(t) => unsqueeze_in_place(t, &axes).map(Output::Int8Tensor),
Output::UInt8Tensor(t) => unsqueeze_in_place(t, &axes).map(Output::UInt8Tensor),
}
}
}
Expand Down
9 changes: 8 additions & 1 deletion src/ops/pad.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,11 +192,18 @@ impl Operator for Pad {
let const_val = inputs.get_as_scalar::<i32>(2)?.unwrap_or(0);
pad(pool, t, &pads, self.mode, const_val).into_op_result()
}
Input::Int8Tensor(t) => {
let const_val = inputs.get_as_scalar::<i8>(2)?.unwrap_or(0);
pad(pool, t, &pads, self.mode, const_val).into_op_result()
}
Input::UInt8Tensor(t) => {
let const_val = inputs.get_as_scalar::<u8>(2)?.unwrap_or(0);
pad(pool, t, &pads, self.mode, const_val).into_op_result()
}
Input::FloatTensor(t) => {
let const_val = inputs.get_as_scalar::<f32>(2)?.unwrap_or(0.);
pad(pool, t, &pads, self.mode, const_val).into_op_result()
}
_ => Err(OpError::UnsupportedType),
}
}
}
Expand Down
16 changes: 14 additions & 2 deletions src/ops/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,12 @@ impl Operator for Slice {
Input::Int32Tensor(input) => {
slice(pool, input, &starts, &ends, axes.as_ref(), steps.as_ref()).map(|t| t.into())
}
_ => Err(OpError::UnsupportedType),
Input::Int8Tensor(input) => {
slice(pool, input, &starts, &ends, axes.as_ref(), steps.as_ref()).map(|t| t.into())
}
Input::UInt8Tensor(input) => {
slice(pool, input, &starts, &ends, axes.as_ref(), steps.as_ref()).map(|t| t.into())
}
};
result.into_op_result()
}
Expand Down Expand Up @@ -168,7 +173,14 @@ impl Operator for Slice {
slice_in_place(&mut output, &starts, &ends, axes.as_ref())?;
Ok(output.into())
}
_ => Err(OpError::UnsupportedType),
Output::Int8Tensor(mut output) => {
slice_in_place(&mut output, &starts, &ends, axes.as_ref())?;
Ok(output.into())
}
Output::UInt8Tensor(mut output) => {
slice_in_place(&mut output, &starts, &ends, axes.as_ref())?;
Ok(output.into())
}
}
}
}
Expand Down
Loading