training: download pretrained RESNET weights if missing

This commit is contained in:
Robert Sachunsky 2026-01-22 19:49:39 +01:00
parent 6a81db934e
commit eb92760f73
2 changed files with 24 additions and 8 deletions

View file

@ -12,7 +12,10 @@ from tensorflow.keras.regularizers import l2
###projection_dim = 64 ###projection_dim = 64
##transformer_layers = 2#8 ##transformer_layers = 2#8
##num_heads = 1#4 ##num_heads = 1#4
resnet50_Weights_path = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5' RESNET50_WEIGHTS_PATH = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
RESNET50_WEIGHTS_URL = ('https://github.com/fchollet/deep-learning-models/releases/download/v0.2/'
'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5')
IMAGE_ORDERING = 'channels_last' IMAGE_ORDERING = 'channels_last'
MERGE_AXIS = -1 MERGE_AXIS = -1
@ -242,7 +245,7 @@ def resnet50_unet_light(n_classes, input_height=224, input_width=224, taks="segm
f5 = x f5 = x
if pretraining: if pretraining:
model = Model(img_input, x).load_weights(resnet50_Weights_path) model = Model(img_input, x).load_weights(RESNET50_WEIGHTS_PATH)
v512_2048 = Conv2D(512, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f5) v512_2048 = Conv2D(512, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f5)
v512_2048 = (BatchNormalization(axis=bn_axis))(v512_2048) v512_2048 = (BatchNormalization(axis=bn_axis))(v512_2048)
@ -343,7 +346,7 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentati
f5 = x f5 = x
if pretraining: if pretraining:
Model(img_input, x).load_weights(resnet50_Weights_path) Model(img_input, x).load_weights(RESNET50_WEIGHTS_PATH)
v1024_2048 = Conv2D(1024, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))( v1024_2048 = Conv2D(1024, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(
f5) f5)
@ -442,7 +445,7 @@ def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_he
f5 = x f5 = x
if pretraining: if pretraining:
model = Model(inputs, x).load_weights(resnet50_Weights_path) model = Model(inputs, x).load_weights(RESNET50_WEIGHTS_PATH)
#num_patches = x.shape[1]*x.shape[2] #num_patches = x.shape[1]*x.shape[2]
@ -590,7 +593,7 @@ def vit_resnet50_unet_transformer_before_cnn(n_classes, patch_size_x, patch_size
f5 = x f5 = x
if pretraining: if pretraining:
model = Model(encoded_patches, x).load_weights(resnet50_Weights_path) model = Model(encoded_patches, x).load_weights(RESNET50_WEIGHTS_PATH)
v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(x) v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(x)
v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048) v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048)
@ -690,7 +693,7 @@ def resnet50_classifier(n_classes,input_height=224,input_width=224,weight_decay=
f5 = x f5 = x
if pretraining: if pretraining:
Model(img_input, x).load_weights(resnet50_Weights_path) Model(img_input, x).load_weights(RESNET50_WEIGHTS_PATH)
x = AveragePooling2D((7, 7), name='avg_pool')(x) x = AveragePooling2D((7, 7), name='avg_pool')(x)
x = Flatten()(x) x = Flatten()(x)
@ -746,7 +749,7 @@ def machine_based_reading_order_model(n_classes,input_height=224,input_width=224
x1 = identity_block(x1, 3, [512, 512, 2048], stage=5, block='c') x1 = identity_block(x1, 3, [512, 512, 2048], stage=5, block='c')
if pretraining: if pretraining:
Model(img_input , x1).load_weights(resnet50_Weights_path) Model(img_input , x1).load_weights(RESNET50_WEIGHTS_PATH)
x1 = AveragePooling2D((7, 7), name='avg_pool1')(x1) x1 = AveragePooling2D((7, 7), name='avg_pool1')(x1)
flattened = Flatten()(x1) flattened = Flatten()(x1)

View file

@ -2,6 +2,7 @@ import os
import sys import sys
import json import json
import requests
import click import click
from eynollah.training.metrics import ( from eynollah.training.metrics import (
@ -15,7 +16,9 @@ from eynollah.training.models import (
resnet50_classifier, resnet50_classifier,
resnet50_unet, resnet50_unet,
vit_resnet50_unet, vit_resnet50_unet,
vit_resnet50_unet_transformer_before_cnn vit_resnet50_unet_transformer_before_cnn,
RESNET50_WEIGHTS_PATH,
RESNET50_WEIGHTS_URL
) )
from eynollah.training.utils import ( from eynollah.training.utils import (
data_gen, data_gen,
@ -80,6 +83,12 @@ def get_dirs_or_files(input_data):
assert os.path.isdir(labels_input), "{} is not a directory".format(labels_input) assert os.path.isdir(labels_input), "{} is not a directory".format(labels_input)
return image_input, labels_input return image_input, labels_input
def download_file(url, path):
with open(path, 'wb') as f:
with requests.get(url, stream=True) as r:
r.raise_for_status()
for data in r.iter_content(chunk_size=4096):
f.write(data)
ex = Experiment(save_git_info=False) ex = Experiment(save_git_info=False)
@ -163,6 +172,10 @@ def run(_config, n_classes, n_epochs, input_height,
transformer_patchsize_x, transformer_patchsize_y, transformer_patchsize_x, transformer_patchsize_y,
transformer_num_patches_xy, backbone_type, save_interval, flip_index, dir_eval, dir_output, transformer_num_patches_xy, backbone_type, save_interval, flip_index, dir_eval, dir_output,
pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds, dir_rgb_foregrounds): pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds, dir_rgb_foregrounds):
if pretraining and not os.path.isfile(RESNET50_WEIGHTS_PATH):
print("downloading RESNET50 pretrained weights to", RESNET50_WEIGHTS_PATH)
download_file(RESNET50_WEIGHTS_URL, RESNET50_WEIGHTS_PATH)
if dir_rgb_backgrounds: if dir_rgb_backgrounds:
list_all_possible_background_images = os.listdir(dir_rgb_backgrounds) list_all_possible_background_images = os.listdir(dir_rgb_backgrounds)