Skip to content

Commit

Permalink
Update default sys desc
Browse files Browse the repository at this point in the history
  • Loading branch information
jnie-TT committed Aug 8, 2024
1 parent bc16b85 commit 71cde8c
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 21 deletions.
21 changes: 1 addition & 20 deletions lib/Dialect/TT/IR/TTOpsTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ mlir::tt::SystemDescAttr::getDefault(MLIRContext *context) {
{
tt::ChipDescAttr::get(
context, tt::ArchAttr::get(context, tt::Arch::WormholeB0), {8, 8},
(1 << 20), 12, (1 << 20), 16, 32, 32),
1499136, 12, (1 << 30), 16, 32, 32),
},
// Chip Descriptor Indices
{
Expand Down Expand Up @@ -357,15 +357,6 @@ mlir::Type LayoutAttr::getElementType() const {
return getMemref().getElementType();
}

uint64_t LayoutAttr::getElementSizeBytes() const {
mlir::Type elementType = getElementType();
if (mlir::isa<TileType>(elementType)) {
auto tileType = mlir::cast<TileType>(elementType);
return tileType.getSizeBytes();
}
return elementType.getIntOrFloatBitWidth() / 8;
}

LayoutAttr LayoutAttr::withGrid(
::mlir::MLIRContext *context, ArrayRef<int64_t> tensorShape, GridAttr grid,
ArrayRef<std::pair<std::int64_t, std::int64_t>> collapseIntervals) {
Expand Down Expand Up @@ -458,16 +449,6 @@ DeviceAttr::verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
return ::mlir::success();
}

::mlir::LogicalResult
TileType::verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, DataType dataType) {
if (shape.size() != 2) {
emitError() << "expected 2D shape";
return ::mlir::failure();
}
return ::mlir::success();
}

llvm::SmallVector<int64_t>
TileType::getScalarShape(SmallVector<int64_t> tiledShape) const {
assert(tiledShape.size() >= 2 && "expected at least 2D shape");
Expand Down
3 changes: 2 additions & 1 deletion runtime/lib/ttnn/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ getCurrentSystemDescImpl(const ::ttnn::multi_device::DeviceMesh &deviceMesh) {
std::vector<uint32_t> chipDescIndices;
std::vector<::tt::target::ChipCapability> chipCapabilities;
// Ignore for now
std::vector<::tt::target::ChipCoord> chipCoords;
std::vector<::tt::target::ChipCoord> chipCoords = {
::tt::target::ChipCoord(0, 0, 0, 0)};
::flatbuffers::FlatBufferBuilder fbb;

for (const ::ttnn::Device *device : devices) {
Expand Down

0 comments on commit 71cde8c

Please sign in to comment.