diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs index e797ecf924..84200bcd90 100644 --- a/include/ttmlir/Target/TTNN/program.fbs +++ b/include/ttmlir/Target/TTNN/program.fbs @@ -7,12 +7,14 @@ table GetDeviceOp { mesh: Dim2d; chip_ids: [uint32]; out: tt.target.DeviceRef; + debug_info: string; } table ToMemoryConfigOp { in0: tt.target.TensorRef; memcfg: MemoryConfigDesc; out: tt.target.TensorRef; + debug_info: string; } table ToLayoutOp { @@ -22,12 +24,14 @@ table ToLayoutOp { memcfg: tt.target.MemoryConfigDesc; device: tt.target.DeviceRef; out: tt.target.TensorRef; + debug_info: string; } table TypecastOp { in: tt.target.TensorRef; dtype: tt.target.DataType; out: tt.target.TensorRef; + debug_info: string; } table ToDeviceOp { @@ -35,11 +39,13 @@ table ToDeviceOp { device: tt.target.DeviceRef; memcfg: tt.target.MemoryConfigDesc; out: tt.target.TensorRef; + debug_info: string; } table FromDeviceOp { in: tt.target.TensorRef; out: tt.target.TensorRef; + debug_info: string; } table EmptyOp { @@ -51,6 +57,7 @@ table EmptyOp { memcfg: tt.target.MemoryConfigDesc; // optional strategy: tt.target.DistributionStrategy; out: tt.target.TensorRef; + debug_info: string; } table FullOp { @@ -59,6 +66,7 @@ table FullOp { num_shards: uint32; strategy: tt.target.DistributionStrategy; out: tt.target.TensorRef; + debug_info: string; } enum EltwiseOpType: uint32 { @@ -108,6 +116,7 @@ table EltwiseOp { ins: [tt.target.TensorRef]; out: tt.target.TensorRef; params: EltwiseOpParams; + debug_info: string; } enum ReductionOpType: uint32 { @@ -122,18 +131,21 @@ table ReductionOp { out: tt.target.TensorRef; dim_arg: [int32]; keep_dim: bool; + debug_info: string; } table EmbeddingOp { input: tt.target.TensorRef; weight: tt.target.TensorRef; out: tt.target.TensorRef; + debug_info: string; } table SoftmaxOp { in: tt.target.TensorRef; out: tt.target.TensorRef; dimension: int32; + debug_info: string; } table TransposeOp { @@ -141,18 +153,21 @@ table TransposeOp { out: tt.target.TensorRef; dim0: int32; dim1: int32; + debug_info: string; } table ConcatOp { inputs: [tt.target.TensorRef]; out: tt.target.TensorRef; dim: int32; + debug_info: string; } table ReshapeOp { in: tt.target.TensorRef; out: tt.target.TensorRef; shape: [int32]; + debug_info: string; } table SliceOp { @@ -161,6 +176,7 @@ table SliceOp { begins: [int64]; ends: [int64]; step: [int64]; + debug_info: string; } // ANCHOR: adding_an_op_matmul_fbs @@ -168,6 +184,7 @@ table MatmulOp { in0: tt.target.TensorRef; in1: tt.target.TensorRef; out: tt.target.TensorRef; + debug_info: string; } // ANCHOR_END: adding_an_op_matmul_fbs @@ -191,6 +208,7 @@ table Conv2dOp { dilation_height: uint32; dilation_width: uint32; groups: uint32; + debug_info: string; } table MaxPool2dOp { @@ -210,10 +228,12 @@ table MaxPool2dOp { ceil_mode: bool; padding_height: uint32; padding_width: uint32; + debug_info: string; } table DeallocOp { in: tt.target.TensorRef; + debug_info: string; } table AllGatherOp { @@ -221,6 +241,7 @@ table AllGatherOp { out: tt.target.TensorRef; dim: uint32; num_links: uint32; + debug_info: string; } union OpType { diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp index a3262a680d..ba22a057c3 100644 --- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp +++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp @@ -95,7 +95,8 @@ createOperation(FlatbufferObjectCache &cache, ::flatbuffers::Offset op, } ::flatbuffers::Offset<::tt::target::ttnn::GetDeviceOp> -createOp(FlatbufferObjectCache &cache, GetDeviceOp op) { 
+createOp(FlatbufferObjectCache &cache, GetDeviceOp op, + std::string const &debugString) { auto result = op.getResult(); auto resultType = mlir::cast(result.getType()); auto meshShape = resultType.getDesc().getMeshShape(); @@ -109,11 +110,13 @@ createOp(FlatbufferObjectCache &cache, GetDeviceOp op) { ::tt::target::Dim2d mesh(1, meshVolume); auto chipIds = toFlatbuffer(cache, resultType.getDesc().getChipIds()); auto out = cache.getOrCreate(result, createDeviceRef); - return ::tt::target::ttnn::CreateGetDeviceOp(*cache.fbb, &mesh, chipIds, out); + return ::tt::target::ttnn::CreateGetDeviceOp( + *cache.fbb, &mesh, chipIds, out, cache.fbb->CreateString(debugString)); } ::flatbuffers::Offset<::tt::target::ttnn::ToMemoryConfigOp> -createOp(FlatbufferObjectCache &cache, ToMemoryConfigOp op) { +createOp(FlatbufferObjectCache &cache, ToMemoryConfigOp op, + std::string const &debugString) { auto input = cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput())); @@ -122,12 +125,14 @@ createOp(FlatbufferObjectCache &cache, ToMemoryConfigOp op) { auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer, kHostAllocatedAddress, kHostAllocatedSize); - return ::tt::target::ttnn::CreateToMemoryConfigOp(*cache.fbb, input, - memoryConfigDesc, output); + return ::tt::target::ttnn::CreateToMemoryConfigOp( + *cache.fbb, input, memoryConfigDesc, output, + cache.fbb->CreateString(debugString)); } ::flatbuffers::Offset<::tt::target::ttnn::ToLayoutOp> -createOp(FlatbufferObjectCache &cache, ToLayoutOp op) { +createOp(FlatbufferObjectCache &cache, ToLayoutOp op, + std::string const &debugString) { auto input = cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput())); ::tt::target::TensorLayout layout = @@ -151,11 +156,13 @@ createOp(FlatbufferObjectCache &cache, ToLayoutOp op) { memoryConfig.has_value() ? cache.getOrCreate(memoryConfig.value(), memoryConfigToFlatbuffer) : 0, - device ? cache.at<::tt::target::DeviceRef>(device) : 0, output); + device ? 
cache.at<::tt::target::DeviceRef>(device) : 0, output, + cache.fbb->CreateString(debugString)); } ::flatbuffers::Offset<::tt::target::ttnn::TypecastOp> -createOp(FlatbufferObjectCache &cache, TypecastOp op) { +createOp(FlatbufferObjectCache &cache, TypecastOp op, + std::string const &debugString) { auto input = cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput())); ::tt::target::DataType dtype = @@ -163,11 +170,13 @@ createOp(FlatbufferObjectCache &cache, TypecastOp op) { auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer, kHostAllocatedAddress, kHostAllocatedSize); - return ::tt::target::ttnn::CreateTypecastOp(*cache.fbb, input, dtype, output); + return ::tt::target::ttnn::CreateTypecastOp( + *cache.fbb, input, dtype, output, cache.fbb->CreateString(debugString)); } ::flatbuffers::Offset<::tt::target::ttnn::ToDeviceOp> -createOp(FlatbufferObjectCache &cache, ToDeviceOp op) { +createOp(FlatbufferObjectCache &cache, ToDeviceOp op, + std::string const &debugString) { auto input = cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput())); auto device = getOperandThroughDPSOps(op.getDevice()); @@ -178,7 +187,7 @@ createOp(FlatbufferObjectCache &cache, ToDeviceOp op) { if (!op.getMemoryConfig()) { return ::tt::target::ttnn::CreateToDeviceOp( *cache.fbb, input, cache.at<::tt::target::DeviceRef>(device), - /* memoryConfigDesc */ 0, output); + /* memoryConfigDesc */ 0, output, cache.fbb->CreateString(debugString)); } auto memoryConfigDesc = @@ -190,18 +199,21 @@ createOp(FlatbufferObjectCache &cache, ToDeviceOp op) { } ::flatbuffers::Offset<::tt::target::ttnn::FromDeviceOp> -createOp(FlatbufferObjectCache &cache, FromDeviceOp op) { +createOp(FlatbufferObjectCache &cache, FromDeviceOp op, + std::string const &debugString) { auto input = cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput())); auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer, kHostAllocatedAddress, kHostAllocatedSize); - return ::tt::target::ttnn::CreateFromDeviceOp(*cache.fbb, input, output); + return ::tt::target::ttnn::CreateFromDeviceOp( + *cache.fbb, input, output, cache.fbb->CreateString(debugString)); } ::flatbuffers::Offset<::tt::target::ttnn::EmptyOp> -createOp(FlatbufferObjectCache &cache, EmptyOp op) { +createOp(FlatbufferObjectCache &cache, EmptyOp op, + std::string const &debugString) { ::llvm::ArrayRef shape = op.getShape().getShape(); ::tt::target::DataType dtype = ::tt::mlir::ttnn::utils::toTargetDataType(op.getDtype().value()); @@ -237,11 +249,13 @@ createOp(FlatbufferObjectCache &cache, EmptyOp op) { numShards, cache.at<::tt::target::DeviceRef>(device), memoryConfigDesc, strategy, cache.getOrCreate(output, tensorValueToFlatbuffer, kHostAllocatedAddress, - kHostAllocatedSize)); + kHostAllocatedSize), + cache.fbb->CreateString(debugString)); } ::flatbuffers::Offset<::tt::target::ttnn::FullOp> -createOp(FlatbufferObjectCache &cache, FullOp op) { +createOp(FlatbufferObjectCache &cache, FullOp op, + std::string const &debugString) { auto device = getOperandThroughDPSOps(op.getDevice()); auto fillValue = op.getFillValue().convertToFloat(); auto output = getOperandThroughDPSOps(op.getResult()); @@ -256,24 +270,28 @@ createOp(FlatbufferObjectCache &cache, FullOp op) { *cache.fbb, cache.at<::tt::target::DeviceRef>(device), fillValue, numShards, strategy, cache.getOrCreate(output, tensorValueToFlatbuffer, kHostAllocatedAddress, - kHostAllocatedSize)); + kHostAllocatedSize), + cache.fbb->CreateString(debugString)); } // ANCHOR: 
adding_an_op_matmul_serialize_to_binary ::flatbuffers::Offset<::tt::target::ttnn::MatmulOp> -createOp(FlatbufferObjectCache &cache, MatmulOp op) { +createOp(FlatbufferObjectCache &cache, MatmulOp op, + std::string const &debugString) { auto in0 = cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getA())); auto in1 = cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getB())); auto output = cache.at<::tt::target::TensorRef>( getOperandThroughDPSOps(op.getResult())); - return ::tt::target::ttnn::CreateMatmulOp(*cache.fbb, in0, in1, output); + return ::tt::target::ttnn::CreateMatmulOp( + *cache.fbb, in0, in1, output, cache.fbb->CreateString(debugString)); } // ANCHOR_END: adding_an_op_matmul_serialize_to_binary ::flatbuffers::Offset<::tt::target::ttnn::Conv2dOp> -createOp(FlatbufferObjectCache &cache, Conv2dOp op) { +createOp(FlatbufferObjectCache &cache, Conv2dOp op, + std::string const &debugString) { auto in0 = cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput())); auto in1 = cache.at<::tt::target::TensorRef>( @@ -293,22 +311,25 @@ createOp(FlatbufferObjectCache &cache, Conv2dOp op) { op.getInputWidth(), op.getKernelHeight(), op.getKernelWidth(), op.getStrideHeight(), op.getStrideWidth(), op.getPaddingHeight(), op.getPaddingWidth(), op.getDilationHeight(), op.getDilationWidth(), - op.getGroups()); + op.getGroups(), cache.fbb->CreateString(debugString)); } ::flatbuffers::Offset<::tt::target::ttnn::AllGatherOp> -createOp(FlatbufferObjectCache &cache, AllGatherOp op) { +createOp(FlatbufferObjectCache &cache, AllGatherOp op, + std::string const &debugString) { auto input = cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput())); auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer, kHostAllocatedAddress, kHostAllocatedSize); - return ::tt::target::ttnn::CreateAllGatherOp(*cache.fbb, input, output, - op.getDim(), op.getNumLinks()); + return ::tt::target::ttnn::CreateAllGatherOp( + *cache.fbb, input, output, op.getDim(), op.getNumLinks(), + cache.fbb->CreateString(debugString)); } template ::flatbuffers::Offset<::tt::target::ttnn::EltwiseOp> -createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) { +createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op, + std::string const &debugString) { ::tt::target::ttnn::EltwiseOpType type; ::tt::target::ttnn::EltwiseOpParams paramsType = ::tt::target::ttnn::EltwiseOpParams::NONE; @@ -394,12 +415,13 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) { *cache.fbb, type, &ins, cache.at<::tt::target::TensorRef>( getOperandThroughDPSOps(op.getOutputs().front())), - paramsType, params); + paramsType, params, debugString.c_str()); } template ::flatbuffers::Offset<::tt::target::ttnn::ReductionOp> -createReductionOp(FlatbufferObjectCache &cache, ReductionOp op) { +createReductionOp(FlatbufferObjectCache &cache, ReductionOp op, + std::string const &debugString) { ::tt::target::ttnn::ReductionOpType type; if constexpr (std::is_same_v) { type = ::tt::target::ttnn::ReductionOpType::Sum; @@ -418,13 +440,15 @@ createReductionOp(FlatbufferObjectCache &cache, ReductionOp op) { auto dim_arg = arrayAttrToFlatbuffer(cache, op.getDimArg()); - return ::tt::target::ttnn::CreateReductionOp(*cache.fbb, type, in, output, - dim_arg, op.getKeepDim()); + return ::tt::target::ttnn::CreateReductionOp( + *cache.fbb, type, in, output, dim_arg, op.getKeepDim(), + cache.fbb->CreateString(debugString)); } template ::flatbuffers::Offset<::tt::target::ttnn::TransposeOp> 
-createTransposeOp(FlatbufferObjectCache &cache, TransposeOp op) { +createTransposeOp(FlatbufferObjectCache &cache, TransposeOp op, + std::string const &debugString) { auto in = cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput())); auto out = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer, @@ -432,12 +456,14 @@ createTransposeOp(FlatbufferObjectCache &cache, TransposeOp op) { int32_t dim0 = op.getDim0(); int32_t dim1 = op.getDim1(); - return ::tt::target::ttnn::CreateTransposeOp(*cache.fbb, in, out, dim0, dim1); + return ::tt::target::ttnn::CreateTransposeOp( + *cache.fbb, in, out, dim0, dim1, cache.fbb->CreateString(debugString)); } template ::flatbuffers::Offset<::tt::target::ttnn::ConcatOp> -createConcatOp(FlatbufferObjectCache &cache, ConcatOp op) { +createConcatOp(FlatbufferObjectCache &cache, ConcatOp op, + std::string const &debugString) { std::vector<::flatbuffers::Offset<::tt::target::TensorRef>> ins; for (auto input : op.getInputs()) { ins.push_back( @@ -447,24 +473,28 @@ createConcatOp(FlatbufferObjectCache &cache, ConcatOp op) { getOperandThroughDPSOps(op.getResult())); int32_t dim = op.getDim(); - return ::tt::target::ttnn::CreateConcatOpDirect(*cache.fbb, &ins, out, dim); + return ::tt::target::ttnn::CreateConcatOpDirect(*cache.fbb, &ins, out, dim, + debugString.c_str()); } template ::flatbuffers::Offset<::tt::target::ttnn::EmbeddingOp> -createEmbeddingOp(FlatbufferObjectCache &cache, EmbeddingOp op) { +createEmbeddingOp(FlatbufferObjectCache &cache, EmbeddingOp op, + std::string const &debugString) { auto in0 = cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput())); auto in1 = cache.at<::tt::target::TensorRef>( getOperandThroughDPSOps(op.getWeight())); auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer, kHostAllocatedAddress, kHostAllocatedSize); - return ::tt::target::ttnn::CreateEmbeddingOp(*cache.fbb, in0, in1, output); + return ::tt::target::ttnn::CreateEmbeddingOp( + *cache.fbb, in0, in1, output, cache.fbb->CreateString(debugString)); } template ::flatbuffers::Offset<::tt::target::ttnn::ReshapeOp> -createReshapeOp(FlatbufferObjectCache &cache, ReshapeOp op) { +createReshapeOp(FlatbufferObjectCache &cache, ReshapeOp op, + std::string const &debugString) { auto in = cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput())); auto shape = @@ -472,12 +502,14 @@ createReshapeOp(FlatbufferObjectCache &cache, ReshapeOp op) { auto out = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer, kHostAllocatedAddress, kHostAllocatedSize); - return ::tt::target::ttnn::CreateReshapeOp(*cache.fbb, in, out, shape); + return ::tt::target::ttnn::CreateReshapeOp( + *cache.fbb, in, out, shape, cache.fbb->CreateString(debugString)); } template ::flatbuffers::Offset<::tt::target::ttnn::SliceOp> -createSliceOp(FlatbufferObjectCache &cache, SliceOp op) { +createSliceOp(FlatbufferObjectCache &cache, SliceOp op, + std::string const &debugString) { auto in = cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput())); auto out = cache.at<::tt::target::TensorRef>( @@ -489,13 +521,15 @@ createSliceOp(FlatbufferObjectCache &cache, SliceOp op) { auto step = arrayAttrToFlatbuffer(cache, op.getStep()); - return ::tt::target::ttnn::CreateSliceOp(*cache.fbb, in, out, begins, ends, - step); + return ::tt::target::ttnn::CreateSliceOp( + *cache.fbb, in, out, begins, ends, step, + cache.fbb->CreateString(debugString)); } template ::flatbuffers::Offset<::tt::target::ttnn::MaxPool2dOp> 
-createMaxPool2dOp(FlatbufferObjectCache &cache, MaxPool2dOp op) { +createMaxPool2dOp(FlatbufferObjectCache &cache, MaxPool2dOp op, + std::string const &debugString) { auto in = cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput())); auto out = cache.at<::tt::target::TensorRef>( @@ -508,216 +542,261 @@ createMaxPool2dOp(FlatbufferObjectCache &cache, MaxPool2dOp op) { op.getChannels(), op.getKernelHeight(), op.getKernelWidth(), op.getStrideHeight(), op.getStrideWidth(), op.getDilationHeight(), op.getDilationWidth(), op.getCeilMode(), op.getPaddingHeight(), - op.getPaddingWidth()); + op.getPaddingWidth(), cache.fbb->CreateString(debugString)); } template ::flatbuffers::Offset<::tt::target::ttnn::SoftmaxOp> -createSoftmaxOp(FlatbufferObjectCache &cache, SoftmaxOp op) { +createSoftmaxOp(FlatbufferObjectCache &cache, SoftmaxOp op, + std::string const &debugString) { auto in = cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput())); auto out = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer, kHostAllocatedAddress, kHostAllocatedSize); int32_t dimension = op.getDimension(); - return ::tt::target::ttnn::CreateSoftmaxOp(*cache.fbb, in, out, dimension); + return ::tt::target::ttnn::CreateSoftmaxOp( + *cache.fbb, in, out, dimension, cache.fbb->CreateString(debugString)); } template ::flatbuffers::Offset<::tt::target::ttnn::DeallocOp> -createDeallocOp(FlatbufferObjectCache &cache, DeallocOp op) { +createDeallocOp(FlatbufferObjectCache &cache, DeallocOp op, + std::string const &debugString) { auto in = cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput())); - return ::tt::target::ttnn::CreateDeallocOp(*cache.fbb, in); + return ::tt::target::ttnn::CreateDeallocOp( + *cache.fbb, in, cache.fbb->CreateString(debugString)); } ::flatbuffers::Offset<::tt::target::ttnn::Operation> emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op, std::string const &debugString) { if (auto getDeviceOp = dyn_cast(op); getDeviceOp) { - return createOperation(cache, createOp(cache, getDeviceOp), debugString); + return createOperation(cache, createOp(cache, getDeviceOp, debugString), + debugString); } if (auto toMemoryConfigOp = dyn_cast(op); toMemoryConfigOp) { - return createOperation(cache, createOp(cache, toMemoryConfigOp), - debugString); + return createOperation( + cache, createOp(cache, toMemoryConfigOp, debugString), debugString); } if (auto toLayoutOp = dyn_cast(op); toLayoutOp) { - return createOperation(cache, createOp(cache, toLayoutOp), debugString); + return createOperation(cache, createOp(cache, toLayoutOp, debugString), + debugString); } if (auto typecastOp = dyn_cast(op); typecastOp) { - return createOperation(cache, createOp(cache, typecastOp), debugString); + return createOperation(cache, createOp(cache, typecastOp, debugString), + debugString); } if (auto toDeviceOp = dyn_cast(op); toDeviceOp) { - return createOperation(cache, createOp(cache, toDeviceOp), debugString); + return createOperation(cache, createOp(cache, toDeviceOp, debugString), + debugString); } if (auto fromDeviceOp = dyn_cast(op); fromDeviceOp) { - return createOperation(cache, createOp(cache, fromDeviceOp), debugString); + return createOperation(cache, createOp(cache, fromDeviceOp, debugString), + debugString); } if (auto emptyOp = dyn_cast(op); emptyOp) { - return createOperation(cache, createOp(cache, emptyOp), debugString); + return createOperation(cache, createOp(cache, emptyOp, debugString), + debugString); } if (auto fullOp = dyn_cast(op); fullOp) { - return 
createOperation(cache, createOp(cache, fullOp), debugString); + return createOperation(cache, createOp(cache, fullOp, debugString), + debugString); } if (auto absOp = dyn_cast(op); absOp) { - return createOperation(cache, createEltwiseOp(cache, absOp), debugString); + return createOperation(cache, createEltwiseOp(cache, absOp, debugString), + debugString); } if (auto addOp = dyn_cast(op); addOp) { - return createOperation(cache, createEltwiseOp(cache, addOp), debugString); + return createOperation(cache, createEltwiseOp(cache, addOp, debugString), + debugString); } if (auto floorOp = dyn_cast(op); floorOp) { - return createOperation(cache, createEltwiseOp(cache, floorOp), debugString); + return createOperation(cache, createEltwiseOp(cache, floorOp, debugString), + debugString); } if (auto isFiniteOp = dyn_cast(op); isFiniteOp) { - return createOperation(cache, createEltwiseOp(cache, isFiniteOp), - debugString); + return createOperation( + cache, createEltwiseOp(cache, isFiniteOp, debugString), debugString); } if (auto andOp = dyn_cast(op); andOp) { - return createOperation(cache, createEltwiseOp(cache, andOp), debugString); + return createOperation(cache, createEltwiseOp(cache, andOp, debugString), + debugString); } if (auto cbrtOp = dyn_cast(op); cbrtOp) { - return createOperation(cache, createEltwiseOp(cache, cbrtOp), debugString); + return createOperation(cache, createEltwiseOp(cache, cbrtOp, debugString), + debugString); } if (auto notOp = dyn_cast(op); notOp) { - return createOperation(cache, createEltwiseOp(cache, notOp), debugString); + return createOperation(cache, createEltwiseOp(cache, notOp, debugString), + debugString); } if (auto orOp = dyn_cast(op); orOp) { - return createOperation(cache, createEltwiseOp(cache, orOp), debugString); + return createOperation(cache, createEltwiseOp(cache, orOp, debugString), + debugString); } if (auto multiplyOp = dyn_cast(op); multiplyOp) { - return createOperation(cache, createEltwiseOp(cache, multiplyOp), - debugString); + return createOperation( + cache, createEltwiseOp(cache, multiplyOp, debugString), debugString); } if (auto negOp = dyn_cast(op); negOp) { - return createOperation(cache, createEltwiseOp(cache, negOp), debugString); + return createOperation(cache, createEltwiseOp(cache, negOp, debugString), + debugString); } if (auto subtractOp = dyn_cast(op); subtractOp) { - return createOperation(cache, createEltwiseOp(cache, subtractOp), - debugString); + return createOperation( + cache, createEltwiseOp(cache, subtractOp, debugString), debugString); } if (auto eqOp = dyn_cast(op); eqOp) { - return createOperation(cache, createEltwiseOp(cache, eqOp), debugString); + return createOperation(cache, createEltwiseOp(cache, eqOp, debugString), + debugString); } if (auto neOp = dyn_cast(op); neOp) { - return createOperation(cache, createEltwiseOp(cache, neOp), debugString); + return createOperation(cache, createEltwiseOp(cache, neOp, debugString), + debugString); } if (auto geOp = dyn_cast(op); geOp) { - return createOperation(cache, createEltwiseOp(cache, geOp), debugString); + return createOperation(cache, createEltwiseOp(cache, geOp, debugString), + debugString); } if (auto gtOp = dyn_cast(op); gtOp) { - return createOperation(cache, createEltwiseOp(cache, gtOp), debugString); + return createOperation(cache, createEltwiseOp(cache, gtOp, debugString), + debugString); } if (auto leOp = dyn_cast(op); leOp) { - return createOperation(cache, createEltwiseOp(cache, leOp), debugString); + return createOperation(cache, createEltwiseOp(cache, leOp, 
debugString), + debugString); } if (auto ltOp = dyn_cast(op); ltOp) { - return createOperation(cache, createEltwiseOp(cache, ltOp), debugString); + return createOperation(cache, createEltwiseOp(cache, ltOp, debugString), + debugString); } if (auto maximumOp = dyn_cast(op); maximumOp) { - return createOperation(cache, createEltwiseOp(cache, maximumOp), - debugString); + return createOperation( + cache, createEltwiseOp(cache, maximumOp, debugString), debugString); } if (auto minimumOp = dyn_cast(op); minimumOp) { - return createOperation(cache, createEltwiseOp(cache, minimumOp), - debugString); + return createOperation( + cache, createEltwiseOp(cache, minimumOp, debugString), debugString); } if (auto reluOp = dyn_cast(op); reluOp) { - return createOperation(cache, createEltwiseOp(cache, reluOp), debugString); + return createOperation(cache, createEltwiseOp(cache, reluOp, debugString), + debugString); } if (auto sqrtOp = dyn_cast(op); sqrtOp) { - return createOperation(cache, createEltwiseOp(cache, sqrtOp), debugString); + return createOperation(cache, createEltwiseOp(cache, sqrtOp, debugString), + debugString); } if (auto rsqrtOp = dyn_cast(op); rsqrtOp) { - return createOperation(cache, createEltwiseOp(cache, rsqrtOp), debugString); + return createOperation(cache, createEltwiseOp(cache, rsqrtOp, debugString), + debugString); } if (auto signOp = dyn_cast(op); signOp) { - return createOperation(cache, createEltwiseOp(cache, signOp), debugString); + return createOperation(cache, createEltwiseOp(cache, signOp, debugString), + debugString); } if (auto expOp = dyn_cast(op); expOp) { - return createOperation(cache, createEltwiseOp(cache, expOp), debugString); + return createOperation(cache, createEltwiseOp(cache, expOp, debugString), + debugString); } if (auto logOp = dyn_cast(op); logOp) { - return createOperation(cache, createEltwiseOp(cache, logOp), debugString); + return createOperation(cache, createEltwiseOp(cache, logOp, debugString), + debugString); } if (auto expm1Op = dyn_cast(op); expm1Op) { - return createOperation(cache, createEltwiseOp(cache, expm1Op), debugString); + return createOperation(cache, createEltwiseOp(cache, expm1Op, debugString), + debugString); } if (auto sigmoidOp = dyn_cast(op); sigmoidOp) { - return createOperation(cache, createEltwiseOp(cache, sigmoidOp), - debugString); + return createOperation( + cache, createEltwiseOp(cache, sigmoidOp, debugString), debugString); } if (auto log1pOp = dyn_cast(op); log1pOp) { - return createOperation(cache, createEltwiseOp(cache, log1pOp), debugString); + return createOperation(cache, createEltwiseOp(cache, log1pOp, debugString), + debugString); } if (auto reciprocalOp = dyn_cast(op); reciprocalOp) { - return createOperation(cache, createEltwiseOp(cache, reciprocalOp), - debugString); + return createOperation( + cache, createEltwiseOp(cache, reciprocalOp, debugString), debugString); } if (auto divOp = dyn_cast(op); divOp) { - return createOperation(cache, createEltwiseOp(cache, divOp), debugString); + return createOperation(cache, createEltwiseOp(cache, divOp, debugString), + debugString); } if (auto remainderOp = dyn_cast(op); remainderOp) { - return createOperation(cache, createEltwiseOp(cache, remainderOp), - debugString); + return createOperation( + cache, createEltwiseOp(cache, remainderOp, debugString), debugString); } if (auto matmulOp = dyn_cast(op); matmulOp) { - return createOperation(cache, createOp(cache, matmulOp), debugString); + return createOperation(cache, createOp(cache, matmulOp, debugString), + debugString); } 
if (auto sumOp = dyn_cast(op); sumOp) { - return createOperation(cache, createReductionOp(cache, sumOp), debugString); + return createOperation(cache, createReductionOp(cache, sumOp, debugString), + debugString); } if (auto meanOp = dyn_cast(op); meanOp) { - return createOperation(cache, createReductionOp(cache, meanOp), + return createOperation(cache, createReductionOp(cache, meanOp, debugString), debugString); } if (auto maxOp = dyn_cast(op); maxOp) { - return createOperation(cache, createReductionOp(cache, maxOp), debugString); + return createOperation(cache, createReductionOp(cache, maxOp, debugString), + debugString); } if (auto embeddingOp = dyn_cast(op); embeddingOp) { - return createOperation(cache, createEmbeddingOp(cache, embeddingOp), - debugString); + return createOperation( + cache, createEmbeddingOp(cache, embeddingOp, debugString), debugString); } if (auto softmaxOp = dyn_cast(op); softmaxOp) { - return createOperation(cache, createSoftmaxOp(cache, softmaxOp), - debugString); + return createOperation( + cache, createSoftmaxOp(cache, softmaxOp, debugString), debugString); } if (auto transposeOp = dyn_cast(op); transposeOp) { - return createOperation(cache, createTransposeOp(cache, transposeOp), - debugString); + return createOperation( + cache, createTransposeOp(cache, transposeOp, debugString), debugString); } if (auto conv2dOp = dyn_cast(op); conv2dOp) { - return createOperation(cache, createOp(cache, conv2dOp), debugString); + return createOperation(cache, createOp(cache, conv2dOp, debugString), + debugString); } if (auto allGatherOp = dyn_cast(op); allGatherOp) { - return createOperation(cache, createOp(cache, allGatherOp), debugString); + return createOperation(cache, createOp(cache, allGatherOp, debugString), + debugString); } if (auto concatOp = dyn_cast(op); concatOp) { - return createOperation(cache, createConcatOp(cache, concatOp), debugString); + return createOperation(cache, createConcatOp(cache, concatOp, debugString), + debugString); } if (auto reshapeOp = dyn_cast(op); reshapeOp) { - return createOperation(cache, createReshapeOp(cache, reshapeOp), - debugString); + return createOperation( + cache, createReshapeOp(cache, reshapeOp, debugString), debugString); } if (auto sliceOp = dyn_cast(op); sliceOp) { - return createOperation(cache, createSliceOp(cache, sliceOp), debugString); + return createOperation(cache, createSliceOp(cache, sliceOp, debugString), + debugString); } if (auto max_pool2dOp = dyn_cast(op); max_pool2dOp) { - return createOperation(cache, createMaxPool2dOp(cache, max_pool2dOp), + return createOperation(cache, + createMaxPool2dOp(cache, max_pool2dOp, debugString), debugString); } if (auto deallocOp = dyn_cast(op); deallocOp) { - return createOperation(cache, createDeallocOp(cache, deallocOp), - debugString); + return createOperation( + cache, createDeallocOp(cache, deallocOp, debugString), debugString); } if (auto ceilOp = dyn_cast(op); ceilOp) { - return createOperation(cache, createEltwiseOp(cache, ceilOp), debugString); + return createOperation(cache, createEltwiseOp(cache, ceilOp, debugString), + debugString); } if (auto cosOp = dyn_cast(op); cosOp) { - return createOperation(cache, createEltwiseOp(cache, cosOp), debugString); + return createOperation(cache, createEltwiseOp(cache, cosOp, debugString), + debugString); } if (auto sinOp = dyn_cast(op); sinOp) { - return createOperation(cache, createEltwiseOp(cache, sinOp), debugString); + return createOperation(cache, createEltwiseOp(cache, sinOp, debugString), + debugString); } 
llvm_unreachable("unhandled op in emitTTNNOperation"); diff --git a/runtime/include/tt/runtime/detail/debug.h b/runtime/include/tt/runtime/detail/debug.h index c5d84c4d98..eef43b0128 100644 --- a/runtime/include/tt/runtime/detail/debug.h +++ b/runtime/include/tt/runtime/detail/debug.h @@ -5,6 +5,8 @@ #ifndef TT_RUNTIME_DETAIL_DEBUG_H #define TT_RUNTIME_DETAIL_DEBUG_H +#include +#include #include namespace tt::runtime::debug { @@ -41,6 +43,49 @@ inline std::ostream &operator<<(std::ostream &os, Env const &env) { return os; } +struct Hooks { +#if defined(TT_RUNTIME_DEBUG) && TT_RUNTIME_DEBUG == 1 + static Hooks const & + get(std::optional, + std::optional)>> + operatorCallback = std::nullopt); +#else + constexpr static Hooks get() { return Hooks(); } +#endif + + std::optional, + std::optional)>> + getOperatorCallback() const { +#if defined(TT_RUNTIME_DEBUG) && TT_RUNTIME_DEBUG == 1 + return operatorCallback; +#else + return std::nullopt; +#endif + } + +private: +#if defined(TT_RUNTIME_DEBUG) && TT_RUNTIME_DEBUG == 1 + Hooks(std::optional, + std::optional)>> + operatorCallback) + : operatorCallback(operatorCallback) {} + + std::optional, + std::optional)>> + operatorCallback; +#else + constexpr Hooks() = default; +#endif +}; + +inline std::ostream &operator<<(std::ostream &os, Hooks const &hooks) { + os << "debug::Hooks{\n" + << "\t" + << "operatorCallback: " << bool(hooks.getOperatorCallback()) << ",\n" + << "}"; + return os; +} + } // namespace tt::runtime::debug #endif // TT_RUNTIME_DETAIL_DEBUG_H diff --git a/runtime/include/tt/runtime/runtime.h b/runtime/include/tt/runtime/runtime.h index a070f2f0f5..43bf4e0d6f 100644 --- a/runtime/include/tt/runtime/runtime.h +++ b/runtime/include/tt/runtime/runtime.h @@ -69,6 +69,10 @@ Event submit(Device device, Binary executable, std::uint32_t programIndex, void wait(Event event); +std::vector getOpOutputTensor(const void *context, + const void *opContext); +std::string getOpDebugString(const void *context, const void *opContext); + } // namespace tt::runtime #endif diff --git a/runtime/lib/common/debug.cpp b/runtime/lib/common/debug.cpp index f075177642..e3491e6793 100644 --- a/runtime/lib/common/debug.cpp +++ b/runtime/lib/common/debug.cpp @@ -13,6 +13,18 @@ Env const &Env::get(bool loadKernelsFromDisk, bool enableAsyncTTNN) { return config; } +#if defined(TT_RUNTIME_DEBUG) && TT_RUNTIME_DEBUG == 1 +Hooks const & +Hooks::get(std::optional, + std::optional)>> + operatorCallback) { + static Hooks config(operatorCallback); + return config; +} +#else +Hooks get() { return Hooks(); } +#endif + } // namespace tt::runtime::debug #endif diff --git a/runtime/lib/runtime.cpp b/runtime/lib/runtime.cpp index 8b0e79daab..227a2cae1f 100644 --- a/runtime/lib/runtime.cpp +++ b/runtime/lib/runtime.cpp @@ -10,6 +10,7 @@ #if defined(TT_RUNTIME_ENABLE_TTNN) #include "tt/runtime/detail/ttnn.h" +#include "tt/runtime/ttnn/types.h" #endif #if defined(TT_RUNTIME_ENABLE_TTMETAL) @@ -245,4 +246,152 @@ void wait(Event event) { throw std::runtime_error("runtime is not enabled"); } +#if defined(TT_RUNTIME_ENABLE_TTNN) +std::vector getOpOutputTensor(const void *context, + const void *opContext) { + auto *contextPtr = static_cast(context); + auto *opContextPtr = + static_cast(opContext); + const ::ttnn::Tensor *outPtr = nullptr; + const ttnn::ProgramTensorPool &tensorPool = contextPtr->getTensorPool(); + std::uint32_t globalId; + + switch (opContextPtr->type_type()) { + case ::tt::target::ttnn::OpType::GetDeviceOp: { + globalId = 
opContextPtr->type_as_GetDeviceOp()->out()->global_id(); + break; + } + case ::tt::target::ttnn::OpType::ToMemoryConfigOp: { + globalId = opContextPtr->type_as_ToMemoryConfigOp()->out()->global_id(); + ; + break; + } + case ::tt::target::ttnn::OpType::ToLayoutOp: { + globalId = opContextPtr->type_as_ToLayoutOp()->out()->global_id(); + ; + break; + } + case ::tt::target::ttnn::OpType::TypecastOp: { + globalId = opContextPtr->type_as_TypecastOp()->out()->global_id(); + ; + break; + } + case ::tt::target::ttnn::OpType::ToDeviceOp: { + globalId = opContextPtr->type_as_ToDeviceOp()->out()->global_id(); + ; + break; + } + case ::tt::target::ttnn::OpType::FromDeviceOp: { + globalId = opContextPtr->type_as_FromDeviceOp()->out()->global_id(); + ; + break; + } + case ::tt::target::ttnn::OpType::EmptyOp: { + globalId = opContextPtr->type_as_EmptyOp()->out()->global_id(); + ; + break; + } + case ::tt::target::ttnn::OpType::FullOp: { + globalId = opContextPtr->type_as_FullOp()->out()->global_id(); + ; + break; + } + case ::tt::target::ttnn::OpType::EltwiseOp: { + globalId = opContextPtr->type_as_EltwiseOp()->out()->global_id(); + ; + break; + } + case ::tt::target::ttnn::OpType::MatmulOp: { + globalId = opContextPtr->type_as_MatmulOp()->out()->global_id(); + ; + break; + } + case ::tt::target::ttnn::OpType::ReductionOp: { + globalId = opContextPtr->type_as_ReductionOp()->out()->global_id(); + ; + break; + } + case ::tt::target::ttnn::OpType::EmbeddingOp: { + globalId = opContextPtr->type_as_EmbeddingOp()->out()->global_id(); + ; + break; + } + case ::tt::target::ttnn::OpType::SoftmaxOp: { + globalId = opContextPtr->type_as_SoftmaxOp()->out()->global_id(); + ; + break; + } + case ::tt::target::ttnn::OpType::TransposeOp: { + globalId = opContextPtr->type_as_TransposeOp()->out()->global_id(); + ; + break; + } + case ::tt::target::ttnn::OpType::ConcatOp: { + globalId = opContextPtr->type_as_ConcatOp()->out()->global_id(); + ; + break; + } + case ::tt::target::ttnn::OpType::ReshapeOp: { + globalId = opContextPtr->type_as_ReshapeOp()->out()->global_id(); + ; + break; + } + case ::tt::target::ttnn::OpType::SliceOp: { + globalId = opContextPtr->type_as_SliceOp()->out()->global_id(); + ; + break; + } + case ::tt::target::ttnn::OpType::Conv2dOp: { + globalId = opContextPtr->type_as_Conv2dOp()->out()->global_id(); + ; + break; + } + case ::tt::target::ttnn::OpType::DeallocOp: { + LOG_WARNING("getting output tensor for DeallocOp is not supported"); + return {}; + } + case ::tt::target::ttnn::OpType::MaxPool2dOp: { + globalId = opContextPtr->type_as_MaxPool2dOp()->out()->global_id(); + ; + break; + } + case ::tt::target::ttnn::OpType::AllGatherOp: { + globalId = opContextPtr->type_as_AllGatherOp()->out()->global_id(); + ; + break; + } + default: { + throw std::runtime_error("Unsupported operation type"); + } + } + + if (tensorPool.contains(globalId)) { + outPtr = &tensorPool.at(globalId); + } else { + LOG_WARNING("Output tensor not found in tensor pool"); + return {}; + } + ::ttnn::Tensor hostTensor = ::ttnn::from_device(*outPtr); + ::ttnn::Tensor outCopy = + ::ttnn::to_layout(hostTensor, ::ttnn::ROW_MAJOR_LAYOUT, std::nullopt, + std::nullopt, static_cast<::ttnn::Device *>(nullptr)); + std::uint32_t outCopySize = outCopy.volume() * outCopy.element_size(); + void *src = ::tt::tt_metal::get_raw_host_data_ptr(outCopy); + void *dst = malloc(outCopySize); + std::memcpy(dst, src, outCopySize); + std::vector outVec(static_cast(dst), + static_cast(dst) + outCopy.volume()); + + return outVec; +} +#endif + +#if 
defined(TT_RUNTIME_ENABLE_TTNN) +std::string getOpDebugString(const void *context, const void *opContext) { + auto *opContextPtr = + static_cast(opContext); + return std::string(opContextPtr->debug_info()->c_str()); +} +#endif + } // namespace tt::runtime diff --git a/runtime/lib/ttnn/include/tt/runtime/ttnn/types.h b/runtime/lib/ttnn/include/tt/runtime/ttnn/types.h index e59bb66d60..4698cc0bac 100644 --- a/runtime/lib/ttnn/include/tt/runtime/ttnn/types.h +++ b/runtime/lib/ttnn/include/tt/runtime/ttnn/types.h @@ -46,6 +46,11 @@ class ProgramTensorPool { return *liveTensors.at(globalId); } + const ::ttnn::Tensor &at(std::uint32_t globalId) const { + assert(liveTensors.contains(globalId)); + return *liveTensors.at(globalId); + } + size_t erase(std::uint32_t globalId) { assert(liveTensors.contains(globalId) && intermedTensors.contains(globalId)); @@ -161,6 +166,7 @@ class ProgramContext { // Tensor Pool Operations // ProgramTensorPool &getTensorPool() { return tensorPool; } + const ProgramTensorPool &getTensorPool() const { return tensorPool; } private: ProgramTensorPool tensorPool; diff --git a/runtime/lib/ttnn/program.cpp b/runtime/lib/ttnn/program.cpp index af1b28d990..214b7bf048 100644 --- a/runtime/lib/ttnn/program.cpp +++ b/runtime/lib/ttnn/program.cpp @@ -25,6 +25,7 @@ #include "operations/normalization/softmax.h" #include "operations/pool/maxpool2d.h" #include "operations/reduction/reduction.h" +#include "tt/runtime/detail/debug.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/ttnn/types.h" #include "ttmlir/Target/TTNN/program_generated.h" @@ -87,75 +88,121 @@ void ProgramExecutor::runEltwiseOperation( void ProgramExecutor::runOperation(const ::tt::target::ttnn::Operation *op) { switch (op->type_type()) { case ::tt::target::ttnn::OpType::GetDeviceOp: { - return operations::context::run(op->type_as_GetDeviceOp(), context); + auto childOp = op->type_as_GetDeviceOp(); + operations::context::run(childOp, context); + break; } case ::tt::target::ttnn::OpType::ToMemoryConfigOp: { - return operations::layout::run(op->type_as_ToMemoryConfigOp(), context); + auto childOp = op->type_as_ToMemoryConfigOp(); + operations::layout::run(childOp, context); + break; } case ::tt::target::ttnn::OpType::ToLayoutOp: { - return operations::layout::run(op->type_as_ToLayoutOp(), context); + auto childOp = op->type_as_ToLayoutOp(); + operations::layout::run(childOp, context); + break; } case ::tt::target::ttnn::OpType::TypecastOp: { - return operations::layout::run(op->type_as_TypecastOp(), context); + auto childOp = op->type_as_TypecastOp(); + operations::layout::run(childOp, context); + break; } case ::tt::target::ttnn::OpType::ToDeviceOp: { - return operations::layout::run(op->type_as_ToDeviceOp(), context); + auto childOp = op->type_as_ToDeviceOp(); + operations::layout::run(childOp, context); + break; } case ::tt::target::ttnn::OpType::FromDeviceOp: { - return operations::layout::run(op->type_as_FromDeviceOp(), context); + auto childOp = op->type_as_FromDeviceOp(); + operations::layout::run(childOp, context); + break; } case ::tt::target::ttnn::OpType::EmptyOp: { - return operations::creation::run(op->type_as_EmptyOp(), context); + auto childOp = op->type_as_EmptyOp(); + operations::creation::run(childOp, context); + break; } case ::tt::target::ttnn::OpType::FullOp: { - return operations::creation::run(op->type_as_FullOp(), context); + auto childOp = op->type_as_FullOp(); + operations::creation::run(childOp, context); + break; } case ::tt::target::ttnn::OpType::EltwiseOp: { - const 
::tt::target::ttnn::EltwiseOp *eltwiseOp = op->type_as_EltwiseOp(); - return runEltwiseOperation(eltwiseOp); + auto childOp = op->type_as_EltwiseOp(); + runEltwiseOperation(childOp); + break; } // ANCHOR: adding_an_op_matmul_runtime_program case ::tt::target::ttnn::OpType::MatmulOp: { - return operations::matmul::run(op->type_as_MatmulOp(), context); + auto childOp = op->type_as_MatmulOp(); + operations::matmul::run(childOp, context); + break; } // ANCHOR_END: adding_an_op_matmul_runtime_program case ::tt::target::ttnn::OpType::ReductionOp: { - return operations::reduction::run(op->type_as_ReductionOp(), context); + auto childOp = op->type_as_ReductionOp(); + operations::reduction::run(childOp, context); + break; } case ::tt::target::ttnn::OpType::EmbeddingOp: { - return operations::embedding::run(op->type_as_EmbeddingOp(), context); + auto childOp = op->type_as_EmbeddingOp(); + operations::embedding::run(childOp, context); + break; } case ::tt::target::ttnn::OpType::SoftmaxOp: { - return operations::normalization::run(op->type_as_SoftmaxOp(), context); + auto childOp = op->type_as_SoftmaxOp(); + operations::normalization::run(childOp, context); + break; } case ::tt::target::ttnn::OpType::TransposeOp: { - return operations::data_movement::run(op->type_as_TransposeOp(), context); + auto childOp = op->type_as_TransposeOp(); + operations::data_movement::run(childOp, context); + break; } case ::tt::target::ttnn::OpType::ConcatOp: { - return operations::data_movement::run(op->type_as_ConcatOp(), context); + auto childOp = op->type_as_ConcatOp(); + operations::data_movement::run(childOp, context); + break; } case ::tt::target::ttnn::OpType::ReshapeOp: { - return operations::data_movement::run(op->type_as_ReshapeOp(), context); + auto childOp = op->type_as_ReshapeOp(); + operations::data_movement::run(childOp, context); + break; } case ::tt::target::ttnn::OpType::SliceOp: { - return operations::data_movement::run(op->type_as_SliceOp(), context); + auto childOp = op->type_as_SliceOp(); + operations::data_movement::run(childOp, context); + break; } case ::tt::target::ttnn::OpType::Conv2dOp: { - return operations::conv::run(op->type_as_Conv2dOp(), context); + auto childOp = op->type_as_Conv2dOp(); + operations::conv::run(childOp, context); + break; } case ::tt::target::ttnn::OpType::DeallocOp: { - return operations::deletion::run(op->type_as_DeallocOp(), context); + auto childOp = op->type_as_DeallocOp(); + return operations::deletion::run(childOp, context); + break; } case ::tt::target::ttnn::OpType::MaxPool2dOp: { - return operations::pool::run(op->type_as_MaxPool2dOp(), context); + auto childOp = op->type_as_MaxPool2dOp(); + operations::pool::run(childOp, context); + break; } case ::tt::target::ttnn::OpType::AllGatherOp: { - return operations::ccl::run(op->type_as_AllGatherOp(), context); + auto childOp = op->type_as_AllGatherOp(); + operations::ccl::run(childOp, context); + break; } default: { throw std::runtime_error("Unsupported operation type"); } } + + if (auto callback = debug::Hooks::get().getOperatorCallback(); callback) { + (*callback)(static_cast(&context), + static_cast(op)); + } } // Nop is single input, output tensor where input is returned as output. 
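Note on the executor change above: each op case now falls through to a common tail that invokes the optional operator callback registered via tt::runtime::debug::Hooks, passing the ProgramContext and the just-executed flatbuffer Operation as opaque pointers. A minimal sketch of a conforming callback, written against the Python bindings added later in this patch; the printed fields are illustrative only and the callback name is an assumption:

import ttrt.runtime

def op_callback(context, op_context):
    # context:    opaque ProgramContext handle forwarded by the executor
    # op_context: opaque ::tt::target::ttnn::Operation that just ran
    debug_str = ttrt.runtime.get_op_debug_str(context, op_context)
    # Empty when no output tensor is available (e.g. DeallocOp).
    values = ttrt.runtime.get_op_output_tensor(context, op_context)
    print(f"ran {debug_str} -> {len(values)} output elements")

# Register once; in TT_RUNTIME_DEBUG builds the hook object is a process-wide singleton.
ttrt.runtime.DebugHooks.get(op_callback)
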
diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp index b06ae893aa..f89d891046 100644 --- a/runtime/lib/ttnn/runtime.cpp +++ b/runtime/lib/ttnn/runtime.cpp @@ -5,6 +5,7 @@ #include "tt/runtime/detail/debug.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" +#include "tt/runtime/ttnn/types.h" #include "tt/runtime/ttnn/utils.h" #include "tt/runtime/utils.h" #include "ttmlir/Target/TTNN/Target.h" diff --git a/runtime/tools/python/ttrt/common/callback.py b/runtime/tools/python/ttrt/common/callback.py new file mode 100644 index 0000000000..58319c8377 --- /dev/null +++ b/runtime/tools/python/ttrt/common/callback.py @@ -0,0 +1,157 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +import os +import json +import importlib.machinery +import sys +import signal +import os +import io +import subprocess +import time +import socket +from pkg_resources import get_distribution +import shutil +import atexit +import re + +from ttrt.common.util import * + +GOLDENS = {} + + +def get_atol_rtol_pcc(golden, calculated): + import numpy as np + import torch + + # Calculate atol and rtol + cal_atol = torch.max(torch.abs(golden - calculated)).item() + cal_rtol = torch.max(torch.abs(golden - calculated) / torch.abs(calculated)).item() + + # Calculate PCC + def get_pcc(golden, calculated): + # Both tensors are nan + if torch.all(torch.isnan(golden)) and torch.all(torch.isnan(calculated)): + print("Both tensors are 'nan'") + return 1.0 + # One tensor is all nan, the other is not + elif torch.all(torch.isnan(golden)) or torch.all(torch.isnan(calculated)): + print("One tensor is all nan, the other is not.") + return 0.0 + else: + # For now, mask all infs and nans so that we check the rest... 
TODO + golden = golden.clone() + golden[ + torch.logical_or( + torch.isnan(golden), + torch.logical_or(torch.isinf(golden), torch.isneginf(golden)), + ) + ] = 0 + calculated = calculated.clone() + calculated[ + torch.logical_or( + torch.isnan(calculated), + torch.logical_or( + torch.isinf(calculated), torch.isneginf(calculated) + ), + ) + ] = 0 + + if torch.equal(golden, calculated): + return 1.0 + + if golden.dtype == torch.bfloat16: + golden = golden.type(torch.float32) + calculated = calculated.type(torch.float32) + + # Single element case + if golden.numel() == 1: + return float(torch.equal(golden, calculated)) + + # If both tensors are contant + if torch.max(golden) == torch.min(golden) and torch.max( + calculated + ) == torch.min(calculated): + return torch.isclose(torch.max(golden), torch.max(calculated)).item() + + cal_pcc = np.ma.corrcoef( + np.ma.masked_invalid(torch.squeeze(golden).detach().numpy()).flatten(), + np.ma.masked_invalid( + torch.squeeze(calculated).detach().numpy() + ).flatten(), + ) + # Remove correlation coefficient with self (typically always 1.0) + mask = np.ones(cal_pcc.shape, dtype=bool) + np.fill_diagonal(mask, 0) + cal_pcc = np.min(cal_pcc[mask]) + + if isinstance(cal_pcc, np.ma.core.MaskedConstant): + return 1.0 + + return cal_pcc + + cal_pcc = get_pcc(golden, calculated) + + return ( + cal_atol, + cal_rtol, + cal_pcc, + f"Max ATOL Delta: {cal_atol}, Max RTOL Delta: {cal_rtol}, PCC: {cal_pcc}", + ) + + +def add_global_golden(golden_tensor): + global GOLDENS + GOLDENS[golden_tensor.tensor_id] = golden_tensor.get_torch_tensor() + + +def golden(context=None, opContext=None): + import torch + import ttrt.runtime + + print("-----------executing golden comparision-----------") + + try: + device_tensor = ttrt.runtime.get_op_output_tensor(context, opContext) + op_debug_str = ttrt.runtime.get_op_debug_str(context, opContext) + + if device_tensor == None or len(device_tensor) == 0: + print("No device tensor provided for golden comparison") + return + elif op_debug_str == None or op_debug_str == "": + print("No debug string provided for golden comparison") + return + else: + # find matching golden tensor based on loc in op debug string + match = re.search(r"loc\(([^)]+)\)", op_debug_str) + + if not match: + print(f"debug_str={op_debug_str}") + print("No location found in debug string - skipping golden comparison") + return + + loc = match.group(1).replace('"', "") + print(f"found location={loc}") + + if loc not in GOLDENS.keys(): + print( + f"No golden tensor found for loc={loc} in golden cache - skipping golden comparison" + ) + return + + golden_torch_tensor = GOLDENS[loc].flatten() + device_tensor_torch = torch.tensor(device_tensor, dtype=torch.float32) + _, _, cal_pcc, output_str = get_atol_rtol_pcc( + golden_torch_tensor, device_tensor_torch + ) + + print(f"PCC={cal_pcc}") + print(output_str) + finally: + print("-----------finished executing golden comparision-----------") + + +def pdb(): + print("right now pdb doesn't do anything") diff --git a/runtime/tools/python/ttrt/common/run.py b/runtime/tools/python/ttrt/common/run.py index 976779e5fe..b9f32a139e 100644 --- a/runtime/tools/python/ttrt/common/run.py +++ b/runtime/tools/python/ttrt/common/run.py @@ -18,6 +18,7 @@ from ttrt.common.util import * from ttrt.common.query import Query +from ttrt.common.callback import golden, pdb class Run: @@ -172,6 +173,13 @@ def initialize_api(): choices=None, help="test file to save results to", ) + Run.register_arg( + name="--golden", + type=bool, + default=False, + 
choices=[True, False], + help="run golden comparison for intermediate and output tensors", + ) Run.register_arg( name="binary", type=str, @@ -361,6 +369,9 @@ def _execute(binaries): self.logging.warning(f"no binaries found to run - returning early") return + if self["--golden"]: + callback_env = ttrt.runtime.DebugHooks.get(golden) + debug_env = ttrt.runtime.DebugEnv.get( self["--load-kernels-from-disk"], self["--enable-async-ttnn"] ) diff --git a/runtime/tools/python/ttrt/common/util.py b/runtime/tools/python/ttrt/common/util.py index ebbf1d6d72..5e9f7abe8e 100644 --- a/runtime/tools/python/ttrt/common/util.py +++ b/runtime/tools/python/ttrt/common/util.py @@ -17,6 +17,7 @@ import shutil import ttrt.binary +from ttrt.common.callback import add_global_golden # environment tweaks if "LOGGER_LEVEL" not in os.environ: @@ -522,20 +523,45 @@ def get_ttsys_file_extension(): return Flatbuffer.ttsys_file_extension -class Golden: - def __init__(self, tensor_id, tensor_shape, tensor_stride, tensor_data): - self.tensor_id = tensor_id - self.tensor_shape = tensor_shape - self.tensor_stride = tensor_stride - self.tensor_data = tensor_data +class GoldenMap: + def __init__(self): + self.golden_map = {} - def get_golden_tensor(self): - tensor_byte_data = bytes(self.tensor_data) - float_data = np.frombuffer(tensor_byte_data, dtype=np.float32) - golden_tensor = torch.tensor(float_data, dtype=torch.float32).reshape( - self.tensor_shape - ) - return golden_tensor + def add_golden(self, element): + self.golden_map[element.tensor_id] = element + + if not element.tensor_id.startswith("input"): + add_global_golden(element) + + def get_golden(self, tensor_id): + return self.golden_map[tensor_id] + + def get_inputs(self): + inputs = [] + + for i, tensor in self.golden_map.items(): + if i.startswith("input"): + inputs.append(tensor) + + return inputs + + class Golden: + def __init__(self, tensor_id, tensor_shape, tensor_stride, tensor_data): + self.tensor_id = tensor_id + self.tensor_shape = tensor_shape + self.tensor_stride = tensor_stride + self.tensor_data = tensor_data + + def get_torch_tensor(self): + import numpy as np + import torch + + tensor_byte_data = bytes(self.tensor_data) + float_data = np.frombuffer(tensor_byte_data, dtype=np.float32) + golden_tensor = torch.tensor(float_data, dtype=torch.float32).reshape( + self.tensor_shape + ) + return golden_tensor class Binary(Flatbuffer): @@ -557,20 +583,6 @@ def __init__(self, logger, file_manager, file_path, capsule=None): program = Binary.Program(i, self.fbb_dict["programs"][i]) self.programs.append(program) - # populate golden tensors if they exist - if "debug_info" in self.fbb_dict["programs"][i]: - golden_info_list = self.fbb_dict["programs"][i]["debug_info"][ - "golden_info" - ]["golden_map"] - - for golden_tensor_dict in golden_info_list: - Golden( - golden_tensor_dict["key"], - golden_tensor_dict["value"]["shape"], - golden_tensor_dict["value"]["stride"], - golden_tensor_dict["value"]["data"], - ) - def check_system_desc(self, query): import ttrt.binary @@ -615,16 +627,35 @@ def __init__(self, index, program): self.program = program self.input_tensors = [] self.output_tensors = [] + self.golden_map = GoldenMap() - def populate_inputs(self, init_fn): - for i in self.program["inputs"]: - torch_tensor = init_fn( - i["desc"]["shape"], - dtype=Binary.Program.from_data_type( - i["desc"]["layout"]["memory_desc"]["data_type"] - ), + # populate golden tensors if they exist + golden_info_list = self.program["debug_info"]["golden_info"]["golden_map"] + for 
golden_tensor_dict in golden_info_list: + golden_tensor = GoldenMap.Golden( + golden_tensor_dict["key"], + golden_tensor_dict["value"]["shape"], + golden_tensor_dict["value"]["stride"], + golden_tensor_dict["value"]["data"], ) - self.input_tensors.append(torch_tensor) + self.golden_map.add_golden(golden_tensor) + + def populate_inputs(self, init_fn): + inputs = self.golden_map.get_inputs() + + if len(inputs) != 0: + for tensor in inputs: + torch_tensor = tensor.get_torch_tensor() + self.input_tensors.append(torch_tensor) + else: + for i in self.program["inputs"]: + torch_tensor = init_fn( + i["desc"]["shape"], + dtype=Binary.Program.from_data_type( + i["desc"]["layout"]["memory_desc"]["data_type"] + ), + ) + self.input_tensors.append(torch_tensor) def populate_outputs(self, init_fn): for i in self.program["outputs"]: diff --git a/runtime/tools/python/ttrt/runtime/__init__.py b/runtime/tools/python/ttrt/runtime/__init__.py index 1a616db248..2bdb96609d 100644 --- a/runtime/tools/python/ttrt/runtime/__init__.py +++ b/runtime/tools/python/ttrt/runtime/__init__.py @@ -10,6 +10,7 @@ DataType, DeviceRuntime, DebugEnv, + DebugHooks, get_current_runtime, set_compatible_runtime, get_current_system_desc, @@ -18,6 +19,8 @@ submit, create_tensor, create_multi_device_tensor, + get_op_output_tensor, + get_op_debug_str, wait, WorkaroundEnv, ) diff --git a/runtime/tools/python/ttrt/runtime/module.cpp b/runtime/tools/python/ttrt/runtime/module.cpp index 4f528c02f9..99cfee4aaf 100644 --- a/runtime/tools/python/ttrt/runtime/module.cpp +++ b/runtime/tools/python/ttrt/runtime/module.cpp @@ -9,6 +9,7 @@ #include "tt/runtime/runtime.h" #include "tt/runtime/utils.h" +#include #include #include @@ -77,6 +78,10 @@ PYBIND11_MODULE(_C, m) { dataType, strategy); }, "Create a multi-device host tensor with owned memory"); + m.def("get_op_output_tensor", &tt::runtime::getOpOutputTensor, + "Get the output tensor of an operation"); + m.def("get_op_debug_str", &tt::runtime::getOpDebugString, + "Get the debug string of an operation"); m.def("get_num_available_devices", &tt::runtime::getNumAvailableDevices, "Get the number of available devices"); m.def("open_device", &tt::runtime::openDevice, py::arg("device_ids"), @@ -96,6 +101,21 @@ PYBIND11_MODULE(_C, m) { return os.str(); }); + py::class_(m, "DebugHooks") + .def_static("get", + [](py::function func) { + tt::runtime::debug::Hooks::get( + [func](std::optional context, + std::optional opContext) { + func(context, opContext); + }); + }) + .def("__str__", [](const tt::runtime::debug::Hooks &hooks) { + std::stringstream os; + os << hooks; + return os.str(); + }); + py::class_(m, "WorkaroundEnv") .def_static("get", &tt::runtime::workaround::Env::get) .def("__str__", [](const tt::runtime::workaround::Env &env) {
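
The ttrt pieces above tie the hook into golden checking: run.py registers the golden callback when --golden is passed, util.py caches the non-input golden tensors from each program's debug_info golden_map keyed by tensor id, and callback.py matches an executed op to its golden tensor by parsing the MLIR loc(...) out of the op's debug string, then reports max ATOL, max RTOL, and PCC. A hedged sketch of that matching step as a standalone helper (the compare_against_golden name and return shape are assumptions, not part of this patch):

import re
import torch
from ttrt.common.callback import GOLDENS, get_atol_rtol_pcc

def compare_against_golden(op_debug_str, device_values):
    # Extract the MLIR location, e.g. '... loc("relu_3") ...' -> 'relu_3'.
    match = re.search(r"loc\(([^)]+)\)", op_debug_str)
    if not match:
        return None
    loc = match.group(1).replace('"', "")
    golden = GOLDENS.get(loc)
    if golden is None:
        return None
    # device_values is the flat float list returned by get_op_output_tensor.
    calculated = torch.tensor(device_values, dtype=torch.float32)
    atol, rtol, pcc, summary = get_atol_rtol_pcc(golden.flatten(), calculated)
    return pcc, summary
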