1
1
import os
2
2
import boto3
3
+ import time
4
+ from cloudpathlib import AnyPath
3
5
from dotenv import load_dotenv
4
6
from attr import define , field
5
7
from typing import List , Dict , Any
@@ -26,6 +28,36 @@ def aws_token_formator(token):
26
28
return token
27
29
28
30
31
def is_job_complete(client, job_id, poll_interval=1, timeout=None):
    """
    Poll an asynchronous text-detection job until it leaves ``IN_PROGRESS``.

    Args:
        client: Object exposing ``get_document_text_detection(JobId=...)``
            that returns a mapping with a ``"JobStatus"`` key.
        job_id: Identifier of the job to poll.
        poll_interval: Seconds to sleep before every poll (including the
            first one). Defaults to 1, matching the previous behaviour.
        timeout: Optional maximum number of seconds to wait. ``None``
            (the default) waits indefinitely, as before.

    Returns:
        The first ``JobStatus`` value that is not ``"IN_PROGRESS"``.

    Raises:
        TimeoutError: Only when ``timeout`` is given and the job is still
            ``IN_PROGRESS`` once that many seconds have elapsed.
    """
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        # Sleep first: a freshly started job is not expected to be done yet.
        time.sleep(poll_interval)
        response = client.get_document_text_detection(JobId=job_id)
        status = response["JobStatus"]
        if status != "IN_PROGRESS":
            return status
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                "Job {} still IN_PROGRESS after {} seconds".format(
                    job_id, timeout))
41
+
42
+
43
def get_job_results(client, job_id):
    """
    Collect every paginated response of a text-detection job.

    The first request is made with only ``JobId``; each following request
    adds the ``NextToken`` taken from the previous response. Pagination
    stops at the first response whose ``NextToken`` is missing or empty.

    Args:
        client: Object exposing
            ``get_document_text_detection(JobId=..., NextToken=...)`` that
            returns a mapping.
        job_id: Identifier of the job whose results are fetched.

    Returns:
        A list with the raw responses, in the order they were received.
    """
    request = {"JobId": job_id}
    pages = []
    while True:
        response = client.get_document_text_detection(**request)
        pages.append(response)
        next_token = response.get("NextToken")
        if not next_token:
            return pages
        # Carry the continuation token into the next request.
        request["NextToken"] = next_token
59
+
60
+
29
61
@define
30
62
class AwsLineSegmenter (AbstractLineSegmenter ):
31
63
"""
@@ -41,6 +73,7 @@ def lines(self) -> List[Dict[str, Any]]:
41
73
lines = []
42
74
for line in self .ocr ["Blocks" ]:
43
75
if line ["BlockType" ] == "LINE" :
76
+
44
77
idx = line .get ("Id" )
45
78
text = line .get ("Text" )
46
79
region = aws_region_extractor (line )
@@ -57,8 +90,9 @@ def _aws_token_extractor(self, relationship):
57
90
for i in relationship :
58
91
for idx in i .get ('Ids' ):
59
92
token = self .mapper .get (idx )
60
- token = aws_token_formator (token )
61
- tokens .append (token )
93
+ if token :
94
+ token = aws_token_formator (token )
95
+ tokens .append (token )
62
96
return tokens
63
97
64
98
@@ -80,37 +114,53 @@ class AwsTextOCR(AbstractTextOCR):
80
114
def __attrs_post_init__(self):
    """
    Build the Textract client after attrs has initialised the instance.

    When ``self.env_file`` is set it is loaded with ``load_dotenv`` first.
    The region and credentials are then read from the ``region_name``,
    ``aws_access_key_id`` and ``aws_secret_access_key`` environment
    variables and the resulting client is stored on ``self.textract``.
    """
    if self.env_file:
        load_dotenv(self.env_file)

    client_options = {
        'region_name': os.getenv('region_name'),
        'aws_access_key_id': os.getenv('aws_access_key_id'),
        'aws_secret_access_key': os.getenv('aws_secret_access_key'),
    }
    self.textract = boto3.client('textract', **client_options)
91
122
92
123
@property
def parse(self):
    """
    Run the OCR pipeline and return its result.

    Accessing this property calls ``self._process_data()`` every time;
    nothing is cached here.
    """
    processed = self._process_data()
    return processed
95
126
96
127
def _process_data (self ):
97
- is_image = False
98
- if isinstance (self .document , bytes ):
99
- self .document = [self .document ]
100
- is_image = True
101
-
102
128
result = {}
103
- for index , document in enumerate (self .document ):
104
- ocr = self .textract .detect_document_text (
105
- Document = {'Bytes' : document })
106
- data = dict (text = self ._get_text (ocr ), lines = self ._get_lines (
107
- ocr ), blocks = self ._get_blocks (ocr ), tokens = self ._get_tokens (ocr ))
129
+ ocr = self ._get_ocr ()
130
+ if not isinstance (ocr , list ):
131
+ ocr = [ocr ]
132
+ for index , page in enumerate (ocr ):
133
+ print ("Processing page {}" .format (index ))
134
+ data = dict (text = self ._get_text (page ), lines = self ._get_lines (
135
+ page ), blocks = self ._get_blocks (page ), tokens = self ._get_tokens (page ))
108
136
result [index ] = data
137
+ return result
138
+
139
+ def _get_ocr (self ):
140
+ storage_type = self .reader .get_storage_type ()
141
+
142
+ if storage_type == 's3' :
143
+ path = AnyPath (self .reader .file )
144
+
145
+ response = self .textract .start_document_text_detection (DocumentLocation = {
146
+ 'S3Object' : {
147
+ 'Bucket' : path .bucket ,
148
+ 'Name' : path .key
149
+ }})
150
+ job_id = response ['JobId' ]
151
+ status = is_job_complete (self .textract , job_id )
152
+ ocr = get_job_results (self .textract , job_id )
109
153
110
- if is_image :
111
- return result [0 ]
112
154
else :
113
- return result
155
+ self .document = self .reader .read ()
156
+ if isinstance (self .document , bytes ):
157
+ self .document = [self .document ]
158
+ ocr = []
159
+ for document in self .document :
160
+ result = self .textract .detect_document_text (
161
+ Document = {'Bytes' : document })
162
+ ocr .append (result )
163
+ return ocr
114
164
115
165
def _get_blocks (self , ocr ):
116
166
try :
0 commit comments