include("../../src/SI.jl")
include("../../benchmarks/geometries.jl")
using PyPlot
PyPlot.close("all")
using StatsBase

# Relative round-trip solve errors for the matrix `F`.
# For each of `ntests` vectors `x` produced by `draw()`, solve `F y = F x` with
# the backslash operator and record norm(y - x) / norm(x).
function _solveRelErrors(F, draw, ntests)
    relerrs = zeros(ntests)
    for t = 1:ntests
        x = draw()
        relerrs[t] = norm(F \ (F * x) - x) / norm(x)
    end
    return relerrs
end

# Print min / max / mean / std of the error vector `e`, prefixed by `label`.
function _printErrStats(label, e)
    @printf "%s Err Min %e Max %e Mean %e Std %e\n" label minimum(e) maximum(e) mean(e) std(e)
end

# 2d plates test.
# Build the exact kernel matrix between point sets X and Y, then for several
# interpolation tolerances compare it against the SI low-rank approximation
# U * (F \ V'). Also check solve accuracy on the kernel matrix Fhat restricted
# to the selected skeleton points. Finally compare with skeletons taken from a
# uniform grid and from randomly sampled points.
function main()

    # Decide on geometry using provided geometries
    (X, Y, b1, b2) = getGeometry(10,30,coef=-0.5)
    I1 = SI.getIntervals(b1)
    I2 = SI.getIntervals(b2)

    tol = 1e-10
    # `1 ./ d` (spaced) is the elementwise reciprocal; the unspaced `1./d`
    # is ambiguous and rejected by the parser on Julia >= 0.7.
    kernel = (x,y) -> 1 ./ SI.distance_2(x,y)
    # kernel = (x,y) -> exp.(-SI.distance_2(x,y).^2)
    # kernel = (x,y) -> SI.distance_2(x,y)
    # Same kernel, with arguments given in the coordinates that SI.RtoX maps
    # back through b1 / b2.
    kernelR = (x,y) -> kernel(SI.RtoX(b1,x),SI.RtoX(b2,y))

    Atrue = SI.meshKernelFull(kernel, X, Y)

    realRank = SI.rank_eps_fro(Atrue, [tol])[1]

    tolInterps = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]
    errs = zeros(size(tolInterps))

    r0 = zeros(size(tolInterps))
    r1 = zeros(size(tolInterps))

    ruletype = SI.chebyshev
    # The skeleton points of the last iteration are plotted after the loop,
    # so they must be declared in the function scope (a `for` body is its own
    # local scope).
    local Xk, Yk
    nsolve = 10000   # number of right-hand sides per solve-accuracy test

    for (i, tolInterp) = enumerate(tolInterps)

        @printf "========\n"

        (U, F, V, Xhat, Yhat, Xbar, Ybar) = SI.fastMeshKernelLowRankApprox(kernelR, I1, I2, SI.XtoR(b1, X), SI.XtoR(b2, Y), tol, ruletype=ruletype, tolInterp=tol.^(-tolInterp))
        Xk = SI.RtoX(b1, Xhat)
        Yk = SI.RtoX(b2, Yhat)
        # Aapp is compared with Atrue (one row per X point, one column per
        # Y point), so U has one row per X point and V one row per Y point.
        Aapp  = U*(F\V')
        Fhat = SI.meshKernelFull(kernel, Xk, Yk)
        err = norm(Aapp - Atrue)/norm(Atrue)
        @printf "Coef %f Error %e Cond %e for rank %d (actual %e rank is %d)\n" tolInterp err cond(Fhat) minimum(size(U)) tol realRank

        # Solve accuracy of Fhat for random vectors, and for rows of V / U
        # drawn uniformly over ALL their rows (row counts are size(., 1)).
        _printErrStats("Random", _solveRelErrors(Fhat, () -> rand(size(Fhat, 1)), nsolve))
        _printErrStats("Xhat-Y", _solveRelErrors(Fhat, () -> V[rand(1:size(V, 1)), :], nsolve))
        _printErrStats("Yhat-X", _solveRelErrors(Fhat, () -> U[rand(1:size(U, 1)), :], nsolve))

        # PyPlot.figure(i)
        # PyPlot.plot(Xk[1,:],Xk[2,:],"*r")
        # SI.plotBoundingBox(b1, i)
        # PyPlot.plot(Yk[1,:],Yk[2,:],"*r")
        # SI.plotBoundingBox(b2, i)
        # PyPlot.title(@sprintf "Coef %f Error %e rank %d (real %d)" tolInterp err minimum(size(U)) realRank)

        errs[i] = err
        r0[i] = min(size(Xbar, 2), size(Ybar, 2))
        r1[i] = min(size(Xhat, 2), size(Yhat, 2))

    end

    # Skeleton points selected for the last tolerance.
    PyPlot.figure()
    id = PyPlot.gcf()[:number]
    PyPlot.plot(Xk[1,:],Xk[2,:],"*r")
    SI.plotBoundingBox(b1, id)
    PyPlot.plot(Yk[1,:],Yk[2,:],"*r")
    SI.plotBoundingBox(b2, id)
    PyPlot.title(@sprintf "Coef %f Error %e/%e rank %d (real %d) type %s" tolInterps[end] errs[end] tol size(Xk, 2) realRank ruletype)

    PyPlot.figure()
    PyPlot.semilogy(tolInterps, errs)
    PyPlot.title(string(ruletype))

    PyPlot.figure()
    PyPlot.plot(tolInterps, r0, label="r0")
    PyPlot.plot(tolInterps, r1, label="r1")
    PyPlot.plot(tolInterps, realRank * ones(size(tolInterps)), label="r")
    PyPlot.title(string(ruletype))

    @printf "===============\n"

    # Compare with uniform grid - should be bad hopefully ?
    npts = ceil(Int, sqrt(realRank))
    orders = [npts, npts]
    Xhat_unif = SI.RtoX(b1, SI.TensorInterpolator{Float64}(I1, orders, SI.uniform).xk)
    Yhat_unif = SI.RtoX(b2, SI.TensorInterpolator{Float64}(I2, orders, SI.uniform).xk)
    Aapp_unif = SI.meshKernelFull(kernel, X, Yhat_unif) * (SI.meshKernelFull(kernel, Xhat_unif, Yhat_unif) \ SI.meshKernelFull(kernel, Xhat_unif, Y))
    err_unif = norm(Aapp_unif - Atrue)/norm(Atrue)

    PyPlot.figure()
    id = PyPlot.gcf()[:number]
    PyPlot.plot(Xhat_unif[1,:],Xhat_unif[2,:],"*r")
    SI.plotBoundingBox(b1, id)
    PyPlot.plot(Yhat_unif[1,:],Yhat_unif[2,:],"*r")
    SI.plotBoundingBox(b2, id)
    PyPlot.title(@sprintf "Unif no SI Error %e/%e rank %d (real %d)" err_unif tol size(Xhat_unif, 2) realRank)

    # And with random skeletons: realRank points sampled without replacement
    # from X and from Y. Kept separate from err_unif, which is the
    # uniform-grid error.
    nrand = 50
    err_rand = zeros(nrand)
    for t = 1:nrand
        Xhat_rand = X[:,sample(1:size(X, 2), realRank, replace=false)]
        Yhat_rand = Y[:,sample(1:size(Y, 2), realRank, replace=false)]
        Aapp_rand = SI.meshKernelFull(kernel, X, Yhat_rand) * (SI.meshKernelFull(kernel, Xhat_rand, Yhat_rand) \ SI.meshKernelFull(kernel, Xhat_rand, Y))
        err_rand[t] = norm(Aapp_rand - Atrue)/norm(Atrue)
    end
    _printErrStats("Rand", err_rand)

end

# Run the experiment at top level, so it executes whenever this file is run or included.
main()
