diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py
index 7a0cb3d..7ed8282 100644
--- a/src/eynollah/training/train.py
+++ b/src/eynollah/training/train.py
@@ -1,7 +1,8 @@
 import os
 import sys
 import json
-
+import numpy as np
+import cv2
 import click
 
 from eynollah.training.metrics import (
@@ -27,7 +28,8 @@ from eynollah.training.utils import (
     generate_data_from_folder_training,
     get_one_hot,
     provide_patches,
-    return_number_of_total_training_data
+    return_number_of_total_training_data,
+    OCRDatasetYieldAugmentations
 )
 
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
@@ -41,8 +43,13 @@ from sklearn.metrics import f1_score
 from tensorflow.keras.callbacks import Callback
 from tensorflow.keras.layers import StringLookup
-import numpy as np
-import cv2
+
+import torch
+from transformers import TrOCRProcessor
+import evaluate
+from transformers import default_data_collator
+from transformers import VisionEncoderDecoderModel
+from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
 
 
 class SaveWeightsAfterSteps(Callback):
     def __init__(self, save_interval, save_path, _config):
@@ -559,6 +566,121 @@ def run(
             with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
                 json.dump(_config, fp) # encode dict into JSON
+
+    elif task=="transformer-ocr":
+        dir_img, dir_lab = get_dirs_or_files(dir_train)
+
+        processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
+
+        ls_files_images = os.listdir(dir_img)
+
+        aug_multip = return_multiplier_based_on_augmnentations(
+            augmentation, color_padding_rotation, rotation_not_90, blur_aug, degrading, bin_deg,
+            brightening, padding_white, adding_rgb_foreground, adding_rgb_background, binarization,
+            image_inversion, channels_shuffling, add_red_textlines, white_noise_strap,
+            textline_skewing, textline_skewing_bin, textline_left_in_depth, textline_left_in_depth_bin,
+            textline_right_in_depth, textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin,
+            textline_down_in_depth, textline_down_in_depth_bin, pepper_bin_aug, pepper_aug,
+            degrade_scales, number_of_backgrounds_per_image, thetha, thetha_padd, brightness, padd_colors,
+            shuffle_indexes, pepper_indexes, skewing_amplitudes, blur_k, white_padds)
+
+        len_dataset = aug_multip*len(ls_files_images)
+
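+        # every augmentation enabled below makes the dataset yield one extra variant per
+        # line image, so len_dataset is the image count times the augmentation multiplier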
+        dataset = OCRDatasetYieldAugmentations(
+            dir_img=dir_img,
+            dir_img_bin=dir_img_bin,
+            dir_lab=dir_lab,
+            processor=processor,
+            max_target_length=max_len,
+            augmentation = augmentation,
+            binarization = binarization,
+            add_red_textlines = add_red_textlines,
+            white_noise_strap = white_noise_strap,
+            adding_rgb_foreground = adding_rgb_foreground,
+            adding_rgb_background = adding_rgb_background,
+            bin_deg = bin_deg,
+            blur_aug = blur_aug,
+            brightening = brightening,
+            padding_white = padding_white,
+            color_padding_rotation = color_padding_rotation,
+            rotation_not_90 = rotation_not_90,
+            degrading = degrading,
+            channels_shuffling = channels_shuffling,
+            textline_skewing = textline_skewing,
+            textline_skewing_bin = textline_skewing_bin,
+            textline_right_in_depth = textline_right_in_depth,
+            textline_left_in_depth = textline_left_in_depth,
+            textline_up_in_depth = textline_up_in_depth,
+            textline_down_in_depth = textline_down_in_depth,
+            textline_right_in_depth_bin = textline_right_in_depth_bin,
+            textline_left_in_depth_bin = textline_left_in_depth_bin,
+            textline_up_in_depth_bin = textline_up_in_depth_bin,
+            textline_down_in_depth_bin = textline_down_in_depth_bin,
+            pepper_aug = pepper_aug,
+            pepper_bin_aug = pepper_bin_aug,
+            list_all_possible_background_images=list_all_possible_background_images,
+            list_all_possible_foreground_rgbs=list_all_possible_foreground_rgbs,
+            blur_k = blur_k,
+            degrade_scales = degrade_scales,
+            white_padds = white_padds,
+            thetha_padd = thetha_padd,
+            thetha = thetha,
+            brightness = brightness,
+            padd_colors = padd_colors,
+            number_of_backgrounds_per_image = number_of_backgrounds_per_image,
+            shuffle_indexes = shuffle_indexes,
+            pepper_indexes = pepper_indexes,
+            skewing_amplitudes = skewing_amplitudes,
+            dir_rgb_backgrounds = dir_rgb_backgrounds,
+            dir_rgb_foregrounds = dir_rgb_foregrounds,
+            len_data=len_dataset,
+        )
+
+        # Create a DataLoader
+        data_loader = torch.utils.data.DataLoader(dataset, batch_size=1)
+        train_dataset = data_loader.dataset
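+        # only the underlying dataset is handed to the Trainer below; batching and padding
+        # of the yielded samples are done by Seq2SeqTrainer together with default_data_collator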
+
+        if continue_training:
+            model = VisionEncoderDecoderModel.from_pretrained(dir_of_start_model)
+        else:
+            model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
+
+        # set special tokens used for creating the decoder_input_ids from the labels
+        model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
+        model.config.pad_token_id = processor.tokenizer.pad_token_id
+        # make sure vocab size is set correctly
+        model.config.vocab_size = model.config.decoder.vocab_size
+
+        # set beam search parameters
+        model.config.eos_token_id = processor.tokenizer.sep_token_id
+        model.config.max_length = max_len
+        model.config.early_stopping = True
+        model.config.no_repeat_ngram_size = 3
+        model.config.length_penalty = 2.0
+        model.config.num_beams = 4
+
+        training_args = Seq2SeqTrainingArguments(
+            predict_with_generate=True,
+            num_train_epochs=n_epochs,
+            learning_rate=learning_rate,
+            per_device_train_batch_size=n_batch,
+            fp16=True,
+            output_dir=dir_output,
+            logging_steps=2,
+            save_steps=save_interval,
+        )
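+        # note: fp16=True assumes training on a CUDA-capable GPU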
+
+        cer_metric = evaluate.load("cer")
+
+        # instantiate trainer
+        trainer = Seq2SeqTrainer(
+            model=model,
+            tokenizer=processor.feature_extractor,
+            args=training_args,
+            train_dataset=train_dataset,
+            data_collator=default_data_collator,
+        )
+        trainer.train()
+
+
 
     elif task=='classification':
         configuration()
         model = resnet50_classifier(n_classes, input_height, input_width, weight_decay, pretraining)
diff --git a/src/eynollah/training/utils.py b/src/eynollah/training/utils.py
index 005810f..9b4e01a 100644
--- a/src/eynollah/training/utils.py
+++ b/src/eynollah/training/utils.py
@@ -9,12 +9,16 @@ from scipy.ndimage.interpolation import map_coordinates
 from scipy.ndimage.filters import gaussian_filter
 from tqdm import tqdm
 import imutils
-import tensorflow as tf
-from tensorflow.keras.utils import to_categorical
+##import tensorflow as tf
+##from tensorflow.keras.utils import to_categorical
 from PIL import Image, ImageFile, ImageEnhance
 
+import torch
+from torch.utils.data import IterableDataset
+
 ImageFile.LOAD_TRUNCATED_IMAGES = True
 
+
 def vectorize_label(label, char_to_num, padding_token, max_len):
     label = char_to_num(tf.strings.unicode_split(label, input_encoding="UTF-8"))
     length = tf.shape(label)[0]
@@ -76,6 +80,7 @@
     return noisy_image
 
 
+
 def invert_image(img):
     img_inv = 255 - img
     return img_inv
@@ -1668,3 +1673,411 @@
         aug_multip += len(pepper_indexes)
 
     return aug_multip
+
+
+class OCRDatasetYieldAugmentations(IterableDataset):
+    def __init__(
+        self,
+        dir_img,
+        dir_img_bin,
+        dir_lab,
+        processor,
+        max_target_length=128,
+        augmentation = None,
+        binarization = None,
+        add_red_textlines = None,
+        white_noise_strap = None,
+        adding_rgb_foreground = None,
+        adding_rgb_background = None,
+        bin_deg = None,
+        blur_aug = None,
+        brightening = None,
+        padding_white = None,
+        color_padding_rotation = None,
+        rotation_not_90 = None,
+        degrading = None,
+        channels_shuffling = None,
+        textline_skewing = None,
+        textline_skewing_bin = None,
+        textline_right_in_depth = None,
+        textline_left_in_depth = None,
+        textline_up_in_depth = None,
+        textline_down_in_depth = None,
+        textline_right_in_depth_bin = None,
+        textline_left_in_depth_bin = None,
+        textline_up_in_depth_bin = None,
+        textline_down_in_depth_bin = None,
+        pepper_aug = None,
+        pepper_bin_aug = None,
+        list_all_possible_background_images=None,
+        list_all_possible_foreground_rgbs=None,
+        blur_k = None,
+        degrade_scales = None,
+        white_padds = None,
+        thetha_padd = None,
+        thetha = None,
+        brightness = None,
+        padd_colors = None,
+        number_of_backgrounds_per_image = None,
+        shuffle_indexes = None,
+        pepper_indexes = None,
+        skewing_amplitudes = None,
+        dir_rgb_backgrounds = None,
+        dir_rgb_foregrounds = None,
+        len_data=None,
+    ):
+        """
+        Args:
+            dir_img (str): Path to the directory containing line images.
+            dir_img_bin (str or None): Path to the binarized counterparts of the images.
+            dir_lab (str): Path to the directory containing label text files.
+            processor: TrOCRProcessor used to tokenize labels and extract pixel values.
+            max_target_length (int): Maximum sequence length for tokenized labels.
+            augmentation, ... : Flags and parameter lists that control which augmented
+                variants of each line image are yielded in addition to the original.
+            len_data (int): Total number of samples, i.e. number of line images times
+                the augmentation multiplier.
+        """
+        self.dir_img = dir_img
+        self.dir_img_bin = dir_img_bin
+        self.dir_lab = dir_lab
+        self.processor = processor
+        self.max_target_length = max_target_length
+        #self.scales = scales if scales else []
+
+        self.augmentation = augmentation
+        self.binarization = binarization
+        self.add_red_textlines = add_red_textlines
+        self.white_noise_strap = white_noise_strap
+        self.adding_rgb_foreground = adding_rgb_foreground
+        self.adding_rgb_background = adding_rgb_background
+        self.bin_deg = bin_deg
+        self.blur_aug = blur_aug
+        self.brightening = brightening
+        self.padding_white = padding_white
+        self.color_padding_rotation = color_padding_rotation
+        self.rotation_not_90 = rotation_not_90
+        self.degrading = degrading
+        self.channels_shuffling = channels_shuffling
+        self.textline_skewing = textline_skewing
+        self.textline_skewing_bin = textline_skewing_bin
+        self.textline_right_in_depth = textline_right_in_depth
+        self.textline_left_in_depth = textline_left_in_depth
+        self.textline_up_in_depth = textline_up_in_depth
+        self.textline_down_in_depth = textline_down_in_depth
+        self.textline_right_in_depth_bin = textline_right_in_depth_bin
+        self.textline_left_in_depth_bin = textline_left_in_depth_bin
+        self.textline_up_in_depth_bin = textline_up_in_depth_bin
+        self.textline_down_in_depth_bin = textline_down_in_depth_bin
+        self.pepper_aug = pepper_aug
+        self.pepper_bin_aug = pepper_bin_aug
+        self.list_all_possible_background_images=list_all_possible_background_images
+        self.list_all_possible_foreground_rgbs=list_all_possible_foreground_rgbs
+        self.blur_k = blur_k
+        self.degrade_scales = degrade_scales
+        self.white_padds = white_padds
+        self.thetha_padd = thetha_padd
+        self.thetha = thetha
+        self.brightness = brightness
+        self.padd_colors = padd_colors
+        self.number_of_backgrounds_per_image = number_of_backgrounds_per_image
+        self.shuffle_indexes = shuffle_indexes
+        self.pepper_indexes = pepper_indexes
+        self.skewing_amplitudes = skewing_amplitudes
+        self.dir_rgb_backgrounds = dir_rgb_backgrounds
+        self.dir_rgb_foregrounds = dir_rgb_foregrounds
+        self.image_files = os.listdir(dir_img) #sorted([f for f in os.listdir(images_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
+        self.len_data = len_data
+        #assert len(self.image_files) == len(self.label_files), "Number of images and labels must match!"
+
+    def __len__(self):
+        return self.len_data
+
+    def __iter__(self):
+        for img_file in self.image_files:
+            # Load image
+            f_name = img_file.split('.')[0]
+
+            txt_inp = open(os.path.join(self.dir_lab, f_name+'.txt'),'r').read().split('\n')[0]
+
+            img = cv2.imread(os.path.join(self.dir_img, img_file))
+            img = img.astype(np.uint8)
+
+            if self.dir_img_bin:
+                img_bin_corr = cv2.imread(os.path.join(self.dir_img_bin, f_name+'.png') )
+                img_bin_corr = img_bin_corr.astype(np.uint8)
+            else:
+                img_bin_corr = None
+
+            labels = self.processor.tokenizer(txt_inp,
+                                              padding="max_length",
+                                              max_length=self.max_target_length).input_ids
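+            # replace padding token ids with -100 so they are ignored by the loss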
+            labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]
+
+            if self.augmentation:
+                pixel_values = self.processor(Image.fromarray(img), return_tensors="pt").pixel_values
+                encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                yield encoding
+
+                if self.color_padding_rotation:
+                    for index, thetha_ind in enumerate(self.thetha_padd):
+                        for padd_col in self.padd_colors:
+                            img_out = rotation_not_90_func_single_image(do_padding_for_ocr(img, 1.2, padd_col), thetha_ind)
+                            img_out = img_out.astype(np.uint8)
+                            pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                            encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                            yield encoding
+
+                if self.rotation_not_90:
+                    for index, thetha_ind in enumerate(self.thetha):
+                        img_out = rotation_not_90_func_single_image(img, thetha_ind)
+                        img_out = img_out.astype(np.uint8)
+                        pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                        yield encoding
+
+                if self.blur_aug:
+                    for index, blur_type in enumerate(self.blur_k):
+                        img_out = bluring(img, blur_type)
+                        img_out = img_out.astype(np.uint8)
+                        pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                        yield encoding
+
+                if self.degrading:
+                    for index, deg_scale_ind in enumerate(self.degrade_scales):
+                        try:
+                            img_out = do_degrading(img, deg_scale_ind)
+                        except:
+                            img_out = np.copy(img)
+                        img_out = img_out.astype(np.uint8)
+                        pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                        yield encoding
+
+                if self.bin_deg:
+                    for index, deg_scale_ind in enumerate(self.degrade_scales):
+                        try:
+                            img_out = do_degrading(img_bin_corr, deg_scale_ind)
+                        except:
+                            img_out = np.copy(img_bin_corr)
+                        img_out = img_out.astype(np.uint8)
+                        pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                        yield encoding
+
+                if self.brightening:
+                    for index, bright_scale_ind in enumerate(self.brightness):
+                        try:
+                            # do_brightening reads the image from disk, so it is given the path of the current line image
+                            img_out = do_brightening(os.path.join(self.dir_img, img_file), bright_scale_ind)
+                        except:
+                            img_out = np.copy(img)
+                        img_out = img_out.astype(np.uint8)
+                        pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                        yield encoding
+
+                if self.padding_white:
+                    for index, padding_size in enumerate(self.white_padds):
+                        for padd_col in self.padd_colors:
+                            img_out = do_padding_for_ocr(img, padding_size, padd_col)
+                            img_out = img_out.astype(np.uint8)
+                            pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                            encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                            yield encoding
+
+                if self.adding_rgb_foreground:
+                    for i_n in range(self.number_of_backgrounds_per_image):
+                        background_image_chosen_name = random.choice(self.list_all_possible_background_images)
+                        foreground_rgb_chosen_name = random.choice(self.list_all_possible_foreground_rgbs)
+
+                        img_rgb_background_chosen = cv2.imread(self.dir_rgb_backgrounds + '/' + background_image_chosen_name)
+                        foreground_rgb_chosen = np.load(self.dir_rgb_foregrounds + '/' + foreground_rgb_chosen_name)
+
+                        img_out = return_binary_image_with_given_rgb_background_and_given_foreground_rgb(img_bin_corr, img_rgb_background_chosen, foreground_rgb_chosen)
+                        img_out = img_out.astype(np.uint8)
+                        pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                        yield encoding
+
+                if self.adding_rgb_background:
+                    for i_n in range(self.number_of_backgrounds_per_image):
+                        background_image_chosen_name = random.choice(self.list_all_possible_background_images)
+                        img_rgb_background_chosen = cv2.imread(self.dir_rgb_backgrounds + '/' + background_image_chosen_name)
+                        img_out = return_binary_image_with_given_rgb_background(img_bin_corr, img_rgb_background_chosen)
+                        img_out = img_out.astype(np.uint8)
+                        pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                        yield encoding
+
+                if self.binarization:
+                    pixel_values = self.processor(Image.fromarray(img_bin_corr), return_tensors="pt").pixel_values
+                    encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                    yield encoding
+
+                if self.channels_shuffling:
+                    for shuffle_index in self.shuffle_indexes:
+                        img_out = return_shuffled_channels(img, shuffle_index)
+                        img_out = img_out.astype(np.uint8)
+                        pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                        yield encoding
+
+                if self.add_red_textlines:
+                    img_out = return_image_with_red_elements(img, img_bin_corr)
+                    img_out = img_out.astype(np.uint8)
+                    pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                    encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                    yield encoding
+
+                if self.white_noise_strap:
+                    img_out = return_image_with_strapped_white_noises(img)
+                    img_out = img_out.astype(np.uint8)
+                    pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                    encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                    yield encoding
+
+                if self.textline_skewing:
+                    for index, des_scale_ind in enumerate(self.skewing_amplitudes):
+                        try:
+                            img_out = do_deskewing(img, des_scale_ind)
+                        except:
+                            img_out = np.copy(img)
+                        img_out = img_out.astype(np.uint8)
+                        pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                        yield encoding
+
+                if self.textline_skewing_bin:
+                    for index, des_scale_ind in enumerate(self.skewing_amplitudes):
+                        try:
+                            img_out = do_deskewing(img_bin_corr, des_scale_ind)
+                        except:
+                            img_out = np.copy(img_bin_corr)
+                        img_out = img_out.astype(np.uint8)
+                        pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                        yield encoding
+
+                if self.textline_left_in_depth:
+                    try:
+                        img_out = do_direction_in_depth(img, 'left')
+                    except:
+                        img_out = np.copy(img)
+                    img_out = img_out.astype(np.uint8)
+                    pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                    encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                    yield encoding
+
+                if self.textline_left_in_depth_bin:
+                    try:
+                        img_out = do_direction_in_depth(img_bin_corr, 'left')
+                    except:
+                        img_out = np.copy(img_bin_corr)
+                    img_out = img_out.astype(np.uint8)
+                    pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                    encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                    yield encoding
+
+                if self.textline_right_in_depth:
+                    try:
+                        img_out = do_direction_in_depth(img, 'right')
+                    except:
+                        img_out = np.copy(img)
+                    img_out = img_out.astype(np.uint8)
+                    pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                    encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                    yield encoding
+
+                if self.textline_right_in_depth_bin:
+                    try:
+                        img_out = do_direction_in_depth(img_bin_corr, 'right')
+                    except:
+                        img_out = np.copy(img_bin_corr)
+                    img_out = img_out.astype(np.uint8)
+                    pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                    encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                    yield encoding
+
+                if self.textline_up_in_depth:
+                    try:
+                        img_out = do_direction_in_depth(img, 'up')
+                    except:
+                        img_out = np.copy(img)
+                    img_out = img_out.astype(np.uint8)
+                    pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                    encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                    yield encoding
+
+                if self.textline_up_in_depth_bin:
+                    try:
+                        img_out = do_direction_in_depth(img_bin_corr, 'up')
+                    except:
+                        img_out = np.copy(img_bin_corr)
+                    img_out = img_out.astype(np.uint8)
+                    pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                    encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                    yield encoding
+
+                if self.textline_down_in_depth:
+                    try:
+                        img_out = do_direction_in_depth(img, 'down')
+                    except:
+                        img_out = np.copy(img)
+                    img_out = img_out.astype(np.uint8)
+                    pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                    encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                    yield encoding
+
+                if self.textline_down_in_depth_bin:
+                    try:
+                        img_out = do_direction_in_depth(img_bin_corr, 'down')
+                    except:
+                        img_out = np.copy(img_bin_corr)
+                    img_out = img_out.astype(np.uint8)
+                    pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                    encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                    yield encoding
+
+                if self.pepper_bin_aug:
+                    for index, pepper_ind in enumerate(self.pepper_indexes):
+                        img_out = add_salt_and_pepper_noise(img_bin_corr, pepper_ind, pepper_ind)
+                        img_out = img_out.astype(np.uint8)
+                        pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                        yield encoding
+
+                if self.pepper_aug:
+                    for index, pepper_ind in enumerate(self.pepper_indexes):
+                        img_out = add_salt_and_pepper_noise(img, pepper_ind, pepper_ind)
+                        img_out = img_out.astype(np.uint8)
+                        pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                        yield encoding
+
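+            # augmentation disabled: yield only the original line image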
+            else:
+                pixel_values = self.processor(Image.fromarray(img), return_tensors="pt").pixel_values
+                encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                yield encoding
diff --git a/train/config_params.json b/train/config_params.json
index b01ac08..34c6376 100644
--- a/train/config_params.json
+++ b/train/config_params.json
@@ -1,17 +1,17 @@
 {
     "backbone_type" : "transformer",
-    "task": "cnn-rnn-ocr",
+    "task": "transformer-ocr",
     "n_classes" : 2,
-    "max_len": 280,
-    "n_epochs" : 3,
+    "max_len": 192,
+    "n_epochs" : 1,
     "input_height" : 32,
     "input_width" : 512,
     "weight_decay" : 1e-6,
-    "n_batch" : 4,
+    "n_batch" : 1,
     "learning_rate": 1e-5,
     "save_interval": 1500,
     "patches" : false,
-    "pretraining" : true,
+    "pretraining" : false,
     "augmentation" : true,
     "flip_aug" : false,
     "blur_aug" : true,
@@ -77,7 +76,6 @@
     "dir_output": "/home/vahid/extracted_lines/1919_bin/output",
     "dir_rgb_backgrounds": "/home/vahid/Documents/1_2_test_eynollah/set_rgb_background",
     "dir_rgb_foregrounds": "/home/vahid/Documents/1_2_test_eynollah/out_set_rgb_foreground",
-    "dir_img_bin": "/home/vahid/extracted_lines/1919_bin/images_bin",
-    "characters_txt_file":"/home/vahid/Downloads/models_eynollah/model_eynollah_ocr_cnnrnn_20250930/characters_org.txt"
+    "dir_img_bin": "/home/vahid/extracted_lines/1919_bin/images_bin"
 }
diff --git a/train/requirements.txt b/train/requirements.txt
index 63f3813..e3599a8 100644
--- a/train/requirements.txt
+++ b/train/requirements.txt
@@ -4,3 +4,8 @@ numpy <1.24.0
 tqdm
 imutils
 scipy
+torch
+evaluate
+accelerate
+jiwer
+transformers <= 4.30.2
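+# added for the transformer-ocr task; jiwer is needed by evaluate's CER metric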