
Commit 50e1ca9

Merge branch 'dev' into main
2 parents: 712907a + bb4afd0

21 files changed: +198 -128 lines

data/SROIE_dataset.py

Lines changed: 1 addition & 1 deletion
@@ -119,7 +119,7 @@ def __getitem__(self, index):
         ocr_text_filter = []
         seg_index = 0
         for text in ocr_text:
-            if text == "" or text.isspace:
+            if text == "" or text.isspace():
                 continue
             curr_tokens = self.tokenizer.tokenize(text.lower())
             if len(curr_tokens) == 0:
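
Note on this fix (illustration, not part of the commit): without parentheses, `text.isspace` is a bound method object, which is always truthy, so the original condition treated every string as blank and skipped it. A minimal sketch of the difference:

    # Illustrative only -- not from the repository.
    texts = ["TOTAL 12.00", "   ", ""]

    # Buggy check: `t.isspace` (no call) is always truthy, so the `or`
    # evaluates to True for every string and everything is skipped.
    kept_buggy = [t for t in texts if not (t == "" or t.isspace)]
    print(kept_buggy)  # []

    # Fixed check: the method is called, so only empty/blank strings are dropped.
    kept_fixed = [t for t in texts if not (t == "" or t.isspace())]
    print(kept_fixed)  # ['TOTAL 12.00']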

deployment/inference_EPHOIE.py

Lines changed: 1 addition & 1 deletion
@@ -172,7 +172,7 @@ def inference_pipe(
         DEVICE,
         NUM_CLASSES,
         image_bytes=image_bytes,
-        parse_mode=PARSE_MODE
+        parse_mode=PARSE_MODE,
     )

     with open(image_dir.replace(".jpg", ".json"), "w") as f:

deployment/inference_preporcessing.py

Lines changed: 1 addition & 1 deletion
@@ -144,7 +144,7 @@ def generate_batch(
     parse_mode: str = None,
 ):
     image = Image.open(io.BytesIO(image_bytes))
-    image = image.convert('RGB')
+    image = image.convert("RGB")

     status_code, return_text_list, return_coor_list = ocr_extraction(
         image_bytes=image_bytes, ocr_url=ocr_url, parse_mode=parse_mode

deployment/module_load.py

Lines changed: 1 addition & 3 deletions
@@ -9,9 +9,7 @@
 from model.ViBERTgrid_net import ViBERTgridNet


-def inference_init(
-    dir_config: str = "./deployment/config/network_config.yaml"
-):
+def inference_init(dir_config: str = "./deployment/config/network_config.yaml"):
     with open(dir_config, "r") as c:
         hyp = yaml.load(c, Loader=yaml.FullLoader)

1715

eval_FUNSD.py

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ def evaluation_FUNSD(
         pred_gt_dict.update({pred_label.detach(): gt_label.detach()})

     p, r, f, report = BIO_F1_criteria(
-        pred_gt_dict=pred_gt_dict, tag_to_idx=TAG_TO_IDX, average="macro"
+        pred_gt_list=pred_gt_dict, tag_to_idx=TAG_TO_IDX, average="macro"
     )

     return p, r, f, report

model/BERTgrid_generator.py

Lines changed: 1 addition & 1 deletion
@@ -180,7 +180,7 @@ def BERT_embedding(
                 curr_batch_aggre_embeddings.append(curr_embedding.unsqueeze(0))

                 prev_seg_index = curr_seg_index.int().item()
-
+
             if self.grid_mode == "mean":
                 mean_embeddings /= num_tok
                 curr_batch_aggre_embeddings.append(mean_embeddings.unsqueeze(0))

model/crf.py

Lines changed: 10 additions & 4 deletions
@@ -21,6 +21,7 @@ def argmax(vec):
     _, idx = torch.max(vec, 1)
     return idx.item()

+
 # Compute log sum exp in a numerically stable way for the forward algorithm
 def log_sum_exp(vec):
     max_score = vec[0, argmax(vec)]
@@ -43,7 +44,6 @@ def __init__(self, tag_to_ix):
         self.transitions.data[tag_to_ix[START_TAG], :] = -10000
         self.transitions.data[:, tag_to_ix[STOP_TAG]] = -10000

-
     def _forward_alg(self, feats):
         device = self.transitions.device

@@ -81,7 +81,12 @@ def _score_sentence(self, feats, tags):
         # Gives the score of a provided tag sequence
         score = torch.zeros(1, device=device)
         tags = torch.cat(
-            [torch.tensor([self.tag_to_ix[START_TAG]], dtype=torch.long, device=device), tags]
+            [
+                torch.tensor(
+                    [self.tag_to_ix[START_TAG]], dtype=torch.long, device=device
+                ),
+                tags,
+            ]
         )
         for i, feat in enumerate(feats):
             score = score + self.transitions[tags[i + 1], tags[i]] + feat[tags[i + 1]]
@@ -94,7 +99,9 @@ def _viterbi_decode(self, feats):
         backpointers = []

         # Initialize the viterbi variables in log space
-        init_vvars = torch.full((1, self.tagset_size), -10000.0, device=self.transitions.device)
+        init_vvars = torch.full(
+            (1, self.tagset_size), -10000.0, device=self.transitions.device
+        )
         init_vvars[0][self.tag_to_ix[START_TAG]] = 0

         # forward_var at step i holds the viterbi variables for step i-1
@@ -148,4 +155,3 @@ def inference(self, feats):  # dont confuse this with _forward_alg above.
         # Find the best path, given the features.
         score, tag_seq = self._viterbi_decode(feats)
         return score, tag_seq
-

model/field_type_classification_head.py

Lines changed: 1 addition & 1 deletion
@@ -610,7 +610,7 @@ def __init__(
             type of classifier, `single` for a single layer perceptron, `multi` for a MLP
         work_mode: str, optional
             work mode of the model, controls the return values, `train`, `eval` or `inference`
-
+
         """
         super().__init__()


model/semantic_segmentation_head.py

Lines changed: 6 additions & 5 deletions
@@ -81,13 +81,14 @@ def forward(self, x):
 class SemanticSegmentationBinaryClassifier(nn.Module):
     """binaly classifier used in auxiliary semantic segmentation head

-    Parameters
-    ----------
-    in_channels : int
-        number of channels of the input feature
+    Parameters
+    ----------
+    in_channels : int
+        number of channels of the input feature
     """
+
     def __init__(self, in_channels: int) -> None:
-
+
         super().__init__()
         self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=1, kernel_size=1)


pipeline/criteria.py

Lines changed: 15 additions & 8 deletions
@@ -6,7 +6,7 @@
     classification_report,
 )

-from typing import Dict
+from typing import Dict, List, Tuple


 @torch.no_grad()
@@ -22,12 +22,16 @@ def token_classification_criteria(gt_label: torch.Tensor, pred_label: torch.Tens


 @torch.no_grad()
-def BIO_F1_criteria(pred_gt_dict: Dict[torch.Tensor, torch.Tensor], tag_to_idx: Dict, average: str = "micro"):
+def BIO_F1_criteria(
+    pred_gt_list: List[Tuple[torch.Tensor, torch.Tensor]],
+    tag_to_idx: Dict,
+    average: str = "micro",
+):
     idx_to_tag = {v: k for k, v in tag_to_idx.items()}

     pred_list = list()
     label_list = list()
-    for pred, label in pred_gt_dict.items():
+    for (pred, label) in pred_gt_list:
         if len(pred.shape) != 1 and pred.shape[1] != 1:
             pred = pred.argmax(dim=1)
         if len(pred.shape) != 1:
@@ -49,11 +53,14 @@ def BIO_F1_criteria(pred_gt_dict: Dict[torch.Tensor, torch.Tensor], tag_to_idx:


 @torch.no_grad()
-def token_F1_criteria(pred_gt_dict: Dict[torch.Tensor, torch.Tensor]):
-    pred_label: torch.Tensor
-    gt_label: torch.Tensor
-    pred_label = torch.cat(list(pred_gt_dict.keys()), dim=0)
-    gt_label = torch.cat(list(pred_gt_dict.values()), dim=0)
+def token_F1_criteria(pred_gt_list: List[Tuple[torch.Tensor, torch.Tensor]]):
+    pred_label = list()
+    gt_label = list()
+    for item in pred_gt_list:
+        pred_label.append(item[0])
+        gt_label.append(item[1])
+    pred_label = torch.cat(pred_label, dim=0)
+    gt_label = torch.cat(gt_label, dim=0)

     num_classes = pred_label.shape[1]
     pred_label = pred_label.int()
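
Note on this interface change (my reading, not stated in the diff): predictions and labels are now accumulated as a list of `(pred, gt)` tuples instead of a dict keyed by prediction tensors. Tensors hash by object identity, so a tensor-keyed dict is awkward to merge after a cross-process gather, while a plain list keeps order and duplicates. A minimal sketch of the new calling convention, with made-up shapes and a placeholder `TAG_TO_IDX`:

    import torch

    # Hypothetical accumulation loop mirroring validate() in pipeline/train_val_utils.py.
    pred_gt_list = []
    for _ in range(3):
        pred = torch.rand(8, 5)         # per-token class scores (8 tokens, 5 classes)
        gt = torch.randint(0, 5, (8,))  # ground-truth class indices
        pred_gt_list.append((pred.detach(), gt.detach()))

    # The keyword is now pred_gt_list; TAG_TO_IDX is repo-specific and only a placeholder here.
    # p, r, f1, report = BIO_F1_criteria(
    #     pred_gt_list=pred_gt_list, tag_to_idx=TAG_TO_IDX, average="macro"
    # )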

pipeline/distributed_utils.py

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ def reduce_loss(loss, average=True):

     return loss

+
 def is_dist_avail_and_initialized():
     if not torch.distributed.is_available():
         return False

pipeline/funsd_data_preprocessing.py

Lines changed: 2 additions & 3 deletions
@@ -30,7 +30,6 @@ def annotation_parsing_word(dir_annotation: str, dir_save: str):
         top = coors[1]
         right = coors[2]
         bot = coors[3]
-

         curr_row_dict = {
             "left": [left],
@@ -64,7 +63,7 @@ def annotation_parsing_seg(dir_annotation: str, dir_save: str):
             seg_text = Literal["N/A"]
         if seg_text == "NA":
             seg_text = Literal["NA"]
-
+
         data_class = seg["label"]
         pos_neg = 2 if data_class == 0 else 1

@@ -129,4 +128,4 @@ def run_annotation_parser(dir_funsd_root: str, mode: str):
     parser.add_argument("--mode", type=str, help="label data level, word or seg")
     args = parser.parse_args()

-    run_annotation_parser(args.root, args.mode)
+    run_annotation_parser(args.root, args.mode)

pipeline/sroie_data_preprocessing.py

Lines changed: 1 addition & 1 deletion
@@ -495,7 +495,7 @@ def data_parser_multiprocessing(
        |___img: images
        |___box: txt files that contain OCR results
        |___key: txt files that contain key info labels
-
+
    ___test_raw
        |___img: images
        |___box: txt files that contain OCR results

pipeline/train_val_utils.py

Lines changed: 16 additions & 11 deletions
@@ -405,7 +405,7 @@ def validate(
     method_precision_sum = torch.zeros(1, device=device)

     model.eval()
-    pred_gt_dict = dict()
+    pred_gt_list = list()
     mean_validate_loss = torch.zeros(1).to(device)
     for step, validate_batch in enumerate(validate_loader):
         (
@@ -517,7 +517,7 @@ def validate(
             num_gt += torch.tensor(curr_num_gt, device=device)
             num_det += torch.tensor(curr_num_det, device=device)

-            pred_gt_dict.update({pred_label.detach(): gt_label.detach()})
+            pred_gt_list.append((pred_label.detach(), gt_label.detach()))

     validate_loss = reduce_loss(validate_loss)
     validate_loss_value = validate_loss.item()
@@ -541,14 +541,19 @@ def validate(
         torch.distributed.all_reduce(method_precision_sum)
         torch.distributed.all_reduce(method_recall_sum)

-        pred_gt_dict_syn = [None for _ in range(num_proc)]
+        pred_gt_list_syn = [None for _ in range(num_proc)]
         torch.distributed.all_gather_object(
-            object_list=pred_gt_dict_syn, obj=pred_gt_dict
+            object_list=pred_gt_list_syn, obj=pred_gt_list
         )
-        pred_gt_dict_ = dict()
-        for p_g_d in pred_gt_dict_syn:
-            for k, v in p_g_d.items():
-                pred_gt_dict_.update({k: v})
+        pred_gt_list_ = list()
+        for p_g_d in pred_gt_list_syn:
+            for p_g_item in p_g_d:
+                pred_gt_list_.append(p_g_item)
+        del pred_gt_list_syn
+    else:
+        pred_gt_list_ = pred_gt_list
+
+    del pred_gt_list

     num_gt = int(num_gt.item())
     num_det = int(num_det.item())
@@ -558,7 +563,7 @@ def validate(
     if eval_mode == "seqeval":
         assert tag_to_idx is not None
         precision, recall, F1, report = BIO_F1_criteria(
-            pred_gt_dict=pred_gt_dict_, tag_to_idx=tag_to_idx, average=seqeval_average
+            pred_gt_list=pred_gt_list_, tag_to_idx=tag_to_idx, average=seqeval_average
         )
         print(report)
         print(
@@ -588,7 +593,7 @@ def validate(
     elif eval_mode == "seq_and_str":
         assert tag_to_idx is not None
         token_precision, token_recall, token_F1, report = BIO_F1_criteria(
-            pred_gt_dict=pred_gt_dict_, tag_to_idx=tag_to_idx, average=seqeval_average
+            pred_gt_list=pred_gt_list_, tag_to_idx=tag_to_idx, average=seqeval_average
         )
         print("==> token level result")
         print(report)
@@ -620,7 +625,7 @@ def validate(

     else:
         result_dict: Dict
-        result_dict = token_F1_criteria(pred_gt_dict=pred_gt_dict_)
+        result_dict = token_F1_criteria(pred_gt_list=pred_gt_list_)
         num_classes = result_dict["num_classes"]
         precision = 0.0
         recall = 0.0
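
Note on the gather step above (illustration, not part of the diff): `torch.distributed.all_gather_object` fills one slot per rank, so each element of `pred_gt_list_syn` is one process's local list, and flattening those lists rebuilds the full set of `(pred, gt)` pairs; the non-distributed branch simply reuses the local list. A rough sketch, assuming a process group is already initialized and `pred_gt_list` holds this rank's tuples:

    import torch.distributed as dist

    # Assumes dist.init_process_group(...) has been called and
    # pred_gt_list is this rank's local list of (pred, gt) tuples.
    num_proc = dist.get_world_size()

    pred_gt_list_syn = [None for _ in range(num_proc)]  # one slot per rank
    dist.all_gather_object(object_list=pred_gt_list_syn, obj=pred_gt_list)

    # Flatten the per-rank lists back into a single list of (pred, gt) tuples.
    pred_gt_list_ = [item for rank_list in pred_gt_list_syn for item in rank_list]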

pipeline/transform.py

Lines changed: 1 addition & 4 deletions
@@ -259,10 +259,7 @@ def return_batch(

         # create a tensor with shape batch_shape and all values set to 0
         batched_imgs = images[0].new_full(batch_image_shape, 0)
-        for (
-            img,
-            pad_img,
-        ) in zip(
+        for (img, pad_img,) in zip(
             images,
             batched_imgs,
         ):

train_EPHOIE.py

Lines changed: 26 additions & 9 deletions
@@ -8,7 +8,8 @@
 import torch
 from transformers import BertTokenizer, RobertaTokenizer

-from data.EPHOIE_dataset import load_train_dataset_multi_gpu as EPHOIE_load_train
+from data.EPHOIE_dataset import load_train_dataset_multi_gpu as EPHOIE_load_train_multi
+from data.EPHOIE_dataset import load_train_dataset as EPHOIE_load_train
 from model.ViBERTgrid_net import ViBERTgridNet
 from pipeline.train_val_utils import (
     train_one_epoch,
@@ -155,6 +156,11 @@ def train(args):
     eval_mode = hyp["eval_mode"]
     tag_mode = hyp["tag_mode"]

+    if classifier_mode == "crf":
+        assert (
+            eval_mode == "seqeval"
+        ), "When using the crf classifier, only the seqeval metric is available"
+
     if tag_mode == "BIO":
         map_dict = TAG_TO_IDX_BIO
     else:
@@ -170,12 +176,21 @@ def train(args):
     print(f"==> tokenizer {bert_version} loaded")

     print(f"==> loading datasets")
-    train_loader, val_loader, train_sampler = EPHOIE_load_train(
-        root=data_root,
-        batch_size=batch_size,
-        num_workers=num_workers,
-        tokenizer=tokenizer,
-    )
+    if args.distributed:
+        train_loader, val_loader, train_sampler = EPHOIE_load_train_multi(
+            root=data_root,
+            batch_size=batch_size,
+            num_workers=num_workers,
+            tokenizer=tokenizer,
+        )
+    else:
+        train_loader, val_loader = EPHOIE_load_train(
+            root=data_root,
+            batch_size=batch_size,
+            num_workers=num_workers,
+            tokenizer=tokenizer,
+        )
+
     print(f"==> dataset loaded")

     print(f"==> creating model {backbone} | {bert_version}")
@@ -296,7 +311,7 @@ def train(args):
         f"{curr_time.tm_year:04d}-{curr_time.tm_mon:02d}-{curr_time.tm_mday:02d}"
     )
     curr_time_h += (
-        f"_{curr_time.tm_hour:02d}:{curr_time.tm_min:02d}:{curr_time.tm_sec:02d}"
+        f"_{curr_time.tm_hour:02d}-{curr_time.tm_min:02d}-{curr_time.tm_sec:02d}"
     )
     comment = (
         comment_exp
@@ -321,6 +336,7 @@ def train(args):
         device=device,
         epoch=0,
         logger=logger,
+        distributed=args.distributed,
         eval_mode=eval_mode,
         tag_to_idx=map_dict,
         category_list=EPHOIE_CLASS_LIST,
@@ -361,6 +377,7 @@ def train(args):
             device=device,
             epoch=epoch,
             logger=logger,
+            distributed=args.distributed,
             eval_mode=eval_mode,
             tag_to_idx=map_dict,
             category_list=EPHOIE_CLASS_LIST,
@@ -371,7 +388,7 @@ def train(args):
         if F1 > top_F1:
             top_F1 = F1

-        if F1 > top_F1_tresh or (epoch % 400 == 0 and epoch != start_epoch):
+        if F1 > top_F1_tresh or (epoch % 10 == 0 and epoch != start_epoch):
             top_F1_tresh = F1
             if save_top is not None:
                 if not os.path.exists(save_top) and is_main_process():
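
Note on the timestamp change (my reading, not stated in the diff): `:` is not a valid character in Windows file names, so a colon-separated time stamp breaks if it ends up in a log or checkpoint name; dashes are safe everywhere. A tiny sketch of the two formats:

    import time

    curr_time = time.localtime()

    # Old form: contains ':' and is rejected when used in a Windows file name.
    stamp_old = f"_{curr_time.tm_hour:02d}:{curr_time.tm_min:02d}:{curr_time.tm_sec:02d}"

    # New form: dash-separated, safe to embed in file and directory names.
    stamp_new = f"_{curr_time.tm_hour:02d}-{curr_time.tm_min:02d}-{curr_time.tm_sec:02d}"

    print(stamp_old, stamp_new)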
