mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-05-26 07:39:22 +02:00
ModelZoo: fix Torch device selection
This commit is contained in:
parent
000e4ac8d8
commit
074753a98e
1 changed files with 10 additions and 10 deletions
|
|
@ -247,9 +247,9 @@ class EynollahModelZoo:
|
||||||
if variant == 'tr':
|
if variant == 'tr':
|
||||||
from transformers import VisionEncoderDecoderModel
|
from transformers import VisionEncoderDecoderModel
|
||||||
import torch
|
import torch
|
||||||
ret = VisionEncoderDecoderModel.from_pretrained(model_dir)
|
model = VisionEncoderDecoderModel.from_pretrained(model_dir)
|
||||||
assert isinstance(ret, VisionEncoderDecoderModel)
|
assert isinstance(model, VisionEncoderDecoderModel)
|
||||||
dev = torch.device('cpu')
|
device0 = torch.device('cpu')
|
||||||
if not device and torch.cuda.is_available():
|
if not device and torch.cuda.is_available():
|
||||||
device = 'GPU' # try
|
device = 'GPU' # try
|
||||||
if device and ':' in device:
|
if device and ':' in device:
|
||||||
|
|
@ -260,17 +260,17 @@ class EynollahModelZoo:
|
||||||
break
|
break
|
||||||
if device and device.startswith('GPU'):
|
if device and device.startswith('GPU'):
|
||||||
try:
|
try:
|
||||||
dev = torch.device('cuda', int(device[3:] or 0))
|
device0 = torch.device('cuda', int(device[3:] or 0))
|
||||||
name = torch.cuda.get_device_name(dev)
|
name = torch.cuda.get_device_name(device0)
|
||||||
self.logger.info("using GPU %s (%s) for model ocr:tr", dev, name)
|
self.logger.info("using GPU %s (%s) for model ocr:tr", device0, name)
|
||||||
except:
|
except:
|
||||||
self.logger.exception("cannot configure GPU device")
|
self.logger.exception("cannot configure GPU device")
|
||||||
dev = torch.device('cpu')
|
device0 = torch.device('cpu')
|
||||||
if dev.type == 'cuda':
|
if device0.type == 'cuda':
|
||||||
ret.to(dev)
|
model.to(device0)
|
||||||
else:
|
else:
|
||||||
self.logger.warning("no GPU device available")
|
self.logger.warning("no GPU device available")
|
||||||
return ret
|
return model
|
||||||
|
|
||||||
return self.load_model('ocr', model_variant=variant, device=device)
|
return self.load_model('ocr', model_variant=variant, device=device)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue