Running inference on files in a directory

This commit is contained in:
vahidrezanezhad 2025-09-13 22:40:11 +02:00
parent 8b75d46d3d
commit e944d334e9

View file

@ -28,8 +28,9 @@ Tool to load model and predict for given image.
"""
class sbb_predict:
def __init__(self,image, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area):
def __init__(self,image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area):
self.image=image
self.dir_in=dir_in
self.patches=patches
self.save=save
self.save_layout=save_layout
@ -223,11 +224,10 @@ class sbb_predict:
return added_image, layout_only
def predict(self):
self.start_new_session_and_model()
def predict(self, image_dir):
if self.task == 'classification':
classes_names = self.config_params_model['classification_classes_name']
img_1ch = img=cv2.imread(self.image, 0)
img_1ch = img=cv2.imread(image_dir, 0)
img_1ch = img_1ch / 255.0
img_1ch = cv2.resize(img_1ch, (self.config_params_model['input_height'], self.config_params_model['input_width']), interpolation=cv2.INTER_NEAREST)
@ -438,7 +438,7 @@ class sbb_predict:
if self.patches:
#def textline_contours(img,input_width,input_height,n_classes,model):
img=cv2.imread(self.image)
img=cv2.imread(image_dir)
self.img_org = np.copy(img)
if img.shape[0] < self.img_height:
@ -529,7 +529,7 @@ class sbb_predict:
else:
img=cv2.imread(self.image)
img=cv2.imread(image_dir)
self.img_org = np.copy(img)
width=self.img_width
@ -557,22 +557,50 @@ class sbb_predict:
def run(self):
res=self.predict()
if (self.task == 'classification' or self.task == 'reading_order'):
pass
elif self.task == 'enhancement':
if self.save:
cv2.imwrite(self.save,res)
self.start_new_session_and_model()
if self.image:
res=self.predict(image_dir = self.image)
if (self.task == 'classification' or self.task == 'reading_order'):
pass
elif self.task == 'enhancement':
if self.save:
cv2.imwrite(self.save,res)
else:
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
if self.save:
cv2.imwrite(self.save,img_seg_overlayed)
if self.save_layout:
cv2.imwrite(self.save_layout, only_layout)
if self.ground_truth:
gt_img=cv2.imread(self.ground_truth)
self.IoU(gt_img[:,:,0],res[:,:,0])
else:
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
if self.save:
cv2.imwrite(self.save,img_seg_overlayed)
if self.save_layout:
cv2.imwrite(self.save_layout, only_layout)
ls_images = os.listdir(self.dir_in)
for ind_image in ls_images:
f_name = ind_image.split('.')[0]
image_dir = os.path.join(self.dir_in, ind_image)
res=self.predict(image_dir)
if self.ground_truth:
gt_img=cv2.imread(self.ground_truth)
self.IoU(gt_img[:,:,0],res[:,:,0])
if (self.task == 'classification' or self.task == 'reading_order'):
pass
elif self.task == 'enhancement':
self.save = os.path.join(self.out, f_name+'.png')
cv2.imwrite(self.save,res)
else:
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
self.save = os.path.join(self.out, f_name+'_overlayed.png')
cv2.imwrite(self.save,img_seg_overlayed)
self.save_layout = os.path.join(self.out, f_name+'_layout.png')
cv2.imwrite(self.save_layout, only_layout)
if self.ground_truth:
gt_img=cv2.imread(self.ground_truth)
self.IoU(gt_img[:,:,0],res[:,:,0])
@click.command()
@click.option(
@ -581,6 +609,12 @@ class sbb_predict:
help="image filename",
type=click.Path(exists=True, dir_okay=False),
)
@click.option(
"--dir_in",
"-di",
help="directory of images",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--out",
"-o",
@ -626,15 +660,19 @@ class sbb_predict:
"-min",
help="min area size of regions considered for reading order detection. The default value is zero and means that all text regions are considered for reading order.",
)
def main(image, model, patches, save, save_layout, ground_truth, xml_file, out, min_area):
def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, out, min_area):
assert image or dir_in, "Either a single image -i or a dir_in -di is required"
with open(os.path.join(model,'config.json')) as f:
config_params_model = json.load(f)
task = config_params_model['task']
if (task != 'classification' and task != 'reading_order'):
if not save:
print("Error: You used one of segmentation or binarization task but not set -s, you need a filename to save visualized output with -s")
if image and not save:
print("Error: You used one of segmentation or binarization task with image input but not set -s, you need a filename to save visualized output with -s")
sys.exit(1)
x=sbb_predict(image, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area)
if dir_in and not out:
print("Error: You used one of segmentation or binarization task with dir_in but not set -out")
sys.exit(1)
x=sbb_predict(image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area)
x.run()
if __name__=="__main__":