mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-03-02 05:11:57 +01:00
training.utils.preprocess_imgs: fix polymorphy in 27f43c1…
(Functions cannot be both generators and procedures, so make this a pure generator and save the image files on the caller's side; also avoids passing output directories) Moreover, simplify by moving the `os.listdir` into the function body (saving lots of extra variable bindings).
This commit is contained in:
parent
42bab0f935
commit
b6d2440ce1
2 changed files with 36 additions and 67 deletions
|
|
@ -2,6 +2,7 @@ import os
|
||||||
import sys
|
import sys
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||||
|
|
@ -422,28 +423,22 @@ def run(_config,
|
||||||
os.mkdir(dir_flow_eval_imgs)
|
os.mkdir(dir_flow_eval_imgs)
|
||||||
os.mkdir(dir_flow_eval_labels)
|
os.mkdir(dir_flow_eval_labels)
|
||||||
|
|
||||||
dir_img, dir_seg = get_dirs_or_files(dir_train)
|
|
||||||
dir_img_val, dir_seg_val = get_dirs_or_files(dir_eval)
|
|
||||||
|
|
||||||
imgs_list = list(os.listdir(dir_img))
|
|
||||||
segs_list = list(os.listdir(dir_seg))
|
|
||||||
|
|
||||||
imgs_list_test = list(os.listdir(dir_img_val))
|
|
||||||
segs_list_test = list(os.listdir(dir_seg_val))
|
|
||||||
|
|
||||||
# writing patches into a sub-folder in order to be flowed from directory.
|
# writing patches into a sub-folder in order to be flowed from directory.
|
||||||
preprocess_imgs(_config,
|
def gen(dir_img, dir_lab, dir_flow_imgs, dir_flow_labs, augmentation=True):
|
||||||
imgs_list,
|
indexer = 0
|
||||||
segs_list,
|
for img, lab in tqdm(preprocess_imgs(_config,
|
||||||
dir_img,
|
dir_img,
|
||||||
dir_seg,
|
dir_lab,
|
||||||
|
augmentation=augmentation),
|
||||||
|
desc="data_is_provided"):
|
||||||
|
fname = 'img_%d.png' % indexer
|
||||||
|
cv2.imwrite(os.path.join(dir_flow_imgs, fname), img)
|
||||||
|
cv2.imwrite(os.path.join(dir_flow_labs, fname), lab)
|
||||||
|
indexer += 1
|
||||||
|
gen(*get_dirs_or_files(dir_train),
|
||||||
dir_flow_train_imgs,
|
dir_flow_train_imgs,
|
||||||
dir_flow_train_labels)
|
dir_flow_train_labels)
|
||||||
preprocess_imgs(_config,
|
gen(*get_dirs_or_files(dir_eval),
|
||||||
imgs_list_test,
|
|
||||||
segs_list_test,
|
|
||||||
dir_img_val,
|
|
||||||
dir_seg_val,
|
|
||||||
dir_flow_eval_imgs,
|
dir_flow_eval_imgs,
|
||||||
dir_flow_eval_labels,
|
dir_flow_eval_labels,
|
||||||
augmentation=False)
|
augmentation=False)
|
||||||
|
|
@ -606,13 +601,6 @@ def run(_config,
|
||||||
|
|
||||||
elif task=="cnn-rnn-ocr":
|
elif task=="cnn-rnn-ocr":
|
||||||
|
|
||||||
dir_img_train, dir_lab_train = get_dirs_or_files(dir_train)
|
|
||||||
dir_img_valdn, dir_lab_valdn = get_dirs_or_files(dir_eval)
|
|
||||||
imgs_list_train = list(os.listdir(dir_img_train))
|
|
||||||
labs_list_train = list(os.listdir(dir_lab_train))
|
|
||||||
imgs_list_valdn = list(os.listdir(dir_img_valdn))
|
|
||||||
labs_list_valdn = list(os.listdir(dir_lab_valdn))
|
|
||||||
|
|
||||||
with open(characters_txt_file, 'r') as char_txt_f:
|
with open(characters_txt_file, 'r') as char_txt_f:
|
||||||
characters = json.load(char_txt_f)
|
characters = json.load(char_txt_f)
|
||||||
padding_token = len(characters) + 5
|
padding_token = len(characters) + 5
|
||||||
|
|
@ -631,15 +619,11 @@ def run(_config,
|
||||||
#print(model.summary())
|
#print(model.summary())
|
||||||
|
|
||||||
# todo: use Dataset.map() on Dataset.list_files()
|
# todo: use Dataset.map() on Dataset.list_files()
|
||||||
def get_dataset(imgs_list, labs_list, dir_img, dir_lab):
|
def get_dataset(dir_img, dir_lab):
|
||||||
def gen():
|
def gen():
|
||||||
return preprocess_imgs(_config,
|
return preprocess_imgs(_config,
|
||||||
imgs_list,
|
|
||||||
labs_list,
|
|
||||||
dir_img,
|
dir_img,
|
||||||
dir_lab,
|
dir_lab,
|
||||||
None, # no file I/O, but in-memory
|
|
||||||
None, # no file I/O, but in-memory
|
|
||||||
# extra+overrides
|
# extra+overrides
|
||||||
char_to_num=char_to_num,
|
char_to_num=char_to_num,
|
||||||
padding_token=padding_token
|
padding_token=padding_token
|
||||||
|
|
@ -654,14 +638,8 @@ def run(_config,
|
||||||
.map(lambda x, y: {"image": x, "label": y})
|
.map(lambda x, y: {"image": x, "label": y})
|
||||||
.prefetch(tf.data.AUTOTUNE)
|
.prefetch(tf.data.AUTOTUNE)
|
||||||
)
|
)
|
||||||
train_ds = get_dataset(imgs_list_train,
|
train_ds = get_dataset(*get_dirs_or_files(dir_train))
|
||||||
labs_list_train,
|
valdn_ds = get_dataset(*get_dirs_or_files(dir_eval))
|
||||||
dir_img_train,
|
|
||||||
dir_lab_train)
|
|
||||||
valdn_ds = get_dataset(imgs_list_valdn,
|
|
||||||
labs_list_valdn,
|
|
||||||
dir_img_valdn,
|
|
||||||
dir_lab_valdn)
|
|
||||||
|
|
||||||
#initial_learning_rate = 1e-4
|
#initial_learning_rate = 1e-4
|
||||||
#decay_steps = int (n_epochs * ( len_dataset / n_batch ))
|
#decay_steps = int (n_epochs * ( len_dataset / n_batch ))
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,6 @@ import numpy as np
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
from scipy.ndimage.interpolation import map_coordinates
|
from scipy.ndimage.interpolation import map_coordinates
|
||||||
from scipy.ndimage.filters import gaussian_filter
|
from scipy.ndimage.filters import gaussian_filter
|
||||||
from tqdm import tqdm
|
|
||||||
import imutils
|
import imutils
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
@ -753,17 +752,11 @@ def get_patches_num_scale_new(img, label, height, width, scaler=1.0):
|
||||||
yield img_patch, label_patch
|
yield img_patch, label_patch
|
||||||
|
|
||||||
|
|
||||||
# TODO: refactor to combine with data_gen_ocr
|
|
||||||
def preprocess_imgs(config,
|
def preprocess_imgs(config,
|
||||||
imgs_list,
|
|
||||||
labs_list,
|
|
||||||
dir_img,
|
dir_img,
|
||||||
dir_lab,
|
dir_lab,
|
||||||
dir_flow_imgs,
|
|
||||||
dir_flow_lbls,
|
|
||||||
logger=None,
|
logger=None,
|
||||||
**kwargs,
|
**kwargs):
|
||||||
):
|
|
||||||
if logger is None:
|
if logger is None:
|
||||||
logger = getLogger('')
|
logger = getLogger('')
|
||||||
|
|
||||||
|
|
@ -779,14 +772,16 @@ def preprocess_imgs(config,
|
||||||
# override keys from call
|
# override keys from call
|
||||||
config.update(kwargs)
|
config.update(kwargs)
|
||||||
|
|
||||||
|
imgs_list = list(sorted(os.listdir(dir_img)))
|
||||||
|
labs_list = list(sorted(os.listdir(dir_lab)))
|
||||||
|
|
||||||
seed = random.getstate()
|
seed = random.getstate()
|
||||||
random.shuffle(imgs_list)
|
random.shuffle(imgs_list)
|
||||||
random.setstate(seed)
|
random.setstate(seed)
|
||||||
random.shuffle(labs_list)
|
random.shuffle(labs_list)
|
||||||
|
|
||||||
# labs_list not used because stem matching more robust
|
# labs_list not used because stem matching more robust
|
||||||
indexer = 0
|
for img, lab in zip(imgs_list, labs_list):
|
||||||
for img, lab in tqdm(zip(imgs_list, labs_list)):
|
|
||||||
img_name = os.path.splitext(img)[0]
|
img_name = os.path.splitext(img)[0]
|
||||||
img = cv2.imread(os.path.join(dir_img, img))
|
img = cv2.imread(os.path.join(dir_img, img))
|
||||||
if config['task'] in ["segmentation", "binarization"]:
|
if config['task'] in ["segmentation", "binarization"]:
|
||||||
|
|
@ -803,20 +798,16 @@ def preprocess_imgs(config,
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if config['task'] == "cnn-rnn-ocr":
|
if config['task'] == "cnn-rnn-ocr":
|
||||||
yield from preprocess_img_ocr(img, img_name, lab,
|
yield from preprocess_img_ocr(img, img_name, lab, **config)
|
||||||
**config)
|
|
||||||
continue
|
continue
|
||||||
for img, lab in preprocess_img(img, img_name, lab,
|
else:
|
||||||
**config):
|
for img, lab in preprocess_img(img, img_name, lab, **config):
|
||||||
cv2.imwrite(os.path.join(dir_flow_imgs, '/img_%d.png' % indexer),
|
yield (resize_image(img,
|
||||||
resize_image(img,
|
|
||||||
config['input_height'],
|
config['input_height'],
|
||||||
config['input_width']))
|
config['input_width']),
|
||||||
cv2.imwrite(os.path.join(dir_flow_lbls, '/img_%d.png' % indexer),
|
|
||||||
resize_image(lab,
|
resize_image(lab,
|
||||||
config['input_height'],
|
config['input_height'],
|
||||||
config['input_width']))
|
config['input_width']))
|
||||||
indexer += 1
|
|
||||||
except:
|
except:
|
||||||
logger.exception("skipping image %s", img_name)
|
logger.exception("skipping image %s", img_name)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue