Copy the characters list (needed to decode CNN-RNN OCR model output) into saved model directories during training and ensembling

This commit is contained in:
vahidrezanezhad 2026-02-18 16:47:21 +01:00
parent 4f66734e4d
commit 77adcbea8a
2 changed files with 7 additions and 2 deletions

View file

@ -52,7 +52,7 @@ from transformers import VisionEncoderDecoderModel
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
class SaveWeightsAfterSteps(Callback):
def __init__(self, save_interval, save_path, _config):
def __init__(self, save_interval, save_path, _config, characters_cnnrnn_ocr=None):
super(SaveWeightsAfterSteps, self).__init__()
self.save_interval = save_interval
self.save_path = save_path
@ -68,6 +68,9 @@ class SaveWeightsAfterSteps(Callback):
self.model.save(save_file)
if characters_cnnrnn_ocr:
os.system("cp "+characters_cnnrnn_ocr+" "+os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"characters_org.txt"))
with open(os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"config_eynollah.json"), "w") as fp:
json.dump(self._config, fp) # encode dict into JSON
print(f"saved model as steps {self.step_count} to {save_file}")
@ -544,7 +547,7 @@ def run(
opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)#1e-4)#(lr_schedule)
model.compile(optimizer=opt)
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config) if save_interval else None
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config, characters_cnnrnn_ocr=characters_txt_file) if save_interval else None
for i in tqdm(range(index_start, n_epochs + index_start)):
if save_interval:
@ -563,6 +566,7 @@ def run(
if i >=0:
model.save( os.path.join(dir_output,'model_'+str(i) ))
os.system("cp "+characters_txt_file+" "+os.path.join(os.path.join(dir_output,'model_'+str(i)),"characters_org.txt")
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config_eynollah.json"), "w") as fp:
json.dump(_config, fp) # encode dict into JSON

View file

@ -136,6 +136,7 @@ def run_ensembling(dir_models, out, framework):
model.set_weights(new_weights)
model.save(out)
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config_eynollah.json ")+out)
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "characters_org.txt ")+out)
@click.command()
@click.option(