diff --git a/src/eynollah/cli/cli_readingorder.py b/src/eynollah/cli/cli_readingorder.py index 0f44b7f..eed9fb9 100644 --- a/src/eynollah/cli/cli_readingorder.py +++ b/src/eynollah/cli/cli_readingorder.py @@ -20,14 +20,19 @@ import click type=click.Path(exists=True, file_okay=False), required=True, ) +@click.option( + "--device", + "-D", + help="placement of computations in predictors for each model type; if none (by default), will try to use first available GPU or fall back to CPU; set string to force using a device (e.g. 'GPU0', 'GPU1' or 'CPU'). Can also be a comma-separated list of model category to device mappings (e.g. 'col_classifier:CPU,page:GPU0,*:GPU1')", +) @click.pass_context -def readingorder_cli(ctx, input, dir_in, out): +def readingorder_cli(ctx, input, dir_in, out, device): """ Generate ReadingOrder with a ML model """ - from ..mb_ro_on_layout import machine_based_reading_order_on_layout + from ..mb_ro_on_layout import Reorder assert bool(input) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both." - orderer = machine_based_reading_order_on_layout(model_zoo=ctx.obj.model_zoo) + orderer = Reorder(model_zoo=ctx.obj.model_zoo, device=device) orderer.run(xml_filename=input, dir_in=dir_in, dir_out=out, diff --git a/src/eynollah/mb_ro_on_layout.py b/src/eynollah/mb_ro_on_layout.py index 22fe97b..5725ba1 100644 --- a/src/eynollah/mb_ro_on_layout.py +++ b/src/eynollah/mb_ro_on_layout.py @@ -21,6 +21,7 @@ os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15 import tensorflow as tf from tensorflow.keras.models import Model +from .eynollah import Eynollah from .model_zoo import EynollahModelZoo from .utils.resize import resize_image from .utils.contour import ( @@ -34,23 +35,27 @@ DPI_THRESHOLD = 298 KERNEL = np.ones((5, 5), np.uint8) -class machine_based_reading_order_on_layout: +class Reorder(Eynollah): def __init__( - self, - *, - model_zoo: EynollahModelZoo, - logger : Optional[logging.Logger] = None, + self, + *, + model_zoo: EynollahModelZoo, + logger : Optional[logging.Logger] = None, + device: str = '', ): self.logger = logger or logging.getLogger('eynollah.mbreorder') self.model_zoo = model_zoo - try: - for device in tf.config.list_physical_devices('GPU'): - tf.config.experimental.set_memory_growth(device, True) - except: - self.logger.warning("no GPU device available") - self.model_zoo.load_model('reading_order') + self.setup_models(device=device) + + def setup_models(self, device=''): + loadable = ['reading_order'] + self.model_zoo.load_models(*loadable, device=device) + for model in loadable: + self.logger.debug("model %s has input shape %s", model, + self.model_zoo.get(model).input_shape) + def read_xml(self, xml_file): tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8')) @@ -676,7 +681,7 @@ class machine_based_reading_order_on_layout: tot_counter += 1 batch.append(j) if tot_counter % inference_bs == 0 or tot_counter == len(ij_list): - y_pr = self.model_zoo.get('reading_order', Model).predict(input_1 , verbose='0') + y_pr = self.model_zoo.get('reading_order').predict(input_1, verbose=0) for jb, j in enumerate(batch): if y_pr[jb][0]>=0.5: post_list.append(j)