include("../../src/SI.jl")
include("../../benchmarks/geometries.jl")
include("../../benchmarks/aca.jl")
using PyPlot
close("all")

# Raw candidate points generated in two boxes via SI.generatePoints
# (50 x 50 per box); each set is then cut down to the unit disk around its
# own centre (cx for X, cy for Y).
b1 = SI.CubeBoundingBox{Float64}([-1.,-1.],[3.,3.])
b2 = SI.CubeBoundingBox{Float64}([0.5,0.5],[3.,3.])
Xtmp = SI.generatePoints(b1, [50, 50])
Ytmp = SI.generatePoints(b2, [50, 50])

cx = [0;0];
cy = [2.5;2.5];

# Return the columns of `P` within distance `r` of `c`, in their original
# order. Each cloud is filtered on its own column range, so Xtmp and Ytmp do
# not need the same number of points (the previous shared loop index silently
# truncated or overran Ytmp). A Bool mask replaces the repeated `hcat`, which
# copied the growing matrix on every accepted point.
indisk(P, c, r=1.0) = P[:, [vecnorm(P[:, i] - c) <= r for i = 1:size(P, 2)]]

X = indisk(Xtmp, cx)
Y = indisk(Ytmp, cy)

# PyPlot.figure()
# PyPlot.plot(X[1,:],X[2,:],"or");
# PyPlot.plot(Y[1,:],Y[2,:],"ob");

# Log kernel: element-wise log of SI.distance_2(x, y).
kernel = (x, y) -> log.(SI.distance_2(x, y))
# Dense kernel matrix between the two clouds; the stopping criterion and the
# SVD rank comparison are measured against it.
Ktrue = SI.meshKernelFull(kernel, X, Y)

# Axis-aligned boxes around each cloud. `kernelR` evaluates the kernel on
# reference coordinates (rx, ry) by first mapping them to physical points
# through SI.RtoX on the matching box.
cbb1 = SI.CubeBoundingBox{Float64}(X, axisaligned=true)
cbb2 = SI.CubeBoundingBox{Float64}(Y, axisaligned=true)
I1, I2 = SI.getIntervals(cbb1), SI.getIntervals(cbb2)
kernelR = function (rx, ry)
    xphys = SI.RtoX(cbb1, rx)
    yphys = SI.RtoX(cbb2, ry)
    return kernel(xphys, yphys)
end

# Candidate generators for SI.adaptive_si: each maps a count `n` to a
# (points, weights) pair with unit weights.
#
# Circle: `n` points from SI.generatePoints on each cloud's SphereBoundingBox.
b1sphere = SI.SphereBoundingBox{Float64}(X)
b2sphere = SI.SphereBoundingBox{Float64}(Y)
XbarCircle = function (n)
    pts = SI.generatePoints(b1sphere, n)
    return (pts, ones(n))
end
YbarCircle = function (n)
    pts = SI.generatePoints(b2sphere, n)
    return (pts, ones(n))
end

# Mdv: the columns of the cloud picked by SI.get_mdv(cloud, n)
# (NOTE(review): selection rule lives in SI — confirm what "mdv" optimizes).
XbarMdv = function (n)
    cols = SI.get_mdv(X, n)
    return (X[:, cols], ones(n))
end
YbarMdv = function (n)
    cols = SI.get_mdv(Y, n)
    return (Y[:, cols], ones(n))
end
        
# Chebyshev candidates for SI.adaptive_si, shared by XbarCheb and YbarCheb.
# SI.getInterpolationOrder picks an (x, y) order pair for target size `n`;
# `pick` (:x or :y) chooses which component to use on `interval`. Returns the
# interpolator nodes (`xk`) mapped to physical space through `box`, together
# with the interpolator's `w_intk` weights (presumably quadrature weights).
function _chebnodes(n, interval, box, pick)
    (ordx, ordy) = SI.getInterpolationOrder(kernelR, I1, I2, n, logLevel=SI.none)
    order  = pick == :x ? ordx : ordy
    interp = SI.TensorInterpolator{Float64}(interval, order, SI.chebyshev)
    nodes  = interp.xk
    return (SI.RtoX(box, nodes), interp.w_intk)
end

# Chebyshev candidates for the X cloud (reference interval I1, box cbb1).
XbarCheb(n) = _chebnodes(n, I1, cbb1, :x)
# Chebyshev candidates for the Y cloud (reference interval I2, box cbb2).
YbarCheb(n) = _chebnodes(n, I2, cbb2, :y)

# Random candidates: `n` columns drawn without replacement from each cloud.
# Each sampler uses its own cloud's column count. The two disks generally
# hold different numbers of points, so sampling Y with X's count would either
# index past the end of Y or never pick Y's trailing columns.
N  = size(X, 2)
NY = size(Y, 2)
XbarRnd = n -> (X[:, randperm(N)[1:n]], ones(n))
YbarRnd = n -> (Y[:, randperm(NY)[1:n]], ones(n))

# Per-tolerance rank counts filled in by the experiment loop
# (r0* = columns of Xbar, r1* = columns of Xhat returned by adaptive_si).
r0Circle = Int[]
r0Mdv    = Int[]
r1Circle = Int[]
r1Mdv    = Int[]
r0Cheb   = Int[]
r1Cheb   = Int[]
r1ACA    = Int[]
# Random-sampling results are means over repeated trials, hence Float64.
r0Rnd    = Float64[]
r1Rnd    = Float64[]

# Skeleton points from the first tolerance, kept for inspection after the run.
XhatMDV = nothing
YhatMDV = nothing
XhatC   = nothing
YhatC   = nothing
XhatCheb = nothing
YhatCheb = nothing

# Tolerances to sweep (currently the single value 1e-14).
tols = logspace(-14, -14, 1)

# Sample points of the unit circle, used to outline both source disks.
tt = linspace(0, 2*pi, 100)

# Draw the unit circles around cx ("--r") and cy ("--b") on the current figure.
function plotdisks()
    PyPlot.plot(cx[1]+cos.(tt), cx[2]+sin.(tt), "--r")
    PyPlot.plot(cy[1]+cos.(tt), cy[2]+sin.(tt), "--b")
    return nothing
end

# Mean of `v`, or NaN when `v` is empty. The random-sampling trials below are
# disabled, so their lists stay empty; calling `mean` on an empty untyped
# array throws and aborted the script before any rank plot was drawn.
meanornan(v) = isempty(v) ? NaN : mean(v)

for (i, tol) = enumerate(tols)
    # Stopping rule for adaptive_si: judges the candidate (Xhat, Yhat) against
    # Ktrue at `tol` via SI.accuracy_stopping_criterion (r0, r1 are unused).
    stop = (r0, r1, Xhat, Yhat) -> SI.accuracy_stopping_criterion(kernel, X, Y, Xhat, Yhat, Ktrue, tol)
    # Shared adaptive_si configuration for every candidate strategy:
    # internal tolerance 0.1*tol, rinit=4, rmax=1000, rinc r -> ceil(1.1 r).
    runsi = (Xs, Ys) -> SI.adaptive_si(kernel, Xs, Ys, 0.1*tol, rinit=4, rmax=1000,
                                       rinc=x->Int(ceil(x*1.1)), loglevel=SI.debug,
                                       stopping_criterion=stop)

    # ---- Circle candidates ----
    (Xhat, Yhat, Xbar, Ybar) = runsi(XbarCircle, YbarCircle)
    push!(r0Circle, size(Xbar, 2))
    push!(r1Circle, size(Xhat, 2))
    if i == 1
        PyPlot.figure(1)
        plotdisks()
        PyPlot.plot(Xhat[1,:],Xhat[2,:],"*b",label="Circle");
        PyPlot.plot(Yhat[1,:],Yhat[2,:],"*r",label="Circle");

        PyPlot.figure(2)
        plotdisks()
        PyPlot.plot(Xbar[1,:],Xbar[2,:],"*b",label="Circle");
        PyPlot.plot(Ybar[1,:],Ybar[2,:],"*r",label="Circle");

        writedlm("circle_mdv_vs_circle.dat", hcat(Xhat', Yhat'))
        writedlm("circle_mdv_vs_circle_0.dat", hcat(Xbar', Ybar'))

        XhatC = Xhat
        YhatC = Yhat
    end

    # ---- Mdv candidates ----
    (Xhat, Yhat, Xbar, Ybar) = runsi(XbarMdv, YbarMdv)
    push!(r0Mdv, size(Xbar, 2))
    push!(r1Mdv, size(Xhat, 2))
    if i == 1
        PyPlot.figure(1)
        PyPlot.plot(Xhat[1,:],Xhat[2,:],"dk",label="Mdv");
        PyPlot.plot(Yhat[1,:],Yhat[2,:],"dg",label="Mdv");
        PyPlot.figure(2)
        PyPlot.plot(Xbar[1,:],Xbar[2,:],"dk",label="Mdv");
        PyPlot.plot(Ybar[1,:],Ybar[2,:],"dg",label="Mdv");

        writedlm("mdv_mdv_vs_circle.dat", hcat(Xhat', Yhat'))
        writedlm("mdv_mdv_vs_circle_0.dat", hcat(Xbar', Ybar'))

        @show "====== MDV ====="
        @show hcat(Xhat', Yhat')
        @show hcat(Xbar', Ybar')

        XhatMDV = Xhat
        YhatMDV = Yhat
    end

    # ---- Chebyshev candidates ----
    (Xhat, Yhat, Xbar, Ybar) = runsi(XbarCheb, YbarCheb)
    push!(r0Cheb, size(Xbar, 2))
    push!(r1Cheb, size(Xhat, 2))
    # if i == 1
    #     PyPlot.figure(1)
    #     PyPlot.plot(Xhat[1,:],Xhat[2,:],"om",label="Cheb");
    #     PyPlot.plot(Yhat[1,:],Yhat[2,:],"oc",label="Cheb");
    # end
    if i == 1
        XhatCheb = Xhat
        YhatCheb = Yhat
    end

    # ---- Random candidates (trials currently disabled) ----
    r0rndTmp = Int[]
    r1rndTmp = Int[]
    # for repeat = 1:10
    #     (Xhat, Yhat, Xbar, Ybar) = SI.adaptive_si(kernel, XbarRnd, YbarRnd, 0.1*tol, rinit=4, rmax=1000, rinc=x->Int(ceil(x*1.1)), loglevel=SI.debug, stopping_criterion=stop)
    #     push!(r0rndTmp, size(Xbar)[2])
    #     push!(r1rndTmp, size(Xhat)[2])
    #     if i == 1
    #         PyPlot.figure(1)
    #         PyPlot.plot(Xhat[1,:],Xhat[2,:],"om",label="Rand");
    #         PyPlot.plot(Yhat[1,:],Yhat[2,:],"oc",label="Rand");
    #         PyPlot.legend()
    #     end
    # end
    push!(r0Rnd, meanornan(r0rndTmp))
    push!(r1Rnd, meanornan(r1rndTmp))

    #(U, V) = pp_aca(kernel, X, Y, tol)
    #push!(r1ACA, size(U)[2])
end

# Reference rank of Ktrue per tolerance from SI.rank_eps_fro (presumably the
# Frobenius-norm epsilon-rank from its SVD — confirm); plotted as "SVD".
r01svd = SI.rank_eps_fro(Ktrue, tols, ll=SI.info)

# Figure 3 compares the r0 counts and figure 4 the r1 counts of each strategy,
# each against the SVD reference, on a log-scaled tolerance axis.
rankfigures = (
    (3, ((r0Circle, "r0 Circle"), (r0Mdv, "r0 Mdv"), (r0Cheb, "r0 Cheb"),
         (r0Rnd, "r0 Rand"), (r01svd, "SVD"))),
    (4, ((r1Circle, "r1 Circle"), (r1Mdv, "r1 Mdv"), (r1Cheb, "r1 Cheb"),
         (r1Rnd, "r1 Rand"), (r01svd, "SVD")))
)
for (fignum, curves) in rankfigures
    PyPlot.figure(fignum)
    for (ranks, lbl) in curves
        PyPlot.semilogx(tols, ranks, label=lbl)
    end
    PyPlot.legend()
end
