Skip to content

Commit 6abb7f5

Browse files
feynmanliangararslan
authored andcommitted
Add insupport method to MixtureModel (#651)
* Adds mixturemodel#insupport * Adds tests * Fix type of x * Fix tests
1 parent 2d98eb6 commit 6abb7f5

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

src/mixtures/mixturemodel.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,13 @@ Compute the overall mean (expectation).
5858
"""
5959
mean(d::AbstractMixtureModel)
6060

61+
"""
62+
insupport(d::MultivariateMixture, x)
63+
64+
Evaluate whether `x` is within the support of mixture distribution `d`.
65+
"""
66+
insupport(d::AbstractMixtureModel, x::AbstractVector)
67+
6168
"""
6269
pdf(d::Union{UnivariateMixture, MultivariateMixture}, x)
6370
@@ -263,6 +270,19 @@ end
263270

264271
#### Evaluation
265272

273+
function insupport(d::AbstractMixtureModel, x::AbstractVector)
274+
K = ncomponents(d)
275+
p = probs(d)
276+
@assert length(p) == K
277+
for i = 1:K
278+
@inbounds pi = p[i]
279+
if pi > 0.0 && insupport(component(d, i), x)
280+
return true
281+
end
282+
end
283+
return false
284+
end
285+
266286
function _cdf(d::UnivariateMixture, x::Real)
267287
K = ncomponents(d)
268288
p = probs(d)

test/mixture.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,8 @@ test_params(g_u)
150150
g_u = MixtureModel([TriangularDist(-1,2,0),TriangularDist(-.5,3,1),TriangularDist(-2,0,-1)])
151151
@test minimum(g_u) == -2.0
152152
@test maximum(g_u) == 3.0
153+
@test insupport(g_u, 2.5) == true
154+
@test insupport(g_u, 3.5) == false
153155

154156
g_u = UnivariateGMM([0.0, 2.0, -4.0], [1.0, 1.2, 1.5], Categorical([0.2, 0.5, 0.3]))
155157
@test isa(g_u, UnivariateGMM)
@@ -168,5 +170,6 @@ g_m = MixtureModel(
168170
@test isa(g_m, MixtureModel{Multivariate, Continuous, IsoNormal})
169171
@test length(components(g_m)) == 3
170172
@test length(g_m) == 2
173+
@test insupport(g_m, [0.0, 0.0]) == true
171174
test_mixture(g_m, 1000, 10^6)
172175
test_params(g_m)

0 commit comments

Comments
 (0)