Skip to content
This repository has been archived by the owner on Apr 18, 2024. It is now read-only.

Commit

Permalink
Merged PR 2: Merge latest commits
Browse files Browse the repository at this point in the history
  • Loading branch information
wenxcs committed Feb 23, 2022
1 parent 55cfc4a commit f904e94
Show file tree
Hide file tree
Showing 32 changed files with 2,608 additions and 2 deletions.
21 changes: 21 additions & 0 deletions .azurepipeline/ci-track_github_tvm.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Scheduled pipeline: once a day (cron "0 0 * * *") on branch main, pull the
# main branch of https://github.com/apache/tvm.git and push the result to the
# tvm repository on dev.azure.com.
schedules:
- cron: "0 0 * * *"
displayName: Daily midnight check
branches:
include:
- main

# Agent pool: the ubuntu-latest VM image.
pool:
vmImage: ubuntu-latest

# Single script step: set the git identity used for the merge commit, check out
# main, add apache/tvm as remote "tvm", pull its main branch (merge, not rebase)
# and push to the Azure DevOps remote.
# NOTE(review): the push URL relies on the shell expanding ${PAT}; confirm that
# the PAT variable is actually exported into this script's environment.
steps:
- script: |
git config --global user.email "nnfusion_team@microsoft.com"
git config --global user.name "NNFusion team"
git config pull.rebase false
git checkout origin main
git checkout main
git remote add tvm https://github.com/apache/tvm.git
git pull tvm main
git push https://${PAT}:PAT@dev.azure.com/TensorStar/TensorStar/_git/tvm main
displayName: 'Add apache/tvm as new origin'
20 changes: 20 additions & 0 deletions .azurepipeline/ci-windows_build.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Windows build pipeline, triggered by changes on the develop branch.
trigger:
- develop

# Runs on the agent pool named "Default".
# NOTE(review): the steps reference F:\tool\vcpkg, so the agent presumably needs
# vcpkg pre-installed at that path - confirm against the agent setup.
pool: Default

# Steps, in order:
#   1. check out this repository including submodules (credentials persisted),
#   2. install gtest for x64-windows through vcpkg,
#   3. configure with CMake (DirectX enabled, vcpkg toolchain file),
#   4. build with CMake.
steps:
- checkout: self
submodules: true
persistCredentials: true
- task: PowerShell@2
inputs:
targetType: 'inline'
script: |
F:\tool\vcpkg\vcpkg.exe install gtest:x64-windows
- task: CMake@1
inputs:
cmakeArgs: '.. -DUSE_DIRECTX=ON -DCMAKE_TOOLCHAIN_FILE="F:/tool/vcpkg/scripts/buildsystems/vcpkg.cmake"'
- task: CMake@1
inputs:
cmakeArgs: '--build . -j'
9 changes: 9 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,12 @@
[submodule "3rdparty/cutlass"]
path = 3rdparty/cutlass
url = https://github.com/NVIDIA/cutlass
[submodule "3rdparty/DirectXShaderCompiler"]
path = 3rdparty/DirectXShaderCompiler
url = https://github.com/nnfusion/DirectXShaderCompiler
[submodule "3rdparty/DirectX-Headers"]
path = 3rdparty/DirectX-Headers
url = https://github.com/microsoft/DirectX-Headers.git
[submodule "3rdparty/DirectXTK12"]
path = 3rdparty/DirectXTK12
url = https://github.com/microsoft/DirectXTK12.git
1 change: 1 addition & 0 deletions 3rdparty/DirectX-Headers
Submodule DirectX-Headers added at 0644e7
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ tvm_option(USE_OPENCL "Build with OpenCL" OFF)
tvm_option(USE_VULKAN "Build with Vulkan" OFF)
tvm_option(USE_METAL "Build with Metal" OFF)
tvm_option(USE_ROCM "Build with ROCM" OFF)
tvm_option(USE_DIRECTX "Build with DIRECTX" ON)
tvm_option(ROCM_PATH "The path to rocm" /opt/rocm)
tvm_option(USE_HEXAGON_DEVICE "Build with Hexagon device support in TVM runtime" OFF)
tvm_option(USE_HEXAGON_SDK "Path to the Hexagon SDK root (required for Hexagon support in TVM runtime or for building TVM runtime for Hexagon)" /path/to/sdk)
Expand Down Expand Up @@ -63,6 +64,7 @@ tvm_option(DMLC_PATH "Path to DMLC" "3rdparty/dmlc-core/include")
tvm_option(RANG_PATH "Path to RANG" "3rdparty/rang/include")
tvm_option(COMPILER_RT_PATH "Path to COMPILER-RT" "3rdparty/compiler-rt")
tvm_option(PICOJSON_PATH "Path to PicoJSON" "3rdparty/picojson")
tvm_option(DIRECTX_HEADER_PATH "Path to DirectX headers" "3rdparty/DirectX-Headers/include")

# Contrib library options
tvm_option(USE_BYODT_POSIT "Build with BYODT software emulated posit custom datatype" OFF)
Expand Down Expand Up @@ -103,6 +105,7 @@ include_directories(SYSTEM ${DMLC_PATH})
include_directories(SYSTEM ${RANG_PATH})
include_directories(SYSTEM ${COMPILER_RT_PATH})
include_directories(SYSTEM ${PICOJSON_PATH})
include_directories(SYSTEM ${DIRECTX_HEADER_PATH})

# initial variables
set(TVM_LINKER_LIBS "")
Expand Down Expand Up @@ -267,6 +270,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS
tvm_file_glob(GLOB CODEGEN_SRCS
src/target/*.cc
src/target/source/*.cc
src/target/directx/*.cc
)

list(APPEND COMPILER_SRCS ${CODEGEN_SRCS})
Expand Down Expand Up @@ -428,6 +432,7 @@ include(cmake/modules/Arduino.cmake)
include(cmake/modules/CUDA.cmake)
include(cmake/modules/Hexagon.cmake)
include(cmake/modules/OpenCL.cmake)
include(cmake/modules/DirectX.cmake)
include(cmake/modules/OpenMP.cmake)
include(cmake/modules/Vulkan.cmake)
include(cmake/modules/Metal.cmake)
Expand Down
5 changes: 4 additions & 1 deletion cmake/config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ set(USE_OPENCL OFF)
# Whether enable Metal runtime
set(USE_METAL OFF)

# Whether enable DirectX runtime
set(USE_DIRECTX OFF)

# Whether enable Vulkan runtime
#
# Possible values:
Expand All @@ -99,7 +102,7 @@ set(USE_IOS_RPC OFF)
# Whether embed stackvm into the runtime
set(USE_STACKVM_RUNTIME OFF)

# Whether enable tiny embedded graph executor.
# Whether enable tiny embedded graph executor.
set(USE_GRAPH_EXECUTOR ON)

# Whether enable tiny graph executor with CUDA Graph
Expand Down
32 changes: 32 additions & 0 deletions cmake/modules/DirectX.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# DirectX Module

# USE_DIRECTX=ON : add every src/runtime/directx/*.cc file to RUNTIME_SRCS and,
#                  when GTEST_FOUND is set, build the dx_test executable from
#                  src/runtime/directx/test/*.cc.
# USE_DIRECTX=OFF: add src/target/opt/build_directx_off.cc to COMPILER_SRCS instead.
if(USE_DIRECTX)
message(STATUS "Build with DirectX support")
# Runtime sources for the DirectX backend.
file(GLOB RUNTIME_DIRECTX_SRCS src/runtime/directx/*.cc)
list(APPEND RUNTIME_SRCS ${RUNTIME_DIRECTX_SRCS})
# NOTE(review): GTEST_FOUND and gtest_discover_tests are not defined in this
# module; presumably find_package(GTest) / include(GoogleTest) run elsewhere
# before this file is included - confirm.
if(GTEST_FOUND)
file(GLOB RUNTIME_TEST_DIRECTX_SRCS src/runtime/directx/test/*.cc)
add_executable(dx_test ${RUNTIME_TEST_DIRECTX_SRCS})
# Links the test binary against the TVM object targets plus GoogleTest.
target_link_libraries(dx_test tvm_libinfo_objs tvm_objs tvm_runtime_objs GTest::GTest GTest::Main)
gtest_discover_tests(dx_test)
endif()
else()
list(APPEND COMPILER_SRCS src/target/opt/build_directx_off.cc)
endif(USE_DIRECTX)
6 changes: 5 additions & 1 deletion include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,11 @@ typedef enum {
kOpenGL = 11,
kDLMicroDev = 13,
kDLHexagon = 14,
kDLWebGPU = 15
kDLWebGPU = 15,
kDLDirectX = 20,
kDLDirectXHost = 21,
kDLDirectXUpload = 22,
kDLDirectXReadback = 23,
// AddExtraTVMType which is not in DLPack here
} TVMDeviceExtType;

Expand Down
8 changes: 8 additions & 0 deletions include/tvm/runtime/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,14 @@ inline const char* DeviceName(int type) {
return "webgpu";
case kDLHexagon:
return "hexagon";
case kDLDirectXUpload:
return "directx_upload";
case kDLDirectX:
return "directx";
case kDLDirectXReadback:
return "directx_readback";
case kDLDirectXHost:
return "directx_host";
default:
LOG(FATAL) << "unknown type =" << type;
return "Unknown";
Expand Down
8 changes: 8 additions & 0 deletions python/tvm/_ffi/runtime_ctypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,10 @@ class Device(ctypes.Structure):
12: "ext_dev",
14: "hexagon",
15: "webgpu",
20: "directx",
21: "directx_host",
22: "directx_upload",
23: "directx_readback",
}
STR2MASK = {
"llvm": 1,
Expand All @@ -232,6 +236,10 @@ class Device(ctypes.Structure):
"ext_dev": 12,
"hexagon": 14,
"webgpu": 15,
"directx": 20,
"directx_host": 21,
"directx_upload": 22,
"directx_readback": 23,
}

def __init__(self, device_type, device_id):
Expand Down
123 changes: 123 additions & 0 deletions src/runtime/directx/directx_buffer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
#include "directx_header.h"

using namespace tvm::runtime::dx;

// Creates a buffer backed by a fresh allocation from `_dev->device_allocate(size)`.
// The stored `size` is the Width reported by the allocated resource's descriptor,
// not the `size` argument itself.
DirectBuffer::DirectBuffer(DirectXDevice* _dev, UINT64 size, DLDataType type) : _dxdev(_dev) {
  _res = _dev->device_allocate(size);
  this->type = type;
  const D3D12_RESOURCE_DESC res_desc = _res->GetDesc();
  this->size = res_desc.Width;
}

// Wraps an existing resource `res` instead of allocating one. The stored `size`
// is the Width reported by that resource's descriptor.
DirectBuffer::DirectBuffer(DirectXDevice* _dev, ComPtr<ID3D12Resource> res, DLDataType type)
    : _dxdev(_dev), _res(res) {
  this->type = type;
  const D3D12_RESOURCE_DESC res_desc = _res->GetDesc();
  this->size = res_desc.Width;
}

// Creates a host buffer in the requested state.
//   upload   -> resource comes from `_dev->upload_allocate(size)`
//   readback -> resource comes from `_dev->readback_allocate(size)`
// Any other state throws std::invalid_argument.
// `range` covers [0, size) of the requested size, while the stored `size` is the
// Width reported by the allocated resource's descriptor.
DirectHostBuffer::DirectHostBuffer(DirectXDevice* _dev, UINT64 size, DLDataType type, hostbuffer_state state)
    : _cur_state(state), ptr(nullptr) {
  if (state == hostbuffer_state::readback) {
    _host_res = _dev->readback_allocate(size);
  } else if (state == hostbuffer_state::upload) {
    _host_res = _dev->upload_allocate(size);
  } else {
    throw std::invalid_argument(_msg_("Buffer state is not supported"));
  }

  const D3D12_RESOURCE_DESC host_desc = _host_res->GetDesc();
  this->size = host_desc.Width;
  this->type = type;
  _dxdev = _dev;
  range.Begin = 0;
  range.End = static_cast<SIZE_T>(size);
}

// Returns a CPU pointer to the host resource. The resource is mapped on the
// first call (subresource 0, using `range`); while `ptr` stays non-null the
// cached pointer is returned without mapping again. Throws via ThrowIfFailed
// when Map fails.
void* DirectHostBuffer::open_data_ptr() {
  if (ptr == nullptr) {
    ThrowIfFailed(_host_res->Map(0, &range, reinterpret_cast<void**>(&ptr)));
  }
  return ptr;
}

void DirectHostBuffer::close_data_ptr() {
ptr = nullptr;
if (_cur_state == hostbuffer_state::readback)
// Use begin = end to tell no data is changed;
{
D3D12_RANGE range = {0, 0};
_host_res->Unmap(0, &range);
} else
_host_res->Unmap(0, &range);
}

// todo(wenxh): use resource barrier to support call this transition in async way;
//
// Moves the buffer contents into a newly allocated resource of the requested
// state and makes that resource the current one. No-op when `hs` already equals
// the current state; throws std::invalid_argument for any state other than
// upload or readback.
//
//   -> upload   : the current resource is read through the CPU. Both resources
//                 are mapped, `size` bytes are memcpy'd, both are unmapped.
//   -> readback : the copy is delegated to `_dxdev->copy(tgt, _host_res)`
//                 (presumably destination first, as in the to_host/to_device
//                 helpers of the sibling buffer classes).
//
// Any pointer previously returned by open_data_ptr() is invalid afterwards.
void DirectHostBuffer::change_state(hostbuffer_state hs) {
  if (hs == _cur_state) return;
  if (hs == hostbuffer_state::upload) {
    // readback to upload, need to memcpy
    auto tgt = _dxdev->upload_allocate(size);

    // Map the new upload resource, then the current one (no-op if already mapped).
    void* t_ptr = nullptr;
    ThrowIfFailed(tgt->Map(0, &range, reinterpret_cast<void**>(&t_ptr)));
    open_data_ptr();

    // cpu memcpy
    memcpy(t_ptr, ptr, size);

    // Unmap both; _cur_state has not changed yet, so close_data_ptr() still
    // takes the path for the old state.
    close_data_ptr();
    tgt->Unmap(0, &range);

    _host_res = tgt;
  } else if (hs == hostbuffer_state::readback) {
    // Only unmap when the resource is actually mapped; an Unmap without a
    // matching Map must not be issued.
    if (ptr != nullptr) close_data_ptr();
    auto tgt = _dxdev->readback_allocate(size);
    _dxdev->copy(tgt, _host_res);
    _host_res = tgt;
  } else {
    throw std::invalid_argument(_msg_("Target buffer state is not supported."));
  }
  _cur_state = hs;
}

// Builds the device-side buffer through the DirectBuffer constructor and adds a
// host-side resource obtained from `_dev->readback_allocate(size)`. `range`
// covers [0, size) of the requested size; no CPU pointer is open initially.
DirectReadBackBuffer::DirectReadBackBuffer(DirectXDevice* _dev, UINT64 size, DLDataType type)
    : DirectBuffer(_dev, size, type) {
  _host_res = _dev->readback_allocate(size);
  this->ptr = nullptr;
  range.Begin = 0;
  range.End = static_cast<SIZE_T>(size);
}

// Returns a CPU pointer to the readback resource, mapping it (subresource 0,
// using `range`) only when no pointer is cached yet. Throws via ThrowIfFailed
// when Map fails.
void* DirectReadBackBuffer::open_data_ptr() {
  if (ptr == nullptr) {
    ThrowIfFailed(_host_res->Map(0, &range, reinterpret_cast<void**>(&ptr)));
  }
  return ptr;
}

// Asks the owning device to copy between the device resource `_res` and the
// readback resource `_host_res`; `async` is forwarded unchanged.
// The call is copy(_host_res, _res, async) - given the method name, the first
// argument is presumably the destination.
void DirectReadBackBuffer::to_host(bool async) {
  _dxdev->copy(_host_res, _res, async);
}

void DirectReadBackBuffer::close_data_ptr() {
ptr = nullptr;
// Use begin = end to tell no data is changed;
D3D12_RANGE range = {0, 0};
_host_res->Unmap(0, &range);
}

// Builds the device-side buffer through the DirectBuffer constructor and adds a
// host-side resource obtained from `_dev->upload_allocate(size)`. `range`
// covers [0, size) of the requested size; no CPU pointer is open initially.
DirectUploadBuffer::DirectUploadBuffer(DirectXDevice* _dev, UINT64 size, DLDataType type)
    : DirectBuffer(_dev, size, type) {
  _host_res = _dev->upload_allocate(size);
  this->ptr = nullptr;
  range.Begin = 0;
  range.End = static_cast<SIZE_T>(size);
}

// Returns a CPU pointer to the upload resource, mapping it (subresource 0,
// using `range`) only when no pointer is cached yet. Throws via ThrowIfFailed
// when Map fails.
void* DirectUploadBuffer::open_data_ptr() {
  if (ptr == nullptr) {
    ThrowIfFailed(_host_res->Map(0, &range, reinterpret_cast<void**>(&ptr)));
  }
  return ptr;
}

// Unmaps the upload resource if it is currently mapped and clears the cached
// pointer. The full `range` is reported to Unmap as written.
//
// `ptr` is the only record of the mapped state (open_data_ptr() fills it in when
// it calls Map), so a null `ptr` means there is nothing to unmap and the call is
// a no-op. This avoids issuing an Unmap that has no matching Map.
void DirectUploadBuffer::close_data_ptr() {
  if (ptr == nullptr) return;  // never mapped, or already closed
  ptr = nullptr;
  _host_res->Unmap(0, &range);
}

// Asks the owning device to copy between the upload resource `_host_res` and
// the device resource `_res`; `async` is forwarded unchanged.
// The call is copy(_res, _host_res, async) - given the method name, the first
// argument is presumably the destination.
void DirectUploadBuffer::to_device(bool async) {
  _dxdev->copy(_res, _host_res, async);
}
Loading

0 comments on commit f904e94

Please sign in to comment.