//
// LiDIA - a library for computational number theory
//   Copyright (c) 1994, 1995 by the LiDIA Group
//
// File        : modular_fft_rep.c
// Author      : Victor Shoup, Thomas Pfahler (TPf)
// Last change : TPf, Feb 29, 1996, initial version
//

#if defined(HAVE_MAC_DIRS) || defined(__MWERKS__)
#include <LiDIA:Fp_polynomial_fft.h>
#else
#include <LiDIA/Fp_polynomial_fft.h>
#endif



/***************************************************************

				class modular_fft_rep

****************************************************************/

/*
REMARK: for implementation of 
void modular_fft_rep::to_modular_fft_rep(const poly_mod_rep &a, lidia_size_t lo, lidia_size_t hi, lidia_size_t index)
-> see file fft_arith.c
*/															

// Scratch vector shared by all modular_fft_rep objects (static member).
// set_length() grows it so that it holds at least 2^l entries; it is the
// input of F.evaluate() in to_modular_fft_rep() and the output of
// F.interpolate() in from_modular_fft_rep().
sdigit* modular_fft_rep::stat_vec = 0;
// Number of sdigit entries currently allocated for stat_vec
// (0 as long as nothing has been allocated).
lidia_size_t modular_fft_rep::stat_alloc = 0;


// Builds a representation for 2^l points with respect to the modulus
// description m.  k and max_k start out as 0 and -1 (max_k == -1 means
// that the vector s has not been allocated yet); init() does the rest.
modular_fft_rep::modular_fft_rep(lidia_size_t l, const mcp& m)
    : k(0), max_k(-1)
{
    debug_handler("modular_fft_rep",
                  "modular_fft_rep( lidia_size_t, const mcp& ) ");
    init(l, m);
}

// Copy constructor: takes over the FFT data F of R and allocates a
// vector for 2^(R.k) points.  Only F is copied; the entries of R.s are
// not transferred by this constructor.
modular_fft_rep::modular_fft_rep(const modular_fft_rep& R)
    : k(0), max_k(-1), F(R.F)
{
    debug_handler("modular_fft_rep",
                  "modular_fft_rep( modular_fft_rep& )");
    set_length(R.k);            // allocates s, grows stat_vec if necessary
    C.init(F.crttable());       // set up C from the crt table of F
}

// Conversion constructor: creates a modular_fft_rep that uses the same
// FFT data F as the fft_rep R and has room for 2^(R.k) points.  No
// values are taken over from R.
modular_fft_rep::modular_fft_rep(const fft_rep& R)
    : k(0), max_k(-1), F(R.F)
{
    debug_handler("modular_fft_rep", "modular_fft_rep( fft_rep& )");
    set_length(R.k);            // allocates s, grows stat_vec if necessary
    C.init(F.crttable());       // set up C from the crt table of F
}

	
// Releases the vector s.  s is only allocated once set_length() has
// raised max_k above its initial value -1.
modular_fft_rep::~modular_fft_rep()
{
    debug_handler("modular_fft_rep", "destructor()");

    if (max_k == -1)
        return;                 // nothing was ever allocated

    delete[] s;
}


// Sets the current length to 2^l points (k = l).
// The vector s is reallocated only if l exceeds the largest length
// allocated so far (max_k); its old contents are discarded in that case.
// The shared scratch vector stat_vec is enlarged as well whenever it has
// fewer than 2^l entries.
// A negative l is reported through lidia_error_handler.
void modular_fft_rep::set_length(lidia_size_t l)
{
    debug_handler("modular_fft_rep", "set_length( lidia_size_t )");

    if (l < 0) {
        lidia_error_handler("modular_fft_rep",
                            "set_length( lidia_size_t )::bad arg");
    }

    const lidia_size_t num_points = 1 << l;

    if (l > max_k) {
        // the present vector (if there is one) is too short
        if (max_k != -1) {
            delete[] s;
        }
        s = new sdigit[num_points];
        if (!s) {
            lidia_error_handler("modular_fft_rep",
                                "set_length( lidia_size_t )::out of memory");
        }
        max_k = l;
    }
    k = l;

    if (stat_alloc < num_points) {
        // stat_vec is 0 before its first allocation; delete[] on a null
        // pointer has no effect
        delete[] stat_vec;
        stat_vec = new sdigit[num_points];
        if (!stat_vec) {
            lidia_error_handler("modular_fft_rep",
                                "set_length( lidia_size_t )::out of memory");
        }
        stat_alloc = num_points;
    }
}

// Resizes the representation to 2^l points.
// As long as l does not exceed the largest length seen so far, only the
// length of the vectors is adjusted.  Otherwise F is asked for a new
// length as well, and C is either re-initialized from the crt table of F
// (if F reports a change) or reset.
void modular_fft_rep::set_size(lidia_size_t l)
{
    debug_handler("modular_fft_rep", "set_size( lidia_size_t )");

    // set_length() may raise max_k, so remember the value it had before
    const lidia_size_t previous_max_k = max_k;

    set_length(l);

    if (l > previous_max_k) {
        if (F.set_new_length(l) == true) {
            // F has changed for length l
            C.init(F.crttable());
        } else {
            C.reset();
        }
    }
}
	

// (Re-)initializes the object for 2^l points and the modulus
// description m: the vectors are sized first, then F is initialized and
// finally C is set up from the crt table of F.
void modular_fft_rep::init(lidia_size_t l, const mcp& m)
{
    debug_handler("modular_fft_rep", "init( lidia_size_t, const mcp& )");

    set_length(l);              // length of the vectors
    F.init(l, m);
    C.init(F.crttable());
}



///////////////////////////////////////////////////

void modular_fft_rep::to_modular_fft_rep(const Fp_polynomial &a, lidia_size_t lo, lidia_size_t hi, lidia_size_t index)
// computes an n = 2^k point convolution.
// if deg(x) >= 2^k, then x is first reduced modulo X^n-1.
{
    debug_handler( "modular_fft_rep", "to_modular_fft_rep( Fp_polynomial&, lidia_size_t, lidia_size_t, lidia_size_t )" );

    lidia_size_t K = 1 << k;
    lidia_size_t j, m, j1;

    if (lo < 0)
	lidia_error_handler( "modular_fft_rep", "to_modular_fft_rep( Fp_polynomial&, lidia_size_t, lidia_size_t, lidia_size_t )::bad arg (lo < 0)" );

    hi = comparator<lidia_size_t>::min(hi, a.degree());
    m = comparator<lidia_size_t>::max(hi-lo + 1, 0);

    bigint accum;
    const bigint &p = a.modulus();

    const bigint *aptr = &a.coeff[lo];
    sdigit *uptr = stat_vec;

    if (m < K)
    {
	C.reduce(uptr, aptr, m, index);
	for (uptr = &uptr[m], j = K-m; j > 0; j--, uptr++)
	    *uptr = 0;
    }
    else
    {
	if (m < (K<<1))
	{
	    lidia_size_t m2 = m - K;
	    const bigint *ap2 = &a.coeff[lo+K];
	    for (j = 0; j < m2; j++, aptr++, ap2++, uptr++)
	    {
		add(accum, *aptr, *ap2);
		C.reduce(*uptr, accum, index);
	    }
	    C.reduce(&stat_vec[m2], &a.coeff[m2], K-m2, index);
	}
	else
	{
	    for (j=0; j< K; j++, uptr++, aptr++)
	    {
		accum.assign( *aptr );
		for (j1 = j + K; j1 < m; j1 += K)
		    AddMod(accum, accum, a.coeff[j1+lo], p);
		C.reduce(*uptr, accum, index);
	    }
	}
    }
    F.evaluate(s, stat_vec, k, index);
}




// Transforms the values in s back into stat_vec and passes the entries
// lo..hi (hi is clipped to 2^k - 1) on to C:
//   1. F.interpolate writes the interpolated vector to stat_vec,
//   2. F.divide_by_power_of_two is applied to the selected entries,
//   3. C.combine receives the selected entries.
// index is handed on unchanged to all three calls.
void modular_fft_rep::from_modular_fft_rep(lidia_size_t lo, lidia_size_t hi, lidia_size_t index)
{
    debug_handler("modular_fft_rep",
                  "from_modular_fft_rep( lidia_size_t, lidia_size_t, lidia_size_t )");

    const lidia_size_t num_points = 1 << k;
    const lidia_size_t top = comparator<lidia_size_t>::min(hi, num_points - 1);
    const lidia_size_t count = comparator<lidia_size_t>::max(top - lo + 1, 0);

    F.interpolate(stat_vec, s, k, index);

    sdigit *selected = &stat_vec[lo];
    F.divide_by_power_of_two(selected, count, k, index);
    C.combine(selected, count, index);
}

// Stores the result collected in C in the polynomial a.
// a receives max(hi-lo+1, 0) coefficients (hi is clipped to 2^k - 1) and
// the modulus F.CT->mod; every coefficient is reduced modulo a.modulus()
// and leading zeros are removed afterwards.
void modular_fft_rep::get_result(Fp_polynomial &a, lidia_size_t lo, lidia_size_t hi)
{
    debug_handler("modular_fft_rep",
                  "get_result( Fp_polynomial&, lidia_size_t, lidia_size_t )");

    const lidia_size_t num_points = 1 << k;
    const lidia_size_t top = comparator<lidia_size_t>::min(hi, num_points - 1);
    const lidia_size_t count = comparator<lidia_size_t>::max(top - lo + 1, 0);

    a.set_degree(count - 1);
    a.MOD = F.CT->mod;
    C.get_result(a.coeff, count);

    // bring all coefficients delivered by C into the range of the modulus
    const bigint &p = a.modulus();
    for (lidia_size_t i = 0; i < count; i++) {
        Remainder(a.coeff[i], a.coeff[i], p);
    }

    a.remove_leading_zeros();
}

// Variant of get_result() that writes to a plain array of bigints
// (used in build_from_roots).
// The caller must provide room for max(hi-lo+1, 0) entries in a.
// In contrast to get_result(), hi is used as given (no clipping to
// 2^k - 1).  Every entry is reduced modulo F.CT->mod.mod().
void modular_fft_rep::get_result_ptr(bigint *a, lidia_size_t lo, lidia_size_t hi)
{
    debug_handler("modular_fft_rep",
                  "get_result_ptr( bigint*, lidia_size_t, lidia_size_t )");

    const lidia_size_t count = comparator<lidia_size_t>::max(hi - lo + 1, 0);
    C.get_result(a, count);

    const bigint &p = F.CT->mod.mod();
    bigint *const stop = a + count;
    for (bigint *cur = a; cur != stop; ++cur) {
        Remainder(*cur, *cur, p);
    }
}

///////////////////////////////////////////////////////


void reduce(modular_fft_rep &x, const modular_fft_rep &a, lidia_size_t l)
// reduces a 2^k point modular_fft_rep (k = a.k) to a 2^l point
// modular_fft_rep by keeping every 2^(k-l)-th entry of a.s
// input may alias output
// l > a.k and operands with different FFT tables are reported through
// lidia_error_handler
{
    debug_handler( "modular_fft_rep", "reduce( modular_fft_rep&, modular_fft_rep&, lidia_size_t )");
    if (l > a.k)
	lidia_error_handler( "modular_fft_rep", "reduce( modular_fft_rep&, modular_fft_rep&, lidia_size_t ): bad operand" );
    if (x.F.FT != a.F.FT)
	lidia_error_handler( "modular_fft_rep", "reduce( modular_fft_rep&, modular_fft_rep&, lidia_size_t ): Reps do not match" );

    // a.k has to be read BEFORE x is resized: if x and a are the same
    // object, x.set_size(l) overwrites a.k with l and the stride would
    // degenerate to 1 (i.e. nothing would be reduced).
    const lidia_size_t diff = a.k-l;

    lidia_size_t L = 1 << l;
    if (x.k != l)
	x.set_size(l);		// l <= a.k, so an aliased vector is not reallocated

    const sdigit *ap = &a.s[0];
    sdigit *xp = &x.s[0];
    // in the aliased case the read position (j << diff) is never behind
    // the write position j, so copying in place is safe
    for (lidia_size_t j = 0; j < L; j++, xp++)
	*xp = ap[j << diff];
}


// Pointwise product of a and b, stored in x (delegated to
// x.F.pointwise_multiply).
// a and b must have the same length, and a, b and x must share the same
// FFT table; violations are reported through lidia_error_handler.
// x is resized to the length of a if necessary.
void multiply(modular_fft_rep &x, const modular_fft_rep &a,
              const modular_fft_rep &b, lidia_size_t index)
{
    debug_handler("modular_fft_rep",
                  "multiply( modular_fft_rep&, modular_fft_rep&, modular_fft_rep&, lidia_size_t )");

    const lidia_size_t len = a.k;       // common length of the operands

    if (len != b.k) {
        lidia_error_handler("modular_fft_rep",
                            "multiply( modular_fft_rep&, modular_fft_rep&, modular_fft_rep&, lidia_size_t )::size mismatch");
    }
    if (a.F.FT != b.F.FT || a.F.FT != x.F.FT) {
        lidia_error_handler("modular_fft_rep",
                            "multiply( modular_fft_rep&, modular_fft_rep&, modular_fft_rep&, lidia_size_t )::Reps do not match");
    }

    if (x.k != len) {
        x.set_size(len);
    }

    x.F.pointwise_multiply(x.s, a.s, b.s, len, index);
}

// Pointwise sum of a and b, stored in x (delegated to
// x.F.pointwise_add).
// a and b must have the same length, and a, b and x must share the same
// FFT table; violations are reported through lidia_error_handler.
// x is resized to the length of a if necessary.
void add(modular_fft_rep &x, const modular_fft_rep &a,
         const modular_fft_rep &b, lidia_size_t index)
{
    debug_handler("modular_fft_rep",
                  "add( modular_fft_rep&, modular_fft_rep&, modular_fft_rep&, lidia_size_t )");

    const lidia_size_t len = a.k;       // common length of the operands

    if (len != b.k) {
        lidia_error_handler("modular_fft_rep",
                            "add( modular_fft_rep&, modular_fft_rep&, modular_fft_rep&, lidia_size_t )::size mismatch");
    }
    if (a.F.FT != b.F.FT || a.F.FT != x.F.FT) {
        lidia_error_handler("modular_fft_rep",
                            "add( modular_fft_rep&, modular_fft_rep&, modular_fft_rep&, lidia_size_t )::Reps do not match");
    }

    if (x.k != len) {
        x.set_size(len);
    }

    x.F.pointwise_add(x.s, a.s, b.s, len, index);
}


// Pointwise difference of a and b, stored in x (delegated to
// x.F.pointwise_subtract).
// a and b must have the same length, and a, b and x must share the same
// FFT table; violations are reported through lidia_error_handler.
// x is resized to the length of a if necessary.
void subtract(modular_fft_rep &x, const modular_fft_rep &a,
              const modular_fft_rep &b, lidia_size_t index)
{
    debug_handler("modular_fft_rep",
                  "subtract( modular_fft_rep&, modular_fft_rep&, modular_fft_rep&, lidia_size_t )");

    const lidia_size_t len = a.k;       // common length of the operands

    if (len != b.k) {
        lidia_error_handler("modular_fft_rep",
                            "subtract( modular_fft_rep&, modular_fft_rep&, modular_fft_rep&, lidia_size_t )::size mismatch");
    }
    if (a.F.FT != b.F.FT || a.F.FT != x.F.FT) {
        lidia_error_handler("modular_fft_rep",
                            "subtract( modular_fft_rep&, modular_fft_rep&, modular_fft_rep&, lidia_size_t )::Reps do not match");
    }

    if (x.k != len) {
        x.set_size(len);
    }

    x.F.pointwise_subtract(x.s, a.s, b.s, len, index);
}


// Combines the four operands a, b, c, d pointwise into x by calling
// x.F.pointwise_add_mul (presumably x = a*b + c*d - TODO confirm against
// the definition of pointwise_add_mul).
// All four operands must have the same length, and all of them must
// share the FFT table of x; violations are reported through
// lidia_error_handler.  x is resized to the length of a if necessary.
void add_mul(modular_fft_rep &x,
	const modular_fft_rep &a, const modular_fft_rep &b,
	const modular_fft_rep &c, const modular_fft_rep &d, lidia_size_t index)
{
    debug_handler( "modular_fft_rep", "add_mul( modular_fft_rep&, modular_fft_rep&, modular_fft_rep&, modular_fft_rep&, modular_fft_rep&, lidia_size_t )");
    lidia_size_t k = a.k;
    if (k != b.k || k != c.k || k != d.k)
	lidia_error_handler( "modular_fft_rep", "add_mul( modular_fft_rep&, modular_fft_rep&, modular_fft_rep&, modular_fft_rep&, modular_fft_rep&, lidia_size_t )::size mismatch");
    // this message used to name multiply() (copy-and-paste slip); it now
    // reports the function that actually failed
    if (x.F.FT != a.F.FT || x.F.FT != b.F.FT || x.F.FT != c.F.FT || x.F.FT != d.F.FT)
	lidia_error_handler( "modular_fft_rep", "add_mul( modular_fft_rep&, modular_fft_rep&, modular_fft_rep&, modular_fft_rep&, modular_fft_rep&, lidia_size_t )::Reps do not match" );

    if (x.k != k)
	x.set_size(k);

    x.F.pointwise_add_mul(x.s, a.s, b.s, c.s, d.s, k, index);
}



// Pointwise product of row 'index' of the fft_rep a (a.tbl[index]) and
// the modular_fft_rep b, stored in x (delegated to
// x.F.pointwise_multiply).
// a and b must have the same length, and both must share the FFT table
// of x; violations are reported through lidia_error_handler.
// x is resized to the length of a if necessary.
void multiply(modular_fft_rep &x, const fft_rep &a,
              const modular_fft_rep &b, lidia_size_t index)
{
    debug_handler("::",
                  "multiply( modular_fft_rep&, fft_rep&, modular_fft_rep&, lidia_size_t )");

    const lidia_size_t len = a.k;       // common length of the operands

    if (len != b.k) {
        lidia_error_handler("modular_fft_rep",
                            "multiply( modular_fft_rep&, fft_rep&, modular_fft_rep&, lidia_size_t )::size mismatch");
    }
    if (x.F.FT != a.F.FT || x.F.FT != b.F.FT) {
        lidia_error_handler("modular_fft_rep",
                            "multiply( modular_fft_rep&, fft_rep&, modular_fft_rep&, lidia_size_t )::Reps do not match");
    }

    if (x.k != len) {
        x.set_size(len);
    }

    x.F.pointwise_multiply(x.s, a.tbl[index], b.s, len, index);
}


// Reduces row 'index' of a 2^k point fft_rep (k = a.k) to a 2^l point
// modular_fft_rep: x.s receives every 2^(k-l)-th entry of a.tbl[index].
// l > a.k and operands with different FFT tables are reported through
// lidia_error_handler.  x is resized to length l if necessary.
void reduce(modular_fft_rep &x, const fft_rep &a, lidia_size_t l, lidia_size_t index)
{
    debug_handler("modular_fft_rep",
                  "reduce( modular_fft_rep&, fft_rep&, lidia_size_t, lidia_size_t )");

    if (l > a.k) {
        lidia_error_handler("modular_fft_rep",
                            "reduce( modular_fft_rep&, fft_rep&, lidia_size_t, lidia_size_t ): bad operand");
    }
    if (x.F.FT != a.F.FT) {
        lidia_error_handler("modular_fft_rep",
                            "reduce( modular_fft_rep&, fft_rep&, lidia_size_t, lidia_size_t )::Reps do not match");
    }

    const lidia_size_t target_len = 1 << l;
    if (x.k != l) {
        x.set_size(l);
    }

    sdigit *src = &a.tbl[index][0];
    sdigit *dst = x.s;
    const lidia_size_t shift = a.k - l; // log2 of the sampling stride

    for (lidia_size_t j = 0; j < target_len; j++) {
        dst[j] = src[j << shift];
    }
}

// Pointwise difference of row 'index' of the fft_rep a (a.tbl[index])
// and the modular_fft_rep b, stored in x (delegated to
// x.F.pointwise_subtract).
// a and b must have the same length, and both must share the FFT table
// of x; violations are reported through lidia_error_handler.
// x is resized to the length of a if necessary.
void subtract(modular_fft_rep &x, const fft_rep &a, const modular_fft_rep &b,
              lidia_size_t index)
{
    debug_handler("modular_fft_rep",
                  "subtract( modular_fft_rep&, fft_rep&, modular_fft_rep&, lidia_size_t )");

    const lidia_size_t len = a.k;       // common length of the operands

    if (len != b.k) {
        lidia_error_handler("modular_fft_rep",
                            "subtract( modular_fft_rep&, fft_rep&, modular_fft_rep&, lidia_size_t )::size mismatch");
    }
    if (x.F.FT != a.F.FT || x.F.FT != b.F.FT) {
        lidia_error_handler("modular_fft_rep",
                            "subtract( modular_fft_rep&, fft_rep&, modular_fft_rep&, lidia_size_t )::Reps do not match");
    }

    if (x.k != len) {
        x.set_size(len);
    }

    x.F.pointwise_subtract(x.s, a.tbl[index], b.s, len, index);
}

