/********************************************************
    © 2020 Continuum Graphics LLC. All Rights Reserved
 ********************************************************/

#if !defined _DIFFUSELIGHTING_
#define _DIFFUSELIGHTING_

#include "/../ContinuumLib/Uniform/ShadowDistortion.glsl"
#include "/../ContinuumLib/Common/Shadows.glsl"
#include "/../ContinuumLib/Utilities/SphericalHarmonics.glsl"
#include "/../InternalLib/Fragment/SpecularLighting.fsh"

// Integer texel offsets used by the bilateral filters in this file, ordered as
// concentric square rings around the centre texel:
//   [0]      centre
//   [1..8]   ring at Chebyshev distance 1
//   [9..24]  ring at distance 2
//   [25..48] ring at distance 3
//   [49..80] ring at distance 4
// Because of this ordering, the first (2r+1)^2 entries cover exactly a
// (2r+1)x(2r+1) window: SampleAO iterates the first 9 (3x3 window) and
// SampleRSM all 81 (9x9 window). Callers scale the offsets by pixelSize.
const vec2 bilateralOffsets[81] = vec2[81](
    vec2(0.0, 0.0),
    vec2(-1.0, 0.0),
    vec2(-1.0, 1.0),
    vec2(0.0, 1.0),
    vec2(1.0, 1.0),
    vec2(1.0, 0.0),
    vec2(1.0, -1.0),
    vec2(0.0, -1.0),
    vec2(-1.0, -1.0),
    vec2(-2.0, -1.0),
    vec2(-2.0, 0.0),
    vec2(-2.0, 1.0),
    vec2(-2.0, 2.0),
    vec2(-1.0, 2.0),
    vec2(0.0, 2.0),
    vec2(1.0, 2.0),
    vec2(2.0, 2.0),
    vec2(2.0, 1.0),
    vec2(2.0, 0.0),
    vec2(2.0, -1.0),
    vec2(2.0, -2.0),
    vec2(1.0, -2.0),
    vec2(0.0, -2.0),
    vec2(-1.0, -2.0),
    vec2(-2.0, -2.0),
    vec2(2.0, -3.0),
    vec2(1.0, -3.0),
    vec2(0.0, -3.0),
    vec2(-1.0, -3.0),
    vec2(-2.0, -3.0),
    vec2(-3.0, -3.0),
    vec2(-3.0, -2.0),
    vec2(-3.0, -1.0),
    vec2(-3.0, 0.0),
    vec2(-3.0, 1.0),
    vec2(-3.0, 2.0),
    vec2(-3.0, 3.0),
    vec2(-2.0, 3.0),
    vec2(-1.0, 3.0),
    vec2(0.0, 3.0),
    vec2(1.0, 3.0),
    vec2(2.0, 3.0),
    vec2(3.0, 3.0),
    vec2(3.0, 2.0),
    vec2(3.0, 1.0),
    vec2(3.0, 0.0),
    vec2(3.0, -1.0),
    vec2(3.0, -2.0),
    vec2(3.0, -3.0),
    vec2(3.0, -4.0),
    vec2(2.0, -4.0),
    vec2(1.0, -4.0),
    vec2(0.0, -4.0),
    vec2(-1.0, -4.0),
    vec2(-2.0, -4.0),
    vec2(-3.0, -4.0),
    vec2(-4.0, -4.0),
    vec2(-4.0, -3.0),
    vec2(-4.0, -2.0),
    vec2(-4.0, -1.0),
    vec2(-4.0, 0.0),
    vec2(-4.0, 1.0),
    vec2(-4.0, 2.0),
    vec2(-4.0, 3.0),
    vec2(-4.0, 4.0),
    vec2(-3.0, 4.0),
    vec2(-2.0, 4.0),
    vec2(-1.0, 4.0),
    vec2(0.0, 4.0),
    vec2(1.0, 4.0),
    vec2(2.0, 4.0),
    vec2(3.0, 4.0),
    vec2(4.0, 4.0),
    vec2(4.0, 3.0),
    vec2(4.0, 2.0),
    vec2(4.0, 1.0),
    vec2(4.0, 0.0),
    vec2(4.0, -1.0),
    vec2(4.0, -2.0),
    vec2(4.0, -3.0),
    vec2(4.0, -4.0)
);

// Diffuse BRDF for GGX microfacet surfaces. The structure (a smooth and a rough
// single-scattering term blended by roughness, plus a 0.1159 * roughness
// multi-scattering term tinted by albedo) follows Earl Hammon Jr.'s
// "PBR Diffuse Lighting for GGX+Smith Microsurfaces" (GDC 2017).
//
// albedo    - diffuse colour.
// roughness - blend factor between the smooth and rough terms. Hammon's
//             formulation uses GGX alpha here; NOTE(review): confirm whether
//             callers pass perceptual roughness or alpha.
// NoL, NoV  - cosines between the normal and the light / view directions.
// NoH       - cosine between the normal and the half vector.
// LoV       - cosine between the light and view directions; remapped from
//             [-1, 1] to the [0, 1] "facing" term.
// Returns the BRDF value. The single-scattering term includes the 1/PI
// normalisation; no NoL cosine factor is applied here.
vec3 GGXDiffuse(vec3 albedo, float roughness, float NoL, float NoV, float NoH, float LoV) {
    float facing = LoV * 0.5 + 0.5;

    // NoH reaches 0 when NoL == -NoV. Unguarded, the division produced inf, and
    // inf * 0 (inside mix() or from the caller's clamped NoL) becomes NaN.
    // Non-positive NoH only occurs for back-facing light (NoL <= 0), whose
    // contribution callers already zero out.
    float safeNoH = max(NoH, 1e-6);
    float roughFacet = facing * (0.9 - 0.4 * facing) * ((0.5 + safeNoH) / safeNoH);
    float smoothFacet = 1.05 * (1.0 - pow5(1.0 - NoL)) * (1.0 - pow5(1.0 - NoV));

    float single = (1.0 / PI) * mix(smoothFacet, roughFacet, roughness);
    float multi = 0.1159 * roughness;

    return albedo * (single + albedo * multi);
}


// Light contributed by block-light sources (torches etc.).
//
// torchLightmap - block-light lightmap value, presumably in [0, 1].
// The lightmap is remapped to a pseudo distance (1 at a value of 1.0, 16 at
// 0.0) and the lightmap value is divided by that distance squared. In the
// Nether (WORLD_1) the constant NETHER_AMBIENT_LUM is added to the falloff.
// The result is scaled by LIGHT_LUM and tinted by the LIGHT_TEMP blackbody colour.
vec3 CalculateTorchLight(float torchLightmap) {
    float pseudoDistance = 1.0 + 15.0 * (1.0 - torchLightmap);
    float falloff = torchLightmap / (pseudoDistance * pseudoDistance);

    #if defined WORLD_1
        falloff += NETHER_AMBIENT_LUM;
    #endif

    return falloff * LIGHT_LUM * blackbody(LIGHT_TEMP);
}

// Edge-aware 3x3 filter of the ambient-occlusion term stored in colortex3.a.
//
// coord  - sample position inside the AO/normal tile of colortex3, which spans
//          [0.5, 0.75] on both axes; every tap is clamped to that tile, inset
//          by one texel.
// normal - receiver normal. Each tap is weighted by exp2(64 * (|dot| - 1))
//          against the normal packed in colortex3.rgb, so taps with a
//          different orientation contribute almost nothing.
// Returns the weighted mean over the first 9 bilateralOffsets (a 3x3 window).
float SampleAO(vec2 coord, vec3 normal){
    vec2 tileMin = 0.5 + pixelSize;
    vec2 tileMax = 0.75 - pixelSize;

    float weightSum = 0.0;
    float aoSum = 0.0;

    for (int tap = 0; tap < 9; tap++) {
        vec2 tapCoord = clamp(coord + pixelSize * bilateralOffsets[tap], tileMin, tileMax);
        vec4 tapSample = texture(colortex3, tapCoord);

        float cosine = abs(dot(tapSample.rgb * 2.0 - 1.0, normal));
        float w = exp2(cosine * 64.0 - 64.0);

        aoSum += tapSample.a * w;
        weightSum += w;
    }

    return aoSum / weightSum;
}

// Joint-bilateral 9x9 filter of the indirect lighting stored in colortex3
// (presumably reflective-shadow-map GI).
//
// coordG - sample position inside the lighting tile, [0.75, 1.0] x [0.5, 0.75].
//          .rgb holds the lighting; .a a linear depth that is compared with
//          the receiver's (assumed positive view distance — TODO confirm writer).
// coordN - matching position inside the normal tile, [0.5, 0.75] on both
//          axes; .rgb holds a packed normal.
// normal - receiver normal.
// depth  - receiver screen-space depth.
// NoV    - receiver N.V; scales the depth weight.
// Each tap is weighted by a spatial falloff exp2(-0.1 * r^2), a normal-
// similarity term and a depth-similarity term. Returns the normalised sum,
// or black when every weight underflows to zero.
vec3 SampleRSM(vec2 coordG, vec2 coordN, vec3 normal, float depth, float NoV){
    float totalWeight = 0.0;
    vec3 rsm = vec3(0.0);

    // Clamp bounds for both tiles, inset by one texel.
    vec2 minCoordRsm = vec2(0.75, 0.5) + pixelSize;
    vec2 maxCoordRsm = vec2(1.0, 0.75) - pixelSize;

    vec2 minCoordN = 0.5 + pixelSize;
    vec2 maxCoordN = 0.75 - pixelSize;

    float linearDepth = -ScreenToViewSpaceDepth(depth);

    for (int i = 0; i < 81; ++i){
        vec2 offset = pixelSize * bilateralOffsets[i];

        // Spatial falloff on the tap's distance in texels.
        float weight = length(bilateralOffsets[i]);
              weight = exp2(-weight * weight * 0.1);

        vec2 rsmCoord = clamp(coordG + offset, minCoordRsm, maxCoordRsm);
        vec2 nCoord = clamp(coordN + offset, minCoordN, maxCoordN);

        vec4 rsmTex = texture(colortex3, rsmCoord);
        vec4 nTex = texture(colortex3, nCoord);

        // Normal similarity: 1 for parallel normals, 2^-64 for perpendicular ones.
        vec3 loddedNormal = nTex.rgb * 2.0 - 1.0;
        float normalWeight = abs(dot(loddedNormal, normal));
              normalWeight = exp2(normalWeight * 64.0 - 64.0);

        // Depth similarity, scaled by NoV (presumably so grazing surfaces,
        // whose depth changes quickly across the screen, are not over-rejected).
        float offsetLinearDepth = rsmTex.a;
        float depthWeight = exp2(-abs(linearDepth - offsetLinearDepth) * 4.0 * NoV);

        weight = normalWeight * depthWeight * weight;

        rsm += rsmTex.rgb * weight;
        totalWeight += weight;
    }

    // Large depth differences (e.g. a thin foreground object over distant
    // background in this reduced-resolution buffer) can drive every weight to
    // exactly zero; return black instead of the NaN produced by 0 / 0.
    return totalWeight > 0.0 ? rsm / totalWeight : vec3(0.0);
}

#if defined Composite1_glsl
    // Screen-space ray march toward the light that produces a short-range
    // (contact) shadow term and, for foliage, a subsurface-scattering estimate.
    //
    // hitPixel    - screen-space position of the shaded pixel (xy = texture
    //               coordinate, z = depth comparable to depthtex1).
    // rayOrigin   - view-space position of the shaded pixel.
    // rayDir      - view-space direction toward the light.
    // viewDir     - not read by this function.
    // NoL         - only used for the grazing-angle early-out (|NoL| < 0.05).
    // dither      - per-pixel jitter of the first step's offset.
    // foliageMask - foliage never receives a traced shadow (shadow stays 1.0);
    //               the march is used to estimate SSS instead.
    // handMask    - hand pixels skip the march.
    // SSS         - written on the early-out path (1.0 for foliage, else 0.0)
    //               and for foliage after the march; otherwise left untouched.
    // Returns 0.0 when a non-foliage pixel finds an occluder within the depth
    // thresholds, 1.0 otherwise.
    float rayTraceShadow(vec3 hitPixel, vec3 rayOrigin, vec3 rayDir, vec3 viewDir, float NoL, float dither, bool foliageMask, bool handMask, inout float SSS) {
        if (abs(NoL) < 0.05 || handMask) {
            SSS = float(foliageMask);
            return 1.0;
        }

        // The march covers roughly steps * maxStepSize = 0.025 in screen space.
        const int steps = 8;
        const float maxStepSize = 0.025 / float(steps);
        // Used both as a relative thickness (dist / -linearD) and as an
        // absolute view-space distance threshold below.
        const float minDepthDifference = 0.1;

        // Extinction scale for the foliage SSS term.
        const float densityPerStep = 128.0;

        float maxDist = far * sqrt(3.);

        // Shorten the ray so its end point does not cross the near plane
        // (points in front of the camera have z < -near here).
        float rayLength = ((rayOrigin.z + rayDir.z * maxDist) > -near) ?
                        (-near - rayOrigin.z) / rayDir.z : maxDist;

        // March direction in screen space (normalised including depth).
        vec3 direction = normalize(ViewSpaceToScreenSpace(rayDir * rayLength + rayOrigin) - hitPixel);

        vec3 increment = direction * maxStepSize;
        vec3 rayPosition = hitPixel + increment * (dither + 1.);

        float depth = texture(depthtex1, rayPosition.xy).x;

        // Index of the last step that found an occluder (drives foliage SSS).
        float sampleDepth = 0.0;
        float shadow = 1.0;

        for (int i = 0; i < steps; i++) {

            // Stop once the ray leaves the screen.
            if (clamp01(rayPosition.xy) != rayPosition.xy) {
                break;
            }

            // Scene surface is in front of the ray sample.
            if (depth < rayPosition.z) {
                float linearZ = ScreenToViewSpaceDepth(rayPosition.z);
                float linearD = ScreenToViewSpaceDepth(depth);

                float dist = linearD - linearZ;

                // Accept occluders that are in front of the ray and thin
                // relative to their distance from the camera.
                if (dist / -linearD < minDepthDifference && dist > 0.0) {
                    sampleDepth = float(i);
                    
                    if (!foliageMask && dist < minDepthDifference) {
                        shadow = 0.0;
                        break;
                    }
                }
            }

            rayPosition += increment;
            depth = texture(depthtex1, rayPosition.xy).x;
        }

        if (foliageMask) {
            SSS = exp(-sampleDepth * maxStepSize * densityPerStep);
        }

        return shadow;
    }
#endif

// Diffuse lighting for one pixel.
//   WORLD0:  direct sun (shadow map, cloud shadows, POM shadow, traced contact
//            shadows and foliage SSS in Composite1), sky ambient, RSM indirect
//            light (Composite1 only) and block light.
//   WORLD_1: block light only.
//   other:   END_AMBIENT_LUM ambient plus block light.
//
// positions     - [0] is used as the view-space ray origin, [1] as the
//                 position fed to the shadow/cloud lookups (presumably
//                 player-relative world space — TODO confirm).
// pixelPosition - screen-space position where the contact-shadow march starts.
// normal        - shading normal in the space of lightVector / viewVector.
// worldNormal   - world-space normal for the shadow and sky-SH lookups.
// viewVector    - direction from the surface toward the camera.
// depth         - screen depth of the pixel (RSM depth weighting).
// roughness, f0, metalIOR - material parameters for Fresnel and the diffuse BRDF.
// lightmaps     - x: block light, y: sky light.
// sunIrradiance - out: shadowed sun irradiance in WORLD0, vec3(0.0) otherwise.
vec3 CalculateDiffuseLighting(mat2x3 positions, vec3 pixelPosition, vec3 diffuseColor, vec3 normal, vec3 worldNormal, vec3 viewVector, float depth, float roughness, float f0, mat2x3 metalIOR, vec2 lightmaps, float dither, float pomShadow, bool isFoliage, bool isLava, bool isHand, out vec3 sunIrradiance) {
    // An out parameter is undefined unless written, and only the WORLD0 path
    // assigns one below; give every other dimension a defined value.
    sunIrradiance = vec3(0.0);

    // Calculate Shadows and multiply by sunlight color
    vec3 shadowPosition = WorldSpaceToShadowSpace(positions[1]);
    vec4 texSampleCoords = vec4(texcoord * 0.25 + 0.5, texcoord * 0.25 + vec2(0.75, 0.5));

    #if defined WORLD0
    float cloudShadow = CalculateCloudShadows(positions[1] + cameraPosition, wLightVector, 4);

    vec3 shadows = CalculateShadows(shadowPosition, positions, normal, worldNormal, lightVector, dither, isFoliage, isLava) * pomShadow;
         shadows *= cloudShadow;

    // Get Light Coeffs
    vec3 skyIrradiance = shUnprojectCosineLobe(shR, shG, shB, worldNormal) * rPI;
         sunIrradiance = sunIlluminanceVert;
    #endif

    float NoL = dot(normal, lightVector);
    float NoV = max(1e-6, dot(normal, viewVector));

    // |L + V|: the half vector is H = (L + V) / |L + V|, which gives N.H and
    // L.H = (1 + L.V) / |L + V| = |L + V| / 2. Clamped away from zero because
    // L == -V (looking straight at the light through the surface) would
    // otherwise make the divisions produce inf/NaN.
    float halfLength = max(1e-6, length(lightVector + viewVector));
    float LoH = 0.5 * halfLength;
    float NoH = max(-fsign(NoV), (NoL + NoV) / halfLength);

    // GGXDiffuse remaps its LoV argument from [-1, 1] to [0, 1], i.e. it takes
    // the cosine L.V. Previously |L + V| (range [0, 2]) was passed, pushing
    // the "facing" term outside its intended range.
    float LoV = dot(lightVector, viewVector);

    #if defined WORLD0
    #if defined Composite1_glsl
        float tracedSSS = 0.0;
        vec3 flatWorldNormal = clamp(normalize(cross(dFdx(positions[1]), dFdy(positions[1]))), -1.0, 1.0);
        float tracedShadow = rayTraceShadow(pixelPosition, positions[0], lightVector, viewVector, dot(flatWorldNormal, wLightVector), dither, isFoliage, isHand, tracedSSS);
        vec3 SSS = tracedSSS * diffuseColor * rTAU * sunIrradiance * shadows * (NoL * -0.5 + 0.5);

        shadows *= tracedShadow;
    #else
        vec3 SSS = vec3(0.0);
    #endif
    #endif

    #if defined WORLD0

    vec3 F = FMaster(f0, LoH, metalIOR, diffuseColor);
    vec3 D = GGXDiffuse(diffuseColor, roughness, NoL, NoV, NoH, LoV) * clamp01(NoL);

    vec3 sunLight = D * (1.0 - F);
         sunLight *= sunIrradiance * shadows;
         sunLight += SSS;
    #endif

    // Calculate the gray scale ao
    float ao = SampleAO(texSampleCoords.xy, normal);

    #if defined WORLD0
    float FSky = FSchlickGaussian(f0, NoV);
    vec3 DSky = GGXDiffuse(diffuseColor, roughness, 1.0, NoV, NoV, 0.0) * (1.0 - FSky);

    vec3 skyLight = skyIrradiance * DSky;
         skyLight *= ao * pow2(lightmaps.y);

    #if defined Composite1_glsl
        vec3 rsm = SampleRSM(texSampleCoords.zw, texSampleCoords.xy, normal, depth, NoV) * cloudShadow * sunIrradiance * diffuseColor;
    #else
        vec3 rsm = vec3(0.0);
    #endif

    sunIrradiance *= shadows;
    #endif

    // TODO: emissive surfaces are not handled here.

    // Block light; AO is faded out as the block-light value approaches 1.
    vec3 torchLight = CalculateTorchLight(lightmaps.x) * diffuseColor * mix(ao, 1.0, pow8(lightmaps.x));

    #if defined WORLD0
        return (sunLight + skyLight + torchLight + rsm);
    #elif defined WORLD_1
        return torchLight;
    #else
        return ao * diffuseColor * END_AMBIENT_LUM * rPI + torchLight;
    #endif
}

#endif
