diff --git a/src/eynollah/eynollah.py b/src/eynollah/eynollah.py index 6356198..4d1644d 100644 --- a/src/eynollah/eynollah.py +++ b/src/eynollah/eynollah.py @@ -12,7 +12,7 @@ from difflib import SequenceMatcher as sq import math import os import time -from typing import Dict, Type, Union,List, Optional, Tuple +from typing import List, Optional, Tuple import warnings from functools import partial from pathlib import Path diff --git a/src/eynollah/mb_ro_on_layout.py b/src/eynollah/mb_ro_on_layout.py index 0a8a7ae..8338d35 100644 --- a/src/eynollah/mb_ro_on_layout.py +++ b/src/eynollah/mb_ro_on_layout.py @@ -10,12 +10,13 @@ from pathlib import Path import xml.etree.ElementTree as ET import cv2 +from keras.models import Model import numpy as np from ocrd_utils import getLogger import statistics import tensorflow as tf -from tensorflow.keras.models import load_model +from .model_zoo import EynollahModelZoo from .utils.resize import resize_image from .utils.contour import ( find_new_features_of_contours, @@ -23,7 +24,6 @@ from .utils.contour import ( return_parent_contours, ) from .utils import is_xml_filename -from .patch_encoder import PatchEncoder, Patches DPI_THRESHOLD = 298 KERNEL = np.ones((5, 5), np.uint8) @@ -45,21 +45,11 @@ class machine_based_reading_order_on_layout: except: self.logger.warning("no GPU device available") - self.model_reading_order = self.our_load_model(self.model_reading_order_dir) + self.model_zoo = EynollahModelZoo(basedir=dir_models) + self.model_zoo.load_model('reading_order') + # FIXME: light_version is always true, no need for checks in the code self.light_version = True - @staticmethod - def our_load_model(model_file): - if model_file.endswith('.h5') and Path(model_file[:-3]).exists(): - # prefer SavedModel over HDF5 format if it exists - model_file = model_file[:-3] - try: - model = load_model(model_file, compile=False) - except: - model = load_model(model_file, compile=False, custom_objects={ - "PatchEncoder": PatchEncoder, "Patches": Patches}) - return model - def read_xml(self, xml_file): tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8')) root1=tree1.getroot() @@ -69,6 +59,7 @@ class machine_based_reading_order_on_layout: index_tot_regions = [] tot_region_ref = [] + y_len, x_len = 0, 0 for jj in root1.iter(link+'Page'): y_len=int(jj.attrib['imageHeight']) x_len=int(jj.attrib['imageWidth']) @@ -81,13 +72,13 @@ class machine_based_reading_order_on_layout: co_printspace = [] if link+'PrintSpace' in alltags: region_tags_printspace = np.unique([x for x in alltags if x.endswith('PrintSpace')]) - elif link+'Border' in alltags: + else: region_tags_printspace = np.unique([x for x in alltags if x.endswith('Border')]) for tag in region_tags_printspace: if link+'PrintSpace' in alltags: tag_endings_printspace = ['}PrintSpace','}printspace'] - elif link+'Border' in alltags: + else: tag_endings_printspace = ['}Border','}border'] if tag.endswith(tag_endings_printspace[0]) or tag.endswith(tag_endings_printspace[1]): @@ -683,7 +674,7 @@ class machine_based_reading_order_on_layout: tot_counter += 1 batch.append(j) if tot_counter % inference_bs == 0 or tot_counter == len(ij_list): - y_pr = self.model_reading_order.predict(input_1 , verbose=0) + y_pr = self.model_zoo.get('reading_order', Model).predict(input_1 , verbose='0') for jb, j in enumerate(batch): if y_pr[jb][0]>=0.5: post_list.append(j) @@ -802,6 +793,7 @@ class machine_based_reading_order_on_layout: alltags=[elem.tag for elem in root_xml.iter()] ET.register_namespace("",name_space) + assert dir_out tree_xml.write(os.path.join(dir_out, file_name+'.xml'), xml_declaration=True, method='xml',