Repository: https://github.com/qurator-spk/dinglehopper.git (mirror, synced 2025-10-31 17:34:15 +01:00)
Commit 53064bf833: Include fca as parameter and add some tests
Parent commit: 9b76539936
11 changed files with 219 additions and 65 deletions
@@ -3,3 +3,4 @@ from .extracted_text import *
 from .character_error_rate import *
 from .word_error_rate import *
 from .align import *
+from .flexible_character_accuracy import flexible_character_accuracy, split_matches

@@ -8,11 +8,12 @@ def align(t1, t2):
     return seq_align(s1, s2)
 
 
-def seq_align(s1, s2):
+def seq_align(s1, s2, ops=None):
     """Align general sequences."""
     s1 = list(s1)
     s2 = list(s2)
-    ops = seq_editops(s1, s2)
+    if not ops:
+        ops = seq_editops(s1, s2)
     i = 0
     j = 0
 

@@ -6,6 +6,7 @@ from markupsafe import escape
 from uniseg.graphemecluster import grapheme_clusters
 
 from .character_error_rate import character_error_rate_n
+from .flexible_character_accuracy import flexible_character_accuracy, split_matches
 from .word_error_rate import word_error_rate_n, words_normalized
 from .align import seq_align
 from .extracted_text import ExtractedText
@@ -13,7 +14,7 @@ from .ocr_files import extract
 from .config import Config
 
 
-def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
+def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none, ops=None):
     gtx = ""
     ocrx = ""
 
@@ -53,7 +54,7 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
 
     g_pos = 0
     o_pos = 0
-    for k, (g, o) in enumerate(seq_align(gt_things, ocr_things)):
+    for k, (g, o) in enumerate(seq_align(gt_things, ocr_things, ops=ops)):
         css_classes = None
         gt_id = None
         ocr_id = None
@@ -83,28 +84,43 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
     )
 
 
-def process(gt, ocr, report_prefix, *, metrics=True, textequiv_level="region"):
+def process(gt, ocr, report_prefix, *, metrics="cer,wer", textequiv_level="region"):
     """Check OCR result against GT.
 
-    The @click decorators change the signature of the decorated functions, so we keep this undecorated version and use
-    Click on a wrapper.
+    The @click decorators change the signature of the decorated functions,
+    so we keep this undecorated version and use Click on a wrapper.
     """
+    cer, char_diff_report, n_characters = None, None, None
+    wer, word_diff_report, n_words = None, None, None
+    fca, fca_diff_report = None, None
 
     gt_text = extract(gt, textequiv_level=textequiv_level)
     ocr_text = extract(ocr, textequiv_level=textequiv_level)
 
-    cer, n_characters = character_error_rate_n(gt_text, ocr_text)
-    wer, n_words = word_error_rate_n(gt_text, ocr_text)
+    if "cer" in metrics or not metrics:
+        cer, n_characters = character_error_rate_n(gt_text, ocr_text)
+        char_diff_report = gen_diff_report(
+            gt_text, ocr_text, css_prefix="c", joiner="", none="·"
+        )
 
-    char_diff_report = gen_diff_report(
-        gt_text, ocr_text, css_prefix="c", joiner="", none="·"
-    )
-
-    gt_words = words_normalized(gt_text)
-    ocr_words = words_normalized(ocr_text)
-    word_diff_report = gen_diff_report(
-        gt_words, ocr_words, css_prefix="w", joiner=" ", none="⋯"
-    )
+    if "wer" in metrics:
+        gt_words = words_normalized(gt_text)
+        ocr_words = words_normalized(ocr_text)
+        wer, n_words = word_error_rate_n(gt_text, ocr_text)
+        word_diff_report = gen_diff_report(
+            gt_words, ocr_words, css_prefix="w", joiner=" ", none="⋯"
+        )
+    if "fca" in metrics:
+        fca, fca_matches = flexible_character_accuracy(gt_text.text, ocr_text.text)
+        fca_gt_segments, fca_ocr_segments, ops = split_matches(fca_matches)
+        fca_diff_report = gen_diff_report(
+            fca_gt_segments,
+            fca_ocr_segments,
+            css_prefix="c",
+            joiner="",
+            none="·",
+            ops=ops,
+        )
 
     def json_float(value):
         """Convert a float value to an JSON float.
@@ -137,8 +153,10 @@ def process(gt, ocr, report_prefix, *, metrics=True, textequiv_level="region"):
             n_characters=n_characters,
             wer=wer,
             n_words=n_words,
+            fca=fca,
             char_diff_report=char_diff_report,
             word_diff_report=word_diff_report,
+            fca_diff_report=fca_diff_report,
             metrics=metrics,
         ).dump(out_fn)
 
@@ -148,7 +166,9 @@ def process(gt, ocr, report_prefix, *, metrics=True, textequiv_level="region"):
 @click.argument("ocr", type=click.Path(exists=True))
 @click.argument("report_prefix", type=click.Path(), default="report")
 @click.option(
-    "--metrics/--no-metrics", default=True, help="Enable/disable metrics and green/red"
+    "--metrics",
+    default="cer,wer",
+    help="Enable different metrics like cer, wer and fca.",
 )
 @click.option(
     "--textequiv-level",
@@ -166,12 +186,16 @@ def main(gt, ocr, report_prefix, metrics, textequiv_level, progress):
 
     The files GT and OCR are usually a ground truth document and the result of
     an OCR software, but you may use dinglehopper to compare two OCR results. In
-    that case, use --no-metrics to disable the then meaningless metrics and also
+    that case, use --metrics='' to disable the then meaningless metrics and also
    change the color scheme from green/red to blue.
 
     The comparison report will be written to $REPORT_PREFIX.{html,json}, where
-    $REPORT_PREFIX defaults to "report". The reports include the character error
-    rate (CER) and the word error rate (WER).
+    $REPORT_PREFIX defaults to "report". Depending on your configuration the
+    reports include the character error rate (CER), the word error rate (WER)
+    and the flexible character accuracy (FCA).
+
+    The metrics can be chosen via a comma separated combination of their acronyms
+    like "--metrics=cer,wer,fca".
 
     By default, the text of PAGE files is extracted on 'region' level. You may
     use "--textequiv-level line" to extract from the level of TextLine tags.

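The new --metrics option replaces the old --metrics/--no-metrics switch: metrics are now selected by a comma-separated list of acronyms, and an empty string disables them. A minimal usage sketch of the undecorated process() wrapper, mirroring the integration tests further down; the import path and file names here are assumptions, not part of this commit:

    # Sketch only: import path may be qurator.dinglehopper.cli or dinglehopper.cli,
    # depending on packaging; gt.xml/ocr.xml are placeholder file names.
    from qurator.dinglehopper.cli import process

    # Compute only CER and FCA; WER is skipped, so the JSON report will not
    # contain a "wer" key (see the report templates and tests below).
    process("gt.xml", "ocr.xml", "report", metrics="cer,fca")

    # Roughly equivalent CLI call (assuming the dinglehopper console script):
    #   dinglehopper gt.xml ocr.xml report --metrics=cer,fca
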
@@ -270,7 +270,9 @@ def score_edit_distance(match: Match) -> int:
     return match.dist.delete + match.dist.insert + 2 * match.dist.replace
 
 
-def calculate_penalty(gt: "Part", ocr: "Part", match: Match, coef: Coefficients) -> float:
+def calculate_penalty(
+    gt: "Part", ocr: "Part", match: Match, coef: Coefficients
+) -> float:
     """Calculate the penalty for a given match.
 
     For details and discussion see Section 3 in doi:10.1016/j.patrec.2020.02.003.
@@ -325,6 +327,8 @@ def character_accuracy(edits: Distance) -> float:
     if not chars and not errors:
         # comparison of empty strings is considered a full match
         score = 1.0
+    elif not chars:
+        score = -errors
     else:
         score = 1.0 - errors / chars
     return score
@@ -349,25 +353,25 @@ def initialize_lines(text: str) -> List["Part"]:
     return lines
 
 
-def combine_lines(matches: List[Match]) -> Tuple[str, str]:
-    """Combines the matches to aligned texts.
-
-    TODO: just hacked, needs tests and refinement. Also missing insert/delete marking.
+def split_matches(matches: List[Match]) -> Tuple[List[str], List[str], List[List]]:
+    """Extracts text segments and editing operations in separate lists.
 
     :param matches: List of match objects.
-    :return: the aligned ground truth and ocr as texts.
+    :return: List of ground truth segments, ocr segments and editing operations.
     """
-    matches.sort(key=lambda x: x.gt.line + x.gt.start / 10000)
+    matches = sorted(matches, key=lambda x: x.gt.line + x.gt.start / 10000)
     line = 0
-    gt, ocr = "", ""
+    gt, ocr, ops = [], [], []
     for match in matches:
         if match.gt.line > line:
-            gt += "\n"
-            ocr += "\n"
-            line += 1
-        gt += match.gt.text
-        ocr += match.ocr.text
-    return gt, ocr
+            gt.append("\n")
+            ocr.append("\n")
+            ops.append([])
+        line = match.gt.line
+        gt.append(match.gt.text)
+        ocr.append(match.ocr.text)
+        ops.append(match.ops)
+    return gt, ocr, ops
 
 
 class Part(PartVersionSpecific):

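The new elif branch in character_accuracy covers the case where the ground truth contributes no characters but the OCR still produces output: instead of dividing by zero, the score goes negative by the number of errors. A standalone sketch of that scoring rule, assuming (as the hunk and the new test cases suggest) that chars counts ground-truth characters (matches, replacements, deletions) and errors counts edit operations; the real function takes a Distance object:

    # Illustration of the branch added above; not the library function itself.
    def character_accuracy_sketch(chars: int, errors: int) -> float:
        if not chars and not errors:
            return 1.0             # two empty strings count as a full match
        if not chars:
            return float(-errors)  # empty GT, only spurious insertions: negative score
        return 1.0 - errors / chars

    # Mirrors the new test cases: Distance(insert=1) -> -1, Distance(delete=1) -> 0.
    assert character_accuracy_sketch(chars=0, errors=1) == -1
    assert character_accuracy_sketch(chars=1, errors=1) == 0

This is also why the empty-GT integration test further down expects fca == -13 when the OCR text "Not important" (13 characters) is compared against an empty ground truth.
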
@@ -19,9 +19,10 @@
       ],
       "parameters": {
         "metrics": {
-          "type": "boolean",
-          "default": true,
-          "description": "Enable/disable metrics and green/red"
+          "type": "string",
+          "enum": ["", "cer", "wer", "fca", "cer,wer", "cer,fca", "wer,fca", "cer,wer,fca"],
+          "default": "cer,wer",
+          "description": "Enable different metrics like cer, wer and fca."
         },
         "textequiv_level": {
           "type": "string",

@@ -40,16 +40,31 @@
 
 {% if metrics %}
 <h2>Metrics</h2>
-<p>CER: {{ cer|round(4) }}</p>
-<p>WER: {{ wer|round(4) }}</p>
+    {% if cer %}
+    <p>CER: {{ cer|round(4) }}</p>
+    {% endif %}
+    {% if wer %}
+    <p>WER: {{ wer|round(4) }}</p>
+    {% endif %}
+    {% if fca %}
+    <p>FCA: {{ fca|round(4) }}</p>
+    {% endif %}
 {% endif %}
 
+{% if char_diff_report %}
 <h2>Character differences</h2>
 {{ char_diff_report }}
+{% endif %}
 
+{% if word_diff_report %}
 <h2>Word differences</h2>
 {{ word_diff_report }}
+{% endif %}
 
+{% if fca_diff_report %}
+<h2>Flexible character accuracy differences</h2>
+{{ fca_diff_report }}
+{% endif %}
 
 </div>
 

@@ -1,10 +1,11 @@
 {
-    "gt": "{{ gt }}",
-    "ocr": "{{ ocr }}",
 {% if metrics %}
-    "cer": {{ cer|json_float }},
-    "wer": {{ wer|json_float }},
+    {% if cer %}"cer": {{ cer|json_float }},{% endif %}
+    {% if wer %}"wer": {{ wer|json_float }},{% endif %}
+    {% if fca %}"fca": {{ fca|json_float }},{% endif %}
+    {% if n_characters %}"n_characters": {{ n_characters }},{% endif %}
+    {% if n_words %}"n_words": {{ n_words }},{% endif %}
 {% endif %}
-    "n_characters": {{ n_characters }},
-    "n_words": {{ n_words }}
+    "gt": "{{ gt }}",
+    "ocr": "{{ ocr }}"
 }

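Because the JSON template now emits each metric key only when the corresponding value is set, consumers should treat those keys as optional. A small sketch of reading a report produced with the default report prefix:

    import json

    # Read the report written by process()/the CLI; keys for metrics that were
    # not requested are simply absent.
    with open("report.json") as f:
        report = json.load(f)

    for metric in ("cer", "wer", "fca"):
        if metric in report:
            print(f"{metric}: {report[metric]}")
        else:
            print(f"{metric}: not computed")
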
@@ -117,13 +117,13 @@ def test_flexible_character_accuracy_simple(gt, ocr, first_line_score, all_line_
         ),
         (
             "Config II",
-            '1 hav\nnospecial\ntalents. Alberto\n'
+            "1 hav\nnospecial\ntalents. Alberto\n"
             'I am one Emstein\npassionate\ncuriousity."',
         ),
         (
             "Config III",
-            'Alberto\nEmstein\n'
-            '1 hav\nnospecial\ntalents.\n'
+            "Alberto\nEmstein\n"
+            "1 hav\nnospecial\ntalents.\n"
             'I am one\npassionate\ncuriousity."',
         ),
     ],
@@ -323,6 +323,8 @@ def test_character_accuracy_matches(matches, expected_dist):
         (Distance(), 1),
         (Distance(match=1), 1),
         (Distance(replace=1), 0),
+        (Distance(delete=1), 0),
+        (Distance(insert=1), -1),
         (Distance(match=1, insert=1), 0),
         (Distance(match=1, insert=2), 1 - 2 / 1),
         (Distance(match=2, insert=1), 0.5),
@@ -377,9 +379,42 @@ def test_initialize_lines():
     assert lines == [line3, line1, line2]
 
 
-@pytest.mark.xfail
-def test_combine_lines():
-    assert False
+@pytest.mark.parametrize(
+    "matches,expected_gt,expected_ocr,expected_ops",
+    [
+        ([], [], [], []),
+        (
+            [Match(gt=Part(text="aaa"), ocr=Part(text="aaa"), dist=Distance(), ops=[])],
+            ["aaa"],
+            ["aaa"],
+            [[]],
+        ),
+        (
+            [
+                Match(
+                    gt=Part(text="aaa", line=1),
+                    ocr=Part(text="aaa"),
+                    dist=Distance(),
+                    ops=[],
+                ),
+                Match(
+                    gt=Part(text="bbb", line=2),
+                    ocr=Part(text="bbc"),
+                    dist=Distance(),
+                    ops=[["replace", 2]],
+                ),
+            ],
+            ["\n", "aaa", "\n", "bbb"],
+            ["\n", "aaa", "\n", "bbc"],
+            [[], [], [], [["replace", 2]]],
+        ),
+    ],
+)
+def test_split_matches(matches, expected_gt, expected_ocr, expected_ops):
+    gt_segments, ocr_segments, ops = split_matches(matches)
+    assert gt_segments == expected_gt
+    assert ocr_segments == expected_ocr
+    assert ops == expected_ops
 
 
 @pytest.mark.parametrize(

@@ -1,4 +1,5 @@
 import json
+from itertools import combinations
 
 import pytest
 from .util import working_directory
@@ -7,9 +8,19 @@ from ..cli import process
 
 
 @pytest.mark.integration
-def test_cli_json(tmp_path):
+@pytest.mark.parametrize(
+    "metrics",
+    [
+        *(("",), ("cer",), ("wer",), ("fca",)),
+        *combinations(("cer", "wer", "fca"), 2),
+        ("cer", "wer", "fca"),
+    ],
+)
+def test_cli_json(metrics, tmp_path):
     """Test that the cli/process() yields a loadable JSON report"""
 
+    expected_values = {"cer": 0.2, "wer": 1.0, "fca": 0.8}
+
     with working_directory(str(tmp_path)):
         with open("gt.txt", "w") as gtf:
             gtf.write("AAAAA")
@@ -18,12 +29,18 @@ def test_cli_json(tmp_path):
 
         with open("gt.txt", "r") as gtf:
             print(gtf.read())
-        process("gt.txt", "ocr.txt", "report")
+
+        process("gt.txt", "ocr.txt", "report", metrics=",".join(metrics))
 
         with open("report.json", "r") as jsonf:
             print(jsonf.read())
         with open("report.json", "r") as jsonf:
             j = json.load(jsonf)
-            assert j["cer"] == pytest.approx(0.2)
+            for metric, expected_value in expected_values.items():
+                if metric in metrics:
+                    assert j[metric] == pytest.approx(expected_values[metric])
+                else:
+                    assert metric not in j.keys()
 
+
 @pytest.mark.integration
@@ -36,7 +53,8 @@ def test_cli_json_cer_is_infinity(tmp_path):
         with open("ocr.txt", "w") as ocrf:
             ocrf.write("Not important")
 
-        process("gt.txt", "ocr.txt", "report")
+        process("gt.txt", "ocr.txt", "report", metrics="cer,wer,fca")
         with open("report.json", "r") as jsonf:
             j = json.load(jsonf)
             assert j["cer"] == pytest.approx(float("inf"))
+            assert j["fca"] == pytest.approx(-13)

@@ -0,0 +1,50 @@
+import os
+
+import pytest
+from lxml import etree as ET
+
+from .. import distance, page_text
+from .. import flexible_character_accuracy, split_matches
+
+data_dir = os.path.join(
+    os.path.dirname(os.path.abspath(__file__)), "data", "table-order"
+)
+
+
+@pytest.mark.parametrize("file", ["table-order-0002.xml", "table-no-reading-order.xml"])
+@pytest.mark.integration
+def test_fac_ignoring_reading_order(file):
+    expected = "1\n2\n3\n4\n5\n6\n7\n8\n9"
+
+    gt = page_text(ET.parse(os.path.join(data_dir, "table-order-0001.xml")))
+    assert gt == expected
+
+    ocr = page_text(ET.parse(os.path.join(data_dir, file)))
+    assert distance(gt, ocr) > 0
+
+    fac, matches = flexible_character_accuracy(gt, ocr)
+    assert fac == pytest.approx(1.0)
+
+    gt_segments, ocr_segments, ops = split_matches(matches)
+    assert not any(ops)
+    assert "".join(gt_segments) == expected
+    assert "".join(ocr_segments) == expected
+
+
+@pytest.mark.parametrize(
+    "file,expected_text",
+    [
+        ("table-order-0001.xml", "1\n2\n3\n4\n5\n6\n7\n8\n9"),
+        ("table-order-0002.xml", "1\n4\n7\n2\n5\n8\n3\n6\n9"),
+        ("table-no-reading-order.xml", "5\n6\n7\n8\n9\n1\n2\n3\n4"),
+        ("table-unordered.xml", "5\n6\n7\n8\n9\n1\n2\n3\n4"),
+    ],
+)
+@pytest.mark.integration
+def test_reading_order_settings(file, expected_text):
+    if "table-unordered.xml" == file:
+        with pytest.raises(NotImplementedError):
+            page_text(ET.parse(os.path.join(data_dir, file)))
+    else:
+        ocr = page_text(ET.parse(os.path.join(data_dir, file)))
+        assert ocr == expected_text