/*
   A block-diagonal decomposition preconditioner for the SV codes.

   By default, the matrix is divided into BDD_FRAGS pieces of consecutive
   rows; a negative BDD_FRAGS means that about sqrt(rows) pieces are used.
   SpSubsetSorted is used to generate local copies of the diagonal blocks
   (we could use SpSubsetInPlace for this, but see below).

   An optional routine will set the number of blocks, set the 
   method on a block-by-block or all-blocks basis, and define the
   part of the vector to be solved in each block (and perhaps even
   set the solver context, giving the maximum flexibility in setting
   up the solver).
*/

/* Default number of diagonal blocks.  A negative value asks
   SViAllocBDDDomains to choose floor(sqrt(rows + 0.5)) blocks. */
#define BDD_FRAGS (-1)

#include "tools.h"
#include "solvers/svctx.h"
#include "solvers/svpriv.h"
#include <math.h>
#include "inline/spops.h" 
#include "inline/copy.h"  
#include "inline/setval.h"

/*ARGSUSED*/
void SViCreateBDD(ctx,mat)
SVctx *ctx;
SpMat *mat;
{
SVBDDctx *lctx;

lctx         = NEW(SVBDDctx);   CHKPTR(lctx);
lctx->nd     = BDD_FRAGS;
lctx->bb     = 0;
lctx->defMeth= SVLU;
lctx->defIt  = ITGMRES;
lctx->itmax  = -1;
ctx->method  = ITGMRES;
ctx->private = (void *) lctx;
ctx->is_iter = 1;

ctx->setup   = SViSetupBDD;
ctx->solve   = SViSolveBDD;
ctx->destroy = SViDestroyBDD;
}

void SViSetupBDD(ctx)
SVctx *ctx;
{
SVBDDctx *lctx = (SVBDDctx *) ctx->private;
int      i, j, ioff, *idx, nv, dnv, nlv, nvl;
double   t1;

t1  = SYGetCPUTime();
ctx->itctx = ITCreate( ctx->method );     CHKERR(1);
DVSetDefaultFunctions( ctx->itctx->vc );  CHKERR(1);
ctx->itctx->amult = SViMult;
ctx->itctx->tamult= SViMultTrans;
ctx->itctx->binv  = SViApplyBDD;
/* put dummy values in vec_sol and vec_rhs */
ctx->itctx->vec_sol = (void *) 1;
ctx->itctx->vec_rhs = (void *) 1;
ctx->itctx->usr_monitor = 0;

if (!lctx->bb) 
    SViAllocBDDDomains( ctx );

/* Compute a default partition of the matrix */
ioff = 0;
dnv  = (ctx->mat->rows + lctx->nd - 1) / lctx->nd;
nlv  = 0;

/* For each block, generate any data that is not provided. */
TRPUSH(SVTRID+1);
for (i=0; i<lctx->nd; i++) {
    nv = dnv;
    if (!lctx->bb[i].idx) {
	/* Form default decomposition */
	if (ioff + nv > ctx->mat->rows) 
	    nv = ctx->mat->rows - ioff;
	idx = lctx->bb[i].idx   = (int *)MALLOC( nv * sizeof(int) );  
	lctx->bb[i].nv          = nv;
	CHKPTR(idx);
	for (j=0; j<nv; j++) idx[j] = j + ioff;
	lctx->bb[i].iscontig    = 1;
	}
    if (!lctx->bb[i].block) {
	idx = lctx->bb[i].idx;
	nvl = lctx->bb[i].nv;
	lctx->bb[i].block = SpSubsetSorted( ctx->mat, nvl, nvl, idx, idx );
	CHKPTR(lctx->bb[i].block)
	}
    if (!lctx->bb[i].svc) {
	lctx->bb[i].svc   = SVCreate( lctx->bb[i].block, lctx->defMeth );
	CHKPTR( lctx->bb[i].svc );
	SVSetAccelerator( lctx->bb[i].svc, lctx->defIt );
	SVSetUp( lctx->bb[i].svc );
	if (lctx->itmax > 0) {
	    SVSetIts( lctx->bb[i].svc, lctx->itmax );
	    /* Set the relative tolerance small enough to force itmax steps */
	    SVSetRelativeTol( lctx->bb[i].svc, 1.0e-15 );
	    }
	ctx->flops += lctx->bb[i].svc->flops;
	lctx->bb[i].svc->flops = 0;
	}
    if (lctx->bb[i].block->rows > nlv) 
	nlv = lctx->bb[i].block->rows;
    ioff += nv;
    }
TRPOP;
TRPUSH(SVTRID+2);
lctx->w1 = (double *)MALLOC( nlv * 2 * sizeof(double) );    CHKPTR(lctx->w1) ;
lctx->w2 = lctx->w1 + nlv;

ctx->nzorig = SpNz(ctx->mat);
ctx->setupcalled = 1;
ctx->t_setup += SYGetCPUTime() - t1;
TRPOP;
}

int SViSolveBDD( ctx, b, x )
SVctx  *ctx;
double *b, *x;
{
SVBDDctx *lctx = (SVBDDctx *) ctx->private;
int      its;
int      i, nd;
double   t1;

if (!ctx->setupcalled) {(*ctx->setup)( ctx );  CHKERRV(1,-1);}
if (!ctx->solvecalled) {
    t1 = SYGetCPUTime();
    ITSetUp(ctx->itctx, (void *)ctx ); CHKERRV(1,-1);
    ctx->t_setup += SYGetCPUTime() - t1;
    }
ctx->solvecalled = 1;

t1 = SYGetCPUTime();
SViManageInitialGuess( ctx, x );
ctx->itctx->vec_rhs = (void *)b;
ctx->itctx->vec_sol = (void *)x;
its = ITSolve( ctx->itctx, (void *)ctx );
SVGetITFlops(ctx,2*ctx->nzorig,0);
/* Get the flops from the subsidiary solvers */
nd = lctx->nd;
for (i=0; i<nd; i++) {
    ctx->flops += lctx->bb[i].svc->flops;
    lctx->bb[i].svc->flops = 0;
    }
ctx->its     = its;
ctx->t_solve += SYGetCPUTime() - t1;
return its;
}

void SViDestroyBDD( ctx )
SVctx *ctx;
{
SVBDDctx *lctx = (SVBDDctx *) ctx->private;
int i, nd;

nd = lctx->nd;

for (i=0; i<nd; i++) {
    SVDestroy( lctx->bb[i].svc );
    SpDestroy( lctx->bb[i].block );
    if (lctx->bb[i].idx)
	FREE( lctx->bb[i].idx );
    }
FREE( lctx->bb );
FREE( lctx->w1 );
FREE( lctx );
VEDestroy( ctx->itctx->vc );
ITDestroy( ctx->itctx, ctx );
FREE( ctx );
}

/*
   SViApplyBDD - Apply the block-diagonal preconditioner: y = B^{-1} x.

   Each domain's diagonal block is solved independently with its own
   solver.  For a contiguous domain, the solve works directly on the
   segment of x and y that starts at idx[0].  Otherwise the entries are
   gathered into the work vector w1, solved into w2, and scattered back
   into y.  Empty domains are skipped.  Each call increments ctx->nbinv.
 */
void SViApplyBDD( ctx, x, y )
SVctx  *ctx;
double *x, *y;
{
SVBDDctx *lctx = (SVBDDctx *) ctx->private;
int      i, nd, *idx, nv;
double   *bb, *xx;

ctx->nbinv++;

bb = lctx->w1;
xx = lctx->w2;
nd = lctx->nd;
for (i=0; i<nd; i++) {
    idx = lctx->bb[i].idx;
    nv  = lctx->bb[i].nv;
    /* An empty domain has nothing to solve, and idx[0] would be out of range */
    if (nv < 1) continue;
    if (lctx->bb[i].iscontig) {
	/* Block-diagonal special case: pass the contiguous segments directly */
	SVSolve( lctx->bb[i].svc, x + idx[0], y + idx[0] );
	}
    else {
	GATHER(bb,idx,x,nv);
	SVSolve( lctx->bb[i].svc, bb, xx );
	SCATTER(xx,idx,y,nv);
	}
    }
}

/*
    Allocate the domains 
 */
/*
    Allocate the domain table lctx->bb, if it is not already allocated.

    If lctx->nd is not positive, the number of domains is set to
    floor(sqrt(rows + 0.5)), with a minimum of 1.  (A zero count would
    otherwise make SViSetupBDD divide by zero.)  Every entry starts empty:
    no block, no indices, no solver, not contiguous, and lmeth = SVLU.
 */
void SViAllocBDDDomains( ctx )
SVctx *ctx;
{
SVBDDctx *lctx = (SVBDDctx *) ctx->private;
int      i, nd;
TRPUSH(SVTRID+3);
if (!lctx->bb) {
    nd           = lctx->nd;
    if (nd < 1) {
	/* the +0.5 presumably protects perfect squares from rounding down */
	nd = (int) sqrt( ctx->mat->rows + 0.5 );
	if (nd < 1) nd = 1;
	lctx->nd = nd;
	}
    lctx->bb     = (SVBDDBlock *)MALLOC( nd * sizeof(SVBDDBlock) ); 
    CHKPTR(lctx->bb);
    for (i=0; i<nd; i++) {
	lctx->bb[i].block    = 0;
	lctx->bb[i].nv       = 0;
	lctx->bb[i].idx      = 0;
	lctx->bb[i].iscontig = 0;
	lctx->bb[i].svc      = 0;
	lctx->bb[i].lmeth    = SVLU;
	}
    }
TRPOP;
}

/*
   Set the indices (idx) for domain i.  There are nv elements in idx.
   A COPY is made
 */
/*
   Set the indices (idx) for domain i, where 0 <= i < number of domains.
   There are nv (>= 1) elements in idx.  A COPY is made.  The domain table
   is allocated on first use.

   If the domain already has an index set and no block has been extracted
   yet, the old set is freed.  Once a block exists, the old set is left
   alone, because it is unverified whether SpSubsetSorted keeps a
   reference to it.
 */
void SViSetBDDDecomp( ctx, i, idx, nv )
SVctx *ctx;
int   i, *idx, nv;
{
SVBDDctx *lctx = (SVBDDctx *) ctx->private;
int      nd, *lidx;

if (ctx->type != SVBDD) return;
if (!lctx->bb) {
    SViAllocBDDDomains( ctx );
    CHKERR(1);
    }

nd = lctx->nd;
if (i < 0 || i >= nd) { 
    SETERRC(1,"Attempt to set indices for out-of-range domain"); return; }
if (nv < 1) {
    SETERRC(1,"A BDD domain must contain at least one index"); return; }

/* Pushed only after the early returns above, so the trace stack stays balanced */
TRPUSH(SVTRID+4);
if (lctx->bb[i].idx && !lctx->bb[i].block) {
    FREE( lctx->bb[i].idx );
    lctx->bb[i].idx = 0;
    lctx->bb[i].nv  = 0;
    }
lidx = lctx->bb[i].idx = (int *)MALLOC( nv * sizeof(int) );   CHKPTR(lidx);
lctx->bb[i].nv = nv;
MEMCPY(lidx,idx,nv*sizeof(int));
TRPOP;
}

/*
   Set the Method for domain i. 
 */
/*
   Set the method for domain i, where 0 <= i < number of domains.  The
   domain table is allocated on first use.

   NOTE(review): SViSetupBDD creates every block solver with lctx->defMeth
   and does not read bb[i].lmeth.  No code in this file reads the value
   stored here.
 */
void SViSetBDDMethodDecomp( ctx, i, v )
SVctx    *ctx;
int      i;
SVMETHOD v;
{
SVBDDctx *lctx = (SVBDDctx *) ctx->private;
int      nd;

if (ctx->type != SVBDD) return;
if (!lctx->bb) {
    SViAllocBDDDomains( ctx );
    CHKERR(1);
    }

nd = lctx->nd;
if (i < 0 || i >= nd) { 
    SETERRC(1,"Attempt to set method for out-of-range domain"); return; }

lctx->bb[i].lmeth = v;
}

/* Solver context for a particular subdomain (with optional matrix) */
void SViSetBDDSolverDecomp( ctx, i, v, mat )
SVctx    *ctx;
int      i;
ITMETHOD v;
SpMat    *mat;
{
SVBDDctx *lctx = (SVBDDctx *) ctx->private;
int      nd;

if (ctx->type != SVBDD) return;
if (!lctx->bb) {
    SViAllocBDDDomains( ctx );
    CHKERR(1);
    }

nd = lctx->nd;
if (i >= nd) { 
    SETERRC(1,"Attempt to set solver for out-of-range domain"); return; }

/* lctx->bb[i].svc = SVCreate( matrix, v ); */
}

/*@
    SVSetBDDRegularDomains2d - Set the domains for a n1 x n2 regular mesh

    Input parameters:
.   ctx   - solver context
.   n1,n2 - mesh is n1 x n2
.   nc    - there are nc components per mesh point.  Components are numbered
            first

    Notes:
    This routine is provided as (1) an example of a routine to set 
    user-defined domains and (2) a service routine for a relatively
    common case.

    This uses a square decomposition.  It may change the number of subdomains
    if that number is not a square.
@*/
void SVSetBDDRegularDomains2d( ctx, n1, n2, nc )
SVctx *ctx;
int   n1, n2, nc;
{
SVOSMctx *lctx = (SVOSMctx *) ctx->private;
int      k, id, nx1, *idx;
int      nxi, nyi;         /* Number of points along each side of a domain */
int      sx, ex, sy, ey, ii, jj, cnt;
 
if (ctx->type != SVBDD) return;
lctx->nd = SViGet2dDomain( ctx->size, lctx->nd, n1, n2, nc, &nxi, &nyi, &nx1 );
CHKERR(1);

/* Allocate storage to hold the local indices */
idx = (int *)MALLOC( (nxi + 2) * (nyi + 2) * nc * sizeof(int) ); CHKPTR(idx);
/* Set the indices.  Note that we make the domains overlap along their
   common borders */
sy = 0;
ey = nyi;
id = 0;
for (jj = 0; jj < nx1; jj++) {
    sx = 0; 
    ex = nxi;
    for (ii = 0; ii < nx1; ii++) {
	cnt = SViNumber2dDomain( sy, ey, sx, ex, nc, n1, k, idx );
	SViSetBDDDecomp( ctx, id, idx, cnt );

	id++;
	sx = ex;
	ex += nxi;
	if (ex >= n1 || ii == nx1 - 2) ex = n1 - 1;
	}
    sy = ey;
    ey += nyi;
    if (ey >= n2 || jj == nx1 - 2) ey = n2 - 1;
    }
FREE( idx );
}
