diff --git a/.gitignore b/.gitignore
index 8b01850e..880cd98b 100755
--- a/.gitignore
+++ b/.gitignore
@@ -137,3 +137,6 @@ dmypy.json
# Pyre type checker
.pyre/
+
+.idea
+.DS_Store
\ No newline at end of file
diff --git a/README.md b/README.md
index 58f3e9fb..76dfad36 100755
--- a/README.md
+++ b/README.md
@@ -34,12 +34,12 @@ Gradio web demos are available!
-|Task|Sec/Img|Score|Trained Model|Demo|
-|---|---|---|---|---|
-| [CORD](https://github.com/clovaai/cord) (Document Parsing) | 0.7 /<br>0.7 /<br>1.2 | 91.3 /<br>91.1 /<br>90.9 | [donut-base-finetuned-cord-v2](https://huggingface.co/naver-clova-ix/donut-base-finetuned-cord-v2/tree/official) (1280) /<br>[donut-base-finetuned-cord-v1](https://huggingface.co/naver-clova-ix/donut-base-finetuned-cord-v1/tree/official) (1280) /<br>[donut-base-finetuned-cord-v1-2560](https://huggingface.co/naver-clova-ix/donut-base-finetuned-cord-v1-2560/tree/official) | [gradio space web demo](https://huggingface.co/spaces/naver-clova-ix/donut-base-finetuned-cord-v2),<br>[google colab demo (updated at 23.06.15)](https://colab.research.google.com/drive/1NMSqoIZ_l39wyRD7yVjw2FIuU2aglzJi?usp=sharing) |
-| [Train Ticket](https://github.com/beacandler/EATEN) (Document Parsing) | 0.6 | 98.7 | [donut-base-finetuned-zhtrainticket](https://huggingface.co/naver-clova-ix/donut-base-finetuned-zhtrainticket/tree/official) | [google colab demo (updated at 23.06.15)](https://colab.research.google.com/drive/1YJBjllahdqNktXaBlq5ugPh1BCm8OsxI?usp=sharing) |
-| [RVL-CDIP](https://www.cs.cmu.edu/~aharley/rvl-cdip) (Document Classification) | 0.75 | 95.3 | [donut-base-finetuned-rvlcdip](https://huggingface.co/naver-clova-ix/donut-base-finetuned-rvlcdip/tree/official) | [gradio space web demo](https://huggingface.co/spaces/nielsr/donut-rvlcdip),<br>[google colab demo (updated at 23.06.15)](https://colab.research.google.com/drive/1iWOZHvao1W5xva53upcri5V6oaWT-P0O?usp=sharing) |
-| [DocVQA Task1](https://rrc.cvc.uab.es/?ch=17) (Document VQA) | 0.78 | 67.5 | [donut-base-finetuned-docvqa](https://huggingface.co/naver-clova-ix/donut-base-finetuned-docvqa/tree/official) | [gradio space web demo](https://huggingface.co/spaces/nielsr/donut-docvqa),<br>[google colab demo (updated at 23.06.15)](https://colab.research.google.com/drive/1oKieslZCulFiquequ62eMGc-ZWgay4X3?usp=sharing) |
+| Task | Sec/Img | Score | Trained Model | Demo |
+|---|---|---|---|---|
+| [CORD](https://github.com/clovaai/cord) (Document Parsing) | 0.7 /<br>0.7 /<br>1.2 | 91.3 /<br>91.1 /<br>90.9 | [donut-base-finetuned-cord-v2](https://huggingface.co/naver-clova-ix/donut-base-finetuned-cord-v2/tree/official) (1280) /<br>[donut-base-finetuned-cord-v1](https://huggingface.co/naver-clova-ix/donut-base-finetuned-cord-v1/tree/official) (1280) /<br>[donut-base-finetuned-cord-v1-2560](https://huggingface.co/naver-clova-ix/donut-base-finetuned-cord-v1-2560/tree/official) | [gradio space web demo](https://huggingface.co/spaces/naver-clova-ix/donut-base-finetuned-cord-v2),<br>[google colab demo (updated at 23.06.15)](https://colab.research.google.com/drive/1NMSqoIZ_l39wyRD7yVjw2FIuU2aglzJi?usp=sharing) |
+| [Train Ticket](https://github.com/beacandler/EATEN) (Document Parsing) | 0.6 | 98.7 | [donut-base-finetuned-zhtrainticket](https://huggingface.co/naver-clova-ix/donut-base-finetuned-zhtrainticket/tree/official) | [google colab demo (updated at 23.06.15)](https://colab.research.google.com/drive/1YJBjllahdqNktXaBlq5ugPh1BCm8OsxI?usp=sharing) |
+| [RVL-CDIP](https://www.cs.cmu.edu/~aharley/rvl-cdip) (Document Classification) | 0.75 | 95.3 | [donut-base-finetuned-rvlcdip](https://huggingface.co/naver-clova-ix/donut-base-finetuned-rvlcdip/tree/official) | [gradio space web demo](https://huggingface.co/spaces/nielsr/donut-rvlcdip),<br>[google colab demo (updated at 23.06.15)](https://colab.research.google.com/drive/1iWOZHvao1W5xva53upcri5V6oaWT-P0O?usp=sharing) |
+| [DocVQA Task1](https://rrc.cvc.uab.es/?ch=17) (Document VQA) | 0.78 | 67.5 | [donut-base-finetuned-docvqa](https://huggingface.co/naver-clova-ix/donut-base-finetuned-docvqa/tree/official) | [gradio space web demo](https://huggingface.co/spaces/nielsr/donut-docvqa),<br>[google colab demo (updated at 23.06.15)](https://colab.research.google.com/drive/1oKieslZCulFiquequ62eMGc-ZWgay4X3?usp=sharing) |
The links to the pre-trained backbones are here:
- [`donut-base`](https://huggingface.co/naver-clova-ix/donut-base/tree/official): trained with 64 A100 GPUs (~2.5 days), number of layers (encoder: {2,2,14,2}, decoder: 4), input size 2560x1920, swin window size 10, IIT-CDIP (11M) and SynthDoG (English, Chinese, Japanese, Korean, 0.5M x 4).
@@ -62,6 +62,9 @@ To generate synthetic datasets with our SynthDoG, please see `./synthdog/README.
## Updates
+**_2024-05-15_** The `inference` method of `DonutModel` can now return confidence scores alongside the extracted
+items via the `return_confs` parameter, which defaults to `True`. If you do not want the scores, set
+`return_confs` to `False`.
**_2023-06-15_** We have updated all Google Colab demos to ensure its proper working.
**_2022-11-14_** New version 1.0.9 is released (`pip install donut-python --upgrade`). See [1.0.9 Release Notes](https://github.com/clovaai/donut/releases/tag/1.0.9).
**_2022-08-12_** Donut 🍩 is also available at [huggingface/transformers 🤗](https://huggingface.co/docs/transformers/main/en/model_doc/donut) (contributed by [@NielsRogge](https://github.com/NielsRogge)). `donut-python` loads the pre-trained weights from the `official` branch of the model repositories. See [1.0.5 Release Notes](https://github.com/clovaai/donut/releases/tag/1.0.5).
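
Not part of the patch: a minimal usage sketch of the `return_confs` flag described in the update note above, assuming a CORD-finetuned checkpoint and a local `sample_receipt.png` (both placeholders).

```python
import torch
from PIL import Image

from donut import DonutModel

# Hypothetical example: the checkpoint name and image path are placeholders.
model = DonutModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
if torch.cuda.is_available():
    model.half().to("cuda")
model.eval()

image = Image.open("sample_receipt.png")

# Default behaviour after this change: leaf values come back as [value, confidence] pairs.
with_confs = model.inference(image=image, prompt="<s_cord-v2>")["predictions"][0]

# Opt out to get the plain parse, as before.
plain = model.inference(image=image, prompt="<s_cord-v2>", return_confs=False)["predictions"][0]
```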
diff --git a/donut/model.py b/donut/model.py
index 6321c5d2..eeff4480 100755
--- a/donut/model.py
+++ b/donut/model.py
@@ -6,21 +6,20 @@
import math
import os
import re
-from typing import Any, List, Optional, Union
+from typing import Any, List, Optional, Tuple, Union
import numpy as np
-import PIL
import timm
import torch
import torch.nn as nn
-import torch.nn.functional as F
-from PIL import ImageOps
+import torch.nn.functional as fnc
+from PIL import Image, ImageOps
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.swin_transformer import SwinTransformer
from torchvision import transforms
-from torchvision.transforms.functional import resize, rotate
from transformers import MBartConfig, MBartForCausalLM, XLMRobertaTokenizer
from transformers.file_utils import ModelOutput
+from transformers.utils.generic import to_py_obj
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
@@ -41,17 +40,19 @@ class SwinEncoder(nn.Module):
def __init__(
self,
- input_size: List[int],
+ input_size: Tuple[int, int],
align_long_axis: bool,
window_size: int,
- encoder_layer: List[int],
+        encoder_layer: Tuple[int, ...],
name_or_path: Union[str, bytes, os.PathLike] = None,
+ drop_rate: float | None = 0.0,
):
super().__init__()
self.input_size = input_size
self.align_long_axis = align_long_axis
self.window_size = window_size
self.encoder_layer = encoder_layer
+ self.drop_rate = drop_rate
self.to_tensor = transforms.Compose(
[
@@ -66,10 +67,9 @@ def __init__(
window_size=self.window_size,
patch_size=4,
embed_dim=128,
- num_heads=[4, 8, 16, 32],
+ num_heads=(4, 8, 16, 32),
num_classes=0,
)
- self.model.norm = None
# weight init with swin
if not name_or_path:
@@ -86,7 +86,7 @@ def __init__(
old_len = int(math.sqrt(len(pos_bias)))
new_len = int(2 * window_size - 1)
pos_bias = pos_bias.reshape(1, old_len, old_len, -1).permute(0, 3, 1, 2)
- pos_bias = F.interpolate(pos_bias, size=(new_len, new_len), mode="bicubic", align_corners=False)
+ pos_bias = fnc.interpolate(pos_bias, size=(new_len, new_len), mode="bicubic", align_corners=False)
new_swin_state_dict[x] = pos_bias.permute(0, 2, 3, 1).reshape(1, new_len ** 2, -1).squeeze(0)
else:
new_swin_state_dict[x] = swin_state_dict[x]
@@ -98,11 +98,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x: (batch_size, num_channels, height, width)
"""
x = self.model.patch_embed(x)
- x = self.model.pos_drop(x)
+        x = fnc.dropout(x, p=self.drop_rate, training=self.training)
x = self.model.layers(x)
return x
- def prepare_input(self, img: PIL.Image.Image, random_padding: bool = False) -> torch.Tensor:
+ def prepare_input(self, img: Image.Image, random_padding: bool = False) -> torch.Tensor:
"""
Convert PIL Image to tensor according to specified input_size after following steps below:
- resize
@@ -114,8 +114,8 @@ def prepare_input(self, img: PIL.Image.Image, random_padding: bool = False) -> t
(self.input_size[0] > self.input_size[1] and img.width > img.height)
or (self.input_size[0] < self.input_size[1] and img.width < img.height)
):
- img = rotate(img, angle=-90, expand=True)
- img = resize(img, min(self.input_size))
+ img = img.rotate(angle=-90, expand=True)
+ img = img.resize(self.input_size)
img.thumbnail((self.input_size[1], self.input_size[0]))
delta_width = self.input_size[1] - img.width
delta_height = self.input_size[0] - img.height
@@ -134,6 +134,147 @@ def prepare_input(self, img: PIL.Image.Image, random_padding: bool = False) -> t
return self.to_tensor(ImageOps.expand(img, padding))
+class BARTCustomTokenizer(XLMRobertaTokenizer):
+ """
+ Customized XLMRobertaTokenizer to return confidence scores and token id groups aligned with grouped tokens
+    The default batch_decode, decode and _decode methods are overridden for this tokenizer.
+ """
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.DELIM = None
+
+ def batch_decode(
+ self,
+ sequences: Union[List[int], List[List[int]], "np.ndarray", "torch.Tensor"],
+ skip_special_tokens: bool = False,
+ clean_up_tokenization_spaces: bool = True,
+ **kwargs
+ ) -> List[Tuple]:
+ """
+ Convert a list of lists of token ids into a list of strings by calling decode.
+ Args:
+ sequences (`Union[List[int], List[List[int]], np.ndarray, torch.Tensor, tf.Tensor]`):
+ List of tokenized input ids. Can be obtained using the `__call__` method.
+ skip_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether to remove special tokens in the decoding.
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
+ Whether to clean up the tokenization spaces.
+ kwargs (additional keyword arguments, *optional*):
+ Will be passed to the underlying model specific decode method.
+ Returns:
+            `List[Tuple]`: A list of tuples (decoded sentence, confidence scores, grouped token indices).
+ """
+ confidences = kwargs.pop("confidences", [])
+ self.DELIM = kwargs.pop("decoder_delim", None)
+
+ result = []
+ for seq, conf in zip(sequences, confidences):
+ kwargs["token_confs"] = conf
+ result.append(self.decode(
+ seq,
+ skip_special_tokens=skip_special_tokens,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ **kwargs,
+ ))
+ return result
+
+ def decode(
+ self,
+ token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor"],
+ skip_special_tokens: bool = False,
+ clean_up_tokenization_spaces: bool = True,
+ **kwargs
+ ) -> Tuple:
+ """
+        Converts a sequence of ids into a string, using the tokenizer and vocabulary with options to remove special
+ tokens and clean up tokenization spaces.
+ Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`.
+ Args:
+ token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
+ List of tokenized input ids. Can be obtained using the `__call__` method.
+ skip_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether to remove special tokens in the decoding.
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
+ Whether to clean up the tokenization spaces.
+ kwargs (additional keyword arguments, *optional*):
+ Will be passed to the underlying model specific decode method.
+ Returns:
+            `Tuple`: The decoded sentence, its confidence scores, and the grouped token indices.
+ """
+ # Convert inputs to python lists
+ token_ids = to_py_obj(token_ids)
+ kwargs["token_confs"] = to_py_obj(kwargs.pop("token_confs", []))
+
+ return self._decode(
+ token_ids=token_ids,
+ skip_special_tokens=skip_special_tokens,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ **kwargs,
+ )
+
+ def _decode(
+ self,
+ token_ids: List[int],
+ skip_special_tokens: bool = False,
+ clean_up_tokenization_spaces: bool = True,
+ spaces_between_special_tokens: bool = True,
+ **kwargs
+ ) -> Tuple:
+ token_confs = kwargs.pop("token_confs", [])
+
+ self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
+
+ filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
+
+        # To avoid mixing byte-level and unicode for byte-level BPE
+ # we need to build string separately for added tokens and byte-level tokens
+ # cf. https://github.com/huggingface/transformers/issues/1133
+ sub_texts = []
+ sub_confs = []
+ sub_idxs = []
+ current_sub_text = []
+ current_sub_confs = []
+ current_sub_idxs = []
+
+ for idx, (token, conf) in enumerate(zip(filtered_tokens, token_confs)):
+ if skip_special_tokens and token in self.all_special_ids:
+ continue
+ if token in self.added_tokens_encoder:
+ if current_sub_text:
+ sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+ current_sub_text = []
+ sub_confs.append(sum(current_sub_confs) / len(current_sub_confs))
+ current_sub_confs = []
+ sub_idxs.append(current_sub_idxs)
+ current_sub_idxs = []
+ sub_texts.append(token)
+ sub_confs.append(conf)
+ sub_idxs.append([idx])
+ else:
+ current_sub_text.append(token)
+ current_sub_confs.append(conf)
+ current_sub_idxs.append(idx)
+
+ if current_sub_text:
+ sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+ sub_confs.append(sum(current_sub_confs) / len(current_sub_confs))
+ sub_idxs.append(current_sub_idxs)
+
+ decoder_output_confs = sub_confs
+ decoder_output_indxs = sub_idxs
+ if spaces_between_special_tokens:
+ text = self.DELIM.join(sub_texts)
+ else:
+ text = "".join(sub_texts)
+
+ if clean_up_tokenization_spaces:
+ clean_text = self.clean_up_tokenization(text)
+ return clean_text, decoder_output_confs, decoder_output_indxs
+
+ return text, decoder_output_confs, decoder_output_indxs
+
+
class BARTDecoder(nn.Module):
"""
Donut Decoder based on Multilingual BART
@@ -157,7 +298,7 @@ def __init__(
self.decoder_layer = decoder_layer
self.max_position_embeddings = max_position_embeddings
- self.tokenizer = XLMRobertaTokenizer.from_pretrained(
+ self.tokenizer = BARTCustomTokenizer.from_pretrained(
"hyunwoongko/asian-bart-ecjk" if not name_or_path else name_or_path
)
@@ -173,7 +314,7 @@ def __init__(
add_final_layer_norm=True,
)
)
- self.model.forward = self.forward # to get cross attentions and utilize `generate` function
+ self.model.forward = self.forward # to get cross attentions and utilize `generate` function
self.model.config.is_encoder_decoder = True # to get cross-attention
self.add_special_tokens([""]) # is used for representing a list in a JSON
@@ -186,12 +327,9 @@ def __init__(
new_bart_state_dict = self.model.state_dict()
for x in new_bart_state_dict:
if x.endswith("embed_positions.weight") and self.max_position_embeddings != 1024:
+ # https://github.com/huggingface/transformers/blob/v4.11.3/src/transformers/models/mbart/modeling_mbart.py#L118-L119
new_bart_state_dict[x] = torch.nn.Parameter(
- self.resize_bart_abs_pos_emb(
- bart_state_dict[x],
- self.max_position_embeddings
- + 2, # https://github.com/huggingface/transformers/blob/v4.11.3/src/transformers/models/mbart/modeling_mbart.py#L118-L119
- )
+ self.resize_bart_abs_pos_emb(bart_state_dict[x], self.max_position_embeddings + 2)
)
elif x.endswith("embed_tokens.weight") or x.endswith("lm_head.weight"):
new_bart_state_dict[x] = bart_state_dict[x][: len(self.tokenizer), :]
@@ -199,35 +337,49 @@ def __init__(
new_bart_state_dict[x] = bart_state_dict[x]
self.model.load_state_dict(new_bart_state_dict)
- def add_special_tokens(self, list_of_tokens: List[str]):
+ def add_special_tokens(self, list_of_tokens: List[str], replace_additional_special_tokens: bool | None = False):
"""
Add special tokens to tokenizer and resize the token embeddings
"""
- newly_added_num = self.tokenizer.add_special_tokens({"additional_special_tokens": sorted(set(list_of_tokens))})
+ newly_added_num = 0
+ set_of_tokens = set(list_of_tokens)
+ set_special_tokens = set(self.tokenizer.all_special_tokens)
+ if len(set_of_tokens - set_special_tokens) > 0:
+ newly_added_num = self.tokenizer.add_special_tokens(
+ {"additional_special_tokens": sorted(set_of_tokens)},
+ replace_additional_special_tokens=replace_additional_special_tokens
+ )
if newly_added_num > 0:
self.model.resize_token_embeddings(len(self.tokenizer))
- def prepare_inputs_for_inference(self, input_ids: torch.Tensor, encoder_outputs: torch.Tensor, past_key_values=None, past=None, use_cache: bool = None, attention_mask: torch.Tensor = None):
+ def prepare_inputs_for_inference(
+ self,
+ input_ids: torch.Tensor,
+ encoder_outputs: torch.Tensor = None,
+ past=None,
+ use_cache: bool = None,
+ **kwargs,
+ ):
"""
Args:
- input_ids: (batch_size, sequence_lenth)
+ input_ids: (batch_size, sequence_length)
+ encoder_outputs: (batch_size, sequence_length, hidden_size)
+ past: Past key values
+ use_cache: Whether to use cache or not
Returns:
input_ids: (batch_size, sequence_length)
attention_mask: (batch_size, sequence_length)
encoder_hidden_states: (batch_size, sequence_length, embedding_dim)
"""
- # for compatibility with transformers==4.11.x
- if past is not None:
- past_key_values = past
attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long()
- if past_key_values is not None:
+ if past is not None:
input_ids = input_ids[:, -1:]
output = {
"input_ids": input_ids,
"attention_mask": attention_mask,
- "past_key_values": past_key_values,
+ "past_key_values": past,
"use_cache": use_cache,
- "encoder_hidden_states": encoder_outputs.last_hidden_state,
+ "encoder_hidden_states": encoder_outputs.last_hidden_state if encoder_outputs is not None else None,
}
return output
@@ -244,7 +396,7 @@ def forward(
return_dict: bool = None,
):
"""
- A forward fucntion to get cross attentions and utilize `generate` function
+ A forward function to get cross attentions and utilize `generate` function
Source:
https://github.com/huggingface/transformers/blob/v4.11.3/src/transformers/models/mbart/modeling_mbart.py#L1669-L1810
@@ -253,7 +405,12 @@ def forward(
input_ids: (batch_size, sequence_length)
attention_mask: (batch_size, sequence_length)
encoder_hidden_states: (batch_size, sequence_length, hidden_size)
-
+ past_key_values:
+ labels: (batch_size, sequence_length)
+ use_cache: Whether to use cache or not
+ output_attentions: Whether to return attentions or not
+ output_hidden_states: Whether to return hidden states or not
+ return_dict: Whether to return dict or not
Returns:
loss: (1, )
logits: (batch_size, sequence_length, hidden_dim)
@@ -308,7 +465,7 @@ def resize_bart_abs_pos_emb(weight: torch.Tensor, max_length: int) -> torch.Tens
weight = weight[:max_length, ...]
else:
weight = (
- F.interpolate(
+ fnc.interpolate(
weight.permute(1, 0).unsqueeze(0),
size=max_length,
mode="linear",
@@ -349,10 +506,10 @@ class DonutConfig(PretrainedConfig):
def __init__(
self,
- input_size: List[int] = [2560, 1920],
+ input_size: Tuple[int, int] | None = None,
align_long_axis: bool = False,
window_size: int = 10,
- encoder_layer: List[int] = [2, 2, 14, 2],
+        encoder_layer: Tuple[int, ...] | None = None,
decoder_layer: int = 4,
max_position_embeddings: int = None,
max_length: int = 1536,
@@ -360,10 +517,10 @@ def __init__(
**kwargs,
):
super().__init__()
- self.input_size = input_size
+ self.input_size = input_size if input_size else (2560, 1920)
self.align_long_axis = align_long_axis
self.window_size = window_size
- self.encoder_layer = encoder_layer
+ self.encoder_layer = encoder_layer if encoder_layer else (2, 2, 14, 2)
self.decoder_layer = decoder_layer
self.max_position_embeddings = max_length if max_position_embeddings is None else max_position_embeddings
self.max_length = max_length
@@ -382,6 +539,8 @@ class DonutModel(PreTrainedModel):
def __init__(self, config: DonutConfig):
super().__init__(config)
+ self.return_confs = None
+ self.return_tokens = None
self.config = config
self.encoder = SwinEncoder(
input_size=self.config.input_size,
@@ -404,7 +563,7 @@ def forward(self, image_tensors: torch.Tensor, decoder_input_ids: torch.Tensor,
Args:
image_tensors: (batch_size, num_channels, height, width)
decoder_input_ids: (batch_size, sequence_length, embedding_dim)
- decode_labels: (batch_size, sequence_length)
+ decoder_labels: (batch_size, sequence_length)
"""
encoder_outputs = self.encoder(image_tensors)
decoder_outputs = self.decoder(
@@ -416,16 +575,18 @@ def forward(self, image_tensors: torch.Tensor, decoder_input_ids: torch.Tensor,
def inference(
self,
- image: PIL.Image = None,
+        image: Image.Image = None,
prompt: str = None,
image_tensors: Optional[torch.Tensor] = None,
prompt_tensors: Optional[torch.Tensor] = None,
return_json: bool = True,
- return_attentions: bool = False,
+ return_confs: bool = True,
+ return_tokens: bool = False,
+ return_attentions: bool = False
):
"""
- Generate a token sequence in an auto-regressive manner,
- the generated token sequence is convereted into an ordered JSON format
+ Generate a token sequence in an autoregressive manner,
+ the generated token sequence is converted into an ordered JSON format
Args:
image: input document image (PIL.Image)
@@ -434,6 +595,10 @@ def inference(
convert prompt to tensor if image_tensor is not fed
prompt_tensors: (1, sequence_length)
convert image to tensor if prompt_tensor is not fed
+ return_json: whether to return a JSON format or not
+ return_confs: whether to return confidence scores or not
+ return_tokens: whether to return tokens or not
+ return_attentions: whether to return attentions or not
"""
# prepare backbone inputs (image and prompt)
if image is None and image_tensors is None:
@@ -447,6 +612,8 @@ def inference(
if self.device.type == "cuda": # half is not compatible in cpu implementation.
image_tensors = image_tensors.half()
image_tensors = image_tensors.to(self.device)
+ else:
+ image_tensors = image_tensors.to(torch.bfloat16)
if prompt_tensors is None:
prompt_tensors = self.decoder.tokenizer(prompt, add_special_tokens=False, return_tensors="pt")["input_ids"]
@@ -477,16 +644,43 @@ def inference(
bad_words_ids=[[self.decoder.tokenizer.unk_token_id]],
return_dict_in_generate=True,
output_attentions=return_attentions,
+ output_scores=True,
)
+ decoder_output_confs = torch.amax(torch.stack(decoder_output.scores, dim=1).softmax(-1), 2).cpu().numpy()[0]
+ # add score for end token and wrap scores in a list
+ decoder_output_confs = [np.concatenate([decoder_output_confs, [1.]], axis=0)]
+
output = {"predictions": list()}
- for seq in self.decoder.tokenizer.batch_decode(decoder_output.sequences):
- seq = seq.replace(self.decoder.tokenizer.eos_token, "").replace(self.decoder.tokenizer.pad_token, "")
- seq = re.sub(r"<.*?>", "", seq, count=1).strip() # remove first task start token
- if return_json:
- output["predictions"].append(self.token2json(seq))
- else:
- output["predictions"].append(seq)
+ self.return_tokens = return_tokens
+ self.return_confs = return_confs
+ delimiter = "}~}~}~{" # important, use a delimiter that has a very low prob of appearing in text
+
+ for idx, (seq, confs, idxs) in enumerate(self.decoder.tokenizer.batch_decode(
+ decoder_output.sequences, confidences=decoder_output_confs, decoder_delim=delimiter)
+ ):
+ eos_tkn, pad_tkn = self.decoder.tokenizer.eos_token, self.decoder.tokenizer.pad_token
+ split_seq = [tkn for tkn in seq.split(delimiter) if tkn]
+ confs = [confs[i] for i, tkn in enumerate(split_seq) if not (
+ tkn.strip().lower() == eos_tkn.lower() or tkn.strip().lower() == pad_tkn.lower()
+ )]
+ idxs = [idxs[i] for i, tkn in enumerate(seq.split(delimiter)) if not (
+ tkn.strip().lower() == eos_tkn.lower() or tkn.strip().lower() == pad_tkn.lower()
+ )]
+ seq = seq.replace(eos_tkn, "").replace(pad_tkn, "")
+ for i, tkn in enumerate(seq.split(delimiter)):
+ if re.search(r"<.*?>", tkn, re.IGNORECASE): # remove first task start token conf
+ confs.pop(i)
+ idxs.pop(i)
+ break
+ seq = re.sub(r"<.*?>", "", seq, count=1).strip(delimiter) # remove first task start token
+ item = seq
+ if confs and idxs and return_json:
+ item = self.token2json_with_confs(seq, confs, idxs, delim=delimiter) if (
+ return_confs or return_tokens
+ ) else self.token2json(seq.replace(delimiter, ' '))
+
+ output["predictions"].append(item)
if return_attentions:
output["attentions"] = {
@@ -500,37 +694,37 @@ def json2token(self, obj: Any, update_special_tokens_for_json_key: bool = True,
"""
Convert an ordered JSON object into a token sequence
"""
- if type(obj) == dict:
+ if isinstance(obj, dict):
if len(obj) == 1 and "text_sequence" in obj:
return obj["text_sequence"]
- else:
- output = ""
- if sort_json_key:
- keys = sorted(obj.keys(), reverse=True)
- else:
- keys = obj.keys()
- for k in keys:
- if update_special_tokens_for_json_key:
- self.decoder.add_special_tokens([fr"", fr""])
- output += (
- fr""
- + self.json2token(obj[k], update_special_tokens_for_json_key, sort_json_key)
- + fr""
- )
- return output
- elif type(obj) == list:
+
+ output = ""
+ keys = obj.keys()
+ if sort_json_key:
+ keys = sorted(keys, reverse=True)
+ for k in keys:
+ if update_special_tokens_for_json_key:
+ self.decoder.add_special_tokens([fr"", fr""])
+ output += (
+ fr""
+ + self.json2token(obj[k], update_special_tokens_for_json_key, sort_json_key)
+ + fr""
+ )
+ return output
+
+ if isinstance(obj, list):
return r"".join(
[self.json2token(item, update_special_tokens_for_json_key, sort_json_key) for item in obj]
)
- else:
- obj = str(obj)
- if f"<{obj}/>" in self.decoder.tokenizer.all_special_tokens:
- obj = f"<{obj}/>" # for categorical special tokens
- return obj
- def token2json(self, tokens, is_inner_value=False):
+ obj = str(obj)
+ if f"<{obj}/>" in self.decoder.tokenizer.all_special_tokens:
+ obj = f"<{obj}/>" # for categorical special tokens
+ return obj
+
+ def token2json(self, tokens: str, is_inner_value: bool = False) -> List[dict]:
"""
- Convert a (generated) token seuqnce into an ordered JSON format
+ Convert a (generated) token sequence into an ordered JSON format
"""
output = dict()
@@ -570,14 +764,118 @@ def token2json(self, tokens, is_inner_value=False):
if len(output[key]) == 1:
output[key] = output[key][0]
- tokens = tokens[tokens.find(end_token) + len(end_token) :].strip()
+ tokens = tokens[tokens.find(end_token) + len(end_token):].strip()
if tokens[:6] == r"": # non-leaf nodes
return [output] + self.token2json(tokens[6:], is_inner_value=True)
if len(output):
return [output] if is_inner_value else output
- else:
- return [] if is_inner_value else {"text_sequence": tokens}
+
+ return [] if is_inner_value else {"text_sequence": tokens}
+
+ def token2json_with_confs(
+ self, tokens: str, confs: List[float], idxs: List[list], delim: str, is_inner_val: bool = False
+ ) -> List:
+ """
+ Convert a (generated) token sequence into an ordered JSON format
+ """
+ output = dict()
+
+ while tokens:
+ start_token = re.search(r"", tokens, re.IGNORECASE)
+ if start_token is None:
+ break
+ key = start_token.group(1)
+ end_token = re.search(fr"", tokens, re.IGNORECASE)
+ start_token = start_token.group()
+ tokens_split = [tkn for tkn in tokens.split(delim) if tkn]
+ assert len(tokens_split) == len(confs) == len(idxs)
+
+ if end_token is None:
+ # remove all occurrences of start_token idxes from confs list and idxs list
+ confs = [
+ confs[i] for i, tkn in enumerate(tokens_split) if not re.search(start_token, tkn, re.IGNORECASE)
+ ]
+ idxs = [idxs[i] for i, tkn in enumerate(tokens_split) if not re.search(start_token, tkn, re.IGNORECASE)]
+ tokens = tokens.replace(start_token, "")
+ tksplit = [tk for tk in tokens.split(delim) if tk]
+ assert len(tksplit) == len(confs) == len(idxs)
+ else:
+ end_token = end_token.group()
+ start_token_escaped = re.escape(start_token)
+ end_token_escaped = re.escape(end_token)
+ content = re.search(f"{start_token_escaped}(.*?){end_token_escaped}", tokens, re.IGNORECASE)
+ if content is not None:
+ start_tkn_esc_idx = None
+ end_tkn_esc_idx = None
+ for i, tkn in enumerate(tokens_split):
+ # only take the first start token
+ if start_tkn_esc_idx is None and re.search(start_token_escaped, tkn, re.IGNORECASE):
+ start_tkn_esc_idx = i
+ # end_token_escaped must exist after start_token_escaped_idx exists
+ if start_tkn_esc_idx is not None and re.search(end_token_escaped, tkn, re.IGNORECASE):
+ end_tkn_esc_idx = i
+ break
+ content = content.group(1).strip(delim)
+ content_confs = confs[start_tkn_esc_idx + 1:end_tkn_esc_idx]
+ content_idxs = idxs[start_tkn_esc_idx + 1:end_tkn_esc_idx]
+ cntsplit = [tk for tk in content.split(delim) if tk]
+
+ assert len(tokens_split) == len(confs) == len(idxs)
+ assert len(cntsplit) == len(content_confs) == len(content_idxs)
+
+ if r"", tkn, re.IGNORECASE)
+ )]
+ leaf_content_idxs = [content_idxs[i] for i, tkn in enumerate(cntsplit) if not (
+ re.search(r"", tkn, re.IGNORECASE)
+ )]
+ for leaf_i, leaf in enumerate(content.split(r"")):
+ leaf_stripped = leaf.strip(delim)
+ if (
+ leaf_stripped in self.decoder.tokenizer.get_added_vocab()
+ and leaf_stripped[0] == "<"
+ and leaf_stripped[-2:] == "/>"
+ ):
+ leaf_stripped = leaf_stripped[1:-2] # for categorical special tokens
+ if not leaf_stripped:
+ continue
+ if self.return_confs and self.return_tokens:
+ output[key].append(
+ [leaf_stripped, leaf_content_confs[leaf_i], leaf_content_idxs[leaf_i]]
+ )
+ elif self.return_confs:
+ output[key].append([leaf_stripped, leaf_content_confs[leaf_i]])
+ elif self.return_tokens:
+ output[key].append([leaf_stripped, leaf_content_idxs[leaf_i]])
+ else:
+ output[key].append(leaf_stripped)
+ if len(output[key]) == 1:
+ output[key] = output[key][0]
+ for i, tkn in enumerate(tokens_split):
+ if re.search(end_token, tkn, re.IGNORECASE):
+ confs = confs[i + 1:]
+ idxs = idxs[i + 1:]
+ break
+ tokens = tokens[tokens.find(end_token) + len(end_token):].strip(delim)
+ if tokens[:6] == r"": # non-leaf nodes
+ return [output] + self.token2json_with_confs(
+ tokens[6:], confs[1:], idxs[1:], delim, is_inner_val=True
+ )
+
+ if len(output):
+ return [output] if is_inner_val else output
+
+ return [] if is_inner_val else {}
@classmethod
def from_pretrained(
@@ -594,18 +892,18 @@ def from_pretrained(
Name of a pretrained model name either registered in huggingface.co. or saved in local,
e.g., `naver-clova-ix/donut-base`, or `naver-clova-ix/donut-base-finetuned-rvlcdip`
"""
- model = super(DonutModel, cls).from_pretrained(pretrained_model_name_or_path, revision="official", *model_args, **kwargs)
+ model = super(DonutModel, cls).from_pretrained(
+ pretrained_model_name_or_path, revision="official", *model_args, **kwargs
+ )
- # truncate or interplolate position embeddings of donut decoder
+ # truncate or interpolate position embeddings of donut decoder
max_length = kwargs.get("max_length", model.config.max_position_embeddings)
- if (
- max_length != model.config.max_position_embeddings
- ): # if max_length of trained model differs max_length you want to train
+ # if max_length of trained model differs max_length you want to train
+ if max_length != model.config.max_position_embeddings:
+ # https://github.com/huggingface/transformers/blob/v4.11.3/src/transformers/models/mbart/modeling_mbart.py#L118-L119
model.decoder.model.model.decoder.embed_positions.weight = torch.nn.Parameter(
model.decoder.resize_bart_abs_pos_emb(
- model.decoder.model.model.decoder.embed_positions.weight,
- max_length
- + 2, # https://github.com/huggingface/transformers/blob/v4.11.3/src/transformers/models/mbart/modeling_mbart.py#L118-L119
+ model.decoder.model.model.decoder.embed_positions.weight, max_length + 2
)
)
model.config.max_position_embeddings = max_length
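
Not part of the patch: a self-contained toy sketch of the grouping rule the customized `_decode` applies, where sub-token confidences are averaged per text span and added special tokens keep their own score. The helper name and sample values are invented for illustration.

```python
from typing import List, Set, Tuple


def group_confidences(tokens: List[str], confs: List[float], special: Set[str]) -> List[Tuple[str, float]]:
    """Average confidences over consecutive ordinary tokens; special tokens keep their own score."""
    groups, cur_toks, cur_confs = [], [], []
    for tok, conf in zip(tokens, confs):
        if tok in special:
            if cur_toks:  # flush the accumulated text span with its mean confidence
                groups.append(("".join(cur_toks).replace("▁", " ").strip(), sum(cur_confs) / len(cur_confs)))
                cur_toks, cur_confs = [], []
            groups.append((tok, conf))
        else:
            cur_toks.append(tok)
            cur_confs.append(conf)
    if cur_toks:
        groups.append(("".join(cur_toks).replace("▁", " ").strip(), sum(cur_confs) / len(cur_confs)))
    return groups


print(group_confidences(
    ["<s_nm>", "▁choco", "▁cake", "</s_nm>"],
    [0.99, 0.91, 0.87, 0.98],
    {"<s_nm>", "</s_nm>"},
))
# [('<s_nm>', 0.99), ('choco cake', 0.89), ('</s_nm>', 0.98)]
```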
diff --git a/donut/util.py b/donut/util.py
old mode 100755
new mode 100644
index 16b542fe..2f5f41a2
--- a/donut/util.py
+++ b/donut/util.py
@@ -71,8 +71,13 @@ def __init__(
assert isinstance(ground_truth["gt_parses"], list)
gt_jsons = ground_truth["gt_parses"]
else:
- assert "gt_parse" in ground_truth and isinstance(ground_truth["gt_parse"], dict)
- gt_jsons = [ground_truth["gt_parse"]]
+ assert "gt_parse" in ground_truth and (
+ isinstance(ground_truth["gt_parse"], dict) or
+ isinstance(ground_truth["gt_parse"], list)
+ )
+ gt_jsons = [ground_truth["gt_parse"]] if isinstance(
+ ground_truth["gt_parse"], dict
+ ) else ground_truth["gt_parse"]
self.gt_token_sequences.append(
[
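
Not part of the patch: the relaxed assertion above means `gt_parse` may now be either a single dict or a list of dicts. A small sketch with invented field names:

```python
# Both ground-truth shapes are accepted after this change (field names are illustrative).
single = {"gt_parse": {"menu": {"nm": "choco cake", "price": "5.00"}}}
multiple = {"gt_parse": [{"menu": {"nm": "choco cake"}}, {"menu": {"nm": "latte"}}]}

for ground_truth in (single, multiple):
    gt_parse = ground_truth["gt_parse"]
    assert isinstance(gt_parse, (dict, list))
    gt_jsons = [gt_parse] if isinstance(gt_parse, dict) else gt_parse
    print(len(gt_jsons))  # 1, then 2
```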
diff --git a/lightning_module.py b/lightning_module.py
index f05e5ce7..cf47daad 100755
--- a/lightning_module.py
+++ b/lightning_module.py
@@ -192,7 +192,7 @@ def val_dataloader(self):
return loaders
@staticmethod
- def seed_worker(wordker_id):
+ def seed_worker(worker_id):
worker_seed = torch.initial_seed() % 2 ** 32
np.random.seed(worker_seed)
random.seed(worker_seed)
diff --git a/setup.py b/setup.py
index 572977f9..2c3c8a52 100755
--- a/setup.py
+++ b/setup.py
@@ -30,9 +30,9 @@ def read_long_description():
description="OCR-free Document Understanding Transformer",
long_description=read_long_description(),
long_description_content_type="text/markdown",
- author="Geewook Kim, Teakgyu Hong, Moonbin Yim, JeongYeon Nam, Jinyoung Park, Jinyeong Yim, Wonseok Hwang, Sangdoo Yun, Dongyoon Han, Seunghyun Park",
- author_email="gwkim.rsrch@gmail.com",
- url="https://github.com/clovaai/donut",
+ author="Geewook Kim, Teakgyu Hong, Moonbin Yim, JeongYeon Nam, Jinyoung Park, Jinyeong Yim, Wonseok Hwang, Sangdoo Yun, Dongyoon Han, Seunghyun Park, Matteo Cacciola",
+ author_email="gwkim.rsrch@gmail.com, matteo.cacciola@gmail.com",
+ url="https://github.com/clovaai/donut, https://github.com/matteocacciola/donut",
license="MIT",
packages=find_packages(
exclude=[
@@ -50,7 +50,7 @@ def read_long_description():
),
python_requires=">=3.7",
install_requires=[
- "transformers>=4.11.3",
+ "transformers>=4.21.1",
"timm",
"datasets[vision]",
"pytorch-lightning>=1.6.4",
diff --git a/test.py b/test.py
index 4d1b52aa..bce122cc 100755
--- a/test.py
+++ b/test.py
@@ -6,16 +6,13 @@
import argparse
import json
import os
-import re
-from pathlib import Path
import numpy as np
import torch
from datasets import load_dataset
-from PIL import Image
from tqdm import tqdm
-from donut import DonutModel, JSONParseEvaluator, load_json, save_json
+from donut import DonutModel, JSONParseEvaluator, save_json
def test(args):
diff --git a/train.py b/train.py
index 0f7b801c..46ffe284 100755
--- a/train.py
+++ b/train.py
@@ -5,14 +5,10 @@
"""
import argparse
import datetime
-import json
import os
-import random
-from io import BytesIO
from os.path import basename
from pathlib import Path
-import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
@@ -87,35 +83,36 @@ def train(config):
datasets = {"train": [], "validation": []}
for i, dataset_name_or_path in enumerate(config.dataset_name_or_paths):
task_name = os.path.basename(dataset_name_or_path) # e.g., cord-v2, docvqa, rvlcdip, ...
-
+        task_start_token = config.task_start_tokens[i] if config.get("task_start_tokens", None) else f"<s_{task_name}>"
+ prompt_end_token = f""
+
# add categorical special tokens (optional)
if task_name == "rvlcdip":
model_module.model.decoder.add_special_tokens([
- "", "", "", "",
- "", "", "", "",
- "", "", "", "",
+ "", "", "", "",
+ "", "", "", "",
+ "", "", "", "",
"", "", "", ""
])
if task_name == "docvqa":
model_module.model.decoder.add_special_tokens(["", ""])
-
+
for split in ["train", "validation"]:
+ # prompt_end_token is used for ignoring a given prompt in a loss function
+ # for docvqa task, i.e., {"question": {used as a prompt}, "answer": {prediction target}},
+ # set prompt_end_token to ""
datasets[split].append(
DonutDataset(
dataset_name_or_path=dataset_name_or_path,
donut_model=model_module.model,
max_length=config.max_length,
split=split,
- task_start_token=config.task_start_tokens[i]
- if config.get("task_start_tokens", None)
- else f"",
- prompt_end_token="" if "docvqa" in dataset_name_or_path else f"",
+ task_start_token=task_start_token,
+ prompt_end_token=prompt_end_token,
sort_json_key=config.sort_json_key,
)
)
- # prompt_end_token is used for ignoring a given prompt in a loss function
- # for docvqa task, i.e., {"question": {used as a prompt}, "answer": {prediction target}},
- # set prompt_end_token to ""
+
data_module.train_datasets = datasets["train"]
data_module.val_datasets = datasets["validation"]