diff --git a/src/eynollah/model_zoo/model_zoo.py b/src/eynollah/model_zoo/model_zoo.py index 7f3cd6c..f1d8824 100644 --- a/src/eynollah/model_zoo/model_zoo.py +++ b/src/eynollah/model_zoo/model_zoo.py @@ -154,7 +154,7 @@ class EynollahModelZoo: try: gpus = tf.config.list_physical_devices('GPU') if device: - if ',' in device: + if ':' in device: for spec in device.split(','): cat, dev = spec.split(':') if fnmatchcase(model_category, cat): @@ -235,6 +235,12 @@ class EynollahModelZoo: dev = torch.device('cpu') if not device and torch.cuda.is_available(): device = 'GPU' # try + if device and ':' in device: + for spec in device.split(','): + cat, dev = spec.split(':') + if fnmatchcase('ocr', cat): + device = dev + break if device and device.startswith('GPU'): try: dev = torch.device('cuda', int(device[3:] or 0))