reload_weights: save()export() w/ serve() inference

This commit is contained in:
Robert Sachunsky 2026-05-19 03:20:24 +02:00
parent 86adaf299a
commit bdfebd2c70
2 changed files with 11 additions and 10 deletions

View file

@ -191,14 +191,12 @@ class EynollahModelZoo:
try:
# avoid wasting VRAM on non-transformer models
model = load_model(model_path, compile=False)
except Exception as e:
self.logger.error(e)
model = load_model(
model_path, compile=False,
custom_objects=dict(PatchEncoder=PatchEncoder,
Patches=Patches))
model.make_predict_function()
assert isinstance(model, KerasModel)
model.make_predict_function()
except ValueError:
model = tf.saved_model.load(model_path)
model.predict_on_batch = model.serve
model.input_shape = model.signatures.get('serving_default').inputs[0].shape
model._name = model_category
if resized:
model = wrap_layout_model_resized(model)

View file

@ -562,7 +562,8 @@ def run(_config,
if reload_weights:
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model)))
model.save(dir_save, include_optimizer=False)
#model.save(dir_save, include_optimizer=False)
model.export(dir_save)
with open(os.path.join(dir_save, "config.json"), "w") as fp:
json.dump(_config, fp) # encode dict into JSON
_log.info("reloaded model from %s to %s", dir_of_start_model, dir_save)
@ -725,7 +726,8 @@ def run(_config,
if reload_weights:
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model)))
model.save(dir_save, include_optimizer=False)
#model.save(dir_save, include_optimizer=False)
model.export(dir_save)
with open(os.path.join(dir_save, "config.json"), "w") as fp:
json.dump(_config, fp) # encode dict into JSON
_log.info("reloaded model from %s to %s", dir_of_start_model, dir_save)
@ -843,7 +845,8 @@ def run(_config,
if reload_weights:
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model)))
model.save(dir_save, include_optimizer=False)
#model.save(dir_save, include_optimizer=False)
model.export(dir_save)
with open(os.path.join(dir_save, "config.json"), "w") as fp:
json.dump(_config, fp) # encode dict into JSON
_log.info("reloaded model from %s to %s", dir_of_start_model, dir_save)