22
33
44"""
5- UvBinnedDist <: Distribution{Univariate ,Continuous}
5+ MvBinnedDist <: Distribution{Multivariate ,Continuous}
66
77Wraps a multi-dimensional histograms and presents it as a binned multivariate
88distribution.
@@ -11,16 +11,21 @@ Constructor:
1111
1212 MvBinnedDist(h::Histogram{<:Real,N})
1313"""
14- struct MvBinnedDist{T, N} <: Distributions.Distribution{Multivariate,Continuous}
15- h:: StatsBase.Histogram{<:Real, N}
16- edges:: NTuple{N, <:AbstractVector{T}}
17- cart_inds:: CartesianIndices{N, NTuple{N, Base.OneTo{Int}}}
18-
19- probabilty_edges:: AbstractVector{T}
20-
21- μ:: AbstractVector{T}
22- var:: AbstractVector{T}
23- cov:: AbstractMatrix{T}
14+ struct MvBinnedDist{
15+ T <: Real ,
16+ N,
17+ H <: Histogram{<:Real, N} ,
18+ VT <: AbstractVector{T} ,
19+ MT <: AbstractMatrix{T}
20+ } <: Distributions.Distribution{Multivariate,Continuous}
21+ hist:: H
22+ _edges:: NTuple{N, <:AbstractVector{T}}
23+ _cart_inds:: CartesianIndices{N, NTuple{N, Base.OneTo{Int}}}
24+ _probability_edges:: VT
25+ _mean:: VT
26+ _mode:: VT
27+ _var:: VT
28+ _cov:: MT
2429end
2530
2631export MvBinnedDist
@@ -37,83 +42,45 @@ function MvBinnedDist(h::StatsBase.Histogram{<:Real, N}, T::DataType = Float64)
3742 probabilty_edges[i+ 1 ] = v > 1 ? 1 : v
3843 end
3944
40- mean = _mean (h)
41- var = _var (h, mean = mean)
42- cov = _cov (h, mean = mean)
45+ mean_est = _mean (h)
46+ mode_est = _mode (nh)
47+ var_est = _var (h, mean = mean_est)
48+ cov_est = _cov (h, mean = mean_est)
4349
44- return MvBinnedDist {T, N} (
50+ return MvBinnedDist (
4551 nh,
4652 collect .(nh. edges),
4753 CartesianIndices (nh. weights),
4854 probabilty_edges,
49- mean,
50- var,
51- cov
55+ mean_est,
56+ mode_est,
57+ var_est,
58+ cov_est
5259 )
5360end
5461
5562
63+ Base. convert (:: Type{Histogram} , d:: MvBinnedDist ) = d. hist
64+
65+
5666Base. length (d:: MvBinnedDist{T, N} ) where {T, N} = N
5767Base. size (d:: MvBinnedDist{T, N} ) where {T, N} = (N,)
5868Base. eltype (d:: MvBinnedDist{T, N} ) where {T, N} = T
5969
60- Statistics. mean (d:: MvBinnedDist{T, N} ) where {T, N} = d. μ
61- Statistics. var (d:: MvBinnedDist{T, N} ) where {T, N} = d. var
62- Statistics. cov (d:: MvBinnedDist{T, N} ) where {T, N} = d. cov
63-
64-
65- function _mean (h:: StatsBase.Histogram{<:Real, N} ; T:: DataType = Float64) where {N}
66- s_inv:: T = inv (sum (h. weights))
67- m:: Vector{T} = zeros (T, N)
68- mps = StatsBase. midpoints .(h. edges)
69- cart_inds = CartesianIndices (h. weights)
70- for i in cart_inds
71- for idim in 1 : N
72- m[idim] += s_inv * mps[idim][i[idim]] * h. weights[i]
73- end
74- end
75- return m
76- end
77-
78-
79- function _var (h:: StatsBase.Histogram{<:Real, N} ; T:: DataType = Float64, mean = StatsBase. mean (h, T = T), ) where {N}
80- s_inv:: T = inv (sum (h. weights))
81- v:: Vector{T} = zeros (T, N)
82- mps = StatsBase. midpoints .(h. edges)
83- cart_inds = CartesianIndices (h. weights)
84- for i in cart_inds
85- for idim in 1 : N
86- v[idim] += s_inv * (mps[idim][i[idim]] - mean[idim])^ 2 * h. weights[i]
87- end
88- end
89- return v
90- end
91-
92-
93- function _cov (h:: StatsBase.Histogram{<:Real, N} ; T:: DataType = Float64, mean = StatsBase. mean (h, T = T)) where {N}
94- s_inv:: T = inv (sum (h. weights))
95- c:: Matrix{T} = zeros (T, N, N)
96- mps = StatsBase. midpoints .(h. edges)
97- cart_inds = CartesianIndices (h. weights)
98- for i in cart_inds
99- for idim in 1 : N
100- for jdim in 1 : N
101- c[idim, jdim] += s_inv * (mps[idim][i[idim]] - mean[idim]) * (mps[jdim][i[jdim]] - mean[jdim]) * h. weights[i]
102- end
103- end
104- end
105- return c
106- end
70+ Statistics. mean (d:: MvBinnedDist{T, N} ) where {T, N} = d. _mean
71+ StatsBase. mode (d:: MvBinnedDist{T, N} ) where {T, N} = d. _mode
72+ Statistics. var (d:: MvBinnedDist{T, N} ) where {T, N} = d. _var
73+ Statistics. cov (d:: MvBinnedDist{T, N} ) where {T, N} = d. _cov
10774
10875
10976function Distributions. _rand! (r:: AbstractRNG , d:: MvBinnedDist{T,N} , A:: AbstractVector{<:Real} ) where {T, N}
11077 rand! (r, A)
111- next_inds:: UnitRange{Int} = searchsorted (d. probabilty_edges :: Vector{T} , A[1 ]:: T )
78+ next_inds:: UnitRange{Int} = searchsorted (d. _probability_edges :: Vector{T} , A[1 ]:: T )
11279 cell_lin_index:: Int = min (next_inds. start, next_inds. stop)
113- cell_car_index = d. cart_inds [cell_lin_index]
80+ cell_car_index = d. _cart_inds [cell_lin_index]
11481 for idim in Base. OneTo (N)
11582 i = cell_car_index[idim]
116- sub_int = d. edges [idim][i: i+ 1 ]
83+ sub_int = d. _edges [idim][i: i+ 1 ]
11784 sub_int_width:: T = sub_int[2 ] - sub_int[1 ]
11885 A[idim] = sub_int[1 ] + sub_int_width * A[idim]
11986 end
12289
12390function Distributions. _rand! (r:: AbstractRNG , d:: MvBinnedDist{T,N} , A:: AbstractMatrix{<:Real} ) where {T, N}
12491 Distributions. _rand! .((r,), (d,), nestedview (A))
92+ return A
93+ end
94+
95+
96+ # Similar to unroll_tuple in StaticArrays.jl:
97+ @generated function _unsafe_unroll_tuple (A:: AbstractArray , :: Val{L} ) where {L}
98+ exprs = [:(A[idx0 + $ j]) for j = 0 : (L- 1 )]
99+ quote
100+ idx0 = firstindex (A)
101+ Base. @_inline_meta
102+ @inbounds return $ (Expr (:tuple , exprs... ))
103+ end
125104end
126105
127106
128- function Distributions. pdf (d:: MvBinnedDist{T, N} , x:: AbstractArray{<:Real, 1} ) where {T, N}
129- return @inbounds d. h. weights[StatsBase. binindex (d. h, Tuple (x))... ]
107+ function Distributions. pdf (d:: MvBinnedDist{T,N} , x:: AbstractVector{<:Real} ) where {T,N}
108+ length (eachindex (x)) == N || throw (ArgumentError (" Length of variate doesn't match dimensionality of distribution" ))
109+ x_tpl = _unsafe_unroll_tuple (x, Val (N))
110+ _pdf (d. hist, x_tpl)
130111end
131112
132113
0 commit comments