mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-02-20 16:32:03 +01:00
eynollah config files has renamed from config.json to config_eynollah.json - training trocr model still misses to write config file into checkpoint directories
This commit is contained in:
parent
b426f7f152
commit
4f66734e4d
3 changed files with 12 additions and 14 deletions
|
|
@ -740,12 +740,10 @@ class sbb_predict:
|
||||||
)
|
)
|
||||||
def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area):
|
def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area):
|
||||||
assert image or dir_in, "Either a single image -i or a dir_in -di is required"
|
assert image or dir_in, "Either a single image -i or a dir_in -di is required"
|
||||||
try:
|
|
||||||
with open(os.path.join(model,'config_eynollah.json')) as f:
|
with open(os.path.join(model,'config_eynollah.json')) as f:
|
||||||
config_params_model = json.load(f)
|
config_params_model = json.load(f)
|
||||||
except:
|
|
||||||
with open(os.path.join(model,'config.json')) as f:
|
|
||||||
config_params_model = json.load(f)
|
|
||||||
task = config_params_model['task']
|
task = config_params_model['task']
|
||||||
if task != 'classification' and task != 'reading_order' and task != "cnn-rnn-ocr" and task != "trocr":
|
if task != 'classification' and task != 'reading_order' and task != "cnn-rnn-ocr" and task != "trocr":
|
||||||
if image and not save:
|
if image and not save:
|
||||||
|
|
|
||||||
|
|
@ -68,7 +68,7 @@ class SaveWeightsAfterSteps(Callback):
|
||||||
|
|
||||||
self.model.save(save_file)
|
self.model.save(save_file)
|
||||||
|
|
||||||
with open(os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"config.json"), "w") as fp:
|
with open(os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"config_eynollah.json"), "w") as fp:
|
||||||
json.dump(self._config, fp) # encode dict into JSON
|
json.dump(self._config, fp) # encode dict into JSON
|
||||||
print(f"saved model as steps {self.step_count} to {save_file}")
|
print(f"saved model as steps {self.step_count} to {save_file}")
|
||||||
|
|
||||||
|
|
@ -484,7 +484,7 @@ def run(
|
||||||
|
|
||||||
model.save(os.path.join(dir_output,'model_'+str(i)))
|
model.save(os.path.join(dir_output,'model_'+str(i)))
|
||||||
|
|
||||||
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
|
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config_eynollah.json"), "w") as fp:
|
||||||
json.dump(_config, fp) # encode dict into JSON
|
json.dump(_config, fp) # encode dict into JSON
|
||||||
|
|
||||||
#os.system('rm -rf '+dir_train_flowing)
|
#os.system('rm -rf '+dir_train_flowing)
|
||||||
|
|
@ -563,7 +563,7 @@ def run(
|
||||||
|
|
||||||
if i >=0:
|
if i >=0:
|
||||||
model.save( os.path.join(dir_output,'model_'+str(i) ))
|
model.save( os.path.join(dir_output,'model_'+str(i) ))
|
||||||
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
|
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config_eynollah.json"), "w") as fp:
|
||||||
json.dump(_config, fp) # encode dict into JSON
|
json.dump(_config, fp) # encode dict into JSON
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -731,10 +731,10 @@ def run(
|
||||||
model_weight_averaged.set_weights(new_weights)
|
model_weight_averaged.set_weights(new_weights)
|
||||||
|
|
||||||
model_weight_averaged.save(os.path.join(dir_output,'model_ens_avg'))
|
model_weight_averaged.save(os.path.join(dir_output,'model_ens_avg'))
|
||||||
with open(os.path.join( os.path.join(dir_output,'model_ens_avg'), "config.json"), "w") as fp:
|
with open(os.path.join( os.path.join(dir_output,'model_ens_avg'), "config_eynollah.json"), "w") as fp:
|
||||||
json.dump(_config, fp) # encode dict into JSON
|
json.dump(_config, fp) # encode dict into JSON
|
||||||
|
|
||||||
with open(os.path.join( os.path.join(dir_output,'model_best'), "config.json"), "w") as fp:
|
with open(os.path.join( os.path.join(dir_output,'model_best'), "config_eynollah.json"), "w") as fp:
|
||||||
json.dump(_config, fp) # encode dict into JSON
|
json.dump(_config, fp) # encode dict into JSON
|
||||||
|
|
||||||
elif task=='reading_order':
|
elif task=='reading_order':
|
||||||
|
|
@ -767,7 +767,7 @@ def run(
|
||||||
history = model.fit(generate_arrays_from_folder_reading_order(dir_flow_train_labels, dir_flow_train_imgs, n_batch, input_height, input_width, n_classes, thetha, augmentation), steps_per_epoch=num_rows / n_batch, verbose=1)
|
history = model.fit(generate_arrays_from_folder_reading_order(dir_flow_train_labels, dir_flow_train_imgs, n_batch, input_height, input_width, n_classes, thetha, augmentation), steps_per_epoch=num_rows / n_batch, verbose=1)
|
||||||
model.save( os.path.join(dir_output,'model_'+str(i+indexer_start) ))
|
model.save( os.path.join(dir_output,'model_'+str(i+indexer_start) ))
|
||||||
|
|
||||||
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
|
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config_eynollah.json"), "w") as fp:
|
||||||
json.dump(_config, fp) # encode dict into JSON
|
json.dump(_config, fp) # encode dict into JSON
|
||||||
'''
|
'''
|
||||||
if f1score>f1score_tot[0]:
|
if f1score>f1score_tot[0]:
|
||||||
|
|
|
||||||
|
|
@ -113,7 +113,7 @@ def run_ensembling(dir_models, out, framework):
|
||||||
model.load_state_dict(sd_models[0])
|
model.load_state_dict(sd_models[0])
|
||||||
os.system("mkdir "+out)
|
os.system("mkdir "+out)
|
||||||
torch.save(model.state_dict(), os.path.join(out, "pytorch_model.bin"))
|
torch.save(model.state_dict(), os.path.join(out, "pytorch_model.bin"))
|
||||||
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config.json ")+out)
|
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config_eynollah.json ")+out)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
weights=[]
|
weights=[]
|
||||||
|
|
@ -135,7 +135,7 @@ def run_ensembling(dir_models, out, framework):
|
||||||
|
|
||||||
model.set_weights(new_weights)
|
model.set_weights(new_weights)
|
||||||
model.save(out)
|
model.save(out)
|
||||||
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config.json ")+out)
|
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config_eynollah.json ")+out)
|
||||||
|
|
||||||
@click.command()
|
@click.command()
|
||||||
@click.option(
|
@click.option(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue