include("../src/SI.jl")
include("../benchmarks/aca.jl")
include("../benchmarks/geometries.jl")

using SI
using PyPlot
using LowRankApprox

# Result accumulators for the adaptive-SI experiment. Within this file they
# are only filled by the commented-out loop at the bottom.
r0      = Any[]
r1      = Any[]
errCheb = Any[]
errAca  = Any[]

# Sample geometry: X holds points on the unit circle, Y the points at the same
# angles on an ellipse with semi-axes l1 and l2.
n = 1000 # e(-r^2) with 200, 1:h 2h:3h ... and pi/4 shift gives weird stuff
# Alternative layout with four arcs instead of two:
#h = div(n,8)
#starts = [1, 2*h, 4*h, 6*h]
h = div(n, 8)
starts = [1, 4*h]
# Each arc covers 51 consecutive grid indices; `.+` is the explicit
# element-wise form of the vector-plus-scalar shift.
ends   = starts .+ 50
theta = linspace(0, 2*pi, n+1)[vcat([s:e for (s, e) in zip(starts, ends)]...)]
theta = theta'   # row layout, so the vcat calls below stack coordinates as rows
X = vcat(cos.(theta), sin.(theta))
rho = 2.0
beta = 1.8
# Written with explicit spacing: `1.-1/rho` juxtaposes a float literal and a
# dotted operator, which is an ambiguous parse.
l1 = beta * (1 - 1/rho)
l2 = beta * (1/rho)
Y = vcat(l1*cos.(theta), l2*sin.(theta))

# Start from a clean plotting state and draw both point sets in the first
# panel of a 2x2 subplot grid on figure 1.
PyPlot.close("all")

PyPlot.figure(1)
PyPlot.subplot(221)
PyPlot.plot(X[1,:], X[2,:], "*")   # row 1 / row 2 of X are the two coordinates
PyPlot.plot(Y[1,:], Y[2,:], "d")
PyPlot.axis("equal")
PyPlot.title("K = 1/r^3")

# Kernel under test. Spelled `1 ./ ...` because `1./` juxtaposes a numeric
# literal and a dotted operator, which is an ambiguous parse.
# NOTE(review): SI.distance_2 is presumably the distance r of the title —
# confirm in SI.
K = (x, y) -> 1 ./ SI.distance_2(x, y).^3
# K = (x, y) -> exp(-SI.distance_2(x, y).^2/0.2.^2) extreme case

# Kernel evaluated through SI.meshKernelFull on X and Y; passed to
# SI.err_si_frob by the sweeps below.
Ktrue = SI.meshKernelFull(K, X, Y)

# Sweep 1 ("rank free"): for every tolerance, run pp_aca and SI.cur2 with that
# tolerance and record how many indices each selected and the resulting error.
r_aca  = []
e_aca  = []
r_rrqr = []
e_rrqr = []

tols = logspace(-10, -1, 10)

for tol in tols
    @printf "--- %e\n" tol

    # pp_aca returns two factors first; only the index sets are used here.
    _, _, ix_aca, iy_aca = pp_aca(K, X, Y, tol)
    aca_err = SI.err_si_frob(K, X, Y, X[:, ix_aca], Y[:, iy_aca], Ktrue)
    @printf "-> ACA  (%d) error %e\n" length(ix_aca) aca_err

    # NOTE(review): Ktrue is passed twice — confirm against SI.cur2's signature.
    ix_cur, iy_cur = SI.cur2(Ktrue, Ktrue, tol)
    cur_err = SI.err_si_frob(K, X, Y, X[:, ix_cur], Y[:, iy_cur], Ktrue)
    @printf "-> CUR (%d) error %e\n" length(ix_cur) cur_err

    # Record rank and error per method.
    push!(r_aca, length(ix_aca))
    push!(e_aca, aca_err)
    push!(r_rrqr, length(ix_cur))
    push!(e_rrqr, cur_err)
end

# Panel 223: error reached by each method as a function of the tolerance.
PyPlot.subplot(223)
for errs in (e_aca, e_rrqr)
    PyPlot.loglog(tols, errs, "*-")
end
PyPlot.legend(("ACA", "RRQR"))
PyPlot.ylabel("Error(tol), rank free")
PyPlot.xlabel("Tol")

# Panel 224: number of selected indices as a function of the tolerance.
PyPlot.subplot(224)
for ranks in (r_aca, r_rrqr)
    PyPlot.semilogx(tols, ranks, "*-")
end
PyPlot.legend(("ACA", "RRQR"))
PyPlot.ylabel("Rank(tol), rank free")
PyPlot.xlabel("Tol")

# Sweep 2 (matched rank): `cur` chooses the rank from the tolerance, then
# pp_aca is run with `fixedRank` set to that same rank so both errors are
# compared at equal rank. The error lists from sweep 1 are reset here, after
# they have been plotted.
r_aca_rrqr = []
e_aca = []
e_rrqr = []

# Iterate the same tolerance grid as sweep 1 instead of rebuilding it, so the
# two sweeps cannot drift apart.
for tol in tols
    @printf "--- %e\n" tol
    #(px, py) = SI.cur2(Ktrue, Ktrue, tol)
    # NOTE(review): unqualified `cur` presumably comes from LowRankApprox
    # (loaded at the top of the file) — confirm.
    (px, py) = cur(Ktrue, rtol=tol, maxdet_tol=1e-6, sketch=:none)
    err_rrqr = SI.err_si_frob(K, X, Y, X[:,px], Y[:,py], Ktrue)
    @printf "-> CUR (%d) error %e\n" length(px) err_rrqr
    (U,V,Iaca,Jaca) = pp_aca(K, X, Y, tol, fixedRank=length(px))
    err_aca = SI.err_si_frob(K, X, Y, X[:,Iaca], Y[:,Jaca], Ktrue)
    @printf "-> ACA  (%d) error %e\n" length(Iaca) err_aca
    # The shared rank axis below is only meaningful if ACA honoured fixedRank.
    @assert length(Iaca) == length(px) "ACA rank $(length(Iaca)) != CUR rank $(length(px))"
    push!(r_aca_rrqr, length(Iaca))
    push!(e_aca, err_aca)
    push!(e_rrqr, err_rrqr)
end

# Panel 222: error of each method against the shared rank from sweep 2.
PyPlot.subplot(222)
for errs in (e_aca, e_rrqr)
    PyPlot.semilogy(r_aca_rrqr, errs, "*-")
end
PyPlot.legend(("ACA", "RRQR"))
PyPlot.ylabel("Error(rank)")
PyPlot.xlabel("Rank")


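# NOTE(review): everything below is a disabled experiment kept for reference.
# It references `coef`, which is not defined anywhere in this file — confirm
# where it should come from before re-enabling.
# for tol in tols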
#    
#     npts = 20
#     rrt = Array{Array{Float64, 1}, 1}(3)
#     for d = 1:3
#         rrt[d] = linspace(0, 1, npts)
#     end
#     RRt2d = SI.tensor_grid(rrt[1:2])
# 
#     # 2 & 1 \\ 1 & 2
#     X2 = broadcast(+, [0, 0, 0.1], [0 1 ; 0 0 ; 1 0] * [1 0 ; 0 1] * RRt2d) # y = 0 side
#     Y1 = broadcast(+, [2, 0, 0.1], [0 1 ; 0 0 ; 1 0] * [1 0 ; 0 1] * RRt2d)
#     n1 = [0.0, -1.0, 0.0]
#     X1 = broadcast(+, [0, 0.1, 0], [0 1 ; 1 0 ; 0 0] * [1 0 ; 0 1] * RRt2d) # z = 0 side
#     Y2 = broadcast(+, [2, 0.1, 0], [0 1 ; 1 0 ; 0 0] * [1 0 ; 0 1] * RRt2d)
#     n2 = [0.0, 0.0, -1.0]
# 
#     X = hcat(X1, X2)
#     Y = hcat(Y1, Y2)
# 
#     function kernel_fun(X, Y)
#         nx = size(X,2)
#         ny = size(Y,2)
#         @assert nx == ny || nx == 1 || ny == 1
#         n = max(nx, ny)
#         Z = zeros(n)
#         for i = 1:n
#             if nx == 1
#                 x = X[:,1]
#             else
#                 x = X[:,i]
#             end
#             if ny == 1
#                 y = Y[:,1]
#             else
#                 y = Y[:,i]
#             end
#             if y[2] == 0.0 # Y1
#                 n = n1
#             elseif y[3] == 0.0 
#                 n = n2
#             else
#                 @assert false || "Y[2/3] should be 0.0"
#             end
#             d = SI.distance_2(x, y)[1]
#             Z[i] = dot(x-y, n) / d^(3.0)
#         end
#         return Z
#     end
# 
#     bx1 = SI.CubeBoundingBox{Float64}(X1, compress=true, axisaligned=true)
#     bx2 = SI.CubeBoundingBox{Float64}(X2, compress=true, axisaligned=true)
#     by1 = SI.CubeBoundingBox{Float64}(Y1, compress=true, axisaligned=true)
#     by2 = SI.CubeBoundingBox{Float64}(Y2, compress=true, axisaligned=true)
#     
#     Ix1 = SI.getIntervals(bx1)
#     Ix2 = SI.getIntervals(bx2)
#     Iy1 = SI.getIntervals(by1)
#     Iy2 = SI.getIntervals(by2)
# 
#     PyPlot.figure(1)
#     PyPlot.plot3D(X[1,:],X[2,:],X[3,:])
#     PyPlot.plot3D(Y[1,:],Y[2,:],Y[3,:])
#     SI.plotBoundingBox(bx1, 1)
#     SI.plotBoundingBox(bx2, 1)
#     SI.plotBoundingBox(by1, 1)
#     SI.plotBoundingBox(by2, 1)
# 
#     Ktrue = SI.meshKernelFull(kernel_fun, X, Y)
# 
#     @show SI.rank_eps_fro(Ktrue, [1e-6])
# 
#     # Pure ACA
#     (U,V,Iaca,Jaca) = pp_aca(kernel_fun, X, Y, tol)
#     errSIaca = SI.err_si_frob(kernel_fun, X, Y, X[:,Iaca], Y[:,Jaca], Ktrue)
#     @printf "-> ACA  (%d) SI-error %e\n" length(Iaca) errSIaca
#     
#     function BarCheb(n, which)
#         o = Int64(floor(sqrt(div(n,2))))
#         if which == 'X'
#             Is = [Ix1, Ix2]
#             Bs = [bx1, bx2]
#         else
#             Is = [Iy1, Iy2]
#             Bs = [by1, by2]
#         end
#         interp1 = SI.TensorInterpolator{Float64}(Is[1], [o, o], SI.chebyshev)
#         interp2 = SI.TensorInterpolator{Float64}(Is[2], [o, o], SI.chebyshev)
#         Xbar = hcat(SI.RtoX(Bs[1], interp1.xk),    SI.RtoX(Bs[2], interp2.xk)   ) # [Xbar1, Xbar2]
#         Wx   = vcat(               interp1.w_intk,                interp2.w_intk) # [Wx1  ; Wx2]
#         return (Xbar, Wx)
#     end
#     XbarCheb = n -> BarCheb(n, 'X')
#     YbarCheb = n -> BarCheb(n, 'Y')
#    
#     # SI - RRQR
#     stop = (r0, r1, Xhat, Yhat) -> SI.accuracy_stopping_criterion(kernel_fun, X, Y, Xhat, Yhat, Ktrue, tol)
#     (Xhat, Yhat, Xbar, Ybar) = SI.adaptive_si(kernel_fun, XbarCheb, YbarCheb, coef * tol, rinit=8, rmax=1000, rinc=x->Int(ceil(x*1.1)), loglevel=SI.debug, stopping_criterion=stop)
#     fixedRankR0 = size(Xbar,2)
#     fixedRankR1 = size(Xhat,2)
#     Khat = SI.meshKernelFull(kernel_fun, Xhat, Yhat)
#     errSIcheb = SI.err_si_frob(kernel_fun, X, Y, Xhat, Yhat, Ktrue)
#     @printf "Using Rank %d\n" fixedRankR1
#     @printf "SVD Rank Kbar is %d\n" SI.rank_eps_fro(SI.meshKernelFull(kernel_fun, Xhat, Yhat), [tol])[1]
#     @printf "-> Cheb (%d) SI-error %e\n" size(Xhat,2) errSIcheb
#      
#     # SI-ACA
#     (U,V,Iaca,Jaca) = pp_aca(kernel_fun, Xbar, Ybar, tol, fixedRank = fixedRankR1)
#     errSIaca = SI.err_si_frob(kernel_fun, X, Y, Xbar[:,Iaca], Ybar[:,Jaca], Ktrue)
#     @printf "-> ACA  (%d) SI-error %e\n" length(Iaca) errSIaca
# 
#     push!(r0, size(Xbar, 2))
#     push!(r1, size(Xhat, 2))
#     push!(errCheb, errSIcheb)
#     push!(errAca,  errSIaca)
#     
# end
