mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-02-20 16:32:03 +01:00
training: download pretrained RESNET weights if missing
This commit is contained in:
parent
6a81db934e
commit
eb92760f73
2 changed files with 24 additions and 8 deletions
|
|
@ -12,7 +12,10 @@ from tensorflow.keras.regularizers import l2
|
||||||
###projection_dim = 64
|
###projection_dim = 64
|
||||||
##transformer_layers = 2#8
|
##transformer_layers = 2#8
|
||||||
##num_heads = 1#4
|
##num_heads = 1#4
|
||||||
resnet50_Weights_path = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
|
RESNET50_WEIGHTS_PATH = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
|
||||||
|
RESNET50_WEIGHTS_URL = ('https://github.com/fchollet/deep-learning-models/releases/download/v0.2/'
|
||||||
|
'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5')
|
||||||
|
|
||||||
IMAGE_ORDERING = 'channels_last'
|
IMAGE_ORDERING = 'channels_last'
|
||||||
MERGE_AXIS = -1
|
MERGE_AXIS = -1
|
||||||
|
|
||||||
|
|
@ -242,7 +245,7 @@ def resnet50_unet_light(n_classes, input_height=224, input_width=224, taks="segm
|
||||||
f5 = x
|
f5 = x
|
||||||
|
|
||||||
if pretraining:
|
if pretraining:
|
||||||
model = Model(img_input, x).load_weights(resnet50_Weights_path)
|
model = Model(img_input, x).load_weights(RESNET50_WEIGHTS_PATH)
|
||||||
|
|
||||||
v512_2048 = Conv2D(512, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f5)
|
v512_2048 = Conv2D(512, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f5)
|
||||||
v512_2048 = (BatchNormalization(axis=bn_axis))(v512_2048)
|
v512_2048 = (BatchNormalization(axis=bn_axis))(v512_2048)
|
||||||
|
|
@ -343,7 +346,7 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentati
|
||||||
f5 = x
|
f5 = x
|
||||||
|
|
||||||
if pretraining:
|
if pretraining:
|
||||||
Model(img_input, x).load_weights(resnet50_Weights_path)
|
Model(img_input, x).load_weights(RESNET50_WEIGHTS_PATH)
|
||||||
|
|
||||||
v1024_2048 = Conv2D(1024, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(
|
v1024_2048 = Conv2D(1024, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(
|
||||||
f5)
|
f5)
|
||||||
|
|
@ -442,7 +445,7 @@ def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_he
|
||||||
f5 = x
|
f5 = x
|
||||||
|
|
||||||
if pretraining:
|
if pretraining:
|
||||||
model = Model(inputs, x).load_weights(resnet50_Weights_path)
|
model = Model(inputs, x).load_weights(RESNET50_WEIGHTS_PATH)
|
||||||
|
|
||||||
#num_patches = x.shape[1]*x.shape[2]
|
#num_patches = x.shape[1]*x.shape[2]
|
||||||
|
|
||||||
|
|
@ -590,7 +593,7 @@ def vit_resnet50_unet_transformer_before_cnn(n_classes, patch_size_x, patch_size
|
||||||
f5 = x
|
f5 = x
|
||||||
|
|
||||||
if pretraining:
|
if pretraining:
|
||||||
model = Model(encoded_patches, x).load_weights(resnet50_Weights_path)
|
model = Model(encoded_patches, x).load_weights(RESNET50_WEIGHTS_PATH)
|
||||||
|
|
||||||
v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(x)
|
v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(x)
|
||||||
v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048)
|
v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048)
|
||||||
|
|
@ -690,7 +693,7 @@ def resnet50_classifier(n_classes,input_height=224,input_width=224,weight_decay=
|
||||||
f5 = x
|
f5 = x
|
||||||
|
|
||||||
if pretraining:
|
if pretraining:
|
||||||
Model(img_input, x).load_weights(resnet50_Weights_path)
|
Model(img_input, x).load_weights(RESNET50_WEIGHTS_PATH)
|
||||||
|
|
||||||
x = AveragePooling2D((7, 7), name='avg_pool')(x)
|
x = AveragePooling2D((7, 7), name='avg_pool')(x)
|
||||||
x = Flatten()(x)
|
x = Flatten()(x)
|
||||||
|
|
@ -746,7 +749,7 @@ def machine_based_reading_order_model(n_classes,input_height=224,input_width=224
|
||||||
x1 = identity_block(x1, 3, [512, 512, 2048], stage=5, block='c')
|
x1 = identity_block(x1, 3, [512, 512, 2048], stage=5, block='c')
|
||||||
|
|
||||||
if pretraining:
|
if pretraining:
|
||||||
Model(img_input , x1).load_weights(resnet50_Weights_path)
|
Model(img_input , x1).load_weights(RESNET50_WEIGHTS_PATH)
|
||||||
|
|
||||||
x1 = AveragePooling2D((7, 7), name='avg_pool1')(x1)
|
x1 = AveragePooling2D((7, 7), name='avg_pool1')(x1)
|
||||||
flattened = Flatten()(x1)
|
flattened = Flatten()(x1)
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ import os
|
||||||
import sys
|
import sys
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
import requests
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from eynollah.training.metrics import (
|
from eynollah.training.metrics import (
|
||||||
|
|
@ -15,7 +16,9 @@ from eynollah.training.models import (
|
||||||
resnet50_classifier,
|
resnet50_classifier,
|
||||||
resnet50_unet,
|
resnet50_unet,
|
||||||
vit_resnet50_unet,
|
vit_resnet50_unet,
|
||||||
vit_resnet50_unet_transformer_before_cnn
|
vit_resnet50_unet_transformer_before_cnn,
|
||||||
|
RESNET50_WEIGHTS_PATH,
|
||||||
|
RESNET50_WEIGHTS_URL
|
||||||
)
|
)
|
||||||
from eynollah.training.utils import (
|
from eynollah.training.utils import (
|
||||||
data_gen,
|
data_gen,
|
||||||
|
|
@ -80,6 +83,12 @@ def get_dirs_or_files(input_data):
|
||||||
assert os.path.isdir(labels_input), "{} is not a directory".format(labels_input)
|
assert os.path.isdir(labels_input), "{} is not a directory".format(labels_input)
|
||||||
return image_input, labels_input
|
return image_input, labels_input
|
||||||
|
|
||||||
|
def download_file(url, path):
|
||||||
|
with open(path, 'wb') as f:
|
||||||
|
with requests.get(url, stream=True) as r:
|
||||||
|
r.raise_for_status()
|
||||||
|
for data in r.iter_content(chunk_size=4096):
|
||||||
|
f.write(data)
|
||||||
|
|
||||||
ex = Experiment(save_git_info=False)
|
ex = Experiment(save_git_info=False)
|
||||||
|
|
||||||
|
|
@ -163,6 +172,10 @@ def run(_config, n_classes, n_epochs, input_height,
|
||||||
transformer_patchsize_x, transformer_patchsize_y,
|
transformer_patchsize_x, transformer_patchsize_y,
|
||||||
transformer_num_patches_xy, backbone_type, save_interval, flip_index, dir_eval, dir_output,
|
transformer_num_patches_xy, backbone_type, save_interval, flip_index, dir_eval, dir_output,
|
||||||
pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds, dir_rgb_foregrounds):
|
pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds, dir_rgb_foregrounds):
|
||||||
|
|
||||||
|
if pretraining and not os.path.isfile(RESNET50_WEIGHTS_PATH):
|
||||||
|
print("downloading RESNET50 pretrained weights to", RESNET50_WEIGHTS_PATH)
|
||||||
|
download_file(RESNET50_WEIGHTS_URL, RESNET50_WEIGHTS_PATH)
|
||||||
|
|
||||||
if dir_rgb_backgrounds:
|
if dir_rgb_backgrounds:
|
||||||
list_all_possible_background_images = os.listdir(dir_rgb_backgrounds)
|
list_all_possible_background_images = os.listdir(dir_rgb_backgrounds)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue