mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-02-20 16:32:03 +01:00
eynollah config files has renamed from config.json to config_eynollah.json - training trocr model still misses to write config file into checkpoint directories
This commit is contained in:
parent
b426f7f152
commit
4f66734e4d
3 changed files with 12 additions and 14 deletions
|
|
@ -740,12 +740,10 @@ class sbb_predict:
|
|||
)
|
||||
def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area):
|
||||
assert image or dir_in, "Either a single image -i or a dir_in -di is required"
|
||||
try:
|
||||
|
||||
with open(os.path.join(model,'config_eynollah.json')) as f:
|
||||
config_params_model = json.load(f)
|
||||
except:
|
||||
with open(os.path.join(model,'config.json')) as f:
|
||||
config_params_model = json.load(f)
|
||||
|
||||
task = config_params_model['task']
|
||||
if task != 'classification' and task != 'reading_order' and task != "cnn-rnn-ocr" and task != "trocr":
|
||||
if image and not save:
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ class SaveWeightsAfterSteps(Callback):
|
|||
|
||||
self.model.save(save_file)
|
||||
|
||||
with open(os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"config.json"), "w") as fp:
|
||||
with open(os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"config_eynollah.json"), "w") as fp:
|
||||
json.dump(self._config, fp) # encode dict into JSON
|
||||
print(f"saved model as steps {self.step_count} to {save_file}")
|
||||
|
||||
|
|
@ -484,7 +484,7 @@ def run(
|
|||
|
||||
model.save(os.path.join(dir_output,'model_'+str(i)))
|
||||
|
||||
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
|
||||
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config_eynollah.json"), "w") as fp:
|
||||
json.dump(_config, fp) # encode dict into JSON
|
||||
|
||||
#os.system('rm -rf '+dir_train_flowing)
|
||||
|
|
@ -563,7 +563,7 @@ def run(
|
|||
|
||||
if i >=0:
|
||||
model.save( os.path.join(dir_output,'model_'+str(i) ))
|
||||
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
|
||||
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config_eynollah.json"), "w") as fp:
|
||||
json.dump(_config, fp) # encode dict into JSON
|
||||
|
||||
|
||||
|
|
@ -731,10 +731,10 @@ def run(
|
|||
model_weight_averaged.set_weights(new_weights)
|
||||
|
||||
model_weight_averaged.save(os.path.join(dir_output,'model_ens_avg'))
|
||||
with open(os.path.join( os.path.join(dir_output,'model_ens_avg'), "config.json"), "w") as fp:
|
||||
with open(os.path.join( os.path.join(dir_output,'model_ens_avg'), "config_eynollah.json"), "w") as fp:
|
||||
json.dump(_config, fp) # encode dict into JSON
|
||||
|
||||
with open(os.path.join( os.path.join(dir_output,'model_best'), "config.json"), "w") as fp:
|
||||
with open(os.path.join( os.path.join(dir_output,'model_best'), "config_eynollah.json"), "w") as fp:
|
||||
json.dump(_config, fp) # encode dict into JSON
|
||||
|
||||
elif task=='reading_order':
|
||||
|
|
@ -767,7 +767,7 @@ def run(
|
|||
history = model.fit(generate_arrays_from_folder_reading_order(dir_flow_train_labels, dir_flow_train_imgs, n_batch, input_height, input_width, n_classes, thetha, augmentation), steps_per_epoch=num_rows / n_batch, verbose=1)
|
||||
model.save( os.path.join(dir_output,'model_'+str(i+indexer_start) ))
|
||||
|
||||
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
|
||||
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config_eynollah.json"), "w") as fp:
|
||||
json.dump(_config, fp) # encode dict into JSON
|
||||
'''
|
||||
if f1score>f1score_tot[0]:
|
||||
|
|
|
|||
|
|
@ -113,7 +113,7 @@ def run_ensembling(dir_models, out, framework):
|
|||
model.load_state_dict(sd_models[0])
|
||||
os.system("mkdir "+out)
|
||||
torch.save(model.state_dict(), os.path.join(out, "pytorch_model.bin"))
|
||||
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config.json ")+out)
|
||||
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config_eynollah.json ")+out)
|
||||
|
||||
else:
|
||||
weights=[]
|
||||
|
|
@ -135,7 +135,7 @@ def run_ensembling(dir_models, out, framework):
|
|||
|
||||
model.set_weights(new_weights)
|
||||
model.save(out)
|
||||
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config.json ")+out)
|
||||
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config_eynollah.json ")+out)
|
||||
|
||||
@click.command()
|
||||
@click.option(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue