@ -32,6 +32,7 @@ class sbb_predict:
self . image = image
self . patches = patches
self . save = save
self . save_layout = save_layout
self . model_dir = model
self . ground_truth = ground_truth
self . task = task
@ -181,6 +182,7 @@ class sbb_predict:
prediction = prediction * - 1
prediction = prediction + 1
added_image = prediction * 255
layout_only = None
else :
unique_classes = np . unique ( prediction [ : , : , 0 ] )
rgb_colors = { ' 0 ' : [ 255 , 255 , 255 ] ,
@ -200,26 +202,26 @@ class sbb_predict:
' 14 ' : [ 255 , 125 , 125 ] ,
' 15 ' : [ 255 , 0 , 255 ] }
output = np . zeros ( prediction . shape )
layout_only = np . zeros ( prediction . shape )
for unq_class in unique_classes :
rgb_class_unique = rgb_colors [ str ( int ( unq_class ) ) ]
output [ : , : , 0 ] [ prediction [ : , : , 0 ] == unq_class ] = rgb_class_unique [ 0 ]
output [ : , : , 1 ] [ prediction [ : , : , 0 ] == unq_class ] = rgb_class_unique [ 1 ]
output [ : , : , 2 ] [ prediction [ : , : , 0 ] == unq_class ] = rgb_class_unique [ 2 ]
layout_only [ : , : , 0 ] [ prediction [ : , : , 0 ] == unq_class ] = rgb_class_unique [ 0 ]
layout_only [ : , : , 1 ] [ prediction [ : , : , 0 ] == unq_class ] = rgb_class_unique [ 1 ]
layout_only [ : , : , 2 ] [ prediction [ : , : , 0 ] == unq_class ] = rgb_class_unique [ 2 ]
img = self . resize_image ( img , output. shape [ 0 ] , output . shape [ 1 ] )
img = self . resize_image ( img , layout_only. shape [ 0 ] , layout_only . shape [ 1 ] )
output = output . astype ( np . int32 )
layout_only = layout_only . astype ( np . int32 )
img = img . astype ( np . int32 )
added_image = cv2 . addWeighted ( img , 0.5 , output , 0.1 , 0 )
added_image = cv2 . addWeighted ( img , 0.5 , layout_only , 0.1 , 0 )
return added_image , output
return added_image , layout_only
def predict ( self ) :
self . start_new_session_and_model ( )
@ -559,13 +561,12 @@ class sbb_predict:
pass
elif self . task == ' enhancement ' :
if self . save :
print ( self . save )
cv2 . imwrite ( self . save , res )
else :
img_seg_overlayed , only_ prediction = self . visualize_model_output ( res , self . img_org , self . task )
img_seg_overlayed , only_ layout = self . visualize_model_output ( res , self . img_org , self . task )
if self . save :
cv2 . imwrite ( self . save , img_seg_overlayed )
cv2 . imwrite ( ' ./layout.png ' , only_prediction )
cv2 . imwrite ( self . save_layout , only_layout )
if self . ground_truth :
gt_img = cv2 . imread ( self . ground_truth )
@ -595,6 +596,11 @@ class sbb_predict:
" -s " ,
help = " save prediction as a png file in current folder. " ,
)
@click.option (
" --save_layout " ,
" -sl " ,
help = " save layout prediction only as a png file in current folder. " ,
)
@click.option (
" --model " ,
" -m " ,
@ -618,7 +624,7 @@ class sbb_predict:
" -min " ,
help = " min area size of regions considered for reading order detection. The default value is zero and means that all text regions are considered for reading order. " ,
)
def main ( image , model , patches , save , ground_truth, xml_file , out , min_area ) :
def main ( image , model , patches , save , save_layout, ground_truth, xml_file , out , min_area ) :
with open ( os . path . join ( model , ' config.json ' ) ) as f :
config_params_model = json . load ( f )
task = config_params_model [ ' task ' ]
@ -626,7 +632,7 @@ def main(image, model, patches, save, ground_truth, xml_file, out, min_area):
if not save :
print ( " Error: You used one of segmentation or binarization task but not set -s, you need a filename to save visualized output with -s " )
sys . exit ( 1 )
x = sbb_predict ( image , model , task , config_params_model , patches , save , ground_truth, xml_file , out , min_area )
x = sbb_predict ( image , model , task , config_params_model , patches , save , save_layout, ground_truth, xml_file , out , min_area )
x . run ( )
if __name__ == " __main__ " :