adapt mb_ro_on_layout to the zoo

This commit is contained in:
kba 2025-10-21 17:55:08 +02:00
parent bcffa2e503
commit f0c86672f8
2 changed files with 11 additions and 19 deletions

View file

@ -12,7 +12,7 @@ from difflib import SequenceMatcher as sq
import math import math
import os import os
import time import time
from typing import Dict, Type, Union,List, Optional, Tuple from typing import List, Optional, Tuple
import warnings import warnings
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path

View file

@ -10,12 +10,13 @@ from pathlib import Path
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
import cv2 import cv2
from keras.models import Model
import numpy as np import numpy as np
from ocrd_utils import getLogger from ocrd_utils import getLogger
import statistics import statistics
import tensorflow as tf import tensorflow as tf
from tensorflow.keras.models import load_model
from .model_zoo import EynollahModelZoo
from .utils.resize import resize_image from .utils.resize import resize_image
from .utils.contour import ( from .utils.contour import (
find_new_features_of_contours, find_new_features_of_contours,
@ -23,7 +24,6 @@ from .utils.contour import (
return_parent_contours, return_parent_contours,
) )
from .utils import is_xml_filename from .utils import is_xml_filename
from .patch_encoder import PatchEncoder, Patches
DPI_THRESHOLD = 298 DPI_THRESHOLD = 298
KERNEL = np.ones((5, 5), np.uint8) KERNEL = np.ones((5, 5), np.uint8)
@ -45,21 +45,11 @@ class machine_based_reading_order_on_layout:
except: except:
self.logger.warning("no GPU device available") self.logger.warning("no GPU device available")
self.model_reading_order = self.our_load_model(self.model_reading_order_dir) self.model_zoo = EynollahModelZoo(basedir=dir_models)
self.model_zoo.load_model('reading_order')
# FIXME: light_version is always true, no need for checks in the code
self.light_version = True self.light_version = True
@staticmethod
def our_load_model(model_file):
if model_file.endswith('.h5') and Path(model_file[:-3]).exists():
# prefer SavedModel over HDF5 format if it exists
model_file = model_file[:-3]
try:
model = load_model(model_file, compile=False)
except:
model = load_model(model_file, compile=False, custom_objects={
"PatchEncoder": PatchEncoder, "Patches": Patches})
return model
def read_xml(self, xml_file): def read_xml(self, xml_file):
tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8')) tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8'))
root1=tree1.getroot() root1=tree1.getroot()
@ -69,6 +59,7 @@ class machine_based_reading_order_on_layout:
index_tot_regions = [] index_tot_regions = []
tot_region_ref = [] tot_region_ref = []
y_len, x_len = 0, 0
for jj in root1.iter(link+'Page'): for jj in root1.iter(link+'Page'):
y_len=int(jj.attrib['imageHeight']) y_len=int(jj.attrib['imageHeight'])
x_len=int(jj.attrib['imageWidth']) x_len=int(jj.attrib['imageWidth'])
@ -81,13 +72,13 @@ class machine_based_reading_order_on_layout:
co_printspace = [] co_printspace = []
if link+'PrintSpace' in alltags: if link+'PrintSpace' in alltags:
region_tags_printspace = np.unique([x for x in alltags if x.endswith('PrintSpace')]) region_tags_printspace = np.unique([x for x in alltags if x.endswith('PrintSpace')])
elif link+'Border' in alltags: else:
region_tags_printspace = np.unique([x for x in alltags if x.endswith('Border')]) region_tags_printspace = np.unique([x for x in alltags if x.endswith('Border')])
for tag in region_tags_printspace: for tag in region_tags_printspace:
if link+'PrintSpace' in alltags: if link+'PrintSpace' in alltags:
tag_endings_printspace = ['}PrintSpace','}printspace'] tag_endings_printspace = ['}PrintSpace','}printspace']
elif link+'Border' in alltags: else:
tag_endings_printspace = ['}Border','}border'] tag_endings_printspace = ['}Border','}border']
if tag.endswith(tag_endings_printspace[0]) or tag.endswith(tag_endings_printspace[1]): if tag.endswith(tag_endings_printspace[0]) or tag.endswith(tag_endings_printspace[1]):
@ -683,7 +674,7 @@ class machine_based_reading_order_on_layout:
tot_counter += 1 tot_counter += 1
batch.append(j) batch.append(j)
if tot_counter % inference_bs == 0 or tot_counter == len(ij_list): if tot_counter % inference_bs == 0 or tot_counter == len(ij_list):
y_pr = self.model_reading_order.predict(input_1 , verbose=0) y_pr = self.model_zoo.get('reading_order', Model).predict(input_1 , verbose='0')
for jb, j in enumerate(batch): for jb, j in enumerate(batch):
if y_pr[jb][0]>=0.5: if y_pr[jb][0]>=0.5:
post_list.append(j) post_list.append(j)
@ -802,6 +793,7 @@ class machine_based_reading_order_on_layout:
alltags=[elem.tag for elem in root_xml.iter()] alltags=[elem.tag for elem in root_xml.iter()]
ET.register_namespace("",name_space) ET.register_namespace("",name_space)
assert dir_out
tree_xml.write(os.path.join(dir_out, file_name+'.xml'), tree_xml.write(os.path.join(dir_out, file_name+'.xml'),
xml_declaration=True, xml_declaration=True,
method='xml', method='xml',