#define BC7ENC_VERSION "1.08"
#define COMPUTE_SSIM (0)
#if _OPENMP
#include <omp.h>
#endif
#include <cstdint>
#include "bc7enc_rdo/rdo_bc_encoder.h"
#include "bc7enc_rdo/utils.h"

// Adapted from main() in test.cpp and exposed as a single function with C linkage.

// Valid formats: 1 3 4 5 7
// BC1 RGB    4bpp   fast, small
// BC3 RGBA   8bpp   fast
// BC4 gray   4bpp   best for grayscale
// BC5 X+Y    8bpp
// BC7 RGB(A) 8bpp   slow, best quality

// Describes the uncompressed source image handed to encode_bc().
// encode_bc() requires `len == width * height * sizeof(uint32_t)`, i.e. four
// bytes per pixel, and copies exactly `len` bytes out of `data`.
struct EncodeBcInput {
    // Source pixel buffer (read only by encode_bc()).
    // NOTE(review): channel order is presumably RGBA8 — confirm against utils::image_u8.
    void *data;

    // Size of `data` in bytes.
    int32_t len;

    // Image width in pixels; encode_bc() rejects 0.
    int32_t width;

    // Image height in pixels; encode_bc() rejects 0.
    int32_t height;

    // BCn format to use: one of 1, 3, 4, 5 or 7 (anything else is rejected).
    int8_t format;

    // BC1 quality level; encode_bc() accepts rgbcx::MIN_LEVEL .. rgbcx::MAX_LEVEL + 1.
    int8_t bc1_quality;

    // BC7 quality ("uber") level; encode_bc() accepts 0 .. 6.
    int8_t bc7_quality;
};

// Describes the caller-provided destination buffer for encode_bc().
struct EncodeBcOutput {
    // Destination buffer; on success encode_bc() copies the encoded blocks here.
    void *data;

    // Size of `data` in bytes. encode_bc() fails with INVALID_OUTPUT_LENGTH_ERROR
    // unless this equals the encoder's total block size in bytes.
    int32_t len;

    // Not read or written by encode_bc().
    // NOTE(review): presumably the byte length of one row of encoded blocks — confirm with callers.
    int32_t row_len;
};

// Status codes returned by encode_bc(). The numeric values are spelled out
// explicitly so they stay stable for callers on the other side of the C boundary.
enum EncodeBcError {
    // Encoding succeeded and the output buffer was filled.
    NO_ERROR = 0,
    // rdo_bc_encoder::init() reported failure.
    ENCODER_INIT_ERROR = 1,
    // rdo_bc_encoder::encode() reported failure.
    ENCODE_ERROR = 2,
    // The requested image width or height was rejected.
    INVALID_SIZE_ERROR = 3,
    // EncodeBcInput::format is not one of 1, 3, 4, 5 or 7.
    UNSUPPORTED_FORMAT_ERROR = 4,
    // EncodeBcInput::len does not match width * height * 4 bytes.
    INVALID_INPUT_LENGTH_ERROR = 5,
    // EncodeBcOutput::len does not match the encoded size in bytes.
    INVALID_OUTPUT_LENGTH_ERROR = 6,
    // A BC1 or BC7 quality setting is out of range.
    INVALID_SETTINGS = 7,
};

/// Encodes a 4-bytes-per-pixel image into BC1/BC3/BC4/BC5/BC7 blocks.
///
/// All arguments are validated before any image memory is allocated. The source
/// pixels are copied into a temporary image, encoded with rdo_bc::rdo_bc_encoder,
/// and the resulting blocks are copied into the caller's output buffer.
///
/// @param input   Source image: `data` must be non-null and `len` must equal
///                `width * height * sizeof(uint32_t)`. `format` selects the BCn
///                format (1, 3, 4, 5 or 7); `bc1_quality` / `bc7_quality` are
///                range-checked here.
/// @param output  Caller-allocated destination: `data` must be non-null and
///                `len` must equal the encoder's total block size in bytes.
///                `row_len` is not touched.
/// @param verbose When true, status/timing lines are printed to stdout and the
///                encoder's own status output is enabled.
/// @return NO_ERROR on success, otherwise the EncodeBcError describing the
///         first failed check or encoder step.
extern "C" EncodeBcError encode_bc(EncodeBcInput &input, EncodeBcOutput &output, bool verbose)
{
    // Highest BC7 "uber" level accepted by this wrapper.
    const int kMaxBc7UberLevel = 6; // BC7ENC_MAX_UBER_LEVEL

    int max_threads = 1;
#if _OPENMP
    max_threads = std::min(std::max(1, omp_get_max_threads()), 128);
#endif

    rdo_bc::rdo_bc_params rp;
    rp.m_rdo_max_threads = max_threads;
    rp.m_status_output = verbose;

    switch (input.format)
    {
        case 1:
            rp.m_dxgi_format = DXGI_FORMAT_BC1_UNORM;
            break;
        case 3:
            rp.m_dxgi_format = DXGI_FORMAT_BC3_UNORM;
            break;
        case 4:
            rp.m_dxgi_format = DXGI_FORMAT_BC4_UNORM;
            break;
        case 5:
            rp.m_dxgi_format = DXGI_FORMAT_BC5_UNORM;
            break;
        case 7:
            // BC7 is already the default format of rdo_bc_params.
            break;
        default:
            return UNSUPPORTED_FORMAT_ERROR;
    }

    // Range-check the raw settings before storing them, so the comparison never
    // depends on the (possibly unsigned) type of the rdo_bc_params members.
    const int bc1_quality = input.bc1_quality;
    if ((bc1_quality < static_cast<int>(rgbcx::MIN_LEVEL)) || (bc1_quality > static_cast<int>(rgbcx::MAX_LEVEL + 1)))
    {
        fprintf(stderr, "Invalid BC1 quality\n");
        return INVALID_SETTINGS;
    }
    rp.m_bc1_quality_level = input.bc1_quality;

    const int bc7_quality = input.bc7_quality;
    if ((bc7_quality < 0) || (bc7_quality > kMaxBc7UberLevel))
    {
        fprintf(stderr, "Invalid BC7 quality\n");
        return INVALID_SETTINGS;
    }
    rp.m_bc7_uber_level = input.bc7_quality;

    // Negative dimensions are rejected too: they would otherwise be converted
    // implicitly when handed to the image container.
    const int32_t width = input.width;
    const int32_t height = input.height;
    if (width <= 0 || height <= 0)
    {
        return INVALID_SIZE_ERROR;
    }

    // Compute the expected byte count in 64 bits; width * height * 4 does not
    // fit in int32_t for large images and signed overflow is undefined.
    const uint64_t expected_input_size =
        static_cast<uint64_t>(width) * static_cast<uint64_t>(height) * sizeof(uint32_t);
    if (input.len < 0 || static_cast<uint64_t>(input.len) != expected_input_size)
    {
        return INVALID_INPUT_LENGTH_ERROR;
    }
    if (input.data == nullptr)
    {
        fprintf(stderr, "Input data pointer is null\n");
        return INVALID_INPUT_LENGTH_ERROR;
    }

    // A null or empty destination can never be valid; fail before doing the
    // expensive encode. The exact length is verified once the encoder knows it.
    if (output.data == nullptr || output.len <= 0)
    {
        fprintf(stderr, "Output buffer is null or empty\n");
        return INVALID_OUTPUT_LENGTH_ERROR;
    }

    // TODO: avoid a copy somehow
    utils::image_u8 source_image;
    source_image.init(width, height);
    memcpy(source_image.get_pixels().data(), input.data, static_cast<size_t>(expected_input_size));

    if (rp.m_status_output)
    {
        printf("Max threads: %d\n", max_threads);
        printf("Supports bc7e.ispc: %u\n", static_cast<unsigned>(SUPPORT_BC7E));
    }

    // clock() reports processor time, so with several worker threads the figure
    // printed below may exceed the elapsed wall-clock time.
    const clock_t overall_start_t = clock();

    rdo_bc::rdo_bc_encoder encoder;
    if (!encoder.init(source_image, rp))
    {
        fprintf(stderr, "rdo_bc_encoder::init() failed!\n");
        return ENCODER_INIT_ERROR;
    }

    if (rp.m_status_output)
    {
        if (encoder.get_has_alpha())
            printf("Source image has an alpha channel.\n");
        else
            printf("Source image is opaque.\n");
    }

    if (!encoder.encode())
    {
        fprintf(stderr, "rdo_bc_encoder::encode() failed!\n");
        return ENCODE_ERROR;
    }

    const clock_t overall_end_t = clock();

    if (rp.m_status_output)
        printf("Total processing time: %f secs\n", (double)(overall_end_t - overall_start_t) / CLOCKS_PER_SEC);

    // The caller must have sized the output buffer to exactly the encoded size.
    // output.len is known to be positive here, so the unsigned cast is lossless.
    const uint32_t output_data_size = encoder.get_total_blocks_size_in_bytes();
    if (static_cast<uint32_t>(output.len) != output_data_size)
    {
        fprintf(stderr, "Output length is %d, expected %u\n", output.len, output_data_size);
        return INVALID_OUTPUT_LENGTH_ERROR;
    }

    // TODO: avoid a copy
    memcpy(output.data, encoder.get_blocks(), output_data_size);

    return NO_ERROR;
}