include("../src/SI.jl")
using PlotlyJS

# Kernels, each written as a function of r = SI.distance_2(x, y).
# Dotted operators keep them valid whether distance_2 returns a scalar or an
# array. NOTE(review): presumably distance_2 is the Euclidean distance
# between x and y — confirm in SI.
k1 = (x, y) -> 1 ./ (1 .+ SI.distance_2(x, y).^2)
k2 = (x, y) -> 1 ./ SI.distance_2(x, y)
k3 = (x, y) -> log.(1 .+ SI.distance_2(x, y))
k4 = (x, y) -> exp.(-SI.distance_2(x, y))
# Plot labels, in the same order as k1..k4. The k1 label must describe what
# k1 computes: 1/(1+r^2), with no Runge-style factor 25.
knames = ["1/(1+r^2)", "1/r", "log(1+r)", "exp(-r)"]
# Kernels actually swept; append k2, k3, k4 to run the others as well.
kernels = [k1]

# Space: one (X, Y) point-set pair per geometry. Boxes and point sets are
# built in the same order as before, in case generatePoints is stateful.

# 1d geometry: X and Y come from boxes built with identical arguments,
# 1000 points each.
b1_1d = SI.CubeBoundingBox{Float64}([-1.0], [2.0])
b2_1d = SI.CubeBoundingBox{Float64}([-1.0], [2.0])
X_1d = SI.generatePoints(b1_1d, [1000])
Y_1d = SI.generatePoints(b2_1d, [1000])

# 2d geometry: a 50x50 grid request per box.
# NOTE(review): b2_2d receives identical first and second arguments
# ([1.0, 1.0] twice). If the constructor takes (lower, upper) corners,
# this box is degenerate; confirm against SI.CubeBoundingBox.
b1_2d = SI.CubeBoundingBox{Float64}([-1.0, -1.0], [1.0, 1.0])
b2_2d = SI.CubeBoundingBox{Float64}([1.0, 1.0], [1.0, 1.0])
X_2d = SI.generatePoints(b1_2d, [50, 50])
Y_2d = SI.generatePoints(b2_2d, [50, 50])

# Per-geometry collections, iterated together with zip by the experiment.
# Ys holds only the 1d set, so zip stops after the 1d geometry; append Y_2d
# to Ys to run the 2d case too.
b1s, b2s = [b1_1d, b1_2d], [b2_1d, b2_2d]
Xs, Ys = [X_1d, X_2d], [Y_1d]
geo_names = ["1d intervals, |X|=|Y|=1000", "2d squares, |X|=|Y|=50x50"]

# Tolerance handed to SI.cur when it sub-selects the r1 points out of the
# r0 points chosen by SI.get_mdv.
tol = 1.0e-12

# Columns (points) of `P` chosen by SI.get_mdv for a target count `n`.
# Points are stored one per column; the row count is the spatial dimension.
select_mdv(P, n) = P[:, SI.get_mdv(P, n)]

# Skeleton approximation K(X, Ys) * V * S^-1 * U' * K(Xs, Y) of the full kernel
# matrix, where U*S*V' is the (untruncated) SVD of the core matrix
# `Ks = K(Xs, Ys)` on the selected points.
# Returns the approximation and the ratio of the largest to the smallest
# singular value of `Ks`.
function skeleton_approx(kernel, X, Y, Xs, Ys, Ks)
    (U, s, V) = svd(Ks)
    # Diagonal scales the columns without materialising the dense r-by-r
    # matrix that diagm would build.
    Kapp = (SI.meshKernelFull(kernel, X, Ys) * V) * Diagonal(1 ./ s) *
           (U' * SI.meshKernelFull(kernel, Xs, Y))
    return Kapp, maximum(s) / minimum(s)
end

# Relative error of `Kapp` against `Ktrue` in the Frobenius norm.
relerr_fro(Kapp, Ktrue) = vecnorm(Kapp - Ktrue) / vecnorm(Ktrue)

# Scatter trace of the point set `P` (one point per column). 1d points are
# drawn on the horizontal line y = 1; otherwise the first two coordinates
# are used.
function point_trace(P, name, msize)
    if size(P, 1) == 1
        return PlotlyJS.scatter(x=P[:], y=ones(length(P)), name=name,
                                mode="markers", marker_size=msize)
    end
    return PlotlyJS.scatter(x=P[1, :], y=P[2, :], name=name,
                            mode="markers", marker_size=msize)
end

# Largest skeleton size r0 swept by the experiment (r0 = 1, ..., r0max).
r0max = 80

for (kernel, kernel_name) in zip(kernels, knames)
    for (X, Y, geo_name) in zip(Xs, Ys, geo_names)

        Ktrue = SI.meshKernelFull(kernel, X, Y)

        rs = 1:r0max
        err0 = zeros(r0max)     # error using the r0 points from SI.get_mdv
        err1 = zeros(r0max)     # error using the r1 points kept by SI.cur
        condr0 = zeros(r0max)
        condr1 = zeros(r0max)
        r0s = zeros(r0max)
        r1s = zeros(r0max)

        # Declared here so the point sets from the last r0 survive the loop
        # and can be plotted afterwards.
        Xbar = nothing
        Ybar = nothing
        Xhat = nothing
        Yhat = nothing

        for (i, r0) in enumerate(rs)
            # Bar: r0 points per side chosen by SI.get_mdv.
            Xbar = select_mdv(X, r0)
            Ybar = select_mdv(Y, r0)
            Kbar = SI.meshKernelFull(kernel, Xbar, Ybar)
            (Kbar_app, cond_bar) = skeleton_approx(kernel, X, Y, Xbar, Ybar, Kbar)
            condr0[i] = cond_bar
            err0[i] = relerr_fro(Kbar_app, Ktrue)

            # Hat: sub-select the bar points with SI.cur at tolerance `tol`.
            (px, py) = SI.cur(Kbar, tol)
            Xhat = Xbar[:, px]
            Yhat = Ybar[:, py]
            Khat = SI.meshKernelFull(kernel, Xhat, Yhat)
            (Khat_app, cond_hat) = skeleton_approx(kernel, X, Y, Xhat, Yhat, Khat)
            condr1[i] = cond_hat
            err1[i] = relerr_fro(Khat_app, Ktrue)

            # Checks and progress output
            r0s[i] = r0
            r1s[i] = size(Xhat, 2)
            @assert r0 == size(Xbar, 2)
            @show err0[i]
            @show err1[i]
            @show r0s[i]
            @show r1s[i]
        end

        # Plotly ignores "\n" in titles; "<br>" is its line break.
        lyt = PlotlyJS.Layout(
            title=string("Error for r0 and r1 using RRQR with tol<br>", tol,
                         " on ", geo_name, " for kernel ", kernel_name),
            yaxis_type="log", yaxis_exponentformat="e",
            xaxis_title="r0", yaxis_title="Relative Frobenius Error")
        plt1 = PlotlyJS.scatter(x=rs, y=err0, name="r0")
        plt2 = PlotlyJS.scatter(x=rs, y=err1, name="r1")
        PlotlyJS.display(PlotlyJS.plot([plt1, plt2], lyt))

        lyt = PlotlyJS.Layout(
            title=string("Condition number for ", kernel_name),
            yaxis_type="log", yaxis_exponentformat="e",
            xaxis_title="r0", yaxis_title="Condition number")
        plt1 = PlotlyJS.scatter(x=rs, y=condr0, name="r0")
        plt2 = PlotlyJS.scatter(x=rs, y=condr1, name="r1")
        plta = PlotlyJS.plot([plt1, plt2], lyt)

        lyt = PlotlyJS.Layout(title="Rank r1 as a function of r0",
                              xaxis_title="r0", yaxis_title="r1")
        pltb = PlotlyJS.plot(PlotlyJS.scatter(x=r0s, y=r1s, name="r1"), lyt)

        PlotlyJS.display([plta pltb])

        # Selected points from the last r0: 1d sets get only an x-axis label,
        # 2d sets get both coordinate axes labelled.
        if size(Xbar, 1) == 1
            lytx = PlotlyJS.Layout(title=string("Xbar and Xhat for ", kernel_name),
                                   xaxis_title="x")
            lyty = PlotlyJS.Layout(title="Ybar and Yhat", xaxis_title="y")
        else
            lytx = PlotlyJS.Layout(title=string("Xbar and Xhat for ", kernel_name),
                                   xaxis_title="x1", yaxis_title="x2")
            lyty = PlotlyJS.Layout(title="Ybar and Yhat",
                                   xaxis_title="y1", yaxis_title="y2")
        end
        plta = PlotlyJS.plot([point_trace(Xbar, "Xbar (r0)", 10),
                              point_trace(Xhat, "Xhat (r1)", 20)], lytx)
        pltb = PlotlyJS.plot([point_trace(Ybar, "Ybar (r0)", 10),
                              point_trace(Yhat, "Yhat (r1)", 20)], lyty)
        PlotlyJS.display([plta pltb])
    end
end
