From 4651000191bfce4d33337668b71e8a3c75b65c1d Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Mon, 15 Dec 2025 11:36:09 +0100 Subject: [PATCH] debugging input shape + enable finetuning a model --- src/eynollah/training/train.py | 12 +++++++----- src/eynollah/training/utils.py | 11 +++++++++++ 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py index 469df27..c15a562 100644 --- a/src/eynollah/training/train.py +++ b/src/eynollah/training/train.py @@ -428,8 +428,11 @@ def run(_config, n_classes, n_epochs, input_height, n_classes = len(char_to_num.get_vocabulary()) + 2 - - model = cnn_rnn_ocr_model(image_height=input_height, image_width=input_width, n_classes=n_classes, max_seq=max_len) + if continue_training: + model = load_model(dir_of_start_model) + else: + index_start = 0 + model = cnn_rnn_ocr_model(image_height=input_height, image_width=input_width, n_classes=n_classes, max_seq=max_len) print(model.summary()) @@ -459,8 +462,7 @@ def run(_config, n_classes, n_epochs, input_height, if save_interval: save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config) - indexer_start = 0 - for i in range(n_epochs): + for i in tqdm(range(index_start, n_epochs + index_start)): if save_interval: model.fit( train_ds, @@ -476,7 +478,7 @@ def run(_config, n_classes, n_epochs, input_height, ) if i >=0: - model.save( os.path.join(dir_output,'model_'+str(i+indexer_start) )) + model.save( os.path.join(dir_output,'model_'+str(i) )) with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp: json.dump(_config, fp) # encode dict into JSON diff --git a/src/eynollah/training/utils.py b/src/eynollah/training/utils.py index 34ac488..c589957 100644 --- a/src/eynollah/training/utils.py +++ b/src/eynollah/training/utils.py @@ -1390,6 +1390,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir img_out = 
scale_padd_image_for_ocr(img_out, input_height, input_width) except: img_out = np.copy(img) + img_out = scale_padd_image_for_ocr(img_out, input_height, input_width) ret_x[batchcount, :,:,:] = img_out[:,:,:] ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len) @@ -1599,6 +1600,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir img_out = scale_padd_image_for_ocr(img_out, input_height, input_width) except: img_out = np.copy(img) + img_out = scale_padd_image_for_ocr(img_out, input_height, input_width) ret_x[batchcount, :,:,:] = img_out[:,:,:] ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len) @@ -1619,6 +1621,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir img_out = scale_padd_image_for_ocr(img_out, input_height, input_width) except: img_out = np.copy(img_bin_corr) + img_out = scale_padd_image_for_ocr(img_out, input_height, input_width) ret_x[batchcount, :,:,:] = img_out[:,:,:] ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len) @@ -1639,6 +1642,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir img_out = scale_padd_image_for_ocr(img_out, input_height, input_width) except: img_out = np.copy(img) + img_out = scale_padd_image_for_ocr(img_out, input_height, input_width) ret_x[batchcount, :,:,:] = img_out[:,:,:] ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len) @@ -1659,6 +1663,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir img_out = scale_padd_image_for_ocr(img_out, input_height, input_width) except: img_out = np.copy(img_bin_corr) + img_out = scale_padd_image_for_ocr(img_out, input_height, input_width) ret_x[batchcount, :,:,:] = img_out[:,:,:] ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len) @@ -1679,6 +1684,7 @@ def data_gen_ocr(padding_token, n_batch, 
input_height, input_width, max_len, dir img_out = scale_padd_image_for_ocr(img_out, input_height, input_width) except: img_out = np.copy(img) + img_out = scale_padd_image_for_ocr(img_out, input_height, input_width) ret_x[batchcount, :,:,:] = img_out[:,:,:] ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len) @@ -1699,6 +1705,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir img_out = scale_padd_image_for_ocr(img_out, input_height, input_width) except: img_out = np.copy(img_bin_corr) + img_out = scale_padd_image_for_ocr(img_out, input_height, input_width) ret_x[batchcount, :,:,:] = img_out[:,:,:] ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len) @@ -1719,6 +1726,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir img_out = scale_padd_image_for_ocr(img_out, input_height, input_width) except: img_out = np.copy(img) + img_out = scale_padd_image_for_ocr(img_out, input_height, input_width) ret_x[batchcount, :,:,:] = img_out[:,:,:] ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len) @@ -1739,6 +1747,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir img_out = scale_padd_image_for_ocr(img_out, input_height, input_width) except: img_out = np.copy(img_bin_corr) + img_out = scale_padd_image_for_ocr(img_out, input_height, input_width) ret_x[batchcount, :,:,:] = img_out[:,:,:] ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len) @@ -1759,6 +1768,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir img_out = scale_padd_image_for_ocr(img_out, input_height, input_width) except: img_out = np.copy(img) + img_out = scale_padd_image_for_ocr(img_out, input_height, input_width) ret_x[batchcount, :,:,:] = img_out[:,:,:] ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len) @@ -1779,6 +1789,7 
@@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir img_out = scale_padd_image_for_ocr(img_out, input_height, input_width) except: img_out = np.copy(img_bin_corr) + img_out = scale_padd_image_for_ocr(img_out, input_height, input_width) ret_x[batchcount, :,:,:] = img_out[:,:,:] ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len)