/ src / gmpy2_mpmath.c
gmpy2_mpmath.c
  1  /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
  2   * gmpy2_mpmath.c                                                          *
  3   * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
  4   * Python interface to the GMP or MPIR, MPFR, and MPC multiple precision   *
  5   * libraries.                                                              *
  6   *                                                                         *
  7   * Copyright 2000 - 2009 Alex Martelli                                     *
  8   *                                                                         *
  9   * Copyright 2008 - 2021 Case Van Horsen                                   *
 10   *                                                                         *
 11   * This file is part of GMPY2.                                             *
 12   *                                                                         *
 13   * GMPY2 is free software: you can redistribute it and/or modify it under  *
 14   * the terms of the GNU Lesser General Public License as published by the  *
 15   * Free Software Foundation, either version 3 of the License, or (at your  *
 16   * option) any later version.                                              *
 17   *                                                                         *
 18   * GMPY2 is distributed in the hope that it will be useful, but WITHOUT    *
 19   * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or   *
 20   * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public    *
 21   * License for more details.                                               *
 22   *                                                                         *
 23   * You should have received a copy of the GNU Lesser General Public        *
 24   * License along with GMPY2; if not, see <http://www.gnu.org/licenses/>    *
 25   * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
 26  
 27  
 28  /* Internal helper function for mpmath. */
 29  
 30  static PyObject *
 31  mpmath_build_mpf(long sign, MPZ_Object *man, PyObject *exp, mp_bitcnt_t bc)
 32  {
 33      PyObject *tup, *tsign, *tbc;
 34  
 35      if (!(tup = PyTuple_New(4))) {
 36          Py_DECREF((PyObject*)man);
 37          Py_DECREF(exp);
 38          return NULL;
 39      }
 40  
 41      if (!(tsign = PyIntOrLong_FromLong(sign))) {
 42          Py_DECREF((PyObject*)man);
 43          Py_DECREF(exp);
 44          Py_DECREF(tup);
 45          return NULL;
 46      }
 47  
 48      if (!(tbc = PyIntOrLong_FromMpBitCnt(bc))) {
 49          Py_DECREF((PyObject*)man);
 50          Py_DECREF(exp);
 51          Py_DECREF(tup);
 52          Py_DECREF(tsign);
 53          return NULL;
 54      }
 55  
 56      PyTuple_SET_ITEM(tup, 0, tsign);
 57      PyTuple_SET_ITEM(tup, 1, (PyObject*)man);
 58      PyTuple_SET_ITEM(tup, 2, (exp)?exp:PyIntOrLong_FromLong(0));
 59      PyTuple_SET_ITEM(tup, 3, tbc);
 60      return tup;
 61  }
 62  
 63  PyDoc_STRVAR(doc_mpmath_normalizeg,
 64  "_mpmath_normalize(...): helper function for mpmath.");
 65  
 66  static PyObject *
 67  Pympz_mpmath_normalize(PyObject *self, PyObject *args)
 68  {
 69      long sign = 0;
 70      mp_bitcnt_t zbits, bc = 0, prec = 0, shift;
 71      long carry = 0;
 72      PyObject *exp = 0, *newexp = 0, *newexp2 = 0, *tmp = 0, *rndstr = 0;
 73      MPZ_Object *man = 0, *upper = 0, *lower = 0;
 74      Py2or3String_Type rnd = 0;
 75      int err1, err2, err3;
 76  
 77      if (PyTuple_GET_SIZE(args) == 6) {
 78          /* Need better error-checking here. Under Python 3.0, overflow into
 79             C-long is possible. */
 80          sign = GMPy_Integer_AsLongAndError(PyTuple_GET_ITEM(args, 0), &err1);
 81          man = (MPZ_Object*)PyTuple_GET_ITEM(args, 1);
 82          exp = PyTuple_GET_ITEM(args, 2);
 83          bc = GMPy_Integer_AsMpBitCntAndError(PyTuple_GET_ITEM(args, 3), &err2);
 84          prec = GMPy_Integer_AsMpBitCntAndError(PyTuple_GET_ITEM(args, 4), &err3);
 85          rndstr = PyTuple_GET_ITEM(args, 5);
 86          if (err1 || err2 || err3) {
 87              TYPE_ERROR("arguments long, MPZ_Object*, PyObject*, long, long, char needed");
 88              return NULL;
 89          }
 90      }
 91      else {
 92          TYPE_ERROR("6 arguments required");
 93          return NULL;
 94      }
 95  
 96      if (!MPZ_Check(man)) {
 97  		/* Try to convert to an mpz... */
 98  		if (!(man = GMPy_MPZ_From_Integer((PyObject*)man, NULL))) {
 99  			TYPE_ERROR("argument is not an mpz");
100  			return NULL;
101  		}
102      }
103  
104      /* If rndstr really is a string, extract the first character. */
105      if (Py2or3String_Check(rndstr)) {
106          rnd = Py2or3String_1Char(rndstr);
107      }
108      else {
109          VALUE_ERROR("invalid rounding mode specified");
110          return NULL;
111      }
112  
113      /* If the mantissa is 0, return the normalized representation. */
114      if (!mpz_sgn(man->z)) {
115          Py_INCREF((PyObject*)man);
116          return mpmath_build_mpf(0, man, 0, 0);
117      }
118  
119  
120      /* if bc <= prec and the number is odd return it */
121      if ((bc <= prec) && mpz_odd_p(man->z)) {
122          Py_INCREF((PyObject*)man);
123          Py_INCREF((PyObject*)exp);
124          return mpmath_build_mpf(sign, man, exp, bc);
125      }
126  
127      if (!(upper = GMPy_MPZ_New(NULL)) || !(lower = GMPy_MPZ_New(NULL))) {
128          Py_XDECREF((PyObject*)upper);
129          Py_XDECREF((PyObject*)lower);
130      }
131  
132      if (bc > prec) {
133          shift = bc - prec;
134          switch (rnd) {
135              case (Py2or3String_Type)'f':
136                  if(sign) {
137                      mpz_cdiv_q_2exp(upper->z, man->z, shift);
138                  }
139                  else {
140                      mpz_fdiv_q_2exp(upper->z, man->z, shift);
141                  }
142                  break;
143              case (Py2or3String_Type)'c':
144                  if(sign) {
145                      mpz_fdiv_q_2exp(upper->z, man->z, shift);
146                  }
147                  else {
148                      mpz_cdiv_q_2exp(upper->z, man->z, shift);
149                  }
150                  break;
151              case (Py2or3String_Type)'d':
152                  mpz_fdiv_q_2exp(upper->z, man->z, shift);
153                  break;
154              case (Py2or3String_Type)'u':
155                  mpz_cdiv_q_2exp(upper->z, man->z, shift);
156                  break;
157              case (Py2or3String_Type)'n':
158              default:
159                  mpz_tdiv_r_2exp(lower->z, man->z, shift);
160                  mpz_tdiv_q_2exp(upper->z, man->z, shift);
161                  if (mpz_sgn(lower->z)) {
162                      /* lower is not 0 so it must have at least 1 bit set */
163                      if (mpz_sizeinbase(lower->z, 2) == shift) {
164                          /* lower is >= 1/2 */
165                          if (mpz_scan1(lower->z, 0) == shift-1) {
166                              /* lower is exactly 1/2 */
167                              if (mpz_odd_p(upper->z))
168                                  carry = 1;
169                          }
170                          else {
171                              carry = 1;
172                          }
173                      }
174                  }
175                  if (carry)
176                      mpz_add_ui(upper->z, upper->z, 1);
177          }
178  
179          if (!(tmp = PyIntOrLong_FromMpBitCnt(shift))) {
180              Py_DECREF((PyObject*)upper);
181              Py_DECREF((PyObject*)lower);
182              return NULL;
183          }
184  
185          if (!(newexp = PyNumber_Add(exp, tmp))) {
186              Py_DECREF((PyObject*)upper);
187              Py_DECREF((PyObject*)lower);
188              Py_DECREF(tmp);
189              return NULL;
190          }
191          Py_DECREF(tmp);
192          bc = prec;
193      }
194      else {
195          mpz_set(upper->z, man->z);
196          newexp = exp;
197          Py_INCREF(newexp);
198      }
199  
200      /* Strip trailing 0 bits. */
201      if ((zbits = mpz_scan1(upper->z, 0)))
202          mpz_tdiv_q_2exp(upper->z, upper->z, zbits);
203  
204      if (!(tmp = PyIntOrLong_FromMpBitCnt(zbits))) {
205          Py_DECREF((PyObject*)upper);
206          Py_DECREF((PyObject*)lower);
207          Py_DECREF(newexp);
208          return NULL;
209      }
210      if (!(newexp2 = PyNumber_Add(newexp, tmp))) {
211          Py_DECREF((PyObject*)upper);
212          Py_DECREF((PyObject*)lower);
213          Py_DECREF(tmp);
214          Py_DECREF(newexp);
215          return NULL;
216      }
217      Py_DECREF(newexp);
218      Py_DECREF(tmp);
219  
220      bc -= zbits;
221      /* Check if one less than a power of 2 was rounded up. */
222      if (!mpz_cmp_ui(upper->z, 1))
223          bc = 1;
224  
225      Py_DECREF((PyObject*)lower);
226      return mpmath_build_mpf(sign, upper, newexp2, bc);
227  }
228  
229  PyDoc_STRVAR(doc_mpmath_createg,
230  "_mpmath_create(...): helper function for mpmath.");
231  
232  static PyObject *
233  Pympz_mpmath_create(PyObject *self, PyObject *args)
234  {
235      long sign;
236      mp_bitcnt_t zbits, bc = 0, prec = 0, shift;
237      long carry = 0;
238      PyObject *exp = 0, *newexp = 0, *newexp2 = 0, *tmp = 0;
239      MPZ_Object *man = 0, *upper = 0, *lower = 0;
240      int error;
241  
242      Py2or3String_Type rnd = (Py2or3String_Type)'f';
243  
244      if (PyTuple_GET_SIZE(args) < 2) {
245          TYPE_ERROR("mpmath_create() expects 'mpz','int'[,'int','str'] arguments");
246          return NULL;
247      }
248  
249      switch (PyTuple_GET_SIZE(args)) {
250          case 4:
251              rnd = Py2or3String_1Char(PyTuple_GET_ITEM(args, 3));
252          case 3:
253              prec = GMPy_Integer_AsMpBitCntAndError(PyTuple_GET_ITEM(args, 2), &error);
254              if (error)
255                  return NULL;
256          case 2:
257              exp = PyTuple_GET_ITEM(args, 1);
258          case 1:
259              man = GMPy_MPZ_From_Integer(PyTuple_GET_ITEM(args, 0), NULL);
260              if (!man) {
261                  TYPE_ERROR("mpmath_create() expects 'mpz','int'[,'int','str'] arguments");
262                  return NULL;
263              }
264      }
265  
266      /* If the mantissa is 0, return the normalized representation. */
267      if (!mpz_sgn(man->z)) {
268          return mpmath_build_mpf(0, man, 0, 0);
269      }
270  
271      upper = GMPy_MPZ_New(NULL);
272      lower = GMPy_MPZ_New(NULL);
273      if (!upper || !lower) {
274          Py_DECREF((PyObject*)man);
275          Py_XDECREF((PyObject*)upper);
276          Py_XDECREF((PyObject*)lower);
277          return NULL;
278      }
279  
280      /* Extract sign, make man positive, and set bit count */
281  
282      sign = (mpz_sgn(man->z) == -1);
283      mpz_abs(upper->z, man->z);
284      bc = mpz_sizeinbase(upper->z, 2);
285  
286      if (!prec) {
287          prec = bc;
288      }
289  
290      if (bc > prec) {
291          shift = bc - prec;
292          switch (rnd) {
293              case (Py2or3String_Type)'f':
294                  if (sign) {
295                      mpz_cdiv_q_2exp(upper->z, upper->z, shift);
296                  }
297                  else {
298                      mpz_fdiv_q_2exp(upper->z, upper->z, shift);
299                  }
300                  break;
301              case (Py2or3String_Type)'c':
302                  if (sign) {
303                      mpz_fdiv_q_2exp(upper->z, upper->z, shift);
304                  }
305                  else {
306                      mpz_cdiv_q_2exp(upper->z, upper->z, shift);
307                  }
308                  break;
309              case (Py2or3String_Type)'d':
310                  mpz_fdiv_q_2exp(upper->z, upper->z, shift);
311                  break;
312              case (Py2or3String_Type)'u':
313                  mpz_cdiv_q_2exp(upper->z, upper->z, shift);
314                  break;
315              case (Py2or3String_Type)'n':
316              default:
317                  mpz_tdiv_r_2exp(lower->z, upper->z, shift);
318                  mpz_tdiv_q_2exp(upper->z, upper->z, shift);
319                  if (mpz_sgn(lower->z)) {
320                      /* lower is not 0 so it must have at least 1 bit set */
321                      if (mpz_sizeinbase(lower->z, 2)==shift) {
322                          /* lower is >= 1/2 */
323                          if (mpz_scan1(lower->z, 0)==shift-1) {
324                              /* lower is exactly 1/2 */
325                              if (mpz_odd_p(upper->z))
326                                  carry = 1;
327                          }
328                          else {
329                              carry = 1;
330                          }
331                      }
332                  }
333                  if (carry) {
334                      mpz_add_ui(upper->z, upper->z, 1);
335                  }
336          }
337          if (!(tmp = PyIntOrLong_FromMpBitCnt(shift))) {
338              Py_DECREF((PyObject*)upper);
339              Py_DECREF((PyObject*)lower);
340              return NULL;
341          }
342          if (!(newexp = PyNumber_Add(exp, tmp))) {
343              Py_DECREF((PyObject*)man);
344              Py_DECREF((PyObject*)upper);
345              Py_DECREF((PyObject*)lower);
346              Py_DECREF(tmp);
347              return NULL;
348          }
349          Py_DECREF(tmp);
350          bc = prec;
351      }
352      else {
353          newexp = exp;
354          Py_INCREF(newexp);
355      }
356  
357      /* Strip trailing 0 bits. */
358      if ((zbits = mpz_scan1(upper->z, 0)))
359          mpz_tdiv_q_2exp(upper->z, upper->z, zbits);
360  
361      if (!(tmp = PyIntOrLong_FromMpBitCnt(zbits))) {
362          Py_DECREF((PyObject*)man);
363          Py_DECREF((PyObject*)upper);
364          Py_DECREF((PyObject*)lower);
365          Py_DECREF(newexp);
366          return NULL;
367      }
368      if (!(newexp2 = PyNumber_Add(newexp, tmp))) {
369          Py_DECREF((PyObject*)man);
370          Py_DECREF((PyObject*)upper);
371          Py_DECREF((PyObject*)lower);
372          Py_DECREF(tmp);
373          Py_DECREF(newexp);
374          return NULL;
375      }
376      Py_DECREF(newexp);
377      Py_DECREF(tmp);
378  
379      bc -= zbits;
380      /* Check if one less than a power of 2 was rounded up. */
381      if (!mpz_cmp_ui(upper->z, 1))
382          bc = 1;
383  
384      Py_DECREF((PyObject*)lower);
385      Py_DECREF((PyObject*)man);
386      return mpmath_build_mpf(sign, upper, newexp2, bc);
387  }