mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-06-16 09:59:13 +02:00
trocr: move preprocessor and decoder into model object, too…
- ModelZoo: drop `trocr_processor` model type - `ModelZoo.load_models()`: use Predictor for `ocr_tr` models, too - `ModelZoo.load_model()`: for `ocr_tr`, load processor and model, then define a function object as stand-in for the common model interface based on Keras (w/ `.predict_on_batch()`) - Predictor: allow multi-input without actual batch dimension for `ocr_tr` models (because the model takes a list of original image arrays and resizes them to model shape internally) - Eynollah_ocr: adapt (replacing preprocessing, prediction and decoding steps by a single `.predict()` call)
This commit is contained in:
parent
d2f2a1e06b
commit
f447a9f248
4 changed files with 88 additions and 86 deletions
|
|
@ -70,8 +70,7 @@ class Eynollah_ocr(Eynollah):
|
||||||
|
|
||||||
def setup_models(self, device=''):
|
def setup_models(self, device=''):
|
||||||
if self.tr_ocr:
|
if self.tr_ocr:
|
||||||
self.model_zoo.load_models('trocr_processor',
|
self.model_zoo.load_models(('ocr', 'tr'),
|
||||||
('ocr', 'tr'),
|
|
||||||
device=device)
|
device=device)
|
||||||
else:
|
else:
|
||||||
self.model_zoo.load_models('ocr',
|
self.model_zoo.load_models('ocr',
|
||||||
|
|
@ -142,24 +141,7 @@ class Eynollah_ocr(Eynollah):
|
||||||
self.logger.debug("processing %d lines for %d regions",
|
self.logger.debug("processing %d lines for %d regions",
|
||||||
len(cropped_lines), len(set(cropped_lines_region_indexer)))
|
len(cropped_lines), len(set(cropped_lines_region_indexer)))
|
||||||
for imgs in batched(cropped_lines, self.b_s):
|
for imgs in batched(cropped_lines, self.b_s):
|
||||||
pixel_values = self.model_zoo.get('trocr_processor')(
|
text, conf = self.model_zoo.get('ocr').predict(imgs)
|
||||||
imgs, return_tensors="pt").pixel_values
|
|
||||||
output = self.model_zoo.get('ocr').generate(
|
|
||||||
pixel_values.to(self.device),
|
|
||||||
# beam search instead of greedy decoding:
|
|
||||||
num_beams=4,
|
|
||||||
# also return probability
|
|
||||||
output_scores=True,
|
|
||||||
return_dict_in_generate=True)
|
|
||||||
if output.sequences_scores is not None:
|
|
||||||
# log-prob averaged over length
|
|
||||||
conf = output.sequences_scores.exp().clamp(0.0, 1.0).tolist()
|
|
||||||
else:
|
|
||||||
conf = [1.0] * len(output.sequences)
|
|
||||||
text = self.model_zoo.get('trocr_processor').batch_decode(
|
|
||||||
output.sequences,
|
|
||||||
skip_special_tokens=True,
|
|
||||||
clean_up_tokenization_spaces=False)
|
|
||||||
extracted_confs.extend(conf)
|
extracted_confs.extend(conf)
|
||||||
extracted_texts.extend(text)
|
extracted_texts.extend(text)
|
||||||
del cropped_lines
|
del cropped_lines
|
||||||
|
|
|
||||||
|
|
@ -217,20 +217,4 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
|
||||||
type='Keras',
|
type='Keras',
|
||||||
),
|
),
|
||||||
|
|
||||||
EynollahModelSpec(
|
|
||||||
category="trocr_processor",
|
|
||||||
variant='',
|
|
||||||
filename="models_eynollah/model_eynollah_ocr_trocr_20250919",
|
|
||||||
dist_url=dist_url("ocr"),
|
|
||||||
type='TrOCRProcessor',
|
|
||||||
),
|
|
||||||
|
|
||||||
EynollahModelSpec(
|
|
||||||
category="trocr_processor",
|
|
||||||
variant='htr',
|
|
||||||
filename="models_eynollah/microsoft/trocr-base-handwritten",
|
|
||||||
dist_url=dist_url("extra"),
|
|
||||||
type='TrOCRProcessor',
|
|
||||||
),
|
|
||||||
|
|
||||||
])
|
])
|
||||||
|
|
|
||||||
|
|
@ -116,22 +116,15 @@ class EynollahModelZoo:
|
||||||
model_category, model_variant = load_args
|
model_category, model_variant = load_args
|
||||||
load_kwargs["model_variant"] = model_variant
|
load_kwargs["model_variant"] = model_variant
|
||||||
|
|
||||||
if model_category.endswith('_resized'):
|
# if model_category.endswith('_resized'):
|
||||||
model_category = model_category[:-8]
|
# model_category = model_category[:-8]
|
||||||
load_kwargs["resized"] = True
|
# load_kwargs["resized"] = True
|
||||||
elif model_category.endswith('_patched'):
|
# elif model_category.endswith('_patched'):
|
||||||
model_category = model_category[:-8]
|
# model_category = model_category[:-8]
|
||||||
load_kwargs["patched"] = True
|
# load_kwargs["patched"] = True
|
||||||
|
|
||||||
if model_category == 'ocr' and model_variant == 'tr':
|
model = Predictor(self.logger, self)
|
||||||
model = self._load_ocr_model(variant=model_variant, device=device)
|
model.load_model(model_category, **load_kwargs)
|
||||||
elif model_category == 'trocr_processor':
|
|
||||||
from transformers import TrOCRProcessor
|
|
||||||
model_path = self.model_path(model_category, model_variant)
|
|
||||||
model = TrOCRProcessor.from_pretrained(model_path)
|
|
||||||
else:
|
|
||||||
model = Predictor(self.logger, self)
|
|
||||||
model.load_model(model_category, **load_kwargs)
|
|
||||||
|
|
||||||
ret[model_category] = model
|
ret[model_category] = model
|
||||||
self._loaded.update(ret)
|
self._loaded.update(ret)
|
||||||
|
|
@ -142,8 +135,8 @@ class EynollahModelZoo:
|
||||||
model_category: str,
|
model_category: str,
|
||||||
model_variant: str = '',
|
model_variant: str = '',
|
||||||
model_path_override: Optional[str] = None,
|
model_path_override: Optional[str] = None,
|
||||||
patched: bool = False,
|
# patched: bool = False,
|
||||||
resized: bool = False,
|
# resized: bool = False,
|
||||||
device: str = '',
|
device: str = '',
|
||||||
) -> AnyModel:
|
) -> AnyModel:
|
||||||
"""
|
"""
|
||||||
|
|
@ -153,7 +146,9 @@ class EynollahModelZoo:
|
||||||
self.override_models((model_category, model_variant, model_path_override))
|
self.override_models((model_category, model_variant, model_path_override))
|
||||||
model_path = self.model_path(model_category, model_variant)
|
model_path = self.model_path(model_category, model_variant)
|
||||||
|
|
||||||
if model_path.is_dir() and (model_path / "keras_metadata.pb").exists():
|
if model_category == 'ocr' and model_variant == 'tr':
|
||||||
|
model = self._load_trocr_model(model_path, device=device)
|
||||||
|
elif model_path.is_dir() and (model_path / "keras_metadata.pb").exists():
|
||||||
# Keras model
|
# Keras model
|
||||||
model = self._load_keras_model(model_category, model_path, device=device)
|
model = self._load_keras_model(model_category, model_path, device=device)
|
||||||
elif model_path.is_dir():
|
elif model_path.is_dir():
|
||||||
|
|
@ -220,6 +215,30 @@ class EynollahModelZoo:
|
||||||
if not cuda:
|
if not cuda:
|
||||||
self.logger.warning("no GPU device available")
|
self.logger.warning("no GPU device available")
|
||||||
|
|
||||||
|
def _configure_torch_device(self, model_category, device=''):
|
||||||
|
import torch
|
||||||
|
|
||||||
|
device0 = torch.device('cpu')
|
||||||
|
if not device and torch.cuda.is_available():
|
||||||
|
device = 'GPU' # try
|
||||||
|
if device and ':' in device:
|
||||||
|
for spec in device.split(','):
|
||||||
|
cat, dev = spec.split(':')
|
||||||
|
if fnmatchcase('ocr', cat):
|
||||||
|
device = dev
|
||||||
|
break
|
||||||
|
if device and device.startswith('GPU'):
|
||||||
|
try:
|
||||||
|
device0 = torch.device('cuda', int(device[3:] or 0))
|
||||||
|
name = torch.cuda.get_device_name(device0)
|
||||||
|
self.logger.info("using GPU %s (%s) for model ocr:tr", device0, name)
|
||||||
|
except:
|
||||||
|
self.logger.exception("cannot configure GPU device")
|
||||||
|
device0 = torch.device('cpu')
|
||||||
|
if device0.type != 'cuda':
|
||||||
|
self.logger.warning("no GPU device available")
|
||||||
|
return device0
|
||||||
|
|
||||||
def _load_keras_model(self, model_category, model_path, device=''):
|
def _load_keras_model(self, model_category, model_path, device=''):
|
||||||
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
||||||
from ocrd_utils import tf_disable_interactive_logs
|
from ocrd_utils import tf_disable_interactive_logs
|
||||||
|
|
@ -325,40 +344,46 @@ class EynollahModelZoo:
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def _load_ocr_model(self, variant: str, device: str = "") -> AnyModel:
|
def _load_trocr_model(self, model_path, device: str = "") -> AnyModel:
|
||||||
"""
|
"""
|
||||||
Load OCR model
|
Load OCR model
|
||||||
"""
|
"""
|
||||||
model_dir = self.model_path('ocr', variant)
|
from transformers import VisionEncoderDecoderModel, TrOCRProcessor
|
||||||
if variant == 'tr':
|
import numpy as np
|
||||||
from transformers import VisionEncoderDecoderModel
|
|
||||||
import torch
|
|
||||||
model = VisionEncoderDecoderModel.from_pretrained(model_dir)
|
|
||||||
assert isinstance(model, VisionEncoderDecoderModel)
|
|
||||||
device0 = torch.device('cpu')
|
|
||||||
if not device and torch.cuda.is_available():
|
|
||||||
device = 'GPU' # try
|
|
||||||
if device and ':' in device:
|
|
||||||
for spec in device.split(','):
|
|
||||||
cat, dev = spec.split(':')
|
|
||||||
if fnmatchcase('ocr', cat):
|
|
||||||
device = dev
|
|
||||||
break
|
|
||||||
if device and device.startswith('GPU'):
|
|
||||||
try:
|
|
||||||
device0 = torch.device('cuda', int(device[3:] or 0))
|
|
||||||
name = torch.cuda.get_device_name(device0)
|
|
||||||
self.logger.info("using GPU %s (%s) for model ocr:tr", device0, name)
|
|
||||||
except:
|
|
||||||
self.logger.exception("cannot configure GPU device")
|
|
||||||
device0 = torch.device('cpu')
|
|
||||||
if device0.type == 'cuda':
|
|
||||||
model.to(device0)
|
|
||||||
else:
|
|
||||||
self.logger.warning("no GPU device available")
|
|
||||||
return model
|
|
||||||
|
|
||||||
return self.load_model('ocr', model_variant=variant, device=device)
|
device = self._configure_torch_device('ocr', device=device)
|
||||||
|
proc = TrOCRProcessor.from_pretrained(model_path)
|
||||||
|
model = VisionEncoderDecoderModel.from_pretrained(model_path)
|
||||||
|
assert isinstance(model, VisionEncoderDecoderModel)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
def predict_torch(inputs):
|
||||||
|
output = model.generate(
|
||||||
|
proc(inputs, return_tensors="pt").pixel_values.to(device),
|
||||||
|
# beam search instead of greedy decoding:
|
||||||
|
num_beams=4,
|
||||||
|
# also return probability
|
||||||
|
output_scores=True,
|
||||||
|
return_dict_in_generate=True)
|
||||||
|
if output.sequences_scores is not None:
|
||||||
|
# log-prob averaged over length
|
||||||
|
conf = output.sequences_scores.exp().clamp(0.0, 1.0).cpu().numpy()
|
||||||
|
else:
|
||||||
|
conf = np.ones(len(output.sequences), dtype=float)
|
||||||
|
text = proc.batch_decode(
|
||||||
|
output.sequences,
|
||||||
|
skip_special_tokens=True,
|
||||||
|
clean_up_tokenization_spaces=False)
|
||||||
|
# we must convert to ndarray for Predictor resultq to work
|
||||||
|
text = np.array(text)
|
||||||
|
return text, conf
|
||||||
|
model.predict_on_batch = predict_torch
|
||||||
|
# not actually needed (image processor does resize itself)
|
||||||
|
model.input_shape = (None,
|
||||||
|
proc.image_processor.size.height,
|
||||||
|
proc.image_processor.size.width,
|
||||||
|
len(proc.image_processor.image_mean))
|
||||||
|
return model
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return tabulate(
|
return tabulate(
|
||||||
|
|
|
||||||
|
|
@ -129,6 +129,7 @@ class Predictor(mp.context.SpawnProcess):
|
||||||
"enhancement": 4,
|
"enhancement": 4,
|
||||||
"reading_order": 4,
|
"reading_order": 4,
|
||||||
"ocr": 8,
|
"ocr": 8,
|
||||||
|
"ocr_tr": 2,
|
||||||
# medium size (672x672x3)...
|
# medium size (672x672x3)...
|
||||||
"textline": 2,
|
"textline": 2,
|
||||||
# large models...
|
# large models...
|
||||||
|
|
@ -144,7 +145,14 @@ class Predictor(mp.context.SpawnProcess):
|
||||||
self.resultq.put((jobid, result))
|
self.resultq.put((jobid, result))
|
||||||
#self.logger.debug("sent result for '%d': %s", jobid, result)
|
#self.logger.debug("sent result for '%d': %s", jobid, result)
|
||||||
else:
|
else:
|
||||||
if isinstance(shared_data, tuple):
|
if self.name == 'ocr_tr':
|
||||||
|
# this model takes a list of (image) tensors
|
||||||
|
# of heterogeneous shape as input,
|
||||||
|
# resizing them internally;
|
||||||
|
# so this looks like multi-input
|
||||||
|
multi_input = True
|
||||||
|
batch_size = len(shared_data)
|
||||||
|
elif isinstance(shared_data, tuple):
|
||||||
multi_input = True
|
multi_input = True
|
||||||
batch_size = shared_data[0]['shape'][0]
|
batch_size = shared_data[0]['shape'][0]
|
||||||
else:
|
else:
|
||||||
|
|
@ -215,8 +223,11 @@ class Predictor(mp.context.SpawnProcess):
|
||||||
def load_model(self, *load_args, **load_kwargs):
|
def load_model(self, *load_args, **load_kwargs):
|
||||||
assert len(load_args)
|
assert len(load_args)
|
||||||
self.name = '_'.join(list(load_args[:1]) +
|
self.name = '_'.join(list(load_args[:1]) +
|
||||||
|
list(load_kwargs[key] for key in load_kwargs
|
||||||
|
if key == 'model_variant') +
|
||||||
list(key for key in load_kwargs
|
list(key for key in load_kwargs
|
||||||
if key != 'device'))
|
if key in ['patched', 'resized']
|
||||||
|
and load_kwargs[key]))
|
||||||
self.load_args = load_args
|
self.load_args = load_args
|
||||||
self.load_kwargs = load_kwargs
|
self.load_kwargs = load_kwargs
|
||||||
self.start() # call run() in subprocess
|
self.start() # call run() in subprocess
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue