refactor cli tests

This commit is contained in:
kba 2025-10-29 16:20:30 +01:00
parent ef999c8f0a
commit b6f82c72b9
15 changed files with 453 additions and 592 deletions

View file

@ -59,7 +59,7 @@ class Eynollah_ocr:
export_textline_images_and_text: bool=False,
do_not_mask_with_textline_contour: bool=False,
pref_of_dataset=None,
min_conf_value_of_textline_text : float=0.3,
min_conf_value_of_textline_text : Optional[float]=None,
logger: Optional[Logger]=None,
):
self.tr_ocr = tr_ocr
@ -69,7 +69,7 @@ class Eynollah_ocr:
self.do_not_mask_with_textline_contour = do_not_mask_with_textline_contour
# prefix or dataset
self.pref_of_dataset = pref_of_dataset
self.logger = logger if logger else getLogger('eynollah')
self.logger = logger if logger else getLogger('eynollah.ocr')
self.model_zoo = EynollahModelZoo(basedir=dir_models)
# TODO: Properly document what 'export_textline_images_and_text' is about
@ -77,21 +77,15 @@ class Eynollah_ocr:
self.logger.info("export_textline_images_and_text was set, so no actual models are loaded")
return
self.min_conf_value_of_textline_text = min_conf_value_of_textline_text
self.min_conf_value_of_textline_text = min_conf_value_of_textline_text if min_conf_value_of_textline_text else 0.3
self.b_s = 2 if batch_size is None and tr_ocr else 8 if batch_size is None else batch_size
if tr_ocr:
self.model_zoo.load_model('trocr_processor', '')
if model_name:
self.model_zoo.load_model('ocr', 'tr', model_name)
else:
self.model_zoo.load_model('ocr', 'tr')
self.model_zoo.load_model('trocr_processor')
self.model_zoo.load_model('ocr', 'tr', model_path_override=model_name)
self.model_zoo.get('ocr').to(self.device)
else:
if model_name:
self.model_zoo.load_model('ocr', '', model_name)
else:
self.model_zoo.load_model('ocr', '')
self.model_zoo.load_model('ocr', '', model_path_override=model_name)
self.model_zoo.load_model('num_to_char')
self.end_character = len(self.model_zoo.load_model('characters')) + 2
@ -206,10 +200,10 @@ class Eynollah_ocr:
cropped_lines = []
indexer_b_s = 0
pixel_values_merged = self.model_zoo.get('processor')(imgs, return_tensors="pt").pixel_values
pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values
generated_ids_merged = self.model_zoo.get('ocr').generate(
pixel_values_merged.to(self.device))
generated_text_merged = self.model_zoo.get('processor').batch_decode(
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
generated_ids_merged, skip_special_tokens=True)
extracted_texts = extracted_texts + generated_text_merged
@ -229,10 +223,10 @@ class Eynollah_ocr:
cropped_lines = []
indexer_b_s = 0
pixel_values_merged = self.model_zoo.get('processor')(imgs, return_tensors="pt").pixel_values
pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values
generated_ids_merged = self.model_zoo.get('ocr').generate(
pixel_values_merged.to(self.device))
generated_text_merged = self.model_zoo.get('processor').batch_decode(
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
generated_ids_merged, skip_special_tokens=True)
extracted_texts = extracted_texts + generated_text_merged
@ -249,10 +243,10 @@ class Eynollah_ocr:
cropped_lines = []
indexer_b_s = 0
pixel_values_merged = self.model_zoo.get('processor')(imgs, return_tensors="pt").pixel_values
pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values
generated_ids_merged = self.model_zoo.get('ocr').generate(
pixel_values_merged.to(self.device))
generated_text_merged = self.model_zoo.get('processor').batch_decode(
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
generated_ids_merged, skip_special_tokens=True)
extracted_texts = extracted_texts + generated_text_merged
@ -267,10 +261,10 @@ class Eynollah_ocr:
cropped_lines = []
indexer_b_s = 0
pixel_values_merged = self.model_zoo.get('processor')(imgs, return_tensors="pt").pixel_values
pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values
generated_ids_merged = self.model_zoo.get('ocr').generate(
pixel_values_merged.to(self.device))
generated_text_merged = self.model_zoo.get('processor').batch_decode(
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
generated_ids_merged, skip_special_tokens=True)
extracted_texts = extracted_texts + generated_text_merged
@ -284,9 +278,9 @@ class Eynollah_ocr:
cropped_lines = []
indexer_b_s = 0
pixel_values_merged = self.model_zoo.get('processor')(imgs, return_tensors="pt").pixel_values
pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values
generated_ids_merged = self.model_zoo.get('ocr').generate(pixel_values_merged.to(self.device))
generated_text_merged = self.model_zoo.get('processor').batch_decode(generated_ids_merged, skip_special_tokens=True)
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(generated_ids_merged, skip_special_tokens=True)
extracted_texts = extracted_texts + generated_text_merged
@ -301,10 +295,10 @@ class Eynollah_ocr:
####n_start = i*self.b_s
####n_end = (i+1)*self.b_s
####imgs = cropped_lines[n_start:n_end]
####pixel_values_merged = self.model_zoo.get('processor')(imgs, return_tensors="pt").pixel_values
####pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values
####generated_ids_merged = self.model_ocr.generate(
#### pixel_values_merged.to(self.device))
####generated_text_merged = self.model_zoo.get('processor').batch_decode(
####generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
#### generated_ids_merged, skip_special_tokens=True)
####extracted_texts = extracted_texts + generated_text_merged

View file

@ -50,7 +50,7 @@ class Enhancer:
else:
self.num_col_lower = num_col_lower
self.logger = logger if logger else getLogger('enhancement')
self.logger = logger if logger else getLogger('eynollah.enhance')
self.model_zoo = EynollahModelZoo(basedir=dir_models)
for v in ['binarization', 'enhancement', 'col_classifier', 'page']:
self.model_zoo.load_model(v)
@ -142,7 +142,7 @@ class Enhancer:
index_y_d = img_h - img_height_model
img_patch = img[np.newaxis, index_y_d:index_y_u, index_x_d:index_x_u, :]
label_p_pred = self.model_zoo.get('enhancement', Model).predict(img_patch, verbose=0)
label_p_pred = self.model_zoo.get('enhancement', Model).predict(img_patch, verbose='0')
seg = label_p_pred[0, :, :, :] * 255
if i == 0 and j == 0:
@ -667,7 +667,7 @@ class Enhancer:
t0 = time.time()
img_res, is_image_enhanced, num_col_classifier, num_column_is_classified = self.run_enhancement(light_version=False)
return img_res
return img_res, is_image_enhanced
def run(self,
@ -705,9 +705,18 @@ class Enhancer:
self.logger.warning("will skip input for existing output file '%s'", self.output_filename)
continue
image_enhanced = self.run_single()
did_resize = False
image_enhanced, did_enhance = self.run_single()
if self.save_org_scale:
image_enhanced = resize_image(image_enhanced, self.h_org, self.w_org)
did_resize = True
self.logger.info(
"Image %s was %senhanced%s.",
img_filename,
'' if did_enhance else 'not ',
'and resized' if did_resize else ''
)
cv2.imwrite(self.output_filename, image_enhanced)

View file

@ -84,10 +84,13 @@ class EynollahModelZoo:
self,
model_category: str,
model_variant: str = '',
model_path_override: Optional[str] = None,
) -> AnyModel:
"""
Load any model
"""
if model_path_override:
self.override_models((model_category, model_variant, model_path_override))
model_path = self.model_path(model_category, model_variant)
if model_path.suffix == '.h5' and Path(model_path.stem).exists():
# prefer SavedModel over HDF5 format if it exists
@ -183,5 +186,5 @@ class EynollahModelZoo:
Ensure that loaded models are not referenced by ``self._loaded`` anymore
"""
if hasattr(self, '_loaded') and getattr(self, '_loaded'):
for needle in self._loaded.keys():
for needle in list(self._loaded.keys()):
del self._loaded[needle]

View file

@ -322,8 +322,7 @@ class SbbBinarizer:
image = cv2.imread(image_path)
img_last = 0
for n, (model_file, model) in enumerate(self.models.items()):
self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.models.keys())))
self.log.info('Predicting %s with model %s [%s/%s]', image_path if image_path else '[image]', model_file, n + 1, len(self.models.keys()))
res = self.predict(model, image, use_patches)
img_fin = np.zeros((res.shape[0], res.shape[1], 3))
@ -348,11 +347,11 @@ class SbbBinarizer:
ls_imgs = list(filter(is_image_filename, os.listdir(dir_in)))
for image_name in ls_imgs:
image_stem = image_name.split('.')[0]
print(image_name,'image_name')
# print(image_name,'image_name')
image = cv2.imread(os.path.join(dir_in,image_name) )
img_last = 0
for n, (model_file, model) in enumerate(self.models.items()):
self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.models.keys())))
self.log.info('Predicting %s with model %s [%s/%s]', image_name, model_file, n + 1, len(self.models.keys()))
res = self.predict(model, image, use_patches)