Skip to content

Commit

Permalink
Review comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
zygoloid committed Aug 2, 2024
1 parent c2bb90e commit a2e7803
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 65 deletions.
126 changes: 75 additions & 51 deletions toolchain/check/deduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,42 +13,64 @@
namespace Carbon::Check {

namespace {
struct DeductionWorklist {
// A list of pairs of (instruction from generic, corresponding instruction from
// call to of generic) for which we still need to perform deduction, along with
// methods to add and pop pending deductions from the list. Deductions are
// popped in order from most- to least-recently pushed, with the intent that
// they are visited in depth-first order, although the order is not expected to
// matter except when it influences which error is diagnosed.
class DeductionWorklist {
public:
explicit DeductionWorklist(Context& context) : context_(context) {}

struct PendingDeduction {
SemIR::InstId param;
SemIR::InstId arg;
bool needs_substitution;
};

// Adds a single (param, arg) deduction.
auto Add(SemIR::InstId param, SemIR::InstId arg, bool needs_substitution)
-> void {
deductions.push_back(
deductions_.push_back(
{.param = param, .arg = arg, .needs_substitution = needs_substitution});
}

auto AddBlock(llvm::ArrayRef<SemIR::InstId> params,
llvm::ArrayRef<SemIR::InstId> args, bool needs_substitution)
// Adds a list of (param, arg) deductions. These are added in reverse order so
// they are popped in forward order.
auto AddAll(llvm::ArrayRef<SemIR::InstId> params,
llvm::ArrayRef<SemIR::InstId> args, bool needs_substitution)
-> void {
if (params.size() != args.size()) {
// TODO: Decide whether to error on this or just treat the parameter list
// as non-deduced. For now we treat it as non-deduced.
return;
}
for (auto [param, arg] : llvm::zip_equal(params, args)) {
for (auto [param, arg] : llvm::reverse(llvm::zip_equal(params, args))) {
Add(param, arg, needs_substitution);
}
}

auto AddBlock(SemIR::InstBlockId params, llvm::ArrayRef<SemIR::InstId> args,
bool needs_substitution) -> void {
AddBlock(context.inst_blocks().Get(params), args, needs_substitution);
auto AddAll(SemIR::InstBlockId params, llvm::ArrayRef<SemIR::InstId> args,
bool needs_substitution) -> void {
AddAll(context_.inst_blocks().Get(params), args, needs_substitution);
}

auto AddBlock(SemIR::InstBlockId params, SemIR::InstBlockId args,
bool needs_substitution) -> void {
AddBlock(context.inst_blocks().Get(params), context.inst_blocks().Get(args),
needs_substitution);
auto AddAll(SemIR::InstBlockId params, SemIR::InstBlockId args,
bool needs_substitution) -> void {
AddAll(context_.inst_blocks().Get(params), context_.inst_blocks().Get(args),
needs_substitution);
}

struct PendingDeduction {
SemIR::InstId param;
SemIR::InstId arg;
bool needs_substitution;
};
Context& context;
llvm::SmallVector<PendingDeduction> deductions;
// Returns whether we have completed all deductions.
auto Done() -> bool { return deductions_.empty(); }

// Pops the next deduction. Requires `!Done()`.
auto PopNext() -> PendingDeduction { return deductions_.pop_back_val(); }

private:
Context& context_;
llvm::SmallVector<PendingDeduction> deductions_;
};
} // namespace

Expand All @@ -59,28 +81,23 @@ static auto NoteGenericHere(Context& context, SemIR::GenericId generic_id,
diag.Note(context.generics().Get(generic_id).decl_id, DeductionGenericHere);
}

auto DeduceGenericCallArguments(Context& context, Parse::NodeId node_id,
SemIR::GenericId generic_id,
SemIR::SpecificId enclosing_specific_id,
SemIR::InstBlockId implicit_params_id,
SemIR::InstBlockId params_id,
SemIR::InstId self_id,
llvm::ArrayRef<SemIR::InstId> arg_ids)
-> SemIR::SpecificId {
DeductionWorklist worklist = {.context = context};
auto DeduceGenericCallArguments(
Context& context, Parse::NodeId node_id, SemIR::GenericId generic_id,
SemIR::SpecificId enclosing_specific_id,
[[maybe_unused]] SemIR::InstBlockId implicit_params_id,
SemIR::InstBlockId params_id, [[maybe_unused]] SemIR::InstId self_id,
llvm::ArrayRef<SemIR::InstId> arg_ids) -> SemIR::SpecificId {
DeductionWorklist worklist(context);

// TODO: Perform deduction for type of self
static_cast<void>(implicit_params_id);
static_cast<void>(self_id);
llvm::SmallVector<SemIR::InstId> result_arg_ids;
llvm::SmallVector<Substitution> substitutions;

// Copy any outer generic arguments from the specified instance and prepare to
// substitute them into the function declaration.
llvm::SmallVector<SemIR::InstId> results;
llvm::SmallVector<Substitution> substitutions;
if (enclosing_specific_id.is_valid()) {
auto args = context.inst_blocks().Get(
context.specifics().Get(enclosing_specific_id).args_id);
results.assign(args.begin(), args.end());
result_arg_ids.assign(args.begin(), args.end());

// TODO: Subst is linear in the length of the substitutions list. Change it
// so we can pass in an array mapping indexes to substitutions instead.
Expand All @@ -91,18 +108,21 @@ auto DeduceGenericCallArguments(Context& context, Parse::NodeId node_id,
.replacement_id = context.constant_values().Get(subst_inst_id)});
}
}
auto first_deduced_index = SemIR::CompileTimeBindIndex(results.size());
auto first_deduced_index = SemIR::CompileTimeBindIndex(result_arg_ids.size());

worklist.AddBlock(params_id, arg_ids, /*needs_substitution=*/true);
// Initialize the deduced arguments to Invalid.
result_arg_ids.resize(context.inst_blocks()
.Get(context.generics().Get(generic_id).bindings_id)
.size(),
SemIR::InstId::Invalid);

results.resize(context.inst_blocks()
.Get(context.generics().Get(generic_id).bindings_id)
.size(),
SemIR::InstId::Invalid);
// Prepare to perform deduction of the explicit parameters against their
// arguments.
// TODO: Also perform deduction for type of self.
worklist.AddAll(params_id, arg_ids, /*needs_substitution=*/true);

while (!worklist.deductions.empty()) {
auto [param_id, arg_id, needs_substitution] =
worklist.deductions.pop_back_val();
while (!worklist.Done()) {
auto [param_id, arg_id, needs_substitution] = worklist.PopNext();

// If the parameter has a symbolic type, deduce against that.
auto param_type_id = context.insts().Get(param_id).type_id();
Expand Down Expand Up @@ -136,14 +156,15 @@ auto DeduceGenericCallArguments(Context& context, Parse::NodeId node_id,
auto& entity_name = context.entity_names().Get(bind.entity_name_id);
auto index = entity_name.bind_index;
if (index.is_valid() && index >= first_deduced_index) {
CARBON_CHECK(static_cast<size_t>(index.index) < results.size())
CARBON_CHECK(static_cast<size_t>(index.index) < result_arg_ids.size())
<< "Deduced value for unexpected index " << index
<< "; expected to deduce " << results.size() << " arguments.";
<< "; expected to deduce " << result_arg_ids.size()
<< " arguments.";
auto arg_const_inst_id =
context.constant_values().GetConstantInstId(arg_id);
if (arg_const_inst_id.is_valid()) {
if (results[index.index].is_valid() &&
results[index.index] != arg_const_inst_id) {
if (result_arg_ids[index.index].is_valid() &&
result_arg_ids[index.index] != arg_const_inst_id) {
// TODO: Include the two different deduced values.
CARBON_DIAGNOSTIC(DeductionInconsistent, Error,
"Inconsistent deductions for value of generic "
Expand All @@ -155,7 +176,7 @@ auto DeduceGenericCallArguments(Context& context, Parse::NodeId node_id,
diag.Emit();
return SemIR::SpecificId::Invalid;
}
results[index.index] = arg_const_inst_id;
result_arg_ids[index.index] = arg_const_inst_id;
}
}
break;
Expand All @@ -169,10 +190,13 @@ auto DeduceGenericCallArguments(Context& context, Parse::NodeId node_id,
}

// Check we deduced an argument value for every parameter.
for (auto [i, deduced_arg_id] : llvm::enumerate(results)) {
for (auto [i, deduced_arg_id] :
llvm::enumerate(llvm::ArrayRef(result_arg_ids)
.drop_front(first_deduced_index.index))) {
if (!deduced_arg_id.is_valid()) {
auto binding_index = first_deduced_index.index + i;
auto binding_id = context.inst_blocks().Get(
context.generics().Get(generic_id).bindings_id)[i];
context.generics().Get(generic_id).bindings_id)[binding_index];
auto entity_name_id =
context.insts().GetAs<SemIR::AnyBindName>(binding_id).entity_name_id;
CARBON_DIAGNOSTIC(DeductionIncomplete, Error,
Expand All @@ -190,7 +214,7 @@ auto DeduceGenericCallArguments(Context& context, Parse::NodeId node_id,
// TODO: Convert the deduced values to the types of the bindings.

return MakeSpecific(context, generic_id,
context.inst_blocks().AddCanonical(results));
context.inst_blocks().AddCanonical(result_arg_ids));
}

} // namespace Carbon::Check
75 changes: 62 additions & 13 deletions toolchain/check/testdata/function/generic/deduce.carbon
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,18 @@

// --- deduce_explicit.carbon

library "fail_todo_deduce_explicit";
library "deduce_explicit";

fn ExplicitGenericParam(T:! type) -> T*;

fn CallExplicitGenericParam() -> i32* {
return ExplicitGenericParam(i32);
}

fn CallExplicitGenericParamWithGenericArg(T:! type) -> {.a: T}* {
return ExplicitGenericParam({.a: T});
}

// --- fail_todo_explicit_vs_deduced.carbon

library "fail_todo_explicit_vs_deduced";
Expand Down Expand Up @@ -86,6 +90,8 @@ fn CallStructParam() {

library "fail_deduce_incomplete";

// TODO: It would be nice to diagnose this at its point of declaration, because
// U is not deducible.
fn ImplicitNotDeducible[T:! type, U:! type](x: T) -> U;

fn CallImplicitNotDeducible() {
Expand Down Expand Up @@ -128,6 +134,10 @@ fn CallImplicitNotDeducible() {
// CHECK:STDOUT: %.3: type = ptr_type i32 [template]
// CHECK:STDOUT: %CallExplicitGenericParam.type: type = fn_type @CallExplicitGenericParam [template]
// CHECK:STDOUT: %CallExplicitGenericParam: %CallExplicitGenericParam.type = struct_value () [template]
// CHECK:STDOUT: %.4: type = struct_type {.a: %T} [symbolic]
// CHECK:STDOUT: %.5: type = ptr_type %.4 [symbolic]
// CHECK:STDOUT: %CallExplicitGenericParamWithGenericArg.type: type = fn_type @CallExplicitGenericParamWithGenericArg [template]
// CHECK:STDOUT: %CallExplicitGenericParamWithGenericArg: %CallExplicitGenericParamWithGenericArg.type = struct_value () [template]
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: imports {
Expand All @@ -149,12 +159,13 @@ fn CallImplicitNotDeducible() {
// CHECK:STDOUT: .Core = imports.%Core
// CHECK:STDOUT: .ExplicitGenericParam = %ExplicitGenericParam.decl
// CHECK:STDOUT: .CallExplicitGenericParam = %CallExplicitGenericParam.decl
// CHECK:STDOUT: .CallExplicitGenericParamWithGenericArg = %CallExplicitGenericParamWithGenericArg.decl
// CHECK:STDOUT: }
// CHECK:STDOUT: %Core.import = import Core
// CHECK:STDOUT: %ExplicitGenericParam.decl: %ExplicitGenericParam.type = fn_decl @ExplicitGenericParam [template = constants.%ExplicitGenericParam] {
// CHECK:STDOUT: %T.loc4_25.1: type = param T
// CHECK:STDOUT: @ExplicitGenericParam.%T.loc4: type = bind_symbolic_name T 0, %T.loc4_25.1 [symbolic = @ExplicitGenericParam.%T.1 (constants.%T)]
// CHECK:STDOUT: %T.ref: type = name_ref T, @ExplicitGenericParam.%T.loc4 [symbolic = @ExplicitGenericParam.%T.1 (constants.%T)]
// CHECK:STDOUT: %T.ref.loc4: type = name_ref T, @ExplicitGenericParam.%T.loc4 [symbolic = @ExplicitGenericParam.%T.1 (constants.%T)]
// CHECK:STDOUT: %.loc4: type = ptr_type %T [symbolic = @ExplicitGenericParam.%.1 (constants.%.1)]
// CHECK:STDOUT: @ExplicitGenericParam.%return: ref @ExplicitGenericParam.%.1 (%.1) = var <return slot>
// CHECK:STDOUT: }
Expand All @@ -165,6 +176,14 @@ fn CallImplicitNotDeducible() {
// CHECK:STDOUT: %.loc6_37.3: type = ptr_type i32 [template = constants.%.3]
// CHECK:STDOUT: @CallExplicitGenericParam.%return: ref %.3 = var <return slot>
// CHECK:STDOUT: }
// CHECK:STDOUT: %CallExplicitGenericParamWithGenericArg.decl: %CallExplicitGenericParamWithGenericArg.type = fn_decl @CallExplicitGenericParamWithGenericArg [template = constants.%CallExplicitGenericParamWithGenericArg] {
// CHECK:STDOUT: %T.loc10_43.1: type = param T
// CHECK:STDOUT: @CallExplicitGenericParamWithGenericArg.%T.loc10: type = bind_symbolic_name T 0, %T.loc10_43.1 [symbolic = @CallExplicitGenericParamWithGenericArg.%T.1 (constants.%T)]
// CHECK:STDOUT: %T.ref.loc10: type = name_ref T, @CallExplicitGenericParamWithGenericArg.%T.loc10 [symbolic = @CallExplicitGenericParamWithGenericArg.%T.1 (constants.%T)]
// CHECK:STDOUT: %.loc10_62: type = struct_type {.a: %T} [symbolic = @CallExplicitGenericParamWithGenericArg.%.1 (constants.%.4)]
// CHECK:STDOUT: %.loc10_63: type = ptr_type %.4 [symbolic = @CallExplicitGenericParamWithGenericArg.%.2 (constants.%.5)]
// CHECK:STDOUT: @CallExplicitGenericParamWithGenericArg.%return: ref @CallExplicitGenericParamWithGenericArg.%.2 (%.5) = var <return slot>
// CHECK:STDOUT: }
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: generic fn @ExplicitGenericParam(%T.loc4: type) {
Expand All @@ -188,6 +207,25 @@ fn CallImplicitNotDeducible() {
// CHECK:STDOUT: return %.loc7_35.2
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: generic fn @CallExplicitGenericParamWithGenericArg(%T.loc10: type) {
// CHECK:STDOUT: %T.1: type = bind_symbolic_name T 0 [symbolic = %T.1 (constants.%T)]
// CHECK:STDOUT: %.1: type = struct_type {.a: @CallExplicitGenericParamWithGenericArg.%T.1 (%T)} [symbolic = %.1 (constants.%.4)]
// CHECK:STDOUT: %.2: type = ptr_type @CallExplicitGenericParamWithGenericArg.%.1 (%.4) [symbolic = %.2 (constants.%.5)]
// CHECK:STDOUT:
// CHECK:STDOUT: !definition:
// CHECK:STDOUT:
// CHECK:STDOUT: fn(%T.loc10: type) -> @CallExplicitGenericParamWithGenericArg.%.2 (%.5) {
// CHECK:STDOUT: !entry:
// CHECK:STDOUT: %ExplicitGenericParam.ref: %ExplicitGenericParam.type = name_ref ExplicitGenericParam, file.%ExplicitGenericParam.decl [template = constants.%ExplicitGenericParam]
// CHECK:STDOUT: %T.ref: type = name_ref T, %T.loc10 [symbolic = %T.1 (constants.%T)]
// CHECK:STDOUT: %.loc11_37: type = struct_type {.a: %T} [symbolic = %.1 (constants.%.4)]
// CHECK:STDOUT: %ExplicitGenericParam.call: init @CallExplicitGenericParamWithGenericArg.%.2 (%.5) = call %ExplicitGenericParam.ref(%.loc11_37)
// CHECK:STDOUT: %.loc11_39.1: @CallExplicitGenericParamWithGenericArg.%.2 (%.5) = value_of_initializer %ExplicitGenericParam.call
// CHECK:STDOUT: %.loc11_39.2: @CallExplicitGenericParamWithGenericArg.%.2 (%.5) = converted %ExplicitGenericParam.call, %.loc11_39.1
// CHECK:STDOUT: return %.loc11_39.2
// CHECK:STDOUT: }
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: specific @ExplicitGenericParam(constants.%T) {
// CHECK:STDOUT: %T.1 => constants.%T
// CHECK:STDOUT: %.1 => constants.%.1
Expand All @@ -198,6 +236,17 @@ fn CallImplicitNotDeducible() {
// CHECK:STDOUT: %.1 => constants.%.3
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: specific @CallExplicitGenericParamWithGenericArg(constants.%T) {
// CHECK:STDOUT: %T.1 => constants.%T
// CHECK:STDOUT: %.1 => constants.%.4
// CHECK:STDOUT: %.2 => constants.%.5
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: specific @ExplicitGenericParam(constants.%.4) {
// CHECK:STDOUT: %T.1 => constants.%.4
// CHECK:STDOUT: %.1 => constants.%.5
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: --- fail_todo_explicit_vs_deduced.carbon
// CHECK:STDOUT:
// CHECK:STDOUT: constants {
Expand Down Expand Up @@ -529,30 +578,30 @@ fn CallImplicitNotDeducible() {
// CHECK:STDOUT: }
// CHECK:STDOUT: %Core.import = import Core
// CHECK:STDOUT: %ImplicitNotDeducible.decl: %ImplicitNotDeducible.type = fn_decl @ImplicitNotDeducible [template = constants.%ImplicitNotDeducible] {
// CHECK:STDOUT: %T.loc4_25.1: type = param T
// CHECK:STDOUT: @ImplicitNotDeducible.%T.loc4: type = bind_symbolic_name T 0, %T.loc4_25.1 [symbolic = @ImplicitNotDeducible.%T.1 (constants.%T)]
// CHECK:STDOUT: %U.loc4_35.1: type = param U
// CHECK:STDOUT: @ImplicitNotDeducible.%U.loc4: type = bind_symbolic_name U 1, %U.loc4_35.1 [symbolic = @ImplicitNotDeducible.%U.1 (constants.%U)]
// CHECK:STDOUT: %T.ref: type = name_ref T, @ImplicitNotDeducible.%T.loc4 [symbolic = @ImplicitNotDeducible.%T.1 (constants.%T)]
// CHECK:STDOUT: %x.loc4_45.1: @ImplicitNotDeducible.%T.1 (%T) = param x
// CHECK:STDOUT: @ImplicitNotDeducible.%x: @ImplicitNotDeducible.%T.1 (%T) = bind_name x, %x.loc4_45.1
// CHECK:STDOUT: %U.ref: type = name_ref U, @ImplicitNotDeducible.%U.loc4 [symbolic = @ImplicitNotDeducible.%U.1 (constants.%U)]
// CHECK:STDOUT: %T.loc6_25.1: type = param T
// CHECK:STDOUT: @ImplicitNotDeducible.%T.loc6: type = bind_symbolic_name T 0, %T.loc6_25.1 [symbolic = @ImplicitNotDeducible.%T.1 (constants.%T)]
// CHECK:STDOUT: %U.loc6_35.1: type = param U
// CHECK:STDOUT: @ImplicitNotDeducible.%U.loc6: type = bind_symbolic_name U 1, %U.loc6_35.1 [symbolic = @ImplicitNotDeducible.%U.1 (constants.%U)]
// CHECK:STDOUT: %T.ref: type = name_ref T, @ImplicitNotDeducible.%T.loc6 [symbolic = @ImplicitNotDeducible.%T.1 (constants.%T)]
// CHECK:STDOUT: %x.loc6_45.1: @ImplicitNotDeducible.%T.1 (%T) = param x
// CHECK:STDOUT: @ImplicitNotDeducible.%x: @ImplicitNotDeducible.%T.1 (%T) = bind_name x, %x.loc6_45.1
// CHECK:STDOUT: %U.ref: type = name_ref U, @ImplicitNotDeducible.%U.loc6 [symbolic = @ImplicitNotDeducible.%U.1 (constants.%U)]
// CHECK:STDOUT: @ImplicitNotDeducible.%return: ref @ImplicitNotDeducible.%U.1 (%U) = var <return slot>
// CHECK:STDOUT: }
// CHECK:STDOUT: %CallImplicitNotDeducible.decl: %CallImplicitNotDeducible.type = fn_decl @CallImplicitNotDeducible [template = constants.%CallImplicitNotDeducible] {}
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: generic fn @ImplicitNotDeducible(%T.loc4: type, %U.loc4: type) {
// CHECK:STDOUT: generic fn @ImplicitNotDeducible(%T.loc6: type, %U.loc6: type) {
// CHECK:STDOUT: %T.1: type = bind_symbolic_name T 0 [symbolic = %T.1 (constants.%T)]
// CHECK:STDOUT: %U.1: type = bind_symbolic_name U 1 [symbolic = %U.1 (constants.%U)]
// CHECK:STDOUT:
// CHECK:STDOUT: fn[%T.loc4: type, %U.loc4: type](%x: @ImplicitNotDeducible.%T.1 (%T)) -> @ImplicitNotDeducible.%U.1 (%U);
// CHECK:STDOUT: fn[%T.loc6: type, %U.loc6: type](%x: @ImplicitNotDeducible.%T.1 (%T)) -> @ImplicitNotDeducible.%U.1 (%U);
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: fn @CallImplicitNotDeducible() {
// CHECK:STDOUT: !entry:
// CHECK:STDOUT: %ImplicitNotDeducible.ref: %ImplicitNotDeducible.type = name_ref ImplicitNotDeducible, file.%ImplicitNotDeducible.decl [template = constants.%ImplicitNotDeducible]
// CHECK:STDOUT: %.loc14: i32 = int_literal 42 [template = constants.%.2]
// CHECK:STDOUT: %.loc16: i32 = int_literal 42 [template = constants.%.2]
// CHECK:STDOUT: return
// CHECK:STDOUT: }
// CHECK:STDOUT:
Expand Down
Loading

0 comments on commit a2e7803

Please sign in to comment.