Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow closure to unsafe fn coercion #59580

Merged
merged 1 commit into from
Mar 31, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/librustc/middle/expr_use_visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -707,7 +707,7 @@ impl<'a, 'gcx, 'tcx> ExprUseVisitor<'a, 'gcx, 'tcx> {
adjustment::Adjust::NeverToAny |
adjustment::Adjust::ReifyFnPointer |
adjustment::Adjust::UnsafeFnPointer |
adjustment::Adjust::ClosureFnPointer |
adjustment::Adjust::ClosureFnPointer(_) |
adjustment::Adjust::MutToConstPointer |
adjustment::Adjust::Unsize => {
// Creating a closure/fn-pointer or unsizing consumes
Expand Down
2 changes: 1 addition & 1 deletion src/librustc/middle/mem_categorization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ impl<'a, 'gcx, 'tcx> MemCategorizationContext<'a, 'gcx, 'tcx> {
adjustment::Adjust::NeverToAny |
adjustment::Adjust::ReifyFnPointer |
adjustment::Adjust::UnsafeFnPointer |
adjustment::Adjust::ClosureFnPointer |
adjustment::Adjust::ClosureFnPointer(_) |
adjustment::Adjust::MutToConstPointer |
adjustment::Adjust::Borrow(_) |
adjustment::Adjust::Unsize => {
Expand Down
5 changes: 3 additions & 2 deletions src/librustc/mir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2247,8 +2247,9 @@ pub enum CastKind {
/// Converts unique, zero-sized type for a fn to fn()
ReifyFnPointer,

/// Converts non capturing closure to fn()
ClosureFnPointer,
taiki-e marked this conversation as resolved.
Show resolved Hide resolved
/// Converts non capturing closure to fn() or unsafe fn().
/// It cannot convert a closure that requires unsafe.
ClosureFnPointer(hir::Unsafety),

/// Converts safe fn() to unsafe fn()
UnsafeFnPointer,
Expand Down
5 changes: 3 additions & 2 deletions src/librustc/ty/adjustment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ pub enum Adjust<'tcx> {
/// Go from a safe fn pointer to an unsafe fn pointer.
UnsafeFnPointer,

/// Go from a non-capturing closure to an fn pointer.
ClosureFnPointer,
/// Go from a non-capturing closure to an fn pointer or an unsafe fn pointer.
/// It cannot convert a closure that requires unsafe.
ClosureFnPointer(hir::Unsafety),

/// Go from a mut raw pointer to a const raw pointer.
MutToConstPointer,
Expand Down
8 changes: 6 additions & 2 deletions src/librustc/ty/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2441,7 +2441,11 @@ impl<'a, 'gcx, 'tcx> TyCtxt<'a, 'gcx, 'tcx> {
/// type with the same signature. Detuples and so forth -- so
/// e.g., if we have a sig with `Fn<(u32, i32)>` then you would get
/// a `fn(u32, i32)`.
taiki-e marked this conversation as resolved.
Show resolved Hide resolved
pub fn coerce_closure_fn_ty(self, sig: PolyFnSig<'tcx>) -> Ty<'tcx> {
/// `unsafety` determines the unsafety of the `fn` type. If you pass
/// `hir::Unsafety::Unsafe` in the previous example, then you would get
/// an `unsafe fn (u32, i32)`.
/// It cannot convert a closure that requires unsafe.
pub fn coerce_closure_fn_ty(self, sig: PolyFnSig<'tcx>, unsafety: hir::Unsafety) -> Ty<'tcx> {
let converted_sig = sig.map_bound(|s| {
let params_iter = match s.inputs()[0].sty {
ty::Tuple(params) => {
Expand All @@ -2453,7 +2457,7 @@ impl<'a, 'gcx, 'tcx> TyCtxt<'a, 'gcx, 'tcx> {
params_iter,
s.output(),
s.c_variadic,
hir::Unsafety::Normal,
unsafety,
abi::Abi::Rust,
)
});
Expand Down
6 changes: 3 additions & 3 deletions src/librustc/ty/structural_impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -630,8 +630,8 @@ impl<'a, 'tcx> Lift<'tcx> for ty::adjustment::Adjust<'a> {
Some(ty::adjustment::Adjust::ReifyFnPointer),
ty::adjustment::Adjust::UnsafeFnPointer =>
Some(ty::adjustment::Adjust::UnsafeFnPointer),
ty::adjustment::Adjust::ClosureFnPointer =>
Some(ty::adjustment::Adjust::ClosureFnPointer),
ty::adjustment::Adjust::ClosureFnPointer(unsafety) =>
Some(ty::adjustment::Adjust::ClosureFnPointer(unsafety)),
ty::adjustment::Adjust::MutToConstPointer =>
Some(ty::adjustment::Adjust::MutToConstPointer),
ty::adjustment::Adjust::Unsize =>
Expand Down Expand Up @@ -1187,7 +1187,7 @@ EnumTypeFoldableImpl! {
(ty::adjustment::Adjust::NeverToAny),
(ty::adjustment::Adjust::ReifyFnPointer),
(ty::adjustment::Adjust::UnsafeFnPointer),
(ty::adjustment::Adjust::ClosureFnPointer),
(ty::adjustment::Adjust::ClosureFnPointer)(a),
(ty::adjustment::Adjust::MutToConstPointer),
(ty::adjustment::Adjust::Unsize),
(ty::adjustment::Adjust::Deref)(a),
Expand Down
2 changes: 1 addition & 1 deletion src/librustc_codegen_ssa/mir/rvalue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ impl<'a, 'tcx: 'a, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
}
}
}
mir::CastKind::ClosureFnPointer => {
mir::CastKind::ClosureFnPointer(_) => {
match operand.layout.ty.sty {
ty::Closure(def_id, substs) => {
let instance = monomorphize::resolve_closure(
Expand Down
4 changes: 2 additions & 2 deletions src/librustc_mir/borrow_check/nll/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1999,14 +1999,14 @@ impl<'a, 'gcx, 'tcx> TypeChecker<'a, 'gcx, 'tcx> {
}
}

CastKind::ClosureFnPointer => {
CastKind::ClosureFnPointer(unsafety) => {
let sig = match op.ty(mir, tcx).sty {
ty::Closure(def_id, substs) => {
substs.closure_sig_ty(def_id, tcx).fn_sig(tcx)
}
_ => bug!(),
};
let ty_fn_ptr_from = tcx.coerce_closure_fn_ty(sig);
let ty_fn_ptr_from = tcx.coerce_closure_fn_ty(sig, *unsafety);

if let Err(terr) = self.eq_types(
ty_fn_ptr_from,
Expand Down
4 changes: 2 additions & 2 deletions src/librustc_mir/build/expr/as_rvalue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,9 @@ impl<'a, 'gcx, 'tcx> Builder<'a, 'gcx, 'tcx> {
let source = unpack!(block = this.as_operand(block, scope, source));
block.and(Rvalue::Cast(CastKind::UnsafeFnPointer, source, expr.ty))
}
ExprKind::ClosureFnPointer { source } => {
ExprKind::ClosureFnPointer { source, unsafety } => {
let source = unpack!(block = this.as_operand(block, scope, source));
block.and(Rvalue::Cast(CastKind::ClosureFnPointer, source, expr.ty))
block.and(Rvalue::Cast(CastKind::ClosureFnPointer(unsafety), source, expr.ty))
}
ExprKind::MutToConstPointer { source } => {
let source = unpack!(block = this.as_operand(block, scope, source));
Expand Down
4 changes: 2 additions & 2 deletions src/librustc_mir/hair/cx/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ fn apply_adjustment<'a, 'gcx, 'tcx>(cx: &mut Cx<'a, 'gcx, 'tcx>,
Adjust::UnsafeFnPointer => {
ExprKind::UnsafeFnPointer { source: expr.to_ref() }
}
Adjust::ClosureFnPointer => {
ExprKind::ClosureFnPointer { source: expr.to_ref() }
Adjust::ClosureFnPointer(unsafety) => {
ExprKind::ClosureFnPointer { source: expr.to_ref(), unsafety }
}
Adjust::NeverToAny => {
ExprKind::NeverToAny { source: expr.to_ref() }
Expand Down
1 change: 1 addition & 0 deletions src/librustc_mir/hair/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ pub enum ExprKind<'tcx> {
},
ClosureFnPointer {
source: ExprRef<'tcx>,
unsafety: hir::Unsafety,
},
UnsafeFnPointer {
source: ExprRef<'tcx>,
Expand Down
2 changes: 1 addition & 1 deletion src/librustc_mir/interpret/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ impl<'a, 'mir, 'tcx, M: Machine<'a, 'mir, 'tcx>> InterpretCx<'a, 'mir, 'tcx, M>
}
}

ClosureFnPointer => {
ClosureFnPointer(_) => {
// The src operand does not matter, just its type
match src.layout.ty.sty {
ty::Closure(def_id, substs) => {
Expand Down
2 changes: 1 addition & 1 deletion src/librustc_mir/monomorphize/collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ impl<'a, 'tcx> MirVisitor<'tcx> for MirNeighborCollector<'a, 'tcx> {
);
visit_fn_use(self.tcx, fn_ty, false, &mut self.output);
}
mir::Rvalue::Cast(mir::CastKind::ClosureFnPointer, ref operand, _) => {
mir::Rvalue::Cast(mir::CastKind::ClosureFnPointer(_), ref operand, _) => {
let source_ty = operand.ty(self.mir, self.tcx);
let source_ty = self.tcx.subst_and_normalize_erasing_regions(
self.param_substs,
Expand Down
2 changes: 1 addition & 1 deletion src/librustc_mir/transform/qualify_consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1105,7 +1105,7 @@ impl<'a, 'tcx> Visitor<'tcx> for Checker<'a, 'tcx> {
Rvalue::CheckedBinaryOp(..) |
Rvalue::Cast(CastKind::ReifyFnPointer, ..) |
Rvalue::Cast(CastKind::UnsafeFnPointer, ..) |
Rvalue::Cast(CastKind::ClosureFnPointer, ..) |
Rvalue::Cast(CastKind::ClosureFnPointer(_), ..) |
Rvalue::Cast(CastKind::Unsize, ..) |
Rvalue::Cast(CastKind::MutToConstPointer, ..) |
Rvalue::Discriminant(..) |
Expand Down
2 changes: 1 addition & 1 deletion src/librustc_mir/transform/qualify_min_const_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ fn check_rvalue(
check_operand(tcx, mir, operand, span)
}
Rvalue::Cast(CastKind::UnsafeFnPointer, _, _) |
Rvalue::Cast(CastKind::ClosureFnPointer, _, _) |
Rvalue::Cast(CastKind::ClosureFnPointer(_), _, _) |
Rvalue::Cast(CastKind::ReifyFnPointer, _, _) => Err((
span,
"function pointer casts are not allowed in const fn".into(),
Expand Down
2 changes: 1 addition & 1 deletion src/librustc_passes/rvalue_promotion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ fn check_adjustments<'a, 'tcx>(
Adjust::NeverToAny |
Adjust::ReifyFnPointer |
Adjust::UnsafeFnPointer |
Adjust::ClosureFnPointer |
Adjust::ClosureFnPointer(_) |
Adjust::MutToConstPointer |
Adjust::Borrow(_) |
Adjust::Unsize => {}
Expand Down
12 changes: 8 additions & 4 deletions src/librustc_typeck/check/coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,8 @@ impl<'f, 'gcx, 'tcx> Coerce<'f, 'gcx, 'tcx> {
}
ty::Closure(def_id_a, substs_a) => {
// Non-capturing closures are coercible to
// function pointers
// function pointers or unsafe function pointers.
// It cannot convert closures that require unsafe.
self.coerce_closure_to_fn(a, def_id_a, substs_a, b)
}
_ => {
Expand Down Expand Up @@ -714,16 +715,19 @@ impl<'f, 'gcx, 'tcx> Coerce<'f, 'gcx, 'tcx> {

let hir_id_a = self.tcx.hir().as_local_hir_id(def_id_a).unwrap();
match b.sty {
ty::FnPtr(_) if self.tcx.with_freevars(hir_id_a, |v| v.is_empty()) => {
ty::FnPtr(fn_ty) if self.tcx.with_freevars(hir_id_a, |v| v.is_empty()) => {
// We coerce the closure, which has fn type
// `extern "rust-call" fn((arg0,arg1,...)) -> _`
// to
// `fn(arg0,arg1,...) -> _`
taiki-e marked this conversation as resolved.
Show resolved Hide resolved
// or
// `unsafe fn(arg0,arg1,...) -> _`
let sig = self.closure_sig(def_id_a, substs_a);
let pointer_ty = self.tcx.coerce_closure_fn_ty(sig);
let unsafety = fn_ty.unsafety();
let pointer_ty = self.tcx.coerce_closure_fn_ty(sig, unsafety);
debug!("coerce_closure_to_fn(a={:?}, b={:?}, pty={:?})",
a, b, pointer_ty);
self.unify_and(pointer_ty, b, simple(Adjust::ClosureFnPointer))
self.unify_and(pointer_ty, b, simple(Adjust::ClosureFnPointer(unsafety)))
}
_ => self.unify_and(a, b, identity),
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
fn main() {
let _: unsafe fn() = || { ::std::pin::Pin::new_unchecked(&0_u8); };
//~^ ERROR E0133
let _: unsafe fn() = || unsafe { ::std::pin::Pin::new_unchecked(&0_u8); }; // OK
}
7 changes: 7 additions & 0 deletions src/test/run-pass/typeck-closure-to-unsafe-fn-ptr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
unsafe fn call_unsafe(func: unsafe fn() -> ()) -> () {
func()
}

pub fn main() {
unsafe { call_unsafe(|| {}); }
taiki-e marked this conversation as resolved.
Show resolved Hide resolved
}