Mirror of https://github.com/qurator-spk/eynollah.git, synced 2025-06-15 07:09:55 +02:00
Merge pull request #1 from bertsky/v3-api-refactor-init
refactoring of Eynollah init and model loading
Commit 1a0b9d1958
10 changed files with 690 additions and 826 deletions
Makefile (5 changes)
@@ -77,9 +77,14 @@ deps-test: models_eynollah

smoke-test: TMPDIR != mktemp -d
smoke-test: tests/resources/kant_aufklaerung_1784_0020.tif
	# layout analysis:
	eynollah layout -i $< -o $(TMPDIR) -m $(CURDIR)/models_eynollah
	fgrep -q http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15 $(TMPDIR)/$(basename $(<F)).xml
	fgrep -c -e TextRegion -e ImageRegion -e SeparatorRegion $(TMPDIR)/$(basename $(<F)).xml
	# directory mode (skip one, add one):
	eynollah layout -di $(<D) -o $(TMPDIR) -m $(CURDIR)/models_eynollah
	test -s $(TMPDIR)/euler_rechenkunst01_1738_0025.xml
	# binarize:
	eynollah binarization -m $(CURDIR)/default-2021-03-09 $< $(TMPDIR)/$(<F)
	test -s $(TMPDIR)/$(<F)
	@set -x; test "$$(identify -format '%w %h' $<)" = "$$(identify -format '%w %h' $(TMPDIR)/$(<F))"
README.md (25 changes)
@@ -83,23 +83,28 @@ If no option is set, the tool performs layout detection of main regions (backgro

The best output quality is produced when RGB images are used as input rather than greyscale or binarized images.

#### Use as OCR-D processor

🚧 **Work in progress**

Eynollah ships with a CLI interface to be used as [OCR-D](https://ocr-d.de) processor.
Eynollah ships with a CLI interface to be used as [OCR-D](https://ocr-d.de) [processor](https://ocr-d.de/en/spec/cli).

In this case, the source image file group with (preferably) RGB images should be used as input like this:

```
ocrd-eynollah-segment -I OCR-D-IMG -O SEG-LINE -P models
```

```
ocrd-eynollah-segment -I OCR-D-IMG -O OCR-D-SEG -P models 2022-04-05
```

Any image referenced by `@imageFilename` in PAGE-XML is passed on directly to Eynollah as a processor, so that e.g.

```
ocrd-eynollah-segment -I OCR-D-IMG-BIN -O SEG-LINE -P models
```

If the input file group is PAGE-XML (from a previous OCR-D workflow step), Eynollah behaves as follows:
- existing regions are kept and ignored (i.e. in effect they might overlap segments from Eynollah results)
- existing annotation (and respective `AlternativeImage`s) are partially _ignored_:
  - previous page frame detection (`cropped` images)
  - previous derotation (`deskewed` images)
  - previous thresholding (`binarized` images)
- if the page-level image nevertheless deviates from the original (`@imageFilename`)
  (because some other preprocessing step was in effect like `denoised`), then
  the output PAGE-XML will be based on that as new top-level (`@imageFilename`)

uses the original (RGB) image despite any binarization that may have occurred in previous OCR-D processing steps

```
ocrd-eynollah-segment -I OCR-D-XYZ -O OCR-D-SEG -P models 2022-04-05
```

Still, in general, it makes more sense to add other workflow steps **after** Eynollah.

#### Additional documentation

Please check the [wiki](https://github.com/qurator-spk/eynollah/wiki).

@@ -256,26 +256,37 @@ def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_

    if log_level:
        getLogger('eynollah').setLevel(getLevelName(log_level))
    if not enable_plotting and (save_layout or save_deskewed or save_all or save_page or save_images or allow_enhancement):
        print("Error: You used one of -sl, -sd, -sa, -sp, -si or -ae but did not enable plotting with -ep")
        sys.exit(1)
        raise ValueError("Plotting with -sl, -sd, -sa, -sp, -si or -ae also requires -ep")
    elif enable_plotting and not (save_layout or save_deskewed or save_all or save_page or save_images or allow_enhancement):
        print("Error: You used -ep to enable plotting but set none of -sl, -sd, -sa, -sp, -si or -ae")
        sys.exit(1)
        raise ValueError("Plotting with -ep also requires -sl, -sd, -sa, -sp, -si or -ae")
    if textline_light and not light_version:
        print('Error: You used -tll to enable light textline detection but -light is not enabled')
        sys.exit(1)
        raise ValueError("Light textline detection with -tll also requires -light")
    if light_version and not textline_light:
        print('Error: You used -light without -tll. Light version needs light textline to be enabled.')
    if extract_only_images and (allow_enhancement or allow_scaling or light_version or curved_line or textline_light or full_layout or tables or right2left or headers_off):
        print('Error: You used -eoi which can not be enabled alongside light_version -light or allow_scaling -as or allow_enhancement -ae or curved_line -cl or textline_light -tll or full_layout -fl or tables -tab or right2left -r2l or headers_off -ho')
        sys.exit(1)
        raise ValueError("Light version with -light also requires light textline detection -tll")
    if extract_only_images and allow_enhancement:
        raise ValueError("Image extraction with -eoi can not be enabled alongside allow_enhancement -ae")
    if extract_only_images and allow_scaling:
        raise ValueError("Image extraction with -eoi can not be enabled alongside allow_scaling -as")
    if extract_only_images and light_version:
        raise ValueError("Image extraction with -eoi can not be enabled alongside light_version -light")
    if extract_only_images and curved_line:
        raise ValueError("Image extraction with -eoi can not be enabled alongside curved_line -cl")
    if extract_only_images and textline_light:
        raise ValueError("Image extraction with -eoi can not be enabled alongside textline_light -tll")
    if extract_only_images and full_layout:
        raise ValueError("Image extraction with -eoi can not be enabled alongside full_layout -fl")
    if extract_only_images and tables:
        raise ValueError("Image extraction with -eoi can not be enabled alongside tables -tab")
    if extract_only_images and right2left:
        raise ValueError("Image extraction with -eoi can not be enabled alongside right2left -r2l")
    if extract_only_images and headers_off:
        raise ValueError("Image extraction with -eoi can not be enabled alongside headers_off -ho")
    if image is None and dir_in is None:
        raise ValueError("Either a single image -i or a dir_in -di is required")
    eynollah = Eynollah(
        model,
        logger=getLogger('Eynollah'),
        image_filename=image,
        overwrite=overwrite,
        logger=getLogger('eynollah'),
        dir_out=out,
        dir_in=dir_in,
        dir_of_cropped_images=save_images,
        extract_only_images=extract_only_images,
        dir_of_layout=save_layout,
@@ -301,10 +312,9 @@ def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_

        skip_layout_and_reading_order=skip_layout_and_reading_order,
    )
    if dir_in:
        eynollah.run()
        eynollah.run(dir_in=dir_in, overwrite=overwrite)
    else:
        pcgts = eynollah.run()
        eynollah.writer.write_pagexml(pcgts)
        eynollah.run(image_filename=image, overwrite=overwrite)


@main.command()
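Taken together, these CLI hunks move all PAGE-XML writing out of the CLI and into the library. A minimal sketch of driving the refactored API directly, outside the CLI (the import path, model directory, and file names are assumptions, not taken from the diff):

```
from eynollah.eynollah import Eynollah  # package path assumed

# models are loaded once, at construction time
eynollah = Eynollah("models_eynollah", dir_out="out")
eynollah.run(image_filename="page.tif", overwrite=True)  # single-image mode
eynollah.run(dir_in="pages/", overwrite=False)           # directory mode reuses the loaded models
```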
@@ -32,7 +32,7 @@ from scipy.ndimage import gaussian_filter1d

from numba import cuda

from ocrd import OcrdPage
from ocrd_utils import getLogger
from ocrd_utils import getLogger, tf_disable_interactive_logs

try:
    import torch
@@ -47,14 +47,11 @@ try:

except ImportError:
    TrOCRProcessor = VisionEncoderDecoderModel = None

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
#os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
stderr = sys.stderr
sys.stderr = open(os.devnull, "w")
tf_disable_interactive_logs()
import tensorflow as tf
from tensorflow.python.keras import backend as K
from tensorflow.keras.models import load_model
sys.stderr = stderr
tf.get_logger().setLevel("ERROR")
warnings.filterwarnings("ignore")
# use tf1 compatibility for keras backend
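The hand-rolled stderr redirection around the TensorFlow import is replaced by the `ocrd_utils.tf_disable_interactive_logs` helper. For readers wondering why two silencing lines remain, a standalone sketch of the two layers involved (assuming only that TensorFlow is installed):

```
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # silence C++-level TF log spam; must be set before the import

import tensorflow as tf
tf.get_logger().setLevel("ERROR")         # silence Python-level TF logging after the import
```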
@@ -180,12 +177,7 @@ class Eynollah:

    def __init__(
        self,
        dir_models : str,
        image_filename : Optional[str] = None,
        image_pil : Optional[Image] = None,
        image_filename_stem : Optional[str] = None,
        overwrite : bool = False,
        dir_out : Optional[str] = None,
        dir_in : Optional[str] = None,
        dir_of_cropped_images : Optional[str] = None,
        extract_only_images : bool = False,
        dir_of_layout : Optional[str] = None,
@@ -209,24 +201,12 @@ class Eynollah:

        num_col_upper : Optional[int] = None,
        num_col_lower : Optional[int] = None,
        skip_layout_and_reading_order : bool = False,
        override_dpi : Optional[int] = None,
        logger : Logger = None,
        pcgts : Optional[OcrdPage] = None,
    ):
        if skip_layout_and_reading_order:
            textline_light = True
        self.light_version = light_version
        if not dir_in:
            if image_pil:
                self._imgs = self._cache_images(image_pil=image_pil)
            else:
                self._imgs = self._cache_images(image_filename=image_filename)
            if override_dpi:
                self.dpi = override_dpi
            self.image_filename = image_filename
        self.overwrite = overwrite
        self.dir_out = dir_out
        self.dir_in = dir_in
        self.dir_of_all = dir_of_all
        self.dir_save_page = dir_save_page
        self.reading_order_machine_based = reading_order_machine_based
@@ -257,22 +237,6 @@ class Eynollah:

            self.num_col_lower = int(num_col_lower)
        else:
            self.num_col_lower = num_col_lower
        self.pcgts = pcgts
        if not dir_in:
            self.plotter = None if not enable_plotting else EynollahPlotter(
                dir_out=self.dir_out,
                dir_of_all=dir_of_all,
                dir_save_page=dir_save_page,
                dir_of_deskewed=dir_of_deskewed,
                dir_of_cropped_images=dir_of_cropped_images,
                dir_of_layout=dir_of_layout,
                image_filename_stem=Path(Path(image_filename).name).stem)
            self.writer = EynollahXmlWriter(
                dir_out=self.dir_out,
                image_filename=self.image_filename,
                curved_line=self.curved_line,
                textline_light=self.textline_light,
                pcgts=pcgts)
        self.logger = logger if logger else getLogger('eynollah')
        # for parallelization of CPU-intensive tasks:
        self.executor = ProcessPoolExecutor(max_workers=cpu_count(), timeout=1200)
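The constructor now also owns a process pool for CPU-bound per-region work. A hedged sketch of the usual submit-and-collect pattern around such an executor (`estimate_slope` is a hypothetical stand-in, not a function from this diff):

```
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count

def estimate_slope(contour):
    # hypothetical CPU-heavy per-region task
    return sum(contour) / len(contour)

if __name__ == "__main__":
    with ProcessPoolExecutor(max_workers=cpu_count()) as executor:
        print(list(executor.map(estimate_slope, [[1, 2], [3, 4]])))
```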
@@ -324,21 +288,25 @@ class Eynollah:

        self.model_textline_dir = dir_models + "/modelens_textline_0_1__2_4_16092024"
        if self.ocr:
            self.model_ocr_dir = dir_models + "/trocr_model_ens_of_3_checkpoints_201124"

        if self.tables:
            if self.light_version:
                self.model_table_dir = dir_models + "/modelens_table_0t4_201124"
            else:
                self.model_table_dir = dir_models + "/eynollah-tables_20210319"

        self.models = {}

        if dir_in:
            # as in start_new_session:
            config = tf.compat.v1.ConfigProto()
            config.gpu_options.allow_growth = True
            session = tf.compat.v1.Session(config=config)
            set_session(session)
        # #gpu_options = tf.compat.v1.GPUOptions(allow_growth=True)
        # #gpu_options = tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=7.7, allow_growth=True)
        # #session = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(gpu_options=gpu_options))
        # config = tf.compat.v1.ConfigProto()
        # config.gpu_options.allow_growth = True
        # #session = tf.InteractiveSession()
        # session = tf.compat.v1.Session(config=config)
        # set_session(session)
        try:
            for device in tf.config.list_physical_devices('GPU'):
                tf.config.experimental.set_memory_growth(device, True)
        except:
            self.logger.warning("no GPU device available")

        self.model_page = self.our_load_model(self.model_page_dir)
        self.model_classifier = self.our_load_model(self.model_dir_of_col_classifier)
@@ -367,9 +335,7 @@ class Eynollah:

        if self.tables:
            self.model_table = self.our_load_model(self.model_table_dir)

        self.ls_imgs = os.listdir(self.dir_in)

    def _cache_images(self, image_filename=None, image_pil=None):
    def cache_images(self, image_filename=None, image_pil=None, dpi=None):
        ret = {}
        t_c0 = time.time()
        if image_filename:
@@ -387,12 +353,13 @@ class Eynollah:

            ret['img_grayscale'] = cv2.cvtColor(ret['img'], cv2.COLOR_BGR2GRAY)
        for prefix in ('', '_grayscale'):
            ret[f'img{prefix}_uint8'] = ret[f'img{prefix}'].astype(np.uint8)
        return ret
        self._imgs = ret
        if dpi is not None:
            self.dpi = dpi

    def reset_file_name_dir(self, image_filename):
        t_c = time.time()
        self._imgs = self._cache_images(image_filename=image_filename)
        self.image_filename = image_filename
        self.cache_images(image_filename=image_filename)

        self.plotter = None if not self.enable_plotting else EynollahPlotter(
            dir_out=self.dir_out,
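`cache_images` (now public, stateful, and DPI-aware) becomes the hand-off point for callers that already hold an image. A usage sketch under the new signature, continuing the construction from the earlier example (file name and DPI value assumed):

```
from PIL import Image

eynollah.cache_images(image_filename="page.tif")                  # read from disk, cache BGR + grayscale
eynollah.cache_images(image_pil=Image.open("page.tif"), dpi=300)  # in-memory image, pinning a known DPI
```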
@@ -405,10 +372,9 @@ class Eynollah:

        self.writer = EynollahXmlWriter(
            dir_out=self.dir_out,
            image_filename=self.image_filename,
            image_filename=image_filename,
            curved_line=self.curved_line,
            textline_light=self.textline_light,
            pcgts=self.pcgts)
            textline_light=self.textline_light)

    def imread(self, grayscale=False, uint8=True):
        key = 'img'
@@ -423,8 +389,6 @@ class Eynollah:

    def predict_enhancement(self, img):
        self.logger.debug("enter predict_enhancement")
        if not self.dir_in:
            self.model_enhancement, _ = self.start_new_session_and_model(self.model_dir_of_enhancement)

        img_height_model = self.model_enhancement.layers[-1].output_shape[1]
        img_width_model = self.model_enhancement.layers[-1].output_shape[2]
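This hunk is the first of a long series with the same shape: the `if not self.dir_in:` guards that lazily loaded each model per call are deleted, because the constructor now loads everything once via `our_load_model`. A toy, self-contained sketch of the load-once idea that replaces the guards:

```
from functools import lru_cache

@lru_cache(maxsize=None)
def get_model(model_dir: str):
    # stands in for load_model(model_dir, compile=False); the print shows the caching
    print("loading", model_dir)
    return object()

get_model("models/enhancement")  # loads
get_model("models/enhancement")  # cached, no second load
```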
@@ -622,9 +586,6 @@ class Eynollah:

        _, page_coord = self.early_page_for_num_of_column_classification(img)

        if not self.dir_in:
            self.model_classifier, _ = self.start_new_session_and_model(self.model_dir_of_col_classifier)

        if self.input_binary:
            img_in = np.copy(img)
            img_in = img_in / 255.0
@@ -664,9 +625,6 @@ class Eynollah:

        self.logger.info("Detected %s DPI", dpi)
        if self.input_binary:
            img = self.imread()
            if not self.dir_in:
                self.model_bin, _ = self.start_new_session_and_model(self.model_dir_of_binarization)

            prediction_bin = self.do_prediction(True, img, self.model_bin, n_batch_inference=5)
            prediction_bin = 255 * (prediction_bin[:,:,0]==0)
            prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2).astype(np.uint8)
@@ -683,9 +641,6 @@ class Eynollah:

        self.image_page_org_size = img[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3], :]
        self.page_coord = page_coord

        if not self.dir_in:
            self.model_classifier, _ = self.start_new_session_and_model(self.model_dir_of_col_classifier)

        if self.num_col_upper and not self.num_col_lower:
            num_col = self.num_col_upper
            label_p_pred = [np.ones(6)]
@@ -825,43 +780,6 @@ class Eynollah:

        self.writer.height_org = self.height_org
        self.writer.width_org = self.width_org

    def start_new_session_and_model_old(self, model_dir):
        self.logger.debug("enter start_new_session_and_model (model_dir=%s)", model_dir)
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        session = tf.InteractiveSession()
        model = load_model(model_dir, compile=False)

        return model, session

    def start_new_session_and_model(self, model_dir):
        self.logger.debug("enter start_new_session_and_model (model_dir=%s)", model_dir)
        #gpu_options = tf.compat.v1.GPUOptions(allow_growth=True)
        #gpu_options = tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=7.7, allow_growth=True)
        #session = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(gpu_options=gpu_options))
        physical_devices = tf.config.list_physical_devices('GPU')
        try:
            for device in physical_devices:
                tf.config.experimental.set_memory_growth(device, True)
        except:
            self.logger.warning("no GPU device available")

        if model_dir.endswith('.h5') and Path(model_dir[:-3]).exists():
            # prefer SavedModel over HDF5 format if it exists
            model_dir = model_dir[:-3]
        if model_dir in self.models:
            model = self.models[model_dir]
        else:
            try:
                model = load_model(model_dir, compile=False)
            except:
                model = load_model(model_dir, compile=False, custom_objects={
                    "PatchEncoder": PatchEncoder, "Patches": Patches})
            self.models[model_dir] = model

        return model, None

    def do_prediction(
            self, patches, img, model,
            n_batch_inference=1, marginal_of_patch_percent=0.1,
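Both session helpers above are deleted by this commit; their caching logic survives in the static `our_load_model` (see the later hunk). Distilled to its core, the cache-plus-format-preference idea is (a sketch, assuming standard Keras loading):

```
from pathlib import Path
from tensorflow.keras.models import load_model

_models = {}

def load_cached(model_dir: str):
    if model_dir.endswith('.h5') and Path(model_dir[:-3]).exists():
        model_dir = model_dir[:-3]  # prefer SavedModel over HDF5 format if both exist
    if model_dir not in _models:
        _models[model_dir] = load_model(model_dir, compile=False)
    return _models[model_dir]
```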
@@ -1399,9 +1317,6 @@ class Eynollah:

        self.logger.debug("enter extract_page")
        cont_page = []
        if not self.ignore_page_extraction:
            if not self.dir_in:
                self.model_page, _ = self.start_new_session_and_model(self.model_page_dir)

            img = cv2.GaussianBlur(self.image, (5, 5), 0)
            img_page_prediction = self.do_prediction(False, img, self.model_page)
            imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY)
@@ -1449,8 +1364,6 @@ class Eynollah:

            img = np.copy(img_bin).astype(np.uint8)
        else:
            img = self.imread()
        if not self.dir_in:
            self.model_page, _ = self.start_new_session_and_model(self.model_page_dir)
        img = cv2.GaussianBlur(img, (5, 5), 0)
        img_page_prediction = self.do_prediction(False, img, self.model_page)
@@ -1478,11 +1391,6 @@ class Eynollah:

        self.logger.debug("enter extract_text_regions")
        img_height_h = img.shape[0]
        img_width_h = img.shape[1]
        if not self.dir_in:
            if patches:
                self.model_region_fl, _ = self.start_new_session_and_model(self.model_region_dir_fully)
            else:
                self.model_region_fl_np, _ = self.start_new_session_and_model(self.model_region_dir_fully_np)
        model_region = self.model_region_fl if patches else self.model_region_fl_np

        if self.light_version:
@@ -1514,11 +1422,6 @@ class Eynollah:

        self.logger.debug("enter extract_text_regions")
        img_height_h = img.shape[0]
        img_width_h = img.shape[1]
        if not self.dir_in:
            if patches:
                self.model_region_fl, _ = self.start_new_session_and_model(self.model_region_dir_fully)
            else:
                self.model_region_fl_np, _ = self.start_new_session_and_model(self.model_region_dir_fully_np)
        model_region = self.model_region_fl if patches else self.model_region_fl_np

        if not patches:
@@ -1649,8 +1552,6 @@ class Eynollah:

    def textline_contours(self, img, use_patches, scaler_h, scaler_w, num_col_classifier=None):
        self.logger.debug('enter textline_contours')
        if not self.dir_in:
            self.model_textline, _ = self.start_new_session_and_model(self.model_textline_dir)

        #img = img.astype(np.uint8)
        img_org = np.copy(img)
@@ -1752,9 +1653,6 @@ class Eynollah:

            img_h_new = int(img.shape[0] / float(img.shape[1]) * img_w_new)
            img_resized = resize_image(img, img_h_new, img_w_new)

        if not self.dir_in:
            self.model_region, _ = self.start_new_session_and_model(self.model_region_dir_p_ens_light_only_images_extraction)

        prediction_regions_org = self.do_prediction_new_concept(True, img_resized, self.model_region)

        prediction_regions_org = resize_image(prediction_regions_org, img_height_h, img_width_h)
@@ -1843,7 +1741,6 @@ class Eynollah:

        img_height_h = img_org.shape[0]
        img_width_h = img_org.shape[1]

        #model_region, _ = self.start_new_session_and_model(self.model_region_dir_p_ens)
        #print(num_col_classifier,'num_col_classifier')

        if num_col_classifier == 1:
@@ -1866,8 +1763,6 @@ class Eynollah:

        #if self.input_binary:
        #img_bin = np.copy(img_resized)
        ###if (not self.input_binary and self.full_layout) or (not self.input_binary and num_col_classifier >= 30):
        ###if not self.dir_in:
        ###self.model_bin, _ = self.start_new_session_and_model(self.model_dir_of_binarization)
        ###prediction_bin = self.do_prediction(True, img_resized, self.model_bin, n_batch_inference=5)

        ####print("inside bin ", time.time()-t_bin)
@@ -1883,8 +1778,6 @@ class Eynollah:

        ###else:
        ###img_bin = np.copy(img_resized)
        if self.ocr and not self.input_binary:
            if not self.dir_in:
                self.model_bin, _ = self.start_new_session_and_model(self.model_dir_of_binarization)
            prediction_bin = self.do_prediction(True, img_resized, self.model_bin, n_batch_inference=5)
            prediction_bin = 255 * (prediction_bin[:,:,0] == 0)
            prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2)
@@ -1907,12 +1800,7 @@ class Eynollah:

        #plt.show()
        if not skip_layout_and_reading_order:
            #print("inside 2 ", time.time()-t_in)
            if not self.dir_in:
                self.model_region_1_2, _ = self.start_new_session_and_model(self.model_region_dir_p_1_2_sp_np)
                ##self.model_region, _ = self.start_new_session_and_model(self.model_region_dir_p_ens_light)

            if num_col_classifier == 1 or num_col_classifier == 2:
                model_region, session_region = self.start_new_session_and_model(self.model_region_dir_p_1_2_sp_np)
                if self.image_org.shape[0]/self.image_org.shape[1] > 2.5:
                    self.logger.debug("resized to %dx%d for %d cols",
                                      img_resized.shape[1], img_resized.shape[0], num_col_classifier)
@@ -2011,9 +1899,6 @@ class Eynollah:

        img_height_h = img_org.shape[0]
        img_width_h = img_org.shape[1]

        if not self.dir_in:
            self.model_region, _ = self.start_new_session_and_model(self.model_region_dir_p_ens)

        ratio_y=1.3
        ratio_x=1
@@ -2039,9 +1924,6 @@ class Eynollah:

        prediction_regions_org=prediction_regions_org[:,:,0]
        prediction_regions_org[(prediction_regions_org[:,:]==1) & (mask_zeros_y[:,:]==1)]=0

        if not self.dir_in:
            self.model_region_p2, _ = self.start_new_session_and_model(self.model_region_dir_p2)

        img = resize_image(img_org, int(img_org.shape[0]), int(img_org.shape[1]))

        prediction_regions_org2 = self.do_prediction(True, img, self.model_region_p2, marginal_of_patch_percent=0.2)
@@ -2068,15 +1950,11 @@ class Eynollah:

        if self.input_binary:
            prediction_bin = np.copy(img_org)
        else:
            if not self.dir_in:
                self.model_bin, _ = self.start_new_session_and_model(self.model_dir_of_binarization)
            prediction_bin = self.do_prediction(True, img_org, self.model_bin, n_batch_inference=5)
            prediction_bin = resize_image(prediction_bin, img_height_h, img_width_h)
            prediction_bin = 255 * (prediction_bin[:,:,0]==0)
            prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2)

        if not self.dir_in:
            self.model_region, _ = self.start_new_session_and_model(self.model_region_dir_p_ens)
        ratio_y=1
        ratio_x=1
@@ -2109,17 +1987,10 @@ class Eynollah:

        except:
            if self.input_binary:
                prediction_bin = np.copy(img_org)

                if not self.dir_in:
                    self.model_bin, _ = self.start_new_session_and_model(self.model_dir_of_binarization)
                prediction_bin = self.do_prediction(True, img_org, self.model_bin, n_batch_inference=5)
                prediction_bin = resize_image(prediction_bin, img_height_h, img_width_h)
                prediction_bin = 255 * (prediction_bin[:,:,0]==0)
                prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2)

                if not self.dir_in:
                    self.model_region, _ = self.start_new_session_and_model(self.model_region_dir_p_ens)

            else:
                prediction_bin = np.copy(img_org)
            ratio_y=1
@@ -2749,10 +2620,6 @@ class Eynollah:

        img_org = np.copy(img)
        img_height_h = img_org.shape[0]
        img_width_h = img_org.shape[1]

        if not self.dir_in:
            self.model_table, _ = self.start_new_session_and_model(self.model_table_dir)

        patches = False
        if self.light_version:
            prediction_table = self.do_prediction_new_concept(patches, img, self.model_table)
@@ -3389,7 +3256,11 @@ class Eynollah:

            regions_without_separators_d, regions_fully, regions_without_separators,
            polygons_of_marginals, contours_tables)

    def our_load_model(self, model_file):
    @staticmethod
    def our_load_model(model_file):
        if model_file.endswith('.h5') and Path(model_file[:-3]).exists():
            # prefer SavedModel over HDF5 format if it exists
            model_file = model_file[:-3]
        try:
            model = load_model(model_file, compile=False)
        except:
@@ -3440,9 +3311,6 @@ class Eynollah:

        img_header_and_sep = resize_image(img_header_and_sep, height1, width1)
        img_poly = resize_image(img_poly, height3, width3)

        if not self.dir_in:
            self.model_reading_order, _ = self.start_new_session_and_model(self.model_reading_order_dir)

        inference_bs = 3
        input_1 = np.zeros((inference_bs, height1, width1, 3))
        ordered = [list(range(len(co_text_all)))]
@@ -3743,7 +3611,7 @@ class Eynollah:

            for ij in range(len(all_found_textline_polygons[j])):
                con_ind = all_found_textline_polygons[j][ij]
                area = cv2.contourArea(con_ind)
                con_ind = con_ind.astype(np.float)
                con_ind = con_ind.astype(float)

                x_differential = np.diff(con_ind[:,0,0])
                y_differential = np.diff(con_ind[:,0,1])
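This and the next three hunks are the same mechanical fix: `np.float` (like `np.int` in the contour utilities further down) was deprecated in NumPy 1.20 and removed in 1.24, and the builtin type is a drop-in replacement. In isolation:

```
import numpy as np

con_ind = np.array([[[10, 20]], [[30, 40]]])  # toy contour in OpenCV point layout
con_ind = con_ind.astype(float)               # was .astype(np.float) before NumPy 1.24
print(np.diff(con_ind[:, 0, 0]))              # [20.]
```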
@@ -3847,7 +3715,7 @@ class Eynollah:

            con_ind = all_found_textline_polygons[j]
            #print(len(con_ind[:,0,0]),'con_ind[:,0,0]')
            area = cv2.contourArea(con_ind)
            con_ind = con_ind.astype(np.float)
            con_ind = con_ind.astype(float)

            x_differential = np.diff(con_ind[:,0,0])
            y_differential = np.diff(con_ind[:,0,1])
@@ -3950,7 +3818,7 @@ class Eynollah:

                con_ind = all_found_textline_polygons[j][ij]
                area = cv2.contourArea(con_ind)

                con_ind = con_ind.astype(np.float)
                con_ind = con_ind.astype(float)

                x_differential = np.diff(con_ind[:,0,0])
                y_differential = np.diff(con_ind[:,0,1])
@@ -4182,7 +4050,7 @@ class Eynollah:

        for j in range(len(all_found_textline_polygons)):
            for i in range(len(all_found_textline_polygons[j])):
                con_ind = all_found_textline_polygons[j][i]
                con_ind = con_ind.astype(np.float)
                con_ind = con_ind.astype(float)

                x_differential = np.diff(con_ind[:,0,0])
                y_differential = np.diff(con_ind[:,0,1])
@@ -4322,30 +4190,44 @@ class Eynollah:

        return (slopes_rem, all_found_textline_polygons_rem, boxes_text_rem, txt_con_org_rem,
                contours_only_text_parent_rem, index_by_text_par_con_rem_sort)

    def run(self):
    def run(self, image_filename : Optional[str] = None, dir_in : Optional[str] = None, overwrite : bool = False):
        """
        Get image and scales, then extract the page of scanned image
        """
        self.logger.debug("enter run")

        t0_tot = time.time()

        if not self.dir_in:
            self.ls_imgs = [self.image_filename]
        if dir_in:
            self.ls_imgs = os.listdir(dir_in)
        elif image_filename:
            self.ls_imgs = [image_filename]
        else:
            raise ValueError("run requires either a single image filename or a directory")

        for img_name in self.ls_imgs:
            self.logger.info(img_name)
        for img_filename in self.ls_imgs:
            self.logger.info(img_filename)
            t0 = time.time()
            if self.dir_in:
                self.reset_file_name_dir(os.path.join(self.dir_in,img_name))

            self.reset_file_name_dir(os.path.join(dir_in or "", img_filename))
            #print("text region early -11 in %.1fs", time.time() - t0)
            if os.path.exists(self.writer.output_filename):
                if self.overwrite:
                if overwrite:
                    self.logger.warning("will overwrite existing output file '%s'", self.writer.output_filename)
                else:
                    self.logger.warning("will skip input for existing output file '%s'", self.writer.output_filename)
                    continue

            pcgts = self.run_single()
            self.logger.info("Job done in %.1fs", time.time() - t0)
            #print("Job done in %.1fs" % (time.time() - t0))
            self.writer.write_pagexml(pcgts)

        if dir_in:
            self.logger.info("All jobs done in %.1fs", time.time() - t0_tot)
            print("all Job done in %.1fs", time.time() - t0_tot)

    def run_single(self):
        t0 = time.time()
        img_res, is_image_enhanced, num_col_classifier, num_column_is_classified = self.run_enhancement(self.light_version)
        self.logger.info("Enhancing took %.1fs ", time.time() - t0)
        if self.extract_only_images:
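`run` now owns the output decision end to end: it builds the worklist from its arguments, checks `writer.output_filename` against `overwrite`, and serializes each result, while `run_single` only returns the PAGE-XML object. A sketch of consuming a single page result directly, continuing the earlier construction (file name assumed):

```
eynollah.reset_file_name_dir("page.tif")  # caches the image and rebinds plotter/writer
pcgts = eynollah.run_single()             # returns the PAGE-XML result object
eynollah.writer.write_pagexml(pcgts)      # or hand pcgts to a serializer of your own
```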
@@ -4358,11 +4240,6 @@ class Eynollah:

                cont_page, [], [], ocr_all_textlines)
            if self.plotter:
                self.plotter.write_images_into_directory(polygons_of_images, image_page)

            if self.dir_in:
                self.writer.write_pagexml(pcgts)
                continue
            else:
                return pcgts

        if self.skip_layout_and_reading_order:
@@ -4405,10 +4282,6 @@ class Eynollah:

                all_found_textline_polygons, page_coord, polygons_of_images, polygons_of_marginals,
                all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_marginals,
                cont_page, polygons_lines_xml, contours_tables, ocr_all_textlines)
            if self.dir_in:
                self.writer.write_pagexml(pcgts)
                continue
            else:
                return pcgts

        #print("text region early -1 in %.1fs", time.time() - t0)
@@ -4461,11 +4334,6 @@ class Eynollah:

            pcgts = self.writer.build_pagexml_no_full_layout(
                [], page_coord, [], [], [], [], [], [], [], [], [], [],
                cont_page, [], [], ocr_all_textlines)
            self.logger.info("Job done in %.1fs", time.time() - t1)
            if self.dir_in:
                self.writer.write_pagexml(pcgts)
                continue
            else:
                return pcgts

        #print("text region early in %.1fs", time.time() - t0)
@@ -4651,11 +4519,6 @@ class Eynollah:

                polygons_of_images,
                polygons_of_marginals, empty_marginals, empty_marginals, [], [],
                cont_page, polygons_lines_xml, contours_tables, [])
            self.logger.info("Job done in %.1fs", time.time() - t0)
            if self.dir_in:
                self.writer.write_pagexml(pcgts)
                continue
            else:
                return pcgts

        #print("text region early 3 in %.1fs", time.time() - t0)
@@ -4846,15 +4709,8 @@ class Eynollah:

                    polygons_of_images, contours_tables, polygons_of_drop_capitals, polygons_of_marginals,
                    all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_h, slopes_marginals,
                    cont_page, polygons_lines_xml, ocr_all_textlines)
                self.logger.info("Job done in %.1fs", time.time() - t0)
                #print("Job done in %.1fs", time.time() - t0)
                if self.dir_in:
                    self.writer.write_pagexml(pcgts)
                    continue
                else:
                    return pcgts

            else:
                contours_only_text_parent_h = None
                if self.reading_order_machine_based:
                    order_text_new, id_of_texts_tot = self.do_order_of_regions_with_model(
@@ -4932,20 +4788,7 @@ class Eynollah:

                all_found_textline_polygons, all_box_coord, polygons_of_images, polygons_of_marginals,
                all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_marginals,
                cont_page, polygons_lines_xml, contours_tables, ocr_all_textlines)
            #print("Job done in %.1fs" % (time.time() - t0))
            self.logger.info("Job done in %.1fs", time.time() - t0)
            if not self.dir_in:
                return pcgts
            return pcgts
            #print("text region early 7 in %.1fs", time.time() - t0)

            if self.dir_in:
                self.writer.write_pagexml(pcgts)
                self.logger.info("Job done in %.1fs", time.time() - t0)
                #print("Job done in %.1fs" % (time.time() - t0))

        if self.dir_in:
            self.logger.info("All jobs done in %.1fs", time.time() - t0_tot)
            print("all Job done in %.1fs", time.time() - t0_tot)


class Eynollah_ocr:
@@ -2,7 +2,7 @@ from typing import Optional

from ocrd_models import OcrdPage
from ocrd import Processor, OcrdPageResult

from .eynollah import Eynollah
from .eynollah import Eynollah, EynollahXmlWriter

class EynollahProcessor(Processor):
    # already employs background CPU multiprocessing per page
@@ -14,11 +14,28 @@ class EynollahProcessor(Processor):

        return 'ocrd-eynollah-segment'

    def setup(self) -> None:
        # for caching models
        self.models = None
        if self.parameter['textline_light'] and not self.parameter['light_version']:
            raise ValueError("Error: You set parameter 'textline_light' to enable light textline detection, "
                             "but parameter 'light_version' is not enabled")
        self.eynollah = Eynollah(
            self.resolve_resource(self.parameter['models']),
            logger=self.logger,
            allow_enhancement=self.parameter['allow_enhancement'],
            curved_line=self.parameter['curved_line'],
            right2left=self.parameter['right_to_left'],
            ignore_page_extraction=self.parameter['ignore_page_extraction'],
            light_version=self.parameter['light_version'],
            textline_light=self.parameter['textline_light'],
            full_layout=self.parameter['full_layout'],
            allow_scaling=self.parameter['allow_scaling'],
            headers_off=self.parameter['headers_off'],
            tables=self.parameter['tables'],
        )
        self.eynollah.plotter = None

    def shutdown(self):
        if hasattr(self, 'eynollah'):
            del self.eynollah

    def process_page_pcgts(self, *input_pcgts: Optional[OcrdPage], page_id: Optional[str] = None) -> OcrdPageResult:
        """
@@ -60,27 +77,15 @@ class EynollahProcessor(Processor):

        image_filename = "dummy" # will be replaced by ocrd.Processor.process_page_file
        result.images.append(OcrdPageResultImage(page_image, '.IMG', page)) # mark as new original
        # FIXME: mask out already existing regions (incremental segmentation)
        eynollah = Eynollah(
            self.resolve_resource(self.parameter['models']),
            logger=self.logger,
            allow_enhancement=self.parameter['allow_enhancement'],
            curved_line=self.parameter['curved_line'],
            right2left=self.parameter['right_to_left'],
            ignore_page_extraction=self.parameter['ignore_page_extraction'],
            light_version=self.parameter['light_version'],
            textline_light=self.parameter['textline_light'],
            full_layout=self.parameter['full_layout'],
            allow_scaling=self.parameter['allow_scaling'],
            headers_off=self.parameter['headers_off'],
            tables=self.parameter['tables'],
            override_dpi=self.parameter['dpi'],
            pcgts=pcgts,
            image_filename=image_filename,
            image_pil=page_image
        self.eynollah.cache_images(
            image_pil=page_image,
            dpi=self.parameter['dpi'],
        )
        if self.models is not None:
            # reuse loaded models from previous page
            eynollah.models = self.models
        eynollah.run()
        self.models = eynollah.models
        self.eynollah.writer = EynollahXmlWriter(
            dir_out=None,
            image_filename=image_filename,
            curved_line=self.eynollah.curved_line,
            textline_light=self.eynollah.textline_light,
            pcgts=pcgts)
        self.eynollah.run_single()
        return result
@@ -4,25 +4,19 @@ Tool to load model and binarize a given image.

import sys
from glob import glob
from os import environ, devnull
from os.path import join
from warnings import catch_warnings, simplefilter
import os
import logging

import numpy as np
from PIL import Image
import cv2
environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
stderr = sys.stderr
sys.stderr = open(devnull, 'w')
from ocrd_utils import tf_disable_interactive_logs
tf_disable_interactive_logs()
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.python.keras import backend as tensorflow_backend
sys.stderr = stderr


import logging

def resize_image(img_in, input_height, input_width):
    return cv2.resize(img_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST)
@@ -53,7 +47,7 @@ class SbbBinarizer:

        del self.session

    def load_model(self, model_name):
        model = load_model(join(self.model_dir, model_name), compile=False)
        model = load_model(os.path.join(self.model_dir, model_name), compile=False)
        model_height = model.layers[len(model.layers)-1].output_shape[1]
        model_width = model.layers[len(model.layers)-1].output_shape[2]
        n_classes = model.layers[len(model.layers)-1].output_shape[3]
@@ -247,7 +247,7 @@ def get_textregion_contours_in_org_image_light(cnts, img, slope_first, map=map):

    img = cv2.resize(img, (int(img.shape[1]/6), int(img.shape[0]/6)), interpolation=cv2.INTER_NEAREST)
    ##cnts = list( (np.array(cnts)/2).astype(np.int16) )
    #cnts = cnts/2
    cnts = [(i/6).astype(np.int) for i in cnts]
    cnts = [(i/6).astype(int) for i in cnts]
    results = map(partial(do_back_rotation_and_get_cnt_back,
                          img=img,
                          slope_first=slope_first,
@@ -1,3 +1,4 @@

from contextlib import nullcontext
from PIL import Image
import numpy as np
from ocrd_models import OcrdExif
@@ -17,11 +18,12 @@ def pil2cv(img):

def check_dpi(img):
    try:
        if isinstance(img, Image.Image):
            pil_image = img
            pil_image = nullcontext(img)
        elif isinstance(img, str):
            pil_image = Image.open(img)
        else:
            pil_image = cv2pil(img)
            pil_image = nullcontext(cv2pil(img))
        with pil_image:
            exif = OcrdExif(pil_image)
            resolution = exif.resolution
        if resolution == 1:
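The `nullcontext` wrapper lets one `with` block close only what the function itself opened: caller-owned `PIL.Image` objects get a no-op context, while paths are opened (and therefore closed) here. The idiom in isolation (`as_managed` is a hypothetical name for this sketch):

```
from contextlib import nullcontext
from PIL import Image

def as_managed(img):
    if isinstance(img, Image.Image):
        return nullcontext(img)  # caller keeps ownership; __exit__ does nothing
    return Image.open(img)       # opened here, so the with-block closes it

with as_managed("page.tif") as im:  # file handle is released on exit
    print(im.size)
```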
@@ -1616,7 +1616,7 @@ def do_work_of_slopes_new(

    textline_con_fil = filter_contours_area_of_image(img_int_p, textline_con,
                                                     hierarchy,
                                                     max_area=1, min_area=0.00008)
    y_diff_mean = find_contours_mean_y_diff(textline_con_fil)
    y_diff_mean = find_contours_mean_y_diff(textline_con_fil) if len(textline_con_fil) > 1 else np.NaN
    if np.isnan(y_diff_mean):
        slope_for_all = MAX_SLOPE
    else:
@@ -1681,7 +1681,7 @@ def do_work_of_slopes_new_curved(

    textline_con_fil = filter_contours_area_of_image(img_int_p, textline_con,
                                                     hierarchy,
                                                     max_area=1, min_area=0.0008)
    y_diff_mean = find_contours_mean_y_diff(textline_con_fil)
    y_diff_mean = find_contours_mean_y_diff(textline_con_fil) if len(textline_con_fil) > 1 else np.NaN
    if np.isnan(y_diff_mean):
        slope_for_all = MAX_SLOPE
    else:
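The new conditional in both hunks short-circuits to NaN when fewer than two textline contours survive filtering, since a mean of pairwise y-differences is undefined there; the NaN then routes the caller to the `MAX_SLOPE` fallback. The guard reduced to a standalone sketch (operating on precomputed centroid y-values, which is an assumption about the helper's internals):

```
import numpy as np

def mean_y_diff_or_nan(contour_ys):
    # np.diff of a single value is empty, and np.mean([]) warns and yields nan;
    # guarding keeps the MAX_SLOPE fallback path explicit and warning-free
    if len(contour_ys) > 1:
        return float(np.mean(np.diff(sorted(contour_ys))))
    return np.nan

print(mean_y_diff_or_nan([10.0]))        # nan -> caller falls back to MAX_SLOPE
print(mean_y_diff_or_nan([10.0, 40.0]))  # 30.0
```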
BIN tests/resources/euler_rechenkunst01_1738_0025.tif (new file; binary file not shown)