Skip to content

Commit 5b4611a

Browse files
committed
updates to support dir processing
1 parent 0fa9589 commit 5b4611a

File tree

2 files changed

+92
-23
lines changed

2 files changed

+92
-23
lines changed

ocrpy/io/reader.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,32 @@ class DocumentReader:
1717
credentials = field(default=None)
1818

1919
def read(self):
20-
if self.file.endswith(".png") or self.file.endswith(".jpg"):
20+
file_type = self.get_file_type()
21+
if file_type == 'image':
2122
return self._read_image(self.file)
23+
elif file_type == 'pdf':
24+
return self._read_pdf(self.file)
25+
else:
26+
raise ValueError("File type not supported")
2227

28+
def get_file_type(self):
29+
if self.file.endswith(".png") or self.file.endswith(".jpg"):
30+
file_type = "image"
2331
elif self.file.endswith(".pdf"):
24-
return self._read_pdf(self.file)
32+
file_type = "pdf"
33+
else:
34+
file_type = "unknown"
35+
return file_type
2536

37+
def get_storage_type(self):
38+
storage_type = None
39+
if self.file.startswith("gs://"):
40+
storage_type = 'gs'
41+
elif self.file.startswith("s3://"):
42+
storage_type = 's3'
2643
else:
27-
raise ValueError("File type not supported")
44+
storage_type = 'local'
45+
return storage_type
2846

2947
def _read_image(self, file):
3048
return self._read(file)
@@ -40,10 +58,11 @@ def _read(self, file):
4058
return file_data.read_bytes()
4159

4260
def _get_client(self, file):
43-
if file.startswith("gs://") and self.credentials:
61+
storage_type = self.get_storage_type()
62+
if storage_type == "gs" and self.credentials:
4463
client = GSClient(application_credentials=self.credentials)
4564

46-
elif file.startswith("s3://") and self.credentials:
65+
elif storage_type == 's3' and self.credentials:
4766
load_dotenv(self.credentials)
4867
client = S3Client(aws_access_key_id=os.getenv(
4968
'aws_access_key_id'), aws_secret_access_key=os.getenv('aws_secret_access_key'))

ocrpy/parsers/text/aws_text.py

Lines changed: 68 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import os
22
import boto3
3+
import time
4+
from cloudpathlib import AnyPath
35
from dotenv import load_dotenv
46
from attr import define, field
57
from typing import List, Dict, Any
@@ -26,6 +28,36 @@ def aws_token_formator(token):
2628
return token
2729

2830

31+
def is_job_complete(client, job_id):
32+
time.sleep(1)
33+
response = client.get_document_text_detection(JobId=job_id)
34+
status = response["JobStatus"]
35+
response = client.get_document_text_detection(JobId=job_id)
36+
while(status == "IN_PROGRESS"):
37+
time.sleep(1)
38+
response = client.get_document_text_detection(JobId=job_id)
39+
status = response["JobStatus"]
40+
return status
41+
42+
43+
def get_job_results(client, job_id):
44+
pages = []
45+
response = client.get_document_text_detection(JobId=job_id)
46+
pages.append(response)
47+
next_token = None
48+
if 'NextToken' in response:
49+
next_token = response['NextToken']
50+
51+
while next_token:
52+
response = client.\
53+
get_document_text_detection(JobId=job_id, NextToken=next_token)
54+
pages.append(response)
55+
next_token = None
56+
if 'NextToken' in response:
57+
next_token = response['NextToken']
58+
return pages
59+
60+
2961
@define
3062
class AwsLineSegmenter(AbstractLineSegmenter):
3163
"""
@@ -41,6 +73,7 @@ def lines(self) -> List[Dict[str, Any]]:
4173
lines = []
4274
for line in self.ocr["Blocks"]:
4375
if line["BlockType"] == "LINE":
76+
4477
idx = line.get("Id")
4578
text = line.get("Text")
4679
region = aws_region_extractor(line)
@@ -57,8 +90,9 @@ def _aws_token_extractor(self, relationship):
5790
for i in relationship:
5891
for idx in i.get('Ids'):
5992
token = self.mapper.get(idx)
60-
token = aws_token_formator(token)
61-
tokens.append(token)
93+
if token:
94+
token = aws_token_formator(token)
95+
tokens.append(token)
6296
return tokens
6397

6498

@@ -80,37 +114,53 @@ class AwsTextOCR(AbstractTextOCR):
80114
def __attrs_post_init__(self):
81115
if self.env_file:
82116
load_dotenv(self.env_file)
83-
self.document = self.reader.read()
84117
region = os.getenv('region_name')
85118
access_key = os.getenv('aws_access_key_id')
86119
secret_key = os.getenv('aws_secret_access_key')
87120
self.textract = boto3.client('textract', region_name=region,
88121
aws_access_key_id=access_key, aws_secret_access_key=secret_key)
89-
# self.ocr = textract.detect_document_text(
90-
# Document={'Bytes': self.document})
91122

92123
@property
93124
def parse(self):
94125
return self._process_data()
95126

96127
def _process_data(self):
97-
is_image = False
98-
if isinstance(self.document, bytes):
99-
self.document = [self.document]
100-
is_image = True
101-
102128
result = {}
103-
for index, document in enumerate(self.document):
104-
ocr = self.textract.detect_document_text(
105-
Document={'Bytes': document})
106-
data = dict(text=self._get_text(ocr), lines=self._get_lines(
107-
ocr), blocks=self._get_blocks(ocr), tokens=self._get_tokens(ocr))
129+
ocr = self._get_ocr()
130+
if not isinstance(ocr, list):
131+
ocr = [ocr]
132+
for index, page in enumerate(ocr):
133+
print("Processing page {}".format(index))
134+
data = dict(text=self._get_text(page), lines=self._get_lines(
135+
page), blocks=self._get_blocks(page), tokens=self._get_tokens(page))
108136
result[index] = data
137+
return result
138+
139+
def _get_ocr(self):
140+
storage_type = self.reader.get_storage_type()
141+
142+
if storage_type == 's3':
143+
path = AnyPath(self.reader.file)
144+
145+
response = self.textract.start_document_text_detection(DocumentLocation={
146+
'S3Object': {
147+
'Bucket': path.bucket,
148+
'Name': path.key
149+
}})
150+
job_id = response['JobId']
151+
status = is_job_complete(self.textract, job_id)
152+
ocr = get_job_results(self.textract, job_id)
109153

110-
if is_image:
111-
return result[0]
112154
else:
113-
return result
155+
self.document = self.reader.read()
156+
if isinstance(self.document, bytes):
157+
self.document = [self.document]
158+
ocr = []
159+
for document in self.document:
160+
result = self.textract.detect_document_text(
161+
Document={'Bytes': document})
162+
ocr.append(result)
163+
return ocr
114164

115165
def _get_blocks(self, ocr):
116166
try:

0 commit comments

Comments
 (0)