// David Eberly, Geometric Tools, Redmond WA 98052
// Copyright (c) 1998-2025
// Distributed under the Boost Software License, Version 1.0.
// https://www.boost.org/LICENSE_1_0.txt
// https://www.geometrictools.com/License/Boost/LICENSE_1_0.txt
// File Version: 8.0.2025.07.17

#include <Graphics/GL46/GTGraphicsGL46PCH.h>
#include <Graphics/GL46/GLSLComputeProgram.h>
#include <Graphics/GL46/GLSLProgramFactory.h>
#include <Graphics/GL46/GLSLVisualProgram.h>
#include <Graphics/GL46/GLSLShader.h>
using namespace gte;

// Class-wide defaults copied into every newly constructed factory: the GLSL
// version directive and the entry-point name for each shader stage.
std::string GLSLProgramFactory::defaultVersion{ "#version 460" };
std::string GLSLProgramFactory::defaultVSEntry{ "main" };
std::string GLSLProgramFactory::defaultPSEntry{ "main" };
std::string GLSLProgramFactory::defaultGSEntry{ "main" };
std::string GLSLProgramFactory::defaultCSEntry{ "main" };
uint32_t GLSLProgramFactory::defaultFlags{ 0 };  // unused in GLSL for now

GLSLProgramFactory::GLSLProgramFactory()
{
    // Seed this factory's per-instance settings from the static defaults.
    this->version = GLSLProgramFactory::defaultVersion;
    this->flags = GLSLProgramFactory::defaultFlags;

    // Entry-point names for the vertex, pixel, geometry and compute stages.
    this->vsEntry = GLSLProgramFactory::defaultVSEntry;
    this->psEntry = GLSLProgramFactory::defaultPSEntry;
    this->gsEntry = GLSLProgramFactory::defaultGSEntry;
    this->csEntry = GLSLProgramFactory::defaultCSEntry;
}

std::shared_ptr<VisualProgram> GLSLProgramFactory::CreateFromNamedSources(
    std::string const&, std::string const& vsSource,
    std::string const&, std::string const& psSource,
    std::string const&, std::string const& gsSource)
{
    if (vsSource == "" || psSource == "")
    {
        LogError("A program must have a vertex shader and a pixel shader.");
    }

    GLuint vsHandle = Compile(GL_VERTEX_SHADER, vsSource);
    if (vsHandle == 0)
    {
        return nullptr;
    }

    GLuint psHandle = Compile(GL_FRAGMENT_SHADER, psSource);
    if (psHandle == 0)
    {
        return nullptr;
    }

    GLuint gsHandle = 0;
    if (gsSource != "")
    {
        gsHandle = Compile(GL_GEOMETRY_SHADER, gsSource);
        if (gsHandle == 0)
        {
            return nullptr;
        }
    }

    GLuint programHandle = glCreateProgram();
    if (programHandle == 0)
    {
        LogError("Program creation failed.");
    }

    glAttachShader(programHandle, vsHandle);
    glAttachShader(programHandle, psHandle);
    if (gsHandle > 0)
    {
        glAttachShader(programHandle, gsHandle);
    }

    if (!Link(programHandle))
    {
        glDetachShader(programHandle, vsHandle);
        glDeleteShader(vsHandle);
        glDetachShader(programHandle, psHandle);
        glDeleteShader(psHandle);
        if (gsHandle)
        {
            glDetachShader(programHandle, gsHandle);
            glDeleteShader(gsHandle);
        }
        glDeleteProgram(programHandle);
        return nullptr;
    }

    std::shared_ptr<GLSLVisualProgram> program =
        std::make_shared<GLSLVisualProgram>(programHandle, vsHandle,
        psHandle, gsHandle);

    GLSLReflection const& reflector = program->GetReflector();
    auto vshader = std::make_shared<GLSLShader>(reflector, GT_VERTEX_SHADER, GLSLReflection::ReferenceType::VERTEX);
    auto pshader = std::make_shared<GLSLShader>(reflector, GT_PIXEL_SHADER, GLSLReflection::ReferenceType::PIXEL);
    program->SetVertexShader(vshader);
    program->SetPixelShader(pshader);
    if (gsHandle > 0)
    {
        auto gshader = std::make_shared<GLSLShader>(reflector, GT_GEOMETRY_SHADER, GLSLReflection::ReferenceType::GEOMETRY);
        program->SetGeometryShader(gshader);
    }
    return program;
}

std::shared_ptr<ComputeProgram> GLSLProgramFactory::CreateFromNamedSource(
    std::string const&, std::string const& csSource)
{
    if (csSource == "")
    {
        LogError("A program must have a compute shader.");
    }

    GLuint csHandle = Compile(GL_COMPUTE_SHADER, csSource);
    if (csHandle == 0)
    {
        return nullptr;
    }

    GLuint programHandle = glCreateProgram();
    if (programHandle == 0)
    {
        LogError("Program creation failed.");
    }

    glAttachShader(programHandle, csHandle);

    if (!Link(programHandle))
    {
        glDetachShader(programHandle, csHandle);
        glDeleteShader(csHandle);
        glDeleteProgram(programHandle);
        return nullptr;
    }

    auto program = std::make_shared<GLSLComputeProgram>(programHandle, csHandle);
    GLSLReflection const& reflector = program->GetReflector();
    auto cshader = std::make_shared<GLSLShader>(reflector, GT_COMPUTE_SHADER, GLSLReflection::ReferenceType::COMPUTE);
    program->SetComputeShader(cshader);
    return program;
}

// Create a shader object of the given type and compile it from the
// generated prologue (see below) followed by 'source'. Returns the shader
// handle on success. On failure the shader object is deleted, the error is
// reported through LogError, and 0 is returned if LogError does not throw.
GLuint GLSLProgramFactory::Compile(GLenum shaderType, std::string const& source)
{
    GLuint handle = glCreateShader(shaderType);
    if (handle == 0)
    {
        LogError("Cannot create shader.");
        return 0;  // Reached only if LogError does not throw.
    }

    // Prepend to the source code, in this order:
    // 1. The version of the GLSL program; for example, "#version 460".
    // 2. A define for the matrix-vector multiplication convention:
    //    "#define GTE_USE_MAT_VEC 0" when GTE_USE_VEC_MAT is defined,
    //    otherwise "#define GTE_USE_MAT_VEC 1".
    // 3. "layout(std140, *_major) uniform;" for either row_major or
    //    column_major to select default for all uniform matrices and
    //    select std140 layout.
    // 4. "layout(std430, *_major) buffer;" for either row_major or
    //    column_major to select default for all buffer matrices and
    //    select std430 layout.
    // 5. One "#define <name> <value>" line per user-supplied definition.
    // The source-code string is appended last.
    auto const& definitions = defines.Get();
    std::vector<std::string> glslDefines;
    glslDefines.reserve(definitions.size() + 5);
    glslDefines.push_back(version + "\n");
#if defined(GTE_USE_VEC_MAT)
    glslDefines.push_back("#define GTE_USE_MAT_VEC 0\n");
#else
    glslDefines.push_back("#define GTE_USE_MAT_VEC 1\n");
#endif
#if defined(GTE_USE_COL_MAJOR)
    glslDefines.push_back("layout(std140, column_major) uniform;\n");
    glslDefines.push_back("layout(std430, column_major) buffer;\n");
#else
    glslDefines.push_back("layout(std140, row_major) uniform;\n");
    glslDefines.push_back("layout(std430, row_major) buffer;\n");
#endif
    for (auto const& d : definitions)
    {
        glslDefines.push_back("#define " + d.first + " " + d.second + "\n");
    }
    glslDefines.push_back(source);

    // Repackage the strings as the pointer array glShaderSource expects.
    // glslDefines must outlive this array; it does, as both are locals.
    std::vector<GLchar const*> code;
    code.reserve(glslDefines.size());
    for (auto const& d : glslDefines)
    {
        code.push_back(d.c_str());
    }

    glShaderSource(handle, static_cast<GLsizei>(code.size()), code.data(), nullptr);
    glCompileShader(handle);

    GLint status = GL_FALSE;
    glGetShaderiv(handle, GL_COMPILE_STATUS, &status);
    if (status == GL_TRUE)
    {
        return handle;
    }

    // Compilation failed. Capture the info log first, then delete the shader
    // object before reporting so that it is not leaked when LogError throws.
    GLint logLength = 0;
    glGetShaderiv(handle, GL_INFO_LOG_LENGTH, &logLength);
    std::string message;
    if (logLength > 0)
    {
        // The vector is zero-filled, so log.data() is a valid C string even
        // if the driver writes fewer characters than requested.
        std::vector<GLchar> log(static_cast<size_t>(logLength));
        GLsizei numWritten = 0;
        glGetShaderInfoLog(handle, static_cast<GLsizei>(logLength), &numWritten, log.data());
        message = log.data();
    }
    glDeleteShader(handle);

    if (logLength > 0)
    {
        LogError("Compile failed:\n" + message);
    }
    else
    {
        LogError("Invalid info log length.");
    }
    return 0;  // Reached only if LogError does not throw.
}

// Link the program object and report whether linking succeeded. On failure
// the program's info log is reported through LogError, and false is returned
// if LogError does not throw. The caller keeps ownership of the program and
// its attached shaders in every case.
bool GLSLProgramFactory::Link(GLuint programHandle)
{
    glLinkProgram(programHandle);
    GLint status = GL_FALSE;
    glGetProgramiv(programHandle, GL_LINK_STATUS, &status);
    if (status == GL_TRUE)
    {
        return true;
    }

    GLint logLength = 0;
    glGetProgramiv(programHandle, GL_INFO_LOG_LENGTH, &logLength);
    if (logLength > 0)
    {
        // The vector is zero-filled, so log.data() is a valid C string even
        // if the driver writes fewer characters than requested.
        std::vector<GLchar> log(static_cast<size_t>(logLength));
        GLsizei numWritten = 0;
        glGetProgramInfoLog(programHandle, static_cast<GLsizei>(logLength), &numWritten, log.data());
        std::string message(log.data());
        LogError("Link failed:\n" + message);
    }
    else
    {
        LogError("Invalid info log length.");
    }
    return false;  // Reached only if LogError does not throw.
}

