diff --git a/src/eynollah/model_zoo/model_zoo.py b/src/eynollah/model_zoo/model_zoo.py index 3de8b6b..815663e 100644 --- a/src/eynollah/model_zoo/model_zoo.py +++ b/src/eynollah/model_zoo/model_zoo.py @@ -191,14 +191,12 @@ class EynollahModelZoo: try: # avoid wasting VRAM on non-transformer models model = load_model(model_path, compile=False) - except Exception as e: - self.logger.error(e) - model = load_model( - model_path, compile=False, - custom_objects=dict(PatchEncoder=PatchEncoder, - Patches=Patches)) + assert isinstance(model, KerasModel) model.make_predict_function() - assert isinstance(model, KerasModel) + except ValueError: + model = tf.saved_model.load(model_path) + model.predict_on_batch = model.serve + model.input_shape = model.signatures.get('serving_default').inputs[0].shape model._name = model_category if resized: model = wrap_layout_model_resized(model) diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py index de998fd..00ed6ee 100644 --- a/src/eynollah/training/train.py +++ b/src/eynollah/training/train.py @@ -562,7 +562,8 @@ def run(_config, if reload_weights: model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial() dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model))) - model.save(dir_save, include_optimizer=False) + #model.save(dir_save, include_optimizer=False) + model.export(dir_save) with open(os.path.join(dir_save, "config.json"), "w") as fp: json.dump(_config, fp) # encode dict into JSON _log.info("reloaded model from %s to %s", dir_of_start_model, dir_save) @@ -725,7 +726,8 @@ def run(_config, if reload_weights: model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial() dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model))) - model.save(dir_save, include_optimizer=False) + #model.save(dir_save, include_optimizer=False) + model.export(dir_save) with open(os.path.join(dir_save, "config.json"), "w") as fp: json.dump(_config, fp) # encode dict into JSON _log.info("reloaded model from %s to %s", dir_of_start_model, dir_save) @@ -843,7 +845,8 @@ def run(_config, if reload_weights: model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial() dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model))) - model.save(dir_save, include_optimizer=False) + #model.save(dir_save, include_optimizer=False) + model.export(dir_save) with open(os.path.join(dir_save, "config.json"), "w") as fp: json.dump(_config, fp) # encode dict into JSON _log.info("reloaded model from %s to %s", dir_of_start_model, dir_save)