
Commit 99cdf91

Increase versions, speed & update model downloading (#10)

Co-authored-by: Nicolas Dalsass <nicolasdalsass@users.noreply.github.com>

- Move FLAIR to resources_init & make non-en spaCy models optional
- Upgrade to spaCy 3 & increment plugin version
- Disable unused pipeline algos - divides recipe time roughly by two
- Allow multi-cpu processing - on an 8-core machine, divides recipe time roughly by 3

1 parent: 6fb143e

File tree: 11 files changed, +84 -268 lines

CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -1,5 +1,8 @@
 # Changelog
 
+## Version 2.0.0 - Feature release - 2022-06-13
+- Upgrade Flair, spaCy and model downloading functionality
+
 ## Version 1.3.4 - Feature release - 2022-01-04
 - Add Japanese support
 

code-env/python/spec/requirements.txt

Lines changed: 14 additions & 12 deletions

@@ -1,14 +1,16 @@
-torch==1.6.0
-flair==0.6.1
+flair==0.11.3
+flask>=2.0,<2.1
 gensim==3.8.0
-flask>=1.0,<1.1
+numpy==1.19.5
+spacy[ja]==3.3.0
+tokenizers==0.10.3; python_version == '3.6'
+sudachipy==0.6.0; python_version == '3.6'
 tqdm==4.50.0
-spacy[ja]==2.3.2
-https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.3.1/en_core_web_sm-2.3.1.tar.gz
-https://github.com/explosion/spacy-models/releases/download/es_core_news_sm-2.3.1/es_core_news_sm-2.3.1.tar.gz
-https://github.com/explosion/spacy-models/releases/download/zh_core_web_sm-2.3.1/zh_core_web_sm-2.3.1.tar.gz
-https://github.com/explosion/spacy-models/releases/download/pl_core_news_sm-2.3.0/pl_core_news_sm-2.3.0.tar.gz
-https://github.com/explosion/spacy-models/releases/download/nb_core_news_sm-2.3.0/nb_core_news_sm-2.3.0.tar.gz
-https://github.com/explosion/spacy-models/releases/download/fr_core_news_sm-2.3.0/fr_core_news_sm-2.3.0.tar.gz
-https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-2.3.0/de_core_news_sm-2.3.0.tar.gz
-https://github.com/explosion/spacy-models/releases/download/ja_core_news_sm-2.3.0/ja_core_news_sm-2.3.0.tar.gz
+https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.3.0/en_core_web_sm-3.3.0.tar.gz
+# https://github.com/explosion/spacy-models/releases/download/es_core_news_sm-3.3.0/es_core_news_sm-3.3.0.tar.gz
+# https://github.com/explosion/spacy-models/releases/download/zh_core_web_sm-3.3.0/zh_core_web_sm-3.3.0.tar.gz
+# https://github.com/explosion/spacy-models/releases/download/pl_core_news_sm-3.3.0/pl_core_news_sm-3.3.0.tar.gz
+# https://github.com/explosion/spacy-models/releases/download/nb_core_news_sm-3.3.0/nb_core_news_sm-3.3.0.tar.gz
+# https://github.com/explosion/spacy-models/releases/download/fr_core_news_sm-3.3.0/fr_core_news_sm-3.3.0.tar.gz
+# https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.3.0/de_core_news_sm-3.3.0.tar.gz
+# https://github.com/explosion/spacy-models/releases/download/ja_core_news_sm-3.3.0/ja_core_news_sm-3.3.0.tar.gz
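
The non-English spaCy models are now shipped commented out; to enable one, uncomment its line above and rebuild the code env. As a quick sanity check, a minimal sketch (the model names are taken from the list above) that reports which optional models are actually importable in the environment:

import spacy

for model in ["en_core_web_sm", "fr_core_news_sm", "de_core_news_sm"]:
    try:
        spacy.load(model)
        print(f"{model}: installed")
    except OSError:
        print(f"{model}: missing - uncomment its line in requirements.txt and rebuild")
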
Code env resources initialization script (new file; path not shown in this view)

Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
+######################## Base imports #################################
+from dataiku.code_env_resources import clear_all_env_vars
+from dataiku.code_env_resources import set_env_path
+
+######################## Download FLAIR Models ###########################
+# Clear all environment variables defined by a previously run script
+clear_all_env_vars()
+
+# Set Flair cache directory
+set_env_path("FLAIR_CACHE_ROOT", "flair")
+
+from flair.models import SequenceTagger
+
+# Download pretrained model: automatically managed by Flair,
+# does not download anything if model is already in FLAIR_CACHE_ROOT
+SequenceTagger.load('flair/ner-english-fast')
+# Add any other models you want to download, check https://huggingface.co/flair for examples
+# E.g. SequenceTagger.load('flair/ner-french')
+# Make sure to modify the model used in recipe.py if you want to use a different model
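
For context, a minimal sketch of what this buys at recipe time: once the resources script has run, loading the same model id resolves against FLAIR_CACHE_ROOT instead of triggering a download (the sample sentence is illustrative):

from flair.data import Sentence
from flair.models import SequenceTagger

# Same id as in the resources script above: served from the local cache
tagger = SequenceTagger.load("flair/ner-english-fast")

sentence = Sentence("George Washington went to Washington.")
tagger.predict(sentence)
for span in sentence.get_spans("ner"):
    print(span.text, span.get_label("ner").value)  # e.g. "George Washington PER"
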

custom-recipes/named-entity-recognition-extract/recipe.json

Lines changed: 1 addition & 11 deletions

@@ -14,16 +14,6 @@
             "arity": "UNARY",
             "required": true,
             "acceptsDataset": true
-        },
-        {
-            "name": "model_folder",
-            "label": "Flair model (optional)",
-            "description": "Folder containing Flair model weights",
-            "arity": "UNARY",
-            "required": false,
-            "acceptsManagedFolder": true,
-            "acceptsDataset": false,
-            "mustBeStrictlyType": "Filesystem"
         }
     ],
     "outputRoles": [
@@ -124,7 +114,7 @@
             "name": "ner_model",
             "label": "Model",
             "type": "SELECT",
-            "description": "spaCy (multi-lingual, faster) or Flair (English only, slower)",
+            "description": "spaCy (faster) or Flair (slower)",
             "selectChoices": [
                 {
                     "value": "spacy",

custom-recipes/named-entity-recognition-extract/recipe.py

Lines changed: 11 additions & 11 deletions

@@ -1,4 +1,6 @@
 # -*- coding: utf-8 -*-
+import multiprocessing
+
 import dataiku
 from dataiku.customrecipe import get_input_names_for_role, get_output_names_for_role, get_recipe_config
 
@@ -37,16 +39,10 @@
 
     language = recipe_config.get("text_language_spacy", "en")
 else:
-    from ner_utils_flair import extract_entities, CustomSequenceTagger
-
-    try:
-        model_folder = get_input_names_for_role("model_folder")[0]
-    except IndexError:
-        raise Exception(
-            "To use Flair, download the model using the macro and add the resulting folder as input to the recipe."
-        )
-    folder_path = dataiku.Folder(model_folder).get_path()
-    tagger = CustomSequenceTagger.load("ner-ontonotes-fast", folder_path)
+    from flair.models import SequenceTagger
+    from ner_utils_flair import extract_entities
+
+    tagger = SequenceTagger.load("flair/ner-english-fast")
 
 #############################
 # Main Loop
@@ -63,7 +59,11 @@ def compute_entities_df(df):
     out_df = df.merge(out_df, left_index=True, right_index=True)
     return out_df
 
+if ner_model == "spacy":
+    chunksize = 200 * multiprocessing.cpu_count()
+else:
+    chunksize = 100
 
 process_dataset_chunks(
-    input_dataset=input_dataset, output_dataset=output_dataset, func=compute_entities_df, chunksize=100
+    input_dataset=input_dataset, output_dataset=output_dataset, func=compute_entities_df, chunksize=chunksize
 )
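
The spaCy chunk size now scales with core count, presumably because nlp.pipe(n_process=-1) splits each chunk across all workers and small chunks would leave cores idle, while Flair keeps small chunks since tagging is slower and more memory-hungry. A self-contained sketch of the heuristic (the 200-rows-per-core factor is this commit's choice, not a library constant):

import multiprocessing

def choose_chunksize(ner_model: str) -> int:
    if ner_model == "spacy":
        # ~200 rows per core keeps every nlp.pipe worker fed
        return 200 * multiprocessing.cpu_count()
    # Flair: smaller chunks to bound memory while tagging
    return 100

print(choose_chunksize("spacy"))  # 1600 on an 8-core machine
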

plugin.json

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 {
     "id": "named-entity-recognition",
-    "version": "1.3.4",
+    "version": "2.0.0",
     "meta": {
         "label": "Named Entity Recognition",
         "category": "Natural Language Processing",

python-lib/ner_utils_flair.py

Lines changed: 12 additions & 143 deletions

@@ -1,134 +1,9 @@
 # -*- coding: utf-8 -*-
-import os
-import logging
-import re
-import json
-import requests
-from urllib.parse import urlparse
 from collections import defaultdict
+import json
 
 from flair.data import Sentence
-from flair.models.sequence_tagger_model import SequenceTagger
 import pandas as pd
-from tqdm import tqdm
-
-FLAIR_ENTITIES = [
-    "PERSON",
-    "NORP",
-    "FAC",
-    "ORG",
-    "GPE",
-    "LOC",
-    "PRODUCT",
-    "EVENT",
-    "WORK_OF_ART",
-    "LAW",
-    "LANGUAGE",
-    "DATE",
-    "TIME",
-    "PERCENT",
-    "MONEY",
-    "QUANTITY",
-    "ORDINAL",
-    "CARDINAL",
-]
-
-
-def get_from_cache(url: str, cache_dir: str = None) -> str:
-    """
-    Given a URL, look for the corresponding dataset in the local cache.
-    If it's not there, download it. Then return the path to the cached file.
-    """
-    os.makedirs(cache_dir, exist_ok=True)
-
-    filename = re.sub(r".+/", "", url)
-    # get cache path to put the file
-    cache_path = os.path.join(cache_dir, filename)
-    if os.path.exists(cache_path):
-        logging.info("File {} found in cache".format(filename))
-        return cache_path
-
-    # make HEAD request to check ETag
-    response = requests.head(url)
-    if response.status_code != 200:
-        raise IOError("HEAD request failed for url {}".format(url))
-
-    if not os.path.exists(cache_path):
-        logging.info("File {} not found in cache, downloading from URL {}...".format(filename, url))
-        req = requests.get(url, stream=True)
-        content_length = req.headers.get("Content-Length")
-        total = int(content_length) if content_length is not None else None
-        progress = tqdm(unit="B", total=total)
-        with open(cache_path, "wb") as temp_file:
-            for chunk in req.iter_content(chunk_size=1024):
-                if chunk:  # filter out keep-alive new chunks
-                    progress.update(len(chunk))
-                    temp_file.write(chunk)
-        progress.close()
-
-    return cache_path
-
-
-def cached_path(url_or_filename: str, cache_path: str, cache_dir: str) -> str:
-    """
-    Given something that might be a URL (or might be a local path),
-    determine which. If it's a URL, download the file and cache it, and
-    return the path to the cached file. If it's already a local path,
-    make sure the file exists and then return the path.
-    """
-    dataset_cache = os.path.join(cache_path, cache_dir)
-
-    parsed = urlparse(url_or_filename)
-
-    if parsed.scheme in ("http", "https"):
-        # URL, so get it from the cache (downloading if necessary)
-        return get_from_cache(url_or_filename, dataset_cache)
-    elif parsed.scheme == "" and os.path.exists(url_or_filename):
-        # File, and it exists.
-        return url_or_filename
-    elif parsed.scheme == "":
-        # File, but it doesn't exist.
-        raise FileNotFoundError("file {} not found".format(url_or_filename))
-    else:
-        # Something unknown
-        raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
-
-
-class CustomSequenceTagger(SequenceTagger):
-    @staticmethod
-    def load(model: str, cache_path: str):
-        model_file = None
-        aws_resource_path = "https://nlp.informatik.hu-berlin.de/resources/models"
-
-        if model.lower() == "ner":
-            base_path = "/".join([aws_resource_path, "ner", "en-ner-conll03-v0.4.pt"])
-            model_file = cached_path(base_path, cache_path, cache_dir="models")
-
-        if model.lower() == "ner-ontonotes-fast":
-            base_path = "/".join([aws_resource_path, "ner-ontonotes-fast", "en-ner-ontonotes-fast-v0.4.pt"])
-            model_file = cached_path(base_path, cache_path, cache_dir="models")
-
-        if model_file is not None:
-            tagger = SequenceTagger.load(model_file)
-            return tagger
-
-
-#############################
-# NER function
-#############################
-
-# Regex for matching either
-PATTERN = r"({}|{})".format(
-    # Single-word entities
-    r"(?:\s*\S+ <S-[A-Z_]*>)",  # (<S-TAG> format)
-    # Match multi-word entities
-    r"{}{}{}".format(
-        r"(?:\s*\S+ <B-[A-Z_]*>)",  # A first tag in <B-TAG> format
-        r"(?:\s*\S+ <I-[A-Z_]*>)*",  # Zero or more tags in <I-TAG> format
-        r"(?:\s*\S+ <E-[A-Z_]*>)",  # A final tag in <E-TAG> format
-    ),
-)
-matcher = re.compile(PATTERN)
 
 
 def extract_entities(text_column, format, tagger):
@@ -138,33 +13,27 @@ def extract_entities(text_column, format, tagger):
     # Tag Sentences
     tagger.predict(sentences)
 
-    # Retrieve entities
-    if format:
-        entity_df = pd.DataFrame()
-    else:
-        entity_df = pd.DataFrame(columns=FLAIR_ENTITIES)
-
+    # Extract entities
+    rows = []
     for sentence in sentences:
         df_row = defaultdict(list)
-        entities = matcher.findall(sentence.to_tagged_string())
-        # Entities are in the following format: word1 <X-TAG> word2 <X-TAG> ...
-        for entity in entities:
-            # Extract entity text (word1, word2, ...)
-            text = " ".join(entity.split()[::2])
-            # Extract entity type (TAG)
-            tag = re.search(r"<.-(.+?)>", entity).group(1)
+        for entity in sentence.get_spans('ner'):
+            tag = entity.get_label("ner").value
+            text = entity.text
             df_row[tag].append(text)
-
         if format:
             df_row = {"sentence": sentence.to_plain_string(), "entities": json.dumps(df_row)}
         else:
            for k, v in df_row.items():
                df_row[k] = json.dumps(v)
            df_row["sentence"] = sentence.to_plain_string()
 
-        entity_df = entity_df.append(df_row, ignore_index=True)
+        rows.append(df_row)
 
-    cols = [col for col in entity_df.columns.tolist() if col != "sentence"]
-    entity_df = entity_df[cols]
+    entity_df = pd.DataFrame(rows)
 
+    # Put 'sentence' column first
+    cols = sorted(list(entity_df.columns))
+    cols.insert(0, cols.pop(cols.index("sentence")))
+    entity_df = entity_df[cols]
     return entity_df
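
A standalone sketch of the new extraction path (model id and sample text are illustrative): Flair 0.11 exposes tagged spans directly through get_spans("ner"), which is what lets this commit drop the regex over to_tagged_string() and the download/cache machinery above:

from collections import defaultdict

from flair.data import Sentence
from flair.models import SequenceTagger

tagger = SequenceTagger.load("flair/ner-english-fast")
sentence = Sentence("Apple hired Jane Doe in Paris.")
tagger.predict(sentence)

row = defaultdict(list)
for span in sentence.get_spans("ner"):
    row[span.get_label("ner").value].append(span.text)
print(dict(row))  # e.g. {'ORG': ['Apple'], 'PER': ['Jane Doe'], 'LOC': ['Paris']}
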

python-lib/ner_utils_spacy.py

Lines changed: 20 additions & 7 deletions

@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
-import json
 from collections import defaultdict
+import json
+
 import pandas as pd
 import spacy
 
@@ -15,14 +16,25 @@
     "nb": "nb_core_news_sm",
 }
 
+def get_spacy_model(language: str):
+    language_model = SPACY_LANGUAGE_MODELS.get(language, None)
+    if language_model is None:
+        raise ValueError(f"The language {language} is not available. \
+            You can add the language & corresponding model name by editing the code.")
+    try:
+        nlp = spacy.load(language_model, exclude=["tok2vec", "tagger", "parser", "attribute_ruler", "lemmatizer"])
+    except OSError:
+        # Raising ValueError instead of OSError so it shows up at the top of the log
+        raise ValueError(f"Could not find spaCy model for the language {language}. \
+            Maybe you need to edit the requirements.txt file to enable it.")
+    return nlp
 
 def extract_entities(text_column, format: bool, language: str):
     # Tag sentences
-    nlp = spacy.load(SPACY_LANGUAGE_MODELS[language])
-    docs = nlp.pipe(text_column.values)
-
+    nlp = get_spacy_model(language=language)
+    docs = nlp.pipe(text_column.values, n_process=-1, batch_size=100)
     # Extract entities
-    entity_df = pd.DataFrame()
+    rows = []
     for doc in docs:
         df_row = defaultdict(list)
         for entity in doc.ents:
@@ -35,11 +47,12 @@ def extract_entities(text_column, format: bool, language: str):
             df_row[k] = json.dumps(v)
         df_row["sentence"] = doc.text
 
-        entity_df = entity_df.append(df_row, ignore_index=True)
+        rows.append(df_row)
+
+    entity_df = pd.DataFrame(rows)
 
     # Put 'sentence' column first
     cols = sorted(list(entity_df.columns))
     cols.insert(0, cols.pop(cols.index("sentence")))
     entity_df = entity_df[cols]
-
     return entity_df
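
The two speedups in this file, shown in isolation: loading with unused pipeline components excluded (in spaCy's trained pipelines the ner component has its own embedding layer, so it keeps working without the shared tok2vec), and fanning texts out across all cores with nlp.pipe. A hedged sketch, assuming en_core_web_sm is installed:

import spacy

nlp = spacy.load(
    "en_core_web_sm",
    exclude=["tok2vec", "tagger", "parser", "attribute_ruler", "lemmatizer"],
)

texts = ["Alice flew to Berlin.", "Bob works at Acme Corp."] * 1000
# n_process=-1 uses all available CPUs; batch_size sets rows per worker batch
for doc in nlp.pipe(texts, n_process=-1, batch_size=100):
    entities = [(ent.text, ent.label_) for ent in doc.ents]
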

python-runnables/named-entity-recognition-download/runnable.json

Lines changed: 0 additions & 25 deletions
This file was deleted.
