diff --git a/src/eynollah/Amiri-Regular.ttf b/src/eynollah/Amiri-Regular.ttf new file mode 100644 index 0000000..df5e1df Binary files /dev/null and b/src/eynollah/Amiri-Regular.ttf differ diff --git a/src/eynollah/eynollah_ocr.py b/src/eynollah/eynollah_ocr.py index 1b49077..45ce598 100644 --- a/src/eynollah/eynollah_ocr.py +++ b/src/eynollah/eynollah_ocr.py @@ -69,10 +69,11 @@ class Eynollah_ocr: self.model_zoo.load_models(['ocr', 'tr']) self.model_zoo.get('ocr').to(self.device) else: - self.model_zoo.load_models('ocr') - self.model_zoo.load_models('num_to_char') - self.model_zoo.load_models('characters') - self.end_character = len(self.model_zoo.get('characters')) + 2 + self.model_zoo.load_model('ocr', '') + self.input_shape = self.model_zoo.get('ocr').input_shape[1:3] + self.model_zoo.load_model('num_to_char') + self.model_zoo.load_model('characters') + self.end_character = len(self.model_zoo.get('characters', list)) + 2 @property def device(self): @@ -657,7 +658,7 @@ class Eynollah_ocr: if out_image_with_text: image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white") draw = ImageDraw.Draw(image_text) - font = get_font() + font = get_font(font_size=40) for indexer_text, bb_ind in enumerate(total_bb_coordinates): x_bb = bb_ind[0] @@ -823,8 +824,8 @@ class Eynollah_ocr: page_ns=page_ns, img_bin=img_bin, - image_width=512, - image_height=32, + image_width=self.input_shape[1], + image_height=self.input_shape[0], ) self.write_ocr( diff --git a/src/eynollah/patch_encoder.py b/src/eynollah/patch_encoder.py index f163132..fcf41f1 100644 --- a/src/eynollah/patch_encoder.py +++ b/src/eynollah/patch_encoder.py @@ -20,6 +20,7 @@ class PatchEncoder(layers.Layer): def get_config(self): return dict(num_patches=self.num_patches, projection_dim=self.projection_dim, + position_embedding=self.position_embedding, **super().get_config()) class Patches(layers.Layer): diff --git a/src/eynollah/training/cli.py b/src/eynollah/training/cli.py index ae14f04..ff87b90 100644 --- a/src/eynollah/training/cli.py +++ b/src/eynollah/training/cli.py @@ -9,7 +9,12 @@ from .generate_gt_for_training import main as generate_gt_cli from .inference import main as inference_cli from .train import ex from .extract_line_gt import linegt_cli +<<<<<<< HEAD from .weights_ensembling import ensemble_cli +======= +from .weights_ensembling import main as ensemble_cli +from .generate_or_update_cnn_rnn_ocr_character_list import main as update_ocr_characters_cli +>>>>>>> integrating_trocr_and_torch_ensembling_and_updating_characters_list @click.command(context_settings=dict( ignore_unknown_options=True, @@ -28,3 +33,4 @@ main.add_command(inference_cli, 'inference') main.add_command(train_cli, 'train') main.add_command(linegt_cli, 'export_textline_images_and_text') main.add_command(ensemble_cli, 'ensembling') +main.add_command(update_ocr_characters_cli, 'generate_or_update_cnn_rnn_ocr_character_list') diff --git a/src/eynollah/training/extract_line_gt.py b/src/eynollah/training/extract_line_gt.py index 58fc253..4600e79 100644 --- a/src/eynollah/training/extract_line_gt.py +++ b/src/eynollah/training/extract_line_gt.py @@ -50,6 +50,12 @@ from ..utils import is_image_filename is_flag=True, help="if this parameter set to true, cropped textline images will not be masked with textline contour.", ) +@click.option( + "--exclude_vertical_lines", + "-exv", + is_flag=True, + help="if this parameter set to true, vertical textline images will be excluded.", +) def linegt_cli( image, dir_in, @@ -57,6 +63,7 @@ def linegt_cli( dir_out, pref_of_dataset, do_not_mask_with_textline_contour, + exclude_vertical_lines, ): assert bool(dir_in) ^ bool(image), "Set --dir-in or --image-filename, not both" if dir_in: @@ -70,14 +77,13 @@ def linegt_cli( for dir_img in ls_imgs: file_name = Path(dir_img).stem dir_xml = os.path.join(dir_xmls, file_name + '.xml') - img = cv2.imread(dir_img) - + total_bb_coordinates = [] - tree1 = ET.parse(dir_xml, parser=ET.XMLParser(encoding="utf-8")) - root1 = tree1.getroot() - alltags = [elem.tag for elem in root1.iter()] + tree = ET.parse(dir_xml, parser=ET.XMLParser(encoding="utf-8")) + root = tree.getroot() + alltags = [elem.tag for elem in root.iter()] name_space = alltags[0].split('}')[0] name_space = name_space.split('{')[1] @@ -89,7 +95,7 @@ def linegt_cli( indexer_text_region = 0 indexer_textlines = 0 # FIXME: non recursive, use OCR-D PAGE generateDS API. Or use an existing tool for this purpose altogether - for nn in root1.iter(region_tags): + for nn in root.iter(region_tags): for child_textregion in nn: if child_textregion.tag.endswith("TextLine"): for child_textlines in child_textregion: @@ -99,6 +105,10 @@ def linegt_cli( textline_coords = np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h]) x, y, w, h = cv2.boundingRect(textline_coords) + + if exclude_vertical_lines and h > 1.4 * w: + img_crop = None + continue total_bb_coordinates.append([x, y, w, h]) @@ -114,12 +124,15 @@ def linegt_cli( img_crop[mask_poly == 0] = 255 if img_crop.shape[0] == 0 or img_crop.shape[1] == 0: + img_crop = None continue + + if child_textlines.tag.endswith("TextEquiv"): for cheild_text in child_textlines: if cheild_text.tag.endswith("Unicode"): textline_text = cheild_text.text - if textline_text: + if textline_text and img_crop is not None: base_name = os.path.join( dir_out, file_name + '_line_' + str(indexer_textlines) ) @@ -131,4 +144,4 @@ def linegt_cli( with open(base_name + '.txt', 'w') as text_file: text_file.write(textline_text) cv2.imwrite(base_name + '.png', img_crop) - indexer_textlines += 1 + indexer_textlines += 1 diff --git a/src/eynollah/training/generate_gt_for_training.py b/src/eynollah/training/generate_gt_for_training.py index cc5a1b2..a848b65 100644 --- a/src/eynollah/training/generate_gt_for_training.py +++ b/src/eynollah/training/generate_gt_for_training.py @@ -6,6 +6,7 @@ from pathlib import Path from PIL import Image, ImageDraw, ImageFont import cv2 import numpy as np +from eynollah.utils.font import get_font from .gt_gen_utils import ( filter_contours_area_of_image, @@ -393,11 +394,15 @@ def visualize_reading_order(xml_file, dir_xml, dir_out, dir_imgs): layout = np.zeros( (y_len,x_len,3) ) layout = cv2.fillPoly(layout, pts =co_text_all, color=(1,1,1)) - img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name) - img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format)) - - overlayed = overlay_layout_on_image(layout, img, cx_ordered, cy_ordered, color, thickness) - cv2.imwrite(os.path.join(dir_out, f_name+'.png'), overlayed) + try: + img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name) + img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format)) + + overlayed = overlay_layout_on_image(layout, img, cx_ordered, cy_ordered, color, thickness) + cv2.imwrite(os.path.join(dir_out, f_name+'.png'), overlayed) + except: + pass + else: img = np.zeros( (y_len,x_len,3) ) @@ -452,14 +457,17 @@ def visualize_textline_segmentation(xml_file, dir_xml, dir_out, dir_imgs): xml_file = os.path.join(dir_xml,ind_xml ) f_name = Path(ind_xml).stem - img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name) - img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format)) + try: + img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name) + img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format)) + + co_tetxlines, y_len, x_len = get_textline_contours_for_visualization(xml_file) - co_tetxlines, y_len, x_len = get_textline_contours_for_visualization(xml_file) - - added_image = visualize_image_from_contours(co_tetxlines, img) - - cv2.imwrite(os.path.join(dir_out, f_name+'.png'), added_image) + added_image = visualize_image_from_contours(co_tetxlines, img) + + cv2.imwrite(os.path.join(dir_out, f_name+'.png'), added_image) + except: + pass @@ -509,15 +517,17 @@ def visualize_layout_segmentation(xml_file, dir_xml, dir_out, dir_imgs): f_name = Path(ind_xml).stem print(f_name, 'f_name') - img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name) - img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format)) + try: + img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name) + img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format)) + + co_text, co_graphic, co_sep, co_img, co_table, co_map, co_music, co_noise, y_len, x_len = get_layout_contours_for_visualization(xml_file) - co_text, co_graphic, co_sep, co_img, co_table, co_map, co_noise, y_len, x_len = get_layout_contours_for_visualization(xml_file) - - - added_image = visualize_image_from_contours_layout(co_text['paragraph'], co_text['header']+co_text['heading'], co_text['drop-capital'], co_sep, co_img, co_text['marginalia'], co_table, img) + added_image = visualize_image_from_contours_layout(co_text['paragraph'], co_text['header']+co_text['heading'], co_text['drop-capital'], co_sep, co_img, co_text['marginalia'], co_table, co_map, co_music, img) - cv2.imwrite(os.path.join(dir_out, f_name+'.png'), added_image) + cv2.imwrite(os.path.join(dir_out, f_name+'.png'), added_image) + except: + pass @@ -552,8 +562,8 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out): else: xml_files_ind = [xml_file] - font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists! - font = ImageFont.truetype(font_path, 40) + ###font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists! + font = get_font(font_size=40)#ImageFont.truetype(font_path, 40) for ind_xml in tqdm(xml_files_ind): indexer = 0 @@ -590,11 +600,11 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out): is_vertical = h > 2*w # Check orientation - font = fit_text_single_line(draw, ocr_texts[index], font_path, w, int(h*0.4) ) + font = fit_text_single_line(draw, ocr_texts[index], w, int(h*0.4) ) if is_vertical: - vertical_font = fit_text_single_line(draw, ocr_texts[index], font_path, h, int(w * 0.8)) + vertical_font = fit_text_single_line(draw, ocr_texts[index], h, int(w * 0.8)) text_img = Image.new("RGBA", (h, w), (255, 255, 255, 0)) # Note: dimensions are swapped text_draw = ImageDraw.Draw(text_img) diff --git a/src/eynollah/training/generate_or_update_cnn_rnn_ocr_character_list.py b/src/eynollah/training/generate_or_update_cnn_rnn_ocr_character_list.py new file mode 100644 index 0000000..8620515 --- /dev/null +++ b/src/eynollah/training/generate_or_update_cnn_rnn_ocr_character_list.py @@ -0,0 +1,59 @@ +import os +import numpy as np +import json +import click +import logging + + + +def run_character_list_update(dir_labels, out, current_character_list): + ls_labels = os.listdir(dir_labels) + ls_labels = [ind for ind in ls_labels if ind.endswith('.txt')] + + if current_character_list: + with open(current_character_list, 'r') as f_name: + characters = json.load(f_name) + + characters = set(characters) + else: + characters = set() + + + for ind in ls_labels: + label = open(os.path.join(dir_labels,ind),'r').read().split('\n')[0] + + for char in label: + characters.add(char) + + + characters = sorted(list(set(characters))) + + with open(out, 'w') as f_name: + json.dump(characters, f_name) + + +@click.command() +@click.option( + "--dir_labels", + "-dl", + help="directory of labels which are .txt files", + type=click.Path(exists=True, file_okay=False), + required=True, +) +@click.option( + "--current_character_list", + "-ccl", + help="existing character list in a .txt file that needs to be updated with a set of labels", + type=click.Path(exists=True, file_okay=True), + required=False, +) +@click.option( + "--out", + "-o", + help="An output .txt file where the generated or updated character list will be written", + type=click.Path(exists=False, file_okay=True), +) + +def main(dir_labels, out, current_character_list): + run_character_list_update(dir_labels, out, current_character_list) + diff --git a/src/eynollah/training/gt_gen_utils.py b/src/eynollah/training/gt_gen_utils.py index 796e896..473ee11 100644 --- a/src/eynollah/training/gt_gen_utils.py +++ b/src/eynollah/training/gt_gen_utils.py @@ -8,7 +8,7 @@ from shapely import geometry from pathlib import Path from PIL import ImageFont from ocrd_utils import bbox_from_points - +from eynollah.utils.font import get_font KERNEL = np.ones((5, 5), np.uint8) NS = { 'pc': 'http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15' @@ -18,7 +18,7 @@ with warnings.catch_warnings(): warnings.simplefilter("ignore") -def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_image, co_marginal, co_table, co_map, img): +def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_image, co_marginal, co_table, co_map, co_music, img): alpha = 0.5 blank_image = np.ones( (img.shape[:]), dtype=np.uint8) * 255 @@ -32,6 +32,7 @@ def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_ col_marginal = (106, 90, 205) col_table = (0, 90, 205) col_map = (90, 90, 205) + col_music = (90, 90, 0) if len(co_image)>0: cv2.drawContours(blank_image, co_image, -1, col_image, thickness=cv2.FILLED) # Fill the contour @@ -59,6 +60,9 @@ def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_ if len(co_map)>0: cv2.drawContours(blank_image, co_map, -1, col_map, thickness=cv2.FILLED) # Fill the contour + + if len(co_music)>0: + cv2.drawContours(blank_image, co_music, -1, col_music, thickness=cv2.FILLED) # Fill the contour img_final =cv2.cvtColor(blank_image, cv2.COLOR_BGR2RGB) @@ -352,11 +356,11 @@ def get_textline_contours_and_ocr_text(xml_file): ocr_textlines.append(ocr_text_in[0]) return co_use_case, y_len, x_len, ocr_textlines -def fit_text_single_line(draw, text, font_path, max_width, max_height): +def fit_text_single_line(draw, text, max_width, max_height): initial_font_size = 50 font_size = initial_font_size while font_size > 10: # Minimum font size - font = ImageFont.truetype(font_path, font_size) + font = get_font(font_size=font_size)# ImageFont.truetype(font_path, font_size) text_bbox = draw.textbbox((0, 0), text, font=font) # Get text bounding box text_width = text_bbox[2] - text_bbox[0] text_height = text_bbox[3] - text_bbox[1] @@ -366,7 +370,7 @@ def fit_text_single_line(draw, text, font_path, max_width, max_height): font_size -= 2 # Reduce font size and retry - return ImageFont.truetype(font_path, 10) # Smallest font fallback + return get_font(font_size=10)#ImageFont.truetype(font_path, 10) # Smallest font fallback def get_layout_contours_for_visualization(xml_file): tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8')) @@ -389,6 +393,7 @@ def get_layout_contours_for_visualization(xml_file): co_img=[] co_table=[] co_map=[] + co_music=[] co_noise=[] types_text = [] @@ -630,6 +635,31 @@ def get_layout_contours_for_visualization(xml_file): elif vv.tag!=link+'Point' and sumi>=1: break co_map.append(np.array(c_t_in)) + + if tag.endswith('}MusicRegion') or tag.endswith('}musicregion'): + #print('sth') + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + sumi+=1 + #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: + break + co_music.append(np.array(c_t_in)) if tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'): @@ -656,7 +686,7 @@ def get_layout_contours_for_visualization(xml_file): elif vv.tag!=link+'Point' and sumi>=1: break co_noise.append(np.array(c_t_in)) - return co_text, co_graphic, co_sep, co_img, co_table, co_map, co_noise, y_len, x_len + return co_text, co_graphic, co_sep, co_img, co_table, co_map, co_music, co_noise, y_len, x_len def get_images_of_ground_truth( gt_list, @@ -682,171 +712,193 @@ def get_images_of_ground_truth( if not item.endswith('.xml')} for index in tqdm(range(len(gt_list))): - #try: print(gt_list[index]) - tree1 = ET.parse(dir_in+'/'+gt_list[index], parser = ET.XMLParser(encoding='utf-8')) - root1=tree1.getroot() - alltags=[elem.tag for elem in root1.iter()] - link=alltags[0].split('}')[0]+'}' - - x_len, y_len = 0, 0 - for jj in root1.iter(link+'Page'): - y_len=int(jj.attrib['imageHeight']) - x_len=int(jj.attrib['imageWidth']) + try: + tree1 = ET.parse(dir_in+'/'+gt_list[index], parser = ET.XMLParser(encoding='utf-8')) + root1=tree1.getroot() + alltags=[elem.tag for elem in root1.iter()] + link=alltags[0].split('}')[0]+'}' + - if 'columns_width' in list(config_params.keys()): - columns_width_dict = config_params['columns_width'] - # FIXME: look in /Page/@custom as well - metadata_element = root1.find(link+'Metadata') - num_col = None - for child in metadata_element: - tag2 = child.tag - if tag2.endswith('}Comments') or tag2.endswith('}comments'): - text_comments = child.text - num_col = int(text_comments.split('num_col')[1]) + x_len, y_len = 0, 0 + for jj in root1.iter(link+'Page'): + y_len=int(jj.attrib['imageHeight']) + x_len=int(jj.attrib['imageWidth']) - if num_col: - x_new = columns_width_dict[str(num_col)] - y_new = int ( x_new * (y_len / float(x_len)) ) - - if printspace or "printspace_as_class_in_layout" in list(config_params.keys()): - ps = (root1.xpath('/pc:PcGts/pc:Page/pc:Border', namespaces=NS) + - root1.xpath('/pc:PcGts/pc:Page/pc:PrintSpace', namespaces=NS)) - coords = root1.xpath('//pc:Coords/@points', namespaces=NS) - if len(ps): - points = ps[0].find('pc:Coords', NS).get('points') - ps_bbox = bbox_from_points(points) - elif missing_printspace == 'skip': - print(gt_list[index], "has no Border or PrintSpace - skipping file") - continue - elif missing_printspace == 'project' and len(coords): - print(gt_list[index], "has no Border or PrintSpace - projecting hull of segments") - bboxes = list(map(bbox_from_points, coords)) - left, top, right, bottom = zip(*bboxes) - left = max(0, min(left) - 5) - top = max(0, min(top) - 5) - right = min(x_len, max(right) + 5) - bottom = min(y_len, max(bottom) + 5) - ps_bbox = [left, top, right, bottom] - else: - print(gt_list[index], "has no Border or PrintSpace - using full page") - ps_bbox = [0, 0, None, None] - - - if config_file and (config_params['use_case']=='textline' or config_params['use_case']=='word' or config_params['use_case']=='glyph' or config_params['use_case']=='printspace'): - keys = list(config_params.keys()) - if "artificial_class_label" in keys: - artificial_class_rgb_color = (255,255,0) - artificial_class_label = config_params['artificial_class_label'] - - textline_rgb_color = (255, 0, 0) - - if config_params['use_case']=='textline': - region_tags = np.unique([x for x in alltags if x.endswith('TextLine')]) - elif config_params['use_case']=='word': - region_tags = np.unique([x for x in alltags if x.endswith('Word')]) - elif config_params['use_case']=='glyph': - region_tags = np.unique([x for x in alltags if x.endswith('Glyph')]) - elif config_params['use_case']=='printspace': - region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace')]) - - co_use_case = [] - - for tag in region_tags: - if config_params['use_case']=='textline': - tag_endings = ['}TextLine','}textline'] - elif config_params['use_case']=='word': - tag_endings = ['}Word','}word'] - elif config_params['use_case']=='glyph': - tag_endings = ['}Glyph','}glyph'] - elif config_params['use_case']=='printspace': - tag_endings = ['}PrintSpace','}printspace'] + if 'columns_width' in list(config_params.keys()): + columns_width_dict = config_params['columns_width'] + metadata_element = root1.find(link+'Metadata') + num_col = None + for child in metadata_element: + tag2 = child.tag + if tag2.endswith('}Comments') or tag2.endswith('}comments'): + text_comments = child.text + num_col = int(text_comments.split('num_col')[1]) - if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]): - for nn in root1.iter(tag): - c_t_in = [] - sumi = 0 - for vv in nn.iter(): - # check the format of coords - if vv.tag == link + 'Coords': - coords = bool(vv.attrib) - if coords: - p_h = vv.attrib['points'].split(' ') - c_t_in.append( - np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h])) + if num_col: + x_new = columns_width_dict[str(num_col)] + y_new = int ( x_new * (y_len / float(x_len)) ) + + if printspace or "printspace_as_class_in_layout" in list(config_params.keys()): + region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace') or x.endswith('Border')]) + co_use_case = [] + + for tag in region_tags: + tag_endings = ['}PrintSpace','}Border'] + + if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]): + for nn in root1.iter(tag): + c_t_in = [] + sumi = 0 + for vv in nn.iter(): + # check the format of coords + if vv.tag == link + 'Coords': + coords = bool(vv.attrib) + if coords: + p_h = vv.attrib['points'].split(' ') + c_t_in.append( + np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h])) + break + else: + pass + + if vv.tag == link + 'Point': + c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))]) + sumi += 1 + elif vv.tag != link + 'Point' and sumi >= 1: break - else: - pass - - if vv.tag == link + 'Point': - c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))]) - sumi += 1 - elif vv.tag != link + 'Point' and sumi >= 1: - break - co_use_case.append(np.array(c_t_in)) - - - if "artificial_class_label" in keys: - img_boundary = np.zeros((y_len, x_len)) - erosion_rate = 0#1 - dilation_rate = 2 - dilation_early = 0 - erosion_early = 2 - co_use_case, img_boundary = update_region_contours(co_use_case, img_boundary, erosion_rate, dilation_rate, y_len, x_len, dilation_early=dilation_early, erosion_early=erosion_early) - + co_use_case.append(np.array(c_t_in)) + + img = np.zeros((y_len, x_len, 3)) - img = np.zeros((y_len, x_len, 3)) - if output_type == '2d': img_poly = cv2.fillPoly(img, pts=co_use_case, color=(1, 1, 1)) - if "artificial_class_label" in keys: - img_mask = np.copy(img_poly) - ##img_poly[:,:][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=1)] = artificial_class_label - img_poly[:,:][img_boundary[:,:]==1] = artificial_class_label - elif output_type == '3d': - img_poly = cv2.fillPoly(img, pts=co_use_case, color=textline_rgb_color) - if "artificial_class_label" in keys: - img_mask = np.copy(img_poly) - img_poly[:,:,0][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[0] - img_poly[:,:,1][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[1] - img_poly[:,:,2][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[2] - - - if printspace and config_params['use_case']!='printspace': - img_poly = img_poly[ps_bbox[1]:ps_bbox[3], - ps_bbox[0]:ps_bbox[2], :] + img_poly = img_poly.astype(np.uint8) - if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace': - img_poly = resize_image(img_poly, y_new, x_new) + imgray = cv2.cvtColor(img_poly, cv2.COLOR_BGR2GRAY) + _, thresh = cv2.threshold(imgray, 0, 255, 0) - try: - xml_file_stem = os.path.splitext(gt_list[index])[0] - cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly) - except: - xml_file_stem = os.path.splitext(gt_list[index])[0] - cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly) + contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) - if dir_images: - org_image_name = ls_org_imgs[xml_file_stem] - if not org_image_name: - print("image file for XML stem", xml_file_stem, "is missing") - continue - if not os.path.isfile(os.path.join(dir_images, org_image_name)): - print("image file for XML stem", xml_file_stem, "is not readable") - continue - img_org = cv2.imread(os.path.join(dir_images, org_image_name)) + cnt_size = np.array([cv2.contourArea(contours[j]) for j in range(len(contours))]) + try: + cnt = contours[np.argmax(cnt_size)] + x, y, w, h = cv2.boundingRect(cnt) + except: + x, y , w, h = 0, 0, x_len, y_len + + bb_xywh = [x, y, w, h] + + + if config_file and (config_params['use_case']=='textline' or config_params['use_case']=='word' or config_params['use_case']=='glyph' or config_params['use_case']=='printspace'): + keys = list(config_params.keys()) + if "artificial_class_label" in keys: + artificial_class_rgb_color = (255,255,0) + artificial_class_label = config_params['artificial_class_label'] + + textline_rgb_color = (255, 0, 0) + + if config_params['use_case']=='textline': + region_tags = np.unique([x for x in alltags if x.endswith('TextLine')]) + elif config_params['use_case']=='word': + region_tags = np.unique([x for x in alltags if x.endswith('Word')]) + elif config_params['use_case']=='glyph': + region_tags = np.unique([x for x in alltags if x.endswith('Glyph')]) + elif config_params['use_case']=='printspace': + region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace')]) + + co_use_case = [] + + for tag in region_tags: + if config_params['use_case']=='textline': + tag_endings = ['}TextLine','}textline'] + elif config_params['use_case']=='word': + tag_endings = ['}Word','}word'] + elif config_params['use_case']=='glyph': + tag_endings = ['}Glyph','}glyph'] + elif config_params['use_case']=='printspace': + tag_endings = ['}PrintSpace','}printspace'] + + if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]): + for nn in root1.iter(tag): + c_t_in = [] + sumi = 0 + for vv in nn.iter(): + # check the format of coords + if vv.tag == link + 'Coords': + coords = bool(vv.attrib) + if coords: + p_h = vv.attrib['points'].split(' ') + c_t_in.append( + np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h])) + break + else: + pass + + if vv.tag == link + 'Point': + c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))]) + sumi += 1 + elif vv.tag != link + 'Point' and sumi >= 1: + break + co_use_case.append(np.array(c_t_in)) + + + if "artificial_class_label" in keys: + img_boundary = np.zeros((y_len, x_len)) + erosion_rate = 0#1 + dilation_rate = 2 + dilation_early = 0 + erosion_early = 2 + co_use_case, img_boundary = update_region_contours(co_use_case, img_boundary, erosion_rate, dilation_rate, y_len, x_len, dilation_early=dilation_early, erosion_early=erosion_early) + + + img = np.zeros((y_len, x_len, 3)) + if output_type == '2d': + img_poly = cv2.fillPoly(img, pts=co_use_case, color=(1, 1, 1)) + if "artificial_class_label" in keys: + img_mask = np.copy(img_poly) + ##img_poly[:,:][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=1)] = artificial_class_label + img_poly[:,:][img_boundary[:,:]==1] = artificial_class_label + elif output_type == '3d': + img_poly = cv2.fillPoly(img, pts=co_use_case, color=textline_rgb_color) + if "artificial_class_label" in keys: + img_mask = np.copy(img_poly) + img_poly[:,:,0][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[0] + img_poly[:,:,1][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[1] + img_poly[:,:,2][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[2] + + if printspace and config_params['use_case']!='printspace': - img_org = img_org[ps_bbox[1]:ps_bbox[3], - ps_bbox[0]:ps_bbox[2], :] + img_poly = img_poly[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :] + if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace': - img_org = resize_image(img_org, y_new, x_new) - - cv2.imwrite(os.path.join(dir_out_images, org_image_name), img_org) + img_poly = resize_image(img_poly, y_new, x_new) - + try: + xml_file_stem = os.path.splitext(gt_list[index])[0] + cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly) + except: + xml_file_stem = os.path.splitext(gt_list[index])[0] + cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly) + + if dir_images: + org_image_name = ls_org_imgs[ls_org_imgs_stem.index(xml_file_stem)] + img_org = cv2.imread(os.path.join(dir_images, org_image_name)) + + if printspace and config_params['use_case']!='printspace': + img_org = img_org[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :] + + if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace': + img_org = resize_image(img_org, y_new, x_new) + + cv2.imwrite(os.path.join(dir_out_images, org_image_name), img_org) + + except: + pass + if config_file and config_params['use_case']=='layout': keys = list(config_params.keys()) @@ -870,7 +922,7 @@ def get_images_of_ground_truth( types_graphic_label = list(types_graphic_dict.values()) - labels_rgb_color = [ (0,0,0), (255,0,0), (255,125,0), (255,0,125), (125,255,125), (125,125,0), (0,125,255), (0,125,0), (125,125,125), (255,0,255), (125,0,125), (0,255,0),(0,0,255), (0,255,255), (255,125,125), (0,125,125), (0,255,125), (255,125,255), (125,255,0), (125,255,255)] + labels_rgb_color = [ (0,0,0), (255,0,0), (255,125,0), (255,0,125), (125,255,125), (125,125,0), (0,125,255), (0,125,0), (125,125,125), (255,0,255), (125,0,125), (0,255,0),(0,0,255), (0,255,255), (255,125,125), (0,125,125), (0,255,125), (255,125,255), (125,255,0), (125,255,255), (125,125,255)] region_tags=np.unique([x for x in alltags if x.endswith('Region')]) @@ -882,6 +934,7 @@ def get_images_of_ground_truth( co_img=[] co_table=[] co_map=[] + co_music=[] co_noise=[] for tag in region_tags: @@ -966,19 +1019,21 @@ def get_images_of_ground_truth( if "rest_as_decoration" in types_graphic: types_graphic_without_decoration = [element for element in types_graphic if element!='rest_as_decoration' and element!='decoration'] if len(types_graphic_without_decoration) == 0: - if "type" in nn.attrib: - c_t_in_graphic['decoration'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + #if "type" in nn.attrib: + c_t_in_graphic['decoration'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) elif len(types_graphic_without_decoration) >= 1: if "type" in nn.attrib: if nn.attrib['type'] in types_graphic_without_decoration: c_t_in_graphic[nn.attrib['type']].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) else: c_t_in_graphic['decoration'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - + else: + c_t_in_graphic['decoration'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) else: if "type" in nn.attrib: if nn.attrib['type'] in all_defined_graphic_types: - c_t_in_graphic[nn.attrib['type']].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + c_t_in_graphic[nn.attrib['type']].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break else: @@ -989,9 +1044,9 @@ def get_images_of_ground_truth( if "rest_as_decoration" in types_graphic: types_graphic_without_decoration = [element for element in types_graphic if element!='rest_as_decoration' and element!='decoration'] if len(types_graphic_without_decoration) == 0: - if "type" in nn.attrib: - c_t_in_graphic['decoration'].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) - sumi+=1 + #if "type" in nn.attrib: + c_t_in_graphic['decoration'].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) + sumi+=1 elif len(types_graphic_without_decoration) >= 1: if "type" in nn.attrib: if nn.attrib['type'] in types_graphic_without_decoration: @@ -1000,6 +1055,9 @@ def get_images_of_ground_truth( else: c_t_in_graphic['decoration'].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) sumi+=1 + else: + c_t_in_graphic['decoration'].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) + sumi+=1 else: if "type" in nn.attrib: @@ -1118,6 +1176,32 @@ def get_images_of_ground_truth( elif vv.tag!=link+'Point' and sumi>=1: break co_map.append(np.array(c_t_in)) + + if 'musicregion' in keys: + if tag.endswith('}MusicRegion') or tag.endswith('}musicregion'): + #print('sth') + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + sumi+=1 + #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: + break + co_music.append(np.array(c_t_in)) if 'noiseregion' in keys: if tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'): @@ -1195,6 +1279,10 @@ def get_images_of_ground_truth( erosion_rate = 0#2 dilation_rate = 3#4 co_map, img_boundary = update_region_contours(co_map, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + if "musicregion" in elements_with_artificial_class: + erosion_rate = 0#2 + dilation_rate = 3#4 + co_music, img_boundary = update_region_contours(co_music, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) @@ -1222,6 +1310,8 @@ def get_images_of_ground_truth( img_poly=cv2.fillPoly(img, pts =co_table, color=labels_rgb_color[ config_params['tableregion']]) if 'mapregion' in keys: img_poly=cv2.fillPoly(img, pts =co_map, color=labels_rgb_color[ config_params['mapregion']]) + if 'musicregion' in keys: + img_poly=cv2.fillPoly(img, pts =co_music, color=labels_rgb_color[ config_params['musicregion']]) if 'noiseregion' in keys: img_poly=cv2.fillPoly(img, pts =co_noise, color=labels_rgb_color[ config_params['noiseregion']]) @@ -1286,6 +1376,9 @@ def get_images_of_ground_truth( if 'mapregion' in keys: color_label = config_params['mapregion'] img_poly=cv2.fillPoly(img, pts =co_map, color=(color_label,color_label,color_label)) + if 'musicregion' in keys: + color_label = config_params['musicregion'] + img_poly=cv2.fillPoly(img, pts =co_music, color=(color_label,color_label,color_label)) if 'noiseregion' in keys: color_label = config_params['noiseregion'] img_poly=cv2.fillPoly(img, pts =co_noise, color=(color_label,color_label,color_label)) diff --git a/src/eynollah/training/inference.py b/src/eynollah/training/inference.py index 2be937d..c3f229c 100644 --- a/src/eynollah/training/inference.py +++ b/src/eynollah/training/inference.py @@ -9,9 +9,11 @@ import warnings import json import click + import numpy as np from numpy._typing import NDArray import cv2 + import xml.etree.ElementTree as ET os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15 @@ -119,8 +121,33 @@ class SBBPredict: return mIoU def start_new_session_and_model(self): - if self.cpu: - tf.config.set_visible_devices([], 'GPU') + if self.task == "cnn-rnn-ocr": + if self.cpu: + os.environ['CUDA_VISIBLE_DEVICES']='-1' + self.model = load_model(self.model_dir) + self.model = tf.keras.models.Model( + self.model.get_layer(name = "image").input, + self.model.get_layer(name = "dense2").output) + + assert isinstance(self.model, Model) + + elif self.task == "transformer-ocr": + import torch + from transformers import VisionEncoderDecoderModel + from transformers import TrOCRProcessor + + self.model = VisionEncoderDecoderModel.from_pretrained(self.model_dir) + self.processor = TrOCRProcessor.from_pretrained(self.model_dir) + + if self.cpu: + self.device = torch.device('cpu') + else: + self.device = torch.device('cuda:0') + + self.model.to(self.device) + + assert isinstance(self.model, torch.nn.Module) + else: try: for device in tf.config.list_physical_devices('GPU'): @@ -137,15 +164,13 @@ class SBBPredict: custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches}) - ##if self.weights_dir!=None: - ##self.model.load_weights(self.weights_dir) + if self.task != 'classification' and self.task != 'reading_order': + self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1] + self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2] + self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3] - assert isinstance(self.model, Model) - if self.task != 'classification' and self.task != 'reading_order': - last = self.model.layers[-1] - self.img_height = last.output_shape[1] - self.img_width = last.output_shape[2] - self.n_classes = last.output_shape[3] + + assert isinstance(self.model, Model) def visualize_model_output(self, prediction, img, task) -> Tuple[NDArray, NDArray]: if task == "binarization": @@ -191,9 +216,9 @@ class SBBPredict: return added_image, layout_only def predict(self, image_dir): - assert isinstance(self.model, Model) if self.task == 'classification': classes_names = self.config_params_model['classification_classes_name'] + img_1ch = cv2.imread(image_dir, 0) / 255.0 img_1ch = cv2.resize(img_1ch, (self.config_params_model['input_height'], self.config_params_model['input_width']), @@ -230,6 +255,15 @@ class SBBPredict: pred_texts = decode_batch_predictions(preds, num_to_char) pred_texts = pred_texts[0].replace("[UNK]", "") return pred_texts + + elif self.task == "transformer-ocr": + from PIL import Image + image = Image.open(image_dir).convert("RGB") + pixel_values = self.processor(image, return_tensors="pt").pixel_values + generated_ids = self.model.generate(pixel_values.to(self.device)) + return self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + + elif self.task == 'reading_order': @@ -566,6 +600,8 @@ class SBBPredict: cv2.imwrite(self.save,res) elif self.task == "cnn-rnn-ocr": print(f"Detected text: {res}") + elif self.task == "transformer-ocr": + print(f"Detected text: {res}") else: img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task) if self.save: @@ -668,11 +704,13 @@ class SBBPredict: help="min area size of regions considered for reading order detection. The default value is zero and means that all text regions are considered for reading order.", ) def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area): + assert image or dir_in, "Either a single image -i or a dir_in -di input is required" with open(os.path.join(model,'config.json')) as f: config_params_model = json.load(f) + task = config_params_model['task'] - if task not in ['classification', 'reading_order', "cnn-rnn-ocr"]: + if task not in ['classification', 'reading_order', "cnn-rnn-ocr", "transformer-ocr"]: assert not image or save, "For segmentation or binarization, an input single image -i also requires an output filename -s" assert not dir_in or out, "For segmentation or binarization, an input directory -di also requires an output directory -o" x = SBBPredict(image, dir_in, model, task, config_params_model, diff --git a/src/eynollah/training/models.py b/src/eynollah/training/models.py index 3494249..692218f 100644 --- a/src/eynollah/training/models.py +++ b/src/eynollah/training/models.py @@ -308,6 +308,8 @@ def transformer_block(img, x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1) # Skip connection 2. encoded_patches = Add()([x3, x2]) + + #assert isinstance(x, Layer) encoded_patches = tf.reshape(encoded_patches, [-1, diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py index de998fd..01d87bc 100644 --- a/src/eynollah/training/train.py +++ b/src/eynollah/training/train.py @@ -3,9 +3,14 @@ import sys import io import json + from tqdm import tqdm import requests +import numpy as np +import cv2 + + os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15 import tensorflow as tf @@ -47,12 +52,21 @@ from .utils import ( generate_arrays_from_folder_reading_order, get_one_hot, preprocess_imgs, + return_number_of_total_training_data, + OCRDatasetYieldAugmentations ) from .weights_ensembling import run_ensembling +import torch +from transformers import TrOCRProcessor +import evaluate +from transformers import default_data_collator +from transformers import VisionEncoderDecoderModel +from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments + class SaveWeightsAfterSteps(ModelCheckpoint): - def __init__(self, save_interval, save_path, _config, **kwargs): + def __init__(self, save_interval, save_path, _config, characters_cnnrnn_ocr=None, **kwargs): if save_interval: # batches super().__init__( @@ -67,12 +81,15 @@ class SaveWeightsAfterSteps(ModelCheckpoint): verbose=1, **kwargs) self._config = _config + self.characters_cnnrnn_ocr = characters_cnnrnn_ocr # overwrite tf-keras (Keras 2) implementation to get our _config JSON in def _save_handler(self, filepath): super()._save_handler(filepath) with open(os.path.join(filepath, "config.json"), "w") as fp: json.dump(self._config, fp) # encode dict into JSON + if self.characters_cnnrnn_ocr: + os.system("cp "+self.characters_cnnrnn_ocr+" "+os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"characters_org.txt")) def configuration(): try: @@ -820,9 +837,126 @@ def run(_config, usable_checkpoints = [os.path.join(dir_output, 'model_{epoch:02d}'.format(epoch=epoch + 1)) for epoch in usable_checkpoints] ens_path = os.path.join(dir_output, 'model_ens_avg') - run_ensembling(usable_checkpoints, ens_path) + run_ensembling(usable_checkpoints, ens_path, framework="tensorflow") _log.info("ensemble model saved under '%s'", ens_path) + # ======= + + + elif task=="transformer-ocr": + dir_img, dir_lab = get_dirs_or_files(dir_train) + + processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed") + + ls_files_images = os.listdir(dir_img) + + aug_multip = return_multiplier_based_on_augmnentations(augmentation, color_padding_rotation, rotation_not_90, blur_aug, degrading, bin_deg, + brightening, padding_white, adding_rgb_foreground, adding_rgb_background, binarization, + image_inversion, channels_shuffling, add_red_textlines, white_noise_strap, textline_skewing, textline_skewing_bin, textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth, textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin, textline_down_in_depth, textline_down_in_depth_bin, pepper_bin_aug, pepper_aug, degrade_scales, number_of_backgrounds_per_image, thetha, thetha_padd, brightness, padd_colors, shuffle_indexes, pepper_indexes, skewing_amplitudes, blur_k, white_padds) + + len_dataset = aug_multip*len(ls_files_images) + + dataset = OCRDatasetYieldAugmentations( + dir_img=dir_img, + dir_img_bin=dir_img_bin, + dir_lab=dir_lab, + processor=processor, + max_target_length=max_len, + augmentation = augmentation, + binarization = binarization, + add_red_textlines = add_red_textlines, + white_noise_strap = white_noise_strap, + adding_rgb_foreground = adding_rgb_foreground, + adding_rgb_background = adding_rgb_background, + bin_deg = bin_deg, + blur_aug = blur_aug, + brightening = brightening, + padding_white = padding_white, + color_padding_rotation = color_padding_rotation, + rotation_not_90 = rotation_not_90, + degrading = degrading, + channels_shuffling = channels_shuffling, + textline_skewing = textline_skewing, + textline_skewing_bin = textline_skewing_bin, + textline_right_in_depth = textline_right_in_depth, + textline_left_in_depth = textline_left_in_depth, + textline_up_in_depth = textline_up_in_depth, + textline_down_in_depth = textline_down_in_depth, + textline_right_in_depth_bin = textline_right_in_depth_bin, + textline_left_in_depth_bin = textline_left_in_depth_bin, + textline_up_in_depth_bin = textline_up_in_depth_bin, + textline_down_in_depth_bin = textline_down_in_depth_bin, + pepper_aug = pepper_aug, + pepper_bin_aug = pepper_bin_aug, + list_all_possible_background_images=list_all_possible_background_images, + list_all_possible_foreground_rgbs=list_all_possible_foreground_rgbs, + blur_k = blur_k, + degrade_scales = degrade_scales, + white_padds = white_padds, + thetha_padd = thetha_padd, + thetha = thetha, + brightness = brightness, + padd_colors = padd_colors, + number_of_backgrounds_per_image = number_of_backgrounds_per_image, + shuffle_indexes = shuffle_indexes, + pepper_indexes = pepper_indexes, + skewing_amplitudes = skewing_amplitudes, + dir_rgb_backgrounds = dir_rgb_backgrounds, + dir_rgb_foregrounds = dir_rgb_foregrounds, + len_data=len_dataset, + ) + + # Create a DataLoader + data_loader = torch.utils.data.DataLoader(dataset, batch_size=1) + train_dataset = data_loader.dataset + + + if continue_training: + model = VisionEncoderDecoderModel.from_pretrained(dir_of_start_model) + else: + model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed") + + + # set special tokens used for creating the decoder_input_ids from the labels + model.config.decoder_start_token_id = processor.tokenizer.cls_token_id + model.config.pad_token_id = processor.tokenizer.pad_token_id + # make sure vocab size is set correctly + model.config.vocab_size = model.config.decoder.vocab_size + + # set beam search parameters + model.config.eos_token_id = processor.tokenizer.sep_token_id + model.config.max_length = max_len + model.config.early_stopping = True + model.config.no_repeat_ngram_size = 3 + model.config.length_penalty = 2.0 + model.config.num_beams = 4 + + + training_args = Seq2SeqTrainingArguments( + predict_with_generate=True, + num_train_epochs=n_epochs, + learning_rate=learning_rate, + per_device_train_batch_size=n_batch, + fp16=True, + output_dir=dir_output, + logging_steps=2, + save_steps=save_interval, + ) + + + cer_metric = evaluate.load("cer") + + # instantiate trainer + trainer = Seq2SeqTrainer( + model=model, + tokenizer=processor.feature_extractor, + args=training_args, + train_dataset=train_dataset, + data_collator=default_data_collator, + ) + trainer.train() + + elif task=='reading_order': if continue_training: model = load_model(dir_of_start_model, compile=False) diff --git a/src/eynollah/training/utils.py b/src/eynollah/training/utils.py index 33a1fd2..3b4f1f9 100644 --- a/src/eynollah/training/utils.py +++ b/src/eynollah/training/utils.py @@ -14,6 +14,9 @@ import tensorflow as tf from PIL import Image, ImageFile, ImageEnhance +import torch +from torch.utils.data import IterableDataset + ImageFile.LOAD_TRUNCATED_IMAGES = True @@ -78,6 +81,7 @@ def add_salt_and_pepper_noise(img, salt_prob, pepper_prob): return noisy_image + def invert_image(img): img_inv = 255 - img return img_inv @@ -1242,3 +1246,411 @@ def preprocess_img_ocr( for pepper_ind in pepper_indexes: img_noisy = add_salt_and_pepper_noise(img_bin_corr, pepper_ind, pepper_ind) yield scale_image(img_noisy), lab + + +class OCRDatasetYieldAugmentations(IterableDataset): + def __init__( + self, + dir_img, + dir_img_bin, + dir_lab, + processor, + max_target_length=128, + augmentation = None, + binarization = None, + add_red_textlines = None, + white_noise_strap = None, + adding_rgb_foreground = None, + adding_rgb_background = None, + bin_deg = None, + blur_aug = None, + brightening = None, + padding_white = None, + color_padding_rotation = None, + rotation_not_90 = None, + degrading = None, + channels_shuffling = None, + textline_skewing = None, + textline_skewing_bin = None, + textline_right_in_depth = None, + textline_left_in_depth = None, + textline_up_in_depth = None, + textline_down_in_depth = None, + textline_right_in_depth_bin = None, + textline_left_in_depth_bin = None, + textline_up_in_depth_bin = None, + textline_down_in_depth_bin = None, + pepper_aug = None, + pepper_bin_aug = None, + list_all_possible_background_images=None, + list_all_possible_foreground_rgbs=None, + blur_k = None, + degrade_scales = None, + white_padds = None, + thetha_padd = None, + thetha = None, + brightness = None, + padd_colors = None, + number_of_backgrounds_per_image = None, + shuffle_indexes = None, + pepper_indexes = None, + skewing_amplitudes = None, + dir_rgb_backgrounds = None, + dir_rgb_foregrounds = None, + len_data=None, + ): + """ + Args: + images_dir (str): Path to the directory containing images. + labels_dir (str): Path to the directory containing label text files. + tokenizer: Tokenizer for processing labels. + transform: Transformations applied after augmentation (e.g., ToTensor, normalization). + image_size (tuple): Size to resize images to. + max_seq_len (int): Maximum sequence length for tokenized labels. + scales (list or None): List of scale factors to apply. + """ + self.dir_img = dir_img + self.dir_img_bin = dir_img_bin + self.dir_lab = dir_lab + self.processor = processor + self.max_target_length = max_target_length + #self.scales = scales if scales else [] + + self.augmentation = augmentation + self.binarization = binarization + self.add_red_textlines = add_red_textlines + self.white_noise_strap = white_noise_strap + self.adding_rgb_foreground = adding_rgb_foreground + self.adding_rgb_background = adding_rgb_background + self.bin_deg = bin_deg + self.blur_aug = blur_aug + self.brightening = brightening + self.padding_white = padding_white + self.color_padding_rotation = color_padding_rotation + self.rotation_not_90 = rotation_not_90 + self.degrading = degrading + self.channels_shuffling = channels_shuffling + self.textline_skewing = textline_skewing + self.textline_skewing_bin = textline_skewing_bin + self.textline_right_in_depth = textline_right_in_depth + self.textline_left_in_depth = textline_left_in_depth + self.textline_up_in_depth = textline_up_in_depth + self.textline_down_in_depth = textline_down_in_depth + self.textline_right_in_depth_bin = textline_right_in_depth_bin + self.textline_left_in_depth_bin = textline_left_in_depth_bin + self.textline_up_in_depth_bin = textline_up_in_depth_bin + self.textline_down_in_depth_bin = textline_down_in_depth_bin + self.pepper_aug = pepper_aug + self.pepper_bin_aug = pepper_bin_aug + self.list_all_possible_background_images=list_all_possible_background_images + self.list_all_possible_foreground_rgbs=list_all_possible_foreground_rgbs + self.blur_k = blur_k + self.degrade_scales = degrade_scales + self.white_padds = white_padds + self.thetha_padd = thetha_padd + self.thetha = thetha + self.brightness = brightness + self.padd_colors = padd_colors + self.number_of_backgrounds_per_image = number_of_backgrounds_per_image + self.shuffle_indexes = shuffle_indexes + self.pepper_indexes = pepper_indexes + self.skewing_amplitudes = skewing_amplitudes + self.dir_rgb_backgrounds = dir_rgb_backgrounds + self.dir_rgb_foregrounds = dir_rgb_foregrounds + self.image_files = os.listdir(dir_img)#sorted([f for f in os.listdir(images_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]) + self.len_data = len_data + #assert len(self.image_files) == len(self.label_files), "Number of images and labels must match!" + + def __len__(self): + return self.len_data + + def __iter__(self): + for img_file in self.image_files: + # Load image + f_name = img_file.split('.')[0] + + txt_inp = open(os.path.join(self.dir_lab, f_name+'.txt'),'r').read().split('\n')[0] + + img = cv2.imread(os.path.join(self.dir_img, img_file)) + img = img.astype(np.uint8) + + + if self.dir_img_bin: + img_bin_corr = cv2.imread(os.path.join(self.dir_img_bin, f_name+'.png') ) + img_bin_corr = img_bin_corr.astype(np.uint8) + else: + img_bin_corr = None + + + labels = self.processor.tokenizer(txt_inp, + padding="max_length", + max_length=self.max_target_length).input_ids + + labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels] + + + if self.augmentation: + pixel_values = self.processor(Image.fromarray(img), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + if self.color_padding_rotation: + for index, thetha_ind in enumerate(self.thetha_padd): + for padd_col in self.padd_colors: + img_out = rotation_not_90_func_single_image(do_padding_for_ocr(img, 1.2, padd_col), thetha_ind) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + if self.rotation_not_90: + for index, thetha_ind in enumerate(self.thetha): + img_out = rotation_not_90_func_single_image(img, thetha_ind) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + if self.blur_aug: + for index, blur_type in enumerate(self.blur_k): + img_out = bluring(img, blur_type) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + if self.degrading: + for index, deg_scale_ind in enumerate(self.degrade_scales): + try: + img_out = do_degrading(img, deg_scale_ind) + except: + img_out = np.copy(img) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + if self.bin_deg: + for index, deg_scale_ind in enumerate(self.degrade_scales): + try: + img_out = self.do_degrading(img_bin_corr, deg_scale_ind) + except: + img_out = np.copy(img_bin_corr) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + if self.brightening: + for index, bright_scale_ind in enumerate(self.brightness): + try: + img_out = do_brightening(dir_img, bright_scale_ind) + except: + img_out = np.copy(img) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + if self.padding_white: + for index, padding_size in enumerate(self.white_padds): + for padd_col in self.padd_colors: + img_out = do_padding_for_ocr(img, padding_size, padd_col) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + if self.adding_rgb_foreground: + for i_n in range(self.number_of_backgrounds_per_image): + background_image_chosen_name = random.choice(self.list_all_possible_background_images) + foreground_rgb_chosen_name = random.choice(self.list_all_possible_foreground_rgbs) + + img_rgb_background_chosen = cv2.imread(self.dir_rgb_backgrounds + '/' + background_image_chosen_name) + foreground_rgb_chosen = np.load(self.dir_rgb_foregrounds + '/' + foreground_rgb_chosen_name) + + img_out = return_binary_image_with_given_rgb_background_and_given_foreground_rgb(img_bin_corr, img_rgb_background_chosen, foreground_rgb_chosen) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + + + if self.adding_rgb_background: + for i_n in range(self.number_of_backgrounds_per_image): + background_image_chosen_name = random.choice(self.list_all_possible_background_images) + img_rgb_background_chosen = cv2.imread(self.dir_rgb_backgrounds + '/' + background_image_chosen_name) + img_out = return_binary_image_with_given_rgb_background(img_bin_corr, img_rgb_background_chosen) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + if self.binarization: + pixel_values = self.processor(Image.fromarray(img_bin_corr), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + if self.channels_shuffling: + for shuffle_index in self.shuffle_indexes: + img_out = return_shuffled_channels(img, shuffle_index) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + if self.add_red_textlines: + img_out = return_image_with_red_elements(img, img_bin_corr) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + if self.white_noise_strap: + img_out = return_image_with_strapped_white_noises(img) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + if self.textline_skewing: + for index, des_scale_ind in enumerate(self.skewing_amplitudes): + try: + img_out = do_deskewing(img, des_scale_ind) + except: + img_out = np.copy(img) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + if self.textline_skewing_bin: + for index, des_scale_ind in enumerate(self.skewing_amplitudes): + try: + img_out = do_deskewing(img_bin_corr, des_scale_ind) + except: + img_out = np.copy(img_bin_corr) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + if self.textline_left_in_depth: + try: + img_out = do_direction_in_depth(img, 'left') + except: + img_out = np.copy(img) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + if self.textline_left_in_depth_bin: + try: + img_out = do_direction_in_depth(img_bin_corr, 'left') + except: + img_out = np.copy(img_bin_corr) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + if self.textline_right_in_depth: + try: + img_out = do_direction_in_depth(img, 'right') + except: + img_out = np.copy(img) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + if self.textline_right_in_depth_bin: + try: + img_out = do_direction_in_depth(img_bin_corr, 'right') + except: + img_out = np.copy(img_bin_corr) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + if self.textline_up_in_depth: + try: + img_out = do_direction_in_depth(img, 'up') + except: + img_out = np.copy(img) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + if self.textline_up_in_depth_bin: + try: + img_out = do_direction_in_depth(img_bin_corr, 'up') + except: + img_out = np.copy(img_bin_corr) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + if self.textline_down_in_depth: + try: + img_out = do_direction_in_depth(img, 'down') + except: + img_out = np.copy(img) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + if self.textline_down_in_depth_bin: + try: + img_out = do_direction_in_depth(img_bin_corr, 'down') + except: + img_out = np.copy(img_bin_corr) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + if self.pepper_bin_aug: + for index, pepper_ind in enumerate(self.pepper_indexes): + img_out = add_salt_and_pepper_noise(img_bin_corr, pepper_ind, pepper_ind) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + if self.pepper_aug: + for index, pepper_ind in enumerate(self.pepper_indexes): + img_out = add_salt_and_pepper_noise(img, pepper_ind, pepper_ind) + img_out = img_out.astype(np.uint8) + pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding + + + + else: + pixel_values = self.processor(Image.fromarray(img), return_tensors="pt").pixel_values + encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)} + yield encoding diff --git a/src/eynollah/training/weights_ensembling.py b/src/eynollah/training/weights_ensembling.py index e3ede24..0a308ea 100644 --- a/src/eynollah/training/weights_ensembling.py +++ b/src/eynollah/training/weights_ensembling.py @@ -16,28 +16,53 @@ from ..patch_encoder import ( PatchEncoder, Patches, ) +from PIL import Image +import torch +from transformers import VisionEncoderDecoderModel -def run_ensembling(model_dirs, out_dir): - all_weights = [] - - for model_dir in model_dirs: - assert os.path.isdir(model_dir), model_dir - model = load_model(model_dir, compile=False, - custom_objects=dict(PatchEncoder=PatchEncoder, - Patches=Patches)) - all_weights.append(model.get_weights()) +def run_ensembling(dir_models, out, framework): + ls_models = os.listdir(dir_models) + if framework=="torch": + models = [] + sd_models = [] - new_weights = [] - for layer_weights in zip(*all_weights): - layer_weights = np.array([np.array(weights).mean(axis=0) - for weights in zip(*layer_weights)]) - new_weights.append(layer_weights) + for model_name in ls_models: + model = VisionEncoderDecoderModel.from_pretrained(os.path.join(dir_models,model_name)) + models.append(model) + sd_models.append(model.state_dict()) + for key in sd_models[0]: + sd_models[0][key] = sum(sd[key] for sd in sd_models) / len(sd_models) + + model.load_state_dict(sd_models[0]) + os.system("mkdir "+out) + torch.save(model.state_dict(), os.path.join(out, "pytorch_model.bin")) + os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config_eynollah.json ")+out) + + else: + weights=[] - #model = tf.keras.models.clone_model(model) - model.set_weights(new_weights) + for model_name in ls_models: + model = load_model(os.path.join(dir_models,model_name) , compile=False, custom_objects={'PatchEncoder':PatchEncoder, 'Patches': Patches}) + weights.append(model.get_weights()) + + new_weights = list() - model.save(out_dir) - os.system('cp ' + os.path.join(model_dirs[0], "config.json ") + out_dir + "/") + for weights_list_tuple in zip(*weights): + new_weights.append( + [np.array(weights_).mean(axis=0)\ + for weights_ in zip(*weights_list_tuple)]) + + + + new_weights = [np.array(x) for x in new_weights] + + model.set_weights(new_weights) + model.save(out) + os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config_eynollah.json ")+out) + try: + os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "characters_org.txt ")+out) + except: + pass @click.command() @click.option( @@ -55,12 +80,17 @@ def run_ensembling(model_dirs, out_dir): required=True, type=click.Path(exists=False, file_okay=False), ) -def ensemble_cli(in_, out): +@click.option( + "--framework", + "-fw", + help="this parameter gets tensorflow or torch as model framework", +) + +def ensemble_cli(in_, out, framework): """ mix multiple model weights Load a sequence of models and mix them into a single ensemble model by averaging their weights. Write the resulting model. """ - run_ensembling(in_, out) - + run_ensembling(in_, out, framework) diff --git a/src/eynollah/utils/font.py b/src/eynollah/utils/font.py index 939933e..3e9e588 100644 --- a/src/eynollah/utils/font.py +++ b/src/eynollah/utils/font.py @@ -9,8 +9,8 @@ else: import importlib.resources as importlib_resources -def get_font(): +def get_font(font_size): #font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists! - font = importlib_resources.files(__package__) / "../Charis-Regular.ttf" + font = importlib_resources.files(__package__) / "../Amiri-Regular.ttf" with importlib_resources.as_file(font) as font: - return ImageFont.truetype(font=font, size=40) + return ImageFont.truetype(font=font, size=font_size) diff --git a/train/config_params.json b/train/config_params.json index b01ac08..34c6376 100644 --- a/train/config_params.json +++ b/train/config_params.json @@ -1,17 +1,17 @@ { "backbone_type" : "transformer", - "task": "cnn-rnn-ocr", + "task": "transformer-ocr", "n_classes" : 2, - "max_len": 280, - "n_epochs" : 3, + "max_len": 192, + "n_epochs" : 1, "input_height" : 32, "input_width" : 512, "weight_decay" : 1e-6, - "n_batch" : 4, + "n_batch" : 1, "learning_rate": 1e-5, "save_interval": 1500, "patches" : false, - "pretraining" : true, + "pretraining" : false, "augmentation" : true, "flip_aug" : false, "blur_aug" : true, @@ -77,7 +77,6 @@ "dir_output": "/home/vahid/extracted_lines/1919_bin/output", "dir_rgb_backgrounds": "/home/vahid/Documents/1_2_test_eynollah/set_rgb_background", "dir_rgb_foregrounds": "/home/vahid/Documents/1_2_test_eynollah/out_set_rgb_foreground", - "dir_img_bin": "/home/vahid/extracted_lines/1919_bin/images_bin", - "characters_txt_file":"/home/vahid/Downloads/models_eynollah/model_eynollah_ocr_cnnrnn_20250930/characters_org.txt" + "dir_img_bin": "/home/vahid/extracted_lines/1919_bin/images_bin" } diff --git a/train/requirements.txt b/train/requirements.txt index 090bc50..9b8ee75 100644 --- a/train/requirements.txt +++ b/train/requirements.txt @@ -8,3 +8,8 @@ tensorflow-addons # for connected_components, depublished and only compatible wi tensorflow < 2.16 # for tensorflow-addons, so only needed in training tf_data < 2.16 # for tensorflow-addons, so only needed in training protobuf < 5 # for tensorflow-addons, so only needed in training +torch +evaluate +accelerate +jiwer +transformers <= 4.30.2