mirror of
https://github.com/qurator-spk/eynollah.git
synced 2025-10-26 23:34:13 +01:00
adopt mb_ro_on_layout to the zoo
This commit is contained in:
parent
bcffa2e503
commit
f0c86672f8
2 changed files with 11 additions and 19 deletions
|
|
@ -12,7 +12,7 @@ from difflib import SequenceMatcher as sq
|
|||
import math
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, Type, Union,List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
import warnings
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
|
|
|||
|
|
@ -10,12 +10,13 @@ from pathlib import Path
|
|||
import xml.etree.ElementTree as ET
|
||||
|
||||
import cv2
|
||||
from keras.models import Model
|
||||
import numpy as np
|
||||
from ocrd_utils import getLogger
|
||||
import statistics
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras.models import load_model
|
||||
|
||||
from .model_zoo import EynollahModelZoo
|
||||
from .utils.resize import resize_image
|
||||
from .utils.contour import (
|
||||
find_new_features_of_contours,
|
||||
|
|
@ -23,7 +24,6 @@ from .utils.contour import (
|
|||
return_parent_contours,
|
||||
)
|
||||
from .utils import is_xml_filename
|
||||
from .patch_encoder import PatchEncoder, Patches
|
||||
|
||||
DPI_THRESHOLD = 298
|
||||
KERNEL = np.ones((5, 5), np.uint8)
|
||||
|
|
@ -45,21 +45,11 @@ class machine_based_reading_order_on_layout:
|
|||
except:
|
||||
self.logger.warning("no GPU device available")
|
||||
|
||||
self.model_reading_order = self.our_load_model(self.model_reading_order_dir)
|
||||
self.model_zoo = EynollahModelZoo(basedir=dir_models)
|
||||
self.model_zoo.load_model('reading_order')
|
||||
# FIXME: light_version is always true, no need for checks in the code
|
||||
self.light_version = True
|
||||
|
||||
@staticmethod
|
||||
def our_load_model(model_file):
|
||||
if model_file.endswith('.h5') and Path(model_file[:-3]).exists():
|
||||
# prefer SavedModel over HDF5 format if it exists
|
||||
model_file = model_file[:-3]
|
||||
try:
|
||||
model = load_model(model_file, compile=False)
|
||||
except:
|
||||
model = load_model(model_file, compile=False, custom_objects={
|
||||
"PatchEncoder": PatchEncoder, "Patches": Patches})
|
||||
return model
|
||||
|
||||
def read_xml(self, xml_file):
|
||||
tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8'))
|
||||
root1=tree1.getroot()
|
||||
|
|
@ -69,6 +59,7 @@ class machine_based_reading_order_on_layout:
|
|||
index_tot_regions = []
|
||||
tot_region_ref = []
|
||||
|
||||
y_len, x_len = 0, 0
|
||||
for jj in root1.iter(link+'Page'):
|
||||
y_len=int(jj.attrib['imageHeight'])
|
||||
x_len=int(jj.attrib['imageWidth'])
|
||||
|
|
@ -81,13 +72,13 @@ class machine_based_reading_order_on_layout:
|
|||
co_printspace = []
|
||||
if link+'PrintSpace' in alltags:
|
||||
region_tags_printspace = np.unique([x for x in alltags if x.endswith('PrintSpace')])
|
||||
elif link+'Border' in alltags:
|
||||
else:
|
||||
region_tags_printspace = np.unique([x for x in alltags if x.endswith('Border')])
|
||||
|
||||
for tag in region_tags_printspace:
|
||||
if link+'PrintSpace' in alltags:
|
||||
tag_endings_printspace = ['}PrintSpace','}printspace']
|
||||
elif link+'Border' in alltags:
|
||||
else:
|
||||
tag_endings_printspace = ['}Border','}border']
|
||||
|
||||
if tag.endswith(tag_endings_printspace[0]) or tag.endswith(tag_endings_printspace[1]):
|
||||
|
|
@ -683,7 +674,7 @@ class machine_based_reading_order_on_layout:
|
|||
tot_counter += 1
|
||||
batch.append(j)
|
||||
if tot_counter % inference_bs == 0 or tot_counter == len(ij_list):
|
||||
y_pr = self.model_reading_order.predict(input_1 , verbose=0)
|
||||
y_pr = self.model_zoo.get('reading_order', Model).predict(input_1 , verbose='0')
|
||||
for jb, j in enumerate(batch):
|
||||
if y_pr[jb][0]>=0.5:
|
||||
post_list.append(j)
|
||||
|
|
@ -802,6 +793,7 @@ class machine_based_reading_order_on_layout:
|
|||
alltags=[elem.tag for elem in root_xml.iter()]
|
||||
|
||||
ET.register_namespace("",name_space)
|
||||
assert dir_out
|
||||
tree_xml.write(os.path.join(dir_out, file_name+'.xml'),
|
||||
xml_declaration=True,
|
||||
method='xml',
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue