mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-04-14 19:31:57 +02:00
CLI/Eynollah.setup_models/ModelZoo.load_models: add device option/kwarg
allow setting device specifier to load models into either - CPU or single GPU0, GPU1 etc - per-model patterns, e.g. col*:CPU,page:GPU0,*:GPU1 pass through as kwargs until `ModelZoo.load_models()` setup up TF
This commit is contained in:
parent
67e9f84b54
commit
6bbdcc39ef
3 changed files with 33 additions and 8 deletions
|
|
@ -165,6 +165,11 @@ import click
|
|||
type=click.IntRange(min=0),
|
||||
help="number of parallel images to process (also helps better utilise GPU if available); 0 means based on autodetected number of processor cores",
|
||||
)
|
||||
@click.option(
|
||||
"--device",
|
||||
"-D",
|
||||
help="placement of computations in predictors for each model type; if none (by default), will try to use first available GPU or fall back to CPU; set string to force using a device (e.g. 'GPU0', 'GPU1' or 'CPU'). Can also be a comma-separated list of model category to device mappings (e.g. 'col_classifier:CPU,page:GPU0,*:GPU1')",
|
||||
)
|
||||
@click.pass_context
|
||||
def layout_cli(
|
||||
ctx,
|
||||
|
|
@ -194,6 +199,7 @@ def layout_cli(
|
|||
skip_layout_and_reading_order,
|
||||
ignore_page_extraction,
|
||||
num_jobs,
|
||||
device,
|
||||
):
|
||||
"""
|
||||
Detect Layout (with optional image enhancement and reading order detection)
|
||||
|
|
@ -209,6 +215,7 @@ def layout_cli(
|
|||
assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
|
||||
eynollah = Eynollah(
|
||||
model_zoo=ctx.obj.model_zoo,
|
||||
device=device,
|
||||
enable_plotting=enable_plotting,
|
||||
allow_enhancement=allow_enhancement,
|
||||
curved_line=curved_line,
|
||||
|
|
|
|||
|
|
@ -121,6 +121,7 @@ class Eynollah:
|
|||
self,
|
||||
*,
|
||||
model_zoo: EynollahModelZoo,
|
||||
device: str = '',
|
||||
enable_plotting : bool = False,
|
||||
allow_enhancement : bool = False,
|
||||
curved_line : bool = False,
|
||||
|
|
@ -165,10 +166,10 @@ class Eynollah:
|
|||
t_start = time.time()
|
||||
|
||||
self.logger.info("Loading models...")
|
||||
self.setup_models()
|
||||
self.setup_models(device=device)
|
||||
self.logger.info(f"Model initialization complete ({time.time() - t_start:.1f}s)")
|
||||
|
||||
def setup_models(self):
|
||||
def setup_models(self, device=''):
|
||||
|
||||
# load models, depending on modes
|
||||
# (note: loading too many models can cause OOM on GPU/CUDA,
|
||||
|
|
@ -194,7 +195,7 @@ class Eynollah:
|
|||
if self.tables:
|
||||
loadable.append("table")
|
||||
|
||||
self.model_zoo.load_models(*loadable)
|
||||
self.model_zoo.load_models(*loadable, device=device)
|
||||
for model in loadable:
|
||||
# retrieve and cache output shapes
|
||||
if model.endswith(('_resized', '_patched')):
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import json
|
|||
import logging
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from fnmatch import fnmatchcase
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
from tabulate import tabulate
|
||||
|
|
@ -74,6 +75,7 @@ class EynollahModelZoo:
|
|||
def load_models(
|
||||
self,
|
||||
*all_load_args: Union[str, Tuple[str], Tuple[str, str], Tuple[str, str, str]],
|
||||
device: str = '',
|
||||
) -> Dict:
|
||||
"""
|
||||
Load all models by calling load_model and return a dictionary mapping model_category to loaded model
|
||||
|
|
@ -93,7 +95,7 @@ class EynollahModelZoo:
|
|||
load_args[0] = model_category[:-8]
|
||||
load_kwargs["patched"] = True
|
||||
ret[model_category] = Predictor(self.logger, self)
|
||||
ret[model_category].load_model(*load_args, **load_kwargs)
|
||||
ret[model_category].load_model(*load_args, **load_kwargs, device=device)
|
||||
self._loaded.update(ret)
|
||||
return self._loaded
|
||||
|
||||
|
|
@ -104,6 +106,7 @@ class EynollahModelZoo:
|
|||
model_path_override: Optional[str] = None,
|
||||
patched: bool = False,
|
||||
resized: bool = False,
|
||||
device: str = '',
|
||||
) -> AnyModel:
|
||||
"""
|
||||
Load any model
|
||||
|
|
@ -123,10 +126,24 @@ class EynollahModelZoo:
|
|||
)
|
||||
cuda = False
|
||||
try:
|
||||
device = tf.config.list_physical_devices('GPU')[0]
|
||||
tf.config.experimental.set_memory_growth(device, True)
|
||||
cuda = True
|
||||
self.logger.info("using GPU %s", device.name)
|
||||
gpus = tf.config.list_physical_devices('GPU')
|
||||
if device:
|
||||
if ',' in device:
|
||||
for spec in device.split(','):
|
||||
cat, dev = spec.split(':')
|
||||
if fnmatchcase(model_category, cat):
|
||||
device = dev
|
||||
break
|
||||
if device == 'CPU':
|
||||
gpus = []
|
||||
else:
|
||||
assert device.startswith('GPU')
|
||||
gpus = [gpus[int(device[3:])]]
|
||||
tf.config.set_visible_devices(gpus, 'GPU')
|
||||
for device in gpus:
|
||||
tf.config.experimental.set_memory_growth(device, True)
|
||||
cuda = True
|
||||
self.logger.info("using GPU %s for model %s", device.name, model_category)
|
||||
except RuntimeError:
|
||||
self.logger.exception("cannot configure GPU devices")
|
||||
if not cuda:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue