Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Syntax highlighting for PTX code #275

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 0 additions & 72 deletions res/pygments/ptx.py

This file was deleted.

172 changes: 152 additions & 20 deletions src/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,33 +8,165 @@ const Cthulhu = Base.PkgId(UUID("f68482b8-f384-11e8-15f7-abe071a5a75f"), "Cthulh
# syntax highlighting
#

const _pygmentize = Ref{Union{String,Nothing}}()
function pygmentize()
if !isassigned(_pygmentize)
_pygmentize[] = Sys.which("pygmentize")
end
return _pygmentize[]
end
# https://github.com/JuliaLang/julia/blob/dacd16f068fb27719b31effbe8929952ee2d5b32/stdlib/InteractiveUtils/src/codeview.jl
const hlscheme = Dict{Symbol, Tuple{Bool, Union{Symbol, Int}}}(
:default => (false, :normal), # e.g. comma, equal sign, unknown token
:comment => (false, :light_black),
:label => (false, :light_red),
:instruction => ( true, :light_cyan),
:type => (false, :cyan),
:number => (false, :yellow),
:bracket => (false, :yellow),
:variable => (false, :normal), # e.g. variable, register
:keyword => (false, :light_magenta),
:funcname => (false, :light_yellow),
)

function highlight(io::IO, code, lexer)
highlighter = pygmentize()
have_color = get(io, :color, false)
if highlighter === nothing || !have_color
if !haskey(io, :color)
print(io, code)
elseif lexer == "llvm"
InteractiveUtils.print_llvm(io, code)
elseif lexer == "ptx"
highlight_ptx(io, code)
else
custom_lexer = joinpath(dirname(@__DIR__), "res", "pygments", "$lexer.py")
if isfile(custom_lexer)
lexer = `$custom_lexer -x`
end

pipe = open(`$highlighter -f terminal -P bg=dark -l $lexer`, "r+")
print(pipe, code)
close(pipe.in)
print(io, read(pipe, String))
print(io, code)
end
return
end

ptx_instructions = ["abs", "activemask", "add", "addc", "alloca", "and",
"applypriority", "atom", "bar", "barrier", "bfe", "bfi",
"bfind", "bmsk", "bra", "brev", "brkpt", "brx", "call", "clz",
"cnot", "copysign", "cos", "cp", "createpolicy", "cvt", "cvta",
"discard", "div", "dp2a", "dp4a", "ex2", "exit", "fence",
"fma", "fns", "isspacep", "istypep", "ld", "ldmatrix", "ldu",
"lg2", "lop3", "mad", "mad24", "madc", "match", "max", "mbarrier",
"membar", "min", "mma", "mov", "mul", "mul24", "nanosleep", "neg",
"not", "or", "pmevent", "popc", "prefetch", "prefetchu", "prmt",
"rcp", "red", "redux", "rem", "ret", "rsqrt", "sad", "selp",
"set", "setp", "shf", "shfl", "shl", "shr", "sin", "slct", "sqrt",
"st", "stackrestore", "stacksave", "sub", "subc", "suld", "suq",
"sured", "sust", "szext", "tanh", "testp", "tex", "tld4", "trap",
"txq", "vabsdiff", "vabsdiff2", "vabsdiff4", "vadd", "vadd2", "vadd4",
"vavrg2", "vavrg4", "vmad", "vmax", "vmax2", "vmax4", "vmin", "vmin2",
"vmin4", "vote", "vset", "vset2", "vset4", "vshl", "vshr", "vsub",
"vsub2", "vsub4", "wmma", "xor"]

r_ptx_instruction = join(ptx_instructions, "|")

types = ["s8", "s16", "s32", "s64", "u8,", "u16,", "u32", "u64", "f16", "f16x2", "f32", "f64", "b8,", "b16", "b32", "b64", "pred"]
r_types = join(types, "|")


# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-and-bit-size-comparisons
operators_comparison_sint = ["eq", "ne", "lt", "le", "gt", "ge"]
operators_comparison_uint = ["eq", "ne", "lo", "ls", "hi", "hs"]
operators_comparison_bit = ["eq", "ne"]

operators_comparison_float = ["eq", "ne", "lt", "le", "gt", "ge"]
operators_comparison_nanfloat = ["equ", "neu", "ltu", "leu", "gtu", "geu"]
operators_comparison_nan = ["num", "nan"]

modifiers_int = ["rni", "rzi", "rmi", "rpi"]
modifiers_float = ["rn", "rna", "rz", "rm", "rp"]
modifiers = sort(unique([modifiers_int...,modifiers_float...]))

state_spaces = ["reg", "sreg", "const", "global", "local", "param", "shared", "tex"]


operators = sort(unique([operators_comparison_sint..., operators_comparison_uint...,
operators_comparison_bit..., operators_comparison_float...,
operators_comparison_nanfloat..., operators_comparison_nan...,
modifiers..., state_spaces..., types...]))


r_operators = join(operators, "|")

# We can divide into types of instructions as all combinations of instructions, types and operators are not valid.
r_instruction = "(?:(?:$r_ptx_instruction)\\.(?:(?:$r_operators)(?:\\.))?(?:$(r_types)))"

directives = ["address_size", "align", "branchtargets", "callprototype",
"calltargets", "const", "entry", "extern", "file", "func", "global",
"loc", "local", "maxnctapersm", "maxnreg", "maxntid",
"minnctapersm", "param", "pragma", "reg", "reqntid", "section",
"shared", "sreg", "target", "tex", "version", "visible", "weak"]

r_directive = "(?:.(?:" * join(directives, "|") * "))"


r_hex = "0[xX][A-F]+U?"
r_octal = "0[0-8]+U?"
r_binary = "0[bB][01]+U?"
r_decimal = "[0-9]+U?"
r_float = "0[fF]{hexdigit}{8}"
r_double = "0[dD]{hexdigit}{16}"

r_number = join(map(x -> "(?:" * x * ")", [r_hex, r_octal, r_binary, r_decimal, r_float, r_double]), "|")

r_register_special = ["%clock", "%clock64", "%clock_hi", "%ctaid", "%dynamic_smem_size", "%envreg\\d{0,2}", # envreg0-31
"%globaltimer", "%globaltimer_hi", "%globaltimer_lo,", "%gridid", "%laneid", "%lanemask_eq",
"%lanemask_ge", "%lanemask_gt", "%lanemask_le", "%lanemask_lt", "%nctaid", "%nsmid",
"%ntid", "%nwarpid", "%pm\\d,", "%pm\\d_64", "%reserved_smem_offset<2>",
"%reserved_smem_offset_begin", "%reserved_smem_offset_cap", "%reserved_smem_offset_end", "%smid",
"%tid", "%total_smem_size", "%warpid", "%\\w{1,2}\\d{0,2}"]

r_register = join(r_register_special, "|")


r_followsym = "[a-zA-Z0-9_\$]"
r_identifier= "[a-zA-Z]{$r_followsym}* | {[_\$%]{$r_followsym}+"

r_guard_predicate = "@!?%p\\d{0,2}"
r_label = "[\\w_]+:"
r_comment = "//"
r_unknown = "[^\\s]*"

r_line = "(?:(?:$r_directive)|(?:$r_instruction)|(?:$r_register)|(?:$r_number)|(?:$r_label)|(?:$r_guard_predicate)|(?:$r_comment)|(?:$r_identifier)|(?:$r_unknown))"

get_token(n::Nothing) = nothing, nothing, nothing

# simple regex-based highlighter
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
function highlight_ptx(io::IO, code::AbstractString)
function get_token(s)
m = match(Regex("^(\\s*)($r_line)([^\\w\\d]+.*)?"), s)
m !== nothing && (return m.captures[1:3])
return nothing, nothing, nothing
end
get_token(n::Nothing) = nothing, nothing, nothing
print_tok(token, type) = Base.printstyled(io,
token,
bold = hlscheme[type][1],
color = hlscheme[type][2])
code = IOBuffer(code)
while !eof(code)
line = readline(code)
indent, tok, line = get_token(line)
is_tok(regex) = match(Regex("^(" * regex * ")"), tok) !== nothing
while (tok !== nothing)
print(io, indent)
if is_tok(r_comment)
print_tok(tok, :comment)
print_tok(line, :comment)
break
elseif is_tok(r_label)
print_tok(tok, :label)
elseif is_tok(r_instruction)
print_tok(tok, :instruction)
elseif is_tok(r_directive)
print_tok(tok, :type)
elseif is_tok(r_guard_predicate)
print_tok(tok, :keyword)
elseif is_tok(r_register)
print_tok(tok, :number)
else
print_tok(tok, :default)
end
indent, tok, line = get_token(line)
end
print(io, '\n')
end
end

#
# code_* replacements
Expand Down