Skip to content

Commit c7e70d9

Browse files
authored
Restrict number of input arguments for all measures (#327)
* Structure OCE code better. It was hard to debug. * Restrict number of possible input variables for all measures * Faster tests for OCE No need to use so many samples * Don't restrict type of input for PMI * Revert shortening of type aliases
1 parent 4ef4cee commit c7e70d9

File tree

33 files changed

+251
-98
lines changed

33 files changed

+251
-98
lines changed

src/causal_graphs/oce/OCE.jl

Lines changed: 75 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -153,60 +153,89 @@ function prepare_embeddings(alg::OCE, x, i)
153153
return τs, js, 𝒫s
154154
end
155155

156-
157-
function select_parent!(alg::OCE, parents, τs, js, 𝒫s, xᵢ, i::Int; verbose = true)
158-
# Have any parents been identified yet?
156+
function pairwise_test(parents::OCESelectedParents)
157+
# Have any parents been identified yet? If not, then we're doing pairwise tests.
159158
pairwise = isempty(parents.parents)
160159

160+
return pairwise
161+
end
162+
163+
function select_parent!(alg::OCE, parents, τs, js, 𝒫s, xᵢ, i::Int; verbose = true)
161164
# If there are no potential parents to pick from, return immediately.
162165
isempty(𝒫s) && return false
163166

164-
# Configure estimation and independence testing function calls, which differ in the
165-
# number of arguments depending on whether we're doing the pairwise or conditional case.
166-
if !pairwise
167-
P = StateSpaceSet(parents.parents...)
168-
f = (measure, est, xᵢ, Pⱼ) -> estimate(measure, est, xᵢ, Pⱼ, P)
169-
findep = (test, xᵢ, Pix) -> independence(test, xᵢ, Pix, P)
170-
else
171-
f = (measure, est, xᵢ, Pⱼ) -> estimate(measure, est, xᵢ, Pⱼ)
172-
findep = (test, xᵢ, Pix) -> independence(test, xᵢ, Pix)
173-
end
167+
# Anonymous two-argument functions for computing raw measure and performing
168+
# independence tests, taking care of conditioning on parents when necessary.
169+
compute_raw_measure, test_independence = rawmeasure_and_independencetest(alg, parents)
174170

175171
# Compute the measure without significance testing first. This avoids unnecessary
176172
# independence testing, which takes a lot of time.
177173
Is = zeros(length(𝒫s))
178174
for (i, Pⱼ) in enumerate(𝒫s)
179-
Is[i] = f(alg.utest.measure, alg.utest.est, xᵢ, Pⱼ)
175+
Is[i] = compute_raw_measure(xᵢ, Pⱼ)
180176
end
181177

182-
# Sort variables according to maximal measure and select the first lagged variable that
183-
# gives significant association with the target variable.
184-
maximize_sortidxs = sortperm(Is, rev = true)
185-
n_checked = 0
186-
n_potential_vars = length(𝒫s)
178+
# First sort variables according to maximal measure. Then, we select the first lagged
179+
# variable that gives significant association with the target variable.
180+
idxs_that_maximize_measure = sortperm(Is, rev = true)
181+
182+
n_checked, n_potential_vars = 0, length(𝒫s)
187183
while n_checked < n_potential_vars
188184
n_checked += 1
189-
ix = maximize_sortidxs[n_checked]
185+
ix = idxs_that_maximize_measure[n_checked]
190186
if Is[ix] > 0
191-
# findep takes into account the conditioning set too if it is non-empty.
192-
result = findep(alg.utest, xᵢ, 𝒫s[ix])
187+
result = test_independence(xᵢ, 𝒫s[ix])
193188
if pvalue(result) < alg.α
194-
if verbose && !pairwise
195-
println("\tx$i(0) !⫫ x$(js[ix])($(τs[ix])) | $(selected(parents))")
196-
elseif verbose && pairwise
197-
println("\tx$i(0) !⫫ x$(js[ix])($(τs[ix])) | ∅")
198-
end
199-
push!(parents.parents, 𝒫s[ix])
200-
push!(parents.parents_js, js[ix])
201-
push!(parents.parents_τs, τs[ix])
202-
deleteat!(𝒫s, ix)
203-
deleteat!(js, ix)
204-
deleteat!(τs, ix)
189+
print_status(IndependenceStatus(), parents, τs, js, ix, i; verbose)
190+
update_parents_and_selected!(parents, 𝒫s, τs, js, ix)
205191
return true
206192
end
207193
end
208194
end
209-
# If we reach this stage, no variables have been selected. Print an informative message.
195+
196+
# If we reach this stage, no variables have been selected.
197+
print_status(NoVariablesSelected(), parents, τs, js, i; verbose)
198+
return false
199+
end
200+
201+
# For pairwise cases, we don't need to condition on any parents. For conditional
202+
# cases, we must condition on the parents that have already been selected (`P`).
203+
# The measures, estimators and independence tests are different for the pairwise
204+
# and conditional case.
205+
# This just defines the functions `compute_raw_measure` and
206+
# `test_independence` so that they only need two input arguments, ensuring
207+
# that `P` is always conditioned on when relevant. The two functions are returned.
208+
function rawmeasure_and_independencetest(alg, parents::OCESelectedParents)
209+
if pairwise_test(parents)
210+
measure, est = alg.utest.measure, alg.utest.est
211+
compute_raw_measure = (xᵢ, Pⱼ) -> estimate(measure, est, xᵢ, Pⱼ)
212+
test_independence = (xᵢ, Pix) -> independence(alg.utest, xᵢ, Pix)
213+
else
214+
measure, est = alg.ctest.measure, alg.ctest.est
215+
P = StateSpaceSet(parents.parents...)
216+
compute_raw_measure = (xᵢ, Pⱼ) -> estimate(measure, est, xᵢ, Pⱼ, P)
217+
test_independence = (xᵢ, Pix) -> independence(alg.ctest, xᵢ, Pix, P)
218+
end
219+
return compute_raw_measure, test_independence
220+
end
221+
222+
function update_parents_and_selected!(parents::OCESelectedParents, 𝒫s, τs, js, ix::Int)
223+
push!(parents.parents, 𝒫s[ix])
224+
push!(parents.parents_js, js[ix])
225+
push!(parents.parents_τs, τs[ix])
226+
deleteat!(𝒫s, ix)
227+
deleteat!(js, ix)
228+
deleteat!(τs, ix)
229+
end
230+
231+
###################################################################
232+
# Pretty printing
233+
###################################################################
234+
struct NoVariablesSelected end
235+
function print_status(::NoVariablesSelected, parents::OCESelectedParents,
236+
τs, js, i::Int; verbose = true)
237+
238+
pairwise = pairwise_test(parents)
210239
if verbose && !pairwise
211240
# No more associations were found
212241
s = ["x$i(1) ⫫ x$j() | $(selected(parents)))" for (τ, j) in zip(τs, js)]
@@ -216,7 +245,18 @@ function select_parent!(alg::OCE, parents, τs, js, 𝒫s, xᵢ, i::Int; verbose
216245
s = ["x$i(0) ⫫ x$j() | ∅)" for (τ, j) in zip(τs, js)]
217246
println("\t$(join(s, "\n\t"))")
218247
end
219-
return false
248+
end
249+
250+
struct IndependenceStatus end
251+
function print_status(::IndependenceStatus, parents::OCESelectedParents,
252+
τs, js, ix::Int, i::Int; verbose)
253+
if verbose
254+
if pairwise_test(parents)
255+
println("\tx$i(0) !⫫ x$(js[ix])($(τs[ix])) | ∅")
256+
else
257+
println("\tx$i(0) !⫫ x$(js[ix])($(τs[ix])) | $(selected(parents))")
258+
end
259+
end
220260
end
221261

222262
"""

src/core.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ const ArrayOrStateSpaceSet{D, T, N} = Union{AbstractArray{T, N}, AbstractStateSp
66
export AssociationMeasure
77
export DirectedAssociationMeasure
88

9+
10+
# Any non-bivariate association measures must implement:
11+
# - [`min_inputs_vars`](@ref).
12+
# - [`max_inputs_vars`](@ref).
913
"""
1014
AssociationMeasure
1115
@@ -33,3 +37,37 @@ with logarithms to base `b`. This can be used to convert the "unit" of an entrop
3337
function _convert_logunit(h::Real, base_from, base_to)
3438
h / log(base_from, base_to)
3539
end
40+
41+
# Default to bivariate measures. Other measures override it.
42+
"""
43+
min_inputs_vars(m::AssociationMeasure) → nmin::Int
44+
45+
Return the minimum number of variables is that the measure can be computed for.
46+
47+
For example, [`CMIShannon`](@ref) requires 3 input variables.
48+
"""
49+
min_inputs_vars(m::AssociationMeasure) = 2
50+
51+
# Default to bivariate measures. Other measures override it.
52+
53+
"""
54+
max_inputs_vars(m::AssociationMeasure) → nmax::Int
55+
56+
Return the maximum number of variables is that the measure can be computed for.
57+
58+
For example, [`MIShannon`](@ref) cannot be computed for more than 2 variables.
59+
"""
60+
max_inputs_vars(m::AssociationMeasure) = 2
61+
62+
function verify_number_of_inputs_vars(measure::AssociationMeasure, n::Int)
63+
T = typeof(measure)
64+
nmin = min_inputs_vars(measure)
65+
if n < nmin
66+
throw(ArgumentError("$T requires at least $nmin inputs. Got $n inputs."))
67+
end
68+
69+
nmax = max_inputs_vars(measure)
70+
if n > nmax
71+
throw(ArgumentError("$T accepts a maximum of $nmax inputs. Got $n inputs."))
72+
end
73+
end

src/independence_tests/independence.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ Returns a test `summary`, whose type depends on `test`.
2727
- [`LocalPermutationTest`](@ref).
2828
- [`JointDistanceDistributionTest`](@ref).
2929
"""
30-
function independence(test, args...; kwargs...)
31-
error("No concrete implementation for $(typeof(test)) test yet")
30+
function independence(test::IndependenceTest, x...)
31+
throw(ArgumentError("No concrete implementation for $(typeof(test)) test yet"))
3232
end
3333

3434
function pvalue_text_summary(test::IndependenceTestResult)

src/independence_tests/local_permutation/LocalPermutationTest.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,10 @@ end
164164
# KD-trees and do marginal searches for all marginals all the time.
165165
function independence(test::LocalPermutationTest, x, y, z)
166166
measure, est, nshuffles = test.measure, test.est, test.nshuffles
167+
168+
# Make sure that the measure is compatible with the input data.
169+
verify_number_of_inputs_vars(measure, 3)
170+
167171
X, Y, Z = StateSpaceSet(x), StateSpaceSet(y), StateSpaceSet(z)
168172
@assert length(X) == length(Y) == length(Z)
169173
N = length(X)
@@ -178,6 +182,7 @@ end
178182
# computing the test statistic.
179183
function permuted_Îs(X, Y, Z, measure, est, test)
180184
rng, kperm, nshuffles, replace, w = test.rng, test.kperm, test.nshuffles, test.replace, test.w
185+
181186
N = length(X)
182187
test.kperm < N || throw(ArgumentError("kperm must be smaller than input data length"))
183188

src/independence_tests/surrogate/SurrogateTest.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,10 @@ end
129129
# conditional mutual information.
130130
function independence(test::SurrogateTest, x, y, z)
131131
(; measure, est, rng, surrogate, nshuffles) = test
132+
133+
# Make sure that the measure is compatible with the input data.
134+
verify_number_of_inputs_vars(measure, 3)
135+
132136
X, Y, Z = StateSpaceSet(x), StateSpaceSet(y), StateSpaceSet(z)
133137
@assert length(X) == length(Y) == length(Z)
134138
N = length(x)
@@ -145,6 +149,10 @@ end
145149

146150
function independence(test::SurrogateTest, x, y)
147151
(; measure, est, rng, surrogate, nshuffles) = test
152+
153+
# Make sure that the measure is compatible with the input data.
154+
verify_number_of_inputs_vars(measure, 2)
155+
148156
X, Y = StateSpaceSet(x), StateSpaceSet(y)
149157
@assert length(X) == length(Y)
150158
N = length(x)

src/methods/correlation/distance_correlation.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ is computed.
3434
"""
3535
struct DistanceCorrelation <: AssociationMeasure end
3636

37+
max_inputs_vars(::DistanceCorrelation) = 3
38+
3739
"""
3840
distance_correlation(x, y) → dcor ∈ [0, 1]
3941
distance_correlation(x, y, z) → pdcor

src/methods/correlation/partial_correlation.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ where ``\\hat{P} = \\hat{\\Sigma}^{-1}`` is the sample precision matrix.
4040
"""
4141
struct PartialCorrelation <: AssociationMeasure end
4242

43+
min_inputs_vars(::PartialCorrelation) = 3
44+
max_inputs_vars(::PartialCorrelation) = Inf
45+
4346
"""
4447
partial_correlation(x::VectorOrStateSpaceSet, y::VectorOrStateSpaceSet,
4548
z::VectorOrStateSpaceSet...)

src/methods/infomeasures/condmutualinfo/CMIRenyiJizba.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ struct CMIRenyiJizba{E <: Renyi} <: ConditionalMutualInformation{E}
3636
end
3737
end
3838

39+
min_inputs_vars(::CMIRenyiJizba) = 3
40+
max_inputs_vars(::CMIRenyiJizba) = 3
41+
3942
function estimate(measure::CMIRenyiJizba, est::Contingency, x, y, z)
4043
c = _contingency_matrix(measure, est, x, y, z)
4144
pxz = probabilities(c, dims = [1, 3])

src/methods/infomeasures/condmutualinfo/CMIRenyiPoczos.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ defined in (Póczos & Schneider, 2012)[^Póczos2012].
99
## Usage
1010
1111
- Use with [`independence`](@ref) to perform a formal hypothesis test for pairwise dependence.
12-
- Use with [`condmutualinfo`](@ref) to compute the raw conditional mutual information.
12+
- Use with [`condmutualinfo`](@ref) to compute the raw conditional mutual information.
1313
1414
## Definition
1515
@@ -37,3 +37,6 @@ struct CMIRenyiPoczos{E <: Renyi} <: ConditionalMutualInformation{E}
3737
new{E}(e)
3838
end
3939
end
40+
41+
min_inputs_vars(::CMIRenyiPoczos) = 3
42+
max_inputs_vars(::CMIRenyiPoczos) = 3

src/methods/infomeasures/condmutualinfo/CMIRenyiSarbu.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ The Rényi conditional mutual information from Sarbu (2014)[^Sarbu2014]).
99
## Usage
1010
1111
- Use with [`independence`](@ref) to perform a formal hypothesis test for pairwise dependence.
12-
- Use with [`condmutualinfo`](@ref) to compute the raw conditional mutual information.
12+
- Use with [`condmutualinfo`](@ref) to compute the raw conditional mutual information.
1313
1414
## Discrete description
1515
@@ -41,12 +41,15 @@ struct CMIRenyiSarbu{E <: Renyi} <: ConditionalMutualInformation{E}
4141
end
4242
end
4343

44-
function estimate(measure::CMIRenyiSarbu, est::Contingency{<:ProbabilitiesEstimator}, x...)
45-
return estimate(measure, contingency_matrix(est.est, x...))
44+
min_inputs_vars(::CMIRenyiSarbu) = 3
45+
max_inputs_vars(::CMIRenyiSarbu) = 3
46+
47+
function estimate(measure::CMIRenyiSarbu, est::Contingency{<:ProbabilitiesEstimator}, x, y, z)
48+
return estimate(measure, contingency_matrix(est.est, x, y, z))
4649
end
4750

48-
function estimate(measure::CMIRenyiSarbu, est::Contingency{<:Nothing}, x...)
49-
return estimate(measure, contingency_matrix(x...))
51+
function estimate(measure::CMIRenyiSarbu, est::Contingency{<:Nothing}, x, y, z)
52+
return estimate(measure, contingency_matrix(x, y, z))
5053
end
5154

5255
function estimate(

0 commit comments

Comments
 (0)