Skip to content

Commit

Permalink
[CI] Added basic CI skeletons (#23)
Browse files Browse the repository at this point in the history
Includes minor fixes to make things compile and pass static checks properly
  • Loading branch information
ptillet authored Jul 26, 2022
1 parent faa3bc5 commit 40a62f9
Show file tree
Hide file tree
Showing 10 changed files with 61 additions and 60 deletions.
36 changes: 13 additions & 23 deletions .github/workflows/integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@ on:
branches:
- main


jobs:

Integration-Tests:

runs-on: self-hosted
runs-on: ubuntu-20.04

steps:

Expand All @@ -23,32 +22,23 @@ jobs:
rm -r ~/.triton/
continue-on-error: true

- name: Install Triton
run: |
alias python='python3'
cd python
pip3 install -e '.[tests]'
- name: Check imports
run: "isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 )"
run: |
pip install isort
isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 )
- name: Check style
run: "autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 )"
run: |
pip install autopep8
autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 )
- name: Flake8
run: "flake8 --config ./python/setup.cfg ./python || ( echo '::error::Flake8 failed; see logs for errors.' ; exit 1 )"

- name: Unit tests
run: |
cd python/test/unit
pytest -vs .
pip install flake8
flake8 --config ./python/setup.cfg ./python || ( echo '::error::Flake8 failed; see logs for errors.' ; exit 1 )
- name: Regression tests
- name: Install Triton
run: |
cd python/test/regression
sudo nvidia-smi -i 0 -pm 1
sudo nvidia-smi -i 0 --lock-gpu-clocks=1350,1350
sudo nvidia-smi -i 0 --lock-memory-clocks=877,877
pytest -vs .
sudo nvidia-smi -i 0 -rgc
sudo nvidia-smi -i 0 -rmc
alias python='python3'
cd python
pip3 install -e '.[tests]'
3 changes: 3 additions & 0 deletions lib/Analysis/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
add_mlir_library(TritonAnalysis
AxisInfo.cpp

DEPENDS
TritonGPUAttrDefsIncGen
)
11 changes: 6 additions & 5 deletions python/examples/copy_strided.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@

import triton
import triton.language as tl


# triton kernel
@triton.jit
def kernel(X, stride_xm, stride_xn,
Z, stride_zm, stride_zn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
def kernel(X, stride_xm, stride_xn,
Z, stride_zm, stride_zn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
off_m = tl.arange(0, BLOCK_M)
off_n = tl.arange(0, BLOCK_N)
Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn
Expand All @@ -15,4 +16,4 @@ def kernel(X, stride_xm, stride_xn,


ret = triton.compile(kernel, "*fp32,i32,i32,*fp32,i32,i32", constants={"BLOCK_M": 128, "BLOCK_N": 128}, output="ttgir")
print(ret)
print(ret)
4 changes: 3 additions & 1 deletion python/examples/empty.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import triton
import triton.language as tl


@triton.jit
def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr):
pass

ret = triton.compile(kernel, "*fp32,i32,i32", constants={"BLOCK": 256}, output="ttgir")

ret = triton.compile(kernel, "*fp32,i32,i32", constants={"BLOCK": 256}, output="ttgir")
9 changes: 0 additions & 9 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,6 @@


def get_llvm():
# tries to find system LLVM
versions = ['-14.0', '-14', '-14-64']
supported = ['llvm-config{v}'.format(v=v) for v in versions]
paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
paths = [p for p in paths if p is not None]
if paths:
return '', ''
if platform.system() == "Windows":
return '', ''
# download if nothing is installed
name = 'clang+llvm-14.0.0-x86_64-linux-gnu-ubuntu-18.04'
dir = '/tmp'
Expand Down
37 changes: 22 additions & 15 deletions python/triton/compiler.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from __future__ import annotations

import ast
import sys
import warnings
from typing import Dict, Union

import triton
import triton._C.libtriton.triton as _triton


def str_to_ty(name):
if name[0] == "*":
ty = str_to_ty(name[1:])
ty = str_to_ty(name[1:])
return triton.language.pointer_type(ty)
tys = {
"fp8": triton.language.float8,
Expand All @@ -26,9 +28,10 @@ def str_to_ty(name):
"u32": triton.language.uint32,
"u64": triton.language.uint64,
"B": triton.language.int1,
}
}
return tys[name]


def mangle_ty(ty):
if ty.is_ptr():
return 'P' + mangle_ty(ty.element_ty)
Expand Down Expand Up @@ -62,6 +65,7 @@ def mangle_fn(name, arg_tys, constants):
ret = f'{name}__{mangled_arg_names}__{mangled_constants}'
return ret


class enter_sub_region:
def __init__(self, generator: CodeGenerator):
self.generator = generator
Expand All @@ -79,6 +83,7 @@ def __exit__(self, *args, **kwargs):
self.generator.lscope = self.liveins
self.generator.local_defs = self.prev_defs


class CodeGenerator(ast.NodeVisitor):
def __init__(self, context, prototype, gscope, attributes, constants, module=None, is_kernel=False, function_types=dict()):
self.builder = _triton.ir.builder(context)
Expand Down Expand Up @@ -491,8 +496,8 @@ def visit_While(self, node):
while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types],
[arg.handle for arg in init_args])
# merge the condition region
before_block = self.builder.create_block_with_parent(while_op.get_before(),
[ty.to_ir(self.builder) for ty in ret_types])
before_block = self.builder.create_block_with_parent(while_op.get_before(),
[ty.to_ir(self.builder) for ty in ret_types])
cond_block.merge_block_before(before_block)
self.builder.set_insertion_point_to_end(before_block)
# create CondtionOp: e.g., scf.condition(%cond) %arg0, %arg1, ...
Expand Down Expand Up @@ -538,7 +543,6 @@ def visit_For(self, node):
iter_args = [self.visit(arg) for arg in node.iter.args]
is_static = all([isinstance(x, triton.language.constexpr) for x in iter_args])
if is_static:
st_target = ast.Name(id=node.target.id, ctx=ast.Store())
iter_args = [arg.value for arg in iter_args]
range = iterator(*iter_args)
if len(range) <= 10:
Expand Down Expand Up @@ -597,7 +601,7 @@ def visit_For(self, node):
# replace global uses with block arguments
for i, name in enumerate(names):
# arg0 is the induction variable
for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i+1))
for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i + 1))

# update lscope & local_defs (ForOp defines new values)
for i, name in enumerate(names):
Expand Down Expand Up @@ -633,7 +637,7 @@ def visit_Call(self, node):
args = getcallargs(fn.fn, *args, **kws)
args = [args[name] for name in fn.arg_names]
args = [arg if isinstance(arg, triton.language.tensor)
else triton.language.constexpr(arg) for arg in args]
else triton.language.constexpr(arg) for arg in args]
# generate function def
attributes = dict()
constexprs = [i for i, arg in enumerate(args) if isinstance(arg, triton.language.constexpr)]
Expand Down Expand Up @@ -712,7 +716,6 @@ def generic_visit(self, node):
raise NotImplementedError("Unsupported node: {}".format(typename))



class CompilationError(Exception):
def __init__(self, src, node):
self.message = f'at {node.lineno}:{node.col_offset}:\n'
Expand Down Expand Up @@ -742,11 +745,11 @@ def __reduce__(self):
return (type(self), (self.required, self.limit, self.name))


def make_triton_ir(fn, signature, constants = dict(), attributes = dict()):
def make_triton_ir(fn, signature, constants=dict(), attributes=dict()):
context = _triton.ir.context()
context.load_triton()
# create kernel prototype
arg_types = signature.replace(' ','').split(',')
arg_types = signature.replace(' ', '').split(',')
constants = {fn.arg_names.index(name): value for name, value in constants.items()}
arg_types = [str_to_ty(x) for x in arg_types]
prototype = triton.language.function_type([], arg_types)
Expand All @@ -765,6 +768,7 @@ def make_triton_ir(fn, signature, constants = dict(), attributes = dict()):
ret.context = context
return ret


def make_tritongpu_ir(mod, num_warps):
pm = _triton.ir.pass_manager(mod.context)
pm.add_inliner_pass()
Expand All @@ -775,6 +779,7 @@ def make_tritongpu_ir(mod, num_warps):
pm.run(mod)
return mod


def optimize_tritongpu_ir(mod, num_stages):
pm = _triton.ir.pass_manager(mod.context)
pm.add_tritongpu_pipeline_pass(num_stages)
Expand All @@ -785,22 +790,24 @@ def optimize_tritongpu_ir(mod, num_stages):
pm.run(mod)
return mod


def make_ptx(mod):
# TODO
return mod

def compile(fn, signature, constants = dict(), attributes = dict(), num_warps=4, num_stages=3, output = "ttgir"):

def compile(fn, signature, constants=dict(), attributes=dict(), num_warps=4, num_stages=3, output="ttgir"):
assert output in ["ttir", "ttgir", "ptx"]
# triton-ir
module = make_triton_ir(fn, signature, constants, attributes)
if output == "ttir":
if output == "ttir":
return module.str()
# tritongpu-ir
module = make_tritongpu_ir(module, num_warps)
module = optimize_tritongpu_ir(module, num_stages)
if output == "ttgir":
if output == "ttgir":
return module.str()
# ptx
if output == "ptx":
if output == "ptx":
return make_ptx(module)
assert False
assert False
4 changes: 2 additions & 2 deletions python/triton/runtime/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .jit import JITFunction, jit
from .autotuner import Config, autotune, heuristics
from .autotuner import Config, autotune, heuristics # noqa: F401
from .jit import JITFunction, jit # noqa: F401
5 changes: 3 additions & 2 deletions python/triton/runtime/autotuner.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import builtins
import time
from typing import Dict

from ..testing import do_bench


class Autotuner:
Expand Down Expand Up @@ -57,7 +59,7 @@ def kernel_call():
config.pre_hook(self.nargs)
self.hook(args)
self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
return triton.testing.do_bench(kernel_call)
return do_bench(kernel_call)

def __call__(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args))
Expand Down Expand Up @@ -199,4 +201,3 @@ def fun(*args, **meta):
return fn

return decorator

10 changes: 7 additions & 3 deletions python/triton/runtime/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import subprocess
import tempfile
import textwrap

import triton
import triton._C.libtriton.triton as _triton
from ..tools.disasm import extract
Expand All @@ -16,6 +17,7 @@
# Binary
# -----------------------------------------------------------------------------


class Binary:
def __init__(self, backend, name, asm, shared_mem, num_warps):
self.backend = backend
Expand Down Expand Up @@ -63,13 +65,13 @@ def get_sass(self, fun=None):
# Kernel
# -----------------------------------------------------------------------------


class Kernel:

def __call__(self, *args, grid, num_warps=4, num_stages=3, **kwargs):
raise RuntimeError("Not implemented. Public repo implementation will be rewritten to reduce latency.")



# -----------------------------------------------------------------------------
# Dependencies Finder
# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -118,6 +120,7 @@ def visit_Call(self, node):
# JITFunction
# -----------------------------------------------------------------------------


@functools.lru_cache()
def version_key():
import pkgutil
Expand Down Expand Up @@ -232,7 +235,7 @@ def __init__(self, kernel, grid):

def __call__(self, *wargs, **kwargs):
return self.kernel(*wargs, **kwargs, grid=self.grid)

return Launcher(self._init_kernel(), grid)

def __repr__(self):
Expand All @@ -242,6 +245,7 @@ def __repr__(self):
# `jit` decorator
# -----------------------------------------------------------------------------


def jit(*args, **kwargs):
"""
Decorator for JIT-compiling a function using the Triton compiler.
Expand All @@ -265,4 +269,4 @@ def jit(*args, **kwargs):
else:
def decorator(fn):
return JITFunction(fn, **kwargs)
return decorator
return decorator
2 changes: 2 additions & 0 deletions python/triton/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations

import torch


Expand All @@ -17,6 +18,7 @@ def next_power_of_2(n):
n += 1
return n


class TensorWrapper:
def __init__(self, base, dtype):
self.dtype = dtype
Expand Down

0 comments on commit 40a62f9

Please sign in to comment.