From 522f00ab9914d2c8a7345db8fffec8f062930c23 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Thu, 4 Apr 2024 11:26:28 +0200 Subject: [PATCH] adjusting to tf2 --- ..._model_load_pretrained_weights_and_save.py | 4 ++-- metrics.py | 2 +- models.py | 8 +++---- train.py | 24 +++++++------------ 4 files changed, 15 insertions(+), 23 deletions(-) diff --git a/build_model_load_pretrained_weights_and_save.py b/build_model_load_pretrained_weights_and_save.py index 251e698..3b1a577 100644 --- a/build_model_load_pretrained_weights_and_save.py +++ b/build_model_load_pretrained_weights_and_save.py @@ -1,8 +1,8 @@ import os import sys import tensorflow as tf -import keras , warnings -from keras.optimizers import * +import warnings +from tensorflow.keras.optimizers import * from sacred import Experiment from models import * from utils import * diff --git a/metrics.py b/metrics.py index c63cc22..1768960 100644 --- a/metrics.py +++ b/metrics.py @@ -1,4 +1,4 @@ -from keras import backend as K +from tensorflow.keras import backend as K import tensorflow as tf import numpy as np diff --git a/models.py b/models.py index 7c806b4..40a21a1 100644 --- a/models.py +++ b/models.py @@ -1,7 +1,7 @@ -from keras.models import * -from keras.layers import * -from keras import layers -from keras.regularizers import l2 +from tensorflow.keras.models import * +from tensorflow.keras.layers import * +from tensorflow.keras import layers +from tensorflow.keras.regularizers import l2 resnet50_Weights_path='./pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5' IMAGE_ORDERING ='channels_last' diff --git a/train.py b/train.py index 0cc5ef3..142b79b 100644 --- a/train.py +++ b/train.py @@ -1,29 +1,21 @@ import os import sys import tensorflow as tf -from keras.backend.tensorflow_backend import set_session -import keras , warnings -from keras.optimizers import * +from tensorflow.compat.v1.keras.backend import set_session +import warnings +from tensorflow.keras.optimizers import * from sacred import Experiment from models import * from utils import * from metrics import * -from keras.models import load_model +from tensorflow.keras.models import load_model from tqdm import tqdm def configuration(): - keras.backend.clear_session() - tf.reset_default_graph() - warnings.filterwarnings('ignore') - - os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID' - config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True) - - + config = tf.compat.v1.ConfigProto() config.gpu_options.allow_growth = True - config.gpu_options.per_process_gpu_memory_fraction=0.95#0.95 - config.gpu_options.visible_device_list="0" - set_session(tf.Session(config=config)) + session = tf.compat.v1.Session(config=config) + set_session(session) def get_dirs_or_files(input_data): if os.path.isdir(input_data): @@ -219,7 +211,7 @@ def run(n_classes,n_epochs,input_height, validation_data=val_gen, validation_steps=1, epochs=1) - model.save(dir_output+'/'+'model_'+str(i)+'.h5') + model.save(dir_output+'/'+'model_'+str(i)) #os.system('rm -rf '+dir_train_flowing)