mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-04-30 19:22:03 +02:00
predictor: cache models' input shape instead of output shape
This commit is contained in:
parent
829256df91
commit
264b00f8ab
2 changed files with 17 additions and 24 deletions
|
|
@ -193,11 +193,11 @@ class Eynollah:
|
||||||
for model in loadable:
|
for model in loadable:
|
||||||
# retrieve and cache output shapes
|
# retrieve and cache output shapes
|
||||||
if model.endswith(('_resized', '_patched')):
|
if model.endswith(('_resized', '_patched')):
|
||||||
# autosized models do not have a predefined output_shape
|
# autosized models do not have a predefined input_shape
|
||||||
# (and don't need one)
|
# (and don't need one)
|
||||||
continue
|
continue
|
||||||
self.logger.debug("model %s has output shape %s", model,
|
self.logger.debug("model %s has input shape %s", model,
|
||||||
self.model_zoo.get(model).output_shape)
|
self.model_zoo.get(model).input_shape)
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
if model_zoo := getattr(self, 'model_zoo', None):
|
if model_zoo := getattr(self, 'model_zoo', None):
|
||||||
|
|
@ -431,7 +431,7 @@ class Eynollah:
|
||||||
):
|
):
|
||||||
|
|
||||||
self.logger.debug("enter do_prediction (patches=%d)", patches)
|
self.logger.debug("enter do_prediction (patches=%d)", patches)
|
||||||
_, img_height_model, img_width_model, _ = model.output_shape
|
_, img_height_model, img_width_model, _ = model.input_shape
|
||||||
img_h_page = img.shape[0]
|
img_h_page = img.shape[0]
|
||||||
img_w_page = img.shape[1]
|
img_w_page = img.shape[1]
|
||||||
|
|
||||||
|
|
@ -634,7 +634,7 @@ class Eynollah:
|
||||||
):
|
):
|
||||||
|
|
||||||
self.logger.debug("enter do_prediction_new_concept (patches=%d)", patches)
|
self.logger.debug("enter do_prediction_new_concept (patches=%d)", patches)
|
||||||
_, img_height_model, img_width_model, _ = model.output_shape
|
_, img_height_model, img_width_model, _ = model.input_shape
|
||||||
|
|
||||||
img = img / 255.0
|
img = img / 255.0
|
||||||
img = img.astype(np.float16)
|
img = img.astype(np.float16)
|
||||||
|
|
@ -1845,6 +1845,7 @@ class Eynollah:
|
||||||
# not trained on drops directly, but it does work:
|
# not trained on drops directly, but it does work:
|
||||||
polygons_of_drop_capitals,
|
polygons_of_drop_capitals,
|
||||||
text_regions_p,
|
text_regions_p,
|
||||||
|
n_batch_inference=1, # 3 (causes OOM on 8 GB GPUs)
|
||||||
# input labels as in run_boxes_full_layout
|
# input labels as in run_boxes_full_layout
|
||||||
# output labels as in RO model's read_xml
|
# output labels as in RO model's read_xml
|
||||||
label_text=1,
|
label_text=1,
|
||||||
|
|
@ -1859,15 +1860,8 @@ class Eynollah:
|
||||||
# no drop-capital in RO model, yet
|
# no drop-capital in RO model, yet
|
||||||
label_drop_ro=4,
|
label_drop_ro=4,
|
||||||
):
|
):
|
||||||
# FIXME: use model.input_shape
|
model = self.model_zoo.get("reading_order")
|
||||||
height1 =672#448
|
_, height_model, width_model, _ = model.input_shape
|
||||||
width1 = 448#224
|
|
||||||
|
|
||||||
height2 =672#448
|
|
||||||
width2= 448#224
|
|
||||||
|
|
||||||
height3 =672#448
|
|
||||||
width3 = 448#224
|
|
||||||
|
|
||||||
ver_kernel = np.ones((5, 1), dtype=np.uint8)
|
ver_kernel = np.ones((5, 1), dtype=np.uint8)
|
||||||
hor_kernel = np.ones((1, 5), dtype=np.uint8)
|
hor_kernel = np.ones((1, 5), dtype=np.uint8)
|
||||||
|
|
@ -1961,16 +1955,15 @@ class Eynollah:
|
||||||
img = np.zeros(labels_con.shape[:2], dtype=np.uint8)
|
img = np.zeros(labels_con.shape[:2], dtype=np.uint8)
|
||||||
cv2.fillPoly(img, pts=[co_text_all[i] // 6], color=1)
|
cv2.fillPoly(img, pts=[co_text_all[i] // 6], color=1)
|
||||||
labels_con[:, :, i] = img
|
labels_con[:, :, i] = img
|
||||||
labels_con = resize_image(labels_con.astype(np.uint8), height1, width1).astype(bool)
|
labels_con = resize_image(labels_con.astype(np.uint8), height_model, width_model).astype(bool)
|
||||||
img_header_and_sep = resize_image(img_header_and_sep, height1, width1)
|
img_header_and_sep = resize_image(img_header_and_sep, height_model, width_model)
|
||||||
img_poly = resize_image(img_poly, height1, width1)
|
img_poly = resize_image(img_poly, height_model, width_model)
|
||||||
labels_con[img_poly == label_seps_ro] = 2
|
labels_con[img_poly == label_seps_ro] = 2
|
||||||
labels_con[img_header_and_sep == 1] = 3
|
labels_con[img_header_and_sep == 1] = 3
|
||||||
labels_con = labels_con / 3.
|
labels_con = labels_con / 3.
|
||||||
img_poly = img_poly / 5.
|
img_poly = img_poly / 5.
|
||||||
|
|
||||||
inference_bs = 1 # 3 (causes OOM on 8 GB GPUs)
|
input_1 = np.zeros((n_batch_inference, height_model, width_model, 3))
|
||||||
input_1 = np.zeros((inference_bs, height1, width1, 3))
|
|
||||||
ordered = [list(range(len(co_text_all)))]
|
ordered = [list(range(len(co_text_all)))]
|
||||||
index_update = 0
|
index_update = 0
|
||||||
#print(labels_con.shape[2],"number of regions for reading order")
|
#print(labels_con.shape[2],"number of regions for reading order")
|
||||||
|
|
@ -1989,8 +1982,8 @@ class Eynollah:
|
||||||
|
|
||||||
tot_counter += 1
|
tot_counter += 1
|
||||||
batch.append(j)
|
batch.append(j)
|
||||||
if tot_counter % inference_bs == 0 or tot_counter == len(ij_list):
|
if tot_counter % n_batch_inference == 0 or tot_counter == len(ij_list):
|
||||||
y_pr = self.model_zoo.get("reading_order").predict(input_1 , verbose=0)
|
y_pr = model.predict(input_1 , verbose=0)
|
||||||
for post_pr in y_pr:
|
for post_pr in y_pr:
|
||||||
if post_pr[0] >= 0.5:
|
if post_pr[0] >= 0.5:
|
||||||
post_list.append(j)
|
post_list.append(j)
|
||||||
|
|
|
||||||
|
|
@ -37,7 +37,7 @@ class Predictor(mp.context.SpawnProcess):
|
||||||
super().__init__(name="EynollahPredictor", daemon=True)
|
super().__init__(name="EynollahPredictor", daemon=True)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def output_shape(self):
|
def input_shape(self):
|
||||||
return self({})
|
return self({})
|
||||||
|
|
||||||
def predict(self, data: dict, verbose=0):
|
def predict(self, data: dict, verbose=0):
|
||||||
|
|
@ -122,7 +122,7 @@ class Predictor(mp.context.SpawnProcess):
|
||||||
REBATCH_SIZE = 1 # save VRAM; FIXME: re-enable w/ runtime parameter
|
REBATCH_SIZE = 1 # save VRAM; FIXME: re-enable w/ runtime parameter
|
||||||
if not len(shared_data):
|
if not len(shared_data):
|
||||||
#self.logger.debug("getting '%d' output shape of model '%s'", jobid, self.name)
|
#self.logger.debug("getting '%d' output shape of model '%s'", jobid, self.name)
|
||||||
result = self.model.output_shape
|
result = self.model.input_shape
|
||||||
self.resultq.put((jobid, result))
|
self.resultq.put((jobid, result))
|
||||||
#self.logger.debug("sent result for '%d': %s", jobid, result)
|
#self.logger.debug("sent result for '%d': %s", jobid, result)
|
||||||
else:
|
else:
|
||||||
|
|
@ -137,7 +137,7 @@ class Predictor(mp.context.SpawnProcess):
|
||||||
tasks.append((jobid0, shared_data0))
|
tasks.append((jobid0, shared_data0))
|
||||||
else:
|
else:
|
||||||
# immediately answer
|
# immediately answer
|
||||||
self.resultq.put((jobid0, self.model.output_shape))
|
self.resultq.put((jobid0, self.model.input_shape))
|
||||||
if len(tasks) > 1:
|
if len(tasks) > 1:
|
||||||
self.logger.debug("rebatching %d '%s' tasks of batch size %d",
|
self.logger.debug("rebatching %d '%s' tasks of batch size %d",
|
||||||
len(tasks), self.name, batch_size)
|
len(tasks), self.name, batch_size)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue