From 417aad1fe34fa37b2dc79d9d49f7032c981e362b Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Thu, 7 Sep 2023 21:36:18 -0700 Subject: [PATCH] Remove custom u4/s4 definition, use common ml_dtypes one. We want a consistent definition for all int4 usages. Currently, JAX/TF use the underlying type defined in `ml_dtypes`. PiperOrigin-RevId: 563633360 --- xla/BUILD | 7 +- xla/types.h | 181 ++++++++++------------------------------------------ 2 files changed, 37 insertions(+), 151 deletions(-) diff --git a/xla/BUILD b/xla/BUILD index 2213465186e92..bcc59181bdd19 100644 --- a/xla/BUILD +++ b/xla/BUILD @@ -1,12 +1,12 @@ # Placeholder: load py_proto_library +load("//xla:xla.bzl", "xla_cc_test", "xla_py_proto_library") +load("//third_party/compute_library:build_defs.bzl", "if_enable_acl") load("@tsl//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") -load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") load( "@tsl//tsl/platform:build_config.bzl", "tf_proto_library", ) -load("//third_party/compute_library:build_defs.bzl", "if_enable_acl") -load("//xla:xla.bzl", "xla_cc_test", "xla_py_proto_library") +load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -206,6 +206,7 @@ cc_library( deps = [ "@com_google_absl//absl/strings:str_format", "@eigen_archive//:eigen3", + "@ml_dtypes//:int4", ], ) diff --git a/xla/types.h b/xla/types.h index eef409ac2a52b..8a3be3d9819ae 100644 --- a/xla/types.h +++ b/xla/types.h @@ -17,14 +17,13 @@ limitations under the License. #define XLA_TYPES_H_ #include -#include +#include #include -#include -#include #include #include "absl/strings/str_format.h" #include "Eigen/Core" // from @eigen_archive +#include "include/int4.h" // from @ml_dtypes namespace xla { @@ -52,152 +51,28 @@ inline constexpr bool is_specialized_integral_v = std::numeric_limits::is_specialized && std::numeric_limits::is_integer; -// LINT.IfChange -template -struct i4 { - private: - UnderlyingTy v : 4; - - public: - constexpr i4() : v(0) {} - constexpr explicit i4(UnderlyingTy val) : v(val & 0x0F) {} - template - constexpr explicit i4(T t) : i4(static_cast(t)) {} - constexpr i4(const i4& other) = default; - - // NOLINTNEXTLINE(google-explicit-constructor) - constexpr operator UnderlyingTy() const { - return static_cast(v); - } - - template - i4 operator>>(const T amount) const { - return i4(v >> amount); - } - template - i4 operator<<(const T amount) const { - return i4(v << amount); - } - - constexpr bool operator==(const i4 other) const { return v == other.v; } - constexpr bool operator!=(const i4 other) const { return v != other.v; } - constexpr bool operator<(const i4 other) const { return v < other.v; } - constexpr bool operator>(const i4 other) const { return v > other.v; } - constexpr bool operator<=(const i4 other) const { return v <= other.v; } - constexpr bool operator>=(const i4 other) const { return v >= other.v; } - - constexpr i4 operator-() const { return i4(-v); } - constexpr i4 operator~() const { return i4(~v); } - constexpr i4 operator++(int) { - i4 tmp(*this); - v = (v + 1) & 0x0F; - return tmp; - } - constexpr i4& operator++() { - v = (v + 1) & 0x0F; - return *this; - } - constexpr i4& operator+=(const i4 other) { - v = (v + other.v) & 0x0F; - return *this; - } - - friend ::std::ostream& operator<<(::std::ostream& os, const i4& num) { - os << static_cast(num.v); - return os; - } - - friend ::std::istream& operator>>(::std::istream& is, i4& num) { - UnderlyingTy value; - is >> value; - num = i4(static_cast(value)); - return is; - } - - template - friend void AbslStringify(Sink& sink, const i4& i) { - absl::Format(&sink, "%d", i.v); - } - - std::string to_string() const { return std::to_string(v); } -}; - -using u4 = i4; -using s4 = i4; -// LINT.ThenChange(//tensorflow/compiler/xla/literal.cc) +using u4 = ml_dtypes::uint4; +using s4 = ml_dtypes::int4; } // namespace xla +// Extend ml_dtypes to allow absl::String functions. +namespace ml_dtypes { +template +void AbslStringify(Sink& sink, const xla::s4& i) { + absl::Format(&sink, "%d", static_cast(i)); +} + +template +void AbslStringify(Sink& sink, const xla::u4& i) { + absl::Format(&sink, "%d", static_cast(i)); +} +} // namespace ml_dtypes + // Alias namespace ::stream_executor as ::xla::se. namespace stream_executor {} namespace xla { namespace se = ::stream_executor; // NOLINT(misc-unused-alias-decls) -} // namespace xla - -namespace std { -// NOLINTBEGIN: these names must match std::numeric_limits. -template -class numeric_limits_int4t { - public: - static constexpr bool is_specialized = true; - static constexpr const bool is_integer = true; - static constexpr const bool is_exact = true; - - static constexpr bool has_infinity = false; - static constexpr bool has_quiet_NaN = false; - static constexpr bool has_signaling_NaN = false; - static constexpr float_denorm_style has_denorm = denorm_absent; - static constexpr bool has_denorm_loss = false; - static constexpr float_round_style round_style = round_toward_zero; - static constexpr bool is_iec559 = false; - static constexpr bool is_bounded = true; - static constexpr int max_digits10 = 0; - static constexpr int radix = 2; - static constexpr int min_exponent = 0; - static constexpr int min_exponent10 = 0; - static constexpr int max_exponent = 0; - static constexpr int max_exponent10 = 0; - static constexpr bool tinyness_before = false; - - static constexpr Int4T epsilon() { return Int4T(0); } - static constexpr Int4T round_error() { return Int4T(0); } - static constexpr Int4T infinity() { return Int4T(0); } - static constexpr Int4T quiet_NaN() { return Int4T(0); } - static constexpr Int4T signaling_NaN() { return Int4T(0); } - static constexpr Int4T denorm_min() { return Int4T(0); } -}; - -template <> -class numeric_limits : public numeric_limits_int4t { - public: - static constexpr const bool is_signed = false; - static constexpr int digits = 4; - static constexpr int digits10 = 2; - static constexpr bool is_modulo = true; - static constexpr bool traps = numeric_limits::traps; - - static constexpr xla::u4(min)() { return xla::u4(0); } - static constexpr xla::u4 lowest() { return xla::u4(0); } - static constexpr xla::u4(max)() { return xla::u4(15); } -}; - -template <> -class numeric_limits : public numeric_limits_int4t { - public: - static constexpr const bool is_signed = true; - static constexpr int digits = 3; - static constexpr int digits10 = 1; - static constexpr bool is_modulo = false; - static constexpr bool traps = numeric_limits::traps; - - static constexpr xla::s4(min)() { return xla::s4(-8); } - static constexpr xla::s4 lowest() { return xla::s4(-8); } - static constexpr xla::s4(max)() { return xla::s4(7); } -}; -// NOLINTEND -} // namespace std - -namespace xla { // std::make_signed_t is “behavior undefined” for custom types, so provide a // general util to make signed/unsigned for both primitive and custom types. @@ -206,9 +81,14 @@ struct make_specialized_unsigned { using type = std::make_unsigned_t; }; -template -struct make_specialized_unsigned> { - using type = xla::i4>; +template <> +struct make_specialized_unsigned { + using type = xla::u4; +}; + +template <> +struct make_specialized_unsigned { + using type = xla::u4; }; template @@ -219,9 +99,14 @@ struct make_specialized_signed { using type = std::make_signed_t; }; -template -struct make_specialized_signed> { - using type = xla::i4>; +template <> +struct make_specialized_signed { + using type = xla::s4; +}; + +template <> +struct make_specialized_signed { + using type = xla::s4; }; template