#include "tommath_private.h"
#ifdef S_MP_DIV_RECURSIVE_C
/* LibTomMath, multiple-precision integer library -- Tom St Denis */
/* SPDX-License-Identifier: Unlicense */

/*
   Direct implementation of algorithms 1.8 "RecursiveDivRem" and 1.9 "UnbalancedDivision"
   from:

      Brent, Richard P., and Paul Zimmermann. "Modern computer arithmetic"
      Vol. 18. Cambridge University Press, 2010
      Available online at https://arxiv.org/pdf/1004.4710

   pages 19ff. in the above online document.
*/

/*
   Balanced recursive division, algorithm 1.8 "RecursiveDivRem".

   Computes q and r with a = (q * b) + r. Let m = a->used - b->used: if m is
   below MP_MUL_KARATSUBA_CUTOFF the work is handed to s_mp_div_school(),
   otherwise b is split at k = m/2 limbs into B1 (high part) and B0 (low part)
   and the quotient is assembled from two recursive divisions by B1, each
   followed by a correction step.

   The recursive path writes to q and r unconditionally; the caller in this
   file always passes both, with non-negative operands.

   Returns MP_OKAY on success, otherwise the first error reported by a helper.
*/
static mp_err s_recursion(const mp_int *a, const mp_int *b, mp_int *q, mp_int *r)
{
   mp_err err;
   mp_int A1, A2, B1, B0, Q1, Q0, R1, R0, t;
   int m = a->used - b->used, k = m/2;

   /* small enough: plain schoolbook division is faster */
   if (m < (MP_MUL_KARATSUBA_CUTOFF)) {
      return s_mp_div_school(a, b, q, r);
   }

   /*
      If mp_init_multi() fails none of the temporaries are usable (it has
      already released the ones it managed to set up), so they must not be
      handed to mp_clear_multi(): return directly instead of going to LBL_ERR.
    */
   if ((err = mp_init_multi(&A1, &A2, &B1, &B0, &Q1, &Q0, &R1, &R0, &t, NULL)) != MP_OKAY) {
      return err;
   }

   /* B1 = b / beta^k, B0 = b % beta^k*/
   if ((err = mp_div_2d(b, k * MP_DIGIT_BIT, &B1, &B0)) != MP_OKAY)        goto LBL_ERR;

   /* (Q1, R1) =  RecursiveDivRem(A / beta^(2k), B1); t keeps A % beta^(2k) */
   if ((err = mp_div_2d(a, 2*k * MP_DIGIT_BIT, &A1, &t)) != MP_OKAY)       goto LBL_ERR;
   if ((err = s_recursion(&A1, &B1, &Q1, &R1)) != MP_OKAY)                 goto LBL_ERR;

   /* A1 = (R1 * beta^(2k)) + (A % beta^(2k)) - (Q1 * B0 * beta^k) */
   if ((err = mp_lshd(&R1, 2*k)) != MP_OKAY)                               goto LBL_ERR;
   if ((err = mp_add(&R1, &t, &A1)) != MP_OKAY)                            goto LBL_ERR;
   if ((err = mp_mul(&Q1, &B0, &t)) != MP_OKAY)                            goto LBL_ERR;
   if ((err = mp_lshd(&t, k)) != MP_OKAY)                                  goto LBL_ERR;
   if ((err = mp_sub(&A1, &t, &A1)) != MP_OKAY)                            goto LBL_ERR;

   /* while A1 < 0 do Q1 = Q1 - 1, A1 = A1 + (beta^k * B) */
   if (mp_cmp_d(&A1, 0uL) == MP_LT) {
      /* beta^k * B is loop invariant, compute it once and only if needed */
      if ((err = mp_mul_2d(b, k * MP_DIGIT_BIT, &t)) != MP_OKAY)           goto LBL_ERR;
      do {
         if ((err = mp_decr(&Q1)) != MP_OKAY)                              goto LBL_ERR;
         if ((err = mp_add(&A1, &t, &A1)) != MP_OKAY)                      goto LBL_ERR;
      } while (mp_cmp_d(&A1, 0uL) == MP_LT);
   }
   /* (Q0, R0) =  RecursiveDivRem(A1 / beta^(k), B1); t keeps A1 % beta^k */
   if ((err = mp_div_2d(&A1, k * MP_DIGIT_BIT, &A1, &t)) != MP_OKAY)       goto LBL_ERR;
   if ((err = s_recursion(&A1, &B1, &Q0, &R0)) != MP_OKAY)                 goto LBL_ERR;

   /* A2 = (R0*beta^k) +  (A1 % beta^k) - (Q0*B0) */
   if ((err = mp_lshd(&R0, k)) != MP_OKAY)                                 goto LBL_ERR;
   if ((err = mp_add(&R0, &t, &A2)) != MP_OKAY)                            goto LBL_ERR;
   if ((err = mp_mul(&Q0, &B0, &t)) != MP_OKAY)                            goto LBL_ERR;
   if ((err = mp_sub(&A2, &t, &A2)) != MP_OKAY)                            goto LBL_ERR;

   /* while A2 < 0 do Q0 = Q0 - 1, A2 = A2 + B */
   while (mp_cmp_d(&A2, 0uL) == MP_LT) {
      if ((err = mp_decr(&Q0)) != MP_OKAY)                                 goto LBL_ERR;
      if ((err = mp_add(&A2, b, &A2)) != MP_OKAY)                          goto LBL_ERR;
   }
   /* return q = (Q1*beta^k) + Q0, r = A2 */
   if ((err = mp_lshd(&Q1, k)) != MP_OKAY)                                 goto LBL_ERR;
   if ((err = mp_add(&Q1, &Q0, q)) != MP_OKAY)                             goto LBL_ERR;

   if ((err = mp_copy(&A2, r)) != MP_OKAY)                                 goto LBL_ERR;

LBL_ERR:
   mp_clear_multi(&A1, &A2, &B1, &B0, &Q1, &Q0, &R1, &R0, &t, NULL);
   return err;
}


/*
   Unbalanced division, algorithm 1.9 "UnbalancedDivision".

   Computes the quotient and remainder of a / b. The operands are first
   shifted left by sigma bits so that the most significant limb of the divisor
   has its top bit set; the dividend is then consumed in chunks of n = B.used
   limbs, each chunk being divided with s_recursion().

   a  dividend
   b  divisor; a zero divisor is rejected with MP_VAL
   q  receives the quotient, may be NULL
   r  receives the remainder (carrying the sign of a), may be NULL

   Returns MP_OKAY on success, MP_VAL for a zero divisor, otherwise the first
   error reported by a helper. q and r are left untouched on error.
*/
mp_err s_mp_div_recursive(const mp_int *a, const mp_int *b, mp_int *q, mp_int *r)
{
   int j, m, n, sigma;
   mp_err err;
   bool neg;
   mp_digit msb_b, msb;
   mp_int A, B, Q, Q1, R, A_div, A_mod;

   /*
      A zero divisor has no most significant limb: reading b->dp[b->used - 1]
      would be out of bounds and the normalization loop below would never end.
    */
   if (mp_iszero(b)) {
      return MP_VAL;
   }

   /*
      If mp_init_multi() fails none of the temporaries are usable (it has
      already released the ones it managed to set up), so they must not be
      handed to mp_clear_multi(): return directly instead of going to LBL_ERR.
    */
   if ((err = mp_init_multi(&A, &B, &Q, &Q1, &R, &A_div, &A_mod, NULL)) != MP_OKAY) {
      return err;
   }

   /* most significant bit of a limb */
   /* assumes MP_DIGIT_MAX + 1 is representable in mp_digit, i.e. that a limb
      holds fewer bits than (sizeof(mp_digit) * CHAR_BIT) */
   msb = (MP_DIGIT_MAX + (mp_digit)(1)) >> 1;
   /* sigma = number of left shifts needed to set the top bit of b's top limb */
   sigma = 0;
   msb_b = b->dp[b->used - 1];
   while (msb_b < msb) {
      sigma++;
      msb_b <<= 1;
   }
   /* Use that sigma to normalize B, and scale A by the same amount */
   if ((err = mp_mul_2d(b, sigma, &B)) != MP_OKAY)                               goto LBL_ERR;
   if ((err = mp_mul_2d(a, sigma, &A)) != MP_OKAY)                               goto LBL_ERR;

   /* work on magnitudes, remember the sign of the quotient */
   neg = (a->sign != b->sign);
   A.sign = B.sign = MP_ZPOS;

   /*
      If the magnitude of "A" is not more than twice that of "B" we can work
      on them directly, otherwise we need to work at "A" in chunks
    */
   n = B.used;
   m = A.used - B.used;

   /* Q = 0 */
   mp_zero(&Q);
   while (m > n) {
      /* (q, r) = RecursiveDivRem(A / (beta^(m-n)), B) */
      j = (m - n) * MP_DIGIT_BIT;
      if ((err = mp_div_2d(&A, j, &A_div, &A_mod)) != MP_OKAY)                   goto LBL_ERR;
      if ((err = s_recursion(&A_div, &B, &Q1, &R)) != MP_OKAY)                   goto LBL_ERR;
      /* Q = (Q * beta^n) + q */
      if ((err = mp_mul_2d(&Q, n * MP_DIGIT_BIT, &Q)) != MP_OKAY)                goto LBL_ERR;
      if ((err = mp_add(&Q, &Q1, &Q)) != MP_OKAY)                                goto LBL_ERR;
      /* A = (r * beta^(m-n)) + (A % beta^(m-n))*/
      if ((err = mp_mul_2d(&R, j, &R)) != MP_OKAY)                               goto LBL_ERR;
      if ((err = mp_add(&R, &A_mod, &A)) != MP_OKAY)                             goto LBL_ERR;
      /* m = m - n */
      m = m - n;
   }
   /* (q, r) = RecursiveDivRem(A, B) */
   if ((err = s_recursion(&A, &B, &Q1, &R)) != MP_OKAY)                          goto LBL_ERR;
   /* Q = (Q * beta^m) + q, R = r */
   if ((err = mp_mul_2d(&Q, m * MP_DIGIT_BIT, &Q)) != MP_OKAY)                   goto LBL_ERR;
   if ((err = mp_add(&Q, &Q1, &Q)) != MP_OKAY)                                   goto LBL_ERR;

   /* the remainder takes the sign of the dividend, unless it is zero */
   R.sign = (mp_iszero(&R) ? MP_ZPOS : a->sign);

   /*
      de-normalize the remainder; this is the last step that can fail, so do
      it before any result is handed out
    */
   if (r != NULL) {
      if ((err = mp_div_2d(&R, sigma, &R, NULL)) != MP_OKAY)                     goto LBL_ERR;
   }

   if (q != NULL) {
      mp_exch(&Q, q);
      /* never produce a negative zero */
      q->sign = ((neg && !mp_iszero(q)) ? MP_NEG : MP_ZPOS);
   }
   if (r != NULL) {
      mp_exch(&R, r);
   }
LBL_ERR:
   mp_clear_multi(&A, &B, &Q, &Q1, &R, &A_div, &A_mod, NULL);
   return err;
}

#endif