mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-04-14 19:31:57 +02:00
CLI/Eynollah.setup_models/ModelZoo.load_models: add device option/kwarg
allow setting device specifier to load models into either - CPU or single GPU0, GPU1 etc - per-model patterns, e.g. col*:CPU,page:GPU0,*:GPU1 pass through as kwargs until `ModelZoo.load_models()` setup up TF
This commit is contained in:
parent
67e9f84b54
commit
6bbdcc39ef
3 changed files with 33 additions and 8 deletions
|
|
@ -165,6 +165,11 @@ import click
|
||||||
type=click.IntRange(min=0),
|
type=click.IntRange(min=0),
|
||||||
help="number of parallel images to process (also helps better utilise GPU if available); 0 means based on autodetected number of processor cores",
|
help="number of parallel images to process (also helps better utilise GPU if available); 0 means based on autodetected number of processor cores",
|
||||||
)
|
)
|
||||||
|
@click.option(
|
||||||
|
"--device",
|
||||||
|
"-D",
|
||||||
|
help="placement of computations in predictors for each model type; if none (by default), will try to use first available GPU or fall back to CPU; set string to force using a device (e.g. 'GPU0', 'GPU1' or 'CPU'). Can also be a comma-separated list of model category to device mappings (e.g. 'col_classifier:CPU,page:GPU0,*:GPU1')",
|
||||||
|
)
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
def layout_cli(
|
def layout_cli(
|
||||||
ctx,
|
ctx,
|
||||||
|
|
@ -194,6 +199,7 @@ def layout_cli(
|
||||||
skip_layout_and_reading_order,
|
skip_layout_and_reading_order,
|
||||||
ignore_page_extraction,
|
ignore_page_extraction,
|
||||||
num_jobs,
|
num_jobs,
|
||||||
|
device,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Detect Layout (with optional image enhancement and reading order detection)
|
Detect Layout (with optional image enhancement and reading order detection)
|
||||||
|
|
@ -209,6 +215,7 @@ def layout_cli(
|
||||||
assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
|
assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
|
||||||
eynollah = Eynollah(
|
eynollah = Eynollah(
|
||||||
model_zoo=ctx.obj.model_zoo,
|
model_zoo=ctx.obj.model_zoo,
|
||||||
|
device=device,
|
||||||
enable_plotting=enable_plotting,
|
enable_plotting=enable_plotting,
|
||||||
allow_enhancement=allow_enhancement,
|
allow_enhancement=allow_enhancement,
|
||||||
curved_line=curved_line,
|
curved_line=curved_line,
|
||||||
|
|
|
||||||
|
|
@ -121,6 +121,7 @@ class Eynollah:
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
model_zoo: EynollahModelZoo,
|
model_zoo: EynollahModelZoo,
|
||||||
|
device: str = '',
|
||||||
enable_plotting : bool = False,
|
enable_plotting : bool = False,
|
||||||
allow_enhancement : bool = False,
|
allow_enhancement : bool = False,
|
||||||
curved_line : bool = False,
|
curved_line : bool = False,
|
||||||
|
|
@ -165,10 +166,10 @@ class Eynollah:
|
||||||
t_start = time.time()
|
t_start = time.time()
|
||||||
|
|
||||||
self.logger.info("Loading models...")
|
self.logger.info("Loading models...")
|
||||||
self.setup_models()
|
self.setup_models(device=device)
|
||||||
self.logger.info(f"Model initialization complete ({time.time() - t_start:.1f}s)")
|
self.logger.info(f"Model initialization complete ({time.time() - t_start:.1f}s)")
|
||||||
|
|
||||||
def setup_models(self):
|
def setup_models(self, device=''):
|
||||||
|
|
||||||
# load models, depending on modes
|
# load models, depending on modes
|
||||||
# (note: loading too many models can cause OOM on GPU/CUDA,
|
# (note: loading too many models can cause OOM on GPU/CUDA,
|
||||||
|
|
@ -194,7 +195,7 @@ class Eynollah:
|
||||||
if self.tables:
|
if self.tables:
|
||||||
loadable.append("table")
|
loadable.append("table")
|
||||||
|
|
||||||
self.model_zoo.load_models(*loadable)
|
self.model_zoo.load_models(*loadable, device=device)
|
||||||
for model in loadable:
|
for model in loadable:
|
||||||
# retrieve and cache output shapes
|
# retrieve and cache output shapes
|
||||||
if model.endswith(('_resized', '_patched')):
|
if model.endswith(('_resized', '_patched')):
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ import json
|
||||||
import logging
|
import logging
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from fnmatch import fnmatchcase
|
||||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
from tabulate import tabulate
|
from tabulate import tabulate
|
||||||
|
|
@ -74,6 +75,7 @@ class EynollahModelZoo:
|
||||||
def load_models(
|
def load_models(
|
||||||
self,
|
self,
|
||||||
*all_load_args: Union[str, Tuple[str], Tuple[str, str], Tuple[str, str, str]],
|
*all_load_args: Union[str, Tuple[str], Tuple[str, str], Tuple[str, str, str]],
|
||||||
|
device: str = '',
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""
|
"""
|
||||||
Load all models by calling load_model and return a dictionary mapping model_category to loaded model
|
Load all models by calling load_model and return a dictionary mapping model_category to loaded model
|
||||||
|
|
@ -93,7 +95,7 @@ class EynollahModelZoo:
|
||||||
load_args[0] = model_category[:-8]
|
load_args[0] = model_category[:-8]
|
||||||
load_kwargs["patched"] = True
|
load_kwargs["patched"] = True
|
||||||
ret[model_category] = Predictor(self.logger, self)
|
ret[model_category] = Predictor(self.logger, self)
|
||||||
ret[model_category].load_model(*load_args, **load_kwargs)
|
ret[model_category].load_model(*load_args, **load_kwargs, device=device)
|
||||||
self._loaded.update(ret)
|
self._loaded.update(ret)
|
||||||
return self._loaded
|
return self._loaded
|
||||||
|
|
||||||
|
|
@ -104,6 +106,7 @@ class EynollahModelZoo:
|
||||||
model_path_override: Optional[str] = None,
|
model_path_override: Optional[str] = None,
|
||||||
patched: bool = False,
|
patched: bool = False,
|
||||||
resized: bool = False,
|
resized: bool = False,
|
||||||
|
device: str = '',
|
||||||
) -> AnyModel:
|
) -> AnyModel:
|
||||||
"""
|
"""
|
||||||
Load any model
|
Load any model
|
||||||
|
|
@ -123,10 +126,24 @@ class EynollahModelZoo:
|
||||||
)
|
)
|
||||||
cuda = False
|
cuda = False
|
||||||
try:
|
try:
|
||||||
device = tf.config.list_physical_devices('GPU')[0]
|
gpus = tf.config.list_physical_devices('GPU')
|
||||||
tf.config.experimental.set_memory_growth(device, True)
|
if device:
|
||||||
cuda = True
|
if ',' in device:
|
||||||
self.logger.info("using GPU %s", device.name)
|
for spec in device.split(','):
|
||||||
|
cat, dev = spec.split(':')
|
||||||
|
if fnmatchcase(model_category, cat):
|
||||||
|
device = dev
|
||||||
|
break
|
||||||
|
if device == 'CPU':
|
||||||
|
gpus = []
|
||||||
|
else:
|
||||||
|
assert device.startswith('GPU')
|
||||||
|
gpus = [gpus[int(device[3:])]]
|
||||||
|
tf.config.set_visible_devices(gpus, 'GPU')
|
||||||
|
for device in gpus:
|
||||||
|
tf.config.experimental.set_memory_growth(device, True)
|
||||||
|
cuda = True
|
||||||
|
self.logger.info("using GPU %s for model %s", device.name, model_category)
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
self.logger.exception("cannot configure GPU devices")
|
self.logger.exception("cannot configure GPU devices")
|
||||||
if not cuda:
|
if not cuda:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue