/ libevmjit / Optimizer.cpp
Optimizer.cpp
  1  #include "Optimizer.h"
  2  
  3  #include "preprocessor/llvm_includes_start.h"
  4  #include <llvm/IR/BasicBlock.h>
  5  #include <llvm/IR/Function.h>
  6  #include <llvm/IR/Module.h>
  7  #include <llvm/IR/LegacyPassManager.h>
  8  #include <llvm/Transforms/Scalar.h>
  9  #include <llvm/Transforms/IPO.h>
 10  #include <llvm/Transforms/Utils/BasicBlockUtils.h>
 11  #include "preprocessor/llvm_includes_end.h"
 12  
 13  #include "Arith256.h"
 14  #include "Type.h"
 15  
 16  namespace dev
 17  {
 18  namespace eth
 19  {
 20  namespace jit
 21  {
 22  
 23  namespace
 24  {
 25  
 26  class LongJmpEliminationPass: public llvm::FunctionPass
 27  {
 28  	static char ID;
 29  
 30  public:
 31  	LongJmpEliminationPass():
 32  		llvm::FunctionPass(ID)
 33  	{}
 34  
 35  	virtual bool runOnFunction(llvm::Function& _func) override;
 36  };
 37  
 38  char LongJmpEliminationPass::ID = 0;
 39  
 40  bool LongJmpEliminationPass::runOnFunction(llvm::Function& _func)
 41  {
 42  	auto iter = _func.getParent()->begin();
 43  	if (&_func != &(*iter))
 44  		return false;
 45  
 46  	auto& mainFunc = _func;
 47  	auto& ctx = _func.getContext();
 48  	auto abortCode = llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), -1);
 49  
 50  	auto& exitBB = mainFunc.back();
 51  	assert(exitBB.getName() == "Exit");
 52  	auto retPhi = llvm::cast<llvm::PHINode>(&exitBB.front());
 53  
 54  	auto modified = false;
 55  	for (auto bbIt = mainFunc.begin(); bbIt != mainFunc.end(); ++bbIt)
 56  	{
 57  		if (auto term = llvm::dyn_cast<llvm::UnreachableInst>(bbIt->getTerminator()))
 58  		{
 59  			auto longjmp = term->getPrevNode();
 60  			assert(llvm::isa<llvm::CallInst>(longjmp));
 61  			auto bbPtr = &(*bbIt);
 62  			retPhi->addIncoming(abortCode, bbPtr);
 63  			llvm::ReplaceInstWithInst(term, llvm::BranchInst::Create(&exitBB));
 64  			longjmp->eraseFromParent();
 65  			modified = true;
 66  		}
 67  	}
 68  
 69  	return modified;
 70  }
 71  
 72  }
 73  
 74  bool optimize(llvm::Module& _module)
 75  {
 76  	auto pm = llvm::legacy::PassManager{};
 77  	pm.add(llvm::createFunctionInliningPass(2, 2, false));
 78  	pm.add(new LongJmpEliminationPass{}); 				// TODO: Takes a lot of time with little effect
 79  	pm.add(llvm::createCFGSimplificationPass());
 80  	pm.add(llvm::createInstructionCombiningPass());
 81  	pm.add(llvm::createAggressiveDCEPass());
 82  	pm.add(llvm::createLowerSwitchPass());
 83  	return pm.run(_module);
 84  }
 85  
 86  namespace
 87  {
 88  
 89  class LowerEVMPass: public llvm::BasicBlockPass
 90  {
 91  	static char ID;
 92  
 93  public:
 94  	LowerEVMPass():
 95  		llvm::BasicBlockPass(ID)
 96  	{}
 97  
 98  	virtual bool runOnBasicBlock(llvm::BasicBlock& _bb) override;
 99  
100  	using llvm::BasicBlockPass::doFinalization;
101  	virtual bool doFinalization(llvm::Module& _module) override;
102  };
103  
104  char LowerEVMPass::ID = 0;
105  
106  bool LowerEVMPass::runOnBasicBlock(llvm::BasicBlock& _bb)
107  {
108  	auto modified = false;
109  	auto module = _bb.getParent()->getParent();
110  	auto i512Ty = llvm::IntegerType::get(_bb.getContext(), 512);
111  	for (auto it = _bb.begin(); it != _bb.end(); ++it)
112  	{
113  		auto& inst = *it;
114  		llvm::Function* func = nullptr;
115  		if (inst.getType() == Type::Word)
116  		{
117  			switch (inst.getOpcode())
118  			{
119  			case llvm::Instruction::UDiv:
120  				func = Arith256::getUDiv256Func(*module);
121  				break;
122  
123  			case llvm::Instruction::URem:
124  				func = Arith256::getURem256Func(*module);
125  				break;
126  
127  			case llvm::Instruction::SDiv:
128  				func = Arith256::getSDiv256Func(*module);
129  				break;
130  
131  			case llvm::Instruction::SRem:
132  				func = Arith256::getSRem256Func(*module);
133  				break;
134  			}
135  		}
136  		else if (inst.getType() == i512Ty)
137  		{
138  			switch (inst.getOpcode())
139  			{
140  			case llvm::Instruction::URem:
141  				func = Arith256::getURem512Func(*module);
142  				break;
143  			}
144  		}
145  
146  		if (func)
147  		{
148  			auto call = llvm::CallInst::Create(func, {inst.getOperand(0), inst.getOperand(1)});
149  			llvm::ReplaceInstWithInst(_bb.getInstList(), it, call);
150  			modified = true;
151  		}
152  	}
153  	return modified;
154  }
155  
156  bool LowerEVMPass::doFinalization(llvm::Module&)
157  {
158  	return false;
159  }
160  
161  }
162  
163  bool prepare(llvm::Module& _module)
164  {
165  	auto pm = llvm::legacy::PassManager{};
166  	pm.add(llvm::createCFGSimplificationPass());
167  	pm.add(llvm::createDeadCodeEliminationPass());
168  	pm.add(new LowerEVMPass{});
169  	return pm.run(_module);
170  }
171  
172  }
173  }
174  }