diff --git a/qurator/dinglehopper/ocr_files.py b/qurator/dinglehopper/ocr_files.py index 5ce0bcd..180ecd3 100644 --- a/qurator/dinglehopper/ocr_files.py +++ b/qurator/dinglehopper/ocr_files.py @@ -5,6 +5,7 @@ from warnings import warn from lxml import etree as ET from lxml.etree import XMLSyntaxError from contextlib import suppress +from itertools import repeat from .substitute_equivalences import substitute_equivalences import sys import attr @@ -22,16 +23,20 @@ class ExtractedText: def text(self): return self.joiner.join(s.text for s in self.segments) + _segment_id_for_pos = None + def segment_id_for_pos(self, pos): - i = 0 - for s in self.segments: - if i <= pos < i + len(s.text): - return s.id - i += len(s.text) - if i <= pos < i + len(self.joiner): - return None - i += len(self.joiner) - # XXX Cache results + # Calculate segment ids once, on the first call + if not self._segment_id_for_pos: + segment_id_for_pos = [] + for s in self.segments: + segment_id_for_pos.extend(repeat(s.id, len(s.text))) + segment_id_for_pos.extend(repeat(None, len(self.joiner))) + # This is frozen, so we have to jump through the hoop: + object.__setattr__(self, '_segment_id_for_pos', segment_id_for_pos) + assert self._segment_id_for_pos + + return self._segment_id_for_pos[pos] class Normalization(enum.Enum):