mirror of
				https://github.com/qurator-spk/dinglehopper.git
				synced 2025-10-31 17:34:15 +01:00 
			
		
		
		
	Add tooltips to fca report
This commit is contained in:
		
							parent
							
								
									53064bf833
								
							
						
					
					
						commit
						750ad00d1b
					
				
					 4 changed files with 85 additions and 17 deletions
				
			
		|  | @ -3,4 +3,8 @@ from .extracted_text import * | |||
| from .character_error_rate import * | ||||
| from .word_error_rate import * | ||||
| from .align import * | ||||
| from .flexible_character_accuracy import flexible_character_accuracy, split_matches | ||||
| from .flexible_character_accuracy import ( | ||||
|     flexible_character_accuracy, | ||||
|     split_matches, | ||||
|     Match, | ||||
| ) | ||||
|  |  | |||
|  | @ -14,7 +14,7 @@ from .ocr_files import extract | |||
| from .config import Config | ||||
| 
 | ||||
| 
 | ||||
| def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none, ops=None): | ||||
| def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none, matches=None): | ||||
|     gtx = "" | ||||
|     ocrx = "" | ||||
| 
 | ||||
|  | @ -42,7 +42,27 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none, ops=None): | |||
|         else: | ||||
|             return "{html_t}".format(html_t=html_t) | ||||
| 
 | ||||
|     if isinstance(gt_in, ExtractedText): | ||||
|     ops, ocr_ids = None, None | ||||
|     if matches: | ||||
|         gt_things, ocr_things, ops = split_matches(matches) | ||||
|         # we have to reconstruct the order of the ocr because we mixed it for fca | ||||
|         ocr_lines = [match.ocr for match in matches] | ||||
|         ocr_lines_sorted = sorted(ocr_lines, key=lambda x: x.line + x.start / 10000) | ||||
| 
 | ||||
|         ocr_line_region_id = {} | ||||
|         pos = 0 | ||||
|         for ocr_line in ocr_lines_sorted: | ||||
|             if ocr_line.line not in ocr_line_region_id.keys(): | ||||
|                 ocr_line_region_id[ocr_line.line] = ocr_in.segment_id_for_pos(pos) | ||||
|             pos += ocr_line.length | ||||
| 
 | ||||
|         ocr_ids = {None: None} | ||||
|         pos = 0 | ||||
|         for ocr_line in ocr_lines: | ||||
|             for _ in ocr_line.text: | ||||
|                 ocr_ids[pos] = ocr_line_region_id[ocr_line.line] | ||||
|                 pos += 1 | ||||
|     elif isinstance(gt_in, ExtractedText): | ||||
|         if not isinstance(ocr_in, ExtractedText): | ||||
|             raise TypeError() | ||||
|         # XXX splitting should be done in ExtractedText | ||||
|  | @ -61,10 +81,13 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none, ops=None): | |||
|         if g != o: | ||||
|             css_classes = "{css_prefix}diff{k} diff".format(css_prefix=css_prefix, k=k) | ||||
|             if isinstance(gt_in, ExtractedText): | ||||
|                 gt_id = gt_in.segment_id_for_pos(g_pos) if g is not None else None | ||||
|                 ocr_id = ocr_in.segment_id_for_pos(o_pos) if o is not None else None | ||||
|                 # Deletions and inserts only produce one id + None, UI must | ||||
|                 # support this, i.e. display for the one id produced | ||||
|                 gt_id = gt_in.segment_id_for_pos(g_pos) if g else None | ||||
|                 if ocr_ids: | ||||
|                     ocr_id = ocr_ids[o_pos] | ||||
|                 else: | ||||
|                     ocr_id = ocr_in.segment_id_for_pos(o_pos) if o else None | ||||
| 
 | ||||
|         gtx += joiner + format_thing(g, css_classes, gt_id) | ||||
|         ocrx += joiner + format_thing(o, css_classes, ocr_id) | ||||
|  | @ -111,15 +134,9 @@ def process(gt, ocr, report_prefix, *, metrics="cer,wer", textequiv_level="regio | |||
|             gt_words, ocr_words, css_prefix="w", joiner=" ", none="⋯" | ||||
|         ) | ||||
|     if "fca" in metrics: | ||||
|         fca, fca_matches = flexible_character_accuracy(gt_text.text, ocr_text.text) | ||||
|         fca_gt_segments, fca_ocr_segments, ops = split_matches(fca_matches) | ||||
|         fca, fca_matches = flexible_character_accuracy(gt_text, ocr_text) | ||||
|         fca_diff_report = gen_diff_report( | ||||
|             fca_gt_segments, | ||||
|             fca_ocr_segments, | ||||
|             css_prefix="c", | ||||
|             joiner="", | ||||
|             none="·", | ||||
|             ops=ops, | ||||
|             gt_text, ocr_text, css_prefix="c", joiner="", none="·", matches=fca_matches | ||||
|         ) | ||||
| 
 | ||||
|     def json_float(value): | ||||
|  |  | |||
|  | @ -17,7 +17,9 @@ from functools import lru_cache, reduce | |||
| from itertools import product, takewhile | ||||
| from typing import List, Tuple, Optional | ||||
| 
 | ||||
| from . import editops | ||||
| from multimethod import multimethod | ||||
| 
 | ||||
| from . import editops, ExtractedText | ||||
| 
 | ||||
| if sys.version_info.minor == 5: | ||||
|     from .flexible_character_accuracy_ds_35 import ( | ||||
|  | @ -35,6 +37,22 @@ else: | |||
|     ) | ||||
| 
 | ||||
| 
 | ||||
@multimethod
def flexible_character_accuracy(
    gt: ExtractedText, ocr: ExtractedText
) -> Tuple[float, List[Match]]:
    """Calculate the flexible character accuracy for two documents.

    Convenience overload: unwraps the plain-text content of both
    ``ExtractedText`` objects and dispatches to the ``str``/``str``
    implementation (steps 1-7 of the flexible character accuracy algorithm).

    :param gt: The ground truth document.
    :param ocr: The recognized document to compare the ground truth with.
    :return: Score between 0 and 1 and match objects.
    """
    gt_str, ocr_str = gt.text, ocr.text
    return flexible_character_accuracy(gt_str, ocr_str)
| 
 | ||||
| 
 | ||||
| @multimethod | ||||
| def flexible_character_accuracy(gt: str, ocr: str) -> Tuple[float, List[Match]]: | ||||
|     """Calculate the flexible character accuracy. | ||||
| 
 | ||||
|  | @ -359,7 +377,7 @@ def split_matches(matches: List[Match]) -> Tuple[List[str], List[str], List[List | |||
|     :param matches: List of match objects. | ||||
|     :return: List of ground truth segments, ocr segments and editing operations. | ||||
|     """ | ||||
|     matches = sorted(matches, key=lambda x: x.gt.line + x.gt.start / 10000) | ||||
|     matches = sorted(matches, key=lambda m: m.gt.line + m.gt.start / 10000) | ||||
|     line = 0 | ||||
|     gt, ocr, ops = [], [], [] | ||||
|     for match in matches: | ||||
|  | @ -410,4 +428,4 @@ class Part(PartVersionSpecific): | |||
|         """ | ||||
|         text = self.text[rel_start:rel_end] | ||||
|         start = self.start + rel_start | ||||
|         return Part(text=text, line=self.line, start=start) | ||||
|         return Part(**{**self._asdict(), "text": text, "start": start}) | ||||
|  |  | |||
|  | @ -10,6 +10,7 @@ DOI: 10.1016/j.patrec.2020.02.003 | |||
| """ | ||||
| 
 | ||||
| import pytest | ||||
| from lxml import etree as ET | ||||
| 
 | ||||
| from ..flexible_character_accuracy import * | ||||
| 
 | ||||
|  | @ -101,11 +102,39 @@ def extended_case_to_text(gt, ocr): | |||
| 
 | ||||
| 
 | ||||
@pytest.mark.parametrize(CASE_ARGS, [*SIMPLE_CASES, *COMPLEX_CASES])
def test_flexible_character_accuracy_str(gt, ocr, first_line_score, all_line_score):
    """Plain-string inputs must yield the expected all-line score."""
    score, _matches = flexible_character_accuracy(gt, ocr)
    assert score == pytest.approx(all_line_score)
| 
 | ||||
| 
 | ||||
@pytest.mark.parametrize(CASE_ARGS, [*SIMPLE_CASES, *COMPLEX_CASES])
def test_flexible_character_accuracy_xml(gt, ocr, first_line_score, all_line_score):
    """The ExtractedText overload must score identically to the str overload.

    Each test string is wrapped into a minimal PAGE-XML document (one
    TextLine per input line, all inside a single TextRegion) and parsed
    back into an ``ExtractedText`` at line level before scoring.
    """

    def get_extracted_text(text: str):
        # Local stdlib import: escaping keeps arbitrary case text
        # (e.g. containing '&' or '<') well-formed when interpolated
        # into the XML template below.
        from xml.sax.saxutils import escape

        xml_decl = '<?xml version="1.0"?>'
        ns = "http://schema.primaresearch.org/PAGE/gts/pagecontent/2018-07-15"

        textline_tmpl = (
            '<TextLine id="l{0}"><TextEquiv><Unicode>{1}'
            "</Unicode></TextEquiv></TextLine>"
        )
        xml_tmpl = '{0}<TextRegion id="0" xmlns="{1}">{2}</TextRegion>'

        textlines = [
            textline_tmpl.format(i, escape(line))
            for i, line in enumerate(text.splitlines())
        ]
        xml_text = xml_tmpl.format(xml_decl, ns, "".join(textlines))
        root = ET.fromstring(xml_text)
        extracted_text = ExtractedText.from_text_segment(
            root, {"page": ns}, textequiv_level="line"
        )
        return extracted_text

    gt_text = get_extracted_text(gt)
    ocr_text = get_extracted_text(ocr)
    score, _ = flexible_character_accuracy(gt_text, ocr_text)
    assert score == pytest.approx(all_line_score)
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize( | ||||
|     "config,ocr", | ||||
|     [ | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue