mirror of
				https://github.com/qurator-spk/dinglehopper.git
				synced 2025-10-31 09:24:15 +01:00 
			
		
		
		
	apply black
This commit is contained in:
		
							parent
							
								
									01571f23b7
								
							
						
					
					
						commit
						22c3817f45
					
				
					 7 changed files with 32 additions and 14 deletions
				
			
		|  | @ -9,7 +9,9 @@ from .extracted_text import ExtractedText | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @multimethod | @multimethod | ||||||
| def character_error_rate_n(reference: list[str], compared: list[str]) -> Tuple[float, int]: | def character_error_rate_n( | ||||||
|  |     reference: list[str], compared: list[str] | ||||||
|  | ) -> Tuple[float, int]: | ||||||
|     """ |     """ | ||||||
|     Compute character error rate. |     Compute character error rate. | ||||||
| 
 | 
 | ||||||
|  | @ -39,7 +41,9 @@ def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]: | ||||||
| def character_error_rate_n( | def character_error_rate_n( | ||||||
|     reference: ExtractedText, compared: ExtractedText |     reference: ExtractedText, compared: ExtractedText | ||||||
| ) -> Tuple[float, int]: | ) -> Tuple[float, int]: | ||||||
|     return character_error_rate_n(reference.grapheme_clusters, compared.grapheme_clusters) |     return character_error_rate_n( | ||||||
|  |         reference.grapheme_clusters, compared.grapheme_clusters | ||||||
|  |     ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def character_error_rate(reference, compared) -> float: | def character_error_rate(reference, compared) -> float: | ||||||
|  |  | ||||||
|  | @ -26,7 +26,7 @@ def common_suffix(its): | ||||||
| 
 | 
 | ||||||
| def removesuffix(text, suffix): | def removesuffix(text, suffix): | ||||||
|     if suffix and text.endswith(suffix): |     if suffix and text.endswith(suffix): | ||||||
|         return text[:-len(suffix)] |         return text[: -len(suffix)] | ||||||
|     return text |     return text | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -46,7 +46,9 @@ def process(gt_dir, ocr_dir, report_prefix, *, metrics=True): | ||||||
|         ocr = removesuffix(gt, gt_suffix) + ocr_suffix |         ocr = removesuffix(gt, gt_suffix) + ocr_suffix | ||||||
| 
 | 
 | ||||||
|         gt_text = plain_extract(os.path.join(gt_dir, gt), include_filename_in_id=True) |         gt_text = plain_extract(os.path.join(gt_dir, gt), include_filename_in_id=True) | ||||||
|         ocr_text = plain_extract(os.path.join(ocr_dir, ocr), include_filename_in_id=True) |         ocr_text = plain_extract( | ||||||
|  |             os.path.join(ocr_dir, ocr), include_filename_in_id=True | ||||||
|  |         ) | ||||||
|         gt_words = words_normalized(gt_text) |         gt_words = words_normalized(gt_text) | ||||||
|         ocr_words = words_normalized(ocr_text) |         ocr_words = words_normalized(ocr_text) | ||||||
| 
 | 
 | ||||||
|  | @ -56,7 +58,9 @@ def process(gt_dir, ocr_dir, report_prefix, *, metrics=True): | ||||||
|             cer, n_characters = l_cer, l_n_characters |             cer, n_characters = l_cer, l_n_characters | ||||||
|         else: |         else: | ||||||
|             # Rolling update |             # Rolling update | ||||||
|             cer = (cer * n_characters + l_cer * l_n_characters) / (n_characters + l_n_characters) |             cer = (cer * n_characters + l_cer * l_n_characters) / ( | ||||||
|  |                 n_characters + l_n_characters | ||||||
|  |             ) | ||||||
|             n_characters = n_characters + l_n_characters |             n_characters = n_characters + l_n_characters | ||||||
| 
 | 
 | ||||||
|         # Compute WER |         # Compute WER | ||||||
|  |  | ||||||
|  | @ -17,6 +17,7 @@ def distance(seq1: list[str], seq2: list[str]): | ||||||
|     """ |     """ | ||||||
|     return Levenshtein.distance(seq1, seq2) |     return Levenshtein.distance(seq1, seq2) | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| @multimethod | @multimethod | ||||||
| def distance(s1: str, s2: str): | def distance(s1: str, s2: str): | ||||||
|     """Compute the Levenshtein edit distance between two Unicode strings |     """Compute the Levenshtein edit distance between two Unicode strings | ||||||
|  |  | ||||||
|  | @ -175,7 +175,7 @@ class ExtractedText: | ||||||
|             return self._grapheme_clusters |             return self._grapheme_clusters | ||||||
|         else: |         else: | ||||||
|             clusters = [] |             clusters = [] | ||||||
|             for seg in  self.segments: |             for seg in self.segments: | ||||||
|                 # todo could there be cases where joiner is no grapheme cluster? |                 # todo could there be cases where joiner is no grapheme cluster? | ||||||
|                 clusters.extend(seg.grapheme_clusters + [self.joiner]) |                 clusters.extend(seg.grapheme_clusters + [self.joiner]) | ||||||
|             return clusters[:-1] |             return clusters[:-1] | ||||||
|  | @ -242,7 +242,9 @@ class ExtractedText: | ||||||
|     def from_str(cls, text, normalization=Normalization.NFC_SBB): |     def from_str(cls, text, normalization=Normalization.NFC_SBB): | ||||||
|         normalized_text = normalize(text, normalization) |         normalized_text = normalize(text, normalization) | ||||||
|         clusters = list(grapheme_clusters(normalized_text)) |         clusters = list(grapheme_clusters(normalized_text)) | ||||||
|         return cls(None, None, None, normalized_text, clusters, normalization=normalization) |         return cls( | ||||||
|  |             None, None, None, normalized_text, clusters, normalization=normalization | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def invert_dict(d): | def invert_dict(d): | ||||||
|  |  | ||||||
|  | @ -98,14 +98,18 @@ def extract_texts_from_reading_order_group(group, tree, nsmap, textequiv_level): | ||||||
| 
 | 
 | ||||||
|         ro_children = filter(lambda child: "index" in child.attrib.keys(), ro_children) |         ro_children = filter(lambda child: "index" in child.attrib.keys(), ro_children) | ||||||
|         ro_children = sorted(ro_children, key=lambda child: int(child.attrib["index"])) |         ro_children = sorted(ro_children, key=lambda child: int(child.attrib["index"])) | ||||||
|     elif ET.QName(group.tag).localname in ["UnorderedGroup","UnorderedGroupIndexed"]: |     elif ET.QName(group.tag).localname in ["UnorderedGroup", "UnorderedGroupIndexed"]: | ||||||
|         ro_children = list(group) |         ro_children = list(group) | ||||||
|     else: |     else: | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
|     for ro_child in ro_children: |     for ro_child in ro_children: | ||||||
|         if ET.QName(ro_child.tag).localname in ["OrderedGroup", "OrderedGroupIndexed", "UnorderedGroup", "UnorderedGroupIndexed"]: |         if ET.QName(ro_child.tag).localname in [ | ||||||
|  |             "OrderedGroup", | ||||||
|  |             "OrderedGroupIndexed", | ||||||
|  |             "UnorderedGroup", | ||||||
|  |             "UnorderedGroupIndexed", | ||||||
|  |         ]: | ||||||
|             regions.extend( |             regions.extend( | ||||||
|                 extract_texts_from_reading_order_group( |                 extract_texts_from_reading_order_group( | ||||||
|                     ro_child, tree, nsmap, textequiv_level |                     ro_child, tree, nsmap, textequiv_level | ||||||
|  | @ -139,7 +143,11 @@ def plain_extract(filename, include_filename_in_id=False): | ||||||
|         clusters = list(grapheme_clusters(normalized_text)) |         clusters = list(grapheme_clusters(normalized_text)) | ||||||
|         return ExtractedText( |         return ExtractedText( | ||||||
|             id_template.format(filename=os.path.basename(filename), no=no), |             id_template.format(filename=os.path.basename(filename), no=no), | ||||||
|             None, None, normalized_text, clusters) |             None, | ||||||
|  |             None, | ||||||
|  |             normalized_text, | ||||||
|  |             clusters, | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|     with open(filename, "r") as f: |     with open(filename, "r") as f: | ||||||
|         return ExtractedText( |         return ExtractedText( | ||||||
|  | @ -147,7 +155,7 @@ def plain_extract(filename, include_filename_in_id=False): | ||||||
|             [make_segment(no, line) for no, line in enumerate(f.readlines())], |             [make_segment(no, line) for no, line in enumerate(f.readlines())], | ||||||
|             "\n", |             "\n", | ||||||
|             None, |             None, | ||||||
|             None |             None, | ||||||
|         ) |         ) | ||||||
|     # XXX hardcoded SBB normalization |     # XXX hardcoded SBB normalization | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -33,7 +33,7 @@ class OcrdDinglehopperEvaluate(Processor): | ||||||
|         textequiv_level = self.parameter["textequiv_level"] |         textequiv_level = self.parameter["textequiv_level"] | ||||||
|         gt_grp, ocr_grp = self.input_file_grp.split(",") |         gt_grp, ocr_grp = self.input_file_grp.split(",") | ||||||
| 
 | 
 | ||||||
|         input_file_tuples = self.zip_input_files(on_error='abort') |         input_file_tuples = self.zip_input_files(on_error="abort") | ||||||
|         for n, (gt_file, ocr_file) in enumerate(input_file_tuples): |         for n, (gt_file, ocr_file) in enumerate(input_file_tuples): | ||||||
|             if not gt_file or not ocr_file: |             if not gt_file or not ocr_file: | ||||||
|                 # file/page was not found in this group |                 # file/page was not found in this group | ||||||
|  |  | ||||||
|  | @ -40,7 +40,6 @@ def words(s: str): | ||||||
|     if not word_break_patched: |     if not word_break_patched: | ||||||
|         patch_word_break() |         patch_word_break() | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
|     # Check if c is an unwanted character, i.e. whitespace, punctuation, or similar |     # Check if c is an unwanted character, i.e. whitespace, punctuation, or similar | ||||||
|     def unwanted(c): |     def unwanted(c): | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue