
#include "LLL.h"
#include "xdouble.h"


static xdouble InnerProduct(xdouble *a, xdouble *b, long n)
// Returns a[1]*b[1] + ... + a[n]*b[n].
// Both arrays are indexed from 1; entry 0 is never read.
// The sum is accumulated with Kahan (compensated) summation.
{
   xdouble sum = 0;    // running total
   xdouble comp = 0;   // running compensation for lost low-order bits
   xdouble term, next;
   long idx;

   for (idx = 1; idx <= n; idx++) {
      term = a[idx]*b[idx] - comp;
      next = sum + term;
      // (next - sum) is what was really added; the difference from
      // term is the rounding error carried into the next iteration
      comp = (next - sum) - term;
      sum = next;
   }

   return sum;
}

static void RowTransform(vector(ZZ)& A, vector(ZZ)& B, const ZZ& MU1)
// A = A - B*MU1, entry by entry (entries are indexed 1..A.length()).
// Does nothing when MU1 is zero.
{
   // NOTE(review): these are static, presumably so their storage is
   // reused across calls; this makes the routine non-reentrant.
   static ZZ prod, mult;

   mult = MU1;
   if (mult == 0) return;

   // When the lowest digit of the multiplier is zero, MakeOdd is applied
   // and its return value is used below as a left-shift count on each
   // product (presumably: strip trailing zero bits, then shift them back).
   long shift = 0;
   if (digit(mult, 0) == 0)
      shift = MakeOdd(mult);

   const long len = A.length();
   long idx;

   if (mult.size() > 1) {
      // multiplier does not fit the single-digit fast path: ZZ * ZZ
      for (idx = 1; idx <= len; idx++) {
         mul(prod, B(idx), mult);
         if (shift > 0) LeftShift(prod, prod, shift);
         sub(A(idx), A(idx), prod);
      }
      return;
   }

   // multiplier has size() <= 1: convert it to a long and use ZZ * long
   long mult_l;
   mult_l << mult;

   for (idx = 1; idx <= len; idx++) {
      mul(prod, B(idx), mult_l);
      if (shift > 0) LeftShift(prod, prod, shift);
      sub(A(idx), A(idx), prod);
   }
}

static void RowTransform2(vector(ZZ)& A, vector(ZZ)& B, const ZZ& MU1)
// A = A + B*MU1, entry by entry (entries are indexed 1..A.length()).
// Does nothing when MU1 is zero.
{
   // NOTE(review): these are static, presumably so their storage is
   // reused across calls; this makes the routine non-reentrant.
   static ZZ prod, mult;

   mult = MU1;
   if (mult == 0) return;

   // When the lowest digit of the multiplier is zero, MakeOdd is applied
   // and its return value is used below as a left-shift count on each
   // product (presumably: strip trailing zero bits, then shift them back).
   long shift = 0;
   if (digit(mult, 0) == 0)
      shift = MakeOdd(mult);

   const long len = A.length();
   long idx;

   if (mult.size() > 1) {
      // multiplier does not fit the single-digit fast path: ZZ * ZZ
      for (idx = 1; idx <= len; idx++) {
         mul(prod, B(idx), mult);
         if (shift > 0) LeftShift(prod, prod, shift);
         add(A(idx), A(idx), prod);
      }
      return;
   }

   // multiplier has size() <= 1: convert it to a long and use ZZ * long
   long mult_l;
   mult_l << mult;

   for (idx = 1; idx <= len; idx++) {
      mul(prod, B(idx), mult_l);
      if (shift > 0) LeftShift(prod, prod, shift);
      add(A(idx), A(idx), prod);
   }
}


static
void ComputeGS(matrix(ZZ)& B, xdouble **B1, xdouble **mu, xdouble *b, 
               xdouble *c, long k, xdouble bound)
// Fills in row k of the Gram-Schmidt data: mu[k][1..k-1] and c[k].
// Rows 1..k-1 of mu and entries 1..k-1 of c are read and must already
// be valid.  B1[i] is the xdouble approximation of row B(i), and b[i]
// its squared length.  When the approximate inner product of rows k and
// j is small relative to b[k]*b[j]/bound, it is recomputed exactly from
// the integer rows.
{
   const long n = B.NumCols();
   long i, j;
   xdouble dot, acc, comp, term, next;
   ZZ exact;

   for (j = 1; j < k; j++) {
      dot = InnerProduct(B1[k], B1[j], n);

      if (dot*dot <= b[k]*b[j]/bound) {
         // too much cancellation: fall back to the exact integer product
         InnerProduct(exact, B(k), B(j));
         dot << exact;
      }

      // acc = sum_{i<j} mu[j][i]*mu[k][i]*c[i], via Kahan summation
      acc = comp = 0;
      for (i = 1; i < j; i++) {
         term = mu[j][i]*mu[k][i]*c[i] - comp;
         next = acc + term;
         comp = (next - acc) - term;
         acc = next;
      }
 
      mu[k][j] = (dot - acc)/c[j];
   }

   // dot = sum_{j<k} mu[k][j]^2 * c[j], via Kahan summation
   dot = comp = 0;
   for (j = 1; j < k; j++) {
      term = mu[k][j]*mu[k][j]*c[j] - comp;
      next = dot + term;
      comp = (next - dot) - term;
      dot = next;
   }

   c[k] = b[k] - dot;
}

static
long ll_LLL_XD(matrix(ZZ)& B, matrix(ZZ)* U, xdouble delta, long deep, 
           LLLCheckFct check, xdouble **B1, xdouble **mu, xdouble *b, xdouble *c,
           long m, long init_k, long &quit)
// Low-level LLL loop working on rows init_k..m of B (all arrays 1-based).
//
//   B       integer basis, modified in place
//   U       if non-null, receives the same row operations as B
//   delta   parameter of the swap test below
//   deep    if > 0, deep insertions are tried; an insertion at position l
//           is accepted only when l <= deep or k-l <= deep
//   check   if non-null, called on B(k) after each size reduction; a
//           nonzero result sets quit and stops the loop
//   B1      xdouble approximations of the rows of B (kept in sync here)
//   mu, c   Gram-Schmidt coefficients / squared lengths (see ComputeGS);
//           entries for rows < init_k are read, so presumably the caller
//           has them valid -- TODO confirm against callers
//   b       squared lengths of the rows of B1
//   quit    output: set to 0 on entry, 1 if check() fired
//
// Every row k found with b[k] == 0 is rotated to position m and m is
// decremented.  Returns the final value of m.
// NOTE(review): local T1 appears to be unused in this function.
{
   long n = B.NumCols();

   long i, j, k, Fc, Fc1;
   ZZ MU;
   xdouble mu1;

   xdouble t1;
   ZZ T1;
   xdouble *tp;


   // bound = 2^(2*long(0.15*ZZ_DOUBLE_PRECISION)); computed once and
   // cached in a function-local static (passed on to ComputeGS)
   static xdouble bound = 0;



   if (bound == 0) {
      // we tolerate a 15% loss of precision in computing
      // inner products in ComputeGS.

      bound = 1;
      for (i = 2*long(0.15*ZZ_DOUBLE_PRECISION); i > 0; i--) {
         bound = bound * 2;
      }
   }


   // bound1 = 2^(long(0.15*ZZ_DOUBLE_PRECISION)); a rounded multiplier
   // larger than this in magnitude forces another size-reduction pass
   static xdouble bound1 = 0;

   if (bound1 == 0) {
      bound1 = 1;
      // for (i = ZZ_DOUBLE_PRECISION/2; i > 0; i--)
      for (i = long(0.15*ZZ_DOUBLE_PRECISION); i > 0; i--)
         bound1 = bound1 * 2;
   }


   quit = 0;
   k = init_k;

   while (k <= m) {


      ComputeGS(B, B1, mu, b, c, k, bound);


      do {
         // size reduction

         // Fc1: some multiple was subtracted from row k in this pass
         // Fc:  some multiplier was large (> bound1), so repeat the pass
         Fc = Fc1 = 0;
   
         for (j = k-1; j >= 1; j--) {
            t1 = fabs(mu[k][j]);
            if (t1 > 0.5) {
               Fc1 = 1;
   
               // round mu[k][j] to the nearest integer
               mu1 = mu[k][j];
               if (mu1 >= 0)
                  mu1 = ceil(mu1-0.5);
               else
                  mu1 = floor(mu1+0.5);
   
               if (fabs(mu1) > bound1)
                  Fc = 1;
   
               // keep mu[k][1..j] consistent with row k -= mu1 * row j
               for (i = 1; i <= j-1; i++)
                  mu[k][i] -= mu1*mu[j][i];
   
               mu[k][j] -= mu1;
   
               MU << mu1;
   
               RowTransform(B(k), B(j), MU);
               if (U) RowTransform((*U)(k), (*U)(j), MU);
            }
         }

         if (Fc1) {
            // row k changed: refresh its approximation, its squared
            // length, and its Gram-Schmidt data from the integer row
            for (i = 1; i <= n; i++)
               B1[k][i] << B(k, i);
   
            b[k] = InnerProduct(B1[k], B1[k], n);
            ComputeGS(B, B1, mu, b, c, k, bound);
         }
      } while (Fc);

      if (check && (*check)(B(k))) 
         quit = 1;

      if (b[k] == 0) {
         // row k has squared length 0: rotate it down to position m and
         // shrink the working range; k is unchanged, so the row that
         // moved into position k is processed next
         for (i = k; i < m; i++) {
            // swap i, i+1
            swap(B(i), B(i+1));
            tp = B1[i]; B1[i] = B1[i+1]; B1[i+1] = tp;
            t1 = b[i]; b[i] = b[i+1]; b[i+1] = t1;
            if (U) swap((*U)(i), (*U)(i+1));
         }

         m--;
         if (quit) break;
         continue;
      }

      if (quit) break;

      if (deep > 0) {
         // deep insertions
   
         // cc starts as b[k] and has mu[k][l]^2 * c[l] removed for each
         // position l passed; stop at the first l with delta*c[l] > cc
         xdouble cc = b[k];
         long l = 1;
         while (l <= k-1 && delta*c[l] <= cc) {
            cc = cc - mu[k][l]*mu[k][l]*c[l];
            l++;
         }
   
         if (l <= k-1 && (l <= deep || k-l <= deep)) {
            // deep insertion at position l
   
            for (i = k; i > l; i--) {
               // swap rows i, i-1
               swap(B(i), B(i-1));
               tp = B1[i]; B1[i] = B1[i-1]; B1[i-1] = tp;
               t1 = b[i]; b[i] = b[i-1]; b[i-1] = t1;
               if (U) swap((*U)(i), (*U)(i-1));
            }
   
            k = l;
            continue;
         }
      } // end deep insertions

      // test LLL reduction condition

      if (k > 1 && delta*c[k-1] > c[k] + mu[k][k-1]*mu[k][k-1]*c[k-1]) {
         // swap rows k, k-1
         swap(B(k), B(k-1));
         tp = B1[k]; B1[k] = B1[k-1]; B1[k-1] = tp;
         t1 = b[k]; b[k] = b[k-1]; b[k-1] = t1;
         if (U) swap((*U)(k), (*U)(k-1));

         k--;
      }
      else
         k++;
   }

   return m;
}

static
long LLL_XD(matrix(ZZ)& B, matrix(ZZ)* U, xdouble delta, long deep, 
           LLLCheckFct check)
// Allocates the xdouble work arrays, runs ll_LLL_XD over all rows of B,
// and frees the work arrays again.
//
//   B       basis (rows), reduced in place
//   U       if non-null, set to the identity and then given the same
//           row operations as B
//   delta, deep, check   passed through to ll_LLL_XD
//
// Returns the row count reported by ll_LLL_XD (rows it found to have
// squared length 0 are not counted).  Those rows are moved to the front
// of B (and of *U), followed by the remaining rows in order.
{
   long m = B.NumRows();
   long n = B.NumCols();

   // Number of rows allocated below.  m itself is overwritten with the
   // reduced row count, so clean-up must use this value instead.
   const long m_orig = m;

   long i, j;
   long new_m, dep, quit;

   if (U) ident(*U, m);

   xdouble **B1;  // approximates B

   typedef xdouble *xdoubleptr;

   B1 = new xdoubleptr[m+1];
   if (!B1) Error("LLL_XD: out of memory");

   for (i = 1; i <= m; i++) {
      B1[i] = new xdouble[n+1];
      if (!B1[i]) Error("LLL_XD: out of memory");
   }

   xdouble **mu;
   mu = new xdoubleptr[m+1];
   if (!mu) Error("LLL_XD: out of memory");

   for (i = 1; i <= m; i++) {
      mu[i] = new xdouble[m+1];
      if (!mu[i]) Error("LLL_XD: out of memory");
   }

   xdouble *c; // squared lengths of Gram-Schmidt basis vectors

   c = new xdouble[m+1];
   if (!c) Error("LLL_XD: out of memory");

   xdouble *b; // squared lengths of basis vectors

   b = new xdouble[m+1];
   if (!b) Error("LLL_XD: out of memory");

   // initialise the approximations and their squared lengths

   for (i = 1; i <= m; i++)
      for (j = 1; j <= n; j++) 
         B1[i][j] << B(i, j);

   for (i = 1; i <= m; i++) {
      b[i] = InnerProduct(B1[i], B1[i], n);
   }


   new_m = ll_LLL_XD(B, U, delta, deep, check, B1, mu, b, c, m, 1, quit);
   dep = m - new_m;
   m = new_m;

   if (dep > 0) {
      // for consistency, we move all of the zero rows to the front

      for (i = 0; i < m; i++) {
         swap(B(m+dep-i), B(m-i));
         if (U) swap((*U)(m+dep-i), (*U)(m-i));
      }
   }


   // clean-up: free every row that was allocated (m_orig of them), not
   // just the first m -- otherwise dep rows of B1 and mu are leaked

   for (i = 1; i <= m_orig; i++) {
      delete [] B1[i];
   }

   delete [] B1;

   for (i = 1; i <= m_orig; i++) {
      delete [] mu[i];
   }

   delete [] mu;

   delete [] c;

   delete [] b;

   return m;
}

         

long LLL_XD(matrix(ZZ)& B, double delta, long deep, LLLCheckFct check)
// Public entry point: LLL-reduce B in place without recording a
// transformation matrix.  Reports an error via Error() unless
// 0.25 < delta <= 1 and deep >= 0.  Returns the result of the internal
// LLL_XD driver.
{
   if (delta > 1 || delta <= 0.25) {
      Error("LLL_XD: bad delta");
   }

   if (deep < 0) {
      Error("LLL_XD: bad deep");
   }

   matrix(ZZ) *no_transform = 0;
   xdouble xdelta(delta);

   return LLL_XD(B, no_transform, xdelta, deep, check);
}

long LLL_XD(matrix(ZZ)& B, matrix(ZZ)& U, double delta, long deep, 
           LLLCheckFct check)
// Public entry point: LLL-reduce B in place and record the row
// operations in U.  Reports an error via Error() unless
// 0.25 < delta <= 1 and deep >= 0.  Returns the result of the internal
// LLL_XD driver.
{
   if (delta > 1 || delta <= 0.25) {
      Error("LLL_XD: bad delta");
   }

   if (deep < 0) {
      Error("LLL_XD: bad deep");
   }

   matrix(ZZ) *transform = &U;
   xdouble xdelta(delta);

   return LLL_XD(B, transform, xdelta, deep, check);
}



static
long BKZ_XD(matrix(ZZ)& BB, matrix(ZZ)* UU, xdouble delta, 
         long beta, long prune, LLLCheckFct check)
