mirror of
https://github.com/qurator-spk/sbb_binarization.git
synced 2025-06-09 12:19:56 +02:00
allow passing image directly, return image on binarize
This commit is contained in:
parent
389ef088d0
commit
ca03844c2b
2 changed files with 22 additions and 11 deletions
|
@ -16,7 +16,12 @@ def main():
|
|||
|
||||
options = parser.parse_args()
|
||||
|
||||
binarizer = SbbBinarizer(options.image, options.model, options.patches, options.save)
|
||||
binarizer = SbbBinarizer(
|
||||
image_path=options.image,
|
||||
model=options.model,
|
||||
patches=options.patches,
|
||||
save=options.save
|
||||
)
|
||||
binarizer.run()
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -15,25 +15,30 @@ import tensorflow as tf
|
|||
with catch_warnings():
|
||||
simplefilter("ignore")
|
||||
|
||||
def resize_image(img_in, input_height, input_width):
|
||||
return cv2.resize(img_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST)
|
||||
|
||||
class SbbBinarizer:
|
||||
|
||||
# TODO use True/False for patches
|
||||
def __init__(self, image, model, patches='false', save=None):
|
||||
self.image = image
|
||||
def __init__(self, model, image=None, image_path=None, patches='false', save=None):
|
||||
if not(image or image_path) or (image and image_path):
|
||||
raise ValueError("Must pass either a PIL image or an image_path")
|
||||
if image:
|
||||
self.image = image
|
||||
else:
|
||||
self.image = cv2.imread(self.image)
|
||||
self.patches = patches
|
||||
self.save = save
|
||||
self.model_dir = model
|
||||
|
||||
def resize_image(self,img_in,input_height,input_width):
|
||||
return cv2.resize(img_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST)
|
||||
|
||||
def start_new_session_and_model(self):
|
||||
config = tf.ConfigProto()
|
||||
config.gpu_options.allow_growth = True
|
||||
|
||||
self.session = tf.Session(config=config) # tf.InteractiveSession()
|
||||
|
||||
def load_model(self,model_name):
|
||||
def load_model(self, model_name):
|
||||
|
||||
self.model = load_model(join(self.model_dir, model_name), compile=False)
|
||||
|
||||
|
@ -48,11 +53,11 @@ class SbbBinarizer:
|
|||
|
||||
def predict(self,model_name):
|
||||
self.load_model(model_name)
|
||||
img = cv2.imread(self.image)
|
||||
img = self.image
|
||||
img_width_model = self.img_width
|
||||
img_height_model = self.img_height
|
||||
|
||||
if self.patches=='true' or self.patches=='True':
|
||||
if self.patches in ('true', 'True'):
|
||||
|
||||
margin = int(0.1 * img_width_model)
|
||||
|
||||
|
@ -181,14 +186,14 @@ class SbbBinarizer:
|
|||
img_h_page = img.shape[0]
|
||||
img_w_page = img.shape[1]
|
||||
img = img / float(255.0)
|
||||
img = self.resize_image(img, img_height_model, img_width_model)
|
||||
img = resize_image(img, img_height_model, img_width_model)
|
||||
|
||||
label_p_pred = self.model.predict(
|
||||
img.reshape(1, img.shape[0], img.shape[1], img.shape[2]))
|
||||
|
||||
seg = np.argmax(label_p_pred, axis=3)[0]
|
||||
seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
|
||||
prediction_true = self.resize_image(seg_color, img_h_page, img_w_page)
|
||||
prediction_true = resize_image(seg_color, img_h_page, img_w_page)
|
||||
prediction_true = prediction_true.astype(np.uint8)
|
||||
return prediction_true[:,:,0]
|
||||
|
||||
|
@ -217,3 +222,4 @@ class SbbBinarizer:
|
|||
img_last = (img_last[:, :] == 0)*255
|
||||
if self.save:
|
||||
cv2.imwrite(self.save, img_last)
|
||||
return img_last
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue