
#include "LLL.h"

#include "mat_RR.h"




static void RowTransform(vector(ZZ)& A, vector(ZZ)& B, const ZZ& MU1)
// Row operation A := A - MU1*B, applied to entries 1..A.length().
//
// If the lowest digit of the multiplier is zero, its factors of two are
// removed with MakeOdd and re-applied to every product with LeftShift.
// A multiplier that fits in a single digit is converted to a long so the
// ZZ-by-long form of mul is used.
{
   static ZZ prod, mult;   // scratch values, reused across calls
   long shift, len, idx;

   mult = MU1;
   if (mult == 0) return;

   shift = 0;
   if (digit(mult, 0) == 0)
      shift = MakeOdd(mult);

   len = A.length();

   if (mult.size() > 1) {
      // multi-digit multiplier: ZZ * ZZ products
      for (idx = 1; idx <= len; idx++) {
         mul(prod, B(idx), mult);
         if (shift > 0) LeftShift(prod, prod, shift);
         sub(A(idx), A(idx), prod);
      }
      return;
   }

   // single-digit multiplier: ZZ * long products
   long word;
   word << mult;

   for (idx = 1; idx <= len; idx++) {
      mul(prod, B(idx), word);
      if (shift > 0) LeftShift(prod, prod, shift);
      sub(A(idx), A(idx), prod);
   }
}

static void RowTransform2(vector(ZZ)& A, vector(ZZ)& B, const ZZ& MU1)
// Row operation A := A + MU1*B, applied to entries 1..A.length().
//
// Mirrors RowTransform with addition instead of subtraction. Factors of
// two are split off the multiplier by MakeOdd when its lowest digit is
// zero, and are restored on each product with LeftShift.
{
   static ZZ term, factor;   // scratch values, reused across calls

   factor = MU1;
   if (factor == 0) return;

   long k2 = (digit(factor, 0) == 0) ? MakeOdd(factor) : 0;

   const long count = A.length();
   long pos;

   if (factor.size() <= 1) {
      // multiplier fits in one digit: use the ZZ-by-long product
      long small_factor;
      small_factor << factor;

      for (pos = 1; pos <= count; pos++) {
         mul(term, B(pos), small_factor);
         if (k2 > 0) LeftShift(term, term, k2);
         add(A(pos), A(pos), term);
      }
   }
   else {
      for (pos = 1; pos <= count; pos++) {
         mul(term, B(pos), factor);
         if (k2 > 0) LeftShift(term, term, k2);
         add(A(pos), A(pos), term);
      }
   }
}


void ComputeGS(matrix(ZZ)& B, matrix(RR)& B1, matrix(RR)& mu, vector(RR)& b, 
               vector(RR)& c, long k, const RR& bound)
// Computes the Gram-Schmidt data for row k: the coefficients
// mu(k, 1..k-1) and c(k) = b(k) - sum_{j<k} mu(k,j)^2 * c(j).
//
// Inner products are taken on the RR approximations B1. When the squared
// approximate product is <= b(k)*b(j)/bound, cancellation may have lost
// too much precision, so the inner product is recomputed exactly on the
// integer rows of B.
//
// mu(j, *) and c(j) for j < k are read and must already be valid.
{
   RR ip, acc, tmp;
   ZZ exact;
   long col, r;

   for (col = 1; col < k; col++) {
      InnerProduct(ip, B1(k), B1(col));

      // precision test: ip^2 <= b(k)*b(col)/bound  ==>  redo exactly
      sqr(acc, ip);
      mul(tmp, b(k), b(col));
      div(tmp, tmp, bound);
      if (acc <= tmp) {
         InnerProduct(exact, B(k), B(col));
         ip << exact;
      }

      // mu(k,col) = (ip - sum_{r<col} mu(col,r)*mu(k,r)*c(r)) / c(col)
      clear(acc);
      for (r = 1; r < col; r++) {
         mul(tmp, mu(col, r), mu(k, r));
         mul(tmp, tmp, c(r));
         add(acc, acc, tmp);
      }
      sub(tmp, ip, acc);
      div(mu(k, col), tmp, c(col));
   }

   // c(k) = b(k) - sum_{col<k} mu(k,col)^2 * c(col)
   clear(acc);
   for (col = 1; col < k; col++) {
      sqr(tmp, mu(k, col));
      mul(tmp, tmp, c(col));
      add(acc, acc, tmp);
   }
   sub(c(k), b(k), acc);
}

static
long ll_LLL_RR(matrix(ZZ)& B, matrix(ZZ)* U, const RR& delta, long deep, 
           LLLCheckFct check, matrix(RR)& B1, matrix(RR)& mu, 
           vector(RR)& b, vector(RR)& c, long m, long init_k, long &quit)
// Core LLL loop shared by LLL_RR and BKZ_RR.
//
// Processes rows 1..m of B, starting with row init_k. On entry B1 and b
// must hold the RR approximations of the rows and their squared lengths
// for rows 1..m. mu and c for rows 1..init_k-1 are read (not recomputed)
// when row init_k is first processed, so the caller must supply valid
// values for them.
//
//   B      - integer basis; rows are modified in place.
//   U      - if non-null, receives the same row operations as B.
//   delta  - parameter of the swap test (and of the deep-insertion test).
//   deep   - deep-insertion parameter; 0 disables deep insertions.
//   check  - optional callback; a nonzero result for the current row sets
//            quit = 1 and stops the loop.
//   B1, b  - RR rows and squared lengths, kept in sync with B here.
//   mu, c  - Gram-Schmidt coefficients / squared lengths (via ComputeGS).
//   m      - number of rows to process (passed by value).
//   init_k - first row to process.
//   quit   - output flag, see check.
//
// Returns m minus the number of zero rows found. Zero rows are rotated to
// the positions after the returned value.
{
   long n = B.NumCols();

   long i, j, k, Fc, Fc1;
   ZZ MU;
   RR mu1, t1, t2, cc;
   ZZ T1;  // NOTE(review): appears unused in this function

   RR bound;

      // we tolerate a 15% loss of precision in computing
      // inner products in ComputeGS.
      // bound = 2^(2*floor(0.15*precision)); ComputeGS recomputes an inner
      // product exactly when its square is <= b(k)*b(j)/bound.

   bound << 1;
   for (i = 2*long(0.15*RR::precision()); i > 0; i--) 
      mul(bound, bound, 2);


   // bound1 = 2^floor(0.15*precision). If a coefficient being size-reduced
   // exceeds this in absolute value, Fc is set and the size-reduction pass
   // is repeated after the Gram-Schmidt data is recomputed.
   RR bound1;

   bound1 << 1;
   for (i = long(0.15*RR::precision()); i > 0; i--)
      mul(bound1, bound1, 2);


   quit = 0;
   k = init_k;

   while (k <= m) {

      // Gram-Schmidt data (mu(k, *), c(k)) for the current row.
      ComputeGS(B, B1, mu, b, c, k, bound);


      do {
         // size reduction
         // Fc1: some coefficient was reduced in this pass.
         // Fc:  some reduced coefficient exceeded bound1, so repeat.

         Fc = Fc1 = 0;
   
         for (j = k-1; j >= 1; j--) {
            abs(t1, mu(k,j));
            if (t1 > 0.5) {
               Fc1 = 1;
   
               // mu1 = mu(k,j) rounded to the nearest integer,
               // with halves rounded toward zero.
               mu1 = mu(k,j);
               if (sign(mu1) >= 0) {
                  sub(mu1, mu1, 0.5);
                  ceil(mu1, mu1);
               }
               else {
                  add(mu1, mu1, 0.5);
                  floor(mu1, mu1);
               }

               if (t1 > bound1)
                  Fc = 1;
   
               // update mu(k, 1..j) for row k := row k - mu1 * row j
               for (i = 1; i <= j-1; i++) {
                  mul(t2, mu1, mu(j,i));
                  sub(mu(k,i), mu(k,i), t2);
               }

               sub(mu(k,j), mu(k,j), mu1);
   
               MU << mu1;
   
               // apply the same operation to the integer rows (and U)
               RowTransform(B(k), B(j), MU);
               if (U) RowTransform((*U)(k), (*U)(j), MU);
            }
         }

         if (Fc1) {
            // row k of B changed: refresh its RR approximation, squared
            // length and Gram-Schmidt data
            for (i = 1; i <= n; i++)
               B1(k, i) << B(k, i);
   
            InnerProduct(b(k), B1(k), B1(k));
            ComputeGS(B, B1, mu, b, c, k, bound);
         }
      } while (Fc);

      if (check && (*check)(B(k))) 
         quit = 1;

      if (IsZero(b(k))) {
         // Row k is zero: rotate it to position m, shrink m, and process
         // row k again (it now holds the former row k+1). mu and c are
         // recomputed by ComputeGS at the top of the loop.
         for (i = k; i < m; i++) {
            // swap i, i+1
            swap(B(i), B(i+1));
            swap(B1(i), B1(i+1));
            swap(b(i), b(i+1));
            if (U) swap((*U)(i), (*U)(i+1));
         }

         m--;
         if (quit) break;
         continue;
      }

      if (quit) break;

      if (deep > 0) {
         // deep insertions
         // cc = b(k) - sum_{i<l} mu(k,i)^2 * c(i); find the first l at
         // which delta*c(l) > cc.
   
         cc = b(k);
         long l = 1;
         while (l <= k-1) { 
            mul(t1, delta, c(l));
            if (t1 > cc) break;
            sqr(t1, mu(k,l));
            mul(t1, t1, c(l));
            sub(cc, cc, t1);
            l++;
         }
   
         // insert only if l is within deep of either end of 1..k
         if (l <= k-1 && (l <= deep || k-l <= deep)) {
            // deep insertion at position l
   
            for (i = k; i > l; i--) {
               // swap rows i, i-1
               swap(B(i), B(i-1));
               swap(B1(i), B1(i-1));
               swap(b(i), b(i-1));
               if (U) swap((*U)(i), (*U)(i-1));
            }
   
            k = l;
            continue;
         }
      } // end deep insertions

      // test LLL reduction condition
      // swap rows k-1, k when delta*c(k-1) > c(k) + mu(k,k-1)^2 * c(k-1)

      if (k <= 1) 
         k++;
      else {
         sqr(t1, mu(k,k-1));
         mul(t1, t1, c(k-1));
         add(t1, t1, c(k));
         mul(t2, delta, c(k-1));
         if (t2 > t1) {
            // swap rows k, k-1
            swap(B(k), B(k-1));
            swap(B1(k), B1(k-1));
            swap(b(k), b(k-1));
            if (U) swap((*U)(k), (*U)(k-1));
   
            k--;
         }
         else
            k++;
      }
   }

   return m;
}

static
long LLL_RR(matrix(ZZ)& B, matrix(ZZ)* U, const RR& delta, long deep, 
           LLLCheckFct check)
// Shared driver for the public LLL_RR overloads.
//
// Builds RR approximations of the rows of B, runs ll_LLL_RR over all rows,
// and then rotates any zero rows (which ll_LLL_RR leaves at the end) to
// the front of B. If U is non-null, it is initialized to the identity and
// receives the same row operations as B.
//
// Returns the number of nonzero rows.
{
   long m = B.NumRows();
   long n = B.NumCols();

   long i, j;
   long new_m, dep, quit;

   if (U) ident(*U, m);

   matrix(RR) B1;  // approximates B
   B1.SetDims(m, n);

   matrix(RR) mu;  // Gram-Schmidt coefficients (filled in by ll_LLL_RR)
   mu.SetDims(m, m);

   vector(RR) c;  // squared lengths of Gram-Schmidt basis vectors
   c.SetLength(m);

   vector(RR) b; // squared lengths of basis vectors
   b.SetLength(m);

   for (i = 1; i <= m; i++)
      for (j = 1; j <= n; j++) 
         B1(i, j) << B(i, j);

   for (i = 1; i <= m; i++)
      InnerProduct(b(i), B1(i), B1(i));

   new_m = ll_LLL_RR(B, U, delta, deep, check, B1, mu, b, c, m, 1, quit);
   dep = m - new_m;
   m = new_m;

   if (dep > 0) {
      // For consistency, move all of the zero rows to the front. The m
      // nonzero rows occupy 1..m and the dep zero rows occupy m+1..m+dep.
      // Walking down from the last nonzero row, each swap moves row m-i to
      // position m+dep-i. When done, rows 1..dep are the zero rows.
      for (i = 0; i < m; i++) {
         swap(B(m+dep-i), B(m-i));
         if (U) swap((*U)(m+dep-i), (*U)(m-i));
      }
   }

   return m;
}

         

long LLL_RR(matrix(ZZ)& B, double delta, long deep, LLLCheckFct check)
// LLL-reduce the rows of B in place, using RR Gram-Schmidt data.
//
//   delta - reduction parameter; must satisfy 0.25 < delta <= 1.
//   deep  - deep-insertion parameter (0 disables); must be >= 0.
//   check - optional callback; a nonzero result on a row stops early.
//
// Returns the number of nonzero rows. Zero rows are moved to the front.
{
   // Written as a negated in-range test so that a NaN delta (for which
   // every comparison is false) is rejected instead of slipping through.
   if (!(delta > 0.25 && delta <= 1)) Error("LLL_RR: bad delta");
   if (deep < 0) Error("LLL_RR: bad deep");
   RR Delta;
   Delta << delta;
   return LLL_RR(B, 0, Delta, deep, check);
}

long LLL_RR(matrix(ZZ)& B, matrix(ZZ)& U, double delta, long deep, 
           LLLCheckFct check)
// LLL-reduce the rows of B in place, using RR Gram-Schmidt data, and
// record the row operations in U. U is reset to the identity first.
//
//   delta - reduction parameter; must satisfy 0.25 < delta <= 1.
//   deep  - deep-insertion parameter (0 disables); must be >= 0.
//   check - optional callback; a nonzero result on a row stops early.
//
// Returns the number of nonzero rows. Zero rows are moved to the front.
{
   // Negated in-range test: also rejects a NaN delta.
   if (!(delta > 0.25 && delta <= 1)) Error("LLL_RR: bad delta");
   if (deep < 0) Error("LLL_RR: bad deep");
   RR Delta;
   Delta << delta;
   return LLL_RR(B, &U, Delta, deep, check);
}




static
long BKZ_RR(matrix(ZZ)& BB, matrix(ZZ)* UU, const RR& delta, 
         long beta, long prune, LLLCheckFct check)
// Shared driver for the public BKZ_RR overloads: block reduction of the
// rows of BB with block size beta.
//
// Works on a local copy B that has one extra scratch row (row m+1). B is
// first LLL-reduced with ll_LLL_RR. The code then sweeps blocks [jj, kk]
// of at most beta rows. For each block, an enumeration (marked ENUM below)
// looks for integer coefficients over rows jj..kk whose projected squared
// length cbar satisfies delta*c(jj) > cbar. If one is found, the new
// vector is inserted at position jj and LLL is rerun. Otherwise z, the
// count of consecutive blocks without improvement, is incremented. The
// sweep ends when z reaches m-1 or when check requests a stop.
//
//   BB    - basis; overwritten with the result (zero rows moved first).
//   UU    - if non-null, receives the m_orig x m_orig transformation.
//   delta - parameter for LLL and for the improvement test.
//   beta  - block size (clamped to m).
//   prune - nonzero scales the enumeration bound by alpha (linear pruning).
//   check - optional early-stop callback, passed to ll_LLL_RR.
//
// Returns the number of nonzero rows left after the initial LLL.
{
   long m = BB.NumRows();
   long n = BB.NumCols();
   long m_orig = m;
   
   long i, j;
   ZZ MU;

   RR t1, t2;
   ZZ T1;  // NOTE(review): appears unused in this function

   // Working copy with one extra row; row m+1 is scratch space for the
   // new vector built in the general case below.
   matrix(ZZ) B;
   B = BB;

   B.SetDims(m+1, n);


   matrix(RR) B1;  // RR approximations of the rows of B
   B1.SetDims(m+1, n);

   matrix(RR) mu;  // Gram-Schmidt coefficients
   mu.SetDims(m+1, m);

   vector(RR) c;  // squared Gram-Schmidt lengths
   c.SetLength(m+1);

   vector(RR) b;  // squared lengths of the rows of B1
   b.SetLength(m+1);

   RR cbar;  // best projected squared length found by ENUM so far

   // ENUM state, indexed by position t in jj..kk+1:
   //   ctilda(t)    = ctilda(t+1) + (yvec(t)+utildavec(t))^2 * c(t)
   //   utildavec(t) : coefficient of the current candidate
   //   uvec(t)      : coefficients of the best candidate found
   //   yvec(t)      = sum_{i=t+1..s} utildavec(i)*mu(i,t)
   //   vvec(t)      : -yvec(t) rounded to the nearest integer
   //   Deltavec(t), deltavec(t) : step state; utildavec(t) is set to
   //                  vvec(t) + Deltavec(t) when returning to position t
   vector(RR) ctilda;
   ctilda.SetLength(m+1);

   vector(RR) vvec;
   vvec.SetLength(m+1);

   vector(RR) yvec;
   yvec.SetLength(m+1);

   vector(RR) uvec;
   uvec.SetLength(m+1);

   vector(RR) utildavec;
   utildavec.SetLength(m+1);

   vector(long) Deltavec;
   Deltavec.SetLength(m+1);

   vector(long) deltavec;
   deltavec.SetLength(m+1);

   // Row operations are tracked in a local (m+1) x m matrix, starting from
   // the identity, and copied to *UU at the end.
   matrix(ZZ) Ulocal;
   matrix(ZZ) *U;

   if (UU) {
      Ulocal.SetDims(m+1, m);
      for (i = 1; i <= m; i++)
         Ulocal(i, i) << 1;
      U = &Ulocal;
   }
   else
      U = 0;

   long quit;
   long new_m;
   long z, jj, kk;
   long s, t;
   long h;
   long mu1;  // NOTE(review): appears unused in this function
   double alpha;  // pruning factor applied to cbar in ENUM


   for (i = 1; i <=m; i++)
      for (j = 1; j <= n; j++) 
         B1(i, j) << B(i, j);

         
   for (i = 1; i <= m; i++) {
      InnerProduct(b(i), B1(i), B1(i));
   }

   // cerr << "\n";
   // cerr << "first LLL\n";

   m = ll_LLL_RR(B, U, delta, 0, check, B1, mu, b, c, m, 1, quit);

   // LLL left the zero rows at m+1..m_orig. Rotate the scratch row
   // (m_orig+1) down to m+1 so that it stays directly after the nonzero
   // rows. The zero rows now sit at m+2..m_orig+1.
   if (m < m_orig) {
      for (i = m_orig+1; i >= m+2; i--) {
         // swap i, i-1

         swap(B(i), B(i-1));
         if (U) swap((*U)(i), (*U)(i-1));
      }
   }

   if (!quit && m > 1) {
      // cerr << "continuing\n";
      if (beta > m) beta = m;

      z = 0;
      jj = 0;
   
      while (z < m-1) {
         // next block [jj, kk], wrapping back to jj = 1 after jj reaches m
         jj++;
         kk = min(jj+beta-1, m);
   
         if (jj == m) {
            jj = 1;
            kk = beta;
         }
   
         // ENUM
         // Enumerate integer coefficients for rows jj..kk, starting from
         // the unit vector at jj (projected squared length c(jj)), and
         // keep in uvec the candidate with the smallest ctilda(jj).
   
         cbar = c(jj);
         utildavec(jj) << 1;
         uvec(jj) << 1;
   
         yvec(jj) << 0;
         vvec(jj) << 0;
         Deltavec(jj) = 0;
   
   
         s = t = jj;
         deltavec(jj) = 1;
   
         for (i = jj+1; i <= kk+1; i++) {
            ctilda(i) << 0;
            uvec(i) << 0;
            utildavec(i) << 0;
            yvec(i) << 0;
            Deltavec(i) = 0;
            vvec(i) << 0;
            deltavec(i) = 1;
         }
   
         while (t <= kk) {
            add(t1, yvec(t), utildavec(t));
            sqr(t1, t1);
            mul(t1, t1, c(t));
            add(ctilda(t), ctilda(t+1), t1);

   
            // With pruning, positions far from kk get a tighter bound:
            // alpha = min(1, 1.05*(kk-t+1)/(kk-jj)).
            if (prune) {
               alpha = 1.05*double(kk-t+1)/double(kk-jj);
               if (alpha > 1) alpha = 1;
            }
            else
               alpha = 1;

            mul(t1, alpha, cbar);
   
            if (ctilda(t) < t1) {
               // partial length is within bound: descend one position,
               // or record a new best candidate at t == jj
               if (t > jj) {
                  t--;
                  clear(t1);
                  for (i = t+1; i <= s; i++) {
                     mul(t2, utildavec(i), mu(i,t));
                     add(t1, t1, t2);
                  }

                  yvec(t) = t1;
                  negate(t1, t1);
                  if (sign(t1) >= 0) {
                     sub(t1, t1, 0.5);
                     ceil(t1, t1);
                  }
                  else {
                     add(t1, t1, 0.5);
                     floor(t1, t1);
                  }

                  utildavec(t) = t1;
                  vvec(t) = t1;
                  Deltavec(t) = 0;

                  negate(t1, t1);

                  if (t1 < yvec(t)) 
                     deltavec(t) = -1;
                  else
                     deltavec(t) = 1;
               }
               else {
                  cbar = ctilda(jj);
                  for (i = jj; i <= kk; i++) {
                     uvec(i) = utildavec(i);
                  }
               }
            }
            else {
               // bound exceeded: go up one position and move to the next
               // coefficient value there
               t++;
               s = max(s, t);
               if (t < s) Deltavec(t) = -Deltavec(t);
               if (Deltavec(t)*deltavec(t) >= 0) Deltavec(t) += deltavec(t);
               add(utildavec(t), vvec(t), Deltavec(t));
            }
         }
         
   
         h = min(kk+1, m);

         mul(t1, delta, c(jj));
   
         if (t1 > cbar) {
            // improvement found: delta*c(jj) > cbar
            // we treat the case that the new vector is b_s (jj < s <= kk)
            // as a special case that appears to occur most of the time.
   
            // s = the unique i in jj+1..kk with uvec(i) != 0, or -1 if
            // there are several
            s = 0;
            for (i = jj+1; i <= kk; i++) {
               if (uvec(i) != 0) {
                  if (s == 0)
                     s = i;
                  else
                     s = -1;
               }
            }
   
            if (s == 0) Error("BKZ_RR: internal error");
   
            if (s > 0) {
               // special case
               // cerr << "special case\n";
               // rotate row s to position jj, then LLL rows jj..h
   
               for (i = s; i > jj; i--) {
                  // swap i, i-1
                  swap(B(i-1), B(i));
                  swap(B1(i-1), B1(i));
                  swap(b(i-1), b(i));
                  if (U) swap((*U)(i-1), (*U)(i));
               }
   
               new_m = ll_LLL_RR(B, U, delta, 0, check, 
                                B1, mu, b, c, h, jj, quit);
               if (new_m != h) Error("BKZ_RR: internal error");
               if (quit) break;
            }
            else {
               // the general case
               // build sum_i uvec(i)*B(i) in scratch row m+1
   
               for (i = 1; i <= n; i++) B(m+1, i) << 0;

               if (U) {
                  for (i = 1; i <= m_orig; i++)
                     (*U)(m+1, i) << 0;
               }

               for (i = jj; i <= kk; i++) {
                  if (uvec(i) == 0) continue;
                  MU << uvec(i);
                  RowTransform2(B(m+1), B(i), MU);
                  if (U) RowTransform2((*U)(m+1), (*U)(i), MU);
               }
      
               // rotate the new vector from row m+1 to row jj
               for (i = m+1; i >= jj+1; i--) {
                  // swap i, i-1
                  swap(B(i-1), B(i));
                  swap(B1(i-1), B1(i));
                  swap(b(i-1), b(i));
                  if (U) swap((*U)(i-1), (*U)(i));
               }
      
               for (i = 1; i <= n; i++)
                  B1(jj, i) << B(jj, i);
      
               InnerProduct(b(jj), B1(jj), B1(jj));
      
               if (b(jj) == 0) Error("BKZ_RR: internal error"); 
      
               // remove linear dependencies
               // (LLL on rows jj..kk+1; exactly one zero row is expected)
   
               // cerr << "general case\n";
               new_m = ll_LLL_RR(B, U, delta, 0, 0, B1, mu, b, c, kk+1, jj, quit);
              
               if (new_m != kk) Error("BKZ_RR: internal error"); 

               // remove zero vector
               // (move it from row kk+1 back to the scratch row m+1)
      
               for (i = kk+2; i <= m+1; i++) {
                  // swap i, i-1
                  swap(B(i-1), B(i));
                  swap(B1(i-1), B1(i));
                  swap(b(i-1), b(i));
                  if (U) swap((*U)(i-1), (*U)(i));
               }
      
               // check was not passed to the LLL above, so apply it here
               quit = 0;
               if (check) {
                  for (i = 1; i <= kk; i++)
                     if ((*check)(B(i))) {
                        quit = 1;
                        break;
                     }
               }

               if (quit) break;
   
               if (h > kk) {
                  // extend reduced basis
   
                  new_m = ll_LLL_RR(B, U, delta, 0, check, 
                                   B1, mu, b, c, h, h, quit);
   
                  if (new_m != h) Error("BKZ_RR: internal error");
                  if (quit) break;
               }
            }
   
            z = 0;
         }
         else {
            // LLL_RR
            // no improvement for this block: LLL row h and count the block
            // cerr << "progress\n";
            new_m = ll_LLL_RR(B, U, delta, 0, check, B1, mu, b, c, h, h, quit);
   
   
            if (new_m != h) Error("BKZ_RR: internal error");
            if (quit) break;
   
            z++;
         }
      }
   }

   // clean up

   if (m_orig > m) {
      // for consistency, we move zero vectors to the front
      // First move the scratch row from m+1 back to m_orig+1, so that the
      // zero rows return to m+1..m_orig. Then rotate them to the front.

      for (i = m+1; i <= m_orig; i++) {
         swap(B(i), B(i+1));
         if (U) swap((*U)(i), (*U)(i+1));
      }

      for (i = 0; i < m; i++) {
         swap(B(m_orig-i), B(m-i));
         if (U) swap((*U)(m_orig-i), (*U)(m-i));
      }
   }

   // drop the scratch row before copying results out
   B.SetDims(m_orig, n);
   BB = B;

   if (U) {
      U->SetDims(m_orig, m_orig);
      *UU = *U;
   }

   return m;
}

long BKZ_RR(matrix(ZZ)& BB, matrix(ZZ)& UU, double delta, 
         long beta, long prune, LLLCheckFct check)
// BKZ-reduce the rows of BB in place with block size beta, and store the
// m x m transformation in UU (m = number of rows of BB).
//
//   delta - reduction parameter; must satisfy 0.25 < delta <= 1.
//   beta  - block size; must be >= 2 (clamped to the rank internally).
//   prune - nonzero enables pruning of the enumeration.
//   check - optional callback; a nonzero result on a row stops early.
//
// Returns the number of nonzero rows. Zero rows are moved to the front.
{
   // Negated in-range test: also rejects a NaN delta, for which both
   // "delta <= 0.25" and "delta > 1" would be false.
   if (!(delta > 0.25 && delta <= 1)) Error("BKZ_RR: bad delta");
   if (beta < 2) Error("BKZ_RR: bad block size");

   RR Delta;
   Delta << delta;

   return BKZ_RR(BB, &UU, Delta, beta, prune, check);
}

long BKZ_RR(matrix(ZZ)& BB, double delta, 
         long beta, long prune, LLLCheckFct check)
// BKZ-reduce the rows of BB in place with block size beta, without
// tracking the transformation.
//
//   delta - reduction parameter; must satisfy 0.25 < delta <= 1.
//   beta  - block size; must be >= 2 (clamped to the rank internally).
//   prune - nonzero enables pruning of the enumeration.
//   check - optional callback; a nonzero result on a row stops early.
//
// Returns the number of nonzero rows. Zero rows are moved to the front.
{
   // Negated in-range test: also rejects a NaN delta.
   if (!(delta > 0.25 && delta <= 1)) Error("BKZ_RR: bad delta");
   if (beta < 2) Error("BKZ_RR: bad block size");

   RR Delta;
   Delta << delta;

   return BKZ_RR(BB, 0, Delta, beta, prune, check);
}



