mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-02-20 16:32:03 +01:00
training: download pretrained RESNET weights if missing
This commit is contained in:
parent
6a81db934e
commit
eb92760f73
2 changed files with 24 additions and 8 deletions
|
|
@ -12,7 +12,10 @@ from tensorflow.keras.regularizers import l2
|
|||
###projection_dim = 64
|
||||
##transformer_layers = 2#8
|
||||
##num_heads = 1#4
|
||||
resnet50_Weights_path = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
|
||||
RESNET50_WEIGHTS_PATH = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
|
||||
RESNET50_WEIGHTS_URL = ('https://github.com/fchollet/deep-learning-models/releases/download/v0.2/'
|
||||
'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5')
|
||||
|
||||
IMAGE_ORDERING = 'channels_last'
|
||||
MERGE_AXIS = -1
|
||||
|
||||
|
|
@ -242,7 +245,7 @@ def resnet50_unet_light(n_classes, input_height=224, input_width=224, taks="segm
|
|||
f5 = x
|
||||
|
||||
if pretraining:
|
||||
model = Model(img_input, x).load_weights(resnet50_Weights_path)
|
||||
model = Model(img_input, x).load_weights(RESNET50_WEIGHTS_PATH)
|
||||
|
||||
v512_2048 = Conv2D(512, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f5)
|
||||
v512_2048 = (BatchNormalization(axis=bn_axis))(v512_2048)
|
||||
|
|
@ -343,7 +346,7 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentati
|
|||
f5 = x
|
||||
|
||||
if pretraining:
|
||||
Model(img_input, x).load_weights(resnet50_Weights_path)
|
||||
Model(img_input, x).load_weights(RESNET50_WEIGHTS_PATH)
|
||||
|
||||
v1024_2048 = Conv2D(1024, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(
|
||||
f5)
|
||||
|
|
@ -442,7 +445,7 @@ def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_he
|
|||
f5 = x
|
||||
|
||||
if pretraining:
|
||||
model = Model(inputs, x).load_weights(resnet50_Weights_path)
|
||||
model = Model(inputs, x).load_weights(RESNET50_WEIGHTS_PATH)
|
||||
|
||||
#num_patches = x.shape[1]*x.shape[2]
|
||||
|
||||
|
|
@ -590,7 +593,7 @@ def vit_resnet50_unet_transformer_before_cnn(n_classes, patch_size_x, patch_size
|
|||
f5 = x
|
||||
|
||||
if pretraining:
|
||||
model = Model(encoded_patches, x).load_weights(resnet50_Weights_path)
|
||||
model = Model(encoded_patches, x).load_weights(RESNET50_WEIGHTS_PATH)
|
||||
|
||||
v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(x)
|
||||
v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048)
|
||||
|
|
@ -690,7 +693,7 @@ def resnet50_classifier(n_classes,input_height=224,input_width=224,weight_decay=
|
|||
f5 = x
|
||||
|
||||
if pretraining:
|
||||
Model(img_input, x).load_weights(resnet50_Weights_path)
|
||||
Model(img_input, x).load_weights(RESNET50_WEIGHTS_PATH)
|
||||
|
||||
x = AveragePooling2D((7, 7), name='avg_pool')(x)
|
||||
x = Flatten()(x)
|
||||
|
|
@ -746,7 +749,7 @@ def machine_based_reading_order_model(n_classes,input_height=224,input_width=224
|
|||
x1 = identity_block(x1, 3, [512, 512, 2048], stage=5, block='c')
|
||||
|
||||
if pretraining:
|
||||
Model(img_input , x1).load_weights(resnet50_Weights_path)
|
||||
Model(img_input , x1).load_weights(RESNET50_WEIGHTS_PATH)
|
||||
|
||||
x1 = AveragePooling2D((7, 7), name='avg_pool1')(x1)
|
||||
flattened = Flatten()(x1)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import os
|
|||
import sys
|
||||
import json
|
||||
|
||||
import requests
|
||||
import click
|
||||
|
||||
from eynollah.training.metrics import (
|
||||
|
|
@ -15,7 +16,9 @@ from eynollah.training.models import (
|
|||
resnet50_classifier,
|
||||
resnet50_unet,
|
||||
vit_resnet50_unet,
|
||||
vit_resnet50_unet_transformer_before_cnn
|
||||
vit_resnet50_unet_transformer_before_cnn,
|
||||
RESNET50_WEIGHTS_PATH,
|
||||
RESNET50_WEIGHTS_URL
|
||||
)
|
||||
from eynollah.training.utils import (
|
||||
data_gen,
|
||||
|
|
@ -80,6 +83,12 @@ def get_dirs_or_files(input_data):
|
|||
assert os.path.isdir(labels_input), "{} is not a directory".format(labels_input)
|
||||
return image_input, labels_input
|
||||
|
||||
def download_file(url, path):
|
||||
with open(path, 'wb') as f:
|
||||
with requests.get(url, stream=True) as r:
|
||||
r.raise_for_status()
|
||||
for data in r.iter_content(chunk_size=4096):
|
||||
f.write(data)
|
||||
|
||||
ex = Experiment(save_git_info=False)
|
||||
|
||||
|
|
@ -163,6 +172,10 @@ def run(_config, n_classes, n_epochs, input_height,
|
|||
transformer_patchsize_x, transformer_patchsize_y,
|
||||
transformer_num_patches_xy, backbone_type, save_interval, flip_index, dir_eval, dir_output,
|
||||
pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds, dir_rgb_foregrounds):
|
||||
|
||||
if pretraining and not os.path.isfile(RESNET50_WEIGHTS_PATH):
|
||||
print("downloading RESNET50 pretrained weights to", RESNET50_WEIGHTS_PATH)
|
||||
download_file(RESNET50_WEIGHTS_URL, RESNET50_WEIGHTS_PATH)
|
||||
|
||||
if dir_rgb_backgrounds:
|
||||
list_all_possible_background_images = os.listdir(dir_rgb_backgrounds)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue