diff --git a/tests/test_model_zoo.py b/tests/test_model_zoo.py index 2042b28..9d37431 100644 --- a/tests/test_model_zoo.py +++ b/tests/test_model_zoo.py @@ -6,11 +6,11 @@ def test_trocr1( model_zoo = EynollahModelZoo(model_dir) try: from transformers import TrOCRProcessor, VisionEncoderDecoderModel - model_zoo.load_model('trocr_processor') - proc = model_zoo.get('trocr_processor', TrOCRProcessor) + model_zoo.load_models('trocr_processor') + proc = model_zoo.get('trocr_processor') assert isinstance(proc, TrOCRProcessor) - model_zoo.load_model('ocr', 'tr') - model = model_zoo.get('ocr', VisionEncoderDecoderModel) + model_zoo.load_models(['ocr', 'tr']) + model = model_zoo.get('ocr') assert isinstance(model, VisionEncoderDecoderModel) except ImportError: pass