/* Integer Version 2.1, RD, 19.7.92 	imod.c			*/

/* Replaced several
	if (carry || mvecgt(sum->val, mv, k))
   by
	if (carry || !mvecgt(mv, sum->val, k))
   (or similar)
   in MasMplM, MplasM, MasMmiM, MmiasM.
   Thanks to robert@vlsi.cs.caltech.edu
   RD, 18.11.93 						*/
/* Completed dMod, RD, 18.11.93					*/

/* Modular arithmetic with Montgomery's method. Based on the large
   integer arithmetic libI. RD, 16.7.93.			*/

#include "imod.h"
#include "idigit.h"
#include "imem.h"
/* #include "random.h" */
#include <stdlib.h>
#include "timing.h"


/* Residue vector of mod->length digits, least significant digit first. */
typedef DigitType *mvec;
/* Double-length vector (2 * mod->length digits), e.g. a full product. */
typedef DigitType *mdoubvec;
/* Generic pointer; used to thread the free lists through unused vectors. */
typedef void *pointer;

/*********************   special names   **************/

/* Short local names for digit-vector primitives that are not defined in
   this file (presumably declared in idigit.h -- confirm). */
#define mvecadd DigitVecAdd
#define mvecsub DigitVecSub
#define	mvecsr1 DigitVecSr1

/***************   special memory management   **********************/

static mvec 
newmvec(mod)
    ModulusType *mod;
/* Return an uninitialised vector of mod->length digits.
 *
 * A vector is taken from mod's single-length free list when one is
 * available; otherwise a fresh one is obtained from Imalloc.  Every
 * allocation is at least sizeof(void *) bytes because delmvec stores the
 * free-list link inside the vector itself.
 *
 * The byte count is computed in size_t, so it is neither truncated to int
 * nor compared across signedness.  Merror is called when Imalloc fails.
 */
{
    DigitType *u;
    size_t bytes;

    if (mod->memsingfree)
    {
        /* Pop the head of the free list; its first word is the next link. */
        u = (DigitType *) (mod->memsingfree);
        mod->memsingfree = *((pointer *) (mod->memsingfree));
        return u;
    }

    bytes = mod->length * sizeof(DigitType);
    if (bytes < sizeof(void *))
        bytes = sizeof(void *);
    u = (DigitType *) Imalloc(bytes);
    if (!u)
        Merror("newmvec: memory full\n");
    return u;
}				/* newmvec */

static void 
delmvec(u, mod)
    mvec u;
    ModulusType *mod;
/* Give the single-length vector u back to mod.
 *
 * The memory is not released to the allocator: u is pushed onto
 * mod->memsingfree, and its first pointer-sized word is overwritten with
 * the link to the previous list head.
 */
{
    *((pointer *) u) = mod->memsingfree;
    mod->memsingfree = (void *) u;
}				/* delmvec */

static mdoubvec 
newmdoubvec(mod)
    ModulusType *mod;
/* Return an uninitialised vector of 2 * mod->length digits.
 *
 * A vector is taken from mod's double-length free list when one is
 * available; otherwise a fresh one is obtained from Imalloc.  Every
 * allocation is at least sizeof(void *) bytes because delmdoubvec stores
 * the free-list link inside the vector itself.
 *
 * The byte count is computed entirely in size_t, so neither the doubling
 * nor the multiplication by the digit size passes through int.  Merror is
 * called when Imalloc fails.
 */
{
    DigitType *u;
    size_t bytes;

    if (mod->memdoubfree)
    {
        /* Pop the head of the free list; its first word is the next link. */
        u = (DigitType *) (mod->memdoubfree);
        mod->memdoubfree = *((pointer *) (mod->memdoubfree));
        return u;
    }

    bytes = 2 * sizeof(DigitType) * mod->length;
    if (bytes < sizeof(void *))
        bytes = sizeof(void *);
    u = (DigitType *) Imalloc(bytes);
    if (!u)
        Merror("newmdoubvec: memory full\n");
    return u;
}				/* newmdoubvec */

