//
// LiDIA - a library for computational number theory
//   Copyright (c) 1994, 1995, 1996 by the LiDIA Group
//
// File        : udigit_appl.c
// Author      : Markus Maurer (MM)
// Last change : MM, Oct 07 1996, initial version
//


#include <LiDIA/udigit.h>
#include <assert.h>

void add_special_values ()
 {
   udigit max_u;
   udigit a, b, c;
   udigit carry;

   max_u = max_udigit();

   a = 0; b = 0;
   carry = udigit_add (c, a, b);
   assert (c == 0);
   assert (carry == 0);   

   a = max_u; b = 1;
   carry = udigit_add (c, a, b);
   assert (c == 0);
   assert (carry == 1);

   a = 1; b = max_u;
   carry = udigit_add (c, a, b);
   assert (c == 0);
   assert (carry == 1);

   if (max_u >= 2)
    {
      a = max_u; b = 2;
      carry = udigit_add (c, a, b);
      assert (c == 1);
      assert (carry == 1);

      a = 2; b = max_u;
      carry = udigit_add (c, a, b);
      assert (c == 1);
      assert (carry == 1);
    }
 }


void subtract_special_values ()
 {
   udigit max_u;
   udigit a, b, c;
   udigit carry;

   max_u = max_udigit();

   // carry = 0

   a = 0; b = 0;
   carry = udigit_subtract (c, a, b);
   assert (c == 0);
   assert (carry == 0);   

   a = max_u; b = max_u;
   carry = udigit_subtract (c, a, b);
   assert (c == 0);
   assert (carry == 0);

   a = 0; b = max_u;
   carry = udigit_subtract (c, a, b);
   assert (c == 1);
   assert (carry == 1);

   if (max_u >= 2)
    {
      a = 1; b = max_u;
      carry = udigit_subtract (c, a, b);
      assert (c == 2);
      assert (carry == 1);
    }

   a = 0; b = 1;
   carry = udigit_subtract (c, a, b);
   assert (c == max_u);
   assert (carry == 1);

   a = 0; b = 0;
   carry = udigit_subtract (c, a, b);
   assert (c == 0);
   assert (carry == 0);   


   // carry = 1

   a = max_u; b = max_u;
   carry = udigit_subtract (c, a, b, 1);
   assert (c == max_u);
   assert (carry == 1);

   a = 0; b = max_u;
   carry = udigit_subtract (c, a, b, 1);
   assert (c == 0);
   assert (carry == 1);

   a = 1; b = max_u;
   carry = udigit_subtract (c, a, b, 1);
   assert (c == 1);
   assert (carry == 1);

   a = 0; b = 1;
   carry = udigit_subtract (c, a, b, 1);
   assert (c == (max_u-1));
   assert (carry == 1);
 }


//
// Boundary checks for udigit_multiply(c, a, b).  c receives the low digit of
// the product and the return value is checked as the high (overflow) digit.
//
// With digit base B = max_udigit()+1 (the wrap-around modulus that
// add_special_values relies on):
//   max_u * max_u = (B-2)*B + 1   ->  low 1,        high max_u-1
//   max_u * 2     =     1*B + B-2 ->  low max_u-1,  high 1
// The small-operand cases never produce a nonzero high digit, so the last
// cases are what actually exercise the overflow path.
//
void multiply_special_values ()
 {
   udigit max_u;
   udigit a, b, c;
   udigit carry;

   max_u = max_udigit();

   a = 0; b = 0;
   carry = udigit_multiply (c, a, b);
   assert (c == 0);
   assert (carry == 0);

   a = 1; b = 0;
   carry = udigit_multiply (c, a, b);
   assert (c == 0);
   assert (carry == 0);

   a = 0; b = 1;
   carry = udigit_multiply (c, a, b);
   assert (c == 0);
   assert (carry == 0);

   a = 1; b = 1;
   carry = udigit_multiply (c, a, b);
   assert (c == 1);
   assert (carry == 0);

   a = max_u; b = 1;
   carry = udigit_multiply (c, a, b);
   assert (c == max_u);
   assert (carry == 0);

   // largest possible product: high digit is max_u-1, low digit is 1
   a = max_u; b = max_u;
   carry = udigit_multiply (c, a, b);
   assert (c == 1);
   assert (carry == (max_u-1));

   // product just past one digit, both operand orders
   if (max_u >= 2)
    {
      a = max_u; b = 2;
      carry = udigit_multiply (c, a, b);
      assert (c == (max_u-1));
      assert (carry == 1);

      a = 2; b = max_u;
      carry = udigit_multiply (c, a, b);
      assert (c == (max_u-1));
      assert (carry == 1);
    }
 }


void add_mod_special_values ()
 {
   udigit max_m = max_udigit_modulus();
   udigit a, b, c;

   a = max_m -1;
   b = 0;
   c = udigit_add_mod(a,b,max_m);
   assert (c == (max_m -1));

   a = max_m -1;
   b = 1;
   c = udigit_add_mod(a,b,max_m);
   assert (c == 0);

   if (max_m > 2)
    {
      a = max_m -1;
      b = 2;
      c = udigit_add_mod(a,b,max_m);
      assert (c == 1);
      
      a = max_m -1;
      b = max_m -1;
      c = udigit_add_mod(a,b,max_m);
      assert (c == (max_m-2));
    }
 }


void subtract_mod_special_values ()
 {
   udigit max_m = max_udigit_modulus();
   udigit a, b, c;

   a = max_m -1;
   b = 0;
   c = udigit_subtract_mod(a,b,max_m);
   assert (c == a);

   a = 0;
   b = max_m -1;
   c = udigit_subtract_mod(a,b,max_m);
   assert (c == 1);

   a = 0;
   b = 1;
   c = udigit_subtract_mod(a,b,max_m);
   assert (c == (max_m -1));

   if (max_m > 2)
    {
      a = 2;
      b = 1;
      c = udigit_subtract_mod(a,b,max_m);
      assert (c == 1);
      
      a = 1;
      b = 2;
      c = udigit_subtract_mod(a,b,max_m);
      assert (c == (max_m-1));
    }
 }



void multiply_mod_special_values ()
 {
   udigit max_m = max_udigit_modulus();
   udigit a, b, c;

   a = 0;
   b = 0;
   c = udigit_multiply_mod(a,b,max_m);
   assert (c == 0);

   a = 0;
   b = max_m -1;
   c = udigit_multiply_mod(a,b,max_m);
   assert (c == 0);

   a = 1;
   b = max_m -1;
   c = udigit_multiply_mod(a,b,max_m);
   assert (c == b);

   a = 1;
   b = 1;
   c = udigit_multiply_mod(a,b,max_m);
   assert (c == b);

   a = max_m -1;
   b = max_m -1;
   c = udigit_multiply_mod(a,b,max_m);
   assert (c == 1);
 }


//
// taken from bigint_appl.c, identity_test()
//

//
// Checks algebraic identities of the modular udigit operations for one pair
// of operands a, b and modulus m:
//   - negation is an involution,
//   - + and * are commutative,
//   - subtraction agrees with adding the negation,
//   - a*(-b) == -(a*b) and (a-b) == -(b-a),
//   - adding and subtracting b are inverse operations,
//   - (a+b)^2 == a^2 + 2ab + b^2.
// NOTE(review): main() only passes a, b < m; presumably the udigit_*_mod
// functions expect reduced operands -- confirm against udigit.h.
//
void mod_identity_test_3param (udigit a, udigit b, udigit m)
 {
   udigit lhs, rhs;
   udigit a_squared, b_squared;

   // -(-a) == a
   lhs = udigit_negate_mod (udigit_negate_mod (a, m), m);
   assert (lhs == a);

   // a + b == b + a
   lhs = udigit_add_mod (a, b, m);
   rhs = udigit_add_mod (b, a, m);
   assert (lhs == rhs);

   // a + (-b) == a - b
   lhs = udigit_add_mod (a, udigit_negate_mod (b, m), m);
   rhs = udigit_subtract_mod (a, b, m);
   assert (lhs == rhs);

   // a * b == b * a
   lhs = udigit_multiply_mod (a, b, m);
   rhs = udigit_multiply_mod (b, a, m);
   assert (lhs == rhs);

   // a * (-b) == -(a * b)
   lhs = udigit_multiply_mod (a, udigit_negate_mod (b, m), m);
   rhs = udigit_negate_mod (udigit_multiply_mod (a, b, m), m);
   assert (lhs == rhs);

   // a - b == -(b - a)
   lhs = udigit_subtract_mod (a, b, m);
   rhs = udigit_negate_mod (udigit_subtract_mod (b, a, m), m);
   assert (lhs == rhs);

   // (a - b) + b == a
   lhs = udigit_add_mod (udigit_subtract_mod (a, b, m), b, m);
   assert (lhs == a);

   // (a + b) - b == a
   lhs = udigit_subtract_mod (udigit_add_mod (a, b, m), b, m);
   assert (lhs == a);

   // (a + b)^2 == a^2 + 2ab + b^2  (mod m)
   lhs = udigit_add_mod (a, b, m);
   lhs = udigit_multiply_mod (lhs, lhs, m);

   rhs = udigit_multiply_mod (a, b, m);
   rhs = udigit_add_mod (rhs, rhs, m);          // 2ab
   a_squared = udigit_multiply_mod (a, a, m);
   b_squared = udigit_multiply_mod (b, b, m);

   rhs = udigit_add_mod (rhs, a_squared, m);
   rhs = udigit_add_mod (rhs, b_squared, m);

   assert (lhs == rhs);
 }


//
// Checks the three-operand ring laws of the modular udigit operations for
// operands a, b, c and modulus m: associativity of + and *, and
// distributivity of * over +.
// NOTE(review): main() only passes a, b, c < m; presumably reduced operands
// are required -- confirm against udigit.h.
//
void mod_identity_test_4param (udigit a, udigit b, udigit c, udigit m)
{
   udigit lhs, rhs;
   udigit ab, ac;

   // a + (b + c) == (a + b) + c
   lhs = udigit_add_mod (a, udigit_add_mod (b, c, m), m);
   rhs = udigit_add_mod (udigit_add_mod (a, b, m), c, m);
   assert (lhs == rhs);

   // a * (b * c) == (a * b) * c
   lhs = udigit_multiply_mod (a, udigit_multiply_mod (b, c, m), m);
   rhs = udigit_multiply_mod (udigit_multiply_mod (a, b, m), c, m);
   assert (lhs == rhs);

   // a * (b + c) == (a * b) + (a * c)
   lhs = udigit_multiply_mod (a, udigit_add_mod (b, c, m), m);
   ab  = udigit_multiply_mod (a, b, m);
   ac  = udigit_multiply_mod (a, c, m);
   rhs = udigit_add_mod (ab, ac, m);
   assert (lhs == rhs);
}


void main ()
 {
   udigit max_u;
   udigit max_m;
   udigit a,b,c;
   udigit delta;

   max_u = max_udigit();

   if (max_u == 0)
    {
      cout << "max_udigit() == 0 -> No tests started !\n";
      cout.flush();
    }
   else
    {
      add_special_values();
      subtract_special_values();
      multiply_special_values();
      add_mod_special_values();
      subtract_mod_special_values();
      multiply_mod_special_values();

      max_m = max_udigit_modulus();
      delta = max_u / 5;

      if (delta == 0)
         delta = 1;

      for (a = 0; a < max_m; a+=delta)
        for (b = 0; b < max_m; b+=delta)
         {
           mod_identity_test_3param(a,b,max_m);

           for (c = 0; c < max_m; c+=delta)
              mod_identity_test_4param(a,b,c,max_m);
	 }

       cout << "Tests passed.\n";
       cout.flush();
     }
 }


