mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-02-20 16:32:03 +01:00
Merge c4434c7f7d into 586077fbcd
This commit is contained in:
commit
38745111df
13 changed files with 875 additions and 113 deletions
|
|
@ -70,6 +70,7 @@ class Eynollah_ocr:
|
|||
self.model_zoo.get('ocr').to(self.device)
|
||||
else:
|
||||
self.model_zoo.load_model('ocr', '')
|
||||
self.input_shape = self.model_zoo.get('ocr').input_shape[1:3]
|
||||
self.model_zoo.load_model('num_to_char')
|
||||
self.model_zoo.load_model('characters')
|
||||
self.end_character = len(self.model_zoo.get('characters', list)) + 2
|
||||
|
|
@ -657,7 +658,7 @@ class Eynollah_ocr:
|
|||
if out_image_with_text:
|
||||
image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white")
|
||||
draw = ImageDraw.Draw(image_text)
|
||||
font = get_font()
|
||||
font = get_font(font_size=40)
|
||||
|
||||
for indexer_text, bb_ind in enumerate(total_bb_coordinates):
|
||||
x_bb = bb_ind[0]
|
||||
|
|
@ -823,8 +824,8 @@ class Eynollah_ocr:
|
|||
page_ns=page_ns,
|
||||
|
||||
img_bin=img_bin,
|
||||
image_width=512,
|
||||
image_height=32,
|
||||
image_width=self.input_shape[1],
|
||||
image_height=self.input_shape[0],
|
||||
)
|
||||
|
||||
self.write_ocr(
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from .inference import main as inference_cli
|
|||
from .train import ex
|
||||
from .extract_line_gt import linegt_cli
|
||||
from .weights_ensembling import main as ensemble_cli
|
||||
from .generate_or_update_cnn_rnn_ocr_character_list import main as update_ocr_characters_cli
|
||||
|
||||
@click.command(context_settings=dict(
|
||||
ignore_unknown_options=True,
|
||||
|
|
@ -28,3 +29,4 @@ main.add_command(inference_cli, 'inference')
|
|||
main.add_command(train_cli, 'train')
|
||||
main.add_command(linegt_cli, 'export_textline_images_and_text')
|
||||
main.add_command(ensemble_cli, 'ensembling')
|
||||
main.add_command(update_ocr_characters_cli, 'generate_or_update_cnn_rnn_ocr_character_list')
|
||||
|
|
|
|||
|
|
@ -50,6 +50,18 @@ from ..utils import is_image_filename
|
|||
is_flag=True,
|
||||
help="if this parameter set to true, cropped textline images will not be masked with textline contour.",
|
||||
)
|
||||
@click.option(
|
||||
"--exclude_vertical_lines",
|
||||
"-exv",
|
||||
is_flag=True,
|
||||
help="if this parameter set to true, vertical textline images will be excluded.",
|
||||
)
|
||||
@click.option(
|
||||
"--page_alto",
|
||||
"-alto",
|
||||
is_flag=True,
|
||||
help="If this parameter is set to True, text line image cropping and text extraction are performed using PAGE/ALTO files. Otherwise, the default method for PAGE XML files is used.",
|
||||
)
|
||||
def linegt_cli(
|
||||
image,
|
||||
dir_in,
|
||||
|
|
@ -57,6 +69,8 @@ def linegt_cli(
|
|||
dir_out,
|
||||
pref_of_dataset,
|
||||
do_not_mask_with_textline_contour,
|
||||
exclude_vertical_lines,
|
||||
page_alto,
|
||||
):
|
||||
assert bool(dir_in) ^ bool(image), "Set --dir-in or --image-filename, not both"
|
||||
if dir_in:
|
||||
|
|
@ -70,14 +84,91 @@ def linegt_cli(
|
|||
for dir_img in ls_imgs:
|
||||
file_name = Path(dir_img).stem
|
||||
dir_xml = os.path.join(dir_xmls, file_name + '.xml')
|
||||
|
||||
img = cv2.imread(dir_img)
|
||||
|
||||
if page_alto:
|
||||
h, w = img.shape[:2]
|
||||
|
||||
tree = ET.parse(dir_xml)
|
||||
root = tree.getroot()
|
||||
|
||||
NS = {"alto": "http://www.loc.gov/standards/alto/ns-v4#"}
|
||||
|
||||
results = []
|
||||
|
||||
indexer_textlines = 0
|
||||
for line in root.findall(".//alto:TextLine", NS):
|
||||
string_el = line.find("alto:String", NS)
|
||||
textline_text = string_el.attrib["CONTENT"] if string_el is not None else None
|
||||
|
||||
polygon_el = line.find("alto:Shape/alto:Polygon", NS)
|
||||
if polygon_el is None:
|
||||
continue
|
||||
|
||||
points = polygon_el.attrib["POINTS"].split()
|
||||
coords = [
|
||||
(int(points[i]), int(points[i + 1]))
|
||||
for i in range(0, len(points), 2)
|
||||
]
|
||||
|
||||
coords = np.array(coords, dtype=np.int32)
|
||||
x, y, w, h = cv2.boundingRect(coords)
|
||||
|
||||
|
||||
if exclude_vertical_lines and h > 1.4 * w:
|
||||
img_crop = None
|
||||
continue
|
||||
|
||||
img_poly_on_img = np.copy(img)
|
||||
|
||||
mask_poly = np.zeros(img.shape)
|
||||
mask_poly = cv2.fillPoly(mask_poly, pts=[coords], color=(1, 1, 1))
|
||||
|
||||
mask_poly = mask_poly[y : y + h, x : x + w, :]
|
||||
img_crop = img_poly_on_img[y : y + h, x : x + w, :]
|
||||
|
||||
if not do_not_mask_with_textline_contour:
|
||||
img_crop[mask_poly == 0] = 255
|
||||
|
||||
if img_crop.shape[0] == 0 or img_crop.shape[1] == 0:
|
||||
img_crop = None
|
||||
continue
|
||||
|
||||
if textline_text and img_crop is not None:
|
||||
base_name = os.path.join(
|
||||
dir_out, file_name + '_line_' + str(indexer_textlines)
|
||||
)
|
||||
if pref_of_dataset:
|
||||
base_name += '_' + pref_of_dataset
|
||||
if not do_not_mask_with_textline_contour:
|
||||
base_name += '_masked'
|
||||
|
||||
with open(base_name + '.txt', 'w') as text_file:
|
||||
text_file.write(textline_text)
|
||||
cv2.imwrite(base_name + '.png', img_crop)
|
||||
indexer_textlines += 1
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
else:
|
||||
total_bb_coordinates = []
|
||||
|
||||
tree1 = ET.parse(dir_xml, parser=ET.XMLParser(encoding="utf-8"))
|
||||
root1 = tree1.getroot()
|
||||
alltags = [elem.tag for elem in root1.iter()]
|
||||
tree = ET.parse(dir_xml, parser=ET.XMLParser(encoding="utf-8"))
|
||||
root = tree.getroot()
|
||||
alltags = [elem.tag for elem in root.iter()]
|
||||
|
||||
name_space = alltags[0].split('}')[0]
|
||||
name_space = name_space.split('{')[1]
|
||||
|
|
@ -89,7 +180,7 @@ def linegt_cli(
|
|||
indexer_text_region = 0
|
||||
indexer_textlines = 0
|
||||
# FIXME: non recursive, use OCR-D PAGE generateDS API. Or use an existing tool for this purpose altogether
|
||||
for nn in root1.iter(region_tags):
|
||||
for nn in root.iter(region_tags):
|
||||
for child_textregion in nn:
|
||||
if child_textregion.tag.endswith("TextLine"):
|
||||
for child_textlines in child_textregion:
|
||||
|
|
@ -100,6 +191,10 @@ def linegt_cli(
|
|||
|
||||
x, y, w, h = cv2.boundingRect(textline_coords)
|
||||
|
||||
if exclude_vertical_lines and h > 1.4 * w:
|
||||
img_crop = None
|
||||
continue
|
||||
|
||||
total_bb_coordinates.append([x, y, w, h])
|
||||
|
||||
img_poly_on_img = np.copy(img)
|
||||
|
|
@ -114,12 +209,15 @@ def linegt_cli(
|
|||
img_crop[mask_poly == 0] = 255
|
||||
|
||||
if img_crop.shape[0] == 0 or img_crop.shape[1] == 0:
|
||||
img_crop = None
|
||||
continue
|
||||
|
||||
|
||||
if child_textlines.tag.endswith("TextEquiv"):
|
||||
for cheild_text in child_textlines:
|
||||
if cheild_text.tag.endswith("Unicode"):
|
||||
textline_text = cheild_text.text
|
||||
if textline_text:
|
||||
if textline_text and img_crop is not None:
|
||||
base_name = os.path.join(
|
||||
dir_out, file_name + '_line_' + str(indexer_textlines)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from pathlib import Path
|
|||
from PIL import Image, ImageDraw, ImageFont
|
||||
import cv2
|
||||
import numpy as np
|
||||
from eynollah.utils.font import get_font
|
||||
|
||||
from eynollah.training.gt_gen_utils import (
|
||||
filter_contours_area_of_image,
|
||||
|
|
@ -477,7 +478,7 @@ def visualize_layout_segmentation(xml_file, dir_xml, dir_out, dir_imgs):
|
|||
co_text, co_graphic, co_sep, co_img, co_table, co_map, co_noise, y_len, x_len = get_layout_contours_for_visualization(xml_file)
|
||||
|
||||
|
||||
added_image = visualize_image_from_contours_layout(co_text['paragraph'], co_text['header']+co_text['heading'], co_text['drop-capital'], co_sep, co_img, co_text['marginalia'], co_table, img)
|
||||
added_image = visualize_image_from_contours_layout(co_text['paragraph'], co_text['header']+co_text['heading'], co_text['drop-capital'], co_sep, co_img, co_text['marginalia'], co_table, co_map, img)
|
||||
|
||||
cv2.imwrite(os.path.join(dir_out, f_name+'.png'), added_image)
|
||||
|
||||
|
|
@ -514,8 +515,8 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out):
|
|||
else:
|
||||
xml_files_ind = [xml_file]
|
||||
|
||||
font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
|
||||
font = ImageFont.truetype(font_path, 40)
|
||||
###font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
|
||||
font = get_font(font_size=40)#ImageFont.truetype(font_path, 40)
|
||||
|
||||
for ind_xml in tqdm(xml_files_ind):
|
||||
indexer = 0
|
||||
|
|
@ -552,11 +553,11 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out):
|
|||
|
||||
|
||||
is_vertical = h > 2*w # Check orientation
|
||||
font = fit_text_single_line(draw, ocr_texts[index], font_path, w, int(h*0.4) )
|
||||
font = fit_text_single_line(draw, ocr_texts[index], w, int(h*0.4) )
|
||||
|
||||
if is_vertical:
|
||||
|
||||
vertical_font = fit_text_single_line(draw, ocr_texts[index], font_path, h, int(w * 0.8))
|
||||
vertical_font = fit_text_single_line(draw, ocr_texts[index], h, int(w * 0.8))
|
||||
|
||||
text_img = Image.new("RGBA", (h, w), (255, 255, 255, 0)) # Note: dimensions are swapped
|
||||
text_draw = ImageDraw.Draw(text_img)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,59 @@
|
|||
import os
|
||||
import numpy as np
|
||||
import json
|
||||
import click
|
||||
import logging
|
||||
|
||||
|
||||
|
||||
def run_character_list_update(dir_labels, out, current_character_list):
|
||||
ls_labels = os.listdir(dir_labels)
|
||||
ls_labels = [ind for ind in ls_labels if ind.endswith('.txt')]
|
||||
|
||||
if current_character_list:
|
||||
with open(current_character_list, 'r') as f_name:
|
||||
characters = json.load(f_name)
|
||||
|
||||
characters = set(characters)
|
||||
else:
|
||||
characters = set()
|
||||
|
||||
|
||||
for ind in ls_labels:
|
||||
label = open(os.path.join(dir_labels,ind),'r').read().split('\n')[0]
|
||||
|
||||
for char in label:
|
||||
characters.add(char)
|
||||
|
||||
|
||||
characters = sorted(list(set(characters)))
|
||||
|
||||
with open(out, 'w') as f_name:
|
||||
json.dump(characters, f_name)
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option(
|
||||
"--dir_labels",
|
||||
"-dl",
|
||||
help="directory of labels which are .txt files",
|
||||
type=click.Path(exists=True, file_okay=False),
|
||||
required=True,
|
||||
)
|
||||
@click.option(
|
||||
"--current_character_list",
|
||||
"-ccl",
|
||||
help="existing character list in a .txt file that needs to be updated with a set of labels",
|
||||
type=click.Path(exists=True, file_okay=True),
|
||||
required=False,
|
||||
)
|
||||
@click.option(
|
||||
"--out",
|
||||
"-o",
|
||||
help="An output .txt file where the generated or updated character list will be written",
|
||||
type=click.Path(exists=False, file_okay=True),
|
||||
)
|
||||
|
||||
def main(dir_labels, out, current_character_list):
|
||||
run_character_list_update(dir_labels, out, current_character_list)
|
||||
|
||||
|
|
@ -7,7 +7,7 @@ import cv2
|
|||
from shapely import geometry
|
||||
from pathlib import Path
|
||||
from PIL import ImageFont
|
||||
|
||||
from eynollah.utils.font import get_font
|
||||
|
||||
KERNEL = np.ones((5, 5), np.uint8)
|
||||
|
||||
|
|
@ -350,11 +350,11 @@ def get_textline_contours_and_ocr_text(xml_file):
|
|||
ocr_textlines.append(ocr_text_in[0])
|
||||
return co_use_case, y_len, x_len, ocr_textlines
|
||||
|
||||
def fit_text_single_line(draw, text, font_path, max_width, max_height):
|
||||
def fit_text_single_line(draw, text, max_width, max_height):
|
||||
initial_font_size = 50
|
||||
font_size = initial_font_size
|
||||
while font_size > 10: # Minimum font size
|
||||
font = ImageFont.truetype(font_path, font_size)
|
||||
font = get_font(font_size=font_size)# ImageFont.truetype(font_path, font_size)
|
||||
text_bbox = draw.textbbox((0, 0), text, font=font) # Get text bounding box
|
||||
text_width = text_bbox[2] - text_bbox[0]
|
||||
text_height = text_bbox[3] - text_bbox[1]
|
||||
|
|
@ -364,7 +364,7 @@ def fit_text_single_line(draw, text, font_path, max_width, max_height):
|
|||
|
||||
font_size -= 2 # Reduce font size and retry
|
||||
|
||||
return ImageFont.truetype(font_path, 10) # Smallest font fallback
|
||||
return get_font(font_size=10)#ImageFont.truetype(font_path, 10) # Smallest font fallback
|
||||
|
||||
def get_layout_contours_for_visualization(xml_file):
|
||||
tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8'))
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from keras.models import Model, load_model
|
|||
from keras import backend as K
|
||||
import click
|
||||
from tensorflow.python.keras import backend as tensorflow_backend
|
||||
from tensorflow.keras.layers import StringLookup
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from .gt_gen_utils import (
|
||||
|
|
@ -169,6 +170,25 @@ class sbb_predict:
|
|||
self.model = tf.keras.models.Model(
|
||||
self.model.get_layer(name = "image").input,
|
||||
self.model.get_layer(name = "dense2").output)
|
||||
|
||||
assert isinstance(self.model, Model)
|
||||
|
||||
elif self.task == "transformer-ocr":
|
||||
import torch
|
||||
from transformers import VisionEncoderDecoderModel
|
||||
from transformers import TrOCRProcessor
|
||||
|
||||
self.model = VisionEncoderDecoderModel.from_pretrained(self.model_dir)
|
||||
self.processor = TrOCRProcessor.from_pretrained(self.model_dir)
|
||||
|
||||
if self.cpu:
|
||||
self.device = torch.device('cpu')
|
||||
else:
|
||||
self.device = torch.device('cuda:0')
|
||||
|
||||
self.model.to(self.device)
|
||||
|
||||
assert isinstance(self.model, torch.nn.Module)
|
||||
else:
|
||||
config = tf.compat.v1.ConfigProto()
|
||||
config.gpu_options.allow_growth = True
|
||||
|
|
@ -176,16 +196,16 @@ class sbb_predict:
|
|||
session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
|
||||
tensorflow_backend.set_session(session)
|
||||
|
||||
self.model = load_model(self.model_dir , compile=False,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
|
||||
|
||||
##if self.weights_dir!=None:
|
||||
##self.model.load_weights(self.weights_dir)
|
||||
|
||||
assert isinstance(self.model, Model)
|
||||
if self.task != 'classification' and self.task != 'reading_order':
|
||||
self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1]
|
||||
self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2]
|
||||
self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3]
|
||||
|
||||
|
||||
assert isinstance(self.model, Model)
|
||||
|
||||
def visualize_model_output(self, prediction, img, task) -> Tuple[NDArray, NDArray]:
|
||||
if task == "binarization":
|
||||
prediction = prediction * -1
|
||||
|
|
@ -235,10 +255,9 @@ class sbb_predict:
|
|||
return added_image, layout_only
|
||||
|
||||
def predict(self, image_dir):
|
||||
assert isinstance(self.model, Model)
|
||||
if self.task == 'classification':
|
||||
classes_names = self.config_params_model['classification_classes_name']
|
||||
img_1ch = img=cv2.imread(image_dir, 0)
|
||||
img_1ch =cv2.imread(image_dir, 0)
|
||||
|
||||
img_1ch = img_1ch / 255.0
|
||||
img_1ch = cv2.resize(img_1ch, (self.config_params_model['input_height'], self.config_params_model['input_width']), interpolation=cv2.INTER_NEAREST)
|
||||
|
|
@ -274,6 +293,15 @@ class sbb_predict:
|
|||
pred_texts = pred_texts[0].replace("[UNK]", "")
|
||||
return pred_texts
|
||||
|
||||
elif self.task == "transformer-ocr":
|
||||
from PIL import Image
|
||||
image = Image.open(image_dir).convert("RGB")
|
||||
pixel_values = self.processor(image, return_tensors="pt").pixel_values
|
||||
generated_ids = self.model.generate(pixel_values.to(self.device))
|
||||
return self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
|
||||
|
||||
|
||||
elif self.task == 'reading_order':
|
||||
img_height = self.config_params_model['input_height']
|
||||
|
|
@ -607,6 +635,8 @@ class sbb_predict:
|
|||
cv2.imwrite(self.save,res)
|
||||
elif self.task == "cnn-rnn-ocr":
|
||||
print(f"Detected text: {res}")
|
||||
elif self.task == "transformer-ocr":
|
||||
print(f"Detected text: {res}")
|
||||
else:
|
||||
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
|
||||
if self.save:
|
||||
|
|
@ -710,10 +740,12 @@ class sbb_predict:
|
|||
)
|
||||
def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area):
|
||||
assert image or dir_in, "Either a single image -i or a dir_in -di is required"
|
||||
with open(os.path.join(model,'config.json')) as f:
|
||||
|
||||
with open(os.path.join(model,'config_eynollah.json')) as f:
|
||||
config_params_model = json.load(f)
|
||||
|
||||
task = config_params_model['task']
|
||||
if task != 'classification' and task != 'reading_order' and task != "cnn-rnn-ocr":
|
||||
if task != 'classification' and task != 'reading_order' and task != "cnn-rnn-ocr" and task != "transformer-ocr":
|
||||
if image and not save:
|
||||
print("Error: You used one of segmentation or binarization task with image input but not set -s, you need a filename to save visualized output with -s")
|
||||
sys.exit(1)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import os
|
||||
import sys
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
import click
|
||||
|
||||
from eynollah.training.metrics import (
|
||||
|
|
@ -27,7 +28,8 @@ from eynollah.training.utils import (
|
|||
generate_data_from_folder_training,
|
||||
get_one_hot,
|
||||
provide_patches,
|
||||
return_number_of_total_training_data
|
||||
return_number_of_total_training_data,
|
||||
OCRDatasetYieldAugmentations
|
||||
)
|
||||
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||
|
|
@ -41,11 +43,16 @@ from sklearn.metrics import f1_score
|
|||
from tensorflow.keras.callbacks import Callback
|
||||
from tensorflow.keras.layers import StringLookup
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
import torch
|
||||
from transformers import TrOCRProcessor
|
||||
import evaluate
|
||||
from transformers import default_data_collator
|
||||
from transformers import VisionEncoderDecoderModel
|
||||
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
|
||||
|
||||
class SaveWeightsAfterSteps(Callback):
|
||||
def __init__(self, save_interval, save_path, _config):
|
||||
def __init__(self, save_interval, save_path, _config, characters_cnnrnn_ocr=None):
|
||||
super(SaveWeightsAfterSteps, self).__init__()
|
||||
self.save_interval = save_interval
|
||||
self.save_path = save_path
|
||||
|
|
@ -61,7 +68,10 @@ class SaveWeightsAfterSteps(Callback):
|
|||
|
||||
self.model.save(save_file)
|
||||
|
||||
with open(os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"config.json"), "w") as fp:
|
||||
if characters_cnnrnn_ocr:
|
||||
os.system("cp "+characters_cnnrnn_ocr+" "+os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"characters_org.txt"))
|
||||
|
||||
with open(os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"config_eynollah.json"), "w") as fp:
|
||||
json.dump(self._config, fp) # encode dict into JSON
|
||||
print(f"saved model as steps {self.step_count} to {save_file}")
|
||||
|
||||
|
|
@ -477,7 +487,7 @@ def run(
|
|||
|
||||
model.save(os.path.join(dir_output,'model_'+str(i)))
|
||||
|
||||
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
|
||||
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config_eynollah.json"), "w") as fp:
|
||||
json.dump(_config, fp) # encode dict into JSON
|
||||
|
||||
#os.system('rm -rf '+dir_train_flowing)
|
||||
|
|
@ -537,7 +547,7 @@ def run(
|
|||
opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)#1e-4)#(lr_schedule)
|
||||
model.compile(optimizer=opt)
|
||||
|
||||
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config) if save_interval else None
|
||||
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config, characters_cnnrnn_ocr=characters_txt_file) if save_interval else None
|
||||
|
||||
for i in tqdm(range(index_start, n_epochs + index_start)):
|
||||
if save_interval:
|
||||
|
|
@ -556,9 +566,125 @@ def run(
|
|||
|
||||
if i >=0:
|
||||
model.save( os.path.join(dir_output,'model_'+str(i) ))
|
||||
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
|
||||
os.system("cp "+characters_txt_file+" "+os.path.join(os.path.join(dir_output,'model_'+str(i)),"characters_org.txt"))
|
||||
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config_eynollah.json"), "w") as fp:
|
||||
json.dump(_config, fp) # encode dict into JSON
|
||||
|
||||
|
||||
elif task=="transformer-ocr":
|
||||
dir_img, dir_lab = get_dirs_or_files(dir_train)
|
||||
|
||||
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
|
||||
|
||||
ls_files_images = os.listdir(dir_img)
|
||||
|
||||
aug_multip = return_multiplier_based_on_augmnentations(augmentation, color_padding_rotation, rotation_not_90, blur_aug, degrading, bin_deg,
|
||||
brightening, padding_white, adding_rgb_foreground, adding_rgb_background, binarization,
|
||||
image_inversion, channels_shuffling, add_red_textlines, white_noise_strap, textline_skewing, textline_skewing_bin, textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth, textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin, textline_down_in_depth, textline_down_in_depth_bin, pepper_bin_aug, pepper_aug, degrade_scales, number_of_backgrounds_per_image, thetha, thetha_padd, brightness, padd_colors, shuffle_indexes, pepper_indexes, skewing_amplitudes, blur_k, white_padds)
|
||||
|
||||
len_dataset = aug_multip*len(ls_files_images)
|
||||
|
||||
dataset = OCRDatasetYieldAugmentations(
|
||||
dir_img=dir_img,
|
||||
dir_img_bin=dir_img_bin,
|
||||
dir_lab=dir_lab,
|
||||
processor=processor,
|
||||
max_target_length=max_len,
|
||||
augmentation = augmentation,
|
||||
binarization = binarization,
|
||||
add_red_textlines = add_red_textlines,
|
||||
white_noise_strap = white_noise_strap,
|
||||
adding_rgb_foreground = adding_rgb_foreground,
|
||||
adding_rgb_background = adding_rgb_background,
|
||||
bin_deg = bin_deg,
|
||||
blur_aug = blur_aug,
|
||||
brightening = brightening,
|
||||
padding_white = padding_white,
|
||||
color_padding_rotation = color_padding_rotation,
|
||||
rotation_not_90 = rotation_not_90,
|
||||
degrading = degrading,
|
||||
channels_shuffling = channels_shuffling,
|
||||
textline_skewing = textline_skewing,
|
||||
textline_skewing_bin = textline_skewing_bin,
|
||||
textline_right_in_depth = textline_right_in_depth,
|
||||
textline_left_in_depth = textline_left_in_depth,
|
||||
textline_up_in_depth = textline_up_in_depth,
|
||||
textline_down_in_depth = textline_down_in_depth,
|
||||
textline_right_in_depth_bin = textline_right_in_depth_bin,
|
||||
textline_left_in_depth_bin = textline_left_in_depth_bin,
|
||||
textline_up_in_depth_bin = textline_up_in_depth_bin,
|
||||
textline_down_in_depth_bin = textline_down_in_depth_bin,
|
||||
pepper_aug = pepper_aug,
|
||||
pepper_bin_aug = pepper_bin_aug,
|
||||
list_all_possible_background_images=list_all_possible_background_images,
|
||||
list_all_possible_foreground_rgbs=list_all_possible_foreground_rgbs,
|
||||
blur_k = blur_k,
|
||||
degrade_scales = degrade_scales,
|
||||
white_padds = white_padds,
|
||||
thetha_padd = thetha_padd,
|
||||
thetha = thetha,
|
||||
brightness = brightness,
|
||||
padd_colors = padd_colors,
|
||||
number_of_backgrounds_per_image = number_of_backgrounds_per_image,
|
||||
shuffle_indexes = shuffle_indexes,
|
||||
pepper_indexes = pepper_indexes,
|
||||
skewing_amplitudes = skewing_amplitudes,
|
||||
dir_rgb_backgrounds = dir_rgb_backgrounds,
|
||||
dir_rgb_foregrounds = dir_rgb_foregrounds,
|
||||
len_data=len_dataset,
|
||||
)
|
||||
|
||||
# Create a DataLoader
|
||||
data_loader = torch.utils.data.DataLoader(dataset, batch_size=1)
|
||||
train_dataset = data_loader.dataset
|
||||
|
||||
|
||||
if continue_training:
|
||||
model = VisionEncoderDecoderModel.from_pretrained(dir_of_start_model)
|
||||
else:
|
||||
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
|
||||
|
||||
|
||||
# set special tokens used for creating the decoder_input_ids from the labels
|
||||
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
|
||||
model.config.pad_token_id = processor.tokenizer.pad_token_id
|
||||
# make sure vocab size is set correctly
|
||||
model.config.vocab_size = model.config.decoder.vocab_size
|
||||
|
||||
# set beam search parameters
|
||||
model.config.eos_token_id = processor.tokenizer.sep_token_id
|
||||
model.config.max_length = max_len
|
||||
model.config.early_stopping = True
|
||||
model.config.no_repeat_ngram_size = 3
|
||||
model.config.length_penalty = 2.0
|
||||
model.config.num_beams = 4
|
||||
|
||||
|
||||
training_args = Seq2SeqTrainingArguments(
|
||||
predict_with_generate=True,
|
||||
num_train_epochs=n_epochs,
|
||||
learning_rate=learning_rate,
|
||||
per_device_train_batch_size=n_batch,
|
||||
fp16=True,
|
||||
output_dir=dir_output,
|
||||
logging_steps=2,
|
||||
save_steps=save_interval,
|
||||
)
|
||||
|
||||
|
||||
cer_metric = evaluate.load("cer")
|
||||
|
||||
# instantiate trainer
|
||||
trainer = Seq2SeqTrainer(
|
||||
model=model,
|
||||
tokenizer=processor.feature_extractor,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
data_collator=default_data_collator,
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
|
||||
elif task=='classification':
|
||||
configuration()
|
||||
model = resnet50_classifier(n_classes, input_height, input_width, weight_decay, pretraining)
|
||||
|
|
@ -609,10 +735,10 @@ def run(
|
|||
model_weight_averaged.set_weights(new_weights)
|
||||
|
||||
model_weight_averaged.save(os.path.join(dir_output,'model_ens_avg'))
|
||||
with open(os.path.join( os.path.join(dir_output,'model_ens_avg'), "config.json"), "w") as fp:
|
||||
with open(os.path.join( os.path.join(dir_output,'model_ens_avg'), "config_eynollah.json"), "w") as fp:
|
||||
json.dump(_config, fp) # encode dict into JSON
|
||||
|
||||
with open(os.path.join( os.path.join(dir_output,'model_best'), "config.json"), "w") as fp:
|
||||
with open(os.path.join( os.path.join(dir_output,'model_best'), "config_eynollah.json"), "w") as fp:
|
||||
json.dump(_config, fp) # encode dict into JSON
|
||||
|
||||
elif task=='reading_order':
|
||||
|
|
@ -645,7 +771,7 @@ def run(
|
|||
history = model.fit(generate_arrays_from_folder_reading_order(dir_flow_train_labels, dir_flow_train_imgs, n_batch, input_height, input_width, n_classes, thetha, augmentation), steps_per_epoch=num_rows / n_batch, verbose=1)
|
||||
model.save( os.path.join(dir_output,'model_'+str(i+indexer_start) ))
|
||||
|
||||
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
|
||||
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config_eynollah.json"), "w") as fp:
|
||||
json.dump(_config, fp) # encode dict into JSON
|
||||
'''
|
||||
if f1score>f1score_tot[0]:
|
||||
|
|
|
|||
|
|
@ -13,8 +13,12 @@ import tensorflow as tf
|
|||
from tensorflow.keras.utils import to_categorical
|
||||
from PIL import Image, ImageFile, ImageEnhance
|
||||
|
||||
import torch
|
||||
from torch.utils.data import IterableDataset
|
||||
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
|
||||
|
||||
def vectorize_label(label, char_to_num, padding_token, max_len):
|
||||
label = char_to_num(tf.strings.unicode_split(label, input_encoding="UTF-8"))
|
||||
length = tf.shape(label)[0]
|
||||
|
|
@ -76,6 +80,7 @@ def add_salt_and_pepper_noise(img, salt_prob, pepper_prob):
|
|||
|
||||
return noisy_image
|
||||
|
||||
|
||||
def invert_image(img):
|
||||
img_inv = 255 - img
|
||||
return img_inv
|
||||
|
|
@ -1668,3 +1673,411 @@ def return_multiplier_based_on_augmnentations(
|
|||
aug_multip += len(pepper_indexes)
|
||||
|
||||
return aug_multip
|
||||
|
||||
|
||||
class OCRDatasetYieldAugmentations(IterableDataset):
|
||||
def __init__(
|
||||
self,
|
||||
dir_img,
|
||||
dir_img_bin,
|
||||
dir_lab,
|
||||
processor,
|
||||
max_target_length=128,
|
||||
augmentation = None,
|
||||
binarization = None,
|
||||
add_red_textlines = None,
|
||||
white_noise_strap = None,
|
||||
adding_rgb_foreground = None,
|
||||
adding_rgb_background = None,
|
||||
bin_deg = None,
|
||||
blur_aug = None,
|
||||
brightening = None,
|
||||
padding_white = None,
|
||||
color_padding_rotation = None,
|
||||
rotation_not_90 = None,
|
||||
degrading = None,
|
||||
channels_shuffling = None,
|
||||
textline_skewing = None,
|
||||
textline_skewing_bin = None,
|
||||
textline_right_in_depth = None,
|
||||
textline_left_in_depth = None,
|
||||
textline_up_in_depth = None,
|
||||
textline_down_in_depth = None,
|
||||
textline_right_in_depth_bin = None,
|
||||
textline_left_in_depth_bin = None,
|
||||
textline_up_in_depth_bin = None,
|
||||
textline_down_in_depth_bin = None,
|
||||
pepper_aug = None,
|
||||
pepper_bin_aug = None,
|
||||
list_all_possible_background_images=None,
|
||||
list_all_possible_foreground_rgbs=None,
|
||||
blur_k = None,
|
||||
degrade_scales = None,
|
||||
white_padds = None,
|
||||
thetha_padd = None,
|
||||
thetha = None,
|
||||
brightness = None,
|
||||
padd_colors = None,
|
||||
number_of_backgrounds_per_image = None,
|
||||
shuffle_indexes = None,
|
||||
pepper_indexes = None,
|
||||
skewing_amplitudes = None,
|
||||
dir_rgb_backgrounds = None,
|
||||
dir_rgb_foregrounds = None,
|
||||
len_data=None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
images_dir (str): Path to the directory containing images.
|
||||
labels_dir (str): Path to the directory containing label text files.
|
||||
tokenizer: Tokenizer for processing labels.
|
||||
transform: Transformations applied after augmentation (e.g., ToTensor, normalization).
|
||||
image_size (tuple): Size to resize images to.
|
||||
max_seq_len (int): Maximum sequence length for tokenized labels.
|
||||
scales (list or None): List of scale factors to apply.
|
||||
"""
|
||||
self.dir_img = dir_img
|
||||
self.dir_img_bin = dir_img_bin
|
||||
self.dir_lab = dir_lab
|
||||
self.processor = processor
|
||||
self.max_target_length = max_target_length
|
||||
#self.scales = scales if scales else []
|
||||
|
||||
self.augmentation = augmentation
|
||||
self.binarization = binarization
|
||||
self.add_red_textlines = add_red_textlines
|
||||
self.white_noise_strap = white_noise_strap
|
||||
self.adding_rgb_foreground = adding_rgb_foreground
|
||||
self.adding_rgb_background = adding_rgb_background
|
||||
self.bin_deg = bin_deg
|
||||
self.blur_aug = blur_aug
|
||||
self.brightening = brightening
|
||||
self.padding_white = padding_white
|
||||
self.color_padding_rotation = color_padding_rotation
|
||||
self.rotation_not_90 = rotation_not_90
|
||||
self.degrading = degrading
|
||||
self.channels_shuffling = channels_shuffling
|
||||
self.textline_skewing = textline_skewing
|
||||
self.textline_skewing_bin = textline_skewing_bin
|
||||
self.textline_right_in_depth = textline_right_in_depth
|
||||
self.textline_left_in_depth = textline_left_in_depth
|
||||
self.textline_up_in_depth = textline_up_in_depth
|
||||
self.textline_down_in_depth = textline_down_in_depth
|
||||
self.textline_right_in_depth_bin = textline_right_in_depth_bin
|
||||
self.textline_left_in_depth_bin = textline_left_in_depth_bin
|
||||
self.textline_up_in_depth_bin = textline_up_in_depth_bin
|
||||
self.textline_down_in_depth_bin = textline_down_in_depth_bin
|
||||
self.pepper_aug = pepper_aug
|
||||
self.pepper_bin_aug = pepper_bin_aug
|
||||
self.list_all_possible_background_images=list_all_possible_background_images
|
||||
self.list_all_possible_foreground_rgbs=list_all_possible_foreground_rgbs
|
||||
self.blur_k = blur_k
|
||||
self.degrade_scales = degrade_scales
|
||||
self.white_padds = white_padds
|
||||
self.thetha_padd = thetha_padd
|
||||
self.thetha = thetha
|
||||
self.brightness = brightness
|
||||
self.padd_colors = padd_colors
|
||||
self.number_of_backgrounds_per_image = number_of_backgrounds_per_image
|
||||
self.shuffle_indexes = shuffle_indexes
|
||||
self.pepper_indexes = pepper_indexes
|
||||
self.skewing_amplitudes = skewing_amplitudes
|
||||
self.dir_rgb_backgrounds = dir_rgb_backgrounds
|
||||
self.dir_rgb_foregrounds = dir_rgb_foregrounds
|
||||
self.image_files = os.listdir(dir_img)#sorted([f for f in os.listdir(images_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
|
||||
self.len_data = len_data
|
||||
#assert len(self.image_files) == len(self.label_files), "Number of images and labels must match!"
|
||||
|
||||
def __len__(self):
|
||||
return self.len_data
|
||||
|
||||
def __iter__(self):
|
||||
for img_file in self.image_files:
|
||||
# Load image
|
||||
f_name = img_file.split('.')[0]
|
||||
|
||||
txt_inp = open(os.path.join(self.dir_lab, f_name+'.txt'),'r').read().split('\n')[0]
|
||||
|
||||
img = cv2.imread(os.path.join(self.dir_img, img_file))
|
||||
img = img.astype(np.uint8)
|
||||
|
||||
|
||||
if self.dir_img_bin:
|
||||
img_bin_corr = cv2.imread(os.path.join(self.dir_img_bin, f_name+'.png') )
|
||||
img_bin_corr = img_bin_corr.astype(np.uint8)
|
||||
else:
|
||||
img_bin_corr = None
|
||||
|
||||
|
||||
labels = self.processor.tokenizer(txt_inp,
|
||||
padding="max_length",
|
||||
max_length=self.max_target_length).input_ids
|
||||
|
||||
labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]
|
||||
|
||||
|
||||
if self.augmentation:
|
||||
pixel_values = self.processor(Image.fromarray(img), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
if self.color_padding_rotation:
|
||||
for index, thetha_ind in enumerate(self.thetha_padd):
|
||||
for padd_col in self.padd_colors:
|
||||
img_out = rotation_not_90_func_single_image(do_padding_for_ocr(img, 1.2, padd_col), thetha_ind)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
if self.rotation_not_90:
|
||||
for index, thetha_ind in enumerate(self.thetha):
|
||||
img_out = rotation_not_90_func_single_image(img, thetha_ind)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
if self.blur_aug:
|
||||
for index, blur_type in enumerate(self.blur_k):
|
||||
img_out = bluring(img, blur_type)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
if self.degrading:
|
||||
for index, deg_scale_ind in enumerate(self.degrade_scales):
|
||||
try:
|
||||
img_out = do_degrading(img, deg_scale_ind)
|
||||
except:
|
||||
img_out = np.copy(img)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
if self.bin_deg:
|
||||
for index, deg_scale_ind in enumerate(self.degrade_scales):
|
||||
try:
|
||||
img_out = self.do_degrading(img_bin_corr, deg_scale_ind)
|
||||
except:
|
||||
img_out = np.copy(img_bin_corr)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
if self.brightening:
|
||||
for index, bright_scale_ind in enumerate(self.brightness):
|
||||
try:
|
||||
img_out = do_brightening(dir_img, bright_scale_ind)
|
||||
except:
|
||||
img_out = np.copy(img)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
if self.padding_white:
|
||||
for index, padding_size in enumerate(self.white_padds):
|
||||
for padd_col in self.padd_colors:
|
||||
img_out = do_padding_for_ocr(img, padding_size, padd_col)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
if self.adding_rgb_foreground:
|
||||
for i_n in range(self.number_of_backgrounds_per_image):
|
||||
background_image_chosen_name = random.choice(self.list_all_possible_background_images)
|
||||
foreground_rgb_chosen_name = random.choice(self.list_all_possible_foreground_rgbs)
|
||||
|
||||
img_rgb_background_chosen = cv2.imread(self.dir_rgb_backgrounds + '/' + background_image_chosen_name)
|
||||
foreground_rgb_chosen = np.load(self.dir_rgb_foregrounds + '/' + foreground_rgb_chosen_name)
|
||||
|
||||
img_out = return_binary_image_with_given_rgb_background_and_given_foreground_rgb(img_bin_corr, img_rgb_background_chosen, foreground_rgb_chosen)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
|
||||
|
||||
if self.adding_rgb_background:
|
||||
for i_n in range(self.number_of_backgrounds_per_image):
|
||||
background_image_chosen_name = random.choice(self.list_all_possible_background_images)
|
||||
img_rgb_background_chosen = cv2.imread(self.dir_rgb_backgrounds + '/' + background_image_chosen_name)
|
||||
img_out = return_binary_image_with_given_rgb_background(img_bin_corr, img_rgb_background_chosen)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
if self.binarization:
|
||||
pixel_values = self.processor(Image.fromarray(img_bin_corr), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
if self.channels_shuffling:
|
||||
for shuffle_index in self.shuffle_indexes:
|
||||
img_out = return_shuffled_channels(img, shuffle_index)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
if self.add_red_textlines:
|
||||
img_out = return_image_with_red_elements(img, img_bin_corr)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
if self.white_noise_strap:
|
||||
img_out = return_image_with_strapped_white_noises(img)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
if self.textline_skewing:
|
||||
for index, des_scale_ind in enumerate(self.skewing_amplitudes):
|
||||
try:
|
||||
img_out = do_deskewing(img, des_scale_ind)
|
||||
except:
|
||||
img_out = np.copy(img)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
if self.textline_skewing_bin:
|
||||
for index, des_scale_ind in enumerate(self.skewing_amplitudes):
|
||||
try:
|
||||
img_out = do_deskewing(img_bin_corr, des_scale_ind)
|
||||
except:
|
||||
img_out = np.copy(img_bin_corr)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
if self.textline_left_in_depth:
|
||||
try:
|
||||
img_out = do_direction_in_depth(img, 'left')
|
||||
except:
|
||||
img_out = np.copy(img)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
if self.textline_left_in_depth_bin:
|
||||
try:
|
||||
img_out = do_direction_in_depth(img_bin_corr, 'left')
|
||||
except:
|
||||
img_out = np.copy(img_bin_corr)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
if self.textline_right_in_depth:
|
||||
try:
|
||||
img_out = do_direction_in_depth(img, 'right')
|
||||
except:
|
||||
img_out = np.copy(img)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
if self.textline_right_in_depth_bin:
|
||||
try:
|
||||
img_out = do_direction_in_depth(img_bin_corr, 'right')
|
||||
except:
|
||||
img_out = np.copy(img_bin_corr)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
if self.textline_up_in_depth:
|
||||
try:
|
||||
img_out = do_direction_in_depth(img, 'up')
|
||||
except:
|
||||
img_out = np.copy(img)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
if self.textline_up_in_depth_bin:
|
||||
try:
|
||||
img_out = do_direction_in_depth(img_bin_corr, 'up')
|
||||
except:
|
||||
img_out = np.copy(img_bin_corr)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
if self.textline_down_in_depth:
|
||||
try:
|
||||
img_out = do_direction_in_depth(img, 'down')
|
||||
except:
|
||||
img_out = np.copy(img)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
if self.textline_down_in_depth_bin:
|
||||
try:
|
||||
img_out = do_direction_in_depth(img_bin_corr, 'down')
|
||||
except:
|
||||
img_out = np.copy(img_bin_corr)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
if self.pepper_bin_aug:
|
||||
for index, pepper_ind in enumerate(self.pepper_indexes):
|
||||
img_out = add_salt_and_pepper_noise(img_bin_corr, pepper_ind, pepper_ind)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
if self.pepper_aug:
|
||||
for index, pepper_ind in enumerate(self.pepper_indexes):
|
||||
img_out = add_salt_and_pepper_noise(img, pepper_ind, pepper_ind)
|
||||
img_out = img_out.astype(np.uint8)
|
||||
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
||||
|
||||
|
||||
else:
|
||||
pixel_values = self.processor(Image.fromarray(img), return_tensors="pt").pixel_values
|
||||
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
||||
yield encoding
|
||||
|
|
|
|||
|
|
@ -21,6 +21,11 @@ from tensorflow.keras.layers import *
|
|||
import click
|
||||
import logging
|
||||
|
||||
from transformers import TrOCRProcessor
|
||||
from PIL import Image
|
||||
import torch
|
||||
from transformers import VisionEncoderDecoderModel
|
||||
|
||||
|
||||
class Patches(layers.Layer):
|
||||
def __init__(self, patch_size_x, patch_size_y):
|
||||
|
|
@ -92,10 +97,25 @@ def start_new_session():
|
|||
tensorflow_backend.set_session(session)
|
||||
return session
|
||||
|
||||
def run_ensembling(dir_models, out):
|
||||
def run_ensembling(dir_models, out, framework):
|
||||
ls_models = os.listdir(dir_models)
|
||||
if framework=="torch":
|
||||
models = []
|
||||
sd_models = []
|
||||
|
||||
for model_name in ls_models:
|
||||
model = VisionEncoderDecoderModel.from_pretrained(os.path.join(dir_models,model_name))
|
||||
models.append(model)
|
||||
sd_models.append(model.state_dict())
|
||||
for key in sd_models[0]:
|
||||
sd_models[0][key] = sum(sd[key] for sd in sd_models) / len(sd_models)
|
||||
|
||||
model.load_state_dict(sd_models[0])
|
||||
os.system("mkdir "+out)
|
||||
torch.save(model.state_dict(), os.path.join(out, "pytorch_model.bin"))
|
||||
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config_eynollah.json ")+out)
|
||||
|
||||
else:
|
||||
weights=[]
|
||||
|
||||
for model_name in ls_models:
|
||||
|
|
@ -115,7 +135,8 @@ def run_ensembling(dir_models, out):
|
|||
|
||||
model.set_weights(new_weights)
|
||||
model.save(out)
|
||||
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config.json ")+out)
|
||||
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config_eynollah.json ")+out)
|
||||
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "characters_org.txt ")+out)
|
||||
|
||||
@click.command()
|
||||
@click.option(
|
||||
|
|
@ -130,7 +151,12 @@ def run_ensembling(dir_models, out):
|
|||
help="output directory where ensembled model will be written.",
|
||||
type=click.Path(exists=False, file_okay=False),
|
||||
)
|
||||
@click.option(
|
||||
"--framework",
|
||||
"-fw",
|
||||
help="this parameter gets tensorflow or torch as model framework",
|
||||
)
|
||||
|
||||
def main(dir_models, out):
|
||||
run_ensembling(dir_models, out)
|
||||
def main(dir_models, out, framework):
|
||||
run_ensembling(dir_models, out, framework)
|
||||
|
||||
|
|
|
|||
|
|
@ -9,8 +9,8 @@ else:
|
|||
import importlib.resources as importlib_resources
|
||||
|
||||
|
||||
def get_font():
|
||||
def get_font(font_size):
|
||||
#font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
|
||||
font = importlib_resources.files(__package__) / "../Charis-Regular.ttf"
|
||||
with importlib_resources.as_file(font) as font:
|
||||
return ImageFont.truetype(font=font, size=40)
|
||||
return ImageFont.truetype(font=font, size=font_size)
|
||||
|
|
|
|||
|
|
@ -1,17 +1,17 @@
|
|||
{
|
||||
"backbone_type" : "transformer",
|
||||
"task": "cnn-rnn-ocr",
|
||||
"task": "transformer-ocr",
|
||||
"n_classes" : 2,
|
||||
"max_len": 280,
|
||||
"n_epochs" : 3,
|
||||
"max_len": 192,
|
||||
"n_epochs" : 1,
|
||||
"input_height" : 32,
|
||||
"input_width" : 512,
|
||||
"weight_decay" : 1e-6,
|
||||
"n_batch" : 4,
|
||||
"n_batch" : 1,
|
||||
"learning_rate": 1e-5,
|
||||
"save_interval": 1500,
|
||||
"patches" : false,
|
||||
"pretraining" : true,
|
||||
"pretraining" : false,
|
||||
"augmentation" : true,
|
||||
"flip_aug" : false,
|
||||
"blur_aug" : true,
|
||||
|
|
@ -77,7 +77,6 @@
|
|||
"dir_output": "/home/vahid/extracted_lines/1919_bin/output",
|
||||
"dir_rgb_backgrounds": "/home/vahid/Documents/1_2_test_eynollah/set_rgb_background",
|
||||
"dir_rgb_foregrounds": "/home/vahid/Documents/1_2_test_eynollah/out_set_rgb_foreground",
|
||||
"dir_img_bin": "/home/vahid/extracted_lines/1919_bin/images_bin",
|
||||
"characters_txt_file":"/home/vahid/Downloads/models_eynollah/model_eynollah_ocr_cnnrnn_20250930/characters_org.txt"
|
||||
"dir_img_bin": "/home/vahid/extracted_lines/1919_bin/images_bin"
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,3 +4,8 @@ numpy <1.24.0
|
|||
tqdm
|
||||
imutils
|
||||
scipy
|
||||
torch
|
||||
evaluate
|
||||
accelerate
|
||||
jiwer
|
||||
transformers <= 4.30.2
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue