mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-04-14 19:31:57 +02:00
fixup device sel
This commit is contained in:
parent
6bbdcc39ef
commit
1756443605
2 changed files with 19 additions and 5 deletions
|
|
@ -139,11 +139,21 @@ class EynollahModelZoo:
|
|||
else:
|
||||
assert device.startswith('GPU')
|
||||
gpus = [gpus[int(device[3:])]]
|
||||
else:
|
||||
gpus = gpus[:1] # TF will always use first allowable
|
||||
tf.config.set_visible_devices(gpus, 'GPU')
|
||||
for device in gpus:
|
||||
tf.config.experimental.set_memory_growth(device, True)
|
||||
vendor_name = (
|
||||
tf.config.experimental.get_device_details(device)
|
||||
.get('device_name', 'unknown'))
|
||||
cuda = True
|
||||
self.logger.info("using GPU %s for model %s", device.name, model_category)
|
||||
self.logger.info("using GPU %s (%s) for model %s",
|
||||
device.name,
|
||||
vendor_name,
|
||||
model_category + (
|
||||
"_patched" if patched else
|
||||
"_resized" if resized else ""))
|
||||
except RuntimeError:
|
||||
self.logger.exception("cannot configure GPU devices")
|
||||
if not cuda:
|
||||
|
|
@ -166,12 +176,14 @@ class EynollahModelZoo:
|
|||
model = TrOCRProcessor.from_pretrained(model_path)
|
||||
else:
|
||||
try:
|
||||
# avoid wasting VRAM on non-transformer models
|
||||
model = load_model(model_path, compile=False)
|
||||
except Exception as e:
|
||||
self.logger.exception(e)
|
||||
self.logger.error(e)
|
||||
model = load_model(
|
||||
model_path, compile=False, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches}
|
||||
)
|
||||
model_path, compile=False,
|
||||
custom_objects=dict(PatchEncoder=PatchEncoder,
|
||||
Patches=Patches))
|
||||
model._name = model_category
|
||||
if resized:
|
||||
model = wrap_layout_model_resized(model)
|
||||
|
|
|
|||
|
|
@ -170,7 +170,9 @@ class Predictor(mp.context.SpawnProcess):
|
|||
|
||||
def load_model(self, *load_args, **load_kwargs):
|
||||
assert len(load_args)
|
||||
self.name = '_'.join(list(load_args[:1]) + list(load_kwargs))
|
||||
self.name = '_'.join(list(load_args[:1]) +
|
||||
list(key for key in load_kwargs
|
||||
if key != 'device'))
|
||||
self.load_args = load_args
|
||||
self.load_kwargs = load_kwargs
|
||||
self.start() # call run() in subprocess
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue