mirror of
				https://github.com/qurator-spk/dinglehopper.git
				synced 2025-10-31 17:34:15 +01:00 
			
		
		
		
	First draft of flexible character accuracy
This commit is contained in:
		
							parent
							
								
									bd324331e6
								
							
						
					
					
						commit
						d7a74fa58b
					
				
					 3 changed files with 685 additions and 4 deletions
				
			
		
							
								
								
									
										394
									
								
								qurator/dinglehopper/flexible_character_accuracy.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										394
									
								
								qurator/dinglehopper/flexible_character_accuracy.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,394 @@ | |||
| """ | ||||
| Implementation of the flexible character accuracy | ||||
| 
 | ||||
| Citation: | ||||
|     Flexible character accuracy measure for reading-order-independent evaluation | ||||
|     C. Clausner, S. Pletschacher, A. Antonacopoulos | ||||
|     Pattern Recognition Letters, Volume 131, March 2020, Pages 390-397 | ||||
| Link: http://www.primaresearch.org/publications/PRL_Clausner_FlexibleCharacterAccuracy | ||||
| DOI: https://doi.org/10.1016/j.patrec.2020.02.003 | ||||
| 
 | ||||
| Note that we deviated from the original algorithm at some places. | ||||
| """ | ||||
| 
 | ||||
| from collections import Counter | ||||
| from functools import lru_cache, reduce | ||||
| from itertools import product, takewhile | ||||
| from typing import List, NamedTuple, Tuple, Optional | ||||
| 
 | ||||
| from . import editops | ||||
| 
 | ||||
| 
 | ||||
| def flexible_character_accuracy(gt: str, ocr: str) -> Tuple[float, List["Match"]]: | ||||
|     """Calculate the flexible character accuracy. | ||||
| 
 | ||||
|     Reference: contains steps 1-7 of the flexible character accuracy algorithm. | ||||
| 
 | ||||
|     :param gt: The ground truth text. | ||||
|     :param ocr: The text to compare the ground truth with. | ||||
|     :return: Score between 0 and 1 and match objects. | ||||
|     """ | ||||
| 
 | ||||
|     best_score = -float('inf') | ||||
|     best_matches = [] | ||||
|     # TODO: this should be configurable | ||||
|     combinations = product(range(15, 31, 5), | ||||
|                            range(0, 24, 3), | ||||
|                            range(0, 4, 1), | ||||
|                            range(0, 6, 1)) | ||||
|     # TODO: place to parallelize the algorithm | ||||
|     for (edit_dist, length_diff, offset, length) in combinations: | ||||
|         coef = Coefficients( | ||||
|             edit_dist=edit_dist, | ||||
|             length_diff=length_diff, | ||||
|             offset=offset, | ||||
|             length=length | ||||
|         ) | ||||
|         # Steps 1 - 6 of the flexible character accuracy algorithm. | ||||
|         matches = match_with_coefficients(gt, ocr, coef) | ||||
|         # Step 7 of the flexible character accuracy algorithm. | ||||
|         score = character_accuracy_for_matches(matches) | ||||
|         if score > best_score: | ||||
|             best_score = score | ||||
|             best_matches = matches | ||||
|         # early breaking: we only need one perfect fit | ||||
|         if best_score >= 1: | ||||
|             break | ||||
|     return best_score, best_matches | ||||
| 
 | ||||
| 
 | ||||
| def match_with_coefficients(gt: str, ocr: str, coef: "Coefficients") -> List["Match"]: | ||||
|     """Match ground truth with ocr and considers a given set of coefficients. | ||||
| 
 | ||||
|     Reference: contains steps 1 - 6 of the flexible character accuracy algorithm. | ||||
| 
 | ||||
|     :return: A list of match objects to score and align the texts. | ||||
|     """ | ||||
|     # Steps 1 and 2 of the flexible character accuracy algorithm. | ||||
|     ocr_lines = initialize_lines(ocr) | ||||
|     gt_lines = initialize_lines(gt) | ||||
| 
 | ||||
|     matches = [] | ||||
| 
 | ||||
|     # Step 5 of the flexible character accuracy algorithm. | ||||
|     while len(gt_lines) != 0 and len(ocr_lines) != 0: | ||||
|         # Steps 3 and 4 of the flexible character accuracy algorithm. | ||||
|         match = match_longest_gt_lines(gt_lines, ocr_lines, coef) | ||||
|         if match: | ||||
|             matches.append(match) | ||||
| 
 | ||||
|     # Step 6 of the flexible character accuracy algorithm. | ||||
|     # remaining lines are considered as deletes and inserts | ||||
|     deletes = [distance(line, Part(text="", line=line.line, start=line.start)) | ||||
|                for line in gt_lines] | ||||
|     inserts = [distance(Part(text="", line=line.line, start=line.start), line) | ||||
|                for line in ocr_lines] | ||||
| 
 | ||||
|     return [*matches, *deletes, *inserts] | ||||
| 
 | ||||
| 
 | ||||
| def match_longest_gt_lines(gt_lines: List["Part"], | ||||
|                            ocr_lines: List["Part"], | ||||
|                            coef: "Coefficients") -> Optional["Match"]: | ||||
|     """Find the best match for the longest line(s) in ground truth. | ||||
| 
 | ||||
|     The longest lines in ground truth are matched against lines in ocr to find the | ||||
|     best matching pair. This pair is then either considered a match on full line | ||||
| 
 | ||||
|     Reference: contains steps 3 and 4 of the flexible character accuracy algorithm. | ||||
| 
 | ||||
|     :return: Possible match object. | ||||
|     """ | ||||
|     best_score, best_match, best_gt, best_ocr = -float('inf'), None, None, None | ||||
|     if not ocr_lines: | ||||
|         return best_match | ||||
| 
 | ||||
|     # Step 3 of the flexible character accuracy algorithm (variation). | ||||
|     # Instead of the longest line we take all longest lines with equal length. | ||||
|     length = min(gt_lines[0].length, ocr_lines[0].length) | ||||
|     for gt_line in takewhile(lambda line: line.length >= length, gt_lines): | ||||
|         match, ocr_line = match_gt_line(gt_line, ocr_lines, coef) | ||||
|         score = 0 if not match else character_accuracy(match.dist) | ||||
|         if score > best_score: | ||||
|             best_score, best_match, best_gt, best_ocr = score, match, gt_line, ocr_line | ||||
| 
 | ||||
|     # Step 4 of the flexible character accuracy algorithm. | ||||
|     # Remove on full match or split. | ||||
|     if best_match and best_gt: | ||||
|         splitted = remove_or_split(best_gt, best_match.gt, gt_lines) | ||||
|         if splitted: | ||||
|             gt_lines.append(best_match.gt) | ||||
|             best_match = None | ||||
|     if best_match and best_ocr: | ||||
|         remove_or_split(best_ocr, best_match.ocr, ocr_lines) | ||||
| 
 | ||||
|     return best_match | ||||
| 
 | ||||
| 
 | ||||
| def match_gt_line(gt_line: "Part", | ||||
|                   ocr_lines: List["Part"], | ||||
|                   coef: "Coefficients") -> Tuple[Optional["Match"], | ||||
|                                                  Optional["Part"]]: | ||||
|     """Match the given ground truth line against all the lines in ocr. | ||||
| 
 | ||||
|     Reference: contains steps 3 of the flexible character accuracy algorithm. | ||||
| 
 | ||||
|     TODO: Make penalty function configurable? | ||||
|     TODO: Add empty ocr line to avoid having nonesense one character alignments? | ||||
| 
 | ||||
|     :return: Match object and the matched ocr line. | ||||
|     """ | ||||
|     min_penalty = float('inf') | ||||
|     best_match, best_ocr = None, None | ||||
|     for ocr_line in ocr_lines: | ||||
|         match = match_lines(gt_line, ocr_line) | ||||
|         penalty = calculate_penalty(gt_line, ocr_line, match, coef) | ||||
|         if penalty < min_penalty: | ||||
|             min_penalty, best_match, best_ocr = penalty, match, ocr_line | ||||
|     return best_match, best_ocr | ||||
| 
 | ||||
| 
 | ||||
| def remove_or_split(original: "Part", | ||||
|                     match: "Part", | ||||
|                     lines: List["Part"]) -> bool: | ||||
|     """Removes the matched line or splits it into parts. | ||||
| 
 | ||||
|     Reference: contains step 4 of the flexible character accuracy algorithm. | ||||
| 
 | ||||
|     :return: True if line was splitted. | ||||
|     """ | ||||
|     splitted = False | ||||
|     del lines[lines.index(original)] | ||||
|     if match.length < original.length: | ||||
|         lines.extend(original.split(match)) | ||||
|         # sorting for ocr is not mentioned in the paper, but is used as tie breaking =) | ||||
|         lines.sort(key=lambda x: x.length, reverse=True) | ||||
|         splitted = True | ||||
|     return splitted | ||||
| 
 | ||||
| 
 | ||||
| @lru_cache(maxsize=1000000) | ||||
| def match_lines(gt_line: "Part", ocr_line: "Part") -> Optional["Match"]: | ||||
|     """Matches two lines searching for a local alignment. | ||||
| 
 | ||||
|     The shorter line is moved along the longer line | ||||
|     until the editing distance is minimized. | ||||
| 
 | ||||
|     Reference: see figure 2 in the paper. | ||||
| 
 | ||||
|     TODO: make distance function configurable? | ||||
| 
 | ||||
|     :return: Match object if one is found. | ||||
|     """ | ||||
|     min_length = min(gt_line.length, ocr_line.length) | ||||
|     best_match = None | ||||
|     if min_length == 0: | ||||
|         return best_match | ||||
|     length_diff = gt_line.length - ocr_line.length | ||||
|     min_edit_dist = float('inf') | ||||
|     # TODO: handle deletes and replacements by extending the length. | ||||
|     for i in range(0, max(1, length_diff + 1)): | ||||
|         for j in range(0, max(1, -1 * length_diff + 1)): | ||||
|             match = distance(gt_line.substring(rel_start=i, rel_end=i + min_length), | ||||
|                              ocr_line.substring(rel_start=j, rel_end=j + min_length)) | ||||
|             edit_dist = score_edit_distance(match) | ||||
|             if edit_dist < min_edit_dist: | ||||
|                 min_edit_dist = edit_dist | ||||
|                 best_match = match | ||||
|     return best_match | ||||
| 
 | ||||
| 
 | ||||
| @lru_cache(maxsize=1000000) | ||||
| def distance(gt: "Part", ocr: "Part") -> "Match": | ||||
|     """Calculate the editing distance between the two lines. | ||||
| 
 | ||||
|     Using the already available `editops()` function with the Levenshtein distance. | ||||
| 
 | ||||
|     TODO: replace with @cache annotation in Python 3.9 | ||||
| 
 | ||||
|     :return: Match object containing the lines and the editing operations. | ||||
|     """ | ||||
|     ops = editops(gt.text, ocr.text) | ||||
|     edits = Counter([edit[0] for edit in ops]) | ||||
|     edits["match"] = gt.length - edits["delete"] - edits["replace"] | ||||
|     return Match(gt=gt, ocr=ocr, dist=Distance(**edits), ops=ops) | ||||
| 
 | ||||
| 
 | ||||
| def score_edit_distance(match: "Match") -> int: | ||||
|     """Calculate edit distance for a match. | ||||
| 
 | ||||
|     Formula: $deletes + inserts + 2 * replacements$ | ||||
| 
 | ||||
|     :return: Sum of deletes, inserts and replacements. | ||||
|     """ | ||||
|     return match.dist.delete + match.dist.insert + 2 * match.dist.replace | ||||
| 
 | ||||
| 
 | ||||
| def calculate_penalty(gt: "Part", ocr: "Part", match: "Match", | ||||
|                       coef: "Coefficients") -> float: | ||||
|     """Calculate the penalty for a given match. | ||||
| 
 | ||||
|     For details and discussion see Section 3 in doi:10.1016/j.patrec.2020.02.003. | ||||
| 
 | ||||
|     :return: Penalty for the given match. | ||||
|     """ | ||||
|     min_edit_dist = score_edit_distance(match) | ||||
|     length_diff = abs(gt.length - ocr.length) | ||||
|     substring_length = min(gt.length, ocr.length) | ||||
|     offset = 0.0 | ||||
|     if length_diff > 1: | ||||
|         substring_pos = max(match.gt.start - gt.start, match.ocr.start - ocr.start) | ||||
|         offset = length_diff / 2 - abs(substring_pos - length_diff / 2) | ||||
|     return (min_edit_dist * coef.edit_dist | ||||
|             + length_diff * coef.length_diff | ||||
|             + offset * coef.offset | ||||
|             - substring_length * coef.length) | ||||
| 
 | ||||
| 
 | ||||
