using Iterators

# Tensor-product interpolator built from any 1d rule (chebyshev, legendre or uniform nodes)
# over the box [a1 ; b1] x [a2 ; b2] x ... x [adim ; bdim]
type TensorInterpolator{T} <: Interpolator

    # Inputs
    dim::Int64
    intervals::Array{Interval{T},1}  # Interval in every dimension, [dim]
    orders::Array{Int64,1}           # Order in every dimension, [dim]
                                     # Order starts at 1 (1 = 1 point = degree 0, 2 = 2 points = degree 1, etc)
    ruletype::RULETYPE               # 1d rule used in every dimension (chebyshev, legendre or uniform)

    # Tensor 1d-informations
    x::Array{Array{T,1},1}           # 1d nodes, [dim] x [K(dim)]
    w::Array{Array{T,1},1}           # 1d barycentric weights (used by evalBasis), [dim] x [K(dim)]
    w_int::Array{Array{T,1},1}       # 1d integration weights, [dim] x [K(dim)]

    # Accumulated informations
    xk::Array{T,2}                   # the interpolation points [dim x K]
    w_intk::Array{T,1}               # the integration *relative* weights [K] (product of the 1d w_int)
    orderIdx::Array{Int64,2}         # the id in order space of every interpolation point [dim x K]

    # Adaptive constructor.
    #
    # Build a tensor interpolator over `intervals` for the functions `funs`. Several functions can be
    # considered at once, so `funs` is a list of functions; each must map a [dim x M] matrix of points
    # to a vector of M values. The error is evaluated over the point set X ([dim x M]) with
    # errorL2AppTrue, and the tolerance `tol` applies to that error (L2-relative over all the functions).
    # `ruletype` is chebyshev, legendre or uniform.
    #
    # Greedy strategy: start from orders = [1, ..., 1]. At each step, try increasing each dimension's
    # order by one and keep the increase with the smallest error. Stop once that error is <= tol, or
    # give up (with a message) after 49 refinement steps. Every candidate is built with the static
    # constructor below, so there is quite a bit of overhead.
    function TensorInterpolator{T}(intervals::Array{Interval{T},1}, funs, X::Array{T, 2}, tol::T, ruletype::RULETYPE) where T

        @assert size(X)[1] == length(intervals)
        @assert tol > 0
        dim = length(intervals)
        orders = ones(Int64, dim)

        nf = length(funs)

        # The exact values on X do not depend on the candidate orders: evaluate them only once.
        ftrue = Array{Array{T, 1}, 1}(nf)
        for fid = 1:nf
            ftrue[fid] = funs[fid](X)
        end
        fapp = Array{Array{T, 1}, 1}(nf)

        itid = 1
        while true

            err = Array{T, 1}(dim)
            for d = 1:dim
                # Candidate: increase orders[d] by 1
                newOrders = copy(orders)
                newOrders[d] += 1
                # Explicit {T}: only the parametric inner constructors are defined here
                newInterp = TensorInterpolator{T}(intervals, newOrders, ruletype)
                for fid = 1:nf
                    fapp[fid] = interpolate(newInterp, X, funs[fid])
                end
                err[d] = errorL2AppTrue(fapp, ftrue)

                @printf "%d. Order %s has error %e\n" itid string(newOrders) err[d]
            end

            bestD = indmin(err)
            orders[bestD] += 1
            @printf "%d. Best order is now %s with error %e\n" itid string(orders) err[bestD]
            itid += 1

            if err[bestD] <= tol
                @printf "Converged\n"
                break
            end
            if itid >= 50
                info("!!!! Too many iterations. Aborting...\n")
                break
            end

        end

        return TensorInterpolator{T}(intervals, orders, ruletype)

    end

    # Static constructor: build the tensor interpolator with the given per-dimension `orders`
    # (number of 1d nodes in every dimension) using the 1d rule `ruletype`.
    # Throws ArgumentError for a ruletype other than chebyshev, legendre or uniform.
    function TensorInterpolator{T}(intervals::Array{Interval{T},1}, orders::Array{Int64,1}, ruletype::RULETYPE) where T
        this = new()
        this.intervals = intervals
        this.dim = length(intervals)
        @assert this.dim == length(orders)
        this.orders = orders
        this.ruletype = ruletype   # previously `==` (a no-op comparison), which left the field unset

        # Pick the 1d node generator once. All of them are called as (n, a, b) and return
        # (number of nodes, barycentric weights, nodes, integration weights).
        if ruletype == chebyshev
            getNodes = getChebyshevNodes
        elseif ruletype == legendre
            getNodes = getLegendreNodes
        elseif ruletype == uniform
            getNodes = getUniformNodes
        else
            throw(ArgumentError("unsupported ruletype $(ruletype)"))
        end

        # Compute the 1d rules
        this.w = Array{Array{T,1},1}(this.dim)
        this.w_int = Array{Array{T,1},1}(this.dim)
        this.x = Array{Array{T,1},1}(this.dim)
        nnodes = Array{Int64, 1}(this.dim)
        for d = 1:this.dim
            (nd, this.w[d], this.x[d], this.w_int[d]) = getNodes(this.orders[d], intervals[d].a, intervals[d].b)
            @assert nd == this.orders[d]
            @assert length(this.w[d]) == length(this.x[d])
            @assert length(this.x[d]) == nd
            @assert length(this.w_int[d]) == nd
            nnodes[d] = nd
        end

        # Prepare the tensor approx: one column per multi-index of the tensor grid
        K = prod(nnodes)
        this.orderIdx = Array{Int64, 2}(this.dim, K)
        this.xk = Array{T, 2}(this.dim, K)
        this.w_intk = Array{T, 1}(K)
        dimIter = [1:n for n in nnodes]
        for (k, nodeid) in enumerate(product(dimIter...))
            this.w_intk[k] = one(T)
            for d = 1:this.dim
                this.orderIdx[d,k] = nodeid[d]
                this.xk[d,k] = this.x[d][nodeid[d]]
                this.w_intk[k] *= this.w_int[d][nodeid[d]]
            end
        end

        return this
    end

end

# X is [dim x M]
# output is [K x M]
# Evaluate the tensor-product Lagrange basis of `ti` at the points X with the barycentric formula.
#
# X is [dim x M]; each column is one evaluation point.
# Output is [K x M]: out[k, i] is the basis function attached to the interpolation point
# ti.xk[:, k], evaluated at X[:, i]. It is the product over dimensions d of the 1d basis
#     l_o(x) = (w_o / (x - x_o)) / sum_j (w_j / (x - x_j)),   with o = ti.orderIdx[d, k].
# When X[d, i] coincides exactly with a 1d node, the 1d basis is set to the matching unit vector
# to avoid a division by zero. The 1d nodes of every dimension are assumed distinct (asserted).
function evalBasis(ti::TensorInterpolator{T}, X) where T

    @assert size(X)[1] == ti.dim
    @assert length(size(X)) == 2

    nptseval = size(X)[2]   # M
    nnodes = size(ti.xk)[2] # K

    out = ones(T, (nnodes, nptseval))

    # 1d barycentric terms w_j / (x - x_j) in every dimension, and their sums
    wi_x_xi = Array{Array{T,2},1}(ti.dim)     # [dim] x [K(dim) x M]
    sum_wi_x_xi = Array{Array{T,1},1}(ti.dim) # [dim] x [M]
    for d = 1:ti.dim
        wi_x_xi[d] = Array{T,2}((length(ti.x[d]),nptseval))
        for i = 1:nptseval
            ids = find(x -> x == X[d,i], ti.x[d])
            @assert length(ids) <= 1
            if length(ids) == 1
                # Exactly on a node: the 1d basis is the indicator of that node
                wi_x_xi[d][:,i] .= 0.0
                wi_x_xi[d][ids,i] .= 1.0
            else
                wi_x_xi[d][:,i] .= ti.w[d] ./ ( X[d,i] .- ti.x[d] )
            end
        end
        sum_wi_x_xi[d] = sum(wi_x_xi[d],1)[:]
    end

    # Combine: basis k is the product over d of the 1d basis indexed by orderIdx[d,k].
    # Scalar loop (column-major friendly, no row copies); the factors are applied in the
    # same d order and as (out * w) / s, exactly as in the row-slice formulation.
    for i = 1:nptseval, k = 1:nnodes, d = 1:ti.dim
        o = ti.orderIdx[d,k]
        out[k,i] = out[k,i] * wi_x_xi[d][o,i] / sum_wi_x_xi[d][i]
    end

    @assert ! any(isnan, out)

    return out

end

# Interpolate the function f at the points X ([dim x M]); returns the [M] interpolated values.
# f is sampled once at the K interpolation points of ti, so it must map a [dim x K] matrix
# to a vector of K values.
function interpolate(ti::TensorInterpolator{T}, X, f) where T

    @assert size(X)[1] == ti.dim
    @assert length(size(X)) == 2

    samples = f(ti.xk)        # [dim x K] -> [K]
    basis = evalBasis(ti, X)  # [K x M]
    return basis' * samples   # [M x K] * [K] = [M]
end
