Skip to content

Commit

Permalink
more language awareness
Browse files Browse the repository at this point in the history
  • Loading branch information
Ellipse0934 committed Jan 21, 2022
1 parent bf64c20 commit 89f3e3d
Showing 1 changed file with 103 additions and 45 deletions.
148 changes: 103 additions & 45 deletions src/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,73 +34,131 @@ function highlight(io::IO, code, lexer)
end
end

# PTX instruction mnemonics. The entries keep the order in which they were
# originally listed (a five-column table read row by row), so each source
# row below corresponds to one table row.
const ptx_instructions = String[
    "abs", "cvt", "min", "shfl", "vadd",
    "activemask", "cvta", "mma", "shl", "vadd2",
    "add", "discard", "mov", "shr", "vadd4",
    "addc", "div", "mul", "sin", "vavrg2",
    "alloca", "dp2a", "mul24", "slct", "vavrg4",
    "and", "dp4a", "nanosleep", "sqrt", "vmad",
    "applypriority", "ex2", "neg", "st", "vmax",
    "atom", "exit", "not", "stackrestore", "vmax2",
    "bar", "fence", "or", "stacksave", "vmax4",
    "barrier", "fma", "pmevent", "sub", "vmin",
    "bfe", "fns", "popc", "subc", "vmin2",
    "bfi", "isspacep", "prefetch", "suld", "vmin4",
    "bfind", "istypep", "prefetchu", "suq", "vote",
    "bmsk", "ld", "prmt", "sured", "vset",
    "bra", "ldmatrix", "rcp", "sust", "vset2",
    "brev", "ldu", "red", "szext", "vset4",
    "brkpt", "lg2", "redux", "tanh", "vshl",
    "brx", "lop3", "rem", "testp", "vshr",
    "call", "mad", "ret", "tex", "vsub",
    "clz", "mad24", "rsqrt", "tld4", "vsub2",
    "cnot", "madc", "sad", "trap", "vsub4",
    "copysign", "match", "selp", "txq", "wmma",
    "cos", "max", "set", "vabsdiff", "xor",
    "cp", "mbarrier", "setp", "vabsdiff2",
    "createpolicy", "membar", "shf", "vabsdiff4",
]
# PTX instruction mnemonics, in alphabetical order.
# Declared `const` so that functions reading this global stay type-stable.
const ptx_instructions = ["abs", "activemask", "add", "addc", "alloca", "and",
                          "applypriority", "atom", "bar", "barrier", "bfe", "bfi",
                          "bfind", "bmsk", "bra", "brev", "brkpt", "brx", "call", "clz",
                          "cnot", "copysign", "cos", "cp", "createpolicy", "cvt", "cvta",
                          "discard", "div", "dp2a", "dp4a", "ex2", "exit", "fence",
                          "fma", "fns", "isspacep", "istypep", "ld", "ldmatrix", "ldu",
                          "lg2", "lop3", "mad", "mad24", "madc", "match", "max", "mbarrier",
                          "membar", "min", "mma", "mov", "mul", "mul24", "nanosleep", "neg",
                          "not", "or", "pmevent", "popc", "prefetch", "prefetchu", "prmt",
                          "rcp", "red", "redux", "rem", "ret", "rsqrt", "sad", "selp",
                          "set", "setp", "shf", "shfl", "shl", "shr", "sin", "slct", "sqrt",
                          "st", "stackrestore", "stacksave", "sub", "subc", "suld", "suq",
                          "sured", "sust", "szext", "tanh", "testp", "tex", "tld4", "trap",
                          "txq", "vabsdiff", "vabsdiff2", "vabsdiff4", "vadd", "vadd2", "vadd4",
                          "vavrg2", "vavrg4", "vmad", "vmax", "vmax2", "vmax4", "vmin", "vmin2",
                          "vmin4", "vote", "vset", "vset2", "vset4", "vshl", "vshr", "vsub",
                          "vsub2", "vsub4", "wmma", "xor"]

# Regex alternation (without surrounding group) matching any single mnemonic.
# Callers must wrap it in `(?:...)` before appending further pattern text.
const r_ptx_instruction = join(ptx_instructions, "|")

# Type suffixes that may follow an instruction (e.g. the `s32` in `add.s32`).
# The previous literals "u8,", "u16," and "b8," carried a stray comma inside
# the string, so `u8`, `u16` and `b8` could never match.
const types = ["s8", "s16", "s32", "s64", "u8", "u16", "u32", "u64",
               "f16", "f16x2", "f32", "f64", "b8", "b16", "b32", "b64", "pred"]

# Regex alternation over `types`. Alternation is tried left to right and the
# first hit wins, so longer suffixes go first: otherwise `f16` would win over
# `f16x2` whenever the pattern is not anchored at the end.
const r_types = join(sort(types; by = length, rev = true), "|")


# Comparison operators, one list per operand kind. See
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-and-bit-size-comparisons
const operators_comparison_sint = ["eq", "ne", "lt", "le", "gt", "ge"]
const operators_comparison_uint = ["eq", "ne", "lo", "ls", "hi", "hs"]
const operators_comparison_bit = ["eq", "ne"]

const operators_comparison_float = ["eq", "ne", "lt", "le", "gt", "ge"]
const operators_comparison_nanfloat = ["equ", "neu", "ltu", "leu", "gtu", "geu"]
const operators_comparison_nan = ["num", "nan"]

# Instruction modifiers (integer and floating-point variants), merged into one
# sorted, duplicate-free list.
const modifiers_int = ["rni", "rzi", "rmi", "rpi"]
const modifiers_float = ["rn", "rna", "rz", "rm", "rp"]
const modifiers = sort!(unique!(vcat(modifiers_int, modifiers_float)))

const state_spaces = ["reg", "sreg", "const", "global", "local", "param", "shared", "tex"]

# Every word accepted in the optional middle position of an instruction
# (`instr.<operator>.<type>`): comparison operators, modifiers, state spaces
# and type names. `vcat` yields a fresh vector, so sorting and de-duplicating
# in place does not touch the source lists. `types` is the module-level list
# of type suffixes.
const operators = sort!(unique!(vcat(
    operators_comparison_sint, operators_comparison_uint,
    operators_comparison_bit, operators_comparison_float,
    operators_comparison_nanfloat, operators_comparison_nan,
    modifiers, state_spaces, types)))

# Regex alternation (ungrouped) over `operators`.
const r_operators = join(operators, "|")

# Matches `<instruction>.<type>` with at most one `<operator>.` in between,
# e.g. `add.s32` or `setp.eq.s32`. Not every combination of instruction,
# operator and type is valid PTX; the pattern accepts them all and could later
# be split per instruction class.
# `const` because the pattern is interpolated into a Regex inside functions.
const r_instruction =
    "(?:(?:$r_ptx_instruction)\\.(?:(?:$r_operators)\\.)?(?:$r_types))"

# PTX directive names, without the leading dot.
const directives = ["address_size", "align", "branchtargets", "callprototype",
                    "calltargets", "const", "entry", "extern", "file", "func", "global",
                    "loc", "local", "maxnctapersm", "maxnreg", "maxntid",
                    "minnctapersm", "param", "pragma", "reg", "reqntid", "section",
                    "shared", "sreg", "target", "tex", "version", "visible", "weak"]

# Regex alternation (ungrouped, no leading dot) over `directives`.
# Alternatives are tried left to right, so longer names are placed first;
# in alphabetical order `loc` would shadow `local`.
const r_directive = join(sort(directives; by = length, rev = true), "|")


# Numeric literal patterns, following the PTX ISA "Constants" grammar.
# The grammar's `{hexdigit}` placeholder is spelled out as a real character
# class; left verbatim it would be matched as literal text.
const r_hex = "0[xX][0-9a-fA-F]+U?"        # hexadecimal (digits and both cases)
const r_octal = "0[0-7]+U?"                # octal (8 is not an octal digit)
const r_binary = "0[bB][01]+U?"            # binary
const r_decimal = "[0-9]+U?"               # decimal
const r_float = "0[fF][0-9a-fA-F]{8}"      # single precision, 8 hex digits
const r_double = "0[dD][0-9a-fA-F]{16}"    # double precision, 16 hex digits

# Any numeric literal. `r_decimal` goes last: alternation picks the first
# alternative that matches, and a decimal would otherwise consume just the
# leading `0` of `0f...` / `0d...` literals.
const r_number = join(
    ("(?:" * pattern * ")"
     for pattern in (r_hex, r_octal, r_binary, r_float, r_double, r_decimal)),
    "|")

# Patterns for registers, tried in the order listed (first match wins).
# Ordering rules:
#   * a name that extends another one comes first (`%clock64` before `%clock`,
#     `%pm\d_64` before `%pm\d`), otherwise the shorter name would win;
#   * the generic `%<one or two word chars><digits>` pattern is last so it
#     never pre-empts a special register name.
# Fixes relative to the earlier list: the stray commas inside
# "%globaltimer_lo," and "%pm\\d," are gone, the literal "<2>" is replaced by
# `_[01]` (assumed to denote `%reserved_smem_offset_0/_1` -- confirm against the
# PTX special-register list), and the generic pattern accepts any number of
# trailing digits so `%rd1000` is not cut short.
const r_register_special = [
    "%clock64", "%clock_hi", "%clock",
    "%ctaid", "%dynamic_smem_size",
    "%envreg\\d{0,2}", # envreg0-31
    "%globaltimer_hi", "%globaltimer_lo", "%globaltimer",
    "%gridid", "%laneid",
    "%lanemask_eq", "%lanemask_ge", "%lanemask_gt", "%lanemask_le", "%lanemask_lt",
    "%nctaid", "%nsmid", "%ntid", "%nwarpid",
    "%pm\\d_64", "%pm\\d",
    "%reserved_smem_offset_begin", "%reserved_smem_offset_cap",
    "%reserved_smem_offset_end", "%reserved_smem_offset_[01]",
    "%smid", "%tid", "%total_smem_size", "%warpid",
    "%\\w{1,2}\\d*",
]

# Regex alternation (ungrouped) over all register patterns.
const r_register = join(r_register_special, "|")


# PTX identifiers: either a letter followed by any number of "follow symbols",
# or one of `_`, `$`, `%` followed by at least one follow symbol.
# The earlier pattern kept the grammar's `{...}` grouping braces and the
# spaces around `|`, all of which a regex engine treats as literal characters.
const r_followsym = "[a-zA-Z0-9_\$]"
# Ungrouped alternation: wrap in `(?:...)` before embedding in a larger pattern.
const r_identifier = "[a-zA-Z]$(r_followsym)*|[_\$%]$(r_followsym)+"

# Guard predicate in front of an instruction: `@%p1`, `@!%p12`, ...
# Any number of digits is accepted; a two-digit cap would split `%p100`.
const r_guard_predicate = "@!?%p\\d*"
# Label definition: word characters or `$`, then a colon (e.g. `$L__BB0_2:`).
const r_label = "[\\w\$]+:"
# Start of a line comment.
const r_comment = "//"
# Fallback: a (possibly empty) run of non-whitespace characters.
const r_unknown = "[^\\s]*"

# Pattern for a single token. Alternatives are tried in the order listed and
# the first one that matches wins, so more specific token kinds come before
# the catch-all `r_unknown`.
# The directive alternative is `\.` followed by the *grouped* directive names:
# written as `.$r_directive` the dot matched any character and applied only to
# the first name in the alternation, while every other name matched bare.
const r_line = "(?:" * join(
    ("(?:" * alternative * ")"
     for alternative in (
         "\\.(?:" * r_directive * ")",
         r_instruction,
         r_register,
         r_number,
         r_label,
         r_guard_predicate,
         r_comment,
         r_identifier,
         r_unknown,
     )),
    "|") * ")"

# Fallback for when there is no remaining text to tokenize: yields an
# (indent, token, rest) triple whose three slots are all `nothing`.
function get_token(::Nothing)
    return (nothing, nothing, nothing)
end

# simple regex-based highlighter
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
function highlight_ptx(io::IO, code)
function highlight_ptx(io::IO, code::AbstractString)
function get_token(s)
# TODO: doesn't handle `ret;`, `{1`, etc; not properly tokenizing symbols
m = match(r"(\s*)([^\s]+)(.*)", s)
m = match(Regex("^(\\s*)($r_line)([^\\w\\d]+.*)?"), s)
m !== nothing && (return m.captures[1:3])
return nothing, nothing, nothing
end
get_token(n::Nothing) = nothing, nothing, nothing
print_tok(token, type) = Base.printstyled(io,
token,
bold = hlscheme[type][1],
color = hlscheme[type][2])
buf = IOBuffer(code)
while !eof(buf)
line = readline(buf)
code = IOBuffer(code)
while !eof(code)
line = readline(code)
indent, tok, line = get_token(line)
istok(regex) = match(regex, tok) !== nothing
isinstr() = first(split(tok, '.')) in ptx_instructions
is_tok(regex) = match(Regex("^(" * regex * ")"), tok) !== nothing
while (tok !== nothing)
print(io, indent)

# comments
if istok(r"^\/\/")
if is_tok(r_comment)
print_tok(tok, :comment)
print_tok(line, :comment)
break
# labels
elseif istok(r"^[\w]+:")
elseif is_tok(r_label)
print_tok(tok, :label)
# instructions
elseif isinstr()
elseif is_tok(r_instruction)
print_tok(tok, :instruction)
# directives
elseif istok(r"^\.[\w]+")
elseif is_tok(r_directive)
print_tok(tok, :type)
# guard predicates
elseif istok(r"^@!?%p.+")
elseif is_tok(r_guard_predicate)
print_tok(tok, :keyword)
# registers
elseif istok(r"^%[\w]+")
print_tok(tok, :variable)
# constants
elseif istok(r"^0[xX][A-F]+U?") || # hexadecimal
istok(r"^0[0-8]+U?") || # octal
istok(r"^0[bB][01]+U?") || # binary
istok(r"^[0-9]+U?") || # decimal
istok(r"^0[fF]{hexdigit}{8}") || # single-precision floating point
istok(r"^0[dD]{hexdigit}{16}") # double-precision floating point
elseif is_tok(r_register)
print_tok(tok, :number)
# TODO: function names
# TODO: labels as RHS
else
print_tok(tok, :default)
end
Expand Down

0 comments on commit 89f3e3d

Please sign in to comment.