diff --git a/qurator/eynollah/eynollah.py b/qurator/eynollah/eynollah.py index d587cc9..8ce50bd 100644 --- a/qurator/eynollah/eynollah.py +++ b/qurator/eynollah/eynollah.py @@ -429,7 +429,7 @@ class Eynollah: self.writer.height_org = self.height_org self.writer.width_org = self.width_org - def start_new_session_and_model(self, model_dir): + def start_new_session_and_model_old(self, model_dir): self.logger.debug("enter start_new_session_and_model (model_dir=%s)", model_dir) config = tf.ConfigProto() config.gpu_options.allow_growth = True @@ -438,7 +438,13 @@ class Eynollah: model = load_model(model_dir, compile=False) return model, session + def start_new_session_and_model(self, model_dir): + self.logger.debug("enter start_new_session_and_model (model_dir=%s)", model_dir) + gpu_options = tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=7.7, allow_growth=True) + session = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(gpu_options=gpu_options)) + model = load_model(model_dir, compile=False) + return model, session def do_prediction(self, patches, img, model, marginal_of_patch_percent=0.1): self.logger.debug("enter do_prediction")