Skip to content

Commit 6a78fa1

Browse files
committed
implement loading sentences
1 parent 3a4bde0 commit 6a78fa1

File tree

3 files changed

+33
-4
lines changed

3 files changed

+33
-4
lines changed

aonewsela/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11

22

3-
__version__ = '1.0.0'
3+
__version__ = '1.1.0'

aonewsela/dataset.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,20 @@ def load_tokens(self) -> List[str]:
3535
return res
3636

3737
def load_sentences(self) -> List[str]:
    """Split the token stream from ``load_tokens`` into sentences.

    A token ending in ``.``, ``!`` or ``?`` closes the current sentence,
    unless the token is one of a small set of known abbreviations, in
    which case it is treated as an ordinary word.

    Returns:
        The sentences in order. Each sentence is its tokens joined by
        single spaces, with the closing token included.
    """
    # Matched by exact string equality, so these lower-case entries only
    # catch lower-case tokens.
    abbreviations = {'u.s.', 'u.n.', 'st.', 'dr.', 'd.c.', 'jan.', 'feb.'}
    terminators = ('.', '!', '?')

    sentences = []
    pending = []
    for token in self.load_tokens():
        # Append first and join afterwards: a sentence made of a single
        # terminating token must not pick up a leading space.
        pending.append(token)
        if token.endswith(terminators) and token not in abbreviations:
            sentences.append(' '.join(pending))
            pending = []

    # Tokens after the last terminator still form a sentence; keep them
    # instead of silently dropping the tail of the text.
    if pending:
        sentences.append(' '.join(pending))

    return sentences
3952

4053
# Return the full text as one string: every token from load_tokens(), joined by single spaces.
def load_text(self) -> str:
4154
return ' '.join(self.load_tokens())

aonewsela/pipeline.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,28 @@ def path_to_version(p: Path):
4646
res = []
4747
for path_article in sorted(article_paths, key=lambda p: path_to_version(p), reverse=True):
4848

49-
text = path_article.read_text(encoding='utf-8').replace('\n', ' ').lower()
49+
# TODO many articles start with a location followed by a dash - remove this
50+
51+
# filter article sub-headings
52+
lines_filtered = []
53+
for line in path_article.open(encoding='utf-8').readlines():
54+
55+
if line.startswith('##'):
56+
continue
57+
58+
if 'http' in line: # TODO only exclude the sentence or link, not the entire line
59+
continue
60+
61+
if '<img' in line:
62+
continue
63+
64+
lines_filtered.append(line.rstrip('\n'))
5065

5166
if not self.params.punctuation:
5267
raise NotImplementedError
5368

54-
res.append(Transcript(text, path_to_version(path_article)))
69+
text = ' '.join(lines_filtered)
70+
res.append(Transcript(text.lower(), path_to_version(path_article)))
5571

5672
pbar.update()
5773

0 commit comments

Comments
 (0)