Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial rough support for deducing generic arguments in a call to a generic function. #4184

Merged
merged 2 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions toolchain/check/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ cc_library(
"context.cpp",
"convert.cpp",
"decl_name_stack.cpp",
"deduce.cpp",
"eval.cpp",
"function.cpp",
"generic.cpp",
Expand All @@ -36,6 +37,7 @@ cc_library(
"convert.h",
"decl_introducer_state.h",
"decl_name_stack.h",
"deduce.h",
"diagnostic_helpers.h",
"eval.h",
"function.h",
Expand Down
21 changes: 10 additions & 11 deletions toolchain/check/call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "toolchain/base/kind_switch.h"
#include "toolchain/check/context.h"
#include "toolchain/check/convert.h"
#include "toolchain/check/deduce.h"
#include "toolchain/check/function.h"
#include "toolchain/check/generic.h"
#include "toolchain/sem_ir/ids.h"
Expand Down Expand Up @@ -98,18 +99,16 @@ auto PerformCall(Context& context, Parse::NodeId node_id,
}
auto& callable = context.functions().Get(callee_function.function_id);

// TODO: Properly determine the generic argument values for the call. For now,
// we do so only if the function introduces no generic parameters beyond those
// of the enclosing context.
// If the callee is a generic function, determine the generic argument values
// for the call.
auto specific_id = SemIR::SpecificId::Invalid;
if (callee_function.specific_id.is_valid()) {
auto enclosing_args_id =
context.specifics().Get(callee_function.specific_id).args_id;
auto fn_params_id = context.generics().Get(callable.generic_id).bindings_id;
if (context.inst_blocks().Get(fn_params_id).size() ==
context.inst_blocks().Get(enclosing_args_id).size()) {
specific_id =
MakeSpecific(context, callable.generic_id, enclosing_args_id);
if (callable.generic_id.is_valid()) {
specific_id = DeduceGenericCallArguments(
context, node_id, callable.generic_id, callee_function.specific_id,
callable.implicit_param_refs_id, callable.param_refs_id,
callee_function.self_id, arg_ids);
if (!specific_id.is_valid()) {
return SemIR::InstId::BuiltinError;
}
}

Expand Down
196 changes: 196 additions & 0 deletions toolchain/check/deduce.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
// Exceptions. See /LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "toolchain/check/deduce.h"

#include "toolchain/base/kind_switch.h"
#include "toolchain/check/context.h"
#include "toolchain/check/generic.h"
#include "toolchain/check/subst.h"
#include "toolchain/sem_ir/typed_insts.h"

namespace Carbon::Check {

namespace {
struct DeductionWorklist {
geoffromer marked this conversation as resolved.
Show resolved Hide resolved
auto Add(SemIR::InstId param, SemIR::InstId arg, bool needs_substitution)
-> void {
deductions.push_back(
{.param = param, .arg = arg, .needs_substitution = needs_substitution});
}

auto AddBlock(llvm::ArrayRef<SemIR::InstId> params,
geoffromer marked this conversation as resolved.
Show resolved Hide resolved
llvm::ArrayRef<SemIR::InstId> args, bool needs_substitution)
-> void {
if (params.size() != args.size()) {
return;
geoffromer marked this conversation as resolved.
Show resolved Hide resolved
}
for (auto [param, arg] : llvm::zip_equal(params, args)) {
Add(param, arg, needs_substitution);
}
}

auto AddBlock(SemIR::InstBlockId params, llvm::ArrayRef<SemIR::InstId> args,
bool needs_substitution) -> void {
AddBlock(context.inst_blocks().Get(params), args, needs_substitution);
}

auto AddBlock(SemIR::InstBlockId params, SemIR::InstBlockId args,
bool needs_substitution) -> void {
AddBlock(context.inst_blocks().Get(params), context.inst_blocks().Get(args),
needs_substitution);
}

struct PendingDeduction {
SemIR::InstId param;
SemIR::InstId arg;
bool needs_substitution;
};
Context& context;
llvm::SmallVector<PendingDeduction> deductions;
};
} // namespace

static auto NoteGenericHere(Context& context, SemIR::GenericId generic_id,
Context::DiagnosticBuilder& diag) -> void {
CARBON_DIAGNOSTIC(DeductionGenericHere, Note,
"While deducing parameters of generic declared here.");
diag.Note(context.generics().Get(generic_id).decl_id, DeductionGenericHere);
}

auto DeduceGenericCallArguments(Context& context, Parse::NodeId node_id,
SemIR::GenericId generic_id,
SemIR::SpecificId enclosing_specific_id,
SemIR::InstBlockId implicit_params_id,
SemIR::InstBlockId params_id,
SemIR::InstId self_id,
llvm::ArrayRef<SemIR::InstId> arg_ids)
-> SemIR::SpecificId {
DeductionWorklist worklist = {.context = context};

// TODO: Perform deduction for type of self
static_cast<void>(implicit_params_id);
static_cast<void>(self_id);
geoffromer marked this conversation as resolved.
Show resolved Hide resolved

// Copy any outer generic arguments from the specified instance and prepare to
// substitute them into the function declaration.
llvm::SmallVector<SemIR::InstId> results;
geoffromer marked this conversation as resolved.
Show resolved Hide resolved
llvm::SmallVector<Substitution> substitutions;
if (enclosing_specific_id.is_valid()) {
geoffromer marked this conversation as resolved.
Show resolved Hide resolved
geoffromer marked this conversation as resolved.
Show resolved Hide resolved
auto args = context.inst_blocks().Get(
context.specifics().Get(enclosing_specific_id).args_id);
results.assign(args.begin(), args.end());

// TODO: Subst is linear in the length of the substitutions list. Change it
// so we can pass in an array mapping indexes to substitutions instead.
substitutions.reserve(args.size());
for (auto [i, subst_inst_id] : llvm::enumerate(args)) {
substitutions.push_back(
{.bind_id = SemIR::CompileTimeBindIndex(i),
.replacement_id = context.constant_values().Get(subst_inst_id)});
}
}
auto first_deduced_index = SemIR::CompileTimeBindIndex(results.size());

worklist.AddBlock(params_id, arg_ids, /*needs_substitution=*/true);

results.resize(context.inst_blocks()
.Get(context.generics().Get(generic_id).bindings_id)
.size(),
SemIR::InstId::Invalid);

geoffromer marked this conversation as resolved.
Show resolved Hide resolved
while (!worklist.deductions.empty()) {
auto [param_id, arg_id, needs_substitution] =
worklist.deductions.pop_back_val();

// If the parameter has a symbolic type, deduce against that.
auto param_type_id = context.insts().Get(param_id).type_id();
if (param_type_id.AsConstantId().is_symbolic()) {
worklist.Add(
context.types().GetInstId(param_type_id),
context.types().GetInstId(context.insts().Get(arg_id).type_id()),
needs_substitution);
}

// If the parameter is a symbolic constant, deduce against it.
auto param_const_id = context.constant_values().Get(param_id);
if (!param_const_id.is_valid() || !param_const_id.is_symbolic()) {
continue;
}

// If we've not yet substituted into the parameter, do so now.
if (needs_substitution) {
param_const_id = SubstConstant(context, param_const_id, substitutions);
if (!param_const_id.is_valid() || !param_const_id.is_symbolic()) {
continue;
}
needs_substitution = false;
}

CARBON_KIND_SWITCH(context.insts().Get(context.constant_values().GetInstId(
param_const_id))) {
// Deducing a symbolic binding from an argument with a constant value
// deduces the binding as having that constant value.
case CARBON_KIND(SemIR::BindSymbolicName bind): {
auto& entity_name = context.entity_names().Get(bind.entity_name_id);
auto index = entity_name.bind_index;
if (index.is_valid() && index >= first_deduced_index) {
CARBON_CHECK(static_cast<size_t>(index.index) < results.size())
<< "Deduced value for unexpected index " << index
<< "; expected to deduce " << results.size() << " arguments.";
auto arg_const_inst_id =
context.constant_values().GetConstantInstId(arg_id);
if (arg_const_inst_id.is_valid()) {
if (results[index.index].is_valid() &&
results[index.index] != arg_const_inst_id) {
// TODO: Include the two different deduced values.
CARBON_DIAGNOSTIC(DeductionInconsistent, Error,
"Inconsistent deductions for value of generic "
"parameter `{0}`.",
SemIR::NameId);
auto diag = context.emitter().Build(
node_id, DeductionInconsistent, entity_name.name_id);
NoteGenericHere(context, generic_id, diag);
diag.Emit();
return SemIR::SpecificId::Invalid;
}
results[index.index] = arg_const_inst_id;
}
}
break;
}

// TODO: Handle more cases.

default:
break;
}
}

// Check we deduced an argument value for every parameter.
for (auto [i, deduced_arg_id] : llvm::enumerate(results)) {
if (!deduced_arg_id.is_valid()) {
auto binding_id = context.inst_blocks().Get(
geoffromer marked this conversation as resolved.
Show resolved Hide resolved
context.generics().Get(generic_id).bindings_id)[i];
auto entity_name_id =
context.insts().GetAs<SemIR::AnyBindName>(binding_id).entity_name_id;
CARBON_DIAGNOSTIC(DeductionIncomplete, Error,
"Cannot deduce value for generic parameter `{0}`.",
SemIR::NameId);
auto diag = context.emitter().Build(
node_id, DeductionIncomplete,
context.entity_names().Get(entity_name_id).name_id);
NoteGenericHere(context, generic_id, diag);
diag.Emit();
return SemIR::SpecificId::Invalid;
}
}

// TODO: Convert the deduced values to the types of the bindings.

return MakeSpecific(context, generic_id,
context.inst_blocks().AddCanonical(results));
}

} // namespace Carbon::Check
25 changes: 25 additions & 0 deletions toolchain/check/deduce.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
// Exceptions. See /LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef CARBON_TOOLCHAIN_CHECK_DEDUCE_H_
#define CARBON_TOOLCHAIN_CHECK_DEDUCE_H_

#include "toolchain/check/context.h"
#include "toolchain/sem_ir/ids.h"

namespace Carbon::Check {

// Deduces the generic arguments to use in a call to a generic.
auto DeduceGenericCallArguments(Context& context, Parse::NodeId node_id,
SemIR::GenericId generic_id,
SemIR::SpecificId enclosing_specific_id,
SemIR::InstBlockId implicit_params_id,
SemIR::InstBlockId params_id,
SemIR::InstId self_id,
llvm::ArrayRef<SemIR::InstId> arg_ids)
-> SemIR::SpecificId;

} // namespace Carbon::Check

#endif // CARBON_TOOLCHAIN_CHECK_DEDUCE_H_
23 changes: 19 additions & 4 deletions toolchain/check/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,19 +305,34 @@ static auto GetConstantValue(EvalContext& eval_context,
return eval_context.inst_blocks().AddCanonical(const_insts);
}

// The constant value of a type block is that type block, but we still need to
// extract its phase.
// Compute the constant value of a type block. This may be different from the
// input type block if we have known generic arguments.
static auto GetConstantValue(EvalContext& eval_context,
SemIR::TypeBlockId type_block_id, Phase* phase)
-> SemIR::TypeBlockId {
if (!type_block_id.is_valid()) {
return SemIR::TypeBlockId::Invalid;
}
auto types = eval_context.type_blocks().Get(type_block_id);
llvm::SmallVector<SemIR::TypeId> new_types;
for (auto type_id : types) {
GetConstantValue(eval_context, type_id, phase);
auto new_type_id = GetConstantValue(eval_context, type_id, phase);
if (!new_type_id.is_valid()) {
return SemIR::TypeBlockId::Invalid;
}

// Once we leave the small buffer, we know the first few elements are all
// constant, so it's likely that the entire block is constant. Resize to the
// target size given that we're going to allocate memory now anyway.
if (new_types.size() == new_types.capacity()) {
new_types.reserve(types.size());
}

new_types.push_back(new_type_id);
}
return type_block_id;
// TODO: If the new block is identical to the original block, and we know the
// old ID was canonical, return the original ID.
return eval_context.type_blocks().AddCanonical(new_types);
}

// The constant value of a specific is the specific with the corresponding
Expand Down
Loading
Loading