/// Pixel component types exchanged with callers of the extern "C" decoder API.
/// Every enumerator carries an explicit value because these numbers are passed
/// across the C boundary (see oiio_image_decode / oiio_image_get_attributes).
enum ImageFormat {
    FORMAT_UNKNOWN = 0,          // unrecognised or unsupported component type
    FORMAT_BYTE    = 1,          // 8-bit unsigned
    FORMAT_SHORT   = 2,          // 16-bit unsigned
    FORMAT_HALF    = 3,          // 16-bit float
    FORMAT_FLOAT   = 4,          // 32-bit float
    MAX_ENUM       = 0x7FFFFFFF  // makes the enum at least 32 bits wide; presumably for a stable C ABI
};

// image_decoder.cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <string>
#include <vector>

#include <OpenImageIO/filesystem.h>
#include <OpenImageIO/imagebufalgo.h>
#include <OpenImageIO/imageio.h>

using namespace OIIO;

/// Map an OIIO pixel type onto the C-facing ImageFormat enum.
/// Only the base type is consulted; any base type without a dedicated
/// enumerator yields FORMAT_UNKNOWN.
static ImageFormat type_to_format(TypeDesc type) {
    const auto base = type.basetype;
    if (base == TypeDesc::UINT8)  return FORMAT_BYTE;
    if (base == TypeDesc::UINT16) return FORMAT_SHORT;
    if (base == TypeDesc::HALF)   return FORMAT_HALF;
    if (base == TypeDesc::FLOAT)  return FORMAT_FLOAT;
    return FORMAT_UNKNOWN;
}

/// Name of the ImageFormat enumerator that corresponds to `type`'s base type,
/// as a static string (e.g. for diagnostics). Unmapped types give "FORMAT_UNKNOWN".
/// NOTE(review): no caller is visible in this file; confirm it is still needed.
static const char* type_to_string(TypeDesc type) {
    const char* label = "FORMAT_UNKNOWN";
    switch (type.basetype) {
        case TypeDesc::UINT8:  label = "FORMAT_BYTE";  break;
        case TypeDesc::UINT16: label = "FORMAT_SHORT"; break;
        case TypeDesc::HALF:   label = "FORMAT_HALF";  break;
        case TypeDesc::FLOAT:  label = "FORMAT_FLOAT"; break;
        default:                                       break;
    }
    return label;
}

/// Translate a C-facing ImageFormat into the OIIO TypeDesc used when reading
/// and converting pixels. Values with no mapping (FORMAT_UNKNOWN, MAX_ENUM or
/// anything out of range) produce TypeDesc::UNKNOWN.
static TypeDesc format_to_type(enum ImageFormat format) {
    TypeDesc::BASETYPE base = TypeDesc::UNKNOWN;
    switch (format) {
        case FORMAT_BYTE:  base = TypeDesc::UINT8;  break;  // 8-bit unsigned
        case FORMAT_SHORT: base = TypeDesc::UINT16; break;  // 16-bit unsigned
        case FORMAT_HALF:  base = TypeDesc::HALF;   break;  // 16-bit float
        case FORMAT_FLOAT: base = TypeDesc::FLOAT;  break;  // 32-bit float
        default:                                    break;  // stays UNKNOWN
    }
    return TypeDesc(base);
}

extern "C" {

int oiio_image_decode(const uint8_t* input, size_t input_len,
                uint8_t* output, size_t output_len,
                int32_t desired_channels, ImageFormat format, char* is_BGR, const char* file_name) {
    if (!input || !output || desired_channels < 1 || desired_channels > 4) {
        fprintf(stderr, "Invalid input parameters\n");
        return 0;
    }

    uint8_t* decoded = output;
    bool free_decoded = false;

    try {
        auto mem_reader = Filesystem::IOMemReader(input, input_len);

        auto in = ImageInput::open(file_name, nullptr, &mem_reader);
        if (!in) {
            fprintf(stderr, "Open failed: %s\n", geterror().c_str());
            return 0;
        }

        const ImageSpec& spec = in->spec();
        *is_BGR = spec.channelnames[0] == "B"? 1:0;
        const size_t required = spec.width * spec.height * desired_channels;
        const TypeDesc ftype = format_to_type(format);
        if (output_len != required * ftype.size()) {
            fprintf(stderr, "Buffer size mismatch: Needed %zu, got %u\n", 
                    required * ftype.size(), output_len);
            in->close();
            return 0;
        }

        if(spec.nchannels != desired_channels){
            // have a separate buffer for the decoded image, free later
            decoded = (uint8_t*)malloc(spec.image_bytes());
            free_decoded = true;
        }

        if (!in->read_image(0, 0, 0, spec.nchannels, ftype, decoded)) {
            fprintf(stderr, "Read error: %s\n", in->geterror().c_str());
            in->close();
            if(free_decoded) free(decoded);
            return 0;
        }
        
        if(spec.nchannels < desired_channels){
            // expand channels
            ImageSpec src_spec(spec.width, spec.height, spec.nchannels, ftype);
            ImageBuf src_buf(src_spec, decoded);
            ImageSpec dst_spec(spec.width, spec.height, desired_channels, ftype);
            ImageBuf dst_buf(dst_spec, output);

            const float channel_values[] = {0.0f, 0.0f, 0.0f, 1.0f};
            // -1 below fills with the values above
            std::array<int, 4> channel_order = {-1, -1, -1, -1};
            for(int i=0; i<spec.nchannels; i++){
                channel_order[i] = i;
            }
            if(spec.nchannels == 1 && desired_channels >= 3){
                // single channel exception: make it greyscale RGB/RGBA
                channel_order = {0,0,0,-1};
            }

            if (!ImageBufAlgo::channels(dst_buf, src_buf, 
                                    desired_channels,
                                    channel_order, channel_values)) {
                fprintf(stderr, "Channel conversion failed: %s\n", dst_buf.geterror().c_str());
                in->close();
                if(free_decoded) free(decoded);
                return 0;
            }
            uint8_t* pixels = (uint8_t*)dst_buf.localpixels();
            if(pixels != output){
                // it turns out channels() always resets the buffer
                // so it's no longer wrapping output
                memcpy(output, pixels, dst_spec.image_bytes());
            }
        }

        in->close();
        if(free_decoded) free(decoded);
        return 1;
    }
    catch (const std::exception& e) {
        fprintf(stderr, "Exception: %s\n", e.what());
        if(free_decoded) free(decoded);
        return 0;
    }
    catch (...) {
        fprintf(stderr, "Unknown exception occurred\n");
        if(free_decoded) free(decoded);
        return 0;
    }
}

/// Open an in-memory encoded image and report its basic properties without
/// decoding pixels.
///
/// @param input      Encoded file bytes; must be non-null.
/// @param input_len  Number of bytes in `input`.
/// @param width      Optional out (may be null): image width in pixels.
/// @param height     Optional out (may be null): image height in pixels.
/// @param channels   Optional out (may be null): number of channels in the file.
/// @param format     Optional out (may be null): the file's pixel type mapped to
///                   ImageFormat (FORMAT_UNKNOWN when there is no mapping).
/// @param file_name  Name handed to ImageInput::open (presumably used to choose
///                   the reader plugin); must be non-null.
/// @return 1 on success, 0 on failure (a diagnostic is printed to stderr).
int oiio_image_get_attributes(const uint8_t* input, size_t input_len,
                        int32_t* width, int32_t* height, int32_t* channels,
                        enum ImageFormat* format, const char* file_name) {
    if (!input) {
        fprintf(stderr, "Null input buffer\n");
        return 0;
    }
    // ImageInput::open takes a std::string; building one from a null pointer is UB.
    if (!file_name) {
        fprintf(stderr, "Null file name\n");
        return 0;
    }

    try {
        auto mem_reader = Filesystem::IOMemReader(input, input_len);

        auto in = ImageInput::open(file_name, nullptr, &mem_reader);
        if (!in) {
            fprintf(stderr, "Open failed: %s\n", geterror().c_str());
            return 0;
        }

        // Only the header spec is consulted; no pixel data is read.
        const ImageSpec& spec = in->spec();
        if (width) *width = spec.width;
        if (height) *height = spec.height;
        if (channels) *channels = spec.nchannels;
        if (format) *format = type_to_format(spec.format);

        in->close();
        return 1;
    }
    catch (const std::exception& e) {
        fprintf(stderr, "Exception: %s\n", e.what());
        return 0;
    }
    catch (...) {
        fprintf(stderr, "Unknown exception occurred\n");
        return 0;
    }
}

} // extern "C"



// // image_to_ppm.c
// #include <stdint.h>
// #include <stdio.h>
// #include <stdlib.h>

// int main(int argc, char** argv) {
//     if (argc != 2) {
//         fprintf(stderr, "Usage: %s <image-file>\n", argv[0]);
//         return 1;
//     }

//     // Read input file
//     FILE* file = fopen(argv[1], "rb");
//     if (!file) {
//         perror("Failed to open file");
//         return 1;
//     }

//     fseek(file, 0, SEEK_END);
//     long file_size = ftell(file);
//     fseek(file, 0, SEEK_SET);

//     uint8_t* input = (uint8_t*)malloc(file_size);
//     if (!input) {
//         fclose(file);
//         fprintf(stderr, "Memory allocation failed\n");
//         return 1;
//     }

//     if (fread(input, 1, file_size, file) != file_size) {
//         fclose(file);
//         free(input);
//         fprintf(stderr, "File read error\n");
//         return 1;
//     }
//     fclose(file);

//     // Get image attributes
//     int32_t width, height, channels;
//     enum ImageFormat format;
//     if (!oiio_image_get_attributes(input, file_size, &width, &height, &channels, &format, argv[1])) {
//         free(input);
//         fprintf(stderr, "Unsupported image format\n");
//         return 1;
//     }

//     // Allocate output buffer for 3-channel RGB
//     size_t output_size = width * height * 3;
//     uint8_t* pixels = (uint8_t*)malloc(output_size);
//     if (!pixels) {
//         free(input);
//         fprintf(stderr, "Output buffer allocation failed\n");
//         return 1;
//     }

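//     // Decode to 3 channels (RGB) as 8-bit components, matching output_size and the 255 maxval below
//     char is_bgr = 0;
//     if (!oiio_image_decode(input, file_size, pixels, output_size, 3, FORMAT_BYTE, &is_bgr, argv[1])) {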
//         free(input);
//         free(pixels);
//         fprintf(stderr, "Image decoding failed\n");
//         return 1;
//     }

//     free(input);

//     // Output PPM header
//     printf("P6\n%d %d\n255\n", width, height);
    
//     // Output raw pixel data
//     fwrite(pixels, 1, output_size, stdout);

//     free(pixels);
//     return 0;
// }

bool convertRGBToRGBA(const std::string& filename, 
                     void* output_buffer, 
                     size_t buffer_length,
                     std::string& error_msg) {
    // Load source image
    ImageBuf src_buf(filename);
    if (!src_buf.read()) {
        error_msg = "Failed to load image: " + src_buf.geterror();
        return false;
    }

    const ImageSpec& src_spec = src_buf.spec();

    // Verify source is RGB
    if (src_spec.nchannels != 3) {
        error_msg = "Image is not RGB (has " + std::to_string(src_spec.nchannels) + " channels)";
        return false;
    }

    // Calculate required buffer size
    TypeDesc data_type = src_spec.format;
    const size_t required_size = src_spec.width * src_spec.height * 4 * data_type.size();
    if (buffer_length < required_size) {
        error_msg = "Buffer too small. Required: " + std::to_string(required_size) +
                    ", provided: " + std::to_string(buffer_length);
        return false;
    }

    // Prepare destination buffer wrapped in ImageBuf
    ImageSpec dst_spec(src_spec.width, src_spec.height, 4, data_type);
    ImageBuf dst_buf(dst_spec, output_buffer);

    // Set up channel remapping with alpha=1.0
    const int channel_order[] = {0, 1, 2, -1};  // Source RGB, new alpha
    const float channel_values[] = {0.0f, 0.0f, 0.0f, 1.0f};

    // Perform channel conversion
    if (!ImageBufAlgo::channels(dst_buf, src_buf, 
                               /* channel count */ 4,
                               channel_order, channel_values)) {
        error_msg = "Channel conversion failed: " + dst_buf.geterror();
        return false;
    }

    return true;
}