Skip to content

Commit 813d0b4

Browse files

Commit message: Make Spacy model configurable
1 parent 26a064e · commit 813d0b4

File tree

2 files changed

+58
-29
lines changed

2 files changed

+58
-29
lines changed

bugbug/nlp.py

Lines changed: 57 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
basicConfig(level=INFO)
1515
logger = getLogger(__name__)
1616

17+
DEFAULT_SPACY_MODEL = "en_core_web_md"
18+
_spacy_model_name = DEFAULT_SPACY_MODEL
19+
_nlp = None
20+
1721
HAS_OPTIONAL_DEPENDENCIES = False
1822

1923
try:
@@ -23,43 +27,69 @@
2327
except ImportError:
2428
pass
2529

26-
try:
27-
if HAS_OPTIONAL_DEPENDENCIES:
28-
nlp = spacy.load("en_core_web_md")
29-
except OSError:
30-
HAS_OPTIONAL_DEPENDENCIES = False
31-
logger.error(
32-
"Spacy model is missing, install it with: %s -m spacy download en_core_web_md",
33-
sys.executable,
34-
)
3530

3631
OPT_MSG_MISSING = (
3732
"Optional dependencies are missing, install them with: pip install bugbug[nlp]\n"
3833
"You might need also to download the models with: "
39-
f"{sys.executable} -m spacy download en_core_web_md"
34+
f"{sys.executable} -m spacy download {{model_name}}"
4035
)
4136

4237

38+
def get_spacy_model_name():
39+
return _spacy_model_name
40+
41+
42+
def set_spacy_model_name(model_name):
43+
if not model_name:
44+
raise ValueError("model_name must be a non-empty string")
45+
46+
global _spacy_model_name
47+
global _nlp
48+
_spacy_model_name = model_name
49+
_nlp = None
50+
51+
52+
def get_nlp():
53+
model_name = get_spacy_model_name()
54+
opt_msg_missing = OPT_MSG_MISSING.format(model_name=model_name)
55+
56+
if not HAS_OPTIONAL_DEPENDENCIES:
57+
raise NotImplementedError(opt_msg_missing)
58+
59+
global _nlp
60+
if _nlp is None:
61+
try:
62+
_nlp = spacy.load(model_name)
63+
except OSError as e:
64+
logger.error(
65+
"Spacy model '%s' is missing, install it with: %s -m spacy download %s",
66+
model_name,
67+
sys.executable,
68+
model_name,
69+
)
70+
raise NotImplementedError(opt_msg_missing) from e
71+
72+
return _nlp
73+
74+
4375
def spacy_token_lemmatizer(text):
44-
if len(text) > nlp.max_length:
45-
text = text[: nlp.max_length - 1]
46-
doc = nlp(text)
76+
model = get_nlp()
77+
if len(text) > model.max_length:
78+
text = text[: model.max_length - 1]
79+
doc = model(text)
4780
return [token.lemma_ for token in doc]
4881

4982

5083
def lemmatizing_tfidf_vectorizer(**kwargs):
51-
# Detect when the Spacy optional dependency is missing.
52-
if not HAS_OPTIONAL_DEPENDENCIES:
53-
raise NotImplementedError(OPT_MSG_MISSING)
54-
5584
return TfidfVectorizer(tokenizer=spacy_token_lemmatizer, **kwargs)
5685

5786

5887
def _get_vector_dim():
59-
if nlp.vocab.vectors_length:
60-
return nlp.vocab.vectors_length
88+
model = get_nlp()
89+
if model.vocab.vectors_length:
90+
return model.vocab.vectors_length
6191

62-
doc = nlp("vector")
92+
doc = model("vector")
6393
if doc:
6494
return doc[0].vector.shape[0]
6595

@@ -68,10 +98,11 @@ def _get_vector_dim():
6898

6999
def _token_vector(token):
70100
key = token.lower_
101+
vocab = token.vocab
71102

72103
# Check if there is a lowercase word vector first.
73-
if nlp.vocab.has_vector(key):
74-
return nlp.vocab.get_vector(key)
104+
if vocab.has_vector(key):
105+
return vocab.get_vector(key)
75106

76107
if token.has_vector:
77108
return token.vector
@@ -81,24 +112,21 @@ def _token_vector(token):
81112

82113
class MeanEmbeddingTransformer(BaseEstimator, TransformerMixin):
83114
def __init__(self):
84-
# Detect when the Spacy optional dependency is missing.
85-
if not HAS_OPTIONAL_DEPENDENCIES:
86-
raise NotImplementedError(OPT_MSG_MISSING)
87-
88115
self.dim = _get_vector_dim()
89116

90117
def fit(self, x, y=None):
91118
return self
92119

93120
def transform(self, data):
121+
model = get_nlp()
94122
return np.array(
95123
[
96124
np.mean(
97125
[vec for token in doc if (vec := _token_vector(token)) is not None]
98126
or [np.zeros(self.dim)],
99127
axis=0,
100128
)
101-
for doc in nlp.pipe(data)
129+
for doc in model.pipe(data)
102130
]
103131
)
104132

@@ -125,6 +153,7 @@ def fit(self, X, y=None):
125153
return self
126154

127155
def transform(self, data):
156+
model = get_nlp()
128157
return np.array(
129158
[
130159
np.mean(
@@ -136,6 +165,6 @@ def transform(self, data):
136165
or [np.zeros(self.dim)],
137166
axis=0,
138167
)
139-
for doc in nlp.pipe(data)
168+
for doc in model.pipe(data)
140169
]
141170
)

tests/test_nlp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def runtime_nlp(monkeypatch):
1717
model.vocab.set_vector("world", np.array([3.0, 4.0, 5.0], dtype=np.float32))
1818
model.vocab.set_vector("alpha", np.array([1.0, 2.0, 0.0], dtype=np.float32))
1919
model.vocab.set_vector("beta", np.array([3.0, 5.0, 0.0], dtype=np.float32))
20-
monkeypatch.setattr(nlp, "nlp", model)
20+
monkeypatch.setattr(nlp, "get_nlp", lambda: model)
2121
return model
2222

2323

0 commit comments

Comments (0)