include("../src/SI.jl")
include("../benchmarks/geometries.jl")
using PyPlot
using Iterators
using Interpolations
using Polynomials
# Fix the global RNG seed; presumably for reproducibility of any randomness used
# inside the included SI code -- nothing in main() below calls rand directly.
srand(0)

"""
    _columnInterpolants(M, pts, nLag)

Return a vector of `nLag` closures; the `i`-th one evaluates a cubic B-spline
interpolant (`BSpline(Cubic(Line()))`, `OnGrid()`) of the column `M[:, i]`, whose
entries are taken as samples at the equally spaced points
`linspace(minimum(pts), maximum(pts), length(pts))`.
"""
function _columnInterpolants(M, pts, nLag)
    grid  = linspace(minimum(pts), maximum(pts), length(pts))
    funcs = Function[]
    for i = 1:nLag
        itp = scale(interpolate(M[:, i], BSpline(Cubic(Line())), OnGrid()), grid)
        push!(funcs, t -> itp[t])
    end
    return funcs
end

"""
    _tabulate(funcs, pts, nLag)

Return the `nLag × nLag` matrix `T` with `T[i, j] = funcs[j](pts[i])`.
"""
function _tabulate(funcs, pts, nLag)
    T = zeros(nLag, nLag)
    for j = 1:nLag, i = 1:nLag
        T[i, j] = funcs[j](pts[i])
    end
    return T
end

"""
    main()

Compare the points `Xhat`/`Yhat` selected by `SI.meshKernelLowRankAdaptive` for the
1D kernel `1/(4 + x - y)`, sampled on `n` uniform points of `[-1, 1)`, with the
singular vectors of the full kernel matrix (called "eigenfunctions" below).

Prints diagnostics to STDOUT, plots the last eigenfunction together with its
interpolants through `Xhat`, and writes `Uend_X.dat`, `Uend_Xhat.dat` and
`errShatui.dat` to the current directory.
"""
function main()

    kernel = (x, y) -> 1 ./ (4 .+ broadcast(-, x[1,:], y[1,:]))
    #kernel = (x, y) -> sin.(x[1,:]*pi).*cos.(y[1,:]*pi) + exp.(x[1,:]).*(y[1,:]+2)

    n   = 1000
    b1  = SI.CubeBoundingBox{Float64}([-1.0],[2.0])
    b2  = SI.CubeBoundingBox{Float64}([-1.0],[2.0])
    # n equally spaced points on [-1, 1) (spacing 2/n), stored as 1×n matrices
    X   = reshape(Array(linspace(-1,1,n+1)[1:end-1]), (1, n))
    Y   = reshape(Array(linspace(-1,1,n+1)[1:end-1]), (1, n))
    # NOTE(review): dS = 1/n is used as the quadrature weight below although the
    # grid spacing is 2/n; this rescales U1, V1 and s by constant factors. Confirm.
    dS  = 1.0/n
    tol = 1e-10

    # Adaptive low-rank compression; the approximation K ≈ U * (F \ V') is checked below
    (U, F, V, Xhat, Yhat, Uuc, Fuc, Vuc, I1, I2, diags) =
        SI.meshKernelLowRankAdaptive(kernel, b1, b2, X, Y, tol,
                                     ruletype=SI.chebyshev, useSI=true, extraDiags=true,
                                     logLevel=SI.debug, faster=true)
    nLag = size(F)[1]

    # Kernel restricted to the selected points
    Khat = SI.meshKernelFull(kernel, Xhat, Yhat)
    Kx   = SI.meshKernelFull(kernel, X, Yhat)

    # Relative Frobenius error of the low-rank approximation
    K   = SI.meshKernelFull(kernel, X, Y)
    err = vecnorm(U * (F \ V') - K) / vecnorm(K)
    @printf "Error %e\n" err

    # "True" eigenfunctions: singular vectors of K, rescaled by the weight dS
    (U1, s, V1) = svd(K)
    U1 = U1 / sqrt(dS)
    V1 = V1 / sqrt(dS)
    s  = s * dS
    uTrue = _columnInterpolants(U1, X[1,:], nLag)
    vTrue = _columnInterpolants(V1, Y[1,:], nLag)

    # Eigenfunctions at the selected points (row i = point i, column j = function j);
    # assumes size(Xhat, 2) >= nLag and size(Yhat, 2) >= nLag -- TODO confirm.
    uTrueXhat = _tabulate(uTrue, Xhat[1,:], nLag)
    vTrueYhat = _tabulate(vTrue, Yhat[1,:], nLag)

    # for i = 1:nLag
    #     figure();
    #     plot(Xhat[1,:], uTrueXhat[:,i], "*")
    #     plot(X[1,:], U1[:,i])
    #     title((@sprintf "%dth eigenfunction u_i(x) and u_i(Xhat)" i))
    # end
    # for i = 1:nLag
    #     figure();
    #     plot(Yhat[1,:], vTrueYhat[:,i], "*")
    #     plot(Y[1,:], V1[:,i])
    #     title((@sprintf "%dth eigenfunction v_i(y) and v_i(Yhat)" i))
    # end

    @show vTrueYhat' * (Khat \ uTrueXhat)
    show(IOContext(STDOUT), "text/plain", vTrueYhat' * (Khat \ uTrueXhat))
    println()
    @show s[1:nLag]
    # svd returns s in decreasing order, so sInv = [1/s[1], ..., 1/s[nLag]] and the
    # diagonal of Ms is 1/s[i].
    sInv = sort(1 ./ s[1:nLag])
    @show sInv
    Ms = diagm(sInv)
    for i = 1:nLag
        for j = 1:nLag
            if i != j
                # NOTE(review): reversed indices nLag-i+1, nLag-j+1, unlike the
                # diagonal and the alternative 1.0 / (sInv[i] * sInv[j]); confirm.
                Ms[i,j] = sqrt(s[nLag-i+1] * s[nLag-j+1]) # 1.0 / (sInv[i] * sInv[j])
            end
        end
    end
    show(IOContext(STDOUT), "text/plain", Ms)
    println()
    @show Ms

    # Lagrange-type basis Shat = K(X, Yhat) * K(Xhat, Yhat)^{-1}
    Shat = Kx / Khat
    for i = 1:min(nLag,4)
        # i-th SI basis function on X
        Shati = Shat[:,i]
        # Minimal-degree polynomial Lagrange basis function through Xhat
        oneShot = zeros(size(Xhat[1,:]))
        oneShot[i] = 1.0
        Spolyi  = polyfit(Xhat[1,:], oneShot)
        SpolyiX = polyval(Spolyi, X[1,:])
        # Plot
        if i == 4
            # PyPlot.figure();
            # PyPlot.plot(X[1,:], Shati);
            # PyPlot.plot(X[1,:], SpolyiX);
            # PyPlot.plot(Xhat[1,:], oneShot, "*");
            # data1 = hcat(X[1,:], Shati, SpolyiX)
            # data2 = hcat(Xhat[1,:], oneShot)
            # writedlm("Lag4_X.dat", data1);
            # writedlm("Lag4_Xhat.dat", data2);
        end
    end

    # Kernel at the nodes I1.xk / I2.xk (mapped by SI.RtoX), weighted symmetrically
    # by sqrt.(w_int[1]). NOTE(review): the SVD results are not used further here.
    Xbar = SI.RtoX(b1, I1.xk)
    Wx   = I1.w_int[1]
    Ybar = SI.RtoX(b2, I2.xk)
    Wy   = I2.w_int[1]
    Kbar = SI.meshKernelFull(kernel, Xbar, Ybar)
    Kwbar = diagm(sqrt.(Wx)) * Kbar * diagm(sqrt.(Wy))
    (Ubar,sbar,Vbar) = svd(Kwbar)

    # Reference: classical Chebyshev interpolation with as many nodes as Xhat
    interp = SI.TensorInterpolator{Float64}([SI.Interval{Float64}(-1.0,1.0)], [length(Xhat[1,:])], SI.chebyshev)
    Xchebred = interp.xk
    @assert size(Xchebred)[1] == 1
    @assert size(Xchebred)[2] == size(Xhat)[2]
    Schebred = SI.evalBasis(interp, X)'

    errShatui = Float64[]
    for i = 1:nLag
        # Can Shat[x, xhat] approximate u_i well, i.e. is
        # Shat[x, xhat] * u_i(xhat) ≈ u_i(x) for all x ?
        u1hat  = Float64[uTrue[i](x) for x in Xhat[1,:]]
        Shatu1 = Shat * u1hat
        # Same function interpolated at the classical Chebyshev nodes
        u1cheb     = Float64[uTrue[i](x) for x in Xchebred[1,:]]
        Schebredu1 = Schebred * u1cheb
        # And by the minimal-degree polynomial through (Xhat, u_i(Xhat))
        ShatLowu1 = polyval(polyfit(Xhat[1,:], u1hat), X[1,:])

        # Every curve approximates the same function (column i of U1), so all are
        # oriented with ONE common sign taken from u1. Flipping each curve by its
        # own sign at its own sample could hide or fabricate errors, and a zero
        # sample would make sign() return 0 (NaN/Inf after division).
        iRef    = min(10, size(U1, 1))
        signRef = U1[iRef, i] < 0 ? -1.0 : 1.0
        u1         = signRef * U1[:,i]
        Shatu1     = signRef * Shatu1
        Schebredu1 = signRef * Schebredu1
        ShatLowu1  = signRef * ShatLowu1

        # Relative max-norm error of the SI interpolant (used by the plot title)
        errShatu1u1 = maximum(abs.(Shatu1 - u1)) / maximum(abs.(u1))

        # figure() ;
        # plot(X[1,:],u1,label=(@sprintf "%dth eigenfunction in X" i),"-")
        # plot(X[1,:],Shatu1,label=(@sprintf "SI-based interpolant through Xhat"),"--")
        # plot(X[1,:],ShatLowu1,label=(@sprintf "Minimal degree poly interpolant through Xhat"),"--")
        # plot(X[1,:],Schebredu1,label=(@sprintf "Minimal degree classical poly interpolant"),"-.")
        # plot(Xhat[1,:],signRef * u1hat,"b*", label="Xhat")
        # plot(Xchebred[1,:],signRef * u1cheb,"r*",label="Classical Chebyshev nodes")
        # title(string(i, "th x-eigenfunction interpolation, err ", errShatu1u1))
        # xlabel("x")
        # legend()

        if i == nLag
            PyPlot.figure()
            PyPlot.plot(X[1,:], u1)
            PyPlot.plot(X[1,:], Shatu1)
            PyPlot.plot(X[1,:], ShatLowu1)
            PyPlot.plot(Xhat[1,:], signRef * u1hat, "*")
            writedlm("Uend_X.dat", hcat(X[1,:], u1, Shatu1, ShatLowu1))
            writedlm("Uend_Xhat.dat", hcat(Xhat[1,:], signRef * u1hat))
        end

        # Absolute max-norm error of the SI interpolant of u_i
        push!(errShatui, maximum(abs.(u1 - Shatu1)))
    end

    # Columns: i, absolute error for u_i, and errShatui[1] rescaled by s[1]/s[i]
    data = hcat(1:nLag, errShatui, s[1] ./ s[1:nLag] * errShatui[1])
    writedlm("errShatui.dat", data)
end

# Entry point: evaluating this script (running or include-ing it) executes the study.
main()
