mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-02-20 16:32:03 +01:00
training: extend index_start to tasks classification and RO
This commit is contained in:
parent
e85003db4a
commit
1581094141
1 changed files with 22 additions and 10 deletions
|
|
@ -423,6 +423,10 @@ def run(_config,
|
|||
#model.save(dir_output+'/'+'model'+'.h5')
|
||||
|
||||
elif task=='classification':
|
||||
if continue_training:
|
||||
model = load_model(dir_of_start_model, compile=False)
|
||||
else:
|
||||
index_start = 0
|
||||
model = resnet50_classifier(n_classes,
|
||||
input_height,
|
||||
input_width,
|
||||
|
|
@ -453,7 +457,8 @@ def run(_config,
|
|||
verbose=1,
|
||||
epochs=n_epochs,
|
||||
metrics=[F1Score(average='macro', name='f1')],
|
||||
callbacks=callbacks)
|
||||
callbacks=callbacks,
|
||||
initial_epoch=index_start)
|
||||
|
||||
usable_checkpoints = np.flatnonzero(np.array(history['val_f1']) > f1_threshold_classification)
|
||||
if len(usable_checkpoints) >= 1:
|
||||
|
|
@ -481,8 +486,15 @@ def run(_config,
|
|||
_log.info("ensemble model saved under '%s'", cp_path)
|
||||
|
||||
elif task=='reading_order':
|
||||
model = machine_based_reading_order_model(
|
||||
n_classes, input_height, input_width, weight_decay, pretraining)
|
||||
if continue_training:
|
||||
model = load_model(dir_of_start_model, compile=False)
|
||||
else:
|
||||
index_start = 0
|
||||
model = machine_based_reading_order_model(n_classes,
|
||||
input_height,
|
||||
input_width,
|
||||
weight_decay,
|
||||
pretraining)
|
||||
|
||||
dir_flow_train_imgs = os.path.join(dir_train, 'images')
|
||||
dir_flow_train_labels = os.path.join(dir_train, 'labels')
|
||||
|
|
@ -495,7 +507,6 @@ def run(_config,
|
|||
#ls_test = os.listdir(dir_flow_train_labels)
|
||||
|
||||
#f1score_tot = [0]
|
||||
indexer_start = 0
|
||||
model.compile(loss="binary_crossentropy",
|
||||
#optimizer=SGD(learning_rate=0.01, momentum=0.9),
|
||||
optimizer=Adam(learning_rate=0.0001), # rs: why not learning_rate?
|
||||
|
|
@ -515,7 +526,8 @@ def run(_config,
|
|||
steps_per_epoch=num_rows / n_batch,
|
||||
verbose=1,
|
||||
epochs=n_epochs,
|
||||
callbacks=callbacks)
|
||||
callbacks=callbacks,
|
||||
initial_epoch=index_start)
|
||||
'''
|
||||
if f1score>f1score_tot[0]:
|
||||
f1score_tot[0] = f1score
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue