Skip to content

Commit

Permalink
match program.fbs definition with tt-metal
Browse files Browse the repository at this point in the history
  • Loading branch information
mmanzoorTT committed Jan 21, 2025
1 parent 6e51177 commit 81b036b
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 9 deletions.
2 changes: 1 addition & 1 deletion include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ table ReductionProdOp {
in: tt.target.TensorRef;
out: tt.target.TensorRef;
all_dimensions: bool;
dim_arg: [int32];
dim_arg: int64;
keep_dim: bool;
}

Expand Down
12 changes: 8 additions & 4 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -799,12 +799,16 @@ createReductionProdOp(FlatbufferObjectCache &cache, ReductionOp op) {
cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput()));
auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
kHostAllocatedAddress, kHostAllocatedSize);
auto dimArg =
arrayAttrToFlatbuffer<mlir::IntegerAttr, int>(cache, op.getDimArg());
bool allDimensions = op.getDimArg() ? false : true;
auto dimArg = op.getDimArg();
// arrayAttrToFlatbuffer<mlir::IntegerAttr, int>(cache, op.getDimArg());
bool allDimensions = dimArg ? false : true;
int64_t dimension =
dimArg ? (mlir::cast<mlir::IntegerAttr>(dimArg->getValue()[0])).getInt()
: 0;
llvm::errs() << "dimension: " << dimension << '\t' << allDimensions << '\n';

return ::tt::target::ttnn::CreateReductionProdOp(
*cache.fbb, in, output, allDimensions, dimArg, op.getKeepDim());
*cache.fbb, in, output, allDimensions, dimension, op.getKeepDim());
}

::flatbuffers::Offset<::tt::target::ttnn::TransposeOp>
Expand Down
5 changes: 1 addition & 4 deletions runtime/lib/ttnn/operations/reduction/prod.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,8 @@ static void runReductionProdOp(::tt::target::ttnn::ReductionProdOp const *op,
const ::ttnn::Tensor &in = tensorPool.at(op->in()->global_id());
DEBUG_ASSERT(in.is_allocated());

const auto *fbDimArg = op->dim_arg();
int dim = fbDimArg ? static_cast<int>(*fbDimArg->begin()) : 0;

::ttnn::Tensor out =
::ttnn::prod(in, op->all_dimensions(), dim, op->keep_dim(),
::ttnn::prod(in, op->all_dimensions(), op->dim_arg(), op->keep_dim(),
outputMemoryConfig /* memory_config_arg */);

tensorPool.insert_or_assign(op->out()->global_id(), out);
Expand Down

0 comments on commit 81b036b

Please sign in to comment.