| def character_accuracy_for_matches(matches: List["Match"]) -> float: | ||||
|     """Character accuracy of a full text represented by a list of matches. | ||||
| 
 | ||||
|     See other `character_accuracy` for details. | ||||
| 
 | ||||
|     """ | ||||
|     agg: Counter = reduce(lambda acc, match: acc + Counter(match.dist._asdict()), | ||||
|                           matches, Counter()) | ||||
| 
 | ||||
|     score = character_accuracy(Distance(**agg)) | ||||
|     return score | ||||
| 
 | ||||
| 
 | ||||
| def character_accuracy(edits: "Distance") -> float: | ||||
|     """Character accuracy calculated by necessary edit operations. | ||||
| 
 | ||||
|     Edit operations are needed edits to transform one text into another. | ||||
| 
 | ||||
|     The character accuracy is given by $1 - errors / characters$. | ||||
| 
 | ||||
|     Errors are replacements, deletes and inserts. | ||||
| 
 | ||||
|     Note that is is possible to have more errors than characters in which case the | ||||
|     character accuracy turns negative. | ||||
| 
 | ||||
|     Comparing two empty strings (having no edits) results in a character accuracy of 1. | ||||
|     """ | ||||
|     errors = edits.replace + edits.delete + edits.insert | ||||
|     chars = edits.match + edits.replace + edits.delete | ||||
|     if not chars and not errors: | ||||
|         # comparison of empty strings is considered a full match | ||||
|         score = 1 | ||||
|     else: | ||||
|         score = 1 - errors / chars | ||||
|     return score | ||||
| 
 | ||||
| 
 | ||||
| def initialize_lines(text: str) -> List["Part"]: | ||||
|     """Splits a text into lines and converts them to our line data object. | ||||
| 
 | ||||
|     The line objects are sorted by their length descending. | ||||
| 
 | ||||
|     Reference: contains steps 1 and 2 of the flexible character accuracy algorithm. | ||||
| 
 | ||||
|     :param text: Text to split into lines. | ||||
|     :return: List of sorted line objects. | ||||
|     """ | ||||
|     lines = [Part(text=line, line=i, start=0) | ||||
|              for i, line in enumerate(text.splitlines()) | ||||
|              if len(line) > 0] | ||||
|     lines.sort(key=lambda x: x.length, reverse=True) | ||||
|     return lines | ||||
| 
 | ||||
| 
 | ||||
| def combine_lines(matches: List["Match"]) -> Tuple[str, str]: | ||||
|     """Combines the matches to aligned texts. | ||||
| 
 | ||||
|     TODO: just hacked, needs tests and refinement. Also missing insert/delete marking. | ||||
| 
 | ||||
|     :param matches: List of match objects. | ||||
|     :return: the aligned ground truth and ocr as texts. | ||||
|     """ | ||||
|     matches.sort(key=lambda x: x.gt.line + x.gt.start / 10000) | ||||
|     line = 0 | ||||
|     gt, ocr = "", "" | ||||
|     for match in matches: | ||||
|         if match.gt.line > line: | ||||
|             gt += "\n" | ||||
|             ocr += "\n" | ||||
|             line += 1 | ||||
|         gt += match.gt.text | ||||
|         ocr += match.ocr.text | ||||
|     return gt, ocr | ||||
| 
 | ||||
| 
 | ||||
| class Part(NamedTuple): | ||||
|     """Represent a line or part of a line. | ||||
| 
 | ||||
|     This data object is maintained to be able to reproduce the original text. | ||||
|     """ | ||||
|     text: str = "" | ||||
|     line: int = 0 | ||||
|     start: int = 0 | ||||
| 
 | ||||
|     @property | ||||
|     def end(self) -> int: | ||||
|         return self.start + self.length | ||||
| 
 | ||||
|     @property | ||||
|     def length(self) -> int: | ||||
|         return len(self.text) | ||||
| 
 | ||||
|     def split(self, split: "Part") -> List["Part"]: | ||||
|         """Split the line part by another and returns the remaining parts. | ||||
| 
 | ||||
|         `abc.split("b")` will return ´["a", "c"]`. | ||||
| 
 | ||||
|         :param split: The line part we want to use to split. | ||||
|         :return: The parts before and after the split. | ||||
|         """ | ||||
|         rest = [] | ||||
|         if self.start < split.start: | ||||
|             rest.append(self.substring(rel_end=split.start - self.start)) | ||||
|         if split.end < self.end: | ||||
|             rest.append(self.substring(rel_start=split.end - self.start)) | ||||
|         return rest | ||||
| 
 | ||||
|     def substring(self, rel_start: int = 0, rel_end: int = None) -> "Part": | ||||
|         """Get part of the given line. | ||||
| 
 | ||||
|         Automatically handles the offset of the line. | ||||
|         Therefore `substring(rel_start=2)` will return `Part[start+rel_start:]`. | ||||
| 
 | ||||
|         :param rel_start: start relative to the part of the line. | ||||
|         :param rel_end: end relative to the part of the line. | ||||
|         :return: Extracted part of the given part of the line. | ||||
|         """ | ||||
|         text = self.text[rel_start:rel_end] | ||||
|         start = self.start + rel_start | ||||
|         return Part(text=text, line=self.line, start=start) | ||||
| 
 | ||||
| 
 | ||||
| class Distance(NamedTuple): | ||||
|     """Represent distance between two sequences.""" | ||||
|     match: int = 0 | ||||
|     replace: int = 0 | ||||
|     delete: int = 0 | ||||
|     insert: int = 0 | ||||
| 
 | ||||
| 
 | ||||
| class Match(NamedTuple): | ||||
|     """Represent a calculated match between ground truth and the ocr result.""" | ||||
|     gt: "Part" | ||||
|     ocr: "Part" | ||||
|     dist: "Distance" | ||||
|     ops: List | ||||
| 
 | ||||
| 
 | ||||
| class Coefficients(NamedTuple): | ||||
|     """Coefficients to calculate penalty for substrings. | ||||
| 
 | ||||
|     See Section 3 in doi:10.1016/j.patrec.2020.02.003 | ||||
|     """ | ||||
|     edit_dist: int = 25 | ||||
|     length_diff: int = 20 | ||||
|     offset: int = 1 | ||||
|     length: int = 4 | ||||
							
								
								
									
										291
									
								
								qurator/dinglehopper/tests/test_flexible_character_accuracy.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										291
									
								
								qurator/dinglehopper/tests/test_flexible_character_accuracy.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,291 @@ | |||
| """ | ||||
| Tests for the implementation of the flexible character accuracy | ||||
| 
 | ||||
| Citation: | ||||
|     Flexible character accuracy measure for reading-order-independent evaluation | ||||
|     C. Clausner, S. Pletschacher, A. Antonacopoulos | ||||
|     Pattern Recognition Letters, Volume 131, March 2020, Pages 390-397 | ||||
| Link: http://www.primaresearch.org/publications/PRL_Clausner_FlexibleCharacterAccuracy | ||||
| DOI: 10.1016/j.patrec.2020.02.003 | ||||
| """ | ||||
| 
 | ||||
| import pytest | ||||
| 
 | ||||
| from ..flexible_character_accuracy import * | ||||
| 
 | ||||
| CASE_ARGS = "gt,ocr,first_line_score,all_line_score" | ||||
| 
 | ||||
| SIMPLE_CASES = [ | ||||
|     ("a", "", 0, 0), | ||||
|     ("a", "a", 1, 1), | ||||
|     ("a\nb", "a\nb", 1, 1), | ||||
|     ("a\nb", "b\na", 1, 1), | ||||
|     ("aaa\nbbb\nccc", "ccc\naaa\nbbb", 1, 1), | ||||
|     ("aaa\nbbb\nccc", "aaa\nbbb", 1, 1 - 3 / 9), | ||||
|     ("bbb", "aaa\nbbb\nccc", 1, 1 - 6 / 3), | ||||
|     ("a", "a\nbb\nccc", 1, 1 - 5 / 1), | ||||
|     ("bb", "a\nbb\nccc", 1, 1 - 4 / 2), | ||||
| ] | ||||
| 
 | ||||
| COMPLEX_CASES = [ | ||||
|     ("accc", "a\nbb\nccc", 0, 1 - 2 / 4), | ||||
|     ("aaa\nbbb\nccc", "bbb", 1, 1 - 6 / 9), | ||||
| ] | ||||
| 
 | ||||
| EXTENDED_CASES = [ | ||||
|     # A: No errors | ||||
|     ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), | ||||
|      (0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), | ||||
|      1, 1), | ||||
|     # B: Different ordering of text blocks | ||||
|     ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), | ||||
|      (5, 6, 7, 8, 9, 11, 0, 1, 2, 3, 4), | ||||
|      1, 1), | ||||
|     # C: Merge across columns | ||||
|     ((0, 1, 2, 11, 3, 4, 11, 5, 6, 7, 11, 8, 9), | ||||
|      (0, 1, 2, 5, 6, 7, 11, 3, 4, 8, 9), | ||||
|      1, 0.964), | ||||
|     # D: Over-segmentation | ||||
|     ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), | ||||
|      (0, 1, 2, 11, 5, 6, 7, 11, 3, 4, 11, 8, 9), | ||||
|      1, 0.966), | ||||
|     # E: Part missing | ||||
|     ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), | ||||
|      (0, 1, 2, 3, 4), | ||||
|      1, 0.50), | ||||
|     # E.2: Part missing | ||||
|     ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), | ||||
|      (5, 6, 7, 8, 9), | ||||
|      1, 0.50), | ||||
|     # F: All missing | ||||
|     ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), | ||||
|      (), | ||||
|      1, 0), | ||||
|     # G: Added parts | ||||
|     ((0, 1, 2, 3, 4), | ||||
|      (0, 1, 2, 3, 4, 11, 5, 6), | ||||
|      1, 0.621), | ||||
| ] | ||||
| 
 | ||||
| EDIT_ARGS = "gt,ocr,expected_dist" | ||||
| 
 | ||||
| SIMPLE_EDITS = [ | ||||
|     (Part(text="a").substring(), Part(text="a"), Distance(match=1)), | ||||
|     (Part(text="a").substring(), Part(text="b"), Distance(replace=1)), | ||||
|     (Part(text="abcd").substring(), Part(text="beed"), | ||||
|      Distance(match=2, replace=1, insert=1, delete=1)), | ||||
| ] | ||||
| 
 | ||||
| 
 | ||||
| def extended_case_to_text(gt, ocr): | ||||
|     sentence = ("Eight", "happy", "frogs", "scuba", "dived", | ||||
|                 "Jenny", "chick", "flaps", "white", "wings", | ||||
|                 "", "\n") | ||||
| 
 | ||||
|     gt_sentence = " ".join(sentence[i] for i in gt).replace(" \n ", "\n") | ||||
|     ocr_sentence = " ".join(sentence[i] for i in ocr).replace(" \n ", "\n") | ||||
|     return gt_sentence, ocr_sentence | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize(CASE_ARGS, [*SIMPLE_CASES, *COMPLEX_CASES]) | ||||
| def test_flexible_character_accuracy_simple(gt, ocr, first_line_score, all_line_score): | ||||
|     score, _ = flexible_character_accuracy(gt, ocr) | ||||
|     assert score == pytest.approx(all_line_score) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.xfail | ||||
| @pytest.mark.parametrize("ocr", [ | ||||
|     "1 hav\nnospecial\ntalents.\nI am one\npassionate\ncuriousity.\"\nAlberto\nEmstein", | ||||
|     "1 hav\nnospecial\ntalents. Alberto\nI am one Emstein\npassionate\ncuriousity.\"", | ||||
|     "Alberto\nEmstein\n1 hav\nnospecial\ntalents.\nI am one\npassionate\ncuriousity.\"" | ||||
| ]) | ||||
| def test_flexible_character_accuracy(ocr): | ||||
|     """Tests from figure 3 in the paper. | ||||
| 
 | ||||
|     TODO: We have a 2 percent deviation from the original because of redistributed | ||||
|           one character alignments (e.g. the y-insert replaces the y-delete). | ||||
|     """ | ||||
|     gt = """"I have | ||||
| no special | ||||
| talent. | ||||
| I am only | ||||
| passionately | ||||
| curious." | ||||
| Albert | ||||
| Einstein | ||||
| """ | ||||
|     replacements = 3 | ||||
|     inserts = 5 | ||||
|     deletes = 7 | ||||
|     chars = len(gt) - gt.count("\n") | ||||
|     assert replacements + inserts + deletes == 15 | ||||
|     edits = Distance(match=chars - deletes - replacements, replace=replacements, | ||||
|                      insert=inserts, delete=deletes) | ||||
|     expected = character_accuracy(edits) | ||||
|     assert expected == pytest.approx(0.779, abs=0.0005) | ||||
|     result, matches = flexible_character_accuracy(gt, ocr) | ||||
|     assert result == pytest.approx(expected, abs=0.0005) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize(CASE_ARGS, EXTENDED_CASES) | ||||
| def test_flexible_character_accuracy_extended(gt, ocr, first_line_score, | ||||
|                                               all_line_score): | ||||
|     """Tests from figure 4 in the paper.""" | ||||
|     gt_sentence, ocr_sentence = extended_case_to_text(gt, ocr) | ||||
|     result, _ = flexible_character_accuracy(gt_sentence, ocr_sentence) | ||||
|     assert result == pytest.approx(all_line_score, abs=0.001) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize(CASE_ARGS, [*SIMPLE_CASES, *COMPLEX_CASES, *EXTENDED_CASES]) | ||||
| def test_match_with_coefficients(gt, ocr, first_line_score, all_line_score): | ||||
|     coef = Coefficients() | ||||
|     if not isinstance(gt, str): | ||||
|         gt, ocr = extended_case_to_text(gt, ocr) | ||||
|     matches = match_with_coefficients(gt, ocr, coef) | ||||
|     score = character_accuracy_for_matches(matches) | ||||
|     assert score == pytest.approx(all_line_score, abs=0.001) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize(CASE_ARGS, [*SIMPLE_CASES, *COMPLEX_CASES]) | ||||
| def test_match_longest_gt_lines(gt, ocr, first_line_score, all_line_score): | ||||
|     coef = Coefficients() | ||||
|     gt_lines = initialize_lines(gt) | ||||
|     ocr_lines = initialize_lines(ocr) | ||||
|     match = match_longest_gt_lines(gt_lines, ocr_lines, coef) | ||||
|     score = 0 | ||||
|     if match: | ||||
|         score = character_accuracy(match.dist) | ||||
|     assert score == pytest.approx(first_line_score) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize(CASE_ARGS, [ | ||||
|     *SIMPLE_CASES, | ||||
|     ("accc", "a\nbb\nccc", 1.0, 1.0), | ||||
| ]) | ||||
| def test_match_gt_line(gt, ocr, first_line_score, all_line_score): | ||||
|     coef = Coefficients() | ||||
|     gt_lines = initialize_lines(gt) | ||||
|     ocr_lines = initialize_lines(ocr) | ||||
|     match, _ = match_gt_line(gt_lines[0], ocr_lines, coef) | ||||
|     score = 0 | ||||
|     if match: | ||||
|         score = character_accuracy(match.dist) | ||||
|     assert score == pytest.approx(first_line_score) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize("original,match,expected_lines", [ | ||||
|     (Part(), Part(), []), | ||||
|     (Part(text="abc"), Part(), [Part(text="abc")]), | ||||
|     (Part(text="abc"), Part("d"), [Part(text="bc", start=1)]), | ||||
|     (Part(text="abc"), Part("a", start=100), [Part(text="abc")]), | ||||
|     (Part(text="abc"), Part("a"), [Part(text="bc", start=1)]), | ||||
|     (Part(text="abc"), Part("b", start=1), [Part(text="a"), Part(text="c", start=2)]), | ||||
|     (Part(text="abc"), Part("c", start=2), [Part(text="ab")]), | ||||
| ]) | ||||
| def test_remove_or_split(original, match, expected_lines): | ||||
|     lines = [original] | ||||
|     splitted = remove_or_split(original, match, lines) | ||||
|     assert splitted == (len(lines) > 0) | ||||
|     assert lines == expected_lines | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize(EDIT_ARGS, [ | ||||
|     *SIMPLE_EDITS, | ||||
|     (Part(text="aaabbbaaa"), Part(text="bbb"), Distance(match=3)), | ||||
|     (Part(text="bbb"), Part(text="aaabbbaaa"), Distance(match=3)), | ||||
|     (Part(text=""), Part(text=""), None) | ||||
| ]) | ||||
| def test_match_lines(gt, ocr, expected_dist): | ||||
|     match = match_lines(gt, ocr) | ||||
|     if not expected_dist: | ||||
|         assert match is None | ||||
|     else: | ||||
|         assert match.gt.text in gt.text | ||||
|         assert match.ocr.text in ocr.text | ||||
|         assert match.dist == expected_dist | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize(EDIT_ARGS, [ | ||||
|     *SIMPLE_EDITS, | ||||
|     (Part(text="").substring(), Part(text=""), Distance()), | ||||
|     (Part(text="ab").substring(), Part("a"), Distance(match=1, delete=1)), | ||||
|     (Part(text="a").substring(), Part("ab"), Distance(match=1, insert=1)), | ||||
| ]) | ||||
| def test_distance(gt, ocr, expected_dist): | ||||
|     match = distance(gt, ocr) | ||||
|     assert match.gt == gt | ||||
|     assert match.ocr == ocr | ||||
|     assert match.dist == expected_dist | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize("matches,expected_dist", [ | ||||
|     ([], 1), | ||||
|     ([Match(gt=Part(text=""), ocr=Part(text=""), dist=Distance(), ops=[])], 1), | ||||
|     ([Match(gt=Part(text="abee"), ocr=Part("ac"), | ||||
|             dist=Distance(match=1, replace=1, delete=2), ops=[]), | ||||
|       Match(gt=Part(text="cd"), ocr=Part("ceff"), | ||||
|             dist=Distance(match=1, replace=1, insert=2), ops=[])], | ||||
|      1 - 6 / 6), | ||||
| ]) | ||||
| def test_character_accuracy_matches(matches, expected_dist): | ||||
|     assert character_accuracy_for_matches(matches) == pytest.approx(expected_dist) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize("dist,expected_dist", [ | ||||
|     (Distance(), 1), | ||||
|     (Distance(match=1), 1), | ||||
|     (Distance(replace=1), 0), | ||||
|     (Distance(match=1, insert=1), 0), | ||||
|     (Distance(match=1, insert=2), 1 - 2 / 1), | ||||
|     (Distance(match=2, insert=1), 0.5), | ||||
|     (Distance(match=1, delete=1), 0.5), | ||||
| ]) | ||||
| def test_character_accuracy_dist(dist, expected_dist): | ||||
|     assert character_accuracy(dist) == pytest.approx(expected_dist) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize("line,subline,expected_rest", [ | ||||
|     (Part(), Part(), []), | ||||
|     (Part("aaa bbb"), Part("aaa bbb"), []), | ||||
|     (Part("aaa bbb"), Part("aaa"), [Part(" bbb", start=3)]), | ||||
|     (Part("aaa bbb"), Part("bbb", start=4), [Part("aaa ")]), | ||||
|     (Part("aaa bbb", start=3), Part("aaa", start=3), [Part(" bbb", start=6)]), | ||||
|     (Part("aaa bbb", start=3), Part("bbb", start=7), [Part("aaa ", start=3)]), | ||||
|     (Part("aaa bbb ccc"), Part("bbb", start=4), [Part("aaa "), Part(" ccc", start=7)]), | ||||
|     (Part("aaa bbb ccc", start=3), Part("bbb", start=7), | ||||
|      [Part("aaa ", start=3), Part(" ccc", start=10)]), | ||||
|     (Part("aaa bbb"), Part(" ", start=3), [Part("aaa"), Part("bbb", start=4)]), | ||||
|     (Part("aaa bbb", start=3), Part(" ", start=6), | ||||
|      [Part("aaa", start=3), Part("bbb", start=7)]), | ||||
| ]) | ||||
| def test_split_line(line, subline, expected_rest): | ||||
|     rest = line.split(subline) | ||||
|     assert len(rest) == len(expected_rest) | ||||
|     assert set(rest) == set(expected_rest) | ||||
| 
 | ||||
| 
 | ||||
| def test_initialize_lines(): | ||||
|     lines = initialize_lines("") | ||||
|     assert lines == [] | ||||
| 
 | ||||
|     lines = initialize_lines("22\n1\n333") | ||||
|     line1 = Part(text="22", line=0, start=0) | ||||
|     line2 = Part("1", line=1, start=0) | ||||
|     line3 = Part("333", line=2, start=0) | ||||
|     assert lines == [line3, line1, line2] | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.xfail | ||||
| def test_combine_lines(): | ||||
|     assert False | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize("line,start,end,expected", [ | ||||
|     (Part(text=""), 0, None, Part(text="")), | ||||
|     (Part(text="a"), 0, None, Part(text="a")), | ||||
|     (Part(text="ab"), 0, 1, Part(text="a")), | ||||
|     (Part(text="abc"), 0, -1, Part(text="ab")), | ||||
|     (Part(text="ab"), 1, None, Part(text="b", start=1)), | ||||
| ]) | ||||
| def test_line_substring(line, start, end, expected): | ||||
|     assert line.substring(rel_start=start, rel_end=end) == expected | ||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue