include("../src/SI.jl")
include("geometries.jl")

using PlotlyJS

# Kernels evaluated on point sets `x` and `y`. Each one is a function of
# `SI.distance_2(x, y)`, which the labels in `kernels_name` call `r`.
# NOTE(review): presumably `SI.distance_2` returns the Euclidean distance between
# corresponding columns of `x` and `y` -- confirm against SI.jl.

# 1/r with additive Gaussian noise of scale 1e-7; one noise sample is drawn per
# column of `x`.
kernel1 = function (x, y)
    r = SI.distance_2(x, y)
    noise = 1e-7 * randn(size(x, 2))
    return 1 ./ r + noise
end

# 1/r^(3/2)
kernel2 = function (x, y)
    r = SI.distance_2(x, y)
    return 1 ./ (r .^ 1.5)
end

# 1/sqrt(r^2 + 1)
kernel3 = function (x, y)
    r = SI.distance_2(x, y)
    return 1 ./ sqrt.(r.^2 + 1.0)
end

# sqrt(r^2 + 1)
kernel4 = function (x, y)
    r = SI.distance_2(x, y)
    return sqrt.(r.^2 + 1.0)
end

# exp(-0.1 * r^2)
kernel5 = function (x, y)
    r = SI.distance_2(x, y)
    return exp.(-0.1 * r.^2)
end

# Kernels and their display labels, kept in matching order. The label "1/r" is
# also compared verbatim further down to decide when to show the 3-D scatter.
kernels = [kernel1, kernel2, kernel3, kernel4, kernel5]
kernels_name = [
    "1/r",
    "1/r^(3/2)",
    "1/sqrt(r^2+1)",
    "sqrt(r^2+1)",
    "exp(-0.1 * r^2)",
]

# The mesh: two point clouds, `X` generated from box `b1` and `Y` from box `b2`.
# NOTE(review): the meaning of the two CubeBoundingBox arguments (corner/corner
# vs. center/extent) is not visible from this script -- confirm in SI.jl.
# NOTE(review): the `[10, 10, 10]` argument is presumably the number of points
# per axis -- confirm against SI.generatePoints.
b1 = SI.CubeBoundingBox{Float64}(fill(-1.0, 3), ones(3))
b2 = SI.CubeBoundingBox{Float64}(fill(3.0, 3), ones(3))
X = SI.generatePoints(b1, fill(10, 3))
Y = SI.generatePoints(b2, fill(10, 3))

# Sphere-shaped bounding boxes built from the two point clouds.
b1sphere = SI.SphereBoundingBox{Float64}(X)
b2sphere = SI.SphereBoundingBox{Float64}(Y)

# "Sphere" samplers: for a requested count `n`, return a tuple of the points
# generated from the sphere bounding box and a vector of `n` ones.
# NOTE(review): the ones are presumably per-point weights -- confirm with
# SI.adaptive_si.
XbarCircle = function (n)
    pts = SI.generatePoints(b1sphere, n)
    return (pts, ones(n))
end
YbarCircle = function (n)
    pts = SI.generatePoints(b2sphere, n)
    return (pts, ones(n))
end

# "MDV" samplers: for a requested count `n`, return the columns of the original
# cloud selected by `SI.get_mdv`, again paired with a vector of `n` ones.
XbarMDV = function (n)
    cols = SI.get_mdv(X, n)
    return (X[:, cols], ones(n))
end
YbarMDV = function (n)
    cols = SI.get_mdv(Y, n)
    return (Y[:, cols], ones(n))
end

# Experiment driver.
#
# For every kernel:
#   1. build the reference matrix `Ktrue` with `SI.meshKernelFull`;
#   2. for every target accuracy `tol`, run `SI.adaptive_si` once per
#      (sampler, tolQR) combination, using a stopping criterion that calls
#      `SI.accuracy_stopping_criterion` against `Ktrue` and `tol`;
#   3. plot the recorded errors against `r0s` (left plot) and against the first
#      component of each `diags` entry (right plot, x-axis labelled "r1").
#
# Both samplers ("Sphere" and "MDV") go through one code path, so the
# solver options and the plotting stay in sync between them.
for (kernel, kernel_name) in zip(kernels, kernels_name)
    @printf "+++++++++ Kernel %s +++++++++\n" kernel_name
    Ktrue = SI.meshKernelFull(kernel, X, Y)
    for tol in logspace(-8, -8, 1)
        # The first two arguments (r0, r1) are accepted but not used here.
        stop = (r0, r1, Xhat, Yhat) ->
            SI.accuracy_stopping_criterion(kernel, X, Y, Xhat, Yhat, Ktrue, tol)
        datas_1 = []   # traces of error vs. r0s
        datas_2 = []   # traces of error vs. first component of `diags`
        # (legend label, X sampler, Y sampler); Sphere runs before MDV.
        for (label, Xsampler, Ysampler) in (("Sphere", XbarCircle, YbarCircle),
                                            ("MDV", XbarMDV, YbarMDV))
            for tolQR in logspace(-10, -9, 2)
                (Xhat, Yhat, Xbar, Ybar, diags, r0s) = SI.adaptive_si(
                    kernel, Xsampler, Ysampler, tolQR,
                    rinit=4, rmax=1000, rinc=x->Int(ceil(x*1.5)),
                    loglevel=SI.none, stopping_criterion=stop)

                # `diags` iterates as (r, e) pairs; split it into columns once.
                ranks = [r for (r, e) in diags]
                errs = [e for (r, e) in diags]
                trace_name = string(label, " ", tolQR)
                push!(datas_1, PlotlyJS.scatter(x=r0s, y=errs, name=trace_name))
                push!(datas_2, PlotlyJS.scatter(x=ranks, y=errs, name=trace_name))

                # Show the selected points for the tightest tolQR on the 1/r
                # kernel. `isapprox` is used because the value produced by
                # `logspace` need not be bit-identical to a literal power of ten.
                if isapprox(tolQR, 1e-10) && kernel_name == "1/r"
                    plt = PlotlyJS.plot(PlotlyJS.scatter3d(x=Xhat[1,:], y=Xhat[2,:], z=Xhat[3,:]))
                    display(plt)
                end
            end
        end
        layout_1 = PlotlyJS.Layout(
            title=string(kernel_name, " - Looking for error <= ", tol, " using various tolQR, MDV and Sphere"),
            yaxis_type="log", yaxis_exponentformat="e",
            xaxis_title="r0", yaxis_title="Frobenius relative error")
        layout_2 = PlotlyJS.Layout(
            title=string(kernel_name, " - Looking for error <= ", tol, " using various tolQR and r0, MDV and Sphere"),
            yaxis_type="log", yaxis_exponentformat="e",
            xaxis_title="r1", yaxis_title="Frobenius relative error")
        # The comprehensions narrow the element type of the untyped trace lists
        # before they are handed to `plot`.
        disp_1 = PlotlyJS.plot([d for d in datas_1], layout_1)
        disp_2 = PlotlyJS.plot([d for d in datas_2], layout_2)
        display([disp_1 disp_2])
    end
end
