feat: bitwise-ops-for-tensors #2498
base: main
Conversation
Great start! Some minor comments but here is my preliminary review:
I think the bitwise ops should only be added as int tensor operations; they don't make sense for floats.
We could eventually extend the ops to boolean tensors with their logical counterpart (would be applied on a single bit represented by the bool), but this can be left for another PR.
We can leave the candle ops as unimplemented, but for the JIT backends we should wait to merge once it's implemented with cubecl.
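For illustration, keeping the ops int-only would mean they live on the int tensor ops trait rather than the shared numeric ops. A rough sketch of the surface (method names follow the diff below; the exact signatures and placement are an assumption):

// Sketch: bitwise ops declared only for int tensors, no float counterparts.
// Names mirror the PR's diff; this is not the final API.
fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self>;
fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self>;
fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self>;
fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self>;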
@@ -376,4 +376,31 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {
    sign(tensor)
}

fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
    todo!("bitwise_and is not implemented for Candle IntTensor");
Tiny detail, but I would mark those with unimplemented!(...) instead.
Right, will address
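For reference, the suggested change applied to the Candle stub above would look something like this (message text kept from the original; a sketch, not the final code):

fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
    // unimplemented!() signals "intentionally not supported" more clearly than todo!().
    unimplemented!("bitwise_and is not implemented for Candle IntTensor");
}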
@@ -795,6 +873,40 @@ where
let output = B::int_powf(lhs, rhs);
handles.register_int_tensor::<B>(&desc.out.id, output);
}
NumericOperationDescription::BitwiseAnd(desc) => {
    // binary_int_ops!(handles, desc, B::int_bitwise_and(lhs, rhs))
The macro should work here, no? 🤔
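That is, the commented-out line should be usable as-is instead of expanding the match arm by hand; roughly (assuming the backend op is named int_bitwise_and, as in the comment):

NumericOperationDescription::BitwiseAnd(desc) => {
    // The macro handles reading the operands, calling the backend op,
    // and registering the output tensor.
    binary_int_ops!(handles, desc, B::int_bitwise_and(lhs, rhs))
}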
TchTensor::binary_ops_tensor(
    lhs,
    rhs,
    |lhs, rhs| lhs.f_bitwise_and_tensor(rhs).unwrap(),
    |lhs, rhs| rhs.f_bitwise_and_tensor(lhs).unwrap(),
    |lhs, rhs| lhs.f_bitwise_and_tensor(rhs).unwrap(),
)
binary_ops_tensor has

FLMut: Fn(&mut tch::Tensor, &tch::Tensor) -> tch::Tensor,
FRMut: Fn(&tch::Tensor, &mut tch::Tensor) -> tch::Tensor,
FRef: Fn(&tch::Tensor, &tch::Tensor) -> tch::Tensor,

so I think the definitions should use the inplace version as well, similar to the other ops defined. Also, you can use the methods without the f_ prefix and you won't have to unwrap. It should look something like this:
TchTensor::binary_ops_tensor(
lhs,
rhs,
|lhs, rhs| lhs.bitwise_and_tensor_(rhs),
|lhs, rhs| rhs.bitwise_and_tensor_(lhs),
|lhs, rhs| lhs.bitwise_and_tensor(rhs),
)
The same applies to all other bitwise ops below.
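For example, bitwise_or would follow the same pattern, assuming tch exposes the corresponding bitwise_or_tensor / bitwise_or_tensor_ methods (a sketch, not taken from the PR):

TchTensor::binary_ops_tensor(
    lhs,
    rhs,
    // In-place on the left operand, in-place on the right operand, then the
    // allocating fallback, mirroring the bitwise_and suggestion above.
    |lhs, rhs| lhs.bitwise_or_tensor_(rhs),
    |lhs, rhs| rhs.bitwise_or_tensor_(lhs),
    |lhs, rhs| lhs.bitwise_or_tensor(rhs),
)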
/// Operation corresponding to:
///
/// Float => [add](crate::ops::FloatTensorOps::float_add).
/// Int => [add](crate::ops::IntTensorOps::int_add).
BitwiseAnd(BinaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [div scalar](crate::ops::FloatTensorOps::float_div_scalar).
/// Int => [div scalar](crate::ops::IntTensorOps::int_div_scalar).
BitwiseAndScalar(ScalarOperationDescription<E>),
/// Operation corresponding to:
///
/// Float => [add](crate::ops::FloatTensorOps::float_add).
/// Int => [add](crate::ops::IntTensorOps::int_add).
BitwiseOr(BinaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [div scalar](crate::ops::FloatTensorOps::float_div_scalar).
/// Int => [div scalar](crate::ops::IntTensorOps::int_div_scalar).
BitwiseOrScalar(ScalarOperationDescription<E>),
/// Operation corresponding to:
///
/// Float => [add](crate::ops::FloatTensorOps::float_add).
/// Int => [add](crate::ops::IntTensorOps::int_add).
BitwiseXor(BinaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [div scalar](crate::ops::FloatTensorOps::float_div_scalar).
/// Int => [div scalar](crate::ops::IntTensorOps::int_div_scalar).
BitwiseXorScalar(ScalarOperationDescription<E>),
/// Operation corresponding to:
///
/// Float => [powf](crate::ops::FloatTensorOps::float_powf).
/// Int => [powf](crate::ops::IntTensorOps::int_powf).
BitwiseNot(UnaryOperationDescription),
Copy/pasta docstrings should be fixed.
Also, I'm not sure this should be a numeric op since it won't be implemented for float (and doesn't really make sense either). I think it should go in the int operations only (and possibly add the equivalent logical operations for bool).
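For illustration, the docstrings could point at the actual bitwise ops instead of add/div scalar; something like the following, where the exact method paths depend on the final naming of the int ops (an assumption, not the PR's code):

/// Operation corresponding to:
///
/// Int => [bitwise and](crate::ops::IntTensorOps::bitwise_and).
BitwiseAnd(BinaryOperationDescription),
/// Operation corresponding to:
///
/// Int => [bitwise and scalar](crate::ops::IntTensorOps::bitwise_and_scalar).
BitwiseAndScalar(ScalarOperationDescription<E>),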
@@ -856,6 +856,41 @@ where
// Check if the sum is NaN by comparing it to itself
Tensor::new(K::not_equal(sum.clone(), sum))
}

/// Applies the element wise logical and operation with another tensor.
I think a more accurate description for integers is that it applies the bitwise logical and (i.e., for each bit that represents the integer).
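Something along these lines, for example (wording is only a suggestion):

/// Applies the bitwise logical AND with another tensor, i.e., the AND is
/// applied to each bit of the integers' binary representation.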
mod bitwise_and;
mod bitwise_and_scalar;
mod bitwise_not;
mod bitwise_or;
mod bitwise_or_scalar;
mod bitwise_xor;
mod bitwise_xor_scalar;
I think tests for bitwise operators could be grouped into a single file 🙂
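For example, the seven modules above could collapse into a single one; the module/file name here is only an illustration:

// One module (e.g. bitwise.rs) holding the and/or/xor/not tests and their
// scalar variants, instead of seven separate files.
mod bitwise;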
The idea behind the whole float implementation was to convert to int before computing and back to float after. You're right, it doesn't really make sense; will make the changes. Raised an issue in the cubecl repo for adding bitwise op support and will look at that first before implementing here.
This PR has been marked as stale because it has not been updated for over a month.
Pull Request Template
Checklist
The run-checks all script has been executed.
Related Issues/PRs
#2234
Blocked by CubeCL
Changes
Bitwise Operations for Tensors
Testing
The corresponding tests for the ops were included under the burn_tensor/tensor/tests directory. Candle does not seem to have bitwise operations, so as it stands the implementation for the Candle backend is replaced with the todo macro.