Skip to content

Commit

Permalink
#72: Added strict flatbuffer and runtime versioning check (#143)
Browse files Browse the repository at this point in the history
  • Loading branch information
tapspatel authored Jul 15, 2024
1 parent 04f952e commit c4f6752
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 9 deletions.
14 changes: 7 additions & 7 deletions cmake/modules/TTMLIRVersion.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ execute_process(

# Extract the major and minor version from the tag (assumes tags are in "major.minor" format)
string(REGEX MATCH "^v([0-9]+)\\.([0-9]+)$" GIT_TAG_MATCH ${GIT_TAG})
set(PROJECT_VERSION_MAJOR ${CMAKE_MATCH_1})
set(PROJECT_VERSION_MINOR ${CMAKE_MATCH_2})
set(PROJECT_VERSION_PATCH ${GIT_COMMITS})
set(TTMLIR_VERSION_MAJOR ${CMAKE_MATCH_1})
set(TTMLIR_VERSION_MINOR ${CMAKE_MATCH_2})
set(TTMLIR_VERSION_PATCH ${GIT_COMMITS})

message(STATUS "Project commit hash: ${TTMLIR_GIT_HASH}")
message(STATUS "Project version: ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR}.${PROJECT_VERSION_PATCH}")
message(STATUS "Project version: ${TTMLIR_VERSION_MAJOR}.${TTMLIR_VERSION_MINOR}.${TTMLIR_VERSION_PATCH}")

add_definitions("-DTTMLIR_GIT_HASH=${TTMLIR_GIT_HASH}")
add_definitions("-DTTMLIR_VERSION_MAJOR=${PROJECT_VERSION_MAJOR}")
add_definitions("-DTTMLIR_VERSION_MINOR=${PROJECT_VERSION_MINOR}")
add_definitions("-DTTMLIR_VERSION_PATCH=${PROJECT_VERSION_PATCH}")
add_definitions("-DTTMLIR_VERSION_MAJOR=${TTMLIR_VERSION_MAJOR}")
add_definitions("-DTTMLIR_VERSION_MINOR=${TTMLIR_VERSION_MINOR}")
add_definitions("-DTTMLIR_VERSION_PATCH=${TTMLIR_VERSION_PATCH}")
2 changes: 1 addition & 1 deletion runtime/tools/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ add_custom_target(ttrt-copy-files
)

add_custom_target(ttrt
COMMAND TTMLIR_ENABLE_RUNTIME=${TTMLIR_ENABLE_RUNTIME} SOURCE_ROOT=${TTMLIR_SOURCE_DIR} python -m pip install .
COMMAND TTMLIR_ENABLE_RUNTIME=${TTMLIR_ENABLE_RUNTIME} TTMLIR_VERSION_MAJOR=${TTMLIR_VERSION_MAJOR} TTMLIR_VERSION_MINOR=${TTMLIR_VERSION_MINOR} TTMLIR_VERSION_PATCH=${TTMLIR_VERSION_PATCH} SOURCE_ROOT=${TTMLIR_SOURCE_DIR} python -m pip install .
WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}"
COMMENT "python ttrt package"
DEPENDS ttrt-copy-files
Expand Down
6 changes: 5 additions & 1 deletion runtime/tools/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
from setuptools import setup
import shutil

__version__ = "0.0.1"
TTMLIR_VERSION_MAJOR=os.getenv('TTMLIR_VERSION_MAJOR', '0')
TTMLIR_VERSION_MINOR=os.getenv('TTMLIR_VERSION_MINOR', '0')
TTMLIR_VERSION_PATCH=os.getenv('TTMLIR_VERSION_PATCH', '0')

__version__ = f"{TTMLIR_VERSION_MAJOR}.{TTMLIR_VERSION_MINOR}.{TTMLIR_VERSION_PATCH}"

src_dir = os.environ.get(
"SOURCE_ROOT",
Expand Down
11 changes: 11 additions & 0 deletions runtime/tools/python/ttrt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
import json

from pkg_resources import get_distribution

def system_desc_as_dict(desc):
return json.loads(desc.as_json())
Expand All @@ -16,6 +17,14 @@ def system_desc_as_dict(desc):
if "TT_METAL_LOGGER_LEVEL" not in os.environ:
os.environ["TT_METAL_LOGGER_LEVEL"] = "FATAL"

def check_version(fb_version):
package_name = 'ttrt'
try:
package_version = get_distribution(package_name).version
except Exception as e:
print(f"Error retrieving version: {e} for {package_name}")

assert package_version == fb_version, f"{package_name}=v{package_version} does not match flatbuffer=v{fb_version}"

def mlir_sections(fbb):
d = ttrt.binary.as_dict(fbb)
Expand Down Expand Up @@ -73,6 +82,7 @@ def program_outputs(fbb):

def read(args):
fbb = ttrt.binary.load_from_path(args.binary)
check_version(fbb.version)
read_actions[args.section](fbb)


Expand Down Expand Up @@ -117,6 +127,7 @@ def fromDataType(dtype):
raise ValueError(f"unsupported dtype: {dtype}")

fbb = ttrt.binary.load_binary_from_path(args.binary)
check_version(fbb.version)
assert fbb.file_identifier == "TTNN", "Only TTNN binaries are supported"
d = ttrt.binary.as_dict(fbb)
assert args.program_index < len(d["programs"]), "args.program_index out of range"
Expand Down

0 comments on commit c4f6752

Please sign in to comment.