diff --git a/.gitignore b/.gitignore index 8b01850e..880cd98b 100755 --- a/.gitignore +++ b/.gitignore @@ -137,3 +137,6 @@ dmypy.json # Pyre type checker .pyre/ + +.idea +.DS_Store \ No newline at end of file diff --git a/README.md b/README.md index 58f3e9fb..76dfad36 100755 --- a/README.md +++ b/README.md @@ -34,12 +34,12 @@ Gradio web demos are available! [![Demo](https://img.shields.io/badge/Demo-Gradi - Web demos are available from the links in the following table. - Note: We have updated the Google Colab demo (as of June 15, 2023) to ensure its proper working. -|Task|Sec/Img|Score|Trained Model|
Demo
| -|---|---|---|---|---| -| [CORD](https://github.com/clovaai/cord) (Document Parsing) | 0.7 /
0.7 /
1.2 | 91.3 /
91.1 /
90.9 | [donut-base-finetuned-cord-v2](https://huggingface.co/naver-clova-ix/donut-base-finetuned-cord-v2/tree/official) (1280) /
[donut-base-finetuned-cord-v1](https://huggingface.co/naver-clova-ix/donut-base-finetuned-cord-v1/tree/official) (1280) /
[donut-base-finetuned-cord-v1-2560](https://huggingface.co/naver-clova-ix/donut-base-finetuned-cord-v1-2560/tree/official) | [gradio space web demo](https://huggingface.co/spaces/naver-clova-ix/donut-base-finetuned-cord-v2),
[google colab demo (updated at 23.06.15)](https://colab.research.google.com/drive/1NMSqoIZ_l39wyRD7yVjw2FIuU2aglzJi?usp=sharing) | -| [Train Ticket](https://github.com/beacandler/EATEN) (Document Parsing) | 0.6 | 98.7 | [donut-base-finetuned-zhtrainticket](https://huggingface.co/naver-clova-ix/donut-base-finetuned-zhtrainticket/tree/official) | [google colab demo (updated at 23.06.15)](https://colab.research.google.com/drive/1YJBjllahdqNktXaBlq5ugPh1BCm8OsxI?usp=sharing) | -| [RVL-CDIP](https://www.cs.cmu.edu/~aharley/rvl-cdip) (Document Classification) | 0.75 | 95.3 | [donut-base-finetuned-rvlcdip](https://huggingface.co/naver-clova-ix/donut-base-finetuned-rvlcdip/tree/official) | [gradio space web demo](https://huggingface.co/spaces/nielsr/donut-rvlcdip),
[google colab demo (updated at 23.06.15)](https://colab.research.google.com/drive/1iWOZHvao1W5xva53upcri5V6oaWT-P0O?usp=sharing) | -| [DocVQA Task1](https://rrc.cvc.uab.es/?ch=17) (Document VQA) | 0.78 | 67.5 | [donut-base-finetuned-docvqa](https://huggingface.co/naver-clova-ix/donut-base-finetuned-docvqa/tree/official) | [gradio space web demo](https://huggingface.co/spaces/nielsr/donut-docvqa),
[google colab demo (updated at 23.06.15)](https://colab.research.google.com/drive/1oKieslZCulFiquequ62eMGc-ZWgay4X3?usp=sharing) | +| Task | Sec/Img | Score | Trained Model |
Demo
| +|--------------------------------------------------------------------------------|-------------------------|----------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| [CORD](https://github.com/clovaai/cord) (Document Parsing) | 0.7 /
0.7 /
1.2 | 91.3 /
91.1 /
90.9 | [donut-base-finetuned-cord-v2](https://huggingface.co/naver-clova-ix/donut-base-finetuned-cord-v2/tree/official) (1280) /
[donut-base-finetuned-cord-v1](https://huggingface.co/naver-clova-ix/donut-base-finetuned-cord-v1/tree/official) (1280) /
[donut-base-finetuned-cord-v1-2560](https://huggingface.co/naver-clova-ix/donut-base-finetuned-cord-v1-2560/tree/official) | [gradio space web demo](https://huggingface.co/spaces/naver-clova-ix/donut-base-finetuned-cord-v2),
[google colab demo (updated at 23.06.15)](https://colab.research.google.com/drive/1NMSqoIZ_l39wyRD7yVjw2FIuU2aglzJi?usp=sharing) | +| [Train Ticket](https://github.com/beacandler/EATEN) (Document Parsing) | 0.6 | 98.7 | [donut-base-finetuned-zhtrainticket](https://huggingface.co/naver-clova-ix/donut-base-finetuned-zhtrainticket/tree/official) | [google colab demo (updated at 23.06.15)](https://colab.research.google.com/drive/1YJBjllahdqNktXaBlq5ugPh1BCm8OsxI?usp=sharing) | +| [RVL-CDIP](https://www.cs.cmu.edu/~aharley/rvl-cdip) (Document Classification) | 0.75 | 95.3 | [donut-base-finetuned-rvlcdip](https://huggingface.co/naver-clova-ix/donut-base-finetuned-rvlcdip/tree/official) | [gradio space web demo](https://huggingface.co/spaces/nielsr/donut-rvlcdip),
[google colab demo (updated at 23.06.15)](https://colab.research.google.com/drive/1iWOZHvao1W5xva53upcri5V6oaWT-P0O?usp=sharing) | +| [DocVQA Task1](https://rrc.cvc.uab.es/?ch=17) (Document VQA) | 0.78 | 67.5 | [donut-base-finetuned-docvqa](https://huggingface.co/naver-clova-ix/donut-base-finetuned-docvqa/tree/official) | [gradio space web demo](https://huggingface.co/spaces/nielsr/donut-docvqa),
[google colab demo (updated at 23.06.15)](https://colab.research.google.com/drive/1oKieslZCulFiquequ62eMGc-ZWgay4X3?usp=sharing) | The links to the pre-trained backbones are here: - [`donut-base`](https://huggingface.co/naver-clova-ix/donut-base/tree/official): trained with 64 A100 GPUs (~2.5 days), number of layers (encoder: {2,2,14,2}, decoder: 4), input size 2560x1920, swin window size 10, IIT-CDIP (11M) and SynthDoG (English, Chinese, Japanese, Korean, 0.5M x 4). @@ -62,6 +62,9 @@ To generate synthetic datasets with our SynthDoG, please see `./synthdog/README. ## Updates +**_2024-05-15_** `DonutModel` can now return confidence scores together with the extracted items. The `inference` method accepts a +`return_confs` parameter that attaches a score to each predicted item; it defaults to `True`. If you do not want the scores, set +`return_confs` to `False` (see the usage sketch after this Updates list).
**_2023-06-15_** We have updated all Google Colab demos to ensure they work properly.
**_2022-11-14_** New version 1.0.9 is released (`pip install donut-python --upgrade`). See [1.0.9 Release Notes](https://github.com/clovaai/donut/releases/tag/1.0.9).
**_2022-08-12_** Donut 🍩 is also available at [huggingface/transformers 🤗](https://huggingface.co/docs/transformers/main/en/model_doc/donut) (contributed by [@NielsRogge](https://github.com/NielsRogge)). `donut-python` loads the pre-trained weights from the `official` branch of the model repositories. See [1.0.5 Release Notes](https://github.com/clovaai/donut/releases/tag/1.0.5).
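A minimal usage sketch of the new `return_confs` behavior described in the 2024-05-15 entry. The image path is hypothetical, and the CORD checkpoint plus the `<s_cord-v2>` prompt follow the existing README examples; treat this as an illustration, not as part of the patch:

```python
import torch
from PIL import Image
from donut import DonutModel

# load a fine-tuned checkpoint; donut-python pins revision="official" internally
model = DonutModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
if torch.cuda.is_available():
    model.half()
    model.to("cuda")
else:
    model.encoder.to(torch.bfloat16)
model.eval()

image = Image.open("sample_receipt.png").convert("RGB")  # hypothetical input document

# return_confs defaults to True: each extracted item carries a confidence score
with_scores = model.inference(image=image, prompt="<s_cord-v2>")["predictions"][0]

# set return_confs=False to get the plain parse, as with the original DonutModel
plain = model.inference(image=image, prompt="<s_cord-v2>", return_confs=False)["predictions"][0]
```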
diff --git a/donut/model.py b/donut/model.py index 6321c5d2..eeff4480 100755 --- a/donut/model.py +++ b/donut/model.py @@ -6,21 +6,20 @@ import math import os import re -from typing import Any, List, Optional, Union +from typing import Any, List, Optional, Tuple, Union import numpy as np -import PIL import timm import torch import torch.nn as nn -import torch.nn.functional as F -from PIL import ImageOps +import torch.nn.functional as fnc +from PIL import Image, ImageOps from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.swin_transformer import SwinTransformer from torchvision import transforms -from torchvision.transforms.functional import resize, rotate from transformers import MBartConfig, MBartForCausalLM, XLMRobertaTokenizer from transformers.file_utils import ModelOutput +from transformers.utils.generic import to_py_obj from transformers.modeling_utils import PretrainedConfig, PreTrainedModel @@ -41,17 +40,19 @@ class SwinEncoder(nn.Module): def __init__( self, - input_size: List[int], + input_size: Tuple[int, int], align_long_axis: bool, window_size: int, - encoder_layer: List[int], + encoder_layer: Tuple[int], name_or_path: Union[str, bytes, os.PathLike] = None, + drop_rate: float | None = 0.0, ): super().__init__() self.input_size = input_size self.align_long_axis = align_long_axis self.window_size = window_size self.encoder_layer = encoder_layer + self.drop_rate = drop_rate self.to_tensor = transforms.Compose( [ @@ -66,10 +67,9 @@ def __init__( window_size=self.window_size, patch_size=4, embed_dim=128, - num_heads=[4, 8, 16, 32], + num_heads=(4, 8, 16, 32), num_classes=0, ) - self.model.norm = None # weight init with swin if not name_or_path: @@ -86,7 +86,7 @@ def __init__( old_len = int(math.sqrt(len(pos_bias))) new_len = int(2 * window_size - 1) pos_bias = pos_bias.reshape(1, old_len, old_len, -1).permute(0, 3, 1, 2) - pos_bias = F.interpolate(pos_bias, size=(new_len, new_len), mode="bicubic", align_corners=False) + pos_bias = fnc.interpolate(pos_bias, size=(new_len, new_len), mode="bicubic", align_corners=False) new_swin_state_dict[x] = pos_bias.permute(0, 2, 3, 1).reshape(1, new_len ** 2, -1).squeeze(0) else: new_swin_state_dict[x] = swin_state_dict[x] @@ -98,11 +98,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x: (batch_size, num_channels, height, width) """ x = self.model.patch_embed(x) - x = self.model.pos_drop(x) + x = nn.Dropout(p=self.drop_rate)(x) x = self.model.layers(x) return x - def prepare_input(self, img: PIL.Image.Image, random_padding: bool = False) -> torch.Tensor: + def prepare_input(self, img: Image.Image, random_padding: bool = False) -> torch.Tensor: """ Convert PIL Image to tensor according to specified input_size after following steps below: - resize @@ -114,8 +114,8 @@ def prepare_input(self, img: PIL.Image.Image, random_padding: bool = False) -> t (self.input_size[0] > self.input_size[1] and img.width > img.height) or (self.input_size[0] < self.input_size[1] and img.width < img.height) ): - img = rotate(img, angle=-90, expand=True) - img = resize(img, min(self.input_size)) + img = img.rotate(angle=-90, expand=True) + img = img.resize(self.input_size) img.thumbnail((self.input_size[1], self.input_size[0])) delta_width = self.input_size[1] - img.width delta_height = self.input_size[0] - img.height @@ -134,6 +134,147 @@ def prepare_input(self, img: PIL.Image.Image, random_padding: bool = False) -> t return self.to_tensor(ImageOps.expand(img, padding)) +class BARTCustomTokenizer(XLMRobertaTokenizer): + 
""" + Customized XLMRobertaTokenizer to return confidence scores and token id groups aligned with grouped tokens + The default batch_decoder, decode and _decode are overwritten for the Tokenizer + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.DELIM = None + + def batch_decode( + self, + sequences: Union[List[int], List[List[int]], "np.ndarray", "torch.Tensor"], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: bool = True, + **kwargs + ) -> List[Tuple]: + """ + Convert a list of lists of token ids into a list of strings by calling decode. + Args: + sequences (`Union[List[int], List[List[int]], np.ndarray, torch.Tensor, tf.Tensor]`): + List of tokenized input ids. Can be obtained using the `__call__` method. + skip_special_tokens (`bool`, *optional*, defaults to `False`): + Whether to remove special tokens in the decoding. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`): + Whether to clean up the tokenization spaces. + kwargs (additional keyword arguments, *optional*): + Will be passed to the underlying model specific decode method. + Returns: + `List[str]`: The list of decoded sentences. + """ + confidences = kwargs.pop("confidences", []) + self.DELIM = kwargs.pop("decoder_delim", None) + + result = [] + for seq, conf in zip(sequences, confidences): + kwargs["token_confs"] = conf + result.append(self.decode( + seq, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + )) + return result + + def decode( + self, + token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor"], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: bool = True, + **kwargs + ) -> Tuple: + """ + Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special + tokens and clean up tokenization spaces. + Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`. + Args: + token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`): + List of tokenized input ids. Can be obtained using the `__call__` method. + skip_special_tokens (`bool`, *optional*, defaults to `False`): + Whether to remove special tokens in the decoding. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`): + Whether to clean up the tokenization spaces. + kwargs (additional keyword arguments, *optional*): + Will be passed to the underlying model specific decode method. + Returns: + `str`: The decoded sentence. + """ + # Convert inputs to python lists + token_ids = to_py_obj(token_ids) + kwargs["token_confs"] = to_py_obj(kwargs.pop("token_confs", [])) + + return self._decode( + token_ids=token_ids, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + def _decode( + self, + token_ids: List[int], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: bool = True, + spaces_between_special_tokens: bool = True, + **kwargs + ) -> Tuple: + token_confs = kwargs.pop("token_confs", []) + + self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False) + + filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens) + + # To avoid mixing byte-level and unicode for byte-level BPT + # we need to build string separately for added tokens and byte-level tokens + # cf. 
https://github.com/huggingface/transformers/issues/1133 + sub_texts = [] + sub_confs = [] + sub_idxs = [] + current_sub_text = [] + current_sub_confs = [] + current_sub_idxs = [] + + for idx, (token, conf) in enumerate(zip(filtered_tokens, token_confs)): + if skip_special_tokens and token in self.all_special_ids: + continue + if token in self.added_tokens_encoder: + if current_sub_text: + sub_texts.append(self.convert_tokens_to_string(current_sub_text)) + current_sub_text = [] + sub_confs.append(sum(current_sub_confs) / len(current_sub_confs)) + current_sub_confs = [] + sub_idxs.append(current_sub_idxs) + current_sub_idxs = [] + sub_texts.append(token) + sub_confs.append(conf) + sub_idxs.append([idx]) + else: + current_sub_text.append(token) + current_sub_confs.append(conf) + current_sub_idxs.append(idx) + + if current_sub_text: + sub_texts.append(self.convert_tokens_to_string(current_sub_text)) + sub_confs.append(sum(current_sub_confs) / len(current_sub_confs)) + sub_idxs.append(current_sub_idxs) + + decoder_output_confs = sub_confs + decoder_output_indxs = sub_idxs + if spaces_between_special_tokens: + text = self.DELIM.join(sub_texts) + else: + text = "".join(sub_texts) + + if clean_up_tokenization_spaces: + clean_text = self.clean_up_tokenization(text) + return clean_text, decoder_output_confs, decoder_output_indxs + + return text, decoder_output_confs, decoder_output_indxs + + class BARTDecoder(nn.Module): """ Donut Decoder based on Multilingual BART @@ -157,7 +298,7 @@ def __init__( self.decoder_layer = decoder_layer self.max_position_embeddings = max_position_embeddings - self.tokenizer = XLMRobertaTokenizer.from_pretrained( + self.tokenizer = BARTCustomTokenizer.from_pretrained( "hyunwoongko/asian-bart-ecjk" if not name_or_path else name_or_path ) @@ -173,7 +314,7 @@ def __init__( add_final_layer_norm=True, ) ) - self.model.forward = self.forward # to get cross attentions and utilize `generate` function + self.model.forward = self.forward # to get cross attentions and utilize `generate` function self.model.config.is_encoder_decoder = True # to get cross-attention self.add_special_tokens([""]) # is used for representing a list in a JSON @@ -186,12 +327,9 @@ def __init__( new_bart_state_dict = self.model.state_dict() for x in new_bart_state_dict: if x.endswith("embed_positions.weight") and self.max_position_embeddings != 1024: + # https://github.com/huggingface/transformers/blob/v4.11.3/src/transformers/models/mbart/modeling_mbart.py#L118-L119 new_bart_state_dict[x] = torch.nn.Parameter( - self.resize_bart_abs_pos_emb( - bart_state_dict[x], - self.max_position_embeddings - + 2, # https://github.com/huggingface/transformers/blob/v4.11.3/src/transformers/models/mbart/modeling_mbart.py#L118-L119 - ) + self.resize_bart_abs_pos_emb(bart_state_dict[x], self.max_position_embeddings + 2) ) elif x.endswith("embed_tokens.weight") or x.endswith("lm_head.weight"): new_bart_state_dict[x] = bart_state_dict[x][: len(self.tokenizer), :] @@ -199,35 +337,49 @@ def __init__( new_bart_state_dict[x] = bart_state_dict[x] self.model.load_state_dict(new_bart_state_dict) - def add_special_tokens(self, list_of_tokens: List[str]): + def add_special_tokens(self, list_of_tokens: List[str], replace_additional_special_tokens: bool | None = False): """ Add special tokens to tokenizer and resize the token embeddings """ - newly_added_num = self.tokenizer.add_special_tokens({"additional_special_tokens": sorted(set(list_of_tokens))}) + newly_added_num = 0 + set_of_tokens = set(list_of_tokens) + set_special_tokens = 
set(self.tokenizer.all_special_tokens) + if len(set_of_tokens - set_special_tokens) > 0: + newly_added_num = self.tokenizer.add_special_tokens( + {"additional_special_tokens": sorted(set_of_tokens)}, + replace_additional_special_tokens=replace_additional_special_tokens + ) if newly_added_num > 0: self.model.resize_token_embeddings(len(self.tokenizer)) - def prepare_inputs_for_inference(self, input_ids: torch.Tensor, encoder_outputs: torch.Tensor, past_key_values=None, past=None, use_cache: bool = None, attention_mask: torch.Tensor = None): + def prepare_inputs_for_inference( + self, + input_ids: torch.Tensor, + encoder_outputs: torch.Tensor = None, + past=None, + use_cache: bool = None, + **kwargs, + ): """ Args: - input_ids: (batch_size, sequence_lenth) + input_ids: (batch_size, sequence_length) + encoder_outputs: (batch_size, sequence_length, hidden_size) + past: Past key values + use_cache: Whether to use cache or not Returns: input_ids: (batch_size, sequence_length) attention_mask: (batch_size, sequence_length) encoder_hidden_states: (batch_size, sequence_length, embedding_dim) """ - # for compatibility with transformers==4.11.x - if past is not None: - past_key_values = past attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long() - if past_key_values is not None: + if past is not None: input_ids = input_ids[:, -1:] output = { "input_ids": input_ids, "attention_mask": attention_mask, - "past_key_values": past_key_values, + "past_key_values": past, "use_cache": use_cache, - "encoder_hidden_states": encoder_outputs.last_hidden_state, + "encoder_hidden_states": encoder_outputs.last_hidden_state if encoder_outputs is not None else None, } return output @@ -244,7 +396,7 @@ def forward( return_dict: bool = None, ): """ - A forward fucntion to get cross attentions and utilize `generate` function + A forward function to get cross attentions and utilize `generate` function Source: https://github.com/huggingface/transformers/blob/v4.11.3/src/transformers/models/mbart/modeling_mbart.py#L1669-L1810 @@ -253,7 +405,12 @@ def forward( input_ids: (batch_size, sequence_length) attention_mask: (batch_size, sequence_length) encoder_hidden_states: (batch_size, sequence_length, hidden_size) - + past_key_values: + labels: (batch_size, sequence_length) + use_cache: Whether to use cache or not + output_attentions: Whether to return attentions or not + output_hidden_states: Whether to return hidden states or not + return_dict: Whether to return dict or not Returns: loss: (1, ) logits: (batch_size, sequence_length, hidden_dim) @@ -308,7 +465,7 @@ def resize_bart_abs_pos_emb(weight: torch.Tensor, max_length: int) -> torch.Tens weight = weight[:max_length, ...] 
else: weight = ( - F.interpolate( + fnc.interpolate( weight.permute(1, 0).unsqueeze(0), size=max_length, mode="linear", @@ -349,10 +506,10 @@ class DonutConfig(PretrainedConfig): def __init__( self, - input_size: List[int] = [2560, 1920], + input_size: Tuple[int, int] | None = None, align_long_axis: bool = False, window_size: int = 10, - encoder_layer: List[int] = [2, 2, 14, 2], + encoder_layer: Tuple[int] | None = None, decoder_layer: int = 4, max_position_embeddings: int = None, max_length: int = 1536, @@ -360,10 +517,10 @@ def __init__( **kwargs, ): super().__init__() - self.input_size = input_size + self.input_size = input_size if input_size else (2560, 1920) self.align_long_axis = align_long_axis self.window_size = window_size - self.encoder_layer = encoder_layer + self.encoder_layer = encoder_layer if encoder_layer else (2, 2, 14, 2) self.decoder_layer = decoder_layer self.max_position_embeddings = max_length if max_position_embeddings is None else max_position_embeddings self.max_length = max_length @@ -382,6 +539,8 @@ class DonutModel(PreTrainedModel): def __init__(self, config: DonutConfig): super().__init__(config) + self.return_confs = None + self.return_tokens = None self.config = config self.encoder = SwinEncoder( input_size=self.config.input_size, @@ -404,7 +563,7 @@ def forward(self, image_tensors: torch.Tensor, decoder_input_ids: torch.Tensor, Args: image_tensors: (batch_size, num_channels, height, width) decoder_input_ids: (batch_size, sequence_length, embedding_dim) - decode_labels: (batch_size, sequence_length) + decoder_labels: (batch_size, sequence_length) """ encoder_outputs = self.encoder(image_tensors) decoder_outputs = self.decoder( @@ -416,16 +575,18 @@ def forward(self, image_tensors: torch.Tensor, decoder_input_ids: torch.Tensor, def inference( self, - image: PIL.Image = None, + image: Image = None, prompt: str = None, image_tensors: Optional[torch.Tensor] = None, prompt_tensors: Optional[torch.Tensor] = None, return_json: bool = True, - return_attentions: bool = False, + return_confs: bool = True, + return_tokens: bool = False, + return_attentions: bool = False ): """ - Generate a token sequence in an auto-regressive manner, - the generated token sequence is convereted into an ordered JSON format + Generate a token sequence in an autoregressive manner, + the generated token sequence is converted into an ordered JSON format Args: image: input document image (PIL.Image) @@ -434,6 +595,10 @@ def inference( convert prompt to tensor if image_tensor is not fed prompt_tensors: (1, sequence_length) convert image to tensor if prompt_tensor is not fed + return_json: whether to return a JSON format or not + return_confs: whether to return confidence scores or not + return_tokens: whether to return tokens or not + return_attentions: whether to return attentions or not """ # prepare backbone inputs (image and prompt) if image is None and image_tensors is None: @@ -447,6 +612,8 @@ def inference( if self.device.type == "cuda": # half is not compatible in cpu implementation. 
image_tensors = image_tensors.half() image_tensors = image_tensors.to(self.device) + else: + image_tensors = image_tensors.to(torch.bfloat16) if prompt_tensors is None: prompt_tensors = self.decoder.tokenizer(prompt, add_special_tokens=False, return_tensors="pt")["input_ids"] @@ -477,16 +644,43 @@ def inference( bad_words_ids=[[self.decoder.tokenizer.unk_token_id]], return_dict_in_generate=True, output_attentions=return_attentions, + output_scores=True, ) + decoder_output_confs = torch.amax(torch.stack(decoder_output.scores, dim=1).softmax(-1), 2).cpu().numpy()[0] + # add score for end token and wrap scores in a list + decoder_output_confs = [np.concatenate([decoder_output_confs, [1.]], axis=0)] + output = {"predictions": list()} - for seq in self.decoder.tokenizer.batch_decode(decoder_output.sequences): - seq = seq.replace(self.decoder.tokenizer.eos_token, "").replace(self.decoder.tokenizer.pad_token, "") - seq = re.sub(r"<.*?>", "", seq, count=1).strip() # remove first task start token - if return_json: - output["predictions"].append(self.token2json(seq)) - else: - output["predictions"].append(seq) + self.return_tokens = return_tokens + self.return_confs = return_confs + delimiter = "}~}~}~{" # important, use a delimiter that has a very low prob of appearing in text + + for idx, (seq, confs, idxs) in enumerate(self.decoder.tokenizer.batch_decode( + decoder_output.sequences, confidences=decoder_output_confs, decoder_delim=delimiter) + ): + eos_tkn, pad_tkn = self.decoder.tokenizer.eos_token, self.decoder.tokenizer.pad_token + split_seq = [tkn for tkn in seq.split(delimiter) if tkn] + confs = [confs[i] for i, tkn in enumerate(split_seq) if not ( + tkn.strip().lower() == eos_tkn.lower() or tkn.strip().lower() == pad_tkn.lower() + )] + idxs = [idxs[i] for i, tkn in enumerate(seq.split(delimiter)) if not ( + tkn.strip().lower() == eos_tkn.lower() or tkn.strip().lower() == pad_tkn.lower() + )] + seq = seq.replace(eos_tkn, "").replace(pad_tkn, "") + for i, tkn in enumerate(seq.split(delimiter)): + if re.search(r"<.*?>", tkn, re.IGNORECASE): # remove first task start token conf + confs.pop(i) + idxs.pop(i) + break + seq = re.sub(r"<.*?>", "", seq, count=1).strip(delimiter) # remove first task start token + item = seq + if confs and idxs and return_json: + item = self.token2json_with_confs(seq, confs, idxs, delim=delimiter) if ( + return_confs or return_tokens + ) else self.token2json(seq.replace(delimiter, ' ')) + + output["predictions"].append(item) if return_attentions: output["attentions"] = { @@ -500,37 +694,37 @@ def json2token(self, obj: Any, update_special_tokens_for_json_key: bool = True, """ Convert an ordered JSON object into a token sequence """ - if type(obj) == dict: + if isinstance(obj, dict): if len(obj) == 1 and "text_sequence" in obj: return obj["text_sequence"] - else: - output = "" - if sort_json_key: - keys = sorted(obj.keys(), reverse=True) - else: - keys = obj.keys() - for k in keys: - if update_special_tokens_for_json_key: - self.decoder.add_special_tokens([fr"", fr""]) - output += ( - fr"" - + self.json2token(obj[k], update_special_tokens_for_json_key, sort_json_key) - + fr"" - ) - return output - elif type(obj) == list: + + output = "" + keys = obj.keys() + if sort_json_key: + keys = sorted(keys, reverse=True) + for k in keys: + if update_special_tokens_for_json_key: + self.decoder.add_special_tokens([fr"", fr""]) + output += ( + fr"" + + self.json2token(obj[k], update_special_tokens_for_json_key, sort_json_key) + + fr"" + ) + return output + + if isinstance(obj, list): 
return r"".join( [self.json2token(item, update_special_tokens_for_json_key, sort_json_key) for item in obj] ) - else: - obj = str(obj) - if f"<{obj}/>" in self.decoder.tokenizer.all_special_tokens: - obj = f"<{obj}/>" # for categorical special tokens - return obj - def token2json(self, tokens, is_inner_value=False): + obj = str(obj) + if f"<{obj}/>" in self.decoder.tokenizer.all_special_tokens: + obj = f"<{obj}/>" # for categorical special tokens + return obj + + def token2json(self, tokens: str, is_inner_value: bool = False) -> List[dict]: """ - Convert a (generated) token seuqnce into an ordered JSON format + Convert a (generated) token sequence into an ordered JSON format """ output = dict() @@ -570,14 +764,118 @@ def token2json(self, tokens, is_inner_value=False): if len(output[key]) == 1: output[key] = output[key][0] - tokens = tokens[tokens.find(end_token) + len(end_token) :].strip() + tokens = tokens[tokens.find(end_token) + len(end_token):].strip() if tokens[:6] == r"": # non-leaf nodes return [output] + self.token2json(tokens[6:], is_inner_value=True) if len(output): return [output] if is_inner_value else output - else: - return [] if is_inner_value else {"text_sequence": tokens} + + return [] if is_inner_value else {"text_sequence": tokens} + + def token2json_with_confs( + self, tokens: str, confs: List[float], idxs: List[list], delim: str, is_inner_val: bool = False + ) -> List: + """ + Convert a (generated) token sequence into an ordered JSON format + """ + output = dict() + + while tokens: + start_token = re.search(r"", tokens, re.IGNORECASE) + if start_token is None: + break + key = start_token.group(1) + end_token = re.search(fr"", tokens, re.IGNORECASE) + start_token = start_token.group() + tokens_split = [tkn for tkn in tokens.split(delim) if tkn] + assert len(tokens_split) == len(confs) == len(idxs) + + if end_token is None: + # remove all occurrences of start_token idxes from confs list and idxs list + confs = [ + confs[i] for i, tkn in enumerate(tokens_split) if not re.search(start_token, tkn, re.IGNORECASE) + ] + idxs = [idxs[i] for i, tkn in enumerate(tokens_split) if not re.search(start_token, tkn, re.IGNORECASE)] + tokens = tokens.replace(start_token, "") + tksplit = [tk for tk in tokens.split(delim) if tk] + assert len(tksplit) == len(confs) == len(idxs) + else: + end_token = end_token.group() + start_token_escaped = re.escape(start_token) + end_token_escaped = re.escape(end_token) + content = re.search(f"{start_token_escaped}(.*?){end_token_escaped}", tokens, re.IGNORECASE) + if content is not None: + start_tkn_esc_idx = None + end_tkn_esc_idx = None + for i, tkn in enumerate(tokens_split): + # only take the first start token + if start_tkn_esc_idx is None and re.search(start_token_escaped, tkn, re.IGNORECASE): + start_tkn_esc_idx = i + # end_token_escaped must exist after start_token_escaped_idx exists + if start_tkn_esc_idx is not None and re.search(end_token_escaped, tkn, re.IGNORECASE): + end_tkn_esc_idx = i + break + content = content.group(1).strip(delim) + content_confs = confs[start_tkn_esc_idx + 1:end_tkn_esc_idx] + content_idxs = idxs[start_tkn_esc_idx + 1:end_tkn_esc_idx] + cntsplit = [tk for tk in content.split(delim) if tk] + + assert len(tokens_split) == len(confs) == len(idxs) + assert len(cntsplit) == len(content_confs) == len(content_idxs) + + if r"", tkn, re.IGNORECASE) + )] + leaf_content_idxs = [content_idxs[i] for i, tkn in enumerate(cntsplit) if not ( + re.search(r"", tkn, re.IGNORECASE) + )] + for leaf_i, leaf in 
enumerate(content.split(r"")): + leaf_stripped = leaf.strip(delim) + if ( + leaf_stripped in self.decoder.tokenizer.get_added_vocab() + and leaf_stripped[0] == "<" + and leaf_stripped[-2:] == "/>" + ): + leaf_stripped = leaf_stripped[1:-2] # for categorical special tokens + if not leaf_stripped: + continue + if self.return_confs and self.return_tokens: + output[key].append( + [leaf_stripped, leaf_content_confs[leaf_i], leaf_content_idxs[leaf_i]] + ) + elif self.return_confs: + output[key].append([leaf_stripped, leaf_content_confs[leaf_i]]) + elif self.return_tokens: + output[key].append([leaf_stripped, leaf_content_idxs[leaf_i]]) + else: + output[key].append(leaf_stripped) + if len(output[key]) == 1: + output[key] = output[key][0] + for i, tkn in enumerate(tokens_split): + if re.search(end_token, tkn, re.IGNORECASE): + confs = confs[i + 1:] + idxs = idxs[i + 1:] + break + tokens = tokens[tokens.find(end_token) + len(end_token):].strip(delim) + if tokens[:6] == r"": # non-leaf nodes + return [output] + self.token2json_with_confs( + tokens[6:], confs[1:], idxs[1:], delim, is_inner_val=True + ) + + if len(output): + return [output] if is_inner_val else output + + return [] if is_inner_val else {} @classmethod def from_pretrained( @@ -594,18 +892,18 @@ def from_pretrained( Name of a pretrained model name either registered in huggingface.co. or saved in local, e.g., `naver-clova-ix/donut-base`, or `naver-clova-ix/donut-base-finetuned-rvlcdip` """ - model = super(DonutModel, cls).from_pretrained(pretrained_model_name_or_path, revision="official", *model_args, **kwargs) + model = super(DonutModel, cls).from_pretrained( + pretrained_model_name_or_path, revision="official", *model_args, **kwargs + ) - # truncate or interplolate position embeddings of donut decoder + # truncate or interpolate position embeddings of donut decoder max_length = kwargs.get("max_length", model.config.max_position_embeddings) - if ( - max_length != model.config.max_position_embeddings - ): # if max_length of trained model differs max_length you want to train + # if max_length of trained model differs max_length you want to train + if max_length != model.config.max_position_embeddings: + # https://github.com/huggingface/transformers/blob/v4.11.3/src/transformers/models/mbart/modeling_mbart.py#L118-L119 model.decoder.model.model.decoder.embed_positions.weight = torch.nn.Parameter( model.decoder.resize_bart_abs_pos_emb( - model.decoder.model.model.decoder.embed_positions.weight, - max_length - + 2, # https://github.com/huggingface/transformers/blob/v4.11.3/src/transformers/models/mbart/modeling_mbart.py#L118-L119 + model.decoder.model.model.decoder.embed_positions.weight, max_length + 2 ) ) model.config.max_position_embeddings = max_length diff --git a/donut/util.py b/donut/util.py old mode 100755 new mode 100644 index 16b542fe..2f5f41a2 --- a/donut/util.py +++ b/donut/util.py @@ -71,8 +71,13 @@ def __init__( assert isinstance(ground_truth["gt_parses"], list) gt_jsons = ground_truth["gt_parses"] else: - assert "gt_parse" in ground_truth and isinstance(ground_truth["gt_parse"], dict) - gt_jsons = [ground_truth["gt_parse"]] + assert "gt_parse" in ground_truth and ( + isinstance(ground_truth["gt_parse"], dict) or + isinstance(ground_truth["gt_parse"], list) + ) + gt_jsons = [ground_truth["gt_parse"]] if isinstance( + ground_truth["gt_parse"], dict + ) else ground_truth["gt_parse"] self.gt_token_sequences.append( [ diff --git a/lightning_module.py b/lightning_module.py index f05e5ce7..cf47daad 100755 --- a/lightning_module.py 
+++ b/lightning_module.py @@ -192,7 +192,7 @@ def val_dataloader(self): return loaders @staticmethod - def seed_worker(wordker_id): + def seed_worker(worker_id): worker_seed = torch.initial_seed() % 2 ** 32 np.random.seed(worker_seed) random.seed(worker_seed) diff --git a/setup.py b/setup.py index 572977f9..2c3c8a52 100755 --- a/setup.py +++ b/setup.py @@ -30,9 +30,9 @@ def read_long_description(): description="OCR-free Document Understanding Transformer", long_description=read_long_description(), long_description_content_type="text/markdown", - author="Geewook Kim, Teakgyu Hong, Moonbin Yim, JeongYeon Nam, Jinyoung Park, Jinyeong Yim, Wonseok Hwang, Sangdoo Yun, Dongyoon Han, Seunghyun Park", - author_email="gwkim.rsrch@gmail.com", - url="https://github.com/clovaai/donut", + author="Geewook Kim, Teakgyu Hong, Moonbin Yim, JeongYeon Nam, Jinyoung Park, Jinyeong Yim, Wonseok Hwang, Sangdoo Yun, Dongyoon Han, Seunghyun Park, Matteo Cacciola", + author_email="gwkim.rsrch@gmail.com, matteo.cacciola@gmail.com", + url="https://github.com/clovaai/donut, https://github.com/matteocacciola/donut", license="MIT", packages=find_packages( exclude=[ @@ -50,7 +50,7 @@ def read_long_description(): ), python_requires=">=3.7", install_requires=[ - "transformers>=4.11.3", + "transformers>=4.21.1", "timm", "datasets[vision]", "pytorch-lightning>=1.6.4", diff --git a/test.py b/test.py index 4d1b52aa..bce122cc 100755 --- a/test.py +++ b/test.py @@ -6,16 +6,13 @@ import argparse import json import os -import re -from pathlib import Path import numpy as np import torch from datasets import load_dataset -from PIL import Image from tqdm import tqdm -from donut import DonutModel, JSONParseEvaluator, load_json, save_json +from donut import DonutModel, JSONParseEvaluator, save_json def test(args): diff --git a/train.py b/train.py index 0f7b801c..46ffe284 100755 --- a/train.py +++ b/train.py @@ -5,14 +5,10 @@ """ import argparse import datetime -import json import os -import random -from io import BytesIO from os.path import basename from pathlib import Path -import numpy as np import pytorch_lightning as pl import torch from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint @@ -87,35 +83,36 @@ def train(config): datasets = {"train": [], "validation": []} for i, dataset_name_or_path in enumerate(config.dataset_name_or_paths): task_name = os.path.basename(dataset_name_or_path) # e.g., cord-v2, docvqa, rvlcdip, ... - + task_start_token = config.task_start_tokens[i] if config.get("task_start_tokens", None) else f"" + prompt_end_token = f"" + # add categorical special tokens (optional) if task_name == "rvlcdip": model_module.model.decoder.add_special_tokens([ - "", "", "", "", - "
", "", "", "", - "", "", "", "", + "", "", "", "", + "", "", "", "", + "", "", "", "", "", "", "", "" ]) if task_name == "docvqa": model_module.model.decoder.add_special_tokens(["", ""]) - + for split in ["train", "validation"]: + # prompt_end_token is used for ignoring a given prompt in a loss function + # for docvqa task, i.e., {"question": {used as a prompt}, "answer": {prediction target}}, + # set prompt_end_token to "" datasets[split].append( DonutDataset( dataset_name_or_path=dataset_name_or_path, donut_model=model_module.model, max_length=config.max_length, split=split, - task_start_token=config.task_start_tokens[i] - if config.get("task_start_tokens", None) - else f"", - prompt_end_token="" if "docvqa" in dataset_name_or_path else f"", + task_start_token=task_start_token, + prompt_end_token=prompt_end_token, sort_json_key=config.sort_json_key, ) ) - # prompt_end_token is used for ignoring a given prompt in a loss function - # for docvqa task, i.e., {"question": {used as a prompt}, "answer": {prediction target}}, - # set prompt_end_token to "" + data_module.train_datasets = datasets["train"] data_module.val_datasets = datasets["validation"]