#include <rand.h>
#include "libpgp5.h"

/*--------------------------------------------------*/
/* encrypt the conventional key using a DH (pkalg 0x10) or RSA (pkalg 1/2) public key */

/*
 * Build a public-key-encrypted key packet in dbuf.
 *
 * The packet starts with the tag byte 0xc1, followed by a one- or two-byte
 * length, a version byte (3), the 8-byte key id (big-endian), the algorithm
 * byte and the algorithm-specific MPIs holding the encrypted convkey.
 *
 *   keyid    id of the recipient key; looked up via getkey5()
 *   dbuf     output buffer; must be large enough for the whole packet
 *            (header + key-size dependent MPIs) -- the size is not checked here
 *   convkey  the ckeylen bytes to encrypt
 *   ckeylen  number of bytes in convkey
 *   pkalg    0x10 for DH/ElGamal, 1 or 2 for RSA (unless built with NO_RSA)
 *
 * Returns the total number of bytes written to dbuf on success,
 *          0 if getkey5() reported failure,
 *         -1 if pkalg is not supported,
 *         -3 if the encryption step itself failed (allocation or bignum
 *            failure, RSA error, or convkey too long for the key modulus).
 */
int pkeenc5(unsigned long long keyid, unsigned char *dbuf,
            unsigned char *convkey, unsigned int ckeylen, unsigned int pkalg)
{
  unsigned char *bp, *bp0;
  unsigned int i, j;
  DH *dhkey = NULL;
#ifndef NO_RSA
  RSA *rsakey = NULL;
#endif

  if (pkalg == 0x10) {
    dhkey = DH_new();
    j = getkey5(&dhkey, NULL, NULL, &keyid);
  }
#ifndef NO_RSA
  else if (pkalg == 1 || pkalg == 2) {
    rsakey = RSA_new();
    j = getkey5((DH **) & rsakey, (DSA **) & rsakey, NULL, &keyid);
  }
#endif
  else
    return -1;

  if (j) {
    /* Lookup failed: release the key object allocated above instead of
       leaking it.  NOTE(review): assumes getkey5() leaves ownership with
       the caller on failure, as it does on success -- confirm. */
    if (pkalg == 0x10) {
      if (dhkey)
        DH_free(dhkey);
    }
#ifndef NO_RSA
    else if (rsakey)
      RSA_free(rsakey);
#endif
    return 0;
  }

  dbuf[0] = 0xc1;
  bp = &dbuf[3];                /* leave room for a two-byte length */
  *bp++ = 3;                    /* version */
  for (j = 0; j < 8; j++)       /* key id */
    *bp++ = keyid >> (56 - 8 * j);

  *bp++ = pkalg;

  if (pkalg == 0x10) {
    BN_CTX *ctx;
    BIGNUM *a, *b, *c, *d;
    unsigned int plen = BN_num_bytes(dhkey->p);
    int ok = 0;

    /* The padded block is 0 2 <pad> 0 <convkey> and is exactly plen bytes
       long.  Require at least 8 bytes of padding (PKCS#1 v1.5); without
       this check the unsigned pad length below would wrap around and the
       padding loop would run off the end of dbuf. */
    if (ckeylen > plen || plen - ckeylen < 11) {
      DH_free(dhkey);
      return -3;
    }

    ctx = BN_CTX_new();
    a = BN_new();
    b = BN_new();
    c = BN_new();
    d = BN_new();

    if (ctx && a && b && c && d) {
      int have_c;

      /* set c to be 0 2 pad 0 alg key cks (staged in the output buffer) */
      bp0 = bp;
      *bp++ = 0;
      *bp++ = 2;
      j = plen - 3 - ckeylen;
      RAND_bytes(bp, j);        /* random, nonzero padding */
      while (j--) {
        while (!*bp)
          RAND_bytes(bp, 1);
        bp++;
      }
      *bp++ = 0;
      memcpy(bp, convkey, ckeylen);
      bp += ckeylen;

      have_c = BN_bin2bn(bp0, bp - bp0, c) != NULL;
      /* the staged plaintext block is no longer needed: wipe it so it
         cannot survive in dbuf on failure or past a short result */
      memset(bp0, 0, bp - bp0);
      bp = bp0;

      j = BN_num_bits(dhkey->p);
      j = 200 + j / 32;         /* size: overkill, but simple */

      /* d is the random exponent; a = g^d mod p, b = c * y^d mod p */
      if (have_c
          && BN_rand(d, j, 1, 1)
          && BN_mod_exp(a, dhkey->g, d, dhkey->p, ctx)
          && BN_mod_exp(b, dhkey->pub_key, d, dhkey->p, ctx)
          && BN_mod_mul(b, c, b, dhkey->p, ctx)) {

        j = BN_num_bits(a);     /* write out A and B as MPIs */
        *bp++ = j >> 8;
        *bp++ = j;
        BN_bn2bin(a, bp);
        bp += (j + 7) / 8;

        j = BN_num_bits(b);
        *bp++ = j >> 8;
        *bp++ = j;
        BN_bn2bin(b, bp);
        bp += (j + 7) / 8;

        ok = 1;
      }
    }

    if (a)
      BN_clear_free(a);
    if (b)
      BN_clear_free(b);
    if (c)
      BN_clear_free(c);
    if (d)
      BN_clear_free(d);
    if (ctx)
      BN_CTX_free(ctx);
    DH_free(dhkey);

    if (!ok)
      return -3;
  }
#ifndef NO_RSA
  else if (pkalg == 1 || pkalg == 2) {
    /* encrypt straight into the spot behind the two-byte MPI header */
    unsigned char *ct = bp + 2;
    int n = RSA_public_encrypt(ckeylen, convkey, ct, rsakey,
                               RSA_PKCS1_PADDING);

    RSA_free(rsakey);
    if (n <= 0)                 /* -1 on error; never use it as a length */
      return -3;

    /* an MPI carries no leading zero bytes ... */
    while (n > 0 && *ct == 0)
      ct++, n--;

    /* ... and its header is the exact bit count, measured from the most
       significant set bit (same as BN_num_bits() in the DH branch) */
    j = 8 * (unsigned int) n;
    if (n > 0) {
      unsigned char top;

      for (top = *ct; !(top & 0x80); top <<= 1)
        j--;
    }
    *bp++ = j >> 8;
    *bp++ = j;
    memmove(bp, ct, n);
    bp += n;
  }
#endif
  else
    return -3;

  /* Fill in the packet length now that the body size is known. */
  i = bp - dbuf;
  j = i - 3;                    /* body length, excluding tag + 2 length bytes */
  if (j > 191) {
    /* two-byte length form; only represents bodies up to 8383 bytes */
    j -= 192;
    j |= 0xc000;
    dbuf[1] = j >> 8;
    dbuf[2] = j;
  } else {
    /* one-byte length form: close the gap left by the unused second byte */
    memmove(&dbuf[1], &dbuf[2], j + 1);
    dbuf[1] = j;
    i--;
  }
  return i;
}
