include("../src/SI.jl")
include("../benchmarks/aca.jl")
include("../benchmarks/geometries.jl")

using SI
using PyPlot
using JSON

# Scale applied to each target tolerance before it is handed to SI.adaptive_si.
coef = 0.1
# Target tolerances 1e-10, 1e-9, ..., 1e-1. Written with broadcasting rather than
# `logspace` (removed in Julia 1.0) so the line works on every Julia version.
tols = 10.0 .^ (-10:-1)

# Per-tolerance results filled by the sweep below. Typed (not `[]`, i.e. Vector{Any})
# so downstream plotting/serialization gets concrete numeric vectors.
r0 = Int[]          # size(Xbar, 2) returned by SI.adaptive_si
r1 = Int[]          # size(Xhat, 2) returned by SI.adaptive_si
errCheb = Float64[] # SI error of the Chebyshev-based adaptive SI approximation
errAca  = Float64[] # SI error of ACA restricted to the SI candidate points

# Sweep over `tols`, comparing pure ACA, adaptive SI on Chebyshev candidate points, and
# ACA restricted to those SI candidate points. Geometry, kernel, plots and the dense
# reference matrix do not depend on `tol`, so they are built once. The enclosing `let`
# keeps all helpers local (as the loop body did), so no new globals leak and the
# closures capture locals instead of untyped globals.
let
    npts = 20
    # Uniform grids on [0, 1]; only two are needed for the 2D tensor grid.
    rrt = Vector{Float64}[collect(linspace(0, 1, npts)) for d = 1:2]
    RRt2d = SI.tensor_grid(rrt)

    # 2 & 1 \\ 1 & 2
    # Each panel is the 2D grid mapped into 3D; the 2x2 identity factor is a no-op.
    # The Y panels are the X panels shifted by +2 along the first coordinate.
    X2 = broadcast(+, [0, 0, 0.1], [0 1 ; 0 0 ; 1 0] * [1 0 ; 0 1] * RRt2d) # y = 0 side
    Y1 = broadcast(+, [2, 0, 0.1], [0 1 ; 0 0 ; 1 0] * [1 0 ; 0 1] * RRt2d)
    n1 = [0.0, -1.0, 0.0]
    X1 = broadcast(+, [0, 0.1, 0], [0 1 ; 1 0 ; 0 0] * [1 0 ; 0 1] * RRt2d) # z = 0 side
    Y2 = broadcast(+, [2, 0.1, 0], [0 1 ; 1 0 ; 0 0] * [1 0 ; 0 1] * RRt2d)
    n2 = [0.0, 0.0, -1.0]

    X = hcat(X1, X2)
    Y = hcat(Y1, Y2)

    # Kernel K(x, y) = (x - y)⋅n_y / d^3 with d = SI.distance_2(x, y)[1] (assumed to be
    # the Euclidean distance — confirm). Evaluated column by column; if X or Y has a
    # single column it is paired with every column of the other.
    # The normal n_y depends on which Y panel y lies on: Y1 has y[2] == 0 exactly and
    # Y2 has y[3] == 0 exactly (see the offsets above).
    function kernel_fun(X, Y)
        nx = size(X, 2)
        ny = size(Y, 2)
        @assert nx == ny || nx == 1 || ny == 1
        n = max(nx, ny)
        Z = zeros(n)
        for i = 1:n
            x = nx == 1 ? X[:, 1] : X[:, i]
            y = ny == 1 ? Y[:, 1] : Y[:, i]
            # Separate name from `n` so the loop bound stays an Int (type-stable).
            if y[2] == 0.0     # Y1
                nrm = n1
            elseif y[3] == 0.0 # Y2
                nrm = n2
            else
                # `@assert false || "msg"` raised a TypeError instead of this message.
                error("Y[2/3] should be 0.0")
            end
            d = SI.distance_2(x, y)[1]
            Z[i] = dot(x - y, nrm) / d^(3.0)
        end
        return Z
    end

    bx1 = SI.CubeBoundingBox{Float64}(X1, compress=true, axisaligned=true)
    bx2 = SI.CubeBoundingBox{Float64}(X2, compress=true, axisaligned=true)
    by1 = SI.CubeBoundingBox{Float64}(Y1, compress=true, axisaligned=true)
    by2 = SI.CubeBoundingBox{Float64}(Y2, compress=true, axisaligned=true)

    Ix1 = SI.getIntervals(bx1)
    Ix2 = SI.getIntervals(bx2)
    Iy1 = SI.getIntervals(by1)
    Iy2 = SI.getIntervals(by2)

    # Visual check of the point sets and their bounding boxes (drawn once, not per tol).
    PyPlot.figure(1)
    PyPlot.plot3D(X[1,:], X[2,:], X[3,:])
    PyPlot.plot3D(Y[1,:], Y[2,:], Y[3,:])
    SI.plotBoundingBox(bx1, 1)
    SI.plotBoundingBox(bx2, 1)
    SI.plotBoundingBox(by1, 1)
    SI.plotBoundingBox(by2, 1)

    # Dense reference kernel matrix, used by every error estimate below.
    Ktrue = SI.meshKernelFull(kernel_fun, X, Y)
    @show SI.rank_eps_fro(Ktrue, [1e-6])

    # Candidate points for adaptive SI: an o x o Chebyshev tensor grid in each of the two
    # bounding boxes of side `which` ('X', otherwise Y), i.e. 2*o^2 <= n points in total.
    # Returns the mapped points hcat'ed box by box and the interpolators' `w_intk`
    # arrays vcat'ed in the same order (presumably quadrature weights — confirm in SI).
    function BarCheb(n, which)
        o = isqrt(div(n, 2))
        if which == 'X'
            Is = [Ix1, Ix2]
            Bs = [bx1, bx2]
        else
            Is = [Iy1, Iy2]
            Bs = [by1, by2]
        end
        interp1 = SI.TensorInterpolator{Float64}(Is[1], [o, o], SI.chebyshev)
        interp2 = SI.TensorInterpolator{Float64}(Is[2], [o, o], SI.chebyshev)
        Xbar = hcat(SI.RtoX(Bs[1], interp1.xk), SI.RtoX(Bs[2], interp2.xk)) # [Xbar1, Xbar2]
        Wx   = vcat(interp1.w_intk, interp2.w_intk)                         # [Wx1  ; Wx2]
        return (Xbar, Wx)
    end
    XbarCheb = n -> BarCheb(n, 'X')
    YbarCheb = n -> BarCheb(n, 'Y')

    for tol in tols
        # Pure ACA on the full point sets.
        (U, V, Iaca, Jaca) = pp_aca(kernel_fun, X, Y, tol)
        errPureAca = SI.err_si_frob(kernel_fun, X, Y, X[:, Iaca], Y[:, Jaca], Ktrue)
        @printf "-> ACA  (%d) SI-error %e\n" length(Iaca) errPureAca

        # SI - RRQR. The first two stopping-criterion arguments are unused here; they
        # were previously named r0/r1, shadowing the global result vectors.
        stop = (rk0, rk1, Xh, Yh) ->
            SI.accuracy_stopping_criterion(kernel_fun, X, Y, Xh, Yh, Ktrue, tol)
        (Xhat, Yhat, Xbar, Ybar) = SI.adaptive_si(kernel_fun, XbarCheb, YbarCheb, coef * tol,
                                                  rinit=8, rmax=1000,
                                                  rinc=x->Int(ceil(x*1.1)),
                                                  loglevel=SI.debug,
                                                  stopping_criterion=stop)
        fixedRankR1 = size(Xhat, 2)
        Khat = SI.meshKernelFull(kernel_fun, Xhat, Yhat)
        errSIcheb = SI.err_si_frob(kernel_fun, X, Y, Xhat, Yhat, Ktrue)
        @printf "Using Rank %d\n" fixedRankR1
        # Reuse Khat instead of rebuilding the identical matrix.
        @printf "SVD Rank Kbar is %d\n" SI.rank_eps_fro(Khat, [tol])[1]
        @printf "-> Cheb (%d) SI-error %e\n" fixedRankR1 errSIcheb

        # SI-ACA: ACA restricted to the SI candidate points, at the rank SI selected.
        (U, V, Iaca, Jaca) = pp_aca(kernel_fun, Xbar, Ybar, tol, fixedRank = fixedRankR1)
        errSIaca = SI.err_si_frob(kernel_fun, X, Y, Xbar[:, Iaca], Ybar[:, Jaca], Ktrue)
        @printf "-> ACA  (%d) SI-error %e\n" length(Iaca) errSIaca

        push!(r0, size(Xbar, 2))
        push!(r1, fixedRankR1)
        push!(errCheb, errSIcheb)
        push!(errAca,  errSIaca)
    end
end
