diff --git a/src/eynollah/model_zoo/model_zoo.py b/src/eynollah/model_zoo/model_zoo.py index ec35a80..a1f9a24 100644 --- a/src/eynollah/model_zoo/model_zoo.py +++ b/src/eynollah/model_zoo/model_zoo.py @@ -191,9 +191,8 @@ class EynollahModelZoo: try: # avoid wasting VRAM on non-transformer models model = load_model(model_path, compile=False) - assert isinstance(model, KerasModel) model.make_predict_function() - except ValueError: + except (AttributeError, ValueError): model = tf.saved_model.load(model_path) model.predict_on_batch = model.serve model.input_shape = model.signatures.get('serving_default').inputs[0].shape