diff --git a/src/eynollah/training/models.py b/src/eynollah/training/models.py index 3b38fe8..011c614 100644 --- a/src/eynollah/training/models.py +++ b/src/eynollah/training/models.py @@ -12,7 +12,10 @@ from tensorflow.keras.regularizers import l2 ###projection_dim = 64 ##transformer_layers = 2#8 ##num_heads = 1#4 -resnet50_Weights_path = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5' +RESNET50_WEIGHTS_PATH = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5' +RESNET50_WEIGHTS_URL = ('https://github.com/fchollet/deep-learning-models/releases/download/v0.2/' + 'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5') + IMAGE_ORDERING = 'channels_last' MERGE_AXIS = -1 @@ -242,7 +245,7 @@ def resnet50_unet_light(n_classes, input_height=224, input_width=224, taks="segm f5 = x if pretraining: - model = Model(img_input, x).load_weights(resnet50_Weights_path) + model = Model(img_input, x).load_weights(RESNET50_WEIGHTS_PATH) v512_2048 = Conv2D(512, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f5) v512_2048 = (BatchNormalization(axis=bn_axis))(v512_2048) @@ -343,7 +346,7 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentati f5 = x if pretraining: - Model(img_input, x).load_weights(resnet50_Weights_path) + Model(img_input, x).load_weights(RESNET50_WEIGHTS_PATH) v1024_2048 = Conv2D(1024, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))( f5) @@ -442,7 +445,7 @@ def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_he f5 = x if pretraining: - model = Model(inputs, x).load_weights(resnet50_Weights_path) + model = Model(inputs, x).load_weights(RESNET50_WEIGHTS_PATH) #num_patches = x.shape[1]*x.shape[2] @@ -590,7 +593,7 @@ def vit_resnet50_unet_transformer_before_cnn(n_classes, patch_size_x, patch_size f5 = x if pretraining: - model = Model(encoded_patches, x).load_weights(resnet50_Weights_path) + 
def download_file(url, path):
    """Download ``url`` to ``path`` without ever leaving a partial file at ``path``.

    The response body is streamed in chunks into a sibling ``<path>.part``
    file, which is moved onto ``path`` only after the whole transfer has
    succeeded. Callers that use ``os.path.isfile(path)`` to decide whether a
    download is needed therefore never mistake an aborted download for a
    complete one.

    Args:
        url: HTTP(S) address of the file to fetch.
        path: Destination file path (``str`` or ``os.PathLike``). Missing
            parent directories are created.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeouts.
        OSError: If the destination cannot be created or written.
    """
    path = os.fspath(path)
    tmp_path = path + '.part'

    # A fresh checkout usually has no target directory yet.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    try:
        # The timeout bounds connect time and each individual socket read,
        # not the total transfer, so large files still download fine while a
        # stalled server no longer blocks forever.
        with requests.get(url, stream=True, timeout=60) as r:
            # Check the status before touching the filesystem.
            r.raise_for_status()
            with open(tmp_path, 'wb') as f:
                for data in r.iter_content(chunk_size=4096):
                    f.write(data)
        # Publish the finished file in a single step.
        os.replace(tmp_path, path)
    except BaseException:
        # Also covers KeyboardInterrupt: never leave a truncated file behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
input_height, transformer_patchsize_x, transformer_patchsize_y, transformer_num_patches_xy, backbone_type, save_interval, flip_index, dir_eval, dir_output, pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds, dir_rgb_foregrounds): + + if pretraining and not os.path.isfile(RESNET50_WEIGHTS_PATH): + print("downloading RESNET50 pretrained weights to", RESNET50_WEIGHTS_PATH) + download_file(RESNET50_WEIGHTS_URL, RESNET50_WEIGHTS_PATH) if dir_rgb_backgrounds: list_all_possible_background_images = os.listdir(dir_rgb_backgrounds)