mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-05-26 07:39:22 +02:00
model_zoo: also parse comma/colon syntax for device in Torch case
This commit is contained in:
parent
be4fe8c263
commit
17b311441a
1 changed files with 7 additions and 1 deletions
|
|
@ -154,7 +154,7 @@ class EynollahModelZoo:
|
||||||
try:
|
try:
|
||||||
gpus = tf.config.list_physical_devices('GPU')
|
gpus = tf.config.list_physical_devices('GPU')
|
||||||
if device:
|
if device:
|
||||||
if ',' in device:
|
if ':' in device:
|
||||||
for spec in device.split(','):
|
for spec in device.split(','):
|
||||||
cat, dev = spec.split(':')
|
cat, dev = spec.split(':')
|
||||||
if fnmatchcase(model_category, cat):
|
if fnmatchcase(model_category, cat):
|
||||||
|
|
@ -235,6 +235,12 @@ class EynollahModelZoo:
|
||||||
dev = torch.device('cpu')
|
dev = torch.device('cpu')
|
||||||
if not device and torch.cuda.is_available():
|
if not device and torch.cuda.is_available():
|
||||||
device = 'GPU' # try
|
device = 'GPU' # try
|
||||||
|
if device and ':' in device:
|
||||||
|
for spec in device.split(','):
|
||||||
|
cat, dev = spec.split(':')
|
||||||
|
if fnmatchcase('ocr', cat):
|
||||||
|
device = dev
|
||||||
|
break
|
||||||
if device and device.startswith('GPU'):
|
if device and device.startswith('GPU'):
|
||||||
try:
|
try:
|
||||||
dev = torch.device('cuda', int(device[3:] or 0))
|
dev = torch.device('cuda', int(device[3:] or 0))
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue