Optimizer.cpp
1 #include "Optimizer.h" 2 3 #include "preprocessor/llvm_includes_start.h" 4 #include <llvm/IR/BasicBlock.h> 5 #include <llvm/IR/Function.h> 6 #include <llvm/IR/Module.h> 7 #include <llvm/IR/LegacyPassManager.h> 8 #include <llvm/Transforms/Scalar.h> 9 #include <llvm/Transforms/IPO.h> 10 #include <llvm/Transforms/Utils/BasicBlockUtils.h> 11 #include "preprocessor/llvm_includes_end.h" 12 13 #include "Arith256.h" 14 #include "Type.h" 15 16 namespace dev 17 { 18 namespace eth 19 { 20 namespace jit 21 { 22 23 namespace 24 { 25 26 class LongJmpEliminationPass: public llvm::FunctionPass 27 { 28 static char ID; 29 30 public: 31 LongJmpEliminationPass(): 32 llvm::FunctionPass(ID) 33 {} 34 35 virtual bool runOnFunction(llvm::Function& _func) override; 36 }; 37 38 char LongJmpEliminationPass::ID = 0; 39 40 bool LongJmpEliminationPass::runOnFunction(llvm::Function& _func) 41 { 42 auto iter = _func.getParent()->begin(); 43 if (&_func != &(*iter)) 44 return false; 45 46 auto& mainFunc = _func; 47 auto& ctx = _func.getContext(); 48 auto abortCode = llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), -1); 49 50 auto& exitBB = mainFunc.back(); 51 assert(exitBB.getName() == "Exit"); 52 auto retPhi = llvm::cast<llvm::PHINode>(&exitBB.front()); 53 54 auto modified = false; 55 for (auto bbIt = mainFunc.begin(); bbIt != mainFunc.end(); ++bbIt) 56 { 57 if (auto term = llvm::dyn_cast<llvm::UnreachableInst>(bbIt->getTerminator())) 58 { 59 auto longjmp = term->getPrevNode(); 60 assert(llvm::isa<llvm::CallInst>(longjmp)); 61 auto bbPtr = &(*bbIt); 62 retPhi->addIncoming(abortCode, bbPtr); 63 llvm::ReplaceInstWithInst(term, llvm::BranchInst::Create(&exitBB)); 64 longjmp->eraseFromParent(); 65 modified = true; 66 } 67 } 68 69 return modified; 70 } 71 72 } 73 74 bool optimize(llvm::Module& _module) 75 { 76 auto pm = llvm::legacy::PassManager{}; 77 pm.add(llvm::createFunctionInliningPass(2, 2, false)); 78 pm.add(new LongJmpEliminationPass{}); // TODO: Takes a lot of time with little effect 79 pm.add(llvm::createCFGSimplificationPass()); 80 pm.add(llvm::createInstructionCombiningPass()); 81 pm.add(llvm::createAggressiveDCEPass()); 82 pm.add(llvm::createLowerSwitchPass()); 83 return pm.run(_module); 84 } 85 86 namespace 87 { 88 89 class LowerEVMPass: public llvm::BasicBlockPass 90 { 91 static char ID; 92 93 public: 94 LowerEVMPass(): 95 llvm::BasicBlockPass(ID) 96 {} 97 98 virtual bool runOnBasicBlock(llvm::BasicBlock& _bb) override; 99 100 using llvm::BasicBlockPass::doFinalization; 101 virtual bool doFinalization(llvm::Module& _module) override; 102 }; 103 104 char LowerEVMPass::ID = 0; 105 106 bool LowerEVMPass::runOnBasicBlock(llvm::BasicBlock& _bb) 107 { 108 auto modified = false; 109 auto module = _bb.getParent()->getParent(); 110 auto i512Ty = llvm::IntegerType::get(_bb.getContext(), 512); 111 for (auto it = _bb.begin(); it != _bb.end(); ++it) 112 { 113 auto& inst = *it; 114 llvm::Function* func = nullptr; 115 if (inst.getType() == Type::Word) 116 { 117 switch (inst.getOpcode()) 118 { 119 case llvm::Instruction::UDiv: 120 func = Arith256::getUDiv256Func(*module); 121 break; 122 123 case llvm::Instruction::URem: 124 func = Arith256::getURem256Func(*module); 125 break; 126 127 case llvm::Instruction::SDiv: 128 func = Arith256::getSDiv256Func(*module); 129 break; 130 131 case llvm::Instruction::SRem: 132 func = Arith256::getSRem256Func(*module); 133 break; 134 } 135 } 136 else if (inst.getType() == i512Ty) 137 { 138 switch (inst.getOpcode()) 139 { 140 case llvm::Instruction::URem: 141 func = Arith256::getURem512Func(*module); 142 break; 143 } 144 } 145 146 if (func) 147 { 148 auto call = llvm::CallInst::Create(func, {inst.getOperand(0), inst.getOperand(1)}); 149 llvm::ReplaceInstWithInst(_bb.getInstList(), it, call); 150 modified = true; 151 } 152 } 153 return modified; 154 } 155 156 bool LowerEVMPass::doFinalization(llvm::Module&) 157 { 158 return false; 159 } 160 161 } 162 163 bool prepare(llvm::Module& _module) 164 { 165 auto pm = llvm::legacy::PassManager{}; 166 pm.add(llvm::createCFGSimplificationPass()); 167 pm.add(llvm::createDeadCodeEliminationPass()); 168 pm.add(new LowerEVMPass{}); 169 return pm.run(_module); 170 } 171 172 } 173 } 174 }