static void 
delmdoubvec(u, mod)
    mdoubvec u;
    ModulusType *mod;
/* Give the double-length vector u back to mod.
 *
 * The memory is not released to the allocator: u is pushed onto
 * mod->memdoubfree, and its first pointer-sized word is overwritten with
 * the link to the previous list head.
 */
{
    *((pointer *) u) = mod->memdoubfree;
    mod->memdoubfree = (void *) u;
}				/* delmdoubvec */

/***************   special vector functions   **********************/

static void 
mvecas(a, b, k)
    mvec a, b;
    int k;
/* a[k]=b[k]; */
{
    for (; k > 0; k--)
	*a++ = *b++;
}				/* mvecas */

static void 
mvecas0(a, k)
    mvec a;
    int k;
/* a[k]=0; */
{
    for (; k > 0; k--)
	*a++ = 0;
}				/* mvecas0 */

static void 
mvecas1(a, k)
    mvec a;
    int k;
/* a[k]=1; */
{
    *a++ = 1;
    k--;
    for (; k > 0; k--)
	*a++ = 0;
}				/* mvecas1 */

BOOLEAN 
mveceq0(a, k)
    mvec a;
    int k;
/* return a[k]==0; */
{
    for (; k > 0; k--)
	if (*a++)
	    return FALSE;
    return TRUE;
}				/* mveceq0 */

BOOLEAN 
mveceq1(a, k)
    mvec a;
    int k;
/* return a[k]==1; */
{
    if (*a++ != 1)
	return FALSE;
    k--;
    for (; k > 0; k--)
	if (*a++)
	    return FALSE;
    return TRUE;
}				/* mveceq1 */

static BOOLEAN 
mvecgt(a, b, k)
    mvec a, b;
    int k;
/* return a[k]>b[k]; lexikographisch */
{
    for (a += k, b += k; k > 0; k--)
	if (*--a > *--b)
	    return TRUE;
	else if (*a < *b)
	    return FALSE;
    return FALSE;
}				/* mvecgt */


static void 
mveccorsub(diff, b, c, mod, k)
    mvec diff, b, c, mod;
    int k;
/* diff = b - c + mod, computed digit by digit over k digits in one pass.
 *
 * The additions (b + mod) and the subtraction (- c) keep separate running
 * counters: addcarry for overflow of the additions, borrow for underflow of
 * the subtractions; each may reach 2 within one digit.  Whatever carry or
 * borrow remains after the top digit is discarded.
 *
 * The inputs of digit i are read before diff[i] is written, so diff may be
 * the same vector as b or c.
 */
{
    DigitType addcarry = 0, borrow = 0;
    DigitType bdig, cdig, mdig, acc, prev;
    int pos = 0;

    while (pos < k)
    {
        bdig = b[pos];
        cdig = c[pos];
        mdig = mod[pos];

        /* acc = bdig + addcarry + mdig, counting wrap-arounds */
        acc = bdig + addcarry;
        addcarry = (acc < bdig);
        acc += mdig;
        if (acc < mdig)
            addcarry++;

        /* acc = acc - borrow - cdig, counting wrap-arounds */
        prev = acc;
        acc -= borrow;
        borrow = (acc > prev);
        prev = acc;
        acc -= cdig;
        if (acc > prev)
            borrow++;

        diff[pos] = acc;
        pos++;
    }
}				/* mveccorsub */

static void 
mvecmul(prod, b, c, k)
    mdoubvec prod;
    mvec b;
    mvec c;
    int k;
/* prod (2k digits) = b (k digits) * c (k digits), one row per digit of c.
 *
 * The first row is produced by DigitVecMult, every further row is
 * accumulated with DigitVecMultAdd at an offset of one more digit; the
 * value each call returns is stored as the next higher digit of prod.
 * prod must not overlap b or c (it is written while they are still read).
 */
{
    int row;

    prod[k] = DigitVecMult(prod, b, c[0], k);
    row = 1;
    while (row < k)
    {
        prod[k + row] = DigitVecMultAdd(&prod[row], b, c[row], k);
        row++;
    }
}				/* mvecmul */

static void 
omvecred(t, n, nprime, k)
    mdoubvec t;
    mvec n;
    DigitType nprime;
    int k;
/* Montgomery reduction step: modify the 2k-digit vector t modulo n so that
 * its lower half becomes zero; the reduced result is left in the upper
 * half t[k .. 2k-1], which callers then copy out.
 *
 *   t       2k digits, changed in place
 *   n       the k-digit modulus
 *   nprime  per-modulus constant (cMod stores the lowest digit of
 *           -n^-1 modulo the digit radix power there)
 *   k       number of digits of n
 */
{
    DigitType *a, m, c = 0, tmp, aa, c1;
    int j;


    for (j = 0; j < k; j++)
    {
	/* Choose m so that adding m*n clears the lowest remaining digit. */
	a = &t[j];
	m = *a * nprime;
	/* presumably a[0..k) += m * n[0..k), returning the carry digit
	   -- confirm against DigitVecMultAdd */
	tmp = DigitVecMultAdd(a, n, m, k);
	/* Fold that carry and the carry c pending from the previous round
	   into a[k]; c1 and (tmp < c) detect the two possible overflows. */
	aa = a[k];
	tmp += aa;
	c1 = tmp < aa;
	tmp += c;
	c = c1 + (tmp < c);
	a[k] = tmp;
    }
    /* One final correction: subtract n if a carry is still pending or the
       upper half is not smaller than n. */
    if (c || !mvecgt(n, &t[k], k))
	mvecsub(&t[k], &t[k], n, k);
}				/* omvecred */


static void 
mvaddsr1(r, a, mod, k)
    mvec r, a, mod;
    int k;
/* r[k] = (a[k]+mod[k]) >> 1; */
/* r == a moeglich */
{
    DigitType carry;
    carry = DigitVecAdd(r, a, mod, k);
    DigitVecSr1(r, k);
    if (carry)
	r[k - 1] = r[k - 1] | (1L << (BitsPerDigit - 1));
}				/* mvaddsr1 */

/*********** static function for random Digits ********/

/* Number of random bits assumed to be delivered by one call of random(). */
#define NO_RANDOM_BITS 31

static DigitType 
Prandom()
/* Return a random DigitType.
 *
 * On the first call the generator is seeded once with timeseed().  The
 * digit is then assembled from as many random() results as are needed to
 * cover BitsPerDigit bits, shifting earlier results left by
 * NO_RANDOM_BITS each time.
 */
{
    static BOOLEAN seeded = FALSE;
    DigitType value;
    int bits;

    if (!seeded)
    {
        unsigned int seed;

        seeded = TRUE;
        seed = timeseed();
        srandom(seed);
    }

    value = random();
    for (bits = NO_RANDOM_BITS; bits < BitsPerDigit; bits += NO_RANDOM_BITS)
        value = (value << NO_RANDOM_BITS) | random();
    return value;
}				/* Prandom */

/*********** static function for inversion ********/

/* True when the lowest bit of A is clear. */
#define MEVEN(A) (!((A)&1))

/* Compute inv = val^-1 modulo mod->vec with a binary extended gcd that
 * uses only subtraction and halving.
 *
 * Returns FALSE without touching inv when val is zero.  Otherwise inv is
 * always overwritten, and the result is TRUE only if the gcd of val and
 * the modulus turned out to be 1 (so inv is a real inverse).
 *
 * val is copied into a scratch vector before inv is written, so inv may
 * be the same vector as val.  No Montgomery scaling is applied here.
 */
static BOOLEAN 
mvecinv(inv, val, mod)
    mvec inv, val;
    ModulusType *mod;
{
    int k;
    BOOLEAN ok;
    mvec swap, a, b, x, u, b0;

    k = mod->length;
    if (mveceq0(val, k))
	return FALSE;
    a = newmvec(mod);
    b = newmvec(mod);
    x = newmvec(mod);
    u = newmvec(mod);
    mvecas(a, val, k);
    b0 = mod->vec;
    mvecas(b, b0, k);
    mvecas1(u, k);		/* a == u*a0 + v*b0; */
    mvecas0(x, k);		/* b == x*a0 + y*b0; */
    /* Strip factors of two from a.  u is halved modulo b0: directly when
       it is even, otherwise after adding the (odd) modulus. */
    while (MEVEN(*a))
    {
	mvecsr1(a, k);
	if (MEVEN(*u))
	    mvecsr1(u, k);
	else
	    mvaddsr1(u, u, b0, k);
    }
    while (TRUE)
    {
	/* Keep a >= b; the cofactors x and u are swapped along with them. */
	if (mvecgt(b, a, k))
	{
	    swap = b;
	    b = a;
	    a = swap;
	    swap = x;
	    x = u;
	    u = swap;
	}
	mvecsub(a, a, b, k);
	if (!mveceq0(a, k))
	{
	    /* u = u - x modulo b0; add the modulus when it would go negative */
	    if (mvecgt(x, u, k))
		mveccorsub(u, u, x, b0, k);
	    else
		mvecsub(u, u, x, k);
	    /* a is even now (difference of two odd values); halve again. */
	    while (MEVEN(*a))
	    {
		mvecsr1(a, k);
		if (MEVEN(*u))
		    mvecsr1(u, k);
		else
		    mvaddsr1(u, u, b0, k);
	    }
	}
	else
	    break;
    }				/* while(TRUE) */
    /* a reached zero: b holds the gcd and x its cofactor for val. */
    ok = mveceq1(b, k);
    mvecas(inv, x, k);
    delmvec(u, mod);
    delmvec(x, mod);
    delmvec(b, mod);
    delmvec(a, mod);
    return ok;
}				/* mvecinv */



/***********************************************************
** EXTERN_FUNCTION(void Merror, (const char *));
** 	Error message
*/
void 
Merror(s)
    const char *s;
/* Report the message s on stderr, prefixed with "M: " and followed by a
   newline, then terminate the program: abort() where the macro unix is
   defined, exit(-1) everywhere else.  Does not return. */
{
    fprintf(stderr, "M: %s\n", s);
#ifndef unix
    exit(-1);
#else
    abort();
#endif
}				/* Merror */

/***********************************************************
** EXTERN_FUNCTION(void cMod, (ModulusType *, const Integer *));
**	Creator Modulus
*/
void 
cMod(mod, a)
    ModulusType *mod;
    const Integer *a;
/* Construct the modulus object mod for the value a.
 *
 * a must be positive, different from 1 and odd; otherwise Merror is
 * called.  With R = 2^(BitsPerDigit * a->length) the constructor stores
 *   mod->vec, mod->ModIval   copies of a,
 *   mod->nprime              lowest digit of -a^-1 mod R,
 *   mod->rsquare             R^2 mod a,
 *   mod->rcube               R^3 mod a,
 * each vector padded with zero digits up to mod->length.  Both free lists
 * start out empty.
 * (Ixgcd is presumably the extended gcd d = u*r + v*a, and IreasI an
 * in-place remainder -- confirm against the Integer library.)
 */
{
    Integer r, rsquare, rcube, u, v, d;
    int pos;

    /* Reject moduli the Montgomery representation cannot handle. */
    if (Ile0(a))
        Merror("cMod: <=0");
    if (Ieq1(a))
        Merror("cMod: 1");
    if (Ieven(a))
        Merror("cMod: even");

    mod->memsingfree = NULL;
    mod->memdoubfree = NULL;
    mod->length = a->length;
    mod->vec = newmvec(mod);
    mvecas(mod->vec, a->vec, mod->length);
    cIasI(&mod->ModIval, a);

    cIasint(&r, 1);
    cI(&rsquare);
    cI(&rcube);
    cI(&u);
    cI(&v);
    cI(&d);

    /* r = R; then solve u*r + v*a = d and insist on d == 1. */
    IslasD(&r, BitsPerDigit * a->length);
    Ixgcd(&d, &u, &v, &r, a);
    if (!Ieq1(&d))
        Merror("cMod: d!=1");

    /* v = -v, normalised into the range [0, r). */
    Ineg(&v);
    while (Ilt0(&v))
        IplasI(&v, &r);
    while (IgeI(&v, &r))
        ImiasI(&v, &r);
    mod->nprime = v.vec[0];

    IasImuI(&rsquare, &r, &r);
    IreasI(&rsquare, a);
    IasImuI(&rcube, &rsquare, &r);
    IreasI(&rcube, a);

    mod->rsquare = newmvec(mod);
    mod->rcube = newmvec(mod);

    /* Copy the significant digits, then zero-fill the remainder. */
    pos = 0;
    while (pos < rsquare.length)
    {
        mod->rsquare[pos] = rsquare.vec[pos];
        pos++;
    }
    for (; pos < mod->length; pos++)
        mod->rsquare[pos] = 0;

    pos = 0;
    while (pos < rcube.length)
    {
        mod->rcube[pos] = rcube.vec[pos];
        pos++;
    }
    for (; pos < mod->length; pos++)
        mod->rcube[pos] = 0;

    dI(&r);
    dI(&rsquare);
    dI(&rcube);
    dI(&u);
    dI(&v);
    dI(&d);
}				/* cMod */

/***********************************************************
** EXTERN_FUNCTION(void dMod, (ModulusType *));
**	Destructor Modulus
*/
void 
dMod(mod)
    ModulusType *mod;
/* Destroy the modulus object mod.
 *
 * The vectors owned by mod (rcube, rsquare, vec) are first pushed onto the
 * single-length free list, ModIval is destroyed, and then every vector on
 * both free lists is released with Ifree.  Vectors still held by live
 * Mintegers are not on a list and therefore are not released here.
 *
 * The link stored inside each vector is read BEFORE the vector is freed,
 * and both list heads are reset to NULL afterwards so that no dangling
 * list pointer remains in mod.
 */
{
    void *cur, *next;

    delmvec(mod->rcube, mod);
    delmvec(mod->rsquare, mod);
    delmvec(mod->vec, mod);
    dI(&mod->ModIval);

    for (cur = mod->memsingfree; cur; cur = next)
    {
        next = *((pointer *) cur);
        Ifree(cur);
    }
    mod->memsingfree = NULL;

    for (cur = mod->memdoubfree; cur; cur = next)
    {
        next = *((pointer *) cur);
        Ifree(cur);
    }
    mod->memdoubfree = NULL;
}				/* dMod */

/***********************************************************
** EXTERN_FUNCTION(void IasMod, (Integer *, const ModulusType *));
**	Modulus value
*/
void 
IasMod(a, m)
    Integer *a;
    const ModulusType *m;
/* Assign to a the Integer copy of the modulus value kept in m. */
{
    const Integer *modval = &m->ModIval;

    IasI(a, modval);
}				/* IasMod */

/***********************************************************
** EXTERN_FUNCTION(const ModulusType * Mmod, (const Minteger *));
**	Reference Modulus
*/
const ModulusType *
Mmod(a)
    const Minteger *a;
/* Return the modulus object that the Minteger a refers to. */
{
    const ModulusType *owner = a->mod;

    return owner;
}				/* Mmod */

/***********************************************************
** EXTERN_FUNCTION(void cM, (Minteger *, ModulusType *));
**	Creator Minteger
*/
void 
cM(a, mod)
    Minteger *a;
    ModulusType *mod;
/* Construct the Minteger a for the modulus mod: a digit vector is obtained
   from mod and all of its mod->length digits are set to zero. */
{
    int k = mod->length;

    a->mod = mod;
    a->val = newmvec(mod);
    mvecas0(a->val, k);
}				/* cM */

/***********************************************************
** EXTERN_FUNCTION(void cMasI, (Minteger *, const Integer *,
**	ModulusType *));
**	Creator Minteger, Init Integer
*/
void 
cMasI(r, a, mod)
    Minteger *r;
    const Integer *a;
    ModulusType *mod;
/* Construct the Minteger r for the modulus mod and initialise it from the
 * Integer a.
 *
 * A copy of a is reduced with IreasI against the modulus value, its digits
 * are stored zero-padded in r->val, and the value is then multiplied by
 * mod->rsquare and reduced once, which brings it into the internal
 * (Montgomery) representation.
 */
{
    Integer rem;
    DigitType *t;
    int k, pos;

    r->mod = mod;
    k = mod->length;
    r->val = newmvec(mod);

    cIasI(&rem, a);
    IreasI(&rem, &(mod->ModIval));

    /* Copy the significant digits, then zero-fill up to k digits. */
    pos = 0;
    while (pos < rem.length)
    {
        r->val[pos] = rem.vec[pos];
        pos++;
    }
    for (; pos < k; pos++)
        r->val[pos] = 0;
    dI(&rem);

    t = newmdoubvec(mod);
    mvecmul(t, r->val, mod->rsquare, k);
    omvecred(t, mod->vec, mod->nprime, k);
    mvecas(r->val, &t[k], k);
    delmdoubvec(t, mod);
}				/* cMasI */

/***********************************************************
** EXTERN_FUNCTION(void cMasM, (Minteger *, const Minteger *));
**	Creator Minteger, Init Minteger
*/
void 
cMasM(a, b)
    Minteger *a;
    const Minteger *b;
{
    a->mod = b->mod;
    a->val = newmvec(a->mod);
    mvecas(a->val, b->val, a->mod->length);
}				/* cMasM */

/***********************************************************
** EXTERN_FUNCTION(void dM, (Minteger *));
**	Destructor
*/
void 
dM(a)
    Minteger *a;
{
    delmvec(a->val, a->mod);
}				/* dM */

/***********************************************************
** EXTERN_FUNCTION(void MasM, (Minteger *, const Minteger *));
*/
void 
MasM(a, b)
    Minteger *a;
    const Minteger *b;
/* a = b.  Copies as many digits as a's modulus has; the modulus of b is
   not consulted. */
{
    int k = a->mod->length;

    mvecas(a->val, b->val, k);
}				/* MasM */

/***********************************************************
** EXTERN_FUNCTION(void MasI, (Minteger *, const Integer *));
*/
void 
MasI(r, a)
    Minteger *r;
    const Integer *a;
/* Assign the Integer a to the already constructed Minteger r.
 *
 * A copy of a is reduced with IreasI against the modulus value, its digits
 * are stored zero-padded in r->val, and the value is then multiplied by
 * rsquare and reduced once, which brings it into the internal
 * (Montgomery) representation.
 */
{
    ModulusType *mod = r->mod;
    int k = mod->length;
    Integer rem;
    DigitType *t;
    int pos;

    cIasI(&rem, a);
    IreasI(&rem, &(mod->ModIval));

    /* Copy the significant digits, then zero-fill up to k digits. */
    pos = 0;
    while (pos < rem.length)
    {
        r->val[pos] = rem.vec[pos];
        pos++;
    }
    for (; pos < k; pos++)
        r->val[pos] = 0;
    dI(&rem);

    t = newmdoubvec(mod);
    mvecmul(t, r->val, mod->rsquare, k);
    omvecred(t, mod->vec, mod->nprime, k);
    mvecas(r->val, &t[k], k);
    delmdoubvec(t, mod);
}				/* MasI */

/***********************************************************
** EXTERN_FUNCTION(void IasM, (Integer *, const Minteger *));
*/
void 
IasM(a, r)
    Integer *a;
    const Minteger *r;
/* Convert the Minteger r back to an ordinary Integer a.
 *
 * The stored digits are placed in the lower half of a double vector whose
 * upper half is zero, and one reduction leaves the plain value in the
 * upper half.  a is first assigned the modulus value (which gives it a
 * vector of at least k digits), then its digits are overwritten with the
 * result and a->length is trimmed so that it excludes leading zero
 * digits; a zero result gets length 0.
 */
{
    ModulusType *mod = r->mod;
    int k = mod->length;
    DigitType *t = newmdoubvec(mod);
    int used;

    mvecas(t, r->val, k);
    mvecas0(&t[k], k);
    omvecred(t, mod->vec, mod->nprime, k);

    IasI(a, &mod->ModIval);
    mvecas(a->vec, &t[k], k);
    delmdoubvec(t, mod);

    used = k;
    while (used > 0 && a->vec[used - 1] == 0)
        used--;
    a->length = used;
}				/* IasM */

/***********************************************************
** EXTERN_FUNCTION(BOOLEAN Meq0, (const Minteger *));
*/
BOOLEAN 
Meq0(a)
    const Minteger *a;
/* Return TRUE when every digit of a is zero, FALSE otherwise. */
{
    return mveceq0(a->val, a->mod->length);
}				/* Meq0 */

/***********************************************************
** EXTERN_FUNCTION(void MasMplM, (Minteger *, const Minteger *,
**	const Minteger *));
*/
void 
MasMplM(sum, a, b)
    Minteger *sum;
    const Minteger *a, *b;
/* sum = a + b modulo the modulus of sum.
 *
 * The digits are added first; the modulus is subtracted once when the
 * addition reported a carry or the raw sum is not smaller than the
 * modulus.
 */
{
    int k = sum->mod->length;
    DigitType *mv = sum->mod->vec;
    DigitType carry;

    carry = mvecadd(sum->val, a->val, b->val, k);
    if (!carry && mvecgt(mv, sum->val, k))
        return;			/* already reduced */
    mvecsub(sum->val, sum->val, mv, k);
}				/* MasMplM */

/***********************************************************
** EXTERN_FUNCTION(void MplasM, (Minteger *, const Minteger *));
*/
void 
MplasM(sum, b)
    Minteger *sum;
    const Minteger *b;
/* sum += b modulo the modulus of sum.
 *
 * The digits are added in place; the modulus is subtracted once when the
 * addition reported a carry or the raw sum is not smaller than the
 * modulus.
 */
{
    int k = sum->mod->length;
    DigitType *mv = sum->mod->vec;
    DigitType carry;

    carry = mvecadd(sum->val, sum->val, b->val, k);
    if (!carry && mvecgt(mv, sum->val, k))
        return;			/* already reduced */
    mvecsub(sum->val, sum->val, mv, k);
}				/* MplasM */

/***********************************************************
** EXTERN_FUNCTION(void MasMmiM, (Minteger *, const Minteger *,
**	const Minteger *));
*/
void 
MasMmiM(diff, a, b)
    Minteger *diff;
    const Minteger *a, *b;
/* diff = a - b modulo the modulus of diff.
 *
 * When b is larger than a the modulus is added during the subtraction so
 * that the result stays non-negative; otherwise a plain subtraction is
 * enough.
 */
{
    int k = diff->mod->length;

    if (mvecgt(b->val, a->val, k))
        mveccorsub(diff->val, a->val, b->val, diff->mod->vec, k);
    else
        mvecsub(diff->val, a->val, b->val, k);
}				/* MasMmiM */

/***********************************************************
** EXTERN_FUNCTION(void MmiasM, (Minteger *, const Minteger *));
*/
void 
MmiasM(diff, b)
    Minteger *diff;
    const Minteger *b;
/* diff -= b modulo the modulus of diff.
 *
 * When b is larger than diff the modulus is added during the subtraction
 * so that the result stays non-negative; otherwise a plain subtraction is
 * enough.
 */
{
    int k = diff->mod->length;

    if (mvecgt(b->val, diff->val, k))
        mveccorsub(diff->val, diff->val, b->val, diff->mod->vec, k);
    else
        mvecsub(diff->val, diff->val, b->val, k);
}				/* MmiasM */

/***********************************************************
** EXTERN_FUNCTION(void MasMmuM, (Minteger *, const Minteger *,
**	const Minteger *));
*/
void 
MasMmuM(prod, a, b)
    Minteger *prod;
    const Minteger *a, *b;
/* prod = a * b modulo the modulus of prod.
 *
 * The full double-length product is formed in a temporary vector, reduced
 * once, and the upper half of the temporary is copied into prod.  Because
 * the product goes through the temporary, prod may be the same object as
 * a or b.
 */
{
    ModulusType *m = prod->mod;
    int k = m->length;
    mdoubvec t = newmdoubvec(m);

    mvecmul(t, a->val, b->val, k);
    omvecred(t, m->vec, m->nprime, k);
    mvecas(prod->val, &t[k], k);
    delmdoubvec(t, m);
}				/* MasMmuM */

/***********************************************************
** EXTERN_FUNCTION(void MmuasM, (Minteger *, const Minteger *));
*/
void 
MmuasM(prod, b)
    Minteger *prod;
    const Minteger *b;
/* prod *= b modulo the modulus of prod.
 *
 * The full double-length product is formed in a temporary vector, reduced
 * once, and the upper half of the temporary is copied back into prod.
 */
{
    ModulusType *m = prod->mod;
    int k = m->length;
    mdoubvec t = newmdoubvec(m);

    mvecmul(t, prod->val, b->val, k);
    omvecred(t, m->vec, m->nprime, k);
    mvecas(prod->val, &t[k], k);
    delmdoubvec(t, m);
}				/* MmuasM */

/***********************************************************
** EXTERN_FUNCTION(BOOLEAN MasinvM, (Minteger *, const Minteger *));
*/
BOOLEAN 
MasinvM(inv, a)
    Minteger *inv;
    const Minteger *a;
/* inv = a^-1 modulo the modulus of a.
 *
 * Returns FALSE when mvecinv reports that no inverse exists; TRUE
 * otherwise.  On success the raw inverse of the stored digits is
 * multiplied by rcube and reduced once, which rescales it into the
 * internal representation.  The modulus object is taken from a, not from
 * inv.
 */
{
    ModulusType *mod = a->mod;
    int k = mod->length;
    mdoubvec t;

    if (mvecinv(inv->val, a->val, mod))
    {
        t = newmdoubvec(mod);
        mvecmul(t, inv->val, mod->rcube, k);
        omvecred(t, mod->vec, mod->nprime, k);
        mvecas(inv->val, &t[k], k);
        delmdoubvec(t, mod);
        return TRUE;
    }
    return FALSE;
}				/* MasinvM */

/***********************************************************
** EXTERN_FUNCTION(void MasMdiM, (Minteger *, const Minteger *,
** const Minteger *));
*/

/***********************************************************
** EXTERN_FUNCTION(void MdiasM, (Minteger *, const Minteger *));
*/

/***********************************************************
** EXTERN_FUNCTION(void Mrandom, (Minteger * a));
*/
void 
Mrandom(a)
    Minteger *a;
/* Fill a with random digits.
 *
 * The highest digit is a random digit reduced modulo the highest digit of
 * the modulus, which must therefore be non-zero; all lower digits are
 * unrestricted random digits, drawn from the second-highest index down to
 * index 0.
 */
{
    DigitType *digits = a->val;
    int top = a->mod->length - 1;
    int pos;

    digits[top] = Prandom() % (a->mod->vec[top]);
    pos = top;
    while (--pos >= 0)
        digits[pos] = Prandom();
}				/* Mrandom */
