Konstantin Baierer 2026-02-19 12:59:25 +00:00 committed by GitHub
commit 38745111df
13 changed files with 875 additions and 113 deletions

View file

@@ -70,6 +70,7 @@ class Eynollah_ocr:
self.model_zoo.get('ocr').to(self.device)
else:
self.model_zoo.load_model('ocr', '')
self.input_shape = self.model_zoo.get('ocr').input_shape[1:3]
self.model_zoo.load_model('num_to_char')
self.model_zoo.load_model('characters')
self.end_character = len(self.model_zoo.get('characters', list)) + 2
@@ -657,7 +658,7 @@ class Eynollah_ocr:
if out_image_with_text:
image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white")
draw = ImageDraw.Draw(image_text)
font = get_font()
font = get_font(font_size=40)
for indexer_text, bb_ind in enumerate(total_bb_coordinates):
x_bb = bb_ind[0]
@@ -823,8 +824,8 @@ class Eynollah_ocr:
page_ns=page_ns,
img_bin=img_bin,
image_width=512,
image_height=32,
image_width=self.input_shape[1],
image_height=self.input_shape[0],
)
self.write_ocr(

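The hunks above replace the hard-coded 512x32 line geometry with the input shape read from the loaded OCR model. A minimal sketch of that pattern, assuming a standard Keras NHWC model (the path is illustrative):

import tensorflow as tf

# Load the CNN-RNN OCR model and read its expected line-image geometry.
model = tf.keras.models.load_model("path/to/ocr_model", compile=False)  # hypothetical path
input_height, input_width = model.input_shape[1:3]  # NHWC: (batch, H, W, C)
# Downstream calls then pass image_width=input_width, image_height=input_height.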
View file

@@ -10,6 +10,7 @@ from .inference import main as inference_cli
from .train import ex
from .extract_line_gt import linegt_cli
from .weights_ensembling import main as ensemble_cli
from .generate_or_update_cnn_rnn_ocr_character_list import main as update_ocr_characters_cli
@click.command(context_settings=dict(
ignore_unknown_options=True,
@@ -28,3 +29,4 @@ main.add_command(inference_cli, 'inference')
main.add_command(train_cli, 'train')
main.add_command(linegt_cli, 'export_textline_images_and_text')
main.add_command(ensemble_cli, 'ensembling')
main.add_command(update_ocr_characters_cli, 'generate_or_update_cnn_rnn_ocr_character_list')

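The newly registered subcommand can be exercised in-process with click's test runner; a hedged sketch (the module path of main is an assumption, the paths are placeholders):

from click.testing import CliRunner
from eynollah.training.cli import main  # module path is an assumption

runner = CliRunner()
result = runner.invoke(main, [
    "generate_or_update_cnn_rnn_ocr_character_list",
    "-dl", "/path/to/labels",
    "-o", "characters_updated.txt",
])
print(result.output)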
View file

@@ -50,6 +50,18 @@ from ..utils import is_image_filename
is_flag=True,
help="if this parameter set to true, cropped textline images will not be masked with textline contour.",
)
@click.option(
"--exclude_vertical_lines",
"-exv",
is_flag=True,
help="if this parameter set to true, vertical textline images will be excluded.",
)
@click.option(
"--page_alto",
"-alto",
is_flag=True,
help="If this parameter is set to True, text line image cropping and text extraction are performed using PAGE/ALTO files. Otherwise, the default method for PAGE XML files is used.",
)
def linegt_cli(
image,
dir_in,
@@ -57,6 +69,8 @@ def linegt_cli(
dir_out,
pref_of_dataset,
do_not_mask_with_textline_contour,
exclude_vertical_lines,
page_alto,
):
assert bool(dir_in) ^ bool(image), "Set --dir-in or --image-filename, not both"
if dir_in:
@@ -70,65 +84,149 @@
for dir_img in ls_imgs:
file_name = Path(dir_img).stem
dir_xml = os.path.join(dir_xmls, file_name + '.xml')
img = cv2.imread(dir_img)
if page_alto:
h, w = img.shape[:2]
tree = ET.parse(dir_xml)
root = tree.getroot()
NS = {"alto": "http://www.loc.gov/standards/alto/ns-v4#"}
results = []
indexer_textlines = 0
for line in root.findall(".//alto:TextLine", NS):
string_el = line.find("alto:String", NS)
textline_text = string_el.attrib["CONTENT"] if string_el is not None else None
polygon_el = line.find("alto:Shape/alto:Polygon", NS)
if polygon_el is None:
continue
points = polygon_el.attrib["POINTS"].split()
coords = [
(int(points[i]), int(points[i + 1]))
for i in range(0, len(points), 2)
]
coords = np.array(coords, dtype=np.int32)
x, y, w, h = cv2.boundingRect(coords)
if exclude_vertical_lines and h > 1.4 * w:
img_crop = None
continue
img_poly_on_img = np.copy(img)
mask_poly = np.zeros(img.shape)
mask_poly = cv2.fillPoly(mask_poly, pts=[coords], color=(1, 1, 1))
mask_poly = mask_poly[y : y + h, x : x + w, :]
img_crop = img_poly_on_img[y : y + h, x : x + w, :]
if not do_not_mask_with_textline_contour:
img_crop[mask_poly == 0] = 255
if img_crop.shape[0] == 0 or img_crop.shape[1] == 0:
img_crop = None
continue
if textline_text and img_crop is not None:
base_name = os.path.join(
dir_out, file_name + '_line_' + str(indexer_textlines)
)
if pref_of_dataset:
base_name += '_' + pref_of_dataset
if not do_not_mask_with_textline_contour:
base_name += '_masked'
with open(base_name + '.txt', 'w') as text_file:
text_file.write(textline_text)
cv2.imwrite(base_name + '.png', img_crop)
indexer_textlines += 1
else:
total_bb_coordinates = []
tree = ET.parse(dir_xml, parser=ET.XMLParser(encoding="utf-8"))
root = tree.getroot()
alltags = [elem.tag for elem in root.iter()]
name_space = alltags[0].split('}')[0]
name_space = name_space.split('{')[1]
region_tags = np.unique([x for x in alltags if x.endswith('TextRegion')])
cropped_lines_region_indexer = []
indexer_text_region = 0
indexer_textlines = 0
# FIXME: non-recursive; use the OCR-D PAGE generateDS API, or an existing tool for this purpose altogether
for nn in root.iter(region_tags):
for child_textregion in nn:
if child_textregion.tag.endswith("TextLine"):
for child_textlines in child_textregion:
if child_textlines.tag.endswith("Coords"):
cropped_lines_region_indexer.append(indexer_text_region)
p_h = child_textlines.attrib['points'].split(' ')
textline_coords = np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h])
x, y, w, h = cv2.boundingRect(textline_coords)
if exclude_vertical_lines and h > 1.4 * w:
img_crop = None
continue
total_bb_coordinates.append([x, y, w, h])
img_poly_on_img = np.copy(img)
mask_poly = np.zeros(img.shape)
mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1))
mask_poly = mask_poly[y : y + h, x : x + w, :]
img_crop = img_poly_on_img[y : y + h, x : x + w, :]
if not do_not_mask_with_textline_contour:
img_crop[mask_poly == 0] = 255
if img_crop.shape[0] == 0 or img_crop.shape[1] == 0:
img_crop = None
continue
if child_textlines.tag.endswith("TextEquiv"):
for cheild_text in child_textlines:
if cheild_text.tag.endswith("Unicode"):
textline_text = cheild_text.text
if textline_text and img_crop is not None:
base_name = os.path.join(
dir_out, file_name + '_line_' + str(indexer_textlines)
)
if pref_of_dataset:
base_name += '_' + pref_of_dataset
if not do_not_mask_with_textline_contour:
base_name += '_masked'
with open(base_name + '.txt', 'w') as text_file:
text_file.write(textline_text)
cv2.imwrite(base_name + '.png', img_crop)
indexer_textlines += 1

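For reference, the ALTO branch above boils down to this standalone parsing sketch (ALTO v4 namespace as in the diff; the input filename is illustrative):

import xml.etree.ElementTree as ET

NS = {"alto": "http://www.loc.gov/standards/alto/ns-v4#"}
root = ET.parse("page.alto.xml").getroot()  # hypothetical input
for line in root.findall(".//alto:TextLine", NS):
    string_el = line.find("alto:String", NS)
    text = string_el.attrib["CONTENT"] if string_el is not None else None
    polygon_el = line.find("alto:Shape/alto:Polygon", NS)
    if polygon_el is None or not text:
        continue
    # ALTO stores the polygon as a flat "x1 y1 x2 y2 ..." list.
    pts = polygon_el.attrib["POINTS"].split()
    coords = [(int(pts[i]), int(pts[i + 1])) for i in range(0, len(pts), 2)]
    print(text, coords)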
View file

@@ -6,6 +6,7 @@ from pathlib import Path
from PIL import Image, ImageDraw, ImageFont
import cv2
import numpy as np
from eynollah.utils.font import get_font
from eynollah.training.gt_gen_utils import (
filter_contours_area_of_image,
@@ -477,7 +478,7 @@ def visualize_layout_segmentation(xml_file, dir_xml, dir_out, dir_imgs):
co_text, co_graphic, co_sep, co_img, co_table, co_map, co_noise, y_len, x_len = get_layout_contours_for_visualization(xml_file)
added_image = visualize_image_from_contours_layout(co_text['paragraph'], co_text['header']+co_text['heading'], co_text['drop-capital'], co_sep, co_img, co_text['marginalia'], co_table, img)
added_image = visualize_image_from_contours_layout(co_text['paragraph'], co_text['header']+co_text['heading'], co_text['drop-capital'], co_sep, co_img, co_text['marginalia'], co_table, co_map, img)
cv2.imwrite(os.path.join(dir_out, f_name+'.png'), added_image)
@@ -514,8 +515,8 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out):
else:
xml_files_ind = [xml_file]
font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
font = ImageFont.truetype(font_path, 40)
###font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
font = get_font(font_size=40)#ImageFont.truetype(font_path, 40)
for ind_xml in tqdm(xml_files_ind):
indexer = 0
@@ -552,11 +553,11 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out):
is_vertical = h > 2*w # Check orientation
font = fit_text_single_line(draw, ocr_texts[index], font_path, w, int(h*0.4) )
font = fit_text_single_line(draw, ocr_texts[index], w, int(h*0.4) )
if is_vertical:
vertical_font = fit_text_single_line(draw, ocr_texts[index], font_path, h, int(w * 0.8))
vertical_font = fit_text_single_line(draw, ocr_texts[index], h, int(w * 0.8))
text_img = Image.new("RGBA", (h, w), (255, 255, 255, 0)) # Note: dimensions are swapped
text_draw = ImageDraw.Draw(text_img)

View file

@@ -0,0 +1,59 @@
import os
import numpy as np
import json
import click
import logging
def run_character_list_update(dir_labels, out, current_character_list):
ls_labels = os.listdir(dir_labels)
ls_labels = [ind for ind in ls_labels if ind.endswith('.txt')]
if current_character_list:
with open(current_character_list, 'r') as f_name:
characters = json.load(f_name)
characters = set(characters)
else:
characters = set()
for ind in ls_labels:
label = open(os.path.join(dir_labels,ind),'r').read().split('\n')[0]
for char in label:
characters.add(char)
characters = sorted(list(set(characters)))
with open(out, 'w') as f_name:
json.dump(characters, f_name)
@click.command()
@click.option(
"--dir_labels",
"-dl",
help="directory of labels which are .txt files",
type=click.Path(exists=True, file_okay=False),
required=True,
)
@click.option(
"--current_character_list",
"-ccl",
help="existing character list in a .txt file that needs to be updated with a set of labels",
type=click.Path(exists=True, file_okay=True),
required=False,
)
@click.option(
"--out",
"-o",
help="An output .txt file where the generated or updated character list will be written",
type=click.Path(exists=False, file_okay=True),
)
def main(dir_labels, out, current_character_list):
run_character_list_update(dir_labels, out, current_character_list)

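Note that despite the .txt suffix, the character list is stored as a JSON array (e.g. ["!", "0", "A", "a"]): existing characters are merged with every character seen in the label files and written back sorted. A usage sketch of the function above (the module path is an assumption):

from eynollah.training.generate_or_update_cnn_rnn_ocr_character_list import run_character_list_update  # assumed path

run_character_list_update(
    dir_labels="/path/to/labels",                 # directory of *.txt label files
    out="characters_updated.txt",                 # JSON array written here
    current_character_list="characters_org.txt",  # optional existing list to merge
)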
View file

@@ -7,7 +7,7 @@ import cv2
from shapely import geometry
from pathlib import Path
from PIL import ImageFont
from eynollah.utils.font import get_font
KERNEL = np.ones((5, 5), np.uint8)
@@ -350,11 +350,11 @@ def get_textline_contours_and_ocr_text(xml_file):
ocr_textlines.append(ocr_text_in[0])
return co_use_case, y_len, x_len, ocr_textlines
def fit_text_single_line(draw, text, font_path, max_width, max_height):
def fit_text_single_line(draw, text, max_width, max_height):
initial_font_size = 50
font_size = initial_font_size
while font_size > 10: # Minimum font size
font = ImageFont.truetype(font_path, font_size)
font = get_font(font_size=font_size)# ImageFont.truetype(font_path, font_size)
text_bbox = draw.textbbox((0, 0), text, font=font) # Get text bounding box
text_width = text_bbox[2] - text_bbox[0]
text_height = text_bbox[3] - text_bbox[1]
@@ -364,7 +364,7 @@ def fit_text_single_line(draw, text, font_path, max_width, max_height):
font_size -= 2 # Reduce font size and retry
return ImageFont.truetype(font_path, 10) # Smallest font fallback
return get_font(font_size=10)#ImageFont.truetype(font_path, 10) # Smallest font fallback
def get_layout_contours_for_visualization(xml_file):
tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8'))

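The refactored fit_text_single_line now sources fonts from get_font instead of a hard-coded path; its shrink-to-fit search is, in isolation, the following (a self-contained sketch with get_font passed in as a callable):

from PIL import Image, ImageDraw

def fit_single_line(draw, text, max_width, max_height, get_font):
    font_size = 50
    while font_size > 10:
        font = get_font(font_size=font_size)
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
        if right - left <= max_width and bottom - top <= max_height:
            return font  # first size that fits both dimensions
        font_size -= 2   # shrink and retry
    return get_font(font_size=10)  # smallest-size fallback

# usage: draw = ImageDraw.Draw(Image.new("RGB", (800, 100), "white"))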
View file

@@ -12,6 +12,7 @@ from keras.models import Model, load_model
from keras import backend as K
import click
from tensorflow.python.keras import backend as tensorflow_backend
from tensorflow.keras.layers import StringLookup
import xml.etree.ElementTree as ET
from .gt_gen_utils import (
@@ -169,6 +170,25 @@ class sbb_predict:
self.model = tf.keras.models.Model(
self.model.get_layer(name = "image").input,
self.model.get_layer(name = "dense2").output)
assert isinstance(self.model, Model)
elif self.task == "transformer-ocr":
import torch
from transformers import VisionEncoderDecoderModel
from transformers import TrOCRProcessor
self.model = VisionEncoderDecoderModel.from_pretrained(self.model_dir)
self.processor = TrOCRProcessor.from_pretrained(self.model_dir)
if self.cpu:
self.device = torch.device('cpu')
else:
self.device = torch.device('cuda:0')
self.model.to(self.device)
assert isinstance(self.model, torch.nn.Module)
else:
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
@@ -176,15 +196,15 @@
session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
tensorflow_backend.set_session(session)
self.model = load_model(self.model_dir , compile=False,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
##if self.weights_dir!=None:
##self.model.load_weights(self.weights_dir)
if self.task != 'classification' and self.task != 'reading_order':
self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1]
self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2]
self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3]
assert isinstance(self.model, Model)
if self.task != 'classification' and self.task != 'reading_order':
self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1]
self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2]
self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3]
assert isinstance(self.model, Model)
def visualize_model_output(self, prediction, img, task) -> Tuple[NDArray, NDArray]:
if task == "binarization":
@@ -235,10 +255,9 @@ class sbb_predict:
return added_image, layout_only
def predict(self, image_dir):
assert isinstance(self.model, Model)
if self.task == 'classification':
classes_names = self.config_params_model['classification_classes_name']
img_1ch = img=cv2.imread(image_dir, 0)
img_1ch =cv2.imread(image_dir, 0)
img_1ch = img_1ch / 255.0
img_1ch = cv2.resize(img_1ch, (self.config_params_model['input_height'], self.config_params_model['input_width']), interpolation=cv2.INTER_NEAREST)
@@ -274,6 +293,15 @@ class sbb_predict:
pred_texts = pred_texts[0].replace("[UNK]", "")
return pred_texts
elif self.task == "transformer-ocr":
from PIL import Image
image = Image.open(image_dir).convert("RGB")
pixel_values = self.processor(image, return_tensors="pt").pixel_values
generated_ids = self.model.generate(pixel_values.to(self.device))
return self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
elif self.task == 'reading_order':
img_height = self.config_params_model['input_height']
@@ -607,6 +635,8 @@ class sbb_predict:
cv2.imwrite(self.save,res)
elif self.task == "cnn-rnn-ocr":
print(f"Detected text: {res}")
elif self.task == "transformer-ocr":
print(f"Detected text: {res}")
else:
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
if self.save:
@@ -710,10 +740,12 @@ class sbb_predict:
)
def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area):
assert image or dir_in, "Either a single image -i or a dir_in -di is required"
with open(os.path.join(model,'config.json')) as f:
with open(os.path.join(model,'config_eynollah.json')) as f:
config_params_model = json.load(f)
task = config_params_model['task']
if task != 'classification' and task != 'reading_order' and task != "cnn-rnn-ocr":
if task != 'classification' and task != 'reading_order' and task != "cnn-rnn-ocr" and task != "transformer-ocr":
if image and not save:
print("Error: You used one of segmentation or binarization task with image input but not set -s, you need a filename to save visualized output with -s")
sys.exit(1)

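The new transformer-ocr inference branch follows the standard Hugging Face TrOCR recipe; a minimal standalone version (the checkpoint matches the one used for training in this commit, the input path is illustrative):

from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")

image = Image.open("line.png").convert("RGB")  # hypothetical line image
pixel_values = processor(image, return_tensors="pt").pixel_values
generated_ids = model.generate(pixel_values)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])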
View file

@@ -1,7 +1,8 @@
import os
import sys
import json
import numpy as np
import cv2
import click
from eynollah.training.metrics import (
@@ -27,7 +28,8 @@ from eynollah.training.utils import (
generate_data_from_folder_training,
get_one_hot,
provide_patches,
return_number_of_total_training_data
return_number_of_total_training_data,
OCRDatasetYieldAugmentations
)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
@@ -41,11 +43,16 @@ from sklearn.metrics import f1_score
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.layers import StringLookup
import numpy as np
import cv2
import torch
from transformers import TrOCRProcessor
import evaluate
from transformers import default_data_collator
from transformers import VisionEncoderDecoderModel
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
class SaveWeightsAfterSteps(Callback):
def __init__(self, save_interval, save_path, _config):
def __init__(self, save_interval, save_path, _config, characters_cnnrnn_ocr=None):
super(SaveWeightsAfterSteps, self).__init__()
self.save_interval = save_interval
self.save_path = save_path
@@ -61,7 +68,10 @@ class SaveWeightsAfterSteps(Callback):
self.model.save(save_file)
with open(os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"config.json"), "w") as fp:
if characters_cnnrnn_ocr:
os.system("cp "+characters_cnnrnn_ocr+" "+os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"characters_org.txt"))
with open(os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"config_eynollah.json"), "w") as fp:
json.dump(self._config, fp) # encode dict into JSON
print(f"saved model as steps {self.step_count} to {save_file}")
@@ -477,7 +487,7 @@ def run(
model.save(os.path.join(dir_output,'model_'+str(i)))
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config_eynollah.json"), "w") as fp:
json.dump(_config, fp) # encode dict into JSON
#os.system('rm -rf '+dir_train_flowing)
@@ -537,7 +547,7 @@ def run(
opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)#1e-4)#(lr_schedule)
model.compile(optimizer=opt)
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config) if save_interval else None
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config, characters_cnnrnn_ocr=characters_txt_file) if save_interval else None
for i in tqdm(range(index_start, n_epochs + index_start)):
if save_interval:
@@ -556,9 +566,125 @@ def run(
if i >=0:
model.save( os.path.join(dir_output,'model_'+str(i) ))
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
os.system("cp "+characters_txt_file+" "+os.path.join(os.path.join(dir_output,'model_'+str(i)),"characters_org.txt"))
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config_eynollah.json"), "w") as fp:
json.dump(_config, fp) # encode dict into JSON
elif task=="transformer-ocr":
dir_img, dir_lab = get_dirs_or_files(dir_train)
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
ls_files_images = os.listdir(dir_img)
aug_multip = return_multiplier_based_on_augmnentations(augmentation, color_padding_rotation, rotation_not_90, blur_aug, degrading, bin_deg,
brightening, padding_white, adding_rgb_foreground, adding_rgb_background, binarization,
image_inversion, channels_shuffling, add_red_textlines, white_noise_strap, textline_skewing, textline_skewing_bin, textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth, textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin, textline_down_in_depth, textline_down_in_depth_bin, pepper_bin_aug, pepper_aug, degrade_scales, number_of_backgrounds_per_image, thetha, thetha_padd, brightness, padd_colors, shuffle_indexes, pepper_indexes, skewing_amplitudes, blur_k, white_padds)
len_dataset = aug_multip*len(ls_files_images)
dataset = OCRDatasetYieldAugmentations(
dir_img=dir_img,
dir_img_bin=dir_img_bin,
dir_lab=dir_lab,
processor=processor,
max_target_length=max_len,
augmentation = augmentation,
binarization = binarization,
add_red_textlines = add_red_textlines,
white_noise_strap = white_noise_strap,
adding_rgb_foreground = adding_rgb_foreground,
adding_rgb_background = adding_rgb_background,
bin_deg = bin_deg,
blur_aug = blur_aug,
brightening = brightening,
padding_white = padding_white,
color_padding_rotation = color_padding_rotation,
rotation_not_90 = rotation_not_90,
degrading = degrading,
channels_shuffling = channels_shuffling,
textline_skewing = textline_skewing,
textline_skewing_bin = textline_skewing_bin,
textline_right_in_depth = textline_right_in_depth,
textline_left_in_depth = textline_left_in_depth,
textline_up_in_depth = textline_up_in_depth,
textline_down_in_depth = textline_down_in_depth,
textline_right_in_depth_bin = textline_right_in_depth_bin,
textline_left_in_depth_bin = textline_left_in_depth_bin,
textline_up_in_depth_bin = textline_up_in_depth_bin,
textline_down_in_depth_bin = textline_down_in_depth_bin,
pepper_aug = pepper_aug,
pepper_bin_aug = pepper_bin_aug,
list_all_possible_background_images=list_all_possible_background_images,
list_all_possible_foreground_rgbs=list_all_possible_foreground_rgbs,
blur_k = blur_k,
degrade_scales = degrade_scales,
white_padds = white_padds,
thetha_padd = thetha_padd,
thetha = thetha,
brightness = brightness,
padd_colors = padd_colors,
number_of_backgrounds_per_image = number_of_backgrounds_per_image,
shuffle_indexes = shuffle_indexes,
pepper_indexes = pepper_indexes,
skewing_amplitudes = skewing_amplitudes,
dir_rgb_backgrounds = dir_rgb_backgrounds,
dir_rgb_foregrounds = dir_rgb_foregrounds,
len_data=len_dataset,
)
# Create a DataLoader
data_loader = torch.utils.data.DataLoader(dataset, batch_size=1)
train_dataset = data_loader.dataset
if continue_training:
model = VisionEncoderDecoderModel.from_pretrained(dir_of_start_model)
else:
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size
# set beam search parameters
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = max_len
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4
training_args = Seq2SeqTrainingArguments(
predict_with_generate=True,
num_train_epochs=n_epochs,
learning_rate=learning_rate,
per_device_train_batch_size=n_batch,
fp16=True,
output_dir=dir_output,
logging_steps=2,
save_steps=save_interval,
)
cer_metric = evaluate.load("cer")
# instantiate trainer
trainer = Seq2SeqTrainer(
model=model,
tokenizer=processor.feature_extractor,
args=training_args,
train_dataset=train_dataset,
data_collator=default_data_collator,
)
trainer.train()
elif task=='classification':
configuration()
model = resnet50_classifier(n_classes, input_height, input_width, weight_decay, pretraining)
@@ -609,10 +735,10 @@ def run(
model_weight_averaged.set_weights(new_weights)
model_weight_averaged.save(os.path.join(dir_output,'model_ens_avg'))
with open(os.path.join( os.path.join(dir_output,'model_ens_avg'), "config.json"), "w") as fp:
with open(os.path.join( os.path.join(dir_output,'model_ens_avg'), "config_eynollah.json"), "w") as fp:
json.dump(_config, fp) # encode dict into JSON
with open(os.path.join( os.path.join(dir_output,'model_best'), "config.json"), "w") as fp:
with open(os.path.join( os.path.join(dir_output,'model_best'), "config_eynollah.json"), "w") as fp:
json.dump(_config, fp) # encode dict into JSON
elif task=='reading_order':
@@ -645,7 +771,7 @@ def run(
history = model.fit(generate_arrays_from_folder_reading_order(dir_flow_train_labels, dir_flow_train_imgs, n_batch, input_height, input_width, n_classes, thetha, augmentation), steps_per_epoch=num_rows / n_batch, verbose=1)
model.save( os.path.join(dir_output,'model_'+str(i+indexer_start) ))
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config_eynollah.json"), "w") as fp:
json.dump(_config, fp) # encode dict into JSON
'''
if f1score>f1score_tot[0]:

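The characters_org.txt copying above shells out via os.system("cp ..."), which breaks on paths containing spaces; a portable sketch of the same step (names taken from the diff, directory layout assumed):

import os
import shutil

def copy_characters_file(characters_txt_file, save_path, step_count):
    save_dir = os.path.join(save_path, f"model_step_{step_count}")
    os.makedirs(save_dir, exist_ok=True)
    shutil.copy(characters_txt_file, os.path.join(save_dir, "characters_org.txt"))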
View file

@@ -13,8 +13,12 @@ import tensorflow as tf
from tensorflow.keras.utils import to_categorical
from PIL import Image, ImageFile, ImageEnhance
import torch
from torch.utils.data import IterableDataset
ImageFile.LOAD_TRUNCATED_IMAGES = True
def vectorize_label(label, char_to_num, padding_token, max_len):
label = char_to_num(tf.strings.unicode_split(label, input_encoding="UTF-8"))
length = tf.shape(label)[0]
@@ -76,6 +80,7 @@ def add_salt_and_pepper_noise(img, salt_prob, pepper_prob):
return noisy_image
def invert_image(img):
img_inv = 255 - img
return img_inv
@@ -1668,3 +1673,411 @@ def return_multiplier_based_on_augmnentations(
aug_multip += len(pepper_indexes)
return aug_multip
class OCRDatasetYieldAugmentations(IterableDataset):
def __init__(
self,
dir_img,
dir_img_bin,
dir_lab,
processor,
max_target_length=128,
augmentation = None,
binarization = None,
add_red_textlines = None,
white_noise_strap = None,
adding_rgb_foreground = None,
adding_rgb_background = None,
bin_deg = None,
blur_aug = None,
brightening = None,
padding_white = None,
color_padding_rotation = None,
rotation_not_90 = None,
degrading = None,
channels_shuffling = None,
textline_skewing = None,
textline_skewing_bin = None,
textline_right_in_depth = None,
textline_left_in_depth = None,
textline_up_in_depth = None,
textline_down_in_depth = None,
textline_right_in_depth_bin = None,
textline_left_in_depth_bin = None,
textline_up_in_depth_bin = None,
textline_down_in_depth_bin = None,
pepper_aug = None,
pepper_bin_aug = None,
list_all_possible_background_images=None,
list_all_possible_foreground_rgbs=None,
blur_k = None,
degrade_scales = None,
white_padds = None,
thetha_padd = None,
thetha = None,
brightness = None,
padd_colors = None,
number_of_backgrounds_per_image = None,
shuffle_indexes = None,
pepper_indexes = None,
skewing_amplitudes = None,
dir_rgb_backgrounds = None,
dir_rgb_foregrounds = None,
len_data=None,
):
"""
Args:
dir_img (str): Path to the directory containing textline images.
dir_img_bin (str): Optional directory with binarized counterparts of the images.
dir_lab (str): Path to the directory containing label text files (one per image).
processor: TrOCR processor used to tokenize labels and preprocess images.
max_target_length (int): Maximum sequence length for tokenized labels.
len_data (int): Total number of samples, including augmentation multiples.
The remaining keyword arguments toggle individual augmentations and supply their parameters.
"""
self.dir_img = dir_img
self.dir_img_bin = dir_img_bin
self.dir_lab = dir_lab
self.processor = processor
self.max_target_length = max_target_length
#self.scales = scales if scales else []
self.augmentation = augmentation
self.binarization = binarization
self.add_red_textlines = add_red_textlines
self.white_noise_strap = white_noise_strap
self.adding_rgb_foreground = adding_rgb_foreground
self.adding_rgb_background = adding_rgb_background
self.bin_deg = bin_deg
self.blur_aug = blur_aug
self.brightening = brightening
self.padding_white = padding_white
self.color_padding_rotation = color_padding_rotation
self.rotation_not_90 = rotation_not_90
self.degrading = degrading
self.channels_shuffling = channels_shuffling
self.textline_skewing = textline_skewing
self.textline_skewing_bin = textline_skewing_bin
self.textline_right_in_depth = textline_right_in_depth
self.textline_left_in_depth = textline_left_in_depth
self.textline_up_in_depth = textline_up_in_depth
self.textline_down_in_depth = textline_down_in_depth
self.textline_right_in_depth_bin = textline_right_in_depth_bin
self.textline_left_in_depth_bin = textline_left_in_depth_bin
self.textline_up_in_depth_bin = textline_up_in_depth_bin
self.textline_down_in_depth_bin = textline_down_in_depth_bin
self.pepper_aug = pepper_aug
self.pepper_bin_aug = pepper_bin_aug
self.list_all_possible_background_images=list_all_possible_background_images
self.list_all_possible_foreground_rgbs=list_all_possible_foreground_rgbs
self.blur_k = blur_k
self.degrade_scales = degrade_scales
self.white_padds = white_padds
self.thetha_padd = thetha_padd
self.thetha = thetha
self.brightness = brightness
self.padd_colors = padd_colors
self.number_of_backgrounds_per_image = number_of_backgrounds_per_image
self.shuffle_indexes = shuffle_indexes
self.pepper_indexes = pepper_indexes
self.skewing_amplitudes = skewing_amplitudes
self.dir_rgb_backgrounds = dir_rgb_backgrounds
self.dir_rgb_foregrounds = dir_rgb_foregrounds
self.image_files = os.listdir(dir_img)#sorted([f for f in os.listdir(images_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
self.len_data = len_data
#assert len(self.image_files) == len(self.label_files), "Number of images and labels must match!"
def __len__(self):
return self.len_data
def __iter__(self):
for img_file in self.image_files:
# Load image
f_name = img_file.split('.')[0]
txt_inp = open(os.path.join(self.dir_lab, f_name+'.txt'),'r').read().split('\n')[0]
img = cv2.imread(os.path.join(self.dir_img, img_file))
img = img.astype(np.uint8)
if self.dir_img_bin:
img_bin_corr = cv2.imread(os.path.join(self.dir_img_bin, f_name+'.png') )
img_bin_corr = img_bin_corr.astype(np.uint8)
else:
img_bin_corr = None
labels = self.processor.tokenizer(txt_inp,
padding="max_length",
max_length=self.max_target_length).input_ids
labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]
if self.augmentation:
pixel_values = self.processor(Image.fromarray(img), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.color_padding_rotation:
for index, thetha_ind in enumerate(self.thetha_padd):
for padd_col in self.padd_colors:
img_out = rotation_not_90_func_single_image(do_padding_for_ocr(img, 1.2, padd_col), thetha_ind)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.rotation_not_90:
for index, thetha_ind in enumerate(self.thetha):
img_out = rotation_not_90_func_single_image(img, thetha_ind)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.blur_aug:
for index, blur_type in enumerate(self.blur_k):
img_out = bluring(img, blur_type)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.degrading:
for index, deg_scale_ind in enumerate(self.degrade_scales):
try:
img_out = do_degrading(img, deg_scale_ind)
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.bin_deg:
for index, deg_scale_ind in enumerate(self.degrade_scales):
try:
img_out = do_degrading(img_bin_corr, deg_scale_ind)
except:
img_out = np.copy(img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.brightening:
for index, bright_scale_ind in enumerate(self.brightness):
try:
img_out = do_brightening(os.path.join(self.dir_img, img_file), bright_scale_ind)
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.padding_white:
for index, padding_size in enumerate(self.white_padds):
for padd_col in self.padd_colors:
img_out = do_padding_for_ocr(img, padding_size, padd_col)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.adding_rgb_foreground:
for i_n in range(self.number_of_backgrounds_per_image):
background_image_chosen_name = random.choice(self.list_all_possible_background_images)
foreground_rgb_chosen_name = random.choice(self.list_all_possible_foreground_rgbs)
img_rgb_background_chosen = cv2.imread(self.dir_rgb_backgrounds + '/' + background_image_chosen_name)
foreground_rgb_chosen = np.load(self.dir_rgb_foregrounds + '/' + foreground_rgb_chosen_name)
img_out = return_binary_image_with_given_rgb_background_and_given_foreground_rgb(img_bin_corr, img_rgb_background_chosen, foreground_rgb_chosen)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.adding_rgb_background:
for i_n in range(self.number_of_backgrounds_per_image):
background_image_chosen_name = random.choice(self.list_all_possible_background_images)
img_rgb_background_chosen = cv2.imread(self.dir_rgb_backgrounds + '/' + background_image_chosen_name)
img_out = return_binary_image_with_given_rgb_background(img_bin_corr, img_rgb_background_chosen)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.binarization:
pixel_values = self.processor(Image.fromarray(img_bin_corr), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.channels_shuffling:
for shuffle_index in self.shuffle_indexes:
img_out = return_shuffled_channels(img, shuffle_index)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.add_red_textlines:
img_out = return_image_with_red_elements(img, img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.white_noise_strap:
img_out = return_image_with_strapped_white_noises(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_skewing:
for index, des_scale_ind in enumerate(self.skewing_amplitudes):
try:
img_out = do_deskewing(img, des_scale_ind)
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_skewing_bin:
for index, des_scale_ind in enumerate(self.skewing_amplitudes):
try:
img_out = do_deskewing(img_bin_corr, des_scale_ind)
except:
img_out = np.copy(img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_left_in_depth:
try:
img_out = do_direction_in_depth(img, 'left')
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_left_in_depth_bin:
try:
img_out = do_direction_in_depth(img_bin_corr, 'left')
except:
img_out = np.copy(img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_right_in_depth:
try:
img_out = do_direction_in_depth(img, 'right')
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_right_in_depth_bin:
try:
img_out = do_direction_in_depth(img_bin_corr, 'right')
except:
img_out = np.copy(img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_up_in_depth:
try:
img_out = do_direction_in_depth(img, 'up')
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_up_in_depth_bin:
try:
img_out = do_direction_in_depth(img_bin_corr, 'up')
except:
img_out = np.copy(img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_down_in_depth:
try:
img_out = do_direction_in_depth(img, 'down')
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_down_in_depth_bin:
try:
img_out = do_direction_in_depth(img_bin_corr, 'down')
except:
img_out = np.copy(img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.pepper_bin_aug:
for index, pepper_ind in enumerate(self.pepper_indexes):
img_out = add_salt_and_pepper_noise(img_bin_corr, pepper_ind, pepper_ind)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.pepper_aug:
for index, pepper_ind in enumerate(self.pepper_indexes):
img_out = add_salt_and_pepper_noise(img, pepper_ind, pepper_ind)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
else:
pixel_values = self.processor(Image.fromarray(img), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding

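Every augmentation branch in the class above ends in the same three steps: tokenize the label with pad tokens masked to -100 (so the loss ignores them), run the image through the TrOCR processor, and yield the encoding. A stripped-down sketch of just that core pattern (the sample list is illustrative):

import torch
from torch.utils.data import IterableDataset

class ToyOCRDataset(IterableDataset):
    def __init__(self, samples, processor, max_target_length=128):
        self.samples = samples  # list of (PIL.Image, str) pairs
        self.processor = processor
        self.max_target_length = max_target_length

    def __iter__(self):
        for image, text in self.samples:
            labels = self.processor.tokenizer(
                text, padding="max_length", max_length=self.max_target_length
            ).input_ids
            # Replace padding with -100 so those positions are ignored by the loss.
            labels = [l if l != self.processor.tokenizer.pad_token_id else -100
                      for l in labels]
            pixel_values = self.processor(image, return_tensors="pt").pixel_values
            yield {"pixel_values": pixel_values.squeeze(),
                   "labels": torch.tensor(labels)}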
View file

@ -21,6 +21,11 @@ from tensorflow.keras.layers import *
import click
import logging
from transformers import TrOCRProcessor
from PIL import Image
import torch
from transformers import VisionEncoderDecoderModel
class Patches(layers.Layer):
def __init__(self, patch_size_x, patch_size_y):
@@ -92,30 +97,46 @@ def start_new_session():
tensorflow_backend.set_session(session)
return session
def run_ensembling(dir_models, out):
def run_ensembling(dir_models, out, framework):
ls_models = os.listdir(dir_models)
if framework=="torch":
models = []
sd_models = []
for model_name in ls_models:
model = VisionEncoderDecoderModel.from_pretrained(os.path.join(dir_models,model_name))
models.append(model)
sd_models.append(model.state_dict())
for key in sd_models[0]:
sd_models[0][key] = sum(sd[key] for sd in sd_models) / len(sd_models)
model.load_state_dict(sd_models[0])
os.system("mkdir "+out)
torch.save(model.state_dict(), os.path.join(out, "pytorch_model.bin"))
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config_eynollah.json ")+out)
else:
weights=[]
for model_name in ls_models:
model = load_model(os.path.join(dir_models,model_name) , compile=False, custom_objects={'PatchEncoder':PatchEncoder, 'Patches': Patches})
weights.append(model.get_weights())
new_weights = list()
for weights_list_tuple in zip(*weights):
new_weights.append(
[np.array(weights_).mean(axis=0)\
for weights_ in zip(*weights_list_tuple)])
new_weights = [np.array(x) for x in new_weights]
model.set_weights(new_weights)
model.save(out)
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config_eynollah.json ")+out)
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "characters_org.txt ")+out)
@click.command()
@click.option(
@@ -130,7 +151,12 @@ def run_ensembling(dir_models, out):
help="output directory where ensembled model will be written.",
type=click.Path(exists=False, file_okay=False),
)
@click.option(
"--framework",
"-fw",
help="this parameter gets tensorflow or torch as model framework",
)
def main(dir_models, out):
run_ensembling(dir_models, out)
def main(dir_models, out, framework):
run_ensembling(dir_models, out, framework)

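The torch branch above averages checkpoints by taking the element-wise mean of their state dicts; in isolation (assumes architecturally identical models, paths illustrative):

import torch

def average_state_dicts(state_dicts):
    # Element-wise mean over all checkpoints, key by key.
    return {key: sum(sd[key] for sd in state_dicts) / len(state_dicts)
            for key in state_dicts[0]}

# usage sketch:
# models = [VisionEncoderDecoderModel.from_pretrained(p) for p in checkpoint_paths]
# models[0].load_state_dict(average_state_dicts([m.state_dict() for m in models]))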
View file

@@ -9,8 +9,8 @@ else:
import importlib.resources as importlib_resources
def get_font():
def get_font(font_size):
#font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
font = importlib_resources.files(__package__) / "../Charis-Regular.ttf"
with importlib_resources.as_file(font) as font:
return ImageFont.truetype(font=font, size=40)
return ImageFont.truetype(font=font, size=font_size)

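With the size now parameterized, callers pick it per call site, e.g.:

from eynollah.utils.font import get_font

font_page = get_font(font_size=40)      # page-level OCR visualization
font_fallback = get_font(font_size=10)  # shrink-to-fit fallback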
View file

@@ -1,17 +1,17 @@
{
"backbone_type" : "transformer",
"task": "cnn-rnn-ocr",
"task": "transformer-ocr",
"n_classes" : 2,
"max_len": 280,
"n_epochs" : 3,
"max_len": 192,
"n_epochs" : 1,
"input_height" : 32,
"input_width" : 512,
"weight_decay" : 1e-6,
"n_batch" : 4,
"n_batch" : 1,
"learning_rate": 1e-5,
"save_interval": 1500,
"patches" : false,
"pretraining" : true,
"pretraining" : false,
"augmentation" : true,
"flip_aug" : false,
"blur_aug" : true,
@@ -77,7 +77,6 @@
"dir_output": "/home/vahid/extracted_lines/1919_bin/output",
"dir_rgb_backgrounds": "/home/vahid/Documents/1_2_test_eynollah/set_rgb_background",
"dir_rgb_foregrounds": "/home/vahid/Documents/1_2_test_eynollah/out_set_rgb_foreground",
"dir_img_bin": "/home/vahid/extracted_lines/1919_bin/images_bin",
"characters_txt_file":"/home/vahid/Downloads/models_eynollah/model_eynollah_ocr_cnnrnn_20250930/characters_org.txt"
"dir_img_bin": "/home/vahid/extracted_lines/1919_bin/images_bin"
}

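Judging from the train.py hunk in this commit (a reading of the diff, not documented behavior), the keys the new transformer-ocr branch actually consumes are roughly these; paths are placeholders:

{
  "task": "transformer-ocr",
  "max_len": 192,
  "n_epochs": 1,
  "n_batch": 1,
  "learning_rate": 1e-5,
  "save_interval": 1500,
  "augmentation": true,
  "dir_train": "/path/to/train",
  "dir_output": "/path/to/output"
}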
View file

@@ -4,3 +4,8 @@ numpy <1.24.0
tqdm
imutils
scipy
torch
evaluate
accelerate
jiwer
transformers <= 4.30.2