training.models.transformer_block: tf.reshape → Keras Reshape layer

This commit is contained in:
Robert Sachunsky 2026-05-19 03:17:31 +02:00
parent 9efce5e9f2
commit 86adaf299a
2 changed files with 8 additions and 8 deletions

View file

@ -309,11 +309,10 @@ def transformer_block(img,
# Skip connection 2. # Skip connection 2.
encoded_patches = Add()([x3, x2]) encoded_patches = Add()([x3, x2])
encoded_patches = tf.reshape(encoded_patches, encoded_patches = Reshape(target_shape=(img.shape[1],
[-1,
img.shape[1],
img.shape[2], img.shape[2],
projection_dim // (patchsize_x * patchsize_y)]) projection_dim // (patchsize_x * patchsize_y)),
name="reshape_patches")(encoded_patches)
return encoded_patches return encoded_patches
def vit_resnet50_unet(num_patches, def vit_resnet50_unet(num_patches,

View file

@ -26,16 +26,17 @@ RELOADABLE_MODELS = \
all: $(RELOADABLE_MODELS) all: $(RELOADABLE_MODELS)
$(MODELS_DST)/%: $(MODELS_SRC)/% $(MODELS_DST)/%: $(MODELS_SRC)/%
mkdir -p $@
test -e $</config.json || exit 1 test -e $</config.json || exit 1
eynollah-training train --force \ { mkdir -p $@ \
&& eynollah-training train --force \
with $</config.json \ with $</config.json \
reload_weights=True \ reload_weights=True \
continue_training=False \ continue_training=False \
dir_output=$(dir $@) \ dir_output=$(dir $@) \
dir_of_start_model=$< \ dir_of_start_model=$< \
&& cp $</config.json $@/config.json \
|| { rm -rf $@; false; }; } \
2>&1 | tee $(notdir $<).log 2>&1 | tee $(notdir $<).log
cp $</config.json $@/config.json
compare: compare:
for i in `find $(MODELS_DST) -mindepth 2`;do \ for i in `find $(MODELS_DST) -mindepth 2`;do \