|
| 1 | +# This file is a part of EmpiricalDistributions.jl, licensed under the MIT License (MIT). |
| 2 | + |
| 3 | +struct MvBinnedDist{T, N} <: Distributions.Distribution{Multivariate,Continuous} |
| 4 | + h::StatsBase.Histogram{<:Real, N} |
| 5 | + edges::NTuple{N, <:AbstractVector{T}} |
| 6 | + cart_inds::CartesianIndices{N, NTuple{N, Base.OneTo{Int}}} |
| 7 | + |
| 8 | + probabilty_edges::AbstractVector{T} |
| 9 | + |
| 10 | + μ::AbstractVector{T} |
| 11 | + var::AbstractVector{T} |
| 12 | + cov::AbstractMatrix{T} |
| 13 | +end |
| 14 | + |
| 15 | +export MvBinnedDist |
| 16 | + |
| 17 | + |
| 18 | +function MvBinnedDist(h::StatsBase.Histogram{<:Real, N}, T::DataType = Float64) where {N} |
| 19 | + nh = normalize(h) |
| 20 | + |
| 21 | + probabilty_widths = nh.weights * inv(sum(nh.weights)) |
| 22 | + probabilty_edges::Vector{T} = Vector{Float64}(undef, length(h.weights) + 1) |
| 23 | + probabilty_edges[1] = 0 |
| 24 | + for (i, w) in enumerate(probabilty_widths) |
| 25 | + v = probabilty_edges[i] + probabilty_widths[i] |
| 26 | + probabilty_edges[i+1] = v > 1 ? 1 : v |
| 27 | + end |
| 28 | + |
| 29 | + mean = _mean(h) |
| 30 | + var = _var(h, mean = mean) |
| 31 | + cov = _cov(h, mean = mean) |
| 32 | + |
| 33 | + return MvBinnedDist{T, N}( |
| 34 | + nh, |
| 35 | + collect.(nh.edges), |
| 36 | + CartesianIndices(nh.weights), |
| 37 | + probabilty_edges, |
| 38 | + mean, |
| 39 | + var, |
| 40 | + cov |
| 41 | + ) |
| 42 | +end |
| 43 | + |
| 44 | + |
| 45 | +Base.length(d::MvBinnedDist{T, N}) where {T, N} = N |
| 46 | +Base.size(d::MvBinnedDist{T, N}) where {T, N} = (N,) |
| 47 | +Base.eltype(d::MvBinnedDist{T, N}) where {T, N} = T |
| 48 | + |
| 49 | +Statistics.mean(d::MvBinnedDist{T, N}) where {T, N} = d.μ |
| 50 | +Statistics.var(d::MvBinnedDist{T, N}) where {T, N} = d.var |
| 51 | +Statistics.cov(d::MvBinnedDist{T, N}) where {T, N} = d.cov |
| 52 | + |
| 53 | + |
| 54 | +function _mean(h::StatsBase.Histogram{<:Real, N}; T::DataType = Float64) where {N} |
| 55 | + s_inv::T = inv(sum(h.weights)) |
| 56 | + m::Vector{T} = zeros(T, N) |
| 57 | + mps = StatsBase.midpoints.(h.edges) |
| 58 | + cart_inds = CartesianIndices(h.weights) |
| 59 | + for i in cart_inds |
| 60 | + for idim in 1:N |
| 61 | + m[idim] += s_inv * mps[idim][i[idim]] * h.weights[i] |
| 62 | + end |
| 63 | + end |
| 64 | + return m |
| 65 | +end |
| 66 | + |
| 67 | + |
| 68 | +function _var(h::StatsBase.Histogram{<:Real, N}; T::DataType = Float64, mean = StatsBase.mean(h, T = T), ) where {N} |
| 69 | + s_inv::T = inv(sum(h.weights)) |
| 70 | + v::Vector{T} = zeros(T, N) |
| 71 | + mps = StatsBase.midpoints.(h.edges) |
| 72 | + cart_inds = CartesianIndices(h.weights) |
| 73 | + for i in cart_inds |
| 74 | + for idim in 1:N |
| 75 | + v[idim] += s_inv * (mps[idim][i[idim]] - mean[idim])^2 * h.weights[i] |
| 76 | + end |
| 77 | + end |
| 78 | + return v |
| 79 | +end |
| 80 | + |
| 81 | + |
| 82 | +function _cov(h::StatsBase.Histogram{<:Real, N}; T::DataType = Float64, mean = StatsBase.mean(h, T = T)) where {N} |
| 83 | + s_inv::T = inv(sum(h.weights)) |
| 84 | + c::Matrix{T} = zeros(T, N, N) |
| 85 | + mps = StatsBase.midpoints.(h.edges) |
| 86 | + cart_inds = CartesianIndices(h.weights) |
| 87 | + for i in cart_inds |
| 88 | + for idim in 1:N |
| 89 | + for jdim in 1:N |
| 90 | + c[idim, jdim] += s_inv * (mps[idim][i[idim]] - mean[idim]) * (mps[jdim][i[jdim]] - mean[jdim]) * h.weights[i] |
| 91 | + end |
| 92 | + end |
| 93 | + end |
| 94 | + return c |
| 95 | +end |
| 96 | + |
| 97 | + |
| 98 | +function Distributions._rand!(r::AbstractRNG, d::MvBinnedDist{T,N}, A::AbstractVector{<:Real}) where {T, N} |
| 99 | + rand!(r, A) |
| 100 | + next_inds::UnitRange{Int} = searchsorted(d.probabilty_edges::Vector{T}, A[1]::T) |
| 101 | + cell_lin_index::Int = min(next_inds.start, next_inds.stop) |
| 102 | + cell_car_index = d.cart_inds[cell_lin_index] |
| 103 | + for idim in Base.OneTo(N) |
| 104 | + i = cell_car_index[idim] |
| 105 | + sub_int = d.edges[idim][i:i+1] |
| 106 | + sub_int_width::T = sub_int[2] - sub_int[1] |
| 107 | + A[idim] = sub_int[1] + sub_int_width * A[idim] |
| 108 | + end |
| 109 | + return A |
| 110 | +end |
| 111 | + |
| 112 | +function Distributions._rand!(r::AbstractRNG, d::MvBinnedDist{T,N}, A::AbstractMatrix{<:Real}) where {T, N} |
| 113 | + Distributions._rand!.((r,), (d,), nestedview(A)) |
| 114 | +end |
| 115 | + |
| 116 | + |
| 117 | +function Distributions.pdf(d::MvBinnedDist{T, N}, x::AbstractArray{<:Real, 1}) where {T, N} |
| 118 | + return @inbounds d.h.weights[StatsBase.binindex(d.h, Tuple(x))...] |
| 119 | +end |
| 120 | + |
| 121 | + |
| 122 | +function Distributions.logpdf(d::MvBinnedDist{T, N}, x::AbstractArray{<:Real, 1}) where {T, N} |
| 123 | + return log(pdf(d, x)) |
| 124 | +end |
| 125 | + |
| 126 | +function Distributions._logpdf(d::MvBinnedDist{T,N}, x::AbstractArray{<:Real, 1}) where {T, N} |
| 127 | + return logpdf(d, x) |
| 128 | +end |
0 commit comments