From e944d334e93f5da1aad5a2ca70975a128a2a00a1 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Sat, 13 Sep 2025 22:40:11 +0200 Subject: [PATCH] Running inference on files in a directory --- inference.py | 86 +++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 62 insertions(+), 24 deletions(-) diff --git a/inference.py b/inference.py index aecd0e6..094c528 100644 --- a/inference.py +++ b/inference.py @@ -28,8 +28,9 @@ Tool to load model and predict for given image. """ class sbb_predict: - def __init__(self,image, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area): + def __init__(self,image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area): self.image=image + self.dir_in=dir_in self.patches=patches self.save=save self.save_layout=save_layout @@ -223,11 +224,10 @@ class sbb_predict: return added_image, layout_only - def predict(self): - self.start_new_session_and_model() + def predict(self, image_dir): if self.task == 'classification': classes_names = self.config_params_model['classification_classes_name'] - img_1ch = img=cv2.imread(self.image, 0) + img_1ch = img=cv2.imread(image_dir, 0) img_1ch = img_1ch / 255.0 img_1ch = cv2.resize(img_1ch, (self.config_params_model['input_height'], self.config_params_model['input_width']), interpolation=cv2.INTER_NEAREST) @@ -438,7 +438,7 @@ class sbb_predict: if self.patches: #def textline_contours(img,input_width,input_height,n_classes,model): - img=cv2.imread(self.image) + img=cv2.imread(image_dir) self.img_org = np.copy(img) if img.shape[0] < self.img_height: @@ -529,7 +529,7 @@ class sbb_predict: else: - img=cv2.imread(self.image) + img=cv2.imread(image_dir) self.img_org = np.copy(img) width=self.img_width @@ -557,22 +557,50 @@ class sbb_predict: def run(self): - res=self.predict() - if (self.task == 'classification' or self.task == 'reading_order'): - pass - elif self.task == 'enhancement': - if self.save: - cv2.imwrite(self.save,res) + self.start_new_session_and_model() + if self.image: + res=self.predict(image_dir = self.image) + + if (self.task == 'classification' or self.task == 'reading_order'): + pass + elif self.task == 'enhancement': + if self.save: + cv2.imwrite(self.save,res) + else: + img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task) + if self.save: + cv2.imwrite(self.save,img_seg_overlayed) + if self.save_layout: + cv2.imwrite(self.save_layout, only_layout) + + if self.ground_truth: + gt_img=cv2.imread(self.ground_truth) + self.IoU(gt_img[:,:,0],res[:,:,0]) + else: - img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task) - if self.save: - cv2.imwrite(self.save,img_seg_overlayed) - if self.save_layout: - cv2.imwrite(self.save_layout, only_layout) + ls_images = os.listdir(self.dir_in) + for ind_image in ls_images: + f_name = ind_image.split('.')[0] + image_dir = os.path.join(self.dir_in, ind_image) + res=self.predict(image_dir) - if self.ground_truth: - gt_img=cv2.imread(self.ground_truth) - self.IoU(gt_img[:,:,0],res[:,:,0]) + if (self.task == 'classification' or self.task == 'reading_order'): + pass + elif self.task == 'enhancement': + self.save = os.path.join(self.out, f_name+'.png') + cv2.imwrite(self.save,res) + else: + img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task) + self.save = os.path.join(self.out, f_name+'_overlayed.png') + cv2.imwrite(self.save,img_seg_overlayed) + self.save_layout = os.path.join(self.out, f_name+'_layout.png') + cv2.imwrite(self.save_layout, only_layout) + + if self.ground_truth: + gt_img=cv2.imread(self.ground_truth) + self.IoU(gt_img[:,:,0],res[:,:,0]) + + @click.command() @click.option( @@ -581,6 +609,12 @@ class sbb_predict: help="image filename", type=click.Path(exists=True, dir_okay=False), ) +@click.option( + "--dir_in", + "-di", + help="directory of images", + type=click.Path(exists=True, file_okay=False), +) @click.option( "--out", "-o", @@ -626,15 +660,19 @@ class sbb_predict: "-min", help="min area size of regions considered for reading order detection. The default value is zero and means that all text regions are considered for reading order.", ) -def main(image, model, patches, save, save_layout, ground_truth, xml_file, out, min_area): +def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, out, min_area): + assert image or dir_in, "Either a single image -i or a dir_in -di is required" with open(os.path.join(model,'config.json')) as f: config_params_model = json.load(f) task = config_params_model['task'] if (task != 'classification' and task != 'reading_order'): - if not save: - print("Error: You used one of segmentation or binarization task but not set -s, you need a filename to save visualized output with -s") + if image and not save: + print("Error: You used one of segmentation or binarization task with image input but not set -s, you need a filename to save visualized output with -s") sys.exit(1) - x=sbb_predict(image, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area) + if dir_in and not out: + print("Error: You used one of segmentation or binarization task with dir_in but not set -out") + sys.exit(1) + x=sbb_predict(image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area) x.run() if __name__=="__main__":