Skip to content

Commit b479356

Browse files
Merge pull request #11 from RefaceAI/upd-tests
Update tests
2 parents 7ad65fc + 55ca878 commit b479356

File tree

2 files changed

+24
-16
lines changed

2 files changed

+24
-16
lines changed

quantile_estimator/__init__.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,10 @@ def __init__(self, *invariants):
4848
self._invariants = _DEFAULT_INVARIANTS
4949
else:
5050
self._invariants = [_Quantile(q, e) for (q, e) in invariants]
51+
5152
self._buffer = []
5253
self._head = None
5354
self._observations = 0
54-
self._items = 0
5555

5656
def observe(self, value):
5757
"""Samples an observation's value.
@@ -90,7 +90,7 @@ def query(self, rank):
9090
return current._value
9191

9292
mid_rank = math.floor(rank * self._observations)
93-
max_rank = mid_rank + math.floor(self._invariant(mid_rank, self._observations) / 2)
93+
max_rank = mid_rank + math.ceil(self._invariant(mid_rank, self._observations) / 2)
9494

9595
rank = 0.0
9696
while current._successor:
@@ -115,7 +115,8 @@ def _replace_batch(self):
115115
return
116116

117117
if not self._head:
118-
self._head, self._buffer = self._record(self._buffer[0], 1, 0, None), self._buffer[1:]
118+
self._head = self._record(self._buffer[0], 1, 0, None)
119+
self._buffer = self._buffer[1:]
119120

120121
rank = 0.0
121122
current = self._head
@@ -136,7 +137,6 @@ def _replace_batch(self):
136137
def _record(self, value, rank, delta, successor):
137138
"""Catalogs a sample."""
138139
self._observations += 1
139-
self._items += 1
140140

141141
return _Sample(value, rank, delta, successor)
142142

@@ -187,10 +187,10 @@ def __init__(self, quantile, inaccuracy):
187187

188188
"""Computes the delta for the observation."""
189189
def _delta(self, rank, n):
190-
if rank <= math.floor((self._quantile * n)):
190+
if rank <= math.floor(self._quantile * n):
191191
return self._coefficient_i * (n - rank)
192-
193-
return self._coefficient_ii * rank
192+
else:
193+
return self._coefficient_ii * rank
194194

195195

196196
_DEFAULT_INVARIANTS = [_Quantile(0.50, 0.01), _Quantile(0.99, 0.001)]

tests/test_estimator.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import math
12
import random
23

34
import pytest
@@ -7,19 +8,26 @@
78

89
@pytest.mark.parametrize("num_observations", [1, 10, 100, 1000, 10000, 100000])
910
def test_random_observations(num_observations):
10-
estimator = Estimator()
11-
for _ in range(num_observations):
12-
estimator.observe(random.randint(1, 1000) / 100)
11+
invariants = (0.5, 0.01), (0.9, 0.01), (0.99, 0.01)
12+
estimator = Estimator(*invariants)
1313

14-
assert 0 <= estimator.query(0.5) <= estimator.query(0.9) <= estimator.query(0.99) <= 10
14+
values = [random.uniform(0, 100) for _ in range(num_observations)]
15+
for value in values:
16+
estimator.observe(value)
17+
18+
values.sort()
19+
for quantile, inaccuracy in invariants:
20+
min_rank = math.floor(quantile * num_observations - inaccuracy * num_observations)
21+
max_rank = min(math.ceil(quantile * num_observations + inaccuracy * num_observations), num_observations - 1)
22+
assert 0 <= values[min_rank] <= estimator.query(quantile) <= values[max_rank] <= 100
1523

1624

1725
def test_border_invariants():
1826
estimator = Estimator((0.0, 0.0), (1.0, 0.0))
1927

20-
values = [random.randint(1, 1000) for _ in range(1000)]
21-
for x in values:
22-
estimator.observe(x)
28+
values = [random.uniform(0, 100) for _ in range(500)]
29+
for value in values:
30+
estimator.observe(value)
2331

24-
assert estimator.query(0) == min(values)
25-
assert estimator.query(1) == max(values)
32+
assert estimator.query(0.0) == min(values)
33+
assert estimator.query(1.0) == max(values)

0 commit comments

Comments
 (0)