Add a CPU lowering pass (#26)
Per the discussion in #24, this PR adds a reference CPU backend for triton-shared using the standard MLIR-to-LLVM lowering passes.

Without any changes to the Triton runtime, we can run triton-shared on CPU from Python Triton code.
  - It is potentially useful for testing.
  - A very basic kernel (see reduce.py) runs and succeeds.
  - Many kernels do not run yet because of missing LLVM lowerings and other problems.
    - For example, the memref::tensorStore-to-LLVM lowering appears to be unimplemented, so many kernels are not supported at this point.

Help from the community and triton-shared core developers is needed to improve and maintain it. I'd appreciate your feedback and suggestions.

Note: compared with the PR referenced in the discussion, this version can run a compute kernel (reduce.py, for example) by using MLIR's CRunnerUtils.h.
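
A minimal sketch of the kind of kernel this backend can run, in the spirit of reduce.py (this is a hypothetical example, not the actual contents of python/examples/reduce.py; the kernel name and launch configuration are assumptions):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def reduce_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Load one block of input (masked past n_elements) and sum it.
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    tl.store(out_ptr, tl.sum(x, axis=0))

x = torch.rand(1024)
out = torch.empty(1)
# A single program instance suffices for this single-block reduction.
reduce_kernel[(1,)](x, out, x.numel(), BLOCK_SIZE=1024)
print(out.item(), x.sum().item())  # the two values should match
```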
shintaro-iwasaki authored Nov 6, 2023
1 parent 0f6d8e1 commit 6fa7ce3
Showing 10 changed files with 1,256 additions and 0 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/integration-tests.yml
@@ -12,3 +12,9 @@ jobs:
    with:
      triton-ref: '05dc28be0e72dd496300a31b99a21a5a5118f8e9' # known good commit "[CI] refactor workflows (#2504)"
      triton-shared-ref: ${{ github.ref }}

  test-cpuref:
    uses: ./.github/workflows/test-cpuref.yml
    with:
      triton-ref: '05dc28be0e72dd496300a31b99a21a5a5118f8e9' # known good commit "[CI] refactor workflows (#2504)"
      triton-shared-ref: ${{ github.ref }}
75 changes: 75 additions & 0 deletions .github/workflows/test-cpuref.yml
@@ -0,0 +1,75 @@
name: Triton-Shared Plugin Testing

on:
  workflow_call:
    inputs:
      triton-ref:
        required: true
        type: string
      triton-shared-ref:
        required: true
        type: string
  workflow_dispatch:
    inputs:
      triton-ref:
        required: true
        type: string
      triton-shared-ref:
        required: true
        type: string

jobs:
  build_and_test_triton_shared:
    runs-on: ubuntu-latest

    steps:

    - name: Checkout Triton
      uses: actions/checkout@v4
      with:
        repository: 'openai/triton'
        ref: ${{ inputs.triton-ref }}
        path: triton
        submodules: 'recursive'

    - name: Checkout Triton-Shared
      uses: actions/checkout@v4
      with:
        ref: ${{ inputs.triton-shared-ref }}
        path: triton/third_party/triton_shared

    - name: Clear Triton Cache
      run: |
        rm -rf ~/.triton
    - name: Update PATH
      run: |
        echo "PATH=${HOME}/.local/bin:${PATH}" >> "${GITHUB_ENV}"
    - name: Check pre-commit
      run: |
        cd triton
        python3 -m pip install --upgrade pre-commit
        python3 -m pre_commit run --all-files --verbose
    - name: Build/Install Triton
      run: |
        export TRITON_CODEGEN_TRITON_SHARED=1
        cd triton/python
        python3 -m pip install --upgrade pip
        python3 -m pip install cmake==3.24
        python3 -m pip install ninja
        python3 -m pip uninstall -y triton
        python3 setup.py build
        python3 -m pip install --no-build-isolation -vvv '.[tests]'
    - name: Install PyTorch
      run: |
        python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
    - name: Run an example
      run: |
        cd triton/python
        export TRITON_SHARED_OPT_PATH="$(pwd)/build/$(ls $(pwd)/build | grep -i cmake)/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt"
        export LLVM_BINARY_DIR="${HOME}/.triton/llvm/$(ls ${HOME}/.triton/llvm/ | grep -i llvm)/bin"
        python3 ../third_party/triton_shared/python/examples/reduce.py
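
The last step exports TRITON_SHARED_OPT_PATH and LLVM_BINARY_DIR so the CPU backend can find triton-shared-opt and the pinned LLVM tools at run time. As a rough, hypothetical sketch of how a backend could consume these variables (the actual python/__init__.py may differ, and the pass pipeline shown here is illustrative, not the one this PR uses):

```python
import os
import subprocess

def lower_to_llvm_ir(ttir_path: str, workdir: str) -> str:
    """Drive the external tools the workflow points at (illustrative only)."""
    opt = os.environ["TRITON_SHARED_OPT_PATH"]
    llvm_bin = os.environ["LLVM_BINARY_DIR"]
    linalg = os.path.join(workdir, "kernel.linalg.mlir")
    llvm_dialect = os.path.join(workdir, "kernel.llvm.mlir")
    ll = os.path.join(workdir, "kernel.ll")
    # triton-shared-opt lowers Triton IR toward standard MLIR dialects.
    subprocess.check_call([opt, ttir_path, "--triton-to-linalg", "-o", linalg])
    # Standard MLIR passes lower the result to the LLVM dialect.
    subprocess.check_call([os.path.join(llvm_bin, "mlir-opt"), linalg,
                           "--convert-linalg-to-loops", "--convert-scf-to-cf",
                           "--convert-arith-to-llvm", "--finalize-memref-to-llvm",
                           "--convert-func-to-llvm", "--reconcile-unrealized-casts",
                           "-o", llvm_dialect])
    # mlir-translate emits LLVM IR proper.
    subprocess.check_call([os.path.join(llvm_bin, "mlir-translate"),
                           "--mlir-to-llvmir", llvm_dialect, "-o", ll])
    return ll
```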
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -5,7 +5,9 @@ set(TRITON_SHARED_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) # Tablegen'd files

set(TRITON_BUILD_PYTHON_MODULE ON)
add_subdirectory(include)
add_subdirectory(lib)
add_subdirectory(test)
add_subdirectory(tools)
add_subdirectory(python)
12 changes: 12 additions & 0 deletions python/CMakeLists.txt
@@ -0,0 +1,12 @@

# Python module
if(TRITON_BUILD_PYTHON_MODULE)
  message(STATUS "Adding Triton-Shared Reference CPU Backend")
  file(INSTALL
    ${CMAKE_CURRENT_SOURCE_DIR}/__init__.py
    ${CMAKE_CURRENT_SOURCE_DIR}/ExecutionEngine/Msan.h
    ${CMAKE_CURRENT_SOURCE_DIR}/ExecutionEngine/CRunnerUtils.h
    ${CMAKE_CURRENT_SOURCE_DIR}/ExecutionEngine/CRunnerUtils.cpp
    DESTINATION ${PYTHON_THIRD_PARTY_PATH}/cpu/)
  # TODO: perhaps we want to install binary files used by __init__.py
endif()
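
CRunnerUtils.cpp is installed as source rather than as a prebuilt binary (the TODO above notes this gap). One hypothetical way __init__.py could use it is to compile the file on demand into a shared library that JIT-compiled kernels link against; the helper below is an illustration, not code from this PR:

```python
import os
import subprocess

def build_crunner_utils(src_dir: str, out_dir: str) -> str:
    """Compile the installed CRunnerUtils.cpp into a shared library (sketch)."""
    src = os.path.join(src_dir, "CRunnerUtils.cpp")
    lib = os.path.join(out_dir, "libcrunner_utils.so")
    subprocess.check_call([
        "c++", "-shared", "-fPIC", "-O2",
        # This guard enables the function definitions in CRunnerUtils.cpp.
        "-DMLIR_CRUNNERUTILS_DEFINE_FUNCTIONS",
        "-I", src_dir,  # so CRunnerUtils.h and Msan.h resolve
        src, "-o", lib,
    ])
    return lib
```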
192 changes: 192 additions & 0 deletions python/ExecutionEngine/CRunnerUtils.cpp
@@ -0,0 +1,192 @@
//===- CRunnerUtils.cpp - Utils for MLIR execution ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements basic functions to manipulate structured MLIR types at
// runtime. Entities in this file are meant to be retargetable, including on
// targets without a C++ runtime, and must be kept C compatible.
//
//===----------------------------------------------------------------------===//

#include "CRunnerUtils.h"
#include "Msan.h"

#ifndef _WIN32
#if defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) ||     \
    defined(__DragonFly__)
#include <cstdlib>
#else
#include <alloca.h>
#endif
#include <sys/time.h>
#else
#include "malloc.h"
#endif // _WIN32

#include <algorithm>
#include <cinttypes>
#include <cstdio>
#include <cstdlib>
#include <random>
#include <string.h>

#ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS

namespace {
template <typename V>
void stdSort(uint64_t n, V *p) {
  std::sort(p, p + n);
}

} // namespace

// Small runtime support "lib" for vector.print lowering.
// By providing elementary printing methods only, this
// library can remain fully unaware of low-level implementation
// details of our vectors. Also useful for direct LLVM IR output.
extern "C" void printI64(int64_t i) { fprintf(stdout, "%" PRId64, i); }
extern "C" void printU64(uint64_t u) { fprintf(stdout, "%" PRIu64, u); }
extern "C" void printF32(float f) { fprintf(stdout, "%g", f); }
extern "C" void printF64(double d) { fprintf(stdout, "%lg", d); }
extern "C" void printString(char const *s) { fputs(s, stdout); }
extern "C" void printOpen() { fputs("( ", stdout); }
extern "C" void printClose() { fputs(" )", stdout); }
extern "C" void printComma() { fputs(", ", stdout); }
extern "C" void printNewline() { fputc('\n', stdout); }

extern "C" void memrefCopy(int64_t elemSize, UnrankedMemRefType<char> *srcArg,
UnrankedMemRefType<char> *dstArg) {
DynamicMemRefType<char> src(*srcArg);
DynamicMemRefType<char> dst(*dstArg);

int64_t rank = src.rank;
MLIR_MSAN_MEMORY_IS_INITIALIZED(src.sizes, rank * sizeof(int64_t));

// Handle empty shapes -> nothing to copy.
for (int rankp = 0; rankp < rank; ++rankp)
if (src.sizes[rankp] == 0)
return;

char *srcPtr = src.data + src.offset * elemSize;
char *dstPtr = dst.data + dst.offset * elemSize;

if (rank == 0) {
memcpy(dstPtr, srcPtr, elemSize);
return;
}

int64_t *indices = static_cast<int64_t *>(alloca(sizeof(int64_t) * rank));
int64_t *srcStrides = static_cast<int64_t *>(alloca(sizeof(int64_t) * rank));
int64_t *dstStrides = static_cast<int64_t *>(alloca(sizeof(int64_t) * rank));

// Initialize index and scale strides.
for (int rankp = 0; rankp < rank; ++rankp) {
indices[rankp] = 0;
srcStrides[rankp] = src.strides[rankp] * elemSize;
dstStrides[rankp] = dst.strides[rankp] * elemSize;
}

int64_t readIndex = 0, writeIndex = 0;
for (;;) {
// Copy over the element, byte by byte.
memcpy(dstPtr + writeIndex, srcPtr + readIndex, elemSize);
// Advance index and read position.
for (int64_t axis = rank - 1; axis >= 0; --axis) {
// Advance at current axis.
auto newIndex = ++indices[axis];
readIndex += srcStrides[axis];
writeIndex += dstStrides[axis];
// If this is a valid index, we have our next index, so continue copying.
if (src.sizes[axis] != newIndex)
break;
// We reached the end of this axis. If this is axis 0, we are done.
if (axis == 0)
return;
// Else, reset to 0 and undo the advancement of the linear index that
// this axis had. Then continue with the axis one outer.
indices[axis] = 0;
readIndex -= src.sizes[axis] * srcStrides[axis];
writeIndex -= dst.sizes[axis] * dstStrides[axis];
}
}
}

/// Prints GFLOPS rating.
extern "C" void printFlops(double flops) {
  fprintf(stderr, "%lf GFLOPS\n", flops / 1.0E9);
}

/// Returns the number of seconds since Epoch 1970-01-01 00:00:00 +0000 (UTC).
extern "C" double rtclock() {
#ifndef _WIN32
struct timeval tp;
int stat = gettimeofday(&tp, nullptr);
if (stat != 0)
fprintf(stderr, "Error returning time from gettimeofday: %d\n", stat);
return (tp.tv_sec + tp.tv_usec * 1.0e-6);
#else
fprintf(stderr, "Timing utility not implemented on Windows\n");
return 0.0;
#endif // _WIN32
}

extern "C" void *mlirAlloc(uint64_t size) { return malloc(size); }

extern "C" void *mlirAlignedAlloc(uint64_t alignment, uint64_t size) {
#ifdef _WIN32
return _aligned_malloc(size, alignment);
#elif defined(__APPLE__)
// aligned_alloc was added in MacOS 10.15. Fall back to posix_memalign to also
// support older versions.
void *result = nullptr;
(void)::posix_memalign(&result, alignment, size);
return result;
#else
return aligned_alloc(alignment, size);
#endif
}

extern "C" void mlirFree(void *ptr) { free(ptr); }

extern "C" void mlirAlignedFree(void *ptr) {
#ifdef _WIN32
_aligned_free(ptr);
#else
free(ptr);
#endif
}

extern "C" void *rtsrand(uint64_t s) {
// Standard mersenne_twister_engine seeded with s.
return new std::mt19937(s);
}

extern "C" uint64_t rtrand(void *g, uint64_t m) {
std::mt19937 *generator = static_cast<std::mt19937 *>(g);
std::uniform_int_distribution<uint64_t> distrib(0, m);
return distrib(*generator);
}

extern "C" void rtdrand(void *g) {
std::mt19937 *generator = static_cast<std::mt19937 *>(g);
delete generator;
}

#define IMPL_STDSORT(VNAME, V)                                                 \
  extern "C" void _mlir_ciface_stdSort##VNAME(uint64_t n,                      \
                                              StridedMemRefType<V, 1> *vref) { \
    assert(vref);                                                              \
    assert(vref->strides[0] == 1);                                             \
    V *values = vref->data + vref->offset;                                     \
    stdSort(n, values);                                                        \
  }
IMPL_STDSORT(I64, int64_t)
IMPL_STDSORT(F64, double)
IMPL_STDSORT(F32, float)
#undef IMPL_STDSORT

#endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
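
Once compiled into a shared library (see the build sketch after the python/CMakeLists.txt section), these C-compatible entry points can be exercised directly, for example from Python via ctypes. A hypothetical smoke test, assuming the library name used in that sketch:

```python
import ctypes

lib = ctypes.CDLL("./libcrunner_utils.so")  # hypothetical build output

# printF32 takes a C float; declare the argtype so ctypes doesn't pass a double.
lib.printF32.argtypes = [ctypes.c_float]
lib.printF32(3.5)
lib.printNewline()

# The rt*rand helpers wrap a heap-allocated std::mt19937.
lib.rtsrand.restype = ctypes.c_void_p
lib.rtrand.restype = ctypes.c_uint64
g = lib.rtsrand(ctypes.c_uint64(42))
print(lib.rtrand(ctypes.c_void_p(g), ctypes.c_uint64(100)))  # value in [0, 100]
lib.rtdrand(ctypes.c_void_p(g))  # free the generator
```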