include("../../src/SI.jl")
include("../../benchmarks/geometries.jl")
using PyPlot
close("all")

# Two point clouds stored column-wise (N = number of columns) plus the geometry
# objects b1, b2 returned by the benchmark helper.
(X, Y, b1, b2) = getGeometry(10,30,coef=-0.75)
N = size(X,2)
# Kernel 1 / SI.distance_2(x, y). Written as `1 ./ d` with explicit spaces: the
# juxtaposed form `1./d` is ambiguous (`1. / d` vs `1 ./ d`) and is rejected as a
# syntax error by Julia >= 0.7. `1 ./ d` is what Julia 0.6 parsed it as.
kernel = (x, y) -> 1 ./ SI.distance_2(x, y)
# Same kernel, with its arguments first mapped through SI.RtoX using b1 / b2.
kernelR = (rx, ry) -> kernel(SI.RtoX(b1,rx),SI.RtoX(b2,ry))
tol = 1e-6      # accuracy passed to the stopping criterion and reported below
tolQR = 1e-7    # tolerance forwarded positionally to SI.adaptive_si
# Reference kernel evaluation on X x Y (presumably the full dense matrix - confirm
# against SI.meshKernelFull); consumed by the stopping criterion.
Ktrue = SI.meshKernelFull(kernel, X, Y)
# Stopping criterion handed to SI.adaptive_si. r0 and r1 are accepted but unused;
# the decision is delegated to SI.accuracy_stopping_criterion with Ktrue and tol.
stop = (r0, r1, Xhat, Yhat) -> SI.accuracy_stopping_criterion(kernel, X, Y, Xhat, Yhat, Ktrue, tol)

#
# Random sampling case (candidates Xbar = X, Ybar = Y; random monitoring points)
#

# Candidate points: all of X and Y with unit weights (the requested n is ignored)
Xbar = n -> (X, ones(N))
Ybar = n -> (Y, ones(N))
# Monitoring points: n distinct columns drawn at random from X and Y.
# NOTE(review): randperm(N)[1:n] throws a BoundsError if n > N.
Xmon = n -> (X[:,randperm(N)[1:n]], ones(n))
Ymon = n -> (Y[:,randperm(N)[1:n]], ones(n))
# Then call the adaptive SI (the positional tolQR precedes all keywords)
@time (Xhat, Yhat) = SI.adaptive_si(kernel, Xbar, Ybar, tolQR;
                                    Xmon=Xmon, Ymon=Ymon,
                                    rinit=10, rmax=1000, rinc=x->x*2,
                                    loglevel=SI.debug, stopping_criterion=stop)

# Plot: both clouds as small dots, selected points as stars (red = X, blue = Y)
figure()
plot(X[1,:], X[2,:], "or", markersize=0.5)
plot(Y[1,:], Y[2,:], "ob", markersize=0.5)
plot(Xhat[1,:], Xhat[2,:], "*r")
plot(Yhat[1,:], Yhat[2,:], "*b")
title("Random Sampling")
# Check accuracy
err = SI.err_si_frob(kernel, X, Y, Xhat, Yhat)
@printf "==> Random Sampling Error for tol %e is %e with rank r1 %d\n" tol err size(Xhat,2)

#
# Perfect RRQR case
#

# Monitor with the full candidate sets themselves: Xbar / Ybar (all of X and Y,
# as defined in the random-sampling section) are passed as Xmon / Ymon too.
@time (Xhat, Yhat) = SI.adaptive_si(kernel, Xbar, Ybar, tolQR;
    Xmon=Xbar, Ymon=Ybar,
    rinit=10, rmax=1000, rinc=x->x*2,
    loglevel=SI.debug, stopping_criterion=stop)

# Plot: input clouds first (small dots), then the selected points (stars)
figure()
for (P, fmt) in ((X, "or"), (Y, "ob"))
    plot(P[1,:], P[2,:], fmt, markersize=0.5)
end
for (P, fmt) in ((Xhat, "*r"), (Yhat, "*b"))
    plot(P[1,:], P[2,:], fmt)
end
title("Ideal RRQR")
# Report the error of the selected points
err = SI.err_si_frob(kernel, X, Y, Xhat, Yhat)
@printf("==> Ideal RRQR Error for tol %e is %e with rank r1 %d\n", tol, err, size(Xhat,2))

#
# Full random sampling case (candidates Xbar, Ybar are random subsets of X, Y)
#

# Candidate points: n distinct columns drawn at random from X and Y, unit weights.
# No Xmon / Ymon is passed here; monitoring is left to whatever SI.adaptive_si
# does by default (not visible from this script).
Xbar = n -> (X[:, randperm(N)[1:n]], ones(n))
Ybar = n -> (Y[:, randperm(N)[1:n]], ones(n))

# Run the adaptive SI
@time (Xhat, Yhat) = SI.adaptive_si(kernel, Xbar, Ybar, tolQR;
    rinit=10, rmax=1000, rinc=x->x*2,
    loglevel=SI.debug, stopping_criterion=stop)

# Plot: input clouds first (small dots), then the selected points (stars)
figure()
for (P, fmt) in ((X, "or"), (Y, "ob"))
    plot(P[1,:], P[2,:], fmt, markersize=0.5)
end
for (P, fmt) in ((Xhat, "*r"), (Yhat, "*b"))
    plot(P[1,:], P[2,:], fmt)
end
title("Full Random Sampling")
# Report the error of the selected points
err = SI.err_si_frob(kernel, X, Y, Xhat, Yhat)
@printf("==> Full Random Sampling Error for tol %e is %e with rank r1 %d\n", tol, err, size(Xhat,2))

#
# Chebyshev-SI with interpolation and weights
#
I1 = SI.getIntervals(b1)
I2 = SI.getIntervals(b2)
# Warm-up call with logging off. Its results are overwritten by the timed call
# below; it only exists so that @time does not also measure JIT compilation.
(U1, F1, V1, Xhatr, Yhatr, Xbarr, Ybarr, U2, S2, V2) = SI.fastMeshKernelLowRankApprox(
    kernelR, I1, I2, SI.XtoR(b1,X), SI.XtoR(b2,Y), tol;
    ruletype=SI.chebyshev, tolInterp=1.0, tolQR=0.1, logLevel=SI.none)
# Timed run, same arguments, info-level logging
@time (U1, F1, V1, Xhatr, Yhatr, Xbarr, Ybarr, U2, S2, V2) = SI.fastMeshKernelLowRankApprox(
    kernelR, I1, I2, SI.XtoR(b1,X), SI.XtoR(b2,Y), tol;
    ruletype=SI.chebyshev, tolInterp=1.0, tolQR=0.1, logLevel=SI.info)
# Map the selected points back through SI.RtoX before measuring the error
Xhat = SI.RtoX(b1, Xhatr)
Yhat = SI.RtoX(b2, Yhatr)
err = SI.err_si_frob(kernel, X, Y, Xhat, Yhat)
@printf("==> Usual Fast Cheb-SI Error for tol %e is %e with rank r1 %d\n", tol, err, size(Xhat,2))

#
# MDV case
#

# Candidate points: all of X and Y with unit weights (the requested n is ignored)
Xbar = n -> (X, ones(N))
Ybar = n -> (Y, ones(N))
# Monitoring points: the columns of X / Y indexed by SI.get_mdv(., n)
# (assumes get_mdv returns n column indices, since the weights are ones(n) - confirm)
Xmon = n -> (X[:, SI.get_mdv(X, n)], ones(n))
Ymon = n -> (Y[:, SI.get_mdv(Y, n)], ones(n))
# Run the adaptive SI with these monitoring sets
@time (Xhat, Yhat) = SI.adaptive_si(kernel, Xbar, Ybar, tolQR;
    Xmon=Xmon, Ymon=Ymon,
    rinit=10, rmax=1000, rinc=x->x*2,
    loglevel=SI.debug, stopping_criterion=stop)

# Plot: input clouds first (small dots), then the selected points (stars)
figure()
for (P, fmt) in ((X, "or"), (Y, "ob"))
    plot(P[1,:], P[2,:], fmt, markersize=0.5)
end
for (P, fmt) in ((Xhat, "*r"), (Yhat, "*b"))
    plot(P[1,:], P[2,:], fmt)
end
title("MDV")
# Report the error of the selected points
err = SI.err_si_frob(kernel, X, Y, Xhat, Yhat)
@printf("==> MDV Error for tol %e is %e with rank r1 %d\n", tol, err, size(Xhat,2))
