Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
298 changes: 255 additions & 43 deletions ocr_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@
import os
import sys
import re
from typing import Dict, List, Optional, Tuple
import time
import glob
import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Set

import pytesseract
from PIL import Image, ImageEnhance, ImageOps, UnidentifiedImageError
from rapidfuzz import fuzz, process
from difflib import SequenceMatcher

# ---------- CONFIGURATION ----------
# Module-level defaults for the OCR consistency check.
# NOTE(review): MODEL_TEXT and IMAGE_PATH look like sample values for a
# one-off run; confirm whether anything still reads them as CLI defaults.
MODEL_TEXT = "Searching for NikeAir shoes on Myntra"
IMAGE_PATH = "result.png" # path to your image result
# Default for --threshold: a check passes when the "char_ratio" similarity
# metric is greater than or equal to this value.
SIMILARITY_THRESHOLD = 80 # 0-100; below this, flag mismatch
# ----------------------------------

Expand Down Expand Up @@ -134,78 +136,288 @@ def suggest_token_corrections(
return suggestions


def main(argv: Optional[List[str]] = None) -> int:
parser = argparse.ArgumentParser(
description="Detect mismatches between model text and OCR text from an image."
)
parser.add_argument("--text", "-t", default=MODEL_TEXT, help="Model's expected text")
parser.add_argument("--image", "-i", default=IMAGE_PATH, help="Path to image")
parser.add_argument("--threshold", "-th", type=int, default=SIMILARITY_THRESHOLD, help="Threshold")
parser.add_argument("--lang", default="eng", help="Tesseract language")
parser.add_argument("--psm", type=int, default=6, help="PSM mode")
parser.add_argument("--oem", type=int, default=3, help="OEM mode")
parser.add_argument("--no-preprocess", action="store_true", help="Disable preprocessing")
parser.add_argument("--debug", action="store_true", help="Debug output")
def read_sidecar_text_for_image(
    image_path: str,
    *,
    text_exts: Tuple[str, ...] = (".txt", ".caption", ".json"),
    json_key: str = "text",
) -> Optional[str]:
    """Return the sidecar text stored next to an image, if any.

    For each extension in ``text_exts`` (in order) the image's own suffix is
    replaced by that extension and the resulting file is tried:

    - ``.json`` sidecars are parsed and ``json_key`` (a dot-separated path
      through nested objects) is followed; the value is used only if it is a
      non-blank string. An empty ``json_key`` means "the document itself".
    - Any other extension is read as UTF-8 text (undecodable bytes are
      dropped) and stripped of surrounding whitespace.

    Args:
        image_path: Path of the image whose sidecar is wanted.
        text_exts: Sidecar extensions to try, each starting with a dot.
        json_key: Dot-separated key path used for ``.json`` sidecars.

    Returns:
        The first usable sidecar text, or ``None`` when no candidate exists,
        is readable, and contains non-blank text.
    """
    image = Path(image_path)
    key_parts = json_key.split(".") if json_key else []

    for ext in text_exts:
        # Replace only the image's final suffix. Stripping the suffix first
        # and then calling with_suffix() again would eat a second dotted part
        # of the stem ("shot.v2.png" -> "shot.txt" instead of "shot.v2.txt").
        candidate = image.with_suffix(ext)
        try:
            if not candidate.exists():
                continue
            if candidate.suffix.lower() == ".json":
                with candidate.open("r", encoding="utf-8") as f:
                    value = json.load(f)
                for part in key_parts:
                    if isinstance(value, dict) and part in value:
                        value = value[part]
                    else:
                        value = None
                        break
                if isinstance(value, str) and value.strip():
                    return value
            else:
                text = candidate.read_text(encoding="utf-8", errors="ignore").strip()
                if text:
                    return text
        except (OSError, ValueError):
            # Best effort: an unreadable or malformed sidecar (I/O error,
            # invalid JSON, bad encoding) is treated the same as a missing one.
            continue
    return None

args = parser.parse_args(argv)

if not os.path.exists(args.image):
print(f"[ERROR] Image not found: {args.image}")
return 2
def validate_text_vs_image(
    expected_text: str,
    image_path: str,
    *,
    threshold: int,
    lang: str,
    psm: int,
    oem: int,
    enable_preprocess: bool,
    debug: bool,
    quiet: bool = False,
) -> Tuple[bool, Dict[str, int]]:
    """Run OCR on an image and compare the result with the expected text.

    Args:
        expected_text: Text the image is expected to contain.
        image_path: Path of the image to run OCR on.
        threshold: Minimum "char_ratio" score (0-100) that counts as a match.
        lang: Tesseract language passed to the OCR helper.
        psm: Tesseract page segmentation mode passed to the OCR helper.
        oem: Tesseract OCR engine mode passed to the OCR helper.
        enable_preprocess: Whether the OCR helper should preprocess the image.
        debug: Print the raw OCR text (ignored when ``quiet`` is set).
        quiet: Suppress all console output.

    Returns:
        ``(is_match, metrics)``. When the image is missing or OCR raises, the
        result is ``(False, {})`` rather than an exception.
    """
    if not os.path.exists(image_path):
        if not quiet:
            print(f"[ERROR] Image not found: {image_path}")
        return False, {}

    try:
        image_text_raw = extract_text_from_image(
            image_path,
            lang=lang,
            psm=psm,
            oem=oem,
            enable_preprocess=enable_preprocess,
        )
    except Exception as exc:
        # Deliberately broad: any OCR/imaging failure is reported as a
        # non-match so that a caller validating many files keeps going.
        if not quiet:
            print(f"[ERROR] OCR failed for {image_path}: {exc}")
        return False, {}

    if debug and not quiet:
        print("[DEBUG] Raw OCR Text:")
        print(image_text_raw.strip())

    model_text_norm = normalize(expected_text)
    image_text_norm = normalize(image_text_raw)

    metrics = compute_similarity_metrics(model_text_norm, image_text_norm)

    if not quiet:
        print("\n[RESULT] Similarity Metrics (0-100):")
        for name, value in metrics.items():
            print(f"- {name}: {value}")

    # A missing "char_ratio" metric is scored as 0.
    score = metrics.get("char_ratio", 0)
    is_match = score >= threshold

    if not is_match and not quiet:
        print("\n❌ Potential mismatch detected!")

        diffs = get_word_differences(model_text_norm, image_text_norm)
        if diffs:
            print("\n[DIFFERENCES]")
            for d in diffs:
                print(
                    f"- {d['type'].upper()} | Expected: '{d['expected']}' | Found: '{d['found']}'"
                )

        # Suggestions are computed on the raw (un-normalized) texts.
        expected_tokens = tokenize(expected_text)
        found_tokens = tokenize(image_text_raw)
        suggestions = suggest_token_corrections(expected_tokens, found_tokens)
        useful_suggestions = [s for s in suggestions if s.get("suggested")]
        if useful_suggestions:
            print("\n[SUGGESTIONS]")
            for s in useful_suggestions:
                print(f"- '{s['expected']}' → '{s['suggested']}' (score: {s['score']})")

    if is_match and not quiet:
        print("\n✅ Text & Image look consistent.")

    return is_match, metrics


def watch_and_validate(
    *,
    watch_dir: str,
    image_globs: List[str],
    text_exts: Tuple[str, ...],
    json_key: str,
    interval_s: float,
    threshold: int,
    lang: str,
    psm: int,
    oem: int,
    enable_preprocess: bool,
    debug: bool,
    fail_on_mismatch: bool,
) -> int:
    """Poll a directory for new/updated images and validate each against its sidecar text.

    An image is (re)validated whenever its modification time is newer than the
    one recorded at its last validation. Images without usable sidecar text
    are retried on every poll. The loop runs until Ctrl+C, or until the first
    mismatch when ``fail_on_mismatch`` is set.

    Args:
        watch_dir: Directory to poll.
        image_globs: Glob patterns, relative to ``watch_dir``, selecting images.
        text_exts: Sidecar extensions to try for each image.
        json_key: Dot-separated key path used for ``.json`` sidecars.
        interval_s: Seconds between polls (clamped to at least 0.1).
        threshold, lang, psm, oem, enable_preprocess, debug: Forwarded to
            ``validate_text_vs_image``.
        fail_on_mismatch: Return immediately with status 1 on the first mismatch.

    Returns:
        2 if ``watch_dir`` is not a directory, 1 if any validation did not
        match, otherwise 0.
    """
    directory = Path(watch_dir)
    if not directory.is_dir():
        print(f"[ERROR] Watch directory not found or not a dir: {watch_dir}")
        return 2

    print(
        f"[INFO] Watching '{directory.resolve()}' for images: {', '.join(image_globs)} with sidecars {', '.join(text_exts)}"
    )
    if json_key:
        print(f"[INFO] JSON key for text: '{json_key}'")
    print("[INFO] Press Ctrl+C to stop.")

    # Escape the directory part so glob metacharacters in its name
    # ("[", "]", "*", "?") are matched literally; only the user-supplied
    # patterns should be interpreted as globs.
    escaped_dir = Path(glob.escape(str(directory)))
    poll_delay = max(0.1, float(interval_s))

    # Maps absolute image path -> mtime seen at its last validation.
    processed: Dict[str, float] = {}
    any_mismatch = False

    try:
        while True:
            matched_images: Set[str] = set()
            for pattern in image_globs:
                matched_images.update(
                    os.path.abspath(path)
                    for path in glob.glob(str(escaped_dir / pattern))
                )

            for image_path in sorted(matched_images):
                try:
                    mtime = os.path.getmtime(image_path)
                except OSError:
                    # File vanished or is unreadable right now; retry next poll.
                    continue

                last = processed.get(image_path)
                if last is not None and mtime <= last:
                    continue  # unchanged

                expected_text = read_sidecar_text_for_image(
                    image_path, text_exts=text_exts, json_key=json_key
                )
                if expected_text is None:
                    if debug:
                        print(f"[DEBUG] No sidecar text found for {image_path}")
                    # Do not mark as processed so we can retry next loop
                    continue

                print(f"\n[FILE] Validating {image_path}")
                is_match, _metrics = validate_text_vs_image(
                    expected_text,
                    image_path,
                    threshold=threshold,
                    lang=lang,
                    psm=psm,
                    oem=oem,
                    enable_preprocess=enable_preprocess,
                    debug=debug,
                    quiet=False,
                )

                processed[image_path] = mtime
                if not is_match:
                    any_mismatch = True
                    if fail_on_mismatch:
                        print("[INFO] Exiting due to mismatch and --fail-on-mismatch set.")
                        return 1

            time.sleep(poll_delay)
    except KeyboardInterrupt:
        print("\n[INFO] Stopped watching.")

    return 1 if any_mismatch else 0


def main(argv: Optional[List[str]] = None) -> int:
    """Command-line entry point.

    Two modes are supported:

    - ``--watch``: poll ``--watch-dir`` and validate every image that has
      sidecar text; the status of ``watch_and_validate`` is returned.
    - one-off: ``--text`` and ``--image`` must both be given; returns 0 on a
      match, 1 on a mismatch, and 2 when either option is missing.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read ``sys.argv``.
    """
    cli = argparse.ArgumentParser(
        description="Detect mismatches between model text and OCR text from an image."
    )
    # One-off mode inputs (both required unless --watch is used).
    cli.add_argument("--text", "-t", default=None, help="Model's expected text")
    cli.add_argument("--image", "-i", default=None, help="Path to image")
    cli.add_argument("--threshold", "-th", type=int, default=SIMILARITY_THRESHOLD, help="Threshold")
    cli.add_argument("--lang", default="eng", help="Tesseract language")
    cli.add_argument("--psm", type=int, default=6, help="PSM mode")
    cli.add_argument("--oem", type=int, default=3, help="OEM mode")
    cli.add_argument("--no-preprocess", action="store_true", help="Disable preprocessing")
    cli.add_argument("--debug", action="store_true", help="Debug output")

    # Options that only matter together with --watch.
    cli.add_argument("--watch", action="store_true", help="Watch a directory for new image+text pairs")
    cli.add_argument("--watch-dir", default=".", help="Directory to watch for outputs")
    cli.add_argument(
        "--image-glob",
        default="*.png,*.jpg,*.jpeg",
        help="Comma-separated image glob patterns inside watch dir",
    )
    cli.add_argument(
        "--text-exts",
        default=".txt,.caption,.json",
        help="Comma-separated sidecar text extensions to try",
    )
    cli.add_argument(
        "--json-key",
        default="text",
        help="JSON key path (dot-separated) to extract text from sidecar JSON",
    )
    cli.add_argument("--interval", type=float, default=1.0, help="Polling interval seconds in watch mode")
    cli.add_argument(
        "--fail-on-mismatch",
        action="store_true",
        help="Exit immediately with non-zero status when a mismatch is detected in watch mode",
    )

    opts = cli.parse_args(argv)

    # OCR/validation settings shared by both modes.
    ocr_options = {
        "threshold": opts.threshold,
        "lang": opts.lang,
        "psm": opts.psm,
        "oem": opts.oem,
        "enable_preprocess": not opts.no_preprocess,
        "debug": opts.debug,
    }

    if opts.watch:
        patterns = []
        for chunk in opts.image_glob.split(","):
            chunk = chunk.strip()
            if chunk:
                patterns.append(chunk)

        # Accept extensions with or without the leading dot.
        suffixes = []
        for chunk in opts.text_exts.split(","):
            chunk = chunk.strip()
            if not chunk:
                continue
            suffixes.append(chunk if chunk.startswith(".") else f".{chunk}")

        return watch_and_validate(
            watch_dir=opts.watch_dir,
            image_globs=patterns,
            text_exts=tuple(suffixes),
            json_key=opts.json_key,
            interval_s=opts.interval,
            fail_on_mismatch=opts.fail_on_mismatch,
            **ocr_options,
        )

    # One-off run: both inputs are mandatory.
    if not (opts.text and opts.image):
        print("[ERROR] Either use --watch or provide both --text and --image.")
        cli.print_help()
        return 2

    print("[INFO] Extracting text from image...")
    matched, _ = validate_text_vs_image(
        opts.text,
        opts.image,
        quiet=False,
        **ocr_options,
    )
    if matched:
        return 0
    return 1


if __name__ == "__main__":
Expand Down