CLI/Eynollah.setup_models/ModelZoo.load_models: add device option/kwarg

allow setting a device specifier that controls where models are loaded

either
- CPU or single GPU0, GPU1 etc
- per-model patterns, e.g. col*:CPU,page:GPU0,*:GPU1

pass through as kwargs until `ModelZoo.load_models()` sets up TF
This commit is contained in:
Robert Sachunsky 2026-03-15 04:54:04 +01:00
parent 67e9f84b54
commit 6bbdcc39ef
3 changed files with 33 additions and 8 deletions

View file

@ -165,6 +165,11 @@ import click
type=click.IntRange(min=0), type=click.IntRange(min=0),
help="number of parallel images to process (also helps better utilise GPU if available); 0 means based on autodetected number of processor cores", help="number of parallel images to process (also helps better utilise GPU if available); 0 means based on autodetected number of processor cores",
) )
@click.option(
"--device",
"-D",
help="placement of computations in predictors for each model type; if none (by default), will try to use first available GPU or fall back to CPU; set string to force using a device (e.g. 'GPU0', 'GPU1' or 'CPU'). Can also be a comma-separated list of model category to device mappings (e.g. 'col_classifier:CPU,page:GPU0,*:GPU1')",
)
@click.pass_context @click.pass_context
def layout_cli( def layout_cli(
ctx, ctx,
@ -194,6 +199,7 @@ def layout_cli(
skip_layout_and_reading_order, skip_layout_and_reading_order,
ignore_page_extraction, ignore_page_extraction,
num_jobs, num_jobs,
device,
): ):
""" """
Detect Layout (with optional image enhancement and reading order detection) Detect Layout (with optional image enhancement and reading order detection)
@ -209,6 +215,7 @@ def layout_cli(
assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both." assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
eynollah = Eynollah( eynollah = Eynollah(
model_zoo=ctx.obj.model_zoo, model_zoo=ctx.obj.model_zoo,
device=device,
enable_plotting=enable_plotting, enable_plotting=enable_plotting,
allow_enhancement=allow_enhancement, allow_enhancement=allow_enhancement,
curved_line=curved_line, curved_line=curved_line,

View file

@ -121,6 +121,7 @@ class Eynollah:
self, self,
*, *,
model_zoo: EynollahModelZoo, model_zoo: EynollahModelZoo,
device: str = '',
enable_plotting : bool = False, enable_plotting : bool = False,
allow_enhancement : bool = False, allow_enhancement : bool = False,
curved_line : bool = False, curved_line : bool = False,
@ -165,10 +166,10 @@ class Eynollah:
t_start = time.time() t_start = time.time()
self.logger.info("Loading models...") self.logger.info("Loading models...")
self.setup_models() self.setup_models(device=device)
self.logger.info(f"Model initialization complete ({time.time() - t_start:.1f}s)") self.logger.info(f"Model initialization complete ({time.time() - t_start:.1f}s)")
def setup_models(self): def setup_models(self, device=''):
# load models, depending on modes # load models, depending on modes
# (note: loading too many models can cause OOM on GPU/CUDA, # (note: loading too many models can cause OOM on GPU/CUDA,
@ -194,7 +195,7 @@ class Eynollah:
if self.tables: if self.tables:
loadable.append("table") loadable.append("table")
self.model_zoo.load_models(*loadable) self.model_zoo.load_models(*loadable, device=device)
for model in loadable: for model in loadable:
# retrieve and cache output shapes # retrieve and cache output shapes
if model.endswith(('_resized', '_patched')): if model.endswith(('_resized', '_patched')):

View file

@ -3,6 +3,7 @@ import json
import logging import logging
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
from fnmatch import fnmatchcase
from typing import Dict, List, Optional, Tuple, Type, Union from typing import Dict, List, Optional, Tuple, Type, Union
from tabulate import tabulate from tabulate import tabulate
@ -74,6 +75,7 @@ class EynollahModelZoo:
def load_models( def load_models(
self, self,
*all_load_args: Union[str, Tuple[str], Tuple[str, str], Tuple[str, str, str]], *all_load_args: Union[str, Tuple[str], Tuple[str, str], Tuple[str, str, str]],
device: str = '',
) -> Dict: ) -> Dict:
""" """
Load all models by calling load_model and return a dictionary mapping model_category to loaded model Load all models by calling load_model and return a dictionary mapping model_category to loaded model
@ -93,7 +95,7 @@ class EynollahModelZoo:
load_args[0] = model_category[:-8] load_args[0] = model_category[:-8]
load_kwargs["patched"] = True load_kwargs["patched"] = True
ret[model_category] = Predictor(self.logger, self) ret[model_category] = Predictor(self.logger, self)
ret[model_category].load_model(*load_args, **load_kwargs) ret[model_category].load_model(*load_args, **load_kwargs, device=device)
self._loaded.update(ret) self._loaded.update(ret)
return self._loaded return self._loaded
@ -104,6 +106,7 @@ class EynollahModelZoo:
model_path_override: Optional[str] = None, model_path_override: Optional[str] = None,
patched: bool = False, patched: bool = False,
resized: bool = False, resized: bool = False,
device: str = '',
) -> AnyModel: ) -> AnyModel:
""" """
Load any model Load any model
@ -123,10 +126,24 @@ class EynollahModelZoo:
) )
cuda = False cuda = False
try: try:
device = tf.config.list_physical_devices('GPU')[0] gpus = tf.config.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(device, True) if device:
cuda = True if ',' in device:
self.logger.info("using GPU %s", device.name) for spec in device.split(','):
cat, dev = spec.split(':')
if fnmatchcase(model_category, cat):
device = dev
break
if device == 'CPU':
gpus = []
else:
assert device.startswith('GPU')
gpus = [gpus[int(device[3:])]]
tf.config.set_visible_devices(gpus, 'GPU')
for device in gpus:
tf.config.experimental.set_memory_growth(device, True)
cuda = True
self.logger.info("using GPU %s for model %s", device.name, model_category)
except RuntimeError: except RuntimeError:
self.logger.exception("cannot configure GPU devices") self.logger.exception("cannot configure GPU devices")
if not cuda: if not cuda: