
#include "/lib/universal/universal.glsl"

// Number of segments each input triangle edge is split into by main().
// NOTE(review): cInt is not declared in this file — presumably a `const int` macro from
// universal.glsl; confirm.
cInt maxSubdivisions = MAX_WATER_SUBDIVISIONS;

layout (triangles) in;
// main() emits one strip per row i in [0, maxSubdivisions), and row i emits 2*i + 3
// vertices, so with n = maxSubdivisions the total is n*(n - 1) + 3*n. The expression below
// equals that exactly because n*(n - 1) is always even.
layout (triangle_strip, max_vertices = (maxSubdivisions*(maxSubdivisions - 1)/2)*2 + maxSubdivisions*3) out;

// Not referenced in this file; presumably sampled by the included water headers — TODO confirm.
uniform sampler2D gaux4;

// Not referenced in this file; presumably sampled by the included water headers — TODO confirm.
uniform sampler2D noisetex;

// Not referenced in this file; presumably needed by an include — TODO confirm.
uniform mat4 gbufferProjection, gbufferProjectionInverse;
// main() multiplies every emitted vertexNormal by mat3(gbufferModelViewInverse).
uniform mat4 gbufferModelViewInverse;

// Not referenced in this file; presumably needed by an include — TODO confirm.
uniform int isEyeInWater;

//Geometry Inputs
// Per-vertex varyings from the previous stage, indexed 0..2 by input triangle vertex.
in mat3[] out_tbn;
in vec3[] out_tint;
in vec3[] out_worldPosition;
in vec3[] out_viewPosition;
in vec3[] out_vertexNormal;
in vec2[] out_textureCoordinate;
in vec2[] out_lightmapCoordinate;
in float[] out_ao;
in float[] out_id;

//Geometry Outputs
// Written by main() for each emitted vertex by interpolating the matching inputs above.
out mat3 tbn;
out vec3 tint;
out vec3 worldPosition;
out vec3 viewPosition;
out vec3 vertexNormal;
out vec2 textureCoordinate;
out vec2 lightmapCoordinate;
out float ao;
out float id;

#include "/lib/shared/surface/water/constants.glsl"
#include "/lib/shared/surface/water/waves.glsl"

// Projects a position with gl_ProjectionMatrix, equivalent to
// gl_ProjectionMatrix * vec4(position, 1.0) when column 0 has only .x and column 1 only
// .y non-zero (the other entries of those two columns are ignored). Columns 2 and 3 are
// used in full.
vec4 projectVertex(vec3 position) {
    mat4 proj = gl_ProjectionMatrix;

    // x*P00, y*P11, z*P22, z*P23
    vec4 scaled = vec4(proj[0].x, proj[1].y, proj[2].z, proj[2].w) * vec4(position, position.z);
    // Off-center (skew) contribution of column 2 to x and y.
    vec4 skew = vec4(proj[2].x * position.z, proj[2].y * position.z, 0.0, 0.0);

    return scaled + proj[3] + skew;
}

// Returns 1.0 carrying the sign bit of x: +1.0 or -1.0, never 0.0.
// Unlike sign(), -0.0 yields -1.0 and +0.0 yields +1.0.
float SignExtract(float x) {
    const uint signMask = 0x80000000u;
    const uint oneBits  = 0x3F800000u; // IEEE-754 single-precision bit pattern of 1.0

    uint signBit = floatBitsToUint(x) & signMask;
    return uintBitsToFloat(oneBits | signBit);
}

// Builds a 3x3 matrix M with M * from == to, assuming `from` and `to` are unit length.
//
// This is the Rodrigues-style form  c*I + [v]x + tmp * v * v^T  with v = from x to and
// c = dot(from, to). For unit inputs, M * from == to holds for ANY value of tmp.
//
// NOTE(review): tmp = 1 / (SignExtract(c) + c). For c >= +0.0 that is the usual
// 1 / (1 + c), which gives a proper rotation about `from x to`. For c < 0 (and c == -0.0)
// it becomes 1 / (c - 1), which avoids the division by zero at c == -1. The result still
// maps `from` onto `to`, but it sends `from x to` to its negative, so it is an improper
// transform (determinant -1), not a rotation. For exactly antiparallel inputs the cross
// product is zero and the result is c*I == -I.
mat3 GetRotationMatrix(vec3 from, vec3 to) {
    float cosine = dot(from, to);

    // +1.0 or -1.0 following the sign bit of cosine (see SignExtract).
    float tmp = SignExtract(cosine);
          tmp = 1.0 / (tmp + cosine);

    // Reversed operand order (to x from == -(from x to)) compensates for the column-major
    // mat3 constructor below.
    vec3 axis = cross(to, from);
    vec3 tmpv = axis * tmp;

    // GLSL's mat3 constructor is column-major: each source line below is one COLUMN.
    return mat3(
        axis.x * tmpv.x + cosine, axis.x * tmpv.y - axis.z, axis.x * tmpv.z + axis.y,
        axis.y * tmpv.x + axis.z, axis.y * tmpv.y + cosine, axis.y * tmpv.z - axis.x,
        axis.z * tmpv.x - axis.y, axis.z * tmpv.y + axis.x, axis.z * tmpv.z + cosine
    );
}

// Tessellates the incoming triangle into maxSubdivisions^2 sub-triangles and emits them
// as one triangle strip per row. Row i emits 2*i + 3 vertices, which is what the
// max_vertices layout qualifier is sized for.
//
// Each output varying is interpolated from the three inputs as
//     a = a[1] + (a[0] - a[1]) * s + (a[2] - a[1]) * t
// i.e. barycentric weights s, 1 - s - t and t for input vertices 0, 1 and 2.
//
// Vertices whose interpolated id equals 8.0 (presumably the water block id — TODO
// confirm) are handled specially:
//   - when maxSubdivisions > 1, pos.y is offset by calculateWaves();
//   - with USE_VERTEX_NORMAL defined, tbn and vertexNormal are rotated onto the wave
//     normal derived from waterNormal().
// Every emitted vertexNormal is finally multiplied by mat3(gbufferModelViewInverse).
void main() {
    float subdivisions = float(maxSubdivisions);

    // Edge deltas from input vertex 1 toward vertices 0 and 2, one pair per varying.
    vec4 positionEdge0 = gl_in[0].gl_Position - gl_in[1].gl_Position;
    vec4 positionEdge2 = gl_in[2].gl_Position - gl_in[1].gl_Position;

    mat3 tbnEdge0 = out_tbn[0] - out_tbn[1];
    mat3 tbnEdge2 = out_tbn[2] - out_tbn[1];

    vec3 tintEdge0 = out_tint[0] - out_tint[1];
    vec3 tintEdge2 = out_tint[2] - out_tint[1];

    vec3 worldPositionEdge0 = out_worldPosition[0] - out_worldPosition[1];
    vec3 worldPositionEdge2 = out_worldPosition[2] - out_worldPosition[1];

    vec3 viewPositionEdge0 = out_viewPosition[0] - out_viewPosition[1];
    vec3 viewPositionEdge2 = out_viewPosition[2] - out_viewPosition[1];

    vec3 vertexNormalEdge0 = out_vertexNormal[0] - out_vertexNormal[1];
    vec3 vertexNormalEdge2 = out_vertexNormal[2] - out_vertexNormal[1];

    vec2 textureCoordinateEdge0 = out_textureCoordinate[0] - out_textureCoordinate[1];
    vec2 textureCoordinateEdge2 = out_textureCoordinate[2] - out_textureCoordinate[1];

    vec2 lightmapCoordinateEdge0 = out_lightmapCoordinate[0] - out_lightmapCoordinate[1];
    vec2 lightmapCoordinateEdge2 = out_lightmapCoordinate[2] - out_lightmapCoordinate[1];

    float aoEdge0 = out_ao[0] - out_ao[1];
    float aoEdge2 = out_ao[2] - out_ao[1];

    // When all three input ids agree, these deltas are exactly 0.0. The interpolated id
    // then reproduces the input id bit-for-bit, so the == 8.0 test below is reliable.
    float idEdge0 = out_id[0] - out_id[1];
    float idEdge2 = out_id[2] - out_id[1];

    for(int i = 0; i < maxSubdivisions; ++i) {
        for(int j = 0; j <= i; ++j) {
            // Each cell contributes two strip vertices; the last cell of a row adds a third.
            int vertexCount = (j == i) ? 3 : 2;

            for(int v = 0; v < vertexCount; ++v) {
                float s = float(j + v / 2) / subdivisions;
                float t = 1.0 - float(i + 1 - v % 2) / subdivisions;

                vec4 pos = gl_in[1].gl_Position + positionEdge0 * s + positionEdge2 * t;

                tbn                = out_tbn[1]                + tbnEdge0 * s                + tbnEdge2 * t;
                tint               = out_tint[1]               + tintEdge0 * s               + tintEdge2 * t;
                worldPosition      = out_worldPosition[1]      + worldPositionEdge0 * s      + worldPositionEdge2 * t;
                viewPosition       = out_viewPosition[1]       + viewPositionEdge0 * s       + viewPositionEdge2 * t;
                vertexNormal       = out_vertexNormal[1]       + vertexNormalEdge0 * s       + vertexNormalEdge2 * t;
                textureCoordinate  = out_textureCoordinate[1]  + textureCoordinateEdge0 * s  + textureCoordinateEdge2 * t;
                lightmapCoordinate = out_lightmapCoordinate[1] + lightmapCoordinateEdge0 * s + lightmapCoordinateEdge2 * t;
                ao                 = out_ao[1]                 + aoEdge0 * s                 + aoEdge2 * t;
                id                 = out_id[1]                 + idEdge0 * s                 + idEdge2 * t;

                if(id == 8.0) {
                    #ifdef USE_VERTEX_NORMAL
                        // GetRotationMatrix assumes unit-length inputs. The linearly
                        // interpolated tbn[2] is shorter than 1 wherever the input normals
                        // differ, and the wave normal is pushed through an interpolated
                        // (possibly non-orthonormal) TBN, so normalize both first.
                        vec3 surfaceNorm = normalize(tbn[2]);
                        vec3 waveNorm = normalize(tbn * waterNormal(worldPosition.xz).xzy);
                        mat3 rotMat = GetRotationMatrix(surfaceNorm, waveNorm);
                        tbn = rotMat * tbn;
                        vertexNormal = rotMat * vertexNormal;
                    #endif

                    // NOTE(review): only pos is displaced. The worldPosition and
                    // viewPosition passed to the next stage still describe the
                    // undisplaced surface — confirm that is intended.
                    if(maxSubdivisions > 1) pos.y += calculateWaves(worldPosition.xz);
                }

                vertexNormal = mat3(gbufferModelViewInverse) * vertexNormal;

                // Transform pos by gl_ModelViewMatrix (pos.w is ignored), then project.
                // Finally jitter by taaOffset (declared in an include) scaled by w.
                vec3 eyePos = mat3(gl_ModelViewMatrix) * pos.xyz + gl_ModelViewMatrix[3].xyz;
                gl_Position = projectVertex(eyePos);
                gl_Position.xy += taaOffset * gl_Position.w;

                EmitVertex();
            }
        }
        EndPrimitive();
    }
}