#define _GNU_SOURCE 1

#include <sys/ioctl.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <stdio.h>
#include <unistd.h>
#include <string.h>
#include <stdint.h>

#if defined __has_include && __has_include(<libdrm/drm.h>)
#  include <libdrm/drm_fourcc.h>
#else
#  include <drm/drm_fourcc.h>
#endif

#include "nv-driver.h"
#include <nvidia.h>
#include "../vabackend.h"

//all-zero handle, passed wherever a call needs a "no object" placeholder for a root or parent handle
static const NvHandle NULL_OBJECT = {0};

//Allocates an object of class hClass through the NV_ESC_RM_ALLOC ioctl.
//  fd            - descriptor the ioctl is issued on
//  hRoot         - root handle placed in the request
//  hObjectParent - parent handle placed in the request
//  hObjectNew    - in: handle value sent with the request, out: handle reported back (only written on success)
//  hClass        - class of the object to allocate
//  params        - class specific allocation parameters, passed through as-is
//Returns true only when the ioctl returns 0 and the reported status is NV_OK, otherwise logs and returns false.
bool nv_alloc_object(int fd, NvHandle hRoot, NvHandle hObjectParent, NvHandle* hObjectNew, NvV32 hClass, void* params) {
    const unsigned long request = _IOC(_IOC_READ|_IOC_WRITE, NV_IOCTL_MAGIC, NV_ESC_RM_ALLOC, sizeof(NVOS64_PARAMETERS));

    NVOS64_PARAMETERS args = {
        .hClass = hClass,
        .hRoot = hRoot,
        .hObjectParent = hObjectParent,
        .hObjectNew = *hObjectNew,
        .pAllocParms = params,
        .pRightsRequested = NULL
    };

    const int rc = ioctl(fd, request, &args);
    const bool ok = (rc == 0) && (args.status == NV_OK);

    if (ok) {
        //hand the handle from the reply back to the caller
        *hObjectNew = args.hObjectNew;
    } else {
        LOG("nv_alloc_object failed: %d %X", rc, args.status);
    }

    return ok;
}

//Frees hObject through the NV_ESC_RM_FREE ioctl.
//  fd      - descriptor the ioctl is issued on
//  hRoot   - root handle placed in the request
//  hObject - handle to free; a zero handle is treated as "nothing to do"
//Returns true when there was nothing to free or the free succeeded, otherwise logs and returns false.
bool nv_free_object(int fd, NvHandle hRoot, NvHandle hObject) {
    if (hObject != 0) {
        NVOS00_PARAMETERS args = {
            .hObjectOld = hObject,
            .hObjectParent = NULL_OBJECT,
            .hRoot = hRoot
        };

        const int rc = ioctl(fd, _IOC(_IOC_READ|_IOC_WRITE, NV_IOCTL_MAGIC, NV_ESC_RM_FREE, sizeof(NVOS00_PARAMETERS)), &args);

        if (rc != 0 || args.status != NV_OK) {
            LOG("nv_free_object failed: %d %X", rc, args.status);
            return false;
        }
    }

    return true;
}

//Issues control command cmd against hObject through the NV_ESC_RM_CONTROL ioctl.
//  fd        - descriptor the ioctl is issued on
//  hClient   - client handle placed in the request
//  hObject   - object the command is addressed to
//  cmd       - control command id
//  flags     - flags placed in the request
//  paramSize - size in bytes of the buffer behind params
//  params    - command specific parameter block, passed through as-is
//Returns true only when the ioctl returns 0 and the reported status is NV_OK, otherwise logs and returns false.
bool nv_rm_control(int fd, NvHandle hClient, NvHandle hObject, NvV32 cmd, NvU32 flags, int paramSize, void* params) {
    const unsigned long request = _IOC(_IOC_READ|_IOC_WRITE, NV_IOCTL_MAGIC, NV_ESC_RM_CONTROL, sizeof(NVOS54_PARAMETERS));

    NVOS54_PARAMETERS args = {
        .cmd = cmd,
        .flags = flags,
        .hClient = hClient,
        .hObject = hObject,
        .paramsSize = paramSize,
        .params = params
    };

    const int rc = ioctl(fd, request, &args);

    if (rc == 0 && args.status == NV_OK) {
        return true;
    }

    LOG("nv_rm_control failed: %d %X", rc, args.status);
    return false;
}

//Asks the driver, through the NV_ESC_CHECK_VERSION_STR ioctl, whether versionString is recognized.
//  fd            - descriptor the ioctl is issued on
//  versionString - NUL terminated version string; must fit (terminator included) in the request's versionString field
//Returns true when the ioctl succeeds and the reply is NV_RM_API_VERSION_REPLY_RECOGNIZED.
//Returns false without issuing the ioctl when versionString is NULL or too long for the request.
bool nv_check_version(int fd, char *versionString) {
    nv_ioctl_rm_api_version_t obj = {
        .cmd = 0
    };

    if (versionString == NULL) {
        return false;
    }

    //bounded copy: an unchecked strcpy would overrun the fixed-size field on an over-long string
    const size_t len = strlen(versionString);
    if (len >= sizeof(obj.versionString)) {
        LOG("nv_check_version: version string too long");
        return false;
    }
    memcpy(obj.versionString, versionString, len + 1);

    int ret = ioctl(fd, _IOC(_IOC_READ|_IOC_WRITE, NV_IOCTL_MAGIC, NV_ESC_CHECK_VERSION_STR, sizeof(obj)), &obj);

    return ret == 0 && obj.reply == NV_RM_API_VERSION_REPLY_RECOGNIZED;
}

//Sends the NV_ESC_SYS_PARAMS ioctl with a hard-coded memblock_size of 0x8000000.
//The original note on this value points at /sys/devices/system/memory/block_size_bytes.
//Returns the memblock_size field as it stands after the ioctl, or 0 when the ioctl fails.
NvU64 nv_sys_params(int fd) {
    nv_ioctl_sys_params_t sysParams = { .memblock_size = 0x8000000 };

    if (ioctl(fd, _IOC(_IOC_READ|_IOC_WRITE, NV_IOCTL_MAGIC, NV_ESC_SYS_PARAMS, sizeof(sysParams)), &sysParams) != 0) {
        return 0;
    }

    return sysParams.memblock_size;
}

//Issues the NV_ESC_CARD_INFO ioctl on fd, using the caller's 32 entry array as the ioctl argument.
//Returns true when the ioctl returns 0.
bool nv_card_info(int fd, nv_ioctl_card_info_t (*card_info)[32]) {
    const unsigned long request = _IOC(_IOC_READ|_IOC_WRITE, NV_IOCTL_MAGIC, NV_ESC_CARD_INFO, sizeof(nv_ioctl_card_info_t) * 32);

    return ioctl(fd, request, card_info) == 0;
}

//Issues the NV_ESC_ATTACH_GPUS_TO_FD ioctl on fd with a single gpu id as its argument.
//Returns true when the ioctl returns 0.
bool nv_attach_gpus(int fd, int gpu) {
    const unsigned long request = _IOC(_IOC_READ|_IOC_WRITE, NV_IOCTL_MAGIC, NV_ESC_ATTACH_GPUS_TO_FD, sizeof(gpu));

    return ioctl(fd, request, &gpu) == 0;
}

//Exports an RM object to another descriptor using NV0000_CTRL_CMD_OS_UNIX_EXPORT_OBJECT_TO_FD.
//  fd        - descriptor the control call is issued on
//  export_fd - descriptor the object is exported to
//  hClient   - client handle; also used as the target object of the control call
//  hDevice, hParent, hObject - identify the object being exported
//Returns the result of nv_rm_control.
bool nv_export_object_to_fd(int fd, int export_fd, NvHandle hClient, NvHandle hDevice, NvHandle hParent, NvHandle hObject) {
    NV0000_CTRL_OS_UNIX_EXPORT_OBJECT_TO_FD_PARAMS exportParams = {
        .object.type = NV0000_CTRL_OS_UNIX_EXPORT_OBJECT_TYPE_RM,
        .object.data.rmObject.hDevice = hDevice,
        .object.data.rmObject.hParent = hParent,
        .object.data.rmObject.hObject = hObject,
        .fd = export_fd,
        .flags = 0
    };

    const bool exported = nv_rm_control(fd, hClient, hClient, NV0000_CTRL_CMD_OS_UNIX_EXPORT_OBJECT_TO_FD, 0, sizeof(exportParams), &exportParams);

    return exported;
}

//Queries the driver version string using NV0000_CTRL_CMD_SYSTEM_GET_BUILD_VERSION.
//  fd            - descriptor the control call is issued on
//  hClient       - client handle the control call is addressed to
//  driverVersion - out: heap allocated copy of the driver version string, to be released by the caller with free().
//                  Set to NULL whenever false is returned.
//Returns true on success, false when the control call or the string duplication fails.
bool nv_get_versions(int fd, NvHandle hClient, char **driverVersion) {
    //zero the buffers so that they hold valid (empty) strings whatever the driver writes
    char driverVersionBuffer[64] = {0};
    char versionBuffer[64] = {0};
    char titleBuffer[64] = {0};
    NV0000_CTRL_SYSTEM_GET_BUILD_VERSION_PARAMS params = {
        .sizeOfStrings = sizeof(driverVersionBuffer),
        .pDriverVersionBuffer = driverVersionBuffer,
        .pVersionBuffer = versionBuffer,
        .pTitleBuffer = titleBuffer
    };

    //give the caller a defined value on every failure path
    *driverVersion = NULL;

    bool ret = nv_rm_control(fd, hClient, hClient, NV0000_CTRL_CMD_SYSTEM_GET_BUILD_VERSION, 0, sizeof(params), &params);
    if (!ret) {
        LOG("NV0000_CTRL_CMD_SYSTEM_GET_BUILD_VERSION failed");
        return false;
    }

    //never let strdup read past the buffer, even if it was filled completely
    driverVersionBuffer[sizeof(driverVersionBuffer) - 1] = '\0';

    char *copy = strdup(driverVersionBuffer);
    if (copy == NULL) {
        LOG("strdup of driver version failed");
        return false;
    }

    *driverVersion = copy;

    return true;
}

//Issues the NV_ESC_REGISTER_FD ioctl on nv0_fd, passing nvctl_fd as its argument.
//Returns true when the ioctl returns 0.
bool nv0_register_fd(int nv0_fd, int nvctl_fd) {
    const unsigned long request = _IOC(_IOC_READ|_IOC_WRITE, NV_IOCTL_MAGIC, NV_ESC_REGISTER_FD, sizeof(int));

    return ioctl(nv0_fd, request, &nvctl_fd) == 0;
}

//Fills devInfo using the DRM_IOCTL_NVIDIA_GET_DEV_INFO ioctl on fd.
//Returns true when the ioctl returns 0, otherwise logs the return value and returns false.
bool get_device_info(int fd, struct drm_nvidia_get_dev_info_params *devInfo) {
    const int rc = ioctl(fd, DRM_IOCTL_NVIDIA_GET_DEV_INFO, devInfo);

    if (rc == 0) {
        return true;
    }

    LOG("DRM_IOCTL_NVIDIA_GET_DEV_INFO failed: %d", rc);
    return false;
}

//Looks up the UUID of the GPU identified by context->devInfo.gpu_id.
//The request asks for the FORMAT_BINARY / TYPE_SHA1 flavour and the first 16 bytes of the reply are copied out.
//  context - driver context; nvctlFd, clientObject and devInfo.gpu_id are read
//  uuid    - out: receives 16 bytes, only written on success
//Returns true on success, false when the control call fails.
bool get_device_uuid(NVDriverContext *context, char uuid[16]) {
    NV0000_CTRL_GPU_GET_UUID_FROM_GPU_ID_PARAMS uuidParams = {
        .gpuId = context->devInfo.gpu_id,
        .flags = NV0000_CTRL_CMD_GPU_GET_UUID_FROM_GPU_ID_FLAGS_FORMAT_BINARY |
                 NV0000_CTRL_CMD_GPU_GET_UUID_FROM_GPU_ID_FLAGS_TYPE_SHA1
    };

    //nv_rm_control returns true on success, so bail out on a false result
    bool ret = nv_rm_control(context->nvctlFd, context->clientObject, context->clientObject, NV0000_CTRL_CMD_GPU_GET_UUID_FROM_GPU_ID, 0, sizeof(uuidParams), &uuidParams);
    if (!ret) {
        return false;
    }

    for (int i = 0; i < 16; i++) {
        uuid[i] = uuidParams.gpuUuid[i];
    }

    return true;
}

//Initialises context for the GPU behind drmFd.
//Steps: query the DRM device info, open /dev/nvidiactl and /dev/nvidia0, allocate the client, device and
//subdevice objects, register the control fd with the device fd and record the driver's major version.
//  context - filled in here. The object handles already stored in it are sent as the requested handles
//            by nv_alloc_object, so it is presumably expected to be zeroed by the caller - TODO confirm
//  drmFd   - DRM descriptor of the GPU; stored in the context on success
//Returns true on success. On failure everything opened or allocated here is released again, the object
//handles in context are reset to zero and false is returned (drmFd itself is left open).
bool init_nvdriver(NVDriverContext *context, int drmFd) {
    LOG("Initing nvdriver...");
    if (!get_device_info(drmFd, &context->devInfo)) {
        return false;
    }

    LOG("Got dev info: %x %x %x %x", context->devInfo.gpu_id, context->devInfo.sector_layout, context->devInfo.page_kind_generation, context->devInfo.generic_page_kind);

    int nvctlFd = -1, nv0Fd = -1;
    char *ver = NULL;
    bool ret = false;

    nvctlFd = open("/dev/nvidiactl", O_RDWR|O_CLOEXEC);
    if (nvctlFd == -1) {
        goto err;
    }

    nv0Fd = open("/dev/nvidia0", O_RDWR|O_CLOEXEC);
    if (nv0Fd == -1) {
        goto err;
    }

    //nv_check_version(nvctl_fd, "515.48.07");
    //not sure why this is called.
    //printf("sys params: %llu\n", nv_sys_params(nvctl_fd));

    //allocate the root object
    ret = nv_alloc_object(nvctlFd, NULL_OBJECT, NULL_OBJECT, &context->clientObject, NV01_ROOT_CLIENT, (void*)0);
    if (!ret) {
        LOG("nv_alloc_object NV01_ROOT_CLIENT failed");
        goto err;
    }

    //attach the gpu reported by the drm fd to the control fd
    ret = nv_attach_gpus(nvctlFd, context->devInfo.gpu_id);
    if (!ret) {
        LOG("nv_attach_gpu failed");
        goto err;
    }

    //parameters for the device object
    NV0080_ALLOC_PARAMETERS deviceParams = {
       .hClientShare = context->clientObject
    };

    //allocate the device object
    ret = nv_alloc_object(nvctlFd, context->clientObject, context->clientObject, &context->deviceObject, NV01_DEVICE_0, &deviceParams);
    if (!ret) {
        LOG("nv_alloc_object NV01_DEVICE_0 failed");
        goto err;
    }

    //allocate the subdevice object
    NV2080_ALLOC_PARAMETERS subdevice = { 0 };
    ret = nv_alloc_object(nvctlFd, context->clientObject, context->deviceObject, &context->subdeviceObject, NV20_SUBDEVICE_0, &subdevice);
    if (!ret) {
        LOG("nv_alloc_object NV20_SUBDEVICE_0 failed");
        goto err;
    }

    //TODO honestly not sure if this is needed
    ret = nv0_register_fd(nv0Fd, nvctlFd);
    if (!ret) {
        LOG("nv0_register_fd failed");
        goto err;
    }

    //the major version is needed later on, so treat a failed query as fatal instead of parsing garbage
    ret = nv_get_versions(nvctlFd, context->clientObject, &ver);
    if (!ret || ver == NULL) {
        LOG("nv_get_versions failed");
        goto err;
    }
    LOG("NVIDIA kernel driver version: %s", ver);
    context->driverMajorVersion = atoi(ver);
    free(ver);

    //figure out what page sizes are available
    //we don't actually need this at the moment
//    NV0080_CTRL_DMA_ADV_SCHED_GET_VA_CAPS_PARAMS vaParams = {0};
//    ret = nv_rm_control(nvctlFd, context->clientObject, context->deviceObject, NV0080_CTRL_CMD_DMA_ADV_SCHED_GET_VA_CAPS, 0, sizeof(vaParams), &vaParams);
//    if (!ret) {
//        LOG("NV0080_CTRL_CMD_DMA_ADV_SCHED_GET_VA_CAPS failed");
//        goto err;
//    }
//    LOG("Got big page size: %d, huge page size: %d", vaParams.bigPageSize, vaParams.hugePageSize);

    context->drmFd = drmFd;
    context->nvctlFd = nvctlFd;
    context->nv0Fd = nv0Fd;
    //context->hasHugePage = vaParams.hugePageSize != 0;

    return true;
err:

    LOG("Got error initing");
    if (nvctlFd != -1) {
        //release whatever was allocated, children first; nv_free_object skips handles that are still zero
        nv_free_object(nvctlFd, context->clientObject, context->subdeviceObject);
        nv_free_object(nvctlFd, context->clientObject, context->deviceObject);
        nv_free_object(nvctlFd, context->clientObject, context->clientObject);
        close(nvctlFd);
    }
    if (nv0Fd != -1) {
        close(nv0Fd);
    }

    //don't leave stale handles behind: nv_alloc_object would send them as the requested handle on a retry
    context->subdeviceObject = NULL_OBJECT;
    context->deviceObject = NULL_OBJECT;
    context->clientObject = NULL_OBJECT;

    return false;
}

//Tears down a context: frees the subdevice, device and client objects (in that order), closes the
//nvctl, drm and nv0 descriptors when they are greater than zero, then zeroes the whole context.
//The results of the individual free calls are not checked.
//Always returns true.
bool free_nvdriver(NVDriverContext *context) {
    //children go first, the client object last
    const NvHandle releaseOrder[] = {
        context->subdeviceObject,
        context->deviceObject,
        context->clientObject
    };
    for (size_t i = 0; i < sizeof(releaseOrder) / sizeof(releaseOrder[0]); i++) {
        nv_free_object(context->nvctlFd, context->clientObject, releaseOrder[i]);
    }

    //a zeroed context has all descriptors at 0, so only values above zero are closed
    const int descriptors[] = {
        context->nvctlFd,
        context->drmFd,
        context->nv0Fd
    };
    for (size_t i = 0; i < sizeof(descriptors) / sizeof(descriptors[0]); i++) {
        if (descriptors[i] > 0) {
            close(descriptors[i]);
        }
    }

    memset(context, 0, sizeof(NVDriverContext));
    return true;
}

//Allocates a NV01_MEMORY_LOCAL_USER object of the given size under the context's device object and
//exports it to a freshly opened /dev/nvidiactl descriptor.
//  context - initialised driver context (nvctlFd, clientObject, deviceObject and devInfo.gpu_id are read)
//  size    - allocation size in bytes
//  fd      - out: the descriptor the object was exported to, only written on success; the caller owns it
//The local handle to the object is freed again once the export has succeeded.
//Returns true on success. On failure the new descriptor is closed, the object is freed and false is returned.
bool alloc_memory(NVDriverContext *context, uint32_t size, int *fd) {
    //allocate the buffer
    int nvctlFd2 = -1;
    NvHandle bufferObject = {0};

    //we don't have huge pages available on all hardware
    //turns out we don't need to know that anyway, although this will probably result in a less optimal page size
    /*
    NvU32 pageSizeAttr = context->hasHugePage ? DRF_DEF(OS32, _ATTR, _PAGE_SIZE, _HUGE)
                                              : DRF_DEF(OS32, _ATTR, _PAGE_SIZE, _BIG);
    NvU32 pageSizeAttr2 = context->hasHugePage ? DRF_DEF(OS32, _ATTR2, _PAGE_SIZE_HUGE, _2MB)
                                               : 0;*/

    NV_MEMORY_ALLOCATION_PARAMS memParams = {
        .owner = context->clientObject,
        .type = NVOS32_TYPE_IMAGE,
        .flags = NVOS32_ALLOC_FLAGS_IGNORE_BANK_PLACEMENT |
                 //NVOS32_ALLOC_FLAGS_ALIGNMENT_FORCE | //this doesn't seem to be needed
                 NVOS32_ALLOC_FLAGS_MAP_NOT_REQUIRED |
                 NVOS32_ALLOC_FLAGS_PERSISTENT_VIDMEM,

        .attr = //pageSizeAttr |
                DRF_DEF(OS32, _ATTR, _DEPTH, _UNKNOWN) |
                DRF_DEF(OS32, _ATTR, _FORMAT, _BLOCK_LINEAR) |
                DRF_DEF(OS32, _ATTR, _PHYSICALITY, _CONTIGUOUS),
        .format = 0,
        .width = 0,
        .height = 0,
        .size = size,
        .alignment = 0, //see flags above
        .attr2 = //pageSizeAttr2 |
                 DRF_DEF(OS32, _ATTR2, _ZBC, _PREFER_NO_ZBC) |
                 DRF_DEF(OS32, _ATTR2, _GPU_CACHEABLE, _YES)
    };
    bool ret = nv_alloc_object(context->nvctlFd, context->clientObject, context->deviceObject, &bufferObject, NV01_MEMORY_LOCAL_USER, &memParams);
    if (!ret) {
        LOG("nv_alloc_object NV01_MEMORY_LOCAL_USER failed");
        return false;
    }

    //open a new handle to return
    nvctlFd2 = open("/dev/nvidiactl", O_RDWR|O_CLOEXEC);
    if (nvctlFd2 == -1) {
        LOG("open /dev/nvidiactl failed");
        goto err;
    }

    //attach the new fd to the correct gpus
    ret = nv_attach_gpus(nvctlFd2, context->devInfo.gpu_id);
    if (!ret) {
        LOG("nv_attach_gpus failed");
        goto err;
    }

    //actually export the object
    ret = nv_export_object_to_fd(context->nvctlFd, nvctlFd2, context->clientObject, context->deviceObject, context->deviceObject, bufferObject);
    if (!ret) {
        LOG("nv_export_object_to_fd failed");
        goto err;
    }

    //the exported descriptor is what the caller gets, our own handle is no longer needed
    ret = nv_free_object(context->nvctlFd, context->clientObject, bufferObject);
    if (!ret) {
        LOG("nv_free_object failed");
        goto err;
    }

    *fd = nvctlFd2;
    return true;

 err:
    LOG("error");
    //-1 is the "not opened" marker; 0 is a valid descriptor and must be closed too
    if (nvctlFd2 != -1) {
        close(nvctlFd2);
    }

    ret = nv_free_object(context->nvctlFd, context->clientObject, bufferObject);
    if (!ret) {
        LOG("nv_free_object failed");
    }

    return false;
}

//Allocates a block linear image, imports it into the DRM driver and exports it as a dma-buf.
//  context        - initialised driver context (drmFd, devInfo and driverMajorVersion are read here)
//  width, height  - image dimensions in pixels
//  channels       - 1 or 2, anything else is rejected before anything is allocated
//  bitsPerChannel - bitsPerChannel/8 == 1 selects the 8 bit formats (R8/RG88), any other value the 16 bit ones (R16/RG1616)
//  image          - out: descriptors, modifier, pitch, size and fourcc of the new image, only written on success
//Returns true on success. On failure every descriptor and GEM handle created here is released and false is returned.
bool alloc_image(NVDriverContext *context, uint32_t width, uint32_t height, uint8_t channels, uint8_t bitsPerChannel, NVDriverImage *image) {
    uint32_t depth = 1;
    uint32_t gobWidthInBytes = 64;
    uint32_t gobHeightInBytes = 8;
    uint32_t gobDepthInBytes = 1;

    uint32_t bytesPerChannel = bitsPerChannel/8;
    uint32_t bytesPerPixel = channels * bytesPerChannel;

    //descriptors created below, -1 means "not created yet"
    int memFd = -1;
    int memFd2 = -1;
    int primeFd = -1;

    //work out the format first, so that an unsupported channel count fails before anything is allocated
    uint32_t fourcc;
    if (channels == 1) {
        fourcc = bytesPerChannel == 1 ? DRM_FORMAT_R8 : DRM_FORMAT_R16;
    } else if (channels == 2) {
        fourcc = bytesPerChannel == 1 ? DRM_FORMAT_RG88 : DRM_FORMAT_RG1616;
    } else {
        LOG("Unknown fourcc");
        return false;
    }

    //first figure out the gob layout
    uint32_t log2GobsPerBlockX = 0; //TODO not sure if these are the correct numbers to start with, but they're the largest ones i've seen used
    uint32_t log2GobsPerBlockY = height < 96 ? 3 : 4; //TODO 96 is a guess, 80px high needs 3, 112px needs 4, 96px needs 4
    uint32_t log2GobsPerBlockZ = 0;

//    while (log2GobsPerBlockX > 0 && (gobWidthInBytes << (log2GobsPerBlockX - 1)) >= width * bytesPerPixel)
//        log2GobsPerBlockX--;
//    while (log2GobsPerBlockY > 0 && (gobHeightInBytes << (log2GobsPerBlockY - 1)) >= height)
//        log2GobsPerBlockY--;
//    while (log2GobsPerBlockZ > 0 && (gobDepthInBytes << (log2GobsPerBlockZ - 1)) >= depth)
//        log2GobsPerBlockZ--;

    LOG("Calculated GOB size: %dx%d (%dx%d)", gobWidthInBytes << log2GobsPerBlockX, gobHeightInBytes << log2GobsPerBlockY, log2GobsPerBlockX, log2GobsPerBlockY);

    //These two seem to be correct, but it was discovered by trial and error so I'm not 100% sure
    uint32_t widthInBytes = ROUND_UP(width * bytesPerPixel, gobWidthInBytes << log2GobsPerBlockX);
    uint32_t alignedHeight = ROUND_UP(height, gobHeightInBytes << log2GobsPerBlockY);


    //uint32_t granularity = 1;//65536;
    uint32_t imageSizeInBytes = widthInBytes * alignedHeight;
    uint32_t size = imageSizeInBytes;//ROUND_UP(imageSizeInBytes, granularity);
    //uint32_t alignment = 0x200000;

    LOG("Aligned image size: %dx%d = %d (%d)", widthInBytes, alignedHeight, imageSizeInBytes, size);

    //this gets us some memory, and the fd to import into cuda
    bool ret = alloc_memory(context, size, &memFd);
    if (!ret) {
        LOG("alloc_memory failed");
        return false;
    }

    //now export the dma-buf
    uint32_t pitchInBlocks = widthInBytes / (gobWidthInBytes << log2GobsPerBlockX);

    //printf("got gobsPerBlock: %ux%u %u %u %u %d\n", width, height, log2GobsPerBlockX, log2GobsPerBlockY, log2GobsPerBlockZ, pitchInBlocks);
    //duplicate the fd so we don't invalidate it by importing it
    memFd2 = dup(memFd);
    if (memFd2 == -1) {
        LOG("dup failed");
        goto err;
    }

    struct NvKmsKapiPrivImportMemoryParams nvkmsParams = {
        .memFd = memFd2,
        .surfaceParams = {
            .layout = NvKmsSurfaceMemoryLayoutBlockLinear,
            .blockLinear = {
                .genericMemory = 0,
                .pitchInBlocks = pitchInBlocks,
                .log2GobsPerBlock.x = log2GobsPerBlockX,
                .log2GobsPerBlock.y = log2GobsPerBlockY,
                .log2GobsPerBlock.z = log2GobsPerBlockZ,
            }
        }
    };

    struct drm_nvidia_gem_import_nvkms_memory_params params = {
        .mem_size = imageSizeInBytes,
        //go through uintptr_t so the pointer to integer conversion is well defined on every pointer width
        .nvkms_params_ptr = (uint64_t) (uintptr_t) &nvkmsParams,
        .nvkms_params_size = context->driverMajorVersion == 470 ? 0x20 : sizeof(nvkmsParams) //needs to be 0x20 in the 470 series driver
    };
    int drmret = ioctl(context->drmFd, DRM_IOCTL_NVIDIA_GEM_IMPORT_NVKMS_MEMORY, &params);
    if (drmret != 0) {
        LOG("DRM_IOCTL_NVIDIA_GEM_IMPORT_NVKMS_MEMORY failed: %d", drmret);
        goto err;
    }

    //export dma-buf
    struct drm_prime_handle prime_handle = {
        .handle = params.handle
    };
    drmret = ioctl(context->drmFd, DRM_IOCTL_PRIME_HANDLE_TO_FD, &prime_handle);
    if (drmret != 0) {
        LOG("DRM_IOCTL_PRIME_HANDLE_TO_FD failed: %d", drmret);
        //best effort: don't leave the imported GEM handle behind
        struct drm_gem_close gem_cleanup = {
            .handle = params.handle
        };
        ioctl(context->drmFd, DRM_IOCTL_GEM_CLOSE, &gem_cleanup);
        goto err;
    }
    primeFd = prime_handle.fd;

    //the dma-buf keeps the memory referenced, the GEM handle itself is no longer needed
    struct drm_gem_close gem_close = {
        .handle = params.handle
    };
    drmret = ioctl(context->drmFd, DRM_IOCTL_GEM_CLOSE, &gem_close);
    if (drmret != 0) {
        LOG("DRM_IOCTL_GEM_CLOSE failed: %d", drmret);
        goto prime_err;
    }

    image->width = width;
    image->height = height;
    image->nvFd = memFd;
    image->nvFd2 = memFd2; //not sure why we can't close this one, we shouldn't need it after importing the image
    image->drmFd = primeFd;
    image->mods = DRM_FORMAT_MOD_NVIDIA_BLOCK_LINEAR_2D(0, context->devInfo.sector_layout, context->devInfo.page_kind_generation, context->devInfo.generic_page_kind, log2GobsPerBlockY);
    image->offset = 0;
    image->pitch = widthInBytes;
    image->memorySize = imageSizeInBytes;
    image->fourcc = fourcc;

    LOG("created image: %dx%d %lx %d %x", width, height, image->mods, widthInBytes, imageSizeInBytes);

    return true;

prime_err:
    if (primeFd != -1) {
        close(primeFd);
    }

err:
    //nothing was handed to the caller, so both memory descriptors are ours to close
    if (memFd2 != -1) {
        close(memFd2);
    }
    if (memFd != -1) {
        close(memFd);
    }

    return false;
}
