mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-06-16 09:59:13 +02:00
trocr: move preprocessor and decoder into model object, too…
- ModelZoo: drop `trocr_processor` model type - `ModelZoo.load_models()`: use Predictor for `ocr_tr` models, too - `ModelZoo.load_model()`: for `ocr_tr`, load processor and model, then define a function object as stand-in for the common model interface based on Keras (w/ `.predict_on_batch()`) - Predictor: allow multi-input without actual batch dimension for `ocr_tr` models (because the model takes a list of original image arrays and resizes them to model shape internally) - Eynollah_ocr: adapt (replacing preprocessing, prediction and decoding steps by a single `.predict()` call)
This commit is contained in:
parent
d2f2a1e06b
commit
f447a9f248
4 changed files with 88 additions and 86 deletions
|
|
@ -70,8 +70,7 @@ class Eynollah_ocr(Eynollah):
|
|||
|
||||
def setup_models(self, device=''):
|
||||
if self.tr_ocr:
|
||||
self.model_zoo.load_models('trocr_processor',
|
||||
('ocr', 'tr'),
|
||||
self.model_zoo.load_models(('ocr', 'tr'),
|
||||
device=device)
|
||||
else:
|
||||
self.model_zoo.load_models('ocr',
|
||||
|
|
@ -142,24 +141,7 @@ class Eynollah_ocr(Eynollah):
|
|||
self.logger.debug("processing %d lines for %d regions",
|
||||
len(cropped_lines), len(set(cropped_lines_region_indexer)))
|
||||
for imgs in batched(cropped_lines, self.b_s):
|
||||
pixel_values = self.model_zoo.get('trocr_processor')(
|
||||
imgs, return_tensors="pt").pixel_values
|
||||
output = self.model_zoo.get('ocr').generate(
|
||||
pixel_values.to(self.device),
|
||||
# beam search instead of greedy decoding:
|
||||
num_beams=4,
|
||||
# also return probability
|
||||
output_scores=True,
|
||||
return_dict_in_generate=True)
|
||||
if output.sequences_scores is not None:
|
||||
# log-prob averaged over length
|
||||
conf = output.sequences_scores.exp().clamp(0.0, 1.0).tolist()
|
||||
else:
|
||||
conf = [1.0] * len(output.sequences)
|
||||
text = self.model_zoo.get('trocr_processor').batch_decode(
|
||||
output.sequences,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=False)
|
||||
text, conf = self.model_zoo.get('ocr').predict(imgs)
|
||||
extracted_confs.extend(conf)
|
||||
extracted_texts.extend(text)
|
||||
del cropped_lines
|
||||
|
|
|
|||
|
|
@ -217,20 +217,4 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
|
|||
type='Keras',
|
||||
),
|
||||
|
||||
EynollahModelSpec(
|
||||
category="trocr_processor",
|
||||
variant='',
|
||||
filename="models_eynollah/model_eynollah_ocr_trocr_20250919",
|
||||
dist_url=dist_url("ocr"),
|
||||
type='TrOCRProcessor',
|
||||
),
|
||||
|
||||
EynollahModelSpec(
|
||||
category="trocr_processor",
|
||||
variant='htr',
|
||||
filename="models_eynollah/microsoft/trocr-base-handwritten",
|
||||
dist_url=dist_url("extra"),
|
||||
type='TrOCRProcessor',
|
||||
),
|
||||
|
||||
])
|
||||
|
|
|
|||
|
|
@ -116,20 +116,13 @@ class EynollahModelZoo:
|
|||
model_category, model_variant = load_args
|
||||
load_kwargs["model_variant"] = model_variant
|
||||
|
||||
if model_category.endswith('_resized'):
|
||||
model_category = model_category[:-8]
|
||||
load_kwargs["resized"] = True
|
||||
elif model_category.endswith('_patched'):
|
||||
model_category = model_category[:-8]
|
||||
load_kwargs["patched"] = True
|
||||
# if model_category.endswith('_resized'):
|
||||
# model_category = model_category[:-8]
|
||||
# load_kwargs["resized"] = True
|
||||
# elif model_category.endswith('_patched'):
|
||||
# model_category = model_category[:-8]
|
||||
# load_kwargs["patched"] = True
|
||||
|
||||
if model_category == 'ocr' and model_variant == 'tr':
|
||||
model = self._load_ocr_model(variant=model_variant, device=device)
|
||||
elif model_category == 'trocr_processor':
|
||||
from transformers import TrOCRProcessor
|
||||
model_path = self.model_path(model_category, model_variant)
|
||||
model = TrOCRProcessor.from_pretrained(model_path)
|
||||
else:
|
||||
model = Predictor(self.logger, self)
|
||||
model.load_model(model_category, **load_kwargs)
|
||||
|
||||
|
|
@ -142,8 +135,8 @@ class EynollahModelZoo:
|
|||
model_category: str,
|
||||
model_variant: str = '',
|
||||
model_path_override: Optional[str] = None,
|
||||
patched: bool = False,
|
||||
resized: bool = False,
|
||||
# patched: bool = False,
|
||||
# resized: bool = False,
|
||||
device: str = '',
|
||||
) -> AnyModel:
|
||||
"""
|
||||
|
|
@ -153,7 +146,9 @@ class EynollahModelZoo:
|
|||
self.override_models((model_category, model_variant, model_path_override))
|
||||
model_path = self.model_path(model_category, model_variant)
|
||||
|
||||
if model_path.is_dir() and (model_path / "keras_metadata.pb").exists():
|
||||
if model_category == 'ocr' and model_variant == 'tr':
|
||||
model = self._load_trocr_model(model_path, device=device)
|
||||
elif model_path.is_dir() and (model_path / "keras_metadata.pb").exists():
|
||||
# Keras model
|
||||
model = self._load_keras_model(model_category, model_path, device=device)
|
||||
elif model_path.is_dir():
|
||||
|
|
@ -220,6 +215,30 @@ class EynollahModelZoo:
|
|||
if not cuda:
|
||||
self.logger.warning("no GPU device available")
|
||||
|
||||
def _configure_torch_device(self, model_category, device=''):
|
||||
import torch
|
||||
|
||||
device0 = torch.device('cpu')
|
||||
if not device and torch.cuda.is_available():
|
||||
device = 'GPU' # try
|
||||
if device and ':' in device:
|
||||
for spec in device.split(','):
|
||||
cat, dev = spec.split(':')
|
||||
if fnmatchcase('ocr', cat):
|
||||
device = dev
|
||||
break
|
||||
if device and device.startswith('GPU'):
|
||||
try:
|
||||
device0 = torch.device('cuda', int(device[3:] or 0))
|
||||
name = torch.cuda.get_device_name(device0)
|
||||
self.logger.info("using GPU %s (%s) for model ocr:tr", device0, name)
|
||||
except:
|
||||
self.logger.exception("cannot configure GPU device")
|
||||
device0 = torch.device('cpu')
|
||||
if device0.type != 'cuda':
|
||||
self.logger.warning("no GPU device available")
|
||||
return device0
|
||||
|
||||
def _load_keras_model(self, model_category, model_path, device=''):
|
||||
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
||||
from ocrd_utils import tf_disable_interactive_logs
|
||||
|
|
@ -325,40 +344,46 @@ class EynollahModelZoo:
|
|||
|
||||
return model
|
||||
|
||||
def _load_ocr_model(self, variant: str, device: str = "") -> AnyModel:
|
||||
def _load_trocr_model(self, model_path, device: str = "") -> AnyModel:
|
||||
"""
|
||||
Load OCR model
|
||||
"""
|
||||
model_dir = self.model_path('ocr', variant)
|
||||
if variant == 'tr':
|
||||
from transformers import VisionEncoderDecoderModel
|
||||
import torch
|
||||
model = VisionEncoderDecoderModel.from_pretrained(model_dir)
|
||||
assert isinstance(model, VisionEncoderDecoderModel)
|
||||
device0 = torch.device('cpu')
|
||||
if not device and torch.cuda.is_available():
|
||||
device = 'GPU' # try
|
||||
if device and ':' in device:
|
||||
for spec in device.split(','):
|
||||
cat, dev = spec.split(':')
|
||||
if fnmatchcase('ocr', cat):
|
||||
device = dev
|
||||
break
|
||||
if device and device.startswith('GPU'):
|
||||
try:
|
||||
device0 = torch.device('cuda', int(device[3:] or 0))
|
||||
name = torch.cuda.get_device_name(device0)
|
||||
self.logger.info("using GPU %s (%s) for model ocr:tr", device0, name)
|
||||
except:
|
||||
self.logger.exception("cannot configure GPU device")
|
||||
device0 = torch.device('cpu')
|
||||
if device0.type == 'cuda':
|
||||
model.to(device0)
|
||||
else:
|
||||
self.logger.warning("no GPU device available")
|
||||
return model
|
||||
from transformers import VisionEncoderDecoderModel, TrOCRProcessor
|
||||
import numpy as np
|
||||
|
||||
return self.load_model('ocr', model_variant=variant, device=device)
|
||||
device = self._configure_torch_device('ocr', device=device)
|
||||
proc = TrOCRProcessor.from_pretrained(model_path)
|
||||
model = VisionEncoderDecoderModel.from_pretrained(model_path)
|
||||
assert isinstance(model, VisionEncoderDecoderModel)
|
||||
|
||||
model.to(device)
|
||||
def predict_torch(inputs):
|
||||
output = model.generate(
|
||||
proc(inputs, return_tensors="pt").pixel_values.to(device),
|
||||
# beam search instead of greedy decoding:
|
||||
num_beams=4,
|
||||
# also return probability
|
||||
output_scores=True,
|
||||
return_dict_in_generate=True)
|
||||
if output.sequences_scores is not None:
|
||||
# log-prob averaged over length
|
||||
conf = output.sequences_scores.exp().clamp(0.0, 1.0).cpu().numpy()
|
||||
else:
|
||||
conf = np.ones(len(output.sequences), dtype=float)
|
||||
text = proc.batch_decode(
|
||||
output.sequences,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=False)
|
||||
# we must convert to ndarray for Predictor resultq to work
|
||||
text = np.array(text)
|
||||
return text, conf
|
||||
model.predict_on_batch = predict_torch
|
||||
# not actually needed (image processor does resize itself)
|
||||
model.input_shape = (None,
|
||||
proc.image_processor.size.height,
|
||||
proc.image_processor.size.width,
|
||||
len(proc.image_processor.image_mean))
|
||||
return model
|
||||
|
||||
def __str__(self):
|
||||
return tabulate(
|
||||
|
|
|
|||
|
|
@ -129,6 +129,7 @@ class Predictor(mp.context.SpawnProcess):
|
|||
"enhancement": 4,
|
||||
"reading_order": 4,
|
||||
"ocr": 8,
|
||||
"ocr_tr": 2,
|
||||
# medium size (672x672x3)...
|
||||
"textline": 2,
|
||||
# large models...
|
||||
|
|
@ -144,7 +145,14 @@ class Predictor(mp.context.SpawnProcess):
|
|||
self.resultq.put((jobid, result))
|
||||
#self.logger.debug("sent result for '%d': %s", jobid, result)
|
||||
else:
|
||||
if isinstance(shared_data, tuple):
|
||||
if self.name == 'ocr_tr':
|
||||
# this model takes a list of (image) tensors
|
||||
# of heterogeneous shape as input,
|
||||
# resizing them internally;
|
||||
# so this looks like multi-input
|
||||
multi_input = True
|
||||
batch_size = len(shared_data)
|
||||
elif isinstance(shared_data, tuple):
|
||||
multi_input = True
|
||||
batch_size = shared_data[0]['shape'][0]
|
||||
else:
|
||||
|
|
@ -215,8 +223,11 @@ class Predictor(mp.context.SpawnProcess):
|
|||
def load_model(self, *load_args, **load_kwargs):
|
||||
assert len(load_args)
|
||||
self.name = '_'.join(list(load_args[:1]) +
|
||||
list(load_kwargs[key] for key in load_kwargs
|
||||
if key == 'model_variant') +
|
||||
list(key for key in load_kwargs
|
||||
if key != 'device'))
|
||||
if key in ['patched', 'resized']
|
||||
and load_kwargs[key]))
|
||||
self.load_args = load_args
|
||||
self.load_kwargs = load_kwargs
|
||||
self.start() # call run() in subprocess
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue