From 77adcbea8ad97d5b499ae8aeb4c53f500de02dcc Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Wed, 18 Feb 2026 16:47:21 +0100 Subject: [PATCH] copy characters list needed for cnn-rnn ocr model output while training and ensembling --- src/eynollah/training/train.py | 9 +++++++-- src/eynollah/training/weights_ensembling.py | 1 + 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py index 830fab0..e59ef80 100644 --- a/src/eynollah/training/train.py +++ b/src/eynollah/training/train.py @@ -52,7 +52,8 @@ from transformers import VisionEncoderDecoderModel from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments class SaveWeightsAfterSteps(Callback): - def __init__(self, save_interval, save_path, _config): + def __init__(self, save_interval, save_path, _config, characters_cnnrnn_ocr=None): super(SaveWeightsAfterSteps, self).__init__() + self.characters_cnnrnn_ocr = characters_cnnrnn_ocr self.save_interval = save_interval self.save_path = save_path @@ -68,6 +69,9 @@ class SaveWeightsAfterSteps(Callback): self.model.save(save_file) + if self.characters_cnnrnn_ocr: + os.system("cp "+self.characters_cnnrnn_ocr+" "+os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"characters_org.txt")) + with open(os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"config_eynollah.json"), "w") as fp: json.dump(self._config, fp) # encode dict into JSON print(f"saved model as steps {self.step_count} to {save_file}") @@ -544,7 +548,7 @@ def run( opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)#1e-4)#(lr_schedule) model.compile(optimizer=opt) - save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config) if save_interval else None + save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config, characters_cnnrnn_ocr=characters_txt_file) if save_interval else None for i in tqdm(range(index_start, n_epochs + index_start)): if save_interval: @@ -563,6 +567,7 @@ def run( if i >=0: 
model.save( os.path.join(dir_output,'model_'+str(i) )) + os.system("cp "+characters_txt_file+" "+os.path.join(os.path.join(dir_output,'model_'+str(i)),"characters_org.txt")) with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config_eynollah.json"), "w") as fp: json.dump(_config, fp) # encode dict into JSON diff --git a/src/eynollah/training/weights_ensembling.py b/src/eynollah/training/weights_ensembling.py index f293658..2f25dbf 100644 --- a/src/eynollah/training/weights_ensembling.py +++ b/src/eynollah/training/weights_ensembling.py @@ -136,6 +136,7 @@ def run_ensembling(dir_models, out, framework): model.set_weights(new_weights) model.save(out) os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config_eynollah.json ")+out) + os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "characters_org.txt ")+out) @click.command() @click.option(