ModelZoo: fix Torch device selection

This commit is contained in:
Robert Sachunsky 2026-05-21 17:25:53 +02:00
parent 000e4ac8d8
commit 074753a98e

View file

@ -247,9 +247,9 @@ class EynollahModelZoo:
if variant == 'tr': if variant == 'tr':
from transformers import VisionEncoderDecoderModel from transformers import VisionEncoderDecoderModel
import torch import torch
ret = VisionEncoderDecoderModel.from_pretrained(model_dir) model = VisionEncoderDecoderModel.from_pretrained(model_dir)
assert isinstance(ret, VisionEncoderDecoderModel) assert isinstance(model, VisionEncoderDecoderModel)
dev = torch.device('cpu') device0 = torch.device('cpu')
if not device and torch.cuda.is_available(): if not device and torch.cuda.is_available():
device = 'GPU' # try device = 'GPU' # try
if device and ':' in device: if device and ':' in device:
@ -260,17 +260,17 @@ class EynollahModelZoo:
break break
if device and device.startswith('GPU'): if device and device.startswith('GPU'):
try: try:
dev = torch.device('cuda', int(device[3:] or 0)) device0 = torch.device('cuda', int(device[3:] or 0))
name = torch.cuda.get_device_name(dev) name = torch.cuda.get_device_name(device0)
self.logger.info("using GPU %s (%s) for model ocr:tr", dev, name) self.logger.info("using GPU %s (%s) for model ocr:tr", device0, name)
except: except:
self.logger.exception("cannot configure GPU device") self.logger.exception("cannot configure GPU device")
dev = torch.device('cpu') device0 = torch.device('cpu')
if dev.type == 'cuda': if device0.type == 'cuda':
ret.to(dev) model.to(device0)
else: else:
self.logger.warning("no GPU device available") self.logger.warning("no GPU device available")
return ret return model
return self.load_model('ocr', model_variant=variant, device=device) return self.load_model('ocr', model_variant=variant, device=device)