Skip to content

Commit

Permalink
Remove all instances of defaultStrideCompuation
Browse files Browse the repository at this point in the history
This change removes all instances of `defaultStrideComputation` from the
workaround env
  • Loading branch information
ctodTT committed Feb 5, 2025
1 parent fe4ebc2 commit 852e16e
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 26 deletions.
19 changes: 4 additions & 15 deletions runtime/include/tt/runtime/detail/workarounds.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ struct Env {
#endif
get(bool maxpool2dPreshard = true, bool swapBinaryOperands = true,
bool readUpdateIndexFromDeviceForKVCache = true,
bool toDtypeOnHost = true, bool defaultStrideComputation = true)
bool toDtypeOnHost = true)
#if defined(TT_RUNTIME_WORKAROUNDS) && TT_RUNTIME_WORKAROUNDS == 1
;
#else
{
return Env(true, true, true, true, true);
return Env(true, true, true, true);
}
#endif
// TODO(bug #855): Ideally we should have an op that preshards for maxpool2d
Expand All @@ -45,23 +45,14 @@ struct Env {
// to handle this, we should remove this workaround.
bool toDtypeOnHost;

// TODO(bug #2045): Our current stride calculation is incorrect for tilized
// tensors. The current solution is to remove stride entirely from the
// flatbuffer and calculate the stride in runtime assuming using the default
// method ignoring details like grid, layout etc. Once we have a more
// sophisticated way for handling this, we can remove this workaround.
bool defaultStrideComputation;

private:
constexpr Env(bool maxpool2dPreshard, bool swapBinaryOperands,
bool readUpdateIndexFromDeviceForKVCache, bool toDtypeOnHost,
bool defaultStrideComputation)
bool readUpdateIndexFromDeviceForKVCache, bool toDtypeOnHost)
: maxpool2dPreshard(maxpool2dPreshard),
swapBinaryOperands(swapBinaryOperands),
readUpdateIndexFromDeviceForKVCache(
readUpdateIndexFromDeviceForKVCache),
toDtypeOnHost(toDtypeOnHost),
defaultStrideComputation(defaultStrideComputation) {}
toDtypeOnHost(toDtypeOnHost) {}
};

inline std::ostream &operator<<(std::ostream &os, const Env &env) {
Expand All @@ -75,8 +66,6 @@ inline std::ostream &operator<<(std::ostream &os, const Env &env) {
<< env.readUpdateIndexFromDeviceForKVCache << "\n";
os << "\t"
<< "toDtypeOnHost: " << env.toDtypeOnHost << "\n";
os << "\t"
<< "defaultStrideComputation: " << env.defaultStrideComputation << "\n";
os << "}";
return os;
}
Expand Down
5 changes: 5 additions & 0 deletions runtime/lib/binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ static std::string asJson(void const *fbb, uint8_t const *binarySchema,

static std::vector<uint32_t>
calculateStride(std::vector<uint32_t> const &shape) {
// TODO(bug #2045): Our current stride calculation is incorrect for tilized
// tensors. The current solution is to remove stride entirely from the
// flatbuffer and calculate the stride in runtime assuming using the default
// method ignoring details like grid, layout etc. Once we have a more
// sophisticated way for handling this, we can remove this workaround.
LOG_ASSERT(!shape.empty());
std::vector<uint32_t> stride(shape.size(), 1);
for (size_t i = shape.size() - 1; i > 0; i--) {
Expand Down
5 changes: 2 additions & 3 deletions runtime/lib/common/workarounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@ namespace tt::runtime::workaround {
#if defined(TT_RUNTIME_WORKAROUNDS) && TT_RUNTIME_WORKAROUNDS == 1
const Env &Env::get(bool maxpool2dPreshard, bool swapBinaryOperands,
bool readUpdateIndexFromDeviceForKVCache,
bool toDtypeOnHost, bool defaultStrideComputation) {
bool toDtypeOnHost) {
static const Env config(maxpool2dPreshard, swapBinaryOperands,
readUpdateIndexFromDeviceForKVCache, toDtypeOnHost,
defaultStrideComputation);
readUpdateIndexFromDeviceForKVCache, toDtypeOnHost);
return config;
}
#endif
Expand Down
8 changes: 0 additions & 8 deletions runtime/tools/python/ttrt/common/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,6 @@ def initialize_api():
choices=[True, False],
help="disable to_dtype on host workaround",
)
Run.register_arg(
name="--disable-default-stride-computation",
type=bool,
default=False,
choices=[True, False],
help="disable runtime default stride computation workaround",
)
Run.register_arg(
name="--result-file",
type=str,
Expand Down Expand Up @@ -410,7 +403,6 @@ def _execute(binaries):
not self["--disable-swap-binary-operands"],
not self["--disable-read-update-index-for-kv-cache"],
not self["--disable-to-dtype-on-host"],
not self["--disable-default-stride-computation"],
)
self.logging.debug(f"setting tt runtime workaround env={workaround_env}")
self.logging.debug(f"setting torch manual seed={self['--seed']}")
Expand Down

0 comments on commit 852e16e

Please sign in to comment.