Skip to content
67 changes: 50 additions & 17 deletions llvm/lib/CodeGen/GlobalISel/PatternGen.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
//===- llvm/CodeGen/GlobalISel/PatternGen.cpp - PatternGen ---==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
Expand Down Expand Up @@ -210,8 +210,10 @@
return "invalid";
}

std::string makeImmTypeStr(int Size, bool Signed) {
return (Signed ? "simm" : "uimm") + std::to_string(Size);
std::string makeImmTypeStr(int Size, bool Signed, std::string llvm_type) {

Check warning on line 213 in llvm/lib/CodeGen/GlobalISel/PatternGen.cpp

View workflow job for this annotation

GitHub Actions / Run linters

llvm/lib/CodeGen/GlobalISel/PatternGen.cpp:213:63 [readability-identifier-naming]

invalid case style for parameter 'llvm_type'
if (llvm_type.empty())
return (Signed ? "simm" : "uimm") + std::to_string(Size);
return llvm_type;
}

struct PatternNode {
Expand Down Expand Up @@ -579,15 +581,15 @@
std::string TypeStr = lltToString(Type);

// ignore bitcast ops for now
if (Op == TargetOpcode::G_BITCAST)
if ((Op == TargetOpcode::G_BITCAST) || (Op == TargetOpcode::G_CONSTANT_FOLD_BARRIER))
return Operand->patternString();

return "(" + TypeStr + " (" + std::string(UnopStr.at(Op)) + " " +
Operand->patternString() + "))";
}

LLT getRegisterTy(int OperandId) const override {
if (OperandId == -1 && Op != TargetOpcode::G_BITCAST)
if (OperandId == -1 && Op != TargetOpcode::G_BITCAST && Op != TargetOpcode::G_CONSTANT_FOLD_BARRIER)
return Type;
return Operand->getRegisterTy(OperandId);
}
Expand Down Expand Up @@ -623,22 +625,27 @@
StringRef Name;
int Size;
bool Sext;
std::string llvm_type;

Check warning on line 628 in llvm/lib/CodeGen/GlobalISel/PatternGen.cpp

View workflow job for this annotation

GitHub Actions / Run linters

llvm/lib/CodeGen/GlobalISel/PatternGen.cpp:628:15 [readability-identifier-naming]

invalid case style for member 'llvm_type'

size_t RegIdx;

RegisterNode(LLT Type, StringRef Name, size_t RegIdx, bool IsImm, int Size,
bool Sext)
bool Sext, std::string llvm_type)

Check warning on line 633 in llvm/lib/CodeGen/GlobalISel/PatternGen.cpp

View workflow job for this annotation

GitHub Actions / Run linters

llvm/lib/CodeGen/GlobalISel/PatternGen.cpp:633:39 [readability-identifier-naming]

invalid case style for parameter 'llvm_type'
: PatternNode(PN_Register, Type, IsImm), Name(Name), Size(Size),
Sext(Sext), RegIdx(RegIdx) {}
Sext(Sext), llvm_type(llvm_type), RegIdx(RegIdx) {}

std::string patternString() override {
std::string TypeStr = lltToString(Type);
bool PrintType = Type.isPointer();

if (IsImm) {
// Immediate Operands
return ("(" + RegT + " ") + (Sext ? "simm" : "uimm") +
std::to_string(Size) + ":$" + std::string(Name) + ")";
std::string pre;

Check warning on line 643 in llvm/lib/CodeGen/GlobalISel/PatternGen.cpp

View workflow job for this annotation

GitHub Actions / Run linters

llvm/lib/CodeGen/GlobalISel/PatternGen.cpp:643:19 [readability-identifier-naming]

invalid case style for variable 'pre'
if (llvm_type.empty())
pre = (Sext ? "simm" : "uimm") + std::to_string(Size);
else
pre = llvm_type;
return ("(" + RegT + " ") + pre + ":$" + std::string(Name) + ")";
}

// Vector Types (currently rv32 only)
Expand Down Expand Up @@ -701,7 +708,7 @@
abort();
}

static bool classof(const PatternNode *p) { return p->getKind() == PN_Load; }

Check warning on line 711 in llvm/lib/CodeGen/GlobalISel/PatternGen.cpp

View workflow job for this annotation

GitHub Actions / Run linters

llvm/lib/CodeGen/GlobalISel/PatternGen.cpp:711:42 [readability-identifier-naming]

invalid case style for parameter 'p'
};

struct CastNode : public PatternNode {
Expand All @@ -715,7 +722,7 @@
return "(" + LLTString + " " + Value->patternString() + ")";
}

static bool classof(const PatternNode *p) { return p->getKind() == PN_Cast; }

Check warning on line 725 in llvm/lib/CodeGen/GlobalISel/PatternGen.cpp

View workflow job for this annotation

GitHub Actions / Run linters

llvm/lib/CodeGen/GlobalISel/PatternGen.cpp:725:42 [readability-identifier-naming]

invalid case style for parameter 'p'
};

struct StoreNode : public PatternNode {
Expand Down Expand Up @@ -743,20 +750,20 @@
abort();
}

static bool classof(const PatternNode *p) { return p->getKind() == PN_Cast; }

Check warning on line 753 in llvm/lib/CodeGen/GlobalISel/PatternGen.cpp

View workflow job for this annotation

GitHub Actions / Run linters

llvm/lib/CodeGen/GlobalISel/PatternGen.cpp:753:42 [readability-identifier-naming]

invalid case style for parameter 'p'
};

using PatternOrError = std::pair<PatternError, std::unique_ptr<PatternNode>>;
static PatternOrError pError(PatternErrorT Type, MachineInstr *Inst) {
return std::make_pair(PatternError(Type, Inst), nullptr);
}
static PatternOrError PError(PatternError Error) {

Check warning on line 760 in llvm/lib/CodeGen/GlobalISel/PatternGen.cpp

View workflow job for this annotation

GitHub Actions / Run linters

llvm/lib/CodeGen/GlobalISel/PatternGen.cpp:760:23 [readability-identifier-naming]

invalid case style for function 'PError'
return std::make_pair(Error, nullptr);
}
static PatternOrError PError(PatternErrorT Type) {

Check warning on line 763 in llvm/lib/CodeGen/GlobalISel/PatternGen.cpp

View workflow job for this annotation

GitHub Actions / Run linters

llvm/lib/CodeGen/GlobalISel/PatternGen.cpp:763:23 [readability-identifier-naming]

invalid case style for function 'PError'
return std::make_pair(PatternError(Type), nullptr);
}
static PatternOrError PPattern(std::unique_ptr<PatternNode> Pattern) {

Check warning on line 766 in llvm/lib/CodeGen/GlobalISel/PatternGen.cpp

View workflow job for this annotation

GitHub Actions / Run linters

llvm/lib/CodeGen/GlobalISel/PatternGen.cpp:766:23 [readability-identifier-naming]

invalid case style for function 'PPattern'
return std::make_pair(PatternError(SUCCESS), std::move(Pattern));
}

Expand Down Expand Up @@ -937,8 +944,28 @@
ReadOffset = Offset->getOperand(1).getCImm()->getLimitedValue();
}
if (AddrI->getOpcode() == TargetOpcode::G_SELECT) {
// TODO: implement this!
return pError(FORMAT_LOAD, AddrI);
assert(AddrI->getOperand(1).isReg() && "expected register");
auto CondInstr = AddrI->getOperand(1);
auto CondReg = CondInstr.getReg();
auto [ErrCond, CondNode] = traverse(MRI, *MRI.getVRegDef(CondReg));
if (ErrCond)
return PError(ErrCond);
assert(AddrI->getOperand(2).isReg() && "expected register");
auto TrueInstr = AddrI->getOperand(2);
auto TrueReg = TrueInstr.getReg();
auto [ErrTrue, TrueNode] = traverseRegLoad(MRI, Cur, ReadSize, MRI.getVRegDef(TrueReg));
if (ErrTrue)
return PError(ErrTrue);
assert(AddrI->getOperand(3).isReg() && "expected register");
auto FalseInstr = AddrI->getOperand(3);
auto FalseReg = FalseInstr.getReg();
auto [ErrFalse, FalseNode] = traverseRegLoad(MRI, Cur, ReadSize, MRI.getVRegDef(FalseReg));
if (ErrFalse)
return PError(ErrFalse);
auto Node = std::make_unique<TernopNode>(
MRI.getType(Cur.getOperand(0).getReg()), AddrI->getOpcode(),
std::move(CondNode), std::move(TrueNode), std::move(FalseNode));
return PPattern(std::move(Node));
}
if (AddrI->getOpcode() != TargetOpcode::COPY)
return pError(FORMAT_LOAD, AddrI);
Expand All @@ -959,7 +986,7 @@

assert(Cur.getOperand(0).isReg() && "expected register");
std::unique_ptr<PatternNode> Node = std::make_unique<RegisterNode>(
Type, Field->ident, Idx, false, Type.getSizeInBits(), false);
Type, Field->ident, Idx, false, Type.getSizeInBits(), false, Field->llvm_type);

bool SizeMismatch = (int)Type.getSizeInBits() != ReadSize;

Expand Down Expand Up @@ -1040,6 +1067,7 @@

return std::make_pair(SUCCESS, std::move(Node));
}
case TargetOpcode::G_CONSTANT_FOLD_BARRIER:
case TargetOpcode::G_ANYEXT:
case TargetOpcode::G_SEXT:
case TargetOpcode::G_ZEXT:
Expand Down Expand Up @@ -1141,7 +1169,7 @@
PatternArgs[Idx].In = true;
PatternArgs[Idx].Llt = LLT();
PatternArgs[Idx].ArgTypeStr =
makeImmTypeStr(Field->len, Field->type & CDSLInstr::SIGNED);
makeImmTypeStr(Field->len, Field->type & CDSLInstr::SIGNED, Field->llvm_type);

if (Field == nullptr)
return std::make_pair(FORMAT_IMM, nullptr);
Expand All @@ -1150,7 +1178,7 @@
return std::make_pair(
SUCCESS, std::make_unique<RegisterNode>(
MRI.getType(Cur.getOperand(0).getReg()), Field->ident,
Idx, true, Field->len, Field->type & CDSLInstr::SIGNED));
Idx, true, Field->len, Field->type & CDSLInstr::SIGNED, Field->llvm_type));
}

// Else COPY is just a pass-through.
Expand Down Expand Up @@ -1331,6 +1359,10 @@
MayLoad = 0;
MayStore = 0;

if (PatternGenArgs::Args.DumpMIR) {
LLVM_DEBUG(MF.dump());
}

std::string InstName = MF.getName().str().substr(4);
std::string InstNameO = InstName;
++PatternGenNumInstructionsProcessed;
Expand Down Expand Up @@ -1362,10 +1394,6 @@
return true;
}

llvm::outs() << "Pattern for " << InstName << ": " << Node->patternString()
<< '\n';
++PatternGenNumPatternsGenerated;

LLT OutType = LLT();
std::string OutsString;
std::string InsString;
Expand Down Expand Up @@ -1409,6 +1437,11 @@
}
}

llvm::outs() << "Pattern for " << InstName << ": " << Node->patternString()
<< '\n';
++PatternGenNumPatternsGenerated;


InsString = InsString.substr(0, InsString.size() - 2);
OutsString = OutsString.substr(0, OutsString.size() - 2);

Expand Down
18 changes: 8 additions & 10 deletions llvm/tools/pattern-gen/Main.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
#include <cstdio>
#include <ctype.h>
#include <exception>
Expand Down Expand Up @@ -55,6 +55,8 @@
cl::cat(ToolOptions));
static cl::opt<bool> PrintIR("print-ir", cl::desc("Print LLVM-IR module."),
cl::cat(ToolOptions));
static cl::opt<bool> PrintMIR("print-mir", cl::desc("Print LLVM-MIR functions."),
cl::cat(ToolOptions));
static cl::opt<bool> NoExtend(
"no-extend",
cl::desc("Do not apply CDSL typing rules (Use C-like type inference)."),
Expand Down Expand Up @@ -128,15 +130,6 @@
auto Mod = std::make_unique<Module>("mod", Ctx);
auto Instrs = ParseCoreDSL2(Ts, (XLen == 64), Mod.get(), NoExtend);

if (irOut) {
std::string Str;
raw_string_ostream OS(Str);
OS << *Mod;
OS.flush();
irOut << Str << "\n";
irOut.close();
}

if (!SkipVerify)
if (verifyModule(*Mod, &errs()))
return -1;
Expand Down Expand Up @@ -165,9 +158,14 @@
PGArgsStruct Args{.Mattr = "",
.OptLevel = Opt,
.Predicates = Predicates,
.Is64Bit = (XLen == 64)};
.Is64Bit = (XLen == 64),
.DumpMIR = PrintMIR.getValue()};

optimizeBehavior(Mod.get(), Instrs, irOut, Args);

if (irOut)
irOut.close();

if (PrintIR)
llvm::outs() << *Mod << "\n";
if (!SkipFmt)
Expand Down
1 change: 1 addition & 0 deletions llvm/tools/pattern-gen/PatternGen.hpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
#pragma once
#include "lib/InstrInfo.hpp"
#include "llvm/CodeGen/SelectionDAG.h"
Expand All @@ -10,6 +10,7 @@
llvm::CodeGenOptLevel OptLevel;
std::string Predicates;
bool Is64Bit;
bool DumpMIR;
};

int optimizeBehavior(llvm::Module* M, std::vector<CDSLInstr> const& Instrs, std::ostream& OstreamIR, PGArgsStruct Args);
Expand Down
2 changes: 2 additions & 0 deletions llvm/tools/pattern-gen/lib/InstrInfo.hpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
#pragma once
#include "llvm/ADT/SmallVector.h"
#include <cstdint>
Expand Down Expand Up @@ -38,10 +38,12 @@
std::string_view ident;
uint32_t identIdx;
FieldType type;
std::string llvm_type;
};

uint8_t size;
std::string name;
std::string llvm_instr;
std::string mnemonic;
std::string argString;

Expand Down
Loading