include("../src/SI.jl")
include("../benchmarks/aca.jl")
include("../benchmarks/geometries.jl")

using PlotlyJS
using JSON
# using PyPlot

# Values of `dist` swept by the benchmark; the second box is offset by 1.0 + dist
# along the first axis. Same values as 0.5, 0.75, ..., 5.0 written out explicitly.
dists = collect(0.5:0.25:5.0)
# Tolerance passed to the SI accuracy stopping criterion and to ACA
# (multiplied by the per-kernel factor in `coefs` for the adaptive SI run).
tol = 1e-8
# Number of random point samplings per distance.
nRepeat = 25

# Kernels, written with broadcast operations on SI.distance_2(x, y).
# NOTE(review): assumes distance_2 returns the array of pairwise distances between
# the columns of x and y — confirm against SI.
# `1 ./ d` (with spaces) is required: `1./d` is ambiguous and is a syntax error in
# Julia >= 1.0, while `1 ./ d` is elementwise division on every version.
kernel_1r    = (x, y) -> 1 ./ SI.distance_2(x, y)
kernel_1r2   = (x, y) -> 1 ./ SI.distance_2(x, y).^2
kernel_1r3   = (x, y) -> 1 ./ SI.distance_2(x, y).^3
kernel_logr  = (x, y) -> log.(SI.distance_2(x, y))
kernel_expr  = (x, y) -> exp.(-SI.distance_2(x, y))
kernel_gaus  = (x, y) -> exp.(-SI.distance_2(x, y).^2)

# Parallel arrays indexed by kernel: function, display name (used in titles), and
# the factor applied to `tol` for that kernel's adaptive SI run.
kernels_funs   = [kernel_1r, kernel_1r2, kernel_1r3, kernel_logr, kernel_expr, kernel_gaus]
kernels_names  = ["1/r",     "1/r^2",    "1/r^3",    "log(r)",    "exp(-r)",   "exp(-r^2)"]
coefs          = [1.0,       0.8,        0.8,        0.8,         0.8,         0.8        ]

"""
    sample_median(v)

Return the median of the non-empty vector `v`: the middle sorted element for odd
lengths, the mean of the two middle sorted elements for even lengths.

`sort(v)[Int(round(length(v)/2))]` is NOT the median: `round` rounds ties to even,
so for 25 samples it picks the 12th value instead of the 13th.

Throws `ArgumentError` if `v` is empty.
"""
function sample_median(v::AbstractVector)
    isempty(v) && throw(ArgumentError("sample_median: empty vector"))
    s = sort(v)
    m = length(s)
    h = div(m + 1, 2)
    return isodd(m) ? s[h] : (s[h] + s[h + 1]) / 2
end

# If true, the Chebyshev rank r0 found by a kernel's first adaptive SI run is forced
# for every later repeat and distance of that kernel. If false (default), r0 is
# re-derived adaptively in every repeat.
reuseRank = false

# For each kernel: sample points in two boxes separated by each value of `dists` and
# compare the relative Frobenius error of Chebyshev SI (SI.adaptive_si), random CUR
# and ACA (pp_aca). Random CUR and ACA use the same rank r1 as the SI result.
# Per-distance errors are shown as box plots plus median curves, and saved to
# results-benchmarks/benchmark_<i_kernel>.eps and .json.
for (i_kernel, (kernel_fun, kernel_name, coef)) in enumerate(zip(kernels_funs, kernels_names, coefs))

    @printf "+++++++++++++++++ %s +++++++++++++++++++\n" kernel_name

    # The first kernel (1/r) is announced but not run.
    # NOTE(review): presumably kept to avoid re-running that kernel — confirm.
    if i_kernel == 1
        continue
    end

    # Median error per distance, for each method.
    medSI = zeros(length(dists))
    medRand = zeros(length(dists))
    medACA = zeros(length(dists))

    # One vector of nRepeat individual errors per distance, for the box plots.
    allerrsSI = Vector{Float64}[]
    allerrsRand = Vector{Float64}[]
    allerrsACA = Vector{Float64}[]

    # Chebyshev rank r0; `nothing` means "determine it adaptively on the next SI run".
    fixedRankR0 = nothing

    for (i_dist, dist) in enumerate(dists)

        # b2 is b1 shifted by 1.0 + dist along the first coordinate.
        b1 = SI.CubeBoundingBox{Float64}([0.0,0.0],[1.0,1.0])
        b2 = SI.CubeBoundingBox{Float64}([1.0+dist,0.0],[1.0,1.0])
        Xall = SI.generatePoints(b1, [100,100])
        Yall = SI.generatePoints(b2, [100,100])
        # Reference points passed to SI.getInterpolationOrder.
        refxR = [1.0,0.5]
        refyR = [1.0+dist,0.5]
        N = size(Xall,2)

        kernelR = (rx, ry) -> kernel_fun(SI.RtoX(b1,rx),SI.RtoX(b2,ry))
        I1 = SI.getIntervals(b1)
        I2 = SI.getIntervals(b2)

        # Chebyshev node/weight generators handed to SI.adaptive_si. They depend only
        # on the boxes, so they are built once per distance rather than per repeat.
        function XbarCheb(n)
            (orderx, ordery) = SI.getInterpolationOrder(kernelR, I1, I2, n, logLevel=SI.none,
                                                        refx=SI.XtoR(b1,refxR), refy=SI.XtoR(b2,refyR))
            interpx = SI.TensorInterpolator{Float64}(I1, orderx, SI.chebyshev)
            return (SI.RtoX(b1, interpx.xk), interpx.w_intk)
        end
        function YbarCheb(n)
            (orderx, ordery) = SI.getInterpolationOrder(kernelR, I1, I2, n, logLevel=SI.none,
                                                        refx=SI.XtoR(b1,refxR), refy=SI.XtoR(b2,refyR))
            interpy = SI.TensorInterpolator{Float64}(I2, ordery, SI.chebyshev)
            return (SI.RtoX(b2, interpy.xk), interpy.w_intk)
        end

        errsRand = zeros(Float64, nRepeat)
        errsSI   = zeros(Float64, nRepeat)
        errsACA  = zeros(Float64, nRepeat)

        nPoints = 500  # points sampled from each box in every repeat

        for r = 1:nRepeat

            # Build random set of points
            X = Xall[:,randperm(N)[1:nPoints]]
            Y = Yall[:,randperm(N)[1:nPoints]]

            Ktrue = SI.meshKernelFull(kernel_fun, X, Y)

            stop = (r0, r1, Xhat, Yhat) -> SI.accuracy_stopping_criterion(kernel_fun, X, Y, Xhat, Yhat, Ktrue, tol)

            # Chebyshev SI
            if fixedRankR0 === nothing
                # Heuristic-based alternative:
                # (U,S,V,Xhatr,Yhatr,Xbarr,Ybarr,Uc,Sc,Vc) = SI.fastMeshKernelLowRankApprox(kernelR, I1, I2, SI.XtoR(b1,X), SI.XtoR(b2,Y), tol, tolInterp=0.1, tolQR=1, logLevel=SI.none, refx=SI.XtoR(b1,refxR), refy=SI.XtoR(b2,refyR))
                # Xhat = SI.RtoX(b1,Xhatr)
                # Yhat = SI.RtoX(b2,Yhatr)
                # 'Optimal r0-r1' version: r0 starts at 4 and grows by ~10% per step
                # (up to 1000), presumably until `stop` is satisfied.
                (Xhat, Yhat, Xbar, Ybar) = SI.adaptive_si(kernel_fun, XbarCheb, YbarCheb, coef * tol,
                                                          rinit=4, rmax=1000, rinc=x->Int(ceil(x*1.1)),
                                                          loglevel=SI.debug, stopping_criterion=stop)
                fixedRankR0 = size(Xbar,2)
            else
                @printf "Cheb using rank %d\n" fixedRankR0
                # rinit == rmax forces the rank r0.
                (Xhat, Yhat, Xbar, Ybar) = SI.adaptive_si(kernel_fun, XbarCheb, YbarCheb, coef * tol,
                                                          rinit = fixedRankR0, rmax = fixedRankR0,
                                                          loglevel=SI.none)
            end
            # r1 is taken from the SI result on BOTH branches, so the forced-rank
            # path also defines it before the RCUR/ACA comparison below.
            fixedRankR1 = size(Xhat,2)
            errsSI[r] = SI.err_si_frob(kernel_fun, X, Y, Xhat, Yhat, Ktrue)

            @printf "RCUR and ACA using rank %d\n" fixedRankR1

            # Random CUR: r1 randomly chosen points of X and of Y.
            Xhat = X[:,randperm(nPoints)[1:fixedRankR1]]
            Yhat = Y[:,randperm(nPoints)[1:fixedRankR1]]
            errsRand[r] = SI.err_si_frob(kernel_fun, X, Y, Xhat, Yhat, Ktrue)

            # ACA with rank fixed to r1; relative Frobenius error of the U*V' product.
            (U,V,Iaca,Jaca) = pp_aca(kernel_fun, X, Y, tol, fixedRank = fixedRankR1, loglevel = SI.none)
            errsACA[r] = vecnorm(Ktrue - U*V') / vecnorm(Ktrue)
            # errsACA[r] = SI.err_si_frob(kernel_fun, X, Y, X[:,Iaca], Y[:,Jaca], Ktrue) # Checking if its a stability issue

            if !reuseRank
                fixedRankR0 = nothing
            end
        end

        push!(allerrsSI, errsSI)
        push!(allerrsACA, errsACA)
        push!(allerrsRand, errsRand)

        medSI[i_dist] = sample_median(errsSI)
        medRand[i_dist] = sample_median(errsRand)
        medACA[i_dist] = sample_median(errsACA)
    end

    # Create plot: flatten the per-distance errors into one vector per method, paired
    # with the distance of each sample (each distance repeated nRepeat times).
    allDists = vcat([fill(d, nRepeat) for d in dists]...)
    allerrsSIconcat = vcat(allerrsSI...)
    allerrsACAconcat = vcat(allerrsACA...)
    allerrsRandconcat = vcat(allerrsRand...)

    @assert length(allerrsSIconcat) == length(allDists)
    @assert length(allerrsACAconcat) == length(allDists)
    @assert length(allerrsRandconcat) == length(allDists)

    colorSI = "rgb(128,0,0)"
    colorRand = "rgb(0,0,128)"
    colorACA = "rgb(0,128,0)"

    # Random CUR and ACA are shifted by +/-0.05 on the x-axis so the curves do not overlap.
    boxSI   = PlotlyJS.box(y=allerrsSIconcat, x=allDists, name="Chebyshev SI", marker_color=colorSI)
    boxRand = PlotlyJS.box(y=allerrsRandconcat, x=allDists .+ 0.05, name="Random CUR", marker_color=colorRand)
    boxACA  = PlotlyJS.box(y=allerrsACAconcat, x=allDists .- 0.05, name="ACA", marker_color=colorACA)

    si_plt = PlotlyJS.scatter(x=dists, y = medSI, name="Chebyshev SI", line_color=colorSI, marker_color=colorSI, showlegend=false)
    rand_plt = PlotlyJS.scatter(x=dists .+ 0.05, y = medRand, name="Random CUR", line_color=colorRand, marker_color=colorRand, showlegend=false)
    aca_plt = PlotlyJS.scatter(x=dists .- 0.05, y = medACA, name="ACA", line_color=colorACA, marker_color=colorACA, showlegend=false)

    layout = Layout(yaxis_type="log", yaxis_exponentformat="e", boxmode="group",
                    title=string("Error as a function of the distance with same rank r1 for all methods, with kernel ", kernel_name),
                    xaxis_title="Distance", yaxis_title="Frobenius relative error", width=800, height=400)
    disp = PlotlyJS.plot([boxRand, boxACA, boxSI, si_plt, aca_plt, rand_plt], layout)
    PlotlyJS.display(disp)

    # Save to file (create the output directory first so savefig cannot fail on it).
    mkpath("results-benchmarks")
    file = string("results-benchmarks/benchmark_", i_kernel)
    PlotlyJS.savefig(disp, string(file, ".eps"))
    open(string(file, ".json"), "w") do f
        write(f, JSON.json(disp))
    end

end
