do not reload enhancement model in dir_in mode, simplify

pull/142/head
Robert Sachunsky 3 weeks ago
parent 3b9a29bc5c
commit 329fac23f6

@ -363,10 +363,11 @@ class Eynollah:
def predict_enhancement(self, img): def predict_enhancement(self, img):
self.logger.debug("enter predict_enhancement") self.logger.debug("enter predict_enhancement")
model_enhancement, session_enhancement = self.start_new_session_and_model(self.model_dir_of_enhancement) if not self.dir_in:
self.model_enhancement, _ = self.start_new_session_and_model(self.model_dir_of_enhancement)
img_height_model = model_enhancement.layers[len(model_enhancement.layers) - 1].output_shape[1] img_height_model = self.model_enhancement.layers[len(self.model_enhancement.layers) - 1].output_shape[1]
img_width_model = model_enhancement.layers[len(model_enhancement.layers) - 1].output_shape[2] img_width_model = self.model_enhancement.layers[len(self.model_enhancement.layers) - 1].output_shape[2]
if img.shape[0] < img_height_model: if img.shape[0] < img_height_model:
img = cv2.resize(img, (img.shape[1], img_width_model), interpolation=cv2.INTER_NEAREST) img = cv2.resize(img, (img.shape[1], img_width_model), interpolation=cv2.INTER_NEAREST)
@ -409,9 +410,8 @@ class Eynollah:
index_y_u = img_h index_y_u = img_h
index_y_d = img_h - img_height_model index_y_d = img_h - img_height_model
img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :] img_patch = img[np.newaxis, index_y_d:index_y_u, index_x_d:index_x_u, :]
label_p_pred = model_enhancement.predict(img_patch.reshape(1, img_patch.shape[0], img_patch.shape[1], img_patch.shape[2]), label_p_pred = self.model_enhancement.predict(img_patch, verbose=0)
verbose=0)
seg = label_p_pred[0, :, :, :] seg = label_p_pred[0, :, :, :]
seg = seg * 255 seg = seg * 255

Loading…
Cancel
Save