Mirror of https://github.com/qurator-spk/eynollah.git
predictor: cache models' input shape instead of output shape
parent 829256df91
commit 264b00f8ab

2 changed files with 17 additions and 24 deletions
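The rename is not cosmetic: preprocessing has to produce tensors of the size a model expects, and that is its input shape, not its output shape. For a fixed-size Keras model both attributes are NHWC tuples, as in this minimal sketch (a toy model, not one of eynollah's):

# Toy fixed-size Keras model: input_shape is what preprocessing must
# match; output_shape can legitimately differ (here, 4 label channels).
from tensorflow import keras

model = keras.Sequential([
    keras.layers.Input(shape=(448, 448, 3)),
    keras.layers.Conv2D(4, 3, padding="same", activation="softmax"),
])
print(model.input_shape)   # (None, 448, 448, 3)
print(model.output_shape)  # (None, 448, 448, 4)
_, height_model, width_model, _ = model.input_shape
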
@@ -193,11 +193,11 @@ class Eynollah:
         for model in loadable:
             # retrieve and cache output shapes
             if model.endswith(('_resized', '_patched')):
-                # autosized models do not have a predefined output_shape
+                # autosized models do not have a predefined input_shape
                 # (and don't need one)
                 continue
-            self.logger.debug("model %s has output shape %s", model,
-                              self.model_zoo.get(model).output_shape)
+            self.logger.debug("model %s has input shape %s", model,
+                              self.model_zoo.get(model).input_shape)

     def __del__(self):
         if model_zoo := getattr(self, 'model_zoo', None):

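The continue guard exists because fully convolutional ("autosized") models are built with unspecified spatial dimensions, so there is no fixed shape worth logging or caching. A sketch of what such a model reports (hypothetical, for illustration only):

# A hypothetical autosized model: H and W come back as None,
# so a fixed input shape cannot be cached (and isn't needed).
from tensorflow import keras

autosized = keras.Sequential([
    keras.layers.Input(shape=(None, None, 3)),
    keras.layers.Conv2D(4, 3, padding="same"),
])
print(autosized.input_shape)  # (None, None, None, 3)
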
@@ -431,7 +431,7 @@ class Eynollah:
     ):

         self.logger.debug("enter do_prediction (patches=%d)", patches)
-        _, img_height_model, img_width_model, _ = model.output_shape
+        _, img_height_model, img_width_model, _ = model.input_shape
         img_h_page = img.shape[0]
         img_w_page = img.shape[1]

@@ -634,7 +634,7 @@ class Eynollah:
     ):

         self.logger.debug("enter do_prediction_new_concept (patches=%d)", patches)
-        _, img_height_model, img_width_model, _ = model.output_shape
+        _, img_height_model, img_width_model, _ = model.input_shape

         img = img / 255.0
         img = img.astype(np.float16)

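In both prediction methods the unpacked height and width drive how the page image is cut into model-sized patches. A simplified sketch of that pattern (not eynollah's actual tiling code, which also handles overlap margins):

# Simplified patch-wise prediction: cut the page into patches of
# exactly the model's input size, pad the right/bottom edges, and
# stitch the argmax labels back together.
import numpy as np

def predict_in_patches(model, img):
    _, h_model, w_model, _ = model.input_shape
    out = np.zeros(img.shape[:2], dtype=np.uint8)
    for y in range(0, img.shape[0], h_model):
        for x in range(0, img.shape[1], w_model):
            patch = img[y:y + h_model, x:x + w_model]
            pad_y = h_model - patch.shape[0]
            pad_x = w_model - patch.shape[1]
            patch = np.pad(patch, ((0, pad_y), (0, pad_x), (0, 0)))
            pred = model.predict(patch[np.newaxis] / 255.0, verbose=0)
            labels = np.argmax(pred[0], axis=-1).astype(np.uint8)
            out[y:y + h_model - pad_y, x:x + w_model - pad_x] = \
                labels[:h_model - pad_y, :w_model - pad_x]
    return out
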
@@ -1845,6 +1845,7 @@ class Eynollah:
         # not trained on drops directly, but it does work:
         polygons_of_drop_capitals,
         text_regions_p,
+        n_batch_inference=1, # 3 (causes OOM on 8 GB GPUs)
         # input labels as in run_boxes_full_layout
         # output labels as in RO model's read_xml
         label_text=1,

@@ -1859,15 +1860,8 @@ class Eynollah:
         # no drop-capital in RO model, yet
         label_drop_ro=4,
     ):
-        # FIXME: use model.input_shape
-        height1 =672#448
-        width1 = 448#224
-
-        height2 =672#448
-        width2= 448#224
-
-        height3 =672#448
-        width3 = 448#224
+        model = self.model_zoo.get("reading_order")
+        _, height_model, width_model, _ = model.input_shape

         ver_kernel = np.ones((5, 1), dtype=np.uint8)
         hor_kernel = np.ones((1, 5), dtype=np.uint8)

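This hunk and the next retire the thrice-duplicated hard-coded 672x448 constants (the old FIXME) in favor of the reading-order model's own input_shape, so a retrained RO model with a different input size needs no source edit. The downstream arrays are simply resized to those dimensions; a sketch assuming a cv2-based resize_image like the helper in eynollah's utils:

# Sketch only: assumed cv2-based resize_image(img, height, width).
import cv2
import numpy as np

def resize_image(img, height, width):
    # cv2.resize takes (width, height); nearest keeps label values intact
    return cv2.resize(img, (width, height), interpolation=cv2.INTER_NEAREST)

_, height_model, width_model, _ = (None, 672, 448, None)  # stand-in for model.input_shape
img_poly = np.zeros((2200, 1600), dtype=np.uint8)         # dummy page-sized label image
img_poly = resize_image(img_poly, height_model, width_model)
assert img_poly.shape == (672, 448)
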
@@ -1961,16 +1955,15 @@ class Eynollah:
             img = np.zeros(labels_con.shape[:2], dtype=np.uint8)
             cv2.fillPoly(img, pts=[co_text_all[i] // 6], color=1)
             labels_con[:, :, i] = img
-        labels_con = resize_image(labels_con.astype(np.uint8), height1, width1).astype(bool)
-        img_header_and_sep = resize_image(img_header_and_sep, height1, width1)
-        img_poly = resize_image(img_poly, height1, width1)
+        labels_con = resize_image(labels_con.astype(np.uint8), height_model, width_model).astype(bool)
+        img_header_and_sep = resize_image(img_header_and_sep, height_model, width_model)
+        img_poly = resize_image(img_poly, height_model, width_model)
         labels_con[img_poly == label_seps_ro] = 2
         labels_con[img_header_and_sep == 1] = 3
         labels_con = labels_con / 3.
         img_poly = img_poly / 5.

-        inference_bs = 1 # 3 (causes OOM on 8 GB GPUs)
-        input_1 = np.zeros((inference_bs, height1, width1, 3))
+        input_1 = np.zeros((n_batch_inference, height_model, width_model, 3))
         ordered = [list(range(len(co_text_all)))]
         index_update = 0
         #print(labels_con.shape[2],"number of regions for reading order")

@@ -1989,8 +1982,8 @@ class Eynollah:

                 tot_counter += 1
                 batch.append(j)
-                if tot_counter % inference_bs == 0 or tot_counter == len(ij_list):
-                    y_pr = self.model_zoo.get("reading_order").predict(input_1 , verbose=0)
+                if tot_counter % n_batch_inference == 0 or tot_counter == len(ij_list):
+                    y_pr = model.predict(input_1 , verbose=0)
                     for post_pr in y_pr:
                         if post_pr[0] >= 0.5:
                             post_list.append(j)

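The loop flushes a fixed-size batch buffer whenever n_batch_inference items have accumulated, or when the item list runs out. The same pattern in isolation (hypothetical run_batched helper; in eynollah, input_1 is filled with region-pair images):

# Hypothetical illustration of the modulo-flush batching pattern.
import numpy as np

def run_batched(model, items, n_batch_inference, height, width):
    input_1 = np.zeros((n_batch_inference, height, width, 3))
    results, batch = [], []
    for count, item in enumerate(items, start=1):
        input_1[len(batch)] = item          # item: (height, width, 3) array
        batch.append(item)
        if count % n_batch_inference == 0 or count == len(items):
            y_pr = model.predict(input_1[:len(batch)], verbose=0)
            results.extend(y_pr)
            batch.clear()
    return results
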
@@ -37,7 +37,7 @@ class Predictor(mp.context.SpawnProcess):
         super().__init__(name="EynollahPredictor", daemon=True)

     @property
-    def output_shape(self):
+    def input_shape(self):
         return self({})

     def predict(self, data: dict, verbose=0):

@@ -122,7 +122,7 @@ class Predictor(mp.context.SpawnProcess):
         REBATCH_SIZE = 1 # save VRAM; FIXME: re-enable w/ runtime parameter
         if not len(shared_data):
             #self.logger.debug("getting '%d' output shape of model '%s'", jobid, self.name)
-            result = self.model.output_shape
+            result = self.model.input_shape
             self.resultq.put((jobid, result))
             #self.logger.debug("sent result for '%d': %s", jobid, result)
         else:

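On the Predictor side, input_shape is a property that evaluates self({}): an empty dict is the sentinel telling the subprocess to answer with the wrapped model's input_shape instead of running inference. A simplified sketch of that query protocol (plain queues standing in for the real shared-memory transport):

# Simplified stand-in for the Predictor subprocess protocol:
# an empty payload means "shape query", anything else means inference.
import multiprocessing as mp

def worker(input_shape, taskq, resultq):
    while True:
        jobid, shared_data = taskq.get()
        if not len(shared_data):
            resultq.put((jobid, input_shape))          # metadata query
        else:
            resultq.put((jobid, "run inference here")) # normal task

if __name__ == "__main__":
    taskq, resultq = mp.Queue(), mp.Queue()
    proc = mp.Process(target=worker, args=((None, 448, 448, 3), taskq, resultq),
                      daemon=True)
    proc.start()
    taskq.put((0, {}))    # what the input_shape property's self({}) boils down to
    print(resultq.get())  # -> (0, (None, 448, 448, 3))
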
@@ -137,7 +137,7 @@ class Predictor(mp.context.SpawnProcess):
                 tasks.append((jobid0, shared_data0))
             else:
                 # immediately anser
-                self.resultq.put((jobid0, self.model.output_shape))
+                self.resultq.put((jobid0, self.model.input_shape))
             if len(tasks) > 1:
                 self.logger.debug("rebatching %d '%s' tasks of batch size %d",
                                   len(tasks), self.name, batch_size)