Skip to content

Commit

Permalink
Fix map visibility for non-public functions.
Browse files Browse the repository at this point in the history
Also update std::is_unsigned_msb_set to public, since it flagged a violation at https://github.com/google/xls/blob/main/xls/examples/protobuf/varint_streaming_decode.x#L98.

Fixes #1490.

PiperOrigin-RevId: 649100919
  • Loading branch information
mikex-oss authored and copybara-github committed Jul 3, 2024
1 parent 52703c5 commit 03b21bb
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 3 deletions.
2 changes: 1 addition & 1 deletion xls/dslx/stdlib/std.x
Original file line number Diff line number Diff line change
Expand Up @@ -978,7 +978,7 @@ fn test_umul_with_overflow() {
}(());
}

fn is_unsigned_msb_set<N: u32>(x: uN[N]) -> bool { x[-1:] }
pub fn is_unsigned_msb_set<N: u32>(x: uN[N]) -> bool { x[-1:] }

#[test]
fn is_unsigned_msb_set_test() {
Expand Down
18 changes: 16 additions & 2 deletions xls/dslx/type_system/deduce_invocation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ extern absl::StatusOr<std::unique_ptr<Type>> DeduceAndResolve(
// created).
//
// Builtins don't have `Function` nodes (since they're not userspace functions),
// so that inference can't occur, so we essentually perform that synthesis and
// so that inference can't occur, so we essentially perform that synthesis and
// deduction here.
//
// Args:
Expand Down Expand Up @@ -109,9 +109,23 @@ static absl::StatusOr<std::unique_ptr<Type>> DeduceMapInvocation(
XLS_ASSIGN_OR_RETURN(std::unique_ptr<Type> arg0_type,
DeduceAndResolve(args[0], ctx));

Expr* callee = args[1];
// If the callee is an imported function, we need to check that it is public.
if (auto* colon_ref = dynamic_cast<ColonRef*>(callee); colon_ref != nullptr) {
XLS_ASSIGN_OR_RETURN(Function * callee_fn,
ResolveFunction(colon_ref, ctx->type_info()));
if (!callee_fn->is_public()) {
return TypeInferenceErrorStatus(
node->span(), nullptr,
absl::StrFormat("Attempted to refer to module member %s that "
"is not public.",
callee->ToString()));
}
}

// Then get the type and bindings for the mapping fn.
Invocation* element_invocation =
CreateElementInvocation(ctx->module(), node->span(), /*callee=*/args[1],
CreateElementInvocation(ctx->module(), node->span(), /*callee=*/callee,
/*arg_array=*/args[0], /*parent=*/node->parent());
XLS_ASSIGN_OR_RETURN(TypeAndParametricEnv tab,
ctx->typecheck_invocation()(ctx, element_invocation,
Expand Down
38 changes: 38 additions & 0 deletions xls/dslx/type_system/typecheck_module_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,44 @@ fn f() -> u32[3] {
XLS_EXPECT_OK(Typecheck(program));
}

TEST(TypecheckTest, MapImportedNonPublicFunction) {
constexpr std::string_view kImported = R"(
fn some_function(x: u32) -> u32 { x }
)";
constexpr std::string_view kProgram = R"(
import imported;
fn main() -> u32[3] {
map(u32[3]:[1, 2, 3], imported::some_function)
})";
auto import_data = CreateImportDataForTest();
XLS_ASSERT_OK_AND_ASSIGN(
TypecheckedModule module,
ParseAndTypecheck(kImported, "imported.x", "imported", &import_data));
EXPECT_THAT(
ParseAndTypecheck(kProgram, "fake_main_path.x", "main", &import_data),
StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("not public")));
}

TEST(TypecheckTest, MapImportedNonPublicInferredParametricFunction) {
constexpr std::string_view kImported = R"(
fn some_function<N: u32>(x: bits[N]) -> bits[N] { x }
)";
constexpr std::string_view kProgram = R"(
import imported;
fn main() -> u32[3] {
map(u32[3]:[1, 2, 3], imported::some_function)
})";
auto import_data = CreateImportDataForTest();
XLS_ASSERT_OK_AND_ASSIGN(
TypecheckedModule module,
ParseAndTypecheck(kImported, "imported.x", "imported", &import_data));
EXPECT_THAT(
ParseAndTypecheck(kProgram, "fake_main_path.x", "main", &import_data),
StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("not public")));
}

TEST(TypecheckErrorTest, ParametricInvocationConflictingArgs) {
std::string program = R"(
fn id<N: u32>(x: bits[N], y: bits[N]) -> bits[N] { x }
Expand Down

0 comments on commit 03b21bb

Please sign in to comment.