Merge branch 'integrating_trocr_and_torch_ensembling_and_updating_characters_list'

# Conflicts:
#	src/eynollah/eynollah.py
#	src/eynollah/eynollah_ocr.py
#	src/eynollah/patch_encoder.py
#	src/eynollah/training/cli.py
#	src/eynollah/training/gt_gen_utils.py
#	src/eynollah/training/inference.py
#	src/eynollah/training/models.py
#	src/eynollah/training/train.py
#	src/eynollah/training/utils.py
#	src/eynollah/training/weights_ensembling.py
#	train/requirements.txt
This commit is contained in:
kba 2026-06-11 18:59:33 +02:00
commit 6df11d92d8
16 changed files with 1048 additions and 245 deletions

Binary file not shown.

View file

@ -69,10 +69,11 @@ class Eynollah_ocr:
self.model_zoo.load_models(['ocr', 'tr']) self.model_zoo.load_models(['ocr', 'tr'])
self.model_zoo.get('ocr').to(self.device) self.model_zoo.get('ocr').to(self.device)
else: else:
self.model_zoo.load_models('ocr') self.model_zoo.load_model('ocr', '')
self.model_zoo.load_models('num_to_char') self.input_shape = self.model_zoo.get('ocr').input_shape[1:3]
self.model_zoo.load_models('characters') self.model_zoo.load_model('num_to_char')
self.end_character = len(self.model_zoo.get('characters')) + 2 self.model_zoo.load_model('characters')
self.end_character = len(self.model_zoo.get('characters', list)) + 2
@property @property
def device(self): def device(self):
@ -657,7 +658,7 @@ class Eynollah_ocr:
if out_image_with_text: if out_image_with_text:
image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white") image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white")
draw = ImageDraw.Draw(image_text) draw = ImageDraw.Draw(image_text)
font = get_font() font = get_font(font_size=40)
for indexer_text, bb_ind in enumerate(total_bb_coordinates): for indexer_text, bb_ind in enumerate(total_bb_coordinates):
x_bb = bb_ind[0] x_bb = bb_ind[0]
@ -823,8 +824,8 @@ class Eynollah_ocr:
page_ns=page_ns, page_ns=page_ns,
img_bin=img_bin, img_bin=img_bin,
image_width=512, image_width=self.input_shape[1],
image_height=32, image_height=self.input_shape[0],
) )
self.write_ocr( self.write_ocr(

View file

@ -20,6 +20,7 @@ class PatchEncoder(layers.Layer):
def get_config(self): def get_config(self):
return dict(num_patches=self.num_patches, return dict(num_patches=self.num_patches,
projection_dim=self.projection_dim, projection_dim=self.projection_dim,
position_embedding=self.position_embedding,
**super().get_config()) **super().get_config())
class Patches(layers.Layer): class Patches(layers.Layer):

View file

@ -9,7 +9,12 @@ from .generate_gt_for_training import main as generate_gt_cli
from .inference import main as inference_cli from .inference import main as inference_cli
from .train import ex from .train import ex
from .extract_line_gt import linegt_cli from .extract_line_gt import linegt_cli
<<<<<<< HEAD
from .weights_ensembling import ensemble_cli from .weights_ensembling import ensemble_cli
=======
from .weights_ensembling import main as ensemble_cli
from .generate_or_update_cnn_rnn_ocr_character_list import main as update_ocr_characters_cli
>>>>>>> integrating_trocr_and_torch_ensembling_and_updating_characters_list
@click.command(context_settings=dict( @click.command(context_settings=dict(
ignore_unknown_options=True, ignore_unknown_options=True,
@ -28,3 +33,4 @@ main.add_command(inference_cli, 'inference')
main.add_command(train_cli, 'train') main.add_command(train_cli, 'train')
main.add_command(linegt_cli, 'export_textline_images_and_text') main.add_command(linegt_cli, 'export_textline_images_and_text')
main.add_command(ensemble_cli, 'ensembling') main.add_command(ensemble_cli, 'ensembling')
main.add_command(update_ocr_characters_cli, 'generate_or_update_cnn_rnn_ocr_character_list')

View file

@ -50,6 +50,12 @@ from ..utils import is_image_filename
is_flag=True, is_flag=True,
help="if this parameter set to true, cropped textline images will not be masked with textline contour.", help="if this parameter set to true, cropped textline images will not be masked with textline contour.",
) )
@click.option(
"--exclude_vertical_lines",
"-exv",
is_flag=True,
help="if this parameter set to true, vertical textline images will be excluded.",
)
def linegt_cli( def linegt_cli(
image, image,
dir_in, dir_in,
@ -57,6 +63,7 @@ def linegt_cli(
dir_out, dir_out,
pref_of_dataset, pref_of_dataset,
do_not_mask_with_textline_contour, do_not_mask_with_textline_contour,
exclude_vertical_lines,
): ):
assert bool(dir_in) ^ bool(image), "Set --dir-in or --image-filename, not both" assert bool(dir_in) ^ bool(image), "Set --dir-in or --image-filename, not both"
if dir_in: if dir_in:
@ -70,14 +77,13 @@ def linegt_cli(
for dir_img in ls_imgs: for dir_img in ls_imgs:
file_name = Path(dir_img).stem file_name = Path(dir_img).stem
dir_xml = os.path.join(dir_xmls, file_name + '.xml') dir_xml = os.path.join(dir_xmls, file_name + '.xml')
img = cv2.imread(dir_img) img = cv2.imread(dir_img)
total_bb_coordinates = [] total_bb_coordinates = []
tree1 = ET.parse(dir_xml, parser=ET.XMLParser(encoding="utf-8")) tree = ET.parse(dir_xml, parser=ET.XMLParser(encoding="utf-8"))
root1 = tree1.getroot() root = tree.getroot()
alltags = [elem.tag for elem in root1.iter()] alltags = [elem.tag for elem in root.iter()]
name_space = alltags[0].split('}')[0] name_space = alltags[0].split('}')[0]
name_space = name_space.split('{')[1] name_space = name_space.split('{')[1]
@ -89,7 +95,7 @@ def linegt_cli(
indexer_text_region = 0 indexer_text_region = 0
indexer_textlines = 0 indexer_textlines = 0
# FIXME: non recursive, use OCR-D PAGE generateDS API. Or use an existing tool for this purpose altogether # FIXME: non recursive, use OCR-D PAGE generateDS API. Or use an existing tool for this purpose altogether
for nn in root1.iter(region_tags): for nn in root.iter(region_tags):
for child_textregion in nn: for child_textregion in nn:
if child_textregion.tag.endswith("TextLine"): if child_textregion.tag.endswith("TextLine"):
for child_textlines in child_textregion: for child_textlines in child_textregion:
@ -99,6 +105,10 @@ def linegt_cli(
textline_coords = np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h]) textline_coords = np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h])
x, y, w, h = cv2.boundingRect(textline_coords) x, y, w, h = cv2.boundingRect(textline_coords)
if exclude_vertical_lines and h > 1.4 * w:
img_crop = None
continue
total_bb_coordinates.append([x, y, w, h]) total_bb_coordinates.append([x, y, w, h])
@ -114,12 +124,15 @@ def linegt_cli(
img_crop[mask_poly == 0] = 255 img_crop[mask_poly == 0] = 255
if img_crop.shape[0] == 0 or img_crop.shape[1] == 0: if img_crop.shape[0] == 0 or img_crop.shape[1] == 0:
img_crop = None
continue continue
if child_textlines.tag.endswith("TextEquiv"): if child_textlines.tag.endswith("TextEquiv"):
for cheild_text in child_textlines: for cheild_text in child_textlines:
if cheild_text.tag.endswith("Unicode"): if cheild_text.tag.endswith("Unicode"):
textline_text = cheild_text.text textline_text = cheild_text.text
if textline_text: if textline_text and img_crop is not None:
base_name = os.path.join( base_name = os.path.join(
dir_out, file_name + '_line_' + str(indexer_textlines) dir_out, file_name + '_line_' + str(indexer_textlines)
) )
@ -131,4 +144,4 @@ def linegt_cli(
with open(base_name + '.txt', 'w') as text_file: with open(base_name + '.txt', 'w') as text_file:
text_file.write(textline_text) text_file.write(textline_text)
cv2.imwrite(base_name + '.png', img_crop) cv2.imwrite(base_name + '.png', img_crop)
indexer_textlines += 1 indexer_textlines += 1

View file

@ -6,6 +6,7 @@ from pathlib import Path
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
import cv2 import cv2
import numpy as np import numpy as np
from eynollah.utils.font import get_font
from .gt_gen_utils import ( from .gt_gen_utils import (
filter_contours_area_of_image, filter_contours_area_of_image,
@ -393,11 +394,15 @@ def visualize_reading_order(xml_file, dir_xml, dir_out, dir_imgs):
layout = np.zeros( (y_len,x_len,3) ) layout = np.zeros( (y_len,x_len,3) )
layout = cv2.fillPoly(layout, pts =co_text_all, color=(1,1,1)) layout = cv2.fillPoly(layout, pts =co_text_all, color=(1,1,1))
img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name) try:
img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format)) img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name)
img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format))
overlayed = overlay_layout_on_image(layout, img, cx_ordered, cy_ordered, color, thickness)
cv2.imwrite(os.path.join(dir_out, f_name+'.png'), overlayed) overlayed = overlay_layout_on_image(layout, img, cx_ordered, cy_ordered, color, thickness)
cv2.imwrite(os.path.join(dir_out, f_name+'.png'), overlayed)
except:
pass
else: else:
img = np.zeros( (y_len,x_len,3) ) img = np.zeros( (y_len,x_len,3) )
@ -452,14 +457,17 @@ def visualize_textline_segmentation(xml_file, dir_xml, dir_out, dir_imgs):
xml_file = os.path.join(dir_xml,ind_xml ) xml_file = os.path.join(dir_xml,ind_xml )
f_name = Path(ind_xml).stem f_name = Path(ind_xml).stem
img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name) try:
img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format)) img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name)
img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format))
co_tetxlines, y_len, x_len = get_textline_contours_for_visualization(xml_file)
co_tetxlines, y_len, x_len = get_textline_contours_for_visualization(xml_file) added_image = visualize_image_from_contours(co_tetxlines, img)
added_image = visualize_image_from_contours(co_tetxlines, img) cv2.imwrite(os.path.join(dir_out, f_name+'.png'), added_image)
except:
cv2.imwrite(os.path.join(dir_out, f_name+'.png'), added_image) pass
@ -509,15 +517,17 @@ def visualize_layout_segmentation(xml_file, dir_xml, dir_out, dir_imgs):
f_name = Path(ind_xml).stem f_name = Path(ind_xml).stem
print(f_name, 'f_name') print(f_name, 'f_name')
img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name) try:
img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format)) img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name)
img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format))
co_text, co_graphic, co_sep, co_img, co_table, co_map, co_music, co_noise, y_len, x_len = get_layout_contours_for_visualization(xml_file)
co_text, co_graphic, co_sep, co_img, co_table, co_map, co_noise, y_len, x_len = get_layout_contours_for_visualization(xml_file) added_image = visualize_image_from_contours_layout(co_text['paragraph'], co_text['header']+co_text['heading'], co_text['drop-capital'], co_sep, co_img, co_text['marginalia'], co_table, co_map, co_music, img)
added_image = visualize_image_from_contours_layout(co_text['paragraph'], co_text['header']+co_text['heading'], co_text['drop-capital'], co_sep, co_img, co_text['marginalia'], co_table, img)
cv2.imwrite(os.path.join(dir_out, f_name+'.png'), added_image) cv2.imwrite(os.path.join(dir_out, f_name+'.png'), added_image)
except:
pass
@ -552,8 +562,8 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out):
else: else:
xml_files_ind = [xml_file] xml_files_ind = [xml_file]
font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists! ###font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
font = ImageFont.truetype(font_path, 40) font = get_font(font_size=40)#ImageFont.truetype(font_path, 40)
for ind_xml in tqdm(xml_files_ind): for ind_xml in tqdm(xml_files_ind):
indexer = 0 indexer = 0
@ -590,11 +600,11 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out):
is_vertical = h > 2*w # Check orientation is_vertical = h > 2*w # Check orientation
font = fit_text_single_line(draw, ocr_texts[index], font_path, w, int(h*0.4) ) font = fit_text_single_line(draw, ocr_texts[index], w, int(h*0.4) )
if is_vertical: if is_vertical:
vertical_font = fit_text_single_line(draw, ocr_texts[index], font_path, h, int(w * 0.8)) vertical_font = fit_text_single_line(draw, ocr_texts[index], h, int(w * 0.8))
text_img = Image.new("RGBA", (h, w), (255, 255, 255, 0)) # Note: dimensions are swapped text_img = Image.new("RGBA", (h, w), (255, 255, 255, 0)) # Note: dimensions are swapped
text_draw = ImageDraw.Draw(text_img) text_draw = ImageDraw.Draw(text_img)

