predictor: cache models' input shape instead of output shape

This commit is contained in:
Robert Sachunsky 2026-04-20 23:37:54 +02:00
parent 829256df91
commit 264b00f8ab
2 changed files with 17 additions and 24 deletions

View file

@ -193,11 +193,11 @@ class Eynollah:
for model in loadable: for model in loadable:
# retrieve and cache output shapes # retrieve and cache input shapes
if model.endswith(('_resized', '_patched')): if model.endswith(('_resized', '_patched')):
# autosized models do not have a predefined output_shape # autosized models do not have a predefined input_shape
# (and don't need one) # (and don't need one)
continue continue
self.logger.debug("model %s has output shape %s", model, self.logger.debug("model %s has input shape %s", model,
self.model_zoo.get(model).output_shape) self.model_zoo.get(model).input_shape)
def __del__(self): def __del__(self):
if model_zoo := getattr(self, 'model_zoo', None): if model_zoo := getattr(self, 'model_zoo', None):
@ -431,7 +431,7 @@ class Eynollah:
): ):
self.logger.debug("enter do_prediction (patches=%d)", patches) self.logger.debug("enter do_prediction (patches=%d)", patches)
_, img_height_model, img_width_model, _ = model.output_shape _, img_height_model, img_width_model, _ = model.input_shape
img_h_page = img.shape[0] img_h_page = img.shape[0]
img_w_page = img.shape[1] img_w_page = img.shape[1]
@ -634,7 +634,7 @@ class Eynollah:
): ):
self.logger.debug("enter do_prediction_new_concept (patches=%d)", patches) self.logger.debug("enter do_prediction_new_concept (patches=%d)", patches)
_, img_height_model, img_width_model, _ = model.output_shape _, img_height_model, img_width_model, _ = model.input_shape
img = img / 255.0 img = img / 255.0
img = img.astype(np.float16) img = img.astype(np.float16)
@ -1845,6 +1845,7 @@ class Eynollah:
# not trained on drops directly, but it does work: # not trained on drops directly, but it does work:
polygons_of_drop_capitals, polygons_of_drop_capitals,
text_regions_p, text_regions_p,
n_batch_inference=1, # 3 (causes OOM on 8 GB GPUs)
# input labels as in run_boxes_full_layout # input labels as in run_boxes_full_layout
# output labels as in RO model's read_xml # output labels as in RO model's read_xml
label_text=1, label_text=1,
@ -1859,15 +1860,8 @@ class Eynollah:
# no drop-capital in RO model, yet # no drop-capital in RO model, yet
label_drop_ro=4, label_drop_ro=4,
): ):
# FIXME: use model.input_shape model = self.model_zoo.get("reading_order")
height1 =672#448 _, height_model, width_model, _ = model.input_shape
width1 = 448#224
height2 =672#448
width2= 448#224
height3 =672#448
width3 = 448#224
ver_kernel = np.ones((5, 1), dtype=np.uint8) ver_kernel = np.ones((5, 1), dtype=np.uint8)
hor_kernel = np.ones((1, 5), dtype=np.uint8) hor_kernel = np.ones((1, 5), dtype=np.uint8)
@ -1961,16 +1955,15 @@ class Eynollah:
img = np.zeros(labels_con.shape[:2], dtype=np.uint8) img = np.zeros(labels_con.shape[:2], dtype=np.uint8)
cv2.fillPoly(img, pts=[co_text_all[i] // 6], color=1) cv2.fillPoly(img, pts=[co_text_all[i] // 6], color=1)
labels_con[:, :, i] = img labels_con[:, :, i] = img
labels_con = resize_image(labels_con.astype(np.uint8), height1, width1).astype(bool) labels_con = resize_image(labels_con.astype(np.uint8), height_model, width_model).astype(bool)
img_header_and_sep = resize_image(img_header_and_sep, height1, width1) img_header_and_sep = resize_image(img_header_and_sep, height_model, width_model)
img_poly = resize_image(img_poly, height1, width1) img_poly = resize_image(img_poly, height_model, width_model)
labels_con[img_poly == label_seps_ro] = 2 labels_con[img_poly == label_seps_ro] = 2
labels_con[img_header_and_sep == 1] = 3 labels_con[img_header_and_sep == 1] = 3
labels_con = labels_con / 3. labels_con = labels_con / 3.
img_poly = img_poly / 5. img_poly = img_poly / 5.
inference_bs = 1 # 3 (causes OOM on 8 GB GPUs) input_1 = np.zeros((n_batch_inference, height_model, width_model, 3))
input_1 = np.zeros((inference_bs, height1, width1, 3))
ordered = [list(range(len(co_text_all)))] ordered = [list(range(len(co_text_all)))]
index_update = 0 index_update = 0
#print(labels_con.shape[2],"number of regions for reading order") #print(labels_con.shape[2],"number of regions for reading order")
@ -1989,8 +1982,8 @@ class Eynollah:
tot_counter += 1 tot_counter += 1
batch.append(j) batch.append(j)
if tot_counter % inference_bs == 0 or tot_counter == len(ij_list): if tot_counter % n_batch_inference == 0 or tot_counter == len(ij_list):
y_pr = self.model_zoo.get("reading_order").predict(input_1 , verbose=0) y_pr = model.predict(input_1 , verbose=0)
for post_pr in y_pr: for post_pr in y_pr:
if post_pr[0] >= 0.5: if post_pr[0] >= 0.5:
post_list.append(j) post_list.append(j)

View file

@ -37,7 +37,7 @@ class Predictor(mp.context.SpawnProcess):
super().__init__(name="EynollahPredictor", daemon=True) super().__init__(name="EynollahPredictor", daemon=True)
@property @property
def output_shape(self): def input_shape(self):
return self({}) return self({})
def predict(self, data: dict, verbose=0): def predict(self, data: dict, verbose=0):
@ -122,7 +122,7 @@ class Predictor(mp.context.SpawnProcess):
REBATCH_SIZE = 1 # save VRAM; FIXME: re-enable w/ runtime parameter REBATCH_SIZE = 1 # save VRAM; FIXME: re-enable w/ runtime parameter
if not len(shared_data): if not len(shared_data):
#self.logger.debug("getting '%d' output shape of model '%s'", jobid, self.name) #self.logger.debug("getting '%d' input shape of model '%s'", jobid, self.name)
result = self.model.output_shape result = self.model.input_shape
self.resultq.put((jobid, result)) self.resultq.put((jobid, result))
#self.logger.debug("sent result for '%d': %s", jobid, result) #self.logger.debug("sent result for '%d': %s", jobid, result)
else: else:
@ -137,7 +137,7 @@ class Predictor(mp.context.SpawnProcess):
tasks.append((jobid0, shared_data0)) tasks.append((jobid0, shared_data0))
else: else:
# immediately answer # immediately answer
self.resultq.put((jobid0, self.model.output_shape)) self.resultq.put((jobid0, self.model.input_shape))
if len(tasks) > 1: if len(tasks) > 1:
self.logger.debug("rebatching %d '%s' tasks of batch size %d", self.logger.debug("rebatching %d '%s' tasks of batch size %d",
len(tasks), self.name, batch_size) len(tasks), self.name, batch_size)