CLI/Eynollah.setup_models/ModelZoo.load_models: add device option/kwarg

allow setting a device specifier that selects where models are loaded

either
- CPU, or a single GPU (GPU0, GPU1, etc.)
- per-model patterns, e.g. col*:CPU,page:GPU0,*:GPU1

pass through as kwargs until `ModelZoo.load_models()` sets up TF
This commit is contained in:
Robert Sachunsky 2026-03-15 04:54:04 +01:00
parent 67e9f84b54
commit 6bbdcc39ef
3 changed files with 33 additions and 8 deletions

View file

@ -165,6 +165,11 @@ import click
type=click.IntRange(min=0),
help="number of parallel images to process (also helps better utilise GPU if available); 0 means based on autodetected number of processor cores",
)
@click.option(
"--device",
"-D",
help="placement of computations in predictors for each model type; if none (by default), will try to use first available GPU or fall back to CPU; set string to force using a device (e.g. 'GPU0', 'GPU1' or 'CPU'). Can also be a comma-separated list of model category to device mappings (e.g. 'col_classifier:CPU,page:GPU0,*:GPU1')",
)
@click.pass_context
def layout_cli(
ctx,
@ -194,6 +199,7 @@ def layout_cli(
skip_layout_and_reading_order,
ignore_page_extraction,
num_jobs,
device,
):
"""
Detect Layout (with optional image enhancement and reading order detection)
@ -209,6 +215,7 @@ def layout_cli(
assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
eynollah = Eynollah(
model_zoo=ctx.obj.model_zoo,
device=device,
enable_plotting=enable_plotting,
allow_enhancement=allow_enhancement,
curved_line=curved_line,

View file

@ -121,6 +121,7 @@ class Eynollah:
self,
*,
model_zoo: EynollahModelZoo,
device: str = '',
enable_plotting : bool = False,
allow_enhancement : bool = False,
curved_line : bool = False,
@ -165,10 +166,10 @@ class Eynollah:
t_start = time.time()
self.logger.info("Loading models...")
self.setup_models()
self.setup_models(device=device)
self.logger.info(f"Model initialization complete ({time.time() - t_start:.1f}s)")
def setup_models(self):
def setup_models(self, device=''):
# load models, depending on modes
# (note: loading too many models can cause OOM on GPU/CUDA,
@ -194,7 +195,7 @@ class Eynollah:
if self.tables:
loadable.append("table")
self.model_zoo.load_models(*loadable)
self.model_zoo.load_models(*loadable, device=device)
for model in loadable:
# retrieve and cache output shapes
if model.endswith(('_resized', '_patched')):

View file

@ -3,6 +3,7 @@ import json
import logging
from copy import deepcopy
from pathlib import Path
from fnmatch import fnmatchcase
from typing import Dict, List, Optional, Tuple, Type, Union
from tabulate import tabulate
@ -74,6 +75,7 @@ class EynollahModelZoo:
def load_models(
self,
*all_load_args: Union[str, Tuple[str], Tuple[str, str], Tuple[str, str, str]],
device: str = '',
) -> Dict:
"""
Load all models by calling load_model and return a dictionary mapping model_category to loaded model
@ -93,7 +95,7 @@ class EynollahModelZoo:
load_args[0] = model_category[:-8]
load_kwargs["patched"] = True
ret[model_category] = Predictor(self.logger, self)
ret[model_category].load_model(*load_args, **load_kwargs)
ret[model_category].load_model(*load_args, **load_kwargs, device=device)
self._loaded.update(ret)
return self._loaded
@ -104,6 +106,7 @@ class EynollahModelZoo:
model_path_override: Optional[str] = None,
patched: bool = False,
resized: bool = False,
device: str = '',
) -> AnyModel:
"""
Load any model
@ -123,10 +126,24 @@ class EynollahModelZoo:
)
cuda = False
try:
device = tf.config.list_physical_devices('GPU')[0]
tf.config.experimental.set_memory_growth(device, True)
cuda = True
self.logger.info("using GPU %s", device.name)
gpus = tf.config.list_physical_devices('GPU')
if device:
if ',' in device:
for spec in device.split(','):
cat, dev = spec.split(':')
if fnmatchcase(model_category, cat):
device = dev
break
if device == 'CPU':
gpus = []
else:
assert device.startswith('GPU')
gpus = [gpus[int(device[3:])]]
tf.config.set_visible_devices(gpus, 'GPU')
for device in gpus:
tf.config.experimental.set_memory_growth(device, True)
cuda = True
self.logger.info("using GPU %s for model %s", device.name, model_category)
except RuntimeError:
self.logger.exception("cannot configure GPU devices")
if not cuda: