Skip to content

Commit

Permalink
Merge pull request #996 from Nadrieril/cache-trait-res
Browse files Browse the repository at this point in the history
Cache many things
  • Loading branch information
W95Psp authored Oct 14, 2024
2 parents 1ba7ce3 + 3575efc commit 050449c
Show file tree
Hide file tree
Showing 14 changed files with 784 additions and 729 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ readme = "README.md"
itertools = "0.11.0"
schemars = "0.8"
which = "4.4"
serde = { version = "1.0", features = ["derive"] }
serde = { version = "1.0", features = ["derive", "rc"] }
serde_json = "1.0"
clap = { version = "4.0", features = ["derive"] }
syn = { version = "1.0.107", features = [
Expand Down
5 changes: 0 additions & 5 deletions cli/driver/src/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,6 @@ fn setup_logging() {
};
let subscriber = tracing_subscriber::Registry::default()
.with(tracing_subscriber::EnvFilter::from_default_env())
.with(
tracing_subscriber::fmt::layer()
.with_file(true)
.with_line_number(true),
)
.with(
tracing_tree::HierarchicalLayer::new(2)
.with_ansi(enable_colors)
Expand Down
41 changes: 20 additions & 21 deletions cli/driver/src/exporter.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use hax_frontend_exporter::state::{ExportedSpans, LocalContextS};
use hax_frontend_exporter::state::LocalContextS;
use hax_frontend_exporter::SInto;
use hax_types::cli_options::{Backend, PathOrDash, ENV_VAR_OPTIONS_FRONTEND};
use rustc_driver::{Callbacks, Compilation};
Expand Down Expand Up @@ -46,8 +46,7 @@ fn dummy_thir_body(
/// stealing issues (theoretically...)
fn precompute_local_thir_bodies(
tcx: TyCtxt<'_>,
) -> std::collections::HashMap<rustc_span::def_id::LocalDefId, ThirBundle<'_>> {
let hir = tcx.hir();
) -> impl Iterator<Item = (rustc_hir::def_id::DefId, ThirBundle<'_>)> {
use rustc_hir::def::DefKind::*;
use rustc_hir::*;

Expand All @@ -74,7 +73,7 @@ fn precompute_local_thir_bodies(
}

use itertools::Itertools;
hir.body_owners()
tcx.hir().body_owners()
.filter(|ldid| {
match tcx.def_kind(ldid.to_def_id()) {
InlineConst | AnonConst => {
Expand All @@ -91,10 +90,10 @@ fn precompute_local_thir_bodies(
}
})
.sorted_by_key(|ldid| const_level_of(tcx, *ldid))
.filter(|ldid| hir.maybe_body_owned_by(*ldid).is_some())
.map(|ldid| {
.filter(move |ldid| tcx.hir().maybe_body_owned_by(*ldid).is_some())
.map(move |ldid| {
tracing::debug!("⏳ Type-checking THIR body for {:#?}", ldid);
let span = hir.span(tcx.local_def_id_to_hir_id(ldid));
let span = tcx.hir().span(tcx.local_def_id_to_hir_id(ldid));
let (thir, expr) = match tcx.thir_body(ldid) {
Ok(x) => x,
Err(e) => {
Expand All @@ -117,7 +116,7 @@ fn precompute_local_thir_bodies(
tracing::debug!("✅ Type-checked THIR body for {:#?}", ldid);
(ldid, (Rc::new(thir), expr))
})
.collect()
.map(|(ldid, bundle)| (ldid.to_def_id(), bundle))
}

/// Browse a crate and translate every item from HIR+THIR to "THIR'"
Expand All @@ -136,27 +135,27 @@ fn convert_thir<'tcx, Body: hax_frontend_exporter::IsBody>(
)>,
Vec<hax_frontend_exporter::Item<Body>>,
) {
use hax_frontend_exporter::WithGlobalCacheExt;
let mut state = hax_frontend_exporter::state::State::new(tcx, options.clone());
state.base.macro_infos = Rc::new(macro_calls);
state.base.cached_thirs = Rc::new(precompute_local_thir_bodies(tcx));
for (def_id, thir) in precompute_local_thir_bodies(tcx) {
state.with_item_cache(def_id, |caches| caches.thir = Some(thir));
}

let result = hax_frontend_exporter::inline_macro_invocations(tcx.hir().items(), &state);
let impl_infos = hax_frontend_exporter::impl_def_ids_to_impled_types_and_bounds(&state)
.into_iter()
.collect();
let exported_spans = state.base.exported_spans.borrow().clone();
let exported_spans = state.with_global_cache(|cache| cache.spans.keys().copied().collect());
let exported_def_ids = state.with_global_cache(|cache| {
cache
.per_item
.values()
.filter_map(|per_item_cache| per_item_cache.def_id.clone())
.collect()
});

let exported_def_ids = {
let def_ids = state.base.exported_def_ids.borrow();
let state = hax_frontend_exporter::state::State::new(tcx, options.clone());
def_ids.iter().map(|did| did.sinto(&state)).collect()
};
(
exported_spans.into_iter().collect(),
exported_def_ids,
impl_infos,
result,
)
(exported_spans, exported_def_ids, impl_infos, result)
}

/// Collect a map from spans to macro calls
Expand Down
12 changes: 7 additions & 5 deletions engine/lib/concrete_ident/concrete_ident.ml
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,8 @@ module View = struct
last.data |> Imported.of_def_path_item |> string_of_def_path_item
|> Option.map ~f:escape
in
let arity0 =
Option.map ~f:escape << function
let arity0 ty =
match ty.Types.kind with
| Types.Bool -> Some "bool"
| Char -> Some "char"
| Str -> Some "str"
Expand All @@ -305,11 +305,13 @@ module View = struct
| Float F32 -> Some "f32"
| Float F64 -> Some "f64"
| Tuple [] -> Some "unit"
| Adt { def_id; generic_args = []; _ } -> adt def_id
| Adt { def_id; generic_args = []; _ } ->
Option.map ~f:escape (adt def_id)
| _ -> None
in
let apply left right = left ^ "_of_" ^ right in
let rec arity1 = function
let rec arity1 ty =
match ty.Types.kind with
| Types.Slice sub -> arity1 sub |> Option.map ~f:(apply "slice")
| Ref (_, sub, _) -> arity1 sub |> Option.map ~f:(apply "ref")
| Adt { def_id; generic_args = [ Type arg ]; _ } ->
Expand All @@ -319,7 +321,7 @@ module View = struct
| Tuple l ->
let* l = List.map ~f:arity0 l |> Option.all in
Some ("tuple_" ^ String.concat ~sep:"_" l)
| otherwise -> arity0 otherwise
| _ -> arity0 ty
in
arity1

Expand Down
4 changes: 2 additions & 2 deletions engine/lib/import_thir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ end) : EXPR = struct
in
(* if there is no expression & the last expression is ⊥, just use that *)
let lift_last_statement_as_expr_if_possible expr stmts (ty : Thir.ty) =
match (ty, expr, List.drop_last stmts, List.last stmts) with
match (ty.kind, expr, List.drop_last stmts, List.last stmts) with
| ( Thir.Never,
None,
Some stmts,
Expand Down Expand Up @@ -981,7 +981,7 @@ end) : EXPR = struct
("Pointer, with [cast] being " ^ [%show: Thir.pointer_coercion] cast)

and c_ty (span : Thir.span) (ty : Thir.ty) : ty =
match ty with
match ty.kind with
| Bool -> TBool
| Char -> TChar
| Int k -> TInt (c_int_ty k)
Expand Down
8 changes: 5 additions & 3 deletions frontend/exporter/src/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@ mod module {
Rc<rustc_middle::thir::Thir<'tcx>>,
rustc_middle::thir::ExprId,
) {
let base = s.base();
let msg = || fatal!(s[base.tcx.def_span(did)], "THIR not found for {:?}", did);
base.cached_thirs.get(&did).unwrap_or_else(msg).clone()
let tcx = s.base().tcx;
s.with_item_cache(did.to_def_id(), |caches| {
let msg = || fatal!(s[tcx.def_span(did)], "THIR not found for {:?}", did);
caches.thir.as_ref().unwrap_or_else(msg).clone()
})
}

pub trait IsBody: Sized + Clone + 'static {
Expand Down
14 changes: 7 additions & 7 deletions frontend/exporter/src/constant_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -584,9 +584,9 @@ mod rustc {
assert!(dc.variant.is_none());

// The type should be tuple
let hax_ty = ty.sinto(s);
match &hax_ty {
Ty::Tuple(_) => (),
let hax_ty: Ty = ty.sinto(s);
match hax_ty.kind() {
TyKind::Tuple(_) => (),
_ => {
fatal!(s[span], "Expected the type to be tuple: {:?}", val)
}
Expand Down Expand Up @@ -632,12 +632,12 @@ mod rustc {
}
ConstValue::ZeroSized { .. } => {
// Should be unit
let hty = ty.sinto(s);
let cv = match &hty {
Ty::Tuple(tys) if tys.is_empty() => {
let hty: Ty = ty.sinto(s);
let cv = match hty.kind() {
TyKind::Tuple(tys) if tys.is_empty() => {
ConstantExprKind::Tuple { fields: Vec::new() }
}
Ty::Arrow(_) => match ty.kind() {
TyKind::Arrow(_) => match ty.kind() {
rustc_middle::ty::TyKind::FnDef(def_id, args) => {
let (def_id, generics, generics_impls, method_impl) =
get_function_from_def_id_and_generics(s, *def_id, args);
Expand Down
7 changes: 5 additions & 2 deletions frontend/exporter/src/rustc_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@ impl<'tcx, T: ty::TypeFoldable<ty::TyCtxt<'tcx>>> ty::Binder<'tcx, T> {
}

#[tracing::instrument(skip(s))]
pub(crate) fn arrow_of_sig<'tcx, S: UnderOwnerState<'tcx>>(sig: &ty::PolyFnSig<'tcx>, s: &S) -> Ty {
Ty::Arrow(Box::new(sig.sinto(s)))
pub(crate) fn arrow_of_sig<'tcx, S: UnderOwnerState<'tcx>>(
sig: &ty::PolyFnSig<'tcx>,
s: &S,
) -> TyKind {
TyKind::Arrow(Box::new(sig.sinto(s)))
}

#[tracing::instrument(skip(s))]
Expand Down
85 changes: 59 additions & 26 deletions frontend/exporter/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ macro_rules! mk {

mod types {
use crate::prelude::*;
use rustc_middle::ty;
use std::cell::RefCell;
use std::collections::HashSet;

pub struct LocalContextS {
pub vars: HashMap<rustc_middle::thir::LocalVarId, String>,
Expand All @@ -119,24 +119,41 @@ mod types {
}
}

/// Global caches
#[derive(Default)]
pub struct GlobalCache<'tcx> {
/// Cache the `Span` translations.
pub spans: HashMap<rustc_span::Span, Span>,
/// Per-item cache.
pub per_item: HashMap<RDefId, ItemCache<'tcx>>,
}

/// Per-item cache
#[derive(Default)]
pub struct ItemCache<'tcx> {
/// The translated `DefId`.
pub def_id: Option<DefId>,
/// Cache the `Ty` translations.
pub tys: HashMap<ty::Ty<'tcx>, Ty>,
/// Cache the trait resolution engine for each item.
pub predicate_searcher: Option<crate::traits::PredicateSearcher<'tcx>>,
/// Cache of trait refs to resolved impl expressions.
pub impl_exprs: HashMap<ty::PolyTraitRef<'tcx>, crate::traits::ImplExpr>,
/// Cache thir bodies.
pub thir: Option<(
Rc<rustc_middle::thir::Thir<'tcx>>,
rustc_middle::thir::ExprId,
)>,
}

#[derive(Clone)]
pub struct Base<'tcx> {
pub options: Rc<hax_frontend_exporter_options::Options>,
pub macro_infos: MacroCalls,
pub local_ctx: Rc<RefCell<LocalContextS>>,
pub opt_def_id: Option<rustc_hir::def_id::DefId>,
pub exported_spans: ExportedSpans,
pub exported_def_ids: ExportedDefIds,
pub cached_thirs: Rc<
HashMap<
rustc_span::def_id::LocalDefId,
(
Rc<rustc_middle::thir::Thir<'tcx>>,
rustc_middle::thir::ExprId,
),
>,
>,
pub tcx: rustc_middle::ty::TyCtxt<'tcx>,
pub cache: Rc<RefCell<GlobalCache<'tcx>>>,
pub tcx: ty::TyCtxt<'tcx>,
/// Rust doesn't enforce bounds on generic parameters in type
/// aliases. Thus, when translating type aliases, we need to
/// disable the resolution of implementation expressions. For
Expand All @@ -153,22 +170,18 @@ mod types {
Self {
tcx,
macro_infos: Rc::new(HashMap::new()),
cached_thirs: Rc::new(HashMap::new()),
cache: Default::default(),
options: Rc::new(options),
// Always prefer `s.owner_id()` to `s.base().opt_def_id`.
// `opt_def_id` is used in `utils` for error reporting
opt_def_id: None,
local_ctx: Rc::new(RefCell::new(LocalContextS::new())),
exported_spans: Rc::new(RefCell::new(HashSet::new())),
exported_def_ids: Rc::new(RefCell::new(HashSet::new())),
ty_alias_mode: false,
}
}
}

pub type MacroCalls = Rc<HashMap<Span, Span>>;
pub type ExportedSpans = Rc<RefCell<HashSet<rustc_span::Span>>>;
pub type ExportedDefIds = Rc<RefCell<HashSet<rustc_hir::def_id::DefId>>>;
pub type RcThir<'tcx> = Rc<rustc_middle::thir::Thir<'tcx>>;
pub type RcMir<'tcx> = Rc<rustc_middle::mir::Body<'tcx>>;
pub type Binder<'tcx> = rustc_middle::ty::Binder<'tcx, ()>;
Expand Down Expand Up @@ -294,6 +307,31 @@ pub trait UnderBinderState<'tcx> = UnderOwnerState<'tcx> + HasBinder<'tcx>;
/// body and an `owner_id` in the state
pub trait ExprState<'tcx> = UnderOwnerState<'tcx> + HasThir<'tcx>;

pub trait WithGlobalCacheExt<'tcx>: BaseState<'tcx> {
/// Access the global cache. You must not call `sinto` within this function as this will likely
/// result in `BorrowMut` panics.
fn with_global_cache<T>(&self, f: impl FnOnce(&mut GlobalCache<'tcx>) -> T) -> T {
let base = self.base();
let mut cache = base.cache.borrow_mut();
f(&mut *cache)
}
/// Access the cache for a given item. You must not call `sinto` within this function as this
/// will likely result in `BorrowMut` panics.
fn with_item_cache<T>(&self, def_id: RDefId, f: impl FnOnce(&mut ItemCache<'tcx>) -> T) -> T {
self.with_global_cache(|cache| f(cache.per_item.entry(def_id).or_default()))
}
}
impl<'tcx, S: BaseState<'tcx>> WithGlobalCacheExt<'tcx> for S {}

pub trait WithItemCacheExt<'tcx>: UnderOwnerState<'tcx> {
/// Access the cache for the current item. You must not call `sinto` within this function as
/// this will likely result in `BorrowMut` panics.
fn with_cache<T>(&self, f: impl FnOnce(&mut ItemCache<'tcx>) -> T) -> T {
self.with_item_cache(self.owner_id(), f)
}
}
impl<'tcx, S: UnderOwnerState<'tcx>> WithItemCacheExt<'tcx> for S {}

impl ImplInfos {
fn from(base: Base<'_>, did: rustc_hir::def_id::DefId) -> Self {
let tcx = base.tcx;
Expand All @@ -313,13 +351,9 @@ impl ImplInfos {
pub fn impl_def_ids_to_impled_types_and_bounds<'tcx, S: BaseState<'tcx>>(
s: &S,
) -> HashMap<DefId, ImplInfos> {
let Base {
tcx,
exported_def_ids,
..
} = s.base();
let tcx = s.base().tcx;

let def_ids = exported_def_ids.as_ref().borrow().clone();
let def_ids: Vec<_> = s.with_global_cache(|cache| cache.per_item.keys().copied().collect());
let with_parents = |mut did: rustc_hir::def_id::DefId| {
let mut acc = vec![did];
while let Some(parent) = tcx.opt_parent(did) {
Expand All @@ -330,8 +364,7 @@ pub fn impl_def_ids_to_impled_types_and_bounds<'tcx, S: BaseState<'tcx>>(
};
use itertools::Itertools;
def_ids
.iter()
.cloned()
.into_iter()
.flat_map(with_parents)
.unique()
.filter(|&did| {
Expand Down
Loading

0 comments on commit 050449c

Please sign in to comment.