Skip to content

Commit

Permalink
In literal_util.cc, use absl::uniform_int_distribution.
Browse files Browse the repository at this point in the history
absl::uniform_int_distribution is faster than std::uniform_int_distribution. This makes initializing literals in run_hlo_module faster. In particular, I tested the following HLO:

    ENTRY f {
      arg = s8[2000000000] parameter(0)
      ROOT add_result = s8[2000000000] add(arg, arg)
    }

It takes 7.8 seconds to initialize the input literal with the absl function, and 18.2 with the std function.

Unfortunately the absl version of uniform_real_distribution is not faster. It takes 25.5 seconds with absl and 8.3 with std on the HLO when s8 is replaced with f16.

PiperOrigin-RevId: 723316656
  • Loading branch information
reedwm authored and Google-ML-Automation committed Feb 5, 2025
1 parent 6327932 commit b431812
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
1 change: 1 addition & 0 deletions xla/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,7 @@ cc_library(
":xla_data_proto_cc",
"//xla/tsl/lib/core:bitmap",
"@com_google_absl//absl/functional:function_ref",
"@com_google_absl//absl/random:distributions",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
Expand Down
7 changes: 4 additions & 3 deletions xla/literal_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/random/uniform_int_distribution.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
Expand Down Expand Up @@ -323,7 +324,7 @@ void PopulateWithFloatingPointData(
"max_bits_of_precision for floating points.";
CHECK(!no_duplicates) << "Cannot set both no_duplicates and "
"max_bits_of_precision for floating points.";
std::uniform_int_distribution<int64_t> generator(
absl::uniform_int_distribution<int64_t> generator(
-(1 << *max_bits_of_precision), 1 << *max_bits_of_precision);
for (FloatT& value : literal->data<FloatT>()) {
int64_t temp = generator(*engine);
Expand Down Expand Up @@ -391,7 +392,7 @@ void PopulateWithRandomIntegralDataWithBounds(Literal* literal,
std::shuffle(literal->data<IntT>().begin(), literal->data<IntT>().end(),
*engine);
} else {
std::uniform_int_distribution<RngT<IntT>> generator(
absl::uniform_int_distribution<RngT<IntT>> generator(
static_cast<RngT<IntT>>(min), static_cast<RngT<IntT>>(max));
for (IntT& value : literal->data<IntT>()) {
value = static_cast<IntT>(generator(*engine));
Expand Down Expand Up @@ -732,7 +733,7 @@ absl::StatusOr<Literal> MakeFakeLiteral(
return absl::OkStatus();
}
if constexpr (primitive_type_constant == PRED) {
std::uniform_int_distribution<int> generator(0, 1);
absl::uniform_int_distribution<int> generator(0, 1);
TF_CHECK_OK(literal.Populate<bool>(
[&](absl::Span<const int64_t> /*indices*/) {
return generator(*engine);
Expand Down

0 comments on commit b431812

Please sign in to comment.