mirror of https://github.com/qurator-spk/sbb_pixelwise_segmentation.git
machine-based reading order training is integrated
This commit is contained in:
parent bf1468391a
commit 4e4490d740
3 changed files with 109 additions and 0 deletions
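The commit adds a machine_based_reading_order_model builder (a ResNet50 backbone topped with two dense layers and a sigmoid output head) to models.py, a task == 'reading_order' training branch to train.py, and the matching batch generator generate_arrays_from_folder_reading_order to utils.py.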
models.py (+55 lines)

@@ -544,4 +544,59 @@ def resnet50_classifier(n_classes,input_height=224,input_width=224,weight_decay=

    return model


def machine_based_reading_order_model(n_classes, input_height=224, input_width=224, weight_decay=1e-6, pretraining=False):
    assert input_height % 32 == 0
    assert input_width % 32 == 0

    img_input = Input(shape=(input_height, input_width, 3))

    if IMAGE_ORDERING == 'channels_last':
        bn_axis = 3
    else:
        bn_axis = 1

    x1 = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input)
    x1 = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2), kernel_regularizer=l2(weight_decay), name='conv1')(x1)

    x1 = BatchNormalization(axis=bn_axis, name='bn_conv1')(x1)
    x1 = Activation('relu')(x1)
    x1 = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x1)

    x1 = conv_block(x1, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
    x1 = identity_block(x1, 3, [64, 64, 256], stage=2, block='b')
    x1 = identity_block(x1, 3, [64, 64, 256], stage=2, block='c')

    x1 = conv_block(x1, 3, [128, 128, 512], stage=3, block='a')
    x1 = identity_block(x1, 3, [128, 128, 512], stage=3, block='b')
    x1 = identity_block(x1, 3, [128, 128, 512], stage=3, block='c')
    x1 = identity_block(x1, 3, [128, 128, 512], stage=3, block='d')

    x1 = conv_block(x1, 3, [256, 256, 1024], stage=4, block='a')
    x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='b')
    x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='c')
    x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='d')
    x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='e')
    x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='f')

    x1 = conv_block(x1, 3, [512, 512, 2048], stage=5, block='a')
    x1 = identity_block(x1, 3, [512, 512, 2048], stage=5, block='b')
    x1 = identity_block(x1, 3, [512, 512, 2048], stage=5, block='c')

    if pretraining:
        Model(img_input, x1).load_weights(resnet50_Weights_path)

    x1 = AveragePooling2D((7, 7), name='avg_pool1')(x1)
    flattened = Flatten()(x1)

    o = Dense(256, activation='relu', name='fc512')(flattened)
    o = Dropout(0.2)(o)

    o = Dense(256, activation='relu', name='fc512a')(o)
    o = Dropout(0.2)(o)

    o = Dense(n_classes, activation='sigmoid', name='fc1000')(o)
    model = Model(img_input, o)

    return model
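For orientation, a minimal sketch of how the new builder might be called; the argument values below are illustrative assumptions, not taken from this commit:

    # Hypothetical usage; the n_classes and pretraining values are assumptions.
    model = machine_based_reading_order_model(
        n_classes=2,          # e.g. a binary before/after decision (assumption)
        input_height=224,     # must be divisible by 32 (asserted in the builder)
        input_width=224,
        weight_decay=1e-6,
        pretraining=True,     # loads ResNet50 weights from resnet50_Weights_path
    )
    model.summary()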
train.py (+31 lines)

@@ -314,3 +314,34 @@ def run(_config, n_classes, n_epochs, input_height,

        with open(os.path.join(os.path.join(dir_output, 'model_best'), "config.json"), "w") as fp:
            json.dump(_config, fp)  # encode dict into JSON

    elif task == 'reading_order':
        configuration()
        model = machine_based_reading_order_model(n_classes, input_height, input_width, weight_decay, pretraining)

        dir_flow_train_imgs = os.path.join(dir_train, 'images')
        dir_flow_train_labels = os.path.join(dir_train, 'labels')

        classes = os.listdir(dir_flow_train_labels)
        num_rows = len(classes)
        #ls_test = os.listdir(dir_flow_train_labels)

        #f1score_tot = [0]
        indexer_start = 0
        opt = SGD(lr=0.01, momentum=0.9)  # defined but unused; Adam is compiled below
        opt_adam = tf.keras.optimizers.Adam(learning_rate=0.0001)
        model.compile(loss="binary_crossentropy",
                      optimizer=opt_adam, metrics=['accuracy'])
        for i in range(n_epochs):
            history = model.fit(
                generate_arrays_from_folder_reading_order(dir_flow_train_labels, dir_flow_train_imgs, n_batch, input_height, input_width, n_classes),
                steps_per_epoch=num_rows // n_batch, verbose=1)
            model.save(os.path.join(dir_output, 'model_' + str(i + indexer_start)))

            with open(os.path.join(os.path.join(dir_output, 'model_' + str(i + indexer_start)), "config.json"), "w") as fp:
                json.dump(_config, fp)  # encode dict into JSON
            '''
            if f1score > f1score_tot[0]:
                f1score_tot[0] = f1score
                model_dir = os.path.join(dir_out, 'model_best')
                model.save(model_dir)
            '''
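The training branch above assumes dir_train contains parallel images/ and labels/ folders, with one .npy class file per image (see the generator in utils.py below). A small consistency-check sketch under that assumption; the dir_train path here is a placeholder:

    import os
    import numpy as np

    dir_train = '/path/to/dir_train'                  # placeholder path
    dir_imgs = os.path.join(dir_train, 'images')      # <name>.png inputs
    dir_labels = os.path.join(dir_train, 'labels')    # <name>.npy, one integer class each

    for f in sorted(os.listdir(dir_labels)):
        name = f.split('.')[0]
        # every label file should have a matching image, as the generator expects
        assert os.path.isfile(os.path.join(dir_imgs, name + '.png')), name
        label = int(np.load(os.path.join(dir_labels, f)))  # mirrors the generator's np.load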
utils.py (+23 lines)

@@ -268,6 +268,29 @@ def IoU(Yi, y_predi):

    #print("Mean IoU: {:4.3f}".format(mIoU))
    return mIoU


def generate_arrays_from_folder_reading_order(classes_file_dir, modal_dir, batchsize, height, width, n_classes):
    all_labels_files = os.listdir(classes_file_dir)
    ret_x = np.zeros((batchsize, height, width, 3))  # .astype(np.int16)
    ret_y = np.zeros((batchsize, n_classes)).astype(np.int16)
    batchcount = 0
    while True:
        for i in all_labels_files:
            file_name = i.split('.')[0]
            img = cv2.imread(os.path.join(modal_dir, file_name + '.png'))

            label_class = int(np.load(os.path.join(classes_file_dir, i)))

            ret_x[batchcount, :, :, 0] = img[:, :, 0] / 3.0
            ret_x[batchcount, :, :, 2] = img[:, :, 2] / 3.0
            ret_x[batchcount, :, :, 1] = img[:, :, 1] / 5.0

            ret_y[batchcount, :] = label_class
            batchcount += 1
            if batchcount >= batchsize:
                yield (ret_x, ret_y)
                ret_x = np.zeros((batchsize, height, width, 3))  # .astype(np.int16)
                ret_y = np.zeros((batchsize, n_classes)).astype(np.int16)
                batchcount = 0


def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_classes, task='segmentation'):
    c = 0
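The generator loops over the label files indefinitely, so it can be handed straight to model.fit as in train.py. A minimal sketch of driving it by hand; the directory and batch variables are assumed to be set as in the train.py branch above:

    gen = generate_arrays_from_folder_reading_order(
        dir_flow_train_labels, dir_flow_train_imgs,
        n_batch, input_height, input_width, n_classes)
    x, y = next(gen)             # first full batch
    print(x.shape, y.shape)      # (n_batch, height, width, 3), (n_batch, n_classes)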