trocr: move preprocessor and decoder into model object, too…

- ModelZoo: drop `trocr_processor` model type
- `ModelZoo.load_models()`: use Predictor for `ocr_tr` models, too
- `ModelZoo.load_model()`: for `ocr_tr`, load processor and model,
  then define a function object as stand-in for the common model
  interface based on Keras (w/ `.predict_on_batch()`)
- Predictor: allow multi-input without actual batch dimension
  for `ocr_tr` models (because the model takes a list of original
  image arrays and resizes them to model shape internally)
- Eynollah_ocr: adapt (replacing preprocessing, prediction and
  decoding steps by a single `.predict()` call)
This commit is contained in:
Robert Sachunsky 2026-06-03 03:41:44 +02:00
parent d2f2a1e06b
commit f447a9f248
4 changed files with 88 additions and 86 deletions

View file

@ -70,8 +70,7 @@ class Eynollah_ocr(Eynollah):
def setup_models(self, device=''): def setup_models(self, device=''):
if self.tr_ocr: if self.tr_ocr:
self.model_zoo.load_models('trocr_processor', self.model_zoo.load_models(('ocr', 'tr'),
('ocr', 'tr'),
device=device) device=device)
else: else:
self.model_zoo.load_models('ocr', self.model_zoo.load_models('ocr',
@ -142,24 +141,7 @@ class Eynollah_ocr(Eynollah):
self.logger.debug("processing %d lines for %d regions", self.logger.debug("processing %d lines for %d regions",
len(cropped_lines), len(set(cropped_lines_region_indexer))) len(cropped_lines), len(set(cropped_lines_region_indexer)))
for imgs in batched(cropped_lines, self.b_s): for imgs in batched(cropped_lines, self.b_s):
pixel_values = self.model_zoo.get('trocr_processor')( text, conf = self.model_zoo.get('ocr').predict(imgs)
imgs, return_tensors="pt").pixel_values
output = self.model_zoo.get('ocr').generate(
pixel_values.to(self.device),
# beam search instead of greedy decoding:
num_beams=4,
# also return probability
output_scores=True,
return_dict_in_generate=True)
if output.sequences_scores is not None:
# log-prob averaged over length
conf = output.sequences_scores.exp().clamp(0.0, 1.0).tolist()
else:
conf = [1.0] * len(output.sequences)
text = self.model_zoo.get('trocr_processor').batch_decode(
output.sequences,
skip_special_tokens=True,
clean_up_tokenization_spaces=False)
extracted_confs.extend(conf) extracted_confs.extend(conf)
extracted_texts.extend(text) extracted_texts.extend(text)
del cropped_lines del cropped_lines

View file

@ -217,20 +217,4 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
type='Keras', type='Keras',
), ),
EynollahModelSpec(
category="trocr_processor",
variant='',
filename="models_eynollah/model_eynollah_ocr_trocr_20250919",
dist_url=dist_url("ocr"),
type='TrOCRProcessor',
),
EynollahModelSpec(
category="trocr_processor",
variant='htr',
filename="models_eynollah/microsoft/trocr-base-handwritten",
dist_url=dist_url("extra"),
type='TrOCRProcessor',
),
]) ])

View file

@ -116,22 +116,15 @@ class EynollahModelZoo:
model_category, model_variant = load_args model_category, model_variant = load_args
load_kwargs["model_variant"] = model_variant load_kwargs["model_variant"] = model_variant
if model_category.endswith('_resized'): # if model_category.endswith('_resized'):
model_category = model_category[:-8] # model_category = model_category[:-8]
load_kwargs["resized"] = True # load_kwargs["resized"] = True
elif model_category.endswith('_patched'): # elif model_category.endswith('_patched'):
model_category = model_category[:-8] # model_category = model_category[:-8]
load_kwargs["patched"] = True # load_kwargs["patched"] = True
if model_category == 'ocr' and model_variant == 'tr': model = Predictor(self.logger, self)
model = self._load_ocr_model(variant=model_variant, device=device) model.load_model(model_category, **load_kwargs)
elif model_category == 'trocr_processor':
from transformers import TrOCRProcessor
model_path = self.model_path(model_category, model_variant)
model = TrOCRProcessor.from_pretrained(model_path)
else:
model = Predictor(self.logger, self)
model.load_model(model_category, **load_kwargs)
ret[model_category] = model ret[model_category] = model
self._loaded.update(ret) self._loaded.update(ret)
@ -142,8 +135,8 @@ class EynollahModelZoo:
model_category: str, model_category: str,
model_variant: str = '', model_variant: str = '',
model_path_override: Optional[str] = None, model_path_override: Optional[str] = None,
patched: bool = False, # patched: bool = False,
resized: bool = False, # resized: bool = False,
device: str = '', device: str = '',
) -> AnyModel: ) -> AnyModel:
""" """
@ -153,7 +146,9 @@ class EynollahModelZoo:
self.override_models((model_category, model_variant, model_path_override)) self.override_models((model_category, model_variant, model_path_override))
model_path = self.model_path(model_category, model_variant) model_path = self.model_path(model_category, model_variant)
if model_path.is_dir() and (model_path / "keras_metadata.pb").exists(): if model_category == 'ocr' and model_variant == 'tr':
model = self._load_trocr_model(model_path, device=device)
elif model_path.is_dir() and (model_path / "keras_metadata.pb").exists():
# Keras model # Keras model
model = self._load_keras_model(model_category, model_path, device=device) model = self._load_keras_model(model_category, model_path, device=device)
elif model_path.is_dir(): elif model_path.is_dir():
@ -220,6 +215,30 @@ class EynollahModelZoo:
if not cuda: if not cuda:
self.logger.warning("no GPU device available") self.logger.warning("no GPU device available")
def _configure_torch_device(self, model_category, device=''):
import torch
device0 = torch.device('cpu')
if not device and torch.cuda.is_available():
device = 'GPU' # try
if device and ':' in device:
for spec in device.split(','):
cat, dev = spec.split(':')
if fnmatchcase('ocr', cat):
device = dev
break
if device and device.startswith('GPU'):
try:
device0 = torch.device('cuda', int(device[3:] or 0))
name = torch.cuda.get_device_name(device0)
self.logger.info("using GPU %s (%s) for model ocr:tr", device0, name)
except:
self.logger.exception("cannot configure GPU device")
device0 = torch.device('cpu')
if device0.type != 'cuda':
self.logger.warning("no GPU device available")
return device0
def _load_keras_model(self, model_category, model_path, device=''): def _load_keras_model(self, model_category, model_path, device=''):
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15 os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
from ocrd_utils import tf_disable_interactive_logs from ocrd_utils import tf_disable_interactive_logs
@ -325,40 +344,46 @@ class EynollahModelZoo:
return model return model
def _load_ocr_model(self, variant: str, device: str = "") -> AnyModel: def _load_trocr_model(self, model_path, device: str = "") -> AnyModel:
""" """
Load OCR model Load OCR model
""" """
model_dir = self.model_path('ocr', variant) from transformers import VisionEncoderDecoderModel, TrOCRProcessor
if variant == 'tr': import numpy as np
from transformers import VisionEncoderDecoderModel
import torch
model = VisionEncoderDecoderModel.from_pretrained(model_dir)
assert isinstance(model, VisionEncoderDecoderModel)
device0 = torch.device('cpu')
if not device and torch.cuda.is_available():
device = 'GPU' # try
if device and ':' in device:
for spec in device.split(','):
cat, dev = spec.split(':')
if fnmatchcase('ocr', cat):
device = dev
break
if device and device.startswith('GPU'):
try:
device0 = torch.device('cuda', int(device[3:] or 0))
name = torch.cuda.get_device_name(device0)
self.logger.info("using GPU %s (%s) for model ocr:tr", device0, name)
except:
self.logger.exception("cannot configure GPU device")
device0 = torch.device('cpu')
if device0.type == 'cuda':
model.to(device0)
else:
self.logger.warning("no GPU device available")
return model
return self.load_model('ocr', model_variant=variant, device=device) device = self._configure_torch_device('ocr', device=device)
proc = TrOCRProcessor.from_pretrained(model_path)
model = VisionEncoderDecoderModel.from_pretrained(model_path)
assert isinstance(model, VisionEncoderDecoderModel)
model.to(device)
def predict_torch(inputs):
output = model.generate(
proc(inputs, return_tensors="pt").pixel_values.to(device),
# beam search instead of greedy decoding:
num_beams=4,
# also return probability
output_scores=True,
return_dict_in_generate=True)
if output.sequences_scores is not None:
# log-prob averaged over length
conf = output.sequences_scores.exp().clamp(0.0, 1.0).cpu().numpy()
else:
conf = np.ones(len(output.sequences), dtype=float)
text = proc.batch_decode(
output.sequences,
skip_special_tokens=True,
clean_up_tokenization_spaces=False)
# we must convert to ndarray for Predictor resultq to work
text = np.array(text)
return text, conf
model.predict_on_batch = predict_torch
# not actually needed (image processor does resize itself)
model.input_shape = (None,
proc.image_processor.size.height,
proc.image_processor.size.width,
len(proc.image_processor.image_mean))
return model
def __str__(self): def __str__(self):
return tabulate( return tabulate(

View file

@ -129,6 +129,7 @@ class Predictor(mp.context.SpawnProcess):
"enhancement": 4, "enhancement": 4,
"reading_order": 4, "reading_order": 4,
"ocr": 8, "ocr": 8,
"ocr_tr": 2,
# medium size (672x672x3)... # medium size (672x672x3)...
"textline": 2, "textline": 2,
# large models... # large models...
@ -144,7 +145,14 @@ class Predictor(mp.context.SpawnProcess):
self.resultq.put((jobid, result)) self.resultq.put((jobid, result))
#self.logger.debug("sent result for '%d': %s", jobid, result) #self.logger.debug("sent result for '%d': %s", jobid, result)
else: else:
if isinstance(shared_data, tuple): if self.name == 'ocr_tr':
# this model takes a list of (image) tensors
# of heterogeneous shape as input,
# resizing them internally;
# so this looks like multi-input
multi_input = True
batch_size = len(shared_data)
elif isinstance(shared_data, tuple):
multi_input = True multi_input = True
batch_size = shared_data[0]['shape'][0] batch_size = shared_data[0]['shape'][0]
else: else:
@ -215,8 +223,11 @@ class Predictor(mp.context.SpawnProcess):
def load_model(self, *load_args, **load_kwargs): def load_model(self, *load_args, **load_kwargs):
assert len(load_args) assert len(load_args)
self.name = '_'.join(list(load_args[:1]) + self.name = '_'.join(list(load_args[:1]) +
list(load_kwargs[key] for key in load_kwargs
if key == 'model_variant') +
list(key for key in load_kwargs list(key for key in load_kwargs
if key != 'device')) if key in ['patched', 'resized']
and load_kwargs[key]))
self.load_args = load_args self.load_args = load_args
self.load_kwargs = load_kwargs self.load_kwargs = load_kwargs
self.start() # call run() in subprocess self.start() # call run() in subprocess