@@ -13,3 +13,94 @@ function params(::AbstractTransformedMeasure) end
1313function paramnames (:: AbstractTransformedMeasure ) end
1414
1515function parent (:: AbstractTransformedMeasure ) end
16+
17+
18+ export PushforwardMeasure
19+
20+ """
21+ struct PushforwardMeasure{FF,IF,MU,VC<:TransformVolCorr} <: AbstractPushforward
22+ f :: FF
23+ inv_f :: IF
24+ origin :: MU
25+ volcorr :: VC
26+ end
27+ """
28+ struct PushforwardMeasure{FF,IF,M,VC<: TransformVolCorr } <: AbstractPushforward
29+ f:: FF
30+ inv_f:: IF
31+ origin:: M
32+ volcorr:: VC
33+ end
34+
35+ gettransform (ν:: PushforwardMeasure ) = ν. f
36+ parent (ν:: PushforwardMeasure ) = ν. origin
37+
38+
39+ function Pretty. tile (ν:: PushforwardMeasure )
40+ Pretty. list_layout (Pretty. tile .([ν. f, ν. inv_f, ν. origin]); prefix = :PushforwardMeasure )
41+ end
42+
43+
44+ @inline function logdensity_def (ν:: PushforwardMeasure{FF,IF,M,<:WithVolCorr} , y) where {FF,IF,M}
45+ x_orig, inv_ladj = with_logabsdet_jacobian (ν. inv_f, y)
46+ logd_orig = logdensity_def (ν. origin, x_orig)
47+ logd = float (logd_orig + inv_ladj)
48+ neginf = oftype (logd, - Inf )
49+ return ifelse (
50+ # Zero density wins against infinite volume:
51+ (isnan (logd) && logd_orig == - Inf && inv_ladj == + Inf ) ||
52+ # Maybe also for (logd_orig == -Inf) && isfinite(inv_ladj) ?
53+ # Return constant -Inf to prevent problems with ForwardDiff:
54+ (isfinite (logd_orig) && (inv_ladj == - Inf )),
55+ neginf,
56+ logd
57+ )
58+ end
59+
60+ @inline function logdensity_def (ν:: PushforwardMeasure{FF,IF,M,<:NoVolCorr} , y) where {FF,IF,M}
61+ x_orig = to_origin (ν, y)
62+ return logdensity_def (ν. origin, x_orig)
63+ end
64+
65+
66+ insupport (ν:: PushforwardMeasure , y) = insupport (transport_origin (ν), to_origin (ν, y))
67+
68+ testvalue (ν:: PushforwardMeasure ) = from_origin (ν, testvalue (transport_origin (ν)))
69+
70+ @inline function basemeasure (ν:: PushforwardMeasure )
71+ PushforwardMeasure (ν. f, ν. inv_f, basemeasure (transport_origin (ν)), NoVolCorr ())
72+ end
73+
74+
75+ _pushfwd_dof (:: Type{MU} , :: Type , dof) where MU = NoDOF {MU} ()
76+ _pushfwd_dof (:: Type{MU} , :: Type{<:Tuple{Any,Real}} , dof) where MU = dof
77+
78+ # Assume that DOF are preserved if with_logabsdet_jacobian is functional:
79+ @inline function getdof (ν:: MU ) where {MU<: PushforwardMeasure }
80+ T = Core. Compiler. return_type (testvalue, Tuple{typeof (ν. origin)})
81+ R = Core. Compiler. return_type (with_logabsdet_jacobian, Tuple{typeof (ν. f), T})
82+ _pushfwd_dof (MU, R, getdof (ν. origin))
83+ end
84+
85+ # Bypass `checked_var`, would require potentially costly transformation:
86+ @inline checked_var (:: PushforwardMeasure , x) = x
87+
88+
89+ @inline transport_origin (ν:: PushforwardMeasure ) = ν. origin
90+ @inline from_origin (ν:: PushforwardMeasure , x) = ν. f (x)
91+ @inline to_origin (ν:: PushforwardMeasure , y) = ν. inv_f (y)
92+
93+ function Base. rand (rng:: AbstractRNG , :: Type{T} , ν:: PushforwardMeasure ) where T
94+ return from_origin (ν, rand (rng, T, transport_origin (ν)))
95+ end
96+
97+
98+ export pushfwd
99+
100+ """
101+ pushfwd(f, μ, volcorr = WithVolCorr())
102+
103+ Return the [pushforward measure](https://en.wikipedia.org/wiki/Pushforward_measure)
104+ from `μ` the [measurable function](https://en.wikipedia.org/wiki/Measurable_function) `f`.
105+ """
106+ pushfwd (f, μ, volcorr = WithVolCorr ()) = PushforwardMeasure (f, inverse (f), μ, volcorr)
0 commit comments