Skip to content

Commit 99e867d

Browse files
committed
ModalAdaBoost model
1 parent b3b015d commit 99e867d

File tree

7 files changed

+196
-3
lines changed

7 files changed

+196
-3
lines changed

src/ModalDecisionTrees.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ include("posthoc.jl")
121121
# Apply decision tree/forest to a dataset
122122
include("apply.jl")
123123

124-
export ModalDecisionTree, ModalRandomForest
124+
export ModalDecisionTree, ModalRandomForest, ModalAdaBoost
125125
export depth
126126

127127
export wrapdataset

src/build.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,54 @@ a random forest model on logiset `X` with labels `Y` and weights `W`.
3131
3232
"""
3333

34+
# """$(doc_build)"""
35+
function build_stumps(
36+
X :: MultiLogiset,
37+
y :: AbstractVector{L},
38+
weigths :: Union{Nothing,AbstractVector{U},Symbol} = nothing;
39+
n_iter :: Int = 10;
40+
# rng :: Random.AbstractRNG = Random.GLOBAL_RNG,
41+
kwargs...,
42+
) where {L<:Label,U}
43+
n_y = length(y)
44+
n_labels = length(unique(y))
45+
base_coeff = log(n_labels - 1)
46+
thresh = 1 - 1 / n_labels
47+
weights = ones(n_y) / n_y
48+
stumps = DTree[]
49+
coeffs = Float64[]
50+
# n_features = size(X, 2)
51+
52+
for i in 1:n_iter
53+
new_stump = build_stump(X, y, weigths; impurity_importance=false, kwargs...)
54+
# new_stump = MDT.build_stump( # TODO c'è anche in MDT!!!
55+
# X, y, weights; rng=DT.mk_rng(rng), impurity_importance=false
56+
# )
57+
# predictions = MDT.apply_tree(new_stump, X) # TODO c'è anche in MDT!!!
58+
# err = DT._weighted_error(y, predictions, weights)
59+
# if err >= thresh # should be better than random guess
60+
# continue
61+
# end
62+
# # SAMME algorithm
63+
# new_coeff = log((1.0 - err) / err) + base_coeff
64+
# unmatches = labels .!= predictions
65+
# weights[unmatches] *= exp(new_coeff)
66+
# weights /= sum(weights)
67+
# push!(coeffs, new_coeff)
68+
# push!(stumps, new_stump.node)
69+
# if err < 1e-6
70+
# break
71+
# end
72+
end
73+
# return (DT.Ensemble{S,T}(stumps, n_features, Float64[]), coeffs)
74+
75+
stumps = DTree[]
76+
for i in 1:n_iter
77+
push!(stump_trees, build_stump(X, y, weigths; kwargs...))
78+
end
79+
return stump_trees
80+
end
81+
3482
"""$(doc_build)"""
3583
function build_stump(
3684
X :: MultiLogiset,

src/interfaces/MLJ.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
module MLJInterface
44

5-
export ModalDecisionTree, ModalRandomForest
5+
export ModalDecisionTree, ModalRandomForest, ModalAdaBoost
66
export depth
77
export wrapdataset
88

@@ -37,12 +37,14 @@ include("MLJ/feature-importance.jl")
3737

3838
include("MLJ/ModalDecisionTree.jl")
3939
include("MLJ/ModalRandomForest.jl")
40+
include("MLJ/ModalAdaBoost.jl")
4041

4142
include("MLJ/docstrings.jl")
4243

4344
const SymbolicModel = Union{
4445
ModalDecisionTree,
4546
ModalRandomForest,
47+
ModalAdaBoost,
4648
}
4749

4850
const TreeModel = Union{
@@ -53,6 +55,10 @@ const ForestModel = Union{
5355
ModalRandomForest,
5456
}
5557

58+
const StumpsModel = Union{
59+
ModalAdaBoost,
60+
}
61+
5662
include("MLJ/downsize.jl")
5763
include("MLJ/clean.jl")
5864

@@ -77,6 +83,8 @@ function MMI.fit(m::SymbolicModel, verbosity::Integer, X, y, var_grouping, class
7783
MDT.build_tree(X, y, w; get_kwargs(m, X)...)
7884
elseif m isa ModalRandomForest
7985
MDT.build_forest(X, y, w; get_kwargs(m, X)...)
86+
elseif m isa ModalAdaBoost
87+
MDT.build_stumps(X, y, w; get_kwargs(m, X)...)
8088
else
8189
error("Unexpected model type: $(typeof(m))")
8290
end
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
mutable struct ModalAdaBoost <: MMI.Probabilistic
2+
## Pruning conditions
3+
max_depth ::Union{Nothing,Int}
4+
min_samples_leaf ::Union{Nothing,Int}
5+
min_purity_increase ::Union{Nothing,Float64}
6+
max_purity_at_leaf ::Union{Nothing,Float64}
7+
max_modal_depth ::Union{Nothing,Int}
8+
9+
## Logic parameters
10+
11+
# Relation set
12+
relations ::Union{
13+
Nothing, # defaults to a well-known relation set, depending on the data;
14+
Symbol, # one of the relation sets specified in AVAILABLE_RELATIONS;
15+
Vector{<:AbstractRelation}, # explicitly specify the relation set;
16+
# Vector{<:Union{Symbol,Vector{<:AbstractRelation}}}, # MULTIMODAL CASE: specify a relation set for each modality;
17+
Function # A function worldtype -> relation set.
18+
}
19+
20+
# Condition set
21+
features ::Union{
22+
Nothing, # defaults to scalar conditions (with ≥ and <) on well-known feature functions (e.g., minimum, maximum), applied to all variables;
23+
Vector{<:Union{SoleData.VarFeature,Base.Callable}}, # scalar conditions with ≥ and <, on an explicitly specified feature set (callables to be applied to each variable, or VarFeature objects);
24+
Vector{<:Tuple{Base.Callable,Integer}}, # scalar conditions with ≥ and <, on a set of features specified as a set of callables to be applied to a set of variables each;
25+
Vector{<:Tuple{TestOperator,<:Union{SoleData.VarFeature,Base.Callable}}}, # explicitly specify the pairs (test operator, feature);
26+
Vector{<:SoleData.ScalarMetaCondition}, # explicitly specify the scalar condition set.
27+
}
28+
conditions ::Union{
29+
Nothing, # defaults to scalar conditions (with ≥ and <) on well-known feature functions (e.g., minimum, maximum), applied to all variables;
30+
Vector{<:Union{SoleData.VarFeature,Base.Callable}}, # scalar conditions with ≥ and <, on an explicitly specified feature set (callables to be applied to each variable, or VarFeature objects);
31+
Vector{<:Tuple{Base.Callable,Integer}}, # scalar conditions with ≥ and <, on a set of features specified as a set of callables to be applied to a set of variables each;
32+
Vector{<:Tuple{TestOperator,<:Union{SoleData.VarFeature,Base.Callable}}}, # explicitly specify the pairs (test operator, feature);
33+
Vector{<:SoleData.ScalarMetaCondition}, # explicitly specify the scalar condition set.
34+
}
35+
# Type for the extracted feature values
36+
featvaltype ::Type
37+
38+
# Initial conditions
39+
initconditions ::Union{
40+
Nothing, # defaults to standard conditions (e.g., start_without_world)
41+
Symbol, # one of the initial conditions specified in AVAILABLE_INITIALCONDITIONS;
42+
InitialCondition, # explicitly specify an initial condition for the learning algorithm.
43+
}
44+
45+
## Miscellaneous
46+
downsize ::Union{Bool,NTuple{N,Integer} where N,Function}
47+
force_i_variables ::Bool
48+
fixcallablenans ::Bool
49+
print_progress ::Bool
50+
rng ::Union{Random.AbstractRNG,Integer}
51+
52+
## DecisionTree.jl parameters
53+
display_depth ::Union{Nothing,Int}
54+
min_samples_split ::Union{Nothing,Int}
55+
n_subfeatures ::Union{Nothing,Int,Float64,Function}
56+
post_prune ::Bool
57+
merge_purity_threshold ::Union{Nothing,Float64}
58+
feature_importance ::Symbol
59+
60+
## AdaBoost parameters
61+
n_iter ::Int
62+
end
63+
64+
# keyword constructor
65+
function ModalAdaBoost(;
66+
max_depth = 1,
67+
min_samples_leaf = nothing,
68+
min_purity_increase = nothing,
69+
max_purity_at_leaf = nothing,
70+
max_modal_depth = nothing,
71+
#
72+
relations = nothing,
73+
features = nothing,
74+
conditions = nothing,
75+
featvaltype = Float64,
76+
initconditions = nothing,
77+
#
78+
downsize = true,
79+
force_i_variables = true,
80+
fixcallablenans = false,
81+
print_progress = false,
82+
rng = Random.GLOBAL_RNG,
83+
#
84+
display_depth = nothing,
85+
min_samples_split = nothing,
86+
n_subfeatures = nothing,
87+
post_prune = false,
88+
merge_purity_threshold = nothing,
89+
feature_importance = :split,
90+
#
91+
n_iter = 10,
92+
)
93+
model = ModalAdaBoost(
94+
max_depth,
95+
min_samples_leaf,
96+
min_purity_increase,
97+
max_purity_at_leaf,
98+
max_modal_depth,
99+
#
100+
relations,
101+
features,
102+
conditions,
103+
featvaltype,
104+
initconditions,
105+
#
106+
downsize,
107+
force_i_variables,
108+
fixcallablenans,
109+
print_progress,
110+
rng,
111+
#
112+
display_depth,
113+
min_samples_split,
114+
n_subfeatures,
115+
post_prune,
116+
merge_purity_threshold,
117+
feature_importance,
118+
n_iter,
119+
)
120+
message = MMI.clean!(model)
121+
isempty(message) || @warn message
122+
return model
123+
end

src/interfaces/MLJ/ModalDecisionTree.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ function ModalDecisionTree(;
116116
merge_purity_threshold,
117117
feature_importance,
118118
)
119+
@show model isa SymbolicModel
119120
message = MMI.clean!(model)
120121
isempty(message) || @warn message
121122
return model

src/interfaces/MLJ/clean.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ function get_kwargs(m::SymbolicModel, X)
2929
ntrees = m.ntrees,
3030
suppress_parity_warning = true,
3131
)
32+
elseif m isa StumpsModel
33+
(;
34+
n_iter = m.n_iter,
35+
)
3236
else
3337
error("Unexpected model type: $(typeof(m))")
3438
end
@@ -38,7 +42,6 @@ end
3842

3943
function MMI.clean!(m::SymbolicModel)
4044
warning = ""
41-
4245
if m isa TreeModel
4346
mlj_default_min_samples_leaf = mlj_mdt_default_min_samples_leaf
4447
mlj_default_min_purity_increase = mlj_mdt_default_min_purity_increase
@@ -51,6 +54,12 @@ function MMI.clean!(m::SymbolicModel)
5154
mlj_default_n_subfeatures = mlj_mrf_default_n_subfeatures
5255
mlj_default_ntrees = mlj_mrf_default_ntrees
5356
mlj_default_sampling_fraction = mlj_mrf_default_sampling_fraction
57+
elseif m isa StumpsModel
58+
mlj_default_min_samples_leaf = mlj_mdt_default_min_samples_leaf
59+
mlj_default_min_purity_increase = mlj_mdt_default_min_purity_increase
60+
mlj_default_max_purity_at_leaf = mlj_mdt_default_max_purity_at_leaf
61+
mlj_default_n_subfeatures = mlj_mdt_default_n_subfeatures
62+
# TODO mlj_default_n_iter
5463
else
5564
error("Unexpected model type: $(typeof(m))")
5665
end

src/interfaces/MLJ/downsize.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,7 @@ end
88
function make_downsizing_function(::ForestModel)
99
make_downsizing_function(Val(2))
1010
end
11+
12+
function make_downsizing_function(::StumpsModel)
13+
make_downsizing_function(Val(1))
14+
end

0 commit comments

Comments
 (0)