get_textlines_of_a_textregion_sorted: simplify

This commit is contained in:
Robert Sachunsky 2026-04-23 23:45:27 +02:00
parent be61875d6e
commit 998ee2ecee

View file

@ -71,6 +71,7 @@ from .utils.resize import resize_image
from .utils.shm import share_ndarray
from .utils import (
ensure_array,
pairwise,
is_image_filename,
isNaN,
crop_image_inside_box,
@ -929,71 +930,56 @@ class Eynollah:
def get_textlines_of_a_textregion_sorted(self, textlines_textregion, cx_textline, cy_textline, w_h_textline):
N = len(cy_textline)
if N==0:
return []
if N <= 1:
return textlines_textregion
diff_cy = np.abs( np.diff(sorted(cy_textline)) )
diff_cx = np.abs(np.diff(sorted(cx_textline)) )
cx_textline = np.array(cx_textline)
cy_textline = np.array(cy_textline)
diff_cy = np.abs(np.diff(np.sort(cy_textline)))
diff_cx = np.abs(np.diff(np.sort(cx_textline)))
if len(diff_cy)>0:
mean_y_diff = np.mean(diff_cy)
mean_x_diff = np.mean(diff_cx)
if N > 1:
mean_y_diff = np.median(diff_cy)
mean_x_diff = np.median(diff_cx)
count_hor = np.count_nonzero(np.diff(w_h_textline) > 0)
count_ver = len(w_h_textline) - count_hor
count_ver = N - count_hor
else:
mean_y_diff = 0
mean_x_diff = 0
count_hor = 1
count_ver = 0
if count_hor >= count_ver:
row_threshold = mean_y_diff / 1.5 if mean_y_diff > 0 else 10
indices_sorted_by_y = sorted(range(N), key=lambda i: cy_textline[i])
rows = []
current_row = [indices_sorted_by_y[0]]
for i in range(1, N):
current_idx = indices_sorted_by_y[i]
prev_idx = current_row[0]
if abs(cy_textline[current_idx] - cy_textline[prev_idx]) <= row_threshold:
current_row.append(current_idx)
for prev_idx, curr_idx in pairwise(np.argsort(cy_textline)):
if not len(rows):
rows.append([prev_idx])
if abs(cy_textline[curr_idx] - cy_textline[prev_idx]) <= row_threshold:
rows[-1].append(curr_idx)
else:
rows.append(current_row)
current_row = [current_idx]
rows.append(current_row)
rows.append([curr_idx])
sorted_textlines = []
for row in rows:
row_sorted = sorted(row, key=lambda i: cx_textline[i])
for idx in row_sorted:
sorted_textlines.append(textlines_textregion[idx])
for idx in np.argsort(cx_textline[row]):
sorted_textlines.append(textlines_textregion[row[idx]])
else:
row_threshold = mean_x_diff / 1.5 if mean_x_diff > 0 else 10
indices_sorted_by_x = sorted(range(N), key=lambda i: cx_textline[i])
rows = []
current_row = [indices_sorted_by_x[0]]
for i in range(1, N):
current_idy = indices_sorted_by_x[i]
prev_idy = current_row[0]
if abs(cx_textline[current_idy] - cx_textline[prev_idy] ) <= row_threshold:
current_row.append(current_idy)
col_threshold = mean_x_diff / 1.5 if mean_x_diff > 0 else 10
cols = []
for prev_idx, curr_idx in pairwise(np.argsort(cx_textline)):
if not len(cols):
cols.append([prev_idx])
if abs(cx_textline[curr_idx] - cx_textline[prev_idx]) <= col_threshold:
cols[-1].append(curr_idx)
else:
rows.append(current_row)
current_row = [current_idy]
rows.append(current_row)
cols.append([curr_idx])
sorted_textlines = []
for row in rows:
row_sorted = sorted(row , key=lambda i: cy_textline[i])
for idy in row_sorted:
sorted_textlines.append(textlines_textregion[idy])
for col in cols:
for idx in np.argsort(cy_textline[col]):
sorted_textlines.append(textlines_textregion[col[idx]])
return sorted_textlines