From b6d2440ce1eca9f8e2b20f030d604ecd63466aeb Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Wed, 25 Feb 2026 20:39:15 +0100 Subject: [PATCH] =?UTF-8?q?training.utils.preprocess=5Fimgs:=20fix=20polym?= =?UTF-8?q?orphy=20in=2027f43c1=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (Functions cannot be both generators and procedures, so make this a pure generator and save the image files on the caller's side; also avoids passing output directories) Moreover, simplify by moving the `os.listdir` into the function body (saving lots of extra variable bindings). --- src/eynollah/training/train.py | 66 ++++++++++++---------------------- src/eynollah/training/utils.py | 37 ++++++++----------- 2 files changed, 36 insertions(+), 67 deletions(-) diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py index a3cd1e4..74a7a90 100644 --- a/src/eynollah/training/train.py +++ b/src/eynollah/training/train.py @@ -2,6 +2,7 @@ import os import sys import json +from tqdm import tqdm import requests os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' @@ -422,31 +423,25 @@ def run(_config, os.mkdir(dir_flow_eval_imgs) os.mkdir(dir_flow_eval_labels) - dir_img, dir_seg = get_dirs_or_files(dir_train) - dir_img_val, dir_seg_val = get_dirs_or_files(dir_eval) - - imgs_list = list(os.listdir(dir_img)) - segs_list = list(os.listdir(dir_seg)) - - imgs_list_test = list(os.listdir(dir_img_val)) - segs_list_test = list(os.listdir(dir_seg_val)) - # writing patches into a sub-folder in order to be flowed from directory. - preprocess_imgs(_config, - imgs_list, - segs_list, - dir_img, - dir_seg, - dir_flow_train_imgs, - dir_flow_train_labels) - preprocess_imgs(_config, - imgs_list_test, - segs_list_test, - dir_img_val, - dir_seg_val, - dir_flow_eval_imgs, - dir_flow_eval_labels, - augmentation=False) + def gen(dir_img, dir_lab, dir_flow_imgs, dir_flow_labs, augmentation=True): + indexer = 0 + for img, lab in tqdm(preprocess_imgs(_config, + dir_img, + dir_lab, + augmentation=augmentation), + desc="data_is_provided"): + fname = 'img_%d.png' % indexer + cv2.imwrite(os.path.join(dir_flow_imgs, fname), img) + cv2.imwrite(os.path.join(dir_flow_labs, fname), lab) + indexer += 1 + gen(*get_dirs_or_files(dir_train), + dir_flow_train_imgs, + dir_flow_train_labels) + gen(*get_dirs_or_files(dir_eval), + dir_flow_eval_imgs, + dir_flow_eval_labels, + augmentation=False) if weighted_loss: weights = np.zeros(n_classes) @@ -606,13 +601,6 @@ def run(_config, elif task=="cnn-rnn-ocr": - dir_img_train, dir_lab_train = get_dirs_or_files(dir_train) - dir_img_valdn, dir_lab_valdn = get_dirs_or_files(dir_eval) - imgs_list_train = list(os.listdir(dir_img_train)) - labs_list_train = list(os.listdir(dir_lab_train)) - imgs_list_valdn = list(os.listdir(dir_img_valdn)) - labs_list_valdn = list(os.listdir(dir_lab_valdn)) - with open(characters_txt_file, 'r') as char_txt_f: characters = json.load(char_txt_f) padding_token = len(characters) + 5 @@ -631,15 +619,11 @@ def run(_config, #print(model.summary()) # todo: use Dataset.map() on Dataset.list_files() - def get_dataset(imgs_list, labs_list, dir_img, dir_lab): + def get_dataset(dir_img, dir_lab): def gen(): return preprocess_imgs(_config, - imgs_list, - labs_list, dir_img, dir_lab, - None, # no file I/O, but in-memory - None, # no file I/O, but in-memory # extra+overrides char_to_num=char_to_num, padding_token=padding_token @@ -654,14 +638,8 @@ def run(_config, .map(lambda x, y: {"image": x, "label": y}) .prefetch(tf.data.AUTOTUNE) ) - train_ds = get_dataset(imgs_list_train, - labs_list_train, - dir_img_train, - dir_lab_train) - valdn_ds = get_dataset(imgs_list_valdn, - labs_list_valdn, - dir_img_valdn, - dir_lab_valdn) + train_ds = get_dataset(*get_dirs_or_files(dir_train)) + valdn_ds = get_dataset(*get_dirs_or_files(dir_eval)) #initial_learning_rate = 1e-4 #decay_steps = int (n_epochs * ( len_dataset / n_batch )) diff --git a/src/eynollah/training/utils.py b/src/eynollah/training/utils.py index 02a1ca5..33a1fd2 100644 --- a/src/eynollah/training/utils.py +++ b/src/eynollah/training/utils.py @@ -9,7 +9,6 @@ import numpy as np import seaborn as sns from scipy.ndimage.interpolation import map_coordinates from scipy.ndimage.filters import gaussian_filter -from tqdm import tqdm import imutils import tensorflow as tf @@ -753,17 +752,11 @@ def get_patches_num_scale_new(img, label, height, width, scaler=1.0): yield img_patch, label_patch -# TODO: refactor to combine with data_gen_ocr def preprocess_imgs(config, - imgs_list, - labs_list, dir_img, dir_lab, - dir_flow_imgs, - dir_flow_lbls, logger=None, - **kwargs, -): + **kwargs): if logger is None: logger = getLogger('') @@ -779,14 +772,16 @@ def preprocess_imgs(config, # override keys from call config.update(kwargs) + imgs_list = list(sorted(os.listdir(dir_img))) + labs_list = list(sorted(os.listdir(dir_lab))) + seed = random.getstate() random.shuffle(imgs_list) random.setstate(seed) random.shuffle(labs_list) # labs_list not used because stem matching more robust - indexer = 0 - for img, lab in tqdm(zip(imgs_list, labs_list)): + for img, lab in zip(imgs_list, labs_list): img_name = os.path.splitext(img)[0] img = cv2.imread(os.path.join(dir_img, img)) if config['task'] in ["segmentation", "binarization"]: @@ -803,20 +798,16 @@ def preprocess_imgs(config, try: if config['task'] == "cnn-rnn-ocr": - yield from preprocess_img_ocr(img, img_name, lab, - **config) + yield from preprocess_img_ocr(img, img_name, lab, **config) continue - for img, lab in preprocess_img(img, img_name, lab, - **config): - cv2.imwrite(os.path.join(dir_flow_imgs, '/img_%d.png' % indexer), - resize_image(img, - config['input_height'], - config['input_width'])) - cv2.imwrite(os.path.join(dir_flow_lbls, '/img_%d.png' % indexer), - resize_image(lab, - config['input_height'], - config['input_width'])) - indexer += 1 + else: + for img, lab in preprocess_img(img, img_name, lab, **config): + yield (resize_image(img, + config['input_height'], + config['input_width']), + resize_image(lab, + config['input_height'], + config['input_width'])) except: logger.exception("skipping image %s", img_name)