Commit e8dda66

✨ Improve EE Docs
1 parent 22c2e33 commit e8dda66

File tree: 1 file changed, +44 −9 lines


src/mlj_embedder_interface.jl

Lines changed: 44 additions & 9 deletions
@@ -85,10 +85,13 @@ MMI.training_losses(embedder::EntityEmbedder, report) =
 
 In MLJ (or MLJBase) bind an instance unsupervised `model` to data with
 
-    mach = machine(model, X, y)
+    mach = machine(embed_model, X, y)
 
 Here:
 
+- `embed_model` is an instance of `EntityEmbedder`, which wraps a supervised MLJFlux model.
+  The supervised model must be one of: `MLJFlux.NeuralNetworkClassifier`, `MLJFlux.NeuralNetworkBinaryClassifier`,
+  `MLJFlux.NeuralNetworkRegressor`, `MLJFlux.MultitargetNeuralNetworkRegressor`.
 
 - `X` is any table of input features supported by the model being wrapped. Features to be transformed must
   have element scitype `Multiclass` or `OrderedFactor`. Use `schema(X)` to
@@ -116,6 +119,7 @@ Train the machine using `fit!(mach)`.
 ```julia
 using MLJ
 using CategoricalArrays
+import Pkg; Pkg.add("MLJLIBSVMInterface")   # provides SVC
 
 # Setup some data
 N = 200
@@ -129,25 +133,56 @@ X = (;
         repeat(["group1", "group1", "group2", "group2", "group3"], Int(N / 5)),
     ),
 )
-y = categorical([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) # Classification
+y = categorical(repeat(["class1", "class2", "class3", "class4", "class5"], Int(N / 5)))
 
-# Initiate model
-EntityEmbedder = @load EntityEmbedder pkg=MLJFlux
+# Load the entity embedder, its neural network backbone, and the SVC, which
+# supports only continuous features
+EntityEmbedder = @load EntityEmbedder pkg=MLJFlux
 NeuralNetworkClassifier = @load NeuralNetworkClassifier pkg=MLJFlux
+SVC = @load SVC pkg=LIBSVM
 
-clf = NeuralNetworkClassifier(embedding_dims=Dict(:Column2 => 2, :Column3 => 2))
 
-emb = EntityEmbedder(clf)
+emb = EntityEmbedder(NeuralNetworkClassifier(embedding_dims=Dict(:Column2 => 2, :Column3 => 2)))
+clf = SVC(cost = 1.0)
+
+julia> pipeline = emb |> clf
+DeterministicPipeline(
+  entity_embedder = EntityEmbedder(
+        model = NeuralNetworkClassifier(builder = Short(n_hidden = 0, …), …)),
+  svc = SVC(
+        kernel = LIBSVM.Kernel.RadialBasis,
+        gamma = 0.0,
+        cost = 1.0,
+        cachesize = 200.0,
+        degree = 3,
+        coef0 = 0.0,
+        tolerance = 0.001,
+        shrinking = true),
+  cache = true)
 
 # Construct machine
-mach = machine(emb, X, y)
+mach = machine(pipeline, X, y)
 
 # Train model
 fit!(mach)
 
+# Predict
+yhat = predict(mach, X)
+
 # Transform data using model to encode categorical columns
-Xnew = transform(mach, X)
-Xnew
+machy = machine(emb, X, y)
+fit!(machy)
+julia> Xnew = transform(machy, X)
+(Column1 = Float32[1.0, 2.0, 3.0, 4.0, 5.0, 1.0, … ],
+ Column2_1 = Float32[1.285769, 0.08033762, -0.09961729, -0.2812789, 0.94185555, 1.285769, … ],
+ Column2_2 = Float32[-0.8712612, -0.34193662, -0.8327084, 1.6905315, 0.75170106, -0.8712612, … ],
+ Column3_1 = Float32[-0.00044717162, 1.5679433, -0.48835647, -0.9364795, -0.9364795, -0.00044717162, … ],
+ Column3_2 = Float32[-1.086054, 1.1133554, -1.5444189, 0.2760421, 0.2760421, -1.086054, … ],
+ Column4 = Float32[1.0, 2.0, 3.0, 4.0, 5.0, 1.0, … ],
+ Column5 = Float32[0.27364022, 0.12229505, -0.60269946, 1.5815768, -0.6342952, -0.12323896, … ],
+ Column6_1 = Float32[-0.99640805, -0.99640805, 0.8055623, 0.8055623, 0.34632754, -0.99640805, … ],
+ Column6_2 = Float32[-1.0043539, -1.0043539, 0.19345926, 0.19345926, 1.7287723, -1.0043539, … ])
 ```
 
 See also
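Not part of the commit, but a natural next step after building such a pipeline is to estimate its out-of-sample performance with MLJ's resampling machinery. A minimal sketch, assuming the packages from the example above are installed and that `pipeline`, `X`, and `y` are as defined in the docstring:

```julia
# Cross-validate the embed-then-SVC pipeline. `evaluate` fits the pipeline on
# each training fold (re-learning the entity embeddings each time) and measures
# accuracy on the corresponding test fold.
using MLJ

eval_result = evaluate(pipeline, X, y,
    resampling = CV(nfolds = 5, shuffle = true),  # 5-fold cross-validation
    measure = accuracy)                           # SVC predicts point labels

eval_result.measurement   # per-measure aggregate accuracy across folds
```

Because the embedder sits inside the pipeline, the embeddings are learned only from each training fold, so the estimate is free of the leakage that fitting the embedder on all of `X` up front would introduce.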
