Skip to content

Commit 83492dc

Browse files
author
maciej
committed
Add explicit sample weights.
1 parent fbd8417 commit 83492dc

File tree

6 files changed

+726
-533
lines changed

6 files changed

+726
-533
lines changed

changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
### Added
88
- max_sampled hyperparameter for WARP losses. This allows trading off accuracy for WARP training time: a smaller value
99
will mean less negative sampling and faster training when the model is near the optimum.
10+
- Added a sample_weight argument to fit and fit_partial functions. A high value will now increase the size of the SGD step taken for that interaction.
1011

1112
## [1.8][2016-01-14]
1213
### Changed

lightfm/lightfm.py

Lines changed: 58 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import print_function
22

3+
import textwrap
4+
35
import numpy as np
46

57
import scipy.sparse as sp
@@ -197,37 +199,29 @@ def _to_cython_dtype(self, mat):
197199
else:
198200
return mat
199201

200-
def fit(self, interactions, user_features=None, item_features=None,
202+
def fit(self, interactions,
203+
user_features=None, item_features=None,
204+
sample_weight=None,
201205
epochs=1, num_threads=1, verbose=False):
202-
203-
# Discard old results, if any
204-
self._reset_state()
205-
206-
return self.fit_partial(interactions,
207-
user_features=user_features,
208-
item_features=item_features,
209-
epochs=epochs,
210-
num_threads=num_threads,
211-
verbose=verbose)
212-
213-
def fit_partial(self, interactions, user_features=None, item_features=None,
214-
epochs=1, num_threads=1, verbose=False):
215206
"""
216-
Fit the model. Repeated calls to this function will resume training from
217-
the point where the last call finished.
207+
Fit the model.
218208
219209
Arguments:
220210
- coo_matrix interactions: matrix of shape [n_users, n_items] containing
221-
user-item interactions. Will be converted to
211+
user-item interactions. Will be converted to
222212
numpy.float32 dtype if it is not of that type
223-
(this conversion may be heavy depending upon
213+
(this conversion may be heavy depending upon
224214
matrix size)
225215
- csr_matrix user_features: array of shape [n_users, n_user_features].
226216
Each row contains that user's weights
227217
over features.
228218
- csr_matrix item_features: array of shape [n_items, n_item_features].
229219
Each row contains that item's weights
230220
over features.
221+
- np.float32 array sample_weight: array of shape [n_interactions,] with
222+
weights applied to individual interactions.
223+
Defaults to weight 1.0 for all interactions.
224+
Not implemented for the k-OS loss.
231225
232226
- int epochs: number of epochs to run. Default: 1
233227
- int num_threads: number of parallel computation threads to use. Should
@@ -236,6 +230,22 @@ def fit_partial(self, interactions, user_features=None, item_features=None,
236230
- bool verbose: whether to print progress messages.
237231
"""
238232

233+
# Discard old results, if any
234+
self._reset_state()
235+
236+
return self.fit_partial(interactions,
237+
user_features=user_features,
238+
item_features=item_features,
239+
sample_weight=sample_weight,
240+
epochs=epochs,
241+
num_threads=num_threads,
242+
verbose=verbose)
243+
244+
def fit_partial(self, interactions,
245+
user_features=None, item_features=None,
246+
sample_weight=None,
247+
epochs=1, num_threads=1, verbose=False):
248+
239249
# We need this in the COO format.
240250
# If that's already true, this is a no-op.
241251
interactions = interactions.tocoo()
@@ -247,9 +257,17 @@ def fit_partial(self, interactions, user_features=None, item_features=None,
247257
user_features,
248258
item_features)
249259

260+
if self.loss == 'warp-kos' and sample_weight is not None:
261+
raise NotImplementedError('k-OS loss with sample weights '
262+
'not implemented.')
263+
250264
interactions = self._to_cython_dtype(interactions)
251265
user_features = self._to_cython_dtype(user_features)
252266
item_features = self._to_cython_dtype(item_features)
267+
sample_weight = (self._to_cython_dtype(sample_weight)
268+
if sample_weight is not None else
269+
np.ones(interactions.getnnz(),
270+
dtype=CYTHON_DTYPE))
253271

254272
if self.item_embeddings is None:
255273
# Initialise latent factors only if this is the first call
@@ -261,10 +279,17 @@ def fit_partial(self, interactions, user_features=None, item_features=None,
261279
# Check that the dimensionality of the feature matrices has
262280
# not changed between runs.
263281
if not item_features.shape[1] == self.item_embeddings.shape[0]:
264-
raise Exception('Incorrect number of features in item_features')
282+
raise ValueError('Incorrect number of features in item_features')
265283

266284
if not user_features.shape[1] == self.user_embeddings.shape[0]:
267-
raise Exception('Incorrect number of features in user_features')
285+
raise ValueError('Incorrect number of features in user_features')
286+
287+
if sample_weight.ndim != 1:
288+
raise ValueError('Sample weights must be 1-dimensional')
289+
290+
if sample_weight.shape[0] != interactions.getnnz():
291+
raise ValueError('Number of sample weights incompatible '
292+
'with number of interactions')
268293

269294
for epoch in range(epochs):
270295

@@ -274,12 +299,21 @@ def fit_partial(self, interactions, user_features=None, item_features=None,
274299
self._run_epoch(item_features,
275300
user_features,
276301
interactions,
302+
sample_weight,
277303
num_threads,
278304
self.loss)
279305

280306
return self
281307

282-
def _run_epoch(self, item_features, user_features, interactions, num_threads, loss):
308+
fit_partial.__doc__ = (fit.__doc__ +
309+
textwrap.dedent("""
310+
311+
Unlike fit, repeated calls to this method will cause training to resume
312+
from the current model state.
313+
"""))
314+
315+
def _run_epoch(self, item_features, user_features, interactions,
316+
sample_weight, num_threads, loss):
283317
"""
284318
Run an individual epoch.
285319
"""
@@ -318,6 +352,7 @@ def _run_epoch(self, item_features, user_features, interactions, num_threads, lo
318352
interactions.row,
319353
interactions.col,
320354
interactions.data,
355+
sample_weight,
321356
shuffle_indices,
322357
lightfm_data,
323358
self.learning_rate,
@@ -331,6 +366,7 @@ def _run_epoch(self, item_features, user_features, interactions, num_threads, lo
331366
interactions.row,
332367
interactions.col,
333368
interactions.data,
369+
sample_weight,
334370
shuffle_indices,
335371
lightfm_data,
336372
self.learning_rate,
@@ -356,6 +392,7 @@ def _run_epoch(self, item_features, user_features, interactions, num_threads, lo
356392
interactions.row,
357393
interactions.col,
358394
interactions.data,
395+
sample_weight,
359396
shuffle_indices,
360397
lightfm_data,
361398
self.learning_rate,

0 commit comments

Comments
 (0)