Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 0 additions & 22 deletions compiler/rustc_middle/src/mir/terminator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -695,28 +695,6 @@ impl<'tcx> TerminatorKind<'tcx> {
_ => None,
}
}

/// Returns true if the terminator can write to memory.
pub fn can_write_to_memory(&self) -> bool {
match self {
TerminatorKind::Goto { .. }
| TerminatorKind::SwitchInt { .. }
| TerminatorKind::UnwindResume
| TerminatorKind::UnwindTerminate(_)
| TerminatorKind::Return
| TerminatorKind::Assert { .. }
| TerminatorKind::CoroutineDrop
| TerminatorKind::FalseEdge { .. }
| TerminatorKind::FalseUnwind { .. }
| TerminatorKind::Unreachable => false,
TerminatorKind::Call { .. }
| TerminatorKind::Drop { .. }
| TerminatorKind::TailCall { .. }
// Yield writes to the resume_arg place.
| TerminatorKind::Yield { .. }
| TerminatorKind::InlineAsm { .. } => true,
}
}
}

#[derive(Copy, Clone, Debug)]
Expand Down
139 changes: 86 additions & 53 deletions compiler/rustc_mir_transform/src/gvn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,24 +129,18 @@ impl<'tcx> crate::MirPass<'tcx> for GVN {
let ssa = SsaLocals::new(tcx, body, typing_env);
// Clone dominators because we need them while mutating the body.
let dominators = body.basic_blocks.dominators().clone();
let maybe_loop_headers = loops::maybe_loop_headers(body);

let arena = DroplessArena::default();
let mut state =
VnState::new(tcx, body, typing_env, &ssa, dominators, &body.local_decls, &arena);

for local in body.args_iter().filter(|&local| ssa.is_ssa(local)) {
let opaque = state.new_opaque(body.local_decls[local].ty);
let opaque = state.new_argument(body.local_decls[local].ty);
state.assign(local, opaque);
}

let reverse_postorder = body.basic_blocks.reverse_postorder().to_vec();
for bb in reverse_postorder {
// N.B. With loops, reverse postorder cannot produce a valid topological order.
// A statement or terminator from inside the loop, that is not processed yet, may have performed an indirect write.
if maybe_loop_headers.contains(bb) {
state.invalidate_derefs();
}
let data = &mut body.basic_blocks.as_mut_preserves_cfg()[bb];
state.visit_basic_block_data(bb, data);
}
Expand Down Expand Up @@ -204,8 +198,9 @@ enum AddressBase {
enum Value<'a, 'tcx> {
// Root values.
/// Used to represent values we know nothing about.
/// The `usize` is a counter incremented by `new_opaque`.
Opaque(VnOpaque),
/// The value is a argument.
Argument(VnOpaque),
/// Evaluated or unevaluated constant value.
Constant {
value: Const<'tcx>,
Expand Down Expand Up @@ -290,7 +285,7 @@ impl<'a, 'tcx> ValueSet<'a, 'tcx> {
let value = value(VnOpaque);

debug_assert!(match value {
Value::Opaque(_) | Value::Address { .. } => true,
Value::Opaque(_) | Value::Argument(_) | Value::Address { .. } => true,
Value::Constant { disambiguator, .. } => disambiguator.is_some(),
_ => false,
});
Expand Down Expand Up @@ -350,12 +345,6 @@ impl<'a, 'tcx> ValueSet<'a, 'tcx> {
fn ty(&self, index: VnIndex) -> Ty<'tcx> {
self.types[index]
}

/// Replace the value associated with `index` with an opaque value.
#[inline]
fn forget(&mut self, index: VnIndex) {
self.values[index] = Value::Opaque(VnOpaque);
}
}

struct VnState<'body, 'a, 'tcx> {
Expand All @@ -374,8 +363,6 @@ struct VnState<'body, 'a, 'tcx> {
/// - `Some(None)` are values for which computation has failed;
/// - `Some(Some(op))` are successful computations.
evaluated: IndexVec<VnIndex, Option<Option<&'a OpTy<'tcx>>>>,
/// Cache the deref values.
derefs: Vec<VnIndex>,
ssa: &'body SsaLocals,
dominators: Dominators<BasicBlock>,
reused_locals: DenseBitSet<Local>,
Expand Down Expand Up @@ -408,7 +395,6 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
rev_locals: IndexVec::with_capacity(num_values),
values: ValueSet::new(num_values),
evaluated: IndexVec::with_capacity(num_values),
derefs: Vec::new(),
ssa,
dominators,
reused_locals: DenseBitSet::new_empty(local_decls.len()),
Expand Down Expand Up @@ -455,6 +441,13 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
index
}

#[instrument(level = "trace", skip(self), ret)]
fn new_argument(&mut self, ty: Ty<'tcx>) -> VnIndex {
let index = self.insert_unique(ty, Value::Argument);
self.evaluated[index] = Some(None);
index
}

/// Create a new `Value::Address` distinct from all the others.
#[instrument(level = "trace", skip(self), ret)]
fn new_pointer(&mut self, place: Place<'tcx>, kind: AddressKind) -> Option<VnIndex> {
Expand All @@ -473,6 +466,10 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
projection.next();
AddressBase::Deref(base)
} else {
// Only propagate the pointer of the SSA local.
if !self.ssa.is_ssa(place.local) {
return None;
}
AddressBase::Local(place.local)
};
// Do not try evaluating inside `Index`, this has been done by `simplify_place_projection`.
Expand Down Expand Up @@ -541,18 +538,6 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
self.insert(ty, Value::Aggregate(VariantIdx::ZERO, self.arena.alloc_slice(values)))
}

fn insert_deref(&mut self, ty: Ty<'tcx>, value: VnIndex) -> VnIndex {
let value = self.insert(ty, Value::Projection(value, ProjectionElem::Deref));
self.derefs.push(value);
value
}

fn invalidate_derefs(&mut self) {
for deref in std::mem::take(&mut self.derefs) {
self.values.forget(deref);
}
}

#[instrument(level = "trace", skip(self), ret)]
fn eval_to_const_inner(&mut self, value: VnIndex) -> Option<OpTy<'tcx>> {
use Value::*;
Expand All @@ -566,7 +551,7 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
let op = match self.get(value) {
_ if ty.is_zst() => ImmTy::uninit(ty).into(),

Opaque(_) => return None,
Opaque(_) | Argument(_) => return None,
// Keep runtime check constants as symbolic.
RuntimeChecks(..) => return None,

Expand Down Expand Up @@ -818,7 +803,13 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {

// An immutable borrow `_x` always points to the same value for the
// lifetime of the borrow, so we can merge all instances of `*_x`.
return Some((projection_ty, self.insert_deref(projection_ty.ty, value)));
return Some((
projection_ty,
self.insert(
projection_ty.ty,
Value::Projection(value, ProjectionElem::Deref),
),
));
} else {
return None;
}
Expand Down Expand Up @@ -1037,7 +1028,7 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
let op = self.simplify_operand(op, location)?;
Value::Repeat(op, amount)
}
Rvalue::Aggregate(..) => return self.simplify_aggregate(lhs, rvalue, location),
Rvalue::Aggregate(..) => return self.simplify_aggregate(rvalue, location),
Rvalue::Ref(_, borrow_kind, ref mut place) => {
self.simplify_place_projection(place, location);
return self.new_pointer(*place, AddressKind::Ref(borrow_kind));
Expand Down Expand Up @@ -1142,13 +1133,48 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
}
}

// We can introduce a new dereference if the source value cannot be changed in the body.
let mut copy_root = copy_from_local_value;
loop {
match self.get(copy_root) {
Value::Projection(base, _) => {
copy_root = base;
}
Value::Address {
base,
projection,
kind: AddressKind::Ref(BorrowKind::Shared),
..
} if projection.iter().all(ProjectionElem::is_stable_offset) => match base {
AddressBase::Local(_) => {
break;
}
AddressBase::Deref(index) => {
copy_root = index;
}
},
Value::Argument(_) if !self.ty(copy_root).is_mutable_ptr() => {
break;
}
Value::Opaque(_) => {
let ty = self.ty(copy_root);
if ty.is_fn() || !ty.is_any_ptr() {
break;
}
return None;
}
_ => {
return None;
}
}
}

// Both must be variants of the same type.
if self.ty(copy_from_local_value) == ty { Some(copy_from_local_value) } else { None }
}

fn simplify_aggregate(
&mut self,
lhs: &Place<'tcx>,
rvalue: &mut Rvalue<'tcx>,
location: Location,
) -> Option<VnIndex> {
Expand Down Expand Up @@ -1231,12 +1257,7 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
}

if let Some(value) = self.simplify_aggregate_to_copy(ty, variant_index, &fields) {
// Allow introducing places with non-constant offsets, as those are still better than
// reconstructing an aggregate. But avoid creating `*a = copy (*b)`, as they might be
// aliases resulting in overlapping assignments.
let allow_complex_projection =
lhs.projection[..].iter().all(PlaceElem::is_stable_offset);
if let Some(place) = self.try_as_place(value, location, allow_complex_projection) {
if let Some(place) = self.try_as_place(value, location, true) {
self.reused_locals.insert(place.local);
*rvalue = Rvalue::Use(Operand::Copy(place));
}
Expand Down Expand Up @@ -1873,10 +1894,6 @@ impl<'tcx> MutVisitor<'tcx> for VnState<'_, '_, 'tcx> {

fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) {
self.simplify_place_projection(place, location);
if context.is_mutating_use() && place.is_indirect() {
// Non-local mutation maybe invalidate deref.
self.invalidate_derefs();
}
self.super_place(place, context, location);
}

Expand All @@ -1893,7 +1910,7 @@ impl<'tcx> MutVisitor<'tcx> for VnState<'_, '_, 'tcx> {
) {
self.simplify_place_projection(lhs, location);

let value = self.simplify_rvalue(lhs, rvalue, location);
let mut value = self.simplify_rvalue(lhs, rvalue, location);
if let Some(value) = value {
if let Some(const_) = self.try_as_constant(value) {
*rvalue = Rvalue::Use(Operand::Constant(Box::new(const_)));
Expand All @@ -1906,14 +1923,34 @@ impl<'tcx> MutVisitor<'tcx> for VnState<'_, '_, 'tcx> {
}
}

if lhs.is_indirect() {
// Non-local mutation maybe invalidate deref.
self.invalidate_derefs();
let rvalue_ty = rvalue.ty(self.local_decls, self.tcx);
// DO NOT reason the pointer value if it may point to a non-SSA local.
// For instance, we cannot unify two pointers that dereference same local, because they may
// have different lifetimes.
if rvalue_ty.is_ref()
&& let Some(index) = value
{
match self.get(index) {
Value::Opaque(_) | Value::Projection(..) => {
value = None;
}
Value::Constant { .. }
| Value::Address { .. }
| Value::Argument(_)
| Value::RawPtr { .. }
| Value::BinaryOp(..)
| Value::Cast { .. } => {}
Value::Aggregate(..)
| Value::Union(..)
| Value::Repeat(..)
| Value::Discriminant(..)
| Value::RuntimeChecks(..)
| Value::UnaryOp(..) => unreachable!(),
}
}

if let Some(local) = lhs.as_local()
&& self.ssa.is_ssa(local)
&& let rvalue_ty = rvalue.ty(self.local_decls, self.tcx)
// FIXME(#112651) `rvalue` may have a subtype to `local`. We can only mark
// `local` as reusable if we have an exact type match.
&& self.local_decls[local].ty == rvalue_ty
Expand All @@ -1933,10 +1970,6 @@ impl<'tcx> MutVisitor<'tcx> for VnState<'_, '_, 'tcx> {
self.assign(local, opaque);
}
}
// Terminators that can write to memory may invalidate (nested) derefs.
if terminator.kind.can_write_to_memory() {
self.invalidate_derefs();
}
self.super_terminator(terminator, location);
}
}
Expand Down
37 changes: 37 additions & 0 deletions tests/mir-opt/gvn.dereference_reborrow.GVN.panic-abort.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
- // MIR for `dereference_reborrow` before GVN
+ // MIR for `dereference_reborrow` after GVN

fn dereference_reborrow(_1: &mut u8) -> () {
debug mut_a => _1;
let mut _0: ();
let _2: &u8;
scope 1 {
debug a => _2;
let _3: u8;
scope 2 {
debug b => _3;
let _4: u8;
scope 3 {
debug c => _4;
}
}
}

bb0: {
StorageLive(_2);
_2 = &(*_1);
- StorageLive(_3);
+ nop;
_3 = copy (*_2);
StorageLive(_4);
- _4 = copy (*_2);
+ _4 = copy _3;
_0 = const ();
StorageDead(_4);
- StorageDead(_3);
+ nop;
StorageDead(_2);
return;
}
}

37 changes: 37 additions & 0 deletions tests/mir-opt/gvn.dereference_reborrow.GVN.panic-unwind.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
- // MIR for `dereference_reborrow` before GVN
+ // MIR for `dereference_reborrow` after GVN

fn dereference_reborrow(_1: &mut u8) -> () {
debug mut_a => _1;
let mut _0: ();
let _2: &u8;
scope 1 {
debug a => _2;
let _3: u8;
scope 2 {
debug b => _3;
let _4: u8;
scope 3 {
debug c => _4;
}
}
}

bb0: {
StorageLive(_2);
_2 = &(*_1);
- StorageLive(_3);
+ nop;
_3 = copy (*_2);
StorageLive(_4);
- _4 = copy (*_2);
+ _4 = copy _3;
_0 = const ();
StorageDead(_4);
- StorageDead(_3);
+ nop;
StorageDead(_2);
return;
}
}

Loading
Loading