/* odei.c */

/* This file is a part of RLaB ("Our"-LaB)
   Copyright (C) 1994  Ian R. Searle

   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2 of the License, or
   (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program; if not, write to the Free Software
   Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.

   See the file ./COPYING
   ********************************************************************** */

#include "rlab.h"
#include "code.h"
#include "symbol.h"
#include "mem.h"
#include "list.h"
#include "btree.h"
#include "bltin.h"
#include "scop1.h"
#include "matop1.h"
#include "matop2.h"
#include "r_string.h"
#include "util.h"
#include "mathl.h"
#include "function.h"
#include "lp.h"
#include "odei.h"

#include <errno.h>
#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <string.h>

/* Absolute value of X.  X is evaluated more than once. */
#define rabs(x) ((x) >= 0 ? (x) : -(x))

/*
 * Destroy TARG, the result of convert_to_matrix (ARG), when that
 * conversion produced a different entity than ARG itself.
 * Wrapped in do/while(0) so that "TARG_DESTROY (a, b);" is exactly one
 * statement and is safe inside an unbraced if/else.
 */
#define TARG_DESTROY(arg, targ)                 \
  do                                            \
  {                                             \
    if ((targ).u.ent != (arg).u.ent)            \
      remove_tmp_destroy ((targ).u.ent);        \
  } while (0)

int rks_func _PROTO ((double *t, double *y, double *yp));
static double epsilon _PROTO ((void));

/*
 * State shared between odei() and the rks_func() callback.  The
 * callback signature (t, y, yp) carries no user-data pointer, so the
 * rhs function name, the equation count and the "t"/"y" argument
 * entities travel through these file-scope variables.  odei() assigns
 * them unconditionally, so it is not reentrant: a nested ode call made
 * from inside the rhs function would overwrite them.
 */
static int neq;                 /* number of equations (length of y) */
static Scalar *stime;           /* scalar bound to the "t" argument */
static Matrix *my;              /* matrix bound to the "y" argument; rks_func
                                   points its data at the integrator's y */
static ListNode *tent, *my_ent; /* list nodes holding "t" and "y" */
static char *fname;             /* name of the user/builtin rhs function */
static Datum rks_args[3];       /* argument list for call_rlab_script */

/* **************************************************************
 * Builtin interface to RKSUITE
 * ************************************************************** */
void
odei (return_ptr, n_args, d_arg)
     VPTR *return_ptr;
     int n_args;
     Datum *d_arg;
{
  double dtout, eps, t, t0, tend, tstart;
  int i, j, lenwrk, nstep;
  Datum arg1, arg2, targ2, arg3, targ3, arg4, targ4;
  Datum arg5, targ5, arg6, targ6, arg7, targ7;
  F_DOUBLE abserr, relerr;
  F_INT iflag, iwork[5];
  ListNode *tmp1, *tmp2, *tmp5;
  Matrix *out, *work, *y, *ystart;

  eps = epsilon ();

  /*
   * Check arguments.
   */

  if (n_args < 4)
    error_1 ("ode: requires at least 4 arguments", 0);

  /* Get function ptr */
  arg1 = get_bltin_arg ("ode", d_arg, 1, 0);
  if (arg1.type != ENTITY)
    error_1 ("ode: 1st argument must be a user/bltin function", 0);
  if (!(e_type (arg1.u.ent) == U_FUNCTION || e_type (arg1.u.ent) == BLTIN))
    error_1 ("ode: 1st argument must be a user/bltin function", 0);

  fname = e_name (arg1.u.ent);

  /* Get tstart */
  arg2 = get_bltin_arg ("ode", d_arg, 2, NUM);
  targ2 = convert_to_matrix (arg2);
  tstart = (double) MAT (e_data (targ2.u.ent), 1, 1);

  /* Get tend */
  arg3 = get_bltin_arg ("ode", d_arg, 3, NUM);
  targ3 = convert_to_matrix (arg3);
  tend = (double) MAT (e_data (targ3.u.ent), 1, 1);
  
  if (tend == tstart)
    error_1 ("ode: tstart == tend", 0);

  /* Get ystart */
  arg5 = get_bltin_arg ("ode", d_arg, 4, NUM);
  targ5 = convert_to_matrix (arg5);
  ystart = (Matrix *) e_data (targ5.u.ent);
  if (MTYPE (ystart) != REAL)
    error_1 ("ode: YSTART must be REAL", 0);

  /* Extract neq from ystart */
  if (MNR (ystart) != 1 && MNC (ystart) != 1)
    error_1 ("ode: YSTART must be a row or column vector", 0);

  neq = MNR (ystart) * MNC (ystart);
  tmp2 = install_tmp (MATRIX, y = matrix_Create (neq, 1), 
		      matrix_Destroy);

  if (n_args > 4)
  {
    /* Get dtout */
    arg4 = get_bltin_arg ("ode", d_arg, 5, NUM);
    if (arg4.type == ENTITY && e_type (arg4.u.ent) == UNDEF)
    {
      /* Default value */
      dtout = (tend - tstart)/100;
    }
    else
    {
      targ4 = convert_to_matrix (arg4);
      dtout = (double) MAT (e_data (targ4.u.ent), 1, 1);
      TARG_DESTROY (arg4, targ4);
    }
  }
  else
  {
    /* Default value */
    dtout = (tend - tstart)/100;
  }

  if (dtout == 0)
    error_1 ("ode: dout must be non-zero", 0);
    
  if (n_args > 5)
  {
    /* Get relerr */
    arg6 = get_bltin_arg ("ode", d_arg, 6, NUM);
    if (arg6.type == ENTITY && e_type (arg6.u.ent) == UNDEF)
    {
      /* Default value */
      relerr = (F_DOUBLE) 1.e-6;
    }
    else
    {
      targ6 = convert_to_matrix (arg6);
      relerr = (F_DOUBLE) MAT (e_data (targ6.u.ent), 1, 1);
      TARG_DESTROY (arg6, targ6);
    }
  }
  else
  {
    /* Default */
    relerr = (F_DOUBLE) 1.e-6;
  }

  if (n_args > 6)
  {
    /* Get tol */
    arg7 = get_bltin_arg ("ode", d_arg, 7, 0);
    
    if (arg7.type == ENTITY && e_type (arg7.u.ent) == UNDEF)
    {
      /* Default value */
      abserr = (F_DOUBLE) 1.0e-6;
    }
    else
    {
      targ7 = convert_to_matrix (arg7);
      abserr = (F_DOUBLE) MAT (e_data (targ7.u.ent), 1, 1);
      TARG_DESTROY (arg7, targ7);
    }
  }
  else
  {
    /* Default value */
    abserr = (F_DOUBLE) 1.0e-6;
  }

  /*
   * Done with argument processing.
   * Initialize some things...
   */
  
  lenwrk = 100 + 21*neq;
  tmp1 = install_tmp (MATRIX, work = matrix_Create (lenwrk, 1),
		      matrix_Destroy);
  iflag = 1;

  /*
   * Call integrator repeatedley.
   */

  nstep = (rabs (tend - tstart)/dtout + .5);

  /* Set up output array */
  tmp5 = install_tmp (MATRIX, out = matrix_Create (nstep+1, neq+1),
		      matrix_Destroy);

  /*
   * Set up ENTITIES for user-function.
   */

  tent = listNode_Create ();
  listNode_AttachData (tent, SCALAR, stime = scalar_Create (0.0),
		       scalar_Destroy);
  listNode_SetKey (tent, cpstr ("t"));

  my_ent = listNode_Create ();
  listNode_AttachData (my_ent, MATRIX, my = matrix_Create (0, 0),
		       matrix_Destroy);
  listNode_SetKey (my_ent, cpstr ("y"));

  /*
   * Set these manually so that we can just 
   * copy the pointer later, and not duplicate
   * the space.
   */

  scalar_SetName (stime, cpstr ("t"));
  matrix_SetName (my, cpstr ("y"));
  my->nrow = neq;
  my->ncol = 1;

  rks_args[0].u.ent = tent;
  rks_args[0].type = ENTITY;
  rks_args[1].u.ent = my_ent;
  rks_args[1].type = ENTITY;

  /* Save initial conditions and setup y[] */
  MAT (out, 1, 1) = tstart;
  for (j = 2; j <= neq+1; j++)
  {
    MAT (out, 1, j) = MATrv1 (ystart, j-1);
    MATrv1 (y, j-1) = MATrv1 (ystart, j-1);
  }

  /* Now step through output points */
  t0 = tstart;
  for (i = 1; i <= nstep; i++)
  {
    t = tstart + i * dtout;
    if ( i == nstep ) t = tend;

    ODE (rks_func, &neq, MDPTRr (y), &t0, &t,
	 &relerr, &abserr, &iflag, MDPTRr (work), iwork);
    
    /* Check for errors */
    if (iflag > 3)
    {
      /* Check for different types of failures (later) */
      printf ("ode: iflag = %i\n", (int) iflag);
    }

    /* Reset the time */
    t0 = t;

    /* Save the output */
    MAT (out, i+1, 1) = t;
    for (j = 2; j <= neq+1; j++)
      MAT (out, i+1, j) = MATrv1 (y, j-1);
  }

  /* Clean Up */
  remove_tmp_destroy (tmp1);
  remove_tmp_destroy (tmp2);
  remove_tmp (tmp5);

  /* Clean up time, my */
  listNode_Destroy (tent);
  my->nrow = 0;
  my->ncol = 0;
  my->val.mr = 0;
  listNode_Destroy (my_ent);
  
  TARG_DESTROY (arg2, targ2);
  TARG_DESTROY (arg3, targ3);
  TARG_DESTROY (arg5, targ5);

  *return_ptr = (VPTR) out;
}

