# 2d plates test
include("../../src/SI.jl")
include("../../benchmarks/geometries.jl")
using PyPlot
using LowRankApprox
# Start from a clean state: close every open figure and seed the global RNG so runs
# are reproducible.
PyPlot.close("all")
srand(0)

# Obtain the two point sets X, Y and the boxes b1, b2 (drawn below with
# SI.plotBoundingBox) from the shared benchmark geometries.
(X, Y, b1, b2) = getGeometry(10, 20, coef=-0.9)

# Figure 1: both point clouds (X in red, Y in blue) plus their bounding boxes.
# Points are drawn in 2-D when both boxes are 2-D, otherwise in 3-D.
@printf "Plotting\n"
PyPlot.figure(1)
planar = (b1.dim_x == 2) && (b2.dim_x == 2)
for (pts, style) in ((X, "or"), (Y, "ob"))
    if planar
        PyPlot.plot(pts[1,:], pts[2,:], style, markersize=0.5)
    else
        PyPlot.plot3D(pts[1,:], pts[2,:], pts[3,:], style, markersize=0.5)
    end
end
for box in (b1, b2)
    SI.plotBoundingBox(box, 1)
end

# Build stretching boxes
# Good value : ref 1.3, tolInterp tol^-0.6, tolQR = 0.1 for coef = 0.0
#              ref 1.5, tolInterp 1.0,      tolQR = 0.1 for coef = -0.9
ref = 1.5
# Rational stretching maps on the unit interval. For ref outside [0, 1] the
# denominators never vanish on [0, 1], both maps fix 0 and 1, and they are mutual
# inverses: fFwd(fBwd(r)) == r. All operators are dotted so the maps work
# elementwise on arrays as well as on scalars (undotted `array - scalar` is an
# error after Julia 0.6).
fFwd = x -> (1.0 - ref) .* x ./ (x .- ref)
fBwd = r -> ref .* r ./ (r .+ ref .- 1.0)
# No stretching
# b1st = SI.StretchingBox{Float64}([x->x, x->x], [x->x, x->x])
# b2st = SI.StretchingBox{Float64}([x->x, x->x], [x->x, x->x])
# Some stretching
# Box 1 gets fFwd in its first list and fBwd in its second; box 2 gets the opposite
# pairing. Presumably these are the per-dimension forward/backward maps of
# SI.StretchingBox -- confirm against SI.
# NOTE(review): each list has two entries; if getGeometry returns 3-D boxes
# (dim_x == 3, see the plotting branch) check whether three are required.
b1st = SI.StretchingBox{Float64}([fFwd, fFwd], [fBwd, fBwd])
b2st = SI.StretchingBox{Float64}([fBwd, fBwd], [fFwd, fFwd])

# Physical -> stretched reference coordinates: box map first, then stretching.
XtoR1 = x -> SI.XtoR(b1st, SI.XtoR(b1, x))
XtoR2 = x -> SI.XtoR(b2st, SI.XtoR(b2, x))
# Stretched reference -> physical coordinates: the same maps in reverse order.
RtoX1 = r -> SI.RtoX(b1, SI.RtoX(b1st, r))
RtoX2 = r -> SI.RtoX(b2, SI.RtoX(b2st, r))

# Kernel 1 / distance_2(x, y) on physical points. `1 ./` avoids the ambiguous `1./`
# token (a syntax error after Julia 0.6). kernelB evaluates the same kernel on
# stretched reference points by mapping them back to physical space first.
kernel = (x,y) -> 1 ./ SI.distance_2(x,y)
kernelB = (x,y) -> kernel(RtoX1(x),RtoX2(y))
I1 = SI.getIntervals(b1)
I2 = SI.getIntervals(b2)

# Tolerances swept by the loop below: 1e-4, 1e-5, ..., 1e-10.
tols = logspace(-4, -10, 7)

# One slot per tolerance, filled inside the sweep.
nTols = length(tols)
errsr1    = zeros(nTols)   # relative Frobenius error of the first approximation
errsr2    = zeros(nTols)   # relative Frobenius error after recompression
ranksSIr0 = zeros(nTols)   # min(#Xbar points, #Ybar points)
ranksSIr1 = zeros(nTols)   # rank of the first approximation
ranksSIr2 = zeros(nTols)   # rank after recompression
ranksRRQR = zeros(nTols)   # rank found by pivoted QR at the same tolerance

# Placeholders; the sweep overwrites them with the last iteration's point sets.
Xhat = Yhat = Xbar = Ybar = 0

# Dense reference matrix and the per-tolerance SVD ranks reported by SI.rank_eps_fro.
Atrue = SI.meshKernelFull(kernel, X, Y)
ranksSVD = SI.rank_eps_fro(Atrue, tols)


# Exponent p in tolInterp = tol^(-p). p = 0.0 gives tolInterp = 1.0; p = 0.6 is the
# alternative recorded in the "Good value" notes for coef = 0.0.
tolInterpExp = 0.0

# For each tolerance: build the SI low-rank approximation in stretched reference
# coordinates, compare it and its recompression against the dense matrix, and
# record errors and ranks.
for (it, tol) in enumerate(tols)
    # The point sets are kept (from the last iteration) for plotting after the loop.
    # Declaring them global makes the script-level assignment explicit instead of
    # relying on Julia 0.6 soft-scope rules for top-level loops.
    global Xhat, Yhat, Xbar, Ybar

    (U, F, V, Xhatr, Yhatr, Xbarr, Ybarr, U2, S2, V2) = SI.fastMeshKernelLowRankApprox(
        kernelB, I1, I2, XtoR1(X), XtoR2(Y), tol,
        ruletype=SI.chebyshev, tolInterp=tol^(-tolInterpExp), tolQR=0.1,
        logLevel=SI.debug)

    # Map the returned point sets from reference back to physical coordinates.
    Xhat = RtoX1(Xhatr)
    Yhat = RtoX2(Yhatr)
    Xbar = RtoX1(Xbarr)
    Ybar = RtoX2(Ybarr)

    # First approximation: A ~ U * (F \ V').
    Aapp = U * (F \ V')
    rankSI = minimum(size(U))
    errtol = vecnorm(Aapp - Atrue) / vecnorm(Atrue)

    # Reference rank: pivoted QR of the dense matrix at the same relative tolerance.
    PQRtol = pqrfact(Atrue, rtol=tol)
    rankRRQR = size(PQRtol[:R], 1)

    # Recompressed approximation: A ~ U2 * S2 * V2'.
    Aapp2 = U2 * S2 * V2'
    err2tol = vecnorm(Aapp2 - Atrue) / vecnorm(Atrue)
    nKept = size(U2, 2)

    @printf "Error %e for rank %d (actual %e SVD-rank is %d, RRQR-rank is %d)\n" errtol rankSI tol ranksSVD[it] rankRRQR
    @printf "Recompression has error %e and rank %d\n" err2tol nKept

    errsr1[it] = errtol
    errsr2[it] = err2tol
    ranksSIr0[it] = min(size(Xbarr, 2), size(Ybarr, 2))
    ranksSIr1[it] = rankSI
    ranksSIr2[it] = nKept
    ranksRRQR[it] = rankRRQR
end

# Figure 1: overlay the last iteration's Xhat (red) / Yhat (blue) on the geometry.
PyPlot.figure(1)
if b1.dim_x == 2 && b2.dim_x == 2
    PyPlot.plot(Xhat[1,:], Xhat[2,:], "*r")
    PyPlot.plot(Yhat[1,:], Yhat[2,:], "*b")
else
    PyPlot.plot3D(Xhat[1,:], Xhat[2,:], Xhat[3,:], "*r")
    PyPlot.plot3D(Yhat[1,:], Yhat[2,:], Yhat[3,:], "*b")
end

# Figure 2: the last iteration's Xbar (red) / Ybar (blue) with both bounding boxes.
# Uses the same 2-D/3-D branch and draw order as figure 1 (points, then boxes), so
# the third coordinate is not silently dropped for 3-D geometries.
PyPlot.figure(2)
if b1.dim_x == 2 && b2.dim_x == 2
    PyPlot.plot(Xbar[1,:], Xbar[2,:], "dr")
    PyPlot.plot(Ybar[1,:], Ybar[2,:], "db")
else
    PyPlot.plot3D(Xbar[1,:], Xbar[2,:], Xbar[3,:], "dr")
    PyPlot.plot3D(Ybar[1,:], Ybar[2,:], Ybar[3,:], "db")
end
SI.plotBoundingBox(b1, 2)
SI.plotBoundingBox(b2, 2)

# Figure 3: relative Frobenius errors versus the requested tolerance.
PyPlot.figure(3)
PyPlot.loglog(tols, tols, label="Tol")
PyPlot.loglog(tols, errsr1, label="Frob. error r1")
PyPlot.loglog(tols, errsr2, ".-", label="Frob. error r2")
PyPlot.legend()

# Figure 4: ranks versus tolerance, including the RRQR ranks recorded in the sweep.
PyPlot.figure(4)
PyPlot.semilogx(tols, ranksSVD, label="SVD")
PyPlot.semilogx(tols, ranksRRQR, ":", label="RRQR")
PyPlot.semilogx(tols, ranksSIr0, "-.", label="r0")
PyPlot.semilogx(tols, ranksSIr1, ".-", label="r1")
PyPlot.semilogx(tols, ranksSIr2, "--", label="r2")
PyPlot.legend()
