Merge branch 'integrating_trocr_and_torch_ensembling_and_updating_characters_list'

# Conflicts:
#	src/eynollah/eynollah.py
#	src/eynollah/eynollah_ocr.py
#	src/eynollah/patch_encoder.py
#	src/eynollah/training/cli.py
#	src/eynollah/training/gt_gen_utils.py
#	src/eynollah/training/inference.py
#	src/eynollah/training/models.py
#	src/eynollah/training/train.py
#	src/eynollah/training/utils.py
#	src/eynollah/training/weights_ensembling.py
#	train/requirements.txt
This commit is contained in:
kba 2026-06-11 18:59:33 +02:00
commit 6df11d92d8
16 changed files with 1048 additions and 245 deletions

Binary file not shown.

View file

@ -69,10 +69,11 @@ class Eynollah_ocr:
self.model_zoo.load_models(['ocr', 'tr'])
self.model_zoo.get('ocr').to(self.device)
else:
self.model_zoo.load_models('ocr')
self.model_zoo.load_models('num_to_char')
self.model_zoo.load_models('characters')
self.end_character = len(self.model_zoo.get('characters')) + 2
self.model_zoo.load_model('ocr', '')
self.input_shape = self.model_zoo.get('ocr').input_shape[1:3]
self.model_zoo.load_model('num_to_char')
self.model_zoo.load_model('characters')
self.end_character = len(self.model_zoo.get('characters', list)) + 2
@property
def device(self):
@ -657,7 +658,7 @@ class Eynollah_ocr:
if out_image_with_text:
image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white")
draw = ImageDraw.Draw(image_text)
font = get_font()
font = get_font(font_size=40)
for indexer_text, bb_ind in enumerate(total_bb_coordinates):
x_bb = bb_ind[0]
@ -823,8 +824,8 @@ class Eynollah_ocr:
page_ns=page_ns,
img_bin=img_bin,
image_width=512,
image_height=32,
image_width=self.input_shape[1],
image_height=self.input_shape[0],
)
self.write_ocr(

View file

@ -20,6 +20,7 @@ class PatchEncoder(layers.Layer):
def get_config(self):
return dict(num_patches=self.num_patches,
projection_dim=self.projection_dim,
position_embedding=self.position_embedding,
**super().get_config())
class Patches(layers.Layer):

View file

@ -9,7 +9,12 @@ from .generate_gt_for_training import main as generate_gt_cli
from .inference import main as inference_cli
from .train import ex
from .extract_line_gt import linegt_cli
<<<<<<< HEAD
from .weights_ensembling import ensemble_cli
=======
from .weights_ensembling import main as ensemble_cli
from .generate_or_update_cnn_rnn_ocr_character_list import main as update_ocr_characters_cli
>>>>>>> integrating_trocr_and_torch_ensembling_and_updating_characters_list
@click.command(context_settings=dict(
ignore_unknown_options=True,
@ -28,3 +33,4 @@ main.add_command(inference_cli, 'inference')
main.add_command(train_cli, 'train')
main.add_command(linegt_cli, 'export_textline_images_and_text')
main.add_command(ensemble_cli, 'ensembling')
main.add_command(update_ocr_characters_cli, 'generate_or_update_cnn_rnn_ocr_character_list')

View file

@ -50,6 +50,12 @@ from ..utils import is_image_filename
is_flag=True,
help="if this parameter set to true, cropped textline images will not be masked with textline contour.",
)
@click.option(
"--exclude_vertical_lines",
"-exv",
is_flag=True,
help="if this parameter set to true, vertical textline images will be excluded.",
)
def linegt_cli(
image,
dir_in,
@ -57,6 +63,7 @@ def linegt_cli(
dir_out,
pref_of_dataset,
do_not_mask_with_textline_contour,
exclude_vertical_lines,
):
assert bool(dir_in) ^ bool(image), "Set --dir-in or --image-filename, not both"
if dir_in:
@ -70,14 +77,13 @@ def linegt_cli(
for dir_img in ls_imgs:
file_name = Path(dir_img).stem
dir_xml = os.path.join(dir_xmls, file_name + '.xml')
img = cv2.imread(dir_img)
total_bb_coordinates = []
tree1 = ET.parse(dir_xml, parser=ET.XMLParser(encoding="utf-8"))
root1 = tree1.getroot()
alltags = [elem.tag for elem in root1.iter()]
tree = ET.parse(dir_xml, parser=ET.XMLParser(encoding="utf-8"))
root = tree.getroot()
alltags = [elem.tag for elem in root.iter()]
name_space = alltags[0].split('}')[0]
name_space = name_space.split('{')[1]
@ -89,7 +95,7 @@ def linegt_cli(
indexer_text_region = 0
indexer_textlines = 0
# FIXME: non recursive, use OCR-D PAGE generateDS API. Or use an existing tool for this purpose altogether
for nn in root1.iter(region_tags):
for nn in root.iter(region_tags):
for child_textregion in nn:
if child_textregion.tag.endswith("TextLine"):
for child_textlines in child_textregion:
@ -100,6 +106,10 @@ def linegt_cli(
x, y, w, h = cv2.boundingRect(textline_coords)
if exclude_vertical_lines and h > 1.4 * w:
img_crop = None
continue
total_bb_coordinates.append([x, y, w, h])
img_poly_on_img = np.copy(img)
@ -114,12 +124,15 @@ def linegt_cli(
img_crop[mask_poly == 0] = 255
if img_crop.shape[0] == 0 or img_crop.shape[1] == 0:
img_crop = None
continue
if child_textlines.tag.endswith("TextEquiv"):
for cheild_text in child_textlines:
if cheild_text.tag.endswith("Unicode"):
textline_text = cheild_text.text
if textline_text:
if textline_text and img_crop is not None:
base_name = os.path.join(
dir_out, file_name + '_line_' + str(indexer_textlines)
)
@ -131,4 +144,4 @@ def linegt_cli(
with open(base_name + '.txt', 'w') as text_file:
text_file.write(textline_text)
cv2.imwrite(base_name + '.png', img_crop)
indexer_textlines += 1
indexer_textlines += 1

View file

@ -6,6 +6,7 @@ from pathlib import Path
from PIL import Image, ImageDraw, ImageFont
import cv2
import numpy as np
from eynollah.utils.font import get_font
from .gt_gen_utils import (
filter_contours_area_of_image,
@ -393,11 +394,15 @@ def visualize_reading_order(xml_file, dir_xml, dir_out, dir_imgs):
layout = np.zeros( (y_len,x_len,3) )
layout = cv2.fillPoly(layout, pts =co_text_all, color=(1,1,1))
img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name)
img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format))
try:
img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name)
img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format))
overlayed = overlay_layout_on_image(layout, img, cx_ordered, cy_ordered, color, thickness)
cv2.imwrite(os.path.join(dir_out, f_name+'.png'), overlayed)
except:
pass
overlayed = overlay_layout_on_image(layout, img, cx_ordered, cy_ordered, color, thickness)
cv2.imwrite(os.path.join(dir_out, f_name+'.png'), overlayed)
else:
img = np.zeros( (y_len,x_len,3) )
@ -452,14 +457,17 @@ def visualize_textline_segmentation(xml_file, dir_xml, dir_out, dir_imgs):
xml_file = os.path.join(dir_xml,ind_xml )
f_name = Path(ind_xml).stem
img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name)
img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format))
try:
img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name)
img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format))
co_tetxlines, y_len, x_len = get_textline_contours_for_visualization(xml_file)
co_tetxlines, y_len, x_len = get_textline_contours_for_visualization(xml_file)
added_image = visualize_image_from_contours(co_tetxlines, img)
added_image = visualize_image_from_contours(co_tetxlines, img)
cv2.imwrite(os.path.join(dir_out, f_name+'.png'), added_image)
cv2.imwrite(os.path.join(dir_out, f_name+'.png'), added_image)
except:
pass
@ -509,15 +517,17 @@ def visualize_layout_segmentation(xml_file, dir_xml, dir_out, dir_imgs):
f_name = Path(ind_xml).stem
print(f_name, 'f_name')
img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name)
img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format))
try:
img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name)
img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format))
co_text, co_graphic, co_sep, co_img, co_table, co_map, co_noise, y_len, x_len = get_layout_contours_for_visualization(xml_file)
co_text, co_graphic, co_sep, co_img, co_table, co_map, co_music, co_noise, y_len, x_len = get_layout_contours_for_visualization(xml_file)
added_image = visualize_image_from_contours_layout(co_text['paragraph'], co_text['header']+co_text['heading'], co_text['drop-capital'], co_sep, co_img, co_text['marginalia'], co_table, co_map, co_music, img)
added_image = visualize_image_from_contours_layout(co_text['paragraph'], co_text['header']+co_text['heading'], co_text['drop-capital'], co_sep, co_img, co_text['marginalia'], co_table, img)
cv2.imwrite(os.path.join(dir_out, f_name+'.png'), added_image)
cv2.imwrite(os.path.join(dir_out, f_name+'.png'), added_image)
except:
pass
@ -552,8 +562,8 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out):
else:
xml_files_ind = [xml_file]
font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
font = ImageFont.truetype(font_path, 40)
###font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
font = get_font(font_size=40)#ImageFont.truetype(font_path, 40)
for ind_xml in tqdm(xml_files_ind):
indexer = 0
@ -590,11 +600,11 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out):
is_vertical = h > 2*w # Check orientation
font = fit_text_single_line(draw, ocr_texts[index], font_path, w, int(h*0.4) )
font = fit_text_single_line(draw, ocr_texts[index], w, int(h*0.4) )
if is_vertical:
vertical_font = fit_text_single_line(draw, ocr_texts[index], font_path, h, int(w * 0.8))
vertical_font = fit_text_single_line(draw, ocr_texts[index], h, int(w * 0.8))
text_img = Image.new("RGBA", (h, w), (255, 255, 255, 0)) # Note: dimensions are swapped
text_draw = ImageDraw.Draw(text_img)

