mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-06-16 09:59:13 +02:00
Merge 28a559c710 into 2e3f45c300
This commit is contained in:
commit
33734d3eeb
42 changed files with 2173 additions and 1656 deletions
BIN
src/eynollah/Amiri-Regular.ttf
Normal file
BIN
src/eynollah/Amiri-Regular.ttf
Normal file
Binary file not shown.
|
|
@ -1,7 +1,3 @@
|
|||
# NOTE: For predictable order of imports of torch/shapely/tensorflow
|
||||
# this must be the first import of the CLI!
|
||||
from ..eynollah_imports import imported_libs
|
||||
|
||||
from .cli import main
|
||||
from .cli_binarize import binarize_cli
|
||||
from .cli_enhance import enhance_cli
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ class EynollahCliCtx:
|
|||
Holds options relevant for all eynollah subcommands
|
||||
"""
|
||||
model_zoo: EynollahModelZoo
|
||||
device: str = ''
|
||||
log_level : Union[str, None] = 'INFO'
|
||||
|
||||
|
||||
|
|
@ -35,6 +36,11 @@ class EynollahCliCtx:
|
|||
type=(str, str, str),
|
||||
multiple=True,
|
||||
)
|
||||
@click.option(
|
||||
"--device",
|
||||
"-D",
|
||||
help="placement of computations in predictors for each model type; if none (by default), will try to use first available GPU or fall back to CPU; set string to force using a device (e.g. 'GPU0', 'GPU1' or 'CPU'). Can also be a comma-separated list of model category to device mappings (e.g. 'col_classifier:CPU,page:GPU0,*:GPU1')",
|
||||
)
|
||||
@click.option(
|
||||
"--log_level",
|
||||
"-l",
|
||||
|
|
@ -42,7 +48,7 @@ class EynollahCliCtx:
|
|||
help="Override log level globally to this",
|
||||
)
|
||||
@click.pass_context
|
||||
def main(ctx, model_basedir, model_overrides, log_level):
|
||||
def main(ctx, model_basedir, model_overrides, device, log_level):
|
||||
"""
|
||||
eynollah - Document Layout Analysis, Image Enhancement, OCR
|
||||
"""
|
||||
|
|
@ -58,6 +64,7 @@ def main(ctx, model_basedir, model_overrides, log_level):
|
|||
# Initialize CLI context
|
||||
ctx.obj = EynollahCliCtx(
|
||||
model_zoo=model_zoo,
|
||||
device=device,
|
||||
log_level=log_level,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import click
|
||||
|
||||
@click.command()
|
||||
@click.command(context_settings=dict(
|
||||
help_option_names=['-h', '--help'],
|
||||
show_default=True))
|
||||
@click.option(
|
||||
'--patches/--no-patches',
|
||||
default=True,
|
||||
|
|
@ -31,11 +33,6 @@ import click
|
|||
help="overwrite (instead of skipping) if output xml exists",
|
||||
is_flag=True,
|
||||
)
|
||||
@click.option(
|
||||
"--device",
|
||||
"-D",
|
||||
help="placement of computations in predictors for each model type; if none (by default), will try to use first available GPU or fall back to CPU; set string to force using a device (e.g. 'GPU0', 'GPU1' or 'CPU'). Can also be a comma-separated list of model category to device mappings (e.g. 'col_classifier:CPU,page:GPU0,*:GPU1')",
|
||||
)
|
||||
@click.pass_context
|
||||
def binarize_cli(
|
||||
ctx,
|
||||
|
|
@ -44,14 +41,14 @@ def binarize_cli(
|
|||
dir_in,
|
||||
output,
|
||||
overwrite,
|
||||
device,
|
||||
):
|
||||
"""
|
||||
Binarize images with a ML model
|
||||
"""
|
||||
from ..sbb_binarize import SbbBinarizer
|
||||
assert bool(input_image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
|
||||
binarizer = SbbBinarizer(model_zoo=ctx.obj.model_zoo, device=device)
|
||||
binarizer = SbbBinarizer(model_zoo=ctx.obj.model_zoo,
|
||||
device=ctx.obj.device)
|
||||
binarizer.run(
|
||||
image_filename=input_image,
|
||||
use_patches=patches,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import click
|
||||
|
||||
@click.command()
|
||||
@click.command(context_settings=dict(
|
||||
help_option_names=['-h', '--help'],
|
||||
show_default=True))
|
||||
@click.option(
|
||||
"--image",
|
||||
"-i",
|
||||
|
|
@ -46,13 +48,8 @@ import click
|
|||
is_flag=True,
|
||||
help="save the enhanced image in original image size",
|
||||
)
|
||||
@click.option(
|
||||
"--device",
|
||||
"-D",
|
||||
help="placement of computations in predictors for each model type; if none (by default), will try to use first available GPU or fall back to CPU; set string to force using a device (e.g. 'GPU0', 'GPU1' or 'CPU'). Can also be a comma-separated list of model category to device mappings (e.g. 'col_classifier:CPU,page:GPU0,*:GPU1')",
|
||||
)
|
||||
@click.pass_context
|
||||
def enhance_cli(ctx, image, out, overwrite, dir_in, num_col_upper, num_col_lower, save_org_scale, device):
|
||||
def enhance_cli(ctx, image, out, overwrite, dir_in, num_col_upper, num_col_lower, save_org_scale):
|
||||
"""
|
||||
Enhance image
|
||||
"""
|
||||
|
|
@ -60,10 +57,10 @@ def enhance_cli(ctx, image, out, overwrite, dir_in, num_col_upper, num_col_lower
|
|||
from ..image_enhancer import Enhancer
|
||||
enhancer = Enhancer(
|
||||
model_zoo=ctx.obj.model_zoo,
|
||||
device=ctx.obj.device,
|
||||
num_col_upper=num_col_upper,
|
||||
num_col_lower=num_col_lower,
|
||||
save_org_scale=save_org_scale,
|
||||
device=device,
|
||||
)
|
||||
enhancer.run(overwrite=overwrite,
|
||||
dir_in=dir_in,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import click
|
||||
|
||||
@click.command()
|
||||
@click.command(context_settings=dict(
|
||||
help_option_names=['-h', '--help'],
|
||||
show_default=True))
|
||||
@click.option(
|
||||
"--image",
|
||||
"-i",
|
||||
|
|
@ -30,36 +32,40 @@ import click
|
|||
@click.option(
|
||||
"--save_images",
|
||||
"-si",
|
||||
help="if a directory is given, images in documents will be cropped and saved there",
|
||||
help="if a directory is given, cropped images of pages will be saved there",
|
||||
type=click.Path(exists=True, file_okay=False),
|
||||
)
|
||||
@click.option(
|
||||
"--enable-plotting/--disable-plotting",
|
||||
"-ep/-noep",
|
||||
"--enable-plotting",
|
||||
"-ep",
|
||||
is_flag=True,
|
||||
help="If set, will plot intermediary files and images",
|
||||
help="plot intermediary diagnostic images to files",
|
||||
)
|
||||
@click.option(
|
||||
"--input_binary/--input-RGB",
|
||||
"-ib/-irgb",
|
||||
"--input_binary",
|
||||
"-ib",
|
||||
is_flag=True,
|
||||
help="In general, eynollah uses RGB as input but if the input document is very dark, very bright or for any other reason you can turn on input binarization. When this flag is set, eynollah will binarize the RGB input document, you should always provide RGB images to eynollah.",
|
||||
help="In general, eynollah uses RGB as input, but if the input document is very dark, very bright or for any other reason you can turn on internal binarization here. When set, eynollah will binarize the RGB input document first.",
|
||||
)
|
||||
@click.option(
|
||||
"--ignore_page_extraction/--extract_page_included",
|
||||
"-ipe/-epi",
|
||||
"--ignore_page_extraction",
|
||||
"-ipe",
|
||||
is_flag=True,
|
||||
help="if this parameter set to true, this tool would ignore page extraction",
|
||||
help="ignore page extraction (cropping via page frame detection model)",
|
||||
)
|
||||
@click.option(
|
||||
"--num_col_upper",
|
||||
"-ncu",
|
||||
help="lower limit of columns in document image",
|
||||
default=0,
|
||||
type=click.IntRange(min=0),
|
||||
help="lower limit of columns in document image; 0 means autodetected from model",
|
||||
)
|
||||
@click.option(
|
||||
"--num_col_lower",
|
||||
"-ncl",
|
||||
help="upper limit of columns in document image",
|
||||
default=0,
|
||||
type=click.IntRange(min=0),
|
||||
help="upper limit of columns in document image; 0 means autodetected from model",
|
||||
)
|
||||
@click.pass_context
|
||||
def extract_images_cli(
|
||||
|
|
|
|||
|
|
@ -172,11 +172,6 @@ import click
|
|||
type=click.FloatRange(min=0),
|
||||
help="abort when number of failed images exceeds this value (if >=1) or ratio of failed over total images exceeds this value (if <1); 0 means ignore failures",
|
||||
)
|
||||
@click.option(
|
||||
"--device",
|
||||
"-D",
|
||||
help="placement of computations in predictors for each model type; if none (by default), will try to use first available GPU or fall back to CPU; set string to force using a device (e.g. 'GPU0', 'GPU1' or 'CPU'). Can also be a comma-separated list of model category to device mappings (e.g. 'col_classifier:CPU,page:GPU0,*:GPU1')",
|
||||
)
|
||||
@click.pass_context
|
||||
def layout_cli(
|
||||
ctx,
|
||||
|
|
@ -207,7 +202,6 @@ def layout_cli(
|
|||
ignore_page_extraction,
|
||||
num_jobs,
|
||||
halt_fail,
|
||||
device,
|
||||
):
|
||||
"""
|
||||
Detect Layout (with optional image enhancement and reading order detection)
|
||||
|
|
@ -223,7 +217,7 @@ def layout_cli(
|
|||
assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
|
||||
eynollah = Eynollah(
|
||||
model_zoo=ctx.obj.model_zoo,
|
||||
device=device,
|
||||
device=ctx.obj.device,
|
||||
enable_plotting=enable_plotting,
|
||||
allow_enhancement=allow_enhancement,
|
||||
curved_line=curved_line,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import click
|
||||
|
||||
@click.command()
|
||||
@click.command(context_settings=dict(
|
||||
help_option_names=['-h', '--help'],
|
||||
show_default=True))
|
||||
@click.option(
|
||||
"--image",
|
||||
"-i",
|
||||
|
|
@ -16,7 +18,7 @@ import click
|
|||
@click.option(
|
||||
"--dir_in_bin",
|
||||
"-dib",
|
||||
help=("directory of binarized images (in addition to --dir_in for RGB images; filename stems must match the RGB image files, with '.png' \n Perform prediction using both RGB and binary images. (This does not necessarily improve results, however it may be beneficial for certain document images."),
|
||||
help=("directory of binarized images (in addition to --dir_in for RGB images; filename stems must match the RGB image files, with '.png'. \n Perform prediction using both RGB and binary images. (This may improve results for certain document images.)"),
|
||||
type=click.Path(exists=True, file_okay=False),
|
||||
)
|
||||
@click.option(
|
||||
|
|
@ -47,25 +49,29 @@ import click
|
|||
)
|
||||
@click.option(
|
||||
"--tr_ocr",
|
||||
"-trocr/-notrocr",
|
||||
"-trocr",
|
||||
is_flag=True,
|
||||
help="if this parameter set to true, transformer ocr will be applied, otherwise cnn_rnn model.",
|
||||
help="use transformer OCR (instead of classic CNN-RNN) model",
|
||||
)
|
||||
@click.option(
|
||||
"--do_not_mask_with_textline_contour",
|
||||
"-nmtc/-mtc",
|
||||
"-nmtc",
|
||||
is_flag=True,
|
||||
help="if this parameter set to true, cropped textline images will not be masked with textline contour.",
|
||||
help="skip masking each cropped textline image with its corresponding textline contour",
|
||||
)
|
||||
@click.option(
|
||||
"--batch_size",
|
||||
"-bs",
|
||||
default=0,
|
||||
type=click.IntRange(min=0),
|
||||
help="number of inference batch size. Default b_s for trocr and cnn_rnn models are 2 and 8 respectively",
|
||||
)
|
||||
@click.option(
|
||||
"--min_conf_value_of_textline_text",
|
||||
"-min_conf",
|
||||
help="minimum OCR confidence value. Text lines with a confidence value lower than this threshold will not be included in the output XML file.",
|
||||
default=0.3,
|
||||
type=click.FloatRange(min=0.0, max=1.0),
|
||||
help="minimum OCR confidence threshold. Text lines with a lower confidence value will not be included in the output XML file.",
|
||||
)
|
||||
@click.pass_context
|
||||
def ocr_cli(
|
||||
|
|
@ -85,14 +91,16 @@ def ocr_cli(
|
|||
"""
|
||||
Recognize text with a CNN/RNN or transformer ML model.
|
||||
"""
|
||||
assert bool(image) ^ bool(dir_in), "Either -i (single image) or -di (directory) must be provided, but not both."
|
||||
assert bool(image) != bool(dir_in), "Either -i (single image) or -di (directory) must be provided, but not both."
|
||||
from ..eynollah_ocr import Eynollah_ocr
|
||||
eynollah_ocr = Eynollah_ocr(
|
||||
model_zoo=ctx.obj.model_zoo,
|
||||
device=ctx.obj.device,
|
||||
tr_ocr=tr_ocr,
|
||||
do_not_mask_with_textline_contour=do_not_mask_with_textline_contour,
|
||||
batch_size=batch_size,
|
||||
min_conf_value_of_textline_text=min_conf_value_of_textline_text)
|
||||
min_conf_value_of_textline_text=min_conf_value_of_textline_text,
|
||||
)
|
||||
eynollah_ocr.run(overwrite=overwrite,
|
||||
dir_in=dir_in,
|
||||
dir_in_bin=dir_in_bin,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import click
|
||||
|
||||
@click.command()
|
||||
@click.command(context_settings=dict(
|
||||
help_option_names=['-h', '--help'],
|
||||
show_default=True))
|
||||
@click.option(
|
||||
"--input",
|
||||
"-i",
|
||||
|
|
@ -25,9 +27,10 @@ def readingorder_cli(ctx, input, dir_in, out):
|
|||
"""
|
||||
Generate ReadingOrder with a ML model
|
||||
"""
|
||||
from ..mb_ro_on_layout import machine_based_reading_order_on_layout
|
||||
from ..mb_ro_on_layout import Reorder
|
||||
assert bool(input) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
|
||||
orderer = machine_based_reading_order_on_layout(model_zoo=ctx.obj.model_zoo)
|
||||
orderer = Reorder(model_zoo=ctx.obj.model_zoo,
|
||||
device=ctx.obj.device)
|
||||
orderer.run(xml_filename=input,
|
||||
dir_in=dir_in,
|
||||
dir_out=out,
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ import os
|
|||
import time
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
|
|
@ -64,12 +63,6 @@ class EynollahImageExtractor(Eynollah):
|
|||
|
||||
t_start = time.time()
|
||||
|
||||
try:
|
||||
for device in tf.config.list_physical_devices('GPU'):
|
||||
tf.config.experimental.set_memory_growth(device, True)
|
||||
except:
|
||||
self.logger.warning("no GPU device available")
|
||||
|
||||
self.logger.info("Loading models...")
|
||||
self.setup_models()
|
||||
self.logger.info(f"Model initialization complete ({time.time() - t_start:.1f}s)")
|
||||
|
|
|
|||
|
|
@ -1148,7 +1148,6 @@ class Eynollah:
|
|||
boxes,
|
||||
textline_mask_tot
|
||||
):
|
||||
assert np.any(textline_mask_tot)
|
||||
self.logger.debug("enter do_order_of_regions")
|
||||
contours_only_text_parent = ensure_array(contours_only_text_parent)
|
||||
contours_only_text_parent_h = ensure_array(contours_only_text_parent_h)
|
||||
|
|
|
|||
|
|
@ -1,13 +0,0 @@
|
|||
"""
|
||||
Load libraries with possible race conditions once. This must be imported as the first module of eynollah.
|
||||
"""
|
||||
import os
|
||||
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
||||
|
||||
from ocrd_utils import tf_disable_interactive_logs
|
||||
from torch import *
|
||||
tf_disable_interactive_logs()
|
||||
import tensorflow.keras
|
||||
from shapely import *
|
||||
imported_libs = True
|
||||
__all__ = ['imported_libs']
|
||||
|
|
@ -14,20 +14,21 @@ from cv2.typing import MatLike
|
|||
from xml.etree import ElementTree as ET
|
||||
from PIL import Image, ImageDraw
|
||||
import numpy as np
|
||||
from eynollah.model_zoo import EynollahModelZoo
|
||||
from eynollah.utils.font import get_font
|
||||
from eynollah.utils.xml import etree_namespace_for_element_tag
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
torch = None
|
||||
from ocrd_utils import polygon_from_points, xywh_from_polygon
|
||||
|
||||
|
||||
from .utils import is_image_filename
|
||||
from .eynollah import Eynollah
|
||||
from .model_zoo import EynollahModelZoo
|
||||
from .utils import (
|
||||
is_image_filename,
|
||||
batched,
|
||||
pairwise,
|
||||
)
|
||||
from .utils.font import get_font
|
||||
from .utils.xml import etree_namespace_for_element_tag
|
||||
from .utils.resize import resize_image
|
||||
from .utils.utils_ocr import (
|
||||
break_curved_line_into_small_pieces_and_then_merge,
|
||||
decode_batch_predictions,
|
||||
fit_text_single_line,
|
||||
get_contours_and_bounding_boxes,
|
||||
get_orientation_moments,
|
||||
|
|
@ -40,49 +41,45 @@ from .utils.utils_ocr import (
|
|||
@dataclass
|
||||
class EynollahOcrResult:
|
||||
extracted_texts_merged: List
|
||||
extracted_conf_value_merged: Optional[List]
|
||||
extracted_confs_merged: List
|
||||
cropped_lines_region_indexer: List
|
||||
total_bb_coordinates:List
|
||||
|
||||
class Eynollah_ocr:
|
||||
class Eynollah_ocr(Eynollah):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model_zoo: EynollahModelZoo,
|
||||
tr_ocr=False,
|
||||
batch_size: Optional[int]=None,
|
||||
batch_size: int=0,
|
||||
do_not_mask_with_textline_contour: bool=False,
|
||||
min_conf_value_of_textline_text : Optional[float]=None,
|
||||
min_conf_value_of_textline_text : float=0.3,
|
||||
logger: Optional[Logger]=None,
|
||||
device: str = '',
|
||||
):
|
||||
self.tr_ocr = tr_ocr
|
||||
# masking for OCR and GT generation, relevant for skewed lines and bounding boxes
|
||||
self.do_not_mask_with_textline_contour = do_not_mask_with_textline_contour
|
||||
self.logger = logger if logger else getLogger('eynollah.ocr')
|
||||
|
||||
self.min_conf_value_of_textline_text = min_conf_value_of_textline_text
|
||||
self.b_s = batch_size or 2 if tr_ocr else 8
|
||||
|
||||
self.model_zoo = model_zoo
|
||||
self.setup_models(device=device)
|
||||
|
||||
self.min_conf_value_of_textline_text = min_conf_value_of_textline_text if min_conf_value_of_textline_text else 0.3
|
||||
self.b_s = 2 if batch_size is None and tr_ocr else 8 if batch_size is None else batch_size
|
||||
|
||||
if tr_ocr:
|
||||
self.model_zoo.load_models('trocr_processor')
|
||||
self.model_zoo.load_models(['ocr', 'tr'])
|
||||
self.model_zoo.get('ocr').to(self.device)
|
||||
def setup_models(self, device=''):
|
||||
if self.tr_ocr:
|
||||
self.model_zoo.load_models(('ocr', 'tr'),
|
||||
device=device)
|
||||
else:
|
||||
self.model_zoo.load_models('ocr')
|
||||
self.model_zoo.load_models('num_to_char')
|
||||
self.model_zoo.load_models('characters')
|
||||
self.end_character = len(self.model_zoo.get('characters')) + 2
|
||||
self.model_zoo.load_models('ocr',
|
||||
'binarization',
|
||||
device=device)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
assert torch
|
||||
if torch.cuda.is_available():
|
||||
self.logger.info("Using GPU acceleration")
|
||||
return torch.device("cuda:0")
|
||||
else:
|
||||
self.logger.info("Using CPU processing")
|
||||
return torch.device("cpu")
|
||||
return self.model_zoo.get('ocr').device
|
||||
|
||||
def run_trocr(
|
||||
self,
|
||||
|
|
@ -90,178 +87,73 @@ class Eynollah_ocr:
|
|||
img: MatLike,
|
||||
page_tree: ET.ElementTree,
|
||||
page_ns,
|
||||
tr_ocr_input_height_and_width,
|
||||
) -> EynollahOcrResult:
|
||||
|
||||
total_bb_coordinates = []
|
||||
|
||||
|
||||
cropped_lines = []
|
||||
cropped_lines_region_indexer = []
|
||||
cropped_lines_meging_indexing = []
|
||||
|
||||
extracted_texts = []
|
||||
for n_region, region in enumerate(page_tree.getroot().iter('{%s}TextRegion' % page_ns)):
|
||||
for n_line, line in enumerate(region.iter('{%s}TextLine' % page_ns)):
|
||||
cropped_lines_region_indexer.append(n_region)
|
||||
|
||||
indexer_text_region = 0
|
||||
indexer_b_s = 0
|
||||
coords = line.find('{%s}Coords' % page_ns)
|
||||
if coords is None:
|
||||
self.logger.warning("region '%s' line '%s' has no Coords", region.attrib['id'], line.attrib['id'])
|
||||
continue
|
||||
poly = np.array(polygon_from_points(coords.attrib['points'])).astype(int)
|
||||
cont = poly[:, np.newaxis]
|
||||
xywh = xywh_from_polygon(poly)
|
||||
x, y, w, h = xywh['x'], xywh['y'], xywh['w'], xywh['h']
|
||||
|
||||
for nn in page_tree.getroot().iter(f'{{{page_ns}}}TextRegion'):
|
||||
for child_textregion in nn:
|
||||
if child_textregion.tag.endswith("TextLine"):
|
||||
total_bb_coordinates.append([x, y, w, h])
|
||||
|
||||
for child_textlines in child_textregion:
|
||||
if child_textlines.tag.endswith("Coords"):
|
||||
cropped_lines_region_indexer.append(indexer_text_region)
|
||||
p_h=child_textlines.attrib['points'].split(' ')
|
||||
textline_coords = np.array( [ [int(x.split(',')[0]),
|
||||
int(x.split(',')[1]) ]
|
||||
for x in p_h] )
|
||||
x,y,w,h = cv2.boundingRect(textline_coords)
|
||||
img_crop = img[y: y + h, x: x + w]
|
||||
if not self.do_not_mask_with_textline_contour:
|
||||
mask_poly = np.zeros(img_crop.shape[:2], dtype=np.uint8)
|
||||
mask_poly = cv2.fillPoly(mask_poly, pts=[cont - [x, y]], color=1)
|
||||
img_crop[mask_poly == 0] = 255 # FIXME: or median color?
|
||||
|
||||
total_bb_coordinates.append([x,y,w,h])
|
||||
|
||||
h2w_ratio = h/float(w)
|
||||
|
||||
img_poly_on_img = np.copy(img)
|
||||
mask_poly = np.zeros(img.shape)
|
||||
mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1))
|
||||
|
||||
mask_poly = mask_poly[y:y+h, x:x+w, :]
|
||||
img_crop = img_poly_on_img[y:y+h, x:x+w, :]
|
||||
img_crop[mask_poly==0] = 255
|
||||
|
||||
self.logger.debug("processing %d lines for '%s'",
|
||||
len(cropped_lines), nn.attrib['id'])
|
||||
if h2w_ratio > 0.1:
|
||||
cropped_lines.append(resize_image(img_crop,
|
||||
tr_ocr_input_height_and_width,
|
||||
tr_ocr_input_height_and_width) )
|
||||
if h > 0.1 * w:
|
||||
cropped_lines.append(img_crop)
|
||||
cropped_lines_meging_indexing.append(0)
|
||||
indexer_b_s+=1
|
||||
if indexer_b_s==self.b_s:
|
||||
imgs = cropped_lines[:]
|
||||
cropped_lines = []
|
||||
indexer_b_s = 0
|
||||
|
||||
pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values
|
||||
generated_ids_merged = self.model_zoo.get('ocr').generate(
|
||||
pixel_values_merged.to(self.device))
|
||||
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
|
||||
generated_ids_merged, skip_special_tokens=True)
|
||||
|
||||
extracted_texts = extracted_texts + generated_text_merged
|
||||
|
||||
else:
|
||||
splited_images, _ = return_textlines_split_if_needed(img_crop, None)
|
||||
#print(splited_images)
|
||||
if splited_images:
|
||||
cropped_lines.append(resize_image(splited_images[0],
|
||||
tr_ocr_input_height_and_width,
|
||||
tr_ocr_input_height_and_width))
|
||||
cropped_lines.append(splited_images[0])
|
||||
cropped_lines.append(splited_images[1])
|
||||
cropped_lines_meging_indexing.append(1)
|
||||
indexer_b_s+=1
|
||||
|
||||
if indexer_b_s==self.b_s:
|
||||
imgs = cropped_lines[:]
|
||||
cropped_lines = []
|
||||
indexer_b_s = 0
|
||||
|
||||
pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values
|
||||
generated_ids_merged = self.model_zoo.get('ocr').generate(
|
||||
pixel_values_merged.to(self.device))
|
||||
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
|
||||
generated_ids_merged, skip_special_tokens=True)
|
||||
|
||||
extracted_texts = extracted_texts + generated_text_merged
|
||||
|
||||
|
||||
cropped_lines.append(resize_image(splited_images[1],
|
||||
tr_ocr_input_height_and_width,
|
||||
tr_ocr_input_height_and_width))
|
||||
cropped_lines_meging_indexing.append(-1)
|
||||
indexer_b_s+=1
|
||||
|
||||
if indexer_b_s==self.b_s:
|
||||
imgs = cropped_lines[:]
|
||||
cropped_lines = []
|
||||
indexer_b_s = 0
|
||||
|
||||
pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values
|
||||
generated_ids_merged = self.model_zoo.get('ocr').generate(
|
||||
pixel_values_merged.to(self.device))
|
||||
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
|
||||
generated_ids_merged, skip_special_tokens=True)
|
||||
|
||||
extracted_texts = extracted_texts + generated_text_merged
|
||||
|
||||
else:
|
||||
cropped_lines.append(img_crop)
|
||||
cropped_lines_meging_indexing.append(0)
|
||||
indexer_b_s+=1
|
||||
|
||||
if indexer_b_s==self.b_s:
|
||||
imgs = cropped_lines[:]
|
||||
cropped_lines = []
|
||||
indexer_b_s = 0
|
||||
|
||||
pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values
|
||||
generated_ids_merged = self.model_zoo.get('ocr').generate(
|
||||
pixel_values_merged.to(self.device))
|
||||
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
|
||||
generated_ids_merged, skip_special_tokens=True)
|
||||
|
||||
extracted_texts = extracted_texts + generated_text_merged
|
||||
|
||||
|
||||
|
||||
indexer_text_region = indexer_text_region +1
|
||||
|
||||
if indexer_b_s!=0:
|
||||
imgs = cropped_lines[:]
|
||||
cropped_lines = []
|
||||
indexer_b_s = 0
|
||||
|
||||
pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values
|
||||
generated_ids_merged = self.model_zoo.get('ocr').generate(pixel_values_merged.to(self.device))
|
||||
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(generated_ids_merged, skip_special_tokens=True)
|
||||
|
||||
extracted_texts = extracted_texts + generated_text_merged
|
||||
|
||||
####extracted_texts = []
|
||||
####n_iterations = math.ceil(len(cropped_lines) / self.b_s)
|
||||
|
||||
####for i in range(n_iterations):
|
||||
####if i==(n_iterations-1):
|
||||
####n_start = i*self.b_s
|
||||
####imgs = cropped_lines[n_start:]
|
||||
####else:
|
||||
####n_start = i*self.b_s
|
||||
####n_end = (i+1)*self.b_s
|
||||
####imgs = cropped_lines[n_start:n_end]
|
||||
####pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values
|
||||
####generated_ids_merged = self.model_ocr.generate(
|
||||
#### pixel_values_merged.to(self.device))
|
||||
####generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
|
||||
#### generated_ids_merged, skip_special_tokens=True)
|
||||
|
||||
####extracted_texts = extracted_texts + generated_text_merged
|
||||
|
||||
extracted_texts = []
|
||||
extracted_confs = []
|
||||
self.logger.debug("processing %d lines for %d regions",
|
||||
len(cropped_lines), len(set(cropped_lines_region_indexer)))
|
||||
for imgs in batched(cropped_lines, self.b_s):
|
||||
text, conf = self.model_zoo.get('ocr').predict(imgs)
|
||||
extracted_confs.extend(conf)
|
||||
extracted_texts.extend(text)
|
||||
del cropped_lines
|
||||
gc.collect()
|
||||
|
||||
extracted_texts_merged = [extracted_texts[ind]
|
||||
if cropped_lines_meging_indexing[ind]==0
|
||||
else extracted_texts[ind]+" "+extracted_texts[ind+1]
|
||||
if cropped_lines_meging_indexing[ind]==1
|
||||
else None
|
||||
for ind in range(len(cropped_lines_meging_indexing))]
|
||||
|
||||
extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None]
|
||||
#print(extracted_texts_merged, len(extracted_texts_merged))
|
||||
if cropped_lines_meging_indexing[ind] == 0
|
||||
else extracted_texts[ind] + " " + extracted_texts[ind + 1]
|
||||
for ind in range(len(cropped_lines_meging_indexing))
|
||||
if cropped_lines_meging_indexing[ind] >= 0]
|
||||
extracted_confs_merged = [extracted_confs[ind]
|
||||
if cropped_lines_meging_indexing[ind] == 0
|
||||
else 0.5 * (extracted_confs[ind] + extracted_confs[ind + 1])
|
||||
for ind in range(len(cropped_lines_meging_indexing))
|
||||
if cropped_lines_meging_indexing[ind] >= 0]
|
||||
|
||||
return EynollahOcrResult(
|
||||
extracted_texts_merged=extracted_texts_merged,
|
||||
extracted_conf_value_merged=None,
|
||||
extracted_confs_merged=extracted_confs_merged,
|
||||
cropped_lines_region_indexer=cropped_lines_region_indexer,
|
||||
total_bb_coordinates=total_bb_coordinates,
|
||||
)
|
||||
|
|
@ -273,367 +165,163 @@ class Eynollah_ocr:
|
|||
img_bin: Optional[MatLike],
|
||||
page_tree: ET.ElementTree,
|
||||
page_ns,
|
||||
image_width,
|
||||
image_height,
|
||||
) -> EynollahOcrResult:
|
||||
_, image_height, image_width, _ = self.model_zoo.get('ocr').input_shape
|
||||
|
||||
total_bb_coordinates = []
|
||||
|
||||
cropped_lines = []
|
||||
img_crop_bin = None
|
||||
imgs_bin = None
|
||||
imgs_bin_ver_flipped = None
|
||||
cropped_lines_rgb = []
|
||||
cropped_lines_bin = []
|
||||
cropped_lines_ver_index = []
|
||||
cropped_lines_region_indexer = []
|
||||
cropped_lines_meging_indexing = []
|
||||
|
||||
indexer_text_region = 0
|
||||
for nn in page_tree.getroot().iter(f'{{{page_ns}}}TextRegion'):
|
||||
try:
|
||||
type_textregion = nn.attrib['type']
|
||||
except:
|
||||
type_textregion = 'paragraph'
|
||||
for child_textregion in nn:
|
||||
if child_textregion.tag.endswith("TextLine"):
|
||||
for child_textlines in child_textregion:
|
||||
if child_textlines.tag.endswith("Coords"):
|
||||
cropped_lines_region_indexer.append(indexer_text_region)
|
||||
p_h=child_textlines.attrib['points'].split(' ')
|
||||
textline_coords = np.array( [ [int(x.split(',')[0]),
|
||||
int(x.split(',')[1]) ]
|
||||
for x in p_h] )
|
||||
img_rgb = img # cosmetic
|
||||
if img_bin is None:
|
||||
# run ad-hoc binarization
|
||||
self.logger.info("running binarization for ensemble input")
|
||||
img_bin = self.do_prediction(True, img, self.model_zoo.get("binarization"),
|
||||
n_batch_inference=5)
|
||||
img_bin = np.repeat(img_bin[:, :, np.newaxis], 3, axis=2)
|
||||
img_bin = 255 * (img_bin == 0).astype(np.uint8)
|
||||
|
||||
x,y,w,h = cv2.boundingRect(textline_coords)
|
||||
for n_region, region in enumerate(page_tree.getroot().iter('{%s}TextRegion' % page_ns)):
|
||||
type_textregion = region.attrib.get('type', 'paragraph')
|
||||
for n_line, line in enumerate(region.iter('{%s}TextLine' % page_ns)):
|
||||
cropped_lines_region_indexer.append(n_region)
|
||||
|
||||
coords = line.find('{%s}Coords' % page_ns)
|
||||
if coords is None:
|
||||
self.logger.warning("region '%s' line '%s' has no Coords", region.attrib['id'], line.attrib['id'])
|
||||
continue
|
||||
poly = np.array(polygon_from_points(coords.attrib['points'])).astype(int)
|
||||
cont = poly[:, np.newaxis]
|
||||
xywh = xywh_from_polygon(poly)
|
||||
x, y, w, h = xywh['x'], xywh['y'], xywh['w'], xywh['h']
|
||||
|
||||
angle_radians = math.atan2(h, w)
|
||||
# Convert to degrees
|
||||
angle_degrees = math.degrees(angle_radians)
|
||||
if type_textregion=='drop-capital':
|
||||
angle_degrees = 0
|
||||
|
||||
total_bb_coordinates.append([x,y,w,h])
|
||||
total_bb_coordinates.append([x, y, w, h])
|
||||
|
||||
w_scaled = w * image_height/float(h)
|
||||
w_scaled = w * image_height / float(h)
|
||||
|
||||
img_poly_on_img = np.copy(img)
|
||||
if img_bin:
|
||||
img_poly_on_img_bin = np.copy(img_bin)
|
||||
img_crop_bin = img_poly_on_img_bin[y:y+h, x:x+w, :]
|
||||
img_crop_rgb = img_rgb[y: y + h, x: x + w]
|
||||
img_crop_bin = img_bin[y: y + h, x: x + w]
|
||||
|
||||
mask_poly = np.zeros(img.shape)
|
||||
mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1))
|
||||
|
||||
|
||||
mask_poly = mask_poly[y:y+h, x:x+w, :]
|
||||
img_crop = img_poly_on_img[y:y+h, x:x+w, :]
|
||||
|
||||
# print(file_name, angle_degrees, w*h,
|
||||
# mask_poly[:,:,0].sum(),
|
||||
# mask_poly[:,:,0].sum() /float(w*h) ,
|
||||
# 'didi')
|
||||
mask_poly = np.zeros(img_crop_rgb.shape[:2], dtype=np.uint8)
|
||||
mask_poly = cv2.fillPoly(mask_poly, pts=[cont - [x, y]], color=1)
|
||||
|
||||
if angle_degrees > 3:
|
||||
better_des_slope = get_orientation_moments(textline_coords)
|
||||
|
||||
img_crop = rotate_image_with_padding(img_crop, better_des_slope)
|
||||
if img_bin:
|
||||
better_des_slope = get_orientation_moments(cont)
|
||||
img_crop_rgb = rotate_image_with_padding(img_crop_rgb, better_des_slope)
|
||||
img_crop_bin = rotate_image_with_padding(img_crop_bin, better_des_slope)
|
||||
|
||||
mask_poly = rotate_image_with_padding(mask_poly, better_des_slope)
|
||||
mask_poly = mask_poly.astype('uint8')
|
||||
|
||||
#new bounding box
|
||||
x_n, y_n, w_n, h_n = get_contours_and_bounding_boxes(mask_poly[:,:,0])
|
||||
|
||||
mask_poly = mask_poly[y_n:y_n+h_n, x_n:x_n+w_n, :]
|
||||
img_crop = img_crop[y_n:y_n+h_n, x_n:x_n+w_n, :]
|
||||
|
||||
if not self.do_not_mask_with_textline_contour:
|
||||
img_crop[mask_poly==0] = 255
|
||||
if img_bin:
|
||||
img_crop_bin = img_crop_bin[y_n:y_n+h_n, x_n:x_n+w_n, :]
|
||||
if not self.do_not_mask_with_textline_contour:
|
||||
img_crop_bin[mask_poly==0] = 255
|
||||
|
||||
if mask_poly[:,:,0].sum() /float(w_n*h_n) < 0.50 and w_scaled > 90:
|
||||
if img_bin:
|
||||
img_crop, img_crop_bin = \
|
||||
break_curved_line_into_small_pieces_and_then_merge(
|
||||
img_crop, mask_poly, img_crop_bin)
|
||||
else:
|
||||
img_crop, _ = \
|
||||
break_curved_line_into_small_pieces_and_then_merge(
|
||||
img_crop, mask_poly)
|
||||
|
||||
# get new bounding box
|
||||
x_n, y_n, w_n, h_n = get_contours_and_bounding_boxes(mask_poly)
|
||||
img_crop_rgb = img_crop_rgb[y_n: y_n + h_n, x_n: x_n + w_n]
|
||||
img_crop_bin = img_crop_bin[y_n: y_n + h_n, x_n: x_n + w_n]
|
||||
mask_poly = mask_poly[y_n: y_n + h_n, x_n: x_n + w_n]
|
||||
else:
|
||||
better_des_slope = 0
|
||||
|
||||
if not self.do_not_mask_with_textline_contour:
|
||||
img_crop[mask_poly==0] = 255
|
||||
if img_bin:
|
||||
if not self.do_not_mask_with_textline_contour:
|
||||
img_crop_bin[mask_poly==0] = 255
|
||||
if type_textregion=='drop-capital':
|
||||
pass
|
||||
else:
|
||||
if mask_poly[:,:,0].sum() /float(w*h) < 0.50 and w_scaled > 90:
|
||||
if img_bin:
|
||||
img_crop, img_crop_bin = \
|
||||
img_crop_rgb[mask_poly == 0] = 255 # FIXME: or median color?
|
||||
img_crop_bin[mask_poly == 0] = 255
|
||||
|
||||
if (type_textregion !='drop-capital' and
|
||||
mask_poly.sum() < 0.50 * mask_poly.size and
|
||||
w_scaled > 90):
|
||||
|
||||
img_crop_rgb, img_crop_bin = \
|
||||
break_curved_line_into_small_pieces_and_then_merge(
|
||||
img_crop, mask_poly, img_crop_bin)
|
||||
else:
|
||||
img_crop, _ = \
|
||||
break_curved_line_into_small_pieces_and_then_merge(
|
||||
img_crop, mask_poly)
|
||||
img_crop_rgb, img_crop_bin, mask_poly)
|
||||
|
||||
if w_scaled < 750:#1.5*image_width:
|
||||
img_fin = preprocess_and_resize_image_for_ocrcnn_model(
|
||||
img_crop, image_height, image_width)
|
||||
cropped_lines.append(img_fin)
|
||||
img_crop_split_rgb = img_crop_split_bin = None
|
||||
else:
|
||||
img_crop_split_rgb, img_crop_split_bin = return_textlines_split_if_needed(
|
||||
img_crop_rgb, img_crop_bin)
|
||||
if img_crop_split_rgb:
|
||||
cropped_lines_rgb.extend(img_crop_split_rgb)
|
||||
cropped_lines_bin.extend(img_crop_split_bin)
|
||||
if abs(better_des_slope) > 45:
|
||||
cropped_lines_ver_index.append(1)
|
||||
cropped_lines_ver_index.append(1)
|
||||
else:
|
||||
cropped_lines_ver_index.append(0)
|
||||
|
||||
cropped_lines_meging_indexing.append(0)
|
||||
if img_bin:
|
||||
img_fin = preprocess_and_resize_image_for_ocrcnn_model(
|
||||
img_crop_bin, image_height, image_width)
|
||||
cropped_lines_bin.append(img_fin)
|
||||
else:
|
||||
splited_images, splited_images_bin = return_textlines_split_if_needed(
|
||||
img_crop, img_crop_bin if img_bin else None)
|
||||
if splited_images:
|
||||
img_fin = preprocess_and_resize_image_for_ocrcnn_model(
|
||||
splited_images[0], image_height, image_width)
|
||||
cropped_lines.append(img_fin)
|
||||
cropped_lines_ver_index.append(0)
|
||||
cropped_lines_meging_indexing.append(1)
|
||||
|
||||
if abs(better_des_slope) > 45:
|
||||
cropped_lines_ver_index.append(1)
|
||||
else:
|
||||
cropped_lines_ver_index.append(0)
|
||||
|
||||
img_fin = preprocess_and_resize_image_for_ocrcnn_model(
|
||||
splited_images[1], image_height, image_width)
|
||||
|
||||
cropped_lines.append(img_fin)
|
||||
cropped_lines_meging_indexing.append(-1)
|
||||
|
||||
else:
|
||||
cropped_lines_rgb.append(img_crop_rgb)
|
||||
cropped_lines_bin.append(img_crop_bin)
|
||||
if abs(better_des_slope) > 45:
|
||||
cropped_lines_ver_index.append(1)
|
||||
else:
|
||||
cropped_lines_ver_index.append(0)
|
||||
|
||||
if img_bin:
|
||||
img_fin = preprocess_and_resize_image_for_ocrcnn_model(
|
||||
splited_images_bin[0], image_height, image_width)
|
||||
cropped_lines_bin.append(img_fin)
|
||||
img_fin = preprocess_and_resize_image_for_ocrcnn_model(
|
||||
splited_images_bin[1], image_height, image_width)
|
||||
cropped_lines_bin.append(img_fin)
|
||||
|
||||
else:
|
||||
img_fin = preprocess_and_resize_image_for_ocrcnn_model(
|
||||
img_crop, image_height, image_width)
|
||||
cropped_lines.append(img_fin)
|
||||
cropped_lines_meging_indexing.append(0)
|
||||
|
||||
if abs(better_des_slope) > 45:
|
||||
cropped_lines_ver_index.append(1)
|
||||
else:
|
||||
cropped_lines_ver_index.append(0)
|
||||
|
||||
if img_bin:
|
||||
img_fin = preprocess_and_resize_image_for_ocrcnn_model(
|
||||
img_crop_bin, image_height, image_width)
|
||||
cropped_lines_bin.append(img_fin)
|
||||
|
||||
|
||||
indexer_text_region = indexer_text_region +1
|
||||
cropped_lines_rgb = [preprocess_and_resize_image_for_ocrcnn_model(img, image_height, image_width)
|
||||
for img in cropped_lines_rgb]
|
||||
cropped_lines_bin = [preprocess_and_resize_image_for_ocrcnn_model(img, image_height, image_width)
|
||||
for img in cropped_lines_bin]
|
||||
|
||||
extracted_texts = []
|
||||
extracted_conf_value = []
|
||||
extracted_confs = []
|
||||
self.logger.debug("processing %d lines for %d regions",
|
||||
len(cropped_lines_rgb), len(set(cropped_lines_region_indexer)))
|
||||
cropped_lines = zip(cropped_lines_rgb, cropped_lines_bin, cropped_lines_ver_index)
|
||||
for batch in batched(cropped_lines, self.b_s):
|
||||
imgs_rgb, imgs_bin, ver_index = zip(*batch)
|
||||
ver_index = np.array(ver_index)
|
||||
imgs_rgb = np.stack(imgs_rgb)
|
||||
imgs_bin = np.stack(imgs_bin)
|
||||
imgs_rgb_ver = imgs_rgb[ver_index > 0, ::-1, ::-1]
|
||||
imgs_bin_ver = imgs_bin[ver_index > 0, ::-1, ::-1]
|
||||
|
||||
n_iterations = math.ceil(len(cropped_lines) / self.b_s)
|
||||
# inference model now yields (char-bytes, line-prob) instead of vocidx-softmax
|
||||
# (so ctc_decode and inverse StringLookup are included)
|
||||
# also, the model now expects a secondary binary input image
|
||||
preds, probs = self.model_zoo.get('ocr').predict((imgs_rgb, imgs_bin), verbose=0)
|
||||
|
||||
# FIXME: copy pasta
|
||||
for i in range(n_iterations):
|
||||
if i==(n_iterations-1):
|
||||
n_start = i*self.b_s
|
||||
imgs = cropped_lines[n_start:]
|
||||
imgs = np.array(imgs)
|
||||
imgs = imgs.reshape(imgs.shape[0], image_height, image_width, 3)
|
||||
if ver_index.any():
|
||||
preds_ver, probs_ver = self.model_zoo.get('ocr').predict((imgs_rgb_ver, imgs_bin_ver), verbose=0)
|
||||
flipped_ver_is_better = np.flatnonzero(probs_ver > probs[ver_index > 0])
|
||||
if len(flipped_ver_is_better):
|
||||
self.logger.info("%d skewed lines perform better when flipped", len(flipped_ver_is_better))
|
||||
preds[ver_index > 0][flipped_ver_is_better] = preds_ver[flipped_ver_is_better]
|
||||
probs[ver_index > 0][flipped_ver_is_better] = probs_ver[flipped_ver_is_better]
|
||||
|
||||
ver_imgs = np.array( cropped_lines_ver_index[n_start:] )
|
||||
indices_ver = np.where(ver_imgs == 1)[0]
|
||||
|
||||
#print(indices_ver, 'indices_ver')
|
||||
if len(indices_ver)>0:
|
||||
imgs_ver_flipped = imgs[indices_ver, : ,: ,:]
|
||||
imgs_ver_flipped = imgs_ver_flipped[:,::-1,::-1,:]
|
||||
#print(imgs_ver_flipped, 'imgs_ver_flipped')
|
||||
|
||||
else:
|
||||
imgs_ver_flipped = None
|
||||
|
||||
if img_bin:
|
||||
imgs_bin = cropped_lines_bin[n_start:]
|
||||
imgs_bin = np.array(imgs_bin)
|
||||
imgs_bin = imgs_bin.reshape(imgs_bin.shape[0], image_height, image_width, 3)
|
||||
|
||||
if len(indices_ver)>0:
|
||||
imgs_bin_ver_flipped = imgs_bin[indices_ver, : ,: ,:]
|
||||
imgs_bin_ver_flipped = imgs_bin_ver_flipped[:,::-1,::-1,:]
|
||||
#print(imgs_ver_flipped, 'imgs_ver_flipped')
|
||||
|
||||
else:
|
||||
imgs_bin_ver_flipped = None
|
||||
else:
|
||||
n_start = i*self.b_s
|
||||
n_end = (i+1)*self.b_s
|
||||
imgs = cropped_lines[n_start:n_end]
|
||||
imgs = np.array(imgs).reshape(self.b_s, image_height, image_width, 3)
|
||||
|
||||
ver_imgs = np.array( cropped_lines_ver_index[n_start:n_end] )
|
||||
indices_ver = np.where(ver_imgs == 1)[0]
|
||||
#print(indices_ver, 'indices_ver')
|
||||
|
||||
if len(indices_ver)>0:
|
||||
imgs_ver_flipped = imgs[indices_ver, : ,: ,:]
|
||||
imgs_ver_flipped = imgs_ver_flipped[:,::-1,::-1,:]
|
||||
#print(imgs_ver_flipped, 'imgs_ver_flipped')
|
||||
else:
|
||||
imgs_ver_flipped = None
|
||||
|
||||
|
||||
if img_bin:
|
||||
imgs_bin = cropped_lines_bin[n_start:n_end]
|
||||
imgs_bin = np.array(imgs_bin).reshape(self.b_s, image_height, image_width, 3)
|
||||
|
||||
|
||||
if len(indices_ver)>0:
|
||||
imgs_bin_ver_flipped = imgs_bin[indices_ver, : ,: ,:]
|
||||
imgs_bin_ver_flipped = imgs_bin_ver_flipped[:,::-1,::-1,:]
|
||||
#print(imgs_ver_flipped, 'imgs_ver_flipped')
|
||||
else:
|
||||
imgs_bin_ver_flipped = None
|
||||
|
||||
|
||||
self.logger.debug("processing next %d lines", len(imgs))
|
||||
preds = self.model_zoo.get('ocr').predict(imgs, verbose=0)
|
||||
|
||||
if len(indices_ver)>0:
|
||||
preds_flipped = self.model_zoo.get('ocr').predict(imgs_ver_flipped, verbose=0)
|
||||
preds_max_fliped = np.max(preds_flipped, axis=2 )
|
||||
preds_max_args_flipped = np.argmax(preds_flipped, axis=2 )
|
||||
pred_max_not_unk_mask_bool_flipped = preds_max_args_flipped[:,:]!=self.end_character
|
||||
masked_means_flipped = \
|
||||
np.sum(preds_max_fliped * pred_max_not_unk_mask_bool_flipped, axis=1) / \
|
||||
np.sum(pred_max_not_unk_mask_bool_flipped, axis=1)
|
||||
masked_means_flipped[np.isnan(masked_means_flipped)] = 0
|
||||
|
||||
preds_max = np.max(preds, axis=2 )
|
||||
preds_max_args = np.argmax(preds, axis=2 )
|
||||
pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character
|
||||
|
||||
masked_means = \
|
||||
np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \
|
||||
np.sum(pred_max_not_unk_mask_bool, axis=1)
|
||||
masked_means[np.isnan(masked_means)] = 0
|
||||
|
||||
masked_means_ver = masked_means[indices_ver]
|
||||
#print(masked_means_ver, 'pred_max_not_unk')
|
||||
|
||||
indices_where_flipped_conf_value_is_higher = \
|
||||
np.where(masked_means_flipped > masked_means_ver)[0]
|
||||
|
||||
#print(indices_where_flipped_conf_value_is_higher, 'indices_where_flipped_conf_value_is_higher')
|
||||
if len(indices_where_flipped_conf_value_is_higher)>0:
|
||||
indices_to_be_replaced = indices_ver[indices_where_flipped_conf_value_is_higher]
|
||||
preds[indices_to_be_replaced,:,:] = \
|
||||
preds_flipped[indices_where_flipped_conf_value_is_higher, :, :]
|
||||
|
||||
if img_bin:
|
||||
preds_bin = self.model_zoo.get('ocr').predict(imgs_bin, verbose=0)
|
||||
|
||||
if len(indices_ver)>0:
|
||||
preds_flipped = self.model_zoo.get('ocr').predict(imgs_bin_ver_flipped, verbose=0)
|
||||
preds_max_fliped = np.max(preds_flipped, axis=2 )
|
||||
preds_max_args_flipped = np.argmax(preds_flipped, axis=2 )
|
||||
pred_max_not_unk_mask_bool_flipped = preds_max_args_flipped[:,:]!=self.end_character
|
||||
masked_means_flipped = \
|
||||
np.sum(preds_max_fliped * pred_max_not_unk_mask_bool_flipped, axis=1) / \
|
||||
np.sum(pred_max_not_unk_mask_bool_flipped, axis=1)
|
||||
masked_means_flipped[np.isnan(masked_means_flipped)] = 0
|
||||
|
||||
preds_max = np.max(preds, axis=2 )
|
||||
preds_max_args = np.argmax(preds, axis=2 )
|
||||
pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character
|
||||
|
||||
masked_means = \
|
||||
np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \
|
||||
np.sum(pred_max_not_unk_mask_bool, axis=1)
|
||||
masked_means[np.isnan(masked_means)] = 0
|
||||
|
||||
masked_means_ver = masked_means[indices_ver]
|
||||
#print(masked_means_ver, 'pred_max_not_unk')
|
||||
|
||||
indices_where_flipped_conf_value_is_higher = \
|
||||
np.where(masked_means_flipped > masked_means_ver)[0]
|
||||
|
||||
#print(indices_where_flipped_conf_value_is_higher, 'indices_where_flipped_conf_value_is_higher')
|
||||
if len(indices_where_flipped_conf_value_is_higher)>0:
|
||||
indices_to_be_replaced = indices_ver[indices_where_flipped_conf_value_is_higher]
|
||||
preds_bin[indices_to_be_replaced,:,:] = \
|
||||
preds_flipped[indices_where_flipped_conf_value_is_higher, :, :]
|
||||
|
||||
preds = (preds + preds_bin) / 2.
|
||||
|
||||
pred_texts = decode_batch_predictions(preds, self.model_zoo.get('num_to_char'))
|
||||
|
||||
preds_max = np.max(preds, axis=2 )
|
||||
preds_max_args = np.argmax(preds, axis=2 )
|
||||
pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character
|
||||
masked_means = \
|
||||
np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \
|
||||
np.sum(pred_max_not_unk_mask_bool, axis=1)
|
||||
|
||||
for ib in range(imgs.shape[0]):
|
||||
pred_texts_ib = pred_texts[ib].replace("[UNK]", "")
|
||||
if masked_means[ib] >= self.min_conf_value_of_textline_text:
|
||||
extracted_texts.append(pred_texts_ib)
|
||||
extracted_conf_value.append(masked_means[ib])
|
||||
else:
|
||||
extracted_texts.append("")
|
||||
extracted_conf_value.append(0)
|
||||
del cropped_lines
|
||||
def nooov(x):
|
||||
return x != b'[UNK]'
|
||||
for pred, prob in zip(preds, probs):
|
||||
text = b''.join(
|
||||
filter(nooov,
|
||||
map(bytes,
|
||||
(filter(None, char)
|
||||
for char in pred.tolist())))).decode('utf-8')
|
||||
extracted_texts.append(text)
|
||||
extracted_confs.append(prob)
|
||||
del cropped_lines_rgb
|
||||
del cropped_lines_bin
|
||||
gc.collect()
|
||||
|
||||
extracted_texts_merged = [extracted_texts[ind]
|
||||
if cropped_lines_meging_indexing[ind]==0
|
||||
else extracted_texts[ind]+" "+extracted_texts[ind+1]
|
||||
if cropped_lines_meging_indexing[ind]==1
|
||||
else None
|
||||
for ind in range(len(cropped_lines_meging_indexing))]
|
||||
|
||||
extracted_conf_value_merged = [extracted_conf_value[ind] # type: ignore
|
||||
if cropped_lines_meging_indexing[ind]==0
|
||||
else (extracted_conf_value[ind]+extracted_conf_value[ind+1])/2.
|
||||
if cropped_lines_meging_indexing[ind]==1
|
||||
else None
|
||||
for ind in range(len(cropped_lines_meging_indexing))]
|
||||
|
||||
extracted_conf_value_merged: List[float] = [extracted_conf_value_merged[ind_cfm]
|
||||
for ind_cfm in range(len(extracted_texts_merged))
|
||||
if extracted_texts_merged[ind_cfm] is not None]
|
||||
|
||||
extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None]
|
||||
if cropped_lines_meging_indexing[ind] == 0
|
||||
else extracted_texts[ind] + " " + extracted_texts[ind + 1]
|
||||
for ind in range(len(cropped_lines_meging_indexing))
|
||||
if cropped_lines_meging_indexing[ind] >= 0]
|
||||
extracted_confs_merged = [extracted_confs[ind]
|
||||
if cropped_lines_meging_indexing[ind] == 0
|
||||
else 0.5 * (extracted_confs[ind] + extracted_confs[ind + 1])
|
||||
for ind in range(len(cropped_lines_meging_indexing))
|
||||
if cropped_lines_meging_indexing[ind] >= 0]
|
||||
|
||||
return EynollahOcrResult(
|
||||
extracted_texts_merged=extracted_texts_merged,
|
||||
extracted_conf_value_merged=extracted_conf_value_merged,
|
||||
extracted_confs_merged=extracted_confs_merged,
|
||||
cropped_lines_region_indexer=cropped_lines_region_indexer,
|
||||
total_bb_coordinates=total_bb_coordinates,
|
||||
)
|
||||
|
|
@ -651,13 +339,12 @@ class Eynollah_ocr:
|
|||
cropped_lines_region_indexer = result.cropped_lines_region_indexer
|
||||
total_bb_coordinates = result.total_bb_coordinates
|
||||
extracted_texts_merged = result.extracted_texts_merged
|
||||
extracted_conf_value_merged = result.extracted_conf_value_merged
|
||||
extracted_confs_merged = result.extracted_confs_merged
|
||||
|
||||
unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer)
|
||||
if out_image_with_text:
|
||||
image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white")
|
||||
draw = ImageDraw.Draw(image_text)
|
||||
font = get_font()
|
||||
font = get_font(font_size=40)
|
||||
|
||||
for indexer_text, bb_ind in enumerate(total_bb_coordinates):
|
||||
x_bb = bb_ind[0]
|
||||
|
|
@ -681,79 +368,53 @@ class Eynollah_ocr:
|
|||
draw.text((text_x, text_y), extracted_texts_merged[indexer_text], fill="black", font=font)
|
||||
image_text.save(out_image_with_text)
|
||||
|
||||
text_by_textregion = []
|
||||
for ind in unique_cropped_lines_region_indexer:
|
||||
ind = np.array(cropped_lines_region_indexer)==ind
|
||||
extracted_texts_merged_un = np.array(extracted_texts_merged)[ind]
|
||||
if len(extracted_texts_merged_un)>1:
|
||||
text_by_textregion_ind = ""
|
||||
cropped_lines_region_indexer = np.array(cropped_lines_region_indexer)
|
||||
for n_region, region in enumerate(page_tree.getroot().iter('{%s}TextRegion' % page_ns)):
|
||||
lines_indexer = np.flatnonzero(cropped_lines_region_indexer == n_region)
|
||||
if not len(lines_indexer):
|
||||
continue
|
||||
|
||||
text_region = ""
|
||||
next_glue = ""
|
||||
for indt in range(len(extracted_texts_merged_un)):
|
||||
if (extracted_texts_merged_un[indt].endswith('⸗') or
|
||||
extracted_texts_merged_un[indt].endswith('-') or
|
||||
extracted_texts_merged_un[indt].endswith('¬')):
|
||||
text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt][:-1]
|
||||
for line_idx in lines_indexer:
|
||||
if extracted_confs_merged[line_idx] < self.min_conf_value_of_textline_text:
|
||||
continue
|
||||
text_line = extracted_texts_merged[line_idx]
|
||||
if (text_line.endswith(('⸗', '-', '¬')) and
|
||||
# last line of a region can still be wrapped
|
||||
# around columns or pages
|
||||
line_idx < len(lines_indexer) - 1):
|
||||
text_region += next_glue + text_line[:-1]
|
||||
next_glue = ""
|
||||
else:
|
||||
text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt]
|
||||
text_region += next_glue + text_line
|
||||
next_glue = " "
|
||||
text_by_textregion.append(text_by_textregion_ind)
|
||||
|
||||
region_textequiv = region.find('{%s}TextEquiv' % page_ns)
|
||||
if region_textequiv is None:
|
||||
region_textequiv = ET.SubElement(region, 'TextEquiv')
|
||||
region_teunicode = region_textequiv.find('{%s}Unicode' % page_ns)
|
||||
if region_teunicode is None:
|
||||
region_teunicode = ET.SubElement(region_textequiv, 'Unicode')
|
||||
region_teunicode.text = text_region
|
||||
|
||||
for n_line, line in enumerate(region.iter('{%s}TextLine' % page_ns)):
|
||||
line_textequiv = line.find('{%s}TextEquiv' % page_ns)
|
||||
if line_textequiv is None:
|
||||
line_textequiv = ET.SubElement(line, 'TextEquiv')
|
||||
line_teunicode = line_textequiv.find('{%s}Unicode' % page_ns)
|
||||
if line_teunicode is None:
|
||||
line_teunicode = ET.SubElement(line_textequiv, 'Unicode')
|
||||
|
||||
line_idx = lines_indexer[n_line]
|
||||
if extracted_confs_merged[line_idx] < self.min_conf_value_of_textline_text:
|
||||
line.remove(line_textequiv)
|
||||
else:
|
||||
text_by_textregion.append(" ".join(extracted_texts_merged_un))
|
||||
|
||||
indexer = 0
|
||||
indexer_textregion = 0
|
||||
for nn in page_tree.getroot().iter(f'{{{page_ns}}}TextRegion'):
|
||||
|
||||
is_textregion_text = False
|
||||
for childtest in nn:
|
||||
if childtest.tag.endswith("TextEquiv"):
|
||||
is_textregion_text = True
|
||||
|
||||
if not is_textregion_text:
|
||||
text_subelement_textregion = ET.SubElement(nn, 'TextEquiv')
|
||||
unicode_textregion = ET.SubElement(text_subelement_textregion, 'Unicode')
|
||||
|
||||
|
||||
has_textline = False
|
||||
for child_textregion in nn:
|
||||
if child_textregion.tag.endswith("TextLine"):
|
||||
|
||||
is_textline_text = False
|
||||
for childtest2 in child_textregion:
|
||||
if childtest2.tag.endswith("TextEquiv"):
|
||||
is_textline_text = True
|
||||
|
||||
|
||||
if not is_textline_text:
|
||||
text_subelement = ET.SubElement(child_textregion, 'TextEquiv')
|
||||
if extracted_conf_value_merged:
|
||||
text_subelement.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}")
|
||||
unicode_textline = ET.SubElement(text_subelement, 'Unicode')
|
||||
unicode_textline.text = extracted_texts_merged[indexer]
|
||||
else:
|
||||
for childtest3 in child_textregion:
|
||||
if childtest3.tag.endswith("TextEquiv"):
|
||||
for child_uc in childtest3:
|
||||
if child_uc.tag.endswith("Unicode"):
|
||||
if extracted_conf_value_merged:
|
||||
childtest3.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}")
|
||||
child_uc.text = extracted_texts_merged[indexer]
|
||||
|
||||
indexer = indexer + 1
|
||||
has_textline = True
|
||||
if has_textline:
|
||||
if is_textregion_text:
|
||||
for child4 in nn:
|
||||
if child4.tag.endswith("TextEquiv"):
|
||||
for childtr_uc in child4:
|
||||
if childtr_uc.tag.endswith("Unicode"):
|
||||
childtr_uc.text = text_by_textregion[indexer_textregion]
|
||||
else:
|
||||
unicode_textregion.text = text_by_textregion[indexer_textregion]
|
||||
indexer_textregion = indexer_textregion + 1
|
||||
line_textequiv.set('conf', str(round(extracted_confs_merged[line_idx], 2)))
|
||||
line_teunicode.text = extracted_texts_merged[line_idx]
|
||||
|
||||
ET.register_namespace("",page_ns)
|
||||
self.logger.info("output filename: '%s'", out_file_ocr)
|
||||
page_tree.write(out_file_ocr, xml_declaration=True, method='xml', encoding="utf-8", default_namespace=None)
|
||||
|
||||
def run(
|
||||
|
|
@ -813,18 +474,13 @@ class Eynollah_ocr:
|
|||
img=img,
|
||||
page_tree=page_tree,
|
||||
page_ns=page_ns,
|
||||
|
||||
tr_ocr_input_height_and_width = 384
|
||||
)
|
||||
else:
|
||||
result = self.run_cnn(
|
||||
img=img,
|
||||
page_tree=page_tree,
|
||||
page_ns=page_ns,
|
||||
|
||||
img_bin=img_bin,
|
||||
image_width=512,
|
||||
image_height=32,
|
||||
)
|
||||
|
||||
self.write_ocr(
|
||||
|
|
|
|||
|
|
@ -17,9 +17,7 @@ import cv2
|
|||
import numpy as np
|
||||
import statistics
|
||||
|
||||
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
||||
import tensorflow as tf
|
||||
|
||||
from .eynollah import Eynollah
|
||||
from .model_zoo import EynollahModelZoo
|
||||
from .utils.resize import resize_image
|
||||
from .utils.contour import (
|
||||
|
|
@ -33,23 +31,27 @@ DPI_THRESHOLD = 298
|
|||
KERNEL = np.ones((5, 5), np.uint8)
|
||||
|
||||
|
||||
class machine_based_reading_order_on_layout:
|
||||
class Reorder(Eynollah):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model_zoo: EynollahModelZoo,
|
||||
logger : Optional[logging.Logger] = None,
|
||||
device: str = '',
|
||||
):
|
||||
self.logger = logger or logging.getLogger('eynollah.mbreorder')
|
||||
self.model_zoo = model_zoo
|
||||
|
||||
try:
|
||||
for device in tf.config.list_physical_devices('GPU'):
|
||||
tf.config.experimental.set_memory_growth(device, True)
|
||||
except:
|
||||
self.logger.warning("no GPU device available")
|
||||
self.model_zoo.load_model('reading_order')
|
||||
self.setup_models(device=device)
|
||||
|
||||
def setup_models(self, device=''):
|
||||
loadable = ['reading_order']
|
||||
self.model_zoo.load_models(*loadable, device=device)
|
||||
for model in loadable:
|
||||
self.logger.debug("model %s has input shape %s", model,
|
||||
self.model_zoo.get(model).input_shape)
|
||||
|
||||
self.model_zoo.load_models('reading_order')
|
||||
|
||||
def read_xml(self, xml_file):
|
||||
tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8'))
|
||||
|
|
@ -675,7 +677,7 @@ class machine_based_reading_order_on_layout:
|
|||
tot_counter += 1
|
||||
batch.append(j)
|
||||
if tot_counter % inference_bs == 0 or tot_counter == len(ij_list):
|
||||
y_pr = self.model_zoo.get('reading_order').predict(input_1 , verbose='0')
|
||||
y_pr = self.model_zoo.get('reading_order').predict(input_1, verbose=0)
|
||||
for jb, j in enumerate(batch):
|
||||
if y_pr[jb][0]>=0.5:
|
||||
post_list.append(j)
|
||||
|
|
|
|||
|
|
@ -208,22 +208,6 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
|
|||
type='Keras',
|
||||
),
|
||||
|
||||
EynollahModelSpec(
|
||||
category="num_to_char",
|
||||
variant='',
|
||||
filename="characters_org.txt",
|
||||
dist_url=dist_url("ocr"),
|
||||
type='decoder',
|
||||
),
|
||||
|
||||
EynollahModelSpec(
|
||||
category="characters",
|
||||
variant='',
|
||||
filename="characters_org.txt",
|
||||
dist_url=dist_url("ocr"),
|
||||
type='List[str]',
|
||||
),
|
||||
|
||||
EynollahModelSpec(
|
||||
category="ocr",
|
||||
variant='tr',
|
||||
|
|
@ -233,20 +217,4 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
|
|||
type='Keras',
|
||||
),
|
||||
|
||||
EynollahModelSpec(
|
||||
category="trocr_processor",
|
||||
variant='',
|
||||
filename="models_eynollah/model_eynollah_ocr_trocr_20250919",
|
||||
dist_url=dist_url("ocr"),
|
||||
type='TrOCRProcessor',
|
||||
),
|
||||
|
||||
EynollahModelSpec(
|
||||
category="trocr_processor",
|
||||
variant='htr',
|
||||
filename="models_eynollah/microsoft/trocr-base-handwritten",
|
||||
dist_url=dist_url("extra"),
|
||||
type='TrOCRProcessor',
|
||||
),
|
||||
|
||||
])
|
||||
|
|
|
|||
|
|
@ -14,6 +14,19 @@ from .default_specs import DEFAULT_MODEL_SPECS
|
|||
from .types import AnyModel, T
|
||||
|
||||
|
||||
MODEL_VRAM_LIMITS = {
|
||||
"binarization": 868, # due to bs 5
|
||||
"enhancement": 980, # due to bs 3
|
||||
"col_classifier": 210,
|
||||
"page": 618,
|
||||
"textline": 1680, # 954 for bs 1
|
||||
"region_1_2": 1580,
|
||||
"region_fl_np": 1756,
|
||||
"table": 1818,
|
||||
"reading_order": 632,
|
||||
"ocr": 850,
|
||||
}
|
||||
|
||||
class EynollahModelZoo:
|
||||
"""
|
||||
Wrapper class that handles storage and loading of models for all eynollah runners.
|
||||
|
|
@ -35,7 +48,7 @@ class EynollahModelZoo:
|
|||
self._overrides = []
|
||||
if model_overrides:
|
||||
self.override_models(*model_overrides)
|
||||
self._loaded: Dict[str, Predictor] = {}
|
||||
self._loaded: Dict[str, Union[Predictor, AnyModel]] = {}
|
||||
|
||||
@property
|
||||
def model_overrides(self):
|
||||
|
|
@ -70,6 +83,13 @@ class EynollahModelZoo:
|
|||
model_path = Path(self.model_basedir).joinpath(spec.filename)
|
||||
else:
|
||||
model_path = Path(spec.filename)
|
||||
if model_path.suffix == '.h5' and Path(model_path.stem).exists():
|
||||
# prefer SavedModel over HDF5 format if it exists
|
||||
model_path = Path(model_path.stem)
|
||||
if model_path.with_suffix('.onnx').exists():
|
||||
# prefer ONNX over SavedModel format if it exists
|
||||
model_path = model_path.with_suffix('.onnx')
|
||||
|
||||
return model_path
|
||||
|
||||
def load_models(
|
||||
|
|
@ -82,24 +102,31 @@ class EynollahModelZoo:
|
|||
"""
|
||||
ret = {} # cannot use self._loaded here, yet – first spawn all predictors
|
||||
for load_args in all_load_args:
|
||||
load_kwargs = dict(device=device)
|
||||
if isinstance(load_args, str):
|
||||
model_category = load_args
|
||||
load_args = [model_category]
|
||||
model_category, model_variant = load_args, ""
|
||||
elif len(load_args) > 2:
|
||||
# for calls to self.model_path
|
||||
self.override_models(load_args)
|
||||
# for calls to Predictor.load_model
|
||||
model_category, model_variant, model_path = load_args
|
||||
load_kwargs["model_variant"] = model_variant
|
||||
load_kwargs["model_path_override"] = model_path
|
||||
else:
|
||||
model_category = load_args[0]
|
||||
load_kwargs = {}
|
||||
if model_category.endswith('_resized'):
|
||||
load_args[0] = model_category[:-8]
|
||||
load_kwargs["resized"] = True
|
||||
elif model_category.endswith('_patched'):
|
||||
load_args[0] = model_category[:-8]
|
||||
load_kwargs["patched"] = True
|
||||
spec = self.specs.get(model_category, load_args[1] if len(load_args) > 1 else '')
|
||||
if spec.type in ['Keras'] and spec.category != 'ocr':
|
||||
ret[model_category] = Predictor(self.logger, self)
|
||||
ret[model_category].load_model(*load_args, **load_kwargs, device=device)
|
||||
else:
|
||||
ret[model_category] = self.load_model(*load_args, **load_kwargs, device=device)
|
||||
model_category, model_variant = load_args
|
||||
load_kwargs["model_variant"] = model_variant
|
||||
|
||||
# if model_category.endswith('_resized'):
|
||||
# model_category = model_category[:-8]
|
||||
# load_kwargs["resized"] = True
|
||||
# elif model_category.endswith('_patched'):
|
||||
# model_category = model_category[:-8]
|
||||
# load_kwargs["patched"] = True
|
||||
|
||||
model = Predictor(self.logger, self)
|
||||
model.load_model(model_category, **load_kwargs)
|
||||
|
||||
ret[model_category] = model
|
||||
self._loaded.update(ret)
|
||||
return self._loaded
|
||||
|
||||
|
|
@ -108,31 +135,48 @@ class EynollahModelZoo:
|
|||
model_category: str,
|
||||
model_variant: str = '',
|
||||
model_path_override: Optional[str] = None,
|
||||
patched: bool = False,
|
||||
resized: bool = False,
|
||||
# patched: bool = False,
|
||||
# resized: bool = False,
|
||||
device: str = '',
|
||||
) -> AnyModel:
|
||||
"""
|
||||
Load any model
|
||||
"""
|
||||
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
||||
if model_path_override:
|
||||
self.override_models((model_category, model_variant, model_path_override))
|
||||
model_path = self.model_path(model_category, model_variant)
|
||||
|
||||
if model_category == 'ocr' and model_variant == 'tr':
|
||||
model = self._load_trocr_model(model_path, device=device)
|
||||
elif model_path.is_dir() and (model_path / "keras_metadata.pb").exists():
|
||||
# Keras model
|
||||
model = self._load_keras_model(model_category, model_path, device=device)
|
||||
elif model_path.is_dir():
|
||||
# TF-Serving model
|
||||
model = self._load_serving_model(model_category, model_path, device=device)
|
||||
elif model_path.suffix == '.onnx':
|
||||
# ONNX model
|
||||
model = self._load_onnx_model(model_category, model_path, device=device)
|
||||
else:
|
||||
raise ValueError("unknown model type for '%s'" % str(model_path))
|
||||
model._name = model_category
|
||||
return model
|
||||
|
||||
def get(self, model_category: str) -> Union[Predictor, AnyModel]:
|
||||
if model_category not in self._loaded:
|
||||
raise ValueError(f'Model "{model_category}" not previously loaded with "load_model(..)"')
|
||||
return self._loaded[model_category]
|
||||
|
||||
def _configure_tf_device(self, model_category, device=''):
|
||||
from ocrd_utils import tf_disable_interactive_logs
|
||||
tf_disable_interactive_logs()
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras.models import load_model
|
||||
|
||||
from ..patch_encoder import (
|
||||
PatchEncoder,
|
||||
Patches,
|
||||
wrap_layout_model_patched,
|
||||
wrap_layout_model_resized,
|
||||
)
|
||||
cuda = False
|
||||
try:
|
||||
gpus = tf.config.list_physical_devices('GPU')
|
||||
if device:
|
||||
if ',' in device:
|
||||
if ':' in device:
|
||||
for spec in device.split(','):
|
||||
cat, dev = spec.split(':')
|
||||
if fnmatchcase(model_category, cat):
|
||||
|
|
@ -147,7 +191,14 @@ class EynollahModelZoo:
|
|||
gpus = gpus[:1] # TF will always use first allowable
|
||||
tf.config.set_visible_devices(gpus, 'GPU')
|
||||
for device in gpus:
|
||||
tf.config.experimental.set_memory_growth(device, True)
|
||||
# tf.config.experimental.set_memory_growth(device, True)
|
||||
# dynamic growth never frees memory (to avoid fragmentation),
|
||||
# so the VRAM requirements end up much larger than feasible
|
||||
# (for small GPUs); so try hard (calibrated) limits instead:
|
||||
tf.config.set_logical_device_configuration(
|
||||
device,
|
||||
[tf.config.LogicalDeviceConfiguration(
|
||||
memory_limit=MODEL_VRAM_LIMITS[model_category])])
|
||||
vendor_name = (
|
||||
tf.config.experimental.get_device_details(device)
|
||||
.get('device_name', 'unknown'))
|
||||
|
|
@ -155,95 +206,195 @@ class EynollahModelZoo:
|
|||
self.logger.info("using GPU %s (%s) for model %s",
|
||||
device.name,
|
||||
vendor_name,
|
||||
model_category + (
|
||||
"_patched" if patched else
|
||||
"_resized" if resized else ""))
|
||||
model_category # + (
|
||||
# "_patched" if patched else
|
||||
# "_resized" if resized else "")
|
||||
)
|
||||
except RuntimeError:
|
||||
self.logger.exception("cannot configure GPU devices")
|
||||
if not cuda:
|
||||
self.logger.warning("no GPU device available")
|
||||
|
||||
if model_path_override:
|
||||
self.override_models((model_category, model_variant, model_path_override))
|
||||
model_path = self.model_path(model_category, model_variant)
|
||||
if model_path.suffix == '.h5' and Path(model_path.stem).exists():
|
||||
# prefer SavedModel over HDF5 format if it exists
|
||||
model_path = Path(model_path.stem)
|
||||
if model_category == 'ocr':
|
||||
model = self._load_ocr_model(variant=model_variant)
|
||||
elif model_category == 'num_to_char':
|
||||
model = self._load_num_to_char()
|
||||
elif model_category == 'characters':
|
||||
model = self._load_characters()
|
||||
elif model_category == 'trocr_processor':
|
||||
from transformers import TrOCRProcessor
|
||||
model = TrOCRProcessor.from_pretrained(model_path)
|
||||
else:
|
||||
def _configure_torch_device(self, model_category, device=''):
|
||||
import torch
|
||||
|
||||
device0 = torch.device('cpu')
|
||||
if not device and torch.cuda.is_available():
|
||||
device = 'GPU' # try
|
||||
if device and ':' in device:
|
||||
for spec in device.split(','):
|
||||
cat, dev = spec.split(':')
|
||||
if fnmatchcase('ocr', cat):
|
||||
device = dev
|
||||
break
|
||||
if device and device.startswith('GPU'):
|
||||
try:
|
||||
# avoid wasting VRAM on non-transformer models
|
||||
device0 = torch.device('cuda', int(device[3:] or 0))
|
||||
name = torch.cuda.get_device_name(device0)
|
||||
self.logger.info("using GPU %s (%s) for model ocr:tr", device0, name)
|
||||
except:
|
||||
self.logger.exception("cannot configure GPU device")
|
||||
device0 = torch.device('cpu')
|
||||
if device0.type != 'cuda':
|
||||
self.logger.warning("no GPU device available")
|
||||
return device0
|
||||
|
||||
def _load_keras_model(self, model_category, model_path, device=''):
|
||||
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
||||
from ocrd_utils import tf_disable_interactive_logs
|
||||
tf_disable_interactive_logs()
|
||||
|
||||
from tensorflow.keras.models import load_model
|
||||
from tensorflow.keras.models import Model as KerasModel
|
||||
|
||||
from ..training.models import cnn_rnn_ocr_model4inference
|
||||
|
||||
self._configure_tf_device(model_category, device=device)
|
||||
|
||||
model = load_model(model_path, compile=False)
|
||||
except Exception as e:
|
||||
self.logger.error(e)
|
||||
model = load_model(
|
||||
model_path, compile=False,
|
||||
custom_objects=dict(PatchEncoder=PatchEncoder,
|
||||
Patches=Patches))
|
||||
model._name = model_category
|
||||
if resized:
|
||||
model = wrap_layout_model_resized(model)
|
||||
model._name = model_category + '_resized'
|
||||
elif patched:
|
||||
model = wrap_layout_model_patched(model)
|
||||
model._name = model_category + '_patched'
|
||||
else:
|
||||
model.jit_compile = True
|
||||
assert isinstance(model, KerasModel)
|
||||
|
||||
# from ..patch_encoder import (
|
||||
# wrap_layout_model_patched,
|
||||
# wrap_layout_model_resized,
|
||||
# )
|
||||
# if resized:
|
||||
# model = wrap_layout_model_resized(model)
|
||||
# model._name = model_category + '_resized'
|
||||
# elif patched:
|
||||
# model = wrap_layout_model_patched(model)
|
||||
# model._name = model_category + '_patched'
|
||||
|
||||
if model_category == 'ocr':
|
||||
# cnn-rnn-ocr task model may not be in inference mode, yet
|
||||
model = cnn_rnn_ocr_model4inference(model, model_path)
|
||||
|
||||
model.make_predict_function()
|
||||
|
||||
return model
|
||||
|
||||
def get(self, model_category: str) -> Predictor:
|
||||
if model_category not in self._loaded:
|
||||
raise ValueError(f'Model "{model_category}" not previously loaded with "load_model(..)"')
|
||||
return self._loaded[model_category]
|
||||
def _load_serving_model(self, model_category, model_path, device=''):
|
||||
from ocrd_utils import tf_disable_interactive_logs
|
||||
tf_disable_interactive_logs()
|
||||
import tensorflow as tf
|
||||
|
||||
def _load_ocr_model(self, variant: str) -> AnyModel:
|
||||
self._configure_tf_device(model_category, device=device)
|
||||
model = tf.saved_model.load(model_path)
|
||||
model.predict_on_batch = model.serve
|
||||
model.input_shape = tuple(model.signatures.get('serving_default').inputs[0].shape)
|
||||
|
||||
return model
|
||||
|
||||
def _load_onnx_model(self, model_category, model_path, device=''):
|
||||
import onnxruntime as ort
|
||||
import numpy as np
|
||||
|
||||
providers = ort.get_available_providers()
|
||||
if device:
|
||||
if ':' in device:
|
||||
for spec in device.split(','):
|
||||
cat, dev = spec.split(':')
|
||||
if fnmatchcase(model_category, cat):
|
||||
device = dev
|
||||
break
|
||||
if device == 'CPU':
|
||||
gpu = -1
|
||||
else:
|
||||
assert device.startswith('GPU')
|
||||
gpu = int(device[3:] or "0")
|
||||
else:
|
||||
gpu = 0 # try first allowable
|
||||
# configure and prioritise
|
||||
if 'CUDAExecutionProvider' in providers:
|
||||
providers.remove('CUDAExecutionProvider')
|
||||
if gpu >= 0:
|
||||
providers = [('CUDAExecutionProvider', {
|
||||
'device_id': gpu,
|
||||
# 'arena_extend_strategy': 'kNextPowerOfTwo',
|
||||
'gpu_mem_limit': MODEL_VRAM_LIMITS[model_category] * 1024 * 1024,
|
||||
# 'cudnn_conv_algo_search': 'EXHAUSTIVE',
|
||||
# 'do_copy_in_default_stream': True,
|
||||
# ...
|
||||
})] + providers
|
||||
if 'TensorrtExecutionProvider' in providers:
|
||||
providers.remove('TensorrtExecutionProvider')
|
||||
if gpu >= 0:
|
||||
providers = [('TensorrtExecutionProvider', {
|
||||
'device_id': gpu,
|
||||
'trt_max_workspace_size': MODEL_VRAM_LIMITS[model_category] * 1024 * 1024,
|
||||
# 'trt_fp16_enable': True,
|
||||
# 'trt_engine_cache_enable': True,
|
||||
# 'trt_timing_cache_enable': True,
|
||||
# ...
|
||||
})] + providers
|
||||
model = ort.InferenceSession(
|
||||
model_path,
|
||||
providers=providers)
|
||||
# FIXME: notify about selected provider/device
|
||||
model_inputs = [model_input.name
|
||||
for model_input in model.get_inputs()]
|
||||
model_outputs = [model_output.name
|
||||
for model_output in model.get_outputs()]
|
||||
def predict_onnx(inputs):
|
||||
if len(model_inputs) == 1:
|
||||
inputs = [inputs]
|
||||
outputs = model.run(model_outputs, {
|
||||
model_input:
|
||||
input_data.astype(
|
||||
# models expect data_type() == 'tensor(float)', but np.float16 is 'tensor(float16)'
|
||||
# FIXME: do this dynamically (but how to convert .type to np.dtype?)
|
||||
np.float32 if input_data.dtype in [np.float16, np.float64] else
|
||||
input_data.dtype)
|
||||
for model_input, input_data in zip(model_inputs, inputs)
|
||||
})
|
||||
if len(model_outputs) == 1:
|
||||
outputs = outputs[0]
|
||||
return outputs
|
||||
model.predict_on_batch = predict_onnx
|
||||
model.input_shape = model.get_inputs()[0].shape
|
||||
|
||||
return model
|
||||
|
||||
def _load_trocr_model(self, model_path, device: str = "") -> AnyModel:
|
||||
"""
|
||||
Load OCR model
|
||||
"""
|
||||
from tensorflow.keras.models import Model as KerasModel
|
||||
from tensorflow.keras.models import load_model
|
||||
from transformers import VisionEncoderDecoderModel, TrOCRProcessor
|
||||
import numpy as np
|
||||
|
||||
ocr_model_dir = self.model_path('ocr', variant)
|
||||
if variant == 'tr':
|
||||
from transformers import VisionEncoderDecoderModel
|
||||
ret = VisionEncoderDecoderModel.from_pretrained(ocr_model_dir)
|
||||
assert isinstance(ret, VisionEncoderDecoderModel)
|
||||
return ret
|
||||
device = self._configure_torch_device('ocr', device=device)
|
||||
proc = TrOCRProcessor.from_pretrained(model_path)
|
||||
model = VisionEncoderDecoderModel.from_pretrained(model_path)
|
||||
assert isinstance(model, VisionEncoderDecoderModel)
|
||||
|
||||
model.to(device)
|
||||
def predict_torch(inputs):
|
||||
output = model.generate(
|
||||
proc(inputs, return_tensors="pt").pixel_values.to(device),
|
||||
# beam search instead of greedy decoding:
|
||||
num_beams=4,
|
||||
# also return probability
|
||||
output_scores=True,
|
||||
return_dict_in_generate=True)
|
||||
if output.sequences_scores is not None:
|
||||
# log-prob averaged over length
|
||||
conf = output.sequences_scores.exp().clamp(0.0, 1.0).cpu().numpy()
|
||||
else:
|
||||
ocr_model = load_model(ocr_model_dir, compile=False)
|
||||
assert isinstance(ocr_model, KerasModel)
|
||||
return KerasModel(
|
||||
ocr_model.get_layer(name="image").input, # type: ignore
|
||||
ocr_model.get_layer(name="dense2").output, # type: ignore
|
||||
)
|
||||
|
||||
def _load_characters(self) -> List[str]:
|
||||
"""
|
||||
Load encoding for OCR
|
||||
"""
|
||||
with open(self.model_path('num_to_char'), "r") as config_file:
|
||||
return json.load(config_file)
|
||||
|
||||
def _load_num_to_char(self) -> 'StringLookup':
|
||||
"""
|
||||
Load decoder for OCR
|
||||
"""
|
||||
from tensorflow.keras.layers import StringLookup
|
||||
|
||||
characters = self._load_characters()
|
||||
# Mapping characters to integers.
|
||||
char_to_num = StringLookup(vocabulary=characters, mask_token=None)
|
||||
# Mapping integers back to original characters.
|
||||
return StringLookup(vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True)
|
||||
conf = np.ones(len(output.sequences), dtype=float)
|
||||
text = proc.batch_decode(
|
||||
output.sequences,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=False)
|
||||
# we must convert to ndarray for Predictor resultq to work
|
||||
text = np.array(text)
|
||||
return text, conf
|
||||
model.predict_on_batch = predict_torch
|
||||
# not actually needed (image processor does resize itself)
|
||||
# no batch dimension (images passed as list w/ varying shapes)
|
||||
model.input_shape = (None,
|
||||
None,
|
||||
len(proc.image_processor.image_mean))
|
||||
return model
|
||||
|
||||
def __str__(self):
|
||||
return tabulate(
|
||||
|
|
@ -277,5 +428,5 @@ class EynollahModelZoo:
|
|||
"""
|
||||
if hasattr(self, '_loaded') and getattr(self, '_loaded'):
|
||||
for needle in list(self._loaded.keys()):
|
||||
if isinstance(self._loaded[needle], Predictor):
|
||||
self._loaded[needle].shutdown()
|
||||
del self._loaded[needle]
|
||||
|
|
|
|||
|
|
@ -1,10 +1,8 @@
|
|||
# NOTE: For predictable order of imports of torch/shapely/tensorflow
|
||||
# this must be the first import of the CLI!
|
||||
from .eynollah_imports import imported_libs
|
||||
from .processor import EynollahProcessor
|
||||
from click import command
|
||||
from ocrd.decorators import ocrd_cli_options, ocrd_cli_wrap_processor
|
||||
|
||||
from .processor import EynollahProcessor
|
||||
|
||||
@command()
|
||||
@ocrd_cli_options
|
||||
def main(*args, **kwargs):
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@ from tensorflow.keras import layers, models
|
|||
class PatchEncoder(layers.Layer):
|
||||
|
||||
# 441=21*21 # 14*14 # 28*28
|
||||
def __init__(self, num_patches=441, projection_dim=64):
|
||||
super().__init__()
|
||||
def __init__(self, num_patches=441, projection_dim=64, name='encode_patches'):
|
||||
super().__init__(name=name)
|
||||
self.num_patches = num_patches
|
||||
self.projection_dim = projection_dim
|
||||
self.projection = layers.Dense(self.projection_dim)
|
||||
|
|
@ -20,11 +20,12 @@ class PatchEncoder(layers.Layer):
|
|||
def get_config(self):
|
||||
return dict(num_patches=self.num_patches,
|
||||
projection_dim=self.projection_dim,
|
||||
position_embedding=self.position_embedding,
|
||||
**super().get_config())
|
||||
|
||||
class Patches(layers.Layer):
|
||||
def __init__(self, patch_size_x=1, patch_size_y=1):
|
||||
super().__init__()
|
||||
def __init__(self, patch_size_x=1, patch_size_y=1, name='extract_patches'):
|
||||
super().__init__(name=name)
|
||||
self.patch_size_x = patch_size_x
|
||||
self.patch_size_y = patch_size_y
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from contextlib import ExitStack
|
||||
from typing import List, Dict
|
||||
from typing import List, Dict, Tuple, Union
|
||||
import logging
|
||||
import logging.handlers
|
||||
import multiprocessing as mp
|
||||
|
|
@ -8,6 +8,7 @@ import numpy as np
|
|||
from .utils.shm import share_ndarray, ndarray_shared
|
||||
|
||||
QSIZE = 200
|
||||
ArrayT = Union[np.ndarray, Tuple[np.ndarray]]
|
||||
|
||||
|
||||
class Predictor(mp.context.SpawnProcess):
|
||||
|
|
@ -40,10 +41,10 @@ class Predictor(mp.context.SpawnProcess):
|
|||
def input_shape(self):
|
||||
return self({})
|
||||
|
||||
def predict(self, data: dict, verbose=0):
|
||||
def predict(self, data: ArrayT, verbose=0) -> ArrayT:
|
||||
return self(data)
|
||||
|
||||
def __call__(self, data: dict):
|
||||
def __call__(self, data: Union[ArrayT, Dict]) -> Union[ArrayT, Tuple]:
|
||||
# unusable as per python/cpython#79967
|
||||
#with self.jobid.get_lock():
|
||||
# would work, but not public:
|
||||
|
|
@ -55,7 +56,15 @@ class Predictor(mp.context.SpawnProcess):
|
|||
self.taskq.put((jobid, data))
|
||||
#self.logger.debug("sent shape query task '%d' for model '%s'", jobid, self.name)
|
||||
return self.result(jobid)
|
||||
with share_ndarray(data) as shared_data:
|
||||
with ExitStack() as stack:
|
||||
if isinstance(data, tuple):
|
||||
# multi-input
|
||||
shared_data = []
|
||||
for data0 in data:
|
||||
shared_data.append(stack.enter_context(share_ndarray(data0)))
|
||||
shared_data = tuple(shared_data)
|
||||
else:
|
||||
shared_data = stack.enter_context(share_ndarray(data))
|
||||
self.taskq.put((jobid, shared_data))
|
||||
#self.logger.debug("sent prediction task '%d' for model '%s': %s", jobid, self.name, shared_data)
|
||||
return self.result(jobid)
|
||||
|
|
@ -67,6 +76,14 @@ class Predictor(mp.context.SpawnProcess):
|
|||
result = self.results.pop(jobid)
|
||||
if isinstance(result, Exception):
|
||||
raise Exception(f"predictor {self.name} failed for {jobid}") from result
|
||||
elif isinstance(result, tuple) and isinstance(result[0], dict):
|
||||
# multi-output
|
||||
result1 = []
|
||||
for result0 in result:
|
||||
with ndarray_shared(result0) as shared_result0:
|
||||
result1.append(np.copy(shared_result0))
|
||||
result = result1
|
||||
self.closable.append(jobid)
|
||||
elif isinstance(result, dict):
|
||||
with ndarray_shared(result) as shared_result:
|
||||
result = np.copy(shared_result)
|
||||
|
|
@ -111,6 +128,8 @@ class Predictor(mp.context.SpawnProcess):
|
|||
"binarization": 4,
|
||||
"enhancement": 4,
|
||||
"reading_order": 4,
|
||||
"ocr": 8,
|
||||
"ocr_tr": 2,
|
||||
# medium size (672x672x3)...
|
||||
"textline": 2,
|
||||
# large models...
|
||||
|
|
@ -126,8 +145,20 @@ class Predictor(mp.context.SpawnProcess):
|
|||
self.resultq.put((jobid, result))
|
||||
#self.logger.debug("sent result for '%d': %s", jobid, result)
|
||||
else:
|
||||
tasks = [(jobid, shared_data)]
|
||||
if self.name == 'ocr_tr':
|
||||
# this model takes a list of (image) tensors
|
||||
# of heterogeneous shape as input,
|
||||
# resizing them internally;
|
||||
# so this looks like multi-input
|
||||
multi_input = True
|
||||
batch_size = len(shared_data)
|
||||
elif isinstance(shared_data, tuple):
|
||||
multi_input = True
|
||||
batch_size = shared_data[0]['shape'][0]
|
||||
else:
|
||||
multi_input = False
|
||||
batch_size = shared_data['shape'][0]
|
||||
tasks = [(jobid, shared_data)]
|
||||
while (not self.taskq.empty() and
|
||||
# climb to target batch size
|
||||
batch_size * len(tasks) < REBATCH_SIZE):
|
||||
|
|
@ -136,7 +167,7 @@ class Predictor(mp.context.SpawnProcess):
|
|||
# add to our batch
|
||||
tasks.append((jobid0, shared_data0))
|
||||
else:
|
||||
# immediately anser
|
||||
# immediately answer
|
||||
self.resultq.put((jobid0, self.model.input_shape))
|
||||
if len(tasks) > 1:
|
||||
self.logger.debug("rebatching %d '%s' tasks of batch size %d",
|
||||
|
|
@ -147,11 +178,25 @@ class Predictor(mp.context.SpawnProcess):
|
|||
for jobid, shared_data in tasks:
|
||||
#self.logger.debug("predicting '%d' with model '%s': %s", jobid, self.name, shared_data)
|
||||
jobs.append(jobid)
|
||||
if multi_input:
|
||||
data.append(tuple(stack.enter_context(ndarray_shared(shared_data0))
|
||||
for shared_data0 in shared_data))
|
||||
else:
|
||||
data.append(stack.enter_context(ndarray_shared(shared_data)))
|
||||
if multi_input:
|
||||
data = list(np.concatenate(data0)
|
||||
for data0 in zip(*data))
|
||||
else:
|
||||
data = np.concatenate(data)
|
||||
#result = self.model.predict(data, verbose=0)
|
||||
# faster, less VRAM
|
||||
result = self.model.predict_on_batch(data)
|
||||
if isinstance(result, tuple):
|
||||
multi_output = True
|
||||
results = zip(*(np.split(result0, len(jobs))
|
||||
for result0 in result))
|
||||
else:
|
||||
multi_output = False
|
||||
results = np.split(result, len(jobs))
|
||||
#self.logger.debug("sharing result array for '%d'", jobid)
|
||||
with ExitStack() as stack:
|
||||
|
|
@ -160,6 +205,10 @@ class Predictor(mp.context.SpawnProcess):
|
|||
# but don't want to wait either, so track closing
|
||||
# context per job, and wait for closable signal
|
||||
# from client
|
||||
if multi_output:
|
||||
result = tuple(stack.enter_context(share_ndarray(result0))
|
||||
for result0 in result)
|
||||
else:
|
||||
result = stack.enter_context(share_ndarray(result))
|
||||
closing[jobid] = stack.pop_all()
|
||||
self.resultq.put((jobid, result))
|
||||
|
|
@ -174,8 +223,11 @@ class Predictor(mp.context.SpawnProcess):
|
|||
def load_model(self, *load_args, **load_kwargs):
|
||||
assert len(load_args)
|
||||
self.name = '_'.join(list(load_args[:1]) +
|
||||
list(load_kwargs[key] for key in load_kwargs
|
||||
if key == 'model_variant') +
|
||||
list(key for key in load_kwargs
|
||||
if key != 'device'))
|
||||
if key in ['patched', 'resized']
|
||||
and load_kwargs[key]))
|
||||
self.load_args = load_args
|
||||
self.load_kwargs = load_kwargs
|
||||
self.start() # call run() in subprocess
|
||||
|
|
@ -194,17 +246,18 @@ class Predictor(mp.context.SpawnProcess):
|
|||
|
||||
def shutdown(self):
|
||||
# do not terminate from forked processor instances
|
||||
if mp.parent_process() is None:
|
||||
if not hasattr(self, 'model'):
|
||||
self.stopped.set()
|
||||
self.join()
|
||||
self.taskq.close()
|
||||
self.taskq.cancel_join_thread()
|
||||
self.resultq.close()
|
||||
self.resultq.cancel_join_thread()
|
||||
self.logq.close()
|
||||
self.terminate()
|
||||
#self.terminate()
|
||||
else:
|
||||
del self.model
|
||||
|
||||
def __del__(self):
|
||||
#self.logger.debug(f"deinit of {self} in {mp.current_process().name}")
|
||||
#self.logger.debug(f"deinit of {self.name} in {mp.current_process().name}")
|
||||
self.shutdown()
|
||||
|
|
|
|||
|
|
@ -7,16 +7,15 @@ import sys
|
|||
from .build_model_load_pretrained_weights_and_save import build_model_load_pretrained_weights_and_save
|
||||
from .generate_gt_for_training import main as generate_gt_cli
|
||||
from .inference import main as inference_cli
|
||||
from .train import ex
|
||||
from .train import train_cli
|
||||
from .convert import convert_cli
|
||||
from .extract_line_gt import linegt_cli
|
||||
<<<<<<< HEAD
|
||||
from .weights_ensembling import ensemble_cli
|
||||
|
||||
@click.command(context_settings=dict(
|
||||
ignore_unknown_options=True,
|
||||
))
|
||||
@click.argument('SACRED_ARGS', nargs=-1, type=click.UNPROCESSED)
|
||||
def train_cli(sacred_args):
|
||||
ex.run_commandline([sys.argv[0]] + list(sacred_args))
|
||||
=======
|
||||
from .weights_ensembling import main as ensemble_cli
|
||||
from .generate_or_update_cnn_rnn_ocr_character_list import main as update_ocr_characters_cli
|
||||
>>>>>>> integrating_trocr_and_torch_ensembling_and_updating_characters_list
|
||||
|
||||
@click.group('training')
|
||||
def main():
|
||||
|
|
@ -26,5 +25,7 @@ main.add_command(build_model_load_pretrained_weights_and_save)
|
|||
main.add_command(generate_gt_cli, 'generate-gt')
|
||||
main.add_command(inference_cli, 'inference')
|
||||
main.add_command(train_cli, 'train')
|
||||
main.add_command(convert_cli, 'convert')
|
||||
main.add_command(linegt_cli, 'export_textline_images_and_text')
|
||||
main.add_command(ensemble_cli, 'ensembling')
|
||||
main.add_command(update_ocr_characters_cli, 'generate_or_update_cnn_rnn_ocr_character_list')
|
||||
|
|
|
|||
103
src/eynollah/training/convert.py
Normal file
103
src/eynollah/training/convert.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
from shutil import copy2
|
||||
import logging
|
||||
import json
|
||||
|
||||
import click
|
||||
|
||||
@click.command(context_settings=dict(
|
||||
help_option_names=['-h', '--help'],
|
||||
show_default=True))
|
||||
@click.option(
|
||||
"--rebuild",
|
||||
"-r",
|
||||
help="build new model from code and then load existing weights (requires input in SavedModel directory format with config.json present)",
|
||||
is_flag=True
|
||||
)
|
||||
@click.option(
|
||||
"--format",
|
||||
"-f",
|
||||
"format_",
|
||||
help="data format to convert to",
|
||||
type=click.Choice(["hdf5", "keras", "tf", "tf-serving", "onnx"]),
|
||||
default="tf"
|
||||
)
|
||||
@click.option(
|
||||
"--in",
|
||||
"-i",
|
||||
"in_",
|
||||
help="path to input model (file in hdf5 / keras format, or directory in tf format)",
|
||||
required=True,
|
||||
type=click.Path(exists=True, dir_okay=True)
|
||||
)
|
||||
@click.option(
|
||||
"--out",
|
||||
"-o",
|
||||
help="path to output model (file in hdf5 / keras / onnx format, or directory in tf / tf-serving format)",
|
||||
required=True,
|
||||
type=click.Path(exists=False, dir_okay=True)
|
||||
)
|
||||
def convert_cli(rebuild, format_, in_, out):
|
||||
"""
|
||||
convert models for inference
|
||||
|
||||
Load model from path, optionally by rebuilding, convert to output format and write model to path.
|
||||
"""
|
||||
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
||||
from ocrd_utils import tf_disable_interactive_logs
|
||||
tf_disable_interactive_logs()
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras.models import load_model
|
||||
from tensorflow.keras.models import Model as KerasModel
|
||||
|
||||
model_path = Path(in_)
|
||||
config_path = model_path / "config.json"
|
||||
if model_path.is_dir():
|
||||
assert (model_path / "keras_metadata.pb").exists(), (
|
||||
"input directory must be Keras model in SavedModel format")
|
||||
if rebuild:
|
||||
from .train import ex
|
||||
from .models import get_model
|
||||
|
||||
assert config_path.exists(), (
|
||||
"rebuilding requires input model in SavedModel format with config.json")
|
||||
|
||||
# merge defaults with existing config file
|
||||
ex.add_config(str(config_path))
|
||||
# some models deviate between training and inference
|
||||
ex.add_config(inference=True)
|
||||
# just retrieve final config (via pseudo-run)
|
||||
ex.main(lambda: 0)
|
||||
config = ex.run(options={'--loglevel': 'ERROR'}).config
|
||||
# use the config to capture the model builder
|
||||
model = get_model(config, logging.root)
|
||||
model.load_weights(model_path).assert_existing_objects_matched().expect_partial()
|
||||
else:
|
||||
from .models import cnn_rnn_ocr_model4inference
|
||||
|
||||
model = load_model(model_path, compile=False)
|
||||
|
||||
if isinstance(model, KerasModel):
|
||||
# cnn-rnn-ocr task deviates between training and inference
|
||||
model = cnn_rnn_ocr_model4inference(model, model_path)
|
||||
|
||||
if format_ in ["hdf5", "keras", "tf"]:
|
||||
kwargs = {"save_format": {"hdf5": "h5"}.get(format_, format_)}
|
||||
if format_ != "keras":
|
||||
kwargs["include_optimizer"] = False
|
||||
model.save(out, **kwargs)
|
||||
elif format_ == "tf-serving":
|
||||
model.export(out)
|
||||
elif format_ == "onnx":
|
||||
import tf2onnx
|
||||
tf2onnx.convert.from_keras(model, opset=18, output_path=out)
|
||||
else:
|
||||
raise ValueError("unknown output format '%s'" % format_)
|
||||
|
||||
# copy config.json if possible
|
||||
if config_path.exists() and format_ in ['tf', 'tf-serving']:
|
||||
copy2(config_path, Path(out) / config_path.name)
|
||||
|
||||
|
||||
|
|
@ -50,6 +50,12 @@ from ..utils import is_image_filename
|
|||
is_flag=True,
|
||||
help="if this parameter set to true, cropped textline images will not be masked with textline contour.",
|
||||
)
|
||||
@click.option(
|
||||
"--exclude_vertical_lines",
|
||||
"-exv",
|
||||
is_flag=True,
|
||||
help="if this parameter set to true, vertical textline images will be excluded.",
|
||||
)
|
||||
def linegt_cli(
|
||||
image,
|
||||
dir_in,
|
||||
|
|
@ -57,6 +63,7 @@ def linegt_cli(
|
|||
dir_out,
|
||||
pref_of_dataset,
|
||||
do_not_mask_with_textline_contour,
|
||||
exclude_vertical_lines,
|
||||
):
|
||||
assert bool(dir_in) ^ bool(image), "Set --dir-in or --image-filename, not both"
|
||||
if dir_in:
|
||||
|
|
@ -70,14 +77,13 @@ def linegt_cli(
|
|||
for dir_img in ls_imgs:
|
||||
file_name = Path(dir_img).stem
|
||||
dir_xml = os.path.join(dir_xmls, file_name + '.xml')
|
||||
|
||||
img = cv2.imread(dir_img)
|
||||
|
||||
total_bb_coordinates = []
|
||||
|
||||
tree1 = ET.parse(dir_xml, parser=ET.XMLParser(encoding="utf-8"))
|
||||
root1 = tree1.getroot()
|
||||
alltags = [elem.tag for elem in root1.iter()]
|
||||
tree = ET.parse(dir_xml, parser=ET.XMLParser(encoding="utf-8"))
|
||||
root = tree.getroot()
|
||||
alltags = [elem.tag for elem in root.iter()]
|
||||
|
||||
name_space = alltags[0].split('}')[0]
|
||||
name_space = name_space.split('{')[1]
|
||||
|
|
@ -89,7 +95,7 @@ def linegt_cli(
|
|||
indexer_text_region = 0
|
||||
indexer_textlines = 0
|
||||
# FIXME: non recursive, use OCR-D PAGE generateDS API. Or use an existing tool for this purpose altogether
|
||||
for nn in root1.iter(region_tags):
|
||||
for nn in root.iter(region_tags):
|
||||
for child_textregion in nn:
|
||||
if child_textregion.tag.endswith("TextLine"):
|
||||
for child_textlines in child_textregion:
|
||||
|
|
@ -100,6 +106,10 @@ def linegt_cli(
|
|||
|
||||
x, y, w, h = cv2.boundingRect(textline_coords)
|
||||
|
||||
if exclude_vertical_lines and h > 1.4 * w:
|
||||
img_crop = None
|
||||
continue
|
||||
|
||||
total_bb_coordinates.append([x, y, w, h])
|
||||
|
||||
img_poly_on_img = np.copy(img)
|
||||
|
|
@ -114,12 +124,15 @@ def linegt_cli(
|
|||
img_crop[mask_poly == 0] = 255
|
||||
|
||||
if img_crop.shape[0] == 0 or img_crop.shape[1] == 0:
|
||||
img_crop = None
|
||||
continue
|
||||
|
||||
|
||||
if child_textlines.tag.endswith("TextEquiv"):
|
||||
for cheild_text in child_textlines:
|
||||
if cheild_text.tag.endswith("Unicode"):
|
||||
textline_text = cheild_text.text
|
||||
if textline_text:
|
||||
if textline_text and img_crop is not None:
|
||||
base_name = os.path.join(
|
||||
dir_out, file_name + '_line_' + str(indexer_textlines)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from pathlib import Path
|
|||
from PIL import Image, ImageDraw, ImageFont
|
||||
import cv2
|
||||
import numpy as np
|
||||
from eynollah.utils.font import get_font
|
||||
|
||||
from .gt_gen_utils import (
|
||||
filter_contours_area_of_image,
|
||||
|
|
@ -393,11 +394,15 @@ def visualize_reading_order(xml_file, dir_xml, dir_out, dir_imgs):
|
|||
layout = np.zeros( (y_len,x_len,3) )
|
||||
layout = cv2.fillPoly(layout, pts =co_text_all, color=(1,1,1))
|
||||
|
||||
try:
|
||||
img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name)
|
||||
img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format))
|
||||
|
||||
overlayed = overlay_layout_on_image(layout, img, cx_ordered, cy_ordered, color, thickness)
|
||||
cv2.imwrite(os.path.join(dir_out, f_name+'.png'), overlayed)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
else:
|
||||
img = np.zeros( (y_len,x_len,3) )
|
||||
|
|
@ -452,6 +457,7 @@ def visualize_textline_segmentation(xml_file, dir_xml, dir_out, dir_imgs):
|
|||
xml_file = os.path.join(dir_xml,ind_xml )
|
||||
f_name = Path(ind_xml).stem
|
||||
|
||||
try:
|
||||
img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name)
|
||||
img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format))
|
||||
|
||||
|
|
@ -460,6 +466,8 @@ def visualize_textline_segmentation(xml_file, dir_xml, dir_out, dir_imgs):
|
|||
added_image = visualize_image_from_contours(co_tetxlines, img)
|
||||
|
||||
cv2.imwrite(os.path.join(dir_out, f_name+'.png'), added_image)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
|
@ -509,15 +517,17 @@ def visualize_layout_segmentation(xml_file, dir_xml, dir_out, dir_imgs):
|
|||
f_name = Path(ind_xml).stem
|
||||
print(f_name, 'f_name')
|
||||
|
||||
try:
|
||||
img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name)
|
||||
img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format))
|
||||
|
||||
co_text, co_graphic, co_sep, co_img, co_table, co_map, co_noise, y_len, x_len = get_layout_contours_for_visualization(xml_file)
|
||||
co_text, co_graphic, co_sep, co_img, co_table, co_map, co_music, co_noise, y_len, x_len = get_layout_contours_for_visualization(xml_file)
|
||||
|
||||
|
||||
added_image = visualize_image_from_contours_layout(co_text['paragraph'], co_text['header']+co_text['heading'], co_text['drop-capital'], co_sep, co_img, co_text['marginalia'], co_table, img)
|
||||
added_image = visualize_image_from_contours_layout(co_text['paragraph'], co_text['header']+co_text['heading'], co_text['drop-capital'], co_sep, co_img, co_text['marginalia'], co_table, co_map, co_music, img)
|
||||
|
||||
cv2.imwrite(os.path.join(dir_out, f_name+'.png'), added_image)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
|
@ -552,8 +562,8 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out):
|
|||
else:
|
||||
xml_files_ind = [xml_file]
|
||||
|
||||
font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
|
||||
font = ImageFont.truetype(font_path, 40)
|
||||
###font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
|
||||
font = get_font(font_size=40)#ImageFont.truetype(font_path, 40)
|
||||
|
||||
for ind_xml in tqdm(xml_files_ind):
|
||||
indexer = 0
|
||||
|
|
@ -590,11 +600,11 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out):
|
|||
|
||||
|
||||
is_vertical = h > 2*w # Check orientation
|
||||
font = fit_text_single_line(draw, ocr_texts[index], font_path, w, int(h*0.4) )
|
||||
font = fit_text_single_line(draw, ocr_texts[index], w, int(h*0.4) )
|
||||
|
||||
if is_vertical:
|
||||
|
||||
vertical_font = fit_text_single_line(draw, ocr_texts[index], font_path, h, int(w * 0.8))
|
||||
vertical_font = fit_text_single_line(draw, ocr_texts[index], h, int(w * 0.8))
|
||||
|
||||
text_img = Image.new("RGBA", (h, w), (255, 255, 255, 0)) # Note: dimensions are swapped
|
||||
text_draw = ImageDraw.Draw(text_img)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,59 @@
|
|||
import os
|
||||
import numpy as np
|
||||
import json
|
||||
import click
|
||||
import logging
|
||||
|
||||
|
||||
|
||||
def run_character_list_update(dir_labels, out, current_character_list):
|
||||
ls_labels = os.listdir(dir_labels)
|
||||
ls_labels = [ind for ind in ls_labels if ind.endswith('.txt')]
|
||||
|
||||
if current_character_list:
|
||||
with open(current_character_list, 'r') as f_name:
|
||||
characters = json.load(f_name)
|
||||
|
||||
characters = set(characters)
|
||||
else:
|
||||
characters = set()
|
||||
|
||||
|
||||
for ind in ls_labels:
|
||||
label = open(os.path.join(dir_labels,ind),'r').read().split('\n')[0]
|
||||
|
||||
for char in label:
|
||||
characters.add(char)
|
||||
|
||||
|
||||
characters = sorted(list(set(characters)))
|
||||
|
||||
with open(out, 'w') as f_name:
|
||||
json.dump(characters, f_name)
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option(
|
||||
"--dir_labels",
|
||||
"-dl",
|
||||
help="directory of labels which are .txt files",
|
||||
type=click.Path(exists=True, file_okay=False),
|
||||
required=True,
|
||||
)
|
||||
@click.option(
|
||||
"--current_character_list",
|
||||
"-ccl",
|
||||
help="existing character list in a .txt file that needs to be updated with a set of labels",
|
||||
type=click.Path(exists=True, file_okay=True),
|
||||
required=False,
|
||||
)
|
||||
@click.option(
|
||||
"--out",
|
||||
"-o",
|
||||
help="An output .txt file where the generated or updated character list will be written",
|
||||
type=click.Path(exists=False, file_okay=True),
|
||||
)
|
||||
|
||||
def main(dir_labels, out, current_character_list):
|
||||
run_character_list_update(dir_labels, out, current_character_list)
|
||||
|
||||
|
|
@ -8,7 +8,7 @@ from shapely import geometry
|
|||
from pathlib import Path
|
||||
from PIL import ImageFont
|
||||
from ocrd_utils import bbox_from_points
|
||||
|
||||
from eynollah.utils.font import get_font
|
||||
|
||||
KERNEL = np.ones((5, 5), np.uint8)
|
||||
NS = { 'pc': 'http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15'
|
||||
|
|
@ -18,7 +18,7 @@ with warnings.catch_warnings():
|
|||
warnings.simplefilter("ignore")
|
||||
|
||||
|
||||
def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_image, co_marginal, co_table, co_map, img):
|
||||
def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_image, co_marginal, co_table, co_map, co_music, img):
|
||||
alpha = 0.5
|
||||
|
||||
blank_image = np.ones( (img.shape[:]), dtype=np.uint8) * 255
|
||||
|
|
@ -32,6 +32,7 @@ def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_
|
|||
col_marginal = (106, 90, 205)
|
||||
col_table = (0, 90, 205)
|
||||
col_map = (90, 90, 205)
|
||||
col_music = (90, 90, 0)
|
||||
|
||||
if len(co_image)>0:
|
||||
cv2.drawContours(blank_image, co_image, -1, col_image, thickness=cv2.FILLED) # Fill the contour
|
||||
|
|
@ -60,6 +61,9 @@ def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_
|
|||
if len(co_map)>0:
|
||||
cv2.drawContours(blank_image, co_map, -1, col_map, thickness=cv2.FILLED) # Fill the contour
|
||||
|
||||
if len(co_music)>0:
|
||||
cv2.drawContours(blank_image, co_music, -1, col_music, thickness=cv2.FILLED) # Fill the contour
|
||||
|
||||
img_final =cv2.cvtColor(blank_image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
added_image = cv2.addWeighted(img,alpha,img_final,1- alpha,0)
|
||||
|
|
@ -352,11 +356,11 @@ def get_textline_contours_and_ocr_text(xml_file):
|
|||
ocr_textlines.append(ocr_text_in[0])
|
||||
return co_use_case, y_len, x_len, ocr_textlines
|
||||
|
||||
def fit_text_single_line(draw, text, font_path, max_width, max_height):
|
||||
def fit_text_single_line(draw, text, max_width, max_height):
|
||||
initial_font_size = 50
|
||||
font_size = initial_font_size
|
||||
while font_size > 10: # Minimum font size
|
||||
font = ImageFont.truetype(font_path, font_size)
|
||||
font = get_font(font_size=font_size)# ImageFont.truetype(font_path, font_size)
|
||||
text_bbox = draw.textbbox((0, 0), text, font=font) # Get text bounding box
|
||||
text_width = text_bbox[2] - text_bbox[0]
|
||||
text_height = text_bbox[3] - text_bbox[1]
|
||||
|
|
@ -366,7 +370,7 @@ def fit_text_single_line(draw, text, font_path, max_width, max_height):
|
|||
|
||||
font_size -= 2 # Reduce font size and retry
|
||||
|
||||
return ImageFont.truetype(font_path, 10) # Smallest font fallback
|
||||
return get_font(font_size=10)#ImageFont.truetype(font_path, 10) # Smallest font fallback
|
||||
|
||||
def get_layout_contours_for_visualization(xml_file):
|
||||
tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8'))
|
||||
|
|
@ -389,6 +393,7 @@ def get_layout_contours_for_visualization(xml_file):
|
|||
co_img=[]
|
||||
co_table=[]
|
||||
co_map=[]
|
||||
co_music=[]
|
||||
co_noise=[]
|
||||
|
||||
types_text = []
|
||||
|
|
@ -631,6 +636,31 @@ def get_layout_contours_for_visualization(xml_file):
|
|||
break
|
||||
co_map.append(np.array(c_t_in))
|
||||
|
||||
if tag.endswith('}MusicRegion') or tag.endswith('}musicregion'):
|
||||
#print('sth')
|
||||
for nn in root1.iter(tag):
|
||||
c_t_in=[]
|
||||
sumi=0
|
||||
for vv in nn.iter():
|
||||
# check the format of coords
|
||||
if vv.tag==link+'Coords':
|
||||
coords=bool(vv.attrib)
|
||||
if coords:
|
||||
p_h=vv.attrib['points'].split(' ')
|
||||
c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
|
||||
break
|
||||
else:
|
||||
pass
|
||||
|
||||
|
||||
if vv.tag==link+'Point':
|
||||
c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
|
||||
sumi+=1
|
||||
#print(vv.tag,'in')
|
||||
elif vv.tag!=link+'Point' and sumi>=1:
|
||||
break
|
||||
co_music.append(np.array(c_t_in))
|
||||
|
||||
|
||||
if tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'):
|
||||
#print('sth')
|
||||
|
|
@ -656,7 +686,7 @@ def get_layout_contours_for_visualization(xml_file):
|
|||
elif vv.tag!=link+'Point' and sumi>=1:
|
||||
break
|
||||
co_noise.append(np.array(c_t_in))
|
||||
return co_text, co_graphic, co_sep, co_img, co_table, co_map, co_noise, y_len, x_len
|
||||
return co_text, co_graphic, co_sep, co_img, co_table, co_map, co_music, co_noise, y_len, x_len
|
||||
|
||||
def get_images_of_ground_truth(
|
||||
gt_list,
|
||||
|
|
@ -682,8 +712,9 @@ def get_images_of_ground_truth(
|
|||
if not item.endswith('.xml')}
|
||||
|
||||
for index in tqdm(range(len(gt_list))):
|
||||
#try:
|
||||
print(gt_list[index])
|
||||
|
||||
try:
|
||||
tree1 = ET.parse(dir_in+'/'+gt_list[index], parser = ET.XMLParser(encoding='utf-8'))
|
||||
root1=tree1.getroot()
|
||||
alltags=[elem.tag for elem in root1.iter()]
|
||||
|
|
@ -697,7 +728,6 @@ def get_images_of_ground_truth(
|
|||
|
||||
if 'columns_width' in list(config_params.keys()):
|
||||
columns_width_dict = config_params['columns_width']
|
||||
# FIXME: look in /Page/@custom as well
|
||||
metadata_element = root1.find(link+'Metadata')
|
||||
num_col = None
|
||||
for child in metadata_element:
|
||||
|
|
@ -711,27 +741,55 @@ def get_images_of_ground_truth(
|
|||
y_new = int ( x_new * (y_len / float(x_len)) )
|
||||
|
||||
if printspace or "printspace_as_class_in_layout" in list(config_params.keys()):
|
||||
ps = (root1.xpath('/pc:PcGts/pc:Page/pc:Border', namespaces=NS) +
|
||||
root1.xpath('/pc:PcGts/pc:Page/pc:PrintSpace', namespaces=NS))
|
||||
coords = root1.xpath('//pc:Coords/@points', namespaces=NS)
|
||||
if len(ps):
|
||||
points = ps[0].find('pc:Coords', NS).get('points')
|
||||
ps_bbox = bbox_from_points(points)
|
||||
elif missing_printspace == 'skip':
|
||||
print(gt_list[index], "has no Border or PrintSpace - skipping file")
|
||||
continue
|
||||
elif missing_printspace == 'project' and len(coords):
|
||||
print(gt_list[index], "has no Border or PrintSpace - projecting hull of segments")
|
||||
bboxes = list(map(bbox_from_points, coords))
|
||||
left, top, right, bottom = zip(*bboxes)
|
||||
left = max(0, min(left) - 5)
|
||||
top = max(0, min(top) - 5)
|
||||
right = min(x_len, max(right) + 5)
|
||||
bottom = min(y_len, max(bottom) + 5)
|
||||
ps_bbox = [left, top, right, bottom]
|
||||
region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace') or x.endswith('Border')])
|
||||
co_use_case = []
|
||||
|
||||
for tag in region_tags:
|
||||
tag_endings = ['}PrintSpace','}Border']
|
||||
|
||||
if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]):
|
||||
for nn in root1.iter(tag):
|
||||
c_t_in = []
|
||||
sumi = 0
|
||||
for vv in nn.iter():
|
||||
# check the format of coords
|
||||
if vv.tag == link + 'Coords':
|
||||
coords = bool(vv.attrib)
|
||||
if coords:
|
||||
p_h = vv.attrib['points'].split(' ')
|
||||
c_t_in.append(
|
||||
np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h]))
|
||||
break
|
||||
else:
|
||||
print(gt_list[index], "has no Border or PrintSpace - using full page")
|
||||
ps_bbox = [0, 0, None, None]
|
||||
pass
|
||||
|
||||
if vv.tag == link + 'Point':
|
||||
c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))])
|
||||
sumi += 1
|
||||
elif vv.tag != link + 'Point' and sumi >= 1:
|
||||
break
|
||||
co_use_case.append(np.array(c_t_in))
|
||||
|
||||
img = np.zeros((y_len, x_len, 3))
|
||||
|
||||
img_poly = cv2.fillPoly(img, pts=co_use_case, color=(1, 1, 1))
|
||||
|
||||
img_poly = img_poly.astype(np.uint8)
|
||||
|
||||
imgray = cv2.cvtColor(img_poly, cv2.COLOR_BGR2GRAY)
|
||||
_, thresh = cv2.threshold(imgray, 0, 255, 0)
|
||||
|
||||
contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
cnt_size = np.array([cv2.contourArea(contours[j]) for j in range(len(contours))])
|
||||
|
||||
try:
|
||||
cnt = contours[np.argmax(cnt_size)]
|
||||
x, y, w, h = cv2.boundingRect(cnt)
|
||||
except:
|
||||
x, y , w, h = 0, 0, x_len, y_len
|
||||
|
||||
bb_xywh = [x, y, w, h]
|
||||
|
||||
|
||||
if config_file and (config_params['use_case']=='textline' or config_params['use_case']=='word' or config_params['use_case']=='glyph' or config_params['use_case']=='printspace'):
|
||||
|
|
@ -813,8 +871,7 @@ def get_images_of_ground_truth(
|
|||
|
||||
|
||||
if printspace and config_params['use_case']!='printspace':
|
||||
img_poly = img_poly[ps_bbox[1]:ps_bbox[3],
|
||||
ps_bbox[0]:ps_bbox[2], :]
|
||||
img_poly = img_poly[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :]
|
||||
|
||||
|
||||
if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace':
|
||||
|
|
@ -828,24 +885,19 @@ def get_images_of_ground_truth(
|
|||
cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly)
|
||||
|
||||
if dir_images:
|
||||
org_image_name = ls_org_imgs[xml_file_stem]
|
||||
if not org_image_name:
|
||||
print("image file for XML stem", xml_file_stem, "is missing")
|
||||
continue
|
||||
if not os.path.isfile(os.path.join(dir_images, org_image_name)):
|
||||
print("image file for XML stem", xml_file_stem, "is not readable")
|
||||
continue
|
||||
org_image_name = ls_org_imgs[ls_org_imgs_stem.index(xml_file_stem)]
|
||||
img_org = cv2.imread(os.path.join(dir_images, org_image_name))
|
||||
|
||||
if printspace and config_params['use_case']!='printspace':
|
||||
img_org = img_org[ps_bbox[1]:ps_bbox[3],
|
||||
ps_bbox[0]:ps_bbox[2], :]
|
||||
img_org = img_org[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :]
|
||||
|
||||
if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace':
|
||||
img_org = resize_image(img_org, y_new, x_new)
|
||||
|
||||
cv2.imwrite(os.path.join(dir_out_images, org_image_name), img_org)
|
||||
|
||||
except:
|
||||
pass
|
||||
|
||||
if config_file and config_params['use_case']=='layout':
|
||||
keys = list(config_params.keys())
|
||||
|
|
@ -870,7 +922,7 @@ def get_images_of_ground_truth(
|
|||
types_graphic_label = list(types_graphic_dict.values())
|
||||
|
||||
|
||||
labels_rgb_color = [ (0,0,0), (255,0,0), (255,125,0), (255,0,125), (125,255,125), (125,125,0), (0,125,255), (0,125,0), (125,125,125), (255,0,255), (125,0,125), (0,255,0),(0,0,255), (0,255,255), (255,125,125), (0,125,125), (0,255,125), (255,125,255), (125,255,0), (125,255,255)]
|
||||
labels_rgb_color = [ (0,0,0), (255,0,0), (255,125,0), (255,0,125), (125,255,125), (125,125,0), (0,125,255), (0,125,0), (125,125,125), (255,0,255), (125,0,125), (0,255,0),(0,0,255), (0,255,255), (255,125,125), (0,125,125), (0,255,125), (255,125,255), (125,255,0), (125,255,255), (125,125,255)]
|
||||
|
||||
|
||||
region_tags=np.unique([x for x in alltags if x.endswith('Region')])
|
||||
|
|
@ -882,6 +934,7 @@ def get_images_of_ground_truth(
|
|||
co_img=[]
|
||||
co_table=[]
|
||||
co_map=[]
|
||||
co_music=[]
|
||||
co_noise=[]
|
||||
|
||||
for tag in region_tags:
|
||||
|
|
@ -966,7 +1019,7 @@ def get_images_of_ground_truth(
|
|||
if "rest_as_decoration" in types_graphic:
|
||||
types_graphic_without_decoration = [element for element in types_graphic if element!='rest_as_decoration' and element!='decoration']
|
||||
if len(types_graphic_without_decoration) == 0:
|
||||
if "type" in nn.attrib:
|
||||
#if "type" in nn.attrib:
|
||||
c_t_in_graphic['decoration'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
|
||||
elif len(types_graphic_without_decoration) >= 1:
|
||||
if "type" in nn.attrib:
|
||||
|
|
@ -974,12 +1027,14 @@ def get_images_of_ground_truth(
|
|||
c_t_in_graphic[nn.attrib['type']].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
|
||||
else:
|
||||
c_t_in_graphic['decoration'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
|
||||
|
||||
else:
|
||||
c_t_in_graphic['decoration'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
|
||||
else:
|
||||
if "type" in nn.attrib:
|
||||
if nn.attrib['type'] in all_defined_graphic_types:
|
||||
c_t_in_graphic[nn.attrib['type']].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
|
||||
|
||||
|
||||
break
|
||||
else:
|
||||
pass
|
||||
|
|
@ -989,7 +1044,7 @@ def get_images_of_ground_truth(
|
|||
if "rest_as_decoration" in types_graphic:
|
||||
types_graphic_without_decoration = [element for element in types_graphic if element!='rest_as_decoration' and element!='decoration']
|
||||
if len(types_graphic_without_decoration) == 0:
|
||||
if "type" in nn.attrib:
|
||||
#if "type" in nn.attrib:
|
||||
c_t_in_graphic['decoration'].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] )
|
||||
sumi+=1
|
||||
elif len(types_graphic_without_decoration) >= 1:
|
||||
|
|
@ -1000,6 +1055,9 @@ def get_images_of_ground_truth(
|
|||
else:
|
||||
c_t_in_graphic['decoration'].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] )
|
||||
sumi+=1
|
||||
else:
|
||||
c_t_in_graphic['decoration'].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] )
|
||||
sumi+=1
|
||||
|
||||
else:
|
||||
if "type" in nn.attrib:
|
||||
|
|
@ -1119,6 +1177,32 @@ def get_images_of_ground_truth(
|
|||
break
|
||||
co_map.append(np.array(c_t_in))
|
||||
|
||||
if 'musicregion' in keys:
|
||||
if tag.endswith('}MusicRegion') or tag.endswith('}musicregion'):
|
||||
#print('sth')
|
||||
for nn in root1.iter(tag):
|
||||
c_t_in=[]
|
||||
sumi=0
|
||||
for vv in nn.iter():
|
||||
# check the format of coords
|
||||
if vv.tag==link+'Coords':
|
||||
coords=bool(vv.attrib)
|
||||
if coords:
|
||||
p_h=vv.attrib['points'].split(' ')
|
||||
c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
|
||||
break
|
||||
else:
|
||||
pass
|
||||
|
||||
|
||||
if vv.tag==link+'Point':
|
||||
c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
|
||||
sumi+=1
|
||||
#print(vv.tag,'in')
|
||||
elif vv.tag!=link+'Point' and sumi>=1:
|
||||
break
|
||||
co_music.append(np.array(c_t_in))
|
||||
|
||||
if 'noiseregion' in keys:
|
||||
if tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'):
|
||||
#print('sth')
|
||||
|
|
@ -1195,6 +1279,10 @@ def get_images_of_ground_truth(
|
|||
erosion_rate = 0#2
|
||||
dilation_rate = 3#4
|
||||
co_map, img_boundary = update_region_contours(co_map, img_boundary, erosion_rate, dilation_rate, y_len, x_len )
|
||||
if "musicregion" in elements_with_artificial_class:
|
||||
erosion_rate = 0#2
|
||||
dilation_rate = 3#4
|
||||
co_music, img_boundary = update_region_contours(co_music, img_boundary, erosion_rate, dilation_rate, y_len, x_len )
|
||||
|
||||
|
||||
|
||||
|
|
@ -1222,6 +1310,8 @@ def get_images_of_ground_truth(
|
|||
img_poly=cv2.fillPoly(img, pts =co_table, color=labels_rgb_color[ config_params['tableregion']])
|
||||
if 'mapregion' in keys:
|
||||
img_poly=cv2.fillPoly(img, pts =co_map, color=labels_rgb_color[ config_params['mapregion']])
|
||||
if 'musicregion' in keys:
|
||||
img_poly=cv2.fillPoly(img, pts =co_music, color=labels_rgb_color[ config_params['musicregion']])
|
||||
if 'noiseregion' in keys:
|
||||
img_poly=cv2.fillPoly(img, pts =co_noise, color=labels_rgb_color[ config_params['noiseregion']])
|
||||
|
||||
|
|
@ -1286,6 +1376,9 @@ def get_images_of_ground_truth(
|
|||
if 'mapregion' in keys:
|
||||
color_label = config_params['mapregion']
|
||||
img_poly=cv2.fillPoly(img, pts =co_map, color=(color_label,color_label,color_label))
|
||||
if 'musicregion' in keys:
|
||||
color_label = config_params['musicregion']
|
||||
img_poly=cv2.fillPoly(img, pts =co_music, color=(color_label,color_label,color_label))
|
||||
if 'noiseregion' in keys:
|
||||
color_label = config_params['noiseregion']
|
||||
img_poly=cv2.fillPoly(img, pts =co_noise, color=(color_label,color_label,color_label))
|
||||
|
|
|
|||
|
|
@ -9,9 +9,11 @@ import warnings
|
|||
import json
|
||||
|
||||
import click
|
||||
|
||||
import numpy as np
|
||||
from numpy._typing import NDArray
|
||||
import cv2
|
||||
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
||||
|
|
@ -119,8 +121,33 @@ class SBBPredict:
|
|||
return mIoU
|
||||
|
||||
def start_new_session_and_model(self):
|
||||
if self.task == "cnn-rnn-ocr":
|
||||
if self.cpu:
|
||||
tf.config.set_visible_devices([], 'GPU')
|
||||
os.environ['CUDA_VISIBLE_DEVICES']='-1'
|
||||
self.model = load_model(self.model_dir)
|
||||
self.model = tf.keras.models.Model(
|
||||
self.model.get_layer(name = "image").input,
|
||||
self.model.get_layer(name = "dense2").output)
|
||||
|
||||
assert isinstance(self.model, Model)
|
||||
|
||||
elif self.task == "transformer-ocr":
|
||||
import torch
|
||||
from transformers import VisionEncoderDecoderModel
|
||||
from transformers import TrOCRProcessor
|
||||
|
||||
self.model = VisionEncoderDecoderModel.from_pretrained(self.model_dir)
|
||||
self.processor = TrOCRProcessor.from_pretrained(self.model_dir)
|
||||
|
||||
if self.cpu:
|
||||
self.device = torch.device('cpu')
|
||||
else:
|
||||
self.device = torch.device('cuda:0')
|
||||
|
||||
self.model.to(self.device)
|
||||
|
||||
assert isinstance(self.model, torch.nn.Module)
|
||||
|
||||
else:
|
||||
try:
|
||||
for device in tf.config.list_physical_devices('GPU'):
|
||||
|
|
@ -137,15 +164,13 @@ class SBBPredict:
|
|||
custom_objects={"PatchEncoder": PatchEncoder,
|
||||
"Patches": Patches})
|
||||
|
||||
##if self.weights_dir!=None:
|
||||
##self.model.load_weights(self.weights_dir)
|
||||
if self.task != 'classification' and self.task != 'reading_order':
|
||||
self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1]
|
||||
self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2]
|
||||
self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3]
|
||||
|
||||
|
||||
assert isinstance(self.model, Model)
|
||||
if self.task != 'classification' and self.task != 'reading_order':
|
||||
last = self.model.layers[-1]
|
||||
self.img_height = last.output_shape[1]
|
||||
self.img_width = last.output_shape[2]
|
||||
self.n_classes = last.output_shape[3]
|
||||
|
||||
def visualize_model_output(self, prediction, img, task) -> Tuple[NDArray, NDArray]:
|
||||
if task == "binarization":
|
||||
|
|
@ -191,9 +216,9 @@ class SBBPredict:
|
|||
return added_image, layout_only
|
||||
|
||||
def predict(self, image_dir):
|
||||
assert isinstance(self.model, Model)
|
||||
if self.task == 'classification':
|
||||
classes_names = self.config_params_model['classification_classes_name']
|
||||
|
||||
img_1ch = cv2.imread(image_dir, 0) / 255.0
|
||||
img_1ch = cv2.resize(img_1ch, (self.config_params_model['input_height'],
|
||||
self.config_params_model['input_width']),
|
||||
|
|
@ -231,6 +256,15 @@ class SBBPredict:
|
|||
pred_texts = pred_texts[0].replace("[UNK]", "")
|
||||
return pred_texts
|
||||
|
||||
elif self.task == "transformer-ocr":
|
||||
from PIL import Image
|
||||
image = Image.open(image_dir).convert("RGB")
|
||||
pixel_values = self.processor(image, return_tensors="pt").pixel_values
|
||||
generated_ids = self.model.generate(pixel_values.to(self.device))
|
||||
return self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
|
||||
|
||||
|
||||
elif self.task == 'reading_order':
|
||||
img_height = self.config_params_model['input_height']
|
||||
|
|
@ -566,6 +600,8 @@ class SBBPredict:
|
|||
cv2.imwrite(self.save,res)
|
||||
elif self.task == "cnn-rnn-ocr":
|
||||
print(f"Detected text: {res}")
|
||||
elif self.task == "transformer-ocr":
|
||||
print(f"Detected text: {res}")
|
||||
else:
|
||||
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
|
||||
if self.save:
|
||||
|
|
@ -668,11 +704,13 @@ class SBBPredict:
|
|||
help="min area size of regions considered for reading order detection. The default value is zero and means that all text regions are considered for reading order.",
|
||||
)
|
||||
def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area):
|
||||
|
||||
assert image or dir_in, "Either a single image -i or a dir_in -di input is required"
|
||||
with open(os.path.join(model,'config.json')) as f:
|
||||
config_params_model = json.load(f)
|
||||
|
||||
task = config_params_model['task']
|
||||
if task not in ['classification', 'reading_order', "cnn-rnn-ocr"]:
|
||||
if task not in ['classification', 'reading_order', "cnn-rnn-ocr", "transformer-ocr"]:
|
||||
assert not image or save, "For segmentation or binarization, an input single image -i also requires an output filename -s"
|
||||
assert not dir_in or out, "For segmentation or binarization, an input directory -di also requires an output directory -o"
|
||||
x = SBBPredict(image, dir_in, model, task, config_params_model,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
import json
|
||||
|
||||
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
||||
import tensorflow as tf
|
||||
|
|
@ -23,6 +24,7 @@ from tensorflow.keras.layers import (
|
|||
Reshape,
|
||||
UpSampling2D,
|
||||
ZeroPadding2D,
|
||||
StringLookup,
|
||||
add,
|
||||
concatenate
|
||||
)
|
||||
|
|
@ -58,6 +60,50 @@ class CTCLayer(Layer):
|
|||
# At test time, just return the computed predictions.
|
||||
return y_pred
|
||||
|
||||
class CTCDecoder(Layer):
|
||||
def call(self, inputs):
|
||||
n_samples = tf.shape(inputs)[0]
|
||||
n_steps = inputs.shape[1]
|
||||
n_classes = inputs.shape[2]
|
||||
lengths = tf.ones(n_samples, dtype=tf.int32) * n_steps
|
||||
## Keras beam search seems to mess with double letters
|
||||
## but Keras greedy sometimes removes arbitrary letters
|
||||
# outputs, logits = tf.keras.backend.ctc_decode(inputs,
|
||||
# lengths,
|
||||
# beam_width=20
|
||||
# greedy=False, # True,
|
||||
# # backend does not allow these kwargs
|
||||
# #merge_repeated=False,
|
||||
# #mask_index=inputs.shape[2]-1,
|
||||
# )
|
||||
# tf.nn.ctc_*_decoder (in contrast to tf.keras.backend.ctc_decode)
|
||||
# needs logits instead of probs and time-major (batch 2nd dim)
|
||||
inputs = tf.math.log(
|
||||
tf.transpose(inputs, perm=[1, 0, 2]) + tf.keras.backend.epsilon()
|
||||
)
|
||||
# tf.nn.ctc_greedy_decoder() is not as precise
|
||||
# tf.compat.v1.nn.ctc_beam_search_decoder() also needs merge_repeated=False
|
||||
decoded, logits = tf.nn.ctc_beam_search_decoder(
|
||||
inputs,
|
||||
lengths,
|
||||
beam_width=10,
|
||||
top_paths=2,
|
||||
)
|
||||
# get top path for all sequences in batch
|
||||
decoded = decoded[0]
|
||||
logits = logits[:, 0] - logits[:, 1]
|
||||
# convert to dense
|
||||
outputs = tf.SparseTensor(decoded.indices, decoded.values,
|
||||
(n_samples, n_steps))
|
||||
outputs = tf.sparse.to_dense(sp_input=outputs, default_value=-1)
|
||||
# # drop non-tokens (-1) and OOV (0)
|
||||
# result = []
|
||||
# for output in outputs:
|
||||
# result.append(tf.gather(output, tf.where(output > 0)))
|
||||
# outputs = tf.stack(result)
|
||||
probs = tf.exp(-logits)
|
||||
return outputs, probs
|
||||
|
||||
def mlp(x, hidden_units, dropout_rate):
|
||||
for units in hidden_units:
|
||||
x = Dense(units, activation=tf.nn.gelu)(x)
|
||||
|
|
@ -309,11 +355,12 @@ def transformer_block(img,
|
|||
# Skip connection 2.
|
||||
encoded_patches = Add()([x3, x2])
|
||||
|
||||
encoded_patches = tf.reshape(encoded_patches,
|
||||
[-1,
|
||||
img.shape[1],
|
||||
#assert isinstance(x, Layer)
|
||||
|
||||
encoded_patches = Reshape((img.shape[1],
|
||||
img.shape[2],
|
||||
projection_dim // (patchsize_x * patchsize_y)])
|
||||
projection_dim // (patchsize_x * patchsize_y)),
|
||||
name="reshape_patches")(encoded_patches)
|
||||
return encoded_patches
|
||||
|
||||
def vit_resnet50_unet(num_patches,
|
||||
|
|
@ -423,11 +470,11 @@ def machine_based_reading_order_model(n_classes,input_height=224,input_width=224
|
|||
|
||||
return model
|
||||
|
||||
def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_seq=None):
|
||||
input_img = Input(shape=(image_height, image_width, 3), name="image")
|
||||
def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_len=None, inference=False, characters_txt_file=None):
|
||||
inputs = Input(shape=(image_height, image_width, 3), name="image")
|
||||
labels = Input(name="label", shape=(None,))
|
||||
|
||||
x = Conv2D(64,kernel_size=(3,3),padding="same")(input_img)
|
||||
x = Conv2D(64,kernel_size=(3,3),padding="same")(inputs)
|
||||
x = BatchNormalization(name="bn1")(x)
|
||||
x = Activation("relu", name="relu1")(x)
|
||||
x = Conv2D(64,kernel_size=(3,3),padding="same")(x)
|
||||
|
|
@ -460,43 +507,135 @@ def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_s
|
|||
x2d = MaxPooling2D(pool_size=(1,2),strides=(1,2))(x)
|
||||
x4d = MaxPooling2D(pool_size=(1,2),strides=(1,2))(x2d)
|
||||
|
||||
|
||||
new_shape = (x.shape[1]*x.shape[2], x.shape[3])
|
||||
new_shape2 = (x2d.shape[1]*x2d.shape[2], x2d.shape[3])
|
||||
new_shape4 = (x4d.shape[1]*x4d.shape[2], x4d.shape[3])
|
||||
|
||||
x = Reshape(target_shape=new_shape, name="reshape")(x)
|
||||
x2d = Reshape(target_shape=new_shape2, name="reshape2")(x2d)
|
||||
x4d = Reshape(target_shape=new_shape4, name="reshape4")(x4d)
|
||||
x = Reshape(new_shape, name="reshape")(x)
|
||||
x2d = Reshape(new_shape2, name="reshape2")(x2d)
|
||||
x4d = Reshape(new_shape4, name="reshape4")(x4d)
|
||||
|
||||
xrnnorg = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x)
|
||||
xrnn2d = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x2d)
|
||||
xrnn4d = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x4d)
|
||||
|
||||
xrnn2d = Reshape(target_shape=(1, xrnn2d.shape[1], xrnn2d.shape[2]), name="reshape6")(xrnn2d)
|
||||
xrnn4d = Reshape(target_shape=(1, xrnn4d.shape[1], xrnn4d.shape[2]), name="reshape8")(xrnn4d)
|
||||
|
||||
xrnn2d = Reshape((1, xrnn2d.shape[1], xrnn2d.shape[2]), name="reshape6")(xrnn2d)
|
||||
xrnn4d = Reshape((1, xrnn4d.shape[1], xrnn4d.shape[2]), name="reshape8")(xrnn4d)
|
||||
|
||||
xrnn2dup = UpSampling2D(size=(1, 2), interpolation="nearest")(xrnn2d)
|
||||
xrnn4dup = UpSampling2D(size=(1, 4), interpolation="nearest")(xrnn4d)
|
||||
|
||||
xrnn2dup = Reshape(target_shape=(xrnn2dup.shape[2], xrnn2dup.shape[3]), name="reshape10")(xrnn2dup)
|
||||
xrnn4dup = Reshape(target_shape=(xrnn4dup.shape[2], xrnn4dup.shape[3]), name="reshape12")(xrnn4dup)
|
||||
xrnn2dup = Reshape((xrnn2dup.shape[2], xrnn2dup.shape[3]), name="reshape10")(xrnn2dup)
|
||||
xrnn4dup = Reshape((xrnn4dup.shape[2], xrnn4dup.shape[3]), name="reshape12")(xrnn4dup)
|
||||
|
||||
addition = Add()([xrnnorg, xrnn2dup, xrnn4dup])
|
||||
|
||||
addition_rnn = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(addition)
|
||||
|
||||
out = Conv1D(max_seq, 1, data_format="channels_first")(addition_rnn)
|
||||
out = Conv1D(max_len, 1, data_format="channels_first")(addition_rnn)
|
||||
out = BatchNormalization(name="bn9")(out)
|
||||
out = Activation("relu", name="relu9")(out)
|
||||
#out = Conv1D(n_classes, 1, activation='relu', data_format="channels_last")(out)
|
||||
|
||||
out = Dense(n_classes, activation="softmax", name="dense2")(out)
|
||||
|
||||
if inference:
|
||||
# add second path for binarization
|
||||
inputs_bin = Input(shape=(image_height, image_width, 3), name="image_bin")
|
||||
out_bin = Model(inputs, out)(inputs_bin)
|
||||
# ensemble raw results
|
||||
out = 0.5 * (out + out_bin)
|
||||
# get tf.string batch
|
||||
out, prob = CTCDecoder()(out)
|
||||
# decode int to str
|
||||
with open(characters_txt_file, "r") as voc_file:
|
||||
voc = json.load(voc_file)
|
||||
char2num = StringLookup(vocabulary=voc)
|
||||
voc = char2num.get_vocabulary()
|
||||
num2char = StringLookup(vocabulary=voc, invert=True)
|
||||
output = num2char(out)
|
||||
# avoid output tf.dtype=string → np.dtype=object (which cannot be shm-ed)
|
||||
output = tf.io.decode_raw(output, tf.uint8, fixed_length=max(map(len, voc)))
|
||||
|
||||
return Model((inputs, inputs_bin), (output, prob))
|
||||
|
||||
# Add CTC layer for calculating CTC loss at each step.
|
||||
output = CTCLayer(name="ctc_loss")(labels, out)
|
||||
out = CTCLayer(name="ctc_loss")(labels, out)
|
||||
|
||||
model = Model(inputs=(input_img, labels), outputs=output, name="handwriting_recognizer")
|
||||
return Model((inputs, labels), out)
|
||||
|
||||
def cnn_rnn_ocr_model4inference(model, model_path):
|
||||
"""convert trained cnn-rnn-ocr model to inference model post-hoc"""
|
||||
try:
|
||||
model.get_layer(name='ctc_loss')
|
||||
except ValueError:
|
||||
# likely already converted
|
||||
return model
|
||||
else:
|
||||
inputs = model.get_layer(name='image').input
|
||||
output = model.get_layer(name='dense2').output
|
||||
inputs_bin = Input(inputs.shape[1:], name='image_bin')
|
||||
output_bin = Model(inputs, output)(inputs_bin)
|
||||
output = 0.5 * (output + output_bin)
|
||||
output, prob = CTCDecoder()(output)
|
||||
with open(model_path / "characters_org.txt", "r") as voc_file:
|
||||
voc = json.load(voc_file)
|
||||
char2num = StringLookup(vocabulary=voc)
|
||||
voc = char2num.get_vocabulary()
|
||||
num2char = StringLookup(vocabulary=voc, invert=True)
|
||||
output = num2char(output)
|
||||
# avoid output tf.dtype=string → np.dtype=object (which cannot be shm-ed)
|
||||
output = tf.io.decode_raw(output, tf.uint8, fixed_length=max(map(len, voc)))
|
||||
inputs = (inputs, inputs_bin)
|
||||
outputs = (output, prob)
|
||||
return Model(inputs, outputs)
|
||||
|
||||
def get_model(config, logger):
|
||||
from sacred.config import create_captured_function
|
||||
|
||||
task = config['task']
|
||||
if task in ["segmentation", "enhancement", "binarization"]:
|
||||
if config['backbone_type'] == 'nontransformer':
|
||||
builder = resnet50_unet
|
||||
else:
|
||||
num_patches_x, num_patches_y = config['transformer_num_patches_xy']
|
||||
num_patches = num_patches_x * num_patches_y
|
||||
|
||||
if config['transformer_cnn_first']:
|
||||
builder = vit_resnet50_unet
|
||||
multiple = 32
|
||||
else:
|
||||
builder = vit_resnet50_unet_transformer_before_cnn
|
||||
multiple = 1
|
||||
|
||||
assert config['input_height'] == (
|
||||
num_patches_y * config['transformer_patchsize_y'] * multiple), (
|
||||
"transformer_patchsize_y or transformer_num_patches_xy height value error: "
|
||||
"input_height should be equal to "
|
||||
"(transformer_num_patches_xy height value * transformer_patchsize_y * %d)" % multiple)
|
||||
assert config['input_width'] == (
|
||||
num_patches_x * config['transformer_patchsize_x'] * multiple), (
|
||||
"transformer_patchsize_x or transformer_num_patches_xy width value error: "
|
||||
"input_width should be equal to "
|
||||
"(transformer_num_patches_xy width value * transformer_patchsize_x * %d)" % multiple)
|
||||
assert 0 == (config['transformer_projection_dim'] %
|
||||
(config['transformer_patchsize_y'] *
|
||||
config['transformer_patchsize_x'])), (
|
||||
"transformer_projection_dim error: "
|
||||
"The remainder when parameter transformer_projection_dim is divided by "
|
||||
"(transformer_patchsize_y*transformer_patchsize_x) should be zero")
|
||||
|
||||
config['num_patches'] = num_patches
|
||||
elif task == "cnn-rnn-ocr":
|
||||
builder = cnn_rnn_ocr_model
|
||||
elif task=='classification':
|
||||
builder = resnet50_classifier
|
||||
elif task=='reading_order':
|
||||
builder = machine_based_reading_order_model
|
||||
else:
|
||||
raise ValueError("unknown model task '%s'" % task)
|
||||
|
||||
builder = create_captured_function(builder)
|
||||
builder.config = config
|
||||
builder.logger = logger
|
||||
return builder()
|
||||
|
|
|
|||
|
|
@ -4,38 +4,65 @@ MODELS_SRC = models_eynollah
|
|||
MODELS_DST = reloaded/models_eynollah
|
||||
|
||||
|
||||
# $(MODELS_DST)/eynollah-binarization_20210425 \
|
||||
# $(MODELS_DST)/eynollah-column-classifier_20210425 \
|
||||
# $(MODELS_DST)/eynollah-enhancement_20210425 \
|
||||
# $(MODELS_DST)/eynollah-main-regions-aug-rotation_20210425 \
|
||||
# $(MODELS_DST)/eynollah-main-regions-aug-scaling_20210425 \
|
||||
# $(MODELS_DST)/eynollah-main-regions-ensembled_20210425 \
|
||||
# $(MODELS_DST)/eynollah-main-regions_20220314 \
|
||||
# $(MODELS_DST)/eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18 \
|
||||
# $(MODELS_DST)/eynollah-tables_20210319 \
|
||||
# $(MODELS_DST)/model_eynollah_ocr_cnnrnn_20250930 \
|
||||
# eynollah-main-regions-aug-rotation_20210425
|
||||
# eynollah-main-regions-aug-scaling_20210425
|
||||
# eynollah-main-regions-ensembled_20210425
|
||||
# eynollah-main-regions_20220314
|
||||
# eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18
|
||||
# eynollah-tables_20210319
|
||||
|
||||
RELOADABLE_MODELS = \
|
||||
$(MODELS_DST)/model_eynollah_page_extraction_20250915 \
|
||||
$(MODELS_DST)/model_eynollah_reading_order_20250824 \
|
||||
$(MODELS_DST)/modelens_e_l_all_sp_0_1_2_3_4_171024 \
|
||||
$(MODELS_DST)/modelens_full_lay_1__4_3_091124 \
|
||||
$(MODELS_DST)/modelens_table_0t4_201124 \
|
||||
$(MODELS_DST)/modelens_textline_0_1__2_4_16092024
|
||||
CURRENT_MODELS :=
|
||||
CURRENT_MODELS += model_eynollah_page_extraction_20250915
|
||||
CURRENT_MODELS += model_eynollah_reading_order_20250824
|
||||
CURRENT_MODELS += modelens_e_l_all_sp_0_1_2_3_4_171024
|
||||
CURRENT_MODELS += modelens_full_lay_1__4_3_091124
|
||||
CURRENT_MODELS += modelens_table_0t4_201124
|
||||
CURRENT_MODELS += modelens_textline_0_1__2_4_16092024
|
||||
CURRENT_MODELS += model_eynollah_ocr_cnnrnn_20250930
|
||||
CURRENT_MODELS += eynollah-binarization_20210425
|
||||
CURRENT_MODELS += eynollah-column-classifier_20210425
|
||||
CURRENT_MODELS += eynollah-enhancement_20210425
|
||||
|
||||
all: $(RELOADABLE_MODELS)
|
||||
all: tf-serving
|
||||
|
||||
tf-serving: $(CURRENT_MODELS:%=$(MODELS_DST)/%)
|
||||
keras: $(CURRENT_MODELS:%=$(MODELS_DST)/%.keras)
|
||||
hdf5: $(CURRENT_MODELS:%=$(MODELS_DST)/%.h5)
|
||||
onnx: $(CURRENT_MODELS:%=$(MODELS_DST)/%.onnx)
|
||||
|
||||
$(MODELS_DST)/%: $(MODELS_SRC)/%
|
||||
mkdir -p $@
|
||||
test -e $</config.json || exit 1
|
||||
eynollah-training train --force \
|
||||
with $</config.json \
|
||||
reload_weights=True \
|
||||
continue_training=False \
|
||||
dir_output=$(dir $@) \
|
||||
dir_of_start_model=$< \
|
||||
2>&1 | tee $(notdir $<).log
|
||||
cp $</config.json $@/config.json
|
||||
eynollah-training convert \
|
||||
$(and $(wildcard $</config.json),--rebuild) \
|
||||
--in $< \
|
||||
--format tf-serving \
|
||||
--out $@ \
|
||||
2>&1 | tee $(notdir $<).tf-serving.log
|
||||
|
||||
$(MODELS_DST)/%.keras: $(MODELS_SRC)/%
|
||||
eynollah-training convert \
|
||||
$(and $(wildcard $</config.json),--rebuild) \
|
||||
--in $< \
|
||||
--format keras \
|
||||
--out $@ \
|
||||
2>&1 | tee $(notdir $<).keras.log
|
||||
|
||||
$(MODELS_DST)/%.h5: $(MODELS_SRC)/%
|
||||
eynollah-training convert \
|
||||
$(and $(wildcard $</config.json),--rebuild) \
|
||||
--in $< \
|
||||
--format hdf5 \
|
||||
--out $@ \
|
||||
2>&1 | tee $(notdir $<).hdf5.log
|
||||
|
||||
$(MODELS_DST)/%.onnx: $(MODELS_SRC)/%
|
||||
if jq -e '.task == "segmentation" and .backbone_type == "transformer"' $</config.json &>/dev/null; then \
|
||||
echo skipping $@: vision transformer architecture currently does not work with ONNX; else \
|
||||
eynollah-training convert \
|
||||
$(and $(wildcard $</config.json),--rebuild) \
|
||||
--in $< \
|
||||
--format onnx \
|
||||
--out $@ \
|
||||
2>&1 | tee $(notdir $<).onnx.log; fi
|
||||
|
||||
compare:
|
||||
for i in `find $(MODELS_DST) -mindepth 2`;do \
|
||||
|
|
@ -43,6 +70,5 @@ compare:
|
|||
du -bs $$n $$i ; \
|
||||
done
|
||||
|
||||
|
||||
clear:
|
||||
rm -rf $(MODELS_DST)
|
||||
|
|
|
|||
|
|
@ -2,10 +2,16 @@ import os
|
|||
import sys
|
||||
import io
|
||||
import json
|
||||
import click
|
||||
|
||||
|
||||
from tqdm import tqdm
|
||||
import requests
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
||||
import tensorflow as tf
|
||||
|
|
@ -17,7 +23,6 @@ from tensorflow.keras.layers import StringLookup
|
|||
from tensorflow.keras.utils import image_dataset_from_directory
|
||||
from tensorflow.keras.backend import one_hot
|
||||
from sacred import Experiment
|
||||
from sacred.config import create_captured_function
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
|
@ -32,27 +37,29 @@ from .metrics import (
|
|||
connected_components_loss,
|
||||
)
|
||||
from .models import (
|
||||
PatchEncoder,
|
||||
Patches,
|
||||
machine_based_reading_order_model,
|
||||
resnet50_classifier,
|
||||
resnet50_unet,
|
||||
vit_resnet50_unet,
|
||||
vit_resnet50_unet_transformer_before_cnn,
|
||||
cnn_rnn_ocr_model,
|
||||
RESNET50_WEIGHTS_PATH,
|
||||
RESNET50_WEIGHTS_URL
|
||||
RESNET50_WEIGHTS_URL,
|
||||
get_model
|
||||
)
|
||||
from .utils import (
|
||||
generate_arrays_from_folder_reading_order,
|
||||
get_one_hot,
|
||||
preprocess_imgs,
|
||||
return_number_of_total_training_data,
|
||||
OCRDatasetYieldAugmentations
|
||||
)
|
||||
from .weights_ensembling import run_ensembling
|
||||
|
||||
|
||||
import torch
|
||||
from transformers import TrOCRProcessor
|
||||
import evaluate
|
||||
from transformers import default_data_collator
|
||||
from transformers import VisionEncoderDecoderModel
|
||||
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
|
||||
|
||||
class SaveWeightsAfterSteps(ModelCheckpoint):
|
||||
def __init__(self, save_interval, save_path, _config, **kwargs):
|
||||
def __init__(self, save_interval, save_path, _config, characters_cnnrnn_ocr=None, **kwargs):
|
||||
if save_interval:
|
||||
# batches
|
||||
super().__init__(
|
||||
|
|
@ -67,12 +74,15 @@ class SaveWeightsAfterSteps(ModelCheckpoint):
|
|||
verbose=1,
|
||||
**kwargs)
|
||||
self._config = _config
|
||||
self.characters_cnnrnn_ocr = characters_cnnrnn_ocr
|
||||
|
||||
# overwrite tf-keras (Keras 2) implementation to get our _config JSON in
|
||||
def _save_handler(self, filepath):
|
||||
super()._save_handler(filepath)
|
||||
with open(os.path.join(filepath, "config.json"), "w") as fp:
|
||||
json.dump(self._config, fp) # encode dict into JSON
|
||||
if self.characters_cnnrnn_ocr:
|
||||
os.system("cp "+self.characters_cnnrnn_ocr+" "+os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"characters_org.txt"))
|
||||
|
||||
def configuration():
|
||||
try:
|
||||
|
|
@ -355,10 +365,9 @@ def config_params():
|
|||
dir_output = None # Directory where the augmented training data and the model checkpoints will be saved.
|
||||
pretraining = False # Set to true to (down)load pretrained weights of ResNet50 encoder.
|
||||
save_interval = None # frequency for writing model checkpoints (positive integer for number of batches saved under "model_step_{batch:04d}", otherwise epoch saved under "model_{epoch:02d}")
|
||||
reload_weights = False # Set true to build new model from config, load weights from dir_of_start_model, save under dir_output and exit.
|
||||
continue_training = False # Whether to continue training an existing model.
|
||||
dir_of_start_model = '' # Directory of model checkpoint to load to continue training or load weights from. (E.g. if you already trained for 3 epochs, set "dir_of_start_model=dir_output/model_03".)
|
||||
if continue_training:
|
||||
dir_of_start_model = '' # Directory of model checkpoint to load to continue training. (E.g. if you already trained for 3 epochs, set "dir_of_start_model=dir_output/model_03".)
|
||||
index_start = 0 # Epoch counter initial value to continue training. (E.g. if you already trained for 3 epochs, set "index_start=3" to continue naming checkpoints model_04, model_05 etc.)
|
||||
data_is_provided = False # Whether the preprocessed input data (subdirectories "images" and "labels" in both subdirectories "train" and "eval" of "dir_output") has already been generated (in the first epoch of a previous run).
|
||||
|
||||
|
|
@ -379,7 +388,6 @@ def run(_config,
|
|||
weight_decay,
|
||||
learning_rate,
|
||||
continue_training,
|
||||
reload_weights,
|
||||
save_interval,
|
||||
augmentation,
|
||||
# dependent config keys need a default,
|
||||
|
|
@ -477,58 +485,15 @@ def run(_config,
|
|||
if task == "enhancement":
|
||||
assert not is_loss_soft_dice, "for enhancement, soft_dice loss does not apply"
|
||||
assert not weighted_loss, "for enhancement, weighted loss does not apply"
|
||||
|
||||
if continue_training:
|
||||
custom_objects = dict()
|
||||
if is_loss_soft_dice:
|
||||
custom_objects.update(soft_dice_loss=soft_dice_loss)
|
||||
elif weighted_loss:
|
||||
custom_objects.update(loss=weighted_categorical_crossentropy(weights))
|
||||
if backbone_type == 'transformer':
|
||||
custom_objects.update(PatchEncoder=PatchEncoder,
|
||||
Patches=Patches)
|
||||
model = load_model(dir_of_start_model, compile=False,
|
||||
custom_objects=custom_objects)
|
||||
model = load_model(dir_of_start_model, compile=False)
|
||||
else:
|
||||
index_start = 0
|
||||
if backbone_type == 'nontransformer':
|
||||
model = resnet50_unet(n_classes,
|
||||
input_height,
|
||||
input_width,
|
||||
task,
|
||||
weight_decay,
|
||||
pretraining)
|
||||
else:
|
||||
num_patches_x = transformer_num_patches_xy[0]
|
||||
num_patches_y = transformer_num_patches_xy[1]
|
||||
num_patches = num_patches_x * num_patches_y
|
||||
|
||||
if transformer_cnn_first:
|
||||
model_builder = vit_resnet50_unet
|
||||
multiple = 32
|
||||
else:
|
||||
model_builder = vit_resnet50_unet_transformer_before_cnn
|
||||
multiple = 1
|
||||
|
||||
assert input_height == (
|
||||
num_patches_y * transformer_patchsize_y * multiple), (
|
||||
"transformer_patchsize_y or transformer_num_patches_xy height value error: "
|
||||
"input_height should be equal to "
|
||||
"(transformer_num_patches_xy height value * transformer_patchsize_y * %d)" % multiple)
|
||||
assert input_width == (
|
||||
num_patches_x * transformer_patchsize_x * multiple), (
|
||||
"transformer_patchsize_x or transformer_num_patches_xy width value error: "
|
||||
"input_width should be equal to "
|
||||
"(transformer_num_patches_xy width value * transformer_patchsize_x * %d)" % multiple)
|
||||
assert 0 == (transformer_projection_dim %
|
||||
(transformer_patchsize_y * transformer_patchsize_x)), (
|
||||
"transformer_projection_dim error: "
|
||||
"The remainder when parameter transformer_projection_dim is divided by "
|
||||
"(transformer_patchsize_y*transformer_patchsize_x) should be zero")
|
||||
|
||||
model_builder = create_captured_function(model_builder)
|
||||
model_builder.config = _config
|
||||
model_builder.logger = _log
|
||||
model = model_builder(num_patches)
|
||||
model = get_model(_config, _log)
|
||||
if dir_of_start_model:
|
||||
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
|
||||
_log.info("reloaded weights from %s", dir_of_start_model)
|
||||
|
||||
assert model is not None
|
||||
#if you want to see the model structure just uncomment model summary.
|
||||
|
|
@ -559,15 +524,6 @@ def run(_config,
|
|||
optimizer=Adam(learning_rate=learning_rate),
|
||||
metrics=metrics)
|
||||
|
||||
if reload_weights:
|
||||
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
|
||||
dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model)))
|
||||
model.save(dir_save, include_optimizer=False)
|
||||
with open(os.path.join(dir_save, "config.json"), "w") as fp:
|
||||
json.dump(_config, fp) # encode dict into JSON
|
||||
_log.info("reloaded model from %s to %s", dir_of_start_model, dir_save)
|
||||
return
|
||||
|
||||
if not data_is_provided:
|
||||
# first create a directory in output for both training and evaluations
|
||||
# in order to flow data from these directories.
|
||||
|
|
@ -708,10 +664,11 @@ def run(_config,
|
|||
model = load_model(dir_of_start_model)
|
||||
else:
|
||||
index_start = 0
|
||||
model = cnn_rnn_ocr_model(image_height=input_height,
|
||||
image_width=input_width,
|
||||
n_classes=n_classes,
|
||||
max_seq=max_len)
|
||||
model = get_model(_config, _log)
|
||||
if dir_of_start_model:
|
||||
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
|
||||
_log.info("reloaded weights from %s", dir_of_start_model)
|
||||
|
||||
#initial_learning_rate = 1e-4
|
||||
#decay_steps = int (n_epochs * ( len_dataset / n_batch ))
|
||||
#alpha = 0.01
|
||||
|
|
@ -722,15 +679,6 @@ def run(_config,
|
|||
|
||||
#print(model.summary())
|
||||
|
||||
if reload_weights:
|
||||
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
|
||||
dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model)))
|
||||
model.save(dir_save, include_optimizer=False)
|
||||
with open(os.path.join(dir_save, "config.json"), "w") as fp:
|
||||
json.dump(_config, fp) # encode dict into JSON
|
||||
_log.info("reloaded model from %s to %s", dir_of_start_model, dir_save)
|
||||
return
|
||||
|
||||
# todo: use Dataset.map() on Dataset.list_files()
|
||||
def get_dataset(dir_img, dir_lab):
|
||||
def gen():
|
||||
|
|
@ -772,25 +720,15 @@ def run(_config,
|
|||
model = load_model(dir_of_start_model, compile=False)
|
||||
else:
|
||||
index_start = 0
|
||||
model = resnet50_classifier(n_classes,
|
||||
input_height,
|
||||
input_width,
|
||||
weight_decay,
|
||||
pretraining)
|
||||
model = get_model(_config, _log)
|
||||
if dir_of_start_model:
|
||||
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
|
||||
_log.info("reloaded weights from %s", dir_of_start_model)
|
||||
|
||||
model.compile(loss='categorical_crossentropy',
|
||||
optimizer=Adam(learning_rate=0.001), # rs: why not learning_rate?
|
||||
metrics=['accuracy', F1Score(average='macro', name='f1')])
|
||||
|
||||
if reload_weights:
|
||||
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
|
||||
dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model)))
|
||||
model.save(dir_save, include_optimizer=False)
|
||||
with open(os.path.join(dir_save, "config.json"), "w") as fp:
|
||||
json.dump(_config, fp) # encode dict into JSON
|
||||
_log.info("reloaded model from %s to %s", dir_of_start_model, dir_save)
|
||||
return
|
||||
|
||||
list_classes = list(classification_classes_name.values())
|
||||
data_args = dict(label_mode="categorical",
|
||||
class_names=list_classes,
|
||||
|
|
@ -820,19 +758,135 @@ def run(_config,
|
|||
usable_checkpoints = [os.path.join(dir_output, 'model_{epoch:02d}'.format(epoch=epoch + 1))
|
||||
for epoch in usable_checkpoints]
|
||||
ens_path = os.path.join(dir_output, 'model_ens_avg')
|
||||
run_ensembling(usable_checkpoints, ens_path)
|
||||
run_ensembling(usable_checkpoints, ens_path, framework="tensorflow")
|
||||
_log.info("ensemble model saved under '%s'", ens_path)
|
||||
|
||||
# =======
|
||||
|
||||
|
||||
elif task=="transformer-ocr":
|
||||
dir_img, dir_lab = get_dirs_or_files(dir_train)
|
||||
|
||||
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
|
||||
|
||||
ls_files_images = os.listdir(dir_img)
|
||||
|
||||
aug_multip = return_multiplier_based_on_augmnentations(augmentation, color_padding_rotation, rotation_not_90, blur_aug, degrading, bin_deg,
|
||||
brightening, padding_white, adding_rgb_foreground, adding_rgb_background, binarization,
|
||||
image_inversion, channels_shuffling, add_red_textlines, white_noise_strap, textline_skewing, textline_skewing_bin, textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth, textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin, textline_down_in_depth, textline_down_in_depth_bin, pepper_bin_aug, pepper_aug, degrade_scales, number_of_backgrounds_per_image, thetha, thetha_padd, brightness, padd_colors, shuffle_indexes, pepper_indexes, skewing_amplitudes, blur_k, white_padds)
|
||||
|
||||
len_dataset = aug_multip*len(ls_files_images)
|
||||
|
||||
dataset = OCRDatasetYieldAugmentations(
|
||||
dir_img=dir_img,
|
||||
dir_img_bin=dir_img_bin,
|
||||
dir_lab=dir_lab,
|
||||
processor=processor,
|
||||
max_target_length=max_len,
|
||||
augmentation = augmentation,
|
||||
binarization = binarization,
|
||||
add_red_textlines = add_red_textlines,
|
||||
white_noise_strap = white_noise_strap,
|
||||
adding_rgb_foreground = adding_rgb_foreground,
|
||||
adding_rgb_background = adding_rgb_background,
|
||||
bin_deg = bin_deg,
|
||||
blur_aug = blur_aug,
|
||||
brightening = brightening,
|
||||
padding_white = padding_white,
|
||||
color_padding_rotation = color_padding_rotation,
|
||||
rotation_not_90 = rotation_not_90,
|
||||
degrading = degrading,
|
||||
channels_shuffling = channels_shuffling,
|
||||
textline_skewing = textline_skewing,
|
||||
textline_skewing_bin = textline_skewing_bin,
|
||||
textline_right_in_depth = textline_right_in_depth,
|
||||
textline_left_in_depth = textline_left_in_depth,
|
||||
textline_up_in_depth = textline_up_in_depth,
|
||||
textline_down_in_depth = textline_down_in_depth,
|
||||
textline_right_in_depth_bin = textline_right_in_depth_bin,
|
||||
textline_left_in_depth_bin = textline_left_in_depth_bin,
|
||||
textline_up_in_depth_bin = textline_up_in_depth_bin,
|
||||
textline_down_in_depth_bin = textline_down_in_depth_bin,
|
||||
pepper_aug = pepper_aug,
|
||||
pepper_bin_aug = pepper_bin_aug,
|
||||
list_all_possible_background_images=list_all_possible_background_images,
|
||||
list_all_possible_foreground_rgbs=list_all_possible_foreground_rgbs,
|
||||
blur_k = blur_k,
|
||||
degrade_scales = degrade_scales,
|
||||
white_padds = white_padds,
|
||||
thetha_padd = thetha_padd,
|
||||
thetha = thetha,
|
||||
brightness = brightness,
|
||||
padd_colors = padd_colors,
|
||||
number_of_backgrounds_per_image = number_of_backgrounds_per_image,
|
||||
shuffle_indexes = shuffle_indexes,
|
||||
pepper_indexes = pepper_indexes,
|
||||
skewing_amplitudes = skewing_amplitudes,
|
||||
dir_rgb_backgrounds = dir_rgb_backgrounds,
|
||||
dir_rgb_foregrounds = dir_rgb_foregrounds,
|
||||
len_data=len_dataset,
|
||||
)
|
||||
|
||||
# Create a DataLoader
|
||||
data_loader = torch.utils.data.DataLoader(dataset, batch_size=1)
|
||||
train_dataset = data_loader.dataset
|
||||
|
||||
|
||||
if continue_training:
|
||||
model = VisionEncoderDecoderModel.from_pretrained(dir_of_start_model)
|
||||
else:
|
||||
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
|
||||
|
||||
|
||||
# set special tokens used for creating the decoder_input_ids from the labels
|
||||
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
|
||||
model.config.pad_token_id = processor.tokenizer.pad_token_id
|
||||
# make sure vocab size is set correctly
|
||||
model.config.vocab_size = model.config.decoder.vocab_size
|
||||
|
||||
# set beam search parameters
|
||||
model.config.eos_token_id = processor.tokenizer.sep_token_id
|
||||
model.config.max_length = max_len
|
||||
model.config.early_stopping = True
|
||||
model.config.no_repeat_ngram_size = 3
|
||||
model.config.length_penalty = 2.0
|
||||
model.config.num_beams = 4
|
||||
|
||||
|
||||
training_args = Seq2SeqTrainingArguments(
|
||||
predict_with_generate=True,
|
||||
num_train_epochs=n_epochs,
|
||||
learning_rate=learning_rate,
|
||||
per_device_train_batch_size=n_batch,
|
||||
fp16=True,
|
||||
output_dir=dir_output,
|
||||
logging_steps=2,
|
||||
save_steps=save_interval,
|
||||
)
|
||||
|
||||
|
||||
cer_metric = evaluate.load("cer")
|
||||
|
||||
# instantiate trainer
|
||||
trainer = Seq2SeqTrainer(
|
||||
model=model,
|
||||
tokenizer=processor.feature_extractor,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
data_collator=default_data_collator,
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
|
||||
elif task=='reading_order':
|
||||
if continue_training:
|
||||
model = load_model(dir_of_start_model, compile=False)
|
||||
else:
|
||||
index_start = 0
|
||||
model = machine_based_reading_order_model(n_classes,
|
||||
input_height,
|
||||
input_width,
|
||||
weight_decay,
|
||||
pretraining)
|
||||
model = get_model(_config, _log)
|
||||
if dir_of_start_model:
|
||||
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
|
||||
_log.info("reloaded weights from %s", dir_of_start_model)
|
||||
|
||||
#f1score_tot = [0]
|
||||
model.compile(loss="binary_crossentropy",
|
||||
|
|
@ -840,15 +894,6 @@ def run(_config,
|
|||
optimizer=Adam(learning_rate=0.0001), # rs: why not learning_rate?
|
||||
metrics=['accuracy'])
|
||||
|
||||
if reload_weights:
|
||||
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
|
||||
dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model)))
|
||||
model.save(dir_save, include_optimizer=False)
|
||||
with open(os.path.join(dir_save, "config.json"), "w") as fp:
|
||||
json.dump(_config, fp) # encode dict into JSON
|
||||
_log.info("reloaded model from %s to %s", dir_of_start_model, dir_save)
|
||||
return
|
||||
|
||||
dir_flow_train_imgs = os.path.join(dir_train, 'images')
|
||||
dir_flow_train_labels = os.path.join(dir_train, 'labels')
|
||||
|
||||
|
|
@ -881,3 +926,23 @@ def run(_config,
|
|||
model_dir = os.path.join(dir_out,'model_best')
|
||||
model.save(model_dir)
|
||||
'''
|
||||
|
||||
@click.command(context_settings=dict(
|
||||
ignore_unknown_options=True,
|
||||
))
|
||||
@click.argument('SACRED_ARGS', nargs=-1, type=click.UNPROCESSED)
|
||||
def train_cli(sacred_args):
|
||||
"""
|
||||
train model on extracted GT
|
||||
|
||||
SACRED_ARGS as per CLI interface of Sacred, cf.
|
||||
https://sacred.readthedocs.io/en/stable/command_line.html:
|
||||
|
||||
\b
|
||||
To configure the learning task, pass the string `with`,
|
||||
followed by any number of
|
||||
- config JSON file paths
|
||||
- parameter overrides in the form of key=value
|
||||
(where the later settings will override the former).
|
||||
"""
|
||||
ex.run_commandline([sys.argv[0]] + list(sacred_args))
|
||||
|
|
|
|||
|
|
@ -14,6 +14,9 @@ import tensorflow as tf
|
|||
|
||||
from PIL import Image, ImageFile, ImageEnhance
|
||||
|
||||
import torch
|
||||
from torch.utils.data import IterableDataset
|
||||
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
|
||||
|
||||
|
|
@ -78,6 +81,7 @@ def add_salt_and_pepper_noise(img, salt_prob, pepper_prob):
|
|||
|
||||
return noisy_image
|
||||
|
||||
|
||||
def invert_image(img):
|
||||
img_inv = 255 - img
|
||||
return img_inv
|
||||
|
|
@ -1242,3 +1246,411 @@ def preprocess_img_ocr(
|
|||
for pepper_ind in pepper_indexes:
|
||||
img_noisy = add_salt_and_pepper_noise(img_bin_corr, pepper_ind, pepper_ind)
|
||||
yield scale_image(img_noisy), lab
|
||||
|
||||
|
||||
class OCRDatasetYieldAugmentations(IterableDataset):
|
||||
def __init__(
|
||||
self,
|
||||
dir_img,
|
||||
dir_img_bin,
|
||||
dir_lab,
|
||||
processor,
|
||||
max_target_length=128,
|
||||
augmentation = None,
|
||||
binarization = None,
|
||||
add_red_textlines = None,
|
||||
white_noise_strap = None,
|
||||
adding_rgb_foreground = None,
|
||||
adding_rgb_background = None,
|
||||
bin_deg = None,
|
||||
blur_aug = None,
|
||||
brightening = None,
|
||||
padding_white = None,
|
||||
color_padding_rotation = None,
|
||||
rotation_not_90 = None,
|
||||
degrading = None,
|
||||
channels_shuffling = None,
|
||||
textline_skewing = None,
|
||||
textline_skewing_bin = None,
|
||||
textline_right_in_depth = None,
|
||||
textline_left_in_depth = None,
|
||||
textline_up_in_depth = None,
|
||||
textline_down_in_depth = None,
|
||||
textline_right_in_depth_bin = None,
|
||||
textline_left_in_depth_bin = None,
|
||||
textline_up_in_depth_bin = None,
|
||||
textline_down_in_depth_bin = None,
|
||||
pepper_aug = None,
|
||||
pepper_bin_aug = None,
|
||||
list_all_possible_background_images=None,
|
||||
list_all_possible_foreground_rgbs=None,
|
||||
blur_k = None,
|
||||
degrade_scales = None,
|
||||
white_padds = None,
|
||||
thetha_padd = None,
|
||||
thetha = None,
|
||||
brightness = None,
|
||||
padd_colors = None,
|
||||
number_of_backgrounds_per_image = None,
|
||||
shuffle_indexes = None,
|
||||
pepper_indexes = None,
|
||||
skewing_amplitudes = None,
|
||||
dir_rgb_backgrounds = None,
|
||||
dir_rgb_foregrounds = None,
|
||||
len_data=None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
images_dir (str): Path to the directory containing images.
|
||||
labels_dir (str): Path to the directory containing label text files.
|
||||
tokenizer: Tokenizer for processing labels.
|
||||
transform: Transformations applied after augmentation (e.g., ToTensor, normalization).
|
||||
image_size (tuple): Size to resize images to.
|
||||
max_seq_len (int): Maximum sequence length for tokenized labels.
|
||||
scales (list or None): List of scale factors to apply.
|
||||
"""
|
||||
self.dir_img = dir_img
|
||||
self.dir_img_bin = dir_img_bin
|
||||
self.dir_lab = dir_lab
|
||||
self.processor = processor
|
||||
self.max_target_length = max_target_length
|
||||
#self.scales = scales if scales else []
|
||||
|
||||
self.augmentation = augmentation
|
||||
self.binarization = binarization
|
||||
self.add_red_textlines = add_red_textlines
|
||||
self.white_noise_strap = white_noise_strap
|
||||
self.adding_rgb_foreground = adding_rgb_foreground
|
||||
self.adding_rgb_background = adding_rgb_background
|
||||
self.bin_deg = bin_deg
|
||||
self.blur_aug = blur_aug
|
||||
self.brightening = brightening
|
||||
self.padding_white = padding_white
|
||||
self.color_padding_rotation = color_padding_rotation
|
||||
self.rotation_not_90 = rotation_not_90
|
||||
self.degrading = degrading
|
||||
self.channels_shuffling = channels_shuffling
|
||||
self.textline_skewing = textline_skewing
|
||||
self.textline_skewing_bin = textline_skewing_bin
|
||||
self.textline_right_in_depth = textline_right_in_depth
|
||||
self.textline_left_in_depth = textline_left_in_depth
|
||||
self.textline_up_in_depth = textline_up_in_depth
|
||||
self.textline_down_in_depth = textline_down_in_depth
|
||||
self.textline_right_in_depth_bin = textline_right_in_depth_bin
|
||||
self.textline_left_in_depth_bin = textline_left_in_depth_bin
|
||||
self.textline_up_in_depth_bin = textline_up_in_depth_bin
|
||||
self.textline_down_in_depth_bin = textline_down_in_depth_bin
|
||||
self.pepper_aug = pepper_aug
|
||||
self.pepper_bin_aug = pepper_bin_aug
|
||||
self.list_all_possible_background_images=list_all_possible_background_images
|
||||
self.list_all_possible_foreground_rgbs=list_all_possible_foreground_rgbs
|
||||
self.blur_k = blur_k
|
||||
self.degrade_scales = degrade_scales
|
||||
self.white_padds = white_padds
|
||||
self.thetha_padd = thetha_padd
|
||||
self.thetha = thetha
|
||||
self.brightness = brightness
|
||||
self.padd_colors = padd_colors
|
||||
self.number_of_backgrounds_per_image = number_of_backgrounds_per_image
|
||||
self.shuffle_indexes = shuffle_indexes
|
||||
self.pepper_indexes = pepper_indexes
|
||||
self.skewing_amplitudes = skewing_amplitudes
|
||||
self.dir_rgb_backgrounds = dir_rgb_backgrounds
|
||||
self.dir_rgb_foregrounds = dir_rgb_foregrounds
|
||||
self.image_files = os.listdir(dir_img)#sorted([f for f in os.listdir(images_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
|
||||
self.len_data = len_data
|
||||
#assert len(self.image_files) == len(self.label_files), "Number of images and labels must match!"
|
||||
|
||||
def __len__(self):
|
||||
return self.len_data
|
||||
|
||||
def __iter__(self):
|
||||
for img_file in self.image_files:
|
||||
# Load image
|
||||
f_name = img_file.split('.')[0]
|
||||
|
||||
txt_inp = open(os.path.join(self.dir_lab, f_name+'.txt'),'r').read().split('\n')[0]
|
||||
|
||||
img = cv2.imread(os.path.join(self.dir_img, img_file))
|
||||
img = img.astype(np.uint8)
|
||||
|
||||
|
||||
if self.dir_img_bin:
|
||||
img_bin_corr = cv2.imread(os.path.join(self.dir_img_bin, f_name+'.png') )
|
||||
img_bin_corr = img_bin_corr.astype(np.uint8)
|
||||
else:
|
||||
img_bin_corr = None
|
||||
|
||||
|
||||
labels = self.processor.tokenizer(txt_inp,
|
||||
padding="max_length",
|
||||
max_length=self.max_target_length).input_ids
|
||||
|
||||
labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]
|
||||
|
||||
|
||||
if self.augmentation:
|
||||
pixel_values = self.processor(Image.fromarray(img), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
if self.color_padding_rotation:
|
||||
for index, thetha_ind in enumerate(self.thetha_padd):
|
||||
for padd_col in self.padd_colors:
|
||||
img_out = rotation_not_90_func_single_image(do_padding_for_ocr(img, 1.2, padd_col), thetha_ind)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
if self.rotation_not_90:
|
||||
for index, thetha_ind in enumerate(self.thetha):
|
||||
img_out = rotation_not_90_func_single_image(img, thetha_ind)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
if self.blur_aug:
|
||||
for index, blur_type in enumerate(self.blur_k):
|
||||
img_out = bluring(img, blur_type)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
if self.degrading:
|
||||
for index, deg_scale_ind in enumerate(self.degrade_scales):
|
||||
try:
|
||||
img_out = do_degrading(img, deg_scale_ind)
|
||||
except:
|
||||
img_out = np.copy(img)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
if self.bin_deg:
|
||||
for index, deg_scale_ind in enumerate(self.degrade_scales):
|
||||
try:
|
||||
img_out = self.do_degrading(img_bin_corr, deg_scale_ind)
|
||||
except:
|
||||
img_out = np.copy(img_bin_corr)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
if self.brightening:
|
||||
for index, bright_scale_ind in enumerate(self.brightness):
|
||||
try:
|
||||
img_out = do_brightening(dir_img, bright_scale_ind)
|
||||
except:
|
||||
img_out = np.copy(img)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
if self.padding_white:
|
||||
for index, padding_size in enumerate(self.white_padds):
|
||||
for padd_col in self.padd_colors:
|
||||
img_out = do_padding_for_ocr(img, padding_size, padd_col)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
if self.adding_rgb_foreground:
|
||||
for i_n in range(self.number_of_backgrounds_per_image):
|
||||
background_image_chosen_name = random.choice(self.list_all_possible_background_images)
|
||||
foreground_rgb_chosen_name = random.choice(self.list_all_possible_foreground_rgbs)
|
||||
|
||||
img_rgb_background_chosen = cv2.imread(self.dir_rgb_backgrounds + '/' + background_image_chosen_name)
|
||||
foreground_rgb_chosen = np.load(self.dir_rgb_foregrounds + '/' + foreground_rgb_chosen_name)
|
||||
|
||||
img_out = return_binary_image_with_given_rgb_background_and_given_foreground_rgb(img_bin_corr, img_rgb_background_chosen, foreground_rgb_chosen)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
|
||||
|
||||
if self.adding_rgb_background:
|
||||
for i_n in range(self.number_of_backgrounds_per_image):
|
||||
background_image_chosen_name = random.choice(self.list_all_possible_background_images)
|
||||
img_rgb_background_chosen = cv2.imread(self.dir_rgb_backgrounds + '/' + background_image_chosen_name)
|
||||
img_out = return_binary_image_with_given_rgb_background(img_bin_corr, img_rgb_background_chosen)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
if self.binarization:
|
||||
pixel_values = self.processor(Image.fromarray(img_bin_corr), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
if self.channels_shuffling:
|
||||
for shuffle_index in self.shuffle_indexes:
|
||||
img_out = return_shuffled_channels(img, shuffle_index)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
if self.add_red_textlines:
|
||||
img_out = return_image_with_red_elements(img, img_bin_corr)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
if self.white_noise_strap:
|
||||
img_out = return_image_with_strapped_white_noises(img)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
if self.textline_skewing:
|
||||
for index, des_scale_ind in enumerate(self.skewing_amplitudes):
|
||||
try:
|
||||
img_out = do_deskewing(img, des_scale_ind)
|
||||
except:
|
||||
img_out = np.copy(img)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
if self.textline_skewing_bin:
|
||||
for index, des_scale_ind in enumerate(self.skewing_amplitudes):
|
||||
try:
|
||||
img_out = do_deskewing(img_bin_corr, des_scale_ind)
|
||||
except:
|
||||
img_out = np.copy(img_bin_corr)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
if self.textline_left_in_depth:
|
||||
try:
|
||||
img_out = do_direction_in_depth(img, 'left')
|
||||
except:
|
||||
img_out = np.copy(img)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
if self.textline_left_in_depth_bin:
|
||||
try:
|
||||
img_out = do_direction_in_depth(img_bin_corr, 'left')
|
||||
except:
|
||||
img_out = np.copy(img_bin_corr)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
if self.textline_right_in_depth:
|
||||
try:
|
||||
img_out = do_direction_in_depth(img, 'right')
|
||||
except:
|
||||
img_out = np.copy(img)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
if self.textline_right_in_depth_bin:
|
||||
try:
|
||||
img_out = do_direction_in_depth(img_bin_corr, 'right')
|
||||
except:
|
||||
img_out = np.copy(img_bin_corr)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
if self.textline_up_in_depth:
|
||||
try:
|
||||
img_out = do_direction_in_depth(img, 'up')
|
||||
except:
|
||||
img_out = np.copy(img)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
if self.textline_up_in_depth_bin:
|
||||
try:
|
||||
img_out = do_direction_in_depth(img_bin_corr, 'up')
|
||||
except:
|
||||
img_out = np.copy(img_bin_corr)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
if self.textline_down_in_depth:
|
||||
try:
|
||||
img_out = do_direction_in_depth(img, 'down')
|
||||
except:
|
||||
img_out = np.copy(img)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
if self.textline_down_in_depth_bin:
|
||||
try:
|
||||
img_out = do_direction_in_depth(img_bin_corr, 'down')
|
||||
except:
|
||||
img_out = np.copy(img_bin_corr)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
if self.pepper_bin_aug:
|
||||
for index, pepper_ind in enumerate(self.pepper_indexes):
|
||||
img_out = add_salt_and_pepper_noise(img_bin_corr, pepper_ind, pepper_ind)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
if self.pepper_aug:
|
||||
for index, pepper_ind in enumerate(self.pepper_indexes):
|
||||
img_out = add_salt_and_pepper_noise(img, pepper_ind, pepper_ind)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
|
||||
else:
|
||||
pixel_values = self.processor(Image.fromarray(img), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
|
|
|||
|
|
@ -16,33 +16,59 @@ from ..patch_encoder import (
|
|||
PatchEncoder,
|
||||
Patches,
|
||||
)
|
||||
from PIL import Image
|
||||
import torch
|
||||
from transformers import VisionEncoderDecoderModel
|
||||
|
||||
def run_ensembling(model_dirs, out_dir):
|
||||
all_weights = []
|
||||
def run_ensembling(dir_models, out, framework):
|
||||
ls_models = os.listdir(dir_models)
|
||||
if framework=="torch":
|
||||
models = []
|
||||
sd_models = []
|
||||
|
||||
for model_dir in model_dirs:
|
||||
assert os.path.isdir(model_dir), model_dir
|
||||
model = load_model(model_dir, compile=False,
|
||||
custom_objects=dict(PatchEncoder=PatchEncoder,
|
||||
Patches=Patches))
|
||||
all_weights.append(model.get_weights())
|
||||
for model_name in ls_models:
|
||||
model = VisionEncoderDecoderModel.from_pretrained(os.path.join(dir_models,model_name))
|
||||
models.append(model)
|
||||
sd_models.append(model.state_dict())
|
||||
for key in sd_models[0]:
|
||||
sd_models[0][key] = sum(sd[key] for sd in sd_models) / len(sd_models)
|
||||
|
||||
new_weights = []
|
||||
for layer_weights in zip(*all_weights):
|
||||
layer_weights = np.array([np.array(weights).mean(axis=0)
|
||||
for weights in zip(*layer_weights)])
|
||||
new_weights.append(layer_weights)
|
||||
model.load_state_dict(sd_models[0])
|
||||
os.system("mkdir "+out)
|
||||
torch.save(model.state_dict(), os.path.join(out, "pytorch_model.bin"))
|
||||
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config_eynollah.json ")+out)
|
||||
|
||||
else:
|
||||
weights=[]
|
||||
|
||||
for model_name in ls_models:
|
||||
model = load_model(os.path.join(dir_models,model_name) , compile=False, custom_objects={'PatchEncoder':PatchEncoder, 'Patches': Patches})
|
||||
weights.append(model.get_weights())
|
||||
|
||||
new_weights = list()
|
||||
|
||||
for weights_list_tuple in zip(*weights):
|
||||
new_weights.append(
|
||||
[np.array(weights_).mean(axis=0)\
|
||||
for weights_ in zip(*weights_list_tuple)])
|
||||
|
||||
|
||||
|
||||
new_weights = [np.array(x) for x in new_weights]
|
||||
|
||||
#model = tf.keras.models.clone_model(model)
|
||||
model.set_weights(new_weights)
|
||||
|
||||
model.save(out_dir)
|
||||
os.system('cp ' + os.path.join(model_dirs[0], "config.json ") + out_dir + "/")
|
||||
model.save(out)
|
||||
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config_eynollah.json ")+out)
|
||||
try:
|
||||
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "characters_org.txt ")+out)
|
||||
except:
|
||||
pass
|
||||
|
||||
@click.command()
|
||||
@click.option(
|
||||
"--in",
|
||||
"-i",
|
||||
"in_",
|
||||
help="input directory of checkpoint models to be read",
|
||||
multiple=True,
|
||||
required=True,
|
||||
|
|
@ -55,12 +81,17 @@ def run_ensembling(model_dirs, out_dir):
|
|||
required=True,
|
||||
type=click.Path(exists=False, file_okay=False),
|
||||
)
|
||||
def ensemble_cli(in_, out):
|
||||
@click.option(
|
||||
"--framework",
|
||||
"-fw",
|
||||
help="this parameter gets tensorflow or torch as model framework",
|
||||
)
|
||||
|
||||
def ensemble_cli(in_, out, framework):
|
||||
"""
|
||||
mix multiple model weights
|
||||
|
||||
Load a sequence of models and mix them into a single ensemble model
|
||||
by averaging their weights. Write the resulting model.
|
||||
"""
|
||||
run_ensembling(in_, out)
|
||||
|
||||
run_ensembling(in_, out, framework)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from typing import Iterable, List, Tuple
|
|||
from logging import getLogger
|
||||
import time
|
||||
import math
|
||||
from itertools import islice
|
||||
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
|
|
@ -33,6 +34,11 @@ def pairwise(iterable):
|
|||
yield a, b
|
||||
a = b
|
||||
|
||||
def batched(iterable, n):
|
||||
iterator = iter(iterable)
|
||||
while batch := tuple(islice(iterator, n)):
|
||||
yield batch
|
||||
|
||||
def return_multicol_separators_x_start_end(
|
||||
regions_without_separators, peak_points, top, bot,
|
||||
x_min_hor_some, x_max_hor_some, cy_hor_some, y_min_hor_some, y_max_hor_some):
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from shapely.geometry.polygon import orient
|
|||
from shapely import set_precision, affinity
|
||||
from shapely.ops import unary_union, nearest_points
|
||||
|
||||
from .rotate import rotate_image, rotation_image_new
|
||||
from .rotate import rotate_image
|
||||
|
||||
def contours_in_same_horizon(cy_main_hor):
|
||||
"""
|
||||
|
|
@ -120,94 +120,6 @@ def return_contours_of_interested_region(region_pre_p, label, min_area=0.0002, d
|
|||
dilate=dilate)
|
||||
return contours_imgs
|
||||
|
||||
def do_work_of_contours_in_image(contour, index_r_con, img, slope_first):
|
||||
img_copy = np.zeros(img.shape[:2], dtype=np.uint8)
|
||||
img_copy = cv2.fillPoly(img_copy, pts=[contour], color=1)
|
||||
|
||||
img_copy = rotation_image_new(img_copy, -slope_first)
|
||||
_, thresh = cv2.threshold(img_copy, 0, 255, 0)
|
||||
|
||||
cont_int, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
cont_int[0][:, 0, 0] = cont_int[0][:, 0, 0] + np.abs(img_copy.shape[1] - img.shape[1])
|
||||
cont_int[0][:, 0, 1] = cont_int[0][:, 0, 1] + np.abs(img_copy.shape[0] - img.shape[0])
|
||||
|
||||
return cont_int[0], index_r_con
|
||||
|
||||
def get_textregion_contours_in_org_image_multi(cnts, img, slope_first, map=map):
|
||||
if not len(cnts):
|
||||
return [], []
|
||||
results = map(partial(do_work_of_contours_in_image,
|
||||
img=img,
|
||||
slope_first=slope_first,
|
||||
),
|
||||
cnts, range(len(cnts)))
|
||||
return tuple(zip(*results))
|
||||
|
||||
def get_textregion_contours_in_org_image(cnts, img, slope_first):
|
||||
cnts_org = []
|
||||
# print(cnts,'cnts')
|
||||
for i in range(len(cnts)):
|
||||
img_copy = np.zeros(img.shape[:2], dtype=np.uint8)
|
||||
img_copy = cv2.fillPoly(img_copy, pts=[cnts[i]], color=1)
|
||||
|
||||
# plt.imshow(img_copy)
|
||||
# plt.show()
|
||||
|
||||
# print(img.shape,'img')
|
||||
img_copy = rotation_image_new(img_copy, -slope_first)
|
||||
##print(img_copy.shape,'img_copy')
|
||||
# plt.imshow(img_copy)
|
||||
# plt.show()
|
||||
|
||||
_, thresh = cv2.threshold(img_copy, 0, 255, 0)
|
||||
|
||||
cont_int, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
|
||||
cont_int[0][:, 0, 0] = cont_int[0][:, 0, 0] + np.abs(img_copy.shape[1] - img.shape[1])
|
||||
cont_int[0][:, 0, 1] = cont_int[0][:, 0, 1] + np.abs(img_copy.shape[0] - img.shape[0])
|
||||
# print(np.shape(cont_int[0]))
|
||||
cnts_org.append(cont_int[0])
|
||||
|
||||
return cnts_org
|
||||
|
||||
def get_textregion_confidences_old(cnts, img, slope_first):
|
||||
zoom = 3
|
||||
img = cv2.resize(img, (img.shape[1] // zoom,
|
||||
img.shape[0] // zoom),
|
||||
interpolation=cv2.INTER_NEAREST)
|
||||
cnts_org = []
|
||||
for cnt in cnts:
|
||||
img_copy = np.zeros(img.shape[:2], dtype=np.uint8)
|
||||
img_copy = cv2.fillPoly(img_copy, pts=[cnt // zoom], color=1)
|
||||
|
||||
img_copy = rotation_image_new(img_copy, -slope_first).astype(np.uint8)
|
||||
_, thresh = cv2.threshold(img_copy, 0, 255, 0)
|
||||
|
||||
cont_int, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
|
||||
cont_int[0][:, 0, 0] = cont_int[0][:, 0, 0] + np.abs(img_copy.shape[1] - img.shape[1])
|
||||
cont_int[0][:, 0, 1] = cont_int[0][:, 0, 1] + np.abs(img_copy.shape[0] - img.shape[0])
|
||||
cnts_org.append(cont_int[0] * zoom)
|
||||
|
||||
return cnts_org
|
||||
|
||||
def do_back_rotation_and_get_cnt_back(contour_par, index_r_con, img, slope_first, confidence_matrix):
|
||||
img_copy = np.zeros(img.shape[:2], dtype=np.uint8)
|
||||
img_copy = cv2.fillPoly(img_copy, pts=[contour_par], color=1)
|
||||
confidence_matrix_mapped_with_contour = confidence_matrix * img_copy
|
||||
confidence_contour = np.sum(confidence_matrix_mapped_with_contour) / float(np.sum(img_copy))
|
||||
|
||||
img_copy = rotation_image_new(img_copy, -slope_first).astype(np.uint8)
|
||||
_, thresh = cv2.threshold(img_copy, 0, 255, 0)
|
||||
|
||||
cont_int, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
|
||||
if len(cont_int)==0:
|
||||
cont_int = [contour_par]
|
||||
confidence_contour = 0
|
||||
else:
|
||||
cont_int[0][:, 0, 0] = cont_int[0][:, 0, 0] + np.abs(img_copy.shape[1] - img.shape[1])
|
||||
cont_int[0][:, 0, 1] = cont_int[0][:, 0, 1] + np.abs(img_copy.shape[0] - img.shape[0])
|
||||
return cont_int[0], index_r_con, confidence_contour
|
||||
|
||||
def get_region_confidences(cnts, confidence_matrix):
|
||||
if not len(cnts):
|
||||
return []
|
||||
|
|
@ -418,7 +330,7 @@ def estimate_skew_contours(contours):
|
|||
if not np.any(usable):
|
||||
raise ValueError("not enough contours with consistent length")
|
||||
if np.count_nonzero(usable) == 1:
|
||||
return angle_in[usable]
|
||||
return angle_in[usable][0]
|
||||
# 4. there is no way to distinguish between +90 and -89.9 here,
|
||||
# so map to [0,180] when calculating averages, then map back to [-90,90]
|
||||
# (we don't want -90 and +89 to average zero, or +1 and +179 to average 90)
|
||||
|
|
|
|||
|
|
@ -9,8 +9,8 @@ else:
|
|||
import importlib.resources as importlib_resources
|
||||
|
||||
|
||||
def get_font():
|
||||
def get_font(font_size):
|
||||
#font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
|
||||
font = importlib_resources.files(__package__) / "../Charis-Regular.ttf"
|
||||
font = importlib_resources.files(__package__) / "../Amiri-Regular.ttf"
|
||||
with importlib_resources.as_file(font) as font:
|
||||
return ImageFont.truetype(font=font, size=40)
|
||||
return ImageFont.truetype(font=font, size=font_size)
|
||||
|
|
|
|||
|
|
@ -2,10 +2,6 @@ import math
|
|||
import cv2
|
||||
|
||||
|
||||
def rotation_image_new(img, thetha):
|
||||
rotated = rotate_image(img, thetha)
|
||||
return rotate_max_area_new(img, rotated, thetha)
|
||||
|
||||
def rotate_image(img_patch, slope):
|
||||
(h, w) = img_patch.shape[:2]
|
||||
center = (w // 2, h // 2)
|
||||
|
|
|
|||
|
|
@ -3,15 +3,20 @@ import copy
|
|||
|
||||
import numpy as np
|
||||
import cv2
|
||||
import tensorflow as tf
|
||||
# avoid module-level import:
|
||||
# import tensorflow as tf
|
||||
# (wait for tf-keras and logging setup in ModelZoo.load_model)
|
||||
from scipy.signal import find_peaks
|
||||
from scipy.ndimage import gaussian_filter1d
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
from . import pairwise
|
||||
from .resize import resize_image
|
||||
|
||||
|
||||
def decode_batch_predictions(pred, num_to_char, max_len = 128):
|
||||
import tensorflow as tf
|
||||
|
||||
# input_len is the product of the batch size and the
|
||||
# number of time steps.
|
||||
input_len = np.ones(pred.shape[0]) * pred.shape[1]
|
||||
|
|
@ -37,43 +42,6 @@ def decode_batch_predictions(pred, num_to_char, max_len = 128):
|
|||
output.append(d)
|
||||
return output
|
||||
|
||||
|
||||
def distortion_free_resize(image, img_size):
|
||||
w, h = img_size
|
||||
image = tf.image.resize(image, size=(h, w), preserve_aspect_ratio=True)
|
||||
|
||||
# Check tha amount of padding needed to be done.
|
||||
pad_height = h - tf.shape(image)[0]
|
||||
pad_width = w - tf.shape(image)[1]
|
||||
|
||||
# Only necessary if you want to do same amount of padding on both sides.
|
||||
if pad_height % 2 != 0:
|
||||
height = pad_height // 2
|
||||
pad_height_top = height + 1
|
||||
pad_height_bottom = height
|
||||
else:
|
||||
pad_height_top = pad_height_bottom = pad_height // 2
|
||||
|
||||
if pad_width % 2 != 0:
|
||||
width = pad_width // 2
|
||||
pad_width_left = width + 1
|
||||
pad_width_right = width
|
||||
else:
|
||||
pad_width_left = pad_width_right = pad_width // 2
|
||||
|
||||
image = tf.pad(
|
||||
image,
|
||||
paddings=[
|
||||
[pad_height_top, pad_height_bottom],
|
||||
[pad_width_left, pad_width_right],
|
||||
[0, 0],
|
||||
],
|
||||
)
|
||||
|
||||
image = tf.transpose(image, (1, 0, 2))
|
||||
image = tf.image.flip_left_right(image)
|
||||
return image
|
||||
|
||||
def return_start_and_end_of_common_text_of_textline_ocr_without_common_section(textline_image):
|
||||
width = np.shape(textline_image)[1]
|
||||
height = np.shape(textline_image)[0]
|
||||
|
|
@ -256,249 +224,58 @@ def return_splitting_point_of_image(image_to_spliited):
|
|||
|
||||
return np.sort(peaks_sort_4)
|
||||
|
||||
def break_curved_line_into_small_pieces_and_then_merge(img_curved, mask_curved, img_bin_curved=None):
|
||||
peaks_4 = return_splitting_point_of_image(img_curved)
|
||||
if len(peaks_4)>0:
|
||||
def break_curved_line_into_small_pieces_and_then_merge(img_rgb_curved, img_bin_curved, mask_curved):
|
||||
peaks_4 = return_splitting_point_of_image(img_rgb_curved)
|
||||
if len(peaks_4):
|
||||
imgs_tot = []
|
||||
|
||||
for ind in range(len(peaks_4)+1):
|
||||
if ind==0:
|
||||
img = img_curved[:, :peaks_4[ind], :]
|
||||
if img_bin_curved is not None:
|
||||
img_bin = img_bin_curved[:, :peaks_4[ind], :]
|
||||
mask = mask_curved[:, :peaks_4[ind], :]
|
||||
elif ind==len(peaks_4):
|
||||
img = img_curved[:, peaks_4[ind-1]:, :]
|
||||
if img_bin_curved is not None:
|
||||
img_bin = img_bin_curved[:, peaks_4[ind-1]:, :]
|
||||
mask = mask_curved[:, peaks_4[ind-1]:, :]
|
||||
else:
|
||||
img = img_curved[:, peaks_4[ind-1]:peaks_4[ind], :]
|
||||
if img_bin_curved is not None:
|
||||
img_bin = img_bin_curved[:, peaks_4[ind-1]:peaks_4[ind], :]
|
||||
mask = mask_curved[:, peaks_4[ind-1]:peaks_4[ind], :]
|
||||
|
||||
for left, right in pairwise([None] + peaks_4 + [None]):
|
||||
img_rgb = img_rgb_curved[:, left: right]
|
||||
img_bin = img_bin_curved[:, left: right]
|
||||
mask = mask_curved[:, left: right]
|
||||
or_ma = get_orientation_moments_of_mask(mask)
|
||||
|
||||
if img_bin_curved is not None:
|
||||
imgs_tot.append([img, mask, or_ma, img_bin] )
|
||||
else:
|
||||
imgs_tot.append([img, mask, or_ma] )
|
||||
|
||||
imgs_tot.append([img_rgb, img_bin, mask, or_ma])
|
||||
|
||||
w_tot_des_list = []
|
||||
w_tot_des = 0
|
||||
imgs_deskewed_list = []
|
||||
imgs_rgb_deskewed_list = []
|
||||
imgs_bin_deskewed_list = []
|
||||
|
||||
for ind in range(len(imgs_tot)):
|
||||
img_in = imgs_tot[ind][0]
|
||||
mask_in = imgs_tot[ind][1]
|
||||
ori_in = imgs_tot[ind][2]
|
||||
if img_bin_curved is not None:
|
||||
img_bin_in = imgs_tot[ind][3]
|
||||
|
||||
if abs(ori_in)<45:
|
||||
img_in_des = rotate_image_with_padding(img_in, ori_in, border_value=(255,255,255) )
|
||||
if img_bin_curved is not None:
|
||||
for img_rgb_in, img_bin_in, mask_in, ori_in in imgs_tot:
|
||||
if abs(ori_in) < 45:
|
||||
img_rgb_in_des = rotate_image_with_padding(img_rgb_in, ori_in, border_value=(255,255,255) )
|
||||
img_bin_in_des = rotate_image_with_padding(img_bin_in, ori_in, border_value=(255,255,255) )
|
||||
mask_in_des = rotate_image_with_padding(mask_in, ori_in)
|
||||
mask_in_des = mask_in_des.astype('uint8')
|
||||
|
||||
#new bounding box
|
||||
x_n, y_n, w_n, h_n = get_contours_and_bounding_boxes(mask_in_des[:,:,0])
|
||||
|
||||
if w_n==0 or h_n==0:
|
||||
img_in_des = np.copy(img_in)
|
||||
if img_bin_curved is not None:
|
||||
img_bin_in_des = np.copy(img_bin_in)
|
||||
w_relative = int(32 * img_in_des.shape[1]/float(img_in_des.shape[0]) )
|
||||
if w_relative==0:
|
||||
w_relative = img_in_des.shape[1]
|
||||
img_in_des = resize_image(img_in_des, 32, w_relative)
|
||||
if img_bin_curved is not None:
|
||||
img_bin_in_des = resize_image(img_bin_in_des, 32, w_relative)
|
||||
# get new bounding box
|
||||
x_n, y_n, w_n, h_n = get_contours_and_bounding_boxes(mask_in_des)
|
||||
if w_n and h_n:
|
||||
img_rgb_in_des = img_rgb_in_des[y_n: y_n + h_n, x_n: x_n + w_n]
|
||||
img_bin_in_des = img_bin_in_des[y_n: y_n + h_n, x_n: x_n + w_n]
|
||||
else:
|
||||
mask_in_des = mask_in_des[y_n:y_n+h_n, x_n:x_n+w_n, :]
|
||||
img_in_des = img_in_des[y_n:y_n+h_n, x_n:x_n+w_n, :]
|
||||
if img_bin_curved is not None:
|
||||
img_bin_in_des = img_bin_in_des[y_n:y_n+h_n, x_n:x_n+w_n, :]
|
||||
|
||||
w_relative = int(32 * img_in_des.shape[1]/float(img_in_des.shape[0]) )
|
||||
if w_relative==0:
|
||||
w_relative = img_in_des.shape[1]
|
||||
img_in_des = resize_image(img_in_des, 32, w_relative)
|
||||
if img_bin_curved is not None:
|
||||
img_bin_in_des = resize_image(img_bin_in_des, 32, w_relative)
|
||||
|
||||
|
||||
else:
|
||||
img_in_des = np.copy(img_in)
|
||||
if img_bin_curved is not None:
|
||||
img_rgb_in_des = np.copy(img_rgb_in)
|
||||
img_bin_in_des = np.copy(img_bin_in)
|
||||
else:
|
||||
img_rgb_in_des = np.copy(img_rgb_in)
|
||||
img_bin_in_des = np.copy(img_bin_in)
|
||||
w_relative = int(32 * img_in_des.shape[1]/float(img_in_des.shape[0]) )
|
||||
if w_relative==0:
|
||||
w_relative = img_in_des.shape[1]
|
||||
img_in_des = resize_image(img_in_des, 32, w_relative)
|
||||
if img_bin_curved is not None:
|
||||
img_bin_in_des = resize_image(img_bin_in_des, 32, w_relative)
|
||||
|
||||
w_tot_des+=img_in_des.shape[1]
|
||||
w_tot_des_list.append(img_in_des.shape[1])
|
||||
imgs_deskewed_list.append(img_in_des)
|
||||
if img_bin_curved is not None:
|
||||
h, w = img_rgb_in_des.shape[:2]
|
||||
new_h = 32
|
||||
new_w = 32 * w // h
|
||||
new_w = new_w or w
|
||||
img_rgb_in_des = resize_image(img_rgb_in_des, new_h, new_w)
|
||||
img_bin_in_des = resize_image(img_bin_in_des, new_h, new_w)
|
||||
|
||||
w_tot_des_list.append(new_w)
|
||||
imgs_rgb_deskewed_list.append(img_rgb_in_des)
|
||||
imgs_bin_deskewed_list.append(img_bin_in_des)
|
||||
|
||||
|
||||
|
||||
|
||||
img_final_deskewed = np.zeros((32, w_tot_des, 3))+255
|
||||
if img_bin_curved is not None:
|
||||
img_bin_final_deskewed = np.zeros((32, w_tot_des, 3))+255
|
||||
else:
|
||||
img_bin_final_deskewed = None
|
||||
img_rgb_final_deskewed = np.ones((new_h, sum(w_tot_des_list), 3)) * 255
|
||||
img_bin_final_deskewed = np.ones((new_h, sum(w_tot_des_list), 3)) * 255
|
||||
|
||||
w_indexer = 0
|
||||
for ind in range(len(w_tot_des_list)):
|
||||
img_final_deskewed[:,w_indexer:w_indexer+w_tot_des_list[ind],:] = imgs_deskewed_list[ind][:,:,:]
|
||||
if img_bin_curved is not None:
|
||||
img_bin_final_deskewed[:,w_indexer:w_indexer+w_tot_des_list[ind],:] = imgs_bin_deskewed_list[ind][:,:,:]
|
||||
w_indexer = w_indexer+w_tot_des_list[ind]
|
||||
return img_final_deskewed, img_bin_final_deskewed
|
||||
w_indexer2 = w_indexer + w_tot_des_list[ind]
|
||||
img_rgb_final_deskewed[:, w_indexer: w_indexer2] = imgs_rgb_deskewed_list[ind]
|
||||
img_bin_final_deskewed[:, w_indexer: w_indexer2] = imgs_bin_deskewed_list[ind]
|
||||
w_indexer = w_indexer2
|
||||
return img_rgb_final_deskewed, img_bin_final_deskewed
|
||||
else:
|
||||
return img_curved, img_bin_curved
|
||||
|
||||
def return_textline_contour_with_added_box_coordinate(textline_contour, box_ind):
|
||||
textline_contour[:,:,0] += box_ind[2]
|
||||
textline_contour[:,:,1] += box_ind[0]
|
||||
return textline_contour
|
||||
|
||||
|
||||
def return_rnn_cnn_ocr_of_given_textlines(image,
|
||||
all_found_textline_polygons,
|
||||
all_box_coord,
|
||||
prediction_model,
|
||||
b_s_ocr, num_to_char,
|
||||
curved_line=False):
|
||||
max_len = 512
|
||||
padding_token = 299
|
||||
image_width = 512#max_len * 4
|
||||
image_height = 32
|
||||
ind_tot = 0
|
||||
#cv2.imwrite('./img_out.png', image_page)
|
||||
ocr_all_textlines = []
|
||||
cropped_lines_region_indexer = []
|
||||
cropped_lines_meging_indexing = []
|
||||
cropped_lines = []
|
||||
indexer_text_region = 0
|
||||
|
||||
for indexing, ind_poly_first in enumerate(all_found_textline_polygons):
|
||||
#ocr_textline_in_textregion = []
|
||||
if len(ind_poly_first)==0:
|
||||
cropped_lines_region_indexer.append(indexer_text_region)
|
||||
cropped_lines_meging_indexing.append(0)
|
||||
img_fin = np.ones((image_height, image_width, 3))*1
|
||||
cropped_lines.append(img_fin)
|
||||
|
||||
else:
|
||||
for indexing2, ind_poly in enumerate(ind_poly_first):
|
||||
cropped_lines_region_indexer.append(indexer_text_region)
|
||||
if not curved_line:
|
||||
ind_poly = copy.deepcopy(ind_poly)
|
||||
box_ind = all_box_coord[indexing]
|
||||
|
||||
ind_poly = return_textline_contour_with_added_box_coordinate(ind_poly, box_ind)
|
||||
#print(ind_poly_copy)
|
||||
ind_poly[ind_poly<0] = 0
|
||||
x, y, w, h = cv2.boundingRect(ind_poly)
|
||||
|
||||
w_scaled = w * image_height/float(h)
|
||||
|
||||
mask_poly = np.zeros(image.shape)
|
||||
|
||||
img_poly_on_img = np.copy(image)
|
||||
|
||||
mask_poly = cv2.fillPoly(mask_poly, pts=[ind_poly], color=(1, 1, 1))
|
||||
|
||||
|
||||
|
||||
mask_poly = mask_poly[y:y+h, x:x+w, :]
|
||||
img_crop = img_poly_on_img[y:y+h, x:x+w, :]
|
||||
|
||||
img_crop[mask_poly==0] = 255
|
||||
|
||||
if w_scaled < 640:#1.5*image_width:
|
||||
img_fin = preprocess_and_resize_image_for_ocrcnn_model(img_crop, image_height, image_width)
|
||||
cropped_lines.append(img_fin)
|
||||
cropped_lines_meging_indexing.append(0)
|
||||
else:
|
||||
splited_images, splited_images_bin = return_textlines_split_if_needed(img_crop, None)
|
||||
|
||||
if splited_images:
|
||||
img_fin = preprocess_and_resize_image_for_ocrcnn_model(splited_images[0],
|
||||
image_height,
|
||||
image_width)
|
||||
cropped_lines.append(img_fin)
|
||||
cropped_lines_meging_indexing.append(1)
|
||||
|
||||
img_fin = preprocess_and_resize_image_for_ocrcnn_model(splited_images[1],
|
||||
image_height,
|
||||
image_width)
|
||||
|
||||
cropped_lines.append(img_fin)
|
||||
cropped_lines_meging_indexing.append(-1)
|
||||
|
||||
else:
|
||||
img_fin = preprocess_and_resize_image_for_ocrcnn_model(img_crop,
|
||||
image_height,
|
||||
image_width)
|
||||
cropped_lines.append(img_fin)
|
||||
cropped_lines_meging_indexing.append(0)
|
||||
|
||||
indexer_text_region+=1
|
||||
|
||||
extracted_texts = []
|
||||
|
||||
n_iterations = math.ceil(len(cropped_lines) / b_s_ocr)
|
||||
|
||||
for i in range(n_iterations):
|
||||
if i==(n_iterations-1):
|
||||
n_start = i*b_s_ocr
|
||||
imgs = cropped_lines[n_start:]
|
||||
imgs = np.array(imgs)
|
||||
imgs = imgs.reshape(imgs.shape[0], image_height, image_width, 3)
|
||||
|
||||
|
||||
else:
|
||||
n_start = i*b_s_ocr
|
||||
n_end = (i+1)*b_s_ocr
|
||||
imgs = cropped_lines[n_start:n_end]
|
||||
imgs = np.array(imgs).reshape(b_s_ocr, image_height, image_width, 3)
|
||||
|
||||
|
||||
preds = prediction_model.predict(imgs, verbose=0)
|
||||
|
||||
pred_texts = decode_batch_predictions(preds, num_to_char)
|
||||
|
||||
for ib in range(imgs.shape[0]):
|
||||
pred_texts_ib = pred_texts[ib].replace("[UNK]", "")
|
||||
extracted_texts.append(pred_texts_ib)
|
||||
|
||||
extracted_texts_merged = [extracted_texts[ind]
|
||||
if cropped_lines_meging_indexing[ind]==0
|
||||
else extracted_texts[ind]+" "+extracted_texts[ind+1]
|
||||
if cropped_lines_meging_indexing[ind]==1
|
||||
else None
|
||||
for ind in range(len(cropped_lines_meging_indexing))]
|
||||
|
||||
extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None]
|
||||
unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer)
|
||||
|
||||
ocr_all_textlines = []
|
||||
for ind in unique_cropped_lines_region_indexer:
|
||||
ocr_textline_in_textregion = []
|
||||
extracted_texts_merged_un = np.array(extracted_texts_merged)[np.array(cropped_lines_region_indexer)==ind]
|
||||
for it_ind, text_textline in enumerate(extracted_texts_merged_un):
|
||||
ocr_textline_in_textregion.append(text_textline)
|
||||
ocr_all_textlines.append(ocr_textline_in_textregion)
|
||||
return ocr_all_textlines
|
||||
return img_rgb_curved, img_bin_curved
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from typing import List
|
||||
import os
|
||||
import pytest
|
||||
import logging
|
||||
|
||||
|
|
@ -31,6 +32,8 @@ def run_eynollah_ok_and_check_logs(
|
|||
subcommand,
|
||||
*args
|
||||
]
|
||||
if 'EYNOLLAH_OPTIONS' in os.environ:
|
||||
args = os.environ['EYNOLLAH_OPTIONS'].split() + args
|
||||
if pytestconfig.getoption('verbose') > 0:
|
||||
args = ['-l', 'DEBUG'] + args
|
||||
caplog.set_level(logging.INFO)
|
||||
|
|
|
|||
|
|
@ -6,11 +6,12 @@ from ocrd_models.constants import NAMESPACES as NS
|
|||
"options",
|
||||
[
|
||||
[], # defaults
|
||||
#["--allow_scaling", "--curved-line"],
|
||||
["--allow_scaling", "--curved-line", "--full-layout"],
|
||||
["--allow_scaling", "--curved-line", "--full-layout", "--reading_order_machine_based"],
|
||||
#["--curved-line"],
|
||||
["--curved-line", "--full-layout"],
|
||||
["--curved-line", "--full-layout", "--reading_order_machine_based"],
|
||||
# -ep ...
|
||||
# -eoi ...
|
||||
# --input_binary
|
||||
# --ignore_page_extraction
|
||||
# --skip_layout_and_reading_order
|
||||
], ids=str)
|
||||
def test_run_eynollah_layout_filename(
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ def test_run_eynollah_ocr_filename(
|
|||
'-o', str(outfile.parent),
|
||||
] + options,
|
||||
[
|
||||
# FIXME: ocr has no logging!
|
||||
'output filename:'
|
||||
]
|
||||
)
|
||||
assert outfile.exists()
|
||||
|
|
@ -57,7 +57,7 @@ def test_run_eynollah_ocr_directory(
|
|||
'-o', str(outdir),
|
||||
],
|
||||
[
|
||||
# FIXME: ocr has no logging!
|
||||
'output filename:'
|
||||
]
|
||||
)
|
||||
assert len(list(outdir.iterdir())) == 2
|
||||
|
|
|
|||
|
|
@ -1,16 +1,28 @@
|
|||
from eynollah.model_zoo import EynollahModelZoo
|
||||
from eynollah.predictor import Predictor
|
||||
|
||||
def test_trocr1(
|
||||
model_dir,
|
||||
):
|
||||
model_zoo = EynollahModelZoo(model_dir)
|
||||
try:
|
||||
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
||||
model_zoo.load_models('trocr_processor')
|
||||
proc = model_zoo.get('trocr_processor')
|
||||
assert isinstance(proc, TrOCRProcessor)
|
||||
model_zoo.load_models(['ocr', 'tr'])
|
||||
model_zoo.load_models(('ocr', 'tr'))
|
||||
model = model_zoo.get('ocr')
|
||||
assert isinstance(model, VisionEncoderDecoderModel)
|
||||
assert isinstance(model, Predictor)
|
||||
shape = model.input_shape
|
||||
assert len(shape) == 3
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
def test_cnnrnnocr1(
|
||||
model_dir,
|
||||
):
|
||||
model_zoo = EynollahModelZoo(model_dir)
|
||||
try:
|
||||
model_zoo.load_models('ocr')
|
||||
model = model_zoo.get('ocr')
|
||||
assert isinstance(model, Predictor)
|
||||
shape = model.input_shape
|
||||
assert len(shape) == 4
|
||||
except ImportError:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,17 +1,17 @@
|
|||
{
|
||||
"backbone_type" : "transformer",
|
||||
"task": "cnn-rnn-ocr",
|
||||
"task": "transformer-ocr",
|
||||
"n_classes" : 2,
|
||||
"max_len": 280,
|
||||
"n_epochs" : 3,
|
||||
"max_len": 192,
|
||||
"n_epochs" : 1,
|
||||
"input_height" : 32,
|
||||
"input_width" : 512,
|
||||
"weight_decay" : 1e-6,
|
||||
"n_batch" : 4,
|
||||
"n_batch" : 1,
|
||||
"learning_rate": 1e-5,
|
||||
"save_interval": 1500,
|
||||
"patches" : false,
|
||||
"pretraining" : true,
|
||||
"pretraining" : false,
|
||||
"augmentation" : true,
|
||||
"flip_aug" : false,
|
||||
"blur_aug" : true,
|
||||
|
|
@ -77,7 +77,6 @@
|
|||
"dir_output": "/home/vahid/extracted_lines/1919_bin/output",
|
||||
"dir_rgb_backgrounds": "/home/vahid/Documents/1_2_test_eynollah/set_rgb_background",
|
||||
"dir_rgb_foregrounds": "/home/vahid/Documents/1_2_test_eynollah/out_set_rgb_foreground",
|
||||
"dir_img_bin": "/home/vahid/extracted_lines/1919_bin/images_bin",
|
||||
"characters_txt_file":"/home/vahid/Downloads/models_eynollah/model_eynollah_ocr_cnnrnn_20250930/characters_org.txt"
|
||||
"dir_img_bin": "/home/vahid/extracted_lines/1919_bin/images_bin"
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,3 +8,8 @@ tensorflow-addons # for connected_components, depublished and only compatible wi
|
|||
tensorflow < 2.16 # for tensorflow-addons, so only needed in training
|
||||
tf_data < 2.16 # for tensorflow-addons, so only needed in training
|
||||
protobuf < 5 # for tensorflow-addons, so only needed in training
|
||||
torch
|
||||
evaluate
|
||||
accelerate
|
||||
jiwer
|
||||
transformers <= 4.30.2
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue