mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-03-01 21:02:00 +01:00
training.utils.preprocess_imgs: fix polymorphism in 27f43c1…
(Functions cannot be both generators and procedures, so make this a pure generator and save the image files on the caller's side; this also avoids passing output directories.) Moreover, simplify by moving the `os.listdir` calls into the function body (saving lots of extra variable bindings).
This commit is contained in:
parent
42bab0f935
commit
b6d2440ce1
2 changed files with 36 additions and 67 deletions
|
|
@ -2,6 +2,7 @@ import os
|
|||
import sys
|
||||
import json
|
||||
|
||||
from tqdm import tqdm
|
||||
import requests
|
||||
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||
|
|
@ -422,28 +423,22 @@ def run(_config,
|
|||
os.mkdir(dir_flow_eval_imgs)
|
||||
os.mkdir(dir_flow_eval_labels)
|
||||
|
||||
dir_img, dir_seg = get_dirs_or_files(dir_train)
|
||||
dir_img_val, dir_seg_val = get_dirs_or_files(dir_eval)
|
||||
|
||||
imgs_list = list(os.listdir(dir_img))
|
||||
segs_list = list(os.listdir(dir_seg))
|
||||
|
||||
imgs_list_test = list(os.listdir(dir_img_val))
|
||||
segs_list_test = list(os.listdir(dir_seg_val))
|
||||
|
||||
# writing patches into a sub-folder in order to be flowed from directory.
|
||||
preprocess_imgs(_config,
|
||||
imgs_list,
|
||||
segs_list,
|
||||
def gen(dir_img, dir_lab, dir_flow_imgs, dir_flow_labs, augmentation=True):
|
||||
indexer = 0
|
||||
for img, lab in tqdm(preprocess_imgs(_config,
|
||||
dir_img,
|
||||
dir_seg,
|
||||
dir_lab,
|
||||
augmentation=augmentation),
|
||||
desc="data_is_provided"):
|
||||
fname = 'img_%d.png' % indexer
|
||||
cv2.imwrite(os.path.join(dir_flow_imgs, fname), img)
|
||||
cv2.imwrite(os.path.join(dir_flow_labs, fname), lab)
|
||||
indexer += 1
|
||||
gen(*get_dirs_or_files(dir_train),
|
||||
dir_flow_train_imgs,
|
||||
dir_flow_train_labels)
|
||||
preprocess_imgs(_config,
|
||||
imgs_list_test,
|
||||
segs_list_test,
|
||||
dir_img_val,
|
||||
dir_seg_val,
|
||||
gen(*get_dirs_or_files(dir_eval),
|
||||
dir_flow_eval_imgs,
|
||||
dir_flow_eval_labels,
|
||||
augmentation=False)
|
||||
|
|
@ -606,13 +601,6 @@ def run(_config,
|
|||
|
||||
elif task=="cnn-rnn-ocr":
|
||||
|
||||
dir_img_train, dir_lab_train = get_dirs_or_files(dir_train)
|
||||
dir_img_valdn, dir_lab_valdn = get_dirs_or_files(dir_eval)
|
||||
imgs_list_train = list(os.listdir(dir_img_train))
|
||||
labs_list_train = list(os.listdir(dir_lab_train))
|
||||
imgs_list_valdn = list(os.listdir(dir_img_valdn))
|
||||
labs_list_valdn = list(os.listdir(dir_lab_valdn))
|
||||
|
||||
with open(characters_txt_file, 'r') as char_txt_f:
|
||||
characters = json.load(char_txt_f)
|
||||
padding_token = len(characters) + 5
|
||||
|
|
@ -631,15 +619,11 @@ def run(_config,
|
|||
#print(model.summary())
|
||||
|
||||
# todo: use Dataset.map() on Dataset.list_files()
|
||||
def get_dataset(imgs_list, labs_list, dir_img, dir_lab):
|
||||
def get_dataset(dir_img, dir_lab):
|
||||
def gen():
|
||||
return preprocess_imgs(_config,
|
||||
imgs_list,
|
||||
labs_list,
|
||||
dir_img,
|
||||
dir_lab,
|
||||
None, # no file I/O, but in-memory
|
||||
None, # no file I/O, but in-memory
|
||||
# extra+overrides
|
||||
char_to_num=char_to_num,
|
||||
padding_token=padding_token
|
||||
|
|
@ -654,14 +638,8 @@ def run(_config,
|
|||
.map(lambda x, y: {"image": x, "label": y})
|
||||
.prefetch(tf.data.AUTOTUNE)
|
||||
)
|
||||
train_ds = get_dataset(imgs_list_train,
|
||||
labs_list_train,
|
||||
dir_img_train,
|
||||
dir_lab_train)
|
||||
valdn_ds = get_dataset(imgs_list_valdn,
|
||||
labs_list_valdn,
|
||||
dir_img_valdn,
|
||||
dir_lab_valdn)
|
||||
train_ds = get_dataset(*get_dirs_or_files(dir_train))
|
||||
valdn_ds = get_dataset(*get_dirs_or_files(dir_eval))
|
||||
|
||||
#initial_learning_rate = 1e-4
|
||||
#decay_steps = int (n_epochs * ( len_dataset / n_batch ))
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ import numpy as np
|
|||
import seaborn as sns
|
||||
from scipy.ndimage.interpolation import map_coordinates
|
||||
from scipy.ndimage.filters import gaussian_filter
|
||||
from tqdm import tqdm
|
||||
import imutils
|
||||
import tensorflow as tf
|
||||
|
||||
|
|
@ -753,17 +752,11 @@ def get_patches_num_scale_new(img, label, height, width, scaler=1.0):
|
|||
yield img_patch, label_patch
|
||||
|
||||
|
||||
# TODO: refactor to combine with data_gen_ocr
|
||||
def preprocess_imgs(config,
|
||||
imgs_list,
|
||||
labs_list,
|
||||
dir_img,
|
||||
dir_lab,
|
||||
dir_flow_imgs,
|
||||
dir_flow_lbls,
|
||||
logger=None,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs):
|
||||
if logger is None:
|
||||
logger = getLogger('')
|
||||
|
||||
|
|
@ -779,14 +772,16 @@ def preprocess_imgs(config,
|
|||
# override keys from call
|
||||
config.update(kwargs)
|
||||
|
||||
imgs_list = list(sorted(os.listdir(dir_img)))
|
||||
labs_list = list(sorted(os.listdir(dir_lab)))
|
||||
|
||||
seed = random.getstate()
|
||||
random.shuffle(imgs_list)
|
||||
random.setstate(seed)
|
||||
random.shuffle(labs_list)
|
||||
|
||||
# labs_list not used because stem matching more robust
|
||||
indexer = 0
|
||||
for img, lab in tqdm(zip(imgs_list, labs_list)):
|
||||
for img, lab in zip(imgs_list, labs_list):
|
||||
img_name = os.path.splitext(img)[0]
|
||||
img = cv2.imread(os.path.join(dir_img, img))
|
||||
if config['task'] in ["segmentation", "binarization"]:
|
||||
|
|
@ -803,20 +798,16 @@ def preprocess_imgs(config,
|
|||
|
||||
try:
|
||||
if config['task'] == "cnn-rnn-ocr":
|
||||
yield from preprocess_img_ocr(img, img_name, lab,
|
||||
**config)
|
||||
yield from preprocess_img_ocr(img, img_name, lab, **config)
|
||||
continue
|
||||
for img, lab in preprocess_img(img, img_name, lab,
|
||||
**config):
|
||||
cv2.imwrite(os.path.join(dir_flow_imgs, '/img_%d.png' % indexer),
|
||||
resize_image(img,
|
||||
else:
|
||||
for img, lab in preprocess_img(img, img_name, lab, **config):
|
||||
yield (resize_image(img,
|
||||
config['input_height'],
|
||||
config['input_width']))
|
||||
cv2.imwrite(os.path.join(dir_flow_lbls, '/img_%d.png' % indexer),
|
||||
config['input_width']),
|
||||
resize_image(lab,
|
||||
config['input_height'],
|
||||
config['input_width']))
|
||||
indexer += 1
|
||||
except:
|
||||
logger.exception("skipping image %s", img_name)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue