// File: contracts/utils/FixedPointMathLib.sol
  1  // SPDX-License-Identifier: MIT
  2  pragma solidity >=0.6.0;
  3  
  4  /// @notice Arithmetic library with operations for fixed-point numbers.
  5  /// @author Solmate (https://github.com/Rari-Capital/solmate/blob/main/src/utils/FixedPointMathLib.sol)
  6  library FixedPointMathLib {
  7      /*//////////////////////////////////////////////////////////////
  8                      SIMPLIFIED FIXED POINT OPERATIONS
  9      //////////////////////////////////////////////////////////////*/
 10  
 11      uint256 internal constant WAD = 1e18; // The scalar of ETH and most ERC20s.
 12  
 13      function mulWadDown(uint256 x, uint256 y) internal pure returns (uint256) {
 14          return mulDivDown(x, y, WAD); // Equivalent to (x * y) / WAD rounded down.
 15      }
 16  
 17      function mulWadUp(uint256 x, uint256 y) internal pure returns (uint256) {
 18          return mulDivUp(x, y, WAD); // Equivalent to (x * y) / WAD rounded up.
 19      }
 20  
 21      function divWadDown(uint256 x, uint256 y) internal pure returns (uint256) {
 22          return mulDivDown(x, WAD, y); // Equivalent to (x * WAD) / y rounded down.
 23      }
 24  
 25      function divWadUp(uint256 x, uint256 y) internal pure returns (uint256) {
 26          return mulDivUp(x, WAD, y); // Equivalent to (x * WAD) / y rounded up.
 27      }
 28  
 29      function powWad(int256 x, int256 y) internal pure returns (int256) {
 30          // Equivalent to x to the power of y because x ** y = (e ** ln(x)) ** y = e ** (ln(x) * y)
 31          return expWad((lnWad(x) * y) / int256(WAD)); // Using ln(x) means x must be greater than 0.
 32      }
 33  
 34      function expWad(int256 x) internal pure returns (int256 r) {
 35          {
 36              // When the result is < 0.5 we return zero. This happens when
 37              // x <= floor(log(0.5e18) * 1e18) ~ -42e18
 38              if (x <= -42139678854452767551) return 0;
 39  
 40              // When the result is > (2**255 - 1) / 1e18 we can not represent it as an
 41              // int. This happens when x >= floor(log((2**255 - 1) / 1e18) * 1e18) ~ 135.
 42              if (x >= 135305999368893231589) revert("EXP_OVERFLOW");
 43  
 44              // x is now in the range (-42, 136) * 1e18. Convert to (-42, 136) * 2**96
 45              // for more intermediate precision and a binary basis. This base conversion
 46              // is a multiplication by 1e18 / 2**96 = 5**18 / 2**78.
 47              x = (x << 78) / 5**18;
 48  
 49              // Reduce range of x to (-½ ln 2, ½ ln 2) * 2**96 by factoring out powers
 50              // of two such that exp(x) = exp(x') * 2**k, where k is an integer.
 51              // Solving this gives k = round(x / log(2)) and x' = x - k * log(2).
 52              int256 k = ((x << 96) / 54916777467707473351141471128 + 2**95) >> 96;
 53              x = x - k * 54916777467707473351141471128;
 54  
 55              // k is in the range [-61, 195].
 56  
 57              // Evaluate using a (6, 7)-term rational approximation.
 58              // p is made monic, we'll multiply by a scale factor later.
 59              int256 y = x + 1346386616545796478920950773328;
 60              y = ((y * x) >> 96) + 57155421227552351082224309758442;
 61              int256 p = y + x - 94201549194550492254356042504812;
 62              p = ((p * y) >> 96) + 28719021644029726153956944680412240;
 63              p = p * x + (4385272521454847904659076985693276 << 96);
 64  
 65              // We leave p in 2**192 basis so we don't need to scale it back up for the division.
 66              int256 q = x - 2855989394907223263936484059900;
 67              q = ((q * x) >> 96) + 50020603652535783019961831881945;
 68              q = ((q * x) >> 96) - 533845033583426703283633433725380;
 69              q = ((q * x) >> 96) + 3604857256930695427073651918091429;
 70              q = ((q * x) >> 96) - 14423608567350463180887372962807573;
 71              q = ((q * x) >> 96) + 26449188498355588339934803723976023;
 72  
 73              assembly {
 74                  // Div in assembly because solidity adds a zero check despite the unchecked.
 75                  // The q polynomial won't have zeros in the domain as all its roots are complex.
 76                  // No scaling is necessary because p is already 2**96 too large.
 77                  r := sdiv(p, q)
 78              }
 79  
 80              // r should be in the range (0.09, 0.25) * 2**96.
 81  
 82              // We now need to multiply r by:
 83              // * the scale factor s = ~6.031367120.
 84              // * the 2**k factor from the range reduction.
 85              // * the 1e18 / 2**96 factor for base conversion.
 86              // We do this all at once, with an intermediate result in 2**213
 87              // basis, so the final right shift is always by a positive amount.
 88              r = int256((uint256(r) * 3822833074963236453042738258902158003155416615667) >> uint256(195 - k));
 89          }
 90      }
 91  
 92      function lnWad(int256 x) internal pure returns (int256 r) {
 93          {
 94              require(x > 0, "UNDEFINED");
 95  
 96              // We want to convert x from 10**18 fixed point to 2**96 fixed point.
 97              // We do this by multiplying by 2**96 / 10**18. But since
 98              // ln(x * C) = ln(x) + ln(C), we can simply do nothing here
 99              // and add ln(2**96 / 10**18) at the end.
100  
101              // Reduce range of x to (1, 2) * 2**96
102              // ln(2^k * x) = k * ln(2) + ln(x)
103              int256 k = int256(log2fpl(uint256(x))) - 96;
104              x <<= uint256(159 - k);
105              x = int256(uint256(x) >> 159);
106  
107              // Evaluate using a (8, 8)-term rational approximation.
108              // p is made monic, we will multiply by a scale factor later.
109              int256 p = x + 3273285459638523848632254066296;
110              p = ((p * x) >> 96) + 24828157081833163892658089445524;
111              p = ((p * x) >> 96) + 43456485725739037958740375743393;
112              p = ((p * x) >> 96) - 11111509109440967052023855526967;
113              p = ((p * x) >> 96) - 45023709667254063763336534515857;
114              p = ((p * x) >> 96) - 14706773417378608786704636184526;
115              p = p * x - (795164235651350426258249787498 << 96);
116  
117              // We leave p in 2**192 basis so we don't need to scale it back up for the division.
118              // q is monic by convention.
119              int256 q = x + 5573035233440673466300451813936;
120              q = ((q * x) >> 96) + 71694874799317883764090561454958;
121              q = ((q * x) >> 96) + 283447036172924575727196451306956;
122              q = ((q * x) >> 96) + 401686690394027663651624208769553;
123              q = ((q * x) >> 96) + 204048457590392012362485061816622;
124              q = ((q * x) >> 96) + 31853899698501571402653359427138;
125              q = ((q * x) >> 96) + 909429971244387300277376558375;
126              assembly {
127                  // Div in assembly because solidity adds a zero check despite the unchecked.
128                  // The q polynomial is known not to have zeros in the domain.
129                  // No scaling required because p is already 2**96 too large.
130                  r := sdiv(p, q)
131              }
132  
133              // r is in the range (0, 0.125) * 2**96
134  
135              // Finalization, we need to:
136              // * multiply by the scale factor s = 5.549…
137              // * add ln(2**96 / 10**18)
138              // * add k * ln(2)
139              // * multiply by 10**18 / 2**96 = 5**18 >> 78
140  
141              // mul s * 5e18 * 2**96, base is now 5**18 * 2**192
142              r *= 1677202110996718588342820967067443963516166;
143              // add ln(2) * k * 5e18 * 2**192
144              r += 16597577552685614221487285958193947469193820559219878177908093499208371 * k;
145              // add ln(2**96 / 10**18) * 5e18 * 2**192
146              r += 600920179829731861736702779321621459595472258049074101567377883020018308;
147              // base conversion: mul 2**18 / 2**192
148              r >>= 174;
149          }
150      }
151  
152      /*//////////////////////////////////////////////////////////////
153                      LOW LEVEL FIXED POINT OPERATIONS
154      //////////////////////////////////////////////////////////////*/
155  
156      function mulDivDown(
157          uint256 x,
158          uint256 y,
159          uint256 denominator
160      ) internal pure returns (uint256 z) {
161          assembly {
162              // Store x * y in z for now.
163              z := mul(x, y)
164  
165              // Equivalent to require(denominator != 0 && (x == 0 || (x * y) / x == y))
166              if iszero(and(iszero(iszero(denominator)), or(iszero(x), eq(div(z, x), y)))) {
167                  revert(0, 0)
168              }
169  
170              // Divide z by the denominator.
171              z := div(z, denominator)
172          }
173      }
174  
175      function mulDivUp(
176          uint256 x,
177          uint256 y,
178          uint256 denominator
179      ) internal pure returns (uint256 z) {
180          assembly {
181              // Store x * y in z for now.
182              z := mul(x, y)
183  
184              // Equivalent to require(denominator != 0 && (x == 0 || (x * y) / x == y))
185              if iszero(and(iszero(iszero(denominator)), or(iszero(x), eq(div(z, x), y)))) {
186                  revert(0, 0)
187              }
188  
189              // First, divide z - 1 by the denominator and add 1.
190              // We allow z - 1 to underflow if z is 0, because we multiply the
191              // end result by 0 if z is zero, ensuring we return 0 if z is zero.
192              z := mul(iszero(iszero(z)), add(div(sub(z, 1), denominator), 1))
193          }
194      }
195  
196      function rpow(
197          uint256 x,
198          uint256 n,
199          uint256 scalar
200      ) internal pure returns (uint256 z) {
201          assembly {
202              switch x
203              case 0 {
204                  switch n
205                  case 0 {
206                      // 0 ** 0 = 1
207                      z := scalar
208                  }
209                  default {
210                      // 0 ** n = 0
211                      z := 0
212                  }
213              }
214              default {
215                  switch mod(n, 2)
216                  case 0 {
217                      // If n is even, store scalar in z for now.
218                      z := scalar
219                  }
220                  default {
221                      // If n is odd, store x in z for now.
222                      z := x
223                  }
224  
225                  // Shifting right by 1 is like dividing by 2.
226                  let half := shr(1, scalar)
227  
228                  for {
229                      // Shift n right by 1 before looping to halve it.
230                      n := shr(1, n)
231                  } n {
232                      // Shift n right by 1 each iteration to halve it.
233                      n := shr(1, n)
234                  } {
235                      // Revert immediately if x ** 2 would overflow.
236                      // Equivalent to iszero(eq(div(xx, x), x)) here.
237                      if shr(128, x) {
238                          revert(0, 0)
239                      }
240  
241                      // Store x squared.
242                      let xx := mul(x, x)
243  
244                      // Round to the nearest number.
245                      let xxRound := add(xx, half)
246  
247                      // Revert if xx + half overflowed.
248                      if lt(xxRound, xx) {
249                          revert(0, 0)
250                      }
251  
252                      // Set x to scaled xxRound.
253                      x := div(xxRound, scalar)
254  
255                      // If n is even:
256                      if mod(n, 2) {
257                          // Compute z * x.
258                          let zx := mul(z, x)
259  
260                          // If z * x overflowed:
261                          if iszero(eq(div(zx, x), z)) {
262                              // Revert if x is non-zero.
263                              if iszero(iszero(x)) {
264                                  revert(0, 0)
265                              }
266                          }
267  
268                          // Round to the nearest number.
269                          let zxRound := add(zx, half)
270  
271                          // Revert if zx + half overflowed.
272                          if lt(zxRound, zx) {
273                              revert(0, 0)
274                          }
275  
276                          // Return properly scaled zxRound.
277                          z := div(zxRound, scalar)
278                      }
279                  }
280              }
281          }
282      }
283  
284      /*//////////////////////////////////////////////////////////////
285                          GENERAL NUMBER UTILITIES
286      //////////////////////////////////////////////////////////////*/
287  
288      function sqrt(uint256 x) internal pure returns (uint256 z) {
289          assembly {
290              let y := x // We start y at x, which will help us make our initial estimate.
291  
292              z := 181 // The "correct" value is 1, but this saves a multiplication later.
293  
294              // This segment is to get a reasonable initial estimate for the Babylonian method. With a bad
295              // start, the correct # of bits increases ~linearly each iteration instead of ~quadratically.
296  
297              // We check y >= 2^(k + 8) but shift right by k bits
298              // each branch to ensure that if x >= 256, then y >= 256.
299              if iszero(lt(y, 0x10000000000000000000000000000000000)) {
300                  y := shr(128, y)
301                  z := shl(64, z)
302              }
303              if iszero(lt(y, 0x1000000000000000000)) {
304                  y := shr(64, y)
305                  z := shl(32, z)
306              }
307              if iszero(lt(y, 0x10000000000)) {
308                  y := shr(32, y)
309                  z := shl(16, z)
310              }
311              if iszero(lt(y, 0x1000000)) {
312                  y := shr(16, y)
313                  z := shl(8, z)
314              }
315  
316              // Goal was to get z*z*y within a small factor of x. More iterations could
317              // get y in a tighter range. Currently, we will have y in [256, 256*2^16).
318              // We ensured y >= 256 so that the relative difference between y and y+1 is small.
319              // That's not possible if x < 256 but we can just verify those cases exhaustively.
320  
321              // Now, z*z*y <= x < z*z*(y+1), and y <= 2^(16+8), and either y >= 256, or x < 256.
322              // Correctness can be checked exhaustively for x < 256, so we assume y >= 256.
323              // Then z*sqrt(y) is within sqrt(257)/sqrt(256) of sqrt(x), or about 20bps.
324  
325              // For s in the range [1/256, 256], the estimate f(s) = (181/1024) * (s+1) is in the range
326              // (1/2.84 * sqrt(s), 2.84 * sqrt(s)), with largest error when s = 1 and when s = 256 or 1/256.
327  
328              // Since y is in [256, 256*2^16), let a = y/65536, so that a is in [1/256, 256). Then we can estimate
329              // sqrt(y) using sqrt(65536) * 181/1024 * (a + 1) = 181/4 * (y + 65536)/65536 = 181 * (y + 65536)/2^18.
330  
331              // There is no overflow risk here since y < 2^136 after the first branch above.
332              z := shr(18, mul(z, add(y, 65536))) // A mul() is saved from starting z at 181.
333  
334              // Given the worst case multiplicative error of 2.84 above, 7 iterations should be enough.
335              z := shr(1, add(z, div(x, z)))
336              z := shr(1, add(z, div(x, z)))
337              z := shr(1, add(z, div(x, z)))
338              z := shr(1, add(z, div(x, z)))
339              z := shr(1, add(z, div(x, z)))
340              z := shr(1, add(z, div(x, z)))
341              z := shr(1, add(z, div(x, z)))
342  
343              // If x+1 is a perfect square, the Babylonian method cycles between
344              // floor(sqrt(x)) and ceil(sqrt(x)). This statement ensures we return floor.
345              // See: https://en.wikipedia.org/wiki/Integer_square_root#Using_only_integer_division
346              // Since the ceil is rare, we save gas on the assignment and repeat division in the rare case.
347              // If you don't care whether the floor or ceil square root is returned, you can remove this statement.
348              z := sub(z, lt(div(x, z), z))
349          }
350      }
351  
352      function log2fpl(uint256 x) internal pure returns (uint256 r) {
353          require(x > 0, "UNDEFINED");
354  
355          assembly {
356              r := shl(7, lt(0xffffffffffffffffffffffffffffffff, x))
357              r := or(r, shl(6, lt(0xffffffffffffffff, shr(r, x))))
358              r := or(r, shl(5, lt(0xffffffff, shr(r, x))))
359              r := or(r, shl(4, lt(0xffff, shr(r, x))))
360              r := or(r, shl(3, lt(0xff, shr(r, x))))
361              r := or(r, shl(2, lt(0xf, shr(r, x))))
362              r := or(r, shl(1, lt(0x3, shr(r, x))))
363              r := or(r, lt(0x1, shr(r, x)))
364          }
365      }
366  
367      function unsafeMod(uint256 x, uint256 y) internal pure returns (uint256 z) {
368          assembly {
369              // z will equal 0 if y is 0, unlike in Solidity where it will revert.
370              z := mod(x, y)
371          }
372      }
373  
374      function unsafeDiv(uint256 x, uint256 y) internal pure returns (uint256 z) {
375          assembly {
376              // z will equal 0 if y is 0, unlike in Solidity where it will revert.
377              z := div(x, y)
378          }
379      }
380  
381      /// @dev Will return 0 instead of reverting if y is zero.
382      function unsafeDivUp(uint256 x, uint256 y) internal pure returns (uint256 z) {
383          assembly {
384              // Add 1 to x * y if x % y > 0.
385              z := add(gt(mod(x, y), 0), div(x, y))
386          }
387      }
388  }