Commit 91799d8

Merge pull request #306 from FluxML/improve-docs
✨ Improve EE Docs

2 parents d6b00a2 + 02428f0

File tree: 3 files changed (+33 −9 lines)


docs/make.jl

Lines changed: 1 addition & 0 deletions

```diff
@@ -33,6 +33,7 @@ makedocs(
     "Regression"=>"interface/Regression.md",
     "Multi-Target Regression"=>"interface/Multitarget Regression.md",
     "Image Classification"=>"interface/Image Classification.md",
+    "Entity Embeddings"=>"interface/Entity Embeddings.md",
 ],
 "Common Workflows" => Any[
     "Incremental Training"=>"common_workflows/incremental_training/notebook.md",
```
Lines changed: 3 additions & 0 deletions

````diff
@@ -0,0 +1,3 @@
+```@docs
+MLJFlux.EntityEmbedder
+```
````

src/mlj_embedder_interface.jl

Lines changed: 29 additions & 9 deletions

```diff
@@ -85,10 +85,13 @@ MMI.training_losses(embedder::EntityEmbedder, report) =
 
 In MLJ (or MLJBase) bind an instance unsupervised `model` to data with
 
-    mach = machine(model, X, y)
+    mach = machine(embed_model, X, y)
 
 Here:
 
+- `embed_model` is an instance of `EntityEmbedder`, which wraps a supervised MLJFlux model.
+  The supervised model must be one of these: `MLJFlux.NeuralNetworkClassifier`, `NeuralNetworkBinaryClassifier`,
+  `MLJFlux.NeuralNetworkRegressor`, `MLJFlux.MultitargetNeuralNetworkRegressor`.
 
 - `X` is any table of input features supported by the model being wrapped. Features to be transformed must
   have element scitype `Multiclass` or `OrderedFactor`. Use `schema(X)` to
```
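The binding described in the hunk above can be sketched end-to-end. This is a hedged illustration only, not part of the commit: the table `X`, target `y`, and column names are synthetic stand-ins, and it assumes MLJ and MLJFlux are installed.

```julia
using MLJ, CategoricalArrays

# Load the embedder and one of the supported supervised MLJFlux models
EntityEmbedder = @load EntityEmbedder pkg=MLJFlux
NeuralNetworkClassifier = @load NeuralNetworkClassifier pkg=MLJFlux

# Wrap the supervised model in the embedder
embed_model = EntityEmbedder(NeuralNetworkClassifier())

# Synthetic table: features to be embedded must have scitype
# Multiclass or OrderedFactor (check with schema(X))
X = (;
    height = Float32[1.0, 2.0, 3.0, 4.0],
    color  = categorical(["red", "green", "blue", "red"]),
)
y = categorical(["a", "b", "a", "b"])

# Bind the wrapped model to data, as in the docstring
mach = machine(embed_model, X, y)
fit!(mach)
```

After fitting, `transform(mach, X)` returns the table with `color` replaced by continuous embedding columns.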
````diff
@@ -129,25 +132,42 @@ X = (;
         repeat(["group1", "group1", "group2", "group2", "group3"], Int(N / 5)),
     ),
 )
-y = categorical([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])  # Classification
+y = categorical(repeat(["class1", "class2", "class3", "class4", "class5"], Int(N / 5)))
 
-# Initiate model
-EntityEmbedder = @load EntityEmbedder pkg=MLJFlux
+# Load the entity embedder, its neural network backbone and the SVC which inherently supports
+# only continuous features
+EntityEmbedder = @load EntityEmbedder pkg=MLJFlux
 NeuralNetworkClassifier = @load NeuralNetworkClassifier pkg=MLJFlux
+SVC = @load SVC pkg=LIBSVM
 
-clf = NeuralNetworkClassifier(embedding_dims=Dict(:Column2 => 2, :Column3 => 2))
 
-emb = EntityEmbedder(clf)
+emb = EntityEmbedder(NeuralNetworkClassifier(embedding_dims=Dict(:Column2 => 2, :Column3 => 2)))
+clf = SVC(cost = 1.0)
+
+pipeline = emb |> clf
 
 # Construct machine
-mach = machine(emb, X, y)
+mach = machine(pipeline, X, y)
 
 # Train model
 fit!(mach)
 
+# Predict
+yhat = predict(mach, X)
+
 # Transform data using model to encode categorical columns
-Xnew = transform(mach, X)
-Xnew
+machy = machine(emb, X, y)
+fit!(machy)
+julia> Xnew = transform(machy, X)
+(Column1 = Float32[1.0, 2.0, 3.0, …],
+ Column2_1 = Float32[1.2, 0.08, -0.09, -0.2, 0.94, 1.2, …],
+ Column2_2 = Float32[-0.87, -0.34, -0.8, 1.6, 0.75, -0.87, …],
+ Column3_1 = Float32[-0.0, 1.56, -0.48, -0.9, -0.9, -0.0, …],
+ Column3_2 = Float32[-1.0, 1.1, -1.54, 0.2, 0.2, -1.0, …],
+ Column4 = Float32[1.0, 2.0, 3.0, 4.0, 5.0, 1.0, …],
+ Column5 = Float32[0.27, 0.12, -0.60, 1.5, -0.6, -0.123, …],
+ Column6_1 = Float32[-0.99, -0.99, 0.8, 0.8, 0.34, -0.99, …],
+ Column6_2 = Float32[-1.00, -1.0, 0.19, 0.19, 1.7, -1.00, …])
 ```
 
 See also
````
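Assembled from the revised docstring example above, the new pipeline workflow can be run as one script. This is a hedged sketch, not the commit's code verbatim: the column values for `Column1`–`Column3` are synthetic stand-ins (the diff does not show the full definition of `X`), and it assumes MLJ, MLJFlux, and LIBSVM are installed.

```julia
using MLJ, CategoricalArrays

# Synthetic data in the shape the docstring example uses
N = 10
X = (;
    Column1 = Float32.(1:N),
    Column2 = categorical(repeat(["a", "b", "c", "d", "e"], Int(N / 5))),
    Column3 = categorical(repeat(["b", "c", "d", "f", "f"], Int(N / 5)), ordered = true),
)
y = categorical(repeat(["class1", "class2", "class3", "class4", "class5"], Int(N / 5)))

# Load the entity embedder, its neural network backbone, and an SVC,
# which natively accepts only continuous features
EntityEmbedder = @load EntityEmbedder pkg=MLJFlux
NeuralNetworkClassifier = @load NeuralNetworkClassifier pkg=MLJFlux
SVC = @load SVC pkg=LIBSVM

emb = EntityEmbedder(NeuralNetworkClassifier(embedding_dims = Dict(:Column2 => 2, :Column3 => 2)))
clf = SVC(cost = 1.0)
pipeline = emb |> clf

# Fit the embedder-then-SVC pipeline and predict
mach = machine(pipeline, X, y)
fit!(mach)
yhat = predict(mach, X)

# Use the embedder alone to encode the categorical columns
machy = machine(emb, X, y)
fit!(machy)
Xnew = transform(machy, X)  # Column2/Column3 replaced by Float32 embedding columns
```

The `|>` composition builds an MLJ linear pipeline, so the SVC only ever sees the continuous embedded features.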

0 commit comments