From 17564436055d6b6c4b2ad7a5d61c17e57b944b3c Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Mon, 16 Mar 2026 15:35:07 +0100 Subject: [PATCH] fixup device sel --- src/eynollah/model_zoo/model_zoo.py | 20 ++++++++++++++++---- src/eynollah/predictor.py | 4 +++- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/src/eynollah/model_zoo/model_zoo.py b/src/eynollah/model_zoo/model_zoo.py index b72d36a..79d3573 100644 --- a/src/eynollah/model_zoo/model_zoo.py +++ b/src/eynollah/model_zoo/model_zoo.py @@ -139,11 +139,21 @@ class EynollahModelZoo: else: assert device.startswith('GPU') gpus = [gpus[int(device[3:])]] + else: + gpus = gpus[:1] # TF will always use first allowable tf.config.set_visible_devices(gpus, 'GPU') for device in gpus: tf.config.experimental.set_memory_growth(device, True) + vendor_name = ( + tf.config.experimental.get_device_details(device) + .get('device_name', 'unknown')) cuda = True - self.logger.info("using GPU %s for model %s", device.name, model_category) + self.logger.info("using GPU %s (%s) for model %s", + device.name, + vendor_name, + model_category + ( + "_patched" if patched else + "_resized" if resized else "")) except RuntimeError: self.logger.exception("cannot configure GPU devices") if not cuda: @@ -166,12 +176,14 @@ class EynollahModelZoo: model = TrOCRProcessor.from_pretrained(model_path) else: try: + # avoid wasting VRAM on non-transformer models model = load_model(model_path, compile=False) except Exception as e: - self.logger.exception(e) + self.logger.error(e) model = load_model( - model_path, compile=False, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches} - ) + model_path, compile=False, + custom_objects=dict(PatchEncoder=PatchEncoder, + Patches=Patches)) model._name = model_category if resized: model = wrap_layout_model_resized(model) diff --git a/src/eynollah/predictor.py b/src/eynollah/predictor.py index d6e149c..a6b15d5 100644 --- a/src/eynollah/predictor.py +++ b/src/eynollah/predictor.py @@ -170,7 +170,9 @@ class Predictor(mp.context.SpawnProcess): def load_model(self, *load_args, **load_kwargs): assert len(load_args) - self.name = '_'.join(list(load_args[:1]) + list(load_kwargs)) + self.name = '_'.join(list(load_args[:1]) + + list(key for key in load_kwargs + if key != 'device')) self.load_args = load_args self.load_kwargs = load_kwargs self.start() # call run() in subprocess