1
1
from __future__ import print_function
2
2
3
+ import textwrap
4
+
3
5
import numpy as np
4
6
5
7
import scipy .sparse as sp
@@ -197,37 +199,29 @@ def _to_cython_dtype(self, mat):
197
199
else :
198
200
return mat
199
201
200
- def fit (self , interactions , user_features = None , item_features = None ,
202
+ def fit (self , interactions ,
203
+ user_features = None , item_features = None ,
204
+ sample_weight = None ,
201
205
epochs = 1 , num_threads = 1 , verbose = False ):
202
-
203
- # Discard old results, if any
204
- self ._reset_state ()
205
-
206
- return self .fit_partial (interactions ,
207
- user_features = user_features ,
208
- item_features = item_features ,
209
- epochs = epochs ,
210
- num_threads = num_threads ,
211
- verbose = verbose )
212
-
213
- def fit_partial (self , interactions , user_features = None , item_features = None ,
214
- epochs = 1 , num_threads = 1 , verbose = False ):
215
206
"""
216
- Fit the model. Repeated calls to this function will resume training from
217
- the point where the last call finished.
207
+ Fit the model.
218
208
219
209
Arguments:
220
210
- coo_matrix interactions: matrix of shape [n_users, n_items] containing
221
- user-item interactions. Will be converted to
211
+ user-item interactions. Will be converted to
222
212
numpy.float32 dtype if it is not of that type
223
- (this conversion may be heavy depending upon
213
+ (this conversion may be heavy depending upon
224
214
matrix size)
225
215
- csr_matrix user_features: array of shape [n_users, n_user_features].
226
216
Each row contains that user's weights
227
217
over features.
228
218
- csr_matrix item_features: array of shape [n_items, n_item_features].
229
219
Each row contains that item's weights
230
220
over features.
221
+ - np.float32 array user_weights: array of shape [n_interactions,] with
222
+ weights applied to individual interactions.
223
+ Defaults to weight 1.0 for all interactions.
224
+ Not implemented for the k-OS loss.
231
225
232
226
- int epochs: number of epochs to run. Default: 1
233
227
- int num_threads: number of parallel computation threads to use. Should
@@ -236,6 +230,22 @@ def fit_partial(self, interactions, user_features=None, item_features=None,
236
230
- bool verbose: whether to print progress messages.
237
231
"""
238
232
233
+ # Discard old results, if any
234
+ self ._reset_state ()
235
+
236
+ return self .fit_partial (interactions ,
237
+ user_features = user_features ,
238
+ item_features = item_features ,
239
+ sample_weight = sample_weight ,
240
+ epochs = epochs ,
241
+ num_threads = num_threads ,
242
+ verbose = verbose )
243
+
244
+ def fit_partial (self , interactions ,
245
+ user_features = None , item_features = None ,
246
+ sample_weight = None ,
247
+ epochs = 1 , num_threads = 1 , verbose = False ):
248
+
239
249
# We need this in the COO format.
240
250
# If that's already true, this is a no-op.
241
251
interactions = interactions .tocoo ()
@@ -247,9 +257,17 @@ def fit_partial(self, interactions, user_features=None, item_features=None,
247
257
user_features ,
248
258
item_features )
249
259
260
+ if self .loss == 'warp-kos' and sample_weight is not None :
261
+ raise NotImplementedError ('k-OS loss with sample weights '
262
+ 'not implemented.' )
263
+
250
264
interactions = self ._to_cython_dtype (interactions )
251
265
user_features = self ._to_cython_dtype (user_features )
252
266
item_features = self ._to_cython_dtype (item_features )
267
+ sample_weight = (self ._to_cython_dtype (sample_weight )
268
+ if sample_weight is not None else
269
+ np .ones (interactions .getnnz (),
270
+ dtype = CYTHON_DTYPE ))
253
271
254
272
if self .item_embeddings is None :
255
273
# Initialise latent factors only if this is the first call
@@ -261,10 +279,17 @@ def fit_partial(self, interactions, user_features=None, item_features=None,
261
279
# Check that the dimensionality of the feature matrices has
262
280
# not changed between runs.
263
281
if not item_features .shape [1 ] == self .item_embeddings .shape [0 ]:
264
- raise Exception ('Incorrect number of features in item_features' )
282
+ raise ValueError ('Incorrect number of features in item_features' )
265
283
266
284
if not user_features .shape [1 ] == self .user_embeddings .shape [0 ]:
267
- raise Exception ('Incorrect number of features in user_features' )
285
+ raise ValueError ('Incorrect number of features in user_features' )
286
+
287
+ if sample_weight .ndim != 1 :
288
+ raise ValueError ('Sample weights must be 1-dimensional' )
289
+
290
+ if sample_weight .shape [0 ] != interactions .getnnz ():
291
+ raise ValueError ('Number of sample weights incompatible '
292
+ 'with number of interactions' )
268
293
269
294
for epoch in range (epochs ):
270
295
@@ -274,12 +299,21 @@ def fit_partial(self, interactions, user_features=None, item_features=None,
274
299
self ._run_epoch (item_features ,
275
300
user_features ,
276
301
interactions ,
302
+ sample_weight ,
277
303
num_threads ,
278
304
self .loss )
279
305
280
306
return self
281
307
282
- def _run_epoch (self , item_features , user_features , interactions , num_threads , loss ):
308
+ fit_partial .__doc__ = (fit .__doc__ +
309
+ textwrap .dedent ("""
310
+
311
+ Unlike fit, repeated calls to this method will cause trainig to resume
312
+ from the current model state.
313
+ """ ))
314
+
315
+ def _run_epoch (self , item_features , user_features , interactions ,
316
+ sample_weight , num_threads , loss ):
283
317
"""
284
318
Run an individual epoch.
285
319
"""
@@ -318,6 +352,7 @@ def _run_epoch(self, item_features, user_features, interactions, num_threads, lo
318
352
interactions .row ,
319
353
interactions .col ,
320
354
interactions .data ,
355
+ sample_weight ,
321
356
shuffle_indices ,
322
357
lightfm_data ,
323
358
self .learning_rate ,
@@ -331,6 +366,7 @@ def _run_epoch(self, item_features, user_features, interactions, num_threads, lo
331
366
interactions .row ,
332
367
interactions .col ,
333
368
interactions .data ,
369
+ sample_weight ,
334
370
shuffle_indices ,
335
371
lightfm_data ,
336
372
self .learning_rate ,
@@ -356,6 +392,7 @@ def _run_epoch(self, item_features, user_features, interactions, num_threads, lo
356
392
interactions .row ,
357
393
interactions .col ,
358
394
interactions .data ,
395
+ sample_weight ,
359
396
shuffle_indices ,
360
397
lightfm_data ,
361
398
self .learning_rate ,
0 commit comments