//
// bmath.cxx
//
// Leonard Janke
// August 1996

#include "bmath.hxx"

#include <cstdlib>
#include <vector>

// Adds B (sizeOfB words) to A (sizeOfA words), writing the sizeOfA-word sum
// to C.  Words are stored most-significant first: B lines up with the
// trailing sizeOfB words of A, and the leading sizeOfA-sizeOfB words of A
// are copied through unchanged before any carry is pushed into them.
// The carry out of the most significant word is returned in 'carry'.
void BMath::Add(const unsigned int* A, const unsigned int* B, unsigned int* C,
		const int sizeOfA, const int sizeOfB, char& carry)
{
  const int headWords=sizeOfA-sizeOfB;   // words of A with no partner in B

  // copy the high-order words of A that B does not overlap
  int k=0;
  while ( k<headWords )
    {
      C[k]=A[k];
      ++k;
    }

  // add B to the trailing (low-order) words of A
  char lowCarry;
  BasicAdd(A+headWords,B,C+headWords,sizeOfB,lowCarry);

  // With no carry, or no leading words to absorb one, the carry out of the
  // low-order addition is final; otherwise Increment pushes it into the
  // copied leading words and reports the resulting carry.
  if ( !lowCarry || headWords==0 )
    carry=lowCarry;
  else
    Increment(C,headWords,carry);
}

// Computes C = A - B, where A has sizeOfA words, B has sizeOfB words
// (sizeOfB <= sizeOfA) and C has sizeOfA words.  No final borrow is
// reported, so presumably callers guarantee A >= B.
// Words are stored most-significant first: B lines up with the trailing
// sizeOfB words of A, and the leading sizeOfA-sizeOfB words are copied and
// then absorb any borrow from the low-order subtraction.
void BMath::Subtract(const unsigned int* A, const unsigned int* B,
		     unsigned int* C, const int sizeOfA,
		     const int sizeOfB)
{
  const int headWords=sizeOfA-sizeOfB;   // words of A with no partner in B
  char intermediateBorrow;

  for (int i=0; i<headWords; i++)
    C[i]=A[i];

  BasicSubtract(A+headWords,B,C+headWords,sizeOfB,intermediateBorrow);

  // Only ripple the borrow when there are leading words to absorb it.  This
  // mirrors the guard in Add and avoids a zero-length RippleDecrement call
  // (equal sizes with a borrow means A < B, which has no valid result anyway).
  if ( intermediateBorrow && headWords>0 )
    RippleDecrement(C,headWords);
}

// Multiplies a (sizeA words) by b (sizeB words) and ADDS the
// (sizeA+sizeB)-word product into result.  Result needs to be zero'ed by
// the caller for this to work.  Words are stored most-significant first.
//
// Recursive Multiplication based on Knuth.  For two n=2h word operands
// split as a = u1*B^h + u0 and b = v1*B^h + v0 (B = word radix, u1/v1 the
// leading halves):
//   a*b = (B^2h + B^h) u1 v1 + B^h (u1-u0)(v0-v1) + (B^h + 1) u0 v0
// Unequal or odd sizes are reduced to that case (see the branches below).
void BMath::Multiply(const unsigned int* a, int sizeA, const unsigned int* b, 
		     int sizeB, unsigned int* result)
{
  if ( sizeA > sizeB )
    {
      LSwap(a,b);
      LSwap(sizeA,sizeB);
    }

  // can assume sizeB >= sizeA now

  // An empty operand is zero and result is already zero.  Without this
  // check a zero size would recurse forever in the even-size branch.
  if ( sizeA < 1 )
    return;

  if ( sizeA == 1 ) 
    {
      // single-word operand: scalar multiply-accumulate into result
      BasicMultiply(b,*a,result,sizeB);
      return;
    }

  if ( !(sizeA % 2) )
    {
      if ( sizeA==sizeB )
	{
	  if ( sizeA==2 )
	    {
	      MultDouble(result,a,b);
	      return;
	    }

	  const int half=sizeA/2;

	  // One zero-initialised scratch area, released automatically even
	  // if a recursive call throws:
	  //   [0, n)       firstTerm    = u1*v1
	  //   [n, 2n)      middleTerm   = |u1-u0| * |v0-v1|
	  //   [2n, 2n+h)   |u1-u0|
	  //   [2n+h, 3n)   |v0-v1|
	  //   [3n, 4n)     lastTerm     = u0*v0
	  std::vector<unsigned int> scratch(4*sizeA);
	  unsigned int* firstTerm=&scratch[0];
	  unsigned int* middleTerm=firstTerm+sizeA;
	  unsigned int* absu1minusu0=middleTerm+sizeA;
	  unsigned int* absv0minusv1=absu1minusu0+half;
	  unsigned int* lastTerm=absv0minusv1+half;

	  // last term: (B^h + 1) u0 v0
	  Multiply(a+half,half,b+half,half,lastTerm);
	  RippleAdd(lastTerm,result+sizeA,result+sizeA,sizeA);
	  RippleAdd(lastTerm,result+half,result+half,sizeA);

	  // first term: (B^2h + B^h) u1 v1
	  Multiply(a,half,b,half,firstTerm);
	  RippleAdd(firstTerm,result+half,result+half,sizeA);
	  RippleAdd(firstTerm,result,result,sizeA);

	  // middle term: B^h (u1-u0)(v0-v1), formed from absolute values.
	  // The product is negative exactly when the two differences have
	  // opposite signs.  The flags are normalised to bool so the sign
	  // test does not depend on GreaterThanOrEqualTo returning exactly 0/1.
	  const bool u1GeU0=GreaterThanOrEqualTo(a,a+half,half)!=0;
	  const bool v0GeV1=GreaterThanOrEqualTo(b+half,b,half)!=0;

	  if ( u1GeU0 )
	    RippleSubtract(a,a+half,absu1minusu0,half);
	  else
	    RippleSubtract(a+half,a,absu1minusu0,half);

	  if ( v0GeV1 )
	    RippleSubtract(b+half,b,absv0minusv1,half);
	  else
	    RippleSubtract(b,b+half,absv0minusv1,half);

	  Multiply(absu1minusu0,half,absv0minusv1,half,middleTerm);
	  if ( u1GeU0!=v0GeV1 )
	    RippleSubtract(result+half,middleTerm,result+half,sizeA);
	  else
	    RippleAdd(middleTerm,result+half,result+half,sizeA);

	  return;
	}

      // sizeA != sizeB: split b into its leading sizeB-sizeA words and its
      // trailing sizeA words.  a*b_low lands in the low end of result, and
      // a*b_high (sizeB words) is accumulated B^sizeA higher.
      std::vector<unsigned int> temp1(sizeB);
      Multiply(a,sizeA,b+sizeB-sizeA,sizeA,result+sizeB-sizeA);
      Multiply(a,sizeA,b,sizeB-sizeA,&temp1[0]);
      RippleAdd(&temp1[0],result,result,sizeB);

      return;
    }

  // sizeA odd (>= 3): peel off a's leading word a0, leaving a' = a+1 with
  // an even number of words.  With b split into b_high (sizeB-sizeA+1 words)
  // and b_low (sizeA-1 words):
  //   a*b = a'*b_low + B^(sizeA-1) * a'*b_high + B^(sizeA-1) * a0*b

  Multiply(a+1,sizeA-1,b+sizeB-sizeA+1,sizeA-1,result+sizeB-sizeA+2);

  std::vector<unsigned int> temp2(sizeB);
  Multiply(a+1,sizeA-1,b,sizeB-sizeA+1,&temp2[0]);
  RippleAdd(result+1,&temp2[0],result+1,sizeB);

  BasicMultiply(b,*a,result,sizeB);
}

// Squares a (sizeA words) and ADDS the 2*sizeA-word result into result.
// Result needs to be zeroed for this function to work!  Words are stored
// most-significant first.
//
// Recursive squaring: the same Knuth split as Multiply, but with both
// operands equal:
//   a^2 = (B^2h + B^h) u1^2 - B^h (u1-u0)^2 + (B^h + 1) u0^2
// so only one absolute difference is needed and the middle term is always
// subtracted.
void BMath::Square(const unsigned int* a, const int sizeA, unsigned int* result)
{
  // An empty operand is zero and result is already zero.  Without this
  // check a zero size would recurse forever in the even-size branch.
  if ( sizeA < 1 )
    return;

  if ( sizeA==1 )
    {
      BasicMultiply(a,*a,result,1);
      return;
    }

  if ( !(sizeA % 2) )
    {
      if ( sizeA==2 )
	{
	  SquareDouble(result,a);
	  return;
	}

      const int half=sizeA/2;

      // One zero-initialised scratch area, released automatically even if a
      // recursive call throws:
      //   [0, n)    firstTerm   = u1^2
      //   [n, 2n)   middleTerm  = (u1-u0)^2
      //   [2n, 3n)  |u1-u0|     (only the first h words are used)
      //   [3n, 4n)  lastTerm    = u0^2
      std::vector<unsigned int> scratch(4*sizeA);
      unsigned int* firstTerm=&scratch[0];
      unsigned int* middleTerm=firstTerm+sizeA;
      unsigned int* absu1minusu0=middleTerm+sizeA;
      unsigned int* lastTerm=absu1minusu0+sizeA;

      // last term: (B^h + 1) u0^2
      Square(a+half,half,lastTerm);
      RippleAdd(lastTerm,result+sizeA,result+sizeA,sizeA);
      RippleAdd(lastTerm,result+half,result+half,sizeA);

      // first term: (B^2h + B^h) u1^2
      Square(a,half,firstTerm);
      RippleAdd(firstTerm,result+half,result+half,sizeA);
      RippleAdd(firstTerm,result,result,sizeA);

      // middle term: -B^h (u1-u0)^2
      if ( GreaterThanOrEqualTo(a,a+half,half) )
        RippleSubtract(a,a+half,absu1minusu0,half);
      else
        RippleSubtract(a+half,a,absu1minusu0,half);

      Square(absu1minusu0,half,middleTerm);
      RippleSubtract(result+half,middleTerm,result+half,sizeA);

      return;
    }

  // sizeA odd (>= 3): with a = a0*B^(n-1) + a' (a' = a+1),
  //   a^2 = a'^2 + B^(n-1) * a0*a + B^(n-1) * a0*a'
  // and a0*a already contains the a0^2 * B^(n-1) term.
  Square(a+1,sizeA-1,result+2);
  BasicMultiply(a,*a,result,sizeA);
  BasicMultiply(a+1,*a,result+1,sizeA-1);
}

// Divides dividend (dividendSize words) by divisor (divisorSize words),
// both stored most-significant word first.  This is a normalised
// schoolbook long division in the style of Knuth's Algorithm D.
//
// On return the caller owns two new[]-allocated arrays:
//   quotient  - dividendSize-divisorSize+1 words
//   remainder - dividendSize+1 words; after the main loop every leading
//               word is zero and the remainder occupies the trailing
//               divisorSize words.
//
// Preconditions (not checked): dividendSize >= divisorSize, and the leading
// word of divisor is non-zero, because it is passed to BSR to compute the
// normalisation shift.  The code assumes 32-bit words (31, 0xffffffff).
void BMath::Divide(const unsigned int* divisor, int divisorSize,
		   const unsigned int* dividend, int dividendSize,
		   unsigned int*& quotient, unsigned int*& remainder)
{
  // Shift amount that sets the high bit of the divisor's leading word.
  char normalizationFactor=31-char(BSR(*divisor));

  // remainder = concat(0, dividend), normalised
  remainder=new unsigned int[dividendSize+1];
  LMisc::MemCopy(dividend,remainder+1,dividendSize);
  *remainder=0;
  ShiftLeft(remainder,dividendSize+1,normalizationFactor);

  // padded divisor = concat(0, divisor), normalised, for use in comparisons.
  // It is held in a vector so it is released even if a later allocation throws.
  std::vector<unsigned int> paddedStore(divisorSize+1);
  unsigned int* paddedDivisor=&paddedStore[0];
  LMisc::MemCopy(divisor,paddedDivisor+1,divisorSize);
  *paddedDivisor=0;
  ShiftLeft(paddedDivisor,divisorSize+1,normalizationFactor);

  const int quotientSize=dividendSize-divisorSize+1;
  quotient=new unsigned int[quotientSize];

  // Each quotient digit is estimated from the window's top two words
  // divided by d1 = (leading divisor word + 1).  d1 wraps to 0 when the
  // leading word is 0xffffffff, i.e. the true divisor is 2^32, and
  // dividing by 2^32 just yields highWord.
  const unsigned int d1=*(paddedDivisor+1)+1;

  // scratch for q * divisor (divisorSize+1 words)
  std::vector<unsigned int> multStore(divisorSize+1);
  unsigned int* multBuf=&multStore[0];

  for (int i=0; i<quotientSize; i++)
    {
      const unsigned int highWord=*(remainder+i);
      const unsigned int lowWord=*(remainder+i+1);
      unsigned int q;
      unsigned int r;   // BasicDivide's remainder output; not needed here

      if ( !d1 )
	q=highWord;
      else if ( d1==highWord )
	q=0xffffffff;
      else
	BasicDivide(highWord,lowWord,d1,q,r);

      // window -= q * divisor
      LMisc::MemZero(multBuf,divisorSize+1);
      BasicMultiply(paddedDivisor+1,q,multBuf,divisorSize);
      RippleSubtract(remainder+i,multBuf,remainder+i,divisorSize+1);

      // The estimate can only be low (its denominator is >= the true
      // divisor's scale).  Correct it; more than two corrections would
      // mean an internal error.
      int corrections=0;
      while ( GreaterThanOrEqualTo(remainder+i,paddedDivisor,divisorSize+1) )
	{
	  if ( ++corrections>2 )
	    abort(); // oh, oh! A bug!!!
	  q++;
	  RippleSubtract(remainder+i,paddedDivisor,remainder+i,
			 divisorSize+1);
	}

      *(quotient+i)=q;
    }

  // Undo Normalization
  ShiftRight(remainder,dividendSize+1,normalizationFactor);
}
