predictor: cache models' input shape instead of output shape

This commit is contained in:
Robert Sachunsky 2026-04-20 23:37:54 +02:00
parent 829256df91
commit 264b00f8ab
2 changed files with 17 additions and 24 deletions

View file

@ -193,11 +193,11 @@ class Eynollah:
for model in loadable:
# retrieve and cache input shapes
if model.endswith(('_resized', '_patched')):
# autosized models do not have a predefined output_shape
# autosized models do not have a predefined input_shape
# (and don't need one)
continue
self.logger.debug("model %s has output shape %s", model,
self.model_zoo.get(model).output_shape)
self.logger.debug("model %s has input shape %s", model,
self.model_zoo.get(model).input_shape)
def __del__(self):
if model_zoo := getattr(self, 'model_zoo', None):
@ -431,7 +431,7 @@ class Eynollah:
):
self.logger.debug("enter do_prediction (patches=%d)", patches)
_, img_height_model, img_width_model, _ = model.output_shape
_, img_height_model, img_width_model, _ = model.input_shape
img_h_page = img.shape[0]
img_w_page = img.shape[1]
@ -634,7 +634,7 @@ class Eynollah:
):
self.logger.debug("enter do_prediction_new_concept (patches=%d)", patches)
_, img_height_model, img_width_model, _ = model.output_shape
_, img_height_model, img_width_model, _ = model.input_shape
img = img / 255.0
img = img.astype(np.float16)
@ -1845,6 +1845,7 @@ class Eynollah:
# not trained on drops directly, but it does work:
polygons_of_drop_capitals,
text_regions_p,
n_batch_inference=1, # 3 (causes OOM on 8 GB GPUs)
# input labels as in run_boxes_full_layout
# output labels as in RO model's read_xml
label_text=1,
@ -1859,15 +1860,8 @@ class Eynollah:
# no drop-capital in RO model, yet
label_drop_ro=4,
):
# FIXME: use model.input_shape
height1 =672#448
width1 = 448#224
height2 =672#448
width2= 448#224
height3 =672#448
width3 = 448#224
model = self.model_zoo.get("reading_order")
_, height_model, width_model, _ = model.input_shape
ver_kernel = np.ones((5, 1), dtype=np.uint8)
hor_kernel = np.ones((1, 5), dtype=np.uint8)
@ -1961,16 +1955,15 @@ class Eynollah:
img = np.zeros(labels_con.shape[:2], dtype=np.uint8)
cv2.fillPoly(img, pts=[co_text_all[i] // 6], color=1)
labels_con[:, :, i] = img
labels_con = resize_image(labels_con.astype(np.uint8), height1, width1).astype(bool)
img_header_and_sep = resize_image(img_header_and_sep, height1, width1)
img_poly = resize_image(img_poly, height1, width1)
labels_con = resize_image(labels_con.astype(np.uint8), height_model, width_model).astype(bool)
img_header_and_sep = resize_image(img_header_and_sep, height_model, width_model)
img_poly = resize_image(img_poly, height_model, width_model)
labels_con[img_poly == label_seps_ro] = 2
labels_con[img_header_and_sep == 1] = 3
labels_con = labels_con / 3.
img_poly = img_poly / 5.
inference_bs = 1 # 3 (causes OOM on 8 GB GPUs)
input_1 = np.zeros((inference_bs, height1, width1, 3))
input_1 = np.zeros((n_batch_inference, height_model, width_model, 3))
ordered = [list(range(len(co_text_all)))]
index_update = 0
#print(labels_con.shape[2],"number of regions for reading order")
@ -1989,8 +1982,8 @@ class Eynollah:
tot_counter += 1
batch.append(j)
if tot_counter % inference_bs == 0 or tot_counter == len(ij_list):
y_pr = self.model_zoo.get("reading_order").predict(input_1 , verbose=0)
if tot_counter % n_batch_inference == 0 or tot_counter == len(ij_list):
y_pr = model.predict(input_1 , verbose=0)
for post_pr in y_pr:
if post_pr[0] >= 0.5:
post_list.append(j)

View file

@ -37,7 +37,7 @@ class Predictor(mp.context.SpawnProcess):
super().__init__(name="EynollahPredictor", daemon=True)
@property
def output_shape(self):
def input_shape(self):
return self({})
def predict(self, data: dict, verbose=0):
@ -122,7 +122,7 @@ class Predictor(mp.context.SpawnProcess):
REBATCH_SIZE = 1 # save VRAM; FIXME: re-enable w/ runtime parameter
if not len(shared_data):
#self.logger.debug("getting '%d' input shape of model '%s'", jobid, self.name)
result = self.model.output_shape
result = self.model.input_shape
self.resultq.put((jobid, result))
#self.logger.debug("sent result for '%d': %s", jobid, result)
else:
@ -137,7 +137,7 @@ class Predictor(mp.context.SpawnProcess):
tasks.append((jobid0, shared_data0))
else:
# immediately answer
self.resultq.put((jobid0, self.model.output_shape))
self.resultq.put((jobid0, self.model.input_shape))
if len(tasks) > 1:
self.logger.debug("rebatching %d '%s' tasks of batch size %d",
len(tasks), self.name, batch_size)