View file

@ -0,0 +1,59 @@
import os
import numpy as np
import json
import click
import logging
def run_character_list_update(dir_labels, out, current_character_list):
    """Build or extend a character list from line-label text files.

    Every character on the first line of each ``.txt`` file in ``dir_labels``
    is collected, merged with the entries of an existing character list (when
    one is given), and the sorted result is written to ``out`` as a JSON array.

    Args:
        dir_labels: Directory holding the ``.txt`` label files. Files with any
            other extension are ignored.
        out: Path of the file the JSON character list is written to.
        current_character_list: Optional path to an existing JSON character
            list to extend; a falsy value starts from an empty set.
    """
    label_names = [name for name in os.listdir(dir_labels) if name.endswith('.txt')]

    if current_character_list:
        # Explicit UTF-8 so the result does not depend on the platform locale.
        with open(current_character_list, 'r', encoding='utf-8') as list_file:
            characters = set(json.load(list_file))
    else:
        characters = set()

    for name in label_names:
        # Context manager closes each label file; the original leaked handles.
        with open(os.path.join(dir_labels, name), 'r', encoding='utf-8') as label_file:
            first_line = label_file.read().split('\n')[0]
        # Only the first line of a label file contributes characters.
        characters.update(first_line)

    with open(out, 'w', encoding='utf-8') as out_file:
        json.dump(sorted(characters), out_file)
@click.command()
@click.option(
    "--dir_labels",
    "-dl",
    help="directory of labels which are .txt files",
    type=click.Path(exists=True, file_okay=False),
    required=True,
)
@click.option(
    "--current_character_list",
    "-ccl",
    help="existing character list in a .txt file that needs to be updated with a set of labels",
    type=click.Path(exists=True, file_okay=True),
    required=False,
)
@click.option(
    "--out",
    "-o",
    help="An output .txt file where the generated or updated character list will be written",
    type=click.Path(exists=False, file_okay=True),
    # Without an output path the run would fail when opening the file, so
    # let click report the missing option as a usage error instead.
    required=True,
)
def main(dir_labels, out, current_character_list):
    """Generate a character list from label files, or update an existing one."""
    # Thin CLI wrapper: all work happens in run_character_list_update.
    run_character_list_update(dir_labels, out, current_character_list)

View file

@ -8,7 +8,7 @@ from shapely import geometry
from pathlib import Path from pathlib import Path
from PIL import ImageFont from PIL import ImageFont
from ocrd_utils import bbox_from_points from ocrd_utils import bbox_from_points
from eynollah.utils.font import get_font
KERNEL = np.ones((5, 5), np.uint8) KERNEL = np.ones((5, 5), np.uint8)
NS = { 'pc': 'http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15' NS = { 'pc': 'http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15'
@ -18,7 +18,7 @@ with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_image, co_marginal, co_table, co_map, img): def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_image, co_marginal, co_table, co_map, co_music, img):
alpha = 0.5 alpha = 0.5
blank_image = np.ones( (img.shape[:]), dtype=np.uint8) * 255 blank_image = np.ones( (img.shape[:]), dtype=np.uint8) * 255
@ -32,6 +32,7 @@ def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_
col_marginal = (106, 90, 205) col_marginal = (106, 90, 205)
col_table = (0, 90, 205) col_table = (0, 90, 205)
col_map = (90, 90, 205) col_map = (90, 90, 205)
col_music = (90, 90, 0)
if len(co_image)>0: if len(co_image)>0:
cv2.drawContours(blank_image, co_image, -1, col_image, thickness=cv2.FILLED) # Fill the contour cv2.drawContours(blank_image, co_image, -1, col_image, thickness=cv2.FILLED) # Fill the contour
@ -59,6 +60,9 @@ def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_
if len(co_map)>0: if len(co_map)>0:
cv2.drawContours(blank_image, co_map, -1, col_map, thickness=cv2.FILLED) # Fill the contour cv2.drawContours(blank_image, co_map, -1, col_map, thickness=cv2.FILLED) # Fill the contour
if len(co_music)>0:
cv2.drawContours(blank_image, co_music, -1, col_music, thickness=cv2.FILLED) # Fill the contour
img_final =cv2.cvtColor(blank_image, cv2.COLOR_BGR2RGB) img_final =cv2.cvtColor(blank_image, cv2.COLOR_BGR2RGB)
@ -352,11 +356,11 @@ def get_textline_contours_and_ocr_text(xml_file):
ocr_textlines.append(ocr_text_in[0]) ocr_textlines.append(ocr_text_in[0])
return co_use_case, y_len, x_len, ocr_textlines return co_use_case, y_len, x_len, ocr_textlines
def fit_text_single_line(draw, text, font_path, max_width, max_height): def fit_text_single_line(draw, text, max_width, max_height):
initial_font_size = 50 initial_font_size = 50
font_size = initial_font_size font_size = initial_font_size
while font_size > 10: # Minimum font size while font_size > 10: # Minimum font size
font = ImageFont.truetype(font_path, font_size) font = get_font(font_size=font_size)# ImageFont.truetype(font_path, font_size)
text_bbox = draw.textbbox((0, 0), text, font=font) # Get text bounding box text_bbox = draw.textbbox((0, 0), text, font=font) # Get text bounding box
text_width = text_bbox[2] - text_bbox[0] text_width = text_bbox[2] - text_bbox[0]
text_height = text_bbox[3] - text_bbox[1] text_height = text_bbox[3] - text_bbox[1]
@ -366,7 +370,7 @@ def fit_text_single_line(draw, text, font_path, max_width, max_height):
font_size -= 2 # Reduce font size and retry font_size -= 2 # Reduce font size and retry
return ImageFont.truetype(font_path, 10) # Smallest font fallback return get_font(font_size=10)#ImageFont.truetype(font_path, 10) # Smallest font fallback
def get_layout_contours_for_visualization(xml_file): def get_layout_contours_for_visualization(xml_file):
tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8')) tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8'))
@ -389,6 +393,7 @@ def get_layout_contours_for_visualization(xml_file):
co_img=[] co_img=[]
co_table=[] co_table=[]
co_map=[] co_map=[]
co_music=[]
co_noise=[] co_noise=[]
types_text = [] types_text = []
@ -630,6 +635,31 @@ def get_layout_contours_for_visualization(xml_file):
elif vv.tag!=link+'Point' and sumi>=1: elif vv.tag!=link+'Point' and sumi>=1:
break break
co_map.append(np.array(c_t_in)) co_map.append(np.array(c_t_in))
if tag.endswith('}MusicRegion') or tag.endswith('}musicregion'):
#print('sth')
for nn in root1.iter(tag):
c_t_in=[]
sumi=0
for vv in nn.iter():
# check the format of coords
if vv.tag==link+'Coords':
coords=bool(vv.attrib)
if coords:
p_h=vv.attrib['points'].split(' ')
c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
break
else:
pass
if vv.tag==link+'Point':
c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
sumi+=1
#print(vv.tag,'in')
elif vv.tag!=link+'Point' and sumi>=1:
break
co_music.append(np.array(c_t_in))
if tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'): if tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'):
@ -656,7 +686,7 @@ def get_layout_contours_for_visualization(xml_file):
elif vv.tag!=link+'Point' and sumi>=1: elif vv.tag!=link+'Point' and sumi>=1:
break break
co_noise.append(np.array(c_t_in)) co_noise.append(np.array(c_t_in))
return co_text, co_graphic, co_sep, co_img, co_table, co_map, co_noise, y_len, x_len return co_text, co_graphic, co_sep, co_img, co_table, co_map, co_music, co_noise, y_len, x_len
def get_images_of_ground_truth( def get_images_of_ground_truth(
gt_list, gt_list,
@ -682,171 +712,193 @@ def get_images_of_ground_truth(
if not item.endswith('.xml')} if not item.endswith('.xml')}
for index in tqdm(range(len(gt_list))): for index in tqdm(range(len(gt_list))):
#try:
print(gt_list[index]) print(gt_list[index])
tree1 = ET.parse(dir_in+'/'+gt_list[index], parser = ET.XMLParser(encoding='utf-8'))
root1=tree1.getroot()
alltags=[elem.tag for elem in root1.iter()]
link=alltags[0].split('}')[0]+'}'
x_len, y_len = 0, 0 try:
for jj in root1.iter(link+'Page'): tree1 = ET.parse(dir_in+'/'+gt_list[index], parser = ET.XMLParser(encoding='utf-8'))
y_len=int(jj.attrib['imageHeight']) root1=tree1.getroot()
x_len=int(jj.attrib['imageWidth']) alltags=[elem.tag for elem in root1.iter()]
link=alltags[0].split('}')[0]+'}'
if 'columns_width' in list(config_params.keys()): x_len, y_len = 0, 0
columns_width_dict = config_params['columns_width'] for jj in root1.iter(link+'Page'):
# FIXME: look in /Page/@custom as well y_len=int(jj.attrib['imageHeight'])
metadata_element = root1.find(link+'Metadata') x_len=int(jj.attrib['imageWidth'])
num_col = None
for child in metadata_element:
tag2 = child.tag
if tag2.endswith('}Comments') or tag2.endswith('}comments'):
text_comments = child.text
num_col = int(text_comments.split('num_col')[1])
if num_col: if 'columns_width' in list(config_params.keys()):
x_new = columns_width_dict[str(num_col)] columns_width_dict = config_params['columns_width']
y_new = int ( x_new * (y_len / float(x_len)) ) metadata_element = root1.find(link+'Metadata')
num_col = None
if printspace or "printspace_as_class_in_layout" in list(config_params.keys()): for child in metadata_element:
ps = (root1.xpath('/pc:PcGts/pc:Page/pc:Border', namespaces=NS) + tag2 = child.tag
root1.xpath('/pc:PcGts/pc:Page/pc:PrintSpace', namespaces=NS)) if tag2.endswith('}Comments') or tag2.endswith('}comments'):
coords = root1.xpath('//pc:Coords/@points', namespaces=NS) text_comments = child.text
if len(ps): num_col = int(text_comments.split('num_col')[1])
points = ps[0].find('pc:Coords', NS).get('points')
ps_bbox = bbox_from_points(points)
elif missing_printspace == 'skip':
print(gt_list[index], "has no Border or PrintSpace - skipping file")
continue
elif missing_printspace == 'project' and len(coords):
print(gt_list[index], "has no Border or PrintSpace - projecting hull of segments")
bboxes = list(map(bbox_from_points, coords))
left, top, right, bottom = zip(*bboxes)
left = max(0, min(left) - 5)
top = max(0, min(top) - 5)
right = min(x_len, max(right) + 5)
bottom = min(y_len, max(bottom) + 5)
ps_bbox = [left, top, right, bottom]
else:
print(gt_list[index], "has no Border or PrintSpace - using full page")
ps_bbox = [0, 0, None, None]
if config_file and (config_params['use_case']=='textline' or config_params['use_case']=='word' or config_params['use_case']=='glyph' or config_params['use_case']=='printspace'):
keys = list(config_params.keys())
if "artificial_class_label" in keys:
artificial_class_rgb_color = (255,255,0)
artificial_class_label = config_params['artificial_class_label']
textline_rgb_color = (255, 0, 0)
if config_params['use_case']=='textline':
region_tags = np.unique([x for x in alltags if x.endswith('TextLine')])
elif config_params['use_case']=='word':
region_tags = np.unique([x for x in alltags if x.endswith('Word')])
elif config_params['use_case']=='glyph':
region_tags = np.unique([x for x in alltags if x.endswith('Glyph')])
elif config_params['use_case']=='printspace':
region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace')])
co_use_case = []
for tag in region_tags:
if config_params['use_case']=='textline':
tag_endings = ['}TextLine','}textline']
elif config_params['use_case']=='word':
tag_endings = ['}Word','}word']
elif config_params['use_case']=='glyph':
tag_endings = ['}Glyph','}glyph']
elif config_params['use_case']=='printspace':
tag_endings = ['}PrintSpace','}printspace']
if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]): if num_col:
for nn in root1.iter(tag): x_new = columns_width_dict[str(num_col)]
c_t_in = [] y_new = int ( x_new * (y_len / float(x_len)) )
sumi = 0
for vv in nn.iter(): if printspace or "printspace_as_class_in_layout" in list(config_params.keys()):
# check the format of coords region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace') or x.endswith('Border')])
if vv.tag == link + 'Coords': co_use_case = []
coords = bool(vv.attrib)
if coords: for tag in region_tags:
p_h = vv.attrib['points'].split(' ') tag_endings = ['}PrintSpace','}Border']
c_t_in.append(
np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h])) if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]):
for nn in root1.iter(tag):
c_t_in = []
sumi = 0
for vv in nn.iter():
# check the format of coords
if vv.tag == link + 'Coords':
coords = bool(vv.attrib)
if coords:
p_h = vv.attrib['points'].split(' ')
c_t_in.append(
np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h]))
break
else:
pass
if vv.tag == link + 'Point':
c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))])
sumi += 1
elif vv.tag != link + 'Point' and sumi >= 1:
break break
else: co_use_case.append(np.array(c_t_in))
pass
img = np.zeros((y_len, x_len, 3))
if vv.tag == link + 'Point':
c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))])
sumi += 1
elif vv.tag != link + 'Point' and sumi >= 1:
break
co_use_case.append(np.array(c_t_in))
if "artificial_class_label" in keys:
img_boundary = np.zeros((y_len, x_len))
erosion_rate = 0#1
dilation_rate = 2
dilation_early = 0
erosion_early = 2
co_use_case, img_boundary = update_region_contours(co_use_case, img_boundary, erosion_rate, dilation_rate, y_len, x_len, dilation_early=dilation_early, erosion_early=erosion_early)
img = np.zeros((y_len, x_len, 3))
if output_type == '2d':
img_poly = cv2.fillPoly(img, pts=co_use_case, color=(1, 1, 1)) img_poly = cv2.fillPoly(img, pts=co_use_case, color=(1, 1, 1))
if "artificial_class_label" in keys:
img_mask = np.copy(img_poly)
##img_poly[:,:][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=1)] = artificial_class_label
img_poly[:,:][img_boundary[:,:]==1] = artificial_class_label
elif output_type == '3d':
img_poly = cv2.fillPoly(img, pts=co_use_case, color=textline_rgb_color)
if "artificial_class_label" in keys:
img_mask = np.copy(img_poly)
img_poly[:,:,0][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[0]
img_poly[:,:,1][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[1]
img_poly[:,:,2][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[2]
if printspace and config_params['use_case']!='printspace':
img_poly = img_poly[ps_bbox[1]:ps_bbox[3],
ps_bbox[0]:ps_bbox[2], :]
img_poly = img_poly.astype(np.uint8)
if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace': imgray = cv2.cvtColor(img_poly, cv2.COLOR_BGR2GRAY)
img_poly = resize_image(img_poly, y_new, x_new) _, thresh = cv2.threshold(imgray, 0, 255, 0)
try: contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
xml_file_stem = os.path.splitext(gt_list[index])[0]
cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly)
except:
xml_file_stem = os.path.splitext(gt_list[index])[0]
cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly)
if dir_images: cnt_size = np.array([cv2.contourArea(contours[j]) for j in range(len(contours))])
org_image_name = ls_org_imgs[xml_file_stem]
if not org_image_name:
print("image file for XML stem", xml_file_stem, "is missing")
continue
if not os.path.isfile(os.path.join(dir_images, org_image_name)):
print("image file for XML stem", xml_file_stem, "is not readable")
continue
img_org = cv2.imread(os.path.join(dir_images, org_image_name))
try:
cnt = contours[np.argmax(cnt_size)]
x, y, w, h = cv2.boundingRect(cnt)
except:
x, y , w, h = 0, 0, x_len, y_len
bb_xywh = [x, y, w, h]
if config_file and (config_params['use_case']=='textline' or config_params['use_case']=='word' or config_params['use_case']=='glyph' or config_params['use_case']=='printspace'):
keys = list(config_params.keys())
if "artificial_class_label" in keys:
artificial_class_rgb_color = (255,255,0)
artificial_class_label = config_params['artificial_class_label']
textline_rgb_color = (255, 0, 0)
if config_params['use_case']=='textline':
region_tags = np.unique([x for x in alltags if x.endswith('TextLine')])
elif config_params['use_case']=='word':
region_tags = np.unique([x for x in alltags if x.endswith('Word')])
elif config_params['use_case']=='glyph':
region_tags = np.unique([x for x in alltags if x.endswith('Glyph')])
elif config_params['use_case']=='printspace':
region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace')])
co_use_case = []
for tag in region_tags:
if config_params['use_case']=='textline':
tag_endings = ['}TextLine','}textline']
elif config_params['use_case']=='word':
tag_endings = ['}Word','}word']
elif config_params['use_case']=='glyph':
tag_endings = ['}Glyph','}glyph']
elif config_params['use_case']=='printspace':
tag_endings = ['}PrintSpace','}printspace']
if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]):
for nn in root1.iter(tag):
c_t_in = []
sumi = 0
for vv in nn.iter():
# check the format of coords
if vv.tag == link + 'Coords':
coords = bool(vv.attrib)
if coords:
p_h = vv.attrib['points'].split(' ')
c_t_in.append(
np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h]))
break
else:
pass
if vv.tag == link + 'Point':
c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))])
sumi += 1
elif vv.tag != link + 'Point' and sumi >= 1:
break
co_use_case.append(np.array(c_t_in))
if "artificial_class_label" in keys:
img_boundary = np.zeros((y_len, x_len))
erosion_rate = 0#1
dilation_rate = 2
dilation_early = 0
erosion_early = 2
co_use_case, img_boundary = update_region_contours(co_use_case, img_boundary, erosion_rate, dilation_rate, y_len, x_len, dilation_early=dilation_early, erosion_early=erosion_early)
img = np.zeros((y_len, x_len, 3))
if output_type == '2d':
img_poly = cv2.fillPoly(img, pts=co_use_case, color=(1, 1, 1))
if "artificial_class_label" in keys:
img_mask = np.copy(img_poly)
##img_poly[:,:][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=1)] = artificial_class_label
img_poly[:,:][img_boundary[:,:]==1] = artificial_class_label
elif output_type == '3d':
img_poly = cv2.fillPoly(img, pts=co_use_case, color=textline_rgb_color)
if "artificial_class_label" in keys:
img_mask = np.copy(img_poly)
img_poly[:,:,0][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[0]
img_poly[:,:,1][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[1]
img_poly[:,:,2][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[2]
if printspace and config_params['use_case']!='printspace': if printspace and config_params['use_case']!='printspace':
img_org = img_org[ps_bbox[1]:ps_bbox[3], img_poly = img_poly[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :]
ps_bbox[0]:ps_bbox[2], :]
if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace': if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace':
img_org = resize_image(img_org, y_new, x_new) img_poly = resize_image(img_poly, y_new, x_new)
cv2.imwrite(os.path.join(dir_out_images, org_image_name), img_org)
try:
xml_file_stem = os.path.splitext(gt_list[index])[0]
cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly)
except:
xml_file_stem = os.path.splitext(gt_list[index])[0]
cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly)
if dir_images:
org_image_name = ls_org_imgs[ls_org_imgs_stem.index(xml_file_stem)]
img_org = cv2.imread(os.path.join(dir_images, org_image_name))
if printspace and config_params['use_case']!='printspace':
img_org = img_org[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :]
if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace':
img_org = resize_image(img_org, y_new, x_new)
cv2.imwrite(os.path.join(dir_out_images, org_image_name), img_org)
except:
pass
if config_file and config_params['use_case']=='layout': if config_file and config_params['use_case']=='layout':
keys = list(config_params.keys()) keys = list(config_params.keys())
@ -870,7 +922,7 @@ def get_images_of_ground_truth(
types_graphic_label = list(types_graphic_dict.values()) types_graphic_label = list(types_graphic_dict.values())
labels_rgb_color = [ (0,0,0), (255,0,0), (255,125,0), (255,0,125), (125,255,125), (125,125,0), (0,125,255), (0,125,0), (125,125,125), (255,0,255), (125,0,125), (0,255,0),(0,0,255), (0,255,255), (255,125,125), (0,125,125), (0,255,125), (255,125,255), (125,255,0), (125,255,255)] labels_rgb_color = [ (0,0,0), (255,0,0), (255,125,0), (255,0,125), (125,255,125), (125,125,0), (0,125,255), (0,125,0), (125,125,125), (255,0,255), (125,0,125), (0,255,0),(0,0,255), (0,255,255), (255,125,125), (0,125,125), (0,255,125), (255,125,255), (125,255,0), (125,255,255), (125,125,255)]
region_tags=np.unique([x for x in alltags if x.endswith('Region')]) region_tags=np.unique([x for x in alltags if x.endswith('Region')])
@ -882,6 +934,7 @@ def get_images_of_ground_truth(
co_img=[] co_img=[]
co_table=[] co_table=[]
co_map=[] co_map=[]
co_music=[]
co_noise=[] co_noise=[]
for tag in region_tags: for tag in region_tags:
@ -966,19 +1019,21 @@ def get_images_of_ground_truth(
if "rest_as_decoration" in types_graphic: if "rest_as_decoration" in types_graphic:
types_graphic_without_decoration = [element for element in types_graphic if element!='rest_as_decoration' and element!='decoration'] types_graphic_without_decoration = [element for element in types_graphic if element!='rest_as_decoration' and element!='decoration']
if len(types_graphic_without_decoration) == 0: if len(types_graphic_without_decoration) == 0:
if "type" in nn.attrib: #if "type" in nn.attrib:
c_t_in_graphic['decoration'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) c_t_in_graphic['decoration'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
elif len(types_graphic_without_decoration) >= 1: elif len(types_graphic_without_decoration) >= 1:
if "type" in nn.attrib: if "type" in nn.attrib:
if nn.attrib['type'] in types_graphic_without_decoration: if nn.attrib['type'] in types_graphic_without_decoration:
c_t_in_graphic[nn.attrib['type']].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) c_t_in_graphic[nn.attrib['type']].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
else: else:
c_t_in_graphic['decoration'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) c_t_in_graphic['decoration'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
else:
c_t_in_graphic['decoration'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
else: else:
if "type" in nn.attrib: if "type" in nn.attrib:
if nn.attrib['type'] in all_defined_graphic_types: if nn.attrib['type'] in all_defined_graphic_types:
c_t_in_graphic[nn.attrib['type']].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) c_t_in_graphic[nn.attrib['type']].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
break break
else: else:
@ -989,9 +1044,9 @@ def get_images_of_ground_truth(
if "rest_as_decoration" in types_graphic: if "rest_as_decoration" in types_graphic:
types_graphic_without_decoration = [element for element in types_graphic if element!='rest_as_decoration' and element!='decoration'] types_graphic_without_decoration = [element for element in types_graphic if element!='rest_as_decoration' and element!='decoration']
if len(types_graphic_without_decoration) == 0: if len(types_graphic_without_decoration) == 0:
if "type" in nn.attrib: #if "type" in nn.attrib:
c_t_in_graphic['decoration'].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) c_t_in_graphic['decoration'].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] )
sumi+=1 sumi+=1
elif len(types_graphic_without_decoration) >= 1: elif len(types_graphic_without_decoration) >= 1:
if "type" in nn.attrib: if "type" in nn.attrib:
if nn.attrib['type'] in types_graphic_without_decoration: if nn.attrib['type'] in types_graphic_without_decoration:
@ -1000,6 +1055,9 @@ def get_images_of_ground_truth(
else: else:
c_t_in_graphic['decoration'].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) c_t_in_graphic['decoration'].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] )
sumi+=1 sumi+=1
else:
c_t_in_graphic['decoration'].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] )
sumi+=1
else: else:
if "type" in nn.attrib: if "type" in nn.attrib:
@ -1118,6 +1176,32 @@ def get_images_of_ground_truth(
elif vv.tag!=link+'Point' and sumi>=1: elif vv.tag!=link+'Point' and sumi>=1:
break break
co_map.append(np.array(c_t_in)) co_map.append(np.array(c_t_in))
if 'musicregion' in keys:
if tag.endswith('}MusicRegion') or tag.endswith('}musicregion'):
#print('sth')
for nn in root1.iter(tag):
c_t_in=[]
sumi=0
for vv in nn.iter():
# check the format of coords
if vv.tag==link+'Coords':
coords=bool(vv.attrib)
if coords:
p_h=vv.attrib['points'].split(' ')
c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
break
else:
pass
if vv.tag==link+'Point':
c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
sumi+=1
#print(vv.tag,'in')
elif vv.tag!=link+'Point' and sumi>=1:
break
co_music.append(np.array(c_t_in))
if 'noiseregion' in keys: if 'noiseregion' in keys:
if tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'): if tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'):
@ -1195,6 +1279,10 @@ def get_images_of_ground_truth(
erosion_rate = 0#2 erosion_rate = 0#2
dilation_rate = 3#4 dilation_rate = 3#4
co_map, img_boundary = update_region_contours(co_map, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) co_map, img_boundary = update_region_contours(co_map, img_boundary, erosion_rate, dilation_rate, y_len, x_len )
if "musicregion" in elements_with_artificial_class:
erosion_rate = 0#2
dilation_rate = 3#4
co_music, img_boundary = update_region_contours(co_music, img_boundary, erosion_rate, dilation_rate, y_len, x_len )
@ -1222,6 +1310,8 @@ def get_images_of_ground_truth(
img_poly=cv2.fillPoly(img, pts =co_table, color=labels_rgb_color[ config_params['tableregion']]) img_poly=cv2.fillPoly(img, pts =co_table, color=labels_rgb_color[ config_params['tableregion']])
if 'mapregion' in keys: if 'mapregion' in keys:
img_poly=cv2.fillPoly(img, pts =co_map, color=labels_rgb_color[ config_params['mapregion']]) img_poly=cv2.fillPoly(img, pts =co_map, color=labels_rgb_color[ config_params['mapregion']])
if 'musicregion' in keys:
img_poly=cv2.fillPoly(img, pts =co_music, color=labels_rgb_color[ config_params['musicregion']])
if 'noiseregion' in keys: if 'noiseregion' in keys:
img_poly=cv2.fillPoly(img, pts =co_noise, color=labels_rgb_color[ config_params['noiseregion']]) img_poly=cv2.fillPoly(img, pts =co_noise, color=labels_rgb_color[ config_params['noiseregion']])
@ -1286,6 +1376,9 @@ def get_images_of_ground_truth(
if 'mapregion' in keys: if 'mapregion' in keys:
color_label = config_params['mapregion'] color_label = config_params['mapregion']
img_poly=cv2.fillPoly(img, pts =co_map, color=(color_label,color_label,color_label)) img_poly=cv2.fillPoly(img, pts =co_map, color=(color_label,color_label,color_label))
if 'musicregion' in keys:
color_label = config_params['musicregion']
img_poly=cv2.fillPoly(img, pts =co_music, color=(color_label,color_label,color_label))
if 'noiseregion' in keys: if 'noiseregion' in keys:
color_label = config_params['noiseregion'] color_label = config_params['noiseregion']
img_poly=cv2.fillPoly(img, pts =co_noise, color=(color_label,color_label,color_label)) img_poly=cv2.fillPoly(img, pts =co_noise, color=(color_label,color_label,color_label))

View file

@ -9,9 +9,11 @@ import warnings
import json import json
import click import click
import numpy as np import numpy as np
from numpy._typing import NDArray from numpy._typing import NDArray
import cv2 import cv2
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15 os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
@ -119,8 +121,33 @@ class SBBPredict:
return mIoU return mIoU
def start_new_session_and_model(self): def start_new_session_and_model(self):
if self.cpu: if self.task == "cnn-rnn-ocr":
tf.config.set_visible_devices([], 'GPU') if self.cpu:
os.environ['CUDA_VISIBLE_DEVICES']='-1'
self.model = load_model(self.model_dir)
self.model = tf.keras.models.Model(
self.model.get_layer(name = "image").input,
self.model.get_layer(name = "dense2").output)
assert isinstance(self.model, Model)
elif self.task == "transformer-ocr":
import torch
from transformers import VisionEncoderDecoderModel
from transformers import TrOCRProcessor
self.model = VisionEncoderDecoderModel.from_pretrained(self.model_dir)
self.processor = TrOCRProcessor.from_pretrained(self.model_dir)
if self.cpu:
self.device = torch.device('cpu')
else:
self.device = torch.device('cuda:0')
self.model.to(self.device)
assert isinstance(self.model, torch.nn.Module)
else: else:
try: try:
for device in tf.config.list_physical_devices('GPU'): for device in tf.config.list_physical_devices('GPU'):
@ -137,15 +164,13 @@ class SBBPredict:
custom_objects={"PatchEncoder": PatchEncoder, custom_objects={"PatchEncoder": PatchEncoder,
"Patches": Patches}) "Patches": Patches})
##if self.weights_dir!=None: if self.task != 'classification' and self.task != 'reading_order':
##self.model.load_weights(self.weights_dir) self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1]
self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2]
self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3]
assert isinstance(self.model, Model)
if self.task != 'classification' and self.task != 'reading_order': assert isinstance(self.model, Model)
last = self.model.layers[-1]
self.img_height = last.output_shape[1]
self.img_width = last.output_shape[2]
self.n_classes = last.output_shape[3]
def visualize_model_output(self, prediction, img, task) -> Tuple[NDArray, NDArray]: def visualize_model_output(self, prediction, img, task) -> Tuple[NDArray, NDArray]:
if task == "binarization": if task == "binarization":
@ -191,9 +216,9 @@ class SBBPredict:
return added_image, layout_only return added_image, layout_only
def predict(self, image_dir): def predict(self, image_dir):
assert isinstance(self.model, Model)
if self.task == 'classification': if self.task == 'classification':
classes_names = self.config_params_model['classification_classes_name'] classes_names = self.config_params_model['classification_classes_name']
img_1ch = cv2.imread(image_dir, 0) / 255.0 img_1ch = cv2.imread(image_dir, 0) / 255.0
img_1ch = cv2.resize(img_1ch, (self.config_params_model['input_height'], img_1ch = cv2.resize(img_1ch, (self.config_params_model['input_height'],
self.config_params_model['input_width']), self.config_params_model['input_width']),
@ -230,6 +255,15 @@ class SBBPredict:
pred_texts = decode_batch_predictions(preds, num_to_char) pred_texts = decode_batch_predictions(preds, num_to_char)
pred_texts = pred_texts[0].replace("[UNK]", "") pred_texts = pred_texts[0].replace("[UNK]", "")
return pred_texts return pred_texts
elif self.task == "transformer-ocr":
from PIL import Image
image = Image.open(image_dir).convert("RGB")
pixel_values = self.processor(image, return_tensors="pt").pixel_values
generated_ids = self.model.generate(pixel_values.to(self.device))
return self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
elif self.task == 'reading_order': elif self.task == 'reading_order':
@ -566,6 +600,8 @@ class SBBPredict:
cv2.imwrite(self.save,res) cv2.imwrite(self.save,res)
elif self.task == "cnn-rnn-ocr": elif self.task == "cnn-rnn-ocr":
print(f"Detected text: {res}") print(f"Detected text: {res}")
elif self.task == "transformer-ocr":
print(f"Detected text: {res}")
else: else:
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task) img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
if self.save: if self.save:
@ -668,11 +704,13 @@ class SBBPredict:
help="min area size of regions considered for reading order detection. The default value is zero and means that all text regions are considered for reading order.", help="min area size of regions considered for reading order detection. The default value is zero and means that all text regions are considered for reading order.",
) )
def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area): def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area):
assert image or dir_in, "Either a single image -i or a dir_in -di input is required" assert image or dir_in, "Either a single image -i or a dir_in -di input is required"
with open(os.path.join(model,'config.json')) as f: with open(os.path.join(model,'config.json')) as f:
config_params_model = json.load(f) config_params_model = json.load(f)
task = config_params_model['task'] task = config_params_model['task']
if task not in ['classification', 'reading_order', "cnn-rnn-ocr"]: if task not in ['classification', 'reading_order', "cnn-rnn-ocr", "transformer-ocr"]:
assert not image or save, "For segmentation or binarization, an input single image -i also requires an output filename -s" assert not image or save, "For segmentation or binarization, an input single image -i also requires an output filename -s"
assert not dir_in or out, "For segmentation or binarization, an input directory -di also requires an output directory -o" assert not dir_in or out, "For segmentation or binarization, an input directory -di also requires an output directory -o"
x = SBBPredict(image, dir_in, model, task, config_params_model, x = SBBPredict(image, dir_in, model, task, config_params_model,

View file

@ -308,6 +308,8 @@ def transformer_block(img,
x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1) x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1)
# Skip connection 2. # Skip connection 2.
encoded_patches = Add()([x3, x2]) encoded_patches = Add()([x3, x2])
#assert isinstance(x, Layer)
encoded_patches = tf.reshape(encoded_patches, encoded_patches = tf.reshape(encoded_patches,
[-1, [-1,

View file

@ -3,9 +3,14 @@ import sys
import io import io
import json import json
from tqdm import tqdm from tqdm import tqdm
import requests import requests
import numpy as np
import cv2
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15 os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
import tensorflow as tf import tensorflow as tf
@ -47,12 +52,21 @@ from .utils import (
generate_arrays_from_folder_reading_order, generate_arrays_from_folder_reading_order,
get_one_hot, get_one_hot,
preprocess_imgs, preprocess_imgs,
return_number_of_total_training_data,
OCRDatasetYieldAugmentations
) )
from .weights_ensembling import run_ensembling from .weights_ensembling import run_ensembling
import torch
from transformers import TrOCRProcessor
import evaluate
from transformers import default_data_collator
from transformers import VisionEncoderDecoderModel
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
class SaveWeightsAfterSteps(ModelCheckpoint): class SaveWeightsAfterSteps(ModelCheckpoint):
def __init__(self, save_interval, save_path, _config, **kwargs): def __init__(self, save_interval, save_path, _config, characters_cnnrnn_ocr=None, **kwargs):
if save_interval: if save_interval:
# batches # batches
super().__init__( super().__init__(
@ -67,12 +81,15 @@ class SaveWeightsAfterSteps(ModelCheckpoint):
verbose=1, verbose=1,
**kwargs) **kwargs)
self._config = _config self._config = _config
self.characters_cnnrnn_ocr = characters_cnnrnn_ocr
# overwrite tf-keras (Keras 2) implementation to get our _config JSON in # overwrite tf-keras (Keras 2) implementation to get our _config JSON in
def _save_handler(self, filepath):
    """Save the checkpoint, then store our metadata files inside it.

    Overrides the tf-keras (Keras 2) ``ModelCheckpoint`` hook so that every
    saved checkpoint directory also receives:

    * ``config.json`` -- the training configuration (``self._config``),
    * ``characters_org.txt`` -- a copy of the character list file, only
      when ``self.characters_cnnrnn_ocr`` is set.

    Args:
        filepath: Checkpoint directory handed over by the parent class.
    """
    # Local imports keep this method self-contained.
    import shutil
    import warnings

    super()._save_handler(filepath)
    with open(os.path.join(filepath, "config.json"), "w") as fp:
        json.dump(self._config, fp)  # encode dict into JSON
    if self.characters_cnnrnn_ocr:
        # Copy next to config.json, i.e. into the checkpoint that was just
        # written.  shutil.copy replaces a string-built `cp` shell command,
        # which broke on paths containing spaces or shell metacharacters.
        # NOTE(review): the previous destination was derived from
        # self.save_path / self.step_count, which the visible part of
        # __init__ does not set -- confirm `filepath` is the intended target.
        target = os.path.join(filepath, "characters_org.txt")
        try:
            shutil.copy(self.characters_cnnrnn_ocr, target)
        except OSError as err:
            # Best effort, as before: a missing characters file must not
            # abort training, but it should no longer fail silently.
            warnings.warn(
                f"could not copy '{self.characters_cnnrnn_ocr}' to '{target}': {err}")
def configuration(): def configuration():
try: try:
@ -820,9 +837,126 @@ def run(_config,
usable_checkpoints = [os.path.join(dir_output, 'model_{epoch:02d}'.format(epoch=epoch + 1)) usable_checkpoints = [os.path.join(dir_output, 'model_{epoch:02d}'.format(epoch=epoch + 1))
for epoch in usable_checkpoints] for epoch in usable_checkpoints]
ens_path = os.path.join(dir_output, 'model_ens_avg') ens_path = os.path.join(dir_output, 'model_ens_avg')
run_ensembling(usable_checkpoints, ens_path) run_ensembling(usable_checkpoints, ens_path, framework="tensorflow")
_log.info("ensemble model saved under '%s'", ens_path) _log.info("ensemble model saved under '%s'", ens_path)
# =======
elif task=="transformer-ocr":
dir_img, dir_lab = get_dirs_or_files(dir_train)
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
ls_files_images = os.listdir(dir_img)
aug_multip = return_multiplier_based_on_augmnentations(augmentation, color_padding_rotation, rotation_not_90, blur_aug, degrading, bin_deg,
brightening, padding_white, adding_rgb_foreground, adding_rgb_background, binarization,
image_inversion, channels_shuffling, add_red_textlines, white_noise_strap, textline_skewing, textline_skewing_bin, textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth, textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin, textline_down_in_depth, textline_down_in_depth_bin, pepper_bin_aug, pepper_aug, degrade_scales, number_of_backgrounds_per_image, thetha, thetha_padd, brightness, padd_colors, shuffle_indexes, pepper_indexes, skewing_amplitudes, blur_k, white_padds)
len_dataset = aug_multip*len(ls_files_images)
dataset = OCRDatasetYieldAugmentations(
dir_img=dir_img,
dir_img_bin=dir_img_bin,
dir_lab=dir_lab,
processor=processor,
max_target_length=max_len,
augmentation = augmentation,
binarization = binarization,
add_red_textlines = add_red_textlines,
white_noise_strap = white_noise_strap,
adding_rgb_foreground = adding_rgb_foreground,
adding_rgb_background = adding_rgb_background,
bin_deg = bin_deg,
blur_aug = blur_aug,
brightening = brightening,
padding_white = padding_white,
color_padding_rotation = color_padding_rotation,
rotation_not_90 = rotation_not_90,
degrading = degrading,
channels_shuffling = channels_shuffling,
textline_skewing = textline_skewing,
textline_skewing_bin = textline_skewing_bin,
textline_right_in_depth = textline_right_in_depth,
textline_left_in_depth = textline_left_in_depth,
textline_up_in_depth = textline_up_in_depth,
textline_down_in_depth = textline_down_in_depth,
textline_right_in_depth_bin = textline_right_in_depth_bin,
textline_left_in_depth_bin = textline_left_in_depth_bin,
textline_up_in_depth_bin = textline_up_in_depth_bin,
textline_down_in_depth_bin = textline_down_in_depth_bin,
pepper_aug = pepper_aug,
pepper_bin_aug = pepper_bin_aug,
list_all_possible_background_images=list_all_possible_background_images,
list_all_possible_foreground_rgbs=list_all_possible_foreground_rgbs,
blur_k = blur_k,
degrade_scales = degrade_scales,
white_padds = white_padds,
thetha_padd = thetha_padd,
thetha = thetha,
brightness = brightness,
padd_colors = padd_colors,
number_of_backgrounds_per_image = number_of_backgrounds_per_image,
shuffle_indexes = shuffle_indexes,
pepper_indexes = pepper_indexes,
skewing_amplitudes = skewing_amplitudes,
dir_rgb_backgrounds = dir_rgb_backgrounds,
dir_rgb_foregrounds = dir_rgb_foregrounds,
len_data=len_dataset,
)
# Create a DataLoader
data_loader = torch.utils.data.DataLoader(dataset, batch_size=1)
train_dataset = data_loader.dataset
if continue_training:
model = VisionEncoderDecoderModel.from_pretrained(dir_of_start_model)
else:
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size
# set beam search parameters
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = max_len
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4
training_args = Seq2SeqTrainingArguments(
predict_with_generate=True,
num_train_epochs=n_epochs,
learning_rate=learning_rate,
per_device_train_batch_size=n_batch,
fp16=True,
output_dir=dir_output,
logging_steps=2,
save_steps=save_interval,
)
cer_metric = evaluate.load("cer")
# instantiate trainer
trainer = Seq2SeqTrainer(
model=model,
tokenizer=processor.feature_extractor,
args=training_args,
train_dataset=train_dataset,
data_collator=default_data_collator,
)
trainer.train()
elif task=='reading_order': elif task=='reading_order':
if continue_training: if continue_training:
model = load_model(dir_of_start_model, compile=False) model = load_model(dir_of_start_model, compile=False)

View file

@ -14,6 +14,9 @@ import tensorflow as tf
from PIL import Image, ImageFile, ImageEnhance from PIL import Image, ImageFile, ImageEnhance
import torch
from torch.utils.data import IterableDataset
ImageFile.LOAD_TRUNCATED_IMAGES = True ImageFile.LOAD_TRUNCATED_IMAGES = True
@ -78,6 +81,7 @@ def add_salt_and_pepper_noise(img, salt_prob, pepper_prob):
return noisy_image return noisy_image
def invert_image(img): def invert_image(img):
img_inv = 255 - img img_inv = 255 - img
return img_inv return img_inv
@ -1242,3 +1246,411 @@ def preprocess_img_ocr(
for pepper_ind in pepper_indexes: for pepper_ind in pepper_indexes:
img_noisy = add_salt_and_pepper_noise(img_bin_corr, pepper_ind, pepper_ind) img_noisy = add_salt_and_pepper_noise(img_bin_corr, pepper_ind, pepper_ind)
yield scale_image(img_noisy), lab yield scale_image(img_noisy), lab
class OCRDatasetYieldAugmentations(IterableDataset):
def __init__(
self,
dir_img,
dir_img_bin,
dir_lab,
processor,
max_target_length=128,
augmentation = None,
binarization = None,
add_red_textlines = None,
white_noise_strap = None,
adding_rgb_foreground = None,
adding_rgb_background = None,
bin_deg = None,
blur_aug = None,
brightening = None,
padding_white = None,
color_padding_rotation = None,
rotation_not_90 = None,
degrading = None,
channels_shuffling = None,
textline_skewing = None,
textline_skewing_bin = None,
textline_right_in_depth = None,
textline_left_in_depth = None,
textline_up_in_depth = None,
textline_down_in_depth = None,
textline_right_in_depth_bin = None,
textline_left_in_depth_bin = None,
textline_up_in_depth_bin = None,
textline_down_in_depth_bin = None,
pepper_aug = None,
pepper_bin_aug = None,
list_all_possible_background_images=None,
list_all_possible_foreground_rgbs=None,
blur_k = None,
degrade_scales = None,
white_padds = None,
thetha_padd = None,
thetha = None,
brightness = None,
padd_colors = None,
number_of_backgrounds_per_image = None,
shuffle_indexes = None,
pepper_indexes = None,
skewing_amplitudes = None,
dir_rgb_backgrounds = None,
dir_rgb_foregrounds = None,
len_data=None,
):
"""
Args:
images_dir (str): Path to the directory containing images.
labels_dir (str): Path to the directory containing label text files.
tokenizer: Tokenizer for processing labels.
transform: Transformations applied after augmentation (e.g., ToTensor, normalization).
image_size (tuple): Size to resize images to.
max_seq_len (int): Maximum sequence length for tokenized labels.
scales (list or None): List of scale factors to apply.
"""
self.dir_img = dir_img
self.dir_img_bin = dir_img_bin
self.dir_lab = dir_lab
self.processor = processor
self.max_target_length = max_target_length
#self.scales = scales if scales else []
self.augmentation = augmentation
self.binarization = binarization
self.add_red_textlines = add_red_textlines
self.white_noise_strap = white_noise_strap
self.adding_rgb_foreground = adding_rgb_foreground
self.adding_rgb_background = adding_rgb_background
self.bin_deg = bin_deg
self.blur_aug = blur_aug
self.brightening = brightening
self.padding_white = padding_white
self.color_padding_rotation = color_padding_rotation
self.rotation_not_90 = rotation_not_90
self.degrading = degrading
self.channels_shuffling = channels_shuffling
self.textline_skewing = textline_skewing
self.textline_skewing_bin = textline_skewing_bin
self.textline_right_in_depth = textline_right_in_depth
self.textline_left_in_depth = textline_left_in_depth
self.textline_up_in_depth = textline_up_in_depth
self.textline_down_in_depth = textline_down_in_depth
self.textline_right_in_depth_bin = textline_right_in_depth_bin
self.textline_left_in_depth_bin = textline_left_in_depth_bin
self.textline_up_in_depth_bin = textline_up_in_depth_bin
self.textline_down_in_depth_bin = textline_down_in_depth_bin
self.pepper_aug = pepper_aug
self.pepper_bin_aug = pepper_bin_aug
self.list_all_possible_background_images=list_all_possible_background_images
self.list_all_possible_foreground_rgbs=list_all_possible_foreground_rgbs
self.blur_k = blur_k
self.degrade_scales = degrade_scales
self.white_padds = white_padds
self.thetha_padd = thetha_padd
self.thetha = thetha
self.brightness = brightness
self.padd_colors = padd_colors
self.number_of_backgrounds_per_image = number_of_backgrounds_per_image
self.shuffle_indexes = shuffle_indexes
self.pepper_indexes = pepper_indexes
self.skewing_amplitudes = skewing_amplitudes
self.dir_rgb_backgrounds = dir_rgb_backgrounds
self.dir_rgb_foregrounds = dir_rgb_foregrounds
self.image_files = os.listdir(dir_img)#sorted([f for f in os.listdir(images_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
self.len_data = len_data
#assert len(self.image_files) == len(self.label_files), "Number of images and labels must match!"
def __len__(self):
return self.len_data
def __iter__(self):
for img_file in self.image_files:
# Load image
f_name = img_file.split('.')[0]
txt_inp = open(os.path.join(self.dir_lab, f_name+'.txt'),'r').read().split('\n')[0]
img = cv2.imread(os.path.join(self.dir_img, img_file))
img = img.astype(np.uint8)
if self.dir_img_bin:
img_bin_corr = cv2.imread(os.path.join(self.dir_img_bin, f_name+'.png') )
img_bin_corr = img_bin_corr.astype(np.uint8)
else:
img_bin_corr = None
labels = self.processor.tokenizer(txt_inp,
padding="max_length",
max_length=self.max_target_length).input_ids
labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]
if self.augmentation:
pixel_values = self.processor(Image.fromarray(img), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.color_padding_rotation:
for index, thetha_ind in enumerate(self.thetha_padd):
for padd_col in self.padd_colors:
img_out = rotation_not_90_func_single_image(do_padding_for_ocr(img, 1.2, padd_col), thetha_ind)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.rotation_not_90:
for index, thetha_ind in enumerate(self.thetha):
img_out = rotation_not_90_func_single_image(img, thetha_ind)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.blur_aug:
for index, blur_type in enumerate(self.blur_k):
img_out = bluring(img, blur_type)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.degrading:
for index, deg_scale_ind in enumerate(self.degrade_scales):
try:
img_out = do_degrading(img, deg_scale_ind)
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.bin_deg:
for index, deg_scale_ind in enumerate(self.degrade_scales):
try:
img_out = self.do_degrading(img_bin_corr, deg_scale_ind)
except:
img_out = np.copy(img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.brightening:
for index, bright_scale_ind in enumerate(self.brightness):
try:
img_out = do_brightening(dir_img, bright_scale_ind)
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.padding_white:
for index, padding_size in enumerate(self.white_padds):
for padd_col in self.padd_colors:
img_out = do_padding_for_ocr(img, padding_size, padd_col)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.adding_rgb_foreground:
for i_n in range(self.number_of_backgrounds_per_image):
background_image_chosen_name = random.choice(self.list_all_possible_background_images)
foreground_rgb_chosen_name = random.choice(self.list_all_possible_foreground_rgbs)
img_rgb_background_chosen = cv2.imread(self.dir_rgb_backgrounds + '/' + background_image_chosen_name)
foreground_rgb_chosen = np.load(self.dir_rgb_foregrounds + '/' + foreground_rgb_chosen_name)
img_out = return_binary_image_with_given_rgb_background_and_given_foreground_rgb(img_bin_corr, img_rgb_background_chosen, foreground_rgb_chosen)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.adding_rgb_background:
for i_n in range(self.number_of_backgrounds_per_image):
background_image_chosen_name = random.choice(self.list_all_possible_background_images)
img_rgb_background_chosen = cv2.imread(self.dir_rgb_backgrounds + '/' + background_image_chosen_name)
img_out = return_binary_image_with_given_rgb_background(img_bin_corr, img_rgb_background_chosen)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.binarization:
pixel_values = self.processor(Image.fromarray(img_bin_corr), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.channels_shuffling:
for shuffle_index in self.shuffle_indexes:
img_out = return_shuffled_channels(img, shuffle_index)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.add_red_textlines:
img_out = return_image_with_red_elements(img, img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.white_noise_strap:
img_out = return_image_with_strapped_white_noises(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_skewing:
for index, des_scale_ind in enumerate(self.skewing_amplitudes):
try:
img_out = do_deskewing(img, des_scale_ind)
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_skewing_bin:
for index, des_scale_ind in enumerate(self.skewing_amplitudes):
try:
img_out = do_deskewing(img_bin_corr, des_scale_ind)
except:
img_out = np.copy(img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_left_in_depth:
try:
img_out = do_direction_in_depth(img, 'left')
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_left_in_depth_bin:
try:
img_out = do_direction_in_depth(img_bin_corr, 'left')
except:
img_out = np.copy(img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_right_in_depth:
try:
img_out = do_direction_in_depth(img, 'right')
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_right_in_depth_bin:
try:
img_out = do_direction_in_depth(img_bin_corr, 'right')
except:
img_out = np.copy(img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_up_in_depth:
try:
img_out = do_direction_in_depth(img, 'up')
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_up_in_depth_bin:
try:
img_out = do_direction_in_depth(img_bin_corr, 'up')
except:
img_out = np.copy(img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_down_in_depth:
try:
img_out = do_direction_in_depth(img, 'down')
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_down_in_depth_bin:
try:
img_out = do_direction_in_depth(img_bin_corr, 'down')
except:
img_out = np.copy(img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.pepper_bin_aug:
for index, pepper_ind in enumerate(self.pepper_indexes):
img_out = add_salt_and_pepper_noise(img_bin_corr, pepper_ind, pepper_ind)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.pepper_aug:
for index, pepper_ind in enumerate(self.pepper_indexes):
img_out = add_salt_and_pepper_noise(img, pepper_ind, pepper_ind)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
else:
pixel_values = self.processor(Image.fromarray(img), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding

View file

@ -16,28 +16,53 @@ from ..patch_encoder import (
PatchEncoder, PatchEncoder,
Patches, Patches,
) )
from PIL import Image
import torch
from transformers import VisionEncoderDecoderModel
def run_ensembling(dir_models, out, framework):
    """
    Average the weights of every model stored under a directory and save the result.

    Args:
        dir_models: directory whose entries are the saved models to combine.
        out: location the ensembled model is written to.
        framework: ``"torch"`` loads the models with
            ``VisionEncoderDecoderModel.from_pretrained`` and writes
            ``pytorch_model.bin``; any other value uses the Keras
            ``load_model`` / ``model.save`` path.

    Raises:
        ValueError: if ``dir_models`` contains no entries.
    """
    import shutil

    def _copy_if_present(src_dir, filename):
        # Best-effort copy of an auxiliary file next to the ensembled model.
        # Not every model ships every file, so a missing one is not an error.
        src = os.path.join(src_dir, filename)
        if os.path.isfile(src) and os.path.isdir(out):
            shutil.copy(src, out)

    # Sorted so that the model supplying the auxiliary files is deterministic.
    ls_models = sorted(os.listdir(dir_models))
    if not ls_models:
        raise ValueError("no models found in %r" % (dir_models,))

    model = None
    if framework == "torch":
        sd_models = []
        for model_name in ls_models:
            model = VisionEncoderDecoderModel.from_pretrained(
                os.path.join(dir_models, model_name))
            sd_models.append(model.state_dict())

        # Element-wise mean of each parameter over all loaded models.
        averaged = {
            key: sum(sd[key] for sd in sd_models) / len(sd_models)
            for key in sd_models[0]
        }

        # The last loaded model serves as the container for the averaged weights.
        model.load_state_dict(averaged)
        os.makedirs(out, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(out, "pytorch_model.bin"))
    else:
        weights = []
        for model_name in ls_models:
            model = load_model(
                os.path.join(dir_models, model_name),
                compile=False,
                custom_objects={'PatchEncoder': PatchEncoder, 'Patches': Patches})
            weights.append(model.get_weights())

        # zip(*weights) groups the same weight array across all models;
        # stacking and averaging over axis 0 takes the mean across models.
        new_weights = [
            np.mean(np.stack(same_weight_across_models), axis=0)
            for same_weight_across_models in zip(*weights)
        ]

        model.set_weights(new_weights)
        model.save(out)

    # Auxiliary files are taken from the last model in the (sorted) listing.
    last_model_dir = os.path.join(dir_models, ls_models[-1])
    _copy_if_present(last_model_dir, "config_eynollah.json")
    _copy_if_present(last_model_dir, "characters_org.txt")
@click.command() @click.command()
@click.option( @click.option(
@ -55,12 +80,17 @@ def run_ensembling(model_dirs, out_dir):
required=True, required=True,
type=click.Path(exists=False, file_okay=False), type=click.Path(exists=False, file_okay=False),
) )
def ensemble_cli(in_, out): @click.option(
"--framework",
"-fw",
help="this parameter gets tensorflow or torch as model framework",
)
def ensemble_cli(in_, out, framework):
""" """
mix multiple model weights mix multiple model weights
Load a sequence of models and mix them into a single ensemble model Load a sequence of models and mix them into a single ensemble model
by averaging their weights. Write the resulting model. by averaging their weights. Write the resulting model.
""" """
run_ensembling(in_, out) run_ensembling(in_, out, framework)

View file

@ -9,8 +9,8 @@ else:
import importlib.resources as importlib_resources import importlib.resources as importlib_resources
def get_font(font_size=40):
    """
    Load the Amiri regular typeface shipped with the package.

    Args:
        font_size: size handed to ``ImageFont.truetype``. Defaults to 40 so
            that callers which pass no argument keep working.

    Returns:
        The font object produced by ``ImageFont.truetype``.
    """
    font_resource = importlib_resources.files(__package__) / "../Amiri-Regular.ttf"
    # as_file yields a real filesystem path for the packaged resource.
    with importlib_resources.as_file(font_resource) as font_path:
        return ImageFont.truetype(font=font_path, size=font_size)

View file

@ -1,17 +1,17 @@
{ {
"backbone_type" : "transformer", "backbone_type" : "transformer",
"task": "cnn-rnn-ocr", "task": "transformer-ocr",
"n_classes" : 2, "n_classes" : 2,
"max_len": 280, "max_len": 192,
"n_epochs" : 3, "n_epochs" : 1,
"input_height" : 32, "input_height" : 32,
"input_width" : 512, "input_width" : 512,
"weight_decay" : 1e-6, "weight_decay" : 1e-6,
"n_batch" : 4, "n_batch" : 1,
"learning_rate": 1e-5, "learning_rate": 1e-5,
"save_interval": 1500, "save_interval": 1500,
"patches" : false, "patches" : false,
"pretraining" : true, "pretraining" : false,
"augmentation" : true, "augmentation" : true,
"flip_aug" : false, "flip_aug" : false,
"blur_aug" : true, "blur_aug" : true,
@ -77,7 +77,6 @@
"dir_output": "/home/vahid/extracted_lines/1919_bin/output", "dir_output": "/home/vahid/extracted_lines/1919_bin/output",
"dir_rgb_backgrounds": "/home/vahid/Documents/1_2_test_eynollah/set_rgb_background", "dir_rgb_backgrounds": "/home/vahid/Documents/1_2_test_eynollah/set_rgb_background",
"dir_rgb_foregrounds": "/home/vahid/Documents/1_2_test_eynollah/out_set_rgb_foreground", "dir_rgb_foregrounds": "/home/vahid/Documents/1_2_test_eynollah/out_set_rgb_foreground",
"dir_img_bin": "/home/vahid/extracted_lines/1919_bin/images_bin", "dir_img_bin": "/home/vahid/extracted_lines/1919_bin/images_bin"
"characters_txt_file":"/home/vahid/Downloads/models_eynollah/model_eynollah_ocr_cnnrnn_20250930/characters_org.txt"
} }

View file

@ -8,3 +8,8 @@ tensorflow-addons # for connected_components, depublished and only compatible wi
tensorflow < 2.16 # for tensorflow-addons, so only needed in training tensorflow < 2.16 # for tensorflow-addons, so only needed in training
tf_data < 2.16 # for tensorflow-addons, so only needed in training tf_data < 2.16 # for tensorflow-addons, so only needed in training
protobuf < 5 # for tensorflow-addons, so only needed in training protobuf < 5 # for tensorflow-addons, so only needed in training
torch
evaluate
accelerate
jiwer
transformers <= 4.30.2