include("../../src/SI.jl")
include("../../benchmarks/geometries.jl")
include("../../LSTC_benchmarks/loadLSTC.jl")
include("../../benchmarks/aca.jl")
using PyPlot
# Close every PyPlot figure that is still open (e.g. left over from an earlier run
# in the same session) so that only this script's figures are shown.
close("all")

"""
    push_lstc_pairs!(allGeos, benchid, partners)

Load LSTC benchmark number `benchid` with `getLSTCbenchmark` and, for every index
`i` in `partners`, append the tuple `(X, Y, b1, b2, label)` to `allGeos`, where
`X = Xs[1]`, `Y = Xs[i]`, `b1`/`b2` are their `SI.CubeBoundingBox{Float64}`s and
`label` names the benchmark, the pair `(1, i)` and `SI.cluster_distance(X, Y)`.
Returns `allGeos`.
"""
function push_lstc_pairs!(allGeos, benchid, partners)
    (namegeo, Xs) = getLSTCbenchmark("../../LSTC_benchmarks/", benchid)
    for i in partners
        X = Xs[1]
        Y = Xs[i]
        b1 = SI.CubeBoundingBox{Float64}(X)
        b2 = SI.CubeBoundingBox{Float64}(Y)
        # The ")" inside the format string closes the " (pair " opened above.
        label = string(namegeo, " (pair ", 1, ",", i, " distance ",
                       @sprintf("%3.1f)", SI.cluster_distance(X, Y)))
        push!(allGeos, (X, Y, b1, b2, label))
    end
    return allGeos
end

"""
    select_kernel(dim)

Return the kernel `(x, y) -> ...` used for points living in `dim` dimensions:
`1 ./ SI.distance_2(x, y)` when `dim == 3` and `log.(SI.distance_2(x, y))` when
`dim == 2`.

Throws an `ArgumentError` for any other dimension, rather than leaving the kernel
undefined and failing later with an `UndefVarError`.
"""
function select_kernel(dim)
    if dim == 3
        # `1 ./` (with a space) — `1./` is an ambiguous literal/dot form rejected by Julia >= 0.7.
        return (x, y) -> 1 ./ SI.distance_2(x, y)
    elseif dim == 2
        return (x, y) -> log.(SI.distance_2(x, y))
    end
    throw(ArgumentError("unsupported point dimension $dim (expected 2 or 3)"))
end

"""
    main()

For each cluster pair `(X, Y)` built from the LSTC benchmarks:
- print the rank and relative Frobenius error of ACA (`pp_aca`) for each tolerance,
  and compute the SVD ranks with `SI.rank_eps_fro`;
- run `SI.adaptive_si` with each point-selection strategy listed in `names`
  (averaging over `nRepeat` runs for "Random" strategies) and print its errors and
  ranks `r0`/`r1`;
- draw one figure with ranks vs. tolerance and errors vs. tolerance.
"""
function main()

    # Create all the geometries: tuples (X, Y, b1, b2, label)
    allGeos = []

    # # Easy square case
    # for c in [0.0, 1.0, 2.0]
    #     (X, Y, b1, b2) = getGeometry(10,10,coef=c)
    #     push!(allGeos, (X, Y, b1, b2, string("2d-square, corner-opposed, distance ", (@sprintf "%3.1f" SI.cluster_distance(X,Y)))))
    # end

    push_lstc_pairs!(allGeos, 4, [4, 7, 15])     # Torus
    push_lstc_pairs!(allGeos, 6, [10, 20, 30])   # COIL
    push_lstc_pairs!(allGeos, 8, [10, 30, 50])   # ENGINE

    # Number of repetitions averaged for randomized ("Random") strategies
    nRepeat = 25

    for (geoid, (X, Y, b1, b2, namegeo)) in enumerate(allGeos)

        @printf "+++++++++++ %s +++++++++++\n" namegeo

        # figure(figsize=(10, 7))
        # SI.plotBoundingBox(b1);
        # SI.plotBoundingBox(b2);
        # if size(X,1) == 3 
        #     PyPlot.plot(X[1,:],X[2,:],X[3,:],"*r",markersize=2.0)
        # else
        #     PyPlot.plot(X[1,:],X[2,:],"*r",markersize=2.0)
        # end
        # if size(Y,1) == 3
        #     PyPlot.plot(Y[1,:],Y[2,:],Y[3,:],"*b",markersize=2.0)
        # else
        #     PyPlot.plot(Y[1,:],Y[2,:],"*b",markersize=2.0)
        # end
        # PyPlot.title(namegeo) 
        # PyPlot.savefig(string("../../report_20171107/geo-", geoid, ".png"),format="png",dpi="figure",bbox_inches="tight",pad_inches=0.1)

        Nx = size(X, 2)
        Ny = size(Y, 2)
        kernel = select_kernel(size(X, 1))
        # Same kernel expressed in the reference coordinates of the bounding boxes
        kernelR = (rx, ry) -> kernel(SI.RtoX(b1, rx), SI.RtoX(b2, ry))
        Ktrue = SI.meshKernelFull(kernel, X, Y)
        I1 = SI.getIntervals(b1)
        I2 = SI.getIntervals(b2)

        # Reference results: SVD ranks and ACA ranks/errors for every tolerance
        tols = logspace(-10, -3, 8)
        ranksSVD = SI.rank_eps_fro(Ktrue, tols)
        ranksACA = zeros(size(ranksSVD))
        errsACA = zeros(size(ranksACA))
        for (i, tol) in enumerate(tols)
            Uaca, Vaca, Iaca, Jaca = pp_aca(kernel, X, Y, tol, Ktrue=Ktrue)
            @assert length(Iaca) == length(Jaca)
            errsACA[i] = vecnorm(Ktrue - Uaca*Vaca') / vecnorm(Ktrue)
            @printf "==> Error for ACA for tol %e is %e with rank %d\n" tol errsACA[i] length(Iaca)
            ranksACA[i] = length(Iaca)
        end

        # Points around
        b1sphere = SI.SphereBoundingBox{Float64}(X)
        b2sphere = SI.SphereBoundingBox{Float64}(Y)
        XbarCircle = n -> (SI.generatePoints(b1sphere, n), ones(n))
        YbarCircle = n -> (SI.generatePoints(b2sphere, n), ones(n))

        # Random
        XbarRand = n -> (X[:, randperm(Nx)[1:n]], ones(n))
        YbarRand = n -> (Y[:, randperm(Ny)[1:n]], ones(n))

        # Chebyshev
        function XbarCheb(n)
            (orderx, ordery) = SI.getInterpolationOrder(kernelR, I1, I2, n, logLevel=SI.debug)
            interpx = SI.TensorInterpolator{Float64}(I1, orderx, SI.chebyshev)
            return (SI.RtoX(b1, interpx.xk), interpx.w_intk)
        end
        function YbarCheb(n)
            (orderx, ordery) = SI.getInterpolationOrder(kernelR, I1, I2, n, logLevel=SI.debug)
            interpy = SI.TensorInterpolator{Float64}(I2, ordery, SI.chebyshev)
            return (SI.RtoX(b2, interpy.xk), interpy.w_intk)
        end

        # MDV
        XbarMdv = n -> (X[:, SI.get_mdv(X, n)], ones(n))
        YbarMdv = n -> (Y[:, SI.get_mdv(Y, n)], ones(n))

        # All the different methods
        # XbarFuns = [XbarCheb, XbarRand, XbarMdv, XbarCircle]
        # YbarFuns = [YbarCheb, YbarRand, YbarMdv, YbarCircle]
        # names    = ["Chebyshev", (@sprintf "Random (av. %d)" nRepeat), "MDV", "Exo-Sphere"]

        XbarFuns = [XbarMdv]
        YbarFuns = [YbarMdv]
        names    = ["MDV"]

        # zip() below silently truncates, so mismatched method lists would drop methods
        if !(length(XbarFuns) == length(YbarFuns) == length(names))
            throw(ArgumentError("XbarFuns, YbarFuns and names must have the same length"))
        end

        # One result vector per method, filled with one entry per tolerance
        r0all  = [[] for _ in names]
        r1all  = [[] for _ in names]
        errall = [[] for _ in names]

        for (it, tol) in enumerate(tols)

            stop = (r0, r1, Xhat, Yhat) -> SI.accuracy_stopping_criterion(kernel, X, Y, Xhat, Yhat, Ktrue, tol)

            for (id, (XbarFun, YbarFun, name, r0s, r1s, errs)) in enumerate(zip(XbarFuns, YbarFuns, names, r0all, r1all, errall))

                # Cheb and Sphere generate their own points, so they are not limited by Nx/Ny
                rmax = (name == "Chebyshev" || name == "Exo-Sphere") ? 1000 : min(1000, Nx, Ny)

                # Declared here so the (commented) plotting below can see the last run
                Xhat = nothing
                Yhat = nothing
                Xbar = nothing
                Ybar = nothing

                r0Rep = []
                r1Rep = []
                errRep = []
                nRepeatActual = contains(name, "Random") ? nRepeat : 1 # Random only
                for rep in 1:nRepeatActual
                    (Xhat, Yhat, Xbar, Ybar) = SI.adaptive_si(kernel, XbarFun, YbarFun, 0.1*tol, rinit=4, rmax=rmax, rinc=x->Int(ceil(x*1.1)), loglevel=SI.debug, stopping_criterion=stop)
                    push!(r0Rep, max(size(Xbar, 2), size(Ybar, 2)))
                    push!(r1Rep, max(size(Xhat, 2), size(Yhat, 2)))
                    push!(errRep, SI.err_si_frob(kernel, X, Y, Xhat, Yhat))
                end

                push!(errs, mean(errRep))
                push!(r0s, mean(r0Rep))
                push!(r1s, mean(r1Rep))

                @printf "==> Error for %s for tol %e is %e with rank r0 %d r1 %d\n" name tol errs[end] r0s[end] r1s[end]

                # if id == length(names) && it == 1
                #     PyPlot.figure();
                #     if size(X,1) == 2
                #         PyPlot.plot(X[1,:],X[2,:],"*r");
                #         PyPlot.plot(Y[1,:],Y[2,:],"*b");
                #         PyPlot.plot(Xhat[1,:],Xhat[2,:],"db",markersize=2.5);
                #         PyPlot.plot(Yhat[1,:],Yhat[2,:],"dr",markersize=2.5);
                #         PyPlot.plot(Xbar[1,:],Xbar[2,:],"sb",markersize=1.5);
                #         PyPlot.plot(Ybar[1,:],Ybar[2,:],"sr",markersize=1.5);
                #         axis("equal")
                #     else size(X,1) == 3
                #         PyPlot.plot3D(X[1,:],X[2,:],X[3,:],"*r");
                #         PyPlot.plot3D(Y[1,:],Y[2,:],Y[3,:],"*b");
                #         PyPlot.plot3D(Xhat[1,:],Xhat[2,:],Xhat[3,:],"db",markersize=2.5);
                #         PyPlot.plot3D(Yhat[1,:],Yhat[2,:],Yhat[3,:],"dr",markersize=2.5);
                #         PyPlot.plot3D(Xbar[1,:],Xbar[2,:],Xbar[3,:],"sb",markersize=1.5);
                #         PyPlot.plot3D(Ybar[1,:],Ybar[2,:],Ybar[3,:],"sr",markersize=1.5);
                #         axis("equal")
                #     end
                # end

            end

        end

        # Left panel: ranks vs. tolerance
        figure(figsize=(10, 7))
        subplot(1, 2, 1)
        semilogx(tols, ranksSVD, label="SVD")
        semilogx(tols, ranksACA, label="ACA")
        for (r0s, r1s, name) in zip(r0all, r1all, names)
            semilogx(tols, r0s, label=string("r0 ", name))
            semilogx(tols, r1s, label=string("r1 ", name), linestyle="--")
        end
        title(string("Ranks for ", namegeo))
        xlabel("Tolerance")
        ylabel("Ranks")
        legend()
        # PyPlot.savefig(string("../../report_20171107/rank-", geoid, ".png"),format="png",dpi="figure",bbox_inches="tight",pad_inches=0.1)

        # Right panel: errors vs. tolerance
        subplot(1, 2, 2)
        loglog(tols, tols, label="Tol")
        loglog(tols, errsACA, label="ACA")
        for (errs, name) in zip(errall, names)
            loglog(tols, errs, label=name)
        end
        title(string("Error for ", namegeo))
        legend()

    end

end

# Run the full benchmark (all geometries, all tolerances) when this script is executed.
main()