// Block reduction driver built on top of ll_LLL_XD.
//
//   BB      basis (rows); replaced by the reduced basis on return
//   UU      if non-null, receives the accumulated row transformation
//   delta   reduction parameter, passed to ll_LLL_XD and used in the
//           "delta*c[jj] > cbar" improvement test
//   beta    block size (clamped to the row count after the first LLL)
//   prune   if nonzero, the enumeration bound is scaled by alpha <= 1
//   check   optional early-exit predicate on basis rows
//
// Works on a private copy B of BB with one extra row (index m+1), which
// is used as scratch space for a newly found vector.  Outline:
//   1. LLL-reduce the whole basis.
//   2. For each block [jj, kk], enumerate integer combinations of the
//      block rows (the ENUM section) looking for one whose value cbar is
//      below delta*c[jj]; if one is found, insert it at position jj and
//      re-run LLL, otherwise just run LLL on the next row.
//   3. Stop once z (blocks in a row without such an insertion) reaches
//      m-1, or as soon as check() fires.
//
// Returns m, the row count left after the first LLL pass; rows found to
// have squared length 0 are moved to the front of the result.
{
   long m = BB.NumRows();
   long n = BB.NumCols();
   long m_orig = m;   // m is reduced below; allocations use this value
   
   long i, j;
   ZZ MU;

   xdouble t1;
   xdouble *tp;

   matrix(ZZ) B;
   B = BB;

   B.SetDims(m+1, n);


   xdouble **B1;  // approximates B

   typedef xdouble *xdoubleptr;

   B1 = new xdoubleptr[m+2];
   if (!B1) Error("BKZ_XD: out of memory");

   for (i = 1; i <= m+1; i++) {
      B1[i] = new xdouble[n+1];
      if (!B1[i]) Error("BKZ_XD: out of memory");
   }

   xdouble **mu;
   mu = new xdoubleptr[m+2];
   if (!mu) Error("BKZ_XD: out of memory");

   for (i = 1; i <= m+1; i++) {
      mu[i] = new xdouble[m+1];
      if (!mu[i]) Error("BKZ_XD: out of memory");
   }

   xdouble *c; // squared lengths of Gram-Schmidt basis vectors

   c = new xdouble[m+2];
   if (!c) Error("BKZ_XD: out of memory");

   xdouble *b; // squared lengths of basis vectors

   b = new xdouble[m+2];
   if (!b) Error("BKZ_XD: out of memory");

   // Enumeration state.  NOTE(review): the variable names and update
   // rules look like the Schnorr-Euchner ENUM procedure -- confirm.

   xdouble cbar;        // best value found so far for the current block

   xdouble *ctilda;     // partial sums for the current coefficient vector
   ctilda = new xdouble[m+2];
   if (!ctilda) Error("BKZ_XD: out of memory");

   xdouble *vvec;       // rounded centre at each level
   vvec = new xdouble[m+2];
   if (!vvec) Error("BKZ_XD: out of memory");

   xdouble *yvec;       // unrounded centre at each level
   yvec = new xdouble[m+2];
   if (!yvec) Error("BKZ_XD: out of memory");

   xdouble *uvec;       // coefficients of the best vector found
   uvec = new xdouble[m+2];
   if (!uvec) Error("BKZ_XD: out of memory");

   xdouble *utildavec;  // coefficients currently being tried
   utildavec = new xdouble[m+2];
   if (!utildavec) Error("BKZ_XD: out of memory");


   long *Deltavec;      // current offset from vvec at each level
   Deltavec = new long[m+2];
   if (!Deltavec) Error("BKZ_XD: out of memory");

   long *deltavec;      // step direction (+1 / -1) at each level
   deltavec = new long[m+2];
   if (!deltavec) Error("BKZ_XD: out of memory");

   matrix(ZZ) Ulocal;
   matrix(ZZ) *U;

   if (UU) {
      // identity in the first m rows, plus one scratch row
      Ulocal.SetDims(m+1, m);
      for (i = 1; i <= m; i++)
         Ulocal(i, i) << 1;
      U = &Ulocal;
   }
   else
      U = 0;

   long quit;
   long new_m;
   long z, jj, kk;
   long s, t;
   long h;
   xdouble alpha;


   for (i = 1; i <=m; i++)
      for (j = 1; j <= n; j++) 
         B1[i][j] << B(i, j);

         
   for (i = 1; i <= m; i++) {
      b[i] = InnerProduct(B1[i], B1[i], n);
   }

   // cerr << "\n";
   // cerr << "first LLL\n";

   m = ll_LLL_XD(B, U, delta, 0, check, B1, mu, b, c, m, 1, quit);

   if (m < m_orig) {
      // bring the scratch row from position m_orig+1 up to position m+1
      for (i = m_orig+1; i >= m+2; i--) {
         // swap i, i-1

         swap(B(i), B(i-1));
         if (U) swap((*U)(i), (*U)(i-1));
      }
   }

   if (!quit && m > 1) {
      // cerr << "continuing\n";
      if (beta > m) beta = m;

      z = 0;    // number of consecutive blocks without an insertion
      jj = 0;   // first row of the current block
   
      while (z < m-1) {
         jj++;
         kk = min(jj+beta-1, m);   // last row of the current block
   
         if (jj == m) {
            // wrap around to the first block
            jj = 1;
            kk = beta;
         }
   
         // ENUM
   
         cbar = c[jj];
         utildavec[jj] = uvec[jj] = 1;
   
         yvec[jj] = vvec[jj] = 0;
         Deltavec[jj] = 0;
   
   
         s = t = jj;
         deltavec[jj] = 1;
   
         for (i = jj+1; i <= kk+1; i++) {
            ctilda[i] = uvec[i] = utildavec[i] = yvec[i] = 0;
            Deltavec[i] = 0;
            vvec[i] = 0;
            deltavec[i] = 1;
         }
   
         while (t <= kk) {
            ctilda[t] = ctilda[t+1] + 
               (yvec[t]+utildavec[t])*(yvec[t]+utildavec[t])*c[t];
   
            if (prune) {
               // kk > jj here (beta >= 2 is enforced by the public
               // wrappers and jj < m), so the divisor is nonzero
               alpha = 1.05*xdouble(kk-t+1)/xdouble(kk-jj);
               if (alpha > 1) alpha = 1;
            }
            else
               alpha = 1;


            if (ctilda[t] < alpha*cbar) {
               if (t > jj) {
                  // descend one level and centre the next coefficient
                  t--;
                  t1 = 0;
                  for (i = t+1; i <= s; i++) {
                     t1 += utildavec[i]*mu[i][t];
                  }


                  yvec[t] = t1;
                  t1 = -t1;
                  if (t1 >= 0)
                     t1 = ceil(t1-0.5);
                  else
                     t1 = floor(t1+0.5);

                  utildavec[t] = vvec[t] = t1;
                  Deltavec[t] = 0;
                  if (utildavec[t] > -yvec[t]) 
                     deltavec[t] = -1;
                  else
                     deltavec[t] = 1;
               }
               else {
                  // reached level jj with a better value: record it
                  cbar = ctilda[jj];
                  for (i = jj; i <= kk; i++) {
                     uvec[i] = utildavec[i];
                  }
               }
            }
            else {
               // bound exceeded: go up a level and step to the next
               // candidate coefficient there
               t++;
               s = max(s, t);
               if (t < s) Deltavec[t] = -Deltavec[t];
               if (Deltavec[t]*deltavec[t] >= 0) Deltavec[t] += deltavec[t];
               utildavec[t] = vvec[t] + Deltavec[t];
            }
         }
         
   
         h = min(kk+1, m);
   
         if (delta*c[jj] > cbar) {
            // we treat the case that the new vector is b_s (jj < s <= kk)
            // as a special case that appears to occur most of the time.
   
            // s == 0: no nonzero coefficient among jj+1..kk
            // s  > 0: exactly one, at index s
            // s == -1: more than one
            s = 0;
            for (i = jj+1; i <= kk; i++) {
               if (uvec[i] != 0) {
                  if (s == 0)
                     s = i;
                  else
                     s = -1;
               }
            }
   
            if (s == 0) Error("BKZ_XD: internal error");
   
            if (s > 0) {
               // special case
   
               // rotate row s up to position jj
               for (i = s; i > jj; i--) {
                  // swap i, i-1
                  swap(B(i-1), B(i));
                  if (U) swap((*U)(i-1), (*U)(i));
                  tp = B1[i-1]; B1[i-1] = B1[i]; B1[i] = tp;
                  t1 = b[i-1]; b[i-1] = b[i]; b[i] = t1;
               }
   
               // cerr << "special case\n";
               new_m = ll_LLL_XD(B, U, delta, 0, check, 
                                B1, mu, b, c, h, jj, quit);
               if (new_m != h) Error("BKZ_XD: internal error");
               if (quit) break;
            }
            else {
               // the general case
   
               // build the new vector sum_i uvec[i]*B(i) in scratch
               // row m+1 (and the matching row of U)
               for (i = 1; i <= n; i++) B(m+1, i) << 0;

               if (U) {
                  for (i = 1; i <= m_orig; i++)
                     (*U)(m+1, i) << 0;
               }

               for (i = jj; i <= kk; i++) {
                  if (uvec[i] == 0) continue;
                  MU << uvec[i];
                  RowTransform2(B(m+1), B(i), MU);
                  if (U) RowTransform2((*U)(m+1), (*U)(i), MU);
               }
      
               // rotate the new vector up to position jj
               for (i = m+1; i >= jj+1; i--) {
                  // swap i, i-1
                  swap(B(i-1), B(i));
                  if (U) swap((*U)(i-1), (*U)(i));
                  tp = B1[i-1]; B1[i-1] = B1[i]; B1[i] = tp;
                  t1 = b[i-1]; b[i-1] = b[i]; b[i] = t1;
               }
      
               for (i = 1; i <= n; i++)
                  B1[jj][i] << B(jj, i);
      
               b[jj] = InnerProduct(B1[jj], B1[jj], n);
      
               if (b[jj] == 0) Error("BKZ_XD: internal error"); 
      
               // remove linear dependencies
   
               // cerr << "general case\n";
               new_m = ll_LLL_XD(B, U, delta, 0, 0, B1, mu, b, c, kk+1, jj, quit);
              
               if (new_m != kk) Error("BKZ_XD: internal error"); 

               // remove zero vector
      
               // (rotate it from position kk+1 back to scratch row m+1)
               for (i = kk+2; i <= m+1; i++) {
                  // swap i, i-1
                  swap(B(i-1), B(i));
                  if (U) swap((*U)(i-1), (*U)(i));
                  tp = B1[i-1]; B1[i-1] = B1[i]; B1[i] = tp;
                  t1 = b[i-1]; b[i-1] = b[i]; b[i] = t1;
               }
      
               // the LLL call above ran without check, so apply it now
               quit = 0;
               if (check) {
                  for (i = 1; i <= kk; i++)
                     if ((*check)(B(i))) {
                        quit = 1;
                        break;
                     }
               }

               if (quit) break;
   
               if (h > kk) {
                  // extend reduced basis
   
                  new_m = ll_LLL_XD(B, U, delta, 0, check, 
                                   B1, mu, b, c, h, h, quit);
   
                  if (new_m != h) Error("BKZ_XD: internal error");
                  if (quit) break;
               }
            }
   
            z = 0;
         }
         else {
            // LLL_XD
            // cerr << "progress\n";
            new_m = ll_LLL_XD(B, U, delta, 0, check, B1, mu, b, c, h, h, quit);
   
   
            if (new_m != h) Error("BKZ_XD: internal error");
            if (quit) break;
   
            z++;
         }
      }
   }

   // clean up

   if (m_orig > m) {
      // for consistency, we move zero vectors to the front

      // first return the scratch row from position m+1 to m_orig+1
      for (i = m+1; i <= m_orig; i++) {
         swap(B(i), B(i+1));
         if (U) swap((*U)(i), (*U)(i+1));
      }

      for (i = 0; i < m; i++) {
         swap(B(m_orig-i), B(m-i));
         if (U) swap((*U)(m_orig-i), (*U)(m-i));
      }
   }

   B.SetDims(m_orig, n);
   BB = B;

   if (U) {
      U->SetDims(m_orig, m_orig);
      *UU = *U;
   }

   // m_orig+1 rows of B1 and mu were allocated above; m may be smaller
   // than m_orig by now, so free by m_orig to avoid leaking rows

   for (i = 1; i <= m_orig+1; i++) {
      delete [] B1[i];
   }

   delete [] B1;

   for (i = 1; i <= m_orig+1; i++) {
      delete [] mu[i];
   }

   delete [] mu;

   delete [] c;
   delete [] b;
   delete [] ctilda;
   delete [] vvec;
   delete [] yvec;
   delete [] uvec;
   delete [] utildavec;
   delete [] Deltavec;
   delete [] deltavec;

   return m;
}

long BKZ_XD(matrix(ZZ)& BB, matrix(ZZ)& UU, double delta, 
         long beta, long prune, LLLCheckFct check)
// Public entry point: block-reduce BB in place and store the
// transformation in UU.  Reports an error via Error() unless
// 0.25 < delta <= 1 and beta >= 2.  Returns the result of the internal
// BKZ_XD driver.
{
   if (delta > 1 || delta <= 0.25) {
      Error("BKZ_XD: bad delta");
   }

   if (beta < 2) {
      Error("BKZ_XD: bad block size");
   }

   matrix(ZZ) *transform = &UU;
   xdouble xdelta(delta);

   return BKZ_XD(BB, transform, xdelta, beta, prune, check);
}

long BKZ_XD(matrix(ZZ)& BB, double delta, 
         long beta, long prune, LLLCheckFct check)
// Public entry point: block-reduce BB in place without recording a
// transformation matrix.  Reports an error via Error() unless
// 0.25 < delta <= 1 and beta >= 2.  Returns the result of the internal
// BKZ_XD driver.
{
   if (delta > 1 || delta <= 0.25) {
      Error("BKZ_XD: bad delta");
   }

   if (beta < 2) {
      Error("BKZ_XD: bad block size");
   }

   matrix(ZZ) *no_transform = 0;
   xdouble xdelta(delta);

   return BKZ_XD(BB, no_transform, xdelta, beta, prune, check);
}

