get_textlines_of_a_textregion_sorted: simplify

This commit is contained in:
Robert Sachunsky 2026-04-23 23:45:27 +02:00
parent be61875d6e
commit 998ee2ecee

View file

@ -71,6 +71,7 @@ from .utils.resize import resize_image
from .utils.shm import share_ndarray from .utils.shm import share_ndarray
from .utils import ( from .utils import (
ensure_array, ensure_array,
pairwise,
is_image_filename, is_image_filename,
isNaN, isNaN,
crop_image_inside_box, crop_image_inside_box,
@ -929,71 +930,56 @@ class Eynollah:
def get_textlines_of_a_textregion_sorted(self, textlines_textregion, cx_textline, cy_textline, w_h_textline): def get_textlines_of_a_textregion_sorted(self, textlines_textregion, cx_textline, cy_textline, w_h_textline):
N = len(cy_textline) N = len(cy_textline)
if N==0: if N <= 1:
return [] return textlines_textregion
diff_cy = np.abs( np.diff(sorted(cy_textline)) ) cx_textline = np.array(cx_textline)
diff_cx = np.abs(np.diff(sorted(cx_textline)) ) cy_textline = np.array(cy_textline)
diff_cy = np.abs(np.diff(np.sort(cy_textline)))
diff_cx = np.abs(np.diff(np.sort(cx_textline)))
if N > 1:
if len(diff_cy)>0: mean_y_diff = np.median(diff_cy)
mean_y_diff = np.mean(diff_cy) mean_x_diff = np.median(diff_cx)
mean_x_diff = np.mean(diff_cx)
count_hor = np.count_nonzero(np.diff(w_h_textline) > 0) count_hor = np.count_nonzero(np.diff(w_h_textline) > 0)
count_ver = len(w_h_textline) - count_hor count_ver = N - count_hor
else: else:
mean_y_diff = 0 mean_y_diff = 0
mean_x_diff = 0 mean_x_diff = 0
count_hor = 1 count_hor = 1
count_ver = 0 count_ver = 0
if count_hor >= count_ver: if count_hor >= count_ver:
row_threshold = mean_y_diff / 1.5 if mean_y_diff > 0 else 10 row_threshold = mean_y_diff / 1.5 if mean_y_diff > 0 else 10
indices_sorted_by_y = sorted(range(N), key=lambda i: cy_textline[i])
rows = [] rows = []
current_row = [indices_sorted_by_y[0]] for prev_idx, curr_idx in pairwise(np.argsort(cy_textline)):
for i in range(1, N): if not len(rows):
current_idx = indices_sorted_by_y[i] rows.append([prev_idx])
prev_idx = current_row[0] if abs(cy_textline[curr_idx] - cy_textline[prev_idx]) <= row_threshold:
if abs(cy_textline[current_idx] - cy_textline[prev_idx]) <= row_threshold: rows[-1].append(curr_idx)
current_row.append(current_idx)
else: else:
rows.append(current_row) rows.append([curr_idx])
current_row = [current_idx]
rows.append(current_row)
sorted_textlines = [] sorted_textlines = []
for row in rows: for row in rows:
row_sorted = sorted(row, key=lambda i: cx_textline[i]) for idx in np.argsort(cx_textline[row]):
for idx in row_sorted: sorted_textlines.append(textlines_textregion[row[idx]])
sorted_textlines.append(textlines_textregion[idx])
else: else:
row_threshold = mean_x_diff / 1.5 if mean_x_diff > 0 else 10 col_threshold = mean_x_diff / 1.5 if mean_x_diff > 0 else 10
indices_sorted_by_x = sorted(range(N), key=lambda i: cx_textline[i]) cols = []
for prev_idx, curr_idx in pairwise(np.argsort(cx_textline)):
rows = [] if not len(cols):
current_row = [indices_sorted_by_x[0]] cols.append([prev_idx])
if abs(cx_textline[curr_idx] - cx_textline[prev_idx]) <= col_threshold:
for i in range(1, N): cols[-1].append(curr_idx)
current_idy = indices_sorted_by_x[i]
prev_idy = current_row[0]
if abs(cx_textline[current_idy] - cx_textline[prev_idy] ) <= row_threshold:
current_row.append(current_idy)
else: else:
rows.append(current_row) cols.append([curr_idx])
current_row = [current_idy]
rows.append(current_row)
sorted_textlines = [] sorted_textlines = []
for row in rows: for col in cols:
row_sorted = sorted(row , key=lambda i: cy_textline[i]) for idx in np.argsort(cy_textline[col]):
for idy in row_sorted: sorted_textlines.append(textlines_textregion[col[idx]])
sorted_textlines.append(textlines_textregion[idy])
return sorted_textlines return sorted_textlines