mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-02-20 16:32:03 +01:00
Copy the characters list needed by the CNN-RNN OCR model into the saved model output during training and ensembling
This commit is contained in:
parent
4f66734e4d
commit
77adcbea8a
2 changed files with 7 additions and 2 deletions
|
|
@ -52,7 +52,7 @@ from transformers import VisionEncoderDecoderModel
|
||||||
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
|
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
|
||||||
|
|
||||||
class SaveWeightsAfterSteps(Callback):
|
class SaveWeightsAfterSteps(Callback):
|
||||||
def __init__(self, save_interval, save_path, _config):
|
def __init__(self, save_interval, save_path, _config, characters_cnnrnn_ocr=None):
|
||||||
super(SaveWeightsAfterSteps, self).__init__()
|
super(SaveWeightsAfterSteps, self).__init__()
|
||||||
self.save_interval = save_interval
|
self.save_interval = save_interval
|
||||||
self.save_path = save_path
|
self.save_path = save_path
|
||||||
|
|
@ -68,6 +68,9 @@ class SaveWeightsAfterSteps(Callback):
|
||||||
|
|
||||||
self.model.save(save_file)
|
self.model.save(save_file)
|
||||||
|
|
||||||
|
if characters_cnnrnn_ocr:
|
||||||
|
os.system("cp "+characters_cnnrnn_ocr+" "+os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"characters_org.txt"))
|
||||||
|
|
||||||
with open(os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"config_eynollah.json"), "w") as fp:
|
with open(os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"config_eynollah.json"), "w") as fp:
|
||||||
json.dump(self._config, fp) # encode dict into JSON
|
json.dump(self._config, fp) # encode dict into JSON
|
||||||
print(f"saved model as steps {self.step_count} to {save_file}")
|
print(f"saved model as steps {self.step_count} to {save_file}")
|
||||||
|
|
@ -544,7 +547,7 @@ def run(
|
||||||
opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)#1e-4)#(lr_schedule)
|
opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)#1e-4)#(lr_schedule)
|
||||||
model.compile(optimizer=opt)
|
model.compile(optimizer=opt)
|
||||||
|
|
||||||
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config) if save_interval else None
|
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config, characters_cnnrnn_ocr=characters_txt_file) if save_interval else None
|
||||||
|
|
||||||
for i in tqdm(range(index_start, n_epochs + index_start)):
|
for i in tqdm(range(index_start, n_epochs + index_start)):
|
||||||
if save_interval:
|
if save_interval:
|
||||||
|
|
@ -563,6 +566,7 @@ def run(
|
||||||
|
|
||||||
if i >=0:
|
if i >=0:
|
||||||
model.save( os.path.join(dir_output,'model_'+str(i) ))
|
model.save( os.path.join(dir_output,'model_'+str(i) ))
|
||||||
|
os.system("cp "+characters_txt_file+" "+os.path.join(os.path.join(dir_output,'model_'+str(i)),"characters_org.txt")
|
||||||
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config_eynollah.json"), "w") as fp:
|
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config_eynollah.json"), "w") as fp:
|
||||||
json.dump(_config, fp) # encode dict into JSON
|
json.dump(_config, fp) # encode dict into JSON
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -136,6 +136,7 @@ def run_ensembling(dir_models, out, framework):
|
||||||
model.set_weights(new_weights)
|
model.set_weights(new_weights)
|
||||||
model.save(out)
|
model.save(out)
|
||||||
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config_eynollah.json ")+out)
|
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config_eynollah.json ")+out)
|
||||||
|
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "characters_org.txt ")+out)
|
||||||
|
|
||||||
@click.command()
|
@click.command()
|
||||||
@click.option(
|
@click.option(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue