Running inference on files in a directory

This commit is contained in:
vahidrezanezhad 2025-09-13 22:40:11 +02:00
parent 8b75d46d3d
commit e944d334e9

View file

@ -28,8 +28,9 @@ Tool to load model and predict for given image.
""" """
class sbb_predict: class sbb_predict:
def __init__(self,image, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area): def __init__(self,image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area):
self.image=image self.image=image
self.dir_in=dir_in
self.patches=patches self.patches=patches
self.save=save self.save=save
self.save_layout=save_layout self.save_layout=save_layout
@ -223,11 +224,10 @@ class sbb_predict:
return added_image, layout_only return added_image, layout_only
def predict(self): def predict(self, image_dir):
self.start_new_session_and_model()
if self.task == 'classification': if self.task == 'classification':
classes_names = self.config_params_model['classification_classes_name'] classes_names = self.config_params_model['classification_classes_name']
img_1ch = img=cv2.imread(self.image, 0) img_1ch = img=cv2.imread(image_dir, 0)
img_1ch = img_1ch / 255.0 img_1ch = img_1ch / 255.0
img_1ch = cv2.resize(img_1ch, (self.config_params_model['input_height'], self.config_params_model['input_width']), interpolation=cv2.INTER_NEAREST) img_1ch = cv2.resize(img_1ch, (self.config_params_model['input_height'], self.config_params_model['input_width']), interpolation=cv2.INTER_NEAREST)
@ -438,7 +438,7 @@ class sbb_predict:
if self.patches: if self.patches:
#def textline_contours(img,input_width,input_height,n_classes,model): #def textline_contours(img,input_width,input_height,n_classes,model):
img=cv2.imread(self.image) img=cv2.imread(image_dir)
self.img_org = np.copy(img) self.img_org = np.copy(img)
if img.shape[0] < self.img_height: if img.shape[0] < self.img_height:
@ -529,7 +529,7 @@ class sbb_predict:
else: else:
img=cv2.imread(self.image) img=cv2.imread(image_dir)
self.img_org = np.copy(img) self.img_org = np.copy(img)
width=self.img_width width=self.img_width
@ -557,22 +557,50 @@ class sbb_predict:
def run(self): def run(self):
res=self.predict() self.start_new_session_and_model()
if (self.task == 'classification' or self.task == 'reading_order'): if self.image:
pass res=self.predict(image_dir = self.image)
elif self.task == 'enhancement':
if self.save: if (self.task == 'classification' or self.task == 'reading_order'):
cv2.imwrite(self.save,res) pass
elif self.task == 'enhancement':
if self.save:
cv2.imwrite(self.save,res)
else:
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
if self.save:
cv2.imwrite(self.save,img_seg_overlayed)
if self.save_layout:
cv2.imwrite(self.save_layout, only_layout)
if self.ground_truth:
gt_img=cv2.imread(self.ground_truth)
self.IoU(gt_img[:,:,0],res[:,:,0])
else: else:
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task) ls_images = os.listdir(self.dir_in)
if self.save: for ind_image in ls_images:
cv2.imwrite(self.save,img_seg_overlayed) f_name = ind_image.split('.')[0]
if self.save_layout: image_dir = os.path.join(self.dir_in, ind_image)
cv2.imwrite(self.save_layout, only_layout) res=self.predict(image_dir)
if self.ground_truth: if (self.task == 'classification' or self.task == 'reading_order'):
gt_img=cv2.imread(self.ground_truth) pass
self.IoU(gt_img[:,:,0],res[:,:,0]) elif self.task == 'enhancement':
self.save = os.path.join(self.out, f_name+'.png')
cv2.imwrite(self.save,res)
else:
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
self.save = os.path.join(self.out, f_name+'_overlayed.png')
cv2.imwrite(self.save,img_seg_overlayed)
self.save_layout = os.path.join(self.out, f_name+'_layout.png')
cv2.imwrite(self.save_layout, only_layout)
if self.ground_truth:
gt_img=cv2.imread(self.ground_truth)
self.IoU(gt_img[:,:,0],res[:,:,0])
@click.command() @click.command()
@click.option( @click.option(
@ -581,6 +609,12 @@ class sbb_predict:
help="image filename", help="image filename",
type=click.Path(exists=True, dir_okay=False), type=click.Path(exists=True, dir_okay=False),
) )
@click.option(
"--dir_in",
"-di",
help="directory of images",
type=click.Path(exists=True, file_okay=False),
)
@click.option( @click.option(
"--out", "--out",
"-o", "-o",
@ -626,15 +660,19 @@ class sbb_predict:
"-min", "-min",
help="min area size of regions considered for reading order detection. The default value is zero and means that all text regions are considered for reading order.", help="min area size of regions considered for reading order detection. The default value is zero and means that all text regions are considered for reading order.",
) )
def main(image, model, patches, save, save_layout, ground_truth, xml_file, out, min_area): def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, out, min_area):
assert image or dir_in, "Either a single image -i or a dir_in -di is required"
with open(os.path.join(model,'config.json')) as f: with open(os.path.join(model,'config.json')) as f:
config_params_model = json.load(f) config_params_model = json.load(f)
task = config_params_model['task'] task = config_params_model['task']
if (task != 'classification' and task != 'reading_order'): if (task != 'classification' and task != 'reading_order'):
if not save: if image and not save:
print("Error: You used one of segmentation or binarization task but not set -s, you need a filename to save visualized output with -s") print("Error: You used one of segmentation or binarization task with image input but not set -s, you need a filename to save visualized output with -s")
sys.exit(1) sys.exit(1)
x=sbb_predict(image, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area) if dir_in and not out:
print("Error: You used one of segmentation or binarization task with dir_in but not set -out")
sys.exit(1)
x=sbb_predict(image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area)
x.run() x.run()
if __name__=="__main__": if __name__=="__main__":