Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multiple floating point types in client library test base #17546

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions xla/literal_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,16 @@ void SetScalarAtIndexImpl(MutableLiteralBase& literal,
return ConvertType<float, tsl::float8_e5m2fnuz>(f32_literal);
}

/* static */ Literal LiteralUtil::ConvertF32ToF8E5M2(
const LiteralSlice& f32_literal) {
return ConvertType<float, tsl::float8_e5m2>(f32_literal);
}

/* static */ Literal LiteralUtil::ConvertF32ToF8E4M3FN(
const LiteralSlice& f32_literal) {
return ConvertType<float, tsl::float8_e4m3fn>(f32_literal);
}

/* static */ Literal LiteralUtil::ConvertF32ToBF16(
const LiteralSlice& f32_literal) {
return ConvertType<float, bfloat16>(f32_literal);
Expand Down
2 changes: 2 additions & 0 deletions xla/literal_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,8 @@ class LiteralUtil {
static Literal ConvertBF16ToF64(const LiteralSlice& bf16_literal);
static Literal ConvertF32ToF8E4M3FNUZ(const LiteralSlice& f32_literal);
static Literal ConvertF32ToF8E5M2FNUZ(const LiteralSlice& f32_literal);
static Literal ConvertF32ToF8E5M2(const LiteralSlice& f32_literal);
static Literal ConvertF32ToF8E4M3FN(const LiteralSlice& f32_literal);
static Literal ConvertF32ToBF16(const LiteralSlice& f32_literal);
static Literal ConvertF32ToS8(const LiteralSlice& f32_literal);
static Literal ConvertF32ToF64(const LiteralSlice& f32_literal);
Expand Down
1 change: 1 addition & 0 deletions xla/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ cc_library(
"//xla/client:local_client",
"//xla/client:xla_builder",
"//xla/client:xla_computation",
"//xla/hlo/builder:xla_builder",
"//xla/service:interpreter_plugin", # reference backend
"//xla/service:platform_util",
"//xla/stream_executor",
Expand Down
90 changes: 41 additions & 49 deletions xla/tests/client_library_test_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@ limitations under the License.
#include "xla/client/local_client.h"
#include "xla/client/xla_builder.h"
#include "xla/execution_options_util.h"
#include "xla/hlo/builder/xla_builder.h"
#include "xla/literal_util.h"
#include "xla/service/platform_util.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/test_helpers.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/logging.h"

namespace xla {
Expand Down Expand Up @@ -291,7 +293,7 @@ absl::StatusOr<Literal> ClientLibraryTestBase::ComputeAndTransfer(
for (const auto& argument : arguments_) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<GlobalData> owned_argument,
client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)));
client_->TransferToServer(MaybeConvertLiteralToTestType(argument)));
owning_arguments.push_back(std::move(owned_argument));
arguments.push_back(owning_arguments.back().get());
}
Expand All @@ -315,7 +317,7 @@ absl::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
for (const auto& argument : arguments_) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<GlobalData> owned_argument,
client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)));
client_->TransferToServer(MaybeConvertLiteralToTestType(argument)));
owning_arguments.push_back(std::move(owned_argument));
arguments.push_back(owning_arguments.back().get());
}
Expand All @@ -326,20 +328,20 @@ absl::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
ShapeUtil::ElementIsComplex(expected.shape())) {
LOG(WARNING) << "performing exact comparison of floating point numbers";
}
// We allow using a float expected literal for a bfloat16 output. In this
// case, we need to convert the expected literal to bfloat16.
// We allow using a float expected literal for non float outputs. In this
// case, we need to convert the expected literal to test_type_.
const Literal* expected_ptr = &expected;
Literal converted_expected;
Shape layout_shape;
if (use_bfloat16()) {
converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
if (test_type_ != F32) {
converted_expected = MaybeConvertLiteralToTestType(expected);
expected_ptr = &converted_expected;
if (shape_with_layout != nullptr) {
layout_shape = *shape_with_layout;
ShapeUtil::ForEachMutableSubshape(
&layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) {
if (subshape->element_type() == F32) {
subshape->set_element_type(BF16);
subshape->set_element_type(test_type_);
}
});
shape_with_layout = &layout_shape;
Expand Down Expand Up @@ -377,27 +379,27 @@ absl::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
for (const auto& argument : arguments_) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<GlobalData> owned_argument,
client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)));
client_->TransferToServer(MaybeConvertLiteralToTestType(argument)));
owning_arguments.push_back(std::move(owned_argument));
arguments.push_back(owning_arguments.back().get());
}
}

TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
// We allow using a float expected literal for a bfloat16 output. In this
// case, we need to convert the expected literal to bfloat16.
// We allow using a float expected literal for a non float outputs. In this
// case, we need to convert the expected literal to type_test_.
const Literal* expected_ptr = &expected;
Literal converted_expected;
Shape layout_shape;
if (use_bfloat16()) {
converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
if (test_type_ != F32) {
converted_expected = MaybeConvertLiteralToTestType(expected);
expected_ptr = &converted_expected;
if (shape_with_layout != nullptr) {
layout_shape = *shape_with_layout;
ShapeUtil::ForEachMutableSubshape(
&layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) {
if (subshape->element_type() == F32) {
subshape->set_element_type(BF16);
subshape->set_element_type(test_type_);
}
});
shape_with_layout = &layout_shape;
Expand Down Expand Up @@ -535,13 +537,11 @@ ClientLibraryTestBase::ComputeValueAndReference(
return std::make_pair(std::move(reference), std::move(result));
}

XlaComputation ClientLibraryTestBase::CreateScalarRelu() {
XlaComputation ClientLibraryTestBase::CreateScalarReluF32() {
XlaBuilder builder("relu");
auto shape = ShapeUtil::MakeShape(use_bfloat16() ? BF16 : F32, {});
auto shape = ShapeUtil::MakeShape(F32, {});
auto z_value = Parameter(&builder, 0, shape, "z_value");
auto zero = use_bfloat16()
? ConstantR0<bfloat16>(&builder, static_cast<bfloat16>(0.0f))
: ConstantR0<float>(&builder, 0.0f);
auto zero = ConstantR0<float>(&builder, 0.0f);
Max(z_value, zero);
auto computation_status = builder.Build();
TF_CHECK_OK(computation_status.status());
Expand All @@ -550,7 +550,7 @@ XlaComputation ClientLibraryTestBase::CreateScalarRelu() {

XlaComputation ClientLibraryTestBase::CreateScalarMax() {
XlaBuilder builder("max");
auto shape = ShapeUtil::MakeShape(use_bfloat16() ? BF16 : F32, {});
auto shape = ShapeUtil::MakeShape(test_type_, {});
auto x = Parameter(&builder, 0, shape, "x");
auto y = Parameter(&builder, 1, shape, "y");
Max(x, y);
Expand All @@ -559,22 +559,6 @@ XlaComputation ClientLibraryTestBase::CreateScalarMax() {
return std::move(computation_status).value();
}

XlaComputation ClientLibraryTestBase::CreateScalarReluSensitivity() {
XlaBuilder builder("relu_sensitivity");
auto shape = ShapeUtil::MakeShape(use_bfloat16() ? BF16 : F32, {});
auto activation = Parameter(&builder, 0, shape, "activation");
auto backprop = Parameter(&builder, 1, shape, "backprop");
auto zero = use_bfloat16()
? ConstantR0<bfloat16>(&builder, static_cast<bfloat16>(0.0f))
: ConstantR0<float>(&builder, 0.0f);
auto activation_gtz = Gt(activation, zero);
Select(activation_gtz, /*on_true=*/backprop, /*on_false=*/zero);

auto computation_status = builder.Build();
TF_CHECK_OK(computation_status.status());
return std::move(computation_status).value();
}

std::unique_ptr<Array2D<float>> ClientLibraryTestBase::CreatePatternedMatrix(
int rows, int cols, float offset) {
auto array = std::make_unique<Array2D<float>>(rows, cols);
Expand Down Expand Up @@ -605,7 +589,7 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument,
XlaBuilder* builder) {
arguments_.push_back(argument.Clone());
return Parameter(builder, /*parameter_number=*/arguments_.size() - 1,
MaybeConvertShapeToBfloat16(argument.shape()), "");
MaybeConvertShapeToTestType(argument.shape()), "");
}

XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal,
Expand All @@ -623,34 +607,42 @@ ClientLibraryTestBase::CreateParameterAndTransferLiteral(
nullptr, builder, data_handle);
}

Shape ClientLibraryTestBase::MaybeConvertShapeToBfloat16(const Shape& shape) {
if (!use_bfloat16()) {
Shape ClientLibraryTestBase::MaybeConvertShapeToTestType(const Shape& shape) {
if (test_type_ == F32) {
return shape;
}
Shape new_shape = shape;
ShapeUtil::ForEachMutableSubshape(&new_shape,
[](Shape* subshape, const ShapeIndex&) {
if (subshape->element_type() == F32) {
subshape->set_element_type(BF16);
}
});
ShapeUtil::ForEachMutableSubshape(
&new_shape, [test_type = test_type_](Shape* subshape, const ShapeIndex&) {
if (subshape->element_type() == F32) {
subshape->set_element_type(test_type);
}
});
return new_shape;
}

Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16(
Literal ClientLibraryTestBase::MaybeConvertLiteralToTestType(
const Literal& literal) {
if (use_bfloat16()) {
return LiteralUtil::ConvertF32ToBF16(literal);
switch (test_type_) {
case BF16:
return LiteralUtil::ConvertF32ToBF16(literal);
case F32:
return literal.Clone();
case F8E5M2:
return LiteralUtil::ConvertF32ToF8E5M2(literal);
case F8E4M3FN:
return LiteralUtil::ConvertF32ToF8E4M3FN(literal);
default:
LOG(FATAL) << "Unsupported test type: " << test_type_;
}
return literal.Clone();
}

absl::StatusOr<std::unique_ptr<GlobalData>>
ClientLibraryTestBase::CreateParameterAndTransferLiteral(
int64_t parameter_number, const Literal& literal, const std::string& name,
const DeviceHandle* device_handle, XlaBuilder* builder,
XlaOp* data_handle) {
Literal param_literal = MaybeConvertLiteralToBfloat16(literal);
Literal param_literal = MaybeConvertLiteralToTestType(literal);
TF_ASSIGN_OR_RETURN(auto data,
client_->TransferToServer(param_literal, device_handle));
*data_handle =
Expand Down
56 changes: 29 additions & 27 deletions xla/tests/client_library_test_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,20 @@ std::vector<TestCase> ExpandUseBfloat16(
return expanded;
}

template <typename TestCase>
std::vector<TestCase> ExpandTestType(
absl::Span<const PrimitiveType> test_type_params,
absl::Span<const TestCase> specs) {
std::vector<TestCase> expanded;
for (const PrimitiveType test_type : test_type_params) {
for (const auto& spec : specs) {
expanded.push_back(spec);
expanded.back().test_type = test_type;
}
}
return expanded;
}

// A client library test establishes an in-process XLA client connection.
class ClientLibraryTestBase : public ::testing::Test {
protected:
Expand Down Expand Up @@ -236,9 +250,8 @@ class ClientLibraryTestBase : public ::testing::Test {
absl::Span<GlobalData* const> arguments,
ErrorSpec error);
// Create scalar operations for use in reductions.
XlaComputation CreateScalarRelu();
XlaComputation CreateScalarReluF32();
XlaComputation CreateScalarMax();
XlaComputation CreateScalarReluSensitivity();

// Special case convenience functions for creating filled arrays.

Expand Down Expand Up @@ -277,7 +290,7 @@ class ClientLibraryTestBase : public ::testing::Test {
// Creates a parameter instruction, transfers the literal for the parameter to
// server, then stores into "data_handle" the global handle for that
// parameter. When the test_type is bfloat16 but the literal has F32 elements,
// the literal will be converted to BF16 before being transferred.
// the literal will be converted to test_type_ before being transferred.
absl::StatusOr<std::unique_ptr<GlobalData>> CreateParameterAndTransferLiteral(
int64_t parameter_number, const Literal& literal, const std::string& name,
XlaBuilder* builder, XlaOp* data_handle);
Expand All @@ -304,7 +317,7 @@ class ClientLibraryTestBase : public ::testing::Test {

// Creates a constant instruction with the given literal. When the test_type
// is bfloat16 but the literal has F32 elements, the literal will be converted
// to BF16 before being transferred.
// to test_type_ before being transferred.
XlaOp CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder);

// Creates a constant instruction with the given array. When the test_type is
Expand Down Expand Up @@ -417,8 +430,8 @@ class ClientLibraryTestBase : public ::testing::Test {
absl::StatusOr<std::pair<Literal, Literal>> ComputeValueAndReference(
XlaBuilder* builder, absl::Span<const Literal> arguments);

// Converts an f32 literal to bf16 if test_type is BF16.
Literal MaybeConvertLiteralToBfloat16(const Literal& literal);
// Converts a literal to the test_type if the literal's type is F32.
Literal MaybeConvertLiteralToTestType(const Literal& literal);

LocalClient* client_;
LocalClient* ref_client_; // To compute reference result.
Expand All @@ -439,10 +452,11 @@ class ClientLibraryTestBase : public ::testing::Test {
verify_output,
const Shape* output_with_layout = nullptr);

// Converts an f32 shape to bf16 if use_bfloat16_ is true.
Shape MaybeConvertShapeToBfloat16(const Shape& shape);
// Converts an f32 shape to test_type_.
Shape MaybeConvertShapeToTestType(const Shape& shape);

// Type to use when running tests.
// Type to use when running tests. By default, we use F32 for historical
// reasons and we rely on the underlying tests to change it.
PrimitiveType test_type_ = F32;

// Arguments to be passed to the computation when it runs.
Expand Down Expand Up @@ -584,9 +598,7 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
NativeT value, int64_t parameter_number, const std::string& name,
XlaBuilder* builder, XlaOp* data_handle) {
Literal literal = LiteralUtil::CreateR0(value);
if (use_bfloat16() && literal.shape().element_type() == F32) {
literal = LiteralUtil::ConvertF32ToBF16(literal);
}
literal = MaybeConvertLiteralToTestType(literal);
std::unique_ptr<GlobalData> data = client_->TransferToServer(literal).value();
*data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
Expand All @@ -597,9 +609,7 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter(
absl::Span<const NativeT> values, int64_t parameter_number,
const std::string& name, XlaBuilder* builder, XlaOp* data_handle) {
Literal literal = LiteralUtil::CreateR1(values);
if (use_bfloat16() && literal.shape().element_type() == F32) {
literal = LiteralUtil::ConvertF32ToBF16(literal);
}
literal = MaybeConvertLiteralToTestType(literal);
std::unique_ptr<GlobalData> data = client_->TransferToServer(literal).value();
*data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
Expand All @@ -610,9 +620,7 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR2Parameter(
const Array2D<NativeT>& array_2d, int64_t parameter_number,
const std::string& name, XlaBuilder* builder, XlaOp* data_handle) {
Literal literal = LiteralUtil::CreateR2FromArray2D(array_2d);
if (use_bfloat16() && literal.shape().element_type() == F32) {
literal = LiteralUtil::ConvertF32ToBF16(literal);
}
literal = MaybeConvertLiteralToTestType(literal);
std::unique_ptr<GlobalData> data = client_->TransferToServer(literal).value();
*data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
Expand All @@ -623,9 +631,7 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR3Parameter(
const Array3D<NativeT>& array_3d, int64_t parameter_number,
const std::string& name, XlaBuilder* builder, XlaOp* data_handle) {
Literal literal = LiteralUtil::CreateR3FromArray3D(array_3d);
if (use_bfloat16() && literal.shape().element_type() == F32) {
literal = LiteralUtil::ConvertF32ToBF16(literal);
}
literal = MaybeConvertLiteralToTestType(literal);
std::unique_ptr<GlobalData> data = client_->TransferToServer(literal).value();
*data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
Expand All @@ -636,9 +642,7 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR4Parameter(
const Array4D<NativeT>& array_4d, int64_t parameter_number,
const std::string& name, XlaBuilder* builder, XlaOp* data_handle) {
Literal literal = LiteralUtil::CreateR4FromArray4D(array_4d);
if (use_bfloat16() && literal.shape().element_type() == F32) {
literal = LiteralUtil::ConvertF32ToBF16(literal);
}
literal = MaybeConvertLiteralToTestType(literal);
std::unique_ptr<GlobalData> data = client_->TransferToServer(literal).value();
*data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
Expand All @@ -649,9 +653,7 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateParameter(
const Array<NativeT>& array, int64_t parameter_number,
const std::string& name, XlaBuilder* builder, XlaOp* data_handle) {
Literal literal = LiteralUtil::CreateFromArray(array);
if (use_bfloat16() && literal.shape().element_type() == F32) {
literal = LiteralUtil::ConvertF32ToBF16(literal);
}
literal = MaybeConvertLiteralToTestType(literal);
std::unique_ptr<GlobalData> data = client_->TransferToServer(literal).value();
*data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
Expand Down
2 changes: 1 addition & 1 deletion xla/tests/reduce_window_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1339,7 +1339,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase,
/*window_dilations=*/param.window_dilation,
/*padding=*/padding);

ComputeAndCompare(&b, {MaybeConvertLiteralToBfloat16(input_literal)},
ComputeAndCompare(&b, {MaybeConvertLiteralToTestType(input_literal)},
DefaultErrorSpec());
}
};
Expand Down
Loading