/*
 * The interface to the user-specified function.
 */

/*
 * Right-hand-side callback passed to the integrator through ODE().
 *
 * Behaviour:
 *   - Publishes the current time *t and state y to the "t" and "y"
 *     argument entities (file-scope stime and my, set up by odei).
 *     my's data pointer is aimed at y, so nothing is copied.
 *   - Calls the user/builtin function named by fname with those two
 *     arguments.
 *   - Copies the returned derivative into yp[0 .. neq-1].
 *
 * Accepted return values:
 *   - a REAL matrix with neq elements;
 *   - a scalar, only when neq == 1.
 * Anything else raises error_1.
 *
 * An unnamed (temporary) return value is destroyed BEFORE any error is
 * raised, so a badly shaped result does not leak.
 *
 * Always returns 1.
 */
int
rks_func (t, y, yp)
     double *t, *y, *yp;
{
  int i, rtype;
  char *errmsg = 0;
  VPTR retval;

  /*
   * Put t, y into rks_args.
   */

  SVALr (stime) = *t;
  my->val.mr = y;

  /*
   * Call user/builtin function.
   */

  retval = call_rlab_script (fname, rks_args, 2);

  /*
   * The leading int of the returned object is read as its type tag.
   * This assumes every RLaB data object starts with a type field --
   * TODO confirm.
   */
  rtype = (int) *((int *) retval);

  if (rtype == MATRIX)
  {
    if (MNR (retval) * MNC (retval) != neq)
      errmsg = "ode: incorrectly dimensioned derivitive";
    else if (MTYPE (retval) != REAL)
      errmsg = "ode: rhs function must return REAL matrix";
    else
      for (i = 0; i < neq; i++)
        yp[i] = MATrv (retval, i);

    /* Only unnamed temporaries belong to us; named ones are variables */
    if (matrix_GetName (retval) == 0)
      matrix_Destroy ((Matrix *) retval);
  }
  else if (rtype == SCALAR)
  {
    if (neq != 1)
      errmsg = "ode: incorrectly dimensioned derivitive";
    else
      yp[0] = SVALr (retval);

    if (scalar_GetName (retval) == 0)
      scalar_Destroy ((Scalar *) retval);
  }
  else
    errmsg = "ode: derivitive function must return a NUMERIC entity";

  if (errmsg != 0)
    error_1 (errmsg, 0);

  return (1);
}

static double
epsilon ()
{
  /*
   * Start from 1.0 and keep halving the trial increment while it still
   * changes 1.0 when added to it.  Return the first increment that no
   * longer does.  Under IEEE double precision with round-to-nearest
   * this is half of the usual machine epsilon.
   */
  double delta;

  for (delta = 1.0; (1.0 + delta) != 1.0; delta *= 0.5)
    ;

  return delta;
}
