From d2f2a1e06b3632b11b0dbf7c40ed59bc50ee5d44 Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Wed, 3 Jun 2026 00:43:46 +0200 Subject: [PATCH] =?UTF-8?q?Eynollah=5Focr:=20correctly=20handle=20min=5Fco?= =?UTF-8?q?nf,=20improve=20writer=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - `min_conf_value_of_textline_text`: apply by skipping lines below threshold (instead of writing empty text), and delete their TextEquiv (if existing) - `write_ocr()`: simplify, and ensure consistency between line and region level text correctly --- src/eynollah/eynollah_ocr.py | 137 +++++++++++++---------------------- 1 file changed, 50 insertions(+), 87 deletions(-) diff --git a/src/eynollah/eynollah_ocr.py b/src/eynollah/eynollah_ocr.py index 40cbeaa..aeaabfe 100644 --- a/src/eynollah/eynollah_ocr.py +++ b/src/eynollah/eynollah_ocr.py @@ -41,7 +41,7 @@ from .utils.utils_ocr import ( @dataclass class EynollahOcrResult: extracted_texts_merged: List - extracted_confs_merged: Optional[List] + extracted_confs_merged: List cropped_lines_region_indexer: List total_bb_coordinates:List @@ -156,10 +156,6 @@ class Eynollah_ocr(Eynollah): conf = output.sequences_scores.exp().clamp(0.0, 1.0).tolist() else: conf = [1.0] * len(output.sequences) - if conf < self.min_conf_value_of_textline_text: - extracted_confs.extend(0) - extracted_texts.extend("") - continue text = self.model_zoo.get('trocr_processor').batch_decode( output.sequences, skip_special_tokens=True, @@ -327,17 +323,13 @@ class Eynollah_ocr(Eynollah): def nooov(x): return x != b'[UNK]' for pred, prob in zip(preds, probs): - if prob < self.min_conf_value_of_textline_text: - extracted_texts.append("") - extracted_confs.append(0) - else: - text = b''.join( - filter(nooov, - map(bytes, - (filter(None, char) - for char in pred.tolist())))).decode('utf-8') - extracted_texts.append(text) - extracted_confs.append(prob) + text = b''.join( + filter(nooov, + map(bytes, + (filter(None, char) + for char in pred.tolist())))).decode('utf-8') + extracted_texts.append(text) + extracted_confs.append(prob) del cropped_lines_rgb del cropped_lines_bin gc.collect() @@ -375,7 +367,6 @@ class Eynollah_ocr(Eynollah): extracted_texts_merged = result.extracted_texts_merged extracted_confs_merged = result.extracted_confs_merged - unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer) if out_image_with_text: image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white") draw = ImageDraw.Draw(image_text) @@ -403,78 +394,50 @@ class Eynollah_ocr(Eynollah): draw.text((text_x, text_y), extracted_texts_merged[indexer_text], fill="black", font=font) image_text.save(out_image_with_text) - text_by_textregion = [] - for ind in unique_cropped_lines_region_indexer: - ind = np.array(cropped_lines_region_indexer)==ind - extracted_texts_merged_un = np.array(extracted_texts_merged)[ind] - if len(extracted_texts_merged_un)>1: - text_by_textregion_ind = "" - next_glue = "" - for indt in range(len(extracted_texts_merged_un)): - if (extracted_texts_merged_un[indt].endswith('⸗') or - extracted_texts_merged_un[indt].endswith('-') or - extracted_texts_merged_un[indt].endswith('¬')): - text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt][:-1] - next_glue = "" - else: - text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt] - next_glue = " " - text_by_textregion.append(text_by_textregion_ind) - else: - text_by_textregion.append(" ".join(extracted_texts_merged_un)) + cropped_lines_region_indexer = np.array(cropped_lines_region_indexer) + for n_region, region in enumerate(page_tree.getroot().iter('{%s}TextRegion' % page_ns)): + lines_indexer = np.flatnonzero(cropped_lines_region_indexer == n_region) + if not len(lines_indexer): + continue - indexer = 0 - indexer_textregion = 0 - for nn in page_tree.getroot().iter(f'{{{page_ns}}}TextRegion'): - - is_textregion_text = False - for childtest in nn: - if childtest.tag.endswith("TextEquiv"): - is_textregion_text = True - - if not is_textregion_text: - text_subelement_textregion = ET.SubElement(nn, 'TextEquiv') - unicode_textregion = ET.SubElement(text_subelement_textregion, 'Unicode') - - - has_textline = False - for child_textregion in nn: - # FIXME: should remove Word level, if it already exists - if child_textregion.tag.endswith("TextLine"): - - is_textline_text = False - for childtest2 in child_textregion: - if childtest2.tag.endswith("TextEquiv"): - is_textline_text = True - - - if not is_textline_text: - text_subelement = ET.SubElement(child_textregion, 'TextEquiv') - if extracted_confs_merged: - text_subelement.set('conf', f"{extracted_confs_merged[indexer]:.2f}") - unicode_textline = ET.SubElement(text_subelement, 'Unicode') - unicode_textline.text = extracted_texts_merged[indexer] - else: - for childtest3 in child_textregion: - if childtest3.tag.endswith("TextEquiv"): - for child_uc in childtest3: - if child_uc.tag.endswith("Unicode"): - if extracted_confs_merged: - childtest3.set('conf', f"{extracted_confs_merged[indexer]:.2f}") - child_uc.text = extracted_texts_merged[indexer] - - indexer = indexer + 1 - has_textline = True - if has_textline: - if is_textregion_text: - for child4 in nn: - if child4.tag.endswith("TextEquiv"): - for childtr_uc in child4: - if childtr_uc.tag.endswith("Unicode"): - childtr_uc.text = text_by_textregion[indexer_textregion] + text_region = "" + next_glue = "" + for line_idx in lines_indexer: + if extracted_confs_merged[line_idx] < self.min_conf_value_of_textline_text: + continue + text_line = extracted_texts_merged[line_idx] + if (text_line.endswith(('⸗', '-', '¬')) and + # last line of a region can still be wrapped + # around columns or pages + line_idx < len(lines_indexer) - 1): + text_region += next_glue + text_line[:-1] + next_glue = "" else: - unicode_textregion.text = text_by_textregion[indexer_textregion] - indexer_textregion = indexer_textregion + 1 + text_region += next_glue + text_line + next_glue = " " + + region_textequiv = region.find('{%s}TextEquiv' % page_ns) + if region_textequiv is None: + region_textequiv = ET.SubElement(region, 'TextEquiv') + region_teunicode = region_textequiv.find('{%s}Unicode' % page_ns) + if region_teunicode is None: + region_teunicode = ET.SubElement(region_textequiv, 'Unicode') + region_teunicode.text = text_region + + for n_line, line in enumerate(region.iter('{%s}TextLine' % page_ns)): + line_textequiv = line.find('{%s}TextEquiv' % page_ns) + if line_textequiv is None: + line_textequiv = ET.SubElement(line, 'TextEquiv') + line_teunicode = line_textequiv.find('{%s}Unicode' % page_ns) + if line_teunicode is None: + line_teunicode = ET.SubElement(line_textequiv, 'Unicode') + + line_idx = lines_indexer[n_line] + if extracted_confs_merged[line_idx] < self.min_conf_value_of_textline_text: + line.remove(line_textequiv) + else: + line_textequiv.set('conf', str(round(extracted_confs_merged[line_idx], 2))) + line_teunicode.text = extracted_texts_merged[line_idx] ET.register_namespace("",page_ns) self.logger.info("output filename: '%s'", out_file_ocr)