Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add general convolution TTIR op. Add pass that legalizes it to Conv2D #1085

Closed
wants to merge 1 commit into from

Conversation

LPanosTT
Copy link
Contributor

  • Added ConvolutionOp, a very flexible convolution operation that can represent a convolution (and conv-transpose) of any dimensionality
    • Adding because stablehlo only has one convolution operation exactly like this, and we don't want to legalize to specific convolutions in dialect conversion, rather transform it within the same dialect.
    • Created a TTIR_ConvLayoutAttr to store the dimension numbers of the input, weight, and output.
      • I.e which index is the batch dimension, which index is the channels dimension, which indices are the spatial dimensions
  • Added a pass that recognizes when a ConvolutionOp is a legal Conv2d (or can be converted to one by permuting the input, weights, and/or output) and performs the transformation.
  • In the future we can add passes that legalize ConvolutionOp to any ConvNd or ConvNdTranspose.

@LPanosTT LPanosTT force-pushed the lpanos/new_conv branch 2 times, most recently from b01134c to 42a9bca Compare October 28, 2024 15:45
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }

static ArrayRef<bool> getDefaultWindowReversal(const ConvolutionLayoutAttr &convolutionLayout) {
static bool boolArray[1000];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I missed this last time around, let's just make this a SmallVector<bool> boolArray(convolutionLayout.getInputSpatialDimensions().size()).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was able to use SmallVector instead of std::vector for all the attributes actually. Also was able to get rid of the helper because of that. Take a look :)

Copy link
Contributor

@nsmithtt nsmithtt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great!

…o convolution to it

Add pass to transform eligible ConvolutionOp --> Conv2dOp in TTIR
@sdjordjevicTT
Copy link
Contributor

This converts broad TTIR Conv to a more specific and narrow one. Should we leverage Jovan's work to consolidate the decomposition into a narrower set of TTIR operations?
#969

@LPanosTT
Copy link
Contributor Author

This converts broad TTIR Conv to a more specific and narrow one. Should we leverage Jovan's work to consolidate the decomposition into a narrower set of TTIR operations? #969

Sure I suppose it can go in that pipeline. I'll put it in there if his PR gets merged first, otherwise I'll do it after.

@sdjordjevicTT
Copy link
Contributor

I belive @jserbedzijaTT will merge this soon, can we wait a day or two to consolidate things from the start? :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants