fixup device sel

This commit is contained in:
Robert Sachunsky 2026-03-16 15:35:07 +01:00
parent 6bbdcc39ef
commit 1756443605
2 changed files with 19 additions and 5 deletions

View file

@ -139,11 +139,21 @@ class EynollahModelZoo:
else:
assert device.startswith('GPU')
gpus = [gpus[int(device[3:])]]
else:
gpus = gpus[:1] # TF will always use first allowable
tf.config.set_visible_devices(gpus, 'GPU')
for device in gpus:
tf.config.experimental.set_memory_growth(device, True)
vendor_name = (
tf.config.experimental.get_device_details(device)
.get('device_name', 'unknown'))
cuda = True
self.logger.info("using GPU %s for model %s", device.name, model_category)
self.logger.info("using GPU %s (%s) for model %s",
device.name,
vendor_name,
model_category + (
"_patched" if patched else
"_resized" if resized else ""))
except RuntimeError:
self.logger.exception("cannot configure GPU devices")
if not cuda:
@ -166,12 +176,14 @@ class EynollahModelZoo:
model = TrOCRProcessor.from_pretrained(model_path)
else:
try:
# avoid wasting VRAM on non-transformer models
model = load_model(model_path, compile=False)
except Exception as e:
self.logger.exception(e)
self.logger.error(e)
model = load_model(
model_path, compile=False, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches}
)
model_path, compile=False,
custom_objects=dict(PatchEncoder=PatchEncoder,
Patches=Patches))
model._name = model_category
if resized:
model = wrap_layout_model_resized(model)

View file

@ -170,7 +170,9 @@ class Predictor(mp.context.SpawnProcess):
def load_model(self, *load_args, **load_kwargs):
assert len(load_args)
self.name = '_'.join(list(load_args[:1]) + list(load_kwargs))
self.name = '_'.join(list(load_args[:1]) +
list(key for key in load_kwargs
if key != 'device'))
self.load_args = load_args
self.load_kwargs = load_kwargs
self.start() # call run() in subprocess