diff --git a/BUILD b/BUILD index b6b23576beb7..d4b988f04058 100644 --- a/BUILD +++ b/BUILD @@ -10,6 +10,7 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") # copybara:uncomment load("//tools/build_defs/license:license.bzl", "license") +load("//:triton.bzl", "if_not_msvc") package( # copybara:uncomment_begin @@ -37,6 +38,13 @@ package( # exports_files(["LICENSE"]) # copybara:uncomment_end +config_setting( + name = "compiler_is_msvc", + flag_values = { + "@bazel_tools//tools/cpp:compiler": "msvc-cl", + }, +) + td_library( name = "td_files", srcs = glob(["include/triton/**/*.td"]), @@ -276,7 +284,7 @@ cc_library( name = "TritonDialect", srcs = glob(["lib/Dialect/Triton/IR/*.cpp"]), hdrs = glob(["include/triton/Dialect/Triton/IR/*.h"]), - copts = ["-Wno-unused-variable"], # TODO(manany): fix + copts = if_not_msvc(["-Wno-unused-variable"]), includes = ["include"], deps = [ ":triton_dialect_inc_gen", @@ -328,7 +336,7 @@ cc_library( "include/triton/Analysis/*.h", "include/triton/Dialect/TritonGPU/IR/*.h", ]), - copts = ["-Wno-unused-variable"], # TODO(csigg): fix + copts = if_not_msvc(["-Wno-unused-variable"]), includes = ["include"], deps = [ ":TritonDialect", @@ -356,7 +364,7 @@ cc_library( "lib/Dialect/TritonGPU/Transforms/*.h", ]), hdrs = glob(["include/triton/Dialect/TritonGPU/Transforms/*.h"]), - copts = ["-Wno-unused-variable"], # TODO(csigg): fix + copts = if_not_msvc(["-Wno-unused-variable"]), includes = ["include"], deps = [ ":TritonDialect", @@ -391,7 +399,7 @@ cc_library( "include/triton/Tools/Sys/*.hpp", "include/triton/Conversion/TritonGPUToLLVM/*.h", ]), - copts = ["-Wno-unused-variable"], # TODO(csigg): fix + copts = if_not_msvc(["-Wno-unused-variable"]), includes = [ "include", "lib/Conversion/TritonGPUToLLVM", diff --git a/triton.bzl b/triton.bzl new file mode 100644 index 000000000000..25627d030bcd --- /dev/null +++ b/triton.bzl @@ -0,0 +1,10 @@ +"""Bazel macros used by the triton build.""" + +def if_msvc(if_true, if_false = []): + return select({ + ":compiler_is_msvc": if_true, + "//conditions:default": if_false, + }) + +def if_not_msvc(a): + return if_msvc([], a)