diff --git a/src/eynollah/eynollah_ocr.py b/src/eynollah/eynollah_ocr.py index 3c918e5..173ba46 100644 --- a/src/eynollah/eynollah_ocr.py +++ b/src/eynollah/eynollah_ocr.py @@ -70,6 +70,7 @@ class Eynollah_ocr: self.model_zoo.get('ocr').to(self.device) else: self.model_zoo.load_model('ocr', '') + self.input_shape = self.model_zoo.get('ocr').input_shape[1:3] self.model_zoo.load_model('num_to_char') self.model_zoo.load_model('characters') self.end_character = len(self.model_zoo.get('characters', list)) + 2 @@ -657,7 +658,7 @@ class Eynollah_ocr: if out_image_with_text: image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white") draw = ImageDraw.Draw(image_text) - font = get_font() + font = get_font(font_size=40) for indexer_text, bb_ind in enumerate(total_bb_coordinates): x_bb = bb_ind[0] @@ -823,8 +824,8 @@ class Eynollah_ocr: page_ns=page_ns, img_bin=img_bin, - image_width=512, - image_height=32, + image_width=self.input_shape[1], + image_height=self.input_shape[0], ) self.write_ocr( diff --git a/src/eynollah/training/cli.py b/src/eynollah/training/cli.py index 3718275..862d212 100644 --- a/src/eynollah/training/cli.py +++ b/src/eynollah/training/cli.py @@ -10,6 +10,7 @@ from .inference import main as inference_cli from .train import ex from .extract_line_gt import linegt_cli from .weights_ensembling import main as ensemble_cli +from .generate_or_update_cnn_rnn_ocr_character_list import main as update_ocr_characters_cli @click.command(context_settings=dict( ignore_unknown_options=True, @@ -28,3 +29,4 @@ main.add_command(inference_cli, 'inference') main.add_command(train_cli, 'train') main.add_command(linegt_cli, 'export_textline_images_and_text') main.add_command(ensemble_cli, 'ensembling') +main.add_command(update_ocr_characters_cli, 'generate_or_update_cnn_rnn_ocr_character_list') diff --git a/src/eynollah/training/extract_line_gt.py b/src/eynollah/training/extract_line_gt.py index 58fc253..3d508bc 100644 --- a/src/eynollah/training/extract_line_gt.py +++ b/src/eynollah/training/extract_line_gt.py @@ -50,6 +50,18 @@ from ..utils import is_image_filename is_flag=True, help="if this parameter set to true, cropped textline images will not be masked with textline contour.", ) +@click.option( + "--exclude_vertical_lines", + "-exv", + is_flag=True, + help="If this parameter is set to true, vertical text line images will be excluded.", +) +@click.option( + "--page_alto", + "-alto", + is_flag=True, + help="If this parameter is set to True, text line image cropping and text extraction are performed using PAGE/ALTO files.
Otherwise, the default method for PAGE XML files is used.", +) def linegt_cli( image, dir_in, @@ -57,6 +69,8 @@ def linegt_cli( dir_out, pref_of_dataset, do_not_mask_with_textline_contour, + exclude_vertical_lines, + page_alto, ): assert bool(dir_in) ^ bool(image), "Set --dir-in or --image-filename, not both" if dir_in: @@ -70,65 +84,149 @@ def linegt_cli( for dir_img in ls_imgs: file_name = Path(dir_img).stem dir_xml = os.path.join(dir_xmls, file_name + '.xml') img = cv2.imread(dir_img) + + if page_alto: + + tree = ET.parse(dir_xml) + root = tree.getroot() - total_bb_coordinates = [] + NS = {"alto": "http://www.loc.gov/standards/alto/ns-v4#"} - tree1 = ET.parse(dir_xml, parser=ET.XMLParser(encoding="utf-8")) - root1 = tree1.getroot() - alltags = [elem.tag for elem in root1.iter()] + results = [] + + indexer_textlines = 0 + for line in root.findall(".//alto:TextLine", NS): + string_el = line.find("alto:String", NS) + textline_text = string_el.attrib["CONTENT"] if string_el is not None else None - name_space = alltags[0].split('}')[0] - name_space = name_space.split('{')[1] + polygon_el = line.find("alto:Shape/alto:Polygon", NS) + if polygon_el is None: + continue - region_tags = np.unique([x for x in alltags if x.endswith('TextRegion')]) + points = polygon_el.attrib["POINTS"].split() + coords = [ + (int(points[i]), int(points[i + 1])) + for i in range(0, len(points), 2) + ] + + coords = np.array(coords, dtype=np.int32) + x, y, w, h = cv2.boundingRect(coords) + + + if exclude_vertical_lines and h > 1.4 * w: + img_crop = None + continue + + img_poly_on_img = np.copy(img) - cropped_lines_region_indexer = [] + mask_poly = np.zeros(img.shape) + mask_poly = cv2.fillPoly(mask_poly, pts=[coords], color=(1, 1, 1)) - indexer_text_region = 0 - indexer_textlines = 0 - # FIXME: non recursive, use OCR-D PAGE generateDS API.
Or use an existing tool for this purpose altogether - for nn in root1.iter(region_tags): - for child_textregion in nn: - if child_textregion.tag.endswith("TextLine"): - for child_textlines in child_textregion: - if child_textlines.tag.endswith("Coords"): - cropped_lines_region_indexer.append(indexer_text_region) - p_h = child_textlines.attrib['points'].split(' ') - textline_coords = np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h]) + mask_poly = mask_poly[y : y + h, x : x + w, :] + img_crop = img_poly_on_img[y : y + h, x : x + w, :] - x, y, w, h = cv2.boundingRect(textline_coords) + if not do_not_mask_with_textline_contour: + img_crop[mask_poly == 0] = 255 - total_bb_coordinates.append([x, y, w, h]) + if img_crop.shape[0] == 0 or img_crop.shape[1] == 0: + img_crop = None + continue + + if textline_text and img_crop is not None: + base_name = os.path.join( + dir_out, file_name + '_line_' + str(indexer_textlines) + ) + if pref_of_dataset: + base_name += '_' + pref_of_dataset + if not do_not_mask_with_textline_contour: + base_name += '_masked' - img_poly_on_img = np.copy(img) + with open(base_name + '.txt', 'w') as text_file: + text_file.write(textline_text) + cv2.imwrite(base_name + '.png', img_crop) + indexer_textlines += 1 + + - mask_poly = np.zeros(img.shape) - mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1)) - mask_poly = mask_poly[y : y + h, x : x + w, :] - img_crop = img_poly_on_img[y : y + h, x : x + w, :] - if not do_not_mask_with_textline_contour: - img_crop[mask_poly == 0] = 255 - if img_crop.shape[0] == 0 or img_crop.shape[1] == 0: - continue - if child_textlines.tag.endswith("TextEquiv"): - for cheild_text in child_textlines: - if cheild_text.tag.endswith("Unicode"): - textline_text = cheild_text.text - if textline_text: - base_name = os.path.join( - dir_out, file_name + '_line_' + str(indexer_textlines) - ) - if pref_of_dataset: - base_name += '_' + pref_of_dataset - if not do_not_mask_with_textline_contour: - base_name += '_masked' + - with open(base_name + '.txt', 'w') as text_file: - text_file.write(textline_text) - cv2.imwrite(base_name + '.png', img_crop) - indexer_textlines += 1 + + + + + + + + + else: + total_bb_coordinates = [] + + tree = ET.parse(dir_xml, parser=ET.XMLParser(encoding="utf-8")) + root = tree.getroot() + alltags = [elem.tag for elem in root.iter()] + + name_space = alltags[0].split('}')[0] + name_space = name_space.split('{')[1] + + region_tags = np.unique([x for x in alltags if x.endswith('TextRegion')]) + + cropped_lines_region_indexer = [] + + indexer_text_region = 0 + indexer_textlines = 0 + # FIXME: non recursive, use OCR-D PAGE generateDS API. 
Or use an existing tool for this purpose altogether + for nn in root.iter(region_tags): + for child_textregion in nn: + if child_textregion.tag.endswith("TextLine"): + for child_textlines in child_textregion: + if child_textlines.tag.endswith("Coords"): + cropped_lines_region_indexer.append(indexer_text_region) + p_h = child_textlines.attrib['points'].split(' ') + textline_coords = np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h]) + + x, y, w, h = cv2.boundingRect(textline_coords) + + if exclude_vertical_lines and h > 1.4 * w: + img_crop = None + continue + + total_bb_coordinates.append([x, y, w, h]) + + img_poly_on_img = np.copy(img) + + mask_poly = np.zeros(img.shape) + mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1)) + + mask_poly = mask_poly[y : y + h, x : x + w, :] + img_crop = img_poly_on_img[y : y + h, x : x + w, :] + + if not do_not_mask_with_textline_contour: + img_crop[mask_poly == 0] = 255 + + if img_crop.shape[0] == 0 or img_crop.shape[1] == 0: + img_crop = None + continue + + + if child_textlines.tag.endswith("TextEquiv"): + for cheild_text in child_textlines: + if cheild_text.tag.endswith("Unicode"): + textline_text = cheild_text.text + if textline_text and img_crop is not None: + base_name = os.path.join( + dir_out, file_name + '_line_' + str(indexer_textlines) + ) + if pref_of_dataset: + base_name += '_' + pref_of_dataset + if not do_not_mask_with_textline_contour: + base_name += '_masked' + + with open(base_name + '.txt', 'w') as text_file: + text_file.write(textline_text) + cv2.imwrite(base_name + '.png', img_crop) + indexer_textlines += 1 diff --git a/src/eynollah/training/generate_gt_for_training.py b/src/eynollah/training/generate_gt_for_training.py index 30abd04..28f3f1c 100644 --- a/src/eynollah/training/generate_gt_for_training.py +++ b/src/eynollah/training/generate_gt_for_training.py @@ -6,6 +6,7 @@ from pathlib import Path from PIL import Image, ImageDraw, ImageFont import cv2 import numpy as np +from eynollah.utils.font import get_font from eynollah.training.gt_gen_utils import ( filter_contours_area_of_image, @@ -477,7 +478,7 @@ def visualize_layout_segmentation(xml_file, dir_xml, dir_out, dir_imgs): co_text, co_graphic, co_sep, co_img, co_table, co_map, co_noise, y_len, x_len = get_layout_contours_for_visualization(xml_file) - added_image = visualize_image_from_contours_layout(co_text['paragraph'], co_text['header']+co_text['heading'], co_text['drop-capital'], co_sep, co_img, co_text['marginalia'], co_table, img) + added_image = visualize_image_from_contours_layout(co_text['paragraph'], co_text['header']+co_text['heading'], co_text['drop-capital'], co_sep, co_img, co_text['marginalia'], co_table, co_map, img) cv2.imwrite(os.path.join(dir_out, f_name+'.png'), added_image) @@ -514,8 +515,8 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out): else: xml_files_ind = [xml_file] - font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists! - font = ImageFont.truetype(font_path, 40) + ###font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists! 
+ font = get_font(font_size=40)#ImageFont.truetype(font_path, 40) for ind_xml in tqdm(xml_files_ind): indexer = 0 @@ -552,11 +553,11 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out): is_vertical = h > 2*w # Check orientation - font = fit_text_single_line(draw, ocr_texts[index], font_path, w, int(h*0.4) ) + font = fit_text_single_line(draw, ocr_texts[index], w, int(h*0.4) ) if is_vertical: - vertical_font = fit_text_single_line(draw, ocr_texts[index], font_path, h, int(w * 0.8)) + vertical_font = fit_text_single_line(draw, ocr_texts[index], h, int(w * 0.8)) text_img = Image.new("RGBA", (h, w), (255, 255, 255, 0)) # Note: dimensions are swapped text_draw = ImageDraw.Draw(text_img) diff --git a/src/eynollah/training/generate_or_update_cnn_rnn_ocr_character_list.py b/src/eynollah/training/generate_or_update_cnn_rnn_ocr_character_list.py new file mode 100644 index 0000000..8620515 --- /dev/null +++ b/src/eynollah/training/generate_or_update_cnn_rnn_ocr_character_list.py @@ -0,0 +1,59 @@ +import os +import numpy as np +import json +import click +import logging + + + +def run_character_list_update(dir_labels, out, current_character_list): + ls_labels = os.listdir(dir_labels) + ls_labels = [ind for ind in ls_labels if ind.endswith('.txt')] + + if current_character_list: + with open(current_character_list, 'r') as f_name: + characters = json.load(f_name) + + characters = set(characters) + else: + characters = set() + + + for ind in ls_labels: + label = open(os.path.join(dir_labels,ind),'r').read().split('\n')[0] + + for char in label: + characters.add(char) + + + characters = sorted(list(set(characters))) + + with open(out, 'w') as f_name: + json.dump(characters, f_name) + + +@click.command() +@click.option( + "--dir_labels", + "-dl", + help="directory of labels which are .txt files", + type=click.Path(exists=True, file_okay=False), + required=True, +) +@click.option( + "--current_character_list", + "-ccl", + help="existing character list in a .txt file that needs to be updated with a set of labels", + type=click.Path(exists=True, file_okay=True), + required=False, +) +@click.option( + "--out", + "-o", + help="An output .txt file where the generated or updated character list will be written", + type=click.Path(exists=False, file_okay=True), +) + +def main(dir_labels, out, current_character_list): + run_character_list_update(dir_labels, out, current_character_list) + diff --git a/src/eynollah/training/gt_gen_utils.py b/src/eynollah/training/gt_gen_utils.py index 50bb6fa..1e15b3b 100644 --- a/src/eynollah/training/gt_gen_utils.py +++ b/src/eynollah/training/gt_gen_utils.py @@ -7,7 +7,7 @@ import cv2 from shapely import geometry from pathlib import Path from PIL import ImageFont - +from eynollah.utils.font import get_font KERNEL = np.ones((5, 5), np.uint8) @@ -350,11 +350,11 @@ def get_textline_contours_and_ocr_text(xml_file): ocr_textlines.append(ocr_text_in[0]) return co_use_case, y_len, x_len, ocr_textlines -def fit_text_single_line(draw, text, font_path, max_width, max_height): +def fit_text_single_line(draw, text, max_width, max_height): initial_font_size = 50 font_size = initial_font_size while font_size > 10: # Minimum font size - font = ImageFont.truetype(font_path, font_size) + font = get_font(font_size=font_size)# ImageFont.truetype(font_path, font_size) text_bbox = draw.textbbox((0, 0), text, font=font) # Get text bounding box text_width = text_bbox[2] - text_bbox[0] text_height = text_bbox[3] - text_bbox[1] @@ -364,7 +364,7 @@ def fit_text_single_line(draw, text, font_path, max_width, 
max_height): font_size -= 2 # Reduce font size and retry - return ImageFont.truetype(font_path, 10) # Smallest font fallback + return get_font(font_size=10)#ImageFont.truetype(font_path, 10) # Smallest font fallback def get_layout_contours_for_visualization(xml_file): tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8')) diff --git a/src/eynollah/training/inference.py b/src/eynollah/training/inference.py index f74e9e1..dc7979a 100644 --- a/src/eynollah/training/inference.py +++ b/src/eynollah/training/inference.py @@ -12,6 +12,7 @@ from keras.models import Model, load_model from keras import backend as K import click from tensorflow.python.keras import backend as tensorflow_backend +from tensorflow.keras.layers import StringLookup import xml.etree.ElementTree as ET from .gt_gen_utils import ( @@ -169,6 +170,25 @@ class sbb_predict: self.model = tf.keras.models.Model( self.model.get_layer(name = "image").input, self.model.get_layer(name = "dense2").output) + + assert isinstance(self.model, Model) + + elif self.task == "transformer-ocr": + import torch + from transformers import VisionEncoderDecoderModel + from transformers import TrOCRProcessor + + self.model = VisionEncoderDecoderModel.from_pretrained(self.model_dir) + self.processor = TrOCRProcessor.from_pretrained(self.model_dir) + + if self.cpu: + self.device = torch.device('cpu') + else: + self.device = torch.device('cuda:0') + + self.model.to(self.device) + + assert isinstance(self.model, torch.nn.Module) else: config = tf.compat.v1.ConfigProto() config.gpu_options.allow_growth = True @@ -176,15 +196,15 @@ class sbb_predict: session = tf.compat.v1.Session(config=config) # tf.InteractiveSession() tensorflow_backend.set_session(session) - - ##if self.weights_dir!=None: - ##self.model.load_weights(self.weights_dir) + self.model = load_model(self.model_dir , compile=False,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches}) + + if self.task != 'classification' and self.task != 'reading_order': + self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1] + self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2] + self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3] - assert isinstance(self.model, Model) - if self.task != 'classification' and self.task != 'reading_order': - self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1] - self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2] - self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3] + + assert isinstance(self.model, Model) def visualize_model_output(self, prediction, img, task) -> Tuple[NDArray, NDArray]: if task == "binarization": @@ -235,10 +255,9 @@ class sbb_predict: return added_image, layout_only def predict(self, image_dir): - assert isinstance(self.model, Model) if self.task == 'classification': classes_names = self.config_params_model['classification_classes_name'] - img_1ch = img=cv2.imread(image_dir, 0) + img_1ch =cv2.imread(image_dir, 0) img_1ch = img_1ch / 255.0 img_1ch = cv2.resize(img_1ch, (self.config_params_model['input_height'], self.config_params_model['input_width']), interpolation=cv2.INTER_NEAREST) @@ -273,6 +292,15 @@ class sbb_predict: pred_texts = decode_batch_predictions(preds, num_to_char) pred_texts = pred_texts[0].replace("[UNK]", "") return pred_texts + + elif self.task == "transformer-ocr": + from PIL import Image + image = Image.open(image_dir).convert("RGB") + pixel_values = self.processor(image, 
return_tensors="pt").pixel_values + generated_ids = self.model.generate(pixel_values.to(self.device)) + return self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + + elif self.task == 'reading_order': @@ -607,6 +635,8 @@ class sbb_predict: cv2.imwrite(self.save,res) elif self.task == "cnn-rnn-ocr": print(f"Detected text: {res}") + elif self.task == "transformer-ocr": + print(f"Detected text: {res}") else: img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task) if self.save: @@ -710,10 +740,12 @@ class sbb_predict: ) def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area): assert image or dir_in, "Either a single image -i or a dir_in -di is required" - with open(os.path.join(model,'config.json')) as f: + + with open(os.path.join(model,'config_eynollah.json')) as f: config_params_model = json.load(f) + task = config_params_model['task'] - if task != 'classification' and task != 'reading_order' and task != "cnn-rnn-ocr": + if task != 'classification' and task != 'reading_order' and task != "cnn-rnn-ocr" and task != "transformer-ocr": if image and not save: print("Error: You used one of segmentation or binarization task with image input but not set -s, you need a filename to save visualized output with -s") sys.exit(1) diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py index 7a0cb3d..11ecc8c 100644 --- a/src/eynollah/training/train.py +++ b/src/eynollah/training/train.py @@ -1,7 +1,8 @@ import os import sys import json - +import numpy as np +import cv2 import click from eynollah.training.metrics import ( @@ -27,7 +28,8 @@ from eynollah.training.utils import ( generate_data_from_folder_training, get_one_hot, provide_patches, - return_number_of_total_training_data + return_number_of_total_training_data, + OCRDatasetYieldAugmentations ) os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' @@ -41,11 +43,16 @@ from sklearn.metrics import f1_score from tensorflow.keras.callbacks import Callback from tensorflow.keras.layers import StringLookup -import numpy as np -import cv2 + +import torch +from transformers import TrOCRProcessor +import evaluate +from transformers import default_data_collator +from transformers import VisionEncoderDecoderModel +from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments class SaveWeightsAfterSteps(Callback): - def __init__(self, save_interval, save_path, _config): + def __init__(self, save_interval, save_path, _config, characters_cnnrnn_ocr=None): super(SaveWeightsAfterSteps, self).__init__() self.save_interval = save_interval self.save_path = save_path @@ -61,7 +68,10 @@ class SaveWeightsAfterSteps(Callback): self.model.save(save_file) - with open(os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"config.json"), "w") as fp: + if characters_cnnrnn_ocr: + os.system("cp "+characters_cnnrnn_ocr+" "+os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"characters_org.txt")) + + with open(os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"config_eynollah.json"), "w") as fp: json.dump(self._config, fp) # encode dict into JSON print(f"saved model as steps {self.step_count} to {save_file}") @@ -477,7 +487,7 @@ def run( model.save(os.path.join(dir_output,'model_'+str(i))) - with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp: + with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config_eynollah.json"), "w") as fp: 
json.dump(_config, fp) # encode dict into JSON #os.system('rm -rf '+dir_train_flowing) @@ -537,7 +547,7 @@ def run( opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)#1e-4)#(lr_schedule) model.compile(optimizer=opt) - save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config) if save_interval else None + save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config, characters_cnnrnn_ocr=characters_txt_file) if save_interval else None for i in tqdm(range(index_start, n_epochs + index_start)): if save_interval: @@ -556,9 +566,125 @@ def run( if i >=0: model.save( os.path.join(dir_output,'model_'+str(i) )) - with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp: + os.system("cp "+characters_txt_file+" "+os.path.join(os.path.join(dir_output,'model_'+str(i)),"characters_org.txt")) + with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config_eynollah.json"), "w") as fp: json.dump(_config, fp) # encode dict into JSON + + elif task=="transformer-ocr": + dir_img, dir_lab = get_dirs_or_files(dir_train) + + processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed") + + ls_files_images = os.listdir(dir_img) + + aug_multip = return_multiplier_based_on_augmnentations(augmentation, color_padding_rotation, rotation_not_90, blur_aug, degrading, bin_deg, + brightening, padding_white, adding_rgb_foreground, adding_rgb_background, binarization, + image_inversion, channels_shuffling, add_red_textlines, white_noise_strap, textline_skewing, textline_skewing_bin, textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth, textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin, textline_down_in_depth, textline_down_in_depth_bin, pepper_bin_aug, pepper_aug, degrade_scales, number_of_backgrounds_per_image, thetha, thetha_padd, brightness, padd_colors, shuffle_indexes, pepper_indexes, skewing_amplitudes, blur_k, white_padds) + + len_dataset = aug_multip*len(ls_files_images) + + dataset = OCRDatasetYieldAugmentations( + dir_img=dir_img, + dir_img_bin=dir_img_bin, + dir_lab=dir_lab, + processor=processor, + max_target_length=max_len, + augmentation = augmentation, + binarization = binarization, + add_red_textlines = add_red_textlines, + white_noise_strap = white_noise_strap, + adding_rgb_foreground = adding_rgb_foreground, + adding_rgb_background = adding_rgb_background, + bin_deg = bin_deg, + blur_aug = blur_aug, + brightening = brightening, + padding_white = padding_white, + color_padding_rotation = color_padding_rotation, + rotation_not_90 = rotation_not_90, + degrading = degrading, + channels_shuffling = channels_shuffling, + textline_skewing = textline_skewing, + textline_skewing_bin = textline_skewing_bin, + textline_right_in_depth = textline_right_in_depth, + textline_left_in_depth = textline_left_in_depth, + textline_up_in_depth = textline_up_in_depth, + textline_down_in_depth = textline_down_in_depth, + textline_right_in_depth_bin = textline_right_in_depth_bin, + textline_left_in_depth_bin = textline_left_in_depth_bin, + textline_up_in_depth_bin = textline_up_in_depth_bin, + textline_down_in_depth_bin = textline_down_in_depth_bin, + pepper_aug = pepper_aug, + pepper_bin_aug = pepper_bin_aug, + list_all_possible_background_images=list_all_possible_background_images, + list_all_possible_foreground_rgbs=list_all_possible_foreground_rgbs, + blur_k = blur_k, + degrade_scales = degrade_scales, + white_padds = white_padds, + thetha_padd = thetha_padd, + thetha = 
thetha, + brightness = brightness, + padd_colors = padd_colors, + number_of_backgrounds_per_image = number_of_backgrounds_per_image, + shuffle_indexes = shuffle_indexes, + pepper_indexes = pepper_indexes, + skewing_amplitudes = skewing_amplitudes, + dir_rgb_backgrounds = dir_rgb_backgrounds, + dir_rgb_foregrounds = dir_rgb_foregrounds, + len_data=len_dataset, + ) + + # Create a DataLoader + data_loader = torch.utils.data.DataLoader(dataset, batch_size=1) + train_dataset = data_loader.dataset + + + if continue_training: + model = VisionEncoderDecoderModel.from_pretrained(dir_of_start_model) + else: + model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed") + + + # set special tokens used for creating the decoder_input_ids from the labels + model.config.decoder_start_token_id = processor.tokenizer.cls_token_id + model.config.pad_token_id = processor.tokenizer.pad_token_id + # make sure vocab size is set correctly + model.config.vocab_size = model.config.decoder.vocab_size + + # set beam search parameters + model.config.eos_token_id = processor.tokenizer.sep_token_id + model.config.max_length = max_len + model.config.early_stopping = True + model.config.no_repeat_ngram_size = 3 + model.config.length_penalty = 2.0 + model.config.num_beams = 4 + + + training_args = Seq2SeqTrainingArguments( + predict_with_generate=True, + num_train_epochs=n_epochs, + learning_rate=learning_rate, + per_device_train_batch_size=n_batch, + fp16=True, + output_dir=dir_output, + logging_steps=2, + save_steps=save_interval, + ) + + + cer_metric = evaluate.load("cer") + + # instantiate trainer + trainer = Seq2SeqTrainer( + model=model, + tokenizer=processor.feature_extractor, + args=training_args, + train_dataset=train_dataset, + data_collator=default_data_collator, + ) + trainer.train() + + elif task=='classification': configuration() model = resnet50_classifier(n_classes, input_height, input_width, weight_decay, pretraining) @@ -609,10 +735,10 @@ def run( model_weight_averaged.set_weights(new_weights) model_weight_averaged.save(os.path.join(dir_output,'model_ens_avg')) - with open(os.path.join( os.path.join(dir_output,'model_ens_avg'), "config.json"), "w") as fp: + with open(os.path.join( os.path.join(dir_output,'model_ens_avg'), "config_eynollah.json"), "w") as fp: json.dump(_config, fp) # encode dict into JSON - with open(os.path.join( os.path.join(dir_output,'model_best'), "config.json"), "w") as fp: + with open(os.path.join( os.path.join(dir_output,'model_best'), "config_eynollah.json"), "w") as fp: json.dump(_config, fp) # encode dict into JSON elif task=='reading_order': @@ -645,7 +771,7 @@ def run( history = model.fit(generate_arrays_from_folder_reading_order(dir_flow_train_labels, dir_flow_train_imgs, n_batch, input_height, input_width, n_classes, thetha, augmentation), steps_per_epoch=num_rows / n_batch, verbose=1) model.save( os.path.join(dir_output,'model_'+str(i+indexer_start) )) - with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp: + with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config_eynollah.json"), "w") as fp: json.dump(_config, fp) # encode dict into JSON ''' if f1score>f1score_tot[0]: diff --git a/src/eynollah/training/utils.py b/src/eynollah/training/utils.py index 005810f..d44f10c 100644 --- a/src/eynollah/training/utils.py +++ b/src/eynollah/training/utils.py @@ -13,8 +13,12 @@ import tensorflow as tf from tensorflow.keras.utils import to_categorical from PIL import Image, ImageFile, ImageEnhance +import torch 
+from torch.utils.data import IterableDataset + ImageFile.LOAD_TRUNCATED_IMAGES = True + def vectorize_label(label, char_to_num, padding_token, max_len): label = char_to_num(tf.strings.unicode_split(label, input_encoding="UTF-8")) length = tf.shape(label)[0] @@ -76,6 +80,7 @@ def add_salt_and_pepper_noise(img, salt_prob, pepper_prob): return noisy_image + def invert_image(img): img_inv = 255 - img return img_inv @@ -1668,3 +1673,411 @@ def return_multiplier_based_on_augmnentations( aug_multip += len(pepper_indexes) return aug_multip + + +class OCRDatasetYieldAugmentations(IterableDataset): + def __init__( + self, + dir_img, + dir_img_bin, + dir_lab, + processor, + max_target_length=128, + augmentation = None, + binarization = None, + add_red_textlines = None, + white_noise_strap = None, + adding_rgb_foreground = None, + adding_rgb_background = None, + bin_deg = None, + blur_aug = None, + brightening = None, + padding_white = None, + color_padding_rotation = None, + rotation_not_90 = None, + degrading = None, + channels_shuffling = None, + textline_skewing = None, + textline_skewing_bin = None, + textline_right_in_depth = None, + textline_left_in_depth = None, + textline_up_in_depth = None, + textline_down_in_depth = None, + textline_right_in_depth_bin = None, + textline_left_in_depth_bin = None, + textline_up_in_depth_bin = None, + textline_down_in_depth_bin = None, + pepper_aug = None, + pepper_bin_aug = None, + list_all_possible_background_images=None, + list_all_possible_foreground_rgbs=None, + blur_k = None, + degrade_scales = None, + white_padds = None, + thetha_padd = None, + thetha = None, + brightness = None, + padd_colors = None, + number_of_backgrounds_per_image = None, + shuffle_indexes = None, + pepper_indexes = None, + skewing_amplitudes = None, + dir_rgb_backgrounds = None, + dir_rgb_foregrounds = None, + len_data=None, + ): + """ + Args: + images_dir (str): Path to the directory containing images. + labels_dir (str): Path to the directory containing label text files. + tokenizer: Tokenizer for processing labels. + transform: Transformations applied after augmentation (e.g., ToTensor, normalization). + image_size (tuple): Size to resize images to. + max_seq_len (int): Maximum sequence length for tokenized labels. + scales (list or None): List of scale factors to apply. 
+ """ + self.dir_img = dir_img + self.dir_img_bin = dir_img_bin + self.dir_lab = dir_lab + self.processor = processor + self.max_target_length = max_target_length + #self.scales = scales if scales else [] + + self.augmentation = augmentation + self.binarization = binarization + self.add_red_textlines = add_red_textlines + self.white_noise_strap = white_noise_strap + self.adding_rgb_foreground = adding_rgb_foreground + self.adding_rgb_background = adding_rgb_background + self.bin_deg = bin_deg + self.blur_aug = blur_aug + self.brightening = brightening + self.padding_white = padding_white + self.color_padding_rotation = color_padding_rotation + self.rotation_not_90 = rotation_not_90 + self.degrading = degrading + self.channels_shuffling = channels_shuffling + self.textline_skewing = textline_skewing + self.textline_skewing_bin = textline_skewing_bin + self.textline_right_in_depth = textline_right_in_depth + self.textline_left_in_depth = textline_left_in_depth + self.textline_up_in_depth = textline_up_in_depth + self.textline_down_in_depth = textline_down_in_depth + self.textline_right_in_depth_bin = textline_right_in_depth_bin + self.textline_left_in_depth_bin = textline_left_in_depth_bin + self.textline_up_in_depth_bin = textline_up_in_depth_bin + self.textline_down_in_depth_bin = textline_down_in_depth_bin + self.pepper_aug = pepper_aug + self.pepper_bin_aug = pepper_bin_aug + self.list_all_possible_background_images=list_all_possible_background_images + self.list_all_possible_foreground_rgbs=list_all_possible_foreground_rgbs + self.blur_k = blur_k + self.degrade_scales = degrade_scales + self.white_padds = white_padds + self.thetha_padd = thetha_padd + self.thetha = thetha + self.brightness = brightness + self.padd_colors = padd_colors + self.number_of_backgrounds_per_image = number_of_backgrounds_per_image + self.shuffle_indexes = shuffle_indexes + self.pepper_indexes = pepper_indexes + self.skewing_amplitudes = skewing_amplitudes + self.dir_rgb_backgrounds = dir_rgb_backgrounds + self.dir_rgb_foregrounds = dir_rgb_foregrounds + self.image_files = os.listdir(dir_img)#sorted([f for f in os.listdir(images_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]) + self.len_data = len_data + #assert len(self.image_files) == len(self.label_files), "Number of images and labels must match!" 
+ + def __len__(self): + return self.len_data + + def __iter__(self): + for img_file in self.image_files: + # Load image + f_name = img_file.split('.')[0] + + txt_inp = open(os.path.join(self.dir_lab, f_name+'.txt'),'r').read().split('\n')[0] + + img = cv2.imread(os.path.join(self.dir_img, img_file)) + img = img.astype(np.uint8) + + + if self.dir_img_bin: + img_bin_corr = cv2.imread(os.path.join(self.dir_img_bin, f_name+'.png') ) + img_bin_corr = img_bin_corr.astype(np.uint8) + else: + img_bin_corr = None + + + labels = self.processor.tokenizer(txt_inp, + padding="max_length", + max_length=self.max_target_length).input_ids + + labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels] + + + if self.augmentation: + pixel_values = self.processor(Image.fromarray(img), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + if self.color_padding_rotation: + for index, thetha_ind in enumerate(self.thetha_padd): + for padd_col in self.padd_colors: + img_out = rotation_not_90_func_single_image(do_padding_for_ocr(img, 1.2, padd_col), thetha_ind) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + if self.rotation_not_90: + for index, thetha_ind in enumerate(self.thetha): + img_out = rotation_not_90_func_single_image(img, thetha_ind) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + if self.blur_aug: + for index, blur_type in enumerate(self.blur_k): + img_out = bluring(img, blur_type) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + if self.degrading: + for index, deg_scale_ind in enumerate(self.degrade_scales): + try: + img_out = do_degrading(img, deg_scale_ind) + except: + img_out = np.copy(img) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + if self.bin_deg: + for index, deg_scale_ind in enumerate(self.degrade_scales): + try: + img_out = self.do_degrading(img_bin_corr, deg_scale_ind) + except: + img_out = np.copy(img_bin_corr) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + if self.brightening: + for index, bright_scale_ind in enumerate(self.brightness): + try: + img_out = do_brightening(dir_img, bright_scale_ind) + except: + img_out = np.copy(img) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + if self.padding_white: + for index, padding_size in enumerate(self.white_padds): + for padd_col in self.padd_colors: + img_out = do_padding_for_ocr(img, padding_size, 
padd_col) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + if self.adding_rgb_foreground: + for i_n in range(self.number_of_backgrounds_per_image): + background_image_chosen_name = random.choice(self.list_all_possible_background_images) + foreground_rgb_chosen_name = random.choice(self.list_all_possible_foreground_rgbs) + + img_rgb_background_chosen = cv2.imread(self.dir_rgb_backgrounds + '/' + background_image_chosen_name) + foreground_rgb_chosen = np.load(self.dir_rgb_foregrounds + '/' + foreground_rgb_chosen_name) + + img_out = return_binary_image_with_given_rgb_background_and_given_foreground_rgb(img_bin_corr, img_rgb_background_chosen, foreground_rgb_chosen) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + + + if self.adding_rgb_background: + for i_n in range(self.number_of_backgrounds_per_image): + background_image_chosen_name = random.choice(self.list_all_possible_background_images) + img_rgb_background_chosen = cv2.imread(self.dir_rgb_backgrounds + '/' + background_image_chosen_name) + img_out = return_binary_image_with_given_rgb_background(img_bin_corr, img_rgb_background_chosen) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + if self.binarization: + pixel_values = self.processor(Image.fromarray(img_bin_corr), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + if self.channels_shuffling: + for shuffle_index in self.shuffle_indexes: + img_out = return_shuffled_channels(img, shuffle_index) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + if self.add_red_textlines: + img_out = return_image_with_red_elements(img, img_bin_corr) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + if self.white_noise_strap: + img_out = return_image_with_strapped_white_noises(img) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + if self.textline_skewing: + for index, des_scale_ind in enumerate(self.skewing_amplitudes): + try: + img_out = do_deskewing(img, des_scale_ind) + except: + img_out = np.copy(img) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + if self.textline_skewing_bin: + for index, des_scale_ind in enumerate(self.skewing_amplitudes): + try: + img_out = do_deskewing(img_bin_corr, des_scale_ind) + except: + img_out = 
np.copy(img_bin_corr) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + if self.textline_left_in_depth: + try: + img_out = do_direction_in_depth(img, 'left') + except: + img_out = np.copy(img) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + if self.textline_left_in_depth_bin: + try: + img_out = do_direction_in_depth(img_bin_corr, 'left') + except: + img_out = np.copy(img_bin_corr) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + if self.textline_right_in_depth: + try: + img_out = do_direction_in_depth(img, 'right') + except: + img_out = np.copy(img) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + if self.textline_right_in_depth_bin: + try: + img_out = do_direction_in_depth(img_bin_corr, 'right') + except: + img_out = np.copy(img_bin_corr) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + if self.textline_up_in_depth: + try: + img_out = do_direction_in_depth(img, 'up') + except: + img_out = np.copy(img) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + if self.textline_up_in_depth_bin: + try: + img_out = do_direction_in_depth(img_bin_corr, 'up') + except: + img_out = np.copy(img_bin_corr) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + if self.textline_down_in_depth: + try: + img_out = do_direction_in_depth(img, 'down') + except: + img_out = np.copy(img) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + if self.textline_down_in_depth_bin: + try: + img_out = do_direction_in_depth(img_bin_corr, 'down') + except: + img_out = np.copy(img_bin_corr) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + if self.pepper_bin_aug: + for index, pepper_ind in enumerate(self.pepper_indexes): + img_out = add_salt_and_pepper_noise(img_bin_corr, pepper_ind, pepper_ind) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": 
pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + if self.pepper_aug: + for index, pepper_ind in enumerate(self.pepper_indexes): + img_out = add_salt_and_pepper_noise(img, pepper_ind, pepper_ind) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + + else: + pixel_values = self.processor(Image.fromarray(img), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding diff --git a/src/eynollah/training/weights_ensembling.py b/src/eynollah/training/weights_ensembling.py index 6dce7fd..2f25dbf 100644 --- a/src/eynollah/training/weights_ensembling.py +++ b/src/eynollah/training/weights_ensembling.py @@ -21,6 +21,11 @@ from tensorflow.keras.layers import * import click import logging +from transformers import TrOCRProcessor +from PIL import Image +import torch +from transformers import VisionEncoderDecoderModel + class Patches(layers.Layer): def __init__(self, patch_size_x, patch_size_y): @@ -92,30 +97,46 @@ def start_new_session(): tensorflow_backend.set_session(session) return session -def run_ensembling(dir_models, out): +def run_ensembling(dir_models, out, framework): ls_models = os.listdir(dir_models) - - - weights=[] - - for model_name in ls_models: - model = load_model(os.path.join(dir_models,model_name) , compile=False, custom_objects={'PatchEncoder':PatchEncoder, 'Patches': Patches}) - weights.append(model.get_weights()) + if framework=="torch": + models = [] + sd_models = [] - new_weights = list() - - for weights_list_tuple in zip(*weights): - new_weights.append( - [np.array(weights_).mean(axis=0)\ - for weights_ in zip(*weights_list_tuple)]) + for model_name in ls_models: + model = VisionEncoderDecoderModel.from_pretrained(os.path.join(dir_models,model_name)) + models.append(model) + sd_models.append(model.state_dict()) + for key in sd_models[0]: + sd_models[0][key] = sum(sd[key] for sd in sd_models) / len(sd_models) + + model.load_state_dict(sd_models[0]) + os.system("mkdir "+out) + torch.save(model.state_dict(), os.path.join(out, "pytorch_model.bin")) + os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config_eynollah.json ")+out) + else: + weights=[] + + for model_name in ls_models: + model = load_model(os.path.join(dir_models,model_name) , compile=False, custom_objects={'PatchEncoder':PatchEncoder, 'Patches': Patches}) + weights.append(model.get_weights()) + + new_weights = list() + + for weights_list_tuple in zip(*weights): + new_weights.append( + [np.array(weights_).mean(axis=0)\ + for weights_ in zip(*weights_list_tuple)]) + - new_weights = [np.array(x) for x in new_weights] - - model.set_weights(new_weights) - model.save(out) - os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config.json ")+out) + new_weights = [np.array(x) for x in new_weights] + + model.set_weights(new_weights) + model.save(out) + os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config_eynollah.json ")+out) + os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "characters_org.txt ")+out) @click.command() @click.option( @@ -130,7 +151,12 @@ def run_ensembling(dir_models, out): help="output directory where ensembled model will be written.", type=click.Path(exists=False, file_okay=False), ) +@click.option( + "--framework", + "-fw", + help="this parameter 
specifies the framework of the models to be ensembled: either 'tensorflow' or 'torch'.", +) -def main(dir_models, out): - run_ensembling(dir_models, out) +def main(dir_models, out, framework): + run_ensembling(dir_models, out, framework) diff --git a/src/eynollah/utils/font.py b/src/eynollah/utils/font.py index 939933e..0354317 100644 --- a/src/eynollah/utils/font.py +++ b/src/eynollah/utils/font.py @@ -9,8 +9,8 @@ else: import importlib.resources as importlib_resources -def get_font(): +def get_font(font_size): #font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists! font = importlib_resources.files(__package__) / "../Charis-Regular.ttf" with importlib_resources.as_file(font) as font: - return ImageFont.truetype(font=font, size=40) + return ImageFont.truetype(font=font, size=font_size) diff --git a/train/config_params.json b/train/config_params.json index b01ac08..34c6376 100644 --- a/train/config_params.json +++ b/train/config_params.json @@ -1,17 +1,17 @@ { "backbone_type" : "transformer", - "task": "cnn-rnn-ocr", + "task": "transformer-ocr", "n_classes" : 2, - "max_len": 280, - "n_epochs" : 3, + "max_len": 192, + "n_epochs" : 1, "input_height" : 32, "input_width" : 512, "weight_decay" : 1e-6, - "n_batch" : 4, + "n_batch" : 1, "learning_rate": 1e-5, "save_interval": 1500, "patches" : false, - "pretraining" : true, + "pretraining" : false, "augmentation" : true, "flip_aug" : false, "blur_aug" : true, @@ -77,7 +77,6 @@ "dir_output": "/home/vahid/extracted_lines/1919_bin/output", "dir_rgb_backgrounds": "/home/vahid/Documents/1_2_test_eynollah/set_rgb_background", "dir_rgb_foregrounds": "/home/vahid/Documents/1_2_test_eynollah/out_set_rgb_foreground", - "dir_img_bin": "/home/vahid/extracted_lines/1919_bin/images_bin", - "characters_txt_file":"/home/vahid/Downloads/models_eynollah/model_eynollah_ocr_cnnrnn_20250930/characters_org.txt" + "dir_img_bin": "/home/vahid/extracted_lines/1919_bin/images_bin" } diff --git a/train/requirements.txt b/train/requirements.txt index 63f3813..e3599a8 100644 --- a/train/requirements.txt +++ b/train/requirements.txt @@ -4,3 +4,8 @@ numpy <1.24.0 tqdm imutils scipy +torch +evaluate +accelerate +jiwer +transformers <= 4.30.2
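
For reference, a minimal sketch (not part of the patch) of how the new character-list helper introduced above can be driven from Python; the paths below are placeholders:

from eynollah.training.generate_or_update_cnn_rnn_ocr_character_list import run_character_list_update

# Collect every character occurring in the line-label .txt files, optionally merge them
# into an existing character list, and write the sorted result as a JSON list.
run_character_list_update(
    dir_labels="/path/to/line_labels",                      # directory of *.txt transcriptions (placeholder)
    out="/path/to/characters_updated.txt",                  # output character list, JSON-encoded (placeholder)
    current_character_list="/path/to/characters_org.txt",   # existing list to extend, or None (placeholder)
)

The same helper is exposed on the training CLI as 'generate_or_update_cnn_rnn_ocr_character_list' with the options -dl, -ccl and -o.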