mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-05-26 07:39:22 +02:00
reload_weights: save() → export() w/ serve() inference
This commit is contained in:
parent
86adaf299a
commit
bdfebd2c70
2 changed files with 11 additions and 10 deletions
|
|
@ -191,14 +191,12 @@ class EynollahModelZoo:
|
|||
try:
|
||||
# avoid wasting VRAM on non-transformer models
|
||||
model = load_model(model_path, compile=False)
|
||||
except Exception as e:
|
||||
self.logger.error(e)
|
||||
model = load_model(
|
||||
model_path, compile=False,
|
||||
custom_objects=dict(PatchEncoder=PatchEncoder,
|
||||
Patches=Patches))
|
||||
model.make_predict_function()
|
||||
assert isinstance(model, KerasModel)
|
||||
model.make_predict_function()
|
||||
except ValueError:
|
||||
model = tf.saved_model.load(model_path)
|
||||
model.predict_on_batch = model.serve
|
||||
model.input_shape = model.signatures.get('serving_default').inputs[0].shape
|
||||
model._name = model_category
|
||||
if resized:
|
||||
model = wrap_layout_model_resized(model)
|
||||
|
|
|
|||
|
|
@ -562,7 +562,8 @@ def run(_config,
|
|||
if reload_weights:
|
||||
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
|
||||
dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model)))
|
||||
model.save(dir_save, include_optimizer=False)
|
||||
#model.save(dir_save, include_optimizer=False)
|
||||
model.export(dir_save)
|
||||
with open(os.path.join(dir_save, "config.json"), "w") as fp:
|
||||
json.dump(_config, fp) # encode dict into JSON
|
||||
_log.info("reloaded model from %s to %s", dir_of_start_model, dir_save)
|
||||
|
|
@ -725,7 +726,8 @@ def run(_config,
|
|||
if reload_weights:
|
||||
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
|
||||
dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model)))
|
||||
model.save(dir_save, include_optimizer=False)
|
||||
#model.save(dir_save, include_optimizer=False)
|
||||
model.export(dir_save)
|
||||
with open(os.path.join(dir_save, "config.json"), "w") as fp:
|
||||
json.dump(_config, fp) # encode dict into JSON
|
||||
_log.info("reloaded model from %s to %s", dir_of_start_model, dir_save)
|
||||
|
|
@ -843,7 +845,8 @@ def run(_config,
|
|||
if reload_weights:
|
||||
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
|
||||
dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model)))
|
||||
model.save(dir_save, include_optimizer=False)
|
||||
#model.save(dir_save, include_optimizer=False)
|
||||
model.export(dir_save)
|
||||
with open(os.path.join(dir_save, "config.json"), "w") as fp:
|
||||
json.dump(_config, fp) # encode dict into JSON
|
||||
_log.info("reloaded model from %s to %s", dir_of_start_model, dir_save)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue