//
// montyrep.cxx
//
// Copyright (C) 1996-7 by Leonard Janke (janke@unixg.ubc.ca)

#include <linteger/montyrep.hxx>
#include <linteger/lmath.hxx>
#include <cassert>

//
//
//
// MontyRing stuff 
//
//
//

// Construct an "empty" ring with no modulus installed.
//
// The size fields are zeroed along with the pointers. MontyRep::SetRing
// sizes its scratch buffer from _digitsOfN, and copying or installing an
// empty ring must never read indeterminate values.
MontyRing::MontyRing()
{
  _digitsOfN=0;
  _nPrimeModB=0;
  _n=NULL;
  _rSquaredModN=NULL;
}

// Build the Montgomery context for the odd modulus N > 1.
//
// With d = N._digits words and R = 2^(d*LMisc::bitsPerUInt), this stores:
//   _digitsOfN    : d
//   _nPrimeModB   : the least significant word of N', where N*N' == -1 (mod R)
//   _rSquaredModN : R^2 mod N, zero-padded on the left to d words
//   _n            : a private copy of N's d words
// All digit arrays are stored most significant word first (the low word
// is at index _digits-1).
MontyRing::MontyRing(const LInteger& N)
{
  assert( N.IsOdd() && N>LInteger::One );

  _digitsOfN=N._digits;

  LInteger R=LInteger::TwoToThe(_digitsOfN*LMisc::bitsPerUInt);
  LInteger RInv, NPrime;

  // R is a power of two and N is odd, so gcd(R,N) is 1 and the returned gcd
  // is not needed. NOTE(review): this assumes ExtendedEuclid yields Bezout
  // coefficients with R*RInv + N*NPrime == gcd; confirm against lmath.hxx.
  LMath::ExtendedEuclid(R,N,RInv,NPrime);

  // Normalise RInv to be non-negative. (RInv+N, NPrime-R) is another
  // Bezout pair for the same gcd.
  if ( RInv.IsNegative() )
    {
      RInv+=N;
      NPrime-=R;
    }

  // Under the convention above: R*RInv - N*NPrime == 1, so N*NPrime == -1 mod R.
  NPrime=-NPrime;

  // Word-by-word reduction needs only N' mod B, i.e. its lowest word.
  _nPrimeModB=NPrime._magnitude[NPrime._digits-1];

  LInteger RSquaredModN=(R*R)%N;
  _rSquaredModN=new unsigned int[_digitsOfN];
  int offset=_digitsOfN-RSquaredModN._digits;
  LMisc::MemZero(_rSquaredModN,offset);

  // Copy every digit of R^2 mod N after the zero padding. offset + _digits
  // words together fill the whole d-word buffer.
  LMisc::MemCopy(_rSquaredModN+offset,RSquaredModN._magnitude,
		 RSquaredModN._digits);

  _n=new unsigned int[_digitsOfN];
  LMisc::MemCopy(_n,N._magnitude,_digitsOfN);
}

// Free the modulus and R^2-mod-N buffers. Both are NULL for an empty ring,
// and delete[] on NULL is a no-op.
MontyRing::~MontyRing()
{
  delete[] _n;
  delete[] _rSquaredModN;
}

// Deep-copy a ring: the modulus and R^2 mod N buffers are duplicated.
// Copying an empty ring (x._n == NULL) yields an empty ring whose size
// fields are zeroed, so they never hold indeterminate values. (For example,
// MontyRep::SetRing sizes TScratch from _digitsOfN.)
MontyRing::MontyRing(const MontyRing& x)
{
  if ( x._n )
    {
      _digitsOfN=x._digitsOfN;
      _nPrimeModB=x._nPrimeModB;

      _n=new unsigned int[_digitsOfN]; 
      LMisc::MemCopy(_n,x._n,_digitsOfN);
      
      _rSquaredModN=new unsigned int[_digitsOfN]; 
      LMisc::MemCopy(_rSquaredModN,x._rSquaredModN,_digitsOfN);
    }
  else
    {
      _digitsOfN=0;
      _nPrimeModB=0;
      _n=NULL;
      _rSquaredModN=NULL;
    }
}

// Deep-copy assignment.
//
// The replacement buffers are allocated and filled before the current ones
// are released. A throwing new[] then leaves *this intact instead of
// holding dangling pointers that the destructor would delete a second time.
// Assigning an empty ring (x._n == NULL) also resets the size fields to 0
// rather than keeping a stale digit count.
MontyRing& MontyRing::operator=(const MontyRing& x)
{
  if ( this != &x )
    {
      unsigned int* newN=NULL;
      unsigned int* newRSquaredModN=NULL;

      if ( x._n )
	{
	  newN=new unsigned int[x._digitsOfN];
	  LMisc::MemCopy(newN,x._n,x._digitsOfN);

	  // If this allocation throws, newN is lost, but *this stays valid.
	  newRSquaredModN=new unsigned int[x._digitsOfN];
	  LMisc::MemCopy(newRSquaredModN,x._rSquaredModN,x._digitsOfN);
	}

      delete[] _rSquaredModN;
      delete[] _n;

      _n=newN;
      _rSquaredModN=newRSquaredModN;

      if ( x._n )
	{
	  _digitsOfN=x._digitsOfN;
	  _nPrimeModB=x._nPrimeModB;
	}
      else
	{
	  _digitsOfN=0;
	  _nPrimeModB=0;
	}
    }

  return *this;
}


//
//
//
// MontyRepStuff
//
//
//

// Scratch buffer shared by every Montgomery product and reduction.
// SetRing (re)allocates it to 2*ring._digitsOfN+1 words; it is NULL until
// then.
unsigned int* MontyRep::TScratch=NULL;

// The ring (modulus) shared by all MontyRep values. It starts out empty and
// is replaced via SetRing.
MontyRing MontyRep::ring;

// An empty value: it owns no digit buffer until something is assigned to it.
MontyRep::MontyRep()
{
  _montyRep=NULL;
}

// Convert an ordinary residue into Montgomery form for the current ring.
//
// The residue is zero-padded on the left to d = ring._digitsOfN words and
// multiplied by R^2 mod N with MontyMultiply (x * R^2 * R^-1 = x*R mod N).
//
// Preconditions (checked with assert):
//   - SetRing has installed a non-empty ring, so ring._n and TScratch are set;
//   - residue is non-negative (only its magnitude is copied);
//   - residue fits in d words. Otherwise the padding offset below would be
//     negative and the copy would overrun the d-word buffer.
MontyRep::MontyRep(const LInteger& residue) 
{
  assert( ring._n != NULL && TScratch != NULL );
  assert( !residue.IsNegative() );
  assert( residue._digits <= ring._digitsOfN );

  _montyRep=new unsigned int[ring._digitsOfN];

  int offset=ring._digitsOfN-residue._digits;
  LMisc::MemZero(_montyRep,offset);

  LMisc::MemCopy(_montyRep+offset,residue._magnitude,
                 residue._digits);

  // In-place is safe: MontyMultiply writes its output only after reducing.
  MontyMultiply(_montyRep,ring._rSquaredModN,_montyRep);
}

// Release the Montgomery digits. For an empty value the buffer is NULL and
// delete[] is a no-op.
MontyRep::~MontyRep()
{
  unsigned int* digits=_montyRep;
  _montyRep=NULL;
  delete[] digits;
}

// Deep copy. The buffer length comes from the shared ring, so the source is
// assumed to have been built under the currently installed ring.
// Copying an empty value yields an empty value.
MontyRep::MontyRep(const MontyRep& x) 
{
  _montyRep=NULL;

  if ( x._montyRep == NULL )
    return;

  _montyRep=new unsigned int[ring._digitsOfN];
  LMisc::MemCopy(_montyRep,x._montyRep,ring._digitsOfN);
}

// Deep-copy assignment; returns *this.
//
// The copy is built before the old buffer is released, so a throwing new[]
// leaves *this unchanged instead of holding a dangling pointer that the
// destructor would free again. The buffer length comes from the shared
// ring.
MontyRep& MontyRep::operator=(const MontyRep& x) 
{
  if ( this != &x )
    {
      unsigned int* copy=NULL;

      if ( x._montyRep )
	{
	  copy=new unsigned int[ring._digitsOfN];
	  LMisc::MemCopy(copy,x._montyRep,ring._digitsOfN);
	}

      delete[] _montyRep;
      _montyRep=copy;
    }

  // Flowing off the end of a function that returns MontyRep& is undefined
  // behaviour, so the reference must be returned explicitly.
  return *this;
}

// Install newRing as the ring shared by all MontyRep values and resize the
// shared scratch buffer to 2*ring._digitsOfN+1 words for it.
// Returns a copy of the previously installed ring, so a caller can restore it.
//
// NOTE(review): existing MontyRep values were sized and encoded for the old
// ring and are presumably not meaningful under a different one.
MontyRing MontyRep::SetRing(const MontyRing& newRing) 
{ 
  MontyRing oldRing(ring);

  ring=newRing; 

  // Allocate the new scratch buffer before freeing the old one. If new[]
  // throws, TScratch still points at live memory rather than a freed block
  // that a later SetRing would delete a second time.
  unsigned int* newScratch=new unsigned int[2*ring._digitsOfN+1];
  delete[] TScratch;
  TScratch=newScratch;

  return oldRing;
}

// Montgomery reduction of the value held in the shared scratch buffer.
//
// Let d = ring._digitsOfN. TScratch holds 2*d+1 words (see SetRing), most
// significant word first: TScratch[0] is the top word and TScratch[2*d] is
// the lowest. On return the reduced result is in the d words
// TScratch[1..d]. Callers copy it out from TScratch+1; TScratch[0] is not
// part of the result.
//
// NOTE(review): the exact contract of BMath::BasicMultiply is not visible
// here. The loop appears to follow textbook REDC: for each of the d low
// words, from least significant upward, m = word * N' (mod 2^bits) is
// formed and m*N is added so that the word becomes zero. This leaves T/R in
// TScratch[0..d]. Confirm against the bmath sources.
void MontyRep::ReduceTScratch()
{
  // Walk the d low-order words, starting from the least significant
  // (index 2*d). The product of two unsigned words wraps, i.e. it is taken
  // mod 2^bitsPerUInt (assuming _nPrimeModB is an unsigned int).
  for (int i=2*ring._digitsOfN; i>=ring._digitsOfN+1; i--)
    BMath::BasicMultiply(ring._n, TScratch[i]*ring._nPrimeModB,
			 TScratch+i-ring._digitsOfN, ring._digitsOfN);

  // Final conditional subtraction: if the top word is 1, or the d words at
  // TScratch+1 are >= N, subtract N from those d words once. This relies on
  // the pre-subtraction value being below 2N, so the top word is only ever
  // 0 or 1, as in textbook REDC. Confirm.
  if ( TScratch[0] == 1u ||
       BMath::GreaterThanOrEqualTo(TScratch+1,ring._n,ring._digitsOfN) )
    BMath::RippleSubtract(TScratch+1,ring._n,TScratch+1,ring._digitsOfN);
}

// z <- Montgomery product of the d-word operands x and y (d = ring._digitsOfN).
//
// The full 2d-word product is formed in the low 2d words of TScratch
// (TScratch[1..2d]) and reduced in place by ReduceTScratch. z is written
// only after that, so it may alias x or y.
void MontyRep::MontyMultiply(const unsigned int* x,
			     const unsigned int* y,
			     unsigned int* z)
{
  unsigned int* const product=TScratch+1;

  LMisc::MemZero(TScratch,2*ring._digitsOfN+1);
  BMath::Multiply(x,ring._digitsOfN,y,ring._digitsOfN,product);

  ReduceTScratch();

  LMisc::MemCopy(z,product,ring._digitsOfN);
}

// Square this value in place in Montgomery form: the 2d-word square is
// formed in TScratch[1..2d], reduced by ReduceTScratch, and copied back.
// Returns *this so calls can be chained.
MontyRep& MontyRep::Square()
{
  LMisc::MemZero(TScratch,2*ring._digitsOfN+1);
  BMath::Square(_montyRep,ring._digitsOfN,TScratch+1);
  ReduceTScratch();
  LMisc::MemCopy(_montyRep,TScratch+1,ring._digitsOfN);

  // The declared return type is MontyRep&. Falling off the end without
  // returning is undefined behaviour.
  return *this;
}

// Convert out of Montgomery form and return the ordinary residue.
//
// The stored d words are zero-extended to 2d+1 words in TScratch and
// reduced with ReduceTScratch. The d-word result is then wrapped in an
// LInteger and compressed.
LInteger MontyRep::ToLInteger() const
{
  // The stored digits go in the low d words; the d+1 words above them are
  // cleared. The two ranges are disjoint.
  LMisc::MemCopy(TScratch+ring._digitsOfN+1,_montyRep,ring._digitsOfN);
  LMisc::MemZero(TScratch,ring._digitsOfN+1);

  ReduceTScratch();

  unsigned int* digits=new unsigned int[ring._digitsOfN];
  LMisc::MemCopy(digits,TScratch+1,ring._digitsOfN);

  // NOTE(review): `digits` is not freed here, so this LInteger constructor
  // presumably takes ownership of the buffer. Confirm in linteger.hxx.
  LInteger result(digits,ring._digitsOfN,0,0);
  result.compress();

  return result;
}
