Konstantin Baierer 2026-02-19 12:59:25 +00:00 committed by GitHub
commit 38745111df
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 875 additions and 113 deletions


@@ -70,6 +70,7 @@ class Eynollah_ocr:
            self.model_zoo.get('ocr').to(self.device)
        else:
            self.model_zoo.load_model('ocr', '')
+           self.input_shape = self.model_zoo.get('ocr').input_shape[1:3]
        self.model_zoo.load_model('num_to_char')
        self.model_zoo.load_model('characters')
        self.end_character = len(self.model_zoo.get('characters', list)) + 2
@@ -657,7 +658,7 @@ class Eynollah_ocr:
            if out_image_with_text:
                image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white")
                draw = ImageDraw.Draw(image_text)
-               font = get_font()
+               font = get_font(font_size=40)
            for indexer_text, bb_ind in enumerate(total_bb_coordinates):
                x_bb = bb_ind[0]
@@ -823,8 +824,8 @@ class Eynollah_ocr:
                page_ns=page_ns,
                img_bin=img_bin,
-               image_width=512,
-               image_height=32,
+               image_width=self.input_shape[1],
+               image_height=self.input_shape[0],
            )
            self.write_ocr(
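For orientation: a Keras model's input_shape is (batch, height, width, channels), so the [1:3] slice above yields (height, width) and replaces the previously hardcoded 512x32 line size with whatever the loaded OCR model actually expects. A minimal sketch of the same idea outside eynollah (model path and variable names are illustrative, not from this commit):

import tensorflow as tf

# Hypothetical stand-in for model_zoo.get('ocr'); any Keras CNN-RNN OCR model behaves the same.
ocr_model = tf.keras.models.load_model("path/to/cnn_rnn_ocr_model", compile=False)  # placeholder path

input_height, input_width = ocr_model.input_shape[1:3]  # input_shape == (None, H, W, C)
print(input_height, input_width)  # e.g. 32, 512 for the old hardcoded defaults

# Text line crops can then be resized to the size the model was trained on:
# cv2.resize(line_img, (input_width, input_height))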


@@ -10,6 +10,7 @@ from .inference import main as inference_cli
 from .train import ex
 from .extract_line_gt import linegt_cli
 from .weights_ensembling import main as ensemble_cli
+from .generate_or_update_cnn_rnn_ocr_character_list import main as update_ocr_characters_cli

 @click.command(context_settings=dict(
     ignore_unknown_options=True,
@@ -28,3 +29,4 @@ main.add_command(inference_cli, 'inference')
 main.add_command(train_cli, 'train')
 main.add_command(linegt_cli, 'export_textline_images_and_text')
 main.add_command(ensemble_cli, 'ensembling')
+main.add_command(update_ocr_characters_cli, 'generate_or_update_cnn_rnn_ocr_character_list')


@@ -50,6 +50,18 @@ from ..utils import is_image_filename
     is_flag=True,
     help="if this parameter set to true, cropped textline images will not be masked with textline contour.",
 )
+@click.option(
+    "--exclude_vertical_lines",
+    "-exv",
+    is_flag=True,
+    help="if this parameter set to true, vertical textline images will be excluded.",
+)
+@click.option(
+    "--page_alto",
+    "-alto",
+    is_flag=True,
+    help="If this parameter is set to True, text line image cropping and text extraction are performed using PAGE/ALTO files. Otherwise, the default method for PAGE XML files is used.",
+)
 def linegt_cli(
     image,
     dir_in,
@@ -57,6 +69,8 @@ def linegt_cli(
     dir_out,
     pref_of_dataset,
     do_not_mask_with_textline_contour,
+    exclude_vertical_lines,
+    page_alto,
 ):
     assert bool(dir_in) ^ bool(image), "Set --dir-in or --image-filename, not both"
     if dir_in:
@@ -70,65 +84,149 @@ def linegt_cli(
    for dir_img in ls_imgs:
        file_name = Path(dir_img).stem
        dir_xml = os.path.join(dir_xmls, file_name + '.xml')
        img = cv2.imread(dir_img)
-       total_bb_coordinates = []
-       tree1 = ET.parse(dir_xml, parser=ET.XMLParser(encoding="utf-8"))
-       root1 = tree1.getroot()
-       alltags = [elem.tag for elem in root1.iter()]
-       name_space = alltags[0].split('}')[0]
-       name_space = name_space.split('{')[1]
-       region_tags = np.unique([x for x in alltags if x.endswith('TextRegion')])
-       cropped_lines_region_indexer = []
-       indexer_text_region = 0
-       indexer_textlines = 0
-       # FIXME: non recursive, use OCR-D PAGE generateDS API. Or use an existing tool for this purpose altogether
-       for nn in root1.iter(region_tags):
-           for child_textregion in nn:
-               if child_textregion.tag.endswith("TextLine"):
-                   for child_textlines in child_textregion:
-                       if child_textlines.tag.endswith("Coords"):
-                           cropped_lines_region_indexer.append(indexer_text_region)
-                           p_h = child_textlines.attrib['points'].split(' ')
-                           textline_coords = np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h])
-                           x, y, w, h = cv2.boundingRect(textline_coords)
-                           total_bb_coordinates.append([x, y, w, h])
-                           img_poly_on_img = np.copy(img)
-                           mask_poly = np.zeros(img.shape)
-                           mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1))
-                           mask_poly = mask_poly[y : y + h, x : x + w, :]
-                           img_crop = img_poly_on_img[y : y + h, x : x + w, :]
-                           if not do_not_mask_with_textline_contour:
-                               img_crop[mask_poly == 0] = 255
-                           if img_crop.shape[0] == 0 or img_crop.shape[1] == 0:
-                               continue
-                       if child_textlines.tag.endswith("TextEquiv"):
-                           for cheild_text in child_textlines:
-                               if cheild_text.tag.endswith("Unicode"):
-                                   textline_text = cheild_text.text
-                                   if textline_text:
-                                       base_name = os.path.join(
-                                           dir_out, file_name + '_line_' + str(indexer_textlines)
-                                       )
-                                       if pref_of_dataset:
-                                           base_name += '_' + pref_of_dataset
-                                       if not do_not_mask_with_textline_contour:
-                                           base_name += '_masked'
-                                       with open(base_name + '.txt', 'w') as text_file:
-                                           text_file.write(textline_text)
-                                       cv2.imwrite(base_name + '.png', img_crop)
-                                       indexer_textlines += 1
+       if page_alto:
+           h, w = img.shape[:2]
+           tree = ET.parse(dir_xml)
+           root = tree.getroot()
+           NS = {"alto": "http://www.loc.gov/standards/alto/ns-v4#"}
+           results = []
+           indexer_textlines = 0
+           for line in root.findall(".//alto:TextLine", NS):
+               string_el = line.find("alto:String", NS)
+               textline_text = string_el.attrib["CONTENT"] if string_el is not None else None
+               polygon_el = line.find("alto:Shape/alto:Polygon", NS)
+               if polygon_el is None:
+                   continue
+               points = polygon_el.attrib["POINTS"].split()
+               coords = [
+                   (int(points[i]), int(points[i + 1]))
+                   for i in range(0, len(points), 2)
+               ]
+               coords = np.array(coords, dtype=np.int32)
+               x, y, w, h = cv2.boundingRect(coords)
+               if exclude_vertical_lines and h > 1.4 * w:
+                   img_crop = None
+                   continue
+               img_poly_on_img = np.copy(img)
+               mask_poly = np.zeros(img.shape)
+               mask_poly = cv2.fillPoly(mask_poly, pts=[coords], color=(1, 1, 1))
+               mask_poly = mask_poly[y : y + h, x : x + w, :]
+               img_crop = img_poly_on_img[y : y + h, x : x + w, :]
+               if not do_not_mask_with_textline_contour:
+                   img_crop[mask_poly == 0] = 255
+               if img_crop.shape[0] == 0 or img_crop.shape[1] == 0:
+                   img_crop = None
+                   continue
+               if textline_text and img_crop is not None:
+                   base_name = os.path.join(
+                       dir_out, file_name + '_line_' + str(indexer_textlines)
+                   )
+                   if pref_of_dataset:
+                       base_name += '_' + pref_of_dataset
+                   if not do_not_mask_with_textline_contour:
+                       base_name += '_masked'
+                   with open(base_name + '.txt', 'w') as text_file:
+                       text_file.write(textline_text)
+                   cv2.imwrite(base_name + '.png', img_crop)
+                   indexer_textlines += 1
+       else:
+           total_bb_coordinates = []
+           tree = ET.parse(dir_xml, parser=ET.XMLParser(encoding="utf-8"))
+           root = tree.getroot()
+           alltags = [elem.tag for elem in root.iter()]
+           name_space = alltags[0].split('}')[0]
+           name_space = name_space.split('{')[1]
+           region_tags = np.unique([x for x in alltags if x.endswith('TextRegion')])
+           cropped_lines_region_indexer = []
+           indexer_text_region = 0
+           indexer_textlines = 0
+           # FIXME: non recursive, use OCR-D PAGE generateDS API. Or use an existing tool for this purpose altogether
+           for nn in root.iter(region_tags):
+               for child_textregion in nn:
+                   if child_textregion.tag.endswith("TextLine"):
+                       for child_textlines in child_textregion:
+                           if child_textlines.tag.endswith("Coords"):
+                               cropped_lines_region_indexer.append(indexer_text_region)
+                               p_h = child_textlines.attrib['points'].split(' ')
+                               textline_coords = np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h])
+                               x, y, w, h = cv2.boundingRect(textline_coords)
+                               if exclude_vertical_lines and h > 1.4 * w:
+                                   img_crop = None
+                                   continue
+                               total_bb_coordinates.append([x, y, w, h])
+                               img_poly_on_img = np.copy(img)
+                               mask_poly = np.zeros(img.shape)
+                               mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1))
+                               mask_poly = mask_poly[y : y + h, x : x + w, :]
+                               img_crop = img_poly_on_img[y : y + h, x : x + w, :]
+                               if not do_not_mask_with_textline_contour:
+                                   img_crop[mask_poly == 0] = 255
+                               if img_crop.shape[0] == 0 or img_crop.shape[1] == 0:
+                                   img_crop = None
+                                   continue
+                           if child_textlines.tag.endswith("TextEquiv"):
+                               for cheild_text in child_textlines:
+                                   if cheild_text.tag.endswith("Unicode"):
+                                       textline_text = cheild_text.text
+                                       if textline_text and img_crop is not None:
+                                           base_name = os.path.join(
+                                               dir_out, file_name + '_line_' + str(indexer_textlines)
+                                           )
+                                           if pref_of_dataset:
+                                               base_name += '_' + pref_of_dataset
+                                           if not do_not_mask_with_textline_contour:
+                                               base_name += '_masked'
+                                           with open(base_name + '.txt', 'w') as text_file:
+                                               text_file.write(textline_text)
+                                           cv2.imwrite(base_name + '.png', img_crop)
+                                           indexer_textlines += 1
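The new --page_alto branch assumes ALTO v4 files in which each TextLine carries a String element with the transcription in its CONTENT attribute and a Shape/Polygon with space-separated POINTS; the bounding rectangle of that polygon also drives the --exclude_vertical_lines heuristic (skip lines whose height exceeds 1.4 times their width). A self-contained sketch of that parsing, using an inline ALTO fragment purely for illustration (real files come from --dir_xmls):

import xml.etree.ElementTree as ET
import numpy as np
import cv2

# Illustrative ALTO 4 fragment with a single text line.
ALTO = """<alto xmlns="http://www.loc.gov/standards/alto/ns-v4#">
  <Layout><Page><PrintSpace><TextBlock>
    <TextLine>
      <Shape><Polygon POINTS="10 10 200 10 200 40 10 40"/></Shape>
      <String CONTENT="Hello world"/>
    </TextLine>
  </TextBlock></PrintSpace></Page></Layout>
</alto>"""

NS = {"alto": "http://www.loc.gov/standards/alto/ns-v4#"}
root = ET.fromstring(ALTO)
for line in root.findall(".//alto:TextLine", NS):
    string_el = line.find("alto:String", NS)
    polygon_el = line.find("alto:Shape/alto:Polygon", NS)
    if string_el is None or polygon_el is None:
        continue
    pts = polygon_el.attrib["POINTS"].split()
    coords = np.array([(int(pts[i]), int(pts[i + 1])) for i in range(0, len(pts), 2)], dtype=np.int32)
    x, y, w, h = cv2.boundingRect(coords)  # crop box for the line image
    print(string_el.attrib["CONTENT"], (x, y, w, h))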


@@ -6,6 +6,7 @@ from pathlib import Path
 from PIL import Image, ImageDraw, ImageFont
 import cv2
 import numpy as np
+from eynollah.utils.font import get_font

 from eynollah.training.gt_gen_utils import (
     filter_contours_area_of_image,
@@ -477,7 +478,7 @@ def visualize_layout_segmentation(xml_file, dir_xml, dir_out, dir_imgs):
        co_text, co_graphic, co_sep, co_img, co_table, co_map, co_noise, y_len, x_len = get_layout_contours_for_visualization(xml_file)
-       added_image = visualize_image_from_contours_layout(co_text['paragraph'], co_text['header']+co_text['heading'], co_text['drop-capital'], co_sep, co_img, co_text['marginalia'], co_table, img)
+       added_image = visualize_image_from_contours_layout(co_text['paragraph'], co_text['header']+co_text['heading'], co_text['drop-capital'], co_sep, co_img, co_text['marginalia'], co_table, co_map, img)

        cv2.imwrite(os.path.join(dir_out, f_name+'.png'), added_image)
@@ -514,8 +515,8 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out):
    else:
        xml_files_ind = [xml_file]
-   font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
-   font = ImageFont.truetype(font_path, 40)
+   ###font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
+   font = get_font(font_size=40)#ImageFont.truetype(font_path, 40)

    for ind_xml in tqdm(xml_files_ind):
        indexer = 0
@@ -552,11 +553,11 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out):
            is_vertical = h > 2*w # Check orientation
-           font = fit_text_single_line(draw, ocr_texts[index], font_path, w, int(h*0.4) )
+           font = fit_text_single_line(draw, ocr_texts[index], w, int(h*0.4) )

            if is_vertical:
-               vertical_font = fit_text_single_line(draw, ocr_texts[index], font_path, h, int(w * 0.8))
+               vertical_font = fit_text_single_line(draw, ocr_texts[index], h, int(w * 0.8))
                text_img = Image.new("RGBA", (h, w), (255, 255, 255, 0)) # Note: dimensions are swapped
                text_draw = ImageDraw.Draw(text_img)


@@ -0,0 +1,59 @@
import os
import numpy as np
import json
import click
import logging


def run_character_list_update(dir_labels, out, current_character_list):
    ls_labels = os.listdir(dir_labels)
    ls_labels = [ind for ind in ls_labels if ind.endswith('.txt')]

    if current_character_list:
        with open(current_character_list, 'r') as f_name:
            characters = json.load(f_name)
        characters = set(characters)
    else:
        characters = set()

    for ind in ls_labels:
        label = open(os.path.join(dir_labels,ind),'r').read().split('\n')[0]
        for char in label:
            characters.add(char)

    characters = sorted(list(set(characters)))

    with open(out, 'w') as f_name:
        json.dump(characters, f_name)


@click.command()
@click.option(
    "--dir_labels",
    "-dl",
    help="directory of labels which are .txt files",
    type=click.Path(exists=True, file_okay=False),
    required=True,
)
@click.option(
    "--current_character_list",
    "-ccl",
    help="existing character list in a .txt file that needs to be updated with a set of labels",
    type=click.Path(exists=True, file_okay=True),
    required=False,
)
@click.option(
    "--out",
    "-o",
    help="An output .txt file where the generated or updated character list will be written",
    type=click.Path(exists=False, file_okay=True),
)
def main(dir_labels, out, current_character_list):
    run_character_list_update(dir_labels, out, current_character_list)
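The helper simply unions every character found in the .txt label files with an optional existing list and writes the sorted result as JSON. A small usage sketch with throwaway files (paths are placeholders, and run_character_list_update is assumed importable from the module above):

import json, os, tempfile

# Throwaway label directory with two ground-truth transcriptions.
tmp = tempfile.mkdtemp()
labels_dir = os.path.join(tmp, "labels")
os.makedirs(labels_dir)
with open(os.path.join(labels_dir, "line_0.txt"), "w") as f:
    f.write("Grüße\n")
with open(os.path.join(labels_dir, "line_1.txt"), "w") as f:
    f.write("1859\n")

out_file = os.path.join(tmp, "characters_org.txt")
run_character_list_update(labels_dir, out_file, None)  # no existing list to merge

with open(out_file) as f:
    print(json.load(f))  # ['1', '5', '8', '9', 'G', 'e', 'r', 'ß', 'ü']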


@@ -7,7 +7,7 @@ import cv2
 from shapely import geometry
 from pathlib import Path
 from PIL import ImageFont
+from eynollah.utils.font import get_font

 KERNEL = np.ones((5, 5), np.uint8)
@@ -350,11 +350,11 @@ def get_textline_contours_and_ocr_text(xml_file):
            ocr_textlines.append(ocr_text_in[0])
    return co_use_case, y_len, x_len, ocr_textlines

-def fit_text_single_line(draw, text, font_path, max_width, max_height):
+def fit_text_single_line(draw, text, max_width, max_height):
    initial_font_size = 50
    font_size = initial_font_size
    while font_size > 10: # Minimum font size
-       font = ImageFont.truetype(font_path, font_size)
+       font = get_font(font_size=font_size)# ImageFont.truetype(font_path, font_size)
        text_bbox = draw.textbbox((0, 0), text, font=font) # Get text bounding box
        text_width = text_bbox[2] - text_bbox[0]
        text_height = text_bbox[3] - text_bbox[1]
@@ -364,7 +364,7 @@ def fit_text_single_line(draw, text, font_path, max_width, max_height):
        font_size -= 2 # Reduce font size and retry

-   return ImageFont.truetype(font_path, 10) # Smallest font fallback
+   return get_font(font_size=10)#ImageFont.truetype(font_path, 10) # Smallest font fallback

def get_layout_contours_for_visualization(xml_file):
    tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8'))
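fit_text_single_line now obtains the packaged Charis font through get_font and just shrinks the size until the rendered text fits the target box. The same idea in isolation, shown with Pillow's bundled default font so the sketch has no font-file dependency (passing a size to ImageFont.load_default assumes Pillow >= 10.1):

from PIL import Image, ImageDraw, ImageFont

def fit_font_size(draw, text, max_width, max_height, start=50, minimum=10):
    # Shrink the font until the text's bounding box fits into max_width x max_height.
    size = start
    while size > minimum:
        font = ImageFont.load_default(size)  # stand-in for get_font(font_size=size)
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
        if right - left <= max_width and bottom - top <= max_height:
            return font
        size -= 2
    return ImageFont.load_default(minimum)  # smallest-size fallback

canvas = Image.new("RGB", (200, 40), "white")
draw = ImageDraw.Draw(canvas)
font = fit_font_size(draw, "Hello world", max_width=180, max_height=30)
draw.text((5, 5), "Hello world", font=font, fill="black")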


@@ -12,6 +12,7 @@ from keras.models import Model, load_model
 from keras import backend as K
 import click
 from tensorflow.python.keras import backend as tensorflow_backend
+from tensorflow.keras.layers import StringLookup
 import xml.etree.ElementTree as ET

 from .gt_gen_utils import (
@@ -169,6 +170,25 @@ class sbb_predict:
                self.model = tf.keras.models.Model(
                    self.model.get_layer(name = "image").input,
                    self.model.get_layer(name = "dense2").output)
+               assert isinstance(self.model, Model)
+
+       elif self.task == "transformer-ocr":
+           import torch
+           from transformers import VisionEncoderDecoderModel
+           from transformers import TrOCRProcessor
+
+           self.model = VisionEncoderDecoderModel.from_pretrained(self.model_dir)
+           self.processor = TrOCRProcessor.from_pretrained(self.model_dir)
+
+           if self.cpu:
+               self.device = torch.device('cpu')
+           else:
+               self.device = torch.device('cuda:0')
+           self.model.to(self.device)
+           assert isinstance(self.model, torch.nn.Module)
+
        else:
            config = tf.compat.v1.ConfigProto()
            config.gpu_options.allow_growth = True
@@ -176,15 +196,15 @@ class sbb_predict:
            session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
            tensorflow_backend.set_session(session)

            self.model = load_model(self.model_dir , compile=False,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
-       ##if self.weights_dir!=None:
-           ##self.model.load_weights(self.weights_dir)
-       if self.task != 'classification' and self.task != 'reading_order':
-           self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1]
-           self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2]
-           self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3]
-       assert isinstance(self.model, Model)
+           if self.task != 'classification' and self.task != 'reading_order':
+               self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1]
+               self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2]
+               self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3]
+           assert isinstance(self.model, Model)

    def visualize_model_output(self, prediction, img, task) -> Tuple[NDArray, NDArray]:
        if task == "binarization":
@@ -235,10 +255,9 @@ class sbb_predict:
        return added_image, layout_only

    def predict(self, image_dir):
-       assert isinstance(self.model, Model)
        if self.task == 'classification':
            classes_names = self.config_params_model['classification_classes_name']
-           img_1ch = img=cv2.imread(image_dir, 0)
+           img_1ch =cv2.imread(image_dir, 0)

            img_1ch = img_1ch / 255.0
            img_1ch = cv2.resize(img_1ch, (self.config_params_model['input_height'], self.config_params_model['input_width']), interpolation=cv2.INTER_NEAREST)
@@ -273,6 +292,15 @@ class sbb_predict:
            pred_texts = decode_batch_predictions(preds, num_to_char)
            pred_texts = pred_texts[0].replace("[UNK]", "")
            return pred_texts
+
+       elif self.task == "transformer-ocr":
+           from PIL import Image
+           image = Image.open(image_dir).convert("RGB")
+           pixel_values = self.processor(image, return_tensors="pt").pixel_values
+           generated_ids = self.model.generate(pixel_values.to(self.device))
+           return self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
        elif self.task == 'reading_order':
@@ -607,6 +635,8 @@ class sbb_predict:
            cv2.imwrite(self.save,res)
        elif self.task == "cnn-rnn-ocr":
            print(f"Detected text: {res}")
+       elif self.task == "transformer-ocr":
+           print(f"Detected text: {res}")
        else:
            img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
            if self.save:
@@ -710,10 +740,12 @@ class sbb_predict:
 )
 def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area):
    assert image or dir_in, "Either a single image -i or a dir_in -di is required"
-   with open(os.path.join(model,'config.json')) as f:
+   with open(os.path.join(model,'config_eynollah.json')) as f:
        config_params_model = json.load(f)
    task = config_params_model['task']
-   if task != 'classification' and task != 'reading_order' and task != "cnn-rnn-ocr":
+   if task != 'classification' and task != 'reading_order' and task != "cnn-rnn-ocr" and task != "transformer-ocr":
        if image and not save:
            print("Error: You used one of segmentation or binarization task with image input but not set -s, you need a filename to save visualized output with -s")
            sys.exit(1)
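The transformer-ocr branch above follows the standard Hugging Face TrOCR pattern: the processor turns the line image into pixel_values, generate() produces token ids, and batch_decode turns them back into text. A self-contained sketch using the public base checkpoint (in practice a fine-tuned model directory would be passed via -m, and the image path here is a placeholder):

import torch
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

model_dir = "microsoft/trocr-base-printed"  # stand-in for a fine-tuned model directory
processor = TrOCRProcessor.from_pretrained(model_dir)
model = VisionEncoderDecoderModel.from_pretrained(model_dir)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

image = Image.open("line.png").convert("RGB")  # a cropped text line image
pixel_values = processor(image, return_tensors="pt").pixel_values
with torch.no_grad():
    generated_ids = model.generate(pixel_values.to(device))
text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(text)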


@@ -1,7 +1,8 @@
 import os
 import sys
 import json
+import numpy as np
+import cv2
 import click

 from eynollah.training.metrics import (
@@ -27,7 +28,8 @@ from eynollah.training.utils import (
     generate_data_from_folder_training,
     get_one_hot,
     provide_patches,
-    return_number_of_total_training_data
+    return_number_of_total_training_data,
+    OCRDatasetYieldAugmentations
 )

 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
@@ -41,11 +43,16 @@ from sklearn.metrics import f1_score
 from tensorflow.keras.callbacks import Callback
 from tensorflow.keras.layers import StringLookup
-import numpy as np
-import cv2
+
+import torch
+from transformers import TrOCRProcessor
+import evaluate
+from transformers import default_data_collator
+from transformers import VisionEncoderDecoderModel
+from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

 class SaveWeightsAfterSteps(Callback):
-    def __init__(self, save_interval, save_path, _config):
+    def __init__(self, save_interval, save_path, _config, characters_cnnrnn_ocr=None):
        super(SaveWeightsAfterSteps, self).__init__()
        self.save_interval = save_interval
        self.save_path = save_path
@@ -61,7 +68,10 @@ class SaveWeightsAfterSteps(Callback):
        self.model.save(save_file)

-       with open(os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"config.json"), "w") as fp:
+       if characters_cnnrnn_ocr:
+           os.system("cp "+characters_cnnrnn_ocr+" "+os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"characters_org.txt"))
+       with open(os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"config_eynollah.json"), "w") as fp:
            json.dump(self._config, fp) # encode dict into JSON

        print(f"saved model as steps {self.step_count} to {save_file}")
@@ -477,7 +487,7 @@ def run(
                model.save(os.path.join(dir_output,'model_'+str(i)))

-               with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
+               with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config_eynollah.json"), "w") as fp:
                    json.dump(_config, fp) # encode dict into JSON

        #os.system('rm -rf '+dir_train_flowing)
@@ -537,7 +547,7 @@ def run(
        opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)#1e-4)#(lr_schedule)
        model.compile(optimizer=opt)

-       save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config) if save_interval else None
+       save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config, characters_cnnrnn_ocr=characters_txt_file) if save_interval else None

        for i in tqdm(range(index_start, n_epochs + index_start)):
            if save_interval:
@@ -556,9 +566,125 @@ def run(
            if i >=0:
                model.save( os.path.join(dir_output,'model_'+str(i) ))
-               with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
+               os.system("cp "+characters_txt_file+" "+os.path.join(os.path.join(dir_output,'model_'+str(i)),"characters_org.txt"))
+               with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config_eynollah.json"), "w") as fp:
                    json.dump(_config, fp) # encode dict into JSON
elif task=="transformer-ocr":
dir_img, dir_lab = get_dirs_or_files(dir_train)
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
ls_files_images = os.listdir(dir_img)
aug_multip = return_multiplier_based_on_augmnentations(augmentation, color_padding_rotation, rotation_not_90, blur_aug, degrading, bin_deg,
brightening, padding_white, adding_rgb_foreground, adding_rgb_background, binarization,
image_inversion, channels_shuffling, add_red_textlines, white_noise_strap, textline_skewing, textline_skewing_bin, textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth, textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin, textline_down_in_depth, textline_down_in_depth_bin, pepper_bin_aug, pepper_aug, degrade_scales, number_of_backgrounds_per_image, thetha, thetha_padd, brightness, padd_colors, shuffle_indexes, pepper_indexes, skewing_amplitudes, blur_k, white_padds)
len_dataset = aug_multip*len(ls_files_images)
dataset = OCRDatasetYieldAugmentations(
dir_img=dir_img,
dir_img_bin=dir_img_bin,
dir_lab=dir_lab,
processor=processor,
max_target_length=max_len,
augmentation = augmentation,
binarization = binarization,
add_red_textlines = add_red_textlines,
white_noise_strap = white_noise_strap,
adding_rgb_foreground = adding_rgb_foreground,
adding_rgb_background = adding_rgb_background,
bin_deg = bin_deg,
blur_aug = blur_aug,
brightening = brightening,
padding_white = padding_white,
color_padding_rotation = color_padding_rotation,
rotation_not_90 = rotation_not_90,
degrading = degrading,
channels_shuffling = channels_shuffling,
textline_skewing = textline_skewing,
textline_skewing_bin = textline_skewing_bin,
textline_right_in_depth = textline_right_in_depth,
textline_left_in_depth = textline_left_in_depth,
textline_up_in_depth = textline_up_in_depth,
textline_down_in_depth = textline_down_in_depth,
textline_right_in_depth_bin = textline_right_in_depth_bin,
textline_left_in_depth_bin = textline_left_in_depth_bin,
textline_up_in_depth_bin = textline_up_in_depth_bin,
textline_down_in_depth_bin = textline_down_in_depth_bin,
pepper_aug = pepper_aug,
pepper_bin_aug = pepper_bin_aug,
list_all_possible_background_images=list_all_possible_background_images,
list_all_possible_foreground_rgbs=list_all_possible_foreground_rgbs,
blur_k = blur_k,
degrade_scales = degrade_scales,
white_padds = white_padds,
thetha_padd = thetha_padd,
thetha = thetha,
brightness = brightness,
padd_colors = padd_colors,
number_of_backgrounds_per_image = number_of_backgrounds_per_image,
shuffle_indexes = shuffle_indexes,
pepper_indexes = pepper_indexes,
skewing_amplitudes = skewing_amplitudes,
dir_rgb_backgrounds = dir_rgb_backgrounds,
dir_rgb_foregrounds = dir_rgb_foregrounds,
len_data=len_dataset,
)
# Create a DataLoader
data_loader = torch.utils.data.DataLoader(dataset, batch_size=1)
train_dataset = data_loader.dataset
if continue_training:
model = VisionEncoderDecoderModel.from_pretrained(dir_of_start_model)
else:
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size
# set beam search parameters
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = max_len
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4
training_args = Seq2SeqTrainingArguments(
predict_with_generate=True,
num_train_epochs=n_epochs,
learning_rate=learning_rate,
per_device_train_batch_size=n_batch,
fp16=True,
output_dir=dir_output,
logging_steps=2,
save_steps=save_interval,
)
cer_metric = evaluate.load("cer")
# instantiate trainer
trainer = Seq2SeqTrainer(
model=model,
tokenizer=processor.feature_extractor,
args=training_args,
train_dataset=train_dataset,
data_collator=default_data_collator,
)
trainer.train()
    elif task=='classification':
        configuration()
        model = resnet50_classifier(n_classes, input_height, input_width, weight_decay, pretraining)
@@ -609,10 +735,10 @@ def run(
                model_weight_averaged.set_weights(new_weights)
                model_weight_averaged.save(os.path.join(dir_output,'model_ens_avg'))

-               with open(os.path.join( os.path.join(dir_output,'model_ens_avg'), "config.json"), "w") as fp:
+               with open(os.path.join( os.path.join(dir_output,'model_ens_avg'), "config_eynollah.json"), "w") as fp:
                    json.dump(_config, fp) # encode dict into JSON

-           with open(os.path.join( os.path.join(dir_output,'model_best'), "config.json"), "w") as fp:
+           with open(os.path.join( os.path.join(dir_output,'model_best'), "config_eynollah.json"), "w") as fp:
                json.dump(_config, fp) # encode dict into JSON

    elif task=='reading_order':
@@ -645,7 +771,7 @@ def run(
            history = model.fit(generate_arrays_from_folder_reading_order(dir_flow_train_labels, dir_flow_train_imgs, n_batch, input_height, input_width, n_classes, thetha, augmentation), steps_per_epoch=num_rows / n_batch, verbose=1)
            model.save( os.path.join(dir_output,'model_'+str(i+indexer_start) ))

-           with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
+           with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config_eynollah.json"), "w") as fp:
                json.dump(_config, fp) # encode dict into JSON
            '''
            if f1score>f1score_tot[0]:
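The transformer-ocr training branch loads the cer metric from evaluate, although no compute_metrics hook is passed to the Seq2SeqTrainer call shown above. For reference, a character error rate with that library is computed like this (strings are illustrative; jiwer from the requirements is the backend):

import evaluate

cer_metric = evaluate.load("cer")
score = cer_metric.compute(
    predictions=["Hel1o world"],   # model output with one wrong character
    references=["Hello world"],    # ground truth
)
print(score)  # ~0.09: one substitution over eleven reference characters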


@@ -13,8 +13,12 @@ import tensorflow as tf
 from tensorflow.keras.utils import to_categorical
 from PIL import Image, ImageFile, ImageEnhance

+import torch
+from torch.utils.data import IterableDataset
+
 ImageFile.LOAD_TRUNCATED_IMAGES = True

 def vectorize_label(label, char_to_num, padding_token, max_len):
     label = char_to_num(tf.strings.unicode_split(label, input_encoding="UTF-8"))
     length = tf.shape(label)[0]
@@ -76,6 +80,7 @@ def add_salt_and_pepper_noise(img, salt_prob, pepper_prob):
     return noisy_image

 def invert_image(img):
     img_inv = 255 - img
     return img_inv
@@ -1668,3 +1673,411 @@ def return_multiplier_based_on_augmnentations(
        aug_multip += len(pepper_indexes)

    return aug_multip
class OCRDatasetYieldAugmentations(IterableDataset):
def __init__(
self,
dir_img,
dir_img_bin,
dir_lab,
processor,
max_target_length=128,
augmentation = None,
binarization = None,
add_red_textlines = None,
white_noise_strap = None,
adding_rgb_foreground = None,
adding_rgb_background = None,
bin_deg = None,
blur_aug = None,
brightening = None,
padding_white = None,
color_padding_rotation = None,
rotation_not_90 = None,
degrading = None,
channels_shuffling = None,
textline_skewing = None,
textline_skewing_bin = None,
textline_right_in_depth = None,
textline_left_in_depth = None,
textline_up_in_depth = None,
textline_down_in_depth = None,
textline_right_in_depth_bin = None,
textline_left_in_depth_bin = None,
textline_up_in_depth_bin = None,
textline_down_in_depth_bin = None,
pepper_aug = None,
pepper_bin_aug = None,
list_all_possible_background_images=None,
list_all_possible_foreground_rgbs=None,
blur_k = None,
degrade_scales = None,
white_padds = None,
thetha_padd = None,
thetha = None,
brightness = None,
padd_colors = None,
number_of_backgrounds_per_image = None,
shuffle_indexes = None,
pepper_indexes = None,
skewing_amplitudes = None,
dir_rgb_backgrounds = None,
dir_rgb_foregrounds = None,
len_data=None,
):
"""
Args:
images_dir (str): Path to the directory containing images.
labels_dir (str): Path to the directory containing label text files.
tokenizer: Tokenizer for processing labels.
transform: Transformations applied after augmentation (e.g., ToTensor, normalization).
image_size (tuple): Size to resize images to.
max_seq_len (int): Maximum sequence length for tokenized labels.
scales (list or None): List of scale factors to apply.
"""
self.dir_img = dir_img
self.dir_img_bin = dir_img_bin
self.dir_lab = dir_lab
self.processor = processor
self.max_target_length = max_target_length
#self.scales = scales if scales else []
self.augmentation = augmentation
self.binarization = binarization
self.add_red_textlines = add_red_textlines
self.white_noise_strap = white_noise_strap
self.adding_rgb_foreground = adding_rgb_foreground
self.adding_rgb_background = adding_rgb_background
self.bin_deg = bin_deg
self.blur_aug = blur_aug
self.brightening = brightening
self.padding_white = padding_white
self.color_padding_rotation = color_padding_rotation
self.rotation_not_90 = rotation_not_90
self.degrading = degrading
self.channels_shuffling = channels_shuffling
self.textline_skewing = textline_skewing
self.textline_skewing_bin = textline_skewing_bin
self.textline_right_in_depth = textline_right_in_depth
self.textline_left_in_depth = textline_left_in_depth
self.textline_up_in_depth = textline_up_in_depth
self.textline_down_in_depth = textline_down_in_depth
self.textline_right_in_depth_bin = textline_right_in_depth_bin
self.textline_left_in_depth_bin = textline_left_in_depth_bin
self.textline_up_in_depth_bin = textline_up_in_depth_bin
self.textline_down_in_depth_bin = textline_down_in_depth_bin
self.pepper_aug = pepper_aug
self.pepper_bin_aug = pepper_bin_aug
self.list_all_possible_background_images=list_all_possible_background_images
self.list_all_possible_foreground_rgbs=list_all_possible_foreground_rgbs
self.blur_k = blur_k
self.degrade_scales = degrade_scales
self.white_padds = white_padds
self.thetha_padd = thetha_padd
self.thetha = thetha
self.brightness = brightness
self.padd_colors = padd_colors
self.number_of_backgrounds_per_image = number_of_backgrounds_per_image
self.shuffle_indexes = shuffle_indexes
self.pepper_indexes = pepper_indexes
self.skewing_amplitudes = skewing_amplitudes
self.dir_rgb_backgrounds = dir_rgb_backgrounds
self.dir_rgb_foregrounds = dir_rgb_foregrounds
self.image_files = os.listdir(dir_img)#sorted([f for f in os.listdir(images_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
self.len_data = len_data
#assert len(self.image_files) == len(self.label_files), "Number of images and labels must match!"
def __len__(self):
return self.len_data
def __iter__(self):
for img_file in self.image_files:
# Load image
f_name = img_file.split('.')[0]
txt_inp = open(os.path.join(self.dir_lab, f_name+'.txt'),'r').read().split('\n')[0]
img = cv2.imread(os.path.join(self.dir_img, img_file))
img = img.astype(np.uint8)
if self.dir_img_bin:
img_bin_corr = cv2.imread(os.path.join(self.dir_img_bin, f_name+'.png') )
img_bin_corr = img_bin_corr.astype(np.uint8)
else:
img_bin_corr = None
labels = self.processor.tokenizer(txt_inp,
padding="max_length",
max_length=self.max_target_length).input_ids
labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]
if self.augmentation:
pixel_values = self.processor(Image.fromarray(img), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.color_padding_rotation:
for index, thetha_ind in enumerate(self.thetha_padd):
for padd_col in self.padd_colors:
img_out = rotation_not_90_func_single_image(do_padding_for_ocr(img, 1.2, padd_col), thetha_ind)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.rotation_not_90:
for index, thetha_ind in enumerate(self.thetha):
img_out = rotation_not_90_func_single_image(img, thetha_ind)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.blur_aug:
for index, blur_type in enumerate(self.blur_k):
img_out = bluring(img, blur_type)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.degrading:
for index, deg_scale_ind in enumerate(self.degrade_scales):
try:
img_out = do_degrading(img, deg_scale_ind)
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.bin_deg:
for index, deg_scale_ind in enumerate(self.degrade_scales):
try:
img_out = self.do_degrading(img_bin_corr, deg_scale_ind)
except:
img_out = np.copy(img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.brightening:
for index, bright_scale_ind in enumerate(self.brightness):
try:
img_out = do_brightening(dir_img, bright_scale_ind)
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.padding_white:
for index, padding_size in enumerate(self.white_padds):
for padd_col in self.padd_colors:
img_out = do_padding_for_ocr(img, padding_size, padd_col)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.adding_rgb_foreground:
for i_n in range(self.number_of_backgrounds_per_image):
background_image_chosen_name = random.choice(self.list_all_possible_background_images)
foreground_rgb_chosen_name = random.choice(self.list_all_possible_foreground_rgbs)
img_rgb_background_chosen = cv2.imread(self.dir_rgb_backgrounds + '/' + background_image_chosen_name)
foreground_rgb_chosen = np.load(self.dir_rgb_foregrounds + '/' + foreground_rgb_chosen_name)
img_out = return_binary_image_with_given_rgb_background_and_given_foreground_rgb(img_bin_corr, img_rgb_background_chosen, foreground_rgb_chosen)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.adding_rgb_background:
for i_n in range(self.number_of_backgrounds_per_image):
background_image_chosen_name = random.choice(self.list_all_possible_background_images)
img_rgb_background_chosen = cv2.imread(self.dir_rgb_backgrounds + '/' + background_image_chosen_name)
img_out = return_binary_image_with_given_rgb_background(img_bin_corr, img_rgb_background_chosen)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.binarization:
pixel_values = self.processor(Image.fromarray(img_bin_corr), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.channels_shuffling:
for shuffle_index in self.shuffle_indexes:
img_out = return_shuffled_channels(img, shuffle_index)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.add_red_textlines:
img_out = return_image_with_red_elements(img, img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.white_noise_strap:
img_out = return_image_with_strapped_white_noises(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_skewing:
for index, des_scale_ind in enumerate(self.skewing_amplitudes):
try:
img_out = do_deskewing(img, des_scale_ind)
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_skewing_bin:
for index, des_scale_ind in enumerate(self.skewing_amplitudes):
try:
img_out = do_deskewing(img_bin_corr, des_scale_ind)
except:
img_out = np.copy(img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_left_in_depth:
try:
img_out = do_direction_in_depth(img, 'left')
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_left_in_depth_bin:
try:
img_out = do_direction_in_depth(img_bin_corr, 'left')
except:
img_out = np.copy(img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_right_in_depth:
try:
img_out = do_direction_in_depth(img, 'right')
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_right_in_depth_bin:
try:
img_out = do_direction_in_depth(img_bin_corr, 'right')
except:
img_out = np.copy(img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_up_in_depth:
try:
img_out = do_direction_in_depth(img, 'up')
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_up_in_depth_bin:
try:
img_out = do_direction_in_depth(img_bin_corr, 'up')
except:
img_out = np.copy(img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_down_in_depth:
try:
img_out = do_direction_in_depth(img, 'down')
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_down_in_depth_bin:
try:
img_out = do_direction_in_depth(img_bin_corr, 'down')
except:
img_out = np.copy(img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.pepper_bin_aug:
for index, pepper_ind in enumerate(self.pepper_indexes):
img_out = add_salt_and_pepper_noise(img_bin_corr, pepper_ind, pepper_ind)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.pepper_aug:
for index, pepper_ind in enumerate(self.pepper_indexes):
img_out = add_salt_and_pepper_noise(img, pepper_ind, pepper_ind)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
else:
pixel_values = self.processor(Image.fromarray(img), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
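In __iter__ above, every tokenized label position equal to the pad token is replaced with -100, the index that Hugging Face models (and torch.nn.CrossEntropyLoss) ignore when computing the loss, so padding never contributes to training. That step in isolation (text and max length are illustrative):

from transformers import TrOCRProcessor

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
text = "Grüße aus Berlin"  # illustrative ground-truth transcription

labels = processor.tokenizer(
    text, padding="max_length", max_length=32  # plays the role of max_target_length above
).input_ids
# padding positions must not contribute to the cross-entropy loss
labels = [tok if tok != processor.tokenizer.pad_token_id else -100 for tok in labels]
print(labels[:10])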


@@ -21,6 +21,11 @@ from tensorflow.keras.layers import *
 import click
 import logging

+from transformers import TrOCRProcessor
+from PIL import Image
+import torch
+from transformers import VisionEncoderDecoderModel
+

 class Patches(layers.Layer):
     def __init__(self, patch_size_x, patch_size_y):
@@ -92,30 +97,46 @@ def start_new_session():
    tensorflow_backend.set_session(session)
    return session

-def run_ensembling(dir_models, out):
+def run_ensembling(dir_models, out, framework):
    ls_models = os.listdir(dir_models)

-   weights=[]
-   for model_name in ls_models:
-       model = load_model(os.path.join(dir_models,model_name) , compile=False, custom_objects={'PatchEncoder':PatchEncoder, 'Patches': Patches})
-       weights.append(model.get_weights())
-
-   new_weights = list()
-
-   for weights_list_tuple in zip(*weights):
-       new_weights.append(
-           [np.array(weights_).mean(axis=0)\
-               for weights_ in zip(*weights_list_tuple)])
-
-   new_weights = [np.array(x) for x in new_weights]
-   model.set_weights(new_weights)
-   model.save(out)
-   os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config.json ")+out)
+   if framework=="torch":
+       models = []
+       sd_models = []
+
+       for model_name in ls_models:
+           model = VisionEncoderDecoderModel.from_pretrained(os.path.join(dir_models,model_name))
+           models.append(model)
+           sd_models.append(model.state_dict())
+
+       for key in sd_models[0]:
+           sd_models[0][key] = sum(sd[key] for sd in sd_models) / len(sd_models)
+
+       model.load_state_dict(sd_models[0])
+
+       os.system("mkdir "+out)
+       torch.save(model.state_dict(), os.path.join(out, "pytorch_model.bin"))
+       os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config_eynollah.json ")+out)
+   else:
+       weights=[]
+       for model_name in ls_models:
+           model = load_model(os.path.join(dir_models,model_name) , compile=False, custom_objects={'PatchEncoder':PatchEncoder, 'Patches': Patches})
+           weights.append(model.get_weights())
+
+       new_weights = list()
+
+       for weights_list_tuple in zip(*weights):
+           new_weights.append(
+               [np.array(weights_).mean(axis=0)\
+                   for weights_ in zip(*weights_list_tuple)])
+
+       new_weights = [np.array(x) for x in new_weights]
+       model.set_weights(new_weights)
+       model.save(out)
+       os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config_eynollah.json ")+out)
+       os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "characters_org.txt ")+out)


 @click.command()
 @click.option(
@@ -130,7 +151,12 @@ def run_ensembling(dir_models, out):
    help="output directory where ensembled model will be written.",
    type=click.Path(exists=False, file_okay=False),
 )
+@click.option(
+    "--framework",
+    "-fw",
+    help="this parameter gets tensorflow or torch as model framework",
+)
-def main(dir_models, out):
-    run_ensembling(dir_models, out)
+def main(dir_models, out, framework):
+    run_ensembling(dir_models, out, framework)
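In the torch branch, ensembling reduces to averaging the checkpoints' state_dicts key by key before saving a single pytorch_model.bin. The core of that operation on two toy modules (purely illustrative; real use loads VisionEncoderDecoderModel checkpoints from --dir_models):

import torch
import torch.nn as nn

# Two toy models standing in for checkpoints found in --dir_models.
models = [nn.Linear(4, 2) for _ in range(2)]
state_dicts = [m.state_dict() for m in models]

# Average every parameter tensor across the checkpoints.
avg = {key: sum(sd[key] for sd in state_dicts) / len(state_dicts) for key in state_dicts[0]}

ensembled = nn.Linear(4, 2)
ensembled.load_state_dict(avg)
torch.save(ensembled.state_dict(), "pytorch_model.bin")  # same file name the CLI writes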


@@ -9,8 +9,8 @@ else:
     import importlib.resources as importlib_resources

-def get_font():
+def get_font(font_size):
    #font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
    font = importlib_resources.files(__package__) / "../Charis-Regular.ttf"
    with importlib_resources.as_file(font) as font:
-       return ImageFont.truetype(font=font, size=40)
+       return ImageFont.truetype(font=font, size=font_size)


@@ -1,17 +1,17 @@
 {
     "backbone_type" : "transformer",
-    "task": "cnn-rnn-ocr",
+    "task": "transformer-ocr",
     "n_classes" : 2,
-    "max_len": 280,
-    "n_epochs" : 3,
+    "max_len": 192,
+    "n_epochs" : 1,
     "input_height" : 32,
     "input_width" : 512,
     "weight_decay" : 1e-6,
-    "n_batch" : 4,
+    "n_batch" : 1,
     "learning_rate": 1e-5,
     "save_interval": 1500,
     "patches" : false,
-    "pretraining" : true,
+    "pretraining" : false,
     "augmentation" : true,
     "flip_aug" : false,
     "blur_aug" : true,
@@ -77,7 +77,6 @@
     "dir_output": "/home/vahid/extracted_lines/1919_bin/output",
     "dir_rgb_backgrounds": "/home/vahid/Documents/1_2_test_eynollah/set_rgb_background",
     "dir_rgb_foregrounds": "/home/vahid/Documents/1_2_test_eynollah/out_set_rgb_foreground",
-    "dir_img_bin": "/home/vahid/extracted_lines/1919_bin/images_bin",
-    "characters_txt_file":"/home/vahid/Downloads/models_eynollah/model_eynollah_ocr_cnnrnn_20250930/characters_org.txt"
+    "dir_img_bin": "/home/vahid/extracted_lines/1919_bin/images_bin"
 }


@@ -4,3 +4,8 @@ numpy <1.24.0
 tqdm
 imutils
 scipy
+torch
+evaluate
+accelerate
+jiwer
+transformers <= 4.30.2