/ libevmjit / Arith256.cpp
Arith256.cpp
  1  #include "Arith256.h"
  2  
  3  #include <iostream>
  4  #include <iomanip>
  5  
  6  #include "preprocessor/llvm_includes_start.h"
  7  #include <llvm/IR/Module.h>
  8  #include <llvm/IR/IntrinsicInst.h>
  9  #include "preprocessor/llvm_includes_end.h"
 10  
 11  #include "Type.h"
 12  #include "Endianness.h"
 13  #include "Utils.h"
 14  
 15  namespace dev
 16  {
 17  namespace eth
 18  {
 19  namespace jit
 20  {
 21  
 22  Arith256::Arith256(IRBuilder& _builder) :
 23  	CompilerHelper(_builder)
 24  {}
 25  
 26  void Arith256::debug(llvm::Value* _value, char _c, llvm::Module& _module, IRBuilder& _builder)
 27  {
 28  	static const auto funcName = "debug";
 29  	auto func = _module.getFunction(funcName);
 30  	if (!func)
 31  		func = llvm::Function::Create(llvm::FunctionType::get(Type::Void, {Type::Word, _builder.getInt8Ty()}, false), llvm::Function::ExternalLinkage, funcName, &_module);
 32  
 33  	_builder.CreateCall(func, {_builder.CreateZExtOrTrunc(_value, Type::Word), _builder.getInt8(_c)});
 34  }
 35  
 36  namespace
 37  {
 38  llvm::Function* createUDivRemFunc(llvm::Type* _type, llvm::Module& _module, char const* _funcName)
 39  {
 40  	// Based of "Improved shift divisor algorithm" from "Software Integer Division" by Microsoft Research
 41  	// The following algorithm also handles divisor of value 0 returning 0 for both quotient and remainder
 42  
 43  	auto retType = llvm::VectorType::get(_type, 2);
 44  	auto func = llvm::Function::Create(llvm::FunctionType::get(retType, {_type, _type}, false), llvm::Function::PrivateLinkage, _funcName, &_module);
 45  	func->setDoesNotThrow();
 46  	func->setDoesNotAccessMemory();
 47  
 48  	auto zero = llvm::ConstantInt::get(_type, 0);
 49  	auto one = llvm::ConstantInt::get(_type, 1);
 50  
 51  	auto iter = func->arg_begin();
 52  	llvm::Argument* x = &(*iter++);
 53  	x->setName("x");
 54  	llvm::Argument* y = &(*iter);
 55  	y->setName("y");
 56  
 57  	auto entryBB = llvm::BasicBlock::Create(_module.getContext(), "Entry", func);
 58  	auto mainBB = llvm::BasicBlock::Create(_module.getContext(), "Main", func);
 59  	auto loopBB = llvm::BasicBlock::Create(_module.getContext(), "Loop", func);
 60  	auto continueBB = llvm::BasicBlock::Create(_module.getContext(), "Continue", func);
 61  	auto returnBB = llvm::BasicBlock::Create(_module.getContext(), "Return", func);
 62  
 63  	auto builder = IRBuilder{entryBB};
 64  	auto yLEx = builder.CreateICmpULE(y, x);
 65  	auto r0 = x;
 66  	builder.CreateCondBr(yLEx, mainBB, returnBB);
 67  
 68  	builder.SetInsertPoint(mainBB);
 69  	auto ctlzIntr = llvm::Intrinsic::getDeclaration(&_module, llvm::Intrinsic::ctlz, _type);
 70  	// both y and r are non-zero
 71  	auto yLz = builder.CreateCall(ctlzIntr, {y, builder.getInt1(true)}, "y.lz");
 72  	auto rLz = builder.CreateCall(ctlzIntr, {r0, builder.getInt1(true)}, "r.lz");
 73  	auto i0 = builder.CreateNUWSub(yLz, rLz, "i0");
 74  	auto y0 = builder.CreateShl(y, i0);
 75  	builder.CreateBr(loopBB);
 76  
 77  	builder.SetInsertPoint(loopBB);
 78  	auto yPhi = builder.CreatePHI(_type, 2, "y.phi");
 79  	auto rPhi = builder.CreatePHI(_type, 2, "r.phi");
 80  	auto iPhi = builder.CreatePHI(_type, 2, "i.phi");
 81  	auto qPhi = builder.CreatePHI(_type, 2, "q.phi");
 82  	auto rUpdate = builder.CreateNUWSub(rPhi, yPhi);
 83  	auto qUpdate = builder.CreateOr(qPhi, one);	// q += 1, q lowest bit is 0
 84  	auto rGEy = builder.CreateICmpUGE(rPhi, yPhi);
 85  	auto r1 = builder.CreateSelect(rGEy, rUpdate, rPhi, "r1");
 86  	auto q1 = builder.CreateSelect(rGEy, qUpdate, qPhi, "q");
 87  	auto iZero = builder.CreateICmpEQ(iPhi, zero);
 88  	builder.CreateCondBr(iZero, returnBB, continueBB);
 89  
 90  	builder.SetInsertPoint(continueBB);
 91  	auto i2 = builder.CreateNUWSub(iPhi, one);
 92  	auto q2 = builder.CreateShl(q1, one);
 93  	auto y2 = builder.CreateLShr(yPhi, one);
 94  	builder.CreateBr(loopBB);
 95  
 96  	yPhi->addIncoming(y0, mainBB);
 97  	yPhi->addIncoming(y2, continueBB);
 98  	rPhi->addIncoming(r0, mainBB);
 99  	rPhi->addIncoming(r1, continueBB);
100  	iPhi->addIncoming(i0, mainBB);
101  	iPhi->addIncoming(i2, continueBB);
102  	qPhi->addIncoming(zero, mainBB);
103  	qPhi->addIncoming(q2, continueBB);
104  
105  	builder.SetInsertPoint(returnBB);
106  	auto qRet = builder.CreatePHI(_type, 2, "q.ret");
107  	qRet->addIncoming(zero, entryBB);
108  	qRet->addIncoming(q1, loopBB);
109  	auto rRet = builder.CreatePHI(_type, 2, "r.ret");
110  	rRet->addIncoming(r0, entryBB);
111  	rRet->addIncoming(r1, loopBB);
112  	auto ret = builder.CreateInsertElement(llvm::UndefValue::get(retType), qRet, uint64_t(0), "ret0");
113  	ret = builder.CreateInsertElement(ret, rRet, 1, "ret");
114  	builder.CreateRet(ret);
115  
116  	return func;
117  }
118  }
119  
120  llvm::Function* Arith256::getUDivRem256Func(llvm::Module& _module)
121  {
122  	static const auto funcName = "evm.udivrem.i256";
123  	if (auto func = _module.getFunction(funcName))
124  		return func;
125  
126  	return createUDivRemFunc(Type::Word, _module, funcName);
127  }
128  
129  llvm::Function* Arith256::getUDivRem512Func(llvm::Module& _module)
130  {
131  	static const auto funcName = "evm.udivrem.i512";
132  	if (auto func = _module.getFunction(funcName))
133  		return func;
134  
135  	return createUDivRemFunc(llvm::IntegerType::get(_module.getContext(), 512), _module, funcName);
136  }
137  
138  llvm::Function* Arith256::getUDiv256Func(llvm::Module& _module)
139  {
140  	static const auto funcName = "evm.udiv.i256";
141  	if (auto func = _module.getFunction(funcName))
142  		return func;
143  
144  	auto udivremFunc = getUDivRem256Func(_module);
145  
146  	auto func = llvm::Function::Create(llvm::FunctionType::get(Type::Word, {Type::Word, Type::Word}, false), llvm::Function::PrivateLinkage, funcName, &_module);
147  	func->setDoesNotThrow();
148  	func->setDoesNotAccessMemory();
149  
150  	auto iter = func->arg_begin();
151  	llvm::Argument* x = &(*iter++);
152  	x->setName("x");
153  	llvm::Argument* y = &(*iter);
154  	y->setName("y");
155  
156  	auto bb = llvm::BasicBlock::Create(_module.getContext(), {}, func);
157  	auto builder = IRBuilder{bb};
158  	auto udivrem = builder.CreateCall(udivremFunc, {x, y});
159  	auto udiv = builder.CreateExtractElement(udivrem, uint64_t(0));
160  	builder.CreateRet(udiv);
161  
162  	return func;
163  }
164  
165  namespace
166  {
167  llvm::Function* createURemFunc(llvm::Type* _type, llvm::Module& _module, char const* _funcName)
168  {
169  	auto udivremFunc = _type == Type::Word ? Arith256::getUDivRem256Func(_module) : Arith256::getUDivRem512Func(_module);
170  
171  	auto func = llvm::Function::Create(llvm::FunctionType::get(_type, {_type, _type}, false), llvm::Function::PrivateLinkage, _funcName, &_module);
172  	func->setDoesNotThrow();
173  	func->setDoesNotAccessMemory();
174  
175  	auto iter = func->arg_begin();
176  	llvm::Argument* x = &(*iter++);
177  	x->setName("x");
178  	llvm::Argument* y = &(*iter);
179  	y->setName("y");
180  
181  	auto bb = llvm::BasicBlock::Create(_module.getContext(), {}, func);
182  	auto builder = IRBuilder{bb};
183  	auto udivrem = builder.CreateCall(udivremFunc, {x, y});
184  	auto r = builder.CreateExtractElement(udivrem, uint64_t(1));
185  	builder.CreateRet(r);
186  
187  	return func;
188  }
189  }
190  
191  llvm::Function* Arith256::getURem256Func(llvm::Module& _module)
192  {
193  	static const auto funcName = "evm.urem.i256";
194  	if (auto func = _module.getFunction(funcName))
195  		return func;
196  	return createURemFunc(Type::Word, _module, funcName);
197  }
198  
199  llvm::Function* Arith256::getURem512Func(llvm::Module& _module)
200  {
201  	static const auto funcName = "evm.urem.i512";
202  	if (auto func = _module.getFunction(funcName))
203  		return func;
204  	return createURemFunc(llvm::IntegerType::get(_module.getContext(), 512), _module, funcName);
205  }
206  
207  llvm::Function* Arith256::getSDivRem256Func(llvm::Module& _module)
208  {
209  	static const auto funcName = "evm.sdivrem.i256";
210  	if (auto func = _module.getFunction(funcName))
211  		return func;
212  
213  	auto udivremFunc = getUDivRem256Func(_module);
214  
215  	auto retType = llvm::VectorType::get(Type::Word, 2);
216  	auto func = llvm::Function::Create(llvm::FunctionType::get(retType, {Type::Word, Type::Word}, false), llvm::Function::PrivateLinkage, funcName, &_module);
217  	func->setDoesNotThrow();
218  	func->setDoesNotAccessMemory();
219  
220  	auto iter = func->arg_begin();
221  	llvm::Argument* x = &(*iter++);
222  	x->setName("x");
223  	llvm::Argument* y = &(*iter);
224  	y->setName("y");
225  
226  	auto bb = llvm::BasicBlock::Create(_module.getContext(), "", func);
227  	auto builder = IRBuilder{bb};
228  	auto xIsNeg = builder.CreateICmpSLT(x, Constant::get(0));
229  	auto xNeg = builder.CreateSub(Constant::get(0), x);
230  	auto xAbs = builder.CreateSelect(xIsNeg, xNeg, x);
231  
232  	auto yIsNeg = builder.CreateICmpSLT(y, Constant::get(0));
233  	auto yNeg = builder.CreateSub(Constant::get(0), y);
234  	auto yAbs = builder.CreateSelect(yIsNeg, yNeg, y);
235  
236  	auto res = builder.CreateCall(udivremFunc, {xAbs, yAbs});
237  	auto qAbs = builder.CreateExtractElement(res, uint64_t(0));
238  	auto rAbs = builder.CreateExtractElement(res, 1);
239  
240  	// the remainder has the same sign as dividend
241  	auto rNeg = builder.CreateSub(Constant::get(0), rAbs);
242  	auto r = builder.CreateSelect(xIsNeg, rNeg, rAbs);
243  
244  	auto qNeg = builder.CreateSub(Constant::get(0), qAbs);
245  	auto xyOpposite = builder.CreateXor(xIsNeg, yIsNeg);
246  	auto q = builder.CreateSelect(xyOpposite, qNeg, qAbs);
247  
248  	auto ret = builder.CreateInsertElement(llvm::UndefValue::get(retType), q, uint64_t(0));
249  	ret = builder.CreateInsertElement(ret, r, 1);
250  	builder.CreateRet(ret);
251  
252  	return func;
253  }
254  
255  llvm::Function* Arith256::getSDiv256Func(llvm::Module& _module)
256  {
257  	static const auto funcName = "evm.sdiv.i256";
258  	if (auto func = _module.getFunction(funcName))
259  		return func;
260  
261  	auto sdivremFunc = getSDivRem256Func(_module);
262  
263  	auto func = llvm::Function::Create(llvm::FunctionType::get(Type::Word, {Type::Word, Type::Word}, false), llvm::Function::PrivateLinkage, funcName, &_module);
264  	func->setDoesNotThrow();
265  	func->setDoesNotAccessMemory();
266  
267  	auto iter = func->arg_begin();
268  	llvm::Argument* x = &(*iter++);
269  	x->setName("x");
270  	llvm::Argument* y = &(*iter);
271  	y->setName("y");
272  
273  	auto bb = llvm::BasicBlock::Create(_module.getContext(), {}, func);
274  	auto builder = IRBuilder{bb};
275  	auto sdivrem = builder.CreateCall(sdivremFunc, {x, y});
276  	auto q = builder.CreateExtractElement(sdivrem, uint64_t(0));
277  	builder.CreateRet(q);
278  
279  	return func;
280  }
281  
282  llvm::Function* Arith256::getSRem256Func(llvm::Module& _module)
283  {
284  	static const auto funcName = "evm.srem.i256";
285  	if (auto func = _module.getFunction(funcName))
286  		return func;
287  
288  	auto sdivremFunc = getSDivRem256Func(_module);
289  
290  	auto func = llvm::Function::Create(llvm::FunctionType::get(Type::Word, {Type::Word, Type::Word}, false), llvm::Function::PrivateLinkage, funcName, &_module);
291  	func->setDoesNotThrow();
292  	func->setDoesNotAccessMemory();
293  
294  	auto iter = func->arg_begin();
295  	llvm::Argument* x = &(*iter++);
296  	x->setName("x");
297  	llvm::Argument* y = &(*iter);
298  	y->setName("y");
299  
300  	auto bb = llvm::BasicBlock::Create(_module.getContext(), {}, func);
301  	auto builder = IRBuilder{bb};
302  	auto sdivrem = builder.CreateCall(sdivremFunc, {x, y});
303  	auto r = builder.CreateExtractElement(sdivrem, uint64_t(1));
304  	builder.CreateRet(r);
305  
306  	return func;
307  }
308  
309  llvm::Function* Arith256::getExpFunc()
310  {
311  	if (!m_exp)
312  	{
313  		llvm::Type* argTypes[] = {Type::Word, Type::Word};
314  		m_exp = llvm::Function::Create(llvm::FunctionType::get(Type::Word, argTypes, false), llvm::Function::PrivateLinkage, "exp", getModule());
315  		m_exp->setDoesNotThrow();
316  		m_exp->setDoesNotAccessMemory();
317  
318  		auto iter = m_exp->arg_begin();
319  		llvm::Argument* base = &(*iter++);
320  		base->setName("base");
321  		llvm::Argument* exponent = &(*iter);
322  		exponent->setName("exponent");
323  
324  		InsertPointGuard guard{m_builder};
325  
326  		//	while (e != 0) {
327  		//		if (e % 2 == 1)
328  		//			r *= b;
329  		//		b *= b;
330  		//		e /= 2;
331  		//	}
332  
333  		auto entryBB = llvm::BasicBlock::Create(m_builder.getContext(), "Entry", m_exp);
334  		auto headerBB = llvm::BasicBlock::Create(m_builder.getContext(), "LoopHeader", m_exp);
335  		auto bodyBB = llvm::BasicBlock::Create(m_builder.getContext(), "LoopBody", m_exp);
336  		auto updateBB = llvm::BasicBlock::Create(m_builder.getContext(), "ResultUpdate", m_exp);
337  		auto continueBB = llvm::BasicBlock::Create(m_builder.getContext(), "Continue", m_exp);
338  		auto returnBB = llvm::BasicBlock::Create(m_builder.getContext(), "Return", m_exp);
339  
340  		m_builder.SetInsertPoint(entryBB);
341  		m_builder.CreateBr(headerBB);
342  
343  		m_builder.SetInsertPoint(headerBB);
344  		auto r = m_builder.CreatePHI(Type::Word, 2, "r");
345  		auto b = m_builder.CreatePHI(Type::Word, 2, "b");
346  		auto e = m_builder.CreatePHI(Type::Word, 2, "e");
347  		auto eNonZero = m_builder.CreateICmpNE(e, Constant::get(0), "e.nonzero");
348  		m_builder.CreateCondBr(eNonZero, bodyBB, returnBB);
349  
350  		m_builder.SetInsertPoint(bodyBB);
351  		auto eOdd = m_builder.CreateICmpNE(m_builder.CreateAnd(e, Constant::get(1)), Constant::get(0), "e.isodd");
352  		m_builder.CreateCondBr(eOdd, updateBB, continueBB);
353  
354  		m_builder.SetInsertPoint(updateBB);
355  		auto r0 = m_builder.CreateMul(r, b);
356  		m_builder.CreateBr(continueBB);
357  
358  		m_builder.SetInsertPoint(continueBB);
359  		auto r1 = m_builder.CreatePHI(Type::Word, 2, "r1");
360  		r1->addIncoming(r, bodyBB);
361  		r1->addIncoming(r0, updateBB);
362  		auto b1 = m_builder.CreateMul(b, b);
363  		auto e1 = m_builder.CreateLShr(e, Constant::get(1), "e1");
364  		m_builder.CreateBr(headerBB);
365  
366  		r->addIncoming(Constant::get(1), entryBB);
367  		r->addIncoming(r1, continueBB);
368  		b->addIncoming(base, entryBB);
369  		b->addIncoming(b1, continueBB);
370  		e->addIncoming(exponent, entryBB);
371  		e->addIncoming(e1, continueBB);
372  
373  		m_builder.SetInsertPoint(returnBB);
374  		m_builder.CreateRet(r);
375  	}
376  	return m_exp;
377  }
378  
379  llvm::Value* Arith256::exp(llvm::Value* _arg1, llvm::Value* _arg2)
380  {
381  	//	while (e != 0) {
382  	//		if (e % 2 == 1)
383  	//			r *= b;
384  	//		b *= b;
385  	//		e /= 2;
386  	//	}
387  
388  	if (auto c1 = llvm::dyn_cast<llvm::ConstantInt>(_arg1))
389  	{
390  		if (auto c2 = llvm::dyn_cast<llvm::ConstantInt>(_arg2))
391  		{
392  			auto b = c1->getValue();
393  			auto e = c2->getValue();
394  			auto r = llvm::APInt{256, 1};
395  			while (e != 0)
396  			{
397  				if (e[0])
398  					r *= b;
399  				b *= b;
400  				e = e.lshr(1);
401  			}
402  			return Constant::get(r);
403  		}
404  	}
405  
406  	return m_builder.CreateCall(getExpFunc(), {_arg1, _arg2});
407  }
408  
409  }
410  }
411  }
412  
413  extern "C"
414  {
415  	EXPORT void debug(uint64_t a, uint64_t b, uint64_t c, uint64_t d, char z)
416  	{
417  		DLOG(JIT) << "DEBUG " << std::dec << z << ": " //<< d << c << b << a
418  				<< " ["	<< std::hex << std::setfill('0') << std::setw(16) << d << std::setw(16) << c << std::setw(16) << b << std::setw(16) << a << "]\n";
419  	}
420  }