Skip to content

Commit 054982d

Browse files
committed
Add 'zipf' option to specify the exponent of the zipf distribution
1 parent 6babb24 commit 054982d

File tree

8 files changed

+49
-6
lines changed

8 files changed

+49
-6
lines changed

finalfrontier-utils/src/bin/ff-train.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ static MAXN: &str = "maxn";
9292
static MODEL: &str = "model";
9393
static NS: &str = "ns";
9494
static THREADS: &str = "threads";
95+
static ZIPF_EXPONENT: &str = "zipf";
9596

9697
// Argument constants
9798
static CORPUS: &str = "CORPUS";
@@ -145,6 +146,13 @@ fn config_from_matches<'a>(matches: &ArgMatches<'a>) -> Config {
145146
.or_exit("Cannot parse number of negative samples", 1)
146147
})
147148
.unwrap_or(5);
149+
let zipf_exponent = matches
150+
.value_of(ZIPF_EXPONENT)
151+
.map(|v| {
152+
v.parse()
153+
.or_exit("Cannot parse exponent zipf distribution", 1)
154+
})
155+
.unwrap_or(0.5);
148156

149157
Config {
150158
context_size,
@@ -159,6 +167,7 @@ fn config_from_matches<'a>(matches: &ArgMatches<'a>) -> Config {
159167
buckets_exp,
160168
negative_samples,
161169
lr,
170+
zipf_exponent,
162171
}
163172
}
164173

@@ -249,6 +258,13 @@ fn parse_args() -> ArgMatches<'static> {
249258
.help("Number of threads (default: logical_cpus / 2)")
250259
.takes_value(true),
251260
)
261+
.arg(
262+
Arg::with_name(ZIPF_EXPONENT)
263+
.long("zipf")
264+
.value_name("EXP")
265+
.help("Exponent Zipf distribution for negative sampling (default: 0.5)")
266+
.takes_value(true),
267+
)
252268
.arg(
253269
Arg::with_name(CORPUS)
254270
.help("Tokenized corpus")

finalfrontier/src/config.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,4 +96,9 @@ pub struct Config {
9696

9797
/// The initial learning rate.
9898
pub lr: f32,
99+
100+
/// Exponent in zipfian distribution.
101+
///
102+
/// This is s in *f(k) = 1 / (k^s H_{N, s})*.
103+
pub zipf_exponent: f64,
99104
}

finalfrontier/src/model.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
//! Embedding prediction model.
22
3+
use std::f64;
34
use std::fs::File;
45
use std::io::{BufReader, Read, Seek, SeekFrom, Write};
56
use std::iter::Enumerate;
@@ -223,6 +224,7 @@ where
223224
buckets_exp,
224225
negative_samples,
225226
lr,
227+
zipf_exponent: f64::NAN,
226228
})
227229
}
228230

@@ -350,6 +352,7 @@ mod tests {
350352
min_n: 3,
351353
model: ModelType::SkipGram,
352354
negative_samples: 5,
355+
zipf_exponent: 0.5,
353356
};
354357

355358
#[test]

finalfrontier/src/sampling.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ pub trait RangeGenerator: Iterator<Item = usize> {
1010
/// Exponent to use for the Zipf's distribution.
1111
///
1212
/// This is the exponent s in f(k) = 1 / (k^s H_{N, s})
13-
const ZIPF_RANGE_GENERATOR_EXPONENT: f64 = 0.5;
13+
const ZIPF_RANGE_GENERATOR_DEFAULT_EXPONENT: f64 = 0.5;
1414

1515
/// An iterator that draws from *[0, n)* with integer weights.
1616
///
@@ -86,6 +86,7 @@ where
8686
/// is proportional to its frequency.
8787
pub struct ZipfRangeGenerator<R> {
8888
upper_bound: usize,
89+
exponent: f64,
8990
rng: R,
9091
dist: ZipfDistribution,
9192
}
@@ -97,8 +98,9 @@ where
9798
fn clone(&self) -> Self {
9899
ZipfRangeGenerator {
99100
upper_bound: self.upper_bound,
101+
exponent: self.exponent,
100102
rng: self.rng.clone(),
101-
dist: ZipfDistribution::new(self.upper_bound, ZIPF_RANGE_GENERATOR_EXPONENT).unwrap(),
103+
dist: ZipfDistribution::new(self.upper_bound, self.exponent).unwrap(),
102104
}
103105
}
104106
}
@@ -107,11 +109,17 @@ impl<R> ZipfRangeGenerator<R>
107109
where
108110
R: Rng,
109111
{
112+
#[allow(dead_code)]
110113
pub fn new(rng: R, upper: usize) -> Self {
114+
Self::new_with_exponent(rng, upper, ZIPF_RANGE_GENERATOR_DEFAULT_EXPONENT)
115+
}
116+
117+
pub fn new_with_exponent(rng: R, upper_bound: usize, exponent: f64) -> Self {
111118
ZipfRangeGenerator {
112-
upper_bound: upper,
119+
upper_bound,
120+
exponent,
113121
rng,
114-
dist: ZipfDistribution::new(upper, ZIPF_RANGE_GENERATOR_EXPONENT).unwrap(),
122+
dist: ZipfDistribution::new(upper_bound, exponent).unwrap(),
115123
}
116124
}
117125
}

finalfrontier/src/sgd.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,11 @@ where
5959

6060
let range_gen = BandedRangeGenerator::new(
6161
reseed_on_clone.clone(),
62-
ZipfRangeGenerator::new(reseed_on_clone.clone(), model.vocab().len()),
62+
ZipfRangeGenerator::new_with_exponent(
63+
reseed_on_clone.clone(),
64+
model.vocab().len(),
65+
model.config().zipf_exponent,
66+
),
6367
band_size as usize,
6468
);
6569

finalfrontier/src/train_model.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ mod tests {
191191
min_n: 3,
192192
model: ModelType::SkipGram,
193193
negative_samples: 5,
194+
zipf_exponent: 0.5,
194195
};
195196

196197
#[test]

finalfrontier/src/vocab.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,7 @@ mod tests {
273273
min_n: 3,
274274
model: ModelType::SkipGram,
275275
negative_samples: 5,
276+
zipf_exponent: 0.5,
276277
};
277278

278279
#[test]

man/ff-train.1.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,12 @@ OPTIONS
8787

8888
: The number of thread to use during training for parallelization. The
8989
default is to use half of the logical CPUs of the machine.
90-
90+
91+
`--zipf` *EXP*
92+
93+
: Exponent *s* used in the Zipf distribution `p(k) = 1 / (k^s H_N)` for
94+
negative sampling. Default: 0.5
95+
9196
EXAMPLES
9297
========
9398

0 commit comments

Comments
 (0)