diff --git a/src/eynollah/cli/cli_layout.py b/src/eynollah/cli/cli_layout.py index 8ce0872..03cf9c8 100644 --- a/src/eynollah/cli/cli_layout.py +++ b/src/eynollah/cli/cli_layout.py @@ -165,6 +165,11 @@ import click type=click.IntRange(min=0), help="number of parallel images to process (also helps better utilise GPU if available); 0 means based on autodetected number of processor cores", ) +@click.option( + "--device", + "-D", + help="placement of computations in predictors for each model type; if none (by default), will try to use first available GPU or fall back to CPU; set string to force using a device (e.g. 'GPU0', 'GPU1' or 'CPU'). Can also be a comma-separated list of model category to device mappings (e.g. 'col_classifier:CPU,page:GPU0,*:GPU1')", +) @click.pass_context def layout_cli( ctx, @@ -194,6 +199,7 @@ def layout_cli( skip_layout_and_reading_order, ignore_page_extraction, num_jobs, + device, ): """ Detect Layout (with optional image enhancement and reading order detection) @@ -209,6 +215,7 @@ def layout_cli( assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both." 
eynollah = Eynollah( model_zoo=ctx.obj.model_zoo, + device=device, enable_plotting=enable_plotting, allow_enhancement=allow_enhancement, curved_line=curved_line, diff --git a/src/eynollah/eynollah.py b/src/eynollah/eynollah.py index 4823559..67ee525 100644 --- a/src/eynollah/eynollah.py +++ b/src/eynollah/eynollah.py @@ -121,6 +121,7 @@ class Eynollah: self, *, model_zoo: EynollahModelZoo, + device: str = '', enable_plotting : bool = False, allow_enhancement : bool = False, curved_line : bool = False, @@ -165,10 +166,10 @@ class Eynollah: t_start = time.time() self.logger.info("Loading models...") - self.setup_models() + self.setup_models(device=device) self.logger.info(f"Model initialization complete ({time.time() - t_start:.1f}s)") - def setup_models(self): + def setup_models(self, device=''): # load models, depending on modes # (note: loading too many models can cause OOM on GPU/CUDA, @@ -194,7 +195,7 @@ class Eynollah: if self.tables: loadable.append("table") - self.model_zoo.load_models(*loadable) + self.model_zoo.load_models(*loadable, device=device) for model in loadable: # retrieve and cache output shapes if model.endswith(('_resized', '_patched')): diff --git a/src/eynollah/model_zoo/model_zoo.py b/src/eynollah/model_zoo/model_zoo.py index 92fedfb..b72d36a 100644 --- a/src/eynollah/model_zoo/model_zoo.py +++ b/src/eynollah/model_zoo/model_zoo.py @@ -3,6 +3,7 @@ import json import logging from copy import deepcopy from pathlib import Path +from fnmatch import fnmatchcase from typing import Dict, List, Optional, Tuple, Type, Union from tabulate import tabulate @@ -74,6 +75,7 @@ class EynollahModelZoo: def load_models( self, *all_load_args: Union[str, Tuple[str], Tuple[str, str], Tuple[str, str, str]], + device: str = '', ) -> Dict: """ Load all models by calling load_model and return a dictionary mapping model_category to loaded model @@ -93,7 +95,7 @@ class EynollahModelZoo: load_args[0] = model_category[:-8] load_kwargs["patched"] = True 
        ret[model_category] = Predictor(self.logger, self)
-        ret[model_category].load_model(*load_args, **load_kwargs)
+        ret[model_category].load_model(*load_args, **load_kwargs, device=device)
         self._loaded.update(ret)
         return self._loaded
@@ -104,6 +106,7 @@ class EynollahModelZoo:
         model_path_override: Optional[str] = None,
         patched: bool = False,
         resized: bool = False,
+        device: str = '',
     ) -> AnyModel:
         """
         Load any model
@@ -123,10 +126,28 @@
         )
         cuda = False
         try:
-            device = tf.config.list_physical_devices('GPU')[0]
-            tf.config.experimental.set_memory_growth(device, True)
-            cuda = True
-            self.logger.info("using GPU %s", device.name)
+            gpus = tf.config.list_physical_devices('GPU')
+            if device and ',' in device:
+                # resolve per-category mapping, e.g. 'col_classifier:CPU,*:GPU0'
+                for spec in device.split(','):
+                    cat, dev = spec.split(':')
+                    if fnmatchcase(model_category, cat):
+                        device = dev
+                        break
+                else:
+                    # no pattern matched this category: fall back to default placement
+                    device = ''
+            if device == 'CPU':
+                gpus = []
+            elif device:
+                if not device.startswith('GPU'):
+                    raise ValueError("invalid device spec '%s' for model %s" % (device, model_category))
+                gpus = [gpus[int(device[3:])]]
+            tf.config.set_visible_devices(gpus, 'GPU')
+            for gpu in gpus:
+                tf.config.experimental.set_memory_growth(gpu, True)
+                cuda = True
+                self.logger.info("using GPU %s for model %s", gpu.name, model_category)
         except RuntimeError:
             self.logger.exception("cannot configure GPU devices")
         if not cuda: