mirror of
https://github.com/qurator-spk/sbb_pixelwise_segmentation.git
synced 2025-09-16 12:19:58 +02:00
Running inference on files in a directory
This commit is contained in:
parent
8b75d46d3d
commit
e944d334e9
1 changed files with 62 additions and 24 deletions
88
inference.py
88
inference.py
|
@ -28,8 +28,9 @@ Tool to load model and predict for given image.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class sbb_predict:
|
class sbb_predict:
|
||||||
def __init__(self,image, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area):
|
def __init__(self,image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area):
|
||||||
self.image=image
|
self.image=image
|
||||||
|
self.dir_in=dir_in
|
||||||
self.patches=patches
|
self.patches=patches
|
||||||
self.save=save
|
self.save=save
|
||||||
self.save_layout=save_layout
|
self.save_layout=save_layout
|
||||||
|
@ -223,11 +224,10 @@ class sbb_predict:
|
||||||
|
|
||||||
return added_image, layout_only
|
return added_image, layout_only
|
||||||
|
|
||||||
def predict(self):
|
def predict(self, image_dir):
|
||||||
self.start_new_session_and_model()
|
|
||||||
if self.task == 'classification':
|
if self.task == 'classification':
|
||||||
classes_names = self.config_params_model['classification_classes_name']
|
classes_names = self.config_params_model['classification_classes_name']
|
||||||
img_1ch = img=cv2.imread(self.image, 0)
|
img_1ch = img=cv2.imread(image_dir, 0)
|
||||||
|
|
||||||
img_1ch = img_1ch / 255.0
|
img_1ch = img_1ch / 255.0
|
||||||
img_1ch = cv2.resize(img_1ch, (self.config_params_model['input_height'], self.config_params_model['input_width']), interpolation=cv2.INTER_NEAREST)
|
img_1ch = cv2.resize(img_1ch, (self.config_params_model['input_height'], self.config_params_model['input_width']), interpolation=cv2.INTER_NEAREST)
|
||||||
|
@ -438,7 +438,7 @@ class sbb_predict:
|
||||||
if self.patches:
|
if self.patches:
|
||||||
#def textline_contours(img,input_width,input_height,n_classes,model):
|
#def textline_contours(img,input_width,input_height,n_classes,model):
|
||||||
|
|
||||||
img=cv2.imread(self.image)
|
img=cv2.imread(image_dir)
|
||||||
self.img_org = np.copy(img)
|
self.img_org = np.copy(img)
|
||||||
|
|
||||||
if img.shape[0] < self.img_height:
|
if img.shape[0] < self.img_height:
|
||||||
|
@ -529,7 +529,7 @@ class sbb_predict:
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
img=cv2.imread(self.image)
|
img=cv2.imread(image_dir)
|
||||||
self.img_org = np.copy(img)
|
self.img_org = np.copy(img)
|
||||||
|
|
||||||
width=self.img_width
|
width=self.img_width
|
||||||
|
@ -557,22 +557,50 @@ class sbb_predict:
|
||||||
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
res=self.predict()
|
self.start_new_session_and_model()
|
||||||
if (self.task == 'classification' or self.task == 'reading_order'):
|
if self.image:
|
||||||
pass
|
res=self.predict(image_dir = self.image)
|
||||||
elif self.task == 'enhancement':
|
|
||||||
if self.save:
|
if (self.task == 'classification' or self.task == 'reading_order'):
|
||||||
cv2.imwrite(self.save,res)
|
pass
|
||||||
else:
|
elif self.task == 'enhancement':
|
||||||
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
|
if self.save:
|
||||||
if self.save:
|
cv2.imwrite(self.save,res)
|
||||||
cv2.imwrite(self.save,img_seg_overlayed)
|
else:
|
||||||
if self.save_layout:
|
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
|
||||||
cv2.imwrite(self.save_layout, only_layout)
|
if self.save:
|
||||||
|
cv2.imwrite(self.save,img_seg_overlayed)
|
||||||
|
if self.save_layout:
|
||||||
|
cv2.imwrite(self.save_layout, only_layout)
|
||||||
|
|
||||||
|
if self.ground_truth:
|
||||||
|
gt_img=cv2.imread(self.ground_truth)
|
||||||
|
self.IoU(gt_img[:,:,0],res[:,:,0])
|
||||||
|
|
||||||
|
else:
|
||||||
|
ls_images = os.listdir(self.dir_in)
|
||||||
|
for ind_image in ls_images:
|
||||||
|
f_name = ind_image.split('.')[0]
|
||||||
|
image_dir = os.path.join(self.dir_in, ind_image)
|
||||||
|
res=self.predict(image_dir)
|
||||||
|
|
||||||
|
if (self.task == 'classification' or self.task == 'reading_order'):
|
||||||
|
pass
|
||||||
|
elif self.task == 'enhancement':
|
||||||
|
self.save = os.path.join(self.out, f_name+'.png')
|
||||||
|
cv2.imwrite(self.save,res)
|
||||||
|
else:
|
||||||
|
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
|
||||||
|
self.save = os.path.join(self.out, f_name+'_overlayed.png')
|
||||||
|
cv2.imwrite(self.save,img_seg_overlayed)
|
||||||
|
self.save_layout = os.path.join(self.out, f_name+'_layout.png')
|
||||||
|
cv2.imwrite(self.save_layout, only_layout)
|
||||||
|
|
||||||
|
if self.ground_truth:
|
||||||
|
gt_img=cv2.imread(self.ground_truth)
|
||||||
|
self.IoU(gt_img[:,:,0],res[:,:,0])
|
||||||
|
|
||||||
|
|
||||||
if self.ground_truth:
|
|
||||||
gt_img=cv2.imread(self.ground_truth)
|
|
||||||
self.IoU(gt_img[:,:,0],res[:,:,0])
|
|
||||||
|
|
||||||
@click.command()
|
@click.command()
|
||||||
@click.option(
|
@click.option(
|
||||||
|
@ -581,6 +609,12 @@ class sbb_predict:
|
||||||
help="image filename",
|
help="image filename",
|
||||||
type=click.Path(exists=True, dir_okay=False),
|
type=click.Path(exists=True, dir_okay=False),
|
||||||
)
|
)
|
||||||
|
@click.option(
|
||||||
|
"--dir_in",
|
||||||
|
"-di",
|
||||||
|
help="directory of images",
|
||||||
|
type=click.Path(exists=True, file_okay=False),
|
||||||
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--out",
|
"--out",
|
||||||
"-o",
|
"-o",
|
||||||
|
@ -626,15 +660,19 @@ class sbb_predict:
|
||||||
"-min",
|
"-min",
|
||||||
help="min area size of regions considered for reading order detection. The default value is zero and means that all text regions are considered for reading order.",
|
help="min area size of regions considered for reading order detection. The default value is zero and means that all text regions are considered for reading order.",
|
||||||
)
|
)
|
||||||
def main(image, model, patches, save, save_layout, ground_truth, xml_file, out, min_area):
|
def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, out, min_area):
|
||||||
|
assert image or dir_in, "Either a single image -i or a dir_in -di is required"
|
||||||
with open(os.path.join(model,'config.json')) as f:
|
with open(os.path.join(model,'config.json')) as f:
|
||||||
config_params_model = json.load(f)
|
config_params_model = json.load(f)
|
||||||
task = config_params_model['task']
|
task = config_params_model['task']
|
||||||
if (task != 'classification' and task != 'reading_order'):
|
if (task != 'classification' and task != 'reading_order'):
|
||||||
if not save:
|
if image and not save:
|
||||||
print("Error: You used one of segmentation or binarization task but not set -s, you need a filename to save visualized output with -s")
|
print("Error: You used one of segmentation or binarization task with image input but not set -s, you need a filename to save visualized output with -s")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
x=sbb_predict(image, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area)
|
if dir_in and not out:
|
||||||
|
print("Error: You used one of segmentation or binarization task with dir_in but not set -out")
|
||||||
|
sys.exit(1)
|
||||||
|
x=sbb_predict(image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area)
|
||||||
x.run()
|
x.run()
|
||||||
|
|
||||||
if __name__=="__main__":
|
if __name__=="__main__":
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue