binarization: add option --overwrite, skip existing outputs

(also, simplify `run` and separate `run_single`)
This commit is contained in:
Robert Sachunsky 2025-10-15 12:24:21 +02:00
parent 38c028c6b5
commit 086c1880ac
2 changed files with 52 additions and 60 deletions

View file

@ -79,18 +79,28 @@ def machine_based_reading_order(input, dir_in, out, model, log_level):
type=click.Path(file_okay=True, dir_okay=True),
required=True,
)
@click.option(
"--overwrite",
"-O",
help="overwrite (instead of skipping) if output xml exists",
is_flag=True,
)
@click.option(
"--log_level",
"-l",
type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']),
help="Override log level globally to this",
)
def binarization(patches, model_dir, input_image, dir_in, output, log_level):
def binarization(patches, model_dir, input_image, dir_in, output, overwrite, log_level):
assert bool(input_image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
binarizer = SbbBinarizer(model_dir)
if log_level:
binarizer.log.setLevel(getLevelName(log_level))
binarizer.run(image_path=input_image, use_patches=patches, output=output, dir_in=dir_in)
binarizer.logger.setLevel(getLevelName(log_level))
binarizer.run(overwrite=overwrite,
use_patches=patches,
image_path=input_image,
output=output,
dir_in=dir_in)
@main.command()

View file

@ -25,7 +25,7 @@ class SbbBinarizer:
def __init__(self, model_dir, logger=None):
self.model_dir = model_dir
self.log = logger if logger else logging.getLogger('SbbBinarizer')
self.logger = logger if logger else logging.getLogger('SbbBinarizer')
self.start_new_session()
@ -315,64 +315,46 @@ class SbbBinarizer:
prediction_true = prediction_true.astype(np.uint8)
return prediction_true[:,:,0]
def run(self, image=None, image_path=None, output=None, use_patches=False, dir_in=None):
# print(dir_in,'dir_in')
if not dir_in:
if (image is not None and image_path is not None) or \
(image is None and image_path is None):
raise ValueError("Must pass either a opencv2 image or an image_path")
if image_path is not None:
image = cv2.imread(image_path)
img_last = 0
for n, (model, model_file) in enumerate(zip(self.models, self.model_files)):
self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files)))
res = self.predict(model, image, use_patches)
img_fin = np.zeros((res.shape[0], res.shape[1], 3))
res[:, :][res[:, :] == 0] = 2
res = res - 1
res = res * 255
img_fin[:, :, 0] = res
img_fin[:, :, 1] = res
img_fin[:, :, 2] = res
img_fin = img_fin.astype(np.uint8)
img_fin = (res[:, :] == 0) * 255
img_last = img_last + img_fin
kernel = np.ones((5, 5), np.uint8)
img_last[:, :][img_last[:, :] > 0] = 255
img_last = (img_last[:, :] == 0) * 255
if output:
cv2.imwrite(output, img_last)
return img_last
def run(self, image_path=None, output=None, dir_in=None, use_patches=False, overwrite=False):
if dir_in:
ls_imgs = [(os.path.join(dir_in, image_filename),
os.path.join(output, os.path.splitext(image_filename)[0] + '.png'))
for image_filename in filter(is_image_filename,
os.listdir(dir_in))]
else:
ls_imgs = list(filter(is_image_filename, os.listdir(dir_in)))
for image_name in ls_imgs:
image_stem = image_name.split('.')[0]
print(image_name,'image_name')
image = cv2.imread(os.path.join(dir_in,image_name) )
img_last = 0
for n, (model, model_file) in enumerate(zip(self.models, self.model_files)):
self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files)))
ls_imgs = [(image_path, output)]
res = self.predict(model, image, use_patches)
for input_path, output_path in ls_imgs:
print(input_path, 'image_name')
if os.path.exists(output_path):
if overwrite:
self.logger.warning("will overwrite existing output file '%s'", output_ptah)
else:
self.logger.warning("will skip input for existing output file '%s'", output_path)
image = cv2.imread(input_path)
result = self.run_single(image, use_patches)
cv2.imwrite(output_path, result)
img_fin = np.zeros((res.shape[0], res.shape[1], 3))
res[:, :][res[:, :] == 0] = 2
res = res - 1
res = res * 255
img_fin[:, :, 0] = res
img_fin[:, :, 1] = res
img_fin[:, :, 2] = res
def run_single(self, image: np.ndarray, use_patches=False):
img_last = 0
for n, (model, model_file) in enumerate(zip(self.models, self.model_files)):
self.logger.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files)))
img_fin = img_fin.astype(np.uint8)
img_fin = (res[:, :] == 0) * 255
img_last = img_last + img_fin
res = self.predict(model, image, use_patches)
kernel = np.ones((5, 5), np.uint8)
img_last[:, :][img_last[:, :] > 0] = 255
img_last = (img_last[:, :] == 0) * 255
cv2.imwrite(os.path.join(output, image_stem + '.png'), img_last)
img_fin = np.zeros((res.shape[0], res.shape[1], 3))
res[:, :][res[:, :] == 0] = 2
res = res - 1
res = res * 255
img_fin[:, :, 0] = res
img_fin[:, :, 1] = res
img_fin[:, :, 2] = res
img_fin = img_fin.astype(np.uint8)
img_fin = (res[:, :] == 0) * 255
img_last = img_last + img_fin
kernel = np.ones((5, 5), np.uint8)
img_last[:, :][img_last[:, :] > 0] = 255
img_last = (img_last[:, :] == 0) * 255
return img_last