/*
 * This file is part of LibEuFin.
 * Copyright (C) 2023-2025 Taler Systems S.A.

 * LibEuFin is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation; either version 3, or
 * (at your option) any later version.

 * LibEuFin is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.

 * You should have received a copy of the GNU Affero General Public
 * License along with LibEuFin; see the file COPYING.  If not, see
 * <http://www.gnu.org/licenses/>
 */

package tech.libeufin.bank.db

import tech.libeufin.bank.*
import tech.libeufin.common.*
import tech.libeufin.common.db.*
import java.sql.*
import org.postgresql.util.PSQLState

/** Data access logic for conversion */
/**
 * Data access logic for currency conversion.
 *
 * Covers the default conversion rate stored in the `config` table, the
 * conversion rate classes stored in `conversion_rate_classes`, the rate that
 * applies to a given user, and amount conversion through the in-database
 * `conversion_to` / `conversion_from` functions.
 */
class ConversionDAO(private val db: Database) {
    companion object {
        /**
         * Build the conversion rate that applies to [username] from the rate columns of [it].
         *
         * Returns null when [Database.fiatCurrency] is not set. The `admin` account gets an
         * all-zero rate. A Taler exchange account ([isTalerExchange]) only keeps the cashin
         * part of the row, and every other account only keeps the cashout part; the part
         * that is not kept is zeroed.
         */
        fun userRate(db: Database, it: ResultSet, username: String, isTalerExchange: Boolean): ConversionRate? {
            return if (db.fiatCurrency == null) {
                null
            } else if (username == "admin") {
                ConversionRate(
                    cashin_ratio = DecimalNumber.ZERO,
                    cashin_fee = TalerAmount.zero(db.bankCurrency),
                    cashin_tiny_amount = TalerAmount.zero(db.bankCurrency),
                    cashin_rounding_mode = RoundingMode.zero,
                    cashin_min_amount = TalerAmount.zero(db.fiatCurrency),
                    cashout_ratio = DecimalNumber.ZERO,
                    cashout_fee = TalerAmount.zero(db.fiatCurrency),
                    cashout_tiny_amount = TalerAmount.zero(db.fiatCurrency),
                    cashout_rounding_mode = RoundingMode.zero,
                    cashout_min_amount = TalerAmount.zero(db.bankCurrency),
                )
            } else if (isTalerExchange) {
                // Exchanges only cash in: keep the cashin part, zero the cashout part
                ConversionRate(
                    cashin_ratio = it.getDecimal("cashin_ratio"),
                    cashin_fee = it.getAmount("cashin_fee", db.bankCurrency),
                    cashin_tiny_amount = it.getAmount("cashin_tiny_amount", db.bankCurrency),
                    cashin_rounding_mode = it.getEnum("cashin_rounding_mode"),
                    cashin_min_amount = it.getAmount("cashin_min_amount", db.fiatCurrency),
                    cashout_ratio = DecimalNumber.ZERO,
                    cashout_fee = TalerAmount.zero(db.fiatCurrency),
                    cashout_tiny_amount = TalerAmount.zero(db.fiatCurrency),
                    cashout_rounding_mode = RoundingMode.zero,
                    cashout_min_amount = TalerAmount.zero(db.bankCurrency),
                )
            } else {
                // Other accounts only cash out: zero the cashin part, keep the cashout part
                ConversionRate(
                    cashin_ratio = DecimalNumber.ZERO,
                    cashin_fee = TalerAmount.zero(db.bankCurrency),
                    cashin_tiny_amount = TalerAmount.zero(db.bankCurrency),
                    cashin_rounding_mode = RoundingMode.zero,
                    cashin_min_amount = TalerAmount.zero(db.fiatCurrency),
                    cashout_ratio = it.getDecimal("cashout_ratio"),
                    cashout_fee = it.getAmount("cashout_fee", db.fiatCurrency),
                    cashout_tiny_amount = it.getAmount("cashout_tiny_amount", db.fiatCurrency),
                    cashout_rounding_mode = it.getEnum("cashout_rounding_mode"),
                    cashout_min_amount = it.getAmount("cashout_min_amount", db.bankCurrency),
                )
            }
        }
    }

    /** Fiat currency of the bank, failing with a descriptive error when none is configured */
    private fun checkedFiatCurrency() =
        checkNotNull(db.fiatCurrency) { "conversion requires a configured fiat currency" }

    /** Currency that [amount] is converted into: fiat for regional amounts, regional otherwise */
    private fun targetCurrency(amount: TalerAmount) =
        if (amount.currency == db.bankCurrency) checkedFiatCurrency() else db.bankCurrency

    /** Parse the full cashin and cashout conversion rate columns of the current row of [rs] */
    private fun parseRate(rs: ResultSet): ConversionRate {
        val fiat = checkedFiatCurrency()
        return ConversionRate(
            cashin_ratio = rs.getDecimal("cashin_ratio"),
            cashin_fee = rs.getAmount("cashin_fee", db.bankCurrency),
            cashin_tiny_amount = rs.getAmount("cashin_tiny_amount", db.bankCurrency),
            cashin_rounding_mode = rs.getEnum("cashin_rounding_mode"),
            cashin_min_amount = rs.getAmount("cashin_min_amount", fiat),
            cashout_ratio = rs.getDecimal("cashout_ratio"),
            cashout_fee = rs.getAmount("cashout_fee", fiat),
            cashout_tiny_amount = rs.getAmount("cashout_tiny_amount", fiat),
            cashout_rounding_mode = rs.getEnum("cashout_rounding_mode"),
            cashout_min_amount = rs.getAmount("cashout_min_amount", db.bankCurrency),
        )
    }

    /** Parse the current `conversion_rate_classes` row of [rs], identified by [id] */
    private fun parseClass(rs: ResultSet, id: Long): ConversionRateClass {
        val fiat = checkedFiatCurrency()
        return ConversionRateClass(
            name = rs.getString("name"),
            description = rs.getString("description"),
            conversion_rate_class_id = id,
            num_users = rs.getInt("num_users"),
            cashin_ratio = rs.getOptDecimal("cashin_ratio"),
            cashin_fee = rs.getOptAmount("cashin_fee", db.bankCurrency),
            cashin_rounding_mode = rs.getOptEnum<RoundingMode>("cashin_rounding_mode"),
            cashin_min_amount = rs.getOptAmount("cashin_min_amount", fiat),
            cashout_ratio = rs.getOptDecimal("cashout_ratio"),
            cashout_fee = rs.getOptAmount("cashout_fee", fiat),
            cashout_rounding_mode = rs.getOptEnum<RoundingMode>("cashout_rounding_mode"),
            cashout_min_amount = rs.getOptAmount("cashout_min_amount", db.bankCurrency),
        )
    }

    /** Store [cfg] as the default conversion rate through `config_set_conversion_rate` */
    suspend fun updateConfig(cfg: ConversionRate) = db.serializable("""
        CALL config_set_conversion_rate(
            (?, ?)::taler_amount, (?, ?)::taler_amount, (?, ?)::taler_amount, (?, ?)::taler_amount, ?::rounding_mode, 
            (?, ?)::taler_amount, (?, ?)::taler_amount, (?, ?)::taler_amount, (?, ?)::taler_amount, ?::rounding_mode
        )
    """) {
        // Bind order must match the parameter order of config_set_conversion_rate
        bind(cfg.cashin_ratio)
        bind(cfg.cashin_fee)
        bind(cfg.cashin_tiny_amount)
        bind(cfg.cashin_min_amount)
        bind(cfg.cashin_rounding_mode)
        bind(cfg.cashout_ratio)
        bind(cfg.cashout_fee)
        bind(cfg.cashout_tiny_amount)
        bind(cfg.cashout_min_amount)
        bind(cfg.cashout_rounding_mode)
        executeUpdate()
    }

    /** Get the default conversion rate as returned by `config_get_conversion_rate()` */
    suspend fun getDefaultRate(): ConversionRate = db.serializable("""
        SELECT 
            (cashin_ratio).val as cashin_ratio_val, (cashin_ratio).frac as cashin_ratio_frac,
            (cashin_fee).val as cashin_fee_val, (cashin_fee).frac as cashin_fee_frac,
            (cashin_tiny_amount).val as cashin_tiny_amount_val, (cashin_tiny_amount).frac as cashin_tiny_amount_frac,
            (cashin_min_amount).val as cashin_min_amount_val, (cashin_min_amount).frac as cashin_min_amount_frac,
            cashin_rounding_mode,
            (cashout_ratio).val as cashout_ratio_val, (cashout_ratio).frac as cashout_ratio_frac,
            (cashout_fee).val as cashout_fee_val, (cashout_fee).frac as cashout_fee_frac,
            (cashout_tiny_amount).val as cashout_tiny_amount_val, (cashout_tiny_amount).frac as cashout_tiny_amount_frac,
            (cashout_min_amount).val as cashout_min_amount_val, (cashout_min_amount).frac as cashout_min_amount_frac,
            cashout_rounding_mode
        FROM config_get_conversion_rate()
    """) {
        one { parseRate(it) }
    }

    /** Get the conversion rate of class [conversionRateClassId] as returned by `get_conversion_class_rate` */
    suspend fun getClassRate(conversionRateClassId: Long): ConversionRate = db.serializable("""
        SELECT 
            (cashin_ratio).val as cashin_ratio_val, (cashin_ratio).frac as cashin_ratio_frac,
            (cashin_fee).val as cashin_fee_val, (cashin_fee).frac as cashin_fee_frac,
            (cashin_tiny_amount).val as cashin_tiny_amount_val, (cashin_tiny_amount).frac as cashin_tiny_amount_frac,
            (cashin_min_amount).val as cashin_min_amount_val, (cashin_min_amount).frac as cashin_min_amount_frac,
            cashin_rounding_mode,
            (cashout_ratio).val as cashout_ratio_val, (cashout_ratio).frac as cashout_ratio_frac,
            (cashout_fee).val as cashout_fee_val, (cashout_fee).frac as cashout_fee_frac,
            (cashout_tiny_amount).val as cashout_tiny_amount_val, (cashout_tiny_amount).frac as cashout_tiny_amount_frac,
            (cashout_min_amount).val as cashout_min_amount_val, (cashout_min_amount).frac as cashout_min_amount_frac,
            cashout_rounding_mode
        FROM get_conversion_class_rate(?)
    """) {
        bind(conversionRateClassId)
        one { parseRate(it) }
    }

    /**
     * Get the conversion rate that applies to [username], paired with whether the
     * account is a Taler exchange. The rate is restricted as described on [userRate].
     */
    suspend fun getUserRate(username: String): Pair<Boolean, ConversionRate> = db.serializable("""
        SELECT 
            (cashin_ratio).val as cashin_ratio_val, (cashin_ratio).frac as cashin_ratio_frac,
            (cashin_fee).val as cashin_fee_val, (cashin_fee).frac as cashin_fee_frac,
            (cashin_tiny_amount).val as cashin_tiny_amount_val, (cashin_tiny_amount).frac as cashin_tiny_amount_frac,
            (cashin_min_amount).val as cashin_min_amount_val, (cashin_min_amount).frac as cashin_min_amount_frac,
            cashin_rounding_mode,
            (cashout_ratio).val as cashout_ratio_val, (cashout_ratio).frac as cashout_ratio_frac,
            (cashout_fee).val as cashout_fee_val, (cashout_fee).frac as cashout_fee_frac,
            (cashout_tiny_amount).val as cashout_tiny_amount_val, (cashout_tiny_amount).frac as cashout_tiny_amount_frac,
            (cashout_min_amount).val as cashout_min_amount_val, (cashout_min_amount).frac as cashout_min_amount_frac,
            cashout_rounding_mode,
            is_taler_exchange
        FROM bank_accounts
            JOIN customers ON customer_id=owning_customer_id
            CROSS JOIN LATERAL get_conversion_class_rate(conversion_rate_class_id)
        WHERE username=?
    """) {
        bind(username)
        one {
            val isTalerExchange = it.getBoolean("is_taler_exchange")
            // userRate only returns null when no fiat currency is configured
            val rate = checkNotNull(ConversionDAO.userRate(db, it, username, isTalerExchange)) {
                "conversion requires a configured fiat currency"
            }
            Pair(isTalerExchange, rate)
        }
    }

    /** Clear the in-db default conversion config: every config key starting with `cashin` or `cashout` */
    suspend fun clearConfig() = db.serializable(
        "DELETE FROM config WHERE key LIKE 'cashin%' OR key like 'cashout%'"
    ) {
        executeUpdate()
    }

    /** Result of conversion operations */
    sealed interface ConversionResult {
        /** Conversion succeeded, producing [converted] */
        data class Success(val converted: TalerAmount): ConversionResult
        /** The amount is too small to be converted */
        data object ToSmall: ConversionResult
        /** A cashout was requested for a Taler exchange account */
        data object IsExchange: ConversionResult
        /** A cashin was requested for an account that is not a Taler exchange */
        data object NotExchange: ConversionResult
    }

    /**
     * Perform [direction] conversion of [amount] using the in-db `conversion_[function]` function
     * and the rate of class [conversionRateClassId] (null passes SQL NULL to the function).
     */
    private suspend fun conversion(amount: TalerAmount, function: String, direction: String, conversionRateClassId: Long?): ConversionResult = db.serializable(
        "SELECT too_small, (converted).val AS amount_val, (converted).frac AS amount_frac FROM conversion_$function((?, ?)::taler_amount, ?, ?)"
    ) {
        bind(amount)
        bind(direction)
        bind(conversionRateClassId)
        one {
            when {
                it.getBoolean("too_small") -> ConversionResult.ToSmall
                else -> ConversionResult.Success(it.getAmount("amount", targetCurrency(amount)))
            }
        }
    }

    /**
     * Like [conversion] but uses the conversion rate class of [username]'s account.
     *
     * Directions the account cannot use are rejected before the result is looked at:
     * Taler exchanges cannot cash out ([ConversionResult.IsExchange]) and other
     * accounts cannot cash in ([ConversionResult.NotExchange]).
     */
    private suspend fun userConversion(amount: TalerAmount, function: String, direction: String, username: String): ConversionResult = db.serializable(
        """
        SELECT
            is_taler_exchange,
            too_small,
            (converted).val AS amount_val,
            (converted).frac AS amount_frac
        FROM bank_accounts
            JOIN customers ON customer_id=owning_customer_id,
        LATERAL conversion_$function((?, ?)::taler_amount, ?, conversion_rate_class_id)
        WHERE username=?
        """
    ) {
        bind(amount)
        bind(direction)
        bind(username)
        one {
            val isExchange = it.getBoolean("is_taler_exchange")
            when {
                direction == "cashout" && isExchange -> ConversionResult.IsExchange
                direction == "cashin" && !isExchange -> ConversionResult.NotExchange
                it.getBoolean("too_small") -> ConversionResult.ToSmall
                else -> ConversionResult.Success(it.getAmount("amount", targetCurrency(amount)))
            }
        }
    }

    /** Convert [regional] amount to fiat using cashout rate */
    suspend fun defaultToCashout(regional: TalerAmount): ConversionResult = conversion(regional, "to", "cashout", null)
    suspend fun classToCashout(id: Long, regional: TalerAmount): ConversionResult = conversion(regional, "to", "cashout", id)
    suspend fun userToCashout(username: String, regional: TalerAmount): ConversionResult = userConversion(regional, "to", "cashout", username)
    /** Convert [fiat] amount to regional using cashin rate */
    suspend fun defaultToCashin(fiat: TalerAmount): ConversionResult = conversion(fiat, "to", "cashin", null)
    suspend fun classToCashin(id: Long, fiat: TalerAmount): ConversionResult = conversion(fiat, "to", "cashin", id)
    suspend fun userToCashin(username: String, fiat: TalerAmount): ConversionResult = userConversion(fiat, "to", "cashin", username)
    /** Convert [fiat] amount to regional using inverse cashout rate */
    suspend fun defaultFromCashout(fiat: TalerAmount): ConversionResult = conversion(fiat, "from", "cashout", null)
    suspend fun classFromCashout(id: Long, fiat: TalerAmount): ConversionResult = conversion(fiat, "from", "cashout", id)
    suspend fun userFromCashout(username: String, fiat: TalerAmount): ConversionResult = userConversion(fiat, "from", "cashout", username)
    /** Convert [regional] amount to fiat using inverse cashin rate */
    suspend fun defaultFromCashin(regional: TalerAmount): ConversionResult = conversion(regional, "from", "cashin", null)
    suspend fun classFromCashin(id: Long, regional: TalerAmount): ConversionResult = conversion(regional, "from", "cashin", id)
    suspend fun userFromCashin(username: String, regional: TalerAmount): ConversionResult = userConversion(regional, "from", "cashin", username)

    /** Result status of conversion rate class creation */
    sealed interface ClassCreateResult {
        /** The class was created with identifier [id] */
        data class Success(val id: Long): ClassCreateResult
        /** The insert hit a unique constraint violation (class name already used) */
        data object NameReuse: ClassCreateResult
    }

    /** Create a new conversion rate class from [input] */
    suspend fun createClass(
        input: ConversionRateClassInput
    ): ClassCreateResult = db.serializable(
        """
        INSERT INTO conversion_rate_classes (
             name
            ,description
            ,cashin_ratio
            ,cashin_fee
            ,cashin_min_amount
            ,cashin_rounding_mode
            ,cashout_ratio
            ,cashout_fee
            ,cashout_min_amount
            ,cashout_rounding_mode
        ) VALUES (
            ?, ?,
            ${optDecimal(input.cashin_ratio)},
            ${optAmount(input.cashin_fee)},
            ${optAmount(input.cashin_min_amount)},
            ?::rounding_mode,
            ${optDecimal(input.cashout_ratio)},
            ${optAmount(input.cashout_fee)},
            ${optAmount(input.cashout_min_amount)},
            ?::rounding_mode
        )
        RETURNING conversion_rate_class_id 
        """
    ) {
        bind(input.name)
        bind(input.description)
        bind(input.cashin_ratio)
        bind(input.cashin_fee)
        bind(input.cashin_min_amount)
        bind(input.cashin_rounding_mode)
        bind(input.cashout_ratio)
        bind(input.cashout_fee)
        bind(input.cashout_min_amount)
        bind(input.cashout_rounding_mode)
        try {
            one {
                ClassCreateResult.Success(it.getLong("conversion_rate_class_id"))
            }
        } catch (e: SQLException) {
            if (e.sqlState == PSQLState.UNIQUE_VIOLATION.state) ClassCreateResult.NameReuse
            else throw e
        }
    }

    /** Result status of conversion rate class patching */
    enum class ClassPatchResult {
        /** The class was updated */
        Success,
        /** No class has the requested identifier */
        Unknown,
        /** The update hit a unique constraint violation (class name already used) */
        NameReuse
    }

    /** Replace every field of conversion rate class [id] with [input] */
    suspend fun patchClass(
        id: Long,
        input: ConversionRateClassInput
    ): ClassPatchResult = db.serializable(
        """
        UPDATE conversion_rate_classes SET
             name=?
            ,description=?
            ,cashin_ratio=${optDecimal(input.cashin_ratio)}
            ,cashin_fee=${optAmount(input.cashin_fee)}
            ,cashin_min_amount=${optAmount(input.cashin_min_amount)}
            ,cashin_rounding_mode=?::rounding_mode
            ,cashout_ratio=${optDecimal(input.cashout_ratio)}
            ,cashout_fee=${optAmount(input.cashout_fee)}
            ,cashout_min_amount=${optAmount(input.cashout_min_amount)}
            ,cashout_rounding_mode=?::rounding_mode
        WHERE conversion_rate_class_id=?
        """
    ) {
        bind(input.name)
        bind(input.description)
        bind(input.cashin_ratio)
        bind(input.cashin_fee)
        bind(input.cashin_min_amount)
        bind(input.cashin_rounding_mode)
        bind(input.cashout_ratio)
        bind(input.cashout_fee)
        bind(input.cashout_min_amount)
        bind(input.cashout_rounding_mode)
        bind(id)
        try {
            if (executeUpdateCheck()) ClassPatchResult.Success else ClassPatchResult.Unknown
        } catch (e: SQLException) {
            if (e.sqlState == PSQLState.UNIQUE_VIOLATION.state) ClassPatchResult.NameReuse
            else throw e
        }
    }

    /** Delete conversion rate class [id], returning the result of `executeUpdateCheck()` */
    suspend fun deleteClass(
        id: Long
    ): Boolean = db.serializable(
        "DELETE FROM conversion_rate_classes WHERE conversion_rate_class_id=?"
    ) {
        bind(id)
        executeUpdateCheck()
    }

    /** Get conversion rate class [id] together with its number of users, or null if it does not exist */
    suspend fun getClass(id: Long): ConversionRateClass? = db.serializable(
        """
        SELECT
            name
            ,description
            ,(cashin_ratio).val as cashin_ratio_val, (cashin_ratio).frac as cashin_ratio_frac
            ,(cashin_fee).val as cashin_fee_val, (cashin_fee).frac as cashin_fee_frac
            ,(cashin_min_amount).val as cashin_min_amount_val, (cashin_min_amount).frac as cashin_min_amount_frac
            ,cashin_rounding_mode
            ,(cashout_ratio).val as cashout_ratio_val, (cashout_ratio).frac as cashout_ratio_frac
            ,(cashout_fee).val as cashout_fee_val, (cashout_fee).frac as cashout_fee_frac
            ,(cashout_min_amount).val as cashout_min_amount_val, (cashout_min_amount).frac as cashout_min_amount_frac
            ,cashout_rounding_mode
            ,(SELECT count(*) FROM bank_accounts WHERE bank_accounts.conversion_rate_class_id=conversion_rate_classes.conversion_rate_class_id) as num_users
        FROM conversion_rate_classes
        WHERE conversion_rate_class_id=?
        """
    ) {
        bind(id)
        oneOrNull { parseClass(it, id) }
    }

    /** Get a page of conversion rate classes, optionally filtered by name with `ILIKE` */
    suspend fun pageClass(params: ClassParams): List<ConversionRateClass>
        = db.page(
            params.page,
            "conversion_rate_class_id",
            """
            SELECT
            name
            ,description
            ,(cashin_ratio).val as cashin_ratio_val, (cashin_ratio).frac as cashin_ratio_frac
            ,(cashin_fee).val as cashin_fee_val, (cashin_fee).frac as cashin_fee_frac
            ,(cashin_min_amount).val as cashin_min_amount_val, (cashin_min_amount).frac as cashin_min_amount_frac
            ,cashin_rounding_mode
            ,(cashout_ratio).val as cashout_ratio_val, (cashout_ratio).frac as cashout_ratio_frac
            ,(cashout_fee).val as cashout_fee_val, (cashout_fee).frac as cashout_fee_frac
            ,(cashout_min_amount).val as cashout_min_amount_val, (cashout_min_amount).frac as cashout_min_amount_frac
            ,cashout_rounding_mode
            ,(SELECT count(*) FROM bank_accounts WHERE bank_accounts.conversion_rate_class_id=conversion_rate_classes.conversion_rate_class_id) as num_users
            ,conversion_rate_class_id
            FROM conversion_rate_classes
            WHERE ${if (params.nameFilter != null) "name ILIKE ? AND" else ""}
            """,
            {
                // Only bind when the ILIKE placeholder was emitted above
                if (params.nameFilter != null) {
                    bind(params.nameFilter)
                }
            }
        ) {
            parseClass(it, it.getLong("conversion_rate_class_id"))
        }
}