@@ -82,14 +82,27 @@ are still presented a target variable in training, but they behave as transforme
pipelines. They are entity embedding transformers, in the sense of the article, "Entity
Embeddings of Categorical Variables" by Cheng Guo, Felix Berkhahn.

- The atomic `model` must be an instance of `MLJFlux.NeuralNetworkClassifier`,
- `MLJFlux.NeuralNetworkBinaryClassifier`, `MLJFlux.NeuralNetworkRegressor`, or
- `MLJFlux.MultitargetNeuralNetworkRegressor`. Hyperparameters of the atomic model, in
- particular `builder` and `embedding_dims`, will effect embedding performance.
+ # Training data
+
- The wrapped model is bound to a machine and trained exactly as the wrapped supervised
- `model`, and supports the same form of training data. In particular, a training target
- must be supplied.
+ In MLJ (or MLJBase) bind an instance `embed_model` to data with
+
+     mach = machine(embed_model, X, y)
+
+ Here:
+
+ - `embed_model` is an instance of `EntityEmbedder`, wrapping a supervised MLJFlux
+ model, `model`, which must be an instance of one of these:
+ `MLJFlux.NeuralNetworkClassifier`, `MLJFlux.NeuralNetworkBinaryClassifier`,
+ `MLJFlux.NeuralNetworkRegressor`, or `MLJFlux.MultitargetNeuralNetworkRegressor`.
+
+ - `X` is any table of input features supported by the model being wrapped. Features to be
+ transformed must have element scitype `Multiclass` or `OrderedFactor`. Use `schema(X)`
+ to check scitypes.
+
+ - `y` is the target, which can be any `AbstractVector` supported by the model being
+ wrapped.
+
+ Train the machine using `fit!(mach)`.

# Examples
@@ -107,6 +120,7 @@ X = (
b = categorical(rand("abcde", N)),
c = categorical(rand("ABCDEFGHIJ", N), ordered = true),
)
+
y = categorical(rand("YN", N));

# Initiate model
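The binding and training steps documented in this hunk can be illustrated end to end. Below is a minimal Julia sketch using the synthetic data from the docstring's own example; the `EntityEmbedder(clf)` positional constructor and the `epochs` value are assumptions for illustration, not part of the diff:

```julia
using MLJ, MLJFlux  # `categorical` is re-exported by MLJ from CategoricalArrays

# Synthetic data mirroring the docstring's example
N = 200
X = (
    a = rand(Float32, N),
    b = categorical(rand("abcde", N)),
    c = categorical(rand("ABCDEFGHIJ", N), ordered = true),
)
y = categorical(rand("YN", N));

# Wrap the atomic supervised model; hyperparameters such as `builder` and
# `embedding_dims` influence the learned embeddings
clf = MLJFlux.NeuralNetworkClassifier(epochs = 5)
embed_model = EntityEmbedder(clf)  # assumed positional-argument constructor

# Bind to data and train, as described under "# Training data"
mach = machine(embed_model, X, y)
fit!(mach)

# Features with `Multiclass`/`OrderedFactor` scitype (here `b` and `c`) are
# replaced by continuous embedding columns
Xout = transform(mach, X)
```

Note that a target `y` is supplied even though the wrapper behaves as a transformer: the embeddings are learned by training the wrapped supervised model.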