Skip to content

Commit 143067c

Browse files
author
Boyan Hristov
committed
#20 - fixed scipy.sparse support in expected_error, fixed issues from code review
1 parent 68f8878 commit 143067c

File tree

3 files changed

+123
-45
lines changed

3 files changed

+123
-45
lines changed

modAL/expected_error.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from sklearn.exceptions import NotFittedError
1111

1212
from modAL.models import ActiveLearner
13-
from modAL.utils.data import modALinput, data_vstack, enumerate_data, drop_rows
13+
from modAL.utils.data import modALinput, data_vstack, enumerate_data, drop_rows, data_shape, add_row
1414
from modAL.utils.selection import multi_argmax, shuffled_argmax
1515
from modAL.uncertainty import _proba_uncertainty, _proba_entropy
1616

@@ -38,14 +38,13 @@ def expected_error_reduction(learner: ActiveLearner, X: modALinput, loss: str =
3838
3939
4040
Returns:
41-
The indices of the instances from X chosen to be labelled;
42-
the instances from X chosen to be labelled.
41+
The indices of the instances from X chosen to be labelled.
4342
"""
4443

4544
assert 0.0 <= p_subsample <= 1.0, 'p_subsample subsampling keep ratio must be between 0.0 and 1.0'
4645
assert loss in ['binary', 'log'], 'loss must be \'binary\' or \'log\''
4746

48-
expected_error = np.zeros(shape=(len(X), ))
47+
expected_error = np.zeros(shape=(data_shape(X)[0],))
4948
possible_labels = np.unique(learner.y_training)
5049

5150
try:
@@ -62,7 +61,7 @@ def expected_error_reduction(learner: ActiveLearner, X: modALinput, loss: str =
6261
X_reduced = drop_rows(X, x_idx)
6362
# estimate the expected error
6463
for y_idx, y in enumerate(possible_labels):
65-
X_new = data_vstack((learner.X_training, [x]))
64+
X_new = add_row(learner.X_training, x)
6665
y_new = data_vstack((learner.y_training, np.array(y).reshape(1,)))
6766

6867
cloned_estimator.fit(X_new, y_new)

modAL/utils/data.py

Lines changed: 101 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
from typing import Union, List, Sequence
2-
from itertools import chain
32

43
import numpy as np
54
import pandas as pd
65
import scipy.sparse as sp
76

87

9-
modALinput = Union[list, np.ndarray, sp.csr_matrix, pd.DataFrame]
8+
modALinput = Union[sp.csr_matrix, pd.DataFrame, np.ndarray, list]
109

1110

1211
def data_vstack(blocks: Sequence[modALinput]) -> modALinput:
    """
    Stack vertically sparse/dense arrays, lists and pandas data frames.

    The container type of the result is chosen as follows: a scipy sparse
    matrix if any of the blocks is sparse, otherwise the type of the first
    block (pandas data frame, numpy array or list).

    Args:
        blocks: Sequence of modALinput objects.

    Returns:
        New sequence of vertically stacked elements.

    Raises:
        TypeError: If the first block is of an unsupported type.
    """
    # use sparse representation if any of the blocks do
    if any(sp.issparse(block) for block in blocks):
        return sp.vstack(blocks)

    first = blocks[0]
    if isinstance(first, pd.DataFrame):
        # DataFrame.append was removed in pandas 2.0; pd.concat is the
        # version-independent equivalent. A pandas Series block is turned
        # into a one-row frame so that it is stacked as a row (the way
        # DataFrame.append treated a list of Series) rather than as a column.
        frames = [block.to_frame().T if isinstance(block, pd.Series) else block
                  for block in blocks]
        return pd.concat(frames)
    if isinstance(first, np.ndarray):
        return np.concatenate(blocks)
    if isinstance(first, list):
        # go through numpy so nested lists are stacked row-wise
        return np.concatenate(blocks).tolist()

    raise TypeError('%s datatype is not supported' % type(first))
3531

3632

3733
def data_hstack(blocks: Sequence[modALinput]) -> modALinput:
    """
    Stack horizontally sparse/dense arrays and pandas data frames.

    The container type of the result is a scipy sparse matrix if any of the
    blocks is sparse, otherwise it follows the type of the first block
    (pandas data frame, numpy array or list).

    Args:
        blocks: Sequence of modALinput objects.

    Returns:
        New sequence of horizontally stacked elements.

    Raises:
        TypeError: If the first block is of an unsupported type.
    """
    # use sparse representation if any of the blocks do
    if any(sp.issparse(block) for block in blocks):
        return sp.hstack(blocks)

    first = blocks[0]
    if isinstance(first, pd.DataFrame):
        # the concatenated frame must be returned; without the return the
        # call fell through to the end of the function
        return pd.concat(blocks, axis=1)
    if isinstance(first, np.ndarray):
        return np.hstack(blocks)
    if isinstance(first, list):
        return np.hstack(blocks).tolist()

    # the exception has to be raised, not merely constructed, otherwise
    # unsupported input silently yields None
    raise TypeError('%s datatype is not supported' % type(first))
53+
54+
55+
def add_row(X: modALinput, row: modALinput):
    """
    Append a single row to the bottom of the data set X.

    Returns X' =

    [X

    row]

    For a numpy array the result is a numpy array, for a list it is a
    (nested) list; every other input type is delegated to data_vstack.
    """
    if not isinstance(X, (np.ndarray, list)):
        # data_vstack readily supports stacking of matrix as first argument
        # and row as second for the other data types
        return data_vstack([X, row])

    stacked = np.vstack((X, row))
    return stacked.tolist() if isinstance(X, list) else stacked
5571

5672

5773
def retrieve_rows(X: modALinput,
                  I: Union[int, List[int], np.ndarray]) -> Union[sp.csc_matrix, np.ndarray, pd.DataFrame]:
    """
    Returns the rows I from the data set X

    For a single index, the result is as follows:
    * 1xM matrix in case of scipy sparse NxM matrix X
    * pandas series in case of a pandas data frame
    * row in case of list or numpy format

    Args:
        X: The data set to take rows from.
        I: A single row index, a list of indices or a numpy index array.

    Returns:
        The selected row(s), in a container matching the type of X.

    Raises:
        TypeError: If X is of an unsupported type.
    """
    if sp.issparse(X):
        # Out of the sparse matrix formats (sp.csc_matrix, sp.csr_matrix, sp.bsr_matrix,
        # sp.lil_matrix, sp.dok_matrix, sp.coo_matrix, sp.dia_matrix), only sp.bsr_matrix, sp.coo_matrix
        # and sp.dia_matrix don't support indexing and need to be converted to a sparse format
        # that does support indexing. It seems conversion to CSR is currently most efficient.
        try:
            return X[I]
        except Exception:
            # Exception instead of a bare except: the fallback is still taken
            # for any indexing failure, but KeyboardInterrupt and SystemExit
            # are no longer swallowed.
            sp_format = X.format
            return X.tocsr()[I].asformat(sp_format)
    if isinstance(X, pd.DataFrame):
        # positional (not label based) selection
        return X.iloc[I]
    if isinstance(X, np.ndarray):
        return X[I]
    if isinstance(X, list):
        # go through numpy so that index lists and arrays work on plain lists
        return np.array(X)[I].tolist()

    raise TypeError('%s datatype is not supported' % type(X))
66103

67104
def drop_rows(X: modALinput,
              I: Union[int, List[int], np.ndarray]) -> Union[sp.csc_matrix, np.ndarray, pd.DataFrame]:
    """
    Returns X without the row(s) at index/indices I

    The result has the same container type as X. Raises TypeError for
    unsupported types.
    """
    if sp.issparse(X):
        # build a boolean "keep" mask and let retrieve_rows do the selection
        keep = np.ones(X.shape[0], dtype=bool)
        keep[I] = False
        return retrieve_rows(X, keep)

    if isinstance(X, pd.DataFrame):
        # NOTE(review): DataFrame.drop removes by index label, whereas the
        # other branches remove by position - confirm this is intended.
        return X.drop(I, axis=0)

    if isinstance(X, (np.ndarray, list)):
        remaining = np.delete(X, I, axis=0)
        return remaining.tolist() if isinstance(X, list) else remaining

    raise TypeError('%s datatype is not supported' % type(X))
73122

74123
def enumerate_data(X: modALinput):
    """
    Iterate over the rows of X together with their index:

    for i, x in enumerate_data(X):

    Depending on the data type of X, x is:

    * A 1xM matrix in case of scipy sparse NxM matrix X
    * pandas series in case of a pandas data frame X
    * row in case of list or numpy format

    Raises TypeError for unsupported types.
    """
    if sp.issparse(X):
        # convert to CSR before iterating over the rows
        return enumerate(X.tocsr())

    if isinstance(X, pd.DataFrame):
        # iterrows yields (index label, row) pairs
        return X.iterrows()

    if isinstance(X, (np.ndarray, list)):
        # numpy arrays and lists can readily be enumerated
        return enumerate(X)

    raise TypeError('%s datatype is not supported' % type(X))
142+
143+
144+
def data_shape(X: modALinput):
    """
    Returns the shape of the data set X

    Lists are converted to a numpy array to determine their shape.
    Raises TypeError for unsupported types.
    """
    if isinstance(X, list):
        return np.array(X).shape

    if sp.issparse(X) or isinstance(X, (pd.DataFrame, np.ndarray)):
        # scipy.sparse, pandas and numpy all support .shape
        return X.shape

    raise TypeError('%s datatype is not supported' % type(X))

tests/core_tests.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -457,21 +457,24 @@ def test_max_std_sampling(self):
457457
class TestEER(unittest.TestCase):
    def test_eer(self):
        """
        Smoke test for expected_error_reduction: runs it with several keyword
        combinations on random data supplied as sparse matrix, data frame,
        numpy array and list, and checks that invalid arguments are rejected
        with an AssertionError.
        """
        eer = modAL.expected_error.expected_error_reduction

        # keyword combinations that must run without raising
        valid_kwargs = (
            {},
            {'random_tie_break': True},
            {'p_subsample': 0.1},
            {'loss': 'binary'},
            {'p_subsample': 0.1, 'loss': 'log'},
        )
        # keyword combinations that must trip the argument assertions
        invalid_kwargs = (
            {'p_subsample': 1.5},
            {'loss': 42},
        )

        for n_pool, n_features, n_classes in product(range(5, 10), range(1, 5), range(2, 5)):
            # random draws are kept in their original order
            X_training_ = np.random.rand(10, n_features).tolist()
            y_training = np.random.randint(0, n_classes, size=10)
            X_pool_ = np.random.rand(n_pool, n_features).tolist()
            y_pool = np.random.randint(0, n_classes+1, size=n_pool)

            for data_type in (sp.csr_matrix, pd.DataFrame, np.array, list):
                X_training = data_type(X_training_)
                X_pool = data_type(X_pool_)

                learner = modAL.models.ActiveLearner(RandomForestClassifier(n_estimators=2),
                                                     X_training=X_training, y_training=y_training)

                for kwargs in valid_kwargs:
                    eer(learner, X_pool, **kwargs)

                for kwargs in invalid_kwargs:
                    self.assertRaises(AssertionError, eer, learner, X_pool, **kwargs)
475478

476479

477480
class TestUncertainties(unittest.TestCase):

0 commit comments

Comments
 (0)