mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-02-20 16:32:03 +01:00
copy characters list needed for cnn-rnn ocr model output while training and ensembling
This commit is contained in:
parent
4f66734e4d
commit
77adcbea8a
2 changed files with 7 additions and 2 deletions
|
|
@ -52,7 +52,7 @@ from transformers import VisionEncoderDecoderModel
|
|||
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
|
||||
|
||||
class SaveWeightsAfterSteps(Callback):
|
||||
def __init__(self, save_interval, save_path, _config):
|
||||
def __init__(self, save_interval, save_path, _config, characters_cnnrnn_ocr=None):
|
||||
super(SaveWeightsAfterSteps, self).__init__()
|
||||
self.save_interval = save_interval
|
||||
self.save_path = save_path
|
||||
|
|
@ -68,6 +68,9 @@ class SaveWeightsAfterSteps(Callback):
|
|||
|
||||
self.model.save(save_file)
|
||||
|
||||
if characters_cnnrnn_ocr:
|
||||
os.system("cp "+characters_cnnrnn_ocr+" "+os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"characters_org.txt"))
|
||||
|
||||
with open(os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"config_eynollah.json"), "w") as fp:
|
||||
json.dump(self._config, fp) # encode dict into JSON
|
||||
print(f"saved model as steps {self.step_count} to {save_file}")
|
||||
|
|
@ -544,7 +547,7 @@ def run(
|
|||
opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)#1e-4)#(lr_schedule)
|
||||
model.compile(optimizer=opt)
|
||||
|
||||
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config) if save_interval else None
|
||||
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config, characters_cnnrnn_ocr=characters_txt_file) if save_interval else None
|
||||
|
||||
for i in tqdm(range(index_start, n_epochs + index_start)):
|
||||
if save_interval:
|
||||
|
|
@ -563,6 +566,7 @@ def run(
|
|||
|
||||
if i >=0:
|
||||
model.save( os.path.join(dir_output,'model_'+str(i) ))
|
||||
os.system("cp "+characters_txt_file+" "+os.path.join(os.path.join(dir_output,'model_'+str(i)),"characters_org.txt")
|
||||
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config_eynollah.json"), "w") as fp:
|
||||
json.dump(_config, fp) # encode dict into JSON
|
||||
|
||||
|
|
|
|||
|
|
@ -136,6 +136,7 @@ def run_ensembling(dir_models, out, framework):
|
|||
model.set_weights(new_weights)
|
||||
model.save(out)
|
||||
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config_eynollah.json ")+out)
|
||||
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "characters_org.txt ")+out)
|
||||
|
||||
@click.command()
|
||||
@click.option(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue