# Benchmark all
include("../src/SI.jl")
include("../benchmarks/geometries.jl")

using Iterators
using JLD
using Dates
using LowRankApprox

# Kernel under test: element-wise reciprocal of `SI.distance_2(x, y)`.
# The explicit space in `1 ./` avoids the ambiguous `1./` literal/operator
# juxtaposition (parsed as `1 ./` here, rejected by later Julia versions).
kernel = (x,y) -> 1 ./ SI.distance_2(x,y)
rule = SI.chebyshev

# Timestamp tagging this run's output file. `string(::DateTime)` is ISO 8601
# (e.g. "2018-03-05T12:34:56.789"), so only the ':' separators need replacing
# to obtain a filesystem-friendly name.
date_run = replace(string(Dates.now()), ":", ".")
@printf "Run %s\n" date_run

# Dry run: call every routine that is timed in the benchmark loop once on a
# small geometry, so JIT compilation is not charged to the first measurement.
(X, Y, b1, b2) = getGeometry(10, 10)
I1 = SI.getIntervals(b1)
I2 = SI.getIntervals(b2)
#SI.meshKernelLowRankAdaptive(kernel, b1, b2, X, Y, 1e-6, method, adaptive, ruletype=rule, useSI=true)
tol_warmup = 1e-1
# Use exactly the keyword arguments of the timed call (with the loop's first
# tol), so the same keyword-argument method gets compiled here.
# NOTE(review): this anonymous closure is a different function type from the one
# written inside the benchmark loop, so the very first timed LR call may still
# pay for specializing on the loop's closure type.
(Uw, Fw, Vw, Xhatw, Yhatw, Xbarw, Ybarw) = SI.fastMeshKernelLowRankApprox(
    (x,y)->kernel(SI.RtoX(b1,x),SI.RtoX(b2,y)), I1, I2, SI.XtoR(b1,X), SI.XtoR(b2,Y),
    tol_warmup, ruletype=rule, tolInterp=tol_warmup^(-0.3), tolQR=0.1,
    logLevel=SI.none, nRandom = 5)
# The loop times `lufact(F)` on the returned F; compile it here as well.
lufact(Fw)
# Dense assembly and RRQR are also timed in the loop.
trash = SI.meshKernelFull(kernel, X, Y)
pqrfact(trash, rtol=1e-2)

# Run the stuff
# Benchmark configuration.
geoId = 10
dim = 2
# Per-dimension sizes n, chosen so that the total size n^dim is roughly
# log-spaced between 1e1 and 1e5.
nn = Array{Int64,1}((round.(logspace(1, 5, 10).^(1/dim))))
nXnY = nn.^dim
# The dense matrix (and hence RRQR and the error check) is only built for n <= nmax.
nmax = 50
nRepeat = 5
tols = logspace(-1, -6, 6)
# coefs = [-0.5, -0.25, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
# coefs = [-0.5, 0.0, 2.0]
coefs = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
# nn = Array{Int64,1}((round.(sqrt.(logspace(1, 2, 4)))))
# nn2 = nn.^2
# nRepeat = 5
# tols = logspace(-1, -4, 4)
# coefs = [0.0, 1.0, 2.0]

# Result arrays. Per-repeat measurements are indexed [size, tol, coef, repeat].
shape_size_tol_coef_rep = (length(nn), length(tols), length(coefs), nRepeat)
# Dense assembly does not depend on the tolerance, so it has no tol axis.
times_Atrue_s = zeros(Float64, (length(nn), length(coefs), nRepeat))
times_RRQR_s = zeros(Float64, shape_size_tol_coef_rep)
times_LR_s = zeros(Float64, shape_size_tol_coef_rep)
times_Inverse_s = zeros(Float64, shape_size_tol_coef_rep)
# The error is checked once per (size, tol, coef), after the repeats.
errs_fro = zeros(Float64, (length(nn), length(tols), length(coefs)))
ranks_LR = zeros(Int64, shape_size_tol_coef_rep)

# Benchmark sweep. For every geometry coefficient `coef` and per-dimension size
# `n`, time (nRepeat times each):
#   - dense assembly of Atrue with SI.meshKernelFull       (only when n <= nmax),
#   - pqrfact (RRQR) of Atrue, for each tol                (only when n <= nmax),
#   - SI.fastMeshKernelLowRankApprox, for each tol,
#   - lufact of the returned F,
# then record the relative Frobenius error of U*(F\V') against Atrue
# (only when n <= nmax). Dense-based slots get NaN for sizes above nmax.
for (c, coef) in enumerate(coefs)
    for (i, n) in enumerate(nn)
        @printf "WWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWW\n"
        @printf "c %d coef %e i %d n %d\n" c coef i n
        (X, Y, b1, b2) = getGeometry(geoId, n, coef=coef)
        I1 = SI.getIntervals(b1)
        I2 = SI.getIntervals(b2)
        # Reset per-size results so nothing computed for the previous n is reused.
        Atrue = Array{Float64, 2}(0, 0)
        U = Array{Float64, 2}(0, 0)
        V = Array{Float64, 2}(0, 0)
        F = Array{Float64, 2}(0, 0)
        Finv = 0
        # NOTE(review): presumably X and Y store one point per column, so
        # size(X)[2] is the point count reported as |X| — confirm in getGeometry.
        @printf "|X| = |Y| = %d\n" size(X)[2]
        @assert size(X)[2] == size(Y)[2]
        # Build Atrue
        # Assembly is repeated only for timing; the matrix from the last repeat
        # is the one used by RRQR and by the error check below.
        @printf "========= Building Atrue =========\n"
        for nr in 1:nRepeat
            @printf "===== nr %d / %d =====\n" nr nRepeat
            if n <= nmax
                times_Atrue_s[i,c,nr] = @elapsed Atrue = SI.meshKernelFull(kernel, X, Y)
                @printf "Time Atrue %f s.\n" times_Atrue_s[i,c,nr]
            else
                times_Atrue_s[i,c,nr] = NaN
            end
        end
        for (t, tol) in enumerate(tols)
            @printf "================================================\n"
            # NOTE(review): with dim = 2 the point count is n^dim = n^2; the
            # "n3" value printed here (n^3) does not correspond to a problem size.
            @printf "i %d, n %d, n2 %d, n3 %d, tol %e, coef %e\n" i n n^2 n^3 tol coef
            for nr in 1:nRepeat
                @printf "===== nr %d / %d =====\n" nr nRepeat
                # Time  RRQR (Naive)
                # Baseline: rank-revealing QR of the dense matrix at relative tolerance tol.
                if n <= nmax
                    t_RRQR = @elapsed pqrfact(Atrue, rtol=tol)
                    @printf "Time RRQR tol %e : %f s.\n" tol t_RRQR
                    times_RRQR_s[i,t,c,nr] = t_RRQR
                else
                    @printf "Skipping Atrue and RRQR\n"
                    times_RRQR_s[i,t,c,nr] = NaN
                end
                # Time LR Approx
                # Points are passed through SI.XtoR, and the closure applies SI.RtoX
                # before calling `kernel` (presumably a map to/from a parameter space).
                # The XtoR conversions are evaluated inside @elapsed, so they are timed.
                # NOTE(review): tolInterp = tol^(-0.3) here, while the commented-out
                # profiling call at the end of the file uses tol^(-0.5) — confirm intent.
                t_LR = @elapsed (U, F, V, Xhat, Yhat, Xbar, Ybar) = SI.fastMeshKernelLowRankApprox((x,y)->kernel(SI.RtoX(b1,x),SI.RtoX(b2,y)), I1, I2, SI.XtoR(b1, X), SI.XtoR(b2, Y), tol, ruletype=rule, tolInterp=tol^(-0.3), tolQR=0.1, logLevel=SI.none, nRandom = 5)
                # Reported rank: the smaller of the column counts of Xhat and Yhat.
                ranks_LR[i,t,c,nr] = min(size(Xhat, 2), size(Yhat, 2))
                times_LR_s[i,t,c,nr] = t_LR
                # Time inverse F \ V
                # Only the LU factorization is timed. Finv is not used afterwards:
                # the error check below solves with F directly.
                times_Inverse_s[i,t,c,nr] = @elapsed Finv = lufact(F)
                @printf "Time LR tol %e : %f s.\n" tol t_LR
            end
            @printf "====== Checking error ======\n"
            # Check error
            # Uses U, F, V from the last repeat only; relative 2-norm and Frobenius
            # errors against the dense matrix (only the Frobenius one is stored).
            if n <= nmax
                Aapp = U*(F\V')
                err_l2 = norm(Atrue - Aapp, 2) / norm(Atrue, 2)
                err_fro = vecnorm(Atrue - Aapp) / vecnorm(Atrue)
                errs_fro[i,t,c] = err_fro
                @printf "Error l2 %e fro %e (tol %e)\n" err_l2 err_fro tol
            else
                errs_fro[i,t,c] = NaN
                @printf "Skipping error check\n"
            end
        end
    end
end

@printf "============================================================================\n"
filename = string("results-perfs/results-perfs-", date_run, ".jld")
# Ensure the output directory exists. The sweep above can run for a long time,
# and `save` would otherwise fail at the very end, losing all measurements.
mkpath(dirname(filename))
@printf "Saving to %s\n" filename
save(filename, "times_Atrue_s", times_Atrue_s, "times_RRQR_s", times_RRQR_s, "times_Inverse_s", times_Inverse_s, "times_LR_s", times_LR_s, "nn", nn, "nmax", nmax, "nXnY", nXnY, "geoId", geoId, "tols", tols, "date_run", date_run, "nRepeat", nRepeat, "coefs", coefs, "ranks_LR", ranks_LR, "errs_fro", errs_fro)
@printf "Saved done to %s\n" filename

# @printf "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
# (X, Y, b1, b2) = getGeometry(10, 20, coef=0.0)
# I1 = SI.getIntervals(b1)
# I2 = SI.getIntervals(b2)
# tol = 1e-6
# @profile SI.fastMeshKernelLowRankApprox((x,y)->kernel(SI.RtoX(b1,x),SI.RtoX(b2,y)), I1, I2, SI.XtoR(b1, X), SI.XtoR(b2, Y), tol, ruletype=rule, tolInterp=tol^(-0.5), tolQR=0.1, logLevel=SI.none)
# Profile.print()
