diff --git a/src/eynollah/eynollah.py b/src/eynollah/eynollah.py index c1e0f4d..145f722 100644 --- a/src/eynollah/eynollah.py +++ b/src/eynollah/eynollah.py @@ -363,10 +363,11 @@ class Eynollah: def predict_enhancement(self, img): self.logger.debug("enter predict_enhancement") - model_enhancement, session_enhancement = self.start_new_session_and_model(self.model_dir_of_enhancement) + if not self.dir_in: + self.model_enhancement, _ = self.start_new_session_and_model(self.model_dir_of_enhancement) - img_height_model = model_enhancement.layers[len(model_enhancement.layers) - 1].output_shape[1] - img_width_model = model_enhancement.layers[len(model_enhancement.layers) - 1].output_shape[2] + img_height_model = self.model_enhancement.layers[len(self.model_enhancement.layers) - 1].output_shape[1] + img_width_model = self.model_enhancement.layers[len(self.model_enhancement.layers) - 1].output_shape[2] if img.shape[0] < img_height_model: img = cv2.resize(img, (img.shape[1], img_width_model), interpolation=cv2.INTER_NEAREST) @@ -409,9 +410,8 @@ class Eynollah: index_y_u = img_h index_y_d = img_h - img_height_model - img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :] - label_p_pred = model_enhancement.predict(img_patch.reshape(1, img_patch.shape[0], img_patch.shape[1], img_patch.shape[2]), - verbose=0) + img_patch = img[np.newaxis, index_y_d:index_y_u, index_x_d:index_x_u, :] + label_p_pred = self.model_enhancement.predict(img_patch, verbose=0) seg = label_p_pred[0, :, :, :] seg = seg * 255