Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RISCV][GISEL][NFC] Make MRI a member in RISCVInstructionSelector #110926

Merged
merged 2 commits into from
Oct 4, 2024

Conversation

michaelmaitland
Copy link
Contributor

@michaelmaitland michaelmaitland commented Oct 2, 2024

It was requested in #110782 (comment) that MRI be made a member of RISCVInstructionSelector.

RISCVInstructionSelector is created in the RISCVSubtarget, independent of MachineFunction. So it cannot be passed by reference during construction of RISCVInstructionSelector.

The MachineRegisterInfo object belongs to each MachineFunction, so we will set it as we enter select, which is the only public function to RISCVInstructionSelector. We don't need to worry about clearing it before returning from select, since there is no other entry point.

It was requested in llvm#110782 (comment)
that MRI be made a member of RISCVInstructionSelector.

RISCVInstructionSelector is created in the RISCVSubtarget, independent of
MachineFunction. So it cannot be passed by reference during construction of
RISCVInstructionSelector.

The MachineRegisterInfo object belongs to each MachineFunction, so we will
set it as we enter `select`, which is the only public function to
RISCVInstructionSelector. We don't need to worry about clearing it before
returning from `select`, since there is no other entry point.

I'm not sure this is any better than what we have today. Not sure whether
we should take this change or not. If in the future we have other public
functions of RISCVInstructionSelector, we will need to be more careful about
checking that MRI is set and clearing it where appropriate.
@llvmbot
Copy link
Member

llvmbot commented Oct 2, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Michael Maitland (michaelmaitland)

Changes

It was requested in #110782 (comment) that MRI be made a member of RISCVInstructionSelector.

RISCVInstructionSelector is created in the RISCVSubtarget, independent of MachineFunction. So it cannot be passed by reference during construction of RISCVInstructionSelector.

The MachineRegisterInfo object belongs to each MachineFunction, so we will set it as we enter select, which is the only public function to RISCVInstructionSelector. We don't need to worry about clearing it before returning from select, since there is no other entry point.

I'm not sure this is any better than what we have today. Not sure whether we should take this change or not. If in the future we have other public functions of RISCVInstructionSelector, we will need to be more careful about checking that MRI is set and clearing it where appropriate.


Patch is 29.49 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/110926.diff

1 Files Affected:

  • (modified) llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp (+91-116)
diff --git a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
index 92d00c26bd219c..dfaf87509dbcb8 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
@@ -49,8 +49,8 @@ class RISCVInstructionSelector : public InstructionSelector {
   const TargetRegisterClass *
   getRegClassForTypeOnBank(LLT Ty, const RegisterBank &RB) const;
 
-  bool isRegInGprb(Register Reg, MachineRegisterInfo &MRI) const;
-  bool isRegInFprb(Register Reg, MachineRegisterInfo &MRI) const;
+  bool isRegInGprb(Register Reg) const;
+  bool isRegInFprb(Register Reg) const;
 
   // tblgen-erated 'select' implementation, used as the initial selector for
   // the patterns that don't require complex C++.
@@ -58,31 +58,23 @@ class RISCVInstructionSelector : public InstructionSelector {
 
   // A lowering phase that runs before any selection attempts.
   // Returns true if the instruction was modified.
-  void preISelLower(MachineInstr &MI, MachineIRBuilder &MIB,
-                    MachineRegisterInfo &MRI);
+  void preISelLower(MachineInstr &MI, MachineIRBuilder &MIB);
 
-  bool replacePtrWithInt(MachineOperand &Op, MachineIRBuilder &MIB,
-                         MachineRegisterInfo &MRI);
+  bool replacePtrWithInt(MachineOperand &Op, MachineIRBuilder &MIB);
 
   // Custom selection methods
-  bool selectCopy(MachineInstr &MI, MachineRegisterInfo &MRI) const;
-  bool selectImplicitDef(MachineInstr &MI, MachineIRBuilder &MIB,
-                         MachineRegisterInfo &MRI) const;
+  bool selectCopy(MachineInstr &MI) const;
+  bool selectImplicitDef(MachineInstr &MI, MachineIRBuilder &MIB) const;
   bool materializeImm(Register Reg, int64_t Imm, MachineIRBuilder &MIB) const;
-  bool selectAddr(MachineInstr &MI, MachineIRBuilder &MIB,
-                  MachineRegisterInfo &MRI, bool IsLocal = true,
+  bool selectAddr(MachineInstr &MI, MachineIRBuilder &MIB, bool IsLocal = true,
                   bool IsExternWeak = false) const;
   bool selectSExtInreg(MachineInstr &MI, MachineIRBuilder &MIB) const;
-  bool selectSelect(MachineInstr &MI, MachineIRBuilder &MIB,
-                    MachineRegisterInfo &MRI) const;
-  bool selectFPCompare(MachineInstr &MI, MachineIRBuilder &MIB,
-                       MachineRegisterInfo &MRI) const;
+  bool selectSelect(MachineInstr &MI, MachineIRBuilder &MIB) const;
+  bool selectFPCompare(MachineInstr &MI, MachineIRBuilder &MIB) const;
   void emitFence(AtomicOrdering FenceOrdering, SyncScope::ID FenceSSID,
                  MachineIRBuilder &MIB) const;
-  bool selectMergeValues(MachineInstr &MI, MachineIRBuilder &MIB,
-                         MachineRegisterInfo &MRI) const;
-  bool selectUnmergeValues(MachineInstr &MI, MachineIRBuilder &MIB,
-                           MachineRegisterInfo &MRI) const;
+  bool selectMergeValues(MachineInstr &MI, MachineIRBuilder &MIB) const;
+  bool selectUnmergeValues(MachineInstr &MI, MachineIRBuilder &MIB) const;
 
   ComplexRendererFns selectShiftMask(MachineOperand &Root) const;
   ComplexRendererFns selectAddrRegImm(MachineOperand &Root) const;
@@ -121,6 +113,8 @@ class RISCVInstructionSelector : public InstructionSelector {
   const RISCVRegisterBankInfo &RBI;
   const RISCVTargetMachine &TM;
 
+  MachineRegisterInfo *MRI = nullptr;
+
   // FIXME: This is necessary because DAGISel uses "Subtarget->" and GlobalISel
   // uses "STI." in the code generated by TableGen. We need to unify the name of
   // Subtarget variable.
@@ -162,16 +156,15 @@ RISCVInstructionSelector::selectShiftMask(MachineOperand &Root) const {
     return std::nullopt;
 
   using namespace llvm::MIPatternMatch;
-  MachineRegisterInfo &MRI = MF->getRegInfo();
 
   Register RootReg = Root.getReg();
   Register ShAmtReg = RootReg;
-  const LLT ShiftLLT = MRI.getType(RootReg);
+  const LLT ShiftLLT = MRI->getType(RootReg);
   unsigned ShiftWidth = ShiftLLT.getSizeInBits();
   assert(isPowerOf2_32(ShiftWidth) && "Unexpected max shift amount!");
   // Peek through zext.
   Register ZExtSrcReg;
-  if (mi_match(ShAmtReg, MRI, m_GZExt(m_Reg(ZExtSrcReg)))) {
+  if (mi_match(ShAmtReg, *MRI, m_GZExt(m_Reg(ZExtSrcReg)))) {
     ShAmtReg = ZExtSrcReg;
   }
 
@@ -191,7 +184,7 @@ RISCVInstructionSelector::selectShiftMask(MachineOperand &Root) const {
   //
   // 1. the lowest log2(XLEN) bits of the and mask are all set
   // 2. the bits of the register being masked are already unset (zero set)
-  if (mi_match(ShAmtReg, MRI, m_GAnd(m_Reg(AndSrcReg), m_ICst(AndMask)))) {
+  if (mi_match(ShAmtReg, *MRI, m_GAnd(m_Reg(AndSrcReg), m_ICst(AndMask)))) {
     APInt ShMask(AndMask.getBitWidth(), ShiftWidth - 1);
     if (ShMask.isSubsetOf(AndMask)) {
       ShAmtReg = AndSrcReg;
@@ -206,16 +199,16 @@ RISCVInstructionSelector::selectShiftMask(MachineOperand &Root) const {
 
   APInt Imm;
   Register Reg;
-  if (mi_match(ShAmtReg, MRI, m_GAdd(m_Reg(Reg), m_ICst(Imm)))) {
+  if (mi_match(ShAmtReg, *MRI, m_GAdd(m_Reg(Reg), m_ICst(Imm)))) {
     if (Imm != 0 && Imm.urem(ShiftWidth) == 0)
       // If we are shifting by X+N where N == 0 mod Size, then just shift by X
       // to avoid the ADD.
       ShAmtReg = Reg;
-  } else if (mi_match(ShAmtReg, MRI, m_GSub(m_ICst(Imm), m_Reg(Reg)))) {
+  } else if (mi_match(ShAmtReg, *MRI, m_GSub(m_ICst(Imm), m_Reg(Reg)))) {
     if (Imm != 0 && Imm.urem(ShiftWidth) == 0) {
       // If we are shifting by N-X where N == 0 mod Size, then just shift by -X
       // to generate a NEG instead of a SUB of a constant.
-      ShAmtReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
+      ShAmtReg = MRI->createVirtualRegister(&RISCV::GPRRegClass);
       unsigned NegOpc = Subtarget->is64Bit() ? RISCV::SUBW : RISCV::SUB;
       return {{[=](MachineInstrBuilder &MIB) {
         MachineIRBuilder(*MIB.getInstr())
@@ -226,7 +219,7 @@ RISCVInstructionSelector::selectShiftMask(MachineOperand &Root) const {
     if (Imm.urem(ShiftWidth) == ShiftWidth - 1) {
       // If we are shifting by N-X where N == -1 mod Size, then just shift by ~X
       // to generate a NOT instead of a SUB of a constant.
-      ShAmtReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
+      ShAmtReg = MRI->createVirtualRegister(&RISCV::GPRRegClass);
       return {{[=](MachineInstrBuilder &MIB) {
         MachineIRBuilder(*MIB.getInstr())
             .buildInstr(RISCV::XORI, {ShAmtReg}, {Reg})
@@ -243,8 +236,6 @@ InstructionSelector::ComplexRendererFns
 RISCVInstructionSelector::selectSHXADDOp(MachineOperand &Root,
                                          unsigned ShAmt) const {
   using namespace llvm::MIPatternMatch;
-  MachineFunction &MF = *Root.getParent()->getParent()->getParent();
-  MachineRegisterInfo &MRI = MF.getRegInfo();
 
   if (!Root.isReg())
     return std::nullopt;
@@ -255,11 +246,11 @@ RISCVInstructionSelector::selectSHXADDOp(MachineOperand &Root,
   Register RegY;
   std::optional<bool> LeftShift;
   // (and (shl y, c2), mask)
-  if (mi_match(RootReg, MRI,
+  if (mi_match(RootReg, *MRI,
                m_GAnd(m_GShl(m_Reg(RegY), m_ICst(C2)), m_ICst(Mask))))
     LeftShift = true;
   // (and (lshr y, c2), mask)
-  else if (mi_match(RootReg, MRI,
+  else if (mi_match(RootReg, *MRI,
                     m_GAnd(m_GLShr(m_Reg(RegY), m_ICst(C2)), m_ICst(Mask))))
     LeftShift = false;
 
@@ -275,7 +266,7 @@ RISCVInstructionSelector::selectSHXADDOp(MachineOperand &Root,
       // Given (and (shl y, c2), mask) in which mask has no leading zeros and
       // c3 trailing zeros. We can use an SRLI by c3 - c2 followed by a SHXADD.
       if (*LeftShift && Leading == 0 && C2.ult(Trailing) && Trailing == ShAmt) {
-        Register DstReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
+        Register DstReg = MRI->createVirtualRegister(&RISCV::GPRRegClass);
         return {{[=](MachineInstrBuilder &MIB) {
           MachineIRBuilder(*MIB.getInstr())
               .buildInstr(RISCV::SRLI, {DstReg}, {RegY})
@@ -287,7 +278,7 @@ RISCVInstructionSelector::selectSHXADDOp(MachineOperand &Root,
       // Given (and (lshr y, c2), mask) in which mask has c2 leading zeros and
       // c3 trailing zeros. We can use an SRLI by c2 + c3 followed by a SHXADD.
       if (!*LeftShift && Leading == C2 && Trailing == ShAmt) {
-        Register DstReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
+        Register DstReg = MRI->createVirtualRegister(&RISCV::GPRRegClass);
         return {{[=](MachineInstrBuilder &MIB) {
           MachineIRBuilder(*MIB.getInstr())
               .buildInstr(RISCV::SRLI, {DstReg}, {RegY})
@@ -301,12 +292,12 @@ RISCVInstructionSelector::selectSHXADDOp(MachineOperand &Root,
   LeftShift.reset();
 
   // (shl (and y, mask), c2)
-  if (mi_match(RootReg, MRI,
+  if (mi_match(RootReg, *MRI,
                m_GShl(m_OneNonDBGUse(m_GAnd(m_Reg(RegY), m_ICst(Mask))),
                       m_ICst(C2))))
     LeftShift = true;
   // (lshr (and y, mask), c2)
-  else if (mi_match(RootReg, MRI,
+  else if (mi_match(RootReg, *MRI,
                     m_GLShr(m_OneNonDBGUse(m_GAnd(m_Reg(RegY), m_ICst(Mask))),
                             m_ICst(C2))))
     LeftShift = false;
@@ -326,7 +317,7 @@ RISCVInstructionSelector::selectSHXADDOp(MachineOperand &Root,
              (Trailing - C2.getLimitedValue()) == ShAmt;
 
     if (Cond) {
-      Register DstReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
+      Register DstReg = MRI->createVirtualRegister(&RISCV::GPRRegClass);
       return {{[=](MachineInstrBuilder &MIB) {
         MachineIRBuilder(*MIB.getInstr())
             .buildInstr(RISCV::SRLIW, {DstReg}, {RegY})
@@ -343,8 +334,6 @@ InstructionSelector::ComplexRendererFns
 RISCVInstructionSelector::selectSHXADD_UWOp(MachineOperand &Root,
                                             unsigned ShAmt) const {
   using namespace llvm::MIPatternMatch;
-  MachineFunction &MF = *Root.getParent()->getParent()->getParent();
-  MachineRegisterInfo &MRI = MF.getRegInfo();
 
   if (!Root.isReg())
     return std::nullopt;
@@ -356,7 +345,7 @@ RISCVInstructionSelector::selectSHXADD_UWOp(MachineOperand &Root,
   APInt Mask, C2;
   Register RegX;
   if (mi_match(
-          RootReg, MRI,
+          RootReg, *MRI,
           m_OneNonDBGUse(m_GAnd(m_OneNonDBGUse(m_GShl(m_Reg(RegX), m_ICst(C2))),
                                 m_ICst(Mask))))) {
     Mask &= maskTrailingZeros<uint64_t>(C2.getLimitedValue());
@@ -365,7 +354,7 @@ RISCVInstructionSelector::selectSHXADD_UWOp(MachineOperand &Root,
       unsigned Leading = Mask.countl_zero();
       unsigned Trailing = Mask.countr_zero();
       if (Leading == 32 - ShAmt && C2 == Trailing && Trailing > ShAmt) {
-        Register DstReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
+        Register DstReg = MRI->createVirtualRegister(&RISCV::GPRRegClass);
         return {{[=](MachineInstrBuilder &MIB) {
           MachineIRBuilder(*MIB.getInstr())
               .buildInstr(RISCV::SLLI, {DstReg}, {RegX})
@@ -381,13 +370,10 @@ RISCVInstructionSelector::selectSHXADD_UWOp(MachineOperand &Root,
 
 InstructionSelector::ComplexRendererFns
 RISCVInstructionSelector::selectAddrRegImm(MachineOperand &Root) const {
-  MachineFunction &MF = *Root.getParent()->getParent()->getParent();
-  MachineRegisterInfo &MRI = MF.getRegInfo();
-
   if (!Root.isReg())
     return std::nullopt;
 
-  MachineInstr *RootDef = MRI.getVRegDef(Root.getReg());
+  MachineInstr *RootDef = MRI->getVRegDef(Root.getReg());
   if (RootDef->getOpcode() == TargetOpcode::G_FRAME_INDEX) {
     return {{
         [=](MachineInstrBuilder &MIB) { MIB.add(RootDef->getOperand(1)); },
@@ -395,11 +381,11 @@ RISCVInstructionSelector::selectAddrRegImm(MachineOperand &Root) const {
     }};
   }
 
-  if (isBaseWithConstantOffset(Root, MRI)) {
+  if (isBaseWithConstantOffset(Root, *MRI)) {
     MachineOperand &LHS = RootDef->getOperand(1);
     MachineOperand &RHS = RootDef->getOperand(2);
-    MachineInstr *LHSDef = MRI.getVRegDef(LHS.getReg());
-    MachineInstr *RHSDef = MRI.getVRegDef(RHS.getReg());
+    MachineInstr *LHSDef = MRI->getVRegDef(LHS.getReg());
+    MachineInstr *RHSDef = MRI->getVRegDef(RHS.getReg());
 
     int64_t RHSC = RHSDef->getOperand(1).getCImm()->getSExtValue();
     if (isInt<12>(RHSC)) {
@@ -441,9 +427,9 @@ static RISCVCC::CondCode getRISCVCCFromICmp(CmpInst::Predicate CC) {
   }
 }
 
-static void getOperandsForBranch(Register CondReg, MachineRegisterInfo &MRI,
-                                 RISCVCC::CondCode &CC, Register &LHS,
-                                 Register &RHS) {
+static void getOperandsForBranch(Register CondReg, RISCVCC::CondCode &CC,
+                                 Register &LHS, Register &RHS,
+                                 MachineRegisterInfo &MRI) {
   // Try to fold an ICmp. If that fails, use a NE compare with X0.
   CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
   if (!mi_match(CondReg, MRI, m_GICmp(m_Pred(Pred), m_Reg(LHS), m_Reg(RHS)))) {
@@ -509,19 +495,19 @@ static void getOperandsForBranch(Register CondReg, MachineRegisterInfo &MRI,
 bool RISCVInstructionSelector::select(MachineInstr &MI) {
   MachineBasicBlock &MBB = *MI.getParent();
   MachineFunction &MF = *MBB.getParent();
-  MachineRegisterInfo &MRI = MF.getRegInfo();
+  MRI = &MF.getRegInfo();
   MachineIRBuilder MIB(MI);
 
-  preISelLower(MI, MIB, MRI);
+  preISelLower(MI, MIB);
   const unsigned Opc = MI.getOpcode();
 
   if (!MI.isPreISelOpcode() || Opc == TargetOpcode::G_PHI) {
     if (Opc == TargetOpcode::PHI || Opc == TargetOpcode::G_PHI) {
       const Register DefReg = MI.getOperand(0).getReg();
-      const LLT DefTy = MRI.getType(DefReg);
+      const LLT DefTy = MRI->getType(DefReg);
 
       const RegClassOrRegBank &RegClassOrBank =
-          MRI.getRegClassOrRegBank(DefReg);
+          MRI->getRegClassOrRegBank(DefReg);
 
       const TargetRegisterClass *DefRC =
           RegClassOrBank.dyn_cast<const TargetRegisterClass *>();
@@ -540,12 +526,12 @@ bool RISCVInstructionSelector::select(MachineInstr &MI) {
       }
 
       MI.setDesc(TII.get(TargetOpcode::PHI));
-      return RBI.constrainGenericRegister(DefReg, *DefRC, MRI);
+      return RBI.constrainGenericRegister(DefReg, *DefRC, *MRI);
     }
 
     // Certain non-generic instructions also need some special handling.
     if (MI.isCopy())
-      return selectCopy(MI, MRI);
+      return selectCopy(MI);
 
     return true;
   }
@@ -559,7 +545,7 @@ bool RISCVInstructionSelector::select(MachineInstr &MI) {
   case TargetOpcode::G_INTTOPTR:
   case TargetOpcode::G_TRUNC:
   case TargetOpcode::G_FREEZE:
-    return selectCopy(MI, MRI);
+    return selectCopy(MI);
   case TargetOpcode::G_CONSTANT: {
     Register DstReg = MI.getOperand(0).getReg();
     int64_t Imm = MI.getOperand(1).getCImm()->getSExtValue();
@@ -576,9 +562,9 @@ bool RISCVInstructionSelector::select(MachineInstr &MI) {
     Register DstReg = MI.getOperand(0).getReg();
     const APFloat &FPimm = MI.getOperand(1).getFPImm()->getValueAPF();
     APInt Imm = FPimm.bitcastToAPInt();
-    unsigned Size = MRI.getType(DstReg).getSizeInBits();
+    unsigned Size = MRI->getType(DstReg).getSizeInBits();
     if (Size == 16 || Size == 32 || (Size == 64 && Subtarget->is64Bit())) {
-      Register GPRReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
+      Register GPRReg = MRI->createVirtualRegister(&RISCV::GPRRegClass);
       if (!materializeImm(GPRReg, Imm.getSExtValue(), MIB))
         return false;
 
@@ -592,8 +578,8 @@ bool RISCVInstructionSelector::select(MachineInstr &MI) {
       assert(Size == 64 && !Subtarget->is64Bit() &&
              "Unexpected size or subtarget");
       // Split into two pieces and build through the stack.
-      Register GPRRegHigh = MRI.createVirtualRegister(&RISCV::GPRRegClass);
-      Register GPRRegLow = MRI.createVirtualRegister(&RISCV::GPRRegClass);
+      Register GPRRegHigh = MRI->createVirtualRegister(&RISCV::GPRRegClass);
+      Register GPRRegLow = MRI->createVirtualRegister(&RISCV::GPRRegClass);
       if (!materializeImm(GPRRegHigh, Imm.extractBits(32, 32).getSExtValue(),
                           MIB))
         return false;
@@ -615,8 +601,7 @@ bool RISCVInstructionSelector::select(MachineInstr &MI) {
       return false;
     }
 
-    return selectAddr(MI, MIB, MRI, GV->isDSOLocal(),
-                      GV->hasExternalWeakLinkage());
+    return selectAddr(MI, MIB, GV->isDSOLocal(), GV->hasExternalWeakLinkage());
   }
   case TargetOpcode::G_JUMP_TABLE:
   case TargetOpcode::G_CONSTANT_POOL:
@@ -624,7 +609,7 @@ bool RISCVInstructionSelector::select(MachineInstr &MI) {
   case TargetOpcode::G_BRCOND: {
     Register LHS, RHS;
     RISCVCC::CondCode CC;
-    getOperandsForBranch(MI.getOperand(0).getReg(), MRI, CC, LHS, RHS);
+    getOperandsForBranch(MI.getOperand(0).getReg(), CC, LHS, RHS, *MRI);
 
     auto Bcc = MIB.buildInstr(RISCVCC::getBrCond(CC), {}, {LHS, RHS})
                    .addMBB(MI.getOperand(1).getMBB());
@@ -698,9 +683,9 @@ bool RISCVInstructionSelector::select(MachineInstr &MI) {
     return constrainSelectedInstRegOperands(MI, TII, TRI, RBI);
   }
   case TargetOpcode::G_SELECT:
-    return selectSelect(MI, MIB, MRI);
+    return selectSelect(MI, MIB);
   case TargetOpcode::G_FCMP:
-    return selectFPCompare(MI, MIB, MRI);
+    return selectFPCompare(MI, MIB);
   case TargetOpcode::G_FENCE: {
     AtomicOrdering FenceOrdering =
         static_cast<AtomicOrdering>(MI.getOperand(0).getImm());
@@ -711,18 +696,18 @@ bool RISCVInstructionSelector::select(MachineInstr &MI) {
     return true;
   }
   case TargetOpcode::G_IMPLICIT_DEF:
-    return selectImplicitDef(MI, MIB, MRI);
+    return selectImplicitDef(MI, MIB);
   case TargetOpcode::G_MERGE_VALUES:
-    return selectMergeValues(MI, MIB, MRI);
+    return selectMergeValues(MI, MIB);
   case TargetOpcode::G_UNMERGE_VALUES:
-    return selectUnmergeValues(MI, MIB, MRI);
+    return selectUnmergeValues(MI, MIB);
   default:
     return false;
   }
 }
 
-bool RISCVInstructionSelector::selectMergeValues(
-    MachineInstr &MI, MachineIRBuilder &MIB, MachineRegisterInfo &MRI) const {
+bool RISCVInstructionSelector::selectMergeValues(MachineInstr &MI,
+                                                 MachineIRBuilder &MIB) const {
   assert(MI.getOpcode() == TargetOpcode::G_MERGE_VALUES);
 
   // Build a F64 Pair from operands
@@ -731,14 +716,14 @@ bool RISCVInstructionSelector::selectMergeValues(
   Register Dst = MI.getOperand(0).getReg();
   Register Lo = MI.getOperand(1).getReg();
   Register Hi = MI.getOperand(2).getReg();
-  if (!isRegInFprb(Dst, MRI) || !isRegInGprb(Lo, MRI) || !isRegInGprb(Hi, MRI))
+  if (!isRegInFprb(Dst) || !isRegInGprb(Lo) || !isRegInGprb(Hi))
     return false;
   MI.setDesc(TII.get(RISCV::BuildPairF64Pseudo));
   return constrainSelectedInstRegOperands(MI, TII, TRI, RBI);
 }
 
 bool RISCVInstructionSelector::selectUnmergeValues(
-    MachineInstr &MI, MachineIRBuilder &MIB, MachineRegisterInfo &MRI) const {
+    MachineInstr &MI, MachineIRBuilder &MIB) const {
   assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES);
 
   // Split F64 Src into two s32 parts
@@ -747,44 +732,42 @@ bool RISCVInstructionSelector::selectUnmergeValues(
   Register Src = MI.getOperand(2).getReg();
   Register Lo = MI.getOperand(0).getReg();
   Register Hi = MI.getOperand(1).getReg();
-  if (!isRegInFprb(Src, MRI) || !isRegInGprb(Lo, MRI) || !isRegInGprb(Hi, MRI))
+  if (!isRegInFprb(Src) || !isRegInGprb(Lo) || !isRegInGprb(Hi))
     return false;
   MI.setDesc(TII.get(RISCV::SplitF64Pseudo));
   return constrainSelectedInstRegOperands(MI, TII, TRI, RBI);
 }
 
 bool RISCVInstructionSelector::replacePtrWithInt(MachineOperand &Op,
-                                                 MachineIRBuilder &MIB,
-                                                 MachineRegisterInfo &MRI) {
+                                                 MachineIRBuilder &MIB) {
   Register PtrReg = Op.getReg();
-  assert(MRI.getType(PtrReg).isPointer() && "Operand is not a pointer!");
+  assert(MRI->getType(PtrReg).isPointer() && "Operand is not a pointer!");
 
   const LLT sXLen = LLT::scalar(STI.getXLen());
   auto PtrToInt = MIB.buildPtrToInt(sXLen, PtrReg);
-  MRI.setRegBank(PtrToInt.getReg(0), RBI.getRegBank(RISCV::GPRBRegBankID));
+  MRI->setRegBank(PtrToInt.getReg(0), RBI.getRegBank(RISCV::GPRBRegBankID));
   Op.setReg(PtrToInt.getReg(0));
   return select(*PtrToInt);
 }
 
 void RISCV...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Oct 2, 2024

@llvm/pr-subscribers-llvm-globalisel

Author: Michael Maitland (michaelmaitland)

Changes

It was requested in #110782 (comment) that MRI be made a member of RISCVInstructionSelector.

RISCVInstructionSelector is created in the RISCVSubtarget, independent of MachineFunction. So it cannot be passed by reference during construction of RISCVInstructionSelector.

The MachineRegisterInfo object belongs to each MachineFunction, so we will set it as we enter select, which is the only public function to RISCVInstructionSelector. We don't need to worry about clearing it before returning from select, since there is no other entry point.

I'm not sure this is any better than what we have today. Not sure whether we should take this change or not. If in the future we have other public functions of RISCVInstructionSelector, we will need to be more careful about checking that MRI is set and clearing it where appropriate.


Patch is 29.49 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/110926.diff

1 Files Affected:

  • (modified) llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp (+91-116)
diff --git a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
index 92d00c26bd219c..dfaf87509dbcb8 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
@@ -49,8 +49,8 @@ class RISCVInstructionSelector : public InstructionSelector {
   const TargetRegisterClass *
   getRegClassForTypeOnBank(LLT Ty, const RegisterBank &RB) const;
 
-  bool isRegInGprb(Register Reg, MachineRegisterInfo &MRI) const;
-  bool isRegInFprb(Register Reg, MachineRegisterInfo &MRI) const;
+  bool isRegInGprb(Register Reg) const;
+  bool isRegInFprb(Register Reg) const;
 
   // tblgen-erated 'select' implementation, used as the initial selector for
   // the patterns that don't require complex C++.
@@ -58,31 +58,23 @@ class RISCVInstructionSelector : public InstructionSelector {
 
   // A lowering phase that runs before any selection attempts.
   // Returns true if the instruction was modified.
-  void preISelLower(MachineInstr &MI, MachineIRBuilder &MIB,
-                    MachineRegisterInfo &MRI);
+  void preISelLower(MachineInstr &MI, MachineIRBuilder &MIB);
 
-  bool replacePtrWithInt(MachineOperand &Op, MachineIRBuilder &MIB,
-                         MachineRegisterInfo &MRI);
+  bool replacePtrWithInt(MachineOperand &Op, MachineIRBuilder &MIB);
 
   // Custom selection methods
-  bool selectCopy(MachineInstr &MI, MachineRegisterInfo &MRI) const;
-  bool selectImplicitDef(MachineInstr &MI, MachineIRBuilder &MIB,
-                         MachineRegisterInfo &MRI) const;
+  bool selectCopy(MachineInstr &MI) const;
+  bool selectImplicitDef(MachineInstr &MI, MachineIRBuilder &MIB) const;
   bool materializeImm(Register Reg, int64_t Imm, MachineIRBuilder &MIB) const;
-  bool selectAddr(MachineInstr &MI, MachineIRBuilder &MIB,
-                  MachineRegisterInfo &MRI, bool IsLocal = true,
+  bool selectAddr(MachineInstr &MI, MachineIRBuilder &MIB, bool IsLocal = true,
                   bool IsExternWeak = false) const;
   bool selectSExtInreg(MachineInstr &MI, MachineIRBuilder &MIB) const;
-  bool selectSelect(MachineInstr &MI, MachineIRBuilder &MIB,
-                    MachineRegisterInfo &MRI) const;
-  bool selectFPCompare(MachineInstr &MI, MachineIRBuilder &MIB,
-                       MachineRegisterInfo &MRI) const;
+  bool selectSelect(MachineInstr &MI, MachineIRBuilder &MIB) const;
+  bool selectFPCompare(MachineInstr &MI, MachineIRBuilder &MIB) const;
   void emitFence(AtomicOrdering FenceOrdering, SyncScope::ID FenceSSID,
                  MachineIRBuilder &MIB) const;
-  bool selectMergeValues(MachineInstr &MI, MachineIRBuilder &MIB,
-                         MachineRegisterInfo &MRI) const;
-  bool selectUnmergeValues(MachineInstr &MI, MachineIRBuilder &MIB,
-                           MachineRegisterInfo &MRI) const;
+  bool selectMergeValues(MachineInstr &MI, MachineIRBuilder &MIB) const;
+  bool selectUnmergeValues(MachineInstr &MI, MachineIRBuilder &MIB) const;
 
   ComplexRendererFns selectShiftMask(MachineOperand &Root) const;
   ComplexRendererFns selectAddrRegImm(MachineOperand &Root) const;
@@ -121,6 +113,8 @@ class RISCVInstructionSelector : public InstructionSelector {
   const RISCVRegisterBankInfo &RBI;
   const RISCVTargetMachine &TM;
 
+  MachineRegisterInfo *MRI = nullptr;
+
   // FIXME: This is necessary because DAGISel uses "Subtarget->" and GlobalISel
   // uses "STI." in the code generated by TableGen. We need to unify the name of
   // Subtarget variable.
@@ -162,16 +156,15 @@ RISCVInstructionSelector::selectShiftMask(MachineOperand &Root) const {
     return std::nullopt;
 
   using namespace llvm::MIPatternMatch;
-  MachineRegisterInfo &MRI = MF->getRegInfo();
 
   Register RootReg = Root.getReg();
   Register ShAmtReg = RootReg;
-  const LLT ShiftLLT = MRI.getType(RootReg);
+  const LLT ShiftLLT = MRI->getType(RootReg);
   unsigned ShiftWidth = ShiftLLT.getSizeInBits();
   assert(isPowerOf2_32(ShiftWidth) && "Unexpected max shift amount!");
   // Peek through zext.
   Register ZExtSrcReg;
-  if (mi_match(ShAmtReg, MRI, m_GZExt(m_Reg(ZExtSrcReg)))) {
+  if (mi_match(ShAmtReg, *MRI, m_GZExt(m_Reg(ZExtSrcReg)))) {
     ShAmtReg = ZExtSrcReg;
   }
 
@@ -191,7 +184,7 @@ RISCVInstructionSelector::selectShiftMask(MachineOperand &Root) const {
   //
   // 1. the lowest log2(XLEN) bits of the and mask are all set
   // 2. the bits of the register being masked are already unset (zero set)
-  if (mi_match(ShAmtReg, MRI, m_GAnd(m_Reg(AndSrcReg), m_ICst(AndMask)))) {
+  if (mi_match(ShAmtReg, *MRI, m_GAnd(m_Reg(AndSrcReg), m_ICst(AndMask)))) {
     APInt ShMask(AndMask.getBitWidth(), ShiftWidth - 1);
     if (ShMask.isSubsetOf(AndMask)) {
       ShAmtReg = AndSrcReg;
@@ -206,16 +199,16 @@ RISCVInstructionSelector::selectShiftMask(MachineOperand &Root) const {
 
   APInt Imm;
   Register Reg;
-  if (mi_match(ShAmtReg, MRI, m_GAdd(m_Reg(Reg), m_ICst(Imm)))) {
+  if (mi_match(ShAmtReg, *MRI, m_GAdd(m_Reg(Reg), m_ICst(Imm)))) {
     if (Imm != 0 && Imm.urem(ShiftWidth) == 0)
       // If we are shifting by X+N where N == 0 mod Size, then just shift by X
       // to avoid the ADD.
       ShAmtReg = Reg;
-  } else if (mi_match(ShAmtReg, MRI, m_GSub(m_ICst(Imm), m_Reg(Reg)))) {
+  } else if (mi_match(ShAmtReg, *MRI, m_GSub(m_ICst(Imm), m_Reg(Reg)))) {
     if (Imm != 0 && Imm.urem(ShiftWidth) == 0) {
       // If we are shifting by N-X where N == 0 mod Size, then just shift by -X
       // to generate a NEG instead of a SUB of a constant.
-      ShAmtReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
+      ShAmtReg = MRI->createVirtualRegister(&RISCV::GPRRegClass);
       unsigned NegOpc = Subtarget->is64Bit() ? RISCV::SUBW : RISCV::SUB;
       return {{[=](MachineInstrBuilder &MIB) {
         MachineIRBuilder(*MIB.getInstr())
@@ -226,7 +219,7 @@ RISCVInstructionSelector::selectShiftMask(MachineOperand &Root) const {
     if (Imm.urem(ShiftWidth) == ShiftWidth - 1) {
       // If we are shifting by N-X where N == -1 mod Size, then just shift by ~X
       // to generate a NOT instead of a SUB of a constant.
-      ShAmtReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
+      ShAmtReg = MRI->createVirtualRegister(&RISCV::GPRRegClass);
       return {{[=](MachineInstrBuilder &MIB) {
         MachineIRBuilder(*MIB.getInstr())
             .buildInstr(RISCV::XORI, {ShAmtReg}, {Reg})
@@ -243,8 +236,6 @@ InstructionSelector::ComplexRendererFns
 RISCVInstructionSelector::selectSHXADDOp(MachineOperand &Root,
                                          unsigned ShAmt) const {
   using namespace llvm::MIPatternMatch;
-  MachineFunction &MF = *Root.getParent()->getParent()->getParent();
-  MachineRegisterInfo &MRI = MF.getRegInfo();
 
   if (!Root.isReg())
     return std::nullopt;
@@ -255,11 +246,11 @@ RISCVInstructionSelector::selectSHXADDOp(MachineOperand &Root,
   Register RegY;
   std::optional<bool> LeftShift;
   // (and (shl y, c2), mask)
-  if (mi_match(RootReg, MRI,
+  if (mi_match(RootReg, *MRI,
                m_GAnd(m_GShl(m_Reg(RegY), m_ICst(C2)), m_ICst(Mask))))
     LeftShift = true;
   // (and (lshr y, c2), mask)
-  else if (mi_match(RootReg, MRI,
+  else if (mi_match(RootReg, *MRI,
                     m_GAnd(m_GLShr(m_Reg(RegY), m_ICst(C2)), m_ICst(Mask))))
     LeftShift = false;
 
@@ -275,7 +266,7 @@ RISCVInstructionSelector::selectSHXADDOp(MachineOperand &Root,
       // Given (and (shl y, c2), mask) in which mask has no leading zeros and
       // c3 trailing zeros. We can use an SRLI by c3 - c2 followed by a SHXADD.
       if (*LeftShift && Leading == 0 && C2.ult(Trailing) && Trailing == ShAmt) {
-        Register DstReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
+        Register DstReg = MRI->createVirtualRegister(&RISCV::GPRRegClass);
         return {{[=](MachineInstrBuilder &MIB) {
           MachineIRBuilder(*MIB.getInstr())
               .buildInstr(RISCV::SRLI, {DstReg}, {RegY})
@@ -287,7 +278,7 @@ RISCVInstructionSelector::selectSHXADDOp(MachineOperand &Root,
       // Given (and (lshr y, c2), mask) in which mask has c2 leading zeros and
       // c3 trailing zeros. We can use an SRLI by c2 + c3 followed by a SHXADD.
       if (!*LeftShift && Leading == C2 && Trailing == ShAmt) {
-        Register DstReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
+        Register DstReg = MRI->createVirtualRegister(&RISCV::GPRRegClass);
         return {{[=](MachineInstrBuilder &MIB) {
           MachineIRBuilder(*MIB.getInstr())
               .buildInstr(RISCV::SRLI, {DstReg}, {RegY})
@@ -301,12 +292,12 @@ RISCVInstructionSelector::selectSHXADDOp(MachineOperand &Root,
   LeftShift.reset();
 
   // (shl (and y, mask), c2)
-  if (mi_match(RootReg, MRI,
+  if (mi_match(RootReg, *MRI,
                m_GShl(m_OneNonDBGUse(m_GAnd(m_Reg(RegY), m_ICst(Mask))),
                       m_ICst(C2))))
     LeftShift = true;
   // (lshr (and y, mask), c2)
-  else if (mi_match(RootReg, MRI,
+  else if (mi_match(RootReg, *MRI,
                     m_GLShr(m_OneNonDBGUse(m_GAnd(m_Reg(RegY), m_ICst(Mask))),
                             m_ICst(C2))))
     LeftShift = false;
@@ -326,7 +317,7 @@ RISCVInstructionSelector::selectSHXADDOp(MachineOperand &Root,
              (Trailing - C2.getLimitedValue()) == ShAmt;
 
     if (Cond) {
-      Register DstReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
+      Register DstReg = MRI->createVirtualRegister(&RISCV::GPRRegClass);
       return {{[=](MachineInstrBuilder &MIB) {
         MachineIRBuilder(*MIB.getInstr())
             .buildInstr(RISCV::SRLIW, {DstReg}, {RegY})
@@ -343,8 +334,6 @@ InstructionSelector::ComplexRendererFns
 RISCVInstructionSelector::selectSHXADD_UWOp(MachineOperand &Root,
                                             unsigned ShAmt) const {
   using namespace llvm::MIPatternMatch;
-  MachineFunction &MF = *Root.getParent()->getParent()->getParent();
-  MachineRegisterInfo &MRI = MF.getRegInfo();
 
   if (!Root.isReg())
     return std::nullopt;
@@ -356,7 +345,7 @@ RISCVInstructionSelector::selectSHXADD_UWOp(MachineOperand &Root,
   APInt Mask, C2;
   Register RegX;
   if (mi_match(
-          RootReg, MRI,
+          RootReg, *MRI,
           m_OneNonDBGUse(m_GAnd(m_OneNonDBGUse(m_GShl(m_Reg(RegX), m_ICst(C2))),
                                 m_ICst(Mask))))) {
     Mask &= maskTrailingZeros<uint64_t>(C2.getLimitedValue());
@@ -365,7 +354,7 @@ RISCVInstructionSelector::selectSHXADD_UWOp(MachineOperand &Root,
       unsigned Leading = Mask.countl_zero();
       unsigned Trailing = Mask.countr_zero();
       if (Leading == 32 - ShAmt && C2 == Trailing && Trailing > ShAmt) {
-        Register DstReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
+        Register DstReg = MRI->createVirtualRegister(&RISCV::GPRRegClass);
         return {{[=](MachineInstrBuilder &MIB) {
           MachineIRBuilder(*MIB.getInstr())
               .buildInstr(RISCV::SLLI, {DstReg}, {RegX})
@@ -381,13 +370,10 @@ RISCVInstructionSelector::selectSHXADD_UWOp(MachineOperand &Root,
 
 InstructionSelector::ComplexRendererFns
 RISCVInstructionSelector::selectAddrRegImm(MachineOperand &Root) const {
-  MachineFunction &MF = *Root.getParent()->getParent()->getParent();
-  MachineRegisterInfo &MRI = MF.getRegInfo();
-
   if (!Root.isReg())
     return std::nullopt;
 
-  MachineInstr *RootDef = MRI.getVRegDef(Root.getReg());
+  MachineInstr *RootDef = MRI->getVRegDef(Root.getReg());
   if (RootDef->getOpcode() == TargetOpcode::G_FRAME_INDEX) {
     return {{
         [=](MachineInstrBuilder &MIB) { MIB.add(RootDef->getOperand(1)); },
@@ -395,11 +381,11 @@ RISCVInstructionSelector::selectAddrRegImm(MachineOperand &Root) const {
     }};
   }
 
-  if (isBaseWithConstantOffset(Root, MRI)) {
+  if (isBaseWithConstantOffset(Root, *MRI)) {
     MachineOperand &LHS = RootDef->getOperand(1);
     MachineOperand &RHS = RootDef->getOperand(2);
-    MachineInstr *LHSDef = MRI.getVRegDef(LHS.getReg());
-    MachineInstr *RHSDef = MRI.getVRegDef(RHS.getReg());
+    MachineInstr *LHSDef = MRI->getVRegDef(LHS.getReg());
+    MachineInstr *RHSDef = MRI->getVRegDef(RHS.getReg());
 
     int64_t RHSC = RHSDef->getOperand(1).getCImm()->getSExtValue();
     if (isInt<12>(RHSC)) {
@@ -441,9 +427,9 @@ static RISCVCC::CondCode getRISCVCCFromICmp(CmpInst::Predicate CC) {
   }
 }
 
-static void getOperandsForBranch(Register CondReg, MachineRegisterInfo &MRI,
-                                 RISCVCC::CondCode &CC, Register &LHS,
-                                 Register &RHS) {
+static void getOperandsForBranch(Register CondReg, RISCVCC::CondCode &CC,
+                                 Register &LHS, Register &RHS,
+                                 MachineRegisterInfo &MRI) {
   // Try to fold an ICmp. If that fails, use a NE compare with X0.
   CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
   if (!mi_match(CondReg, MRI, m_GICmp(m_Pred(Pred), m_Reg(LHS), m_Reg(RHS)))) {
@@ -509,19 +495,19 @@ static void getOperandsForBranch(Register CondReg, MachineRegisterInfo &MRI,
 bool RISCVInstructionSelector::select(MachineInstr &MI) {
   MachineBasicBlock &MBB = *MI.getParent();
   MachineFunction &MF = *MBB.getParent();
-  MachineRegisterInfo &MRI = MF.getRegInfo();
+  MRI = &MF.getRegInfo();
   MachineIRBuilder MIB(MI);
 
-  preISelLower(MI, MIB, MRI);
+  preISelLower(MI, MIB);
   const unsigned Opc = MI.getOpcode();
 
   if (!MI.isPreISelOpcode() || Opc == TargetOpcode::G_PHI) {
     if (Opc == TargetOpcode::PHI || Opc == TargetOpcode::G_PHI) {
       const Register DefReg = MI.getOperand(0).getReg();
-      const LLT DefTy = MRI.getType(DefReg);
+      const LLT DefTy = MRI->getType(DefReg);
 
       const RegClassOrRegBank &RegClassOrBank =
-          MRI.getRegClassOrRegBank(DefReg);
+          MRI->getRegClassOrRegBank(DefReg);
 
       const TargetRegisterClass *DefRC =
           RegClassOrBank.dyn_cast<const TargetRegisterClass *>();
@@ -540,12 +526,12 @@ bool RISCVInstructionSelector::select(MachineInstr &MI) {
       }
 
       MI.setDesc(TII.get(TargetOpcode::PHI));
-      return RBI.constrainGenericRegister(DefReg, *DefRC, MRI);
+      return RBI.constrainGenericRegister(DefReg, *DefRC, *MRI);
     }
 
     // Certain non-generic instructions also need some special handling.
     if (MI.isCopy())
-      return selectCopy(MI, MRI);
+      return selectCopy(MI);
 
     return true;
   }
@@ -559,7 +545,7 @@ bool RISCVInstructionSelector::select(MachineInstr &MI) {
   case TargetOpcode::G_INTTOPTR:
   case TargetOpcode::G_TRUNC:
   case TargetOpcode::G_FREEZE:
-    return selectCopy(MI, MRI);
+    return selectCopy(MI);
   case TargetOpcode::G_CONSTANT: {
     Register DstReg = MI.getOperand(0).getReg();
     int64_t Imm = MI.getOperand(1).getCImm()->getSExtValue();
@@ -576,9 +562,9 @@ bool RISCVInstructionSelector::select(MachineInstr &MI) {
     Register DstReg = MI.getOperand(0).getReg();
     const APFloat &FPimm = MI.getOperand(1).getFPImm()->getValueAPF();
     APInt Imm = FPimm.bitcastToAPInt();
-    unsigned Size = MRI.getType(DstReg).getSizeInBits();
+    unsigned Size = MRI->getType(DstReg).getSizeInBits();
     if (Size == 16 || Size == 32 || (Size == 64 && Subtarget->is64Bit())) {
-      Register GPRReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
+      Register GPRReg = MRI->createVirtualRegister(&RISCV::GPRRegClass);
       if (!materializeImm(GPRReg, Imm.getSExtValue(), MIB))
         return false;
 
@@ -592,8 +578,8 @@ bool RISCVInstructionSelector::select(MachineInstr &MI) {
       assert(Size == 64 && !Subtarget->is64Bit() &&
              "Unexpected size or subtarget");
       // Split into two pieces and build through the stack.
-      Register GPRRegHigh = MRI.createVirtualRegister(&RISCV::GPRRegClass);
-      Register GPRRegLow = MRI.createVirtualRegister(&RISCV::GPRRegClass);
+      Register GPRRegHigh = MRI->createVirtualRegister(&RISCV::GPRRegClass);
+      Register GPRRegLow = MRI->createVirtualRegister(&RISCV::GPRRegClass);
       if (!materializeImm(GPRRegHigh, Imm.extractBits(32, 32).getSExtValue(),
                           MIB))
         return false;
@@ -615,8 +601,7 @@ bool RISCVInstructionSelector::select(MachineInstr &MI) {
       return false;
     }
 
-    return selectAddr(MI, MIB, MRI, GV->isDSOLocal(),
-                      GV->hasExternalWeakLinkage());
+    return selectAddr(MI, MIB, GV->isDSOLocal(), GV->hasExternalWeakLinkage());
   }
   case TargetOpcode::G_JUMP_TABLE:
   case TargetOpcode::G_CONSTANT_POOL:
@@ -624,7 +609,7 @@ bool RISCVInstructionSelector::select(MachineInstr &MI) {
   case TargetOpcode::G_BRCOND: {
     Register LHS, RHS;
     RISCVCC::CondCode CC;
-    getOperandsForBranch(MI.getOperand(0).getReg(), MRI, CC, LHS, RHS);
+    getOperandsForBranch(MI.getOperand(0).getReg(), CC, LHS, RHS, *MRI);
 
     auto Bcc = MIB.buildInstr(RISCVCC::getBrCond(CC), {}, {LHS, RHS})
                    .addMBB(MI.getOperand(1).getMBB());
@@ -698,9 +683,9 @@ bool RISCVInstructionSelector::select(MachineInstr &MI) {
     return constrainSelectedInstRegOperands(MI, TII, TRI, RBI);
   }
   case TargetOpcode::G_SELECT:
-    return selectSelect(MI, MIB, MRI);
+    return selectSelect(MI, MIB);
   case TargetOpcode::G_FCMP:
-    return selectFPCompare(MI, MIB, MRI);
+    return selectFPCompare(MI, MIB);
   case TargetOpcode::G_FENCE: {
     AtomicOrdering FenceOrdering =
         static_cast<AtomicOrdering>(MI.getOperand(0).getImm());
@@ -711,18 +696,18 @@ bool RISCVInstructionSelector::select(MachineInstr &MI) {
     return true;
   }
   case TargetOpcode::G_IMPLICIT_DEF:
-    return selectImplicitDef(MI, MIB, MRI);
+    return selectImplicitDef(MI, MIB);
   case TargetOpcode::G_MERGE_VALUES:
-    return selectMergeValues(MI, MIB, MRI);
+    return selectMergeValues(MI, MIB);
   case TargetOpcode::G_UNMERGE_VALUES:
-    return selectUnmergeValues(MI, MIB, MRI);
+    return selectUnmergeValues(MI, MIB);
   default:
     return false;
   }
 }
 
-bool RISCVInstructionSelector::selectMergeValues(
-    MachineInstr &MI, MachineIRBuilder &MIB, MachineRegisterInfo &MRI) const {
+bool RISCVInstructionSelector::selectMergeValues(MachineInstr &MI,
+                                                 MachineIRBuilder &MIB) const {
   assert(MI.getOpcode() == TargetOpcode::G_MERGE_VALUES);
 
   // Build a F64 Pair from operands
@@ -731,14 +716,14 @@ bool RISCVInstructionSelector::selectMergeValues(
   Register Dst = MI.getOperand(0).getReg();
   Register Lo = MI.getOperand(1).getReg();
   Register Hi = MI.getOperand(2).getReg();
-  if (!isRegInFprb(Dst, MRI) || !isRegInGprb(Lo, MRI) || !isRegInGprb(Hi, MRI))
+  if (!isRegInFprb(Dst) || !isRegInGprb(Lo) || !isRegInGprb(Hi))
     return false;
   MI.setDesc(TII.get(RISCV::BuildPairF64Pseudo));
   return constrainSelectedInstRegOperands(MI, TII, TRI, RBI);
 }
 
 bool RISCVInstructionSelector::selectUnmergeValues(
-    MachineInstr &MI, MachineIRBuilder &MIB, MachineRegisterInfo &MRI) const {
+    MachineInstr &MI, MachineIRBuilder &MIB) const {
   assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES);
 
   // Split F64 Src into two s32 parts
@@ -747,44 +732,42 @@ bool RISCVInstructionSelector::selectUnmergeValues(
   Register Src = MI.getOperand(2).getReg();
   Register Lo = MI.getOperand(0).getReg();
   Register Hi = MI.getOperand(1).getReg();
-  if (!isRegInFprb(Src, MRI) || !isRegInGprb(Lo, MRI) || !isRegInGprb(Hi, MRI))
+  if (!isRegInFprb(Src) || !isRegInGprb(Lo) || !isRegInGprb(Hi))
     return false;
   MI.setDesc(TII.get(RISCV::SplitF64Pseudo));
   return constrainSelectedInstRegOperands(MI, TII, TRI, RBI);
 }
 
 bool RISCVInstructionSelector::replacePtrWithInt(MachineOperand &Op,
-                                                 MachineIRBuilder &MIB,
-                                                 MachineRegisterInfo &MRI) {
+                                                 MachineIRBuilder &MIB) {
   Register PtrReg = Op.getReg();
-  assert(MRI.getType(PtrReg).isPointer() && "Operand is not a pointer!");
+  assert(MRI->getType(PtrReg).isPointer() && "Operand is not a pointer!");
 
   const LLT sXLen = LLT::scalar(STI.getXLen());
   auto PtrToInt = MIB.buildPtrToInt(sXLen, PtrReg);
-  MRI.setRegBank(PtrToInt.getReg(0), RBI.getRegBank(RISCV::GPRBRegBankID));
+  MRI->setRegBank(PtrToInt.getReg(0), RBI.getRegBank(RISCV::GPRBRegBankID));
   Op.setReg(PtrToInt.getReg(0));
   return select(*PtrToInt);
 }
 
 void RISCV...
[truncated]

Copy link
Collaborator

@topperc topperc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link

github-actions bot commented Oct 2, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@@ -509,19 +495,19 @@ static void getOperandsForBranch(Register CondReg, MachineRegisterInfo &MRI,
bool RISCVInstructionSelector::select(MachineInstr &MI) {
MachineBasicBlock &MBB = *MI.getParent();
MachineFunction &MF = *MBB.getParent();
MachineRegisterInfo &MRI = MF.getRegInfo();
MRI = &MF.getRegInfo();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should go into a setupMF like other targets for function state setup

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where other targets appears to just be AMDGPU and SPIRV?

@@ -509,19 +495,19 @@ static void getOperandsForBranch(Register CondReg, MachineRegisterInfo &MRI,
bool RISCVInstructionSelector::select(MachineInstr &MI) {
MachineBasicBlock &MBB = *MI.getParent();
MachineFunction &MF = *MBB.getParent();
MachineRegisterInfo &MRI = MF.getRegInfo();
MRI = &MF.getRegInfo();
MachineIRBuilder MIB(MI);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated, but should not create single use MachineIRBuilders. Generally shouldn't need it during selection anyway

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it is unrelated, I will address this in a different PR. It looks like AArch64 creates a significant number of these MachineIRBuilders. And it looks like both RISC-V and AArch64 need to use it during instruction selection. Could you please tell me more about what you have in mind?

Copy link
Contributor

@arsenm arsenm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, but should also avoid creating a new MachineIRBuilder for every instruction (if not just avoid using it altogether)

@michaelmaitland michaelmaitland merged commit a3cc4b6 into llvm:main Oct 4, 2024
8 checks passed
@michaelmaitland michaelmaitland deleted the mri-as-member branch October 4, 2024 13:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants