training.utils.preprocess_imgs: fix generator/procedure polymorphism introduced in 27f43c1

(A function cannot be both a generator and a procedure,
 so make this a pure generator and save the image files
 on the caller's side; this also avoids having to pass
 output directories into the function)

Moreover, simplify by moving the `os.listdir` calls into the function
body (saving many extra variable bindings at the call sites).
This commit is contained in:
Robert Sachunsky 2026-02-25 20:39:15 +01:00
parent 42bab0f935
commit b6d2440ce1
2 changed files with 36 additions and 67 deletions

View file

@ -2,6 +2,7 @@ import os
import sys
import json
from tqdm import tqdm
import requests
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
@ -422,28 +423,22 @@ def run(_config,
os.mkdir(dir_flow_eval_imgs)
os.mkdir(dir_flow_eval_labels)
dir_img, dir_seg = get_dirs_or_files(dir_train)
dir_img_val, dir_seg_val = get_dirs_or_files(dir_eval)
imgs_list = list(os.listdir(dir_img))
segs_list = list(os.listdir(dir_seg))
imgs_list_test = list(os.listdir(dir_img_val))
segs_list_test = list(os.listdir(dir_seg_val))
# writing patches into a sub-folder in order to be flowed from directory.
preprocess_imgs(_config,
imgs_list,
segs_list,
def gen(dir_img, dir_lab, dir_flow_imgs, dir_flow_labs, augmentation=True):
indexer = 0
for img, lab in tqdm(preprocess_imgs(_config,
dir_img,
dir_seg,
dir_lab,
augmentation=augmentation),
desc="data_is_provided"):
fname = 'img_%d.png' % indexer
cv2.imwrite(os.path.join(dir_flow_imgs, fname), img)
cv2.imwrite(os.path.join(dir_flow_labs, fname), lab)
indexer += 1
gen(*get_dirs_or_files(dir_train),
dir_flow_train_imgs,
dir_flow_train_labels)
preprocess_imgs(_config,
imgs_list_test,
segs_list_test,
dir_img_val,
dir_seg_val,
gen(*get_dirs_or_files(dir_eval),
dir_flow_eval_imgs,
dir_flow_eval_labels,
augmentation=False)
@ -606,13 +601,6 @@ def run(_config,
elif task=="cnn-rnn-ocr":
dir_img_train, dir_lab_train = get_dirs_or_files(dir_train)
dir_img_valdn, dir_lab_valdn = get_dirs_or_files(dir_eval)
imgs_list_train = list(os.listdir(dir_img_train))
labs_list_train = list(os.listdir(dir_lab_train))
imgs_list_valdn = list(os.listdir(dir_img_valdn))
labs_list_valdn = list(os.listdir(dir_lab_valdn))
with open(characters_txt_file, 'r') as char_txt_f:
characters = json.load(char_txt_f)
padding_token = len(characters) + 5
@ -631,15 +619,11 @@ def run(_config,
#print(model.summary())
# todo: use Dataset.map() on Dataset.list_files()
def get_dataset(imgs_list, labs_list, dir_img, dir_lab):
def get_dataset(dir_img, dir_lab):
def gen():
return preprocess_imgs(_config,
imgs_list,
labs_list,
dir_img,
dir_lab,
None, # no file I/O, but in-memory
None, # no file I/O, but in-memory
# extra+overrides
char_to_num=char_to_num,
padding_token=padding_token
@ -654,14 +638,8 @@ def run(_config,
.map(lambda x, y: {"image": x, "label": y})
.prefetch(tf.data.AUTOTUNE)
)
train_ds = get_dataset(imgs_list_train,
labs_list_train,
dir_img_train,
dir_lab_train)
valdn_ds = get_dataset(imgs_list_valdn,
labs_list_valdn,
dir_img_valdn,
dir_lab_valdn)
train_ds = get_dataset(*get_dirs_or_files(dir_train))
valdn_ds = get_dataset(*get_dirs_or_files(dir_eval))
#initial_learning_rate = 1e-4
#decay_steps = int (n_epochs * ( len_dataset / n_batch ))

View file

@ -9,7 +9,6 @@ import numpy as np
import seaborn as sns
from scipy.ndimage.interpolation import map_coordinates
from scipy.ndimage.filters import gaussian_filter
from tqdm import tqdm
import imutils
import tensorflow as tf
@ -753,17 +752,11 @@ def get_patches_num_scale_new(img, label, height, width, scaler=1.0):
yield img_patch, label_patch
# TODO: refactor to combine with data_gen_ocr
def preprocess_imgs(config,
imgs_list,
labs_list,
dir_img,
dir_lab,
dir_flow_imgs,
dir_flow_lbls,
logger=None,
**kwargs,
):
**kwargs):
if logger is None:
logger = getLogger('')
@ -779,14 +772,16 @@ def preprocess_imgs(config,
# override keys from call
config.update(kwargs)
imgs_list = list(sorted(os.listdir(dir_img)))
labs_list = list(sorted(os.listdir(dir_lab)))
seed = random.getstate()
random.shuffle(imgs_list)
random.setstate(seed)
random.shuffle(labs_list)
# labs_list not used because stem matching more robust
indexer = 0
for img, lab in tqdm(zip(imgs_list, labs_list)):
for img, lab in zip(imgs_list, labs_list):
img_name = os.path.splitext(img)[0]
img = cv2.imread(os.path.join(dir_img, img))
if config['task'] in ["segmentation", "binarization"]:
@ -803,20 +798,16 @@ def preprocess_imgs(config,
try:
if config['task'] == "cnn-rnn-ocr":
yield from preprocess_img_ocr(img, img_name, lab,
**config)
yield from preprocess_img_ocr(img, img_name, lab, **config)
continue
for img, lab in preprocess_img(img, img_name, lab,
**config):
cv2.imwrite(os.path.join(dir_flow_imgs, '/img_%d.png' % indexer),
resize_image(img,
else:
for img, lab in preprocess_img(img, img_name, lab, **config):
yield (resize_image(img,
config['input_height'],
config['input_width']))
cv2.imwrite(os.path.join(dir_flow_lbls, '/img_%d.png' % indexer),
config['input_width']),
resize_image(lab,
config['input_height'],
config['input_width']))
indexer += 1
except:
logger.exception("skipping image %s", img_name)