Skip to content

Commit

Permalink
Remove custom u4/s4 definition, use common ml_dtypes one.
Browse files Browse the repository at this point in the history
We want a consistent definition for all int4 usages.  Currently,
JAX/TF use the underlying type defined in `ml_dtypes`.

PiperOrigin-RevId: 563633360
  • Loading branch information
cantonios authored and copybara-github committed Sep 8, 2023
1 parent cfac4c3 commit 417aad1
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 151 deletions.
7 changes: 4 additions & 3 deletions xla/BUILD
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# Placeholder: load py_proto_library
load("//xla:xla.bzl", "xla_cc_test", "xla_py_proto_library")
load("//third_party/compute_library:build_defs.bzl", "if_enable_acl")
load("@tsl//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable")
load("@tsl//tsl/platform:rules_cc.bzl", "cc_library")
load(
"@tsl//tsl/platform:build_config.bzl",
"tf_proto_library",
)
load("//third_party/compute_library:build_defs.bzl", "if_enable_acl")
load("//xla:xla.bzl", "xla_cc_test", "xla_py_proto_library")
load("@tsl//tsl/platform:rules_cc.bzl", "cc_library")

package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
Expand Down Expand Up @@ -206,6 +206,7 @@ cc_library(
deps = [
"@com_google_absl//absl/strings:str_format",
"@eigen_archive//:eigen3",
"@ml_dtypes//:int4",
],
)

Expand Down
181 changes: 33 additions & 148 deletions xla/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@ limitations under the License.
#define XLA_TYPES_H_

#include <complex>
#include <istream>
#include <cstdint>
#include <limits>
#include <ostream>
#include <string>
#include <type_traits>

#include "absl/strings/str_format.h"
#include "Eigen/Core" // from @eigen_archive
#include "include/int4.h" // from @ml_dtypes

namespace xla {

Expand Down Expand Up @@ -52,152 +51,28 @@ inline constexpr bool is_specialized_integral_v =
std::numeric_limits<T>::is_specialized &&
std::numeric_limits<T>::is_integer;

// LINT.IfChange
template <typename UnderlyingTy>
struct i4 {
private:
UnderlyingTy v : 4;

public:
constexpr i4() : v(0) {}
constexpr explicit i4(UnderlyingTy val) : v(val & 0x0F) {}
template <typename T>
constexpr explicit i4(T t) : i4(static_cast<UnderlyingTy>(t)) {}
constexpr i4(const i4& other) = default;

// NOLINTNEXTLINE(google-explicit-constructor)
constexpr operator UnderlyingTy() const {
return static_cast<UnderlyingTy>(v);
}

template <typename T>
i4 operator>>(const T amount) const {
return i4(v >> amount);
}
template <typename T>
i4 operator<<(const T amount) const {
return i4(v << amount);
}

constexpr bool operator==(const i4 other) const { return v == other.v; }
constexpr bool operator!=(const i4 other) const { return v != other.v; }
constexpr bool operator<(const i4 other) const { return v < other.v; }
constexpr bool operator>(const i4 other) const { return v > other.v; }
constexpr bool operator<=(const i4 other) const { return v <= other.v; }
constexpr bool operator>=(const i4 other) const { return v >= other.v; }

constexpr i4 operator-() const { return i4(-v); }
constexpr i4 operator~() const { return i4(~v); }
constexpr i4 operator++(int) {
i4 tmp(*this);
v = (v + 1) & 0x0F;
return tmp;
}
constexpr i4& operator++() {
v = (v + 1) & 0x0F;
return *this;
}
constexpr i4& operator+=(const i4 other) {
v = (v + other.v) & 0x0F;
return *this;
}

friend ::std::ostream& operator<<(::std::ostream& os, const i4& num) {
os << static_cast<int16_t>(num.v);
return os;
}

friend ::std::istream& operator>>(::std::istream& is, i4& num) {
UnderlyingTy value;
is >> value;
num = i4(static_cast<UnderlyingTy>(value));
return is;
}

template <typename Sink>
friend void AbslStringify(Sink& sink, const i4& i) {
absl::Format(&sink, "%d", i.v);
}

std::string to_string() const { return std::to_string(v); }
};

using u4 = i4<uint8_t>;
using s4 = i4<int8_t>;
// LINT.ThenChange(//tensorflow/compiler/xla/literal.cc)
using u4 = ml_dtypes::uint4;
using s4 = ml_dtypes::int4;

} // namespace xla

// Extend ml_dtypes to allow absl::String functions.
namespace ml_dtypes {
template <typename Sink>
void AbslStringify(Sink& sink, const xla::s4& i) {
absl::Format(&sink, "%d", static_cast<int32_t>(i));
}

template <typename Sink>
void AbslStringify(Sink& sink, const xla::u4& i) {
absl::Format(&sink, "%d", static_cast<uint32_t>(i));
}
} // namespace ml_dtypes

// Alias namespace ::stream_executor as ::xla::se.
namespace stream_executor {}
namespace xla {
namespace se = ::stream_executor; // NOLINT(misc-unused-alias-decls)
} // namespace xla

namespace std {
// NOLINTBEGIN: these names must match std::numeric_limits.
template <typename Int4T>
class numeric_limits_int4t {
public:
static constexpr bool is_specialized = true;
static constexpr const bool is_integer = true;
static constexpr const bool is_exact = true;

static constexpr bool has_infinity = false;
static constexpr bool has_quiet_NaN = false;
static constexpr bool has_signaling_NaN = false;
static constexpr float_denorm_style has_denorm = denorm_absent;
static constexpr bool has_denorm_loss = false;
static constexpr float_round_style round_style = round_toward_zero;
static constexpr bool is_iec559 = false;
static constexpr bool is_bounded = true;
static constexpr int max_digits10 = 0;
static constexpr int radix = 2;
static constexpr int min_exponent = 0;
static constexpr int min_exponent10 = 0;
static constexpr int max_exponent = 0;
static constexpr int max_exponent10 = 0;
static constexpr bool tinyness_before = false;

static constexpr Int4T epsilon() { return Int4T(0); }
static constexpr Int4T round_error() { return Int4T(0); }
static constexpr Int4T infinity() { return Int4T(0); }
static constexpr Int4T quiet_NaN() { return Int4T(0); }
static constexpr Int4T signaling_NaN() { return Int4T(0); }
static constexpr Int4T denorm_min() { return Int4T(0); }
};

template <>
class numeric_limits<xla::u4> : public numeric_limits_int4t<xla::u4> {
public:
static constexpr const bool is_signed = false;
static constexpr int digits = 4;
static constexpr int digits10 = 2;
static constexpr bool is_modulo = true;
static constexpr bool traps = numeric_limits<uint8_t>::traps;

static constexpr xla::u4(min)() { return xla::u4(0); }
static constexpr xla::u4 lowest() { return xla::u4(0); }
static constexpr xla::u4(max)() { return xla::u4(15); }
};

template <>
class numeric_limits<xla::s4> : public numeric_limits_int4t<xla::s4> {
public:
static constexpr const bool is_signed = true;
static constexpr int digits = 3;
static constexpr int digits10 = 1;
static constexpr bool is_modulo = false;
static constexpr bool traps = numeric_limits<int8_t>::traps;

static constexpr xla::s4(min)() { return xla::s4(-8); }
static constexpr xla::s4 lowest() { return xla::s4(-8); }
static constexpr xla::s4(max)() { return xla::s4(7); }
};
// NOLINTEND
} // namespace std

namespace xla {

// std::make_signed_t is “behavior undefined” for custom types, so provide a
// general util to make signed/unsigned for both primitive and custom types.
Expand All @@ -206,9 +81,14 @@ struct make_specialized_unsigned {
using type = std::make_unsigned_t<T>;
};

template <typename UnderlyingTy>
struct make_specialized_unsigned<xla::i4<UnderlyingTy>> {
using type = xla::i4<std::make_unsigned_t<UnderlyingTy>>;
template <>
struct make_specialized_unsigned<xla::s4> {
using type = xla::u4;
};

template <>
struct make_specialized_unsigned<xla::u4> {
using type = xla::u4;
};

template <typename T>
Expand All @@ -219,9 +99,14 @@ struct make_specialized_signed {
using type = std::make_signed_t<T>;
};

template <typename UnderlyingTy>
struct make_specialized_signed<xla::i4<UnderlyingTy>> {
using type = xla::i4<std::make_signed_t<UnderlyingTy>>;
template <>
struct make_specialized_signed<xla::s4> {
using type = xla::s4;
};

template <>
struct make_specialized_signed<xla::u4> {
using type = xla::s4;
};

template <typename T>
Expand Down

0 comments on commit 417aad1

Please sign in to comment.