Skip to content

Commit

Permalink
Initial rough support for deducing generic arguments in a call to a g…
Browse files Browse the repository at this point in the history
…eneric function. (carbon-language#4184)
  • Loading branch information
zygoloid authored and brymer-meneses committed Aug 15, 2024
1 parent 11e3c77 commit 32f3191
Show file tree
Hide file tree
Showing 13 changed files with 1,324 additions and 145 deletions.
2 changes: 2 additions & 0 deletions toolchain/check/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ cc_library(
"context.cpp",
"convert.cpp",
"decl_name_stack.cpp",
"deduce.cpp",
"eval.cpp",
"function.cpp",
"generic.cpp",
Expand All @@ -36,6 +37,7 @@ cc_library(
"convert.h",
"decl_introducer_state.h",
"decl_name_stack.h",
"deduce.h",
"diagnostic_helpers.h",
"eval.h",
"function.h",
Expand Down
21 changes: 10 additions & 11 deletions toolchain/check/call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "toolchain/base/kind_switch.h"
#include "toolchain/check/context.h"
#include "toolchain/check/convert.h"
#include "toolchain/check/deduce.h"
#include "toolchain/check/function.h"
#include "toolchain/check/generic.h"
#include "toolchain/sem_ir/ids.h"
Expand Down Expand Up @@ -98,18 +99,16 @@ auto PerformCall(Context& context, Parse::NodeId node_id,
}
auto& callable = context.functions().Get(callee_function.function_id);

// TODO: Properly determine the generic argument values for the call. For now,
// we do so only if the function introduces no generic parameters beyond those
// of the enclosing context.
// If the callee is a generic function, determine the generic argument values
// for the call.
auto specific_id = SemIR::SpecificId::Invalid;
if (callee_function.specific_id.is_valid()) {
auto enclosing_args_id =
context.specifics().Get(callee_function.specific_id).args_id;
auto fn_params_id = context.generics().Get(callable.generic_id).bindings_id;
if (context.inst_blocks().Get(fn_params_id).size() ==
context.inst_blocks().Get(enclosing_args_id).size()) {
specific_id =
MakeSpecific(context, callable.generic_id, enclosing_args_id);
if (callable.generic_id.is_valid()) {
specific_id = DeduceGenericCallArguments(
context, node_id, callable.generic_id, callee_function.specific_id,
callable.implicit_param_refs_id, callable.param_refs_id,
callee_function.self_id, arg_ids);
if (!specific_id.is_valid()) {
return SemIR::InstId::BuiltinError;
}
}

Expand Down
220 changes: 220 additions & 0 deletions toolchain/check/deduce.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
// Exceptions. See /LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "toolchain/check/deduce.h"

#include "toolchain/base/kind_switch.h"
#include "toolchain/check/context.h"
#include "toolchain/check/generic.h"
#include "toolchain/check/subst.h"
#include "toolchain/sem_ir/typed_insts.h"

namespace Carbon::Check {

namespace {
// A list of pairs of (instruction from generic, corresponding instruction from
// call to of generic) for which we still need to perform deduction, along with
// methods to add and pop pending deductions from the list. Deductions are
// popped in order from most- to least-recently pushed, with the intent that
// they are visited in depth-first order, although the order is not expected to
// matter except when it influences which error is diagnosed.
class DeductionWorklist {
public:
explicit DeductionWorklist(Context& context) : context_(context) {}

struct PendingDeduction {
SemIR::InstId param;
SemIR::InstId arg;
bool needs_substitution;
};

// Adds a single (param, arg) deduction.
auto Add(SemIR::InstId param, SemIR::InstId arg, bool needs_substitution)
-> void {
deductions_.push_back(
{.param = param, .arg = arg, .needs_substitution = needs_substitution});
}

// Adds a list of (param, arg) deductions. These are added in reverse order so
// they are popped in forward order.
auto AddAll(llvm::ArrayRef<SemIR::InstId> params,
llvm::ArrayRef<SemIR::InstId> args, bool needs_substitution)
-> void {
if (params.size() != args.size()) {
// TODO: Decide whether to error on this or just treat the parameter list
// as non-deduced. For now we treat it as non-deduced.
return;
}
for (auto [param, arg] : llvm::reverse(llvm::zip_equal(params, args))) {
Add(param, arg, needs_substitution);
}
}

auto AddAll(SemIR::InstBlockId params, llvm::ArrayRef<SemIR::InstId> args,
bool needs_substitution) -> void {
AddAll(context_.inst_blocks().Get(params), args, needs_substitution);
}

auto AddAll(SemIR::InstBlockId params, SemIR::InstBlockId args,
bool needs_substitution) -> void {
AddAll(context_.inst_blocks().Get(params), context_.inst_blocks().Get(args),
needs_substitution);
}

// Returns whether we have completed all deductions.
auto Done() -> bool { return deductions_.empty(); }

// Pops the next deduction. Requires `!Done()`.
auto PopNext() -> PendingDeduction { return deductions_.pop_back_val(); }

private:
Context& context_;
llvm::SmallVector<PendingDeduction> deductions_;
};
} // namespace

static auto NoteGenericHere(Context& context, SemIR::GenericId generic_id,
Context::DiagnosticBuilder& diag) -> void {
CARBON_DIAGNOSTIC(DeductionGenericHere, Note,
"While deducing parameters of generic declared here.");
diag.Note(context.generics().Get(generic_id).decl_id, DeductionGenericHere);
}

auto DeduceGenericCallArguments(
Context& context, Parse::NodeId node_id, SemIR::GenericId generic_id,
SemIR::SpecificId enclosing_specific_id,
[[maybe_unused]] SemIR::InstBlockId implicit_params_id,
SemIR::InstBlockId params_id, [[maybe_unused]] SemIR::InstId self_id,
llvm::ArrayRef<SemIR::InstId> arg_ids) -> SemIR::SpecificId {
DeductionWorklist worklist(context);

llvm::SmallVector<SemIR::InstId> result_arg_ids;
llvm::SmallVector<Substitution> substitutions;

// Copy any outer generic arguments from the specified instance and prepare to
// substitute them into the function declaration.
if (enclosing_specific_id.is_valid()) {
auto args = context.inst_blocks().Get(
context.specifics().Get(enclosing_specific_id).args_id);
result_arg_ids.assign(args.begin(), args.end());

// TODO: Subst is linear in the length of the substitutions list. Change it
// so we can pass in an array mapping indexes to substitutions instead.
substitutions.reserve(args.size());
for (auto [i, subst_inst_id] : llvm::enumerate(args)) {
substitutions.push_back(
{.bind_id = SemIR::CompileTimeBindIndex(i),
.replacement_id = context.constant_values().Get(subst_inst_id)});
}
}
auto first_deduced_index = SemIR::CompileTimeBindIndex(result_arg_ids.size());

// Initialize the deduced arguments to Invalid.
result_arg_ids.resize(context.inst_blocks()
.Get(context.generics().Get(generic_id).bindings_id)
.size(),
SemIR::InstId::Invalid);

// Prepare to perform deduction of the explicit parameters against their
// arguments.
// TODO: Also perform deduction for type of self.
worklist.AddAll(params_id, arg_ids, /*needs_substitution=*/true);

while (!worklist.Done()) {
auto [param_id, arg_id, needs_substitution] = worklist.PopNext();

// If the parameter has a symbolic type, deduce against that.
auto param_type_id = context.insts().Get(param_id).type_id();
if (param_type_id.AsConstantId().is_symbolic()) {
worklist.Add(
context.types().GetInstId(param_type_id),
context.types().GetInstId(context.insts().Get(arg_id).type_id()),
needs_substitution);
}

// If the parameter is a symbolic constant, deduce against it.
auto param_const_id = context.constant_values().Get(param_id);
if (!param_const_id.is_valid() || !param_const_id.is_symbolic()) {
continue;
}

// If we've not yet substituted into the parameter, do so now.
if (needs_substitution) {
param_const_id = SubstConstant(context, param_const_id, substitutions);
if (!param_const_id.is_valid() || !param_const_id.is_symbolic()) {
continue;
}
needs_substitution = false;
}

CARBON_KIND_SWITCH(context.insts().Get(context.constant_values().GetInstId(
param_const_id))) {
// Deducing a symbolic binding from an argument with a constant value
// deduces the binding as having that constant value.
case CARBON_KIND(SemIR::BindSymbolicName bind): {
auto& entity_name = context.entity_names().Get(bind.entity_name_id);
auto index = entity_name.bind_index;
if (index.is_valid() && index >= first_deduced_index) {
CARBON_CHECK(static_cast<size_t>(index.index) < result_arg_ids.size())
<< "Deduced value for unexpected index " << index
<< "; expected to deduce " << result_arg_ids.size()
<< " arguments.";
auto arg_const_inst_id =
context.constant_values().GetConstantInstId(arg_id);
if (arg_const_inst_id.is_valid()) {
if (result_arg_ids[index.index].is_valid() &&
result_arg_ids[index.index] != arg_const_inst_id) {
// TODO: Include the two different deduced values.
CARBON_DIAGNOSTIC(DeductionInconsistent, Error,
"Inconsistent deductions for value of generic "
"parameter `{0}`.",
SemIR::NameId);
auto diag = context.emitter().Build(
node_id, DeductionInconsistent, entity_name.name_id);
NoteGenericHere(context, generic_id, diag);
diag.Emit();
return SemIR::SpecificId::Invalid;
}
result_arg_ids[index.index] = arg_const_inst_id;
}
}
break;
}

// TODO: Handle more cases.

default:
break;
}
}

// Check we deduced an argument value for every parameter.
for (auto [i, deduced_arg_id] :
llvm::enumerate(llvm::ArrayRef(result_arg_ids)
.drop_front(first_deduced_index.index))) {
if (!deduced_arg_id.is_valid()) {
auto binding_index = first_deduced_index.index + i;
auto binding_id = context.inst_blocks().Get(
context.generics().Get(generic_id).bindings_id)[binding_index];
auto entity_name_id =
context.insts().GetAs<SemIR::AnyBindName>(binding_id).entity_name_id;
CARBON_DIAGNOSTIC(DeductionIncomplete, Error,
"Cannot deduce value for generic parameter `{0}`.",
SemIR::NameId);
auto diag = context.emitter().Build(
node_id, DeductionIncomplete,
context.entity_names().Get(entity_name_id).name_id);
NoteGenericHere(context, generic_id, diag);
diag.Emit();
return SemIR::SpecificId::Invalid;
}
}

// TODO: Convert the deduced values to the types of the bindings.

return MakeSpecific(context, generic_id,
context.inst_blocks().AddCanonical(result_arg_ids));
}

} // namespace Carbon::Check
25 changes: 25 additions & 0 deletions toolchain/check/deduce.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
// Exceptions. See /LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef CARBON_TOOLCHAIN_CHECK_DEDUCE_H_
#define CARBON_TOOLCHAIN_CHECK_DEDUCE_H_

#include "toolchain/check/context.h"
#include "toolchain/sem_ir/ids.h"

namespace Carbon::Check {

// Deduces the generic arguments to use in a call to a generic.
auto DeduceGenericCallArguments(Context& context, Parse::NodeId node_id,
SemIR::GenericId generic_id,
SemIR::SpecificId enclosing_specific_id,
SemIR::InstBlockId implicit_params_id,
SemIR::InstBlockId params_id,
SemIR::InstId self_id,
llvm::ArrayRef<SemIR::InstId> arg_ids)
-> SemIR::SpecificId;

} // namespace Carbon::Check

#endif // CARBON_TOOLCHAIN_CHECK_DEDUCE_H_
23 changes: 19 additions & 4 deletions toolchain/check/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,19 +305,34 @@ static auto GetConstantValue(EvalContext& eval_context,
return eval_context.inst_blocks().AddCanonical(const_insts);
}

// The constant value of a type block is that type block, but we still need to
// extract its phase.
// Compute the constant value of a type block. This may be different from the
// input type block if we have known generic arguments.
static auto GetConstantValue(EvalContext& eval_context,
SemIR::TypeBlockId type_block_id, Phase* phase)
-> SemIR::TypeBlockId {
if (!type_block_id.is_valid()) {
return SemIR::TypeBlockId::Invalid;
}
auto types = eval_context.type_blocks().Get(type_block_id);
llvm::SmallVector<SemIR::TypeId> new_types;
for (auto type_id : types) {
GetConstantValue(eval_context, type_id, phase);
auto new_type_id = GetConstantValue(eval_context, type_id, phase);
if (!new_type_id.is_valid()) {
return SemIR::TypeBlockId::Invalid;
}

// Once we leave the small buffer, we know the first few elements are all
// constant, so it's likely that the entire block is constant. Resize to the
// target size given that we're going to allocate memory now anyway.
if (new_types.size() == new_types.capacity()) {
new_types.reserve(types.size());
}

new_types.push_back(new_type_id);
}
return type_block_id;
// TODO: If the new block is identical to the original block, and we know the
// old ID was canonical, return the original ID.
return eval_context.type_blocks().AddCanonical(new_types);
}

// The constant value of a specific is the specific with the corresponding
Expand Down
Loading

0 comments on commit 32f3191

Please sign in to comment.