Commit 9cc33e2

clarify lora adaptation

1 parent 5902e7e · commit 9cc33e2

File tree: 6 files changed, +41 −11 lines changed

  README.md
  requirements.txt
  setup.py
  wtpsplit/__init__.py
  wtpsplit/train/train_lora.py
  wtpsplit/utils/create_dummy_data.py

README.md

Lines changed: 8 additions & 5 deletions
@@ -194,7 +194,7 @@ pip install adapters==0.2.1 --no-dependencies
 cd ..
 ```
 
-Create data in this format:
+1. Create data in this format:
 ```python
 import torch
 
@@ -217,19 +217,22 @@ torch.save(
     "dummy-dataset.pth"
 )
 ```
-Note that there should not be any newlines within individual sentences! Your corpus should already be well-split.
+Note that there should not be any newlines within individual sentences! This now raises an error: each entry of the list must be a single sentence with no "\n" characters, so your corpus should already be well-split.
 
-Create/adapt config; provide base model via `model_name_or_path` and training data .pth via `text_path`:
+2. Create/adapt a config; provide the base model via `model_name_or_path` and the training data .pth via `text_path`:
 
 
 `configs/lora/lora_dummy_config.json`
 
-Train LoRA:
+We recommend starting from this config and adapting `model_name_or_path`, `output_dir`, and `text_path` if needed.
+You may also wish to adapt other aspects such as `adapter_config` and the batch sizes, but this is more experimental.
+
+3. Train LoRA:
 ```
 python3 wtpsplit/train/train_lora.py configs/lora/lora_dummy_config.json
 ```
 
-Once training is done, provide your saved module's path to SaT:
+4. Once training is done, provide your saved module's path to SaT:
 ```python
 
 sat_lora_adapted = SaT("model-used", lora_path="dummy_lora_path")
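
Read end to end, the renumbered steps form a complete round trip: serialize sentences, point the config at them, train, and reload. Below is a minimal sketch of that flow; the base model name "sat-3l", the language code "en", and the example sentences are illustrative assumptions, not values fixed by this commit.

```python
# Minimal sketch of the four README steps (assumed values: "sat-3l", "en").
import torch

from wtpsplit import SaT

# Step 1: one sentence per list entry, never any "\n" inside an entry.
torch.save(
    {
        "en": {
            "sentence": {
                "dummy-dataset": {
                    "meta": {"train_data": ["First sentence.", "Second sentence."]},
                    "data": ["First sentence.", "Second sentence."],
                }
            }
        }
    },
    "dummy-dataset.pth",
)

# Steps 2 and 3 happen outside Python: point `text_path` in
# configs/lora/lora_dummy_config.json at "dummy-dataset.pth", then run
#   python3 wtpsplit/train/train_lora.py configs/lora/lora_dummy_config.json

# Step 4: load the base model together with the freshly trained LoRA module.
sat_lora_adapted = SaT("sat-3l", lora_path="dummy_lora_path")
print(sat_lora_adapted.split("This is a test This is another test"))
```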

requirements.txt

Lines changed: 0 additions & 1 deletion
@@ -18,7 +18,6 @@ cohere
 replicate
 onnx
 onnxruntime
-torchinfo
 mosestokenizer
 cached_property
 tqdm

setup.py

Lines changed: 3 additions & 3 deletions
@@ -2,23 +2,23 @@
 
 setup(
     name="wtpsplit",
-    version="2.1.2",
+    version="2.1.3",
     packages=find_packages(),
     description="Universal Robust, Efficient and Adaptable Sentence Segmentation",
     author="Markus Frohmann, Igor Sterner, Benjamin Minixhofer",
     author_email="markus.frohmann@gmail.com",
     install_requires=[
         "onnxruntime>=1.13.1",
         "transformers>=4.22.2",
-        "huggingface-hub==0.25.2",  # see https://github.com/segment-any-text/wtpsplit/issues/135
+        "huggingface-hub",
         "numpy>=1.0",
         "scikit-learn>=1",
         "tqdm",
         "skops",
         "pandas>=1",
         "cached_property",  # for Py37
         "mosestokenizer",
-        "adapters",
+        "adapters>=1.0.1",
     ],
     url="https://github.com/segment-any-text/wtpsplit",
     package_data={"wtpsplit": ["data/*"]},

wtpsplit/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@
 from wtpsplit.extract import BertCharORTWrapper, SaTORTWrapper, PyTorchWrapper, extract
 from wtpsplit.utils import Constants, indices_to_sentences, sigmoid, token_to_char_probs
 
-__version__ = "2.1.2"
+__version__ = "2.1.3"
 
 warnings.simplefilter("default", DeprecationWarning)  # show by default
 warnings.simplefilter("ignore", category=FutureWarning)  # for tranformers

wtpsplit/train/train_lora.py

Lines changed: 10 additions & 1 deletion
@@ -126,6 +126,10 @@ def prepare_dataset(
     if one_sample_per_line or isinstance(dataset[0], list):
         processed_dataset = []
         for chunk in dataset:
+            if "\n" in chunk:
+                raise ValueError(
+                    "Newlines in text are not supported! Data needs to be processed as a list of sentences."
+                )
             processed_chunk = {}
             processed_chunk["lang"] = lang
             processed_chunk["ends_with_punctuation"] = chunk[-1].endswith(
@@ -137,10 +141,15 @@
         dataset = datasets.Dataset.from_list(processed_dataset)
 
     else:
+        for i, chunk in enumerate(dataset):
+            if "\n" in chunk:
+                raise ValueError(
+                    "Newlines in text are not supported! Data needs to be processed as a list of sentences."
+                )
         dataset = datasets.Dataset.from_list(
             [
                 {
-                    args.text_column: sample + "\n" if sample and sample[-1] != "\n" else sample,  # TODO
+                    args.text_column: sample + "\n" if sample and sample[-1] != "\n" else sample,
                     "lang": lang,
                     "ends_with_punctuation": sample.endswith(tuple(Constants.PUNCTUATION_CHARS)),
                 }
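
Both branches now fail fast on malformed corpora instead of silently training on mis-split text. The contract they enforce can be illustrated standalone; `check_corpus` below is our own hypothetical helper, not a function in wtpsplit:

```python
def check_corpus(corpus: list[str]) -> None:
    """Reject entries containing newlines, mirroring the new prepare_dataset checks."""
    for sentence in corpus:
        if "\n" in sentence:
            raise ValueError(
                "Newlines in text are not supported! "
                "Data needs to be processed as a list of sentences."
            )

check_corpus(["One sentence per entry.", "No embedded newlines."])  # passes
try:
    check_corpus(["Two sentences\nglued into one entry."])
except ValueError as err:
    print(err)  # the error introduced by this commit
```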

wtpsplit/utils/create_dummy_data.py

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+import torch
+
+torch.save(
+    {
+        "language_code": {
+            "sentence": {
+                "dummy-dataset": {
+                    "meta": {
+                        "train_data": ["train sentence 1", "train sentence 2"],
+                    },
+                    "data": [
+                        "train sentence 1", "train sentence 2"
+                    ]
+                }
+            }
+        }
+    },
+    "dummy-dataset.pth"
+)
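
The new helper writes exactly the nesting that the README's step 1 and train_lora.py expect. A quick sanity check (our own, not part of the repo) is to load the file back and walk that nesting:

```python
import torch

# Load the file written by wtpsplit/utils/create_dummy_data.py and walk
# the expected nesting: language code -> "sentence" -> dataset name.
data = torch.load("dummy-dataset.pth")
sentences = data["language_code"]["sentence"]["dummy-dataset"]["data"]
assert all("\n" not in s for s in sentences)  # the invariant train_lora.py now enforces
print(sentences)  # ['train sentence 1', 'train sentence 2']
```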
