mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-03-01 21:02:00 +01:00
training.train: add validation data for OCR
This commit is contained in:
parent
b399db3c00
commit
36e370aa45
1 changed files with 31 additions and 23 deletions
|
|
@ -10,7 +10,7 @@ import tensorflow as tf
|
||||||
from tensorflow.keras.optimizers import SGD, Adam
|
from tensorflow.keras.optimizers import SGD, Adam
|
||||||
from tensorflow.keras.metrics import MeanIoU, F1Score
|
from tensorflow.keras.metrics import MeanIoU, F1Score
|
||||||
from tensorflow.keras.models import load_model
|
from tensorflow.keras.models import load_model
|
||||||
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
|
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
|
||||||
from tensorflow.keras.layers import StringLookup
|
from tensorflow.keras.layers import StringLookup
|
||||||
from tensorflow.keras.utils import image_dataset_from_directory
|
from tensorflow.keras.utils import image_dataset_from_directory
|
||||||
from tensorflow.keras.backend import one_hot
|
from tensorflow.keras.backend import one_hot
|
||||||
|
|
@ -606,12 +606,12 @@ def run(_config,
|
||||||
|
|
||||||
elif task=="cnn-rnn-ocr":
|
elif task=="cnn-rnn-ocr":
|
||||||
|
|
||||||
dir_img, dir_lab = get_dirs_or_files(dir_train)
|
dir_img_train, dir_lab_train = get_dirs_or_files(dir_train)
|
||||||
dir_img_val, dir_lab_val = get_dirs_or_files(dir_eval)
|
dir_img_valdn, dir_lab_valdn = get_dirs_or_files(dir_eval)
|
||||||
imgs_list = list(os.listdir(dir_img))
|
imgs_list_train = list(os.listdir(dir_img_train))
|
||||||
labs_list = list(os.listdir(dir_lab))
|
labs_list_train = list(os.listdir(dir_lab_train))
|
||||||
imgs_list_val = list(os.listdir(dir_img_val))
|
imgs_list_valdn = list(os.listdir(dir_img_valdn))
|
||||||
labs_list_val = list(os.listdir(dir_lab_val))
|
labs_list_valdn = list(os.listdir(dir_lab_valdn))
|
||||||
|
|
||||||
with open(characters_txt_file, 'r') as char_txt_f:
|
with open(characters_txt_file, 'r') as char_txt_f:
|
||||||
characters = json.load(char_txt_f)
|
characters = json.load(char_txt_f)
|
||||||
|
|
@ -631,20 +631,20 @@ def run(_config,
|
||||||
#print(model.summary())
|
#print(model.summary())
|
||||||
|
|
||||||
# todo: use Dataset.map() on Dataset.list_files()
|
# todo: use Dataset.map() on Dataset.list_files()
|
||||||
# todo: test_ds
|
def get_dataset(imgs_list, labs_list, dir_img, dir_lab):
|
||||||
def gen():
|
def gen():
|
||||||
return preprocess_imgs(_config,
|
return preprocess_imgs(_config,
|
||||||
imgs_list,
|
imgs_list,
|
||||||
labs_list,
|
labs_list,
|
||||||
dir_img,
|
dir_img,
|
||||||
dir_lab,
|
dir_lab,
|
||||||
None, # no file I/O, but in-memory
|
None, # no file I/O, but in-memory
|
||||||
None, # no file I/O, but in-memory
|
None, # no file I/O, but in-memory
|
||||||
# extra+overrides
|
# extra+overrides
|
||||||
char_to_num=char_to_num,
|
char_to_num=char_to_num,
|
||||||
padding_token=padding_token
|
padding_token=padding_token
|
||||||
)
|
)
|
||||||
train_ds = (tf.data.Dataset.from_generator(gen, (tf.float32, tf.int64))
|
return (tf.data.Dataset.from_generator(gen, (tf.float32, tf.int64))
|
||||||
.padded_batch(n_batch,
|
.padded_batch(n_batch,
|
||||||
padded_shapes=([input_height, input_width, 3], [None]),
|
padded_shapes=([input_height, input_width, 3], [None]),
|
||||||
padding_values=(None, tf.constant(padding_token, dtype=tf.int64)),
|
padding_values=(None, tf.constant(padding_token, dtype=tf.int64)),
|
||||||
|
|
@ -653,7 +653,15 @@ def run(_config,
|
||||||
)
|
)
|
||||||
.map(lambda x, y: {"image": x, "label": y})
|
.map(lambda x, y: {"image": x, "label": y})
|
||||||
.prefetch(tf.data.AUTOTUNE)
|
.prefetch(tf.data.AUTOTUNE)
|
||||||
)
|
)
|
||||||
|
train_ds = get_dataset(imgs_list_train,
|
||||||
|
labs_list_train,
|
||||||
|
dir_img_train,
|
||||||
|
dir_lab_train)
|
||||||
|
valdn_ds = get_dataset(imgs_list_valdn,
|
||||||
|
labs_list_valdn,
|
||||||
|
dir_img_valdn,
|
||||||
|
dir_lab_valdn)
|
||||||
|
|
||||||
#initial_learning_rate = 1e-4
|
#initial_learning_rate = 1e-4
|
||||||
#decay_steps = int (n_epochs * ( len_dataset / n_batch ))
|
#decay_steps = int (n_epochs * ( len_dataset / n_batch ))
|
||||||
|
|
@ -669,7 +677,7 @@ def run(_config,
|
||||||
callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config))
|
callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config))
|
||||||
model.fit(
|
model.fit(
|
||||||
train_ds,
|
train_ds,
|
||||||
#validation_data=test_ds,
|
validation_data=valdn_ds,
|
||||||
verbose=1,
|
verbose=1,
|
||||||
epochs=n_epochs,
|
epochs=n_epochs,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue