
#include "zz_p.h"
#include "FFT.h"

// Current modulus used by all zz_p arithmetic, and 1/_modulus as a double.
// Both start at 0 ("no modulus installed") and are rewritten by zz_pInit,
// zz_pFFTInit and the zz_pBak save/move/restore functions in this file.
long   zz_p::_modulus(0);
double zz_p::ModulusInv(0);


// Build the descriptor for a general single-precision modulus NewP.
//
// Picks the smallest number n of FFT primes q_0..q_{n-1} whose product M
// exceeds B = NewP^2 * 2^(maxroot + FFTFudge).  For each prime q_i, with
// cofactor M_i = M / q_i and t_i = M_i^{-1} mod q_i, it stores
//    CoeffModP[i] = (M_i * t_i) mod NewP
//    x[i]         = t_i / q_i   (as a double)
//    u[i]         = t_i
// together with MinusMModP = (-M) mod NewP.  NOTE(review): these look like
// the constants for CRT reconstruction mod NewP of values held as residues
// mod the q_i; the consumers live outside this file section.
//
// The object owns the three malloc'd arrays.  It sets index = -1, which is
// the condition under which the destructor frees them.
//
// Errors (via Error): more than 4 primes needed; maxroot larger than the
// root bound of the last prime chosen; allocation failure.
// NOTE(review): a negative maxroot is accepted and stored in MaxRoot.
zz_pInfoT::zz_pInfoT(long NewP, long maxroot)
{
   ZZ P, B, M, M1, MinusM;
   long n, i;
   long q, t;

   p = NewP;

   pinv = 1/double(p);

   index = -1;   // tells the destructor this object owns CoeffModP/x/u

   P << p;

   // Bound B = p^2 * 2^(maxroot + FFTFudge).
   sqr(B, P);
   LeftShift(B, B, maxroot+FFTFudge);

   // Multiply FFT primes into M until M > B.  The body must run at least
   // once.  LeftShift shifts right for a negative count, so a negative
   // enough maxroot drives B to 0.  A plain "while (M <= B)" loop would then
   // exit with n == 0 and q uninitialized, yet q is read by CalcMaxRoot
   // below.  A single prime always satisfies M = q > 0 >= B in that case.
   set(M);
   n = 0;
   do {
      UseFFTPrime(n);
      q = FFTPrime[n];
      n++;
      mul(M, M, q);
   } while (M <= B);

   if (n > 4) Error("zz_pInit: too many primes");

   NumPrimes = n;
   PrimeCnt = n;
   MaxRoot = CalcMaxRoot(q);   // root bound of the last prime chosen

   if (maxroot > MaxRoot)
      Error("maxroot too big");
   else
      MaxRoot = maxroot;


   // MinusMModP = (-M) mod p
   negate(MinusM, M);
   MinusMModP = rem(MinusM, p);

   if (!(CoeffModP = static_cast<long *>(malloc(n * (sizeof (long))))))
      Error("out of space");

   if (!(x = static_cast<double *>(malloc(n * (sizeof (double))))))
      Error("out of space");

   if (!(u = static_cast<long *>(malloc(n * (sizeof (long))))))
      Error("out of space");

   for (i = 0; i < n; i++) {
      q = FFTPrime[i];

      div(M1, M, q);                      // M1 = M / q_i
      t = rem(M1, q);
      t = InvMod(t, q);                   // t = (M / q_i)^{-1} mod q_i
      mul(M1, M1, t);
      CoeffModP[i] = rem(M1, p);          // ((M / q_i) * t) mod p
      x[i] = ((double) t)/((double) q);
      u[i] = t;
   }
}

// Build the descriptor for the case where the modulus is itself the FFT
// prime FFTPrime[Index].  No CRT tables (CoeffModP/x/u) are allocated here;
// since index >= 0, the destructor will not try to free them.
zz_pInfoT::zz_pInfoT(long Index)
{
   index = Index;
   NumPrimes = 1;
   PrimeCnt = 0;

   if (index < 0)
      Error("bad FFT prime index");

   // Set up every lower-numbered FFT prime that is not yet initialized,
   // then the requested one.  NOTE(review): the original author remarked
   // that this "allows non-consecutive indices" without knowing why it is
   // needed; presumably UseFFTPrime expects primes to be introduced in
   // order -- confirm against the FFT module.
   while (NumFFTPrimes < index)
      UseFFTPrime(NumFFTPrimes);
   UseFFTPrime(index);

   p = FFTPrime[index];
   MaxRoot = CalcMaxRoot(p);
   pinv = FFTPrimeInv[index];
}




// Release the tables owned by a general-modulus descriptor.  Only the
// (NewP, maxroot) constructor sets index to -1 and mallocs CoeffModP, x and
// u; FFT-prime descriptors (index >= 0) own nothing to free.
zz_pInfoT::~zz_pInfoT()
{
   if (index >= 0)
      return;

   free(CoeffModP);
   free(x);
   free(u);
}


// Descriptor for the currently installed modulus.  It is owned by this
// global and replaced by zz_pInit / zz_pFFTInit / zz_pBak.  It is null when
// no modulus is installed.
zz_pInfoT *zz_pInfo(0);

void zz_pInit(long NewP, long maxroot)
{
   if (NewP <= 1) Error("zz_pInit: modulus must be > 1");
   if (NumBits(NewP) > ZZ_NBITS)
      Error("zz_pInit: modulus too big");

   delete zz_pInfo;
   zz_pInfo = new zz_pInfoT(NewP, maxroot);

   zz_p::_modulus = NewP;
   zz_p::ModulusInv = 1/((double) NewP);
}

void zz_pFFTInit(long index)
{
   delete zz_pInfo;
   zz_pInfo = new zz_pInfoT(index);
 
   zz_p::_modulus = zz_pInfo->p;
   zz_p::ModulusInv = 1/((double) zz_pInfo->p);
}

// Move the current modulus descriptor into this backup and mark it to be
// restored (MustRestore = 1).  Any descriptor this backup already held is
// deleted first.  Afterwards no modulus is installed: zz_pInfo is null and
// _modulus / ModulusInv are 0.
void zz_pBak::save()
{
   delete ptr;

   ptr = zz_pInfo;
   zz_pInfo = 0;
   zz_p::_modulus = 0;
   zz_p::ModulusInv = 0;

   MustRestore = 1;
}

// Same transfer as save(): take ownership of the current descriptor
// (deleting any previously held one) and leave no modulus installed.  The
// only difference is that MustRestore is cleared.  NOTE(review): presumably
// the zz_pBak destructor (not shown here) consults MustRestore.
void zz_pBak::move()
{
   delete ptr;

   ptr = zz_pInfo;
   zz_pInfo = 0;
   zz_p::_modulus = 0;
   zz_p::ModulusInv = 0;

   MustRestore = 0;
}

// Reinstall the descriptor held by this backup as the current modulus.
// The currently installed descriptor is deleted, and the backup becomes
// empty with MustRestore cleared.  If the backup held nothing, the result
// is "no modulus installed" (_modulus and ModulusInv set to 0).
void zz_pBak::restore()
{
   delete zz_pInfo;
   zz_pInfo = ptr;
   ptr = 0;
   MustRestore = 0;

   if (!zz_pInfo) {
      zz_p::_modulus = 0;
      zz_p::ModulusInv = 0;
      return;
   }

   const long Modulus = zz_pInfo->p;
   zz_p::_modulus = Modulus;
   zz_p::ModulusInv = 1/double(Modulus);
}


// Return the canonical residue of a modulo p, i.e. a value in [0, p).
// Values already in range are returned unchanged without a division.
// NOTE(review): p must be positive; with p == 0 (no modulus installed) the
// '%' below divides by zero.
static long reduce(long a, long p)
{
   if (0 <= a && a < p)
      return a;

   // '%' can yield a negative remainder for negative a; shift it into [0, p).
   long r = a % p;
   return (r < 0) ? r + p : r;
}

// Construct the residue of a modulo the current modulus.
zz_p::zz_p(INIT_VAL_TYPE, long a)
   : rep(reduce(a, zz_p::_modulus))
{
}

// Assign to x the residue of the machine integer a modulo the current
// modulus.
void operator<<(zz_p& x, long a)
{
   const long Residue = reduce(a, zz_p::_modulus);
   x.rep = Residue;
}

// Assign to x the residue of the multi-precision integer a modulo the
// current modulus, as computed by rem(ZZ, long).
void operator<<(zz_p& x, const ZZ& a)
{
   const long Residue = rem(a, zz_p::_modulus);
   x.rep = Residue;
}

// Read an integer of arbitrary size from s and store its residue modulo
// the current modulus in x.
// NOTE(review): x is assigned unconditionally.  Whatever ZZ's extractor
// leaves in the temporary on a failed read ends up in x, unless that
// extractor itself signals an error.
istream& operator>>(istream& s, zz_p& x)
{
   ZZ Value;

   s >> Value;
   x << Value;

   return s;
}

// Write the canonical representative of a to s.  The value is routed
// through a ZZ, so it is printed by ZZ's output operator.
ostream& operator<<(ostream& s, zz_p a)
{
   ZZ Value;
   Value << rep(a);

   s << Value;
   return s;
}
