mirror of
https://github.com/qurator-spk/sbb_pixelwise_segmentation.git
synced 2025-09-16 04:09:57 +02:00
Running inference on files in a directory
This commit is contained in:
parent
8b75d46d3d
commit
e944d334e9
1 changed files with 62 additions and 24 deletions
88
inference.py
88
inference.py
|
@ -28,8 +28,9 @@ Tool to load model and predict for given image.
|
|||
"""
|
||||
|
||||
class sbb_predict:
|
||||
def __init__(self,image, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area):
|
||||
def __init__(self,image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area):
|
||||
self.image=image
|
||||
self.dir_in=dir_in
|
||||
self.patches=patches
|
||||
self.save=save
|
||||
self.save_layout=save_layout
|
||||
|
@ -223,11 +224,10 @@ class sbb_predict:
|
|||
|
||||
return added_image, layout_only
|
||||
|
||||
def predict(self):
|
||||
self.start_new_session_and_model()
|
||||
def predict(self, image_dir):
|
||||
if self.task == 'classification':
|
||||
classes_names = self.config_params_model['classification_classes_name']
|
||||
img_1ch = img=cv2.imread(self.image, 0)
|
||||
img_1ch = img=cv2.imread(image_dir, 0)
|
||||
|
||||
img_1ch = img_1ch / 255.0
|
||||
img_1ch = cv2.resize(img_1ch, (self.config_params_model['input_height'], self.config_params_model['input_width']), interpolation=cv2.INTER_NEAREST)
|
||||
|
@ -438,7 +438,7 @@ class sbb_predict:
|
|||
if self.patches:
|
||||
#def textline_contours(img,input_width,input_height,n_classes,model):
|
||||
|
||||
img=cv2.imread(self.image)
|
||||
img=cv2.imread(image_dir)
|
||||
self.img_org = np.copy(img)
|
||||
|
||||
if img.shape[0] < self.img_height:
|
||||
|
@ -529,7 +529,7 @@ class sbb_predict:
|
|||
|
||||
else:
|
||||
|
||||
img=cv2.imread(self.image)
|
||||
img=cv2.imread(image_dir)
|
||||
self.img_org = np.copy(img)
|
||||
|
||||
width=self.img_width
|
||||
|
@ -557,22 +557,50 @@ class sbb_predict:
|
|||
|
||||
|
||||
def run(self):
|
||||
res=self.predict()
|
||||
if (self.task == 'classification' or self.task == 'reading_order'):
|
||||
pass
|
||||
elif self.task == 'enhancement':
|
||||
if self.save:
|
||||
cv2.imwrite(self.save,res)
|
||||
else:
|
||||
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
|
||||
if self.save:
|
||||
cv2.imwrite(self.save,img_seg_overlayed)
|
||||
if self.save_layout:
|
||||
cv2.imwrite(self.save_layout, only_layout)
|
||||
self.start_new_session_and_model()
|
||||
if self.image:
|
||||
res=self.predict(image_dir = self.image)
|
||||
|
||||
if (self.task == 'classification' or self.task == 'reading_order'):
|
||||
pass
|
||||
elif self.task == 'enhancement':
|
||||
if self.save:
|
||||
cv2.imwrite(self.save,res)
|
||||
else:
|
||||
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
|
||||
if self.save:
|
||||
cv2.imwrite(self.save,img_seg_overlayed)
|
||||
if self.save_layout:
|
||||
cv2.imwrite(self.save_layout, only_layout)
|
||||
|
||||
if self.ground_truth:
|
||||
gt_img=cv2.imread(self.ground_truth)
|
||||
self.IoU(gt_img[:,:,0],res[:,:,0])
|
||||
|
||||
else:
|
||||
ls_images = os.listdir(self.dir_in)
|
||||
for ind_image in ls_images:
|
||||
f_name = ind_image.split('.')[0]
|
||||
image_dir = os.path.join(self.dir_in, ind_image)
|
||||
res=self.predict(image_dir)
|
||||
|
||||
if (self.task == 'classification' or self.task == 'reading_order'):
|
||||
pass
|
||||
elif self.task == 'enhancement':
|
||||
self.save = os.path.join(self.out, f_name+'.png')
|
||||
cv2.imwrite(self.save,res)
|
||||
else:
|
||||
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
|
||||
self.save = os.path.join(self.out, f_name+'_overlayed.png')
|
||||
cv2.imwrite(self.save,img_seg_overlayed)
|
||||
self.save_layout = os.path.join(self.out, f_name+'_layout.png')
|
||||
cv2.imwrite(self.save_layout, only_layout)
|
||||
|
||||
if self.ground_truth:
|
||||
gt_img=cv2.imread(self.ground_truth)
|
||||
self.IoU(gt_img[:,:,0],res[:,:,0])
|
||||
|
||||
|
||||
if self.ground_truth:
|
||||
gt_img=cv2.imread(self.ground_truth)
|
||||
self.IoU(gt_img[:,:,0],res[:,:,0])
|
||||
|
||||
@click.command()
|
||||
@click.option(
|
||||
|
@ -581,6 +609,12 @@ class sbb_predict:
|
|||
help="image filename",
|
||||
type=click.Path(exists=True, dir_okay=False),
|
||||
)
|
||||
@click.option(
|
||||
"--dir_in",
|
||||
"-di",
|
||||
help="directory of images",
|
||||
type=click.Path(exists=True, file_okay=False),
|
||||
)
|
||||
@click.option(
|
||||
"--out",
|
||||
"-o",
|
||||
|
@ -626,15 +660,19 @@ class sbb_predict:
|
|||
"-min",
|
||||
help="min area size of regions considered for reading order detection. The default value is zero and means that all text regions are considered for reading order.",
|
||||
)
|
||||
def main(image, model, patches, save, save_layout, ground_truth, xml_file, out, min_area):
|
||||
def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, out, min_area):
|
||||
assert image or dir_in, "Either a single image -i or a dir_in -di is required"
|
||||
with open(os.path.join(model,'config.json')) as f:
|
||||
config_params_model = json.load(f)
|
||||
task = config_params_model['task']
|
||||
if (task != 'classification' and task != 'reading_order'):
|
||||
if not save:
|
||||
print("Error: You used one of segmentation or binarization task but not set -s, you need a filename to save visualized output with -s")
|
||||
if image and not save:
|
||||
print("Error: You used one of segmentation or binarization task with image input but not set -s, you need a filename to save visualized output with -s")
|
||||
sys.exit(1)
|
||||
x=sbb_predict(image, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area)
|
||||
if dir_in and not out:
|
||||
print("Error: You used one of segmentation or binarization task with dir_in but not set -out")
|
||||
sys.exit(1)
|
||||
x=sbb_predict(image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area)
|
||||
x.run()
|
||||
|
||||
if __name__=="__main__":
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue