# Array size policy:
# if one vector only, use array of size (dim,)
# if multiple vectors, use array of size (dim,N)
# if input is a vector, output is expected to be a vector : (dim,) -> (dim,)
# if input is a matrix, output is expected to be a matrix : (dim,N) -> (dim,N)
#                                                 or a vector (dim,N) -> (N,)
# Basically, when possible, dim is the first dimension and there are no "singleton" dimensions.

#
# BOUNDING BOX
#

# Map a single point X (dim,) to R by treating it as a (dim, 1) matrix and
# flattening the result back into a vector.
function XtoR{T}(bb::BoundingBox, X::Array{T, 1})
    Xmat = reshape(X, length(X), 1)
    return XtoR(bb, Xmat)[:]
end

# Map a single point R (dim,) to X by treating it as a (dim, 1) matrix and
# flattening the result back into a vector.
function RtoX{T}(bb::BoundingBox, R::Array{T, 1})
    Rmat = reshape(R, length(R), 1)
    return RtoX(bb, Rmat)[:]
end

#
# BOUNDING CUBE
#

type CubeBoundingBox{T} <: BoundingBox

    dim_x::Integer       # Dimensions (in X)
    dim_r::Integer       # Dimensions (in R)
    origin::Array{T,1}   # Origin of the box (in X)
    basis::Array{T,2}    # Basis, ie the axis of the box: M*(R) -> (X)
                         #   columns are vectors in X
    starts::Array{T,1}   # Starts (in R) - 0
    ends::Array{T,1}     # Ends (in R)   - 1
    lengths::Array{T,1}  # Lengths (in X), one per column of basis

    # Fit an oriented box around a cloud of points.
    #
    # points      : (dim, N) array, one point per column
    # compress    : if true, axes along which the data has (numerically) zero
    #               extent are removed, so that dim_r <= dim_x
    # axisaligned : if true, no PCA is done (basis = identity)
    #
    # The resulting box maps R in [0,1]^dim_r to X via
    #   X = origin + basis * (lengths .* R)
    function CubeBoundingBox{T}(points::Array{T,2} ; compress=false, axisaligned=false) where T

        this = new()
        this.dim_x = size(points, 1)

        # Center the data
        center = mean(points, 2)[:]
        points_centered = broadcast(-, points, center)

        # Principal axes. A *full* SVD is needed: with fewer points than
        # dimensions a thin SVD returns a non-square U, so the basis would not
        # span X and the dim_x-sized origin/lengths below would not match it.
        if axisaligned
            @printf "No alignment with the data. The basis is aligned with the axis.\n"
            this.basis = eye(T, this.dim_x, this.dim_x)
        else
            this.basis = svdfact(points_centered, thin=false)[:U]
        end

        # Express the centered points in the box frame. Per axis, the minimum
        # gives the origin and (max - min) gives the length of the box.
        points_centered_rot = this.basis' * points_centered
        origin_centered_rot = vec(minimum(points_centered_rot, 2))
        this.origin = this.basis * origin_centered_rot + center
        this.lengths = vec(maximum(points_centered_rot, 2)) - origin_centered_rot

        # Compress if that's required: drop every axis whose extent is
        # negligible (relative 10*eps(T)) compared to the largest extent
        if compress
            dims_to_keep = this.lengths .>= 10 * eps(T) * maximum(abs.(this.lengths))
            this.lengths = this.lengths[dims_to_keep]
            this.dim_r = length(this.lengths)
            @printf "Compression box from %d to %d dimensions.\n" this.dim_x this.dim_r
            this.basis = this.basis[:,dims_to_keep]
        else
            this.dim_r = this.dim_x
        end
        this.starts = zeros(T, this.dim_r)
        this.ends = ones(T, this.dim_r)

        return this

    end

    # Builds a simple rectangle in nd, aligned with the axis.
    # origin  : lower corner (in X)
    # lengths : extent along each coordinate axis
    function CubeBoundingBox{T}(origin::Array{T,1}, lengths::Array{T,1}) where T
        @assert length(origin) == length(lengths)
        this = new()
        dim = length(origin)
        this.dim_x = dim
        this.dim_r = dim
        this.origin = origin
        this.basis = eye(T, (dim, dim)...)
        this.starts = zeros(T, dim)
        this.ends = ones(T, dim)
        this.lengths = lengths
        return this
    end

end

# Build a tensor grid with npts points in every dimension
# Build a tensor grid over the box: npts[d] evenly spaced samples of the R
# coordinate d between starts[d] and ends[d], combined with tensor_grid and
# mapped back to X.
function generatePoints{T}(cbb::CubeBoundingBox{T}, npts::Array{Int64,1})
    @assert length(npts) == cbb.dim_r
    # One 1D sample vector per dimension of R
    axes1d = Array{T,1}[Array(linspace(cbb.starts[d], cbb.ends[d], npts[d])) for d = 1:cbb.dim_r]
    return RtoX(cbb, tensor_grid(axes1d))
end

# Generate at most n points on a tensor grid over the box.
#
# The number of points per dimension is chosen greedily. Start with 2 per
# dimension (the box corners), then repeatedly add one point to the dimension
# whose spacing (in X) is currently the largest. Stop as soon as one more
# refinement would exceed n points in total. Requires n >= 2^dim_r.
function generateNPoints{T}(cbb::CubeBoundingBox{T}, n)
    # 2 points per dimension, so spacing = length / (2 - 1) = length
    npts = fill(Int64(2), cbb.dim_r)
    spacing = cbb.lengths[1:cbb.dim_r]
    @assert prod(npts) <= n
    while true
        # Refine the dimension with the largest spacing
        dim = findmax(spacing)[2]
        npts_next = copy(npts)
        npts_next[dim] += 1
        if prod(npts_next) > n
            break
        end
        npts = npts_next
        spacing[dim] = cbb.lengths[dim] / (npts[dim] - 1)
    end
    return generatePoints(cbb, npts)
end

# Map points X (dim_x, N) to box coordinates R (dim_r, N).
# Steps: shift to the box origin, project onto the box axes, then scale each
# axis by its length so the box becomes [0,1]^dim_r.
function XtoR{T}(cbb::CubeBoundingBox{T}, X::Array{T,2})
    @assert size(X)[1] == cbb.dim_x
    shifted = X .- cbb.origin
    projected = cbb.basis' * shifted
    return projected ./ cbb.lengths
end

# Map box coordinates R (dim_r, N) back to points X (dim_x, N):
#   X = origin + basis * (lengths .* R)
function RtoX{T}(cbb::CubeBoundingBox{T}, R::Array{T,2})
    @assert size(R)[1] == cbb.dim_r
    scaled = R .* cbb.lengths
    return cbb.origin .+ cbb.basis * scaled
end

# One Interval per R dimension, spanning [starts[d], ends[d]].
function getIntervals{T}(cbb::CubeBoundingBox{T})
    return Interval{T}[Interval{T}(cbb.starts[d], cbb.ends[d]) for d = 1:cbb.dim_r]
end

using PyPlot

# Draw one black 3D segment between two points.
#
# corner1, corner2 : points; only coordinates 1:3 are used
# figid            : figure to draw in, or `nothing` to use the current figure
function plotEdge(corner1, corner2, figid)
    if figid !== nothing
        figure(figid)
    else
        gcf()
    end
    plot3D([corner1[1], corner2[1]], [corner1[2], corner2[2]], [corner1[3], corner2[3]], "-k")
end

# Draw the edges of a CubeBoundingBox: a red rectangle in 2D, a black box in 3D.
#
# cbb   : box to draw; dim_x must be 2 or 3. A compressed box (dim_r < dim_x)
#         is drawn flat, with zero extent along the removed axes.
# figid : figure to draw in, or `nothing` to use the current figure
function plotBoundingBox{T}(cbb::CubeBoundingBox{T},figid=nothing)

    if cbb.dim_x != 2 && cbb.dim_x != 3
        error("Not implemented yet")
    end

    # Pad basis/lengths up to dim_x axes with zeros. Then every corner can be
    # built the same way, even for compressed boxes (e.g. dim_r == 1), which
    # have fewer basis columns than dim_x.
    basis = zeros(T, cbb.dim_x, cbb.dim_x)
    basis[:, 1:cbb.dim_r] = cbb.basis
    lengths = zeros(T, cbb.dim_x)
    lengths[1:cbb.dim_r] = cbb.lengths

    # Corner of the box selected by a 0/1 mask over the axes
    corner(mask) = cbb.origin + basis * (mask .* lengths)

    if cbb.dim_x == 2
        c1 = corner([0, 0])
        c2 = corner([1, 0])
        c3 = corner([0, 1])
        c4 = corner([1, 1])
        if figid !== nothing
            figure(figid)
        else
            gcf()
        end
        for (p, q) in ((c1, c2), (c1, c3), (c2, c4), (c3, c4))
            PyPlot.plot([p[1], q[1]], [p[2], q[2]], "-r")
        end
    else
        c1 = corner([0, 0, 0])
        c2 = corner([1, 0, 0])
        c3 = corner([0, 1, 0])
        c4 = corner([0, 0, 1])
        c5 = corner([1, 1, 0])
        c6 = corner([0, 1, 1])
        c7 = corner([1, 0, 1])
        c8 = corner([1, 1, 1])
        # The 12 edges of the box
        edges = ((c1, c2), (c1, c3), (c1, c4),
                 (c5, c8), (c6, c8), (c7, c8),
                 (c2, c5), (c2, c7), (c4, c6),
                 (c4, c7), (c3, c5), (c3, c6))
        for (p, q) in edges
            plotEdge(p, q, figid)
        end
    end

    return nothing
end

#
# BOUNDING ELLIPSOID
#

type EllipsoidBoundingBox{T} <: BoundingBox

    dim_x::Integer        # Always 3 here
    dim_r::Integer        # 2 or 3: 2 when the radial coordinate was compressed away
    starts::Array{T, 1}   # length dim_r: lower bounds of (azimuth, polar angle[, scaled radius])
    ends::Array{T, 1}     # length dim_r: upper bounds of (azimuth, polar angle[, scaled radius])
    
    center::Array{T, 1}   # length 3
    lengths::Array{T, 1}  # length 3, per-axis scaling of the ellipsoid
    basis::Array{T, 2}    # size 3x3, axes of the ellipsoid (columns, in X)

    radius::T             # if dim_r = 2, stores the (scaled) radius
                          # otherwise, undefined
    tol::T                # compression tolerance (only set when compress = true)

    # Fit an ellipsoidal box around 3D points, given as a (3, N) array.
    #
    # center/lengths/basis come from ellipsoid3DFitting (defined elsewhere).
    # The points are then mapped with XtoR to (azimuth, polar angle, scaled
    # radius), and starts/ends are set to the range of each coordinate.
    #
    # compress : if true and the radius range is negligible compared to
    #            tol times the larger angular range, the radial coordinate is
    #            dropped and its mid value is stored in `radius` (dim_r = 2)
    # tol      : compression tolerance, stored in this.tol only when compress
    function EllipsoidBoundingBox{T}(points::Array{T,2}; compress=false, tol=eps(T)*100) where T
        
        @assert size(points)[1] == 3
        N = size(points)[2]

        # We get the ellipsoid
        this = new()
        this.dim_x = 3
        this.dim_r = 3
        (this.center, this.lengths, this.basis) = ellipsoid3DFitting(points)
        if compress
            this.tol = tol
        end

        # Try orientation #1: basis as returned by the fitting
        R = XtoR(this, points)
        a1 = [minimum(R[1,:]), maximum(R[1,:])]
        b1 = [minimum(R[2,:]), maximum(R[2,:])]
        r1 = [minimum(R[3,:]), maximum(R[3,:])]

        # Try orientation #2: negate the first two axes (a rotation by pi
        # about the third axis), which shifts the azimuth by pi
        this.basis[:,1] = - this.basis[:,1]
        this.basis[:,2] = - this.basis[:,2]
        R = XtoR(this, points)
        a2 = [minimum(R[1,:]), maximum(R[1,:])]
        b2 = [minimum(R[2,:]), maximum(R[2,:])]
        r2 = [minimum(R[3,:]), maximum(R[3,:])]

        # Sign flips of axes do not change the radial coordinate
        @assert r1 == r2

        # Keep the one that gives smallest range in alpha (azimuth)
        if a2[2] - a2[1] < a1[2] - a1[1]
            a = a2
            b = b2
            r = r2
        else
            a = a1
            b = b1
            r = r1
            # revert basis
            this.basis[:,1] = - this.basis[:,1]
            this.basis[:,2] = - this.basis[:,2]
        end

        @printf "Radius %e .. %e\n" r[1] r[2]
        if compress && abs.(r[1] - r[2]) < tol * max(abs.(a[2]-a[1]),abs.(b[1]-b[2]))

            # Points lie (within tolerance) on one shell: drop the radial
            # coordinate and keep only the angles
            this.radius = (r[1]+r[2])/2.0
            @printf "Compressing ellipsoid from 3 to 2. Radius is %e\n" this.radius
            this.starts = Array([a[1], b[1]]) 
            this.ends   = Array([a[2], b[2]])
            this.dim_r  = 2

        else

            # Widen the radius range slightly: relative 1e-8, plus an absolute
            # 1e-8 on each side if the range is nearly empty (presumably so
            # that the extreme points stay strictly inside the box)
            r[1] = r[1] / (1 + 1e-8)
            r[2] = r[2] * (1 + 1e-8)
            if abs.(r[2] - r[1]) < 1e-8
                r[1] = r[1] - 1e-8
                r[2] = r[2] + 1e-8
            end

            # Store start/ends
            this.starts = Array([a[1], b[1], r[1]])
            this.ends   = Array([a[2], b[2], r[2]])

        end

        # Return the constructed box
        this

    end

end

# Map points X (3, N) to ellipsoidal coordinates R (dim_r, N):
#   R[1,:] azimuth (atan2, in [-pi, pi])
#   R[2,:] polar angle (acos, in [0, pi])
#   R[3,:] scaled radius (only when dim_r == 3)
# When dim_r == 2 the points must lie, within tolerance, on the shell of radius ebb.radius.
function XtoR{T}(ebb::EllipsoidBoundingBox{T}, X::Array{T,2})
    @assert size(X)[1] == ebb.dim_x
    # Coordinates in the ellipsoid frame, scaled per axis by lengths
    xscaled = broadcast(/, ebb.basis' * broadcast(-, X, ebb.center), ebb.lengths)
    R = Array{T, 2}(ebb.dim_r, size(X)[2])
    if ebb.dim_r == 3
        R[3,:] = sqrt.(sum(xscaled.*xscaled,1))
    else
        @assert maximum(abs.(sqrt.(sum(xscaled.*xscaled,1)) - ebb.radius)) < ebb.tol * maximum(ebb.lengths)
    end
    R[1,:] = atan2.(xscaled[2,:],xscaled[1,:])
    # cos(polar) = z / r. In the compressed case r is the *average* radius,
    # so a point slightly outside the shell near a pole gives |z / r| > 1.
    # Clamp so that acos does not throw a DomainError.
    radii = ebb.dim_r == 3 ? R[3,:] : fill(ebb.radius, size(X)[2])
    R[2,:] = acos.(clamp.(xscaled[3,:] ./ radii, -one(T), one(T)))
    return R
end

# Map ellipsoidal coordinates R (dim_r, N) = (azimuth, polar[, radius]) back
# to points X (3, N). When dim_r == 2 the stored ebb.radius is used for every point.
function RtoX{T}(ebb::EllipsoidBoundingBox{T},R::Array{T,2})
    @assert size(R)[1] == ebb.dim_r
    nPts = size(R)[2]
    azimuth = R[1,:]
    polar = R[2,:]
    radii = ebb.dim_r == 3 ? R[3,:] : ebb.radius * ones(T, nPts)
    # Unit directions in the ellipsoid frame, stretched by the axis lengths
    # and rotated into X
    unitdir = vcat((cos.(azimuth).*sin.(polar))', (sin.(azimuth).*sin.(polar))', cos.(polar)')
    directions = ebb.basis * (ebb.lengths .* unitdir)
    return ebb.center .+ radii' .* directions
end

# Draw one curved edge of an EllipsoidBoundingBox: a curve in R along which a
# single coordinate varies, mapped to X and drawn in 3D.
#
# a_bnd, b_bnd, r_bnd : azimuth / polar / radius bounds. The first of r_bnd,
#                       a_bnd, b_bnd (in that order) with more than one entry
#                       is swept between its first two entries. The other two
#                       are held at their first entry.
# figid               : figure to draw in, or `nothing` for the current figure
# width               : line width
function plotCurvedEdge{T}(ebb::EllipsoidBoundingBox{T}, a_bnd, b_bnd, r_bnd, figid, width=1)
    nPts = 50
    # Samples are built with element type T so that RtoX{T} dispatches for
    # any T, not only Float64
    sweep(bnd) = collect(T, linspace(bnd[1], bnd[2], nPts))
    hold(bnd)  = fill(T(bnd[1]), nPts)
    if length(r_bnd) > 1
        Ra, Rb, Rr = hold(a_bnd), hold(b_bnd), sweep(r_bnd)
    elseif length(a_bnd) > 1
        Ra, Rb, Rr = sweep(a_bnd), hold(b_bnd), hold(r_bnd)
    elseif length(b_bnd) > 1
        Ra, Rb, Rr = hold(a_bnd), sweep(b_bnd), hold(r_bnd)
    else
        throw(ArgumentError("plotCurvedEdge: one of a_bnd, b_bnd, r_bnd must have two entries"))
    end
    R = permutedims(hcat(Ra, Rb, Rr), (2, 1))
    # A compressed ellipsoid (dim_r == 2) has no radial coordinate in R
    X = ebb.dim_r == 3 ? RtoX(ebb, R) : RtoX(ebb, R[1:2,:])
    if figid !== nothing
        figure(figid)
    else
        gcf()
    end
    plot3D(X[1,:],X[2,:],X[3,:],"-k",linewidth=width)
end

# Draw an EllipsoidBoundingBox in figure figid. This plots the 12 boundary
# edges (width 2) plus nInter interior lines of constant azimuth and of
# constant polar angle on both radial faces.
function plotBoundingBox{T}(ebb::EllipsoidBoundingBox{T}, figid)
    # Bounds in R. For a compressed ellipsoid the radius is fixed.
    a_bnd = [ebb.starts[1], ebb.ends[1]]
    b_bnd = [ebb.starts[2], ebb.ends[2]]
    r_bnd = ebb.dim_r == 3 ? [ebb.starts[3], ebb.ends[3]] : [ebb.radius, ebb.radius]
    lo(v) = v[1:1]
    hi(v) = v[2:2]

    # The 12 edges of the curved box
    edges = ((lo(a_bnd), lo(b_bnd), r_bnd),
             (hi(a_bnd), lo(b_bnd), r_bnd),
             (lo(a_bnd), hi(b_bnd), r_bnd),
             (hi(a_bnd), hi(b_bnd), r_bnd),
             (lo(a_bnd), b_bnd, lo(r_bnd)),
             (hi(a_bnd), b_bnd, lo(r_bnd)),
             (lo(a_bnd), b_bnd, hi(r_bnd)),
             (hi(a_bnd), b_bnd, hi(r_bnd)),
             (a_bnd, lo(b_bnd), lo(r_bnd)),
             (a_bnd, lo(b_bnd), hi(r_bnd)),
             (a_bnd, hi(b_bnd), lo(r_bnd)),
             (a_bnd, hi(b_bnd), hi(r_bnd)))
    for (a, b, r) in edges
        plotCurvedEdge(ebb, a, b, r, figid, 2)
    end

    # Interior grid lines, excluding the boundary values
    nInter = 10
    a_inter = linspace(a_bnd[1], a_bnd[2], 2+nInter)[2:end-1]
    b_inter = linspace(b_bnd[1], b_bnd[2], 2+nInter)[2:end-1]
    for i in 1:nInter
        for r in (lo(r_bnd), hi(r_bnd))
            plotCurvedEdge(ebb, a_inter[i:i], b_bnd, r, figid)
        end
        for r in (lo(r_bnd), hi(r_bnd))
            plotCurvedEdge(ebb, a_bnd, b_inter[i:i], r, figid)
        end
    end
end

#
# Stretching box (each dim separately)
# Stretches on [0,1] <-> [0,1]
# Depending on concavity, stretches towards one end or the other
#

type StretchingBox{T} <: BoundingBox

    dim       # number of dimensions (one map per dimension)
    fun       # fun[i]    : used by XtoR on row i (X -> R)
    funInv    # funInv[i] : used by RtoX on row i (R -> X), presumably the inverse of fun[i]

    # Each map takes a vector of N coordinates and returns N coordinates:
    #   fun[i] : [N] -> [N],  funInv[i] : [N] -> [N]
    # Every map must send the endpoints 0 and 1 to themselves (checked below).
    function StretchingBox{T}(fun, funInv) where T
        dim = size(fun, 1)
        @assert dim == size(funInv, 1)
        for i = 1:dim
            @assert    norm(fun[i]([0.0,1.0]) - [0.0,1.0]) <= 1e-12
            @assert norm(funInv[i]([0.0,1.0]) - [0.0,1.0]) <= 1e-12
        end
        this = new()
        this.dim = dim
        this.fun = fun
        this.funInv = funInv
        return this
    end
    
end

# Map R in [0,1]^dim to X in [0,1]^dim by applying funInv[i] to row i.
# Both R and the result are checked to lie in [0,1] (up to 1e-14).
function RtoX{T}(sb::StretchingBox{T},R::Array{T,2})

    @assert size(R, 1) == sb.dim
    # Validate the input domain before evaluating the user-supplied maps
    @assert all(R .<= 1.0 + 1e-14)
    @assert all(R .>= 0.0 - 1e-14)
    # Keep the element type of the input, as the other boxes do
    X = zeros(T, size(R))
    for i = 1:sb.dim
        X[i,:] = sb.funInv[i](R[i,:])
    end
    @assert all(X .<= 1.0 + 1e-14)
    @assert all(X .>= 0.0 - 1e-14)
    return X

end

# Map X in [0,1]^dim to R in [0,1]^dim by applying fun[i] to row i.
# Both X and the result are checked to lie in [0,1] (up to 1e-14).
function XtoR{T}(sb::StretchingBox{T},X::Array{T,2})
    
    @assert size(X, 1) == sb.dim
    # Validate the input domain before evaluating the user-supplied maps
    @assert all(X .<= 1.0 + 1e-14)
    @assert all(X .>= 0.0 - 1e-14)
    # Keep the element type of the input, as the other boxes do
    R = zeros(T, size(X))
    for i = 1:sb.dim
        R[i,:] = sb.fun[i](X[i,:])
    end
    @assert all(R .<= 1.0 + 1e-14)
    @assert all(R .>= 0.0 - 1e-14)
    return R

end

#
# BOUNDING TORUS (generates points and bounding box - 2d only - no fitting)
#

type TorusBoundingBox{T} <: BoundingBox

    dim_x::Integer        # Always 3
    dim_r::Integer        # Always 2
    starts::Array{T, 1}   # length 2: [theta_start, phi_start]
    ends::Array{T, 1}     # length 2: [theta_end, phi_end]
    
    R::T                  # Big radius (major)
    r::T                  # Small radius (minor)
    # Center is at the origin and the symmetry axis is z. The torus lies in
    # [-(R+r), R+r] x [-(R+r), R+r] x [-r, r].

    # Mapping is (x,y,z) -> (theta,phi):
    #   phi   is the angle in the (x,y) plane (major circle)
    #   theta is the angle around the tube (minor circle). It is 0 on the
    #         outer equator and increases going up in z.
    #
    # x = (R + r cos theta) cos phi
    # y = (R + r cos theta) sin phi
    # z = r sin theta
    #
    # where -pi <= phi   < pi
    # and   -pi <= theta < pi
    #
    # i.e.
    #
    # theta =   asin(z/r)      if x^2+y^2 >= R^2
    # or    =   pi - asin(z/r) if x^2+y^2 < R^2 and asin(z/r) > 0
    #       = - pi - asin(z/r) if x^2+y^2 < R^2 and asin(z/r) < 0
    #         (asin(z/r) == 0 on the inner equator is theta = +-pi, the same point)
    # phi   =   atan2(y, x)
    
    # R, r       : major and minor radii
    # thetarange : [start, end] with -pi <= start < end < pi
    # phirange   : [start, end] with -pi <= start < end < pi
    function TorusBoundingBox{T}(R::T, r::T, thetarange::Array{T,1}, phirange::Array{T,1}) where T
        this = new()
        this.R = R
        this.r = r
        this.dim_x = 3
        this.dim_r = 2
        @assert -pi <= thetarange[1]
        @assert        thetarange[2] < pi
        @assert -pi <= phirange[1]
        @assert        phirange[2]   < pi
        @assert thetarange[1] < thetarange[2]
        @assert phirange[1] < phirange[2]
        this.starts = [thetarange[1], phirange[1]]
        this.ends = [thetarange[2], phirange[2]]
        return this
    end

end

# Map points X (3, N) lying on the torus to R = (theta, phi) of size (2, N),
# with theta in [-pi, pi) and phi = atan2(y, x) in [-pi, pi].
function XtoR{T}(tbb::TorusBoundingBox{T}, X::Array{T,2})
    @assert size(X)[1] == 3
    nPts = size(X)[2]
    x = X[1,:]
    y = X[2,:]
    z = X[3,:]
    # Check they are on the torus indeed. The residual is a difference of
    # squared lengths, so its rounding error grows like the squared size of
    # the torus. Scale the tolerance accordingly (unchanged for tori of size <= 1).
    onTorusTol = 1e-14 * max(one(T), (tbb.R + tbb.r)^2)
    @assert maximum(abs.( (tbb.R - sqrt.(x.^2+y.^2)).^2 + z.^2 - tbb.r^2 )) < onTorusTol
    # theta: asin only covers [-pi/2, pi/2]. Points inside the major circle
    # (x^2+y^2 < R^2) have cos(theta) < 0 and are reflected to the far side.
    # The clamp guards against |z/r| marginally above 1 from rounding.
    asinzr = asin.(clamp.(z ./ tbb.r, -one(T), one(T)))
    x2y2 = x.^2+y.^2
    id1 = (x2y2 .>= tbb.R^2)
    # asin(z/r) == 0 on the inner equator goes to -pi (not pi), so that
    # theta stays in the documented range [-pi, pi)
    id2 = (x2y2 .<  tbb.R^2) .& (asinzr .>  0)
    id3 = (x2y2 .<  tbb.R^2) .& (asinzr .<= 0)
    theta = zeros(T, (nPts,))
    theta[id1] =        asinzr[id1]
    theta[id2] =   pi - asinzr[id2]
    theta[id3] = - pi - asinzr[id3]
    # phi
    phi = atan2.(y, x)
    return vcat(theta', phi')
end

# Map R = (theta, phi) of size (2, N) to points on the torus X (3, N).
# Both angles must lie within the box bounds, up to 1e-14.
function RtoX{T}(tbb::TorusBoundingBox{T},R::Array{T,2})
    @assert size(R,1) == 2
    theta = R[1,:]
    phi   = R[2,:]
    @assert all(tbb.starts[1] - 1e-14 .<= theta) && all(theta .<= tbb.ends[1] + 1e-14)
    @assert all(tbb.starts[2] - 1e-14 .<= phi  ) && all(phi   .<= tbb.ends[2] + 1e-14)
    # Distance from the z axis, shared by the x and y coordinates
    ringRadius = tbb.R + tbb.r*cos.(theta)
    return [(ringRadius .* cos.(phi))'; (ringRadius .* sin.(phi))'; (tbb.r * sin.(theta))']
end

# Tensor grid on the torus: nTheta x nPhi evenly spaced angles over the box,
# combined with tensor_grid and mapped to X.
function generatePoints{T}(tbb::TorusBoundingBox{T},nTheta::Int64,nPhi::Int64)
    samples(k, count) = Array(linspace(tbb.starts[k], tbb.ends[k], count))
    return RtoX(tbb, tensor_grid([samples(1, nTheta), samples(2, nPhi)]))
end

# Plotting a torus box is not supported: only emits a warning.
plotBoundingBox{T}(cbb::TorusBoundingBox{T},figid) = warn("Not implemented yet")

type SphereBoundingBox{T} <: BoundingBox
    center::Array{T,1} # (dim,) centroid of the fitted points
    radius::T          # largest distance_2 from center to a fitted point

    # Sphere (circle in 2D) centered at the centroid of the points X (dim, N),
    # dim = 2 or 3. The radius is the maximum of distance_2 (defined elsewhere;
    # presumably the Euclidean distance — confirm) over all points.
    function SphereBoundingBox{T}(X::Array{T,2}) where T
        @assert size(X,1) == 2 || size(X,1) == 3
        this = new()
        this.center = vec(mean(X,2))
        this.radius = 0.0
        for col in 1:size(X,2)
            this.radius = max(this.radius, distance_2(X[:,col], this.center)[1])
        end
        return this
    end
end

# Generate n points on the boundary of the sphere (dim, n):
#   2D: n points equally spaced in angle on the circle
#   3D: golden-angle spiral, heights evenly spaced in [-1, 1] and longitude
#       advanced by the golden angle at each point
# Returns nothing for any other dimension.
function generatePoints{T}(sbb::SphereBoundingBox{T},n)
    dim = size(sbb.center,1)
    if dim == 2
        # The endpoint 2*pi is dropped since it duplicates 0
        angles = linspace(0,2*pi,n+1)[1:n];
        unitCircle = vcat(reshape(cos.(angles), (1, n)), reshape(sin.(angles), (1, n)))
        return broadcast(+, sbb.center, sbb.radius .* unitCircle)
    elseif dim == 3
        golden = pi * (3 - sqrt(5))
        longitude = golden * (1:n)
        height = linspace(-1.0, 1.0, n)
        ring = sqrt.(1 - height .^ 2)
        unitSphere = hcat(ring .* cos.(longitude), ring .* sin.(longitude), height)'
        return broadcast(+, sbb.center, sbb.radius .* unitSphere)
    end
end
