Commit 9bdc082

Merge branch 'dev' into entity-embedder-docstring-improvements

2 parents: b2226a9 + dc42998

File tree

3 files changed: +33 −8 lines

Project.toml

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 name = "MLJFlux"
 uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845"
 authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>", "Ayush Shridhar <ayush.shridhar1999@gmail.com>"]
-version = "0.6.5"
+version = "0.6.6"

 [deps]
 CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"

src/mlj_embedder_interface.jl

Lines changed: 21 additions & 7 deletions

@@ -82,14 +82,27 @@ are still presented a target variable in training, but they behave as transformers in
 pipelines. They are entity embedding transformers, in the sense of the article, "Entity
 Embeddings of Categorical Variables" by Cheng Guo, Felix Berkhahn.

-The atomic `model` must be an instance of `MLJFlux.NeuralNetworkClassifier`,
-`MLJFlux.NeuralNetworkBinaryClassifier`, `MLJFlux.NeuralNetworkRegressor`, or
-`MLJFlux.MultitargetNeuralNetworkRegressor`. Hyperparameters of the atomic model, in
-particular `builder` and `embedding_dims`, will effect embedding performance.
+# Training data

-The wrapped model is bound to a machine and trained exactly as the wrapped supervised
-`model`, and supports the same form of training data. In particular, a training target
-must be supplied.
+In MLJ (or MLJBase) bind an instance unsupervised `model` to data with
+
+    mach = machine(embed_model, X, y)
+
+Here:
+
+- `embed_model` is an instance of `EntityEmbedder`, which wraps a supervised MLJFlux
+  model, `model`, which must be an instance of one of these:
+  `MLJFlux.NeuralNetworkClassifier`, `NeuralNetworkBinaryClassifier`,
+  `MLJFlux.NeuralNetworkRegressor`, `MLJFlux.MultitargetNeuralNetworkRegressor`.
+
+- `X` is any table of input features supported by the model being wrapped. Features to be
+  transformed must have element scitype `Multiclass` or `OrderedFactor`. Use `schema(X)`
+  to check scitypes.
+
+- `y` is the target, which can be any `AbstractVector` supported by the model being
+  wrapped.
+
+Train the machine using `fit!(mach)`.

 # Examples

@@ -107,6 +120,7 @@ X = (
     b = categorical(rand("abcde", N)),
     c = categorical(rand("ABCDEFGHIJ", N), ordered = true),
 )
+
 y = categorical(rand("YN", N));

 # Initiate model
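Taken together, the new "Training data" section of the docstring describes a standard MLJ machine workflow. The following is a minimal end-to-end sketch of that workflow, assuming the MLJ, MLJFlux, and CategoricalArrays packages are installed; the data mirror the docstring's own example, and `epochs = 5` and the positional `EntityEmbedder(clf)` construction are illustrative assumptions, not taken from this commit:

```julia
# Hedged sketch of the workflow described in the updated docstring.
# Assumes MLJ, MLJFlux, and CategoricalArrays are available.
using MLJ, MLJFlux, CategoricalArrays

N = 200
X = (
    a = rand(Float32, N),                                    # Continuous
    b = categorical(rand("abcde", N)),                       # Multiclass
    c = categorical(rand("ABCDEFGHIJ", N), ordered = true),  # OrderedFactor
)
y = categorical(rand("YN", N))

# Wrap a supported supervised MLJFlux model in an EntityEmbedder
# (constructor form assumed; `epochs` value is illustrative):
clf = MLJFlux.NeuralNetworkClassifier(epochs = 5)
embed_model = MLJFlux.EntityEmbedder(clf)

# Bind to data and train, exactly as for the wrapped supervised model:
mach = machine(embed_model, X, y)
fit!(mach)

# Categorical features are replaced by learned continuous embeddings:
Xembedded = transform(mach, X)
```

As the docstring notes, `schema(X)` can be used beforehand to confirm which columns carry the `Multiclass` or `OrderedFactor` scitypes eligible for embedding.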

test/encoders.jl

Lines changed: 11 additions & 0 deletions

@@ -7,6 +7,7 @@
     Column3 = categorical(["b", "c", "d"]),
     Column4 = [1.0, 2.0, 3.0, 4.0, 5.0],
 )
+# Test Encoding Functionality
 map = MLJFlux.ordinal_encoder_fit(X; featinds = [2, 3])
 Xenc = MLJFlux.ordinal_encoder_transform(X, map)
 @test map[2] == Dict('a' => 1, 'b' => 2, 'c' => 3, 'd' => 4, 'e' => 5)

@@ -21,6 +22,16 @@
 @test !haskey(map, 1) # already encoded

 @test Xenc == MLJFlux.ordinal_encoder_fit_transform(X; featinds = [2, 3])[1]
+
+# Test Consistency with Types
+scs = schema(Xenc).scitypes
+ts = schema(Xenc).types
+
+# 1) all scitypes must be exactly Continuous
+@test all(scs .== Continuous)
+
+# 2) all types must be a concrete subtype of AbstractFloat (i.e. <: AbstractFloat, but ≠ AbstractFloat itself)
+@test all(t -> t <: AbstractFloat && isconcretetype(t), ts)
 end

 @testset "Generate New feature names Function Tests" begin
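The assertions added in this test encode a general invariant of the ordinal encoder: after encoding, every column has scitype `Continuous` and a concrete `AbstractFloat` eltype. A standalone sketch of the same checks follows, assuming MLJFlux, CategoricalArrays, and ScientificTypes (which provides `schema` and the `Continuous` scitype) are installed; the two-column table is hypothetical, not the one from the test suite:

```julia
# Hedged sketch of the type-consistency checks added in this commit.
using MLJFlux, CategoricalArrays
using ScientificTypes  # provides `schema` and the `Continuous` scitype

# Hypothetical table: one already-continuous column, one categorical column.
X = (
    Column1 = [1.0, 2.0, 3.0],
    Column2 = categorical(["a", "b", "a"]),
)

# Encode the categorical column (index 2), as in the test above:
Xenc = MLJFlux.ordinal_encoder_fit_transform(X; featinds = [2])[1]

# 1) every column scitype is exactly Continuous ...
@assert all(schema(Xenc).scitypes .== Continuous)

# 2) ... and every eltype is a concrete subtype of AbstractFloat
#    (e.g. Float32 or Float64, never the abstract type itself).
@assert all(t -> t <: AbstractFloat && isconcretetype(t), schema(Xenc).types)
```

Checking `isconcretetype` as well as the subtype relation rules out columns whose eltype widened to the abstract `AbstractFloat` during encoding, which would defeat type-stable downstream training.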

0 commit comments