Compare commits

...

13 commits
v0.4.0 ... main

Author SHA1 Message Date
Clemens Neudecker
3dcbb20cac
Merge pull request #159 from bertsky/main
update docker
2025-05-06 15:14:06 +02:00
Robert Sachunsky
e9179e1d34 docker: use latest core base stage 2025-05-02 00:16:22 +02:00
Robert Sachunsky
f8b4d29a59 docker: prepackage ocrd-all-module-dir.json 2025-05-02 00:16:22 +02:00
vahidrezanezhad
e2da7a6239 Fix model name to return the correct machine-based model name 2025-04-30 16:06:29 +02:00
vahidrezanezhad
b227736094 Fix OCR text cleaning to correctly handle sentences starting with 'U', 'K', or 'N'; update text line splitting size 2025-04-30 16:04:34 +02:00
vahidrezanezhad
4cb4414740 Resolve remaining issue with #158 and resolve #124 2025-04-30 16:01:52 +02:00
vahidrezanezhad
208bde706f resolving issue #158 2025-04-30 13:55:09 +02:00
Konstantin Baierer
3e8adb86c2
Merge pull request #157 from qurator-spk/kba-patch-1
CI: Use most recent actions/setup-python@v5
2025-04-29 11:42:18 +02:00
Konstantin Baierer
77dae129d5
CI: Use most recent actions/setup-python@v5 2025-04-22 13:22:28 +02:00
Clemens Neudecker
b4df978dd5
Merge pull request #154 from qurator-spk/ci-pypi
CI: pypi
2025-04-17 17:01:20 +02:00
kba
30ba234641 CI: pypi 2025-04-16 19:27:17 +02:00
kba
41318f0404 📝 changelog 2025-04-15 11:14:26 +02:00
vahidrezanezhad
a22df11ebb Restoring the contour in the original image caused an error due to an empty tuple. This issue has been resolved, and as expected, the confidence score for this contour is set to zero 2025-04-14 00:42:08 +02:00
7 changed files with 176 additions and 193 deletions

24
.github/workflows/pypi.yml vendored Normal file
View file

@ -0,0 +1,24 @@
name: PyPI CD
on:
release:
types: [published]
workflow_dispatch:
jobs:
pypi-publish:
name: upload release to PyPI
runs-on: ubuntu-latest
permissions:
# IMPORTANT: this permission is mandatory for Trusted Publishing
id-token: write
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
- name: Build package
run: make build
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
verbose: true

View file

@ -5,6 +5,10 @@ Versioned according to [Semantic Versioning](http://semver.org/).
## Unreleased
Fixed:
* restoring the contour in the original image caused an error due to an empty tuple
## [0.4.0] - 2025-04-07
Fixed:

View file

@ -36,6 +36,8 @@ COPY . .
COPY ocrd-tool.json .
# prepackage ocrd-tool.json as ocrd-all-tool.json
RUN ocrd ocrd-tool ocrd-tool.json dump-tools > $(dirname $(ocrd bashlib filename))/ocrd-all-tool.json
# prepackage ocrd-all-module-dir.json
RUN ocrd ocrd-tool ocrd-tool.json dump-module-dirs > $(dirname $(ocrd bashlib filename))/ocrd-all-module-dir.json
# install everything and reduce image size
RUN make install EXTRAS=OCR && rm -rf /build/eynollah
# smoke test

View file

@ -3,8 +3,9 @@ PIP ?= pip3
EXTRAS ?=
# DOCKER_BASE_IMAGE = artefakt.dev.sbb.berlin:5000/sbb/ocrd_core:v2.68.0
DOCKER_BASE_IMAGE = docker.io/ocrd/core-cuda-tf2:v3.3.0
DOCKER_TAG = ocrd/eynollah
DOCKER_BASE_IMAGE ?= docker.io/ocrd/core-cuda-tf2:latest
DOCKER_TAG ?= ocrd/eynollah
DOCKER ?= docker
#SEG_MODEL := https://qurator-data.de/eynollah/2021-04-25/models_eynollah.tar.gz
#SEG_MODEL := https://qurator-data.de/eynollah/2022-04-05/models_eynollah_renamed.tar.gz
@ -117,7 +118,7 @@ coverage:
# Build docker image
docker:
docker build \
$(DOCKER) build \
--build-arg DOCKER_BASE_IMAGE=$(DOCKER_BASE_IMAGE) \
--build-arg VCS_REF=$$(git rev-parse --short HEAD) \
--build-arg BUILD_DATE=$$(date -u +"%Y-%m-%dT%H:%M:%SZ") \

View file

@ -3320,12 +3320,22 @@ class Eynollah:
def do_order_of_regions_with_model(self, contours_only_text_parent, contours_only_text_parent_h, text_regions_p):
y_len = text_regions_p.shape[0]
x_len = text_regions_p.shape[1]
img_poly = np.zeros((y_len,x_len), dtype='uint8')
img_poly[text_regions_p[:,:]==1] = 1
img_poly[text_regions_p[:,:]==2] = 2
img_poly[text_regions_p[:,:]==3] = 4
img_poly[text_regions_p[:,:]==6] = 5
#temp
sep_mask = (img_poly==5)*1
sep_mask = sep_mask.astype('uint8')
sep_mask = cv2.erode(sep_mask, kernel=KERNEL, iterations=2)
img_poly[img_poly==5] = 0
img_poly[sep_mask==1] = 5
#
img_header_and_sep = np.zeros((y_len,x_len), dtype='uint8')
if contours_only_text_parent_h:
@ -3341,9 +3351,13 @@ class Eynollah:
if not len(co_text_all):
return [], []
labels_con = np.zeros((y_len, x_len, len(co_text_all)), dtype=bool)
labels_con = np.zeros((int(y_len /6.), int(x_len/6.), len(co_text_all)), dtype=bool)
co_text_all = [(i/6).astype(int) for i in co_text_all]
for i in range(len(co_text_all)):
img = labels_con[:,:,i].astype(np.uint8)
#img = cv2.resize(img, (int(img.shape[1]/6), int(img.shape[0]/6)), interpolation=cv2.INTER_NEAREST)
cv2.fillPoly(img, pts=[co_text_all[i]], color=(1,))
labels_con[:,:,i] = img
@ -3359,6 +3373,7 @@ class Eynollah:
labels_con = resize_image(labels_con.astype(np.uint8), height1, width1).astype(bool)
img_header_and_sep = resize_image(img_header_and_sep, height1, width1)
img_poly = resize_image(img_poly, height3, width3)
inference_bs = 3
input_1 = np.zeros((inference_bs, height1, width1, 3))
@ -4575,10 +4590,6 @@ class Eynollah:
return pcgts
## check the ro order
#print("text region early 3 in %.1fs", time.time() - t0)
if self.light_version:
@ -4886,7 +4897,7 @@ class Eynollah_ocr:
self.model_ocr.to(self.device)
else:
self.model_ocr_dir = dir_models + "/model_step_75000_ocr"#"/model_0_ocr_cnnrnn"#"/model_23_ocr_cnnrnn"
self.model_ocr_dir = dir_models + "/model_step_1050000_ocr"#"/model_0_ocr_cnnrnn"#"/model_23_ocr_cnnrnn"
model_ocr = load_model(self.model_ocr_dir , compile=False)
self.prediction_model = tf.keras.models.Model(
@ -4974,7 +4985,7 @@ class Eynollah_ocr:
def return_start_and_end_of_common_text_of_textline_ocr_without_common_section(self, textline_image):
width = np.shape(textline_image)[1]
height = np.shape(textline_image)[0]
common_window = int(0.06*width)
common_window = int(0.22*width)
width1 = int ( width/2. - common_window )
width2 = int ( width/2. + common_window )
@ -4984,13 +4995,17 @@ class Eynollah_ocr:
peaks_real, _ = find_peaks(sum_smoothed, height=0)
if len(peaks_real)>70:
if len(peaks_real)>35:
peaks_real = peaks_real[(peaks_real<width2) & (peaks_real>width1)]
#peaks_real = peaks_real[(peaks_real<width2) & (peaks_real>width1)]
argsort = np.argsort(sum_smoothed[peaks_real])[::-1]
peaks_real_top_six = peaks_real[argsort[:6]]
midpoint = textline_image.shape[1] / 2.
arg_closest = np.argmin(np.abs(peaks_real_top_six - midpoint))
arg_max = np.argmax(sum_smoothed[peaks_real])
#arg_max = np.argmax(sum_smoothed[peaks_real])
peaks_final = peaks_real[arg_max]
peaks_final = peaks_real_top_six[arg_closest]#peaks_real[arg_max]
return peaks_final
else:
@ -5038,10 +5053,19 @@ class Eynollah_ocr:
if width_new == 0:
width_new = img.shape[1]
##if width_new+32 >= image_width:
##width_new = width_new - 32
###patch_zero = np.zeros((32, 32, 3))#+255
###patch_zero[9:19,8:18,:] = 0
img = resize_image(img, image_height, width_new)
img_fin = np.ones((image_height, image_width, 3))*255
img_fin[:,:+width_new,:] = img[:,:,:]
###img_fin[:,:32,:] = patch_zero[:,:,:]
###img_fin[:,32:32+width_new,:] = img[:,:,:]
img_fin[:,:width_new,:] = img[:,:,:]
img_fin = img_fin / 255.
return img_fin
@ -5097,7 +5121,7 @@ class Eynollah_ocr:
img_crop = img_poly_on_img[y:y+h, x:x+w, :]
img_crop[mask_poly==0] = 255
if h2w_ratio > 0.05:
if h2w_ratio > 0.1:
cropped_lines.append(img_crop)
cropped_lines_meging_indexing.append(0)
else:
@ -5234,7 +5258,7 @@ class Eynollah_ocr:
if self.draw_texts_on_image:
total_bb_coordinates.append([x,y,w,h])
h2w_ratio = h/float(w)
w_scaled = w * image_height/float(h)
img_poly_on_img = np.copy(img)
if self.prediction_with_both_of_rgb_and_bin:
@ -5252,7 +5276,7 @@ class Eynollah_ocr:
img_crop_bin[mask_poly==0] = 255
if not self.export_textline_images_and_text:
if h2w_ratio > 0.1:
if w_scaled < 1.5*image_width:
img_fin = self.preprocess_and_resize_image_for_ocrcnn_model(img_crop, image_height, image_width)
cropped_lines.append(img_fin)
cropped_lines_meging_indexing.append(0)
@ -5334,11 +5358,11 @@ class Eynollah_ocr:
if self.prediction_with_both_of_rgb_and_bin:
preds_bin = self.prediction_model.predict(imgs_bin, verbose=0)
preds = (preds + preds_bin) / 2.
pred_texts = self.decode_batch_predictions(preds)
for ib in range(imgs.shape[0]):
pred_texts_ib = pred_texts[ib].strip("[UNK]")
pred_texts_ib = pred_texts[ib].replace("[UNK]", "")
extracted_texts.append(pred_texts_ib)
extracted_texts_merged = [extracted_texts[ind] if cropped_lines_meging_indexing[ind]==0 else extracted_texts[ind]+" "+extracted_texts[ind+1] if cropped_lines_meging_indexing[ind]==1 else None for ind in range(len(cropped_lines_meging_indexing))]
@ -5378,7 +5402,7 @@ class Eynollah_ocr:
text_by_textregion = []
for ind in unique_cropped_lines_region_indexer:
extracted_texts_merged_un = np.array(extracted_texts_merged)[np.array(cropped_lines_region_indexer)==ind]
text_by_textregion.append(" ".join(extracted_texts_merged_un))
text_by_textregion.append("".join(extracted_texts_merged_un))
indexer = 0
indexer_textregion = 0

View file

@ -230,7 +230,6 @@ def get_textregion_contours_in_org_image_light_old(cnts, img, slope_first):
def do_back_rotation_and_get_cnt_back(contour_par, index_r_con, img, slope_first, confidence_matrix):
img_copy = np.zeros(img.shape)
img_copy = cv2.fillPoly(img_copy, pts=[contour_par], color=(1, 1, 1))
confidence_matrix_mapped_with_contour = confidence_matrix * img_copy[:,:,0]
confidence_contour = np.sum(confidence_matrix_mapped_with_contour) / float(np.sum(img_copy[:,:,0]))
@ -239,9 +238,13 @@ def do_back_rotation_and_get_cnt_back(contour_par, index_r_con, img, slope_first
ret, thresh = cv2.threshold(imgray, 0, 255, 0)
cont_int, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
cont_int[0][:, 0, 0] = cont_int[0][:, 0, 0] + np.abs(img_copy.shape[1] - img.shape[1])
cont_int[0][:, 0, 1] = cont_int[0][:, 0, 1] + np.abs(img_copy.shape[0] - img.shape[0])
# print(np.shape(cont_int[0]))
if len(cont_int)==0:
cont_int = []
cont_int.append(contour_par)
confidence_contour = 0
else:
cont_int[0][:, 0, 0] = cont_int[0][:, 0, 0] + np.abs(img_copy.shape[1] - img.shape[1])
cont_int[0][:, 0, 1] = cont_int[0][:, 0, 1] + np.abs(img_copy.shape[0] - img.shape[0])
return cont_int[0], index_r_con, confidence_contour
def get_textregion_contours_in_org_image_light(cnts, img, slope_first, confidence_matrix, map=map):

View file

@ -102,14 +102,15 @@ def dedup_separate_lines(img_patch, contour_text_interest, thetha, axis):
textline_con_fil = filter_contours_area_of_image(img_patch,
textline_con, hierarchy,
max_area=1, min_area=0.0008)
y_diff_mean = np.mean(np.diff(peaks_new_tot)) # self.find_contours_mean_y_diff(textline_con_fil)
sigma_gaus = int(y_diff_mean * (7.0 / 40.0))
# print(sigma_gaus,'sigma_gaus')
if len(np.diff(peaks_new_tot))>1:
y_diff_mean = np.mean(np.diff(peaks_new_tot)) # self.find_contours_mean_y_diff(textline_con_fil)
sigma_gaus = int(y_diff_mean * (7.0 / 40.0))
else:
sigma_gaus = 12
except:
sigma_gaus = 12
if sigma_gaus < 3:
sigma_gaus = 3
# print(sigma_gaus,'sigma')
y_padded_smoothed = gaussian_filter1d(y_padded, sigma_gaus)
y_padded_up_to_down = -y_padded + np.max(y_padded)
@ -137,7 +138,6 @@ def separate_lines(img_patch, contour_text_interest, thetha, x_help, y_help):
M = cv2.getRotationMatrix2D(center, -thetha, 1.0)
x_d = M[0, 2]
y_d = M[1, 2]
thetha = thetha / 180. * np.pi
rotation_matrix = np.array([[np.cos(thetha), -np.sin(thetha)], [np.sin(thetha), np.cos(thetha)]])
contour_text_interest_copy = contour_text_interest.copy()
@ -162,73 +162,73 @@ def separate_lines(img_patch, contour_text_interest, thetha, x_help, y_help):
x = np.array(range(len(y)))
peaks_real, _ = find_peaks(gaussian_filter1d(y, 3), height=0)
if 1>0:
try:
y_padded_smoothed_e= gaussian_filter1d(y_padded, 2)
y_padded_up_to_down_e=-y_padded+np.max(y_padded)
y_padded_up_to_down_padded_e=np.zeros(len(y_padded_up_to_down_e)+40)
y_padded_up_to_down_padded_e[20:len(y_padded_up_to_down_e)+20]=y_padded_up_to_down_e
y_padded_up_to_down_padded_e= gaussian_filter1d(y_padded_up_to_down_padded_e, 2)
try:
y_padded_smoothed_e= gaussian_filter1d(y_padded, 2)
y_padded_up_to_down_e=-y_padded+np.max(y_padded)
y_padded_up_to_down_padded_e=np.zeros(len(y_padded_up_to_down_e)+40)
y_padded_up_to_down_padded_e[20:len(y_padded_up_to_down_e)+20]=y_padded_up_to_down_e
y_padded_up_to_down_padded_e= gaussian_filter1d(y_padded_up_to_down_padded_e, 2)
peaks_e, _ = find_peaks(y_padded_smoothed_e, height=0)
peaks_neg_e, _ = find_peaks(y_padded_up_to_down_padded_e, height=0)
neg_peaks_max=np.max(y_padded_up_to_down_padded_e[peaks_neg_e])
peaks_e, _ = find_peaks(y_padded_smoothed_e, height=0)
peaks_neg_e, _ = find_peaks(y_padded_up_to_down_padded_e, height=0)
neg_peaks_max=np.max(y_padded_up_to_down_padded_e[peaks_neg_e])
arg_neg_must_be_deleted= np.arange(len(peaks_neg_e))[y_padded_up_to_down_padded_e[peaks_neg_e]/float(neg_peaks_max)<0.3]
diff_arg_neg_must_be_deleted=np.diff(arg_neg_must_be_deleted)
arg_diff=np.array(range(len(diff_arg_neg_must_be_deleted)))
arg_diff_cluster=arg_diff[diff_arg_neg_must_be_deleted>1]
peaks_new=peaks_e[:]
peaks_neg_new=peaks_neg_e[:]
arg_neg_must_be_deleted= np.arange(len(peaks_neg_e))[y_padded_up_to_down_padded_e[peaks_neg_e]/float(neg_peaks_max)<0.3]
diff_arg_neg_must_be_deleted=np.diff(arg_neg_must_be_deleted)
arg_diff=np.array(range(len(diff_arg_neg_must_be_deleted)))
arg_diff_cluster=arg_diff[diff_arg_neg_must_be_deleted>1]
clusters_to_be_deleted=[]
if len(arg_diff_cluster)>0:
clusters_to_be_deleted.append(arg_neg_must_be_deleted[0:arg_diff_cluster[0]+1])
for i in range(len(arg_diff_cluster)-1):
clusters_to_be_deleted.append(arg_neg_must_be_deleted[arg_diff_cluster[i]+1:
arg_diff_cluster[i+1]+1])
clusters_to_be_deleted.append(arg_neg_must_be_deleted[arg_diff_cluster[len(arg_diff_cluster)-1]+1:])
if len(clusters_to_be_deleted)>0:
peaks_new_extra=[]
for m in range(len(clusters_to_be_deleted)):
min_cluster=np.min(peaks_e[clusters_to_be_deleted[m]])
max_cluster=np.max(peaks_e[clusters_to_be_deleted[m]])
peaks_new_extra.append( int( (min_cluster+max_cluster)/2.0) )
for m1 in range(len(clusters_to_be_deleted[m])):
peaks_new=peaks_new[peaks_new!=peaks_e[clusters_to_be_deleted[m][m1]-1]]
peaks_new=peaks_new[peaks_new!=peaks_e[clusters_to_be_deleted[m][m1]]]
peaks_neg_new=peaks_neg_new[peaks_neg_new!=peaks_neg_e[clusters_to_be_deleted[m][m1]]]
peaks_new_tot=[]
for i1 in peaks_new:
peaks_new_tot.append(i1)
for i1 in peaks_new_extra:
peaks_new_tot.append(i1)
peaks_new_tot=np.sort(peaks_new_tot)
else:
peaks_new_tot=peaks_e[:]
peaks_new=peaks_e[:]
peaks_neg_new=peaks_neg_e[:]
textline_con,hierarchy=return_contours_of_image(img_patch)
textline_con_fil=filter_contours_area_of_image(img_patch,
textline_con, hierarchy,
max_area=1, min_area=0.0008)
clusters_to_be_deleted=[]
if len(arg_diff_cluster)>0:
clusters_to_be_deleted.append(arg_neg_must_be_deleted[0:arg_diff_cluster[0]+1])
for i in range(len(arg_diff_cluster)-1):
clusters_to_be_deleted.append(arg_neg_must_be_deleted[arg_diff_cluster[i]+1:
arg_diff_cluster[i+1]+1])
clusters_to_be_deleted.append(arg_neg_must_be_deleted[arg_diff_cluster[len(arg_diff_cluster)-1]+1:])
if len(clusters_to_be_deleted)>0:
peaks_new_extra=[]
for m in range(len(clusters_to_be_deleted)):
min_cluster=np.min(peaks_e[clusters_to_be_deleted[m]])
max_cluster=np.max(peaks_e[clusters_to_be_deleted[m]])
peaks_new_extra.append( int( (min_cluster+max_cluster)/2.0) )
for m1 in range(len(clusters_to_be_deleted[m])):
peaks_new=peaks_new[peaks_new!=peaks_e[clusters_to_be_deleted[m][m1]-1]]
peaks_new=peaks_new[peaks_new!=peaks_e[clusters_to_be_deleted[m][m1]]]
peaks_neg_new=peaks_neg_new[peaks_neg_new!=peaks_neg_e[clusters_to_be_deleted[m][m1]]]
peaks_new_tot=[]
for i1 in peaks_new:
peaks_new_tot.append(i1)
for i1 in peaks_new_extra:
peaks_new_tot.append(i1)
peaks_new_tot=np.sort(peaks_new_tot)
else:
peaks_new_tot=peaks_e[:]
textline_con,hierarchy=return_contours_of_image(img_patch)
textline_con_fil=filter_contours_area_of_image(img_patch,
textline_con, hierarchy,
max_area=1, min_area=0.0008)
if len(np.diff(peaks_new_tot))>0:
y_diff_mean=np.mean(np.diff(peaks_new_tot))#self.find_contours_mean_y_diff(textline_con_fil)
sigma_gaus=int( y_diff_mean * (7./40.0) )
#print(sigma_gaus,'sigma_gaus')
except:
else:
sigma_gaus=12
if sigma_gaus<3:
sigma_gaus=3
#print(sigma_gaus,'sigma')
except:
sigma_gaus=12
if sigma_gaus<3:
sigma_gaus=3
y_padded_smoothed= gaussian_filter1d(y_padded, sigma_gaus)
y_padded_up_to_down=-y_padded+np.max(y_padded)
y_padded_up_to_down_padded=np.zeros(len(y_padded_up_to_down)+40)
y_padded_up_to_down_padded[20:len(y_padded_up_to_down)+20]=y_padded_up_to_down
y_padded_up_to_down_padded= gaussian_filter1d(y_padded_up_to_down_padded, sigma_gaus)
peaks, _ = find_peaks(y_padded_smoothed, height=0)
peaks_neg, _ = find_peaks(y_padded_up_to_down_padded, height=0)
@ -239,6 +239,7 @@ def separate_lines(img_patch, contour_text_interest, thetha, x_help, y_help):
arg_diff=np.array(range(len(diff_arg_neg_must_be_deleted)))
arg_diff_cluster=arg_diff[diff_arg_neg_must_be_deleted>1]
except:
arg_neg_must_be_deleted=[]
arg_diff_cluster=[]
@ -246,7 +247,6 @@ def separate_lines(img_patch, contour_text_interest, thetha, x_help, y_help):
peaks_new=peaks[:]
peaks_neg_new=peaks_neg[:]
clusters_to_be_deleted=[]
if len(arg_diff_cluster)>=2 and len(arg_diff_cluster)>0:
clusters_to_be_deleted.append(arg_neg_must_be_deleted[0:arg_diff_cluster[0]+1])
for i in range(len(arg_diff_cluster)-1):
@ -275,21 +275,6 @@ def separate_lines(img_patch, contour_text_interest, thetha, x_help, y_help):
peaks_new_tot.append(i1)
peaks_new_tot=np.sort(peaks_new_tot)
##plt.plot(y_padded_up_to_down_padded)
##plt.plot(peaks_neg,y_padded_up_to_down_padded[peaks_neg],'*')
##plt.show()
##plt.plot(y_padded_up_to_down_padded)
##plt.plot(peaks_neg_new,y_padded_up_to_down_padded[peaks_neg_new],'*')
##plt.show()
##plt.plot(y_padded_smoothed)
##plt.plot(peaks,y_padded_smoothed[peaks],'*')
##plt.show()
##plt.plot(y_padded_smoothed)
##plt.plot(peaks_new_tot,y_padded_smoothed[peaks_new_tot],'*')
##plt.show()
peaks=peaks_new_tot[:]
peaks_neg=peaks_neg_new[:]
else:
@ -298,11 +283,13 @@ def separate_lines(img_patch, contour_text_interest, thetha, x_help, y_help):
peaks_neg=peaks_neg_new[:]
except:
pass
mean_value_of_peaks=np.mean(y_padded_smoothed[peaks])
std_value_of_peaks=np.std(y_padded_smoothed[peaks])
if len(y_padded_smoothed[peaks]) > 1:
mean_value_of_peaks=np.mean(y_padded_smoothed[peaks])
std_value_of_peaks=np.std(y_padded_smoothed[peaks])
else:
mean_value_of_peaks = np.nan
std_value_of_peaks = np.nan
peaks_values=y_padded_smoothed[peaks]
peaks_neg = peaks_neg - 20 - 20
peaks = peaks - 20
for jj in range(len(peaks_neg)):
@ -345,7 +332,6 @@ def separate_lines(img_patch, contour_text_interest, thetha, x_help, y_help):
point_down_narrow = peaks[jj] + first_nonzero + int(
1.1 * dis_to_next_down) ###-int(dis_to_next_down*1./2)
if point_down_narrow >= img_patch.shape[0]:
point_down_narrow = img_patch.shape[0] - 2
@ -601,7 +587,6 @@ def separate_lines(img_patch, contour_text_interest, thetha, x_help, y_help):
[int(x_max), int(point_up)],
[int(x_max), int(point_down)],
[int(x_min), int(point_down)]]))
return peaks, textline_boxes_rot
def separate_lines_vertical(img_patch, contour_text_interest, thetha):
@ -633,7 +618,7 @@ def separate_lines_vertical(img_patch, contour_text_interest, thetha):
peaks_neg_new = peaks_neg[:]
clusters_to_be_deleted = []
if len(arg_diff_cluster) >= 2 and len(arg_diff_cluster) > 0:
if len(arg_neg_must_be_deleted) >= 2 and len(arg_diff_cluster) >= 2:
clusters_to_be_deleted.append(arg_neg_must_be_deleted[0 : arg_diff_cluster[0] + 1])
for i in range(len(arg_diff_cluster) - 1):
clusters_to_be_deleted.append(arg_neg_must_be_deleted[arg_diff_cluster[i] + 1 :
@ -641,7 +626,7 @@ def separate_lines_vertical(img_patch, contour_text_interest, thetha):
clusters_to_be_deleted.append(arg_neg_must_be_deleted[arg_diff_cluster[len(arg_diff_cluster) - 1] + 1 :])
elif len(arg_neg_must_be_deleted) >= 2 and len(arg_diff_cluster) == 0:
clusters_to_be_deleted.append(arg_neg_must_be_deleted[:])
if len(arg_neg_must_be_deleted) == 1:
else:
clusters_to_be_deleted.append(arg_neg_must_be_deleted)
if len(clusters_to_be_deleted) > 0:
peaks_new_extra = []
@ -667,9 +652,14 @@ def separate_lines_vertical(img_patch, contour_text_interest, thetha):
peaks_new_tot = peaks[:]
peaks = peaks_new_tot[:]
peaks_neg = peaks_neg_new[:]
mean_value_of_peaks = np.mean(y_padded_smoothed[peaks])
std_value_of_peaks = np.std(y_padded_smoothed[peaks])
if len(y_padded_smoothed[peaks])>1:
mean_value_of_peaks = np.mean(y_padded_smoothed[peaks])
std_value_of_peaks = np.std(y_padded_smoothed[peaks])
else:
mean_value_of_peaks = np.nan
std_value_of_peaks = np.nan
peaks_values = y_padded_smoothed[peaks]
peaks_neg = peaks_neg - 20 - 20
@ -687,7 +677,6 @@ def separate_lines_vertical(img_patch, contour_text_interest, thetha):
textline_boxes_rot = []
if len(peaks_neg) == len(peaks) + 1 and len(peaks) >= 3:
# print('11')
for jj in range(len(peaks)):
if jj == (len(peaks) - 1):
@ -994,15 +983,16 @@ def separate_lines_new_inside_tiles2(img_patch, thetha):
textline_con_fil = filter_contours_area_of_image(img_patch,
textline_con, hierarchy,
max_area=1, min_area=0.0008)
y_diff_mean = np.mean(np.diff(peaks_new_tot)) # self.find_contours_mean_y_diff(textline_con_fil)
if len(np.diff(peaks_new_tot)):
y_diff_mean = np.mean(np.diff(peaks_new_tot)) # self.find_contours_mean_y_diff(textline_con_fil)
sigma_gaus = int(y_diff_mean * (7.0 / 40.0))
else:
sigma_gaus = 12
sigma_gaus = int(y_diff_mean * (7.0 / 40.0))
# print(sigma_gaus,'sigma_gaus')
except:
sigma_gaus = 12
if sigma_gaus < 3:
sigma_gaus = 3
# print(sigma_gaus,'sigma')
y_padded_smoothed = gaussian_filter1d(y_padded, sigma_gaus)
y_padded_up_to_down = -y_padded + np.max(y_padded)
@ -1026,7 +1016,7 @@ def separate_lines_new_inside_tiles2(img_patch, thetha):
arg_diff_cluster = arg_diff[diff_arg_neg_must_be_deleted > 1]
clusters_to_be_deleted = []
if len(arg_diff_cluster) >= 2 and len(arg_diff_cluster) > 0:
if len(arg_neg_must_be_deleted) >= 2 and len(arg_diff_cluster) >= 2:
clusters_to_be_deleted.append(arg_neg_must_be_deleted[0 : arg_diff_cluster[0] + 1])
for i in range(len(arg_diff_cluster) - 1):
clusters_to_be_deleted.append(arg_neg_must_be_deleted[arg_diff_cluster[i] + 1 :
@ -1034,7 +1024,7 @@ def separate_lines_new_inside_tiles2(img_patch, thetha):
clusters_to_be_deleted.append(arg_neg_must_be_deleted[arg_diff_cluster[len(arg_diff_cluster) - 1] + 1 :])
elif len(arg_neg_must_be_deleted) >= 2 and len(arg_diff_cluster) == 0:
clusters_to_be_deleted.append(arg_neg_must_be_deleted[:])
if len(arg_neg_must_be_deleted) == 1:
else:
clusters_to_be_deleted.append(arg_neg_must_be_deleted)
if len(clusters_to_be_deleted) > 0:
peaks_new_extra = []
@ -1077,9 +1067,14 @@ def separate_lines_new_inside_tiles2(img_patch, thetha):
peaks_new_tot = peaks[:]
peaks = peaks_new_tot[:]
peaks_neg = peaks_neg_new[:]
mean_value_of_peaks = np.mean(y_padded_smoothed[peaks])
std_value_of_peaks = np.std(y_padded_smoothed[peaks])
if len(y_padded_smoothed[peaks]) > 1:
mean_value_of_peaks = np.mean(y_padded_smoothed[peaks])
std_value_of_peaks = np.std(y_padded_smoothed[peaks])
else:
mean_value_of_peaks = np.nan
std_value_of_peaks = np.nan
peaks_values = y_padded_smoothed[peaks]
###peaks_neg = peaks_neg - 20 - 20
@ -1089,10 +1084,8 @@ def separate_lines_new_inside_tiles2(img_patch, thetha):
if len(peaks_neg_true) > 0:
peaks_neg_true = np.array(peaks_neg_true)
peaks_neg_true = peaks_neg_true - 20 - 20
# print(peaks_neg_true)
for i in range(len(peaks_neg_true)):
img_patch[peaks_neg_true[i] - 6 : peaks_neg_true[i] + 6, :] = 0
else:
@ -1177,13 +1170,11 @@ def separate_lines_new_inside_tiles(img_path, thetha):
if diff_peaks[i] <= cut_off:
forest.append(peaks_neg[i + 1])
if diff_peaks[i] > cut_off:
# print(forest[np.argmin(z[forest]) ] )
if not np.isnan(forest[np.argmin(z[forest])]):
peaks_neg_true.append(forest[np.argmin(z[forest])])
forest = []
forest.append(peaks_neg[i + 1])
if i == (len(peaks_neg) - 1):
# print(print(forest[np.argmin(z[forest]) ] ))
if not np.isnan(forest[np.argmin(z[forest])]):
peaks_neg_true.append(forest[np.argmin(z[forest])])
@ -1200,17 +1191,14 @@ def separate_lines_new_inside_tiles(img_path, thetha):
if diff_peaks_pos[i] <= cut_off:
forest.append(peaks[i + 1])
if diff_peaks_pos[i] > cut_off:
# print(forest[np.argmin(z[forest]) ] )
if not np.isnan(forest[np.argmax(z[forest])]):
peaks_pos_true.append(forest[np.argmax(z[forest])])
forest = []
forest.append(peaks[i + 1])
if i == (len(peaks) - 1):
# print(print(forest[np.argmin(z[forest]) ] ))
if not np.isnan(forest[np.argmax(z[forest])]):
peaks_pos_true.append(forest[np.argmax(z[forest])])
# print(len(peaks_neg_true) ,len(peaks_pos_true) ,'lensss')
if len(peaks_neg_true) > 0:
peaks_neg_true = np.array(peaks_neg_true)
@ -1236,7 +1224,6 @@ def separate_lines_new_inside_tiles(img_path, thetha):
"""
peaks_neg_true = peaks_neg_true - 20 - 20
# print(peaks_neg_true)
for i in range(len(peaks_neg_true)):
img_path[peaks_neg_true[i] - 6 : peaks_neg_true[i] + 6, :] = 0
@ -1278,7 +1265,6 @@ def separate_lines_vertical_cont(img_patch, contour_text_interest, thetha, box_i
contours_imgs, hierarchy,
max_area=max_area, min_area=min_area)
cont_final = []
###print(add_boxes_coor_into_textlines,'ikki')
for i in range(len(contours_imgs)):
img_contour = np.zeros((cnts_images.shape[0], cnts_images.shape[1], 3))
img_contour = cv2.fillPoly(img_contour, pts=[contours_imgs[i]], color=(255, 255, 255))
@ -1293,12 +1279,10 @@ def separate_lines_vertical_cont(img_patch, contour_text_interest, thetha, box_i
##0]
##contour_text_copy[:, 0, 1] = contour_text_copy[:, 0, 1] - box_ind[1]
##if add_boxes_coor_into_textlines:
##print(np.shape(contours_text_rot[0]),'sjppo')
##contours_text_rot[0][:, 0, 0]=contours_text_rot[0][:, 0, 0] + box_ind[0]
##contours_text_rot[0][:, 0, 1]=contours_text_rot[0][:, 0, 1] + box_ind[1]
cont_final.append(contours_text_rot[0])
##print(cont_final,'nadizzzz')
return None, cont_final
def textline_contours_postprocessing(textline_mask, slope, contour_text_interest, box_ind, add_boxes_coor_into_textlines=False):
@ -1309,20 +1293,7 @@ def textline_contours_postprocessing(textline_mask, slope, contour_text_interest
textline_mask = cv2.morphologyEx(textline_mask, cv2.MORPH_CLOSE, kernel)
textline_mask = cv2.erode(textline_mask, kernel, iterations=2)
# textline_mask = cv2.erode(textline_mask, kernel, iterations=1)
# print(textline_mask.shape[0]/float(textline_mask.shape[1]),'miz')
try:
# if np.abs(slope)>.5 and textline_mask.shape[0]/float(textline_mask.shape[1])>3:
# plt.imshow(textline_mask)
# plt.show()
# if abs(slope)>1:
# x_help=30
# y_help=2
# else:
# x_help=2
# y_help=2
x_help = 30
y_help = 2
@ -1346,28 +1317,12 @@ def textline_contours_postprocessing(textline_mask, slope, contour_text_interest
img_contour = np.zeros((box_ind[3], box_ind[2], 3))
img_contour = cv2.fillPoly(img_contour, pts=[contour_text_copy], color=(255, 255, 255))
# if np.abs(slope)>.5 and textline_mask.shape[0]/float(textline_mask.shape[1])>3:
# plt.imshow(img_contour)
# plt.show()
img_contour_help = np.zeros((img_contour.shape[0] + int(2 * y_help),
img_contour.shape[1] + int(2 * x_help), 3))
img_contour_help[y_help : y_help + img_contour.shape[0],
x_help : x_help + img_contour.shape[1], :] = np.copy(img_contour[:, :, :])
img_contour_rot = rotate_image(img_contour_help, slope)
# plt.imshow(img_contour_rot_help)
# plt.show()
# plt.imshow(dst_help)
# plt.show()
# if np.abs(slope)>.5 and textline_mask.shape[0]/float(textline_mask.shape[1])>3:
# plt.imshow(img_contour_rot_help)
# plt.show()
# plt.imshow(dst_help)
# plt.show()
img_contour_rot = img_contour_rot.astype(np.uint8)
# dst_help = dst_help.astype(np.uint8)
@ -1378,9 +1333,7 @@ def textline_contours_postprocessing(textline_mask, slope, contour_text_interest
len_con_text_rot = [len(contours_text_rot[ib]) for ib in range(len(contours_text_rot))]
ind_big_con = np.argmax(len_con_text_rot)
# print('juzaa')
if abs(slope) > 45:
# print(add_boxes_coor_into_textlines,'avval')
_, contours_rotated_clean = separate_lines_vertical_cont(
textline_mask, contours_text_rot[ind_big_con], box_ind, slope,
add_boxes_coor_into_textlines=add_boxes_coor_into_textlines)
@ -1412,7 +1365,6 @@ def separate_lines_new2(img_path, thetha, num_col, slope_region, logger=None, pl
length_x = int(img_path.shape[1] / float(num_patches))
# margin = int(0.04 * length_x) just recently this was changed because it break lines into 2
margin = int(0.04 * length_x)
# print(margin,'margin')
# if margin<=4:
# margin = int(0.08 * length_x)
# margin=0
@ -1452,11 +1404,9 @@ def separate_lines_new2(img_path, thetha, num_col, slope_region, logger=None, pl
# if abs(slope_region)>70 and abs(slope_xline)<25:
# slope_xline=[slope_region][0]
slopes_tile_wise.append(slope_xline)
# print(slope_xline,'xlineeee')
img_line_rotated = rotate_image(img_xline, slope_xline)
img_line_rotated[:, :][img_line_rotated[:, :] != 0] = 1
# print(slopes_tile_wise,'slopes_tile_wise')
img_patch_ineterst = img_path[:, :] # [peaks_neg_true[14]-dis_up:peaks_neg_true[14]+dis_down ,:]
img_patch_ineterst_revised = np.zeros(img_patch_ineterst.shape)
@ -1498,8 +1448,6 @@ def separate_lines_new2(img_path, thetha, num_col, slope_region, logger=None, pl
img_patch_separated_returned_true_size = img_patch_separated_returned_true_size[:, margin : length_x - margin]
img_patch_ineterst_revised[:, index_x_d + margin : index_x_u - margin] = img_patch_separated_returned_true_size
# plt.imshow(img_patch_ineterst_revised)
# plt.show()
return img_patch_ineterst_revised
def do_image_rotation(angle, img, sigma_des, logger=None):
@ -1532,20 +1480,13 @@ def return_deskew_slop(img_patch_org, sigma_des,n_tot_angles=100,
#img_resized[ int( img_int.shape[0]*(.4)):int( img_int.shape[0]*(.4))+img_int.shape[0] , int( img_int.shape[1]*(.8)):int( img_int.shape[1]*(.8))+img_int.shape[1] ]=img_int[:,:]
img_resized[ onset_y:onset_y+img_int.shape[0] , onset_x:onset_x+img_int.shape[1] ]=img_int[:,:]
#print(img_resized.shape,'img_resizedshape')
#plt.imshow(img_resized)
#plt.show()
if main_page and img_patch_org.shape[1] > img_patch_org.shape[0]:
#plt.imshow(img_resized)
#plt.show()
angles = np.array([-45, 0, 45, 90,])
angle = get_smallest_skew(img_resized, sigma_des, angles, map=map, logger=logger, plotter=plotter)
angles = np.linspace(angle - 22.5, angle + 22.5, n_tot_angles)
angle = get_smallest_skew(img_resized, sigma_des, angles, map=map, logger=logger, plotter=plotter)
elif main_page:
#plt.imshow(img_resized)
#plt.show()
angles = np.linspace(-12, 12, n_tot_angles)#np.array([0 , 45 , 90 , -45])
angle = get_smallest_skew(img_resized, sigma_des, angles, map=map, logger=logger, plotter=plotter)
@ -1632,22 +1573,12 @@ def do_work_of_slopes_new(
if slope_for_all == MAX_SLOPE:
slope_for_all = slope_deskew
slope = slope_for_all
mask_only_con_region = np.zeros(textline_mask_tot_ea.shape)
mask_only_con_region = cv2.fillPoly(mask_only_con_region, pts=[contour_par], color=(1, 1, 1))
# plt.imshow(mask_only_con_region)
# plt.show()
all_text_region_raw = textline_mask_tot_ea[y: y + h, x: x + w].copy()
mask_only_con_region = mask_only_con_region[y: y + h, x: x + w]
##plt.imshow(textline_mask_tot_ea)
##plt.show()
##plt.imshow(all_text_region_raw)
##plt.show()
##plt.imshow(mask_only_con_region)
##plt.show()
all_text_region_raw[mask_only_con_region == 0] = 0
cnt_clean_rot = textline_contours_postprocessing(all_text_region_raw, slope_for_all, contour_par, box_text)
@ -1708,20 +1639,15 @@ def do_work_of_slopes_new_curved(
mask_region_in_patch_region = mask_biggest[y : y + h, x : x + w]
textline_biggest_region = mask_biggest * textline_mask_tot_ea
# print(slope_for_all,'slope_for_all')
textline_rotated_separated = separate_lines_new2(textline_biggest_region[y: y+h, x: x+w], 0,
num_col, slope_for_all,
logger=logger, plotter=plotter)
# new line added
##print(np.shape(textline_rotated_separated),np.shape(mask_biggest))
textline_rotated_separated[mask_region_in_patch_region[:, :] != 1] = 0
# till here
textline_region_in_image[y : y + h, x : x + w] = textline_rotated_separated
# plt.imshow(textline_region_in_image)
# plt.show()
pixel_img = 1
cnt_textlines_in_image = return_contours_of_interested_textline(textline_region_in_image, pixel_img)
@ -1744,7 +1670,6 @@ def do_work_of_slopes_new_curved(
logger.error(why)
else:
textlines_cnt_per_region = textline_contours_postprocessing(all_text_region_raw, slope_for_all, contour_par, box_text, True)
# print(np.shape(textlines_cnt_per_region),'textlines_cnt_per_region')
return textlines_cnt_per_region[::-1], box_text, contour, contour_par, crop_coor, index_r_con, slope