From 264b00f8ab24c7ff1e35a89961f6a89c59da24b2 Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Mon, 20 Apr 2026 23:37:54 +0200 Subject: [PATCH] predictor: cache models' input shape instead of output shape --- src/eynollah/eynollah.py | 35 ++++++++++++++--------------------- src/eynollah/predictor.py | 6 +++--- 2 files changed, 17 insertions(+), 24 deletions(-) diff --git a/src/eynollah/eynollah.py b/src/eynollah/eynollah.py index 1a9aa2b..df98c19 100644 --- a/src/eynollah/eynollah.py +++ b/src/eynollah/eynollah.py @@ -193,11 +193,11 @@ class Eynollah: for model in loadable: # retrieve and cache output shapes if model.endswith(('_resized', '_patched')): - # autosized models do not have a predefined output_shape + # autosized models do not have a predefined input_shape # (and don't need one) continue - self.logger.debug("model %s has output shape %s", model, - self.model_zoo.get(model).output_shape) + self.logger.debug("model %s has input shape %s", model, + self.model_zoo.get(model).input_shape) def __del__(self): if model_zoo := getattr(self, 'model_zoo', None): @@ -431,7 +431,7 @@ class Eynollah: ): self.logger.debug("enter do_prediction (patches=%d)", patches) - _, img_height_model, img_width_model, _ = model.output_shape + _, img_height_model, img_width_model, _ = model.input_shape img_h_page = img.shape[0] img_w_page = img.shape[1] @@ -634,7 +634,7 @@ class Eynollah: ): self.logger.debug("enter do_prediction_new_concept (patches=%d)", patches) - _, img_height_model, img_width_model, _ = model.output_shape + _, img_height_model, img_width_model, _ = model.input_shape img = img / 255.0 img = img.astype(np.float16) @@ -1845,6 +1845,7 @@ class Eynollah: # not trained on drops directly, but it does work: polygons_of_drop_capitals, text_regions_p, + n_batch_inference=1, # 3 (causes OOM on 8 GB GPUs) # input labels as in run_boxes_full_layout # output labels as in RO model's read_xml label_text=1, @@ -1859,15 +1860,8 @@ class Eynollah: # no drop-capital in RO model, yet label_drop_ro=4, ): - # FIXME: use model.input_shape - height1 =672#448 - width1 = 448#224 - - height2 =672#448 - width2= 448#224 - - height3 =672#448 - width3 = 448#224 + model = self.model_zoo.get("reading_order") + _, height_model, width_model, _ = model.input_shape ver_kernel = np.ones((5, 1), dtype=np.uint8) hor_kernel = np.ones((1, 5), dtype=np.uint8) @@ -1961,16 +1955,15 @@ class Eynollah: img = np.zeros(labels_con.shape[:2], dtype=np.uint8) cv2.fillPoly(img, pts=[co_text_all[i] // 6], color=1) labels_con[:, :, i] = img - labels_con = resize_image(labels_con.astype(np.uint8), height1, width1).astype(bool) - img_header_and_sep = resize_image(img_header_and_sep, height1, width1) - img_poly = resize_image(img_poly, height1, width1) + labels_con = resize_image(labels_con.astype(np.uint8), height_model, width_model).astype(bool) + img_header_and_sep = resize_image(img_header_and_sep, height_model, width_model) + img_poly = resize_image(img_poly, height_model, width_model) labels_con[img_poly == label_seps_ro] = 2 labels_con[img_header_and_sep == 1] = 3 labels_con = labels_con / 3. img_poly = img_poly / 5. - inference_bs = 1 # 3 (causes OOM on 8 GB GPUs) - input_1 = np.zeros((inference_bs, height1, width1, 3)) + input_1 = np.zeros((n_batch_inference, height_model, width_model, 3)) ordered = [list(range(len(co_text_all)))] index_update = 0 #print(labels_con.shape[2],"number of regions for reading order") @@ -1989,8 +1982,8 @@ class Eynollah: tot_counter += 1 batch.append(j) - if tot_counter % inference_bs == 0 or tot_counter == len(ij_list): - y_pr = self.model_zoo.get("reading_order").predict(input_1 , verbose=0) + if tot_counter % n_batch_inference == 0 or tot_counter == len(ij_list): + y_pr = model.predict(input_1 , verbose=0) for post_pr in y_pr: if post_pr[0] >= 0.5: post_list.append(j) diff --git a/src/eynollah/predictor.py b/src/eynollah/predictor.py index 2ab62a0..e1159e7 100644 --- a/src/eynollah/predictor.py +++ b/src/eynollah/predictor.py @@ -37,7 +37,7 @@ class Predictor(mp.context.SpawnProcess): super().__init__(name="EynollahPredictor", daemon=True) @property - def output_shape(self): + def input_shape(self): return self({}) def predict(self, data: dict, verbose=0): @@ -122,7 +122,7 @@ class Predictor(mp.context.SpawnProcess): REBATCH_SIZE = 1 # save VRAM; FIXME: re-enable w/ runtime parameter if not len(shared_data): #self.logger.debug("getting '%d' output shape of model '%s'", jobid, self.name) - result = self.model.output_shape + result = self.model.input_shape self.resultq.put((jobid, result)) #self.logger.debug("sent result for '%d': %s", jobid, result) else: @@ -137,7 +137,7 @@ class Predictor(mp.context.SpawnProcess): tasks.append((jobid0, shared_data0)) else: # immediately anser - self.resultq.put((jobid0, self.model.output_shape)) + self.resultq.put((jobid0, self.model.input_shape)) if len(tasks) > 1: self.logger.debug("rebatching %d '%s' tasks of batch size %d", len(tasks), self.name, batch_size)