diff --git a/qurator/eynollah/eynollah.py b/qurator/eynollah/eynollah.py index ff3ceac..d6f70c3 100644 --- a/qurator/eynollah/eynollah.py +++ b/qurator/eynollah/eynollah.py @@ -515,6 +515,9 @@ class Eynollah: gpu_options = tf.compat.v1.GPUOptions(allow_growth=True) #gpu_options = tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=7.7, allow_growth=True) session = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(gpu_options=gpu_options)) + if model_dir.endswith('.h5') and Path(model_dir[:-3]).exists(): + # prefer SavedModel over HDF5 format if it exists + model_dir = model_dir[:-3] model = load_model(model_dir, compile=False) return model, session