From f447a9f248d0d48ea6f4a6fc7185ae82484c8ac3 Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Wed, 3 Jun 2026 03:41:44 +0200 Subject: [PATCH] =?UTF-8?q?trocr:=20move=20preprocessor=20and=20decoder=20?= =?UTF-8?q?into=20model=20object,=20too=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - ModelZoo: drop `trocr_processor` model type - `ModelZoo.load_models()`: use Predictor for `ocr_tr` models, too - `ModelZoo.load_model()`: for `ocr_tr`, load processor and model, then define a function object as stand-in for the common model interface based on Keras (w/ `.predict_on_batch()`) - Predictor: allow multi-input without actual batch dimension for `ocr_tr` models (because the model takes a list of original image arrays and resizes them to model shape internally) - Eynollah_ocr: adapt (replacing preprocessing, prediction and decoding steps by a single `.predict()` call) --- src/eynollah/eynollah_ocr.py | 22 +---- src/eynollah/model_zoo/default_specs.py | 16 ---- src/eynollah/model_zoo/model_zoo.py | 121 ++++++++++++++---------- src/eynollah/predictor.py | 15 ++- 4 files changed, 88 insertions(+), 86 deletions(-) diff --git a/src/eynollah/eynollah_ocr.py b/src/eynollah/eynollah_ocr.py index aeaabfe..1dfe177 100644 --- a/src/eynollah/eynollah_ocr.py +++ b/src/eynollah/eynollah_ocr.py @@ -70,8 +70,7 @@ class Eynollah_ocr(Eynollah): def setup_models(self, device=''): if self.tr_ocr: - self.model_zoo.load_models('trocr_processor', - ('ocr', 'tr'), + self.model_zoo.load_models(('ocr', 'tr'), device=device) else: self.model_zoo.load_models('ocr', @@ -142,24 +141,7 @@ class Eynollah_ocr(Eynollah): self.logger.debug("processing %d lines for %d regions", len(cropped_lines), len(set(cropped_lines_region_indexer))) for imgs in batched(cropped_lines, self.b_s): - pixel_values = self.model_zoo.get('trocr_processor')( - imgs, return_tensors="pt").pixel_values - output = self.model_zoo.get('ocr').generate( - pixel_values.to(self.device), - # beam search instead of greedy decoding: - num_beams=4, - # also return probability - output_scores=True, - return_dict_in_generate=True) - if output.sequences_scores is not None: - # log-prob averaged over length - conf = output.sequences_scores.exp().clamp(0.0, 1.0).tolist() - else: - conf = [1.0] * len(output.sequences) - text = self.model_zoo.get('trocr_processor').batch_decode( - output.sequences, - skip_special_tokens=True, - clean_up_tokenization_spaces=False) + text, conf = self.model_zoo.get('ocr').predict(imgs) extracted_confs.extend(conf) extracted_texts.extend(text) del cropped_lines diff --git a/src/eynollah/model_zoo/default_specs.py b/src/eynollah/model_zoo/default_specs.py index 170d944..18bf093 100644 --- a/src/eynollah/model_zoo/default_specs.py +++ b/src/eynollah/model_zoo/default_specs.py @@ -217,20 +217,4 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ type='Keras', ), - EynollahModelSpec( - category="trocr_processor", - variant='', - filename="models_eynollah/model_eynollah_ocr_trocr_20250919", - dist_url=dist_url("ocr"), - type='TrOCRProcessor', - ), - - EynollahModelSpec( - category="trocr_processor", - variant='htr', - filename="models_eynollah/microsoft/trocr-base-handwritten", - dist_url=dist_url("extra"), - type='TrOCRProcessor', - ), - ]) diff --git a/src/eynollah/model_zoo/model_zoo.py b/src/eynollah/model_zoo/model_zoo.py index e7d21aa..0dd24a8 100644 --- a/src/eynollah/model_zoo/model_zoo.py +++ b/src/eynollah/model_zoo/model_zoo.py @@ -116,22 +116,15 @@ class EynollahModelZoo: model_category, model_variant = load_args load_kwargs["model_variant"] = model_variant - if model_category.endswith('_resized'): - model_category = model_category[:-8] - load_kwargs["resized"] = True - elif model_category.endswith('_patched'): - model_category = model_category[:-8] - load_kwargs["patched"] = True + # if model_category.endswith('_resized'): + # model_category = model_category[:-8] + # load_kwargs["resized"] = True + # elif model_category.endswith('_patched'): + # model_category = model_category[:-8] + # load_kwargs["patched"] = True - if model_category == 'ocr' and model_variant == 'tr': - model = self._load_ocr_model(variant=model_variant, device=device) - elif model_category == 'trocr_processor': - from transformers import TrOCRProcessor - model_path = self.model_path(model_category, model_variant) - model = TrOCRProcessor.from_pretrained(model_path) - else: - model = Predictor(self.logger, self) - model.load_model(model_category, **load_kwargs) + model = Predictor(self.logger, self) + model.load_model(model_category, **load_kwargs) ret[model_category] = model self._loaded.update(ret) @@ -142,8 +135,8 @@ class EynollahModelZoo: model_category: str, model_variant: str = '', model_path_override: Optional[str] = None, - patched: bool = False, - resized: bool = False, + # patched: bool = False, + # resized: bool = False, device: str = '', ) -> AnyModel: """ @@ -153,7 +146,9 @@ class EynollahModelZoo: self.override_models((model_category, model_variant, model_path_override)) model_path = self.model_path(model_category, model_variant) - if model_path.is_dir() and (model_path / "keras_metadata.pb").exists(): + if model_category == 'ocr' and model_variant == 'tr': + model = self._load_trocr_model(model_path, device=device) + elif model_path.is_dir() and (model_path / "keras_metadata.pb").exists(): # Keras model model = self._load_keras_model(model_category, model_path, device=device) elif model_path.is_dir(): @@ -220,6 +215,30 @@ class EynollahModelZoo: if not cuda: self.logger.warning("no GPU device available") + def _configure_torch_device(self, model_category, device=''): + import torch + + device0 = torch.device('cpu') + if not device and torch.cuda.is_available(): + device = 'GPU' # try + if device and ':' in device: + for spec in device.split(','): + cat, dev = spec.split(':') + if fnmatchcase('ocr', cat): + device = dev + break + if device and device.startswith('GPU'): + try: + device0 = torch.device('cuda', int(device[3:] or 0)) + name = torch.cuda.get_device_name(device0) + self.logger.info("using GPU %s (%s) for model ocr:tr", device0, name) + except: + self.logger.exception("cannot configure GPU device") + device0 = torch.device('cpu') + if device0.type != 'cuda': + self.logger.warning("no GPU device available") + return device0 + def _load_keras_model(self, model_category, model_path, device=''): os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15 from ocrd_utils import tf_disable_interactive_logs @@ -325,40 +344,46 @@ class EynollahModelZoo: return model - def _load_ocr_model(self, variant: str, device: str = "") -> AnyModel: + def _load_trocr_model(self, model_path, device: str = "") -> AnyModel: """ Load OCR model """ - model_dir = self.model_path('ocr', variant) - if variant == 'tr': - from transformers import VisionEncoderDecoderModel - import torch - model = VisionEncoderDecoderModel.from_pretrained(model_dir) - assert isinstance(model, VisionEncoderDecoderModel) - device0 = torch.device('cpu') - if not device and torch.cuda.is_available(): - device = 'GPU' # try - if device and ':' in device: - for spec in device.split(','): - cat, dev = spec.split(':') - if fnmatchcase('ocr', cat): - device = dev - break - if device and device.startswith('GPU'): - try: - device0 = torch.device('cuda', int(device[3:] or 0)) - name = torch.cuda.get_device_name(device0) - self.logger.info("using GPU %s (%s) for model ocr:tr", device0, name) - except: - self.logger.exception("cannot configure GPU device") - device0 = torch.device('cpu') - if device0.type == 'cuda': - model.to(device0) - else: - self.logger.warning("no GPU device available") - return model + from transformers import VisionEncoderDecoderModel, TrOCRProcessor + import numpy as np - return self.load_model('ocr', model_variant=variant, device=device) + device = self._configure_torch_device('ocr', device=device) + proc = TrOCRProcessor.from_pretrained(model_path) + model = VisionEncoderDecoderModel.from_pretrained(model_path) + assert isinstance(model, VisionEncoderDecoderModel) + + model.to(device) + def predict_torch(inputs): + output = model.generate( + proc(inputs, return_tensors="pt").pixel_values.to(device), + # beam search instead of greedy decoding: + num_beams=4, + # also return probability + output_scores=True, + return_dict_in_generate=True) + if output.sequences_scores is not None: + # log-prob averaged over length + conf = output.sequences_scores.exp().clamp(0.0, 1.0).cpu().numpy() + else: + conf = np.ones(len(output.sequences), dtype=float) + text = proc.batch_decode( + output.sequences, + skip_special_tokens=True, + clean_up_tokenization_spaces=False) + # we must convert to ndarray for Predictor resultq to work + text = np.array(text) + return text, conf + model.predict_on_batch = predict_torch + # not actually needed (image processor does resize itself) + model.input_shape = (None, + proc.image_processor.size.height, + proc.image_processor.size.width, + len(proc.image_processor.image_mean)) + return model def __str__(self): return tabulate( diff --git a/src/eynollah/predictor.py b/src/eynollah/predictor.py index 141d3f0..23cc36f 100644 --- a/src/eynollah/predictor.py +++ b/src/eynollah/predictor.py @@ -129,6 +129,7 @@ class Predictor(mp.context.SpawnProcess): "enhancement": 4, "reading_order": 4, "ocr": 8, + "ocr_tr": 2, # medium size (672x672x3)... "textline": 2, # large models... @@ -144,7 +145,14 @@ class Predictor(mp.context.SpawnProcess): self.resultq.put((jobid, result)) #self.logger.debug("sent result for '%d': %s", jobid, result) else: - if isinstance(shared_data, tuple): + if self.name == 'ocr_tr': + # this model takes a list of (image) tensors + # of heterogeneous shape as input, + # resizing them internally; + # so this looks like multi-input + multi_input = True + batch_size = len(shared_data) + elif isinstance(shared_data, tuple): multi_input = True batch_size = shared_data[0]['shape'][0] else: @@ -215,8 +223,11 @@ class Predictor(mp.context.SpawnProcess): def load_model(self, *load_args, **load_kwargs): assert len(load_args) self.name = '_'.join(list(load_args[:1]) + + list(load_kwargs[key] for key in load_kwargs + if key == 'model_variant') + list(key for key in load_kwargs - if key != 'device')) + if key in ['patched', 'resized'] + and load_kwargs[key])) self.load_args = load_args self.load_kwargs = load_kwargs self.start() # call run() in subprocess