diff --git a/ext/enumerable/statistics/extension/statistics.c b/ext/enumerable/statistics/extension/statistics.c index fb5abd2..4e4c640 100644 --- a/ext/enumerable/statistics/extension/statistics.c +++ b/ext/enumerable/statistics/extension/statistics.c @@ -1413,6 +1413,186 @@ enum_stdev(int argc, VALUE* argv, VALUE obj) return stdev; } +#if SIZEOF_SIZE_T == SIZEOF_LONG +static inline size_t +random_usize_limited(VALUE rnd, size_t max) +{ + return (size_t)rb_random_ulong_limited(rnd, max); +} +#else +static inline size_t +random_usize_limited(VALUE rnd, size_t max) +{ + if (max <= ULONG_MAX) { + return (size_t)rb_random_ulong_limited(rnd, (unsigned long)max); + } + else { + VALUE num = rb_random_int(rnd, SIZET2NUM(max)); + return NUM2SIZET(num); + } +} +#endif + +struct enum_sample_memo { + size_t k; + long n; + VALUE sample; + VALUE random; +}; + +static VALUE +enum_sample_single_i(RB_BLOCK_CALL_FUNC_ARGLIST(e, data)) +{ + struct enum_sample_memo *memo = (struct enum_sample_memo *)data; + ENUM_WANT_SVALUE(); + + if (++memo->k <= 1) { + memo->sample = e; + } + else { + size_t j = random_usize_limited(memo->random, memo->k - 1); + if (j == 1) { + memo->sample = e; + } + } + + return Qnil; +} + +static VALUE +enum_sample_single(VALUE obj, VALUE random) +{ + struct enum_sample_memo memo; + + memo.k = 0; + memo.n = 1; + memo.sample = Qundef; + memo.random = random; + + rb_block_call(obj, id_each, 0, 0, enum_sample_single_i, (VALUE)&memo); + + return memo.sample; +} + +static VALUE +enum_sample_multiple_without_replace_unweighted_i(RB_BLOCK_CALL_FUNC_ARGLIST(e, data)) +{ + struct enum_sample_memo *memo = (struct enum_sample_memo *)data; + ENUM_WANT_SVALUE(); + + if (++memo->k <= memo->n) { + rb_ary_push(memo->sample, e); + } + else { + size_t j = random_usize_limited(memo->random, memo->k - 1); + if (j <= memo->n) { + rb_ary_store(memo->sample, (long)(j - 1), e); + } + } + + return Qnil; +} + +static VALUE +enum_sample_multiple_unweighted(VALUE obj, long size, VALUE random, int replace_p) +{ + struct enum_sample_memo memo; + + assert(size > 1); + + memo.k = 0; + memo.n = size; + memo.sample = rb_ary_new_capa(size); + memo.random = random; + + if (replace_p) { + return Qnil; + } + else { + rb_block_call(obj, id_each, 0, 0, enum_sample_multiple_without_replace_unweighted_i, (VALUE)&memo); + } + + return memo.sample; +} + +/* call-seq: + * enum.sample -> obj + * enum.sample(random: rng) -> obj + * enum.sample(n) -> ary + * enum.sample(n, random: rng) -> ary + * enum.sample(n, random: rng, replace: true) -> ary + * + * Choose a random element or +n+ random elements from the enumerable. + * + * The enumerable is completely scanned just once for choosing random elements + * even if +n+ is ommitted or +n+ is +1+. This means this method cannot be + * applicable to an infinite enumerable. + * + * +replace:+ keyword specifies whether the sample is with or without + * replacement. + * + * On without-replacement sampling, the elements are chosen by using random + * in order to ensure that an element doesn't repeat itself unless the + * enumerable already contained duplicated elements. + * + * On with-replacement sampling, the elements are chosen by using random, and + * indices into the array can be duplicated even if the enumerable didn't contain + * duplicated elements. + * + * If the enumerable is empty the first two forms return +nil+, and the latter + * forms with +n+ return an empty array. + * + * The optional +rng+ argument will be used as the random number generator. + */ +static VALUE +enum_sample(int argc, VALUE *argv, VALUE obj) +{ + VALUE size_v, random_v, replace_v, weights_v, opts; + long size; + int replace_p; + + random_v = rb_cRandom; + replace_v = Qundef; + weights_v = Qundef; + + if (argc == 0) goto single; + + rb_scan_args(argc, argv, "01:", &size_v, &opts); + size = NIL_P(size_v) ? 1 : NUM2LONG(size_v); + + if (size == 1 && NIL_P(opts)) { + goto single; + } + + if (!NIL_P(opts)) { + static ID keywords[3]; + VALUE kwargs[3]; + if (!keywords[0]) { + keywords[0] = rb_intern("random"); + keywords[1] = rb_intern("replace"); + /* keywords[2] = rb_intern("weights"); */ + } + rb_get_kwargs(opts, keywords, 0, 2, kwargs); + random_v = kwargs[0]; + replace_v = kwargs[1]; + /* weights_v = kwargs[2]; */ + } + + if (random_v == Qundef) { + random_v = rb_cRandom; + } + + if (size == 1) { +single: + return enum_sample_single(obj, random_v); + } + + replace_p = (replace_v == Qundef) ? 0 : RTEST(replace_v); + + return enum_sample_multiple_unweighted(obj, size, random_v, replace_p); +} + + /* call-seq: * ary.mean_stdev(population: false) * @@ -1479,6 +1659,7 @@ Init_extension(void) rb_define_method(rb_mEnumerable, "variance", enum_variance, -1); rb_define_method(rb_mEnumerable, "mean_stdev", enum_mean_stdev, -1); rb_define_method(rb_mEnumerable, "stdev", enum_stdev, -1); + rb_define_method(rb_mEnumerable, "sample", enum_sample, -1); #ifndef HAVE_ARRAY_SUM rb_define_method(rb_cArray, "sum", ary_sum, -1); diff --git a/spec/enum/sample_spec.rb b/spec/enum/sample_spec.rb new file mode 100644 index 0000000..3959e1a --- /dev/null +++ b/spec/enum/sample_spec.rb @@ -0,0 +1,201 @@ +require 'spec_helper' +require 'enumerable/statistics' + +RSpec.describe Enumerable, '#sample' do + let(:random) { Random.new } + let(:n) { 20 } + + let(:replace) { nil } + let(:weights) { nil } + let(:opts) { {} } + + before do + opts[:replace] = replace if replace + opts[:weights] = weights if weights + end + + context 'when the receiver has 1 item' do + let(:enum) { 1.upto(1) } + + shared_examples_for '1-item enumerable' do + context 'without replacement' do + specify { expect(opts).not_to include(:replace) } + + specify do + expect(enum.sample(**opts)).to eq(1) + + expect(enum.sample(10, **opts)).to eq([1]) + expect(enum.sample(20, **opts)).to eq([1]) + end + end + + context 'with replacement' do + let(:replace) { true } + + specify { expect(opts).to include(replace: true) } + + specify do + expect(enum.sample(10, **opts)).to eq(Array.new(10, 1)) + expect(enum.sample(20, **opts)).to eq(Array.new(20, 1)) + end + end + end + + context 'without weights' do + specify { expect(opts).not_to include(:weights) } + + include_examples '1-item enumerable' + end + + # TODO: weights + xcontext 'with weights' do + let(:weights) do + { 1 => 1.0 } + end + + specify { expect(opts).to include(weights: weights) } + + include_examples '1-item enumerable' + end + end + + context 'when the receiver has 2 item' do + let(:enum) { 1.upto(2) } + + shared_examples_for 'sample from 2-item enumerable without replacement' do + specify { expect(opts).not_to include(:replace) } + + specify do + expect(Array.new(100) { enum.sample(**opts) }).to all(eq(1).or eq(2)) + + expect(enum.sample(10, **opts)).to contain_exactly(1, 2) + expect(enum.sample(20, **opts)).to contain_exactly(1, 2) + end + end + + context 'without weights' do + context 'without replacement' do + it_behaves_like 'sample from 2-item enumerable without replacement' + end + + context 'with replacement' do + let(:replace) { true } + + specify { expect(opts).to include(replace: true) } + + specify do + expect(enum.sample(10, **opts)).to have_attributes(length: 10).and all(eq(1).or eq(2)) + expect(enum.sample(20, **opts)).to have_attributes(length: 20).and all(eq(1).or eq(2)) + end + end + end + + # TODO: weights + xcontext 'with weights' do + specify { expect(opts).to include(weights: weights) } + + context 'without replacement' do + it_behaves_like 'sample from 2-item enumerable without replacement' + end + + context 'with replacement' do + let(:replace) { true } + + specify { expect(opts).to include(replace: true) } + end + end + end + + context 'without weight' do + let(:enum) { 1.upto(100000) } + + specify { expect(opts).not_to include(:weights) } + + context 'without replacement' do + specify { expect(opts).not_to include(:replace) } + + context 'without size' do + context 'without rng' do + specify do + result = enum.sample + expect(result).to be_an(Integer) + other_results = Array.new(100) { enum.sample } + expect(other_results).not_to be_all {|i| i == result } + end + end + + context 'with rng' do + specify do + save_random = random.dup + result = enum.sample(random: random) + expect(result).to be_an(Integer) + other_results = Array.new(100) { enum.sample(random: save_random.dup) } + expect(other_results).to be_all {|i| i == result } + end + end + end + + context 'with size (== 1)' do + context 'without rng' do + specify do + result = enum.sample(1) + expect(result).to be_an(Integer) + other_results = Array.new(100) { enum.sample(1) } + expect(other_results).not_to be_all {|i| i == result } + end + end + + context 'with rng' do + specify do + save_random = random.dup + result = enum.sample(1, random: random) + expect(result).to be_an(Integer) + other_results = Array.new(100) { enum.sample(1, random: save_random.dup) } + expect(other_results).to be_all {|i| i == result } + end + end + end + + context 'with size (> 1)' do + context 'without rng' do + subject(:result) { enum.sample(n) } + + specify do + result = enum.sample(n) + expect(result).to be_an(Array) + expect(result.length).to eq(n) + expect(result.uniq.length).to eq(n) + other_results = Array.new(100) { enum.sample(n) } + expect(other_results).not_to be_all {|i| i == result } + end + end + + context 'with rng' do + subject(:result) { enum.sample(n, random: random) } + + specify do + save_random = random.dup + result = enum.sample(n, random: random) + expect(result).to be_an(Array) + expect(result.length).to eq(n) + expect(result.uniq.length).to eq(n) + other_results = Array.new(100) { enum.sample(n, random: save_random.dup) } + expect(other_results).to be_all {|i| i == result } + end + end + end + end + + context 'with replacement' do + let(:replace) { true } + + specify { expect(opts).to include(replace: true) } + + pending + end + end + + context 'with weight' do + pending + end +end