Mirror of https://github.com/qurator-spk/eynollah.git (synced 2026-02-20 16:32:03 +01:00)
Merge c4434c7f7d into 586077fbcd
This commit is contained in: commit 38745111df
13 changed files with 875 additions and 113 deletions
@@ -70,6 +70,7 @@ class Eynollah_ocr:
self.model_zoo.get('ocr').to(self.device)
else:
self.model_zoo.load_model('ocr', '')
+self.input_shape = self.model_zoo.get('ocr').input_shape[1:3]
self.model_zoo.load_model('num_to_char')
self.model_zoo.load_model('characters')
self.end_character = len(self.model_zoo.get('characters', list)) + 2

@@ -657,7 +658,7 @@ class Eynollah_ocr:
if out_image_with_text:
image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white")
draw = ImageDraw.Draw(image_text)
-font = get_font()
+font = get_font(font_size=40)

for indexer_text, bb_ind in enumerate(total_bb_coordinates):
x_bb = bb_ind[0]

@@ -823,8 +824,8 @@ class Eynollah_ocr:
page_ns=page_ns,

img_bin=img_bin,
-image_width=512,
-image_height=32,
+image_width=self.input_shape[1],
+image_height=self.input_shape[0],
)

self.write_ocr(
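The first hunk stores the loaded CNN-RNN OCR model's input geometry, and the last hunk uses it in place of the hard-coded 512x32 line size. A minimal illustration of the idea, assuming a Keras-style model whose input_shape is (batch, height, width, channels); the checkpoint path is hypothetical:

from tensorflow.keras.models import load_model

model = load_model("path/to/ocr_model", compile=False)   # hypothetical checkpoint path
height, width = model.input_shape[1:3]                   # e.g. (32, 512)
# use height/width instead of hard-coding image_height=32, image_width=512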
@@ -10,6 +10,7 @@ from .inference import main as inference_cli
from .train import ex
from .extract_line_gt import linegt_cli
from .weights_ensembling import main as ensemble_cli
+from .generate_or_update_cnn_rnn_ocr_character_list import main as update_ocr_characters_cli

@click.command(context_settings=dict(
ignore_unknown_options=True,

@@ -28,3 +29,4 @@ main.add_command(inference_cli, 'inference')
main.add_command(train_cli, 'train')
main.add_command(linegt_cli, 'export_textline_images_and_text')
main.add_command(ensemble_cli, 'ensembling')
+main.add_command(update_ocr_characters_cli, 'generate_or_update_cnn_rnn_ocr_character_list')
@@ -50,6 +50,18 @@ from ..utils import is_image_filename
is_flag=True,
help="if this parameter set to true, cropped textline images will not be masked with textline contour.",
)
+@click.option(
+"--exclude_vertical_lines",
+"-exv",
+is_flag=True,
+help="if this parameter set to true, vertical textline images will be excluded.",
+)
+@click.option(
+"--page_alto",
+"-alto",
+is_flag=True,
+help="If this parameter is set to True, text line image cropping and text extraction are performed using PAGE/ALTO files. Otherwise, the default method for PAGE XML files is used.",
+)
def linegt_cli(
image,
dir_in,

@@ -57,6 +69,8 @@ def linegt_cli(
dir_out,
pref_of_dataset,
do_not_mask_with_textline_contour,
+exclude_vertical_lines,
+page_alto,
):
assert bool(dir_in) ^ bool(image), "Set --dir-in or --image-filename, not both"
if dir_in:
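The new --exclude_vertical_lines flag relies on a simple aspect-ratio test over each text line's bounding box (used later in this diff as h > 1.4 * w). A small self-contained illustration of that test with made-up coordinates:

import numpy as np
import cv2

# Hypothetical polygon of a tall, narrow text line, as (x, y) points.
coords = np.array([[10, 5], [30, 5], [30, 100], [10, 100]], dtype=np.int32)
x, y, w, h = cv2.boundingRect(coords)
is_vertical = h > 1.4 * w   # same heuristic as in the hunks below
print(is_vertical)          # True for this example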
@@ -70,65 +84,149 @@ def linegt_cli(
for dir_img in ls_imgs:
file_name = Path(dir_img).stem
dir_xml = os.path.join(dir_xmls, file_name + '.xml')

img = cv2.imread(dir_img)

+if page_alto:
+h, w = img.shape[:2]

+tree = ET.parse(dir_xml)
+root = tree.getroot()

-total_bb_coordinates = []
+NS = {"alto": "http://www.loc.gov/standards/alto/ns-v4#"}

-tree1 = ET.parse(dir_xml, parser=ET.XMLParser(encoding="utf-8"))
+results = []
-root1 = tree1.getroot()
-alltags = [elem.tag for elem in root1.iter()]
+indexer_textlines = 0
+for line in root.findall(".//alto:TextLine", NS):
+string_el = line.find("alto:String", NS)
+textline_text = string_el.attrib["CONTENT"] if string_el is not None else None

-name_space = alltags[0].split('}')[0]
+polygon_el = line.find("alto:Shape/alto:Polygon", NS)
-name_space = name_space.split('{')[1]
+if polygon_el is None:
+continue

-region_tags = np.unique([x for x in alltags if x.endswith('TextRegion')])
+points = polygon_el.attrib["POINTS"].split()
+coords = [
+(int(points[i]), int(points[i + 1]))
+for i in range(0, len(points), 2)
+]

+coords = np.array(coords, dtype=np.int32)
+x, y, w, h = cv2.boundingRect(coords)

+if exclude_vertical_lines and h > 1.4 * w:
+img_crop = None
+continue

+img_poly_on_img = np.copy(img)

-cropped_lines_region_indexer = []
+mask_poly = np.zeros(img.shape)
+mask_poly = cv2.fillPoly(mask_poly, pts=[coords], color=(1, 1, 1))

-indexer_text_region = 0
+mask_poly = mask_poly[y : y + h, x : x + w, :]
-indexer_textlines = 0
+img_crop = img_poly_on_img[y : y + h, x : x + w, :]
-# FIXME: non recursive, use OCR-D PAGE generateDS API. Or use an existing tool for this purpose altogether
-for nn in root1.iter(region_tags):
-for child_textregion in nn:
-if child_textregion.tag.endswith("TextLine"):
-for child_textlines in child_textregion:
-if child_textlines.tag.endswith("Coords"):
-cropped_lines_region_indexer.append(indexer_text_region)
-p_h = child_textlines.attrib['points'].split(' ')
-textline_coords = np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h])

-x, y, w, h = cv2.boundingRect(textline_coords)
+if not do_not_mask_with_textline_contour:
+img_crop[mask_poly == 0] = 255

-total_bb_coordinates.append([x, y, w, h])
+if img_crop.shape[0] == 0 or img_crop.shape[1] == 0:
+img_crop = None
+continue

+if textline_text and img_crop is not None:
+base_name = os.path.join(
+dir_out, file_name + '_line_' + str(indexer_textlines)
+)
+if pref_of_dataset:
+base_name += '_' + pref_of_dataset
+if not do_not_mask_with_textline_contour:
+base_name += '_masked'

-img_poly_on_img = np.copy(img)
+with open(base_name + '.txt', 'w') as text_file:
+text_file.write(textline_text)
+cv2.imwrite(base_name + '.png', img_crop)
+indexer_textlines += 1

-mask_poly = np.zeros(img.shape)
-mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1))

-mask_poly = mask_poly[y : y + h, x : x + w, :]
-img_crop = img_poly_on_img[y : y + h, x : x + w, :]

-if not do_not_mask_with_textline_contour:
-img_crop[mask_poly == 0] = 255

-if img_crop.shape[0] == 0 or img_crop.shape[1] == 0:
-continue
-if child_textlines.tag.endswith("TextEquiv"):
-for cheild_text in child_textlines:
-if cheild_text.tag.endswith("Unicode"):
-textline_text = cheild_text.text
-if textline_text:
-base_name = os.path.join(
-dir_out, file_name + '_line_' + str(indexer_textlines)
-)
-if pref_of_dataset:
-base_name += '_' + pref_of_dataset
-if not do_not_mask_with_textline_contour:
-base_name += '_masked'

-with open(base_name + '.txt', 'w') as text_file:
-text_file.write(textline_text)
-cv2.imwrite(base_name + '.png', img_crop)
-indexer_textlines += 1

+else:
+total_bb_coordinates = []

+tree = ET.parse(dir_xml, parser=ET.XMLParser(encoding="utf-8"))
+root = tree.getroot()
+alltags = [elem.tag for elem in root.iter()]

+name_space = alltags[0].split('}')[0]
+name_space = name_space.split('{')[1]

+region_tags = np.unique([x for x in alltags if x.endswith('TextRegion')])

+cropped_lines_region_indexer = []

+indexer_text_region = 0
+indexer_textlines = 0
+# FIXME: non recursive, use OCR-D PAGE generateDS API. Or use an existing tool for this purpose altogether
+for nn in root.iter(region_tags):
+for child_textregion in nn:
+if child_textregion.tag.endswith("TextLine"):
+for child_textlines in child_textregion:
+if child_textlines.tag.endswith("Coords"):
+cropped_lines_region_indexer.append(indexer_text_region)
+p_h = child_textlines.attrib['points'].split(' ')
+textline_coords = np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h])

+x, y, w, h = cv2.boundingRect(textline_coords)

+if exclude_vertical_lines and h > 1.4 * w:
+img_crop = None
+continue

+total_bb_coordinates.append([x, y, w, h])

+img_poly_on_img = np.copy(img)

+mask_poly = np.zeros(img.shape)
+mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1))

+mask_poly = mask_poly[y : y + h, x : x + w, :]
+img_crop = img_poly_on_img[y : y + h, x : x + w, :]

+if not do_not_mask_with_textline_contour:
+img_crop[mask_poly == 0] = 255

+if img_crop.shape[0] == 0 or img_crop.shape[1] == 0:
+img_crop = None
+continue

+if child_textlines.tag.endswith("TextEquiv"):
+for cheild_text in child_textlines:
+if cheild_text.tag.endswith("Unicode"):
+textline_text = cheild_text.text
+if textline_text and img_crop is not None:
+base_name = os.path.join(
+dir_out, file_name + '_line_' + str(indexer_textlines)
+)
+if pref_of_dataset:
+base_name += '_' + pref_of_dataset
+if not do_not_mask_with_textline_contour:
+base_name += '_masked'

+with open(base_name + '.txt', 'w') as text_file:
+text_file.write(textline_text)
+cv2.imwrite(base_name + '.png', img_crop)
+indexer_textlines += 1
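The page_alto branch above reads each ALTO TextLine's Shape/Polygon, crops its bounding box and optionally whites out everything outside the polygon. A condensed, self-contained sketch of that flow (file names are hypothetical; assumes ALTO v4 with String/@CONTENT and Shape/Polygon/@POINTS):

import xml.etree.ElementTree as ET
import numpy as np
import cv2

NS = {"alto": "http://www.loc.gov/standards/alto/ns-v4#"}
img = cv2.imread("page.png")                    # hypothetical page image
root = ET.parse("page.alto.xml").getroot()      # hypothetical ALTO file

for line in root.findall(".//alto:TextLine", NS):
    string_el = line.find("alto:String", NS)
    text = string_el.attrib["CONTENT"] if string_el is not None else None
    polygon_el = line.find("alto:Shape/alto:Polygon", NS)
    if polygon_el is None or not text:
        continue
    pts = polygon_el.attrib["POINTS"].split()
    coords = np.array([(int(pts[i]), int(pts[i + 1])) for i in range(0, len(pts), 2)], dtype=np.int32)
    x, y, w, h = cv2.boundingRect(coords)
    mask = cv2.fillPoly(np.zeros(img.shape), pts=[coords], color=(1, 1, 1))
    crop = np.copy(img)[y:y + h, x:x + w, :]
    crop[mask[y:y + h, x:x + w, :] == 0] = 255   # mask pixels outside the text line polygon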
@@ -6,6 +6,7 @@ from pathlib import Path
from PIL import Image, ImageDraw, ImageFont
import cv2
import numpy as np
+from eynollah.utils.font import get_font

from eynollah.training.gt_gen_utils import (
filter_contours_area_of_image,

@@ -477,7 +478,7 @@ def visualize_layout_segmentation(xml_file, dir_xml, dir_out, dir_imgs):
co_text, co_graphic, co_sep, co_img, co_table, co_map, co_noise, y_len, x_len = get_layout_contours_for_visualization(xml_file)

-added_image = visualize_image_from_contours_layout(co_text['paragraph'], co_text['header']+co_text['heading'], co_text['drop-capital'], co_sep, co_img, co_text['marginalia'], co_table, img)
+added_image = visualize_image_from_contours_layout(co_text['paragraph'], co_text['header']+co_text['heading'], co_text['drop-capital'], co_sep, co_img, co_text['marginalia'], co_table, co_map, img)

cv2.imwrite(os.path.join(dir_out, f_name+'.png'), added_image)

@@ -514,8 +515,8 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out):
else:
xml_files_ind = [xml_file]

-font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
+###font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
-font = ImageFont.truetype(font_path, 40)
+font = get_font(font_size=40)#ImageFont.truetype(font_path, 40)

for ind_xml in tqdm(xml_files_ind):
indexer = 0

@@ -552,11 +553,11 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out):

is_vertical = h > 2*w # Check orientation
-font = fit_text_single_line(draw, ocr_texts[index], font_path, w, int(h*0.4) )
+font = fit_text_single_line(draw, ocr_texts[index], w, int(h*0.4) )

if is_vertical:

-vertical_font = fit_text_single_line(draw, ocr_texts[index], font_path, h, int(w * 0.8))
+vertical_font = fit_text_single_line(draw, ocr_texts[index], h, int(w * 0.8))

text_img = Image.new("RGBA", (h, w), (255, 255, 255, 0)) # Note: dimensions are swapped
text_draw = ImageDraw.Draw(text_img)
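The visualization code now takes its fonts from eynollah.utils.font.get_font instead of a hard-coded Charis TTF path. The helper's implementation is not part of this diff; a plausible minimal sketch (font file name and location are assumptions, not the actual code):

from pathlib import Path
from PIL import ImageFont

def get_font(font_size=40):
    # Hypothetical: load a TTF shipped next to the module; the real helper may resolve the font differently.
    font_path = Path(__file__).parent / "Charis-Regular.ttf"
    return ImageFont.truetype(str(font_path), font_size)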
@@ -0,0 +1,59 @@
+import os
+import numpy as np
+import json
+import click
+import logging

+def run_character_list_update(dir_labels, out, current_character_list):
+ls_labels = os.listdir(dir_labels)
+ls_labels = [ind for ind in ls_labels if ind.endswith('.txt')]

+if current_character_list:
+with open(current_character_list, 'r') as f_name:
+characters = json.load(f_name)

+characters = set(characters)
+else:
+characters = set()

+for ind in ls_labels:
+label = open(os.path.join(dir_labels,ind),'r').read().split('\n')[0]

+for char in label:
+characters.add(char)

+characters = sorted(list(set(characters)))

+with open(out, 'w') as f_name:
+json.dump(characters, f_name)

+@click.command()
+@click.option(
+"--dir_labels",
+"-dl",
+help="directory of labels which are .txt files",
+type=click.Path(exists=True, file_okay=False),
+required=True,
+)
+@click.option(
+"--current_character_list",
+"-ccl",
+help="existing character list in a .txt file that needs to be updated with a set of labels",
+type=click.Path(exists=True, file_okay=True),
+required=False,
+)
+@click.option(
+"--out",
+"-o",
+help="An output .txt file where the generated or updated character list will be written",
+type=click.Path(exists=False, file_okay=True),
+)

+def main(dir_labels, out, current_character_list):
+run_character_list_update(dir_labels, out, current_character_list)
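The new module scans a folder of .txt label files, collects every character that occurs (optionally merged with an existing list), and writes the sorted result as JSON. A hedged usage sketch calling the function directly, with hypothetical paths:

run_character_list_update(
    dir_labels="ocr_labels/",            # folder of *.txt transcriptions
    out="characters_org.txt",            # JSON array written here
    current_character_list=None,         # or a path to an existing list to extend
)
# The output is a JSON array such as ["!", ",", "A", "a", ...], sorted and de-duplicated.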
@@ -7,7 +7,7 @@ import cv2
from shapely import geometry
from pathlib import Path
from PIL import ImageFont
+from eynollah.utils.font import get_font

KERNEL = np.ones((5, 5), np.uint8)

@@ -350,11 +350,11 @@ def get_textline_contours_and_ocr_text(xml_file):
ocr_textlines.append(ocr_text_in[0])
return co_use_case, y_len, x_len, ocr_textlines

-def fit_text_single_line(draw, text, font_path, max_width, max_height):
+def fit_text_single_line(draw, text, max_width, max_height):
initial_font_size = 50
font_size = initial_font_size
while font_size > 10: # Minimum font size
-font = ImageFont.truetype(font_path, font_size)
+font = get_font(font_size=font_size)# ImageFont.truetype(font_path, font_size)
text_bbox = draw.textbbox((0, 0), text, font=font) # Get text bounding box
text_width = text_bbox[2] - text_bbox[0]
text_height = text_bbox[3] - text_bbox[1]

@@ -364,7 +364,7 @@ def fit_text_single_line(draw, text, font_path, max_width, max_height):

font_size -= 2 # Reduce font size and retry

-return ImageFont.truetype(font_path, 10) # Smallest font fallback
+return get_font(font_size=10)#ImageFont.truetype(font_path, 10) # Smallest font fallback

def get_layout_contours_for_visualization(xml_file):
tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8'))
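fit_text_single_line now drops the font_path argument and asks get_font for progressively smaller sizes until the rendered text fits the given box. A short usage sketch with a hypothetical canvas and text:

from PIL import Image, ImageDraw

canvas = Image.new("RGB", (800, 100), "white")   # hypothetical target box
draw = ImageDraw.Draw(canvas)
font = fit_text_single_line(draw, "Example text line", max_width=780, max_height=40)
draw.text((10, 10), "Example text line", fill="black", font=font)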
@@ -12,6 +12,7 @@ from keras.models import Model, load_model
from keras import backend as K
import click
from tensorflow.python.keras import backend as tensorflow_backend
+from tensorflow.keras.layers import StringLookup
import xml.etree.ElementTree as ET

from .gt_gen_utils import (

@@ -169,6 +170,25 @@ class sbb_predict:
self.model = tf.keras.models.Model(
self.model.get_layer(name = "image").input,
self.model.get_layer(name = "dense2").output)

+assert isinstance(self.model, Model)

+elif self.task == "transformer-ocr":
+import torch
+from transformers import VisionEncoderDecoderModel
+from transformers import TrOCRProcessor

+self.model = VisionEncoderDecoderModel.from_pretrained(self.model_dir)
+self.processor = TrOCRProcessor.from_pretrained(self.model_dir)

+if self.cpu:
+self.device = torch.device('cpu')
+else:
+self.device = torch.device('cuda:0')

+self.model.to(self.device)

+assert isinstance(self.model, torch.nn.Module)
else:
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True

@@ -176,15 +196,15 @@ class sbb_predict:
session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
tensorflow_backend.set_session(session)

-##if self.weights_dir!=None:
-##self.model.load_weights(self.weights_dir)
-assert isinstance(self.model, Model)
-if self.task != 'classification' and self.task != 'reading_order':
-self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1]
-self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2]
-self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3]
+self.model = load_model(self.model_dir , compile=False,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
+if self.task != 'classification' and self.task != 'reading_order':
+self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1]
+self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2]
+self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3]
+assert isinstance(self.model, Model)

def visualize_model_output(self, prediction, img, task) -> Tuple[NDArray, NDArray]:
if task == "binarization":

@@ -235,10 +255,9 @@ class sbb_predict:
return added_image, layout_only

def predict(self, image_dir):
-assert isinstance(self.model, Model)
if self.task == 'classification':
classes_names = self.config_params_model['classification_classes_name']
-img_1ch = img=cv2.imread(image_dir, 0)
+img_1ch =cv2.imread(image_dir, 0)

img_1ch = img_1ch / 255.0
img_1ch = cv2.resize(img_1ch, (self.config_params_model['input_height'], self.config_params_model['input_width']), interpolation=cv2.INTER_NEAREST)

@@ -273,6 +292,15 @@ class sbb_predict:
pred_texts = decode_batch_predictions(preds, num_to_char)
pred_texts = pred_texts[0].replace("[UNK]", "")
return pred_texts

+elif self.task == "transformer-ocr":
+from PIL import Image
+image = Image.open(image_dir).convert("RGB")
+pixel_values = self.processor(image, return_tensors="pt").pixel_values
+generated_ids = self.model.generate(pixel_values.to(self.device))
+return self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

elif self.task == 'reading_order':

@@ -607,6 +635,8 @@ class sbb_predict:
cv2.imwrite(self.save,res)
elif self.task == "cnn-rnn-ocr":
print(f"Detected text: {res}")
+elif self.task == "transformer-ocr":
+print(f"Detected text: {res}")
else:
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
if self.save:

@@ -710,10 +740,12 @@ class sbb_predict:
)
def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area):
assert image or dir_in, "Either a single image -i or a dir_in -di is required"
-with open(os.path.join(model,'config.json')) as f:
+with open(os.path.join(model,'config_eynollah.json')) as f:
config_params_model = json.load(f)

task = config_params_model['task']
-if task != 'classification' and task != 'reading_order' and task != "cnn-rnn-ocr":
+if task != 'classification' and task != 'reading_order' and task != "cnn-rnn-ocr" and task != "transformer-ocr":
if image and not save:
print("Error: You used one of segmentation or binarization task with image input but not set -s, you need a filename to save visualized output with -s")
sys.exit(1)
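The new transformer-ocr branch loads a VisionEncoderDecoderModel and TrOCRProcessor from the model directory and decodes one line image with generate(). A self-contained sketch of the same inference flow (model directory and image path are hypothetical):

import torch
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

model_dir = "path/to/trocr_model"   # hypothetical TrOCR-style checkpoint
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

processor = TrOCRProcessor.from_pretrained(model_dir)
model = VisionEncoderDecoderModel.from_pretrained(model_dir).to(device)

image = Image.open("line.png").convert("RGB")   # hypothetical line image
pixel_values = processor(image, return_tensors="pt").pixel_values
generated_ids = model.generate(pixel_values.to(device))
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])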
@@ -1,7 +1,8 @@
import os
import sys
import json
+import numpy as np
+import cv2
import click

from eynollah.training.metrics import (

@@ -27,7 +28,8 @@ from eynollah.training.utils import (
generate_data_from_folder_training,
get_one_hot,
provide_patches,
-return_number_of_total_training_data
+return_number_of_total_training_data,
+OCRDatasetYieldAugmentations
)

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

@@ -41,11 +43,16 @@ from sklearn.metrics import f1_score
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.layers import StringLookup

-import numpy as np
-import cv2
+import torch
+from transformers import TrOCRProcessor
+import evaluate
+from transformers import default_data_collator
+from transformers import VisionEncoderDecoderModel
+from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

class SaveWeightsAfterSteps(Callback):
-def __init__(self, save_interval, save_path, _config):
+def __init__(self, save_interval, save_path, _config, characters_cnnrnn_ocr=None):
super(SaveWeightsAfterSteps, self).__init__()
self.save_interval = save_interval
self.save_path = save_path

@@ -61,7 +68,10 @@ class SaveWeightsAfterSteps(Callback):

self.model.save(save_file)

-with open(os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"config.json"), "w") as fp:
+if characters_cnnrnn_ocr:
+os.system("cp "+characters_cnnrnn_ocr+" "+os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"characters_org.txt"))

+with open(os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"config_eynollah.json"), "w") as fp:
json.dump(self._config, fp) # encode dict into JSON
print(f"saved model as steps {self.step_count} to {save_file}")

@@ -477,7 +487,7 @@ def run(

model.save(os.path.join(dir_output,'model_'+str(i)))

-with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
+with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config_eynollah.json"), "w") as fp:
json.dump(_config, fp) # encode dict into JSON

#os.system('rm -rf '+dir_train_flowing)

@@ -537,7 +547,7 @@ def run(
opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)#1e-4)#(lr_schedule)
model.compile(optimizer=opt)

-save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config) if save_interval else None
+save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config, characters_cnnrnn_ocr=characters_txt_file) if save_interval else None

for i in tqdm(range(index_start, n_epochs + index_start)):
if save_interval:

@@ -556,9 +566,125 @@ def run(

if i >=0:
model.save( os.path.join(dir_output,'model_'+str(i) ))
-with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
+os.system("cp "+characters_txt_file+" "+os.path.join(os.path.join(dir_output,'model_'+str(i)),"characters_org.txt"))
+with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config_eynollah.json"), "w") as fp:
json.dump(_config, fp) # encode dict into JSON

+elif task=="transformer-ocr":
+dir_img, dir_lab = get_dirs_or_files(dir_train)

+processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")

+ls_files_images = os.listdir(dir_img)

+aug_multip = return_multiplier_based_on_augmnentations(augmentation, color_padding_rotation, rotation_not_90, blur_aug, degrading, bin_deg,
+brightening, padding_white, adding_rgb_foreground, adding_rgb_background, binarization,
+image_inversion, channels_shuffling, add_red_textlines, white_noise_strap, textline_skewing, textline_skewing_bin, textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth, textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin, textline_down_in_depth, textline_down_in_depth_bin, pepper_bin_aug, pepper_aug, degrade_scales, number_of_backgrounds_per_image, thetha, thetha_padd, brightness, padd_colors, shuffle_indexes, pepper_indexes, skewing_amplitudes, blur_k, white_padds)

+len_dataset = aug_multip*len(ls_files_images)

+dataset = OCRDatasetYieldAugmentations(
+dir_img=dir_img,
+dir_img_bin=dir_img_bin,
+dir_lab=dir_lab,
+processor=processor,
+max_target_length=max_len,
+augmentation = augmentation,
+binarization = binarization,
+add_red_textlines = add_red_textlines,
+white_noise_strap = white_noise_strap,
+adding_rgb_foreground = adding_rgb_foreground,
+adding_rgb_background = adding_rgb_background,
+bin_deg = bin_deg,
+blur_aug = blur_aug,
+brightening = brightening,
+padding_white = padding_white,
+color_padding_rotation = color_padding_rotation,
+rotation_not_90 = rotation_not_90,
+degrading = degrading,
+channels_shuffling = channels_shuffling,
+textline_skewing = textline_skewing,
+textline_skewing_bin = textline_skewing_bin,
+textline_right_in_depth = textline_right_in_depth,
+textline_left_in_depth = textline_left_in_depth,
+textline_up_in_depth = textline_up_in_depth,
+textline_down_in_depth = textline_down_in_depth,
+textline_right_in_depth_bin = textline_right_in_depth_bin,
+textline_left_in_depth_bin = textline_left_in_depth_bin,
+textline_up_in_depth_bin = textline_up_in_depth_bin,
+textline_down_in_depth_bin = textline_down_in_depth_bin,
+pepper_aug = pepper_aug,
+pepper_bin_aug = pepper_bin_aug,
+list_all_possible_background_images=list_all_possible_background_images,
+list_all_possible_foreground_rgbs=list_all_possible_foreground_rgbs,
+blur_k = blur_k,
+degrade_scales = degrade_scales,
+white_padds = white_padds,
+thetha_padd = thetha_padd,
+thetha = thetha,
+brightness = brightness,
+padd_colors = padd_colors,
+number_of_backgrounds_per_image = number_of_backgrounds_per_image,
+shuffle_indexes = shuffle_indexes,
+pepper_indexes = pepper_indexes,
+skewing_amplitudes = skewing_amplitudes,
+dir_rgb_backgrounds = dir_rgb_backgrounds,
+dir_rgb_foregrounds = dir_rgb_foregrounds,
+len_data=len_dataset,
+)

+# Create a DataLoader
+data_loader = torch.utils.data.DataLoader(dataset, batch_size=1)
+train_dataset = data_loader.dataset

+if continue_training:
+model = VisionEncoderDecoderModel.from_pretrained(dir_of_start_model)
+else:
+model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")

+# set special tokens used for creating the decoder_input_ids from the labels
+model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
+model.config.pad_token_id = processor.tokenizer.pad_token_id
+# make sure vocab size is set correctly
+model.config.vocab_size = model.config.decoder.vocab_size

+# set beam search parameters
+model.config.eos_token_id = processor.tokenizer.sep_token_id
+model.config.max_length = max_len
+model.config.early_stopping = True
+model.config.no_repeat_ngram_size = 3
+model.config.length_penalty = 2.0
+model.config.num_beams = 4

+training_args = Seq2SeqTrainingArguments(
+predict_with_generate=True,
+num_train_epochs=n_epochs,
+learning_rate=learning_rate,
+per_device_train_batch_size=n_batch,
+fp16=True,
+output_dir=dir_output,
+logging_steps=2,
+save_steps=save_interval,
+)

+cer_metric = evaluate.load("cer")

+# instantiate trainer
+trainer = Seq2SeqTrainer(
+model=model,
+tokenizer=processor.feature_extractor,
+args=training_args,
+train_dataset=train_dataset,
+data_collator=default_data_collator,
+)
+trainer.train()

elif task=='classification':
configuration()
model = resnet50_classifier(n_classes, input_height, input_width, weight_decay, pretraining)

@@ -609,10 +735,10 @@ def run(
model_weight_averaged.set_weights(new_weights)

model_weight_averaged.save(os.path.join(dir_output,'model_ens_avg'))
-with open(os.path.join( os.path.join(dir_output,'model_ens_avg'), "config.json"), "w") as fp:
+with open(os.path.join( os.path.join(dir_output,'model_ens_avg'), "config_eynollah.json"), "w") as fp:
json.dump(_config, fp) # encode dict into JSON

-with open(os.path.join( os.path.join(dir_output,'model_best'), "config.json"), "w") as fp:
+with open(os.path.join( os.path.join(dir_output,'model_best'), "config_eynollah.json"), "w") as fp:
json.dump(_config, fp) # encode dict into JSON

elif task=='reading_order':

@@ -645,7 +771,7 @@ def run(
history = model.fit(generate_arrays_from_folder_reading_order(dir_flow_train_labels, dir_flow_train_imgs, n_batch, input_height, input_width, n_classes, thetha, augmentation), steps_per_epoch=num_rows / n_batch, verbose=1)
model.save( os.path.join(dir_output,'model_'+str(i+indexer_start) ))

-with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
+with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config_eynollah.json"), "w") as fp:
json.dump(_config, fp) # encode dict into JSON
'''
if f1score>f1score_tot[0]:
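Stripped of the augmentation plumbing, the transformer-ocr training path boils down to configuring the TrOCR model's special tokens and handing a streaming dataset to Hugging Face's Seq2SeqTrainer. A sketch under that assumption (hyperparameters and train_dataset are placeholders, mirroring the calls in the diff):

import torch
from transformers import (TrOCRProcessor, VisionEncoderDecoderModel,
                          Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator)

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")

# Special-token and beam-search settings, as in the diff.
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = 128                     # hypothetical max target length
model.config.num_beams = 4

training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    num_train_epochs=1,                           # hypothetical
    per_device_train_batch_size=8,                # hypothetical
    fp16=torch.cuda.is_available(),
    output_dir="trocr_out/",                      # hypothetical
    logging_steps=2,
)

trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=processor.feature_extractor,
    args=training_args,
    train_dataset=train_dataset,                  # e.g. an OCRDatasetYieldAugmentations instance
    data_collator=default_data_collator,
)
trainer.train()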
@ -13,8 +13,12 @@ import tensorflow as tf
|
||||||
from tensorflow.keras.utils import to_categorical
|
from tensorflow.keras.utils import to_categorical
|
||||||
from PIL import Image, ImageFile, ImageEnhance
|
from PIL import Image, ImageFile, ImageEnhance
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import IterableDataset
|
||||||
|
|
||||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||||
|
|
||||||
|
|
||||||
def vectorize_label(label, char_to_num, padding_token, max_len):
|
def vectorize_label(label, char_to_num, padding_token, max_len):
|
||||||
label = char_to_num(tf.strings.unicode_split(label, input_encoding="UTF-8"))
|
label = char_to_num(tf.strings.unicode_split(label, input_encoding="UTF-8"))
|
||||||
length = tf.shape(label)[0]
|
length = tf.shape(label)[0]
|
||||||
|
|
@ -76,6 +80,7 @@ def add_salt_and_pepper_noise(img, salt_prob, pepper_prob):
|
||||||
|
|
||||||
return noisy_image
|
return noisy_image
|
||||||
|
|
||||||
|
|
||||||
def invert_image(img):
|
def invert_image(img):
|
||||||
img_inv = 255 - img
|
img_inv = 255 - img
|
||||||
return img_inv
|
return img_inv
|
||||||
|
|
@ -1668,3 +1673,411 @@ def return_multiplier_based_on_augmnentations(
|
||||||
aug_multip += len(pepper_indexes)
|
aug_multip += len(pepper_indexes)
|
||||||
|
|
||||||
return aug_multip
|
return aug_multip
|
||||||
|
|
||||||
|
|
||||||
|
class OCRDatasetYieldAugmentations(IterableDataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dir_img,
|
||||||
|
dir_img_bin,
|
||||||
|
dir_lab,
|
||||||
|
processor,
|
||||||
|
max_target_length=128,
|
||||||
|
augmentation = None,
|
||||||
|
binarization = None,
|
||||||
|
add_red_textlines = None,
|
||||||
|
white_noise_strap = None,
|
||||||
|
adding_rgb_foreground = None,
|
||||||
|
adding_rgb_background = None,
|
||||||
|
bin_deg = None,
|
||||||
|
blur_aug = None,
|
||||||
|
brightening = None,
|
||||||
|
padding_white = None,
|
||||||
|
color_padding_rotation = None,
|
||||||
|
rotation_not_90 = None,
|
||||||
|
degrading = None,
|
||||||
|
channels_shuffling = None,
|
||||||
|
textline_skewing = None,
|
||||||
|
textline_skewing_bin = None,
|
||||||
|
textline_right_in_depth = None,
|
||||||
|
textline_left_in_depth = None,
|
||||||
|
textline_up_in_depth = None,
|
||||||
|
textline_down_in_depth = None,
|
||||||
|
textline_right_in_depth_bin = None,
|
||||||
|
textline_left_in_depth_bin = None,
|
||||||
|
textline_up_in_depth_bin = None,
|
||||||
|
textline_down_in_depth_bin = None,
|
||||||
|
pepper_aug = None,
|
||||||
|
pepper_bin_aug = None,
|
||||||
|
list_all_possible_background_images=None,
|
||||||
|
list_all_possible_foreground_rgbs=None,
|
||||||
|
blur_k = None,
|
||||||
|
degrade_scales = None,
|
||||||
|
white_padds = None,
|
||||||
|
thetha_padd = None,
|
||||||
|
thetha = None,
|
||||||
|
brightness = None,
|
||||||
|
padd_colors = None,
|
||||||
|
number_of_backgrounds_per_image = None,
|
||||||
|
shuffle_indexes = None,
|
||||||
|
pepper_indexes = None,
|
||||||
|
skewing_amplitudes = None,
|
||||||
|
dir_rgb_backgrounds = None,
|
||||||
|
dir_rgb_foregrounds = None,
|
||||||
|
len_data=None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
images_dir (str): Path to the directory containing images.
|
||||||
|
labels_dir (str): Path to the directory containing label text files.
|
||||||
|
tokenizer: Tokenizer for processing labels.
|
||||||
|
transform: Transformations applied after augmentation (e.g., ToTensor, normalization).
|
||||||
|
image_size (tuple): Size to resize images to.
|
||||||
|
max_seq_len (int): Maximum sequence length for tokenized labels.
|
||||||
|
scales (list or None): List of scale factors to apply.
|
||||||
|
"""
|
||||||
|
self.dir_img = dir_img
|
||||||
|
self.dir_img_bin = dir_img_bin
|
||||||
|
self.dir_lab = dir_lab
|
||||||
|
self.processor = processor
|
||||||
|
self.max_target_length = max_target_length
|
||||||
|
#self.scales = scales if scales else []
|
||||||
|
|
||||||
|
self.augmentation = augmentation
|
||||||
|
self.binarization = binarization
|
||||||
|
self.add_red_textlines = add_red_textlines
|
||||||
|
self.white_noise_strap = white_noise_strap
|
||||||
|
self.adding_rgb_foreground = adding_rgb_foreground
|
||||||
|
self.adding_rgb_background = adding_rgb_background
|
||||||
|
self.bin_deg = bin_deg
|
||||||
|
self.blur_aug = blur_aug
|
||||||
|
self.brightening = brightening
|
||||||
|
self.padding_white = padding_white
|
||||||
|
self.color_padding_rotation = color_padding_rotation
|
||||||
|
self.rotation_not_90 = rotation_not_90
|
||||||
|
self.degrading = degrading
|
||||||
|
self.channels_shuffling = channels_shuffling
|
||||||
|
self.textline_skewing = textline_skewing
|
||||||
|
self.textline_skewing_bin = textline_skewing_bin
|
||||||
|
self.textline_right_in_depth = textline_right_in_depth
|
||||||
|
self.textline_left_in_depth = textline_left_in_depth
|
||||||
|
self.textline_up_in_depth = textline_up_in_depth
|
||||||
|
self.textline_down_in_depth = textline_down_in_depth
|
||||||
|
self.textline_right_in_depth_bin = textline_right_in_depth_bin
|
||||||
|
self.textline_left_in_depth_bin = textline_left_in_depth_bin
|
||||||
|
self.textline_up_in_depth_bin = textline_up_in_depth_bin
|
||||||
|
self.textline_down_in_depth_bin = textline_down_in_depth_bin
|
||||||
|
self.pepper_aug = pepper_aug
|
||||||
|
self.pepper_bin_aug = pepper_bin_aug
|
||||||
|
self.list_all_possible_background_images=list_all_possible_background_images
|
||||||
|
self.list_all_possible_foreground_rgbs=list_all_possible_foreground_rgbs
|
||||||
|
self.blur_k = blur_k
|
||||||
|
self.degrade_scales = degrade_scales
|
||||||
|
self.white_padds = white_padds
|
||||||
|
self.thetha_padd = thetha_padd
|
||||||
|
self.thetha = thetha
|
||||||
|
self.brightness = brightness
|
||||||
|
self.padd_colors = padd_colors
|
||||||
|
self.number_of_backgrounds_per_image = number_of_backgrounds_per_image
|
||||||
|
self.shuffle_indexes = shuffle_indexes
|
||||||
|
self.pepper_indexes = pepper_indexes
|
||||||
|
self.skewing_amplitudes = skewing_amplitudes
|
||||||
|
self.dir_rgb_backgrounds = dir_rgb_backgrounds
|
||||||
|
self.dir_rgb_foregrounds = dir_rgb_foregrounds
|
||||||
|
self.image_files = os.listdir(dir_img)#sorted([f for f in os.listdir(images_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
|
||||||
|
self.len_data = len_data
|
||||||
|
#assert len(self.image_files) == len(self.label_files), "Number of images and labels must match!"
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.len_data
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
for img_file in self.image_files:
|
||||||
|
# Load image
|
||||||
|
f_name = img_file.split('.')[0]
|
||||||
|
|
||||||
|
txt_inp = open(os.path.join(self.dir_lab, f_name+'.txt'),'r').read().split('\n')[0]
|
||||||
|
|
||||||
|
img = cv2.imread(os.path.join(self.dir_img, img_file))
|
||||||
|
img = img.astype(np.uint8)
|
||||||
|
|
||||||
|
|
||||||
|
if self.dir_img_bin:
|
||||||
|
img_bin_corr = cv2.imread(os.path.join(self.dir_img_bin, f_name+'.png') )
|
||||||
|
img_bin_corr = img_bin_corr.astype(np.uint8)
|
||||||
|
else:
|
||||||
|
img_bin_corr = None
|
||||||
|
|
||||||
|
|
||||||
|
labels = self.processor.tokenizer(txt_inp,
|
||||||
|
padding="max_length",
|
||||||
|
max_length=self.max_target_length).input_ids
|
||||||
|
|
||||||
|
labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]
|
||||||
|
|
||||||
|
|
||||||
|
if self.augmentation:
|
||||||
|
pixel_values = self.processor(Image.fromarray(img), return_tensors="pt").pixel_values
|
||||||
|
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||||
|
yield encoding
|
||||||
|
|
||||||
|
|
||||||
|
if self.color_padding_rotation:
|
||||||
|
for index, thetha_ind in enumerate(self.thetha_padd):
|
||||||
|
for padd_col in self.padd_colors:
|
||||||
|
img_out = rotation_not_90_func_single_image(do_padding_for_ocr(img, 1.2, padd_col), thetha_ind)
|
||||||
|
img_out = img_out.astype(np.uint8)
|
||||||
|
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||||
|
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||||
|
yield encoding
|
||||||
|
|
||||||
|
if self.rotation_not_90:
|
||||||
|
for index, thetha_ind in enumerate(self.thetha):
|
||||||
|
img_out = rotation_not_90_func_single_image(img, thetha_ind)
|
||||||
|
img_out = img_out.astype(np.uint8)
|
||||||
|
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||||
|
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||||
|
yield encoding
|
||||||
|
|
||||||
|
|
||||||
|
if self.blur_aug:
|
||||||
|
for index, blur_type in enumerate(self.blur_k):
|
||||||
|
img_out = bluring(img, blur_type)
|
||||||
|
img_out = img_out.astype(np.uint8)
|
||||||
|
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||||
|
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||||
|
yield encoding
|
||||||
|
|
||||||
|
if self.degrading:
|
||||||
|
for index, deg_scale_ind in enumerate(self.degrade_scales):
|
||||||
|
try:
|
||||||
|
img_out = do_degrading(img, deg_scale_ind)
|
||||||
|
except:
|
||||||
|
img_out = np.copy(img)
|
||||||
|
img_out = img_out.astype(np.uint8)
|
||||||
|
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||||
|
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||||
|
yield encoding
|
||||||
|
|
||||||
|
if self.bin_deg:
|
||||||
|
for index, deg_scale_ind in enumerate(self.degrade_scales):
|
||||||
|
try:
|
||||||
|
img_out = self.do_degrading(img_bin_corr, deg_scale_ind)
|
||||||
|
except:
|
||||||
|
img_out = np.copy(img_bin_corr)
|
||||||
|
img_out = img_out.astype(np.uint8)
|
||||||
|
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||||
|
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||||
|
yield encoding
|
||||||
|
|
||||||
|
|
||||||
|
if self.brightening:
|
||||||
|
for index, bright_scale_ind in enumerate(self.brightness):
|
||||||
|
try:
|
||||||
|
img_out = do_brightening(dir_img, bright_scale_ind)
|
||||||
|
except:
|
||||||
|
img_out = np.copy(img)
|
||||||
|
img_out = img_out.astype(np.uint8)
|
||||||
|
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||||
|
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||||
|
yield encoding
|
||||||
|
|
||||||
|
|
||||||
|
if self.padding_white:
|
||||||
|
for index, padding_size in enumerate(self.white_padds):
|
||||||
|
for padd_col in self.padd_colors:
|
||||||
|
img_out = do_padding_for_ocr(img, padding_size, padd_col)
|
||||||
|
img_out = img_out.astype(np.uint8)
|
||||||
|
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||||
|
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||||
|
yield encoding
|
||||||
|
|
||||||
|
|
||||||
|
if self.adding_rgb_foreground:
|
||||||
|
for i_n in range(self.number_of_backgrounds_per_image):
|
||||||
|
background_image_chosen_name = random.choice(self.list_all_possible_background_images)
|
||||||
|
foreground_rgb_chosen_name = random.choice(self.list_all_possible_foreground_rgbs)
|
||||||
|
|
||||||
|
img_rgb_background_chosen = cv2.imread(self.dir_rgb_backgrounds + '/' + background_image_chosen_name)
|
||||||
|
foreground_rgb_chosen = np.load(self.dir_rgb_foregrounds + '/' + foreground_rgb_chosen_name)
|
||||||
|
|
||||||
|
img_out = return_binary_image_with_given_rgb_background_and_given_foreground_rgb(img_bin_corr, img_rgb_background_chosen, foreground_rgb_chosen)
|
||||||
|
img_out = img_out.astype(np.uint8)
|
||||||
|
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||||
|
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||||
|
yield encoding
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if self.adding_rgb_background:
|
||||||
|
for i_n in range(self.number_of_backgrounds_per_image):
|
||||||
|
background_image_chosen_name = random.choice(self.list_all_possible_background_images)
|
||||||
|
img_rgb_background_chosen = cv2.imread(self.dir_rgb_backgrounds + '/' + background_image_chosen_name)
|
||||||
|
img_out = return_binary_image_with_given_rgb_background(img_bin_corr, img_rgb_background_chosen)
|
||||||
|
img_out = img_out.astype(np.uint8)
|
||||||
|
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||||
|
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||||
|
yield encoding
|
||||||
|
|
||||||
|
if self.binarization:
|
||||||
|
pixel_values = self.processor(Image.fromarray(img_bin_corr), return_tensors="pt").pixel_values
|
||||||
|
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||||
|
yield encoding
|
||||||
|
|
||||||
|
|
||||||
|
if self.channels_shuffling:
|
||||||
|
for shuffle_index in self.shuffle_indexes:
|
||||||
|
img_out = return_shuffled_channels(img, shuffle_index)
|
||||||
|
img_out = img_out.astype(np.uint8)
|
||||||
|
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||||
|
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||||
|
yield encoding
|
||||||
|
|
||||||
|
if self.add_red_textlines:
|
||||||
|
img_out = return_image_with_red_elements(img, img_bin_corr)
|
||||||
|
img_out = img_out.astype(np.uint8)
|
||||||
|
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||||
|
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||||
|
yield encoding
|
||||||
|
|
||||||
|
if self.white_noise_strap:
|
||||||
|
img_out = return_image_with_strapped_white_noises(img)
|
||||||
|
img_out = img_out.astype(np.uint8)
|
||||||
|
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||||
|
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||||
|
yield encoding
|
||||||
|
|
||||||
|
if self.textline_skewing:
|
||||||
|
            for index, des_scale_ind in enumerate(self.skewing_amplitudes):
                try:
                    img_out = do_deskewing(img, des_scale_ind)
                except:
                    img_out = np.copy(img)
                img_out = img_out.astype(np.uint8)
                pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
                encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
                yield encoding

        if self.textline_skewing_bin:
            for index, des_scale_ind in enumerate(self.skewing_amplitudes):
                try:
                    img_out = do_deskewing(img_bin_corr, des_scale_ind)
                except:
                    img_out = np.copy(img_bin_corr)
                img_out = img_out.astype(np.uint8)
                pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
                encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
                yield encoding

        if self.textline_left_in_depth:
            try:
                img_out = do_direction_in_depth(img, 'left')
            except:
                img_out = np.copy(img)
            img_out = img_out.astype(np.uint8)
            pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
            encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
            yield encoding

        if self.textline_left_in_depth_bin:
            try:
                img_out = do_direction_in_depth(img_bin_corr, 'left')
            except:
                img_out = np.copy(img_bin_corr)
            img_out = img_out.astype(np.uint8)
            pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
            encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
            yield encoding

        if self.textline_right_in_depth:
            try:
                img_out = do_direction_in_depth(img, 'right')
            except:
                img_out = np.copy(img)
            img_out = img_out.astype(np.uint8)
            pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
            encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
            yield encoding

        if self.textline_right_in_depth_bin:
            try:
                img_out = do_direction_in_depth(img_bin_corr, 'right')
            except:
                img_out = np.copy(img_bin_corr)
            img_out = img_out.astype(np.uint8)
            pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
            encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
            yield encoding

        if self.textline_up_in_depth:
            try:
                img_out = do_direction_in_depth(img, 'up')
            except:
                img_out = np.copy(img)
            img_out = img_out.astype(np.uint8)
            pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
            encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
            yield encoding

        if self.textline_up_in_depth_bin:
            try:
                img_out = do_direction_in_depth(img_bin_corr, 'up')
            except:
                img_out = np.copy(img_bin_corr)
            img_out = img_out.astype(np.uint8)
            pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
            encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
            yield encoding

        if self.textline_down_in_depth:
            try:
                img_out = do_direction_in_depth(img, 'down')
            except:
                img_out = np.copy(img)
            img_out = img_out.astype(np.uint8)
            pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
            encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
            yield encoding

        if self.textline_down_in_depth_bin:
            try:
                img_out = do_direction_in_depth(img_bin_corr, 'down')
            except:
                img_out = np.copy(img_bin_corr)
            img_out = img_out.astype(np.uint8)
            pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
            encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
            yield encoding

        if self.pepper_bin_aug:
            for index, pepper_ind in enumerate(self.pepper_indexes):
                img_out = add_salt_and_pepper_noise(img_bin_corr, pepper_ind, pepper_ind)
                img_out = img_out.astype(np.uint8)
                pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
                encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
                yield encoding

        if self.pepper_aug:
            for index, pepper_ind in enumerate(self.pepper_indexes):
                img_out = add_salt_and_pepper_noise(img, pepper_ind, pepper_ind)
                img_out = img_out.astype(np.uint8)
                pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
                encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
                yield encoding

    else:
        pixel_values = self.processor(Image.fromarray(img), return_tensors="pt").pixel_values
        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        yield encoding
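The pepper_aug / pepper_bin_aug branches above call an add_salt_and_pepper_noise helper that is defined elsewhere in the training code. As a rough, hypothetical sketch of such a helper (the repository's actual implementation may differ), a NumPy version with the same (image, salt_amount, pepper_amount) call shape could look like this:

    import numpy as np

    def add_salt_and_pepper_noise(img, salt_amount, pepper_amount):
        # Sketch only: flip a random fraction of pixels to white (salt)
        # and another fraction to black (pepper); the real helper may differ.
        noisy = img.copy()
        h, w = noisy.shape[:2]
        n_salt = int(salt_amount * h * w)
        n_pepper = int(pepper_amount * h * w)
        ys = np.random.randint(0, h, n_salt)
        xs = np.random.randint(0, w, n_salt)
        noisy[ys, xs] = 255
        ys = np.random.randint(0, h, n_pepper)
        xs = np.random.randint(0, w, n_pepper)
        noisy[ys, xs] = 0
        return noisy

Each noisy copy then goes through the same processor → pixel_values → labels path as the other augmented variants above.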
@@ -21,6 +21,11 @@ from tensorflow.keras.layers import *
 import click
 import logging
+
+from transformers import TrOCRProcessor
+from PIL import Image
+import torch
+from transformers import VisionEncoderDecoderModel
 
 
 class Patches(layers.Layer):
     def __init__(self, patch_size_x, patch_size_y):
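The new imports pull the Hugging Face TrOCR components into this module. For orientation only, here is a minimal, hedged sketch of how a VisionEncoderDecoderModel checkpoint and its TrOCRProcessor are typically applied to a single text-line image; the checkpoint directory and image path are placeholders, not values from this commit:

    import torch
    from PIL import Image
    from transformers import TrOCRProcessor, VisionEncoderDecoderModel

    # Placeholder paths; substitute a real checkpoint directory and line image.
    checkpoint_dir = "path/to/trocr_checkpoint"
    processor = TrOCRProcessor.from_pretrained(checkpoint_dir)
    model = VisionEncoderDecoderModel.from_pretrained(checkpoint_dir)
    model.eval()

    image = Image.open("line.png").convert("RGB")
    pixel_values = processor(image, return_tensors="pt").pixel_values

    with torch.no_grad():
        generated_ids = model.generate(pixel_values)
    text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print(text)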
@@ -92,30 +97,46 @@ def start_new_session():
     tensorflow_backend.set_session(session)
     return session
 
-def run_ensembling(dir_models, out):
+def run_ensembling(dir_models, out, framework):
     ls_models = os.listdir(dir_models)
-    weights=[]
-
-    for model_name in ls_models:
-        model = load_model(os.path.join(dir_models,model_name) , compile=False, custom_objects={'PatchEncoder':PatchEncoder, 'Patches': Patches})
-        weights.append(model.get_weights())
-
-    new_weights = list()
-
-    for weights_list_tuple in zip(*weights):
-        new_weights.append(
-            [np.array(weights_).mean(axis=0)\
-             for weights_ in zip(*weights_list_tuple)])
-
-    new_weights = [np.array(x) for x in new_weights]
-
-    model.set_weights(new_weights)
-    model.save(out)
-    os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config.json ")+out)
+
+    if framework=="torch":
+        models = []
+        sd_models = []
+
+        for model_name in ls_models:
+            model = VisionEncoderDecoderModel.from_pretrained(os.path.join(dir_models,model_name))
+            models.append(model)
+            sd_models.append(model.state_dict())
+
+        for key in sd_models[0]:
+            sd_models[0][key] = sum(sd[key] for sd in sd_models) / len(sd_models)
+
+        model.load_state_dict(sd_models[0])
+        os.system("mkdir "+out)
+        torch.save(model.state_dict(), os.path.join(out, "pytorch_model.bin"))
+        os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config_eynollah.json ")+out)
+
+    else:
+        weights=[]
+
+        for model_name in ls_models:
+            model = load_model(os.path.join(dir_models,model_name) , compile=False, custom_objects={'PatchEncoder':PatchEncoder, 'Patches': Patches})
+            weights.append(model.get_weights())
+
+        new_weights = list()
+
+        for weights_list_tuple in zip(*weights):
+            new_weights.append(
+                [np.array(weights_).mean(axis=0)\
+                 for weights_ in zip(*weights_list_tuple)])
+
+        new_weights = [np.array(x) for x in new_weights]
+
+        model.set_weights(new_weights)
+        model.save(out)
+        os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config_eynollah.json ")+out)
+        os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "characters_org.txt ")+out)
 
 
 @click.command()
 @click.option(

@@ -130,7 +151,12 @@ def run_ensembling(dir_models, out):
     help="output directory where ensembled model will be written.",
     type=click.Path(exists=False, file_okay=False),
 )
+@click.option(
+    "--framework",
+    "-fw",
+    help="this parameter gets tensorflow or torch as model framework",
+)
-def main(dir_models, out):
-    run_ensembling(dir_models, out)
+def main(dir_models, out, framework):
+    run_ensembling(dir_models, out, framework)
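The torch branch of run_ensembling averages the checkpoints key by key over their state dicts. A standalone sketch of the same idea, written as a helper that averages floating-point tensors and keeps the first checkpoint's value for non-float entries (the function name and the skip rule for integer buffers are editorial assumptions, not code from this commit):

    import torch

    def average_state_dicts(state_dicts):
        # Sketch: uniform average of matching keys across checkpoints.
        averaged = {}
        for key in state_dicts[0]:
            tensors = [sd[key] for sd in state_dicts]
            if torch.is_floating_point(tensors[0]):
                averaged[key] = torch.stack(tensors).mean(dim=0)
            else:
                # Integer buffers (e.g. position ids) are copied from the first model.
                averaged[key] = tensors[0].clone()
        return averaged

    # Usage with already-loaded VisionEncoderDecoderModel instances:
    # models = [VisionEncoderDecoderModel.from_pretrained(p) for p in checkpoint_dirs]
    # models[0].load_state_dict(average_state_dicts([m.state_dict() for m in models]))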
@@ -9,8 +9,8 @@ else:
     import importlib.resources as importlib_resources
 
 
-def get_font():
+def get_font(font_size):
     #font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
     font = importlib_resources.files(__package__) / "../Charis-Regular.ttf"
     with importlib_resources.as_file(font) as font:
-        return ImageFont.truetype(font=font, size=40)
+        return ImageFont.truetype(font=font, size=font_size)
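Since get_font now takes the size as a parameter instead of hard-coding 40, callers can scale the overlay text to the line they are rendering. A hypothetical sketch of such a sizing rule with plain PIL (the 0.8 factor and the font file name are assumptions, not from this commit):

    from PIL import Image, ImageDraw, ImageFont

    # Hypothetical sizing rule: scale the font to roughly 80% of the line height.
    line_height = 50                               # text-line bounding-box height in pixels
    font_size = max(8, int(0.8 * line_height))
    font = ImageFont.truetype("Charis-Regular.ttf", size=font_size)  # any available .ttf works

    canvas = Image.new("RGB", (512, 64), "white")
    draw = ImageDraw.Draw(canvas)
    draw.text((5, 5), "recognized text", fill="black", font=font)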
@@ -1,17 +1,17 @@
 {
     "backbone_type" : "transformer",
-    "task": "cnn-rnn-ocr",
+    "task": "transformer-ocr",
     "n_classes" : 2,
-    "max_len": 280,
+    "max_len": 192,
-    "n_epochs" : 3,
+    "n_epochs" : 1,
     "input_height" : 32,
     "input_width" : 512,
     "weight_decay" : 1e-6,
-    "n_batch" : 4,
+    "n_batch" : 1,
     "learning_rate": 1e-5,
     "save_interval": 1500,
     "patches" : false,
-    "pretraining" : true,
+    "pretraining" : false,
     "augmentation" : true,
     "flip_aug" : false,
     "blur_aug" : true,

@@ -77,7 +77,6 @@
     "dir_output": "/home/vahid/extracted_lines/1919_bin/output",
     "dir_rgb_backgrounds": "/home/vahid/Documents/1_2_test_eynollah/set_rgb_background",
     "dir_rgb_foregrounds": "/home/vahid/Documents/1_2_test_eynollah/out_set_rgb_foreground",
-    "dir_img_bin": "/home/vahid/extracted_lines/1919_bin/images_bin",
-    "characters_txt_file":"/home/vahid/Downloads/models_eynollah/model_eynollah_ocr_cnnrnn_20250930/characters_org.txt"
+    "dir_img_bin": "/home/vahid/extracted_lines/1919_bin/images_bin"
 }
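The updated config switches the task to transformer-ocr and shrinks the run to one epoch with batch size 1. A minimal sketch of reading such a JSON config and pulling out the fields a trainer would need (the file name is a placeholder):

    import json

    # Placeholder path; point this at the actual training config file.
    with open("config_params.json") as f:
        config = json.load(f)

    task = config["task"]                  # "transformer-ocr"
    input_height = config["input_height"]  # 32
    input_width = config["input_width"]    # 512
    n_epochs = config["n_epochs"]          # 1
    n_batch = config["n_batch"]            # 1
    print(task, input_height, input_width, n_epochs, n_batch)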
@@ -4,3 +4,8 @@ numpy <1.24.0
 tqdm
 imutils
 scipy
+torch
+evaluate
+accelerate
+jiwer
+transformers <= 4.30.2