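# Benchmark SI's low-rank kernel approximation against the full kernel matrix,
# RRQR and SVD ranks, for several geometries and kernels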
include("../src/SI.jl")
include("../benchmarks/geometries.jl")

# Don't forget to load these modules before running:
# module load julia/precompiled/0.5.0
# module load h5utils/1.12.1
# module load hdf5/1.8.16
# module load zlib/1.2.8

using Iterators
#using JLD
using Dates
using LowRankApprox
# Seed the global RNG so that any randomized routine used below is reproducible.
srand(0)

# Parse command line arguments: either none (default geometry range) or exactly
# two, `geoBegin geoEnd`.
@printf "Arguments %s\n" string(ARGS)
if !(length(ARGS) == 0 || length(ARGS) == 2)
    error("expected 0 or 2 arguments (geoBegin geoEnd), got $(length(ARGS))")
end

# Catalogue of kernels k(x, y), each a function of r = SI.distance_2(x, y) applied
# element-wise. `all_kernels_names` must stay in the same order as `all_kernels`.
# NOTE(review): kernel 5 uses log(1 + r) while its label says ln(r) — confirm.
all_kernels = [(x, y) -> 1 ./ SI.distance_2(x, y),
               (x, y) -> sqrt.(1 .+ SI.distance_2(x, y) .^ 2),
               (x, y) -> 1 ./ SI.distance_2(x, y) .^ 2,
               (x, y) -> SI.distance_2(x, y) .^ 3,
               (x, y) -> SI.distance_2(x, y) .^ 2 .* log.(1 .+ SI.distance_2(x, y)),
               (x, y) -> exp.(-SI.distance_2(x, y) .^ 2)]
all_kernels_names = ["1/r", "sqrt(1+r^2)", "1/r^2", "r^3", "r^2 ln(r)", "exp(-r^2)"]

# Kernels benchmarked in this run, selected by index so that functions and names
# can never get out of sync (1 = "1/r").
kernel_ids    = [1]
kernels       = all_kernels[kernel_ids]
kernels_names = all_kernels_names[kernel_ids]

# Interpolation rule, passed to SI as `ruletype`.
rule     = SI.chebyshev
# Interpolator variants to compare; only the tensor one is used anymore.
methods_names = ["Tensor"]
method = ["tensor"]

# Norm of the relative approximation error. Only "fro" is currently allowed: the
# "l2" code paths exist but are disabled by this check.
norm_type = "fro"
@assert norm_type == "fro"

# Geometry parameter forwarded to getGeometriesNames and getGeometry.
coef = 0.0

geo_names = getGeometriesNames(coef)

# Range of geometry indices to run: geometry 3 by default, otherwise the two
# command-line arguments. They are parsed strictly as Int; plain `parse(s)` would
# accept any Julia expression (e.g. "3+1" gives an Expr, not a number).
if length(ARGS) == 0
    geoBegin = 3
    geoEnd   = 3 # length(geo_names)
elseif length(ARGS) == 2
    geoBegin = parse(Int, ARGS[1])
    geoEnd   = parse(Int, ARGS[2])
end
if !(1 <= geoBegin <= geoEnd <= length(geo_names))
    error("need 1 <= geoBegin <= geoEnd <= $(length(geo_names)), " *
          "got geoBegin = $geoBegin, geoEnd = $geoEnd")
end

geos = geoBegin:geoEnd
@printf "geoBegin %d, geoEnd %d\n" geoBegin geoEnd

# Number of points per geometry: 50, except geometry 11 (3D cube) which uses 10.
npts = [g == 11 ? 10 : 50 for g in geos]

@assert length(kernels) == length(kernels_names)
@printf "===== Settings =====\n"
@printf "Geo Names %s\n" string(geo_names)
@printf "Kernels Names %s\n" string(kernels_names)
@printf "Method Names %s\n" string(methods_names)
@printf "Npts %s\n" string(npts)
@printf "Norm %s\n" norm_type
@printf "--------------------\n"

# Plot line styles. NOTE(review): not referenced anywhere in the visible script.
line_styles = ["-", "--", "-."]

# Target tolerances 1e-12, 1e-11, ..., 1e-1 (12 log-spaced values).
tols = logspace(-12, -1, 12)

# One timestamp shared by every output file of this run. The ISO DateTime string
# (e.g. "2017-01-31T12:34:56.789") contains ':' but no '/', so only ':' needs
# replacing to get a file-name-friendly tag.
date_run = replace(string(Dates.now()), ":", ".")

# Helper for the *XYhat.dat export: A[d, i], or NaN when (d, i) lies outside A.
# It lets two point sets with different numbers of points (columns) or coordinates
# (rows) be written side by side in one table without a BoundsError.
_coord_or_nan(A, d, i) = (d <= size(A, 1) && i <= size(A, 2)) ? A[d, i] : NaN

# Tolerance index whose compressed points (Xhat, Yhat) are exported to *XYhat.dat.
# 7 selects 1e-6 among the default tols 1e-12 ... 1e-1; clamped for shorter `tols`.
xy_tol_index = min(7, length(tols))

# Main benchmark. For every geometry and kernel, build the dense kernel matrix.
# Then, for every tolerance, compare SI's low-rank approximation with the RRQR and
# SVD ranks, and write the results under results-noplots/.
for geo in eachindex(geos)
    @printf("||||||||||||||||| Geometry %d (%d pts) ||||||||||||||| \n", geos[geo],
            npts[geo])
    (X, Y, b1, b2) = getGeometry(geos[geo], npts[geo], coef=coef)
    I1 = SI.getIntervals(b1)
    I2 = SI.getIntervals(b2)
    # Dry run on the first geometry so that compilation is not part of the timings
    # below; a handful of points is enough for meshKernelFull.
    if geo == 1
        @printf "============== Dry Run ===============\n"
        # trash = SI.meshKernelLowRankAdaptive(kernels[1], b1, b2, X, Y, 1e-1, ruletype=rule, useSI=true)
        SI.fastMeshKernelLowRankApprox((x, y) -> kernels[1](SI.RtoX(b1, x), SI.RtoX(b2, y)),
                                       I1, I2, SI.XtoR(b1, X), SI.XtoR(b2, Y), 1e-1,
                                       ruletype=rule)
        SI.meshKernelFull(kernels[1], X[:, 1:min(10, size(X, 2))],
                          Y[:, 1:min(10, size(Y, 2))])
    end
    # Run Stuff
    for ker in eachindex(kernels)
        @printf "=============== Kernel %d %s =============\n" ker kernels_names[ker]
        kernel = kernels[ker]
        # Dense reference kernel matrix, used for the error and the optimal ranks.
        t_Atrue = @elapsed Atrue = SI.meshKernelFull(kernel, X, Y)
        @printf "Time Atrue %f s.\n" t_Atrue
        # The kernel as seen by SI: arguments arrive in reference coordinates and
        # are mapped back to physical coordinates with RtoX before evaluation.
        kernel_ref = (x, y) -> kernel(SI.RtoX(b1, x), SI.RtoX(b2, y))
        # Results, indexed [method, tolerance]
        nm, nt = length(method), length(tols)
        ranks_uncomp = zeros(Float64, (nm, nt))
        sizes_uncomp = Array{Array{Int64,1}, 2}((nm, nt))
        sizes_comp   = Array{Array{Int64,1}, 2}((nm, nt))
        err          = zeros(Float64, (nm, nt))
        ranks_comp   = zeros(Float64, (nm, nt))
        ranks_opt    = zeros(Float64, (nt,))
        if norm_type == "fro"
            ranks_opt[:] = SI.rank_eps_fro(Atrue, tols)
        elseif norm_type == "l2"
            ranks_opt[:] = SI.rank_eps_l2(Atrue, tols)
        end
        Xhat         = Array{Array{Float64,2},2}((nm, nt))
        Yhat         = Array{Array{Float64,2},2}((nm, nt))
        Xbar         = Array{Array{Float64,2},2}((nm, nt))
        Ybar         = Array{Array{Float64,2},2}((nm, nt))
        t_RRQR       = zeros(Float64, nt)
        ranks_rrqr   = zeros(Int64, nt)
        # Build Low Rank
        for t in eachindex(tols)
            tol = tols[t]
            @printf "-------- Tol %e -------\n" tol
            t_RRQR[t] = @elapsed PQRtol = pqrfact(Atrue, rtol=tol)
            @printf "Time RRQR tol %e : %f s.\n" tol (t_RRQR[t])
            # RRQR rank = number of rows of the R factor.
            ranks_rrqr[t] = size(PQRtol[:R], 1)
            @printf "RRQR gives rank %d (SVD gives %d)\n" ranks_rrqr[t] ranks_opt[t]
            for m in eachindex(method)
                @printf "-------- Method %d -------\n" m
                # NOTE(review): tolInterp = tol^(-0.25) is > 1 for every tol < 1
                # (about 1e3 at tol = 1e-12); confirm it is not meant to be tol^0.25.
                # The meaning of tolQR=1 is defined in SI — verify there.
                t_lr = @elapsed begin
                    (U, F, V, Xhat[m, t], Yhat[m, t], Xbar[m, t], Ybar[m, t],
                     Ucomp, Scomp, Vcomp) =
                        SI.fastMeshKernelLowRankApprox(kernel_ref, I1, I2,
                                                       SI.XtoR(b1, X), SI.XtoR(b2, Y),
                                                       tol, ruletype=rule,
                                                       logLevel=SI.debug,
                                                       tolInterp=tol^(-0.25), tolQR=1)
                end
                # Bring the returned point sets back to physical coordinates.
                Xhat[m, t] = SI.RtoX(b1, Xhat[m, t])
                Yhat[m, t] = SI.RtoX(b2, Yhat[m, t])
                Xbar[m, t] = SI.RtoX(b1, Xbar[m, t])
                Ybar[m, t] = SI.RtoX(b2, Ybar[m, t])
                # Factorization is U * inv(F) * V'; solve with F rather than inverting.
                Aapp = U * (F \ V')
                if norm_type == "fro"
                    err[m, t] = vecnorm(Aapp - Atrue) / vecnorm(Atrue)
                elseif norm_type == "l2"
                    err[m, t] = norm(Aapp - Atrue, 2) / norm(Atrue, 2)
                end
                sizes_uncomp[m, t] = [size(Xbar[m, t], 2), size(Ybar[m, t], 2)]
                sizes_comp[m, t]   = [size(Xhat[m, t], 2), size(Yhat[m, t], 2)]
                @printf("Error %e for rank %d kernel %s geo %s (actual %e rank is %d, uncomp rank was %d) in %e sec (Atrue %e sec)\n",
                        err[m, t], minimum(sizes_comp[m, t]), kernels_names[ker],
                        geo_names[geos[geo]], tol, ranks_opt[t],
                        minimum(sizes_uncomp[m, t]), t_lr, t_Atrue)
                ranks_uncomp[m, t] = minimum(sizes_uncomp[m, t])
                ranks_comp[m, t]   = minimum(sizes_comp[m, t])
            end
        end
        # Saving stuff
        # filename = string("results-noplots/results-", date_run, "-geo", geos[geo], "-kernel", ker, ".jld")
        # @printf "Saving to %s\n" filename
        # save(filename, "err", err, "sizes_uncomp", sizes_uncomp, "sizes_comp", sizes_comp, "ranks_uncomp", ranks_uncomp, "ranks_comp", ranks_comp, "ranks_opt", ranks_opt, "ranks_rrqr", ranks_rrqr, "tols", tols, "geos", geos, "geo_names", geo_names, "kernel", kernels_names, "method", methods_names, "date", string(Dates.now()), "Xhat", Xhat, "Yhat", Yhat, "Xbar", Xbar, "Ybar", Ybar, "X", X, "Y", Y, "b1", b1, "b2", b2, "norm_type", norm_type, "coef", coef)
        # @printf "Saved done to %s\n" filename
        # Per-tolerance summary table. The header has one set of rank/error columns,
        # so they are taken from the first method. `open ... do` closes the file even
        # if a write fails.
        filename = string("results-noplots/results-", date_run, "-geo", geos[geo],
                          "-kernel", ker, ".dat")
        @printf "Saving to %s\n" filename
        open(filename, "w") do io
            @printf(io, "tol\tr0\tr1\trRRQR\trSVD\terr\n")
            for k in eachindex(tols)
                @printf(io, "%e\t%d\t%d\t%d\t%d\t%e\n", tols[k], ranks_uncomp[1, k],
                        ranks_comp[1, k], ranks_rrqr[k], ranks_opt[k], err[1, k])
            end
        end
        # Compressed points (first method, tolerance tols[xy_tol_index]) written side
        # by side. A point set with fewer points or coordinates is padded with NaN
        # instead of raising a BoundsError or silently dropping points.
        filename = string("results-noplots/results-", date_run, "-geo", geos[geo],
                          "-kernel", ker, "XYhat.dat")
        @printf "Saving to %s\n" filename
        Xsel = Xhat[1, xy_tol_index]
        Ysel = Yhat[1, xy_tol_index]
        open(filename, "w") do io
            @printf(io, "x1\tx2\tx3\ty1\ty2\ty3\n")
            for i in 1:max(size(Xsel, 2), size(Ysel, 2))
                @printf(io, "%f\t%f\t%f\t%f\t%f\t%f\n",
                        _coord_or_nan(Xsel, 1, i), _coord_or_nan(Xsel, 2, i),
                        _coord_or_nan(Xsel, 3, i), _coord_or_nan(Ysel, 1, i),
                        _coord_or_nan(Ysel, 2, i), _coord_or_nan(Ysel, 3, i))
            end
        end
        @printf "Saved done to %s\n" filename
    end
end

# Every selected geometry/kernel pair has been processed; print the closing banner.
println("Done.")
println("======================")