View file

@ -0,0 +1,59 @@
import os
import numpy as np
import json
import click
import logging
def run_character_list_update(dir_labels, out, current_character_list):
ls_labels = os.listdir(dir_labels)
ls_labels = [ind for ind in ls_labels if ind.endswith('.txt')]
if current_character_list:
with open(current_character_list, 'r') as f_name:
characters = json.load(f_name)
characters = set(characters)
else:
characters = set()
for ind in ls_labels:
label = open(os.path.join(dir_labels,ind),'r').read().split('\n')[0]
for char in label:
characters.add(char)
characters = sorted(list(set(characters)))
with open(out, 'w') as f_name:
json.dump(characters, f_name)
@click.command()
@click.option(
"--dir_labels",
"-dl",
help="directory of labels which are .txt files",
type=click.Path(exists=True, file_okay=False),
required=True,
)
@click.option(
"--current_character_list",
"-ccl",
help="existing character list in a .txt file that needs to be updated with a set of labels",
type=click.Path(exists=True, file_okay=True),
required=False,
)
@click.option(
"--out",
"-o",
help="An output .txt file where the generated or updated character list will be written",
type=click.Path(exists=False, file_okay=True),
)
def main(dir_labels, out, current_character_list):
run_character_list_update(dir_labels, out, current_character_list)

View file

@ -8,7 +8,7 @@ from shapely import geometry
from pathlib import Path
from PIL import ImageFont
from ocrd_utils import bbox_from_points
from eynollah.utils.font import get_font
KERNEL = np.ones((5, 5), np.uint8)
NS = { 'pc': 'http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15'
@ -18,7 +18,7 @@ with warnings.catch_warnings():
warnings.simplefilter("ignore")
def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_image, co_marginal, co_table, co_map, img):
def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_image, co_marginal, co_table, co_map, co_music, img):
alpha = 0.5
blank_image = np.ones( (img.shape[:]), dtype=np.uint8) * 255
@ -32,6 +32,7 @@ def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_
col_marginal = (106, 90, 205)
col_table = (0, 90, 205)
col_map = (90, 90, 205)
col_music = (90, 90, 0)
if len(co_image)>0:
cv2.drawContours(blank_image, co_image, -1, col_image, thickness=cv2.FILLED) # Fill the contour
@ -60,6 +61,9 @@ def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_
if len(co_map)>0:
cv2.drawContours(blank_image, co_map, -1, col_map, thickness=cv2.FILLED) # Fill the contour
if len(co_music)>0:
cv2.drawContours(blank_image, co_music, -1, col_music, thickness=cv2.FILLED) # Fill the contour
img_final =cv2.cvtColor(blank_image, cv2.COLOR_BGR2RGB)
added_image = cv2.addWeighted(img,alpha,img_final,1- alpha,0)
@ -352,11 +356,11 @@ def get_textline_contours_and_ocr_text(xml_file):
ocr_textlines.append(ocr_text_in[0])
return co_use_case, y_len, x_len, ocr_textlines
def fit_text_single_line(draw, text, font_path, max_width, max_height):
def fit_text_single_line(draw, text, max_width, max_height):
initial_font_size = 50
font_size = initial_font_size
while font_size > 10: # Minimum font size
font = ImageFont.truetype(font_path, font_size)
font = get_font(font_size=font_size)# ImageFont.truetype(font_path, font_size)
text_bbox = draw.textbbox((0, 0), text, font=font) # Get text bounding box
text_width = text_bbox[2] - text_bbox[0]
text_height = text_bbox[3] - text_bbox[1]
@ -366,7 +370,7 @@ def fit_text_single_line(draw, text, font_path, max_width, max_height):
font_size -= 2 # Reduce font size and retry
return ImageFont.truetype(font_path, 10) # Smallest font fallback
return get_font(font_size=10)#ImageFont.truetype(font_path, 10) # Smallest font fallback
def get_layout_contours_for_visualization(xml_file):
tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8'))
@ -389,6 +393,7 @@ def get_layout_contours_for_visualization(xml_file):
co_img=[]
co_table=[]
co_map=[]
co_music=[]
co_noise=[]
types_text = []
@ -631,6 +636,31 @@ def get_layout_contours_for_visualization(xml_file):
break
co_map.append(np.array(c_t_in))
if tag.endswith('}MusicRegion') or tag.endswith('}musicregion'):
#print('sth')
for nn in root1.iter(tag):
c_t_in=[]
sumi=0
for vv in nn.iter():
# check the format of coords
if vv.tag==link+'Coords':
coords=bool(vv.attrib)
if coords:
p_h=vv.attrib['points'].split(' ')
c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
break
else:
pass
if vv.tag==link+'Point':
c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
sumi+=1
#print(vv.tag,'in')
elif vv.tag!=link+'Point' and sumi>=1:
break
co_music.append(np.array(c_t_in))
if tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'):
#print('sth')
@ -656,7 +686,7 @@ def get_layout_contours_for_visualization(xml_file):
elif vv.tag!=link+'Point' and sumi>=1:
break
co_noise.append(np.array(c_t_in))
return co_text, co_graphic, co_sep, co_img, co_table, co_map, co_noise, y_len, x_len
return co_text, co_graphic, co_sep, co_img, co_table, co_map, co_music, co_noise, y_len, x_len
def get_images_of_ground_truth(
gt_list,
@ -682,170 +712,192 @@ def get_images_of_ground_truth(
if not item.endswith('.xml')}
for index in tqdm(range(len(gt_list))):
#try:
print(gt_list[index])
tree1 = ET.parse(dir_in+'/'+gt_list[index], parser = ET.XMLParser(encoding='utf-8'))
root1=tree1.getroot()
alltags=[elem.tag for elem in root1.iter()]
link=alltags[0].split('}')[0]+'}'
try:
tree1 = ET.parse(dir_in+'/'+gt_list[index], parser = ET.XMLParser(encoding='utf-8'))
root1=tree1.getroot()
alltags=[elem.tag for elem in root1.iter()]
link=alltags[0].split('}')[0]+'}'
x_len, y_len = 0, 0
for jj in root1.iter(link+'Page'):
y_len=int(jj.attrib['imageHeight'])
x_len=int(jj.attrib['imageWidth'])
x_len, y_len = 0, 0
for jj in root1.iter(link+'Page'):
y_len=int(jj.attrib['imageHeight'])
x_len=int(jj.attrib['imageWidth'])
if 'columns_width' in list(config_params.keys()):
columns_width_dict = config_params['columns_width']
# FIXME: look in /Page/@custom as well
metadata_element = root1.find(link+'Metadata')
num_col = None
for child in metadata_element:
tag2 = child.tag
if tag2.endswith('}Comments') or tag2.endswith('}comments'):
text_comments = child.text
num_col = int(text_comments.split('num_col')[1])
if 'columns_width' in list(config_params.keys()):
columns_width_dict = config_params['columns_width']
metadata_element = root1.find(link+'Metadata')
num_col = None
for child in metadata_element:
tag2 = child.tag
if tag2.endswith('}Comments') or tag2.endswith('}comments'):
text_comments = child.text
num_col = int(text_comments.split('num_col')[1])
if num_col:
x_new = columns_width_dict[str(num_col)]
y_new = int ( x_new * (y_len / float(x_len)) )
if num_col:
x_new = columns_width_dict[str(num_col)]
y_new = int ( x_new * (y_len / float(x_len)) )
if printspace or "printspace_as_class_in_layout" in list(config_params.keys()):
ps = (root1.xpath('/pc:PcGts/pc:Page/pc:Border', namespaces=NS) +
root1.xpath('/pc:PcGts/pc:Page/pc:PrintSpace', namespaces=NS))
coords = root1.xpath('//pc:Coords/@points', namespaces=NS)
if len(ps):
points = ps[0].find('pc:Coords', NS).get('points')
ps_bbox = bbox_from_points(points)
elif missing_printspace == 'skip':
print(gt_list[index], "has no Border or PrintSpace - skipping file")
continue
elif missing_printspace == 'project' and len(coords):
print(gt_list[index], "has no Border or PrintSpace - projecting hull of segments")
bboxes = list(map(bbox_from_points, coords))
left, top, right, bottom = zip(*bboxes)
left = max(0, min(left) - 5)
top = max(0, min(top) - 5)
right = min(x_len, max(right) + 5)
bottom = min(y_len, max(bottom) + 5)
ps_bbox = [left, top, right, bottom]
else:
print(gt_list[index], "has no Border or PrintSpace - using full page")
ps_bbox = [0, 0, None, None]
if printspace or "printspace_as_class_in_layout" in list(config_params.keys()):
region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace') or x.endswith('Border')])
co_use_case = []
for tag in region_tags:
tag_endings = ['}PrintSpace','}Border']
if config_file and (config_params['use_case']=='textline' or config_params['use_case']=='word' or config_params['use_case']=='glyph' or config_params['use_case']=='printspace'):
keys = list(config_params.keys())
if "artificial_class_label" in keys:
artificial_class_rgb_color = (255,255,0)
artificial_class_label = config_params['artificial_class_label']
if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]):
for nn in root1.iter(tag):
c_t_in = []
sumi = 0
for vv in nn.iter():
# check the format of coords
if vv.tag == link + 'Coords':
coords = bool(vv.attrib)
if coords:
p_h = vv.attrib['points'].split(' ')
c_t_in.append(
np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h]))
break
else:
pass
textline_rgb_color = (255, 0, 0)
if config_params['use_case']=='textline':
region_tags = np.unique([x for x in alltags if x.endswith('TextLine')])
elif config_params['use_case']=='word':
region_tags = np.unique([x for x in alltags if x.endswith('Word')])
elif config_params['use_case']=='glyph':
region_tags = np.unique([x for x in alltags if x.endswith('Glyph')])
elif config_params['use_case']=='printspace':
region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace')])
co_use_case = []
for tag in region_tags:
if config_params['use_case']=='textline':
tag_endings = ['}TextLine','}textline']
elif config_params['use_case']=='word':
tag_endings = ['}Word','}word']
elif config_params['use_case']=='glyph':
tag_endings = ['}Glyph','}glyph']
elif config_params['use_case']=='printspace':
tag_endings = ['}PrintSpace','}printspace']
if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]):
for nn in root1.iter(tag):
c_t_in = []
sumi = 0
for vv in nn.iter():
# check the format of coords
if vv.tag == link + 'Coords':
coords = bool(vv.attrib)
if coords:
p_h = vv.attrib['points'].split(' ')
c_t_in.append(
np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h]))
if vv.tag == link + 'Point':
c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))])
sumi += 1
elif vv.tag != link + 'Point' and sumi >= 1:
break
else:
pass
co_use_case.append(np.array(c_t_in))
if vv.tag == link + 'Point':
c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))])
sumi += 1
elif vv.tag != link + 'Point' and sumi >= 1:
break
co_use_case.append(np.array(c_t_in))
img = np.zeros((y_len, x_len, 3))
if "artificial_class_label" in keys:
img_boundary = np.zeros((y_len, x_len))
erosion_rate = 0#1
dilation_rate = 2
dilation_early = 0
erosion_early = 2
co_use_case, img_boundary = update_region_contours(co_use_case, img_boundary, erosion_rate, dilation_rate, y_len, x_len, dilation_early=dilation_early, erosion_early=erosion_early)
img = np.zeros((y_len, x_len, 3))
if output_type == '2d':
img_poly = cv2.fillPoly(img, pts=co_use_case, color=(1, 1, 1))
img_poly = img_poly.astype(np.uint8)
imgray = cv2.cvtColor(img_poly, cv2.COLOR_BGR2GRAY)
_, thresh = cv2.threshold(imgray, 0, 255, 0)
contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
cnt_size = np.array([cv2.contourArea(contours[j]) for j in range(len(contours))])
try:
cnt = contours[np.argmax(cnt_size)]
x, y, w, h = cv2.boundingRect(cnt)
except:
x, y , w, h = 0, 0, x_len, y_len
bb_xywh = [x, y, w, h]
if config_file and (config_params['use_case']=='textline' or config_params['use_case']=='word' or config_params['use_case']=='glyph' or config_params['use_case']=='printspace'):
keys = list(config_params.keys())
if "artificial_class_label" in keys:
img_mask = np.copy(img_poly)
##img_poly[:,:][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=1)] = artificial_class_label
img_poly[:,:][img_boundary[:,:]==1] = artificial_class_label
elif output_type == '3d':
img_poly = cv2.fillPoly(img, pts=co_use_case, color=textline_rgb_color)
artificial_class_rgb_color = (255,255,0)
artificial_class_label = config_params['artificial_class_label']
textline_rgb_color = (255, 0, 0)
if config_params['use_case']=='textline':
region_tags = np.unique([x for x in alltags if x.endswith('TextLine')])
elif config_params['use_case']=='word':
region_tags = np.unique([x for x in alltags if x.endswith('Word')])
elif config_params['use_case']=='glyph':
region_tags = np.unique([x for x in alltags if x.endswith('Glyph')])
elif config_params['use_case']=='printspace':
region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace')])
co_use_case = []
for tag in region_tags:
if config_params['use_case']=='textline':
tag_endings = ['}TextLine','}textline']
elif config_params['use_case']=='word':
tag_endings = ['}Word','}word']
elif config_params['use_case']=='glyph':
tag_endings = ['}Glyph','}glyph']
elif config_params['use_case']=='printspace':
tag_endings = ['}PrintSpace','}printspace']
if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]):
for nn in root1.iter(tag):
c_t_in = []
sumi = 0
for vv in nn.iter():
# check the format of coords
if vv.tag == link + 'Coords':
coords = bool(vv.attrib)
if coords:
p_h = vv.attrib['points'].split(' ')
c_t_in.append(
np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h]))
break
else:
pass
if vv.tag == link + 'Point':
c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))])
sumi += 1
elif vv.tag != link + 'Point' and sumi >= 1:
break
co_use_case.append(np.array(c_t_in))
if "artificial_class_label" in keys:
img_mask = np.copy(img_poly)
img_poly[:,:,0][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[0]
img_poly[:,:,1][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[1]
img_poly[:,:,2][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[2]
img_boundary = np.zeros((y_len, x_len))
erosion_rate = 0#1
dilation_rate = 2
dilation_early = 0
erosion_early = 2
co_use_case, img_boundary = update_region_contours(co_use_case, img_boundary, erosion_rate, dilation_rate, y_len, x_len, dilation_early=dilation_early, erosion_early=erosion_early)
if printspace and config_params['use_case']!='printspace':
img_poly = img_poly[ps_bbox[1]:ps_bbox[3],
ps_bbox[0]:ps_bbox[2], :]
img = np.zeros((y_len, x_len, 3))
if output_type == '2d':
img_poly = cv2.fillPoly(img, pts=co_use_case, color=(1, 1, 1))
if "artificial_class_label" in keys:
img_mask = np.copy(img_poly)
##img_poly[:,:][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=1)] = artificial_class_label
img_poly[:,:][img_boundary[:,:]==1] = artificial_class_label
elif output_type == '3d':
img_poly = cv2.fillPoly(img, pts=co_use_case, color=textline_rgb_color)
if "artificial_class_label" in keys:
img_mask = np.copy(img_poly)
img_poly[:,:,0][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[0]
img_poly[:,:,1][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[1]
img_poly[:,:,2][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[2]
if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace':
img_poly = resize_image(img_poly, y_new, x_new)
try:
xml_file_stem = os.path.splitext(gt_list[index])[0]
cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly)
except:
xml_file_stem = os.path.splitext(gt_list[index])[0]
cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly)
if dir_images:
org_image_name = ls_org_imgs[xml_file_stem]
if not org_image_name:
print("image file for XML stem", xml_file_stem, "is missing")
continue
if not os.path.isfile(os.path.join(dir_images, org_image_name)):
print("image file for XML stem", xml_file_stem, "is not readable")
continue
img_org = cv2.imread(os.path.join(dir_images, org_image_name))
if printspace and config_params['use_case']!='printspace':
img_org = img_org[ps_bbox[1]:ps_bbox[3],
ps_bbox[0]:ps_bbox[2], :]
img_poly = img_poly[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :]
if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace':
img_org = resize_image(img_org, y_new, x_new)
img_poly = resize_image(img_poly, y_new, x_new)
cv2.imwrite(os.path.join(dir_out_images, org_image_name), img_org)
try:
xml_file_stem = os.path.splitext(gt_list[index])[0]
cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly)
except:
xml_file_stem = os.path.splitext(gt_list[index])[0]
cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly)
if dir_images:
org_image_name = ls_org_imgs[ls_org_imgs_stem.index(xml_file_stem)]
img_org = cv2.imread(os.path.join(dir_images, org_image_name))
if printspace and config_params['use_case']!='printspace':
img_org = img_org[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :]
if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace':
img_org = resize_image(img_org, y_new, x_new)
cv2.imwrite(os.path.join(dir_out_images, org_image_name), img_org)
except:
pass
if config_file and config_params['use_case']=='layout':
keys = list(config_params.keys())
@ -870,7 +922,7 @@ def get_images_of_ground_truth(
types_graphic_label = list(types_graphic_dict.values())
labels_rgb_color = [ (0,0,0), (255,0,0), (255,125,0), (255,0,125), (125,255,125), (125,125,0), (0,125,255), (0,125,0), (125,125,125), (255,0,255), (125,0,125), (0,255,0),(0,0,255), (0,255,255), (255,125,125), (0,125,125), (0,255,125), (255,125,255), (125,255,0), (125,255,255)]
labels_rgb_color = [ (0,0,0), (255,0,0), (255,125,0), (255,0,125), (125,255,125), (125,125,0), (0,125,255), (0,125,0), (125,125,125), (255,0,255), (125,0,125), (0,255,0),(0,0,255), (0,255,255), (255,125,125), (0,125,125), (0,255,125), (255,125,255), (125,255,0), (125,255,255), (125,125,255)]
region_tags=np.unique([x for x in alltags if x.endswith('Region')])
@ -882,6 +934,7 @@ def get_images_of_ground_truth(
co_img=[]
co_table=[]
co_map=[]
co_music=[]
co_noise=[]
for tag in region_tags:
@ -966,20 +1019,22 @@ def get_images_of_ground_truth(
if "rest_as_decoration" in types_graphic:
types_graphic_without_decoration = [element for element in types_graphic if element!='rest_as_decoration' and element!='decoration']
if len(types_graphic_without_decoration) == 0:
if "type" in nn.attrib:
c_t_in_graphic['decoration'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
#if "type" in nn.attrib:
c_t_in_graphic['decoration'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
elif len(types_graphic_without_decoration) >= 1:
if "type" in nn.attrib:
if nn.attrib['type'] in types_graphic_without_decoration:
c_t_in_graphic[nn.attrib['type']].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
else:
c_t_in_graphic['decoration'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
else:
c_t_in_graphic['decoration'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
else:
if "type" in nn.attrib:
if nn.attrib['type'] in all_defined_graphic_types:
c_t_in_graphic[nn.attrib['type']].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
break
else:
pass
@ -989,9 +1044,9 @@ def get_images_of_ground_truth(
if "rest_as_decoration" in types_graphic:
types_graphic_without_decoration = [element for element in types_graphic if element!='rest_as_decoration' and element!='decoration']
if len(types_graphic_without_decoration) == 0:
if "type" in nn.attrib:
c_t_in_graphic['decoration'].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] )
sumi+=1
#if "type" in nn.attrib:
c_t_in_graphic['decoration'].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] )
sumi+=1
elif len(types_graphic_without_decoration) >= 1:
if "type" in nn.attrib:
if nn.attrib['type'] in types_graphic_without_decoration:
@ -1000,6 +1055,9 @@ def get_images_of_ground_truth(
else:
c_t_in_graphic['decoration'].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] )
sumi+=1
else:
c_t_in_graphic['decoration'].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] )
sumi+=1
else:
if "type" in nn.attrib:
@ -1119,6 +1177,32 @@ def get_images_of_ground_truth(
break
co_map.append(np.array(c_t_in))
if 'musicregion' in keys:
if tag.endswith('}MusicRegion') or tag.endswith('}musicregion'):
#print('sth')
for nn in root1.iter(tag):
c_t_in=[]
sumi=0
for vv in nn.iter():
# check the format of coords
if vv.tag==link+'Coords':
coords=bool(vv.attrib)
if coords:
p_h=vv.attrib['points'].split(' ')
c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
break
else:
pass
if vv.tag==link+'Point':
c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
sumi+=1
#print(vv.tag,'in')
elif vv.tag!=link+'Point' and sumi>=1:
break
co_music.append(np.array(c_t_in))
if 'noiseregion' in keys:
if tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'):
#print('sth')
@ -1195,6 +1279,10 @@ def get_images_of_ground_truth(
erosion_rate = 0#2
dilation_rate = 3#4
co_map, img_boundary = update_region_contours(co_map, img_boundary, erosion_rate, dilation_rate, y_len, x_len )
if "musicregion" in elements_with_artificial_class:
erosion_rate = 0#2
dilation_rate = 3#4
co_music, img_boundary = update_region_contours(co_music, img_boundary, erosion_rate, dilation_rate, y_len, x_len )
@ -1222,6 +1310,8 @@ def get_images_of_ground_truth(
img_poly=cv2.fillPoly(img, pts =co_table, color=labels_rgb_color[ config_params['tableregion']])
if 'mapregion' in keys:
img_poly=cv2.fillPoly(img, pts =co_map, color=labels_rgb_color[ config_params['mapregion']])
if 'musicregion' in keys:
img_poly=cv2.fillPoly(img, pts =co_music, color=labels_rgb_color[ config_params['musicregion']])
if 'noiseregion' in keys:
img_poly=cv2.fillPoly(img, pts =co_noise, color=labels_rgb_color[ config_params['noiseregion']])
@ -1286,6 +1376,9 @@ def get_images_of_ground_truth(
if 'mapregion' in keys:
color_label = config_params['mapregion']
img_poly=cv2.fillPoly(img, pts =co_map, color=(color_label,color_label,color_label))
if 'musicregion' in keys:
color_label = config_params['musicregion']
img_poly=cv2.fillPoly(img, pts =co_music, color=(color_label,color_label,color_label))
if 'noiseregion' in keys:
color_label = config_params['noiseregion']
img_poly=cv2.fillPoly(img, pts =co_noise, color=(color_label,color_label,color_label))

View file

@ -9,9 +9,11 @@ import warnings
import json
import click
import numpy as np
from numpy._typing import NDArray
import cv2
import xml.etree.ElementTree as ET
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
@ -119,8 +121,33 @@ class SBBPredict:
return mIoU
def start_new_session_and_model(self):
if self.cpu:
tf.config.set_visible_devices([], 'GPU')
if self.task == "cnn-rnn-ocr":
if self.cpu:
os.environ['CUDA_VISIBLE_DEVICES']='-1'
self.model = load_model(self.model_dir)
self.model = tf.keras.models.Model(
self.model.get_layer(name = "image").input,
self.model.get_layer(name = "dense2").output)
assert isinstance(self.model, Model)
elif self.task == "transformer-ocr":
import torch
from transformers import VisionEncoderDecoderModel
from transformers import TrOCRProcessor
self.model = VisionEncoderDecoderModel.from_pretrained(self.model_dir)
self.processor = TrOCRProcessor.from_pretrained(self.model_dir)
if self.cpu:
self.device = torch.device('cpu')
else:
self.device = torch.device('cuda:0')
self.model.to(self.device)
assert isinstance(self.model, torch.nn.Module)
else:
try:
for device in tf.config.list_physical_devices('GPU'):
@ -137,15 +164,13 @@ class SBBPredict:
custom_objects={"PatchEncoder": PatchEncoder,
"Patches": Patches})
##if self.weights_dir!=None:
##self.model.load_weights(self.weights_dir)
if self.task != 'classification' and self.task != 'reading_order':
self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1]
self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2]
self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3]
assert isinstance(self.model, Model)
if self.task != 'classification' and self.task != 'reading_order':
last = self.model.layers[-1]
self.img_height = last.output_shape[1]
self.img_width = last.output_shape[2]
self.n_classes = last.output_shape[3]
assert isinstance(self.model, Model)
def visualize_model_output(self, prediction, img, task) -> Tuple[NDArray, NDArray]:
if task == "binarization":
@ -191,9 +216,9 @@ class SBBPredict:
return added_image, layout_only
def predict(self, image_dir):
assert isinstance(self.model, Model)
if self.task == 'classification':
classes_names = self.config_params_model['classification_classes_name']
img_1ch = cv2.imread(image_dir, 0) / 255.0
img_1ch = cv2.resize(img_1ch, (self.config_params_model['input_height'],
self.config_params_model['input_width']),
@ -231,6 +256,15 @@ class SBBPredict:
pred_texts = pred_texts[0].replace("[UNK]", "")
return pred_texts
elif self.task == "transformer-ocr":
from PIL import Image
image = Image.open(image_dir).convert("RGB")
pixel_values = self.processor(image, return_tensors="pt").pixel_values
generated_ids = self.model.generate(pixel_values.to(self.device))
return self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
elif self.task == 'reading_order':
img_height = self.config_params_model['input_height']
@ -566,6 +600,8 @@ class SBBPredict:
cv2.imwrite(self.save,res)
elif self.task == "cnn-rnn-ocr":
print(f"Detected text: {res}")
elif self.task == "transformer-ocr":
print(f"Detected text: {res}")
else:
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
if self.save:
@ -668,11 +704,13 @@ class SBBPredict:
help="min area size of regions considered for reading order detection. The default value is zero and means that all text regions are considered for reading order.",
)
def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area):
assert image or dir_in, "Either a single image -i or a dir_in -di input is required"
with open(os.path.join(model,'config.json')) as f:
config_params_model = json.load(f)
task = config_params_model['task']
if task not in ['classification', 'reading_order', "cnn-rnn-ocr"]:
if task not in ['classification', 'reading_order', "cnn-rnn-ocr", "transformer-ocr"]:
assert not image or save, "For segmentation or binarization, an input single image -i also requires an output filename -s"
assert not dir_in or out, "For segmentation or binarization, an input directory -di also requires an output directory -o"
x = SBBPredict(image, dir_in, model, task, config_params_model,

View file

@ -309,6 +309,8 @@ def transformer_block(img,
# Skip connection 2.
encoded_patches = Add()([x3, x2])
#assert isinstance(x, Layer)
encoded_patches = tf.reshape(encoded_patches,
[-1,
img.shape[1],

View file

@ -3,9 +3,14 @@ import sys
import io
import json
from tqdm import tqdm
import requests
import numpy as np
import cv2
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
import tensorflow as tf
@ -47,12 +52,21 @@ from .utils import (
generate_arrays_from_folder_reading_order,
get_one_hot,
preprocess_imgs,
return_number_of_total_training_data,
OCRDatasetYieldAugmentations
)
from .weights_ensembling import run_ensembling
import torch
from transformers import TrOCRProcessor
import evaluate
from transformers import default_data_collator
from transformers import VisionEncoderDecoderModel
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
class SaveWeightsAfterSteps(ModelCheckpoint):
def __init__(self, save_interval, save_path, _config, **kwargs):
def __init__(self, save_interval, save_path, _config, characters_cnnrnn_ocr=None, **kwargs):
if save_interval:
# batches
super().__init__(
@ -67,12 +81,15 @@ class SaveWeightsAfterSteps(ModelCheckpoint):
verbose=1,
**kwargs)
self._config = _config
self.characters_cnnrnn_ocr = characters_cnnrnn_ocr
# overwrite tf-keras (Keras 2) implementation to get our _config JSON in
def _save_handler(self, filepath):
super()._save_handler(filepath)
with open(os.path.join(filepath, "config.json"), "w") as fp:
json.dump(self._config, fp) # encode dict into JSON
if self.characters_cnnrnn_ocr:
os.system("cp "+self.characters_cnnrnn_ocr+" "+os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"characters_org.txt"))
def configuration():
try:
@ -820,9 +837,126 @@ def run(_config,
usable_checkpoints = [os.path.join(dir_output, 'model_{epoch:02d}'.format(epoch=epoch + 1))
for epoch in usable_checkpoints]
ens_path = os.path.join(dir_output, 'model_ens_avg')
run_ensembling(usable_checkpoints, ens_path)
run_ensembling(usable_checkpoints, ens_path, framework="tensorflow")
_log.info("ensemble model saved under '%s'", ens_path)
# =======
elif task=="transformer-ocr":
dir_img, dir_lab = get_dirs_or_files(dir_train)
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
ls_files_images = os.listdir(dir_img)
aug_multip = return_multiplier_based_on_augmnentations(augmentation, color_padding_rotation, rotation_not_90, blur_aug, degrading, bin_deg,
brightening, padding_white, adding_rgb_foreground, adding_rgb_background, binarization,
image_inversion, channels_shuffling, add_red_textlines, white_noise_strap, textline_skewing, textline_skewing_bin, textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth, textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin, textline_down_in_depth, textline_down_in_depth_bin, pepper_bin_aug, pepper_aug, degrade_scales, number_of_backgrounds_per_image, thetha, thetha_padd, brightness, padd_colors, shuffle_indexes, pepper_indexes, skewing_amplitudes, blur_k, white_padds)
len_dataset = aug_multip*len(ls_files_images)
dataset = OCRDatasetYieldAugmentations(
dir_img=dir_img,
dir_img_bin=dir_img_bin,
dir_lab=dir_lab,
processor=processor,
max_target_length=max_len,
augmentation = augmentation,
binarization = binarization,
add_red_textlines = add_red_textlines,
white_noise_strap = white_noise_strap,
adding_rgb_foreground = adding_rgb_foreground,
adding_rgb_background = adding_rgb_background,
bin_deg = bin_deg,
blur_aug = blur_aug,
brightening = brightening,
padding_white = padding_white,
color_padding_rotation = color_padding_rotation,
rotation_not_90 = rotation_not_90,
degrading = degrading,
channels_shuffling = channels_shuffling,
textline_skewing = textline_skewing,
textline_skewing_bin = textline_skewing_bin,
textline_right_in_depth = textline_right_in_depth,
textline_left_in_depth = textline_left_in_depth,
textline_up_in_depth = textline_up_in_depth,
textline_down_in_depth = textline_down_in_depth,
textline_right_in_depth_bin = textline_right_in_depth_bin,
textline_left_in_depth_bin = textline_left_in_depth_bin,
textline_up_in_depth_bin = textline_up_in_depth_bin,
textline_down_in_depth_bin = textline_down_in_depth_bin,
pepper_aug = pepper_aug,
pepper_bin_aug = pepper_bin_aug,
list_all_possible_background_images=list_all_possible_background_images,
list_all_possible_foreground_rgbs=list_all_possible_foreground_rgbs,
blur_k = blur_k,
degrade_scales = degrade_scales,
white_padds = white_padds,
thetha_padd = thetha_padd,
thetha = thetha,
brightness = brightness,
padd_colors = padd_colors,
number_of_backgrounds_per_image = number_of_backgrounds_per_image,
shuffle_indexes = shuffle_indexes,
pepper_indexes = pepper_indexes,
skewing_amplitudes = skewing_amplitudes,
dir_rgb_backgrounds = dir_rgb_backgrounds,
dir_rgb_foregrounds = dir_rgb_foregrounds,
len_data=len_dataset,
)
# Create a DataLoader
data_loader = torch.utils.data.DataLoader(dataset, batch_size=1)
train_dataset = data_loader.dataset
if continue_training:
model = VisionEncoderDecoderModel.from_pretrained(dir_of_start_model)
else:
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size
# set beam search parameters
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = max_len
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4
training_args = Seq2SeqTrainingArguments(
predict_with_generate=True,
num_train_epochs=n_epochs,
learning_rate=learning_rate,
per_device_train_batch_size=n_batch,
fp16=True,
output_dir=dir_output,
logging_steps=2,
save_steps=save_interval,
)
cer_metric = evaluate.load("cer")
# instantiate trainer
trainer = Seq2SeqTrainer(
model=model,
tokenizer=processor.feature_extractor,
args=training_args,
train_dataset=train_dataset,
data_collator=default_data_collator,
)
trainer.train()
elif task=='reading_order':
if continue_training:
model = load_model(dir_of_start_model, compile=False)

View file

@ -14,6 +14,9 @@ import tensorflow as tf
from PIL import Image, ImageFile, ImageEnhance
import torch
from torch.utils.data import IterableDataset
ImageFile.LOAD_TRUNCATED_IMAGES = True
@ -78,6 +81,7 @@ def add_salt_and_pepper_noise(img, salt_prob, pepper_prob):
return noisy_image
def invert_image(img):
img_inv = 255 - img
return img_inv
@ -1242,3 +1246,411 @@ def preprocess_img_ocr(
for pepper_ind in pepper_indexes:
img_noisy = add_salt_and_pepper_noise(img_bin_corr, pepper_ind, pepper_ind)
yield scale_image(img_noisy), lab
class OCRDatasetYieldAugmentations(IterableDataset):
def __init__(
self,
dir_img,
dir_img_bin,
dir_lab,
processor,
max_target_length=128,
augmentation = None,
binarization = None,
add_red_textlines = None,
white_noise_strap = None,
adding_rgb_foreground = None,
adding_rgb_background = None,
bin_deg = None,
blur_aug = None,
brightening = None,
padding_white = None,
color_padding_rotation = None,
rotation_not_90 = None,
degrading = None,
channels_shuffling = None,
textline_skewing = None,
textline_skewing_bin = None,
textline_right_in_depth = None,
textline_left_in_depth = None,
textline_up_in_depth = None,
textline_down_in_depth = None,
textline_right_in_depth_bin = None,
textline_left_in_depth_bin = None,
textline_up_in_depth_bin = None,
textline_down_in_depth_bin = None,
pepper_aug = None,
pepper_bin_aug = None,
list_all_possible_background_images=None,
list_all_possible_foreground_rgbs=None,
blur_k = None,
degrade_scales = None,
white_padds = None,
thetha_padd = None,
thetha = None,
brightness = None,
padd_colors = None,
number_of_backgrounds_per_image = None,
shuffle_indexes = None,
pepper_indexes = None,
skewing_amplitudes = None,
dir_rgb_backgrounds = None,
dir_rgb_foregrounds = None,
len_data=None,
):
"""
Args:
images_dir (str): Path to the directory containing images.
labels_dir (str): Path to the directory containing label text files.
tokenizer: Tokenizer for processing labels.
transform: Transformations applied after augmentation (e.g., ToTensor, normalization).
image_size (tuple): Size to resize images to.
max_seq_len (int): Maximum sequence length for tokenized labels.
scales (list or None): List of scale factors to apply.
"""
self.dir_img = dir_img
self.dir_img_bin = dir_img_bin
self.dir_lab = dir_lab
self.processor = processor
self.max_target_length = max_target_length
#self.scales = scales if scales else []
self.augmentation = augmentation
self.binarization = binarization
self.add_red_textlines = add_red_textlines
self.white_noise_strap = white_noise_strap
self.adding_rgb_foreground = adding_rgb_foreground
self.adding_rgb_background = adding_rgb_background
self.bin_deg = bin_deg
self.blur_aug = blur_aug
self.brightening = brightening
self.padding_white = padding_white
self.color_padding_rotation = color_padding_rotation
self.rotation_not_90 = rotation_not_90
self.degrading = degrading
self.channels_shuffling = channels_shuffling
self.textline_skewing = textline_skewing
self.textline_skewing_bin = textline_skewing_bin
self.textline_right_in_depth = textline_right_in_depth
self.textline_left_in_depth = textline_left_in_depth
self.textline_up_in_depth = textline_up_in_depth
self.textline_down_in_depth = textline_down_in_depth
self.textline_right_in_depth_bin = textline_right_in_depth_bin
self.textline_left_in_depth_bin = textline_left_in_depth_bin
self.textline_up_in_depth_bin = textline_up_in_depth_bin
self.textline_down_in_depth_bin = textline_down_in_depth_bin
self.pepper_aug = pepper_aug
self.pepper_bin_aug = pepper_bin_aug
self.list_all_possible_background_images=list_all_possible_background_images
self.list_all_possible_foreground_rgbs=list_all_possible_foreground_rgbs
self.blur_k = blur_k
self.degrade_scales = degrade_scales
self.white_padds = white_padds
self.thetha_padd = thetha_padd
self.thetha = thetha
self.brightness = brightness
self.padd_colors = padd_colors
self.number_of_backgrounds_per_image = number_of_backgrounds_per_image
self.shuffle_indexes = shuffle_indexes
self.pepper_indexes = pepper_indexes
self.skewing_amplitudes = skewing_amplitudes
self.dir_rgb_backgrounds = dir_rgb_backgrounds
self.dir_rgb_foregrounds = dir_rgb_foregrounds
self.image_files = os.listdir(dir_img)#sorted([f for f in os.listdir(images_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
self.len_data = len_data
#assert len(self.image_files) == len(self.label_files), "Number of images and labels must match!"
def __len__(self):
return self.len_data
def __iter__(self):
for img_file in self.image_files:
# Load image
f_name = img_file.split('.')[0]
txt_inp = open(os.path.join(self.dir_lab, f_name+'.txt'),'r').read().split('\n')[0]
img = cv2.imread(os.path.join(self.dir_img, img_file))
img = img.astype(np.uint8)
if self.dir_img_bin:
img_bin_corr = cv2.imread(os.path.join(self.dir_img_bin, f_name+'.png') )
img_bin_corr = img_bin_corr.astype(np.uint8)
else:
img_bin_corr = None
labels = self.processor.tokenizer(txt_inp,
padding="max_length",
max_length=self.max_target_length).input_ids
labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]
if self.augmentation:
pixel_values = self.processor(Image.fromarray(img), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.color_padding_rotation:
for index, thetha_ind in enumerate(self.thetha_padd):
for padd_col in self.padd_colors:
img_out = rotation_not_90_func_single_image(do_padding_for_ocr(img, 1.2, padd_col), thetha_ind)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.rotation_not_90:
for index, thetha_ind in enumerate(self.thetha):
img_out = rotation_not_90_func_single_image(img, thetha_ind)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.blur_aug:
for index, blur_type in enumerate(self.blur_k):
img_out = bluring(img, blur_type)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.degrading:
for index, deg_scale_ind in enumerate(self.degrade_scales):
try:
img_out = do_degrading(img, deg_scale_ind)
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.bin_deg:
for index, deg_scale_ind in enumerate(self.degrade_scales):
try:
img_out = self.do_degrading(img_bin_corr, deg_scale_ind)
except:
img_out = np.copy(img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.brightening:
for index, bright_scale_ind in enumerate(self.brightness):
try:
img_out = do_brightening(dir_img, bright_scale_ind)
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.padding_white:
for index, padding_size in enumerate(self.white_padds):
for padd_col in self.padd_colors:
img_out = do_padding_for_ocr(img, padding_size, padd_col)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.adding_rgb_foreground:
for i_n in range(self.number_of_backgrounds_per_image):
background_image_chosen_name = random.choice(self.list_all_possible_background_images)
foreground_rgb_chosen_name = random.choice(self.list_all_possible_foreground_rgbs)
img_rgb_background_chosen = cv2.imread(self.dir_rgb_backgrounds + '/' + background_image_chosen_name)
foreground_rgb_chosen = np.load(self.dir_rgb_foregrounds + '/' + foreground_rgb_chosen_name)
img_out = return_binary_image_with_given_rgb_background_and_given_foreground_rgb(img_bin_corr, img_rgb_background_chosen, foreground_rgb_chosen)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.adding_rgb_background:
for i_n in range(self.number_of_backgrounds_per_image):
background_image_chosen_name = random.choice(self.list_all_possible_background_images)
img_rgb_background_chosen = cv2.imread(self.dir_rgb_backgrounds + '/' + background_image_chosen_name)
img_out = return_binary_image_with_given_rgb_background(img_bin_corr, img_rgb_background_chosen)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.binarization:
pixel_values = self.processor(Image.fromarray(img_bin_corr), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.channels_shuffling:
for shuffle_index in self.shuffle_indexes:
img_out = return_shuffled_channels(img, shuffle_index)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.add_red_textlines:
img_out = return_image_with_red_elements(img, img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.white_noise_strap:
img_out = return_image_with_strapped_white_noises(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_skewing:
for index, des_scale_ind in enumerate(self.skewing_amplitudes):
try:
img_out = do_deskewing(img, des_scale_ind)
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_skewing_bin:
for index, des_scale_ind in enumerate(self.skewing_amplitudes):
try:
img_out = do_deskewing(img_bin_corr, des_scale_ind)
except:
img_out = np.copy(img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_left_in_depth:
try:
img_out = do_direction_in_depth(img, 'left')
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_left_in_depth_bin:
try:
img_out = do_direction_in_depth(img_bin_corr, 'left')
except:
img_out = np.copy(img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_right_in_depth:
try:
img_out = do_direction_in_depth(img, 'right')
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_right_in_depth_bin:
try:
img_out = do_direction_in_depth(img_bin_corr, 'right')
except:
img_out = np.copy(img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_up_in_depth:
try:
img_out = do_direction_in_depth(img, 'up')
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_up_in_depth_bin:
try:
img_out = do_direction_in_depth(img_bin_corr, 'up')
except:
img_out = np.copy(img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_down_in_depth:
try:
img_out = do_direction_in_depth(img, 'down')
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_down_in_depth_bin:
try:
img_out = do_direction_in_depth(img_bin_corr, 'down')
except:
img_out = np.copy(img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.pepper_bin_aug:
for index, pepper_ind in enumerate(self.pepper_indexes):
img_out = add_salt_and_pepper_noise(img_bin_corr, pepper_ind, pepper_ind)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.pepper_aug:
for index, pepper_ind in enumerate(self.pepper_indexes):
img_out = add_salt_and_pepper_noise(img, pepper_ind, pepper_ind)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
else:
pixel_values = self.processor(Image.fromarray(img), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding

View file

@ -16,28 +16,53 @@ from ..patch_encoder import (
PatchEncoder,
Patches,
)
from PIL import Image
import torch
from transformers import VisionEncoderDecoderModel
def run_ensembling(model_dirs, out_dir):
all_weights = []
def run_ensembling(dir_models, out, framework):
ls_models = os.listdir(dir_models)
if framework=="torch":
models = []
sd_models = []
for model_dir in model_dirs:
assert os.path.isdir(model_dir), model_dir
model = load_model(model_dir, compile=False,
custom_objects=dict(PatchEncoder=PatchEncoder,
Patches=Patches))
all_weights.append(model.get_weights())
for model_name in ls_models:
model = VisionEncoderDecoderModel.from_pretrained(os.path.join(dir_models,model_name))
models.append(model)
sd_models.append(model.state_dict())
for key in sd_models[0]:
sd_models[0][key] = sum(sd[key] for sd in sd_models) / len(sd_models)
new_weights = []
for layer_weights in zip(*all_weights):
layer_weights = np.array([np.array(weights).mean(axis=0)
for weights in zip(*layer_weights)])
new_weights.append(layer_weights)
model.load_state_dict(sd_models[0])
os.system("mkdir "+out)
torch.save(model.state_dict(), os.path.join(out, "pytorch_model.bin"))
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config_eynollah.json ")+out)
#model = tf.keras.models.clone_model(model)
model.set_weights(new_weights)
else:
weights=[]
model.save(out_dir)
os.system('cp ' + os.path.join(model_dirs[0], "config.json ") + out_dir + "/")
for model_name in ls_models:
model = load_model(os.path.join(dir_models,model_name) , compile=False, custom_objects={'PatchEncoder':PatchEncoder, 'Patches': Patches})
weights.append(model.get_weights())
new_weights = list()
for weights_list_tuple in zip(*weights):
new_weights.append(
[np.array(weights_).mean(axis=0)\
for weights_ in zip(*weights_list_tuple)])
new_weights = [np.array(x) for x in new_weights]
model.set_weights(new_weights)
model.save(out)
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config_eynollah.json ")+out)
try:
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "characters_org.txt ")+out)
except:
pass
@click.command()
@click.option(
@ -55,12 +80,17 @@ def run_ensembling(model_dirs, out_dir):
required=True,
type=click.Path(exists=False, file_okay=False),
)
def ensemble_cli(in_, out):
@click.option(
"--framework",
"-fw",
help="this parameter gets tensorflow or torch as model framework",
)
def ensemble_cli(in_, out, framework):
"""
mix multiple model weights
Load a sequence of models and mix them into a single ensemble model
by averaging their weights. Write the resulting model.
"""
run_ensembling(in_, out)
run_ensembling(in_, out, framework)

View file

@ -9,8 +9,8 @@ else:
import importlib.resources as importlib_resources
def get_font():
def get_font(font_size):
#font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
font = importlib_resources.files(__package__) / "../Charis-Regular.ttf"
font = importlib_resources.files(__package__) / "../Amiri-Regular.ttf"
with importlib_resources.as_file(font) as font:
return ImageFont.truetype(font=font, size=40)
return ImageFont.truetype(font=font, size=font_size)

View file

@ -1,17 +1,17 @@
{
"backbone_type" : "transformer",
"task": "cnn-rnn-ocr",
"task": "transformer-ocr",
"n_classes" : 2,
"max_len": 280,
"n_epochs" : 3,
"max_len": 192,
"n_epochs" : 1,
"input_height" : 32,
"input_width" : 512,
"weight_decay" : 1e-6,
"n_batch" : 4,
"n_batch" : 1,
"learning_rate": 1e-5,
"save_interval": 1500,
"patches" : false,
"pretraining" : true,
"pretraining" : false,
"augmentation" : true,
"flip_aug" : false,
"blur_aug" : true,
@ -77,7 +77,6 @@
"dir_output": "/home/vahid/extracted_lines/1919_bin/output",
"dir_rgb_backgrounds": "/home/vahid/Documents/1_2_test_eynollah/set_rgb_background",
"dir_rgb_foregrounds": "/home/vahid/Documents/1_2_test_eynollah/out_set_rgb_foreground",
"dir_img_bin": "/home/vahid/extracted_lines/1919_bin/images_bin",
"characters_txt_file":"/home/vahid/Downloads/models_eynollah/model_eynollah_ocr_cnnrnn_20250930/characters_org.txt"
"dir_img_bin": "/home/vahid/extracted_lines/1919_bin/images_bin"
}

View file

@ -8,3 +8,8 @@ tensorflow-addons # for connected_components, depublished and only compatible wi
tensorflow < 2.16 # for tensorflow-addons, so only needed in training
tf_data < 2.16 # for tensorflow-addons, so only needed in training
protobuf < 5 # for tensorflow-addons, so only needed in training
torch
evaluate
accelerate
jiwer
transformers <= 4.30.2