include("../src/SI.jl")
include("../LSTC_benchmarks/loadLSTC.jl")
using PyPlot
close("all")

# Alternative synthetic point sets (disabled), generated from two 2-D cube
# bounding boxes instead of the LSTC benchmark.
# b1 = SI.CubeBoundingBox{Float64}([-1.,-1.],[3.,3.])
# b2 = SI.CubeBoundingBox{Float64}([0.5,0.5],[3.,3.])
# Xtmp = SI.generatePoints(b1, [50, 50])
# Ytmp = SI.generatePoints(b2, [50, 50])

# Load LSTC benchmark 6; only the second return value (the list of clusters) is used.
(trash,XYs) = getLSTCbenchmark("../LSTC_benchmarks/", 6)

# Two clusters whose interaction block K(X, Y) is studied below.
# NOTE(review): points appear to be stored one per column with 3 coordinates
# (see the plot3D calls and size(X)[2] further down) — confirm in loadLSTC.jl.
X = XYs[1]
Y = XYs[36]

@show SI.cluster_distance(X,Y)

# Scatter the two clusters: X in red, Y in blue.
PyPlot.figure()
PyPlot.plot3D(X[1,:],X[2,:],X[3,:],"or");
PyPlot.plot3D(Y[1,:],Y[2,:],Y[3,:],"ob");

# Kernel k(x, y) = 1 / distance_2(x, y). Spelled `1 ./ d` with spaces: the unspaced
# form `1./d` is ambiguous (it already parsed as `1 ./ d`) and Julia >= 1.0 rejects it.
kernel = (x, y) -> 1 ./ SI.distance_2(x, y)
# Reference matrix K(X, Y); every approximation error below is measured against it.
Ktrue = SI.meshKernelFull(kernel, X, Y)
    
# Cube bounding boxes of both clusters and their coordinate intervals.
cbb1 = SI.CubeBoundingBox{Float64}(X)
cbb2 = SI.CubeBoundingBox{Float64}(Y)
I1 = SI.getIntervals(cbb1)
I2 = SI.getIntervals(cbb2)

# Kernel in reference coordinates: map rx through cbb1 and ry through cbb2, then
# evaluate the physical kernel.
kernelR = function (rx, ry)
    xphys = SI.RtoX(cbb1, rx)
    yphys = SI.RtoX(cbb2, ry)
    return kernel(xphys, yphys)
end

# Candidate samplers: each maps a count n to (points, weights), with unit weights.

# Sphere sampler: n points produced by SI.generatePoints from each cluster's
# SphereBoundingBox.
b1sphere = SI.SphereBoundingBox{Float64}(X)
b2sphere = SI.SphereBoundingBox{Float64}(Y)
XbarCircle = function (n)
    pts = SI.generatePoints(b1sphere, n)
    return (pts, ones(n))
end
YbarCircle = function (n)
    pts = SI.generatePoints(b2sphere, n)
    return (pts, ones(n))
end

# Mdv sampler: the n columns of the cluster selected by SI.get_mdv.
XbarMdv = function (n)
    cols = SI.get_mdv(X, n)
    return (X[:, cols], ones(n))
end
YbarMdv = function (n)
    cols = SI.get_mdv(Y, n)
    return (Y[:, cols], ones(n))
end
        
# Chebyshev candidate points for the X side.
# Uses the x-direction order that SI.getInterpolationOrder picks for kernelR at
# target n, builds a Chebyshev TensorInterpolator on I1, and returns its nodes `xk`
# mapped back through cbb1 together with its `w_intk` weights.
# NOTE(review): YbarCheb repeats the same getInterpolationOrder call for the same n.
function XbarCheb(n)
    (ordX, ordY) = SI.getInterpolationOrder(kernelR, I1, I2, n, logLevel=SI.none)
    tensor = SI.TensorInterpolator{Float64}(I1, ordX, SI.chebyshev)
    return (SI.RtoX(cbb1, tensor.xk), tensor.w_intk)
end
# Chebyshev candidate points for the Y side.
# Uses the y-direction order that SI.getInterpolationOrder picks for kernelR at
# target n, builds a Chebyshev TensorInterpolator on I2, and returns its nodes `xk`
# mapped back through cbb2 together with its `w_intk` weights.
function YbarCheb(n)
    (ordX, ordY) = SI.getInterpolationOrder(kernelR, I1, I2, n, logLevel=SI.none)
    tensor = SI.TensorInterpolator{Float64}(I2, ordY, SI.chebyshev)
    return (SI.RtoX(cbb2, tensor.xk), tensor.w_intk)
end

# Number of points (columns) in each cluster.
Nx = size(X)[2]
Ny = size(Y)[2]
# Random sampler: n distinct columns drawn uniformly without replacement, unit weights.
# Requires n <= Nx (resp. Ny).
XbarRnd = n -> (X[:,randperm(Nx)[1:n]], ones(n))
YbarRnd = n -> (Y[:,randperm(Ny)[1:n]], ones(n))

# Rank histories, one entry per tolerance, filled by the experiment loop:
#   r0* = columns of the Xbar returned by SI.adaptive_si,
#   r1* = columns of the skeleton Xhat,
#   r2* = rank kept after SVD recompression.
# Concrete element types avoid Vector{Any}; the random-sampler entries are means over
# repeated trials and therefore Float64.
r0Circle = Int[]
r0Mdv    = Int[]
r1Circle = Int[]
r1Mdv    = Int[]
r0Cheb   = Int[]
r1Cheb   = Int[]
r1ACA    = Int[]      # NOTE(review): never filled in this script
r0Rnd    = Float64[]
r1Rnd    = Float64[]
r2Circle = Int[]
r2Mdv    = Int[]
r2Cheb   = Int[]
r2Rnd    = Float64[]

# Skeleton points kept from the first tolerance (i == 1) of each sampler.
XhatMDV = nothing
YhatMDV = nothing
XhatC   = nothing
YhatC   = nothing
XhatCheb = nothing
YhatCheb = nothing
XhatR = nothing
YhatR = nothing

# Recompress the skeleton approximation
#     K ≈ K(X,Yhat) * K(Xhat,Yhat)^-1 * K(Xhat,Y).
# With thin QRs K(X,Yhat) = Qx*Rx and K(Xhat,Y)' = Qy*Ry, this equals
#     Qx * (Rx * K(Xhat,Yhat)^-1 * Ry') * Qy'.
# An SVD Ux*S*Uy' of the small middle factor, truncated to the rank returned by
# SI.cut_spectrum_fro(S, tol), gives K ≈ U*V' with U = Qx*Ux and V = Qy*Uy*diag(S).
# The function prints the relative Frobenius errors of both the skeleton and the
# recompressed factors against the global Ktrue, then returns (U, V).
# It reads the globals kernel, X, Y and Ktrue.
function recompress(Xhat, Yhat, tol)
    # The three kernel blocks of the skeleton.
    left   = SI.meshKernelFull(kernel, X, Yhat)
    center = SI.meshKernelFull(kernel, Xhat, Yhat)
    right  = SI.meshKernelFull(kernel, Xhat, Y)

    # Orthogonal bases for the outer factors (unpivoted QR).
    leftQR  = qrfact(left, Val{false})
    rightQR = qrfact(right', Val{false})
    Rl = leftQR[:R]
    Rr = rightQR[:R]
    Ql = leftQR[:Q][:, 1:size(Rl, 1)]
    Qr = rightQR[:Q][:, 1:size(Rr, 1)]

    # Truncated SVD of the small core.
    core = svdfact(Rl * (center \ Rr'))
    keep = 1:SI.cut_spectrum_fro(core[:S], tol)
    U = Ql * core[:U][:, keep]
    V = Qr * core[:V][:, keep] * diagm(core[:S][keep])

    # Report the error before and after recompression.
    normK     = vecnorm(Ktrue)
    errSkel   = vecnorm(Ktrue - left * (center \ right)) / normK
    errRecomp = vecnorm(Ktrue - U * V') / normK
    @printf "Recompression: error %e -> %e, r1 %d -> r2 %d\n" errSkel errRecomp size(Xhat, 2) size(U, 2)
    return U, V
end

tols = logspace(-10, -1, 10)

# Number of independent random-subset trials averaged at every tolerance.
nrepeats = 25
# Samplers that select columns of X / Y cannot ask for more points than a cluster has.
rmaxSubset = min(size(X)[2], size(Y)[2])

# Run SI.adaptive_si with the given candidate samplers (same settings for every
# sampler, 0.1*tol passed as its tolerance argument), then SVD-recompress the
# resulting skeleton with `recompress`.
# Returns (Xhat, Yhat, r0, r1, r2): the skeleton points, the column counts of the
# Xbar and Xhat returned by adaptive_si, and the rank kept by the recompression.
function run_adaptive(Xsampler, Ysampler, tol, rmax, stop)
    (Xhat, Yhat, Xbar, Ybar) = SI.adaptive_si(kernel, Xsampler, Ysampler, 0.1*tol,
                                              rinit=4, rmax=rmax,
                                              rinc=x->Int(ceil(x*1.1)),
                                              loglevel=SI.debug,
                                              stopping_criterion=stop)
    (U, V) = recompress(Xhat, Yhat, tol)
    return Xhat, Yhat, size(Xbar)[2], size(Xhat)[2], size(U)[2]
end

for (i, tol) in enumerate(tols)
    # Explicit so that these snapshots update the script-level variables; Julia >= 1.0
    # would otherwise treat assignments inside a top-level loop as loop-local.
    global XhatC, YhatC, XhatMDV, YhatMDV, XhatCheb, YhatCheb

    @printf "+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
    @printf "Tolerance %e\n" tol

    stop = (r0, r1, Xhat, Yhat) -> SI.accuracy_stopping_criterion(kernel, X, Y, Xhat, Yhat, Ktrue, tol)

    # Sphere samplers. The first tolerance also opens the figure of selected points.
    (Xhat, Yhat, r0, r1, r2) = run_adaptive(XbarCircle, YbarCircle, tol, 1000, stop)
    push!(r0Circle, r0)
    push!(r1Circle, r1)
    push!(r2Circle, r2)
    if i == 1
        PyPlot.figure()
        PyPlot.plot3D(Xhat[1,:],Xhat[2,:],Xhat[3,:],"*b",label="Sphere");
        PyPlot.plot3D(Yhat[1,:],Yhat[2,:],Yhat[3,:],"*r",label="Sphere");
        XhatC = Xhat
        YhatC = Yhat
    end

    # Mdv samplers (subsets of the clusters, hence bounded by rmaxSubset).
    (Xhat, Yhat, r0, r1, r2) = run_adaptive(XbarMdv, YbarMdv, tol, rmaxSubset, stop)
    push!(r0Mdv, r0)
    push!(r1Mdv, r1)
    push!(r2Mdv, r2)
    if i == 1
        PyPlot.plot3D(Xhat[1,:],Xhat[2,:],Xhat[3,:],"dk",label="Mdv");
        PyPlot.plot3D(Yhat[1,:],Yhat[2,:],Yhat[3,:],"dg",label="Mdv");
        XhatMDV = Xhat
        YhatMDV = Yhat
    end

    # Chebyshev samplers.
    (Xhat, Yhat, r0, r1, r2) = run_adaptive(XbarCheb, YbarCheb, tol, 1000, stop)
    push!(r0Cheb, r0)
    push!(r1Cheb, r1)
    push!(r2Cheb, r2)
    if i == 1
        PyPlot.plot3D(Xhat[1,:],Xhat[2,:],Xhat[3,:],"om",label="Cheb");
        PyPlot.plot3D(Yhat[1,:],Yhat[2,:],Yhat[3,:],"oc",label="Cheb");
        XhatCheb = Xhat
        YhatCheb = Yhat
    end

    # Random subsets: the ranks are averaged over `nrepeats` independent trials.
    r0rndTmp = Int[]
    r1rndTmp = Int[]
    r2rndTmp = Int[]
    for trial = 1:nrepeats
        (Xhat, Yhat, r0, r1, r2) = run_adaptive(XbarRnd, YbarRnd, tol, rmaxSubset, stop)
        push!(r0rndTmp, r0)
        push!(r1rndTmp, r1)
        push!(r2rndTmp, r2)
        if i == 1 && trial == 1
            # Square markers: the Cheb points above already use "om"/"oc", and reusing
            # them made the two methods indistinguishable in the figure.
            PyPlot.plot3D(Xhat[1,:],Xhat[2,:],Xhat[3,:],"sy",label="Rand");
            PyPlot.plot3D(Yhat[1,:],Yhat[2,:],Yhat[3,:],"sk",label="Rand");
            PyPlot.legend()
        end
    end
    push!(r0Rnd, mean(r0rndTmp))
    push!(r1Rnd, mean(r1rndTmp))
    push!(r2Rnd, mean(r2rndTmp))
end

# SVD reference curve: presumably the Frobenius-norm epsilon-rank of Ktrue at each
# tolerance (see SI.rank_eps_fro).
r01svd = SI.rank_eps_fro(Ktrue, tols, ll=SI.info)

# One figure per rank kind: every sampler against the SVD reference, on a log-x axis.
function plot_ranks(prefix, tols, svdranks, sphere, mdv, cheb, rnd)
    PyPlot.figure()
    PyPlot.semilogx(tols, sphere,   label="$prefix Sphere")
    PyPlot.semilogx(tols, mdv,      label="$prefix Mdv")
    PyPlot.semilogx(tols, cheb,     label="$prefix Cheb")
    PyPlot.semilogx(tols, rnd,      label="$prefix Rand")
    PyPlot.semilogx(tols, svdranks, label="SVD")
    PyPlot.legend()
    return nothing
end

plot_ranks("r0", tols, r01svd, r0Circle, r0Mdv, r0Cheb, r0Rnd)
plot_ranks("r1", tols, r01svd, r1Circle, r1Mdv, r1Cheb, r1Rnd)
plot_ranks("r2", tols, r01svd, r2Circle, r2Mdv, r2Cheb, r2Rnd)

# Tab-separated table, one row per tolerance. The do-block closes the file even if a
# write fails. Rows carry exactly the 14 header columns (no trailing tab).
open("r0r1r2.dat", "w") do f
    write(f, "Error\tSVD\tr0-sphere\tr0-mdv\tr0-random\tr0-Cheb\tr1-sphere\tr1-mdv\tr1-random\tr1-Cheb\tr2-sphere\tr2-mdv\tr2-random\tr2-Cheb\n")
    for i in eachindex(tols)
        @printf(f, "%e\t%e\t%e\t%e\t%e\t%e\t%e\t%e\t%e\t%e\t%e\t%e\t%e\t%e\n",
                tols[i], r01svd[i],
                r0Circle[i], r0Mdv[i], r0Rnd[i], r0Cheb[i],
                r1Circle[i], r1Mdv[i], r1Rnd[i], r1Cheb[i],
                r2Circle[i], r2Mdv[i], r2Rnd[i], r2Cheb[i])
    end
end
