mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-04-15 03:41:56 +02:00
fixup device sel
This commit is contained in:
parent
6bbdcc39ef
commit
1756443605
2 changed files with 19 additions and 5 deletions
|
|
@ -139,11 +139,21 @@ class EynollahModelZoo:
|
||||||
else:
|
else:
|
||||||
assert device.startswith('GPU')
|
assert device.startswith('GPU')
|
||||||
gpus = [gpus[int(device[3:])]]
|
gpus = [gpus[int(device[3:])]]
|
||||||
|
else:
|
||||||
|
gpus = gpus[:1] # TF will always use first allowable
|
||||||
tf.config.set_visible_devices(gpus, 'GPU')
|
tf.config.set_visible_devices(gpus, 'GPU')
|
||||||
for device in gpus:
|
for device in gpus:
|
||||||
tf.config.experimental.set_memory_growth(device, True)
|
tf.config.experimental.set_memory_growth(device, True)
|
||||||
|
vendor_name = (
|
||||||
|
tf.config.experimental.get_device_details(device)
|
||||||
|
.get('device_name', 'unknown'))
|
||||||
cuda = True
|
cuda = True
|
||||||
self.logger.info("using GPU %s for model %s", device.name, model_category)
|
self.logger.info("using GPU %s (%s) for model %s",
|
||||||
|
device.name,
|
||||||
|
vendor_name,
|
||||||
|
model_category + (
|
||||||
|
"_patched" if patched else
|
||||||
|
"_resized" if resized else ""))
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
self.logger.exception("cannot configure GPU devices")
|
self.logger.exception("cannot configure GPU devices")
|
||||||
if not cuda:
|
if not cuda:
|
||||||
|
|
@ -166,12 +176,14 @@ class EynollahModelZoo:
|
||||||
model = TrOCRProcessor.from_pretrained(model_path)
|
model = TrOCRProcessor.from_pretrained(model_path)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
|
# avoid wasting VRAM on non-transformer models
|
||||||
model = load_model(model_path, compile=False)
|
model = load_model(model_path, compile=False)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.exception(e)
|
self.logger.error(e)
|
||||||
model = load_model(
|
model = load_model(
|
||||||
model_path, compile=False, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches}
|
model_path, compile=False,
|
||||||
)
|
custom_objects=dict(PatchEncoder=PatchEncoder,
|
||||||
|
Patches=Patches))
|
||||||
model._name = model_category
|
model._name = model_category
|
||||||
if resized:
|
if resized:
|
||||||
model = wrap_layout_model_resized(model)
|
model = wrap_layout_model_resized(model)
|
||||||
|
|
|
||||||
|
|
@ -170,7 +170,9 @@ class Predictor(mp.context.SpawnProcess):
|
||||||
|
|
||||||
def load_model(self, *load_args, **load_kwargs):
|
def load_model(self, *load_args, **load_kwargs):
|
||||||
assert len(load_args)
|
assert len(load_args)
|
||||||
self.name = '_'.join(list(load_args[:1]) + list(load_kwargs))
|
self.name = '_'.join(list(load_args[:1]) +
|
||||||
|
list(key for key in load_kwargs
|
||||||
|
if key != 'device'))
|
||||||
self.load_args = load_args
|
self.load_args = load_args
|
||||||
self.load_kwargs = load_kwargs
|
self.load_kwargs = load_kwargs
|
||||||
self.start() # call run() in subprocess
|
self.start() # call run() in subprocess
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue