diff --git a/.github/workflows/test-eynollah.yml b/.github/workflows/test-eynollah.yml index 466e690..d6b92ba 100644 --- a/.github/workflows/test-eynollah.yml +++ b/.github/workflows/test-eynollah.yml @@ -24,61 +24,63 @@ jobs: sudo rm -rf "$AGENT_TOOLSDIRECTORY" df -h - uses: actions/checkout@v4 - - uses: actions/cache/restore@v4 - id: seg_model_cache + + # - name: Lint with ruff + # uses: astral-sh/ruff-action@v3 + # with: + # src: "./src" + + - name: Try to restore models_eynollah + uses: actions/cache/restore@v4 + id: all_model_cache with: - path: models_layout_v0_5_0 - key: seg-models - - uses: actions/cache/restore@v4 - id: ocr_model_cache - with: - path: models_ocr_v0_5_1 - key: ocr-models - - uses: actions/cache/restore@v4 - id: bin_model_cache - with: - path: default-2021-03-09 - key: bin-models + path: models_eynollah + key: models_eynollah-${{ hashFiles('src/eynollah/model_zoo/default_specs.py') }} + - name: Download models - if: steps.seg_model_cache.outputs.cache-hit != 'true' || steps.bin_model_cache.outputs.cache-hit != 'true' || steps.ocr_model_cache.outputs.cache-hit != true - run: make models + if: steps.all_model_cache.outputs.cache-hit != 'true' + run: | + make models + ls -la models_eynollah + - uses: actions/cache/save@v4 - if: steps.seg_model_cache.outputs.cache-hit != 'true' + if: steps.all_model_cache.outputs.cache-hit != 'true' with: - path: models_layout_v0_5_0 - key: seg-models - - uses: actions/cache/save@v4 - if: steps.ocr_model_cache.outputs.cache-hit != 'true' - with: - path: models_ocr_v0_5_1 - key: ocr-models - - uses: actions/cache/save@v4 - if: steps.bin_model_cache.outputs.cache-hit != 'true' - with: - path: default-2021-03-09 - key: bin-models + path: models_eynollah + key: models_eynollah-${{ hashFiles('src/eynollah/model_zoo/default_specs.py') }} + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} + + # - uses: actions/cache@v4 + # with: + # path: | + # path/to/dependencies + # some/other/dependencies + # key: ${{ runner.os }}-${{ hashFiles('**/lockfiles') }} + - name: Install dependencies run: | python -m pip install --upgrade pip make install-dev EXTRAS=OCR,plotting make deps-test EXTRAS=OCR,plotting - ls -l models_* - - name: Lint with ruff - uses: astral-sh/ruff-action@v3 - with: - src: "./src" + + - name: Hard-upgrade torch for debugging + run: | + python -m pip install --upgrade torch + - name: Test with pytest run: make coverage PYTEST_ARGS="-vv --junitxml=pytest.xml" + - name: Get coverage results run: | coverage report --format=markdown >> $GITHUB_STEP_SUMMARY coverage html coverage json coverage xml + - name: Store coverage results uses: actions/upload-artifact@v4 with: @@ -88,12 +90,15 @@ jobs: pytest.xml coverage.xml coverage.json + - name: Upload coverage results uses: codecov/codecov-action@v4 with: files: coverage.xml fail_ci_if_error: false + - name: Test standalone CLI run: make smoke-test + - name: Test OCR-D CLI run: make ocrd-test diff --git a/.gitignore b/.gitignore index fd64f0b..49835a7 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ output.html *.tif *.sw? 
TAGS +uv.lock diff --git a/Makefile b/Makefile index 29dd877..6d7ceff 100644 --- a/Makefile +++ b/Makefile @@ -6,23 +6,17 @@ EXTRAS ?= DOCKER_BASE_IMAGE ?= docker.io/ocrd/core-cuda-tf2:latest DOCKER_TAG ?= ocrd/eynollah DOCKER ?= docker +WGET = wget -O #SEG_MODEL := https://qurator-data.de/eynollah/2021-04-25/models_eynollah.tar.gz #SEG_MODEL := https://qurator-data.de/eynollah/2022-04-05/models_eynollah_renamed.tar.gz # SEG_MODEL := https://qurator-data.de/eynollah/2022-04-05/models_eynollah.tar.gz #SEG_MODEL := https://github.com/qurator-spk/eynollah/releases/download/v0.3.0/models_eynollah.tar.gz #SEG_MODEL := https://github.com/qurator-spk/eynollah/releases/download/v0.3.1/models_eynollah.tar.gz -SEG_MODEL := https://zenodo.org/records/17194824/files/models_layout_v0_5_0.tar.gz?download=1 -SEG_MODELFILE = $(notdir $(patsubst %?download=1,%,$(SEG_MODEL))) -SEG_MODELNAME = $(SEG_MODELFILE:%.tar.gz=%) - -BIN_MODEL := https://github.com/qurator-spk/sbb_binarization/releases/download/v0.0.11/saved_model_2021_03_09.zip -BIN_MODELFILE = $(notdir $(BIN_MODEL)) -BIN_MODELNAME := default-2021-03-09 - -OCR_MODEL := https://zenodo.org/records/17236998/files/models_ocr_v0_5_1.tar.gz?download=1 -OCR_MODELFILE = $(notdir $(patsubst %?download=1,%,$(OCR_MODEL))) -OCR_MODELNAME = $(OCR_MODELFILE:%.tar.gz=%) +#SEG_MODEL := https://zenodo.org/records/17194824/files/models_layout_v0_5_0.tar.gz?download=1 +EYNOLLAH_MODELS_URL := https://zenodo.org/records/17580627/files/models_all_v0_7_0.zip +EYNOLLAH_MODELS_ZIP = $(notdir $(EYNOLLAH_MODELS_URL)) +EYNOLLAH_MODELS_DIR = $(EYNOLLAH_MODELS_ZIP:%.zip=%) PYTEST_ARGS ?= -vv --isolate @@ -38,7 +32,7 @@ help: @echo " install-dev Install editable with pip" @echo " deps-test Install test dependencies with pip" @echo " models Download and extract models to $(CURDIR):" - @echo " $(BIN_MODELNAME) $(SEG_MODELNAME) $(OCR_MODELNAME)" + @echo " $(EYNOLLAH_MODELS_DIR)" @echo " smoke-test Run simple CLI check" @echo " ocrd-test Run OCR-D CLI check" @echo " test Run unit tests" @@ -47,34 +41,22 @@ help: @echo " EXTRAS comma-separated list of features (like 'OCR,plotting') for 'install' [$(EXTRAS)]" @echo " DOCKER_TAG Docker image tag for 'docker' [$(DOCKER_TAG)]" @echo " PYTEST_ARGS pytest args for 'test' (Set to '-s' to see log output during test execution, '-vv' to see individual tests. 
[$(PYTEST_ARGS)]" - @echo " SEG_MODEL URL of 'models' archive to download for segmentation 'test' [$(SEG_MODEL)]" - @echo " BIN_MODEL URL of 'models' archive to download for binarization 'test' [$(BIN_MODEL)]" - @echo " OCR_MODEL URL of 'models' archive to download for binarization 'test' [$(OCR_MODEL)]" + @echo " ALL_MODELS URL of archive of all models [$(ALL_MODELS)]" @echo "" # END-EVAL - -# Download and extract models to $(PWD)/models_layout_v0_5_0 -models: $(BIN_MODELNAME) $(SEG_MODELNAME) $(OCR_MODELNAME) +# Download and extract models to $(PWD)/models_layout_v0_6_0 +models: $(EYNOLLAH_MODELS_DIR) # do not download these files if we already have the directories -.INTERMEDIATE: $(BIN_MODELFILE) $(SEG_MODELFILE) $(OCR_MODELFILE) +.INTERMEDIATE: $(EYNOLLAH_MODELS_ZIP) -$(BIN_MODELFILE): - wget -O $@ $(BIN_MODEL) -$(SEG_MODELFILE): - wget -O $@ $(SEG_MODEL) -$(OCR_MODELFILE): - wget -O $@ $(OCR_MODEL) +$(EYNOLLAH_MODELS_ZIP): + $(WGET) $@ $(EYNOLLAH_MODELS_URL) -$(BIN_MODELNAME): $(BIN_MODELFILE) - mkdir $@ - unzip -d $@ $< -$(SEG_MODELNAME): $(SEG_MODELFILE) - tar zxf $< -$(OCR_MODELNAME): $(OCR_MODELFILE) - tar zxf $< +$(EYNOLLAH_MODELS_DIR): $(EYNOLLAH_MODELS_ZIP) + unzip $< build: $(PIP) install build @@ -88,34 +70,28 @@ install: install-dev: $(PIP) install -e .$(and $(EXTRAS),[$(EXTRAS)]) -ifeq (OCR,$(findstring OCR, $(EXTRAS))) -deps-test: $(OCR_MODELNAME) -endif -deps-test: $(BIN_MODELNAME) $(SEG_MODELNAME) +deps-test: $(PIP) install -r requirements-test.txt -ifeq (OCR,$(findstring OCR, $(EXTRAS))) - ln -rs $(OCR_MODELNAME)/* $(SEG_MODELNAME)/ -endif smoke-test: TMPDIR != mktemp -d smoke-test: tests/resources/kant_aufklaerung_1784_0020.tif # layout analysis: - eynollah layout -i $< -o $(TMPDIR) -m $(CURDIR)/$(SEG_MODELNAME) + eynollah -m $(CURDIR)/models_eynollah layout -i $< -o $(TMPDIR) fgrep -q http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15 $(TMPDIR)/$(basename $( Document Layout Analysis, Binarization and OCR with Deep Learning and Heuristics +[![Python Versions](https://img.shields.io/pypi/pyversions/eynollah.svg)](https://pypi.python.org/pypi/eynollah) [![PyPI Version](https://img.shields.io/pypi/v/eynollah)](https://pypi.org/project/eynollah/) [![GH Actions Test](https://github.com/qurator-spk/eynollah/actions/workflows/test-eynollah.yml/badge.svg)](https://github.com/qurator-spk/eynollah/actions/workflows/test-eynollah.yml) [![GH Actions Deploy](https://github.com/qurator-spk/eynollah/actions/workflows/build-docker.yml/badge.svg)](https://github.com/qurator-spk/eynollah/actions/workflows/build-docker.yml) @@ -11,24 +12,22 @@ ![](https://user-images.githubusercontent.com/952378/102350683-8a74db80-3fa5-11eb-8c7e-f743f7d6eae2.jpg) ## Features -* Support for 10 distinct segmentation classes: +* Document layout analysis using pixelwise segmentation models with support for 10 segmentation classes: * background, [page border](https://ocr-d.de/en/gt-guidelines/trans/lyRand.html), [text region](https://ocr-d.de/en/gt-guidelines/trans/lytextregion.html#textregionen__textregion_), [text line](https://ocr-d.de/en/gt-guidelines/pagexml/pagecontent_xsd_Complex_Type_pc_TextLineType.html), [header](https://ocr-d.de/en/gt-guidelines/trans/lyUeberschrift.html), [image](https://ocr-d.de/en/gt-guidelines/trans/lyBildbereiche.html), [separator](https://ocr-d.de/en/gt-guidelines/trans/lySeparatoren.html), [marginalia](https://ocr-d.de/en/gt-guidelines/trans/lyMarginalie.html), [initial](https://ocr-d.de/en/gt-guidelines/trans/lyInitiale.html), 
[table](https://ocr-d.de/en/gt-guidelines/trans/lyTabellen.html) -* Support for various image optimization operations: - * cropping (border detection), binarization, deskewing, dewarping, scaling, enhancing, resizing * Textline segmentation to bounding boxes or polygons (contours) including for curved lines and vertical text -* Text recognition (OCR) using either CNN-RNN or Transformer models -* Detection of reading order (left-to-right or right-to-left) using either heuristics or trainable models +* Document image binarization with pixelwise segmentation or hybrid CNN-Transformer models +* Text recognition (OCR) with CNN-RNN or TrOCR models +* Detection of reading order (left-to-right or right-to-left) using heuristics or trainable models * Output in [PAGE-XML](https://github.com/PRImA-Research-Lab/PAGE-XML) * [OCR-D](https://github.com/qurator-spk/eynollah#use-as-ocr-d-processor) interface :warning: Development is focused on achieving the best quality of results for a wide variety of historical -documents and therefore processing can be very slow. We aim to improve this, but contributions are welcome. +documents using a combination of multiple deep learning models and heuristics; therefore processing can be slow. ## Installation - Python `3.8-3.11` with Tensorflow `<2.13` on Linux are currently supported. - -For (limited) GPU support the CUDA toolkit needs to be installed. A known working config is CUDA `11` with cuDNN `8.6`. +For (limited) GPU support the CUDA toolkit needs to be installed. +A working config is CUDA `11.8` with cuDNN `8.6`. You can either install from PyPI @@ -53,31 +52,41 @@ pip install "eynollah[OCR]" make install EXTRAS=OCR ``` +### Docker + +Use + +``` +docker pull ghcr.io/qurator-spk/eynollah:latest +``` + +When using Eynollah with Docker, see [`docker.md`](https://github.com/qurator-spk/eynollah/tree/main/docs/docker.md). + ## Models -Pretrained models can be downloaded from [zenodo](https://zenodo.org/records/17194824) or [huggingface](https://huggingface.co/SBB?search_models=eynollah). +Pretrained models can be downloaded from [Zenodo](https://zenodo.org/records/17194824) or [Hugging Face](https://huggingface.co/SBB?search_models=eynollah). -For documentation on models, have a look at [`models.md`](https://github.com/qurator-spk/eynollah/tree/main/docs/models.md). -Model cards are also provided for our trained models. +For model documentation and model cards, see [`models.md`](https://github.com/qurator-spk/eynollah/tree/main/docs/models.md). ## Training -In case you want to train your own model with Eynollah, see the -documentation in [`train.md`](https://github.com/qurator-spk/eynollah/tree/main/docs/train.md) and use the -tools in the [`train` folder](https://github.com/qurator-spk/eynollah/tree/main/train). +To train your own model with Eynollah, see [`train.md`](https://github.com/qurator-spk/eynollah/tree/main/docs/train.md) and use the tools in the [`train`](https://github.com/qurator-spk/eynollah/tree/main/train) folder. ## Usage -Eynollah supports five use cases: layout analysis (segmentation), binarization, -image enhancement, text recognition (OCR), and reading order detection. +Eynollah supports five use cases: +1. [layout analysis (segmentation)](#layout-analysis), +2. [binarization](#binarization), +3. [image enhancement](#image-enhancement), +4. [text recognition (OCR)](#ocr), and +5. [reading order detection](#reading-order-detection). 
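A quick way to sanity-check an installation against the new single-entry CLI is to script it; a minimal sketch using only the Python standard library, assuming models were fetched with `make models` and reusing the repository's test fixture (all paths are illustrative):

```python
# Minimal smoke check of the restructured CLI: the models directory is now a
# global -m option given *before* the subcommand, as in the Makefile's
# smoke-test target. Paths below are placeholders.
import subprocess
from pathlib import Path

models = Path("models_eynollah")
page = Path("tests/resources/kant_aufklaerung_1784_0020.tif")
out = Path("/tmp/eynollah-layout")
out.mkdir(parents=True, exist_ok=True)

subprocess.run(
    ["eynollah", "-m", str(models), "layout", "-i", str(page), "-o", str(out)],
    check=True,
)
print("PAGE-XML written to", out)
```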
+ +Some example outputs can be found in [`examples.md`](https://github.com/qurator-spk/eynollah/tree/main/docs/examples.md). ### Layout Analysis The layout analysis module is responsible for detecting layout elements, identifying text lines, and determining reading -order using either heuristic methods or a [pretrained reading order detection model](https://github.com/qurator-spk/eynollah#machine-based-reading-order). - -Reading order detection can be performed either as part of layout analysis based on image input, or, currently under -development, based on pre-existing layout analysis results in PAGE-XML format as input. +order using heuristic methods or a [pretrained model](https://github.com/qurator-spk/eynollah#machine-based-reading-order). The command-line interface for layout analysis can be called like this: @@ -91,29 +100,42 @@ eynollah layout \ The following options can be used to further configure the processing: -| option | description | -|-------------------|:-------------------------------------------------------------------------------| -| `-fl` | full layout analysis including all steps and segmentation classes | -| `-light` | lighter and faster but simpler method for main region detection and deskewing | -| `-tll` | this indicates the light textline and should be passed with light version | -| `-tab` | apply table detection | -| `-ae` | apply enhancement (the resulting image is saved to the output directory) | -| `-as` | apply scaling | -| `-cl` | apply contour detection for curved text lines instead of bounding boxes | -| `-ib` | apply binarization (the resulting image is saved to the output directory) | -| `-ep` | enable plotting (MUST always be used with `-sl`, `-sd`, `-sa`, `-si` or `-ae`) | -| `-eoi` | extract only images to output directory (other processing will not be done) | -| `-ho` | ignore headers for reading order dectection | -| `-si ` | save image regions detected to this directory | -| `-sd ` | save deskewed image to this directory | -| `-sl ` | save layout prediction as plot to this directory | -| `-sp ` | save cropped page image to this directory | -| `-sa ` | save all (plot, enhanced/binary image, layout) to this directory | +| option | description | +|-------------------|:--------------------------------------------------------------------------------------------| +| `-fl` | full layout analysis including all steps and segmentation classes (recommended) | +| `-light` | lighter and faster but simpler method for main region detection and deskewing (recommended) | +| `-tll` | this indicates the light textline and should be passed with light version (recommended) | +| `-tab` | apply table detection | +| `-ae` | apply enhancement (the resulting image is saved to the output directory) | +| `-as` | apply scaling | +| `-cl` | apply contour detection for curved text lines instead of bounding boxes | +| `-ib` | apply binarization (the resulting image is saved to the output directory) | +| `-ep` | enable plotting (MUST always be used with `-sl`, `-sd`, `-sa`, `-si` or `-ae`) | +| `-eoi` | extract only images to output directory (other processing will not be done) | +| `-ho` | ignore headers for reading order detection | +| `-si ` | save image regions detected to this directory | +| `-sd ` | save deskewed image to this directory | +| `-sl ` | save layout prediction as plot to this directory | +| `-sp ` | save cropped page image to this directory | +| `-sa ` | save all (plot, enhanced/binary image, layout) to this directory | +| `-thart` | threshold of artificial
class in the case of textline detection. The default value is 0.1 | +| `-tharl` | threshold of artificial class in the case of layout detection. The default value is 0.1 | +| `-ocr` | perform OCR | +| `-tr` | apply transformer OCR (the default model is a CNN-RNN model) | +| `-bs_ocr` | OCR inference batch size. Default batch sizes for trocr and cnn_rnn models are 2 and 8 respectively | +| `-ncu` | upper limit of columns in document image | +| `-ncl` | lower limit of columns in document image | +| `-slro` | skip layout detection and reading order | +| `-romb` | apply machine based reading order detection | +| `-ipe` | ignore page extraction | + If no further option is set, the tool performs layout detection of main regions (background, text, images, separators and marginals). The best output quality is achieved when RGB images are used as input rather than greyscale or binarized images. +Additional documentation can be found in [`usage.md`](https://github.com/qurator-spk/eynollah/tree/main/docs/usage.md). + ### Binarization The binarization module performs document image binarization using pretrained pixelwise segmentation models. @@ -124,9 +146,12 @@ The command-line interface for binarization can be called like this: eynollah binarization \ -i | -di \ -o \ - -m \ + -m ``` +### Image Enhancement +TODO + ### OCR The OCR module performs text recognition using either a CNN-RNN model or a Transformer model. @@ -138,12 +163,29 @@ eynollah ocr \ -i | -di \ -dx \ -o \ - -m | --model_name \ + -m | --model_name ``` -### Machine-based-reading-order +The following options can be used to further configure the OCR processing: -The machine-based reading-order module employs a pretrained model to identify the reading order from layouts represented in PAGE-XML files. +| option | description | +|-------------------|:-------------------------------------------------------------------------------------------| +| `-dib` | directory of binarized images (file type must be '.png'), prediction with both RGB and bin | +| `-doit` | directory for output images rendered with the predicted text | +| `--model_name` | file path to use a specific model for OCR | +| `-trocr` | use transformer OCR model (otherwise the cnn_rnn model is used) | +| `-etit` | export textline images and text in xml to output dir (OCR training data) | +| `-nmtc` | cropped textline images will not be masked with textline contour | +| `-bs` | OCR inference batch size. Default batch size is 2 for trocr and 8 for cnn_rnn models | +| `-ds_pref` | add an abbreviation of the dataset name to generated training data | +| `-min_conf` | minimum OCR confidence value. Textlines with a lower confidence value will be ignored | + + +### Reading Order Detection +Reading order detection can be performed either as part of layout analysis based on image input, or, currently under +development, based on pre-existing layout analysis data in PAGE-XML format as input. + +The reading order detection module employs a pretrained model to identify the reading order from layouts represented in PAGE-XML files. The command-line interface for machine based reading order can be called like this: @@ -155,36 +197,9 @@ eynollah machine-based-reading-order \ -o ``` -#### Use as OCR-D processor +## Use as OCR-D processor -Eynollah ships with a CLI interface to be used as [OCR-D](https://ocr-d.de) [processor](https://ocr-d.de/en/spec/cli), -formally described in [`ocrd-tool.json`](https://github.com/qurator-spk/eynollah/tree/main/src/eynollah/ocrd-tool.json).
- -In this case, the source image file group with (preferably) RGB images should be used as input like this: - - ocrd-eynollah-segment -I OCR-D-IMG -O OCR-D-SEG -P models eynollah_layout_v0_5_0 - -If the input file group is PAGE-XML (from a previous OCR-D workflow step), Eynollah behaves as follows: -- existing regions are kept and ignored (i.e. in effect they might overlap segments from Eynollah results) -- existing annotation (and respective `AlternativeImage`s) are partially _ignored_: - - previous page frame detection (`cropped` images) - - previous derotation (`deskewed` images) - - previous thresholding (`binarized` images) -- if the page-level image nevertheless deviates from the original (`@imageFilename`) - (because some other preprocessing step was in effect like `denoised`), then - the output PAGE-XML will be based on that as new top-level (`@imageFilename`) - - ocrd-eynollah-segment -I OCR-D-XYZ -O OCR-D-SEG -P models eynollah_layout_v0_5_0 - -In general, it makes more sense to add other workflow steps **after** Eynollah. - -There is also an OCR-D processor for binarization: - - ocrd-sbb-binarize -I OCR-D-IMG -O OCR-D-BIN -P models default-2021-03-09 - -#### Additional documentation - -Additional documentation is available in the [docs](https://github.com/qurator-spk/eynollah/tree/main/docs) directory. +See [`ocrd.md`](https://github.com/qurator-spk/eynollah/tree/main/docs/ocrd.md). ## How to cite diff --git a/docs/docker.md b/docs/docker.md new file mode 100644 index 0000000..7965622 --- /dev/null +++ b/docs/docker.md @@ -0,0 +1,43 @@ +## Inference with Docker + + docker pull ghcr.io/qurator-spk/eynollah:latest + +### 1. ocrd resource manager +(just once, to get the models and install them into a named volume for later re-use) + + vol_models=ocrd-resources:/usr/local/share/ocrd-resources + docker run --rm -v $vol_models ocrd/eynollah ocrd resmgr download ocrd-eynollah-segment default + +Now, each time you want to use Eynollah, pass the same resources volume again. +Also, bind-mount some data directory, e.g. current working directory $PWD (/data is default working directory in the container). + +Either use standalone CLI (2) or OCR-D CLI (3): + +### 2. standalone CLI +(follow self-help, cf. readme) + + docker run --rm -v $vol_models -v $PWD:/data ocrd/eynollah eynollah binarization --help + docker run --rm -v $vol_models -v $PWD:/data ocrd/eynollah eynollah layout --help + docker run --rm -v $vol_models -v $PWD:/data ocrd/eynollah eynollah ocr --help + +### 3. OCR-D CLI +(follow self-help, cf. readme and https://ocr-d.de/en/spec/cli) + + docker run --rm -v $vol_models -v $PWD:/data ocrd/eynollah ocrd-eynollah-segment -h + docker run --rm -v $vol_models -v $PWD:/data ocrd/eynollah ocrd-sbb-binarize -h + +Alternatively, just "log in" to the container once and use the commands there: + + docker run --rm -v $vol_models -v $PWD:/data -it ocrd/eynollah bash + +## Training with Docker + +Build the Docker training image + + cd train + docker build -t model-training . 
+ +Run the Docker training image + + cd train + docker run --gpus all -v $PWD:/entry_point_dir model-training diff --git a/docs/examples.md b/docs/examples.md new file mode 100644 index 0000000..24336b3 --- /dev/null +++ b/docs/examples.md @@ -0,0 +1,18 @@ +# Examples + +Example outputs of various Eynollah models. + +## Binarization + + + + +## Reading Order Detection + +Input Image +Output Image + +## OCR + +Input ImageOutput Image +Input ImageOutput Image diff --git a/docs/models.md b/docs/models.md index 3d296d5..b858630 100644 --- a/docs/models.md +++ b/docs/models.md @@ -18,7 +18,8 @@ Two Arabic/Persian terms form the name of the model suite: عين الله, whic See the flowchart below for the different stages and how they interact: -![](https://user-images.githubusercontent.com/952378/100619946-1936f680-331e-11eb-9297-6e8b4cab3c16.png) +eynollah_flowchart + ## Models @@ -151,15 +152,75 @@ This model is used for the task of illustration detection only. Model card: [Reading Order Detection]() -TODO +The model extracts the reading order of text regions from the layout by classifying pairwise relationships between them. A sorting algorithm then determines the overall reading sequence. + +### OCR + +We have trained three OCR models: two CNN-RNN–based models and one transformer-based TrOCR model. The CNN-RNN models are generally faster and provide better results in most cases, though their performance decreases with heavily degraded images. The TrOCR model, on the other hand, is computationally expensive and slower during inference, but it may produce better results on strongly degraded images. + +#### CNN-RNN model: model_eynollah_ocr_cnnrnn_20250805 + +This model is trained on data where most of the samples are in German Fraktur script. + +| Dataset | Input | CER | WER | +|-----------------------|:-------|:-----------|:----------| +| OCR-D-GT-Archiveform | BIN | 0.02147 | 0.05685 | +| OCR-D-GT-Archiveform | RGB | 0.01636 | 0.06285 | + +#### CNN-RNN model: model_eynollah_ocr_cnnrnn_20250904 (Default) + +Compared to the model_eynollah_ocr_cnnrnn_20250805 model, this model is trained on a larger proportion of Antiqua data and achieves superior performance. + +| Dataset | Input | CER | WER | +|-----------------------|:------------|:-----------|:----------| +| OCR-D-GT-Archiveform | BIN | 0.01635 | 0.05410 | +| OCR-D-GT-Archiveform | RGB | 0.01471 | 0.05813 | +| BLN600 | RGB | 0.04409 | 0.08879 | +| BLN600 | Enhanced | 0.03599 | 0.06244 | + + +#### Transformer OCR model: model_eynollah_ocr_trocr_20250919 + +This transformer OCR model is trained on the same data as the default CNN-RNN model, model_eynollah_ocr_cnnrnn_20250904.
+ +| Dataset | Input | CER | WER | +|-----------------------|:------------|:-----------|:----------| +| OCR-D-GT-Archiveform | BIN | 0.01841 | 0.05589 | +| OCR-D-GT-Archiveform | RGB | 0.01552 | 0.06177 | +| BLN600 | RGB | 0.06347 | 0.13853 | + +##### Qualitative evaluation of the models + +| | | | | +|:---:|:---:|:---:|:---:| +| Image | cnnrnn_20250805 | cnnrnn_20250904 | trocr_20250919 | + + + +| | | | | +|:---:|:---:|:---:|:---:| +| Image | cnnrnn_20250805 | cnnrnn_20250904 | trocr_20250919 | + + +| | | | | +|:---:|:---:|:---:|:---:| +| Image | cnnrnn_20250805 | cnnrnn_20250904 | trocr_20250919 | + + +| | | | | +|:---:|:---:|:---:|:---:| +| Image | cnnrnn_20250805 | cnnrnn_20250904 | trocr_20250919 | + + ## Heuristic methods Additionally, some heuristic methods are employed to further improve the model predictions: * After border detection, the largest contour is determined by a bounding box, and the image cropped to these coordinates. -* For text region detection, the image is scaled up to make it easier for the model to detect background space between text regions. +* Unlike the non-light version, where the image is scaled up to help the model better detect the background spaces between text regions, the light version uses down-scaled images. In this case, introducing an artificial class along the boundaries of text regions and text lines has helped to isolate and separate the text regions more effectively. * A minimum area is defined for text regions in relation to the overall image dimensions, so that very small regions that are noise can be filtered out. -* Deskewing is applied on the text region level (due to regions having different degrees of skew) in order to improve the textline segmentation result. -* After deskewing, a calculation of the pixel distribution on the X-axis allows the separation of textlines (foreground) and background pixels. -* Finally, using the derived coordinates, bounding boxes are determined for each textline. +* In the non-light version, deskewing is applied at the text-region level (since regions may have different degrees of skew) to improve text-line segmentation results. In contrast, the light version performs deskewing only at the page level to enhance margin detection and heuristic reading-order estimation. +* After deskewing, a calculation of the pixel distribution on the X-axis allows the separation of textlines (foreground) and background pixels (only in non-light version). +* Finally, using the derived coordinates, bounding boxes are determined for each textline (only in non-light version). +* As mentioned above, the reading order can be determined using a model; however, this approach is computationally expensive, time-consuming, and less accurate due to the limited amount of ground-truth data available for training. Therefore, our tool uses a heuristic reading-order detection method as the default. The heuristic approach relies on headers and separators to determine the reading order of text regions. diff --git a/docs/ocrd.md b/docs/ocrd.md new file mode 100644 index 0000000..9e7e268 --- /dev/null +++ b/docs/ocrd.md @@ -0,0 +1,26 @@ +## Use as OCR-D processor + +Eynollah ships with a CLI interface to be used as [OCR-D](https://ocr-d.de) [processor](https://ocr-d.de/en/spec/cli), +formally described in [`ocrd-tool.json`](https://github.com/qurator-spk/eynollah/tree/main/src/eynollah/ocrd-tool.json). 
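One convenient consequence of that declarative description is that the parameter set can be inspected programmatically. A minimal sketch following the standard OCR-D tool-json layout; the file path assumes a source checkout and is illustrative:

```python
# List the parameters that ocrd-eynollah-segment declares in its
# ocrd-tool.json. The "tools" -> <executable> -> "parameters" nesting
# follows the OCR-D specification.
import json
from pathlib import Path

tool_json = Path("src/eynollah/ocrd-tool.json")  # adjust to your checkout
spec = json.loads(tool_json.read_text())
for name, param in spec["tools"]["ocrd-eynollah-segment"]["parameters"].items():
    print(f"{name}: {param.get('description', '(no description)')}")
```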
+ +When using Eynollah in OCR-D, the source image file group with (preferably) RGB images should be used as input like this: + + ocrd-eynollah-segment -I OCR-D-IMG -O OCR-D-SEG -P models eynollah_layout_v0_5_0 + +If the input file group is PAGE-XML (from a previous OCR-D workflow step), Eynollah behaves as follows: +- existing regions are kept and ignored (i.e. in effect they might overlap segments from Eynollah results) +- existing annotation (and respective `AlternativeImage`s) are partially _ignored_: + - previous page frame detection (`cropped` images) + - previous derotation (`deskewed` images) + - previous thresholding (`binarized` images) +- if the page-level image nevertheless deviates from the original (`@imageFilename`) + (because some other preprocessing step was in effect like `denoised`), then + the output PAGE-XML will be based on that as new top-level (`@imageFilename`) + + ocrd-eynollah-segment -I OCR-D-XYZ -O OCR-D-SEG -P models eynollah_layout_v0_5_0 + +In general, it makes more sense to add other workflow steps **after** Eynollah. + +There is also an OCR-D processor for binarization: + + ocrd-sbb-binarize -I OCR-D-IMG -O OCR-D-BIN -P models default-2021-03-09 diff --git a/docs/train.md b/docs/train.md index 252bead..ffa39a9 100644 --- a/docs/train.md +++ b/docs/train.md @@ -1,3 +1,41 @@ +# Prerequisites + +## 1. Install Eynollah with training dependencies + +Clone the repository and install eynollah along with the dependencies necessary for training: + +```sh +git clone https://github.com/qurator-spk/eynollah +cd eynollah +pip install '.[training]' +``` + +## 2. Pretrained encoder + +Download our pretrained weights and add them to a `train/pretrained_model` folder (a Python alternative to wget/tar is sketched after this section): + +```sh +cd train +wget -O pretrained_model.tar.gz https://zenodo.org/records/17243320/files/pretrained_model_v0_5_1.tar.gz?download=1 +tar xf pretrained_model.tar.gz +``` + +## 3. Example data + +### Binarization +A small sample of training data for a binarization experiment can be found on [Zenodo](https://zenodo.org/records/17243320/files/training_data_sample_binarization_v0_5_1.tar.gz?download=1), +which contains `images` and `labels` folders. + +## 4. Helpful tools + +* [`pagexml2img`](https://github.com/qurator-spk/page2img) +> Tool to extract 2-D or 3-D RGB images from PAGE-XML data. In the former case, the output is a single 2-D image array in which each class is filled with a distinct pixel value. In the case of a 3-D RGB image, each class is assigned an RGB value, and besides the images, a text file listing the classes is also produced. +* [`cocoSegmentationToPng`](https://github.com/nightrome/cocostuffapi/blob/17acf33aef3c6cc2d6aca46dcf084266c2778cf0/PythonAPI/pycocotools/cocostuffhelper.py#L130) +> Convert COCO GT or results for a single image to a segmentation map and write it to disk. +* [`ocrd-segment-extract-pages`](https://github.com/OCR-D/ocrd_segment/blob/master/ocrd_segment/extract_pages.py) +> Extract region classes and their colours in mask (pseg) images. Allows the color map as free dict parameter, and comes with a default that mimics PageViewer's coloring for quick debugging; it also warns when regions do overlap.
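For environments without `wget`, the fetch-and-extract step from section 2 can also be done with the Python standard library (same URL as above; the archive is expected to unpack into `pretrained_model/`):

```python
# Python equivalent of the wget/tar commands in section 2 above,
# run from the repository root.
import tarfile
import urllib.request

URL = ("https://zenodo.org/records/17243320/files/"
       "pretrained_model_v0_5_1.tar.gz?download=1")
urllib.request.urlretrieve(URL, "pretrained_model.tar.gz")
with tarfile.open("pretrained_model.tar.gz") as tar:
    tar.extractall("train")  # expected to yield train/pretrained_model
```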
+ # Training documentation This document aims to assist users in preparing training datasets, training models, and diff --git a/pyproject.toml b/pyproject.toml index e7744a1..fde7967 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,12 @@ description = "Document Layout Analysis" readme = "README.md" license.file = "LICENSE" requires-python = ">=3.8" -keywords = ["document layout analysis", "image segmentation"] +keywords = [ + "document layout analysis", + "image segmentation", + "binarization", + "optical character recognition" +] dynamic = [ "dependencies", @@ -25,6 +30,10 @@ classifiers = [ "Intended Audience :: Science/Research", "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3 :: Only", "Topic :: Scientific/Engineering :: Image Processing", ] diff --git a/requirements.txt b/requirements.txt index db1d7df..bbacd48 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ tensorflow < 2.13 numba <= 0.58.1 scikit-image biopython +tabulate diff --git a/src/eynollah/cli.py b/src/eynollah/cli.py index c9bad52..9787054 100644 --- a/src/eynollah/cli.py +++ b/src/eynollah/cli.py @@ -1,15 +1,67 @@ -import sys -import click +from dataclasses import dataclass import logging -from ocrd_utils import initLogging, getLevelName, getLogger -from eynollah.eynollah import Eynollah, Eynollah_ocr -from eynollah.sbb_binarize import SbbBinarizer -from eynollah.image_enhancer import Enhancer -from eynollah.mb_ro_on_layout import machine_based_reading_order_on_layout +import sys +import os +from typing import Union + +import click + +# NOTE: For debugging/predictable order of imports +from .eynollah_imports import imported_libs +from .model_zoo import EynollahModelZoo +from .cli_models import models_cli + +@dataclass() +class EynollahCliCtx: + """ + Holds options relevant for all eynollah subcommands + """ + model_zoo: EynollahModelZoo + log_level : Union[str, None] = 'INFO' + @click.group() -def main(): - pass +@click.option( + "--model-basedir", + "-m", + help="directory of models", + type=click.Path(exists=True), + default=f'{os.getcwd()}/models_eynollah', +) +@click.option( + "--model-overrides", + "-mv", + help="override default versions of model categories, syntax is 'CATEGORY VARIANT PATH', e.g 'region light /path/to/model'. 
See `eynollah models list` for the full list", + type=(str, str, str), + multiple=True, +) +@click.option( + "--log_level", + "-l", + type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']), + help="Override log level globally to this", +) +@click.pass_context +def main(ctx, model_basedir, model_overrides, log_level): + """ + eynollah - Document Layout Analysis, Image Enhancement, OCR + """ + # Initialize logging + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(logging.NOTSET) + formatter = logging.Formatter('%(asctime)s.%(msecs)03d %(levelname)s %(name)s - %(message)s', datefmt='%H:%M:%S') + console_handler.setFormatter(formatter) + logging.getLogger('eynollah').addHandler(console_handler) + logging.getLogger('eynollah').setLevel(log_level or logging.INFO) + # Initialize model zoo + model_zoo = EynollahModelZoo(basedir=model_basedir, model_overrides=model_overrides) + # Initialize CLI context + ctx.obj = EynollahCliCtx( + model_zoo=model_zoo, + log_level=log_level, + ) + +main.add_command(models_cli, 'models') @main.command() @click.option( @@ -31,26 +83,14 @@ def main(): type=click.Path(exists=True, file_okay=False), required=True, ) -@click.option( "--model", "-m", help="directory of models", type=click.Path(exists=True, file_okay=False), required=True, ) -@click.option( "--log_level", "-l", type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']), help="Override log level globally to this", ) - -def machine_based_reading_order(input, dir_in, out, model, log_level): +@click.pass_context +def machine_based_reading_order(ctx, input, dir_in, out): + """ + Generate reading order with an ML model + """ + from eynollah.mb_ro_on_layout import machine_based_reading_order_on_layout assert bool(input) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both." - orderer = machine_based_reading_order_on_layout(model) - if log_level: - orderer.logger.setLevel(getLevelName(log_level)) - + orderer = machine_based_reading_order_on_layout(model_zoo=ctx.obj.model_zoo) orderer.run(xml_filename=input, dir_in=dir_in, dir_out=out, @@ -59,7 +99,6 @@ def machine_based_reading_order(input, dir_in, out, model, log_level): @main.command() @click.option('--patches/--no-patches', default=True, help='by enabling this parameter you let the model to see the image in patches.') -@click.option('--model_dir', '-m', type=click.Path(exists=True, file_okay=False), required=True, help='directory containing models for prediction') @click.option( "--input-image", "--image", "-i", @@ -80,17 +119,33 @@ def machine_based_reading_order(input, dir_in, out, model, log_level): required=True, ) @click.option( '-M', '--mode', type=click.Choice(['single', 'multi']), default='single', help="Whether to use the (newer and faster) single-model binarization or the (slightly better) multi-model binarization" ) +@click.pass_context +def binarization( + ctx, + patches, + input_image, + mode, + dir_in, + output, +): + """ + Binarize images with an ML model + """ + from eynollah.sbb_binarize import SbbBinarizer assert bool(input_image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
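For reviewers, the wiring this subcommand performs below can also be reproduced directly in Python; a sketch based on the interfaces this PR introduces (`EynollahModelZoo`, the `mode` switch on `SbbBinarizer`), assuming the override argument is optional and using placeholder paths:

```python
# Programmatic equivalent of `eynollah -m models_eynollah binarization ...`,
# using the model-zoo plumbing added in this PR. Paths are placeholders.
from eynollah.model_zoo import EynollahModelZoo
from eynollah.sbb_binarize import SbbBinarizer

model_zoo = EynollahModelZoo(basedir="models_eynollah")
binarizer = SbbBinarizer(model_zoo=model_zoo, mode="single")  # or "multi"
binarizer.run(
    image_path="page.tif",  # single file; alternatively pass dir_in=...
    use_patches=True,
    output="page.bin.png",
    dir_in=None,
)
```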
- binarizer = SbbBinarizer(model_dir) - if log_level: - binarizer.log.setLevel(getLevelName(log_level)) - binarizer.run(image_path=input_image, use_patches=patches, output=output, dir_in=dir_in) + binarizer = SbbBinarizer(model_zoo=ctx.obj.model_zoo, mode=mode) + binarizer.run( + image_path=input_image, + use_patches=patches, + output=output, + dir_in=dir_in + ) @main.command() @@ -120,14 +175,6 @@ def binarization(patches, model_dir, input_image, dir_in, output, log_level): help="directory of input images (instead of --image)", type=click.Path(exists=True, file_okay=False), ) -@click.option( - "--model", - "-m", - help="directory of models", - type=click.Path(exists=True, file_okay=False), - required=True, -) - @click.option( "--num_col_upper", "-ncu", @@ -144,24 +191,19 @@ def binarization(patches, model_dir, input_image, dir_in, output, log_level): is_flag=True, help="if this parameter set to true, this tool will save the enhanced image in org scale.", ) -@click.option( - "--log_level", - "-l", - type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']), - help="Override log level globally to this", -) - -def enhancement(image, out, overwrite, dir_in, model, num_col_upper, num_col_lower, save_org_scale, log_level): +@click.pass_context +def enhancement(ctx, image, out, overwrite, dir_in, num_col_upper, num_col_lower, save_org_scale): + """ + Enhance image + """ assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both." - initLogging() + from .image_enhancer import Enhancer enhancer = Enhancer( - model, + model_zoo=ctx.obj.model_zoo, num_col_upper=num_col_upper, num_col_lower=num_col_lower, save_org_scale=save_org_scale, ) - if log_level: - enhancer.logger.setLevel(getLevelName(log_level)) enhancer.run(overwrite=overwrite, dir_in=dir_in, image_filename=image, @@ -195,20 +237,6 @@ def enhancement(image, out, overwrite, dir_in, model, num_col_upper, num_col_low help="directory of input images (instead of --image)", type=click.Path(exists=True, file_okay=False), ) -@click.option( - "--model", - "-m", - help="directory of models", - type=click.Path(exists=True, file_okay=False), - required=True, -) -@click.option( - "--model_version", - "-mv", - help="override default versions of model categories", - type=(str, str), - multiple=True, -) @click.option( "--save_images", "-si", @@ -366,30 +394,45 @@ def enhancement(image, out, overwrite, dir_in, model, num_col_upper, num_col_low is_flag=True, help="if this parameter set to true, this tool will ignore layout detection and reading order. 
It means that textline detection will be done within printspace and contours of textline will be written in xml output file.", ) -# TODO move to top-level CLI context -@click.option( - "--log_level", - "-l", - type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']), - help="Override 'eynollah' log level globally to this", -) -# -@click.option( - "--setup-logging", - is_flag=True, - help="Setup a basic console logger", -) - -def layout(image, out, overwrite, dir_in, model, model_version, save_images, save_layout, save_deskewed, save_all, extract_only_images, save_page, enable_plotting, allow_enhancement, curved_line, textline_light, full_layout, tables, right2left, input_binary, allow_scaling, headers_off, light_version, reading_order_machine_based, do_ocr, transformer_ocr, batch_size_ocr, num_col_upper, num_col_lower, threshold_art_class_textline, threshold_art_class_layout, skip_layout_and_reading_order, ignore_page_extraction, log_level, setup_logging): - if setup_logging: - console_handler = logging.StreamHandler(sys.stdout) - console_handler.setLevel(logging.INFO) - formatter = logging.Formatter('%(message)s') - console_handler.setFormatter(formatter) - getLogger('eynollah').addHandler(console_handler) - getLogger('eynollah').setLevel(logging.INFO) - else: - initLogging() +@click.pass_context +def layout( + ctx, + image, + out, + overwrite, + dir_in, + save_images, + save_layout, + save_deskewed, + save_all, + extract_only_images, + save_page, + enable_plotting, + allow_enhancement, + curved_line, + textline_light, + full_layout, + tables, + right2left, + input_binary, + allow_scaling, + headers_off, + light_version, + reading_order_machine_based, + do_ocr, + transformer_ocr, + batch_size_ocr, + num_col_upper, + num_col_lower, + threshold_art_class_textline, + threshold_art_class_layout, + skip_layout_and_reading_order, + ignore_page_extraction, +): + """ + Detect Layout (with optional image enhancement and reading order detection) + """ + from .eynollah import Eynollah assert enable_plotting or not save_layout, "Plotting with -sl also requires -ep" assert enable_plotting or not save_deskewed, "Plotting with -sd also requires -ep" assert enable_plotting or not save_all, "Plotting with -sa also requires -ep" @@ -410,8 +453,7 @@ def layout(image, out, overwrite, dir_in, model, model_version, save_images, sav assert not extract_only_images or not headers_off, "Image extraction -eoi can not be set alongside headers_off -ho" assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both." 
eynollah = Eynollah( - model, - model_versions=model_version, + model_zoo=ctx.obj.model_zoo, extract_only_images=extract_only_images, enable_plotting=enable_plotting, allow_enhancement=allow_enhancement, @@ -435,8 +477,6 @@ def layout(image, out, overwrite, dir_in, model, model_version, save_images, sav threshold_art_class_textline=threshold_art_class_textline, threshold_art_class_layout=threshold_art_class_layout, ) - if log_level: - eynollah.logger.setLevel(getLevelName(log_level)) eynollah.run(overwrite=overwrite, image_filename=image, dir_in=dir_in, @@ -493,17 +533,6 @@ def layout(image, out, overwrite, dir_in, model, model_version, save_images, sav help="overwrite (instead of skipping) if output xml exists", is_flag=True, ) -@click.option( - "--model", - "-m", - help="directory of models", - type=click.Path(exists=True, file_okay=False), -) -@click.option( - "--model_name", - help="Specific model file path to use for OCR", - type=click.Path(exists=True, file_okay=False), -) @click.option( "--tr_ocr", "-trocr/-notrocr", @@ -537,35 +566,42 @@ def layout(image, out, overwrite, dir_in, model, model_version, save_images, sav "-min_conf", help="minimum OCR confidence value. Text lines with a confidence value lower than this threshold will not be included in the output XML file.", ) -@click.option( - "--log_level", - "-l", - type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']), - help="Override log level globally to this", -) - -def ocr(image, dir_in, dir_in_bin, dir_xmls, out, dir_out_image_text, overwrite, model, model_name, tr_ocr, export_textline_images_and_text, do_not_mask_with_textline_contour, batch_size, dataset_abbrevation, min_conf_value_of_textline_text, log_level): - initLogging() - - assert bool(model) != bool(model_name), "Either -m (model directory) or --model_name (specific model name) must be provided." +@click.pass_context +def ocr( + ctx, + image, + dir_in, + dir_in_bin, + dir_xmls, + out, + dir_out_image_text, + overwrite, + tr_ocr, + export_textline_images_and_text, + do_not_mask_with_textline_contour, + batch_size, + dataset_abbrevation, + min_conf_value_of_textline_text, +): + """ + Recognize text with a CNN/RNN or transformer ML model. + """ assert not export_textline_images_and_text or not tr_ocr, "Exporting textline and text -etit can not be set alongside transformer ocr -tr_ocr" - assert not export_textline_images_and_text or not model, "Exporting textline and text -etit can not be set alongside model -m" + # FIXME: refactor: move export_textline_images_and_text out of eynollah.py + # assert not export_textline_images_and_text or not model, "Exporting textline and text -etit can not be set alongside model -m" assert not export_textline_images_and_text or not batch_size, "Exporting textline and text -etit can not be set alongside batch size -bs" assert not export_textline_images_and_text or not dir_in_bin, "Exporting textline and text -etit can not be set alongside directory of bin images -dib" assert not export_textline_images_and_text or not dir_out_image_text, "Exporting textline and text -etit can not be set alongside directory of images with predicted text -doit" assert bool(image) != bool(dir_in), "Either -i (single image) or -di (directory) must be provided, but not both." 
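Analogously, the OCR runner constructed below can be driven outside the CLI; a sketch using only the constructor arguments visible in this diff (the remaining `run()` arguments mirror the CLI options, but their exact keyword names are not shown here and are left out):

```python
# Direct use of the refactored OCR runner, with constructor kwargs as wired
# up in the subcommand below. Paths are placeholders.
from eynollah.model_zoo import EynollahModelZoo
from eynollah.eynollah_ocr import Eynollah_ocr

model_zoo = EynollahModelZoo(basedir="models_eynollah")
ocr_engine = Eynollah_ocr(
    model_zoo=model_zoo,
    tr_ocr=False,  # CNN-RNN model; True selects the TrOCR variant
    export_textline_images_and_text=False,
    do_not_mask_with_textline_contour=False,
    batch_size=8,  # README default: 8 for cnn_rnn, 2 for trocr
    pref_of_dataset=None,
    min_conf_value_of_textline_text=None,
)
ocr_engine.run(
    overwrite=False,
    dir_in="images/",  # input page images
    dir_in_bin=None,   # optional binarized counterparts (-dib)
    # further arguments (PAGE-XML input dir, output dir) follow the CLI
    # options; see the subcommand wiring below for the full call
)
```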
+ from .eynollah_ocr import Eynollah_ocr eynollah_ocr = Eynollah_ocr( - dir_models=model, - model_name=model_name, + model_zoo=ctx.obj.model_zoo, tr_ocr=tr_ocr, export_textline_images_and_text=export_textline_images_and_text, do_not_mask_with_textline_contour=do_not_mask_with_textline_contour, batch_size=batch_size, pref_of_dataset=dataset_abbrevation, - min_conf_value_of_textline_text=min_conf_value_of_textline_text, - ) - if log_level: - eynollah_ocr.logger.setLevel(getLevelName(log_level)) + min_conf_value_of_textline_text=min_conf_value_of_textline_text) eynollah_ocr.run(overwrite=overwrite, dir_in=dir_in, dir_in_bin=dir_in_bin, diff --git a/src/eynollah/cli_models.py b/src/eynollah/cli_models.py new file mode 100644 index 0000000..f3de596 --- /dev/null +++ b/src/eynollah/cli_models.py @@ -0,0 +1,69 @@ +from pathlib import Path +from typing import Set, Tuple +import click + +from eynollah.model_zoo.default_specs import MODELS_VERSION + +@click.group() +@click.pass_context +def models_cli( + ctx, +): + """ + Organize models for the various runners in eynollah. + """ + assert ctx.obj.model_zoo + + +@models_cli.command('list') +@click.pass_context +def list_models( + ctx, +): + """ + List all the models in the zoo + """ + print(f"Model basedir: {ctx.obj.model_zoo.model_basedir}") + print(f"Model overrides: {ctx.obj.model_zoo.model_overrides}") + print(ctx.obj.model_zoo) + + +@models_cli.command('package') +@click.option( + '--set-version', '-V', 'version', help="Version to use for packaging", default=MODELS_VERSION, show_default=True +) +@click.argument('output_dir') +@click.pass_context +def package( + ctx, + version, + output_dir, +): + """ + Generate shell code to copy all the models in the zoo into properly named folders in OUTPUT_DIR for distribution. + + eynollah models -m SRC package OUTPUT_DIR + + SRC should contain a directory "models_eynollah" containing all the models. + """ + mkdirs: Set[Path] = set([]) + copies: Set[Tuple[Path, Path]] = set([]) + for spec in ctx.obj.model_zoo.specs.specs: + # skip these as they are dependent on the ocr model + if spec.category in ('num_to_char', 'characters'): + continue + src: Path = ctx.obj.model_zoo.model_path(spec.category, spec.variant) + # Only copy the top-most directory relative to models_eynollah + while src.parent.name != 'models_eynollah': + src = src.parent + for dist in spec.dists: + dist_dir = Path(f"{output_dir}/models_{dist}_{version}/models_eynollah") + copies.add((src, dist_dir)) + mkdirs.add(dist_dir) + for dir in mkdirs: + print(f"mkdir -vp {dir}") + for (src, dst) in copies: + print(f"cp -vr {src} {dst}") + for dir in mkdirs: + zip_path = Path(f'../{dir.parent.name}.zip') + print(f"(cd {dir}/..; zip -vr {zip_path} models_eynollah)") diff --git a/src/eynollah/eynollah.py b/src/eynollah/eynollah.py index 13acba6..dc90f1d 100644 --- a/src/eynollah/eynollah.py +++ b/src/eynollah/eynollah.py @@ -1,47 +1,63 @@ +""" +document layout analysis (segmentation) with output in PAGE-XML +""" # pylint: disable=no-member,invalid-name,line-too-long,missing-function-docstring,missing-class-docstring,too-many-branches # pylint: disable=too-many-locals,wrong-import-position,too-many-lines,too-many-statements,chained-comparison,fixme,broad-except,c-extension-no-member # pylint: disable=too-many-public-methods,too-many-arguments,too-many-instance-attributes,too-many-public-methods, # pylint: disable=consider-using-enumerate -""" -document layout analysis (segmentation) with output in PAGE-XML -""" +# FIXME: fix all of those... 
+# pyright: reportUnnecessaryTypeIgnoreComment=true +# pyright: reportPossiblyUnboundVariable=false +# pyright: reportMissingImports=false +# pyright: reportCallIssue=false +# pyright: reportOperatorIssue=false +# pyright: reportUnboundVariable=false +# pyright: reportArgumentType=false +# pyright: reportAttributeAccessIssue=false +# pyright: reportOptionalMemberAccess=false +# pyright: reportGeneralTypeIssues=false +# pyright: reportOptionalSubscript=false + +import logging +import sys # cannot use importlib.resources until we move to 3.9+ forimportlib.resources.files -import sys if sys.version_info < (3, 10): import importlib_resources else: import importlib.resources as importlib_resources from difflib import SequenceMatcher as sq -from PIL import Image, ImageDraw, ImageFont import math import os -import sys import time -from typing import Dict, List, Optional, Tuple -import atexit -import warnings +from typing import Optional from functools import partial from pathlib import Path from multiprocessing import cpu_count import gc import copy -import json from concurrent.futures import ProcessPoolExecutor -import xml.etree.ElementTree as ET import cv2 import numpy as np import shapely.affinity from scipy.signal import find_peaks from scipy.ndimage import gaussian_filter1d -from numba import cuda from skimage.morphology import skeletonize -from ocrd import OcrdPage -from ocrd_utils import getLogger, tf_disable_interactive_logs +from ocrd_utils import tf_disable_interactive_logs import statistics +tf_disable_interactive_logs() + +import tensorflow as tf +# warnings.filterwarnings("ignore") +from tensorflow.python.keras import backend as K +from tensorflow.keras.models import load_model +# use tf1 compatibility for keras backend +from tensorflow.compat.v1.keras.backend import set_session +from tensorflow.keras import layers +from tensorflow.keras.layers import StringLookup try: import torch except ImportError: @@ -50,23 +66,8 @@ try: import matplotlib.pyplot as plt except ImportError: plt = None -try: - from transformers import TrOCRProcessor, VisionEncoderDecoderModel -except ImportError: - TrOCRProcessor = VisionEncoderDecoderModel = None - -#os.environ['CUDA_VISIBLE_DEVICES'] = '-1' -tf_disable_interactive_logs() -import tensorflow as tf -from tensorflow.python.keras import backend as K -from tensorflow.keras.models import load_model -tf.get_logger().setLevel("ERROR") -warnings.filterwarnings("ignore") -# use tf1 compatibility for keras backend -from tensorflow.compat.v1.keras.backend import set_session -from tensorflow.keras import layers -from tensorflow.keras.layers import StringLookup +from .model_zoo import EynollahModelZoo from .utils.contour import ( filter_contours_area_of_image, filter_contours_area_of_image_tables, @@ -155,59 +156,12 @@ patch_size = 1 num_patches =21*21#14*14#28*28#14*14#28*28 -class Patches(layers.Layer): - def __init__(self, **kwargs): - super(Patches, self).__init__() - self.patch_size = patch_size - - def call(self, images): - batch_size = tf.shape(images)[0] - patches = tf.image.extract_patches( - images=images, - sizes=[1, self.patch_size, self.patch_size, 1], - strides=[1, self.patch_size, self.patch_size, 1], - rates=[1, 1, 1, 1], - padding="VALID", - ) - patch_dims = patches.shape[-1] - patches = tf.reshape(patches, [batch_size, -1, patch_dims]) - return patches - def get_config(self): - - config = super().get_config().copy() - config.update({ - 'patch_size': self.patch_size, - }) - return config - -class PatchEncoder(layers.Layer): - def __init__(self, 
**kwargs): - super(PatchEncoder, self).__init__() - self.num_patches = num_patches - self.projection = layers.Dense(units=projection_dim) - self.position_embedding = layers.Embedding( - input_dim=num_patches, output_dim=projection_dim - ) - - def call(self, patch): - positions = tf.range(start=0, limit=self.num_patches, delta=1) - encoded = self.projection(patch) + self.position_embedding(positions) - return encoded - def get_config(self): - - config = super().get_config().copy() - config.update({ - 'num_patches': self.num_patches, - 'projection': self.projection, - 'position_embedding': self.position_embedding, - }) - return config class Eynollah: def __init__( self, - dir_models : str, - model_versions: List[Tuple[str, str]] = [], + *, + model_zoo: EynollahModelZoo, extract_only_images : bool =False, enable_plotting : bool = False, allow_enhancement : bool = False, @@ -230,8 +184,10 @@ class Eynollah: threshold_art_class_layout: Optional[float] = None, threshold_art_class_textline: Optional[float] = None, skip_layout_and_reading_order : bool = False, + logger : Optional[logging.Logger] = None, ): - self.logger = getLogger('eynollah') + self.logger = logger or logging.getLogger('eynollah') + self.model_zoo = model_zoo self.plotter = None if skip_layout_and_reading_order: @@ -245,6 +201,7 @@ class Eynollah: self.full_layout = full_layout self.tables = tables self.right2left = right2left + # --input-binary sensible if image is very dark, if layout is not working. self.input_binary = input_binary self.allow_scaling = allow_scaling self.headers_off = headers_off @@ -297,93 +254,11 @@ class Eynollah: self.logger.warning("no GPU device available") self.logger.info("Loading models...") - self.setup_models(dir_models, model_versions) + self.setup_models() self.logger.info(f"Model initialization complete ({time.time() - t_start:.1f}s)") - @staticmethod - def our_load_model(model_file, basedir=""): - if basedir: - model_file = os.path.join(basedir, model_file) - if model_file.endswith('.h5') and Path(model_file[:-3]).exists(): - # prefer SavedModel over HDF5 format if it exists - model_file = model_file[:-3] - try: - model = load_model(model_file, compile=False) - except: - model = load_model(model_file, compile=False, custom_objects={ - "PatchEncoder": PatchEncoder, "Patches": Patches}) - return model + def setup_models(self): - def setup_models(self, basedir: Path, model_versions: List[Tuple[str, str]] = []): - self.model_versions = { - "enhancement": "eynollah-enhancement_20210425", - "binarization": "eynollah-binarization_20210425", - "col_classifier": "eynollah-column-classifier_20210425", - "page": "model_eynollah_page_extraction_20250915", - #?: "eynollah-main-regions-aug-scaling_20210425", - "region": ( # early layout - "eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18" if self.extract_only_images else - "eynollah-main-regions_20220314" if self.light_version else - "eynollah-main-regions-ensembled_20210425"), - "region_p2": ( # early layout, non-light, 2nd part - "eynollah-main-regions-aug-rotation_20210425"), - "region_1_2": ( # early layout, light, 1-or-2-column - #"modelens_12sp_elay_0_3_4__3_6_n" - #"modelens_earlylayout_12spaltige_2_3_5_6_7_8" - #"modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18" - #"modelens_1_2_4_5_early_lay_1_2_spaltige" - #"model_3_eraly_layout_no_patches_1_2_spaltige" - "modelens_e_l_all_sp_0_1_2_3_4_171024"), - "region_fl_np": ( # full layout / no patches - #"modelens_full_lay_1_3_031124" - #"modelens_full_lay_13__3_19_241024" - 
#"model_full_lay_13_241024" - #"modelens_full_lay_13_17_231024" - #"modelens_full_lay_1_2_221024" - #"eynollah-full-regions-1column_20210425" - "modelens_full_lay_1__4_3_091124"), - "region_fl": ( # full layout / with patches - #"eynollah-full-regions-3+column_20210425" - ##"model_2_full_layout_new_trans" - #"modelens_full_lay_1_3_031124" - #"modelens_full_lay_13__3_19_241024" - #"model_full_lay_13_241024" - #"modelens_full_lay_13_17_231024" - #"modelens_full_lay_1_2_221024" - #"modelens_full_layout_24_till_28" - #"model_2_full_layout_new_trans" - "modelens_full_lay_1__4_3_091124"), - "reading_order": ( - #"model_mb_ro_aug_ens_11" - #"model_step_3200000_mb_ro" - #"model_ens_reading_order_machine_based" - #"model_mb_ro_aug_ens_8" - #"model_ens_reading_order_machine_based" - "model_eynollah_reading_order_20250824"), - "textline": ( - #"modelens_textline_1_4_16092024" - #"model_textline_ens_3_4_5_6_artificial" - #"modelens_textline_1_3_4_20240915" - #"model_textline_ens_3_4_5_6_artificial" - #"modelens_textline_9_12_13_14_15" - #"eynollah-textline_light_20210425" - "modelens_textline_0_1__2_4_16092024" if self.textline_light else - #"eynollah-textline_20210425" - "modelens_textline_0_1__2_4_16092024"), - "table": ( - None if not self.tables else - "modelens_table_0t4_201124" if self.light_version else - "eynollah-tables_20210319"), - "ocr": ( - None if not self.ocr else - "model_eynollah_ocr_trocr_20250919" if self.tr else - "model_eynollah_ocr_cnnrnn_20250930") - } - # override defaults from CLI - for key, val in model_versions: - assert key in self.model_versions, "unknown model category '%s'" % key - self.logger.warning("overriding default model %s version %s to %s", key, self.model_versions[key], val) - self.model_versions[key] = val # load models, depending on modes # (note: loading too many models can cause OOM on GPU/CUDA, # thus, we try set up the minimal configuration for the current mode) @@ -391,10 +266,10 @@ class Eynollah: "col_classifier", "binarization", "page", - "region" + ("region", 'extract_only_images' if self.extract_only_images else 'light' if self.light_version else '') ] if not self.extract_only_images: - loadable.append("textline") + loadable.append(("textline", 'light' if self.light_version else '')) if self.light_version: loadable.append("region_1_2") else: @@ -407,47 +282,34 @@ class Eynollah: if self.reading_order_machine_based: loadable.append("reading_order") if self.tables: - loadable.append("table") - - self.models = {name: self.our_load_model(self.model_versions[name], basedir) - for name in loadable - } + loadable.append(("table", 'light' if self.light_version else '')) if self.ocr: - ocr_model_dir = os.path.join(basedir, self.model_versions["ocr"]) if self.tr: - self.models["ocr"] = VisionEncoderDecoderModel.from_pretrained(ocr_model_dir) - if torch.cuda.is_available(): - self.logger.info("Using GPU acceleration") - self.device = torch.device("cuda:0") - else: - self.logger.info("Using CPU processing") - self.device = torch.device("cpu") - #self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten") - self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed") + loadable.append(('ocr', 'tr')) + loadable.append(('trocr_processor', '')) else: - ocr_model = load_model(ocr_model_dir, compile=False) - self.models["ocr"] = tf.keras.models.Model( - ocr_model.get_layer(name = "image").input, - ocr_model.get_layer(name = "dense2").output) - - with open(os.path.join(ocr_model_dir, "characters_org.txt"), "r") as config_file: 
- characters = json.load(config_file) - # Mapping characters to integers. - char_to_num = StringLookup(vocabulary=list(characters), mask_token=None) - # Mapping integers back to original characters. - self.num_to_char = StringLookup( - vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True - ) + loadable.append('ocr') + loadable.append('num_to_char') + + self.model_zoo.load_models(*loadable) def __del__(self): if hasattr(self, 'executor') and getattr(self, 'executor'): + assert self.executor self.executor.shutdown() self.executor = None - if hasattr(self, 'models') and getattr(self, 'models'): - for model_name in list(self.models): - if self.models[model_name]: - del self.models[model_name] + self.model_zoo.shutdown() + + @property + def device(self): + # TODO why here and why only for tr? + assert torch + if torch.cuda.is_available(): + self.logger.info("Using GPU acceleration") + return torch.device("cuda:0") + self.logger.info("Using CPU processing") + return torch.device("cpu") def cache_images(self, image_filename=None, image_pil=None, dpi=None): ret = {} @@ -494,8 +356,8 @@ class Eynollah: def predict_enhancement(self, img): self.logger.debug("enter predict_enhancement") - img_height_model = self.models["enhancement"].layers[-1].output_shape[1] - img_width_model = self.models["enhancement"].layers[-1].output_shape[2] + img_height_model = self.model_zoo.get("enhancement").layers[-1].output_shape[1] + img_width_model = self.model_zoo.get("enhancement").layers[-1].output_shape[2] if img.shape[0] < img_height_model: img = cv2.resize(img, (img.shape[1], img_width_model), interpolation=cv2.INTER_NEAREST) if img.shape[1] < img_width_model: @@ -536,7 +398,7 @@ class Eynollah: index_y_d = img_h - img_height_model img_patch = img[np.newaxis, index_y_d:index_y_u, index_x_d:index_x_u, :] - label_p_pred = self.models["enhancement"].predict(img_patch, verbose=0) + label_p_pred = self.model_zoo.get("enhancement").predict(img_patch, verbose=0) seg = label_p_pred[0, :, :, :] * 255 if i == 0 and j == 0: @@ -711,7 +573,7 @@ class Eynollah: img_in[0, :, :, 1] = img_1ch[:, :] img_in[0, :, :, 2] = img_1ch[:, :] - label_p_pred = self.models["col_classifier"].predict(img_in, verbose=0) + label_p_pred = self.model_zoo.get("col_classifier").predict(img_in, verbose=0) num_col = np.argmax(label_p_pred[0]) + 1 self.logger.info("Found %s columns (%s)", num_col, label_p_pred) @@ -729,7 +591,7 @@ class Eynollah: self.logger.info("Detected %s DPI", dpi) if self.input_binary: img = self.imread() - prediction_bin = self.do_prediction(True, img, self.models["binarization"], n_batch_inference=5) + prediction_bin = self.do_prediction(True, img, self.model_zoo.get("binarization"), n_batch_inference=5) prediction_bin = 255 * (prediction_bin[:,:,0] == 0) prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2).astype(np.uint8) img= np.copy(prediction_bin) @@ -769,7 +631,7 @@ class Eynollah: img_in[0, :, :, 1] = img_1ch[:, :] img_in[0, :, :, 2] = img_1ch[:, :] - label_p_pred = self.models["col_classifier"].predict(img_in, verbose=0) + label_p_pred = self.model_zoo.get("col_classifier").predict(img_in, verbose=0) num_col = np.argmax(label_p_pred[0]) + 1 elif (self.num_col_upper and self.num_col_lower) and (self.num_col_upper!=self.num_col_lower): @@ -790,7 +652,7 @@ class Eynollah: img_in[0, :, :, 1] = img_1ch[:, :] img_in[0, :, :, 2] = img_1ch[:, :] - label_p_pred = self.models["col_classifier"].predict(img_in, verbose=0) + label_p_pred = self.model_zoo.get("col_classifier").predict(img_in, 
verbose=0) num_col = np.argmax(label_p_pred[0]) + 1 if num_col > self.num_col_upper: @@ -845,8 +707,8 @@ class Eynollah: self.img_hight_int = int(self.image.shape[0] * scale) self.img_width_int = int(self.image.shape[1] * scale) - self.scale_y = self.img_hight_int / float(self.image.shape[0]) - self.scale_x = self.img_width_int / float(self.image.shape[1]) + self.scale_y: float = self.img_hight_int / float(self.image.shape[0]) + self.scale_x: float = self.img_width_int / float(self.image.shape[1]) self.image = resize_image(self.image, self.img_hight_int, self.img_width_int) @@ -1642,7 +1504,7 @@ class Eynollah: cont_page = [] if not self.ignore_page_extraction: img = np.copy(self.image)#cv2.GaussianBlur(self.image, (5, 5), 0) - img_page_prediction = self.do_prediction(False, img, self.models["page"]) + img_page_prediction = self.do_prediction(False, img, self.model_zoo.get("page")) imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY) _, thresh = cv2.threshold(imgray, 0, 255, 0) ##thresh = cv2.dilate(thresh, KERNEL, iterations=3) @@ -1690,7 +1552,7 @@ class Eynollah: else: img = self.imread() img = cv2.GaussianBlur(img, (5, 5), 0) - img_page_prediction = self.do_prediction(False, img, self.models["page"]) + img_page_prediction = self.do_prediction(False, img, self.model_zoo.get("page")) imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY) _, thresh = cv2.threshold(imgray, 0, 255, 0) @@ -1716,7 +1578,7 @@ class Eynollah: self.logger.debug("enter extract_text_regions") img_height_h = img.shape[0] img_width_h = img.shape[1] - model_region = self.models["region_fl"] if patches else self.models["region_fl_np"] + model_region = self.model_zoo.get("region_fl") if patches else self.model_zoo.get("region_fl_np") if self.light_version: thresholding_for_fl_light_version = True @@ -1751,7 +1613,7 @@ class Eynollah: self.logger.debug("enter extract_text_regions") img_height_h = img.shape[0] img_width_h = img.shape[1] - model_region = self.models["region_fl"] if patches else self.models["region_fl_np"] + model_region = self.model_zoo.get("region_fl") if patches else self.model_zoo.get("region_fl_np") if not patches: img = otsu_copy_binary(img) @@ -1911,6 +1773,7 @@ class Eynollah: return [], [], [] self.logger.debug("enter get_slopes_and_deskew_new_light") with share_ndarray(textline_mask_tot) as textline_mask_tot_shared: + assert self.executor results = self.executor.map(partial(do_work_of_slopes_new_light, textline_mask_tot_ea=textline_mask_tot_shared, slope_deskew=slope_deskew, @@ -1927,6 +1790,7 @@ class Eynollah: return [], [], [] self.logger.debug("enter get_slopes_and_deskew_new") with share_ndarray(textline_mask_tot) as textline_mask_tot_shared: + assert self.executor results = self.executor.map(partial(do_work_of_slopes_new, textline_mask_tot_ea=textline_mask_tot_shared, slope_deskew=slope_deskew, @@ -1947,6 +1811,7 @@ class Eynollah: self.logger.debug("enter get_slopes_and_deskew_new_curved") with share_ndarray(textline_mask_tot) as textline_mask_tot_shared: with share_ndarray(mask_texts_only) as mask_texts_only_shared: + assert self.executor results = self.executor.map(partial(do_work_of_slopes_new_curved, textline_mask_tot_ea=textline_mask_tot_shared, mask_texts_only=mask_texts_only_shared, @@ -1972,14 +1837,14 @@ class Eynollah: img_w = img_org.shape[1] img = resize_image(img_org, int(img_org.shape[0] * scaler_h), int(img_org.shape[1] * scaler_w)) - prediction_textline = self.do_prediction(use_patches, img, self.models["textline"], + prediction_textline = 
self.do_prediction(use_patches, img, self.model_zoo.get("textline"), marginal_of_patch_percent=0.15, n_batch_inference=3, thresholding_for_artificial_class_in_light_version=self.textline_light, threshold_art_class_textline=self.threshold_art_class_textline) #if not self.textline_light: #if num_col_classifier==1: - #prediction_textline_nopatch = self.do_prediction(False, img, self.models["textline"]) + #prediction_textline_nopatch = self.do_prediction(False, img, self.model_zoo.get_model("textline")) #prediction_textline[:,:][prediction_textline_nopatch[:,:]==0] = 0 prediction_textline = resize_image(prediction_textline, img_h, img_w) @@ -2050,7 +1915,7 @@ class Eynollah: #cv2.imwrite('prediction_textline2.png', prediction_textline[:,:,0]) - prediction_textline_longshot = self.do_prediction(False, img, self.models["textline"]) + prediction_textline_longshot = self.do_prediction(False, img, self.model_zoo.get("textline")) prediction_textline_longshot_true_size = resize_image(prediction_textline_longshot, img_h, img_w) @@ -2083,7 +1948,7 @@ class Eynollah: img_h_new = int(img.shape[0] / float(img.shape[1]) * img_w_new) img_resized = resize_image(img,img_h_new, img_w_new ) - prediction_regions_org, _ = self.do_prediction_new_concept(True, img_resized, self.models["region"]) + prediction_regions_org, _ = self.do_prediction_new_concept(True, img_resized, self.model_zoo.get("region")) prediction_regions_org = resize_image(prediction_regions_org,img_height_h, img_width_h ) image_page, page_coord, cont_page = self.extract_page() @@ -2199,7 +2064,7 @@ class Eynollah: #if self.input_binary: #img_bin = np.copy(img_resized) ###if (not self.input_binary and self.full_layout) or (not self.input_binary and num_col_classifier >= 30): - ###prediction_bin = self.do_prediction(True, img_resized, self.models["binarization"], n_batch_inference=5) + ###prediction_bin = self.do_prediction(True, img_resized, self.model_zoo.get_model("binarization"), n_batch_inference=5) ####print("inside bin ", time.time()-t_bin) ###prediction_bin=prediction_bin[:,:,0] @@ -2214,7 +2079,7 @@ class Eynollah: ###else: ###img_bin = np.copy(img_resized) if (self.ocr and self.tr) and not self.input_binary: - prediction_bin = self.do_prediction(True, img_resized, self.models["binarization"], n_batch_inference=5) + prediction_bin = self.do_prediction(True, img_resized, self.model_zoo.get("binarization"), n_batch_inference=5) prediction_bin = 255 * (prediction_bin[:,:,0] == 0) prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2) prediction_bin = prediction_bin.astype(np.uint16) @@ -2246,14 +2111,14 @@ class Eynollah: self.logger.debug("resized to %dx%d for %d cols", img_resized.shape[1], img_resized.shape[0], num_col_classifier) prediction_regions_org, confidence_matrix = self.do_prediction_new_concept( - True, img_resized, self.models["region_1_2"], n_batch_inference=1, + True, img_resized, self.model_zoo.get("region_1_2"), n_batch_inference=1, thresholding_for_some_classes_in_light_version=True, threshold_art_class_layout=self.threshold_art_class_layout) else: prediction_regions_org = np.zeros((self.image_org.shape[0], self.image_org.shape[1], 3)) confidence_matrix = np.zeros((self.image_org.shape[0], self.image_org.shape[1])) prediction_regions_page, confidence_matrix_page = self.do_prediction_new_concept( - False, self.image_page_org_size, self.models["region_1_2"], n_batch_inference=1, + False, self.image_page_org_size, self.model_zoo.get("region_1_2"), n_batch_inference=1, 
thresholding_for_artificial_class_in_light_version=True, threshold_art_class_layout=self.threshold_art_class_layout) ys = slice(*self.page_coord[0:2]) @@ -2267,10 +2132,10 @@ class Eynollah: self.logger.debug("resized to %dx%d (new_h=%d) for %d cols", img_resized.shape[1], img_resized.shape[0], new_h, num_col_classifier) prediction_regions_org, confidence_matrix = self.do_prediction_new_concept( - True, img_resized, self.models["region_1_2"], n_batch_inference=2, + True, img_resized, self.model_zoo.get("region_1_2"), n_batch_inference=2, thresholding_for_some_classes_in_light_version=True, threshold_art_class_layout=self.threshold_art_class_layout) - ###prediction_regions_org = self.do_prediction(True, img_bin, self.models["region"], + ###prediction_regions_org = self.do_prediction(True, img_bin, self.model_zoo.get_model("region"), ###n_batch_inference=3, ###thresholding_for_some_classes_in_light_version=True) #print("inside 3 ", time.time()-t_in) @@ -2350,7 +2215,7 @@ class Eynollah: ratio_x=1 img = resize_image(img_org, int(img_org.shape[0]*ratio_y), int(img_org.shape[1]*ratio_x)) - prediction_regions_org_y = self.do_prediction(True, img, self.models["region"]) + prediction_regions_org_y = self.do_prediction(True, img, self.model_zoo.get("region")) prediction_regions_org_y = resize_image(prediction_regions_org_y, img_height_h, img_width_h ) #plt.imshow(prediction_regions_org_y[:,:,0]) @@ -2365,7 +2230,7 @@ class Eynollah: _, _ = find_num_col(img_only_regions, num_col_classifier, self.tables, multiplier=6.0) img = resize_image(img_org, int(img_org.shape[0]), int(img_org.shape[1]*(1.2 if is_image_enhanced else 1))) - prediction_regions_org = self.do_prediction(True, img, self.models["region"]) + prediction_regions_org = self.do_prediction(True, img, self.model_zoo.get("region")) prediction_regions_org = resize_image(prediction_regions_org, img_height_h, img_width_h ) prediction_regions_org=prediction_regions_org[:,:,0] @@ -2373,7 +2238,7 @@ class Eynollah: img = resize_image(img_org, int(img_org.shape[0]), int(img_org.shape[1])) - prediction_regions_org2 = self.do_prediction(True, img, self.models["region_p2"], marginal_of_patch_percent=0.2) + prediction_regions_org2 = self.do_prediction(True, img, self.model_zoo.get("region_p2"), marginal_of_patch_percent=0.2) prediction_regions_org2=resize_image(prediction_regions_org2, img_height_h, img_width_h ) mask_zeros2 = (prediction_regions_org2[:,:,0] == 0) @@ -2397,7 +2262,7 @@ class Eynollah: if self.input_binary: prediction_bin = np.copy(img_org) else: - prediction_bin = self.do_prediction(True, img_org, self.models["binarization"], n_batch_inference=5) + prediction_bin = self.do_prediction(True, img_org, self.model_zoo.get("binarization"), n_batch_inference=5) prediction_bin = resize_image(prediction_bin, img_height_h, img_width_h ) prediction_bin = 255 * (prediction_bin[:,:,0]==0) prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2) @@ -2407,7 +2272,7 @@ class Eynollah: img = resize_image(prediction_bin, int(img_org.shape[0]*ratio_y), int(img_org.shape[1]*ratio_x)) - prediction_regions_org = self.do_prediction(True, img, self.models["region"]) + prediction_regions_org = self.do_prediction(True, img, self.model_zoo.get("region")) prediction_regions_org = resize_image(prediction_regions_org, img_height_h, img_width_h ) prediction_regions_org=prediction_regions_org[:,:,0] @@ -2434,7 +2299,7 @@ class Eynollah: except: if self.input_binary: prediction_bin = np.copy(img_org) - prediction_bin = self.do_prediction(True, img_org, 
self.models["binarization"], n_batch_inference=5) + prediction_bin = self.do_prediction(True, img_org, self.model_zoo.get("binarization"), n_batch_inference=5) prediction_bin = resize_image(prediction_bin, img_height_h, img_width_h ) prediction_bin = 255 * (prediction_bin[:,:,0]==0) prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2) @@ -2445,14 +2310,14 @@ class Eynollah: img = resize_image(prediction_bin, int(img_org.shape[0]*ratio_y), int(img_org.shape[1]*ratio_x)) - prediction_regions_org = self.do_prediction(True, img, self.models["region"]) + prediction_regions_org = self.do_prediction(True, img, self.model_zoo.get("region")) prediction_regions_org = resize_image(prediction_regions_org, img_height_h, img_width_h ) prediction_regions_org=prediction_regions_org[:,:,0] #mask_lines_only=(prediction_regions_org[:,:]==3)*1 #img = resize_image(img_org, int(img_org.shape[0]*1), int(img_org.shape[1]*1)) - #prediction_regions_org = self.do_prediction(True, img, self.models["region"]) + #prediction_regions_org = self.do_prediction(True, img, self.model_zoo.get_model("region")) #prediction_regions_org = resize_image(prediction_regions_org, img_height_h, img_width_h ) #prediction_regions_org = prediction_regions_org[:,:,0] #prediction_regions_org[(prediction_regions_org[:,:] == 1) & (mask_zeros_y[:,:] == 1)]=0 @@ -2823,13 +2688,13 @@ class Eynollah: img_width_h = img_org.shape[1] patches = False if self.light_version: - prediction_table, _ = self.do_prediction_new_concept(patches, img, self.models["table"]) + prediction_table, _ = self.do_prediction_new_concept(patches, img, self.model_zoo.get("table")) prediction_table = prediction_table.astype(np.int16) return prediction_table[:,:,0] else: if num_col_classifier < 4 and num_col_classifier > 2: - prediction_table = self.do_prediction(patches, img, self.models["table"]) - pre_updown = self.do_prediction(patches, cv2.flip(img[:,:,:], -1), self.models["table"]) + prediction_table = self.do_prediction(patches, img, self.model_zoo.get("table")) + pre_updown = self.do_prediction(patches, cv2.flip(img[:,:,:], -1), self.model_zoo.get("table")) pre_updown = cv2.flip(pre_updown, -1) prediction_table[:,:,0][pre_updown[:,:,0]==1]=1 @@ -2848,8 +2713,8 @@ class Eynollah: xs = slice(w_start, w_start + img.shape[1]) img_new[ys, xs] = img - prediction_ext = self.do_prediction(patches, img_new, self.models["table"]) - pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), self.models["table"]) + prediction_ext = self.do_prediction(patches, img_new, self.model_zoo.get("table")) + pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), self.model_zoo.get("table")) pre_updown = cv2.flip(pre_updown, -1) prediction_table = prediction_ext[ys, xs] @@ -2870,8 +2735,8 @@ class Eynollah: xs = slice(w_start, w_start + img.shape[1]) img_new[ys, xs] = img - prediction_ext = self.do_prediction(patches, img_new, self.models["table"]) - pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), self.models["table"]) + prediction_ext = self.do_prediction(patches, img_new, self.model_zoo.get("table")) + pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), self.model_zoo.get("table")) pre_updown = cv2.flip(pre_updown, -1) prediction_table = prediction_ext[ys, xs] @@ -2883,10 +2748,10 @@ class Eynollah: prediction_table = np.zeros(img.shape) img_w_half = img.shape[1] // 2 - pre1 = self.do_prediction(patches, img[:,0:img_w_half,:], self.models["table"]) - pre2 = self.do_prediction(patches, 
img[:,img_w_half:,:], self.models["table"]) - pre_full = self.do_prediction(patches, img[:,:,:], self.models["table"]) - pre_updown = self.do_prediction(patches, cv2.flip(img[:,:,:], -1), self.models["table"]) + pre1 = self.do_prediction(patches, img[:,0:img_w_half,:], self.model_zoo.get("table")) + pre2 = self.do_prediction(patches, img[:,img_w_half:,:], self.model_zoo.get("table")) + pre_full = self.do_prediction(patches, img[:,:,:], self.model_zoo.get("table")) + pre_updown = self.do_prediction(patches, cv2.flip(img[:,:,:], -1), self.model_zoo.get("table")) pre_updown = cv2.flip(pre_updown, -1) prediction_table_full_erode = cv2.erode(pre_full[:,:,0], KERNEL, iterations=4) @@ -3678,7 +3543,7 @@ class Eynollah: tot_counter += 1 batch.append(j) if tot_counter % inference_bs == 0 or tot_counter == len(ij_list): - y_pr = self.models["reading_order"].predict(input_1 , verbose=0) + y_pr = self.model_zoo.get("reading_order").predict(input_1 , verbose=0) for jb, j in enumerate(batch): if y_pr[jb][0]>=0.5: post_list.append(j) @@ -3808,7 +3673,15 @@ class Eynollah: pass def return_ocr_of_textline_without_common_section( - self, textline_image, model_ocr, processor, device, width_textline, h2w_ratio,ind_tot): + self, + textline_image, + model_ocr, + processor, + device, + width_textline, + h2w_ratio, + ind_tot, + ): if h2w_ratio > 0.05: pixel_values = processor(textline_image, return_tensors="pt").pixel_values @@ -4261,7 +4134,7 @@ class Eynollah: gc.collect() ocr_all_textlines = return_rnn_cnn_ocr_of_given_textlines( image_page, all_found_textline_polygons, np.zeros((len(all_found_textline_polygons), 4)), - self.models["ocr"], self.b_s_ocr, self.num_to_char, textline_light=True) + self.model_zoo.get("ocr"), self.b_s_ocr, self.model_zoo.get("num_to_char"), textline_light=True) else: ocr_all_textlines = None @@ -4770,27 +4643,27 @@ class Eynollah: if len(all_found_textline_polygons): ocr_all_textlines = return_rnn_cnn_ocr_of_given_textlines( image_page, all_found_textline_polygons, all_box_coord, - self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line) + self.model_zoo.get("ocr"), self.b_s_ocr, self.model_zoo.get("num_to_char"), self.textline_light, self.curved_line) if len(all_found_textline_polygons_marginals_left): ocr_all_textlines_marginals_left = return_rnn_cnn_ocr_of_given_textlines( image_page, all_found_textline_polygons_marginals_left, all_box_coord_marginals_left, - self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line) + self.model_zoo.get("ocr"), self.b_s_ocr, self.model_zoo.get("num_to_char"), self.textline_light, self.curved_line) if len(all_found_textline_polygons_marginals_right): ocr_all_textlines_marginals_right = return_rnn_cnn_ocr_of_given_textlines( image_page, all_found_textline_polygons_marginals_right, all_box_coord_marginals_right, - self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line) + self.model_zoo.get("ocr"), self.b_s_ocr, self.model_zoo.get("num_to_char"), self.textline_light, self.curved_line) if self.full_layout and len(all_found_textline_polygons): ocr_all_textlines_h = return_rnn_cnn_ocr_of_given_textlines( image_page, all_found_textline_polygons_h, all_box_coord_h, - self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line) + self.model_zoo.get("ocr"), self.b_s_ocr, self.model_zoo.get("num_to_char"), self.textline_light, self.curved_line) if self.full_layout and len(polygons_of_drop_capitals): ocr_all_textlines_drop = 
return_rnn_cnn_ocr_of_given_textlines( image_page, polygons_of_drop_capitals, np.zeros((len(polygons_of_drop_capitals), 4)), - self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line) + self.model_zoo.get("ocr"), self.b_s_ocr, self.model_zoo.get("num_to_char"), self.textline_light, self.curved_line) else: if self.light_version: @@ -4802,7 +4675,7 @@ class Eynollah: gc.collect() torch.cuda.empty_cache() - self.models["ocr"].to(self.device) + self.model_zoo.get("ocr").to(self.device) ind_tot = 0 #cv2.imwrite('./img_out.png', image_page) @@ -4839,7 +4712,7 @@ class Eynollah: img_croped = img_poly_on_img[y:y+h, x:x+w, :] #cv2.imwrite('./extracted_lines/'+str(ind_tot)+'.jpg', img_croped) text_ocr = self.return_ocr_of_textline_without_common_section( - img_croped, self.models["ocr"], self.processor, self.device, w, h2w_ratio, ind_tot) + img_croped, self.model_zoo.get("ocr"), self.model_zoo.get("trocr_processor"), self.device, w, h2w_ratio, ind_tot) ocr_textline_in_textregion.append(text_ocr) ind_tot = ind_tot +1 ocr_all_textlines.append(ocr_textline_in_textregion) @@ -4874,966 +4747,3 @@ class Eynollah: conf_contours_textregions=conf_contours_textregions) return pcgts - - -class Eynollah_ocr: - def __init__( - self, - dir_models, - model_name=None, - dir_xmls=None, - tr_ocr=False, - batch_size=None, - export_textline_images_and_text=False, - do_not_mask_with_textline_contour=False, - pref_of_dataset=None, - min_conf_value_of_textline_text : Optional[float]=None, - logger=None, - ): - self.model_name = model_name - self.tr_ocr = tr_ocr - self.export_textline_images_and_text = export_textline_images_and_text - self.do_not_mask_with_textline_contour = do_not_mask_with_textline_contour - self.pref_of_dataset = pref_of_dataset - self.logger = logger if logger else getLogger('eynollah') - - if not export_textline_images_and_text: - if min_conf_value_of_textline_text: - self.min_conf_value_of_textline_text = float(min_conf_value_of_textline_text) - else: - self.min_conf_value_of_textline_text = 0.3 - if tr_ocr: - self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed") - self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - if self.model_name: - self.model_ocr_dir = self.model_name - else: - self.model_ocr_dir = dir_models + "/model_eynollah_ocr_trocr_20250919" - self.model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir) - self.model_ocr.to(self.device) - if not batch_size: - self.b_s = 2 - else: - self.b_s = int(batch_size) - - else: - if self.model_name: - self.model_ocr_dir = self.model_name - else: - self.model_ocr_dir = dir_models + "/model_eynollah_ocr_cnnrnn_20250930" - model_ocr = load_model(self.model_ocr_dir , compile=False) - - self.prediction_model = tf.keras.models.Model( - model_ocr.get_layer(name = "image").input, - model_ocr.get_layer(name = "dense2").output) - if not batch_size: - self.b_s = 8 - else: - self.b_s = int(batch_size) - - with open(os.path.join(self.model_ocr_dir, "characters_org.txt"),"r") as config_file: - characters = json.load(config_file) - - AUTOTUNE = tf.data.AUTOTUNE - - # Mapping characters to integers. - char_to_num = StringLookup(vocabulary=list(characters), mask_token=None) - - # Mapping integers back to original characters. 
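For reference, the `num_to_char` lookup now served by the model zoo is the inverse of a `StringLookup` built from the OCR model's character list, exactly as in the code removed here; a self-contained sketch:

```python
import json
from keras.layers import StringLookup

# character list shipped alongside the CNN/RNN OCR model
with open("characters_org.txt") as config_file:
    characters = json.load(config_file)

# Mapping characters to integers.
char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)
# Mapping integers back to original characters (used to decode predictions).
num_to_char = StringLookup(vocabulary=char_to_num.get_vocabulary(),
                           mask_token=None, invert=True)
```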
- self.num_to_char = StringLookup( - vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True - ) - self.end_character = len(characters) + 2 - - def run(self, overwrite: bool = False, - dir_in: Optional[str] = None, - dir_in_bin: Optional[str] = None, - image_filename: Optional[str] = None, - dir_xmls: Optional[str] = None, - dir_out_image_text: Optional[str] = None, - dir_out: Optional[str] = None, - ): - if dir_in: - ls_imgs = [os.path.join(dir_in, image_filename) - for image_filename in filter(is_image_filename, - os.listdir(dir_in))] - else: - ls_imgs = [image_filename] - - if self.tr_ocr: - tr_ocr_input_height_and_width = 384 - for dir_img in ls_imgs: - file_name = Path(dir_img).stem - dir_xml = os.path.join(dir_xmls, file_name+'.xml') - out_file_ocr = os.path.join(dir_out, file_name+'.xml') - - if os.path.exists(out_file_ocr): - if overwrite: - self.logger.warning("will overwrite existing output file '%s'", out_file_ocr) - else: - self.logger.warning("will skip input for existing output file '%s'", out_file_ocr) - continue - - img = cv2.imread(dir_img) - - if dir_out_image_text: - out_image_with_text = os.path.join(dir_out_image_text, file_name+'.png') - image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white") - draw = ImageDraw.Draw(image_text) - total_bb_coordinates = [] - - ##file_name = Path(dir_xmls).stem - tree1 = ET.parse(dir_xml, parser = ET.XMLParser(encoding="utf-8")) - root1=tree1.getroot() - alltags=[elem.tag for elem in root1.iter()] - link=alltags[0].split('}')[0]+'}' - - name_space = alltags[0].split('}')[0] - name_space = name_space.split('{')[1] - - region_tags=np.unique([x for x in alltags if x.endswith('TextRegion')]) - - - - cropped_lines = [] - cropped_lines_region_indexer = [] - cropped_lines_meging_indexing = [] - - extracted_texts = [] - - indexer_text_region = 0 - indexer_b_s = 0 - - for nn in root1.iter(region_tags): - for child_textregion in nn: - if child_textregion.tag.endswith("TextLine"): - - for child_textlines in child_textregion: - if child_textlines.tag.endswith("Coords"): - cropped_lines_region_indexer.append(indexer_text_region) - p_h=child_textlines.attrib['points'].split(' ') - textline_coords = np.array( [ [int(x.split(',')[0]), - int(x.split(',')[1]) ] - for x in p_h] ) - x,y,w,h = cv2.boundingRect(textline_coords) - - if dir_out_image_text: - total_bb_coordinates.append([x,y,w,h]) - - h2w_ratio = h/float(w) - - img_poly_on_img = np.copy(img) - mask_poly = np.zeros(img.shape) - mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1)) - - mask_poly = mask_poly[y:y+h, x:x+w, :] - img_crop = img_poly_on_img[y:y+h, x:x+w, :] - img_crop[mask_poly==0] = 255 - - self.logger.debug("processing %d lines for '%s'", - len(cropped_lines), nn.attrib['id']) - if h2w_ratio > 0.1: - cropped_lines.append(resize_image(img_crop, - tr_ocr_input_height_and_width, - tr_ocr_input_height_and_width) ) - cropped_lines_meging_indexing.append(0) - indexer_b_s+=1 - if indexer_b_s==self.b_s: - imgs = cropped_lines[:] - cropped_lines = [] - indexer_b_s = 0 - - pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values - generated_ids_merged = self.model_ocr.generate( - pixel_values_merged.to(self.device)) - generated_text_merged = self.processor.batch_decode( - generated_ids_merged, skip_special_tokens=True) - - extracted_texts = extracted_texts + generated_text_merged - - else: - splited_images, _ = return_textlines_split_if_needed(img_crop, None) - #print(splited_images) - if splited_images: - 
cropped_lines.append(resize_image(splited_images[0], - tr_ocr_input_height_and_width, - tr_ocr_input_height_and_width)) - cropped_lines_meging_indexing.append(1) - indexer_b_s+=1 - - if indexer_b_s==self.b_s: - imgs = cropped_lines[:] - cropped_lines = [] - indexer_b_s = 0 - - pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values - generated_ids_merged = self.model_ocr.generate( - pixel_values_merged.to(self.device)) - generated_text_merged = self.processor.batch_decode( - generated_ids_merged, skip_special_tokens=True) - - extracted_texts = extracted_texts + generated_text_merged - - - cropped_lines.append(resize_image(splited_images[1], - tr_ocr_input_height_and_width, - tr_ocr_input_height_and_width)) - cropped_lines_meging_indexing.append(-1) - indexer_b_s+=1 - - if indexer_b_s==self.b_s: - imgs = cropped_lines[:] - cropped_lines = [] - indexer_b_s = 0 - - pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values - generated_ids_merged = self.model_ocr.generate( - pixel_values_merged.to(self.device)) - generated_text_merged = self.processor.batch_decode( - generated_ids_merged, skip_special_tokens=True) - - extracted_texts = extracted_texts + generated_text_merged - - else: - cropped_lines.append(img_crop) - cropped_lines_meging_indexing.append(0) - indexer_b_s+=1 - - if indexer_b_s==self.b_s: - imgs = cropped_lines[:] - cropped_lines = [] - indexer_b_s = 0 - - pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values - generated_ids_merged = self.model_ocr.generate( - pixel_values_merged.to(self.device)) - generated_text_merged = self.processor.batch_decode( - generated_ids_merged, skip_special_tokens=True) - - extracted_texts = extracted_texts + generated_text_merged - - - - indexer_text_region = indexer_text_region +1 - - if indexer_b_s!=0: - imgs = cropped_lines[:] - cropped_lines = [] - indexer_b_s = 0 - - pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values - generated_ids_merged = self.model_ocr.generate(pixel_values_merged.to(self.device)) - generated_text_merged = self.processor.batch_decode(generated_ids_merged, skip_special_tokens=True) - - extracted_texts = extracted_texts + generated_text_merged - - ####extracted_texts = [] - ####n_iterations = math.ceil(len(cropped_lines) / self.b_s) - - ####for i in range(n_iterations): - ####if i==(n_iterations-1): - ####n_start = i*self.b_s - ####imgs = cropped_lines[n_start:] - ####else: - ####n_start = i*self.b_s - ####n_end = (i+1)*self.b_s - ####imgs = cropped_lines[n_start:n_end] - ####pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values - ####generated_ids_merged = self.model_ocr.generate( - #### pixel_values_merged.to(self.device)) - ####generated_text_merged = self.processor.batch_decode( - #### generated_ids_merged, skip_special_tokens=True) - - ####extracted_texts = extracted_texts + generated_text_merged - - del cropped_lines - gc.collect() - - extracted_texts_merged = [extracted_texts[ind] - if cropped_lines_meging_indexing[ind]==0 - else extracted_texts[ind]+" "+extracted_texts[ind+1] - if cropped_lines_meging_indexing[ind]==1 - else None - for ind in range(len(cropped_lines_meging_indexing))] - - extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None] - #print(extracted_texts_merged, len(extracted_texts_merged)) - - unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer) - - if dir_out_image_text: - - #font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this 
file exists! - font = importlib_resources.files(__package__) / "Charis-Regular.ttf" - with importlib_resources.as_file(font) as font: - font = ImageFont.truetype(font=font, size=40) - - for indexer_text, bb_ind in enumerate(total_bb_coordinates): - - - x_bb = bb_ind[0] - y_bb = bb_ind[1] - w_bb = bb_ind[2] - h_bb = bb_ind[3] - - font = fit_text_single_line(draw, extracted_texts_merged[indexer_text], - font.path, w_bb, int(h_bb*0.4) ) - - ##draw.rectangle([x_bb, y_bb, x_bb + w_bb, y_bb + h_bb], outline="red", width=2) - - text_bbox = draw.textbbox((0, 0), extracted_texts_merged[indexer_text], font=font) - text_width = text_bbox[2] - text_bbox[0] - text_height = text_bbox[3] - text_bbox[1] - - text_x = x_bb + (w_bb - text_width) // 2 # Center horizontally - text_y = y_bb + (h_bb - text_height) // 2 # Center vertically - - # Draw the text - draw.text((text_x, text_y), extracted_texts_merged[indexer_text], fill="black", font=font) - image_text.save(out_image_with_text) - - #print(len(unique_cropped_lines_region_indexer), 'unique_cropped_lines_region_indexer') - #######text_by_textregion = [] - #######for ind in unique_cropped_lines_region_indexer: - #######ind = np.array(cropped_lines_region_indexer)==ind - #######extracted_texts_merged_un = np.array(extracted_texts_merged)[ind] - #######text_by_textregion.append(" ".join(extracted_texts_merged_un)) - - text_by_textregion = [] - for ind in unique_cropped_lines_region_indexer: - ind = np.array(cropped_lines_region_indexer) == ind - extracted_texts_merged_un = np.array(extracted_texts_merged)[ind] - if len(extracted_texts_merged_un)>1: - text_by_textregion_ind = "" - next_glue = "" - for indt in range(len(extracted_texts_merged_un)): - if (extracted_texts_merged_un[indt].endswith('⸗') or - extracted_texts_merged_un[indt].endswith('-') or - extracted_texts_merged_un[indt].endswith('¬')): - text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt][:-1] - next_glue = "" - else: - text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt] - next_glue = " " - text_by_textregion.append(text_by_textregion_ind) - else: - text_by_textregion.append(" ".join(extracted_texts_merged_un)) - - - indexer = 0 - indexer_textregion = 0 - for nn in root1.iter(region_tags): - #id_textregion = nn.attrib['id'] - #id_textregions.append(id_textregion) - #textregions_by_existing_ids.append(text_by_textregion[indexer_textregion]) - - is_textregion_text = False - for childtest in nn: - if childtest.tag.endswith("TextEquiv"): - is_textregion_text = True - - if not is_textregion_text: - text_subelement_textregion = ET.SubElement(nn, 'TextEquiv') - unicode_textregion = ET.SubElement(text_subelement_textregion, 'Unicode') - - - has_textline = False - for child_textregion in nn: - if child_textregion.tag.endswith("TextLine"): - - is_textline_text = False - for childtest2 in child_textregion: - if childtest2.tag.endswith("TextEquiv"): - is_textline_text = True - - - if not is_textline_text: - text_subelement = ET.SubElement(child_textregion, 'TextEquiv') - ##text_subelement.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}") - unicode_textline = ET.SubElement(text_subelement, 'Unicode') - unicode_textline.text = extracted_texts_merged[indexer] - else: - for childtest3 in child_textregion: - if childtest3.tag.endswith("TextEquiv"): - for child_uc in childtest3: - if child_uc.tag.endswith("Unicode"): - ##childtest3.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}") - child_uc.text = extracted_texts_merged[indexer] - - indexer = indexer + 1 - 
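The region-level text assembled above joins line texts while resolving end-of-line hyphenation marks (⸗, -, ¬): a trailing mark is dropped and the next line is glued on without a space. As a standalone sketch of that rule (function name hypothetical):

```python
def join_dehyphenated(line_texts):
    """Join OCR line texts into region text, removing trailing
    hyphenation marks and suppressing the space after them."""
    text, glue = "", ""
    for line in line_texts:
        if line.endswith(("⸗", "-", "¬")):
            text += glue + line[:-1]   # drop the hyphenation mark ...
            glue = ""                  # ... and join the next line directly
        else:
            text += glue + line
            glue = " "
    return text
```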
has_textline = True - if has_textline: - if is_textregion_text: - for child4 in nn: - if child4.tag.endswith("TextEquiv"): - for childtr_uc in child4: - if childtr_uc.tag.endswith("Unicode"): - childtr_uc.text = text_by_textregion[indexer_textregion] - else: - unicode_textregion.text = text_by_textregion[indexer_textregion] - indexer_textregion = indexer_textregion + 1 - - ###sample_order = [(id_to_order[tid], text) - ### for tid, text in zip(id_textregions, textregions_by_existing_ids) - ### if tid in id_to_order] - - ##ordered_texts_sample = [text for _, text in sorted(sample_order)] - ##tot_page_text = ' '.join(ordered_texts_sample) - - ##for page_element in root1.iter(link+'Page'): - ##text_page = ET.SubElement(page_element, 'TextEquiv') - ##unicode_textpage = ET.SubElement(text_page, 'Unicode') - ##unicode_textpage.text = tot_page_text - - ET.register_namespace("",name_space) - tree1.write(out_file_ocr,xml_declaration=True,method='xml',encoding="utf-8",default_namespace=None) - else: - ###max_len = 280#512#280#512 - ###padding_token = 1500#299#1500#299 - image_width = 512#max_len * 4 - image_height = 32 - - - img_size=(image_width, image_height) - - for dir_img in ls_imgs: - file_name = Path(dir_img).stem - dir_xml = os.path.join(dir_xmls, file_name+'.xml') - out_file_ocr = os.path.join(dir_out, file_name+'.xml') - - if os.path.exists(out_file_ocr): - if overwrite: - self.logger.warning("will overwrite existing output file '%s'", out_file_ocr) - else: - self.logger.warning("will skip input for existing output file '%s'", out_file_ocr) - continue - - img = cv2.imread(dir_img) - if dir_in_bin is not None: - cropped_lines_bin = [] - dir_img_bin = os.path.join(dir_in_bin, file_name+'.png') - img_bin = cv2.imread(dir_img_bin) - - if dir_out_image_text: - out_image_with_text = os.path.join(dir_out_image_text, file_name+'.png') - image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white") - draw = ImageDraw.Draw(image_text) - total_bb_coordinates = [] - - tree1 = ET.parse(dir_xml, parser = ET.XMLParser(encoding="utf-8")) - root1=tree1.getroot() - alltags=[elem.tag for elem in root1.iter()] - link=alltags[0].split('}')[0]+'}' - - name_space = alltags[0].split('}')[0] - name_space = name_space.split('{')[1] - - region_tags=np.unique([x for x in alltags if x.endswith('TextRegion')]) - - cropped_lines = [] - cropped_lines_ver_index = [] - cropped_lines_region_indexer = [] - cropped_lines_meging_indexing = [] - - tinl = time.time() - indexer_text_region = 0 - indexer_textlines = 0 - for nn in root1.iter(region_tags): - try: - type_textregion = nn.attrib['type'] - except: - type_textregion = 'paragraph' - for child_textregion in nn: - if child_textregion.tag.endswith("TextLine"): - for child_textlines in child_textregion: - if child_textlines.tag.endswith("Coords"): - cropped_lines_region_indexer.append(indexer_text_region) - p_h=child_textlines.attrib['points'].split(' ') - textline_coords = np.array( [ [int(x.split(',')[0]), - int(x.split(',')[1]) ] - for x in p_h] ) - - x,y,w,h = cv2.boundingRect(textline_coords) - - angle_radians = math.atan2(h, w) - # Convert to degrees - angle_degrees = math.degrees(angle_radians) - if type_textregion=='drop-capital': - angle_degrees = 0 - - if dir_out_image_text: - total_bb_coordinates.append([x,y,w,h]) - - w_scaled = w * image_height/float(h) - - img_poly_on_img = np.copy(img) - if dir_in_bin is not None: - img_poly_on_img_bin = np.copy(img_bin) - img_crop_bin = img_poly_on_img_bin[y:y+h, x:x+w, :] - - mask_poly = np.zeros(img.shape) - mask_poly = 
cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1)) - - - mask_poly = mask_poly[y:y+h, x:x+w, :] - img_crop = img_poly_on_img[y:y+h, x:x+w, :] - - if self.export_textline_images_and_text: - if not self.do_not_mask_with_textline_contour: - img_crop[mask_poly==0] = 255 - - else: - # print(file_name, angle_degrees, w*h, - # mask_poly[:,:,0].sum(), - # mask_poly[:,:,0].sum() /float(w*h) , - # 'didi') - - if angle_degrees > 3: - better_des_slope = get_orientation_moments(textline_coords) - - img_crop = rotate_image_with_padding(img_crop, better_des_slope) - if dir_in_bin is not None: - img_crop_bin = rotate_image_with_padding(img_crop_bin, better_des_slope) - - mask_poly = rotate_image_with_padding(mask_poly, better_des_slope) - mask_poly = mask_poly.astype('uint8') - - #new bounding box - x_n, y_n, w_n, h_n = get_contours_and_bounding_boxes(mask_poly[:,:,0]) - - mask_poly = mask_poly[y_n:y_n+h_n, x_n:x_n+w_n, :] - img_crop = img_crop[y_n:y_n+h_n, x_n:x_n+w_n, :] - - if not self.do_not_mask_with_textline_contour: - img_crop[mask_poly==0] = 255 - if dir_in_bin is not None: - img_crop_bin = img_crop_bin[y_n:y_n+h_n, x_n:x_n+w_n, :] - if not self.do_not_mask_with_textline_contour: - img_crop_bin[mask_poly==0] = 255 - - if mask_poly[:,:,0].sum() /float(w_n*h_n) < 0.50 and w_scaled > 90: - if dir_in_bin is not None: - img_crop, img_crop_bin = \ - break_curved_line_into_small_pieces_and_then_merge( - img_crop, mask_poly, img_crop_bin) - else: - img_crop, _ = \ - break_curved_line_into_small_pieces_and_then_merge( - img_crop, mask_poly) - - else: - better_des_slope = 0 - if not self.do_not_mask_with_textline_contour: - img_crop[mask_poly==0] = 255 - if dir_in_bin is not None: - if not self.do_not_mask_with_textline_contour: - img_crop_bin[mask_poly==0] = 255 - if type_textregion=='drop-capital': - pass - else: - if mask_poly[:,:,0].sum() /float(w*h) < 0.50 and w_scaled > 90: - if dir_in_bin is not None: - img_crop, img_crop_bin = \ - break_curved_line_into_small_pieces_and_then_merge( - img_crop, mask_poly, img_crop_bin) - else: - img_crop, _ = \ - break_curved_line_into_small_pieces_and_then_merge( - img_crop, mask_poly) - - if not self.export_textline_images_and_text: - if w_scaled < 750:#1.5*image_width: - img_fin = preprocess_and_resize_image_for_ocrcnn_model( - img_crop, image_height, image_width) - cropped_lines.append(img_fin) - if abs(better_des_slope) > 45: - cropped_lines_ver_index.append(1) - else: - cropped_lines_ver_index.append(0) - - cropped_lines_meging_indexing.append(0) - if dir_in_bin is not None: - img_fin = preprocess_and_resize_image_for_ocrcnn_model( - img_crop_bin, image_height, image_width) - cropped_lines_bin.append(img_fin) - else: - splited_images, splited_images_bin = return_textlines_split_if_needed( - img_crop, img_crop_bin if dir_in_bin is not None else None) - if splited_images: - img_fin = preprocess_and_resize_image_for_ocrcnn_model( - splited_images[0], image_height, image_width) - cropped_lines.append(img_fin) - cropped_lines_meging_indexing.append(1) - - if abs(better_des_slope) > 45: - cropped_lines_ver_index.append(1) - else: - cropped_lines_ver_index.append(0) - - img_fin = preprocess_and_resize_image_for_ocrcnn_model( - splited_images[1], image_height, image_width) - - cropped_lines.append(img_fin) - cropped_lines_meging_indexing.append(-1) - - if abs(better_des_slope) > 45: - cropped_lines_ver_index.append(1) - else: - cropped_lines_ver_index.append(0) - - if dir_in_bin is not None: - img_fin = preprocess_and_resize_image_for_ocrcnn_model( - 
splited_images_bin[0], image_height, image_width) - cropped_lines_bin.append(img_fin) - img_fin = preprocess_and_resize_image_for_ocrcnn_model( - splited_images_bin[1], image_height, image_width) - cropped_lines_bin.append(img_fin) - - else: - img_fin = preprocess_and_resize_image_for_ocrcnn_model( - img_crop, image_height, image_width) - cropped_lines.append(img_fin) - cropped_lines_meging_indexing.append(0) - - if abs(better_des_slope) > 45: - cropped_lines_ver_index.append(1) - else: - cropped_lines_ver_index.append(0) - - if dir_in_bin is not None: - img_fin = preprocess_and_resize_image_for_ocrcnn_model( - img_crop_bin, image_height, image_width) - cropped_lines_bin.append(img_fin) - - if self.export_textline_images_and_text: - if img_crop.shape[0]==0 or img_crop.shape[1]==0: - pass - else: - if child_textlines.tag.endswith("TextEquiv"): - for cheild_text in child_textlines: - if cheild_text.tag.endswith("Unicode"): - textline_text = cheild_text.text - if textline_text: - base_name = os.path.join( - dir_out, file_name + '_line_' + str(indexer_textlines)) - if self.pref_of_dataset: - base_name += '_' + self.pref_of_dataset - if not self.do_not_mask_with_textline_contour: - base_name += '_masked' - - with open(base_name + '.txt', 'w') as text_file: - text_file.write(textline_text) - cv2.imwrite(base_name + '.png', img_crop) - indexer_textlines+=1 - - if not self.export_textline_images_and_text: - indexer_text_region = indexer_text_region +1 - - if not self.export_textline_images_and_text: - extracted_texts = [] - extracted_conf_value = [] - - n_iterations = math.ceil(len(cropped_lines) / self.b_s) - - for i in range(n_iterations): - if i==(n_iterations-1): - n_start = i*self.b_s - imgs = cropped_lines[n_start:] - imgs = np.array(imgs) - imgs = imgs.reshape(imgs.shape[0], image_height, image_width, 3) - - ver_imgs = np.array( cropped_lines_ver_index[n_start:] ) - indices_ver = np.where(ver_imgs == 1)[0] - - #print(indices_ver, 'indices_ver') - if len(indices_ver)>0: - imgs_ver_flipped = imgs[indices_ver, : ,: ,:] - imgs_ver_flipped = imgs_ver_flipped[:,::-1,::-1,:] - #print(imgs_ver_flipped, 'imgs_ver_flipped') - - else: - imgs_ver_flipped = None - - if dir_in_bin is not None: - imgs_bin = cropped_lines_bin[n_start:] - imgs_bin = np.array(imgs_bin) - imgs_bin = imgs_bin.reshape(imgs_bin.shape[0], image_height, image_width, 3) - - if len(indices_ver)>0: - imgs_bin_ver_flipped = imgs_bin[indices_ver, : ,: ,:] - imgs_bin_ver_flipped = imgs_bin_ver_flipped[:,::-1,::-1,:] - #print(imgs_ver_flipped, 'imgs_ver_flipped') - - else: - imgs_bin_ver_flipped = None - else: - n_start = i*self.b_s - n_end = (i+1)*self.b_s - imgs = cropped_lines[n_start:n_end] - imgs = np.array(imgs).reshape(self.b_s, image_height, image_width, 3) - - ver_imgs = np.array( cropped_lines_ver_index[n_start:n_end] ) - indices_ver = np.where(ver_imgs == 1)[0] - #print(indices_ver, 'indices_ver') - - if len(indices_ver)>0: - imgs_ver_flipped = imgs[indices_ver, : ,: ,:] - imgs_ver_flipped = imgs_ver_flipped[:,::-1,::-1,:] - #print(imgs_ver_flipped, 'imgs_ver_flipped') - else: - imgs_ver_flipped = None - - - if dir_in_bin is not None: - imgs_bin = cropped_lines_bin[n_start:n_end] - imgs_bin = np.array(imgs_bin).reshape(self.b_s, image_height, image_width, 3) - - - if len(indices_ver)>0: - imgs_bin_ver_flipped = imgs_bin[indices_ver, : ,: ,:] - imgs_bin_ver_flipped = imgs_bin_ver_flipped[:,::-1,::-1,:] - #print(imgs_ver_flipped, 'imgs_ver_flipped') - else: - imgs_bin_ver_flipped = None - - - self.logger.debug("processing next 
%d lines", len(imgs)) - preds = self.prediction_model.predict(imgs, verbose=0) - - if len(indices_ver)>0: - preds_flipped = self.prediction_model.predict(imgs_ver_flipped, verbose=0) - preds_max_fliped = np.max(preds_flipped, axis=2 ) - preds_max_args_flipped = np.argmax(preds_flipped, axis=2 ) - pred_max_not_unk_mask_bool_flipped = preds_max_args_flipped[:,:]!=self.end_character - masked_means_flipped = \ - np.sum(preds_max_fliped * pred_max_not_unk_mask_bool_flipped, axis=1) / \ - np.sum(pred_max_not_unk_mask_bool_flipped, axis=1) - masked_means_flipped[np.isnan(masked_means_flipped)] = 0 - - preds_max = np.max(preds, axis=2 ) - preds_max_args = np.argmax(preds, axis=2 ) - pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character - - masked_means = \ - np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \ - np.sum(pred_max_not_unk_mask_bool, axis=1) - masked_means[np.isnan(masked_means)] = 0 - - masked_means_ver = masked_means[indices_ver] - #print(masked_means_ver, 'pred_max_not_unk') - - indices_where_flipped_conf_value_is_higher = \ - np.where(masked_means_flipped > masked_means_ver)[0] - - #print(indices_where_flipped_conf_value_is_higher, 'indices_where_flipped_conf_value_is_higher') - if len(indices_where_flipped_conf_value_is_higher)>0: - indices_to_be_replaced = indices_ver[indices_where_flipped_conf_value_is_higher] - preds[indices_to_be_replaced,:,:] = \ - preds_flipped[indices_where_flipped_conf_value_is_higher, :, :] - if dir_in_bin is not None: - preds_bin = self.prediction_model.predict(imgs_bin, verbose=0) - - if len(indices_ver)>0: - preds_flipped = self.prediction_model.predict(imgs_bin_ver_flipped, verbose=0) - preds_max_fliped = np.max(preds_flipped, axis=2 ) - preds_max_args_flipped = np.argmax(preds_flipped, axis=2 ) - pred_max_not_unk_mask_bool_flipped = preds_max_args_flipped[:,:]!=self.end_character - masked_means_flipped = \ - np.sum(preds_max_fliped * pred_max_not_unk_mask_bool_flipped, axis=1) / \ - np.sum(pred_max_not_unk_mask_bool_flipped, axis=1) - masked_means_flipped[np.isnan(masked_means_flipped)] = 0 - - preds_max = np.max(preds, axis=2 ) - preds_max_args = np.argmax(preds, axis=2 ) - pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character - - masked_means = \ - np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \ - np.sum(pred_max_not_unk_mask_bool, axis=1) - masked_means[np.isnan(masked_means)] = 0 - - masked_means_ver = masked_means[indices_ver] - #print(masked_means_ver, 'pred_max_not_unk') - - indices_where_flipped_conf_value_is_higher = \ - np.where(masked_means_flipped > masked_means_ver)[0] - - #print(indices_where_flipped_conf_value_is_higher, 'indices_where_flipped_conf_value_is_higher') - if len(indices_where_flipped_conf_value_is_higher)>0: - indices_to_be_replaced = indices_ver[indices_where_flipped_conf_value_is_higher] - preds_bin[indices_to_be_replaced,:,:] = \ - preds_flipped[indices_where_flipped_conf_value_is_higher, :, :] - - preds = (preds + preds_bin) / 2. 
- - pred_texts = decode_batch_predictions(preds, self.num_to_char) - - preds_max = np.max(preds, axis=2 ) - preds_max_args = np.argmax(preds, axis=2 ) - pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character - masked_means = \ - np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \ - np.sum(pred_max_not_unk_mask_bool, axis=1) - - for ib in range(imgs.shape[0]): - pred_texts_ib = pred_texts[ib].replace("[UNK]", "") - if masked_means[ib] >= self.min_conf_value_of_textline_text: - extracted_texts.append(pred_texts_ib) - extracted_conf_value.append(masked_means[ib]) - else: - extracted_texts.append("") - extracted_conf_value.append(0) - del cropped_lines - if dir_in_bin is not None: - del cropped_lines_bin - gc.collect() - - extracted_texts_merged = [extracted_texts[ind] - if cropped_lines_meging_indexing[ind]==0 - else extracted_texts[ind]+" "+extracted_texts[ind+1] - if cropped_lines_meging_indexing[ind]==1 - else None - for ind in range(len(cropped_lines_meging_indexing))] - - extracted_conf_value_merged = [extracted_conf_value[ind] - if cropped_lines_meging_indexing[ind]==0 - else (extracted_conf_value[ind]+extracted_conf_value[ind+1])/2. - if cropped_lines_meging_indexing[ind]==1 - else None - for ind in range(len(cropped_lines_meging_indexing))] - - extracted_conf_value_merged = [extracted_conf_value_merged[ind_cfm] - for ind_cfm in range(len(extracted_texts_merged)) - if extracted_texts_merged[ind_cfm] is not None] - extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None] - unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer) - - if dir_out_image_text: - #font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists! - font = importlib_resources.files(__package__) / "Charis-Regular.ttf" - with importlib_resources.as_file(font) as font: - font = ImageFont.truetype(font=font, size=40) - - for indexer_text, bb_ind in enumerate(total_bb_coordinates): - x_bb = bb_ind[0] - y_bb = bb_ind[1] - w_bb = bb_ind[2] - h_bb = bb_ind[3] - - font = fit_text_single_line(draw, extracted_texts_merged[indexer_text], - font.path, w_bb, int(h_bb*0.4) ) - - ##draw.rectangle([x_bb, y_bb, x_bb + w_bb, y_bb + h_bb], outline="red", width=2) - - text_bbox = draw.textbbox((0, 0), extracted_texts_merged[indexer_text], font=font) - text_width = text_bbox[2] - text_bbox[0] - text_height = text_bbox[3] - text_bbox[1] - - text_x = x_bb + (w_bb - text_width) // 2 # Center horizontally - text_y = y_bb + (h_bb - text_height) // 2 # Center vertically - - # Draw the text - draw.text((text_x, text_y), extracted_texts_merged[indexer_text], fill="black", font=font) - image_text.save(out_image_with_text) - - text_by_textregion = [] - for ind in unique_cropped_lines_region_indexer: - ind = np.array(cropped_lines_region_indexer)==ind - extracted_texts_merged_un = np.array(extracted_texts_merged)[ind] - if len(extracted_texts_merged_un)>1: - text_by_textregion_ind = "" - next_glue = "" - for indt in range(len(extracted_texts_merged_un)): - if (extracted_texts_merged_un[indt].endswith('⸗') or - extracted_texts_merged_un[indt].endswith('-') or - extracted_texts_merged_un[indt].endswith('¬')): - text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt][:-1] - next_glue = "" - else: - text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt] - next_glue = " " - text_by_textregion.append(text_by_textregion_ind) - else: - text_by_textregion.append(" ".join(extracted_texts_merged_un)) - #print(text_by_textregion, 
'text_by_textregiontext_by_textregiontext_by_textregiontext_by_textregiontext_by_textregion') - - ###index_tot_regions = [] - ###tot_region_ref = [] - - ###for jj in root1.iter(link+'RegionRefIndexed'): - ###index_tot_regions.append(jj.attrib['index']) - ###tot_region_ref.append(jj.attrib['regionRef']) - - ###id_to_order = {tid: ro for tid, ro in zip(tot_region_ref, index_tot_regions)} - - #id_textregions = [] - #textregions_by_existing_ids = [] - indexer = 0 - indexer_textregion = 0 - for nn in root1.iter(region_tags): - #id_textregion = nn.attrib['id'] - #id_textregions.append(id_textregion) - #textregions_by_existing_ids.append(text_by_textregion[indexer_textregion]) - - is_textregion_text = False - for childtest in nn: - if childtest.tag.endswith("TextEquiv"): - is_textregion_text = True - - if not is_textregion_text: - text_subelement_textregion = ET.SubElement(nn, 'TextEquiv') - unicode_textregion = ET.SubElement(text_subelement_textregion, 'Unicode') - - - has_textline = False - for child_textregion in nn: - if child_textregion.tag.endswith("TextLine"): - - is_textline_text = False - for childtest2 in child_textregion: - if childtest2.tag.endswith("TextEquiv"): - is_textline_text = True - - - if not is_textline_text: - text_subelement = ET.SubElement(child_textregion, 'TextEquiv') - text_subelement.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}") - unicode_textline = ET.SubElement(text_subelement, 'Unicode') - unicode_textline.text = extracted_texts_merged[indexer] - else: - for childtest3 in child_textregion: - if childtest3.tag.endswith("TextEquiv"): - for child_uc in childtest3: - if child_uc.tag.endswith("Unicode"): - childtest3.set('conf', - f"{extracted_conf_value_merged[indexer]:.2f}") - child_uc.text = extracted_texts_merged[indexer] - - indexer = indexer + 1 - has_textline = True - if has_textline: - if is_textregion_text: - for child4 in nn: - if child4.tag.endswith("TextEquiv"): - for childtr_uc in child4: - if childtr_uc.tag.endswith("Unicode"): - childtr_uc.text = text_by_textregion[indexer_textregion] - else: - unicode_textregion.text = text_by_textregion[indexer_textregion] - indexer_textregion = indexer_textregion + 1 - - ###sample_order = [(id_to_order[tid], text) - ### for tid, text in zip(id_textregions, textregions_by_existing_ids) - ### if tid in id_to_order] - - ##ordered_texts_sample = [text for _, text in sorted(sample_order)] - ##tot_page_text = ' '.join(ordered_texts_sample) - - ##for page_element in root1.iter(link+'Page'): - ##text_page = ET.SubElement(page_element, 'TextEquiv') - ##unicode_textpage = ET.SubElement(text_page, 'Unicode') - ##unicode_textpage.text = tot_page_text - - ET.register_namespace("",name_space) - tree1.write(out_file_ocr,xml_declaration=True,method='xml',encoding="utf-8",default_namespace=None) - #print("Job done in %.1fs", time.time() - t0) diff --git a/src/eynollah/eynollah_imports.py b/src/eynollah/eynollah_imports.py new file mode 100644 index 0000000..a57f87d --- /dev/null +++ b/src/eynollah/eynollah_imports.py @@ -0,0 +1,8 @@ +""" +Load libraries with possible race conditions once. This must be imported as the first module of eynollah. +""" +from torch import * +import tensorflow.keras +from shapely import * +imported_libs = True +__all__ = ['imported_libs'] diff --git a/src/eynollah/eynollah_ocr.py b/src/eynollah/eynollah_ocr.py new file mode 100644 index 0000000..0f3eda6 --- /dev/null +++ b/src/eynollah/eynollah_ocr.py @@ -0,0 +1,1001 @@ +# FIXME: fix all of those... 
+# pyright: reportPossiblyUnboundVariable=false
+# pyright: reportOptionalMemberAccess=false
+# pyright: reportArgumentType=false
+# pyright: reportCallIssue=false
+# pyright: reportOptionalSubscript=false
+
+from logging import Logger, getLogger
+from typing import Optional
+from pathlib import Path
+import os
+import json
+import gc
+import sys
+import math
+import time
+
+from keras.layers import StringLookup
+import cv2
+import xml.etree.ElementTree as ET
+import tensorflow as tf
+from keras.models import load_model
+from PIL import Image, ImageDraw, ImageFont
+import numpy as np
+from eynollah.model_zoo import EynollahModelZoo
+try:
+    import torch
+except ImportError:
+    torch = None
+
+
+from .utils import is_image_filename
+from .utils.resize import resize_image
+from .utils.utils_ocr import (
+    break_curved_line_into_small_pieces_and_then_merge,
+    decode_batch_predictions,
+    fit_text_single_line,
+    get_contours_and_bounding_boxes,
+    get_orientation_moments,
+    preprocess_and_resize_image_for_ocrcnn_model,
+    return_textlines_split_if_needed,
+    rotate_image_with_padding,
+)
+
+# cannot use stdlib importlib.resources until we move to 3.9+ for importlib.resources.files
+if sys.version_info < (3, 10):
+    import importlib_resources
+else:
+    import importlib.resources as importlib_resources
+
+try:
+    from transformers import TrOCRProcessor, VisionEncoderDecoderModel
+except ImportError:
+    TrOCRProcessor = VisionEncoderDecoderModel = None
+
+class Eynollah_ocr:
+    def __init__(
+        self,
+        *,
+        model_zoo: EynollahModelZoo,
+        tr_ocr=False,
+        batch_size: Optional[int] = None,
+        export_textline_images_and_text: bool = False,
+        do_not_mask_with_textline_contour: bool = False,
+        pref_of_dataset=None,
+        min_conf_value_of_textline_text: Optional[float] = None,
+        logger: Optional[Logger] = None,
+    ):
+        self.tr_ocr = tr_ocr
+        # For generating textline-image pairs for training; move to generate_gt_for_training
+        self.export_textline_images_and_text = export_textline_images_and_text
+        # masking for OCR and GT generation, relevant for skewed lines and bounding boxes
+        self.do_not_mask_with_textline_contour = do_not_mask_with_textline_contour
+        # dataset prefix appended to exported textline file names
+        self.pref_of_dataset = pref_of_dataset
+        self.logger = logger if logger else getLogger('eynollah.ocr')
+        self.model_zoo = model_zoo
+
+        # TODO: Properly document what 'export_textline_images_and_text' is about
+        if export_textline_images_and_text:
+            self.logger.info("export_textline_images_and_text was set, so no actual models are loaded")
+            return
+
+        self.min_conf_value_of_textline_text = min_conf_value_of_textline_text if min_conf_value_of_textline_text else 0.3
+        self.b_s = 2 if batch_size is None and tr_ocr else 8 if batch_size is None else batch_size
+
+        if tr_ocr:
+            self.model_zoo.load_model('trocr_processor')
+            self.model_zoo.load_model('ocr', 'tr')
+            self.model_zoo.get('ocr').to(self.device)
+        else:
+            self.model_zoo.load_model('ocr', '')
+            self.model_zoo.load_model('num_to_char')
+            self.model_zoo.load_model('characters')
+            self.end_character = len(self.model_zoo.get('characters', list)) + 2
+
+    @property
+    def device(self):
+        if torch.cuda.is_available():
+            self.logger.info("Using GPU acceleration")
+            return torch.device("cuda:0")
+        else:
+            self.logger.info("Using CPU processing")
+            return torch.device("cpu")
+
+    def run(self, overwrite: bool = False,
+            dir_in: Optional[str] = None,
+            # Prediction with RGB and binarized images for selected pages, should not be the default
+            dir_in_bin: Optional[str] = None,
+            image_filename: Optional[str] = None,
+
dir_xmls: Optional[str] = None, + dir_out_image_text: Optional[str] = None, + dir_out: Optional[str] = None, + ): + if dir_in: + ls_imgs = [os.path.join(dir_in, image_filename) + for image_filename in filter(is_image_filename, + os.listdir(dir_in))] + else: + assert image_filename + ls_imgs = [image_filename] + + if self.tr_ocr: + tr_ocr_input_height_and_width = 384 + for dir_img in ls_imgs: + file_name = Path(dir_img).stem + assert dir_xmls # FIXME: check the logic + dir_xml = os.path.join(dir_xmls, file_name+'.xml') + assert dir_out # FIXME: check the logic + out_file_ocr = os.path.join(dir_out, file_name+'.xml') + + if os.path.exists(out_file_ocr): + if overwrite: + self.logger.warning("will overwrite existing output file '%s'", out_file_ocr) + else: + self.logger.warning("will skip input for existing output file '%s'", out_file_ocr) + continue + + img = cv2.imread(dir_img) + + if dir_out_image_text: + out_image_with_text = os.path.join(dir_out_image_text, file_name+'.png') + image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white") + draw = ImageDraw.Draw(image_text) + total_bb_coordinates = [] + + ##file_name = Path(dir_xmls).stem + tree1 = ET.parse(dir_xml, parser = ET.XMLParser(encoding="utf-8")) + root1=tree1.getroot() + alltags=[elem.tag for elem in root1.iter()] + link=alltags[0].split('}')[0]+'}' + + name_space = alltags[0].split('}')[0] + name_space = name_space.split('{')[1] + + region_tags=np.unique([x for x in alltags if x.endswith('TextRegion')]) + + + + cropped_lines = [] + cropped_lines_region_indexer = [] + cropped_lines_meging_indexing = [] + + extracted_texts = [] + + indexer_text_region = 0 + indexer_b_s = 0 + + for nn in root1.iter(region_tags): + for child_textregion in nn: + if child_textregion.tag.endswith("TextLine"): + + for child_textlines in child_textregion: + if child_textlines.tag.endswith("Coords"): + cropped_lines_region_indexer.append(indexer_text_region) + p_h=child_textlines.attrib['points'].split(' ') + textline_coords = np.array( [ [int(x.split(',')[0]), + int(x.split(',')[1]) ] + for x in p_h] ) + x,y,w,h = cv2.boundingRect(textline_coords) + + if dir_out_image_text: + total_bb_coordinates.append([x,y,w,h]) + + h2w_ratio = h/float(w) + + img_poly_on_img = np.copy(img) + mask_poly = np.zeros(img.shape) + mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1)) + + mask_poly = mask_poly[y:y+h, x:x+w, :] + img_crop = img_poly_on_img[y:y+h, x:x+w, :] + img_crop[mask_poly==0] = 255 + + self.logger.debug("processing %d lines for '%s'", + len(cropped_lines), nn.attrib['id']) + if h2w_ratio > 0.1: + cropped_lines.append(resize_image(img_crop, + tr_ocr_input_height_and_width, + tr_ocr_input_height_and_width) ) + cropped_lines_meging_indexing.append(0) + indexer_b_s+=1 + if indexer_b_s==self.b_s: + imgs = cropped_lines[:] + cropped_lines = [] + indexer_b_s = 0 + + pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values + generated_ids_merged = self.model_zoo.get('ocr').generate( + pixel_values_merged.to(self.device)) + generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( + generated_ids_merged, skip_special_tokens=True) + + extracted_texts = extracted_texts + generated_text_merged + + else: + splited_images, _ = return_textlines_split_if_needed(img_crop, None) + #print(splited_images) + if splited_images: + cropped_lines.append(resize_image(splited_images[0], + tr_ocr_input_height_and_width, + tr_ocr_input_height_and_width)) + 
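The Coords handling above (and again in the CNN-RNN branch below) turns a PAGE-XML `points` attribute into a bounding box. A standalone sketch of that parsing step; the cast to int32 is added because OpenCV expects 32-bit point coordinates:

```python
import numpy as np
import cv2

def coords_to_bbox(points: str):
    """Parse a PAGE-XML Coords@points string such as '10,20 30,20 30,40 10,40'
    into an (x, y, w, h) bounding box, mirroring the loop above."""
    pts = np.array([[int(p.split(',')[0]), int(p.split(',')[1])]
                    for p in points.split(' ')], dtype=np.int32)
    return cv2.boundingRect(pts)

print(coords_to_bbox('10,20 30,20 30,40 10,40'))  # -> (10, 20, 21, 21)
```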
cropped_lines_meging_indexing.append(1) + indexer_b_s+=1 + + if indexer_b_s==self.b_s: + imgs = cropped_lines[:] + cropped_lines = [] + indexer_b_s = 0 + + pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values + generated_ids_merged = self.model_zoo.get('ocr').generate( + pixel_values_merged.to(self.device)) + generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( + generated_ids_merged, skip_special_tokens=True) + + extracted_texts = extracted_texts + generated_text_merged + + + cropped_lines.append(resize_image(splited_images[1], + tr_ocr_input_height_and_width, + tr_ocr_input_height_and_width)) + cropped_lines_meging_indexing.append(-1) + indexer_b_s+=1 + + if indexer_b_s==self.b_s: + imgs = cropped_lines[:] + cropped_lines = [] + indexer_b_s = 0 + + pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values + generated_ids_merged = self.model_zoo.get('ocr').generate( + pixel_values_merged.to(self.device)) + generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( + generated_ids_merged, skip_special_tokens=True) + + extracted_texts = extracted_texts + generated_text_merged + + else: + cropped_lines.append(img_crop) + cropped_lines_meging_indexing.append(0) + indexer_b_s+=1 + + if indexer_b_s==self.b_s: + imgs = cropped_lines[:] + cropped_lines = [] + indexer_b_s = 0 + + pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values + generated_ids_merged = self.model_zoo.get('ocr').generate( + pixel_values_merged.to(self.device)) + generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( + generated_ids_merged, skip_special_tokens=True) + + extracted_texts = extracted_texts + generated_text_merged + + + + indexer_text_region = indexer_text_region +1 + + if indexer_b_s!=0: + imgs = cropped_lines[:] + cropped_lines = [] + indexer_b_s = 0 + + pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values + generated_ids_merged = self.model_zoo.get('ocr').generate(pixel_values_merged.to(self.device)) + generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(generated_ids_merged, skip_special_tokens=True) + + extracted_texts = extracted_texts + generated_text_merged + + ####extracted_texts = [] + ####n_iterations = math.ceil(len(cropped_lines) / self.b_s) + + ####for i in range(n_iterations): + ####if i==(n_iterations-1): + ####n_start = i*self.b_s + ####imgs = cropped_lines[n_start:] + ####else: + ####n_start = i*self.b_s + ####n_end = (i+1)*self.b_s + ####imgs = cropped_lines[n_start:n_end] + ####pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values + ####generated_ids_merged = self.model_ocr.generate( + #### pixel_values_merged.to(self.device)) + ####generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( + #### generated_ids_merged, skip_special_tokens=True) + + ####extracted_texts = extracted_texts + generated_text_merged + + del cropped_lines + gc.collect() + + extracted_texts_merged = [extracted_texts[ind] + if cropped_lines_meging_indexing[ind]==0 + else extracted_texts[ind]+" "+extracted_texts[ind+1] + if cropped_lines_meging_indexing[ind]==1 + else None + for ind in range(len(cropped_lines_meging_indexing))] + + extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None] + #print(extracted_texts_merged, len(extracted_texts_merged)) + + 
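The batch-flush logic above is inlined four times; factored out, it amounts to the following helper (a sketch: `processor`, `model` and `device` stand for the zoo's 'trocr_processor' and 'ocr' entries and the torch device):

```python
def flush_trocr_batch(imgs, processor, model, device):
    """Run one TrOCR inference batch and decode it to strings,
    as in the four inlined copies above."""
    pixel_values = processor(imgs, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values.to(device))
    return processor.batch_decode(generated_ids, skip_special_tokens=True)
```

The merging indices built alongside the batches follow the convention 0 = unsplit line, 1 = first half of a split line (joined with its successor), -1 = second half (dropped after joining), as the list comprehension above applies it:

```python
texts = ["erste", "zwei-", "te", "dritte"]
merge_idx = [0, 1, -1, 0]
merged = [texts[i] if merge_idx[i] == 0
          else texts[i] + " " + texts[i + 1]
          for i in range(len(merge_idx)) if merge_idx[i] != -1]
# -> ['erste', 'zwei- te', 'dritte']
```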
unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer) + + if dir_out_image_text: + + #font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists! + font = importlib_resources.files(__package__) / "Charis-Regular.ttf" + with importlib_resources.as_file(font) as font: + font = ImageFont.truetype(font=font, size=40) + + for indexer_text, bb_ind in enumerate(total_bb_coordinates): + + + x_bb = bb_ind[0] + y_bb = bb_ind[1] + w_bb = bb_ind[2] + h_bb = bb_ind[3] + + font = fit_text_single_line(draw, extracted_texts_merged[indexer_text], + font.path, w_bb, int(h_bb*0.4) ) + + ##draw.rectangle([x_bb, y_bb, x_bb + w_bb, y_bb + h_bb], outline="red", width=2) + + text_bbox = draw.textbbox((0, 0), extracted_texts_merged[indexer_text], font=font) + text_width = text_bbox[2] - text_bbox[0] + text_height = text_bbox[3] - text_bbox[1] + + text_x = x_bb + (w_bb - text_width) // 2 # Center horizontally + text_y = y_bb + (h_bb - text_height) // 2 # Center vertically + + # Draw the text + draw.text((text_x, text_y), extracted_texts_merged[indexer_text], fill="black", font=font) + image_text.save(out_image_with_text) + + #print(len(unique_cropped_lines_region_indexer), 'unique_cropped_lines_region_indexer') + #######text_by_textregion = [] + #######for ind in unique_cropped_lines_region_indexer: + #######ind = np.array(cropped_lines_region_indexer)==ind + #######extracted_texts_merged_un = np.array(extracted_texts_merged)[ind] + #######text_by_textregion.append(" ".join(extracted_texts_merged_un)) + + text_by_textregion = [] + for ind in unique_cropped_lines_region_indexer: + ind = np.array(cropped_lines_region_indexer) == ind + extracted_texts_merged_un = np.array(extracted_texts_merged)[ind] + if len(extracted_texts_merged_un)>1: + text_by_textregion_ind = "" + next_glue = "" + for indt in range(len(extracted_texts_merged_un)): + if (extracted_texts_merged_un[indt].endswith('⸗') or + extracted_texts_merged_un[indt].endswith('-') or + extracted_texts_merged_un[indt].endswith('¬')): + text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt][:-1] + next_glue = "" + else: + text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt] + next_glue = " " + text_by_textregion.append(text_by_textregion_ind) + else: + text_by_textregion.append(" ".join(extracted_texts_merged_un)) + + + indexer = 0 + indexer_textregion = 0 + for nn in root1.iter(region_tags): + #id_textregion = nn.attrib['id'] + #id_textregions.append(id_textregion) + #textregions_by_existing_ids.append(text_by_textregion[indexer_textregion]) + + is_textregion_text = False + for childtest in nn: + if childtest.tag.endswith("TextEquiv"): + is_textregion_text = True + + if not is_textregion_text: + text_subelement_textregion = ET.SubElement(nn, 'TextEquiv') + unicode_textregion = ET.SubElement(text_subelement_textregion, 'Unicode') + + + has_textline = False + for child_textregion in nn: + if child_textregion.tag.endswith("TextLine"): + + is_textline_text = False + for childtest2 in child_textregion: + if childtest2.tag.endswith("TextEquiv"): + is_textline_text = True + + + if not is_textline_text: + text_subelement = ET.SubElement(child_textregion, 'TextEquiv') + ##text_subelement.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}") + unicode_textline = ET.SubElement(text_subelement, 'Unicode') + unicode_textline.text = extracted_texts_merged[indexer] + else: + for childtest3 in child_textregion: + if childtest3.tag.endswith("TextEquiv"): + for child_uc in childtest3: + if 
child_uc.tag.endswith("Unicode"): + ##childtest3.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}") + child_uc.text = extracted_texts_merged[indexer] + + indexer = indexer + 1 + has_textline = True + if has_textline: + if is_textregion_text: + for child4 in nn: + if child4.tag.endswith("TextEquiv"): + for childtr_uc in child4: + if childtr_uc.tag.endswith("Unicode"): + childtr_uc.text = text_by_textregion[indexer_textregion] + else: + unicode_textregion.text = text_by_textregion[indexer_textregion] + indexer_textregion = indexer_textregion + 1 + + ###sample_order = [(id_to_order[tid], text) + ### for tid, text in zip(id_textregions, textregions_by_existing_ids) + ### if tid in id_to_order] + + ##ordered_texts_sample = [text for _, text in sorted(sample_order)] + ##tot_page_text = ' '.join(ordered_texts_sample) + + ##for page_element in root1.iter(link+'Page'): + ##text_page = ET.SubElement(page_element, 'TextEquiv') + ##unicode_textpage = ET.SubElement(text_page, 'Unicode') + ##unicode_textpage.text = tot_page_text + + ET.register_namespace("",name_space) + tree1.write(out_file_ocr,xml_declaration=True,method='xml',encoding="utf-8",default_namespace=None) + else: + ###max_len = 280#512#280#512 + ###padding_token = 1500#299#1500#299 + image_width = 512#max_len * 4 + image_height = 32 + + + img_size=(image_width, image_height) + + for dir_img in ls_imgs: + file_name = Path(dir_img).stem + dir_xml = os.path.join(dir_xmls, file_name+'.xml') + out_file_ocr = os.path.join(dir_out, file_name+'.xml') + + if os.path.exists(out_file_ocr): + if overwrite: + self.logger.warning("will overwrite existing output file '%s'", out_file_ocr) + else: + self.logger.warning("will skip input for existing output file '%s'", out_file_ocr) + continue + + img = cv2.imread(dir_img) + if dir_in_bin is not None: + cropped_lines_bin = [] + dir_img_bin = os.path.join(dir_in_bin, file_name+'.png') + img_bin = cv2.imread(dir_img_bin) + + if dir_out_image_text: + out_image_with_text = os.path.join(dir_out_image_text, file_name+'.png') + image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white") + draw = ImageDraw.Draw(image_text) + total_bb_coordinates = [] + + tree1 = ET.parse(dir_xml, parser = ET.XMLParser(encoding="utf-8")) + root1=tree1.getroot() + alltags=[elem.tag for elem in root1.iter()] + link=alltags[0].split('}')[0]+'}' + + name_space = alltags[0].split('}')[0] + name_space = name_space.split('{')[1] + + region_tags=np.unique([x for x in alltags if x.endswith('TextRegion')]) + + cropped_lines = [] + cropped_lines_ver_index = [] + cropped_lines_region_indexer = [] + cropped_lines_meging_indexing = [] + + tinl = time.time() + indexer_text_region = 0 + indexer_textlines = 0 + for nn in root1.iter(region_tags): + try: + type_textregion = nn.attrib['type'] + except: + type_textregion = 'paragraph' + for child_textregion in nn: + if child_textregion.tag.endswith("TextLine"): + for child_textlines in child_textregion: + if child_textlines.tag.endswith("Coords"): + cropped_lines_region_indexer.append(indexer_text_region) + p_h=child_textlines.attrib['points'].split(' ') + textline_coords = np.array( [ [int(x.split(',')[0]), + int(x.split(',')[1]) ] + for x in p_h] ) + + x,y,w,h = cv2.boundingRect(textline_coords) + + angle_radians = math.atan2(h, w) + # Convert to degrees + angle_degrees = math.degrees(angle_radians) + if type_textregion=='drop-capital': + angle_degrees = 0 + + if dir_out_image_text: + total_bb_coordinates.append([x,y,w,h]) + + w_scaled = w * image_height/float(h) + + img_poly_on_img = 
np.copy(img) + if dir_in_bin is not None: + img_poly_on_img_bin = np.copy(img_bin) + img_crop_bin = img_poly_on_img_bin[y:y+h, x:x+w, :] + + mask_poly = np.zeros(img.shape) + mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1)) + + + mask_poly = mask_poly[y:y+h, x:x+w, :] + img_crop = img_poly_on_img[y:y+h, x:x+w, :] + + if self.export_textline_images_and_text: + if not self.do_not_mask_with_textline_contour: + img_crop[mask_poly==0] = 255 + + else: + # print(file_name, angle_degrees, w*h, + # mask_poly[:,:,0].sum(), + # mask_poly[:,:,0].sum() /float(w*h) , + # 'didi') + + if angle_degrees > 3: + better_des_slope = get_orientation_moments(textline_coords) + + img_crop = rotate_image_with_padding(img_crop, better_des_slope) + if dir_in_bin is not None: + img_crop_bin = rotate_image_with_padding(img_crop_bin, better_des_slope) + + mask_poly = rotate_image_with_padding(mask_poly, better_des_slope) + mask_poly = mask_poly.astype('uint8') + + #new bounding box + x_n, y_n, w_n, h_n = get_contours_and_bounding_boxes(mask_poly[:,:,0]) + + mask_poly = mask_poly[y_n:y_n+h_n, x_n:x_n+w_n, :] + img_crop = img_crop[y_n:y_n+h_n, x_n:x_n+w_n, :] + + if not self.do_not_mask_with_textline_contour: + img_crop[mask_poly==0] = 255 + if dir_in_bin is not None: + img_crop_bin = img_crop_bin[y_n:y_n+h_n, x_n:x_n+w_n, :] + if not self.do_not_mask_with_textline_contour: + img_crop_bin[mask_poly==0] = 255 + + if mask_poly[:,:,0].sum() /float(w_n*h_n) < 0.50 and w_scaled > 90: + if dir_in_bin is not None: + img_crop, img_crop_bin = \ + break_curved_line_into_small_pieces_and_then_merge( + img_crop, mask_poly, img_crop_bin) + else: + img_crop, _ = \ + break_curved_line_into_small_pieces_and_then_merge( + img_crop, mask_poly) + + else: + better_des_slope = 0 + if not self.do_not_mask_with_textline_contour: + img_crop[mask_poly==0] = 255 + if dir_in_bin is not None: + if not self.do_not_mask_with_textline_contour: + img_crop_bin[mask_poly==0] = 255 + if type_textregion=='drop-capital': + pass + else: + if mask_poly[:,:,0].sum() /float(w*h) < 0.50 and w_scaled > 90: + if dir_in_bin is not None: + img_crop, img_crop_bin = \ + break_curved_line_into_small_pieces_and_then_merge( + img_crop, mask_poly, img_crop_bin) + else: + img_crop, _ = \ + break_curved_line_into_small_pieces_and_then_merge( + img_crop, mask_poly) + + if not self.export_textline_images_and_text: + if w_scaled < 750:#1.5*image_width: + img_fin = preprocess_and_resize_image_for_ocrcnn_model( + img_crop, image_height, image_width) + cropped_lines.append(img_fin) + if abs(better_des_slope) > 45: + cropped_lines_ver_index.append(1) + else: + cropped_lines_ver_index.append(0) + + cropped_lines_meging_indexing.append(0) + if dir_in_bin is not None: + img_fin = preprocess_and_resize_image_for_ocrcnn_model( + img_crop_bin, image_height, image_width) + cropped_lines_bin.append(img_fin) + else: + splited_images, splited_images_bin = return_textlines_split_if_needed( + img_crop, img_crop_bin if dir_in_bin is not None else None) + if splited_images: + img_fin = preprocess_and_resize_image_for_ocrcnn_model( + splited_images[0], image_height, image_width) + cropped_lines.append(img_fin) + cropped_lines_meging_indexing.append(1) + + if abs(better_des_slope) > 45: + cropped_lines_ver_index.append(1) + else: + cropped_lines_ver_index.append(0) + + img_fin = preprocess_and_resize_image_for_ocrcnn_model( + splited_images[1], image_height, image_width) + + cropped_lines.append(img_fin) + cropped_lines_meging_indexing.append(-1) + + if 
abs(better_des_slope) > 45: + cropped_lines_ver_index.append(1) + else: + cropped_lines_ver_index.append(0) + + if dir_in_bin is not None: + img_fin = preprocess_and_resize_image_for_ocrcnn_model( + splited_images_bin[0], image_height, image_width) + cropped_lines_bin.append(img_fin) + img_fin = preprocess_and_resize_image_for_ocrcnn_model( + splited_images_bin[1], image_height, image_width) + cropped_lines_bin.append(img_fin) + + else: + img_fin = preprocess_and_resize_image_for_ocrcnn_model( + img_crop, image_height, image_width) + cropped_lines.append(img_fin) + cropped_lines_meging_indexing.append(0) + + if abs(better_des_slope) > 45: + cropped_lines_ver_index.append(1) + else: + cropped_lines_ver_index.append(0) + + if dir_in_bin is not None: + img_fin = preprocess_and_resize_image_for_ocrcnn_model( + img_crop_bin, image_height, image_width) + cropped_lines_bin.append(img_fin) + + if self.export_textline_images_and_text: + if img_crop.shape[0]==0 or img_crop.shape[1]==0: + pass + else: + if child_textlines.tag.endswith("TextEquiv"): + for cheild_text in child_textlines: + if cheild_text.tag.endswith("Unicode"): + textline_text = cheild_text.text + if textline_text: + base_name = os.path.join( + dir_out, file_name + '_line_' + str(indexer_textlines)) + if self.pref_of_dataset: + base_name += '_' + self.pref_of_dataset + if not self.do_not_mask_with_textline_contour: + base_name += '_masked' + + with open(base_name + '.txt', 'w') as text_file: + text_file.write(textline_text) + cv2.imwrite(base_name + '.png', img_crop) + indexer_textlines+=1 + + if not self.export_textline_images_and_text: + indexer_text_region = indexer_text_region +1 + + if not self.export_textline_images_and_text: + extracted_texts = [] + extracted_conf_value = [] + + n_iterations = math.ceil(len(cropped_lines) / self.b_s) + + for i in range(n_iterations): + if i==(n_iterations-1): + n_start = i*self.b_s + imgs = cropped_lines[n_start:] + imgs = np.array(imgs) + imgs = imgs.reshape(imgs.shape[0], image_height, image_width, 3) + + ver_imgs = np.array( cropped_lines_ver_index[n_start:] ) + indices_ver = np.where(ver_imgs == 1)[0] + + #print(indices_ver, 'indices_ver') + if len(indices_ver)>0: + imgs_ver_flipped = imgs[indices_ver, : ,: ,:] + imgs_ver_flipped = imgs_ver_flipped[:,::-1,::-1,:] + #print(imgs_ver_flipped, 'imgs_ver_flipped') + + else: + imgs_ver_flipped = None + + if dir_in_bin is not None: + imgs_bin = cropped_lines_bin[n_start:] + imgs_bin = np.array(imgs_bin) + imgs_bin = imgs_bin.reshape(imgs_bin.shape[0], image_height, image_width, 3) + + if len(indices_ver)>0: + imgs_bin_ver_flipped = imgs_bin[indices_ver, : ,: ,:] + imgs_bin_ver_flipped = imgs_bin_ver_flipped[:,::-1,::-1,:] + #print(imgs_ver_flipped, 'imgs_ver_flipped') + + else: + imgs_bin_ver_flipped = None + else: + n_start = i*self.b_s + n_end = (i+1)*self.b_s + imgs = cropped_lines[n_start:n_end] + imgs = np.array(imgs).reshape(self.b_s, image_height, image_width, 3) + + ver_imgs = np.array( cropped_lines_ver_index[n_start:n_end] ) + indices_ver = np.where(ver_imgs == 1)[0] + #print(indices_ver, 'indices_ver') + + if len(indices_ver)>0: + imgs_ver_flipped = imgs[indices_ver, : ,: ,:] + imgs_ver_flipped = imgs_ver_flipped[:,::-1,::-1,:] + #print(imgs_ver_flipped, 'imgs_ver_flipped') + else: + imgs_ver_flipped = None + + + if dir_in_bin is not None: + imgs_bin = cropped_lines_bin[n_start:n_end] + imgs_bin = np.array(imgs_bin).reshape(self.b_s, image_height, image_width, 3) + + + if len(indices_ver)>0: + imgs_bin_ver_flipped = 
imgs_bin[indices_ver, : ,: ,:] + imgs_bin_ver_flipped = imgs_bin_ver_flipped[:,::-1,::-1,:] + #print(imgs_ver_flipped, 'imgs_ver_flipped') + else: + imgs_bin_ver_flipped = None + + + self.logger.debug("processing next %d lines", len(imgs)) + preds = self.model_zoo.get('ocr').predict(imgs, verbose=0) + + if len(indices_ver)>0: + preds_flipped = self.model_zoo.get('ocr').predict(imgs_ver_flipped, verbose=0) + preds_max_fliped = np.max(preds_flipped, axis=2 ) + preds_max_args_flipped = np.argmax(preds_flipped, axis=2 ) + pred_max_not_unk_mask_bool_flipped = preds_max_args_flipped[:,:]!=self.end_character + masked_means_flipped = \ + np.sum(preds_max_fliped * pred_max_not_unk_mask_bool_flipped, axis=1) / \ + np.sum(pred_max_not_unk_mask_bool_flipped, axis=1) + masked_means_flipped[np.isnan(masked_means_flipped)] = 0 + + preds_max = np.max(preds, axis=2 ) + preds_max_args = np.argmax(preds, axis=2 ) + pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character + + masked_means = \ + np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \ + np.sum(pred_max_not_unk_mask_bool, axis=1) + masked_means[np.isnan(masked_means)] = 0 + + masked_means_ver = masked_means[indices_ver] + #print(masked_means_ver, 'pred_max_not_unk') + + indices_where_flipped_conf_value_is_higher = \ + np.where(masked_means_flipped > masked_means_ver)[0] + + #print(indices_where_flipped_conf_value_is_higher, 'indices_where_flipped_conf_value_is_higher') + if len(indices_where_flipped_conf_value_is_higher)>0: + indices_to_be_replaced = indices_ver[indices_where_flipped_conf_value_is_higher] + preds[indices_to_be_replaced,:,:] = \ + preds_flipped[indices_where_flipped_conf_value_is_higher, :, :] + if dir_in_bin is not None: + preds_bin = self.model_zoo.get('ocr').predict(imgs_bin, verbose=0) + + if len(indices_ver)>0: + preds_flipped = self.model_zoo.get('ocr').predict(imgs_bin_ver_flipped, verbose=0) + preds_max_fliped = np.max(preds_flipped, axis=2 ) + preds_max_args_flipped = np.argmax(preds_flipped, axis=2 ) + pred_max_not_unk_mask_bool_flipped = preds_max_args_flipped[:,:]!=self.end_character + masked_means_flipped = \ + np.sum(preds_max_fliped * pred_max_not_unk_mask_bool_flipped, axis=1) / \ + np.sum(pred_max_not_unk_mask_bool_flipped, axis=1) + masked_means_flipped[np.isnan(masked_means_flipped)] = 0 + + preds_max = np.max(preds, axis=2 ) + preds_max_args = np.argmax(preds, axis=2 ) + pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character + + masked_means = \ + np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \ + np.sum(pred_max_not_unk_mask_bool, axis=1) + masked_means[np.isnan(masked_means)] = 0 + + masked_means_ver = masked_means[indices_ver] + #print(masked_means_ver, 'pred_max_not_unk') + + indices_where_flipped_conf_value_is_higher = \ + np.where(masked_means_flipped > masked_means_ver)[0] + + #print(indices_where_flipped_conf_value_is_higher, 'indices_where_flipped_conf_value_is_higher') + if len(indices_where_flipped_conf_value_is_higher)>0: + indices_to_be_replaced = indices_ver[indices_where_flipped_conf_value_is_higher] + preds_bin[indices_to_be_replaced,:,:] = \ + preds_flipped[indices_where_flipped_conf_value_is_higher, :, :] + + preds = (preds + preds_bin) / 2. 
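The confidence computation used here reduces the (batch, timesteps, vocab) prediction tensor to one score per line. As a standalone sketch, with `end_character` the padding/end index defined in the constructor:

```python
import numpy as np

def masked_mean_confidence(preds: np.ndarray, end_character: int) -> np.ndarray:
    """Average the per-timestep max probability over all timesteps whose
    argmax is not the padding/end character; all-padding lines (NaN) -> 0."""
    preds_max = np.max(preds, axis=2)
    not_pad = np.argmax(preds, axis=2) != end_character
    conf = np.sum(preds_max * not_pad, axis=1) / np.sum(not_pad, axis=1)
    conf[np.isnan(conf)] = 0
    return conf
```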
+ + pred_texts = decode_batch_predictions(preds, self.model_zoo.get('num_to_char')) + + preds_max = np.max(preds, axis=2 ) + preds_max_args = np.argmax(preds, axis=2 ) + pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character + masked_means = \ + np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \ + np.sum(pred_max_not_unk_mask_bool, axis=1) + + for ib in range(imgs.shape[0]): + pred_texts_ib = pred_texts[ib].replace("[UNK]", "") + if masked_means[ib] >= self.min_conf_value_of_textline_text: + extracted_texts.append(pred_texts_ib) + extracted_conf_value.append(masked_means[ib]) + else: + extracted_texts.append("") + extracted_conf_value.append(0) + del cropped_lines + if dir_in_bin is not None: + del cropped_lines_bin + gc.collect() + + extracted_texts_merged = [extracted_texts[ind] + if cropped_lines_meging_indexing[ind]==0 + else extracted_texts[ind]+" "+extracted_texts[ind+1] + if cropped_lines_meging_indexing[ind]==1 + else None + for ind in range(len(cropped_lines_meging_indexing))] + + extracted_conf_value_merged = [extracted_conf_value[ind] + if cropped_lines_meging_indexing[ind]==0 + else (extracted_conf_value[ind]+extracted_conf_value[ind+1])/2. + if cropped_lines_meging_indexing[ind]==1 + else None + for ind in range(len(cropped_lines_meging_indexing))] + + extracted_conf_value_merged = [extracted_conf_value_merged[ind_cfm] + for ind_cfm in range(len(extracted_texts_merged)) + if extracted_texts_merged[ind_cfm] is not None] + extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None] + unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer) + + if dir_out_image_text: + #font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists! + font = importlib_resources.files(__package__) / "Charis-Regular.ttf" + with importlib_resources.as_file(font) as font: + font = ImageFont.truetype(font=font, size=40) + + for indexer_text, bb_ind in enumerate(total_bb_coordinates): + x_bb = bb_ind[0] + y_bb = bb_ind[1] + w_bb = bb_ind[2] + h_bb = bb_ind[3] + + font = fit_text_single_line(draw, extracted_texts_merged[indexer_text], + font.path, w_bb, int(h_bb*0.4) ) + + ##draw.rectangle([x_bb, y_bb, x_bb + w_bb, y_bb + h_bb], outline="red", width=2) + + text_bbox = draw.textbbox((0, 0), extracted_texts_merged[indexer_text], font=font) + text_width = text_bbox[2] - text_bbox[0] + text_height = text_bbox[3] - text_bbox[1] + + text_x = x_bb + (w_bb - text_width) // 2 # Center horizontally + text_y = y_bb + (h_bb - text_height) // 2 # Center vertically + + # Draw the text + draw.text((text_x, text_y), extracted_texts_merged[indexer_text], fill="black", font=font) + image_text.save(out_image_with_text) + + text_by_textregion = [] + for ind in unique_cropped_lines_region_indexer: + ind = np.array(cropped_lines_region_indexer)==ind + extracted_texts_merged_un = np.array(extracted_texts_merged)[ind] + if len(extracted_texts_merged_un)>1: + text_by_textregion_ind = "" + next_glue = "" + for indt in range(len(extracted_texts_merged_un)): + if (extracted_texts_merged_un[indt].endswith('⸗') or + extracted_texts_merged_un[indt].endswith('-') or + extracted_texts_merged_un[indt].endswith('¬')): + text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt][:-1] + next_glue = "" + else: + text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt] + next_glue = " " + text_by_textregion.append(text_by_textregion_ind) + else: + text_by_textregion.append(" ".join(extracted_texts_merged_un)) + #print(text_by_textregion, 
'text_by_textregiontext_by_textregiontext_by_textregiontext_by_textregiontext_by_textregion') + + ###index_tot_regions = [] + ###tot_region_ref = [] + + ###for jj in root1.iter(link+'RegionRefIndexed'): + ###index_tot_regions.append(jj.attrib['index']) + ###tot_region_ref.append(jj.attrib['regionRef']) + + ###id_to_order = {tid: ro for tid, ro in zip(tot_region_ref, index_tot_regions)} + + #id_textregions = [] + #textregions_by_existing_ids = [] + indexer = 0 + indexer_textregion = 0 + for nn in root1.iter(region_tags): + #id_textregion = nn.attrib['id'] + #id_textregions.append(id_textregion) + #textregions_by_existing_ids.append(text_by_textregion[indexer_textregion]) + + is_textregion_text = False + for childtest in nn: + if childtest.tag.endswith("TextEquiv"): + is_textregion_text = True + + if not is_textregion_text: + text_subelement_textregion = ET.SubElement(nn, 'TextEquiv') + unicode_textregion = ET.SubElement(text_subelement_textregion, 'Unicode') + + + has_textline = False + for child_textregion in nn: + if child_textregion.tag.endswith("TextLine"): + + is_textline_text = False + for childtest2 in child_textregion: + if childtest2.tag.endswith("TextEquiv"): + is_textline_text = True + + + if not is_textline_text: + text_subelement = ET.SubElement(child_textregion, 'TextEquiv') + text_subelement.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}") + unicode_textline = ET.SubElement(text_subelement, 'Unicode') + unicode_textline.text = extracted_texts_merged[indexer] + else: + for childtest3 in child_textregion: + if childtest3.tag.endswith("TextEquiv"): + for child_uc in childtest3: + if child_uc.tag.endswith("Unicode"): + childtest3.set('conf', + f"{extracted_conf_value_merged[indexer]:.2f}") + child_uc.text = extracted_texts_merged[indexer] + + indexer = indexer + 1 + has_textline = True + if has_textline: + if is_textregion_text: + for child4 in nn: + if child4.tag.endswith("TextEquiv"): + for childtr_uc in child4: + if childtr_uc.tag.endswith("Unicode"): + childtr_uc.text = text_by_textregion[indexer_textregion] + else: + unicode_textregion.text = text_by_textregion[indexer_textregion] + indexer_textregion = indexer_textregion + 1 + + ###sample_order = [(id_to_order[tid], text) + ### for tid, text in zip(id_textregions, textregions_by_existing_ids) + ### if tid in id_to_order] + + ##ordered_texts_sample = [text for _, text in sorted(sample_order)] + ##tot_page_text = ' '.join(ordered_texts_sample) + + ##for page_element in root1.iter(link+'Page'): + ##text_page = ET.SubElement(page_element, 'TextEquiv') + ##unicode_textpage = ET.SubElement(text_page, 'Unicode') + ##unicode_textpage.text = tot_page_text + + ET.register_namespace("",name_space) + tree1.write(out_file_ocr,xml_declaration=True,method='xml',encoding="utf-8",default_namespace=None) + #print("Job done in %.1fs", time.time() - t0) diff --git a/src/eynollah/image_enhancer.py b/src/eynollah/image_enhancer.py index 9247efe..575a583 100644 --- a/src/eynollah/image_enhancer.py +++ b/src/eynollah/image_enhancer.py @@ -2,27 +2,32 @@ Image enhancer. The output can be written as same scale of input or in new predicted scale. """ -from logging import Logger +# FIXME: fix all of those... 
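Both OCR branches in eynollah_ocr.py above assemble region text with the same dehyphenation rule: a trailing '⸗', '-' or '¬' glues a line to its successor instead of inserting a space. A self-contained sketch of the `text_by_textregion` loops:

```python
def join_textlines(lines):
    """Join line texts into region text, dropping German hyphenation marks
    ('⸗', '-', '¬') at line ends instead of inserting a space."""
    text, glue = "", ""
    for line in lines:
        if line.endswith(('⸗', '-', '¬')):
            text += glue + line[:-1]
            glue = ""
        else:
            text += glue + line
            glue = " "
    return text

assert join_textlines(["Zei⸗", "tung", "von", "heute"]) == "Zeitung von heute"
```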
+# pyright: reportUnboundVariable=false +# pyright: reportCallIssue=false +# pyright: reportArgumentType=false + +import logging import os import time -from typing import Optional +from typing import Dict, Optional from pathlib import Path import gc import cv2 +from keras.models import Model import numpy as np -from ocrd_utils import getLogger, tf_disable_interactive_logs import tensorflow as tf from skimage.morphology import skeletonize -from tensorflow.keras.models import load_model +from .model_zoo import EynollahModelZoo from .utils.resize import resize_image from .utils.pil_cv2 import pil2cv from .utils import ( is_image_filename, crop_image_inside_box ) -from .eynollah import PatchEncoder, Patches +from .patch_encoder import PatchEncoder, Patches DPI_THRESHOLD = 298 KERNEL = np.ones((5, 5), np.uint8) @@ -31,11 +36,11 @@ KERNEL = np.ones((5, 5), np.uint8) class Enhancer: def __init__( self, - dir_models : str, + *, + model_zoo: EynollahModelZoo, num_col_upper : Optional[int] = None, num_col_lower : Optional[int] = None, save_org_scale : bool = False, - logger : Optional[Logger] = None, ): self.input_binary = False self.light_version = False @@ -49,12 +54,10 @@ class Enhancer: else: self.num_col_lower = num_col_lower - self.logger = logger if logger else getLogger('enhancement') - self.dir_models = dir_models - self.model_dir_of_binarization = dir_models + "/eynollah-binarization_20210425" - self.model_dir_of_enhancement = dir_models + "/eynollah-enhancement_20210425" - self.model_dir_of_col_classifier = dir_models + "/eynollah-column-classifier_20210425" - self.model_page_dir = dir_models + "/model_eynollah_page_extraction_20250915" + self.logger = logging.getLogger('eynollah.enhance') + self.model_zoo = model_zoo + for v in ['binarization', 'enhancement', 'col_classifier', 'page']: + self.model_zoo.load_model(v) try: for device in tf.config.list_physical_devices('GPU'): @@ -62,11 +65,6 @@ class Enhancer: except: self.logger.warning("no GPU device available") - self.model_page = self.our_load_model(self.model_page_dir) - self.model_classifier = self.our_load_model(self.model_dir_of_col_classifier) - self.model_enhancement = self.our_load_model(self.model_dir_of_enhancement) - self.model_bin = self.our_load_model(self.model_dir_of_binarization) - def cache_images(self, image_filename=None, image_pil=None, dpi=None): ret = {} if image_filename: @@ -102,24 +100,12 @@ class Enhancer: def isNaN(self, num): return num != num - - @staticmethod - def our_load_model(model_file): - if model_file.endswith('.h5') and Path(model_file[:-3]).exists(): - # prefer SavedModel over HDF5 format if it exists - model_file = model_file[:-3] - try: - model = load_model(model_file, compile=False) - except: - model = load_model(model_file, compile=False, custom_objects={ - "PatchEncoder": PatchEncoder, "Patches": Patches}) - return model def predict_enhancement(self, img): self.logger.debug("enter predict_enhancement") - img_height_model = self.model_enhancement.layers[-1].output_shape[1] - img_width_model = self.model_enhancement.layers[-1].output_shape[2] + img_height_model = self.model_zoo.get('enhancement', Model).layers[-1].output_shape[1] + img_width_model = self.model_zoo.get('enhancement', Model).layers[-1].output_shape[2] if img.shape[0] < img_height_model: img = cv2.resize(img, (img.shape[1], img_width_model), interpolation=cv2.INTER_NEAREST) if img.shape[1] < img_width_model: @@ -160,7 +146,7 @@ class Enhancer: index_y_d = img_h - img_height_model img_patch = img[np.newaxis, index_y_d:index_y_u, 
index_x_d:index_x_u, :] - label_p_pred = self.model_enhancement.predict(img_patch, verbose=0) + label_p_pred = self.model_zoo.get('enhancement', Model).predict(img_patch, verbose='0') seg = label_p_pred[0, :, :, :] * 255 if i == 0 and j == 0: @@ -246,7 +232,7 @@ class Enhancer: else: img = self.imread() img = cv2.GaussianBlur(img, (5, 5), 0) - img_page_prediction = self.do_prediction(False, img, self.model_page) + img_page_prediction = self.do_prediction(False, img, self.model_zoo.get('page')) imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY) _, thresh = cv2.threshold(imgray, 0, 255, 0) @@ -291,7 +277,7 @@ class Enhancer: self.logger.info("Detected %s DPI", dpi) if self.input_binary: img = self.imread() - prediction_bin = self.do_prediction(True, img, self.model_bin, n_batch_inference=5) + prediction_bin = self.do_prediction(True, img, self.model_zoo.get('binarization'), n_batch_inference=5) prediction_bin = 255 * (prediction_bin[:,:,0]==0) prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2).astype(np.uint8) img= np.copy(prediction_bin) @@ -332,7 +318,7 @@ class Enhancer: img_in[0, :, :, 1] = img_1ch[:, :] img_in[0, :, :, 2] = img_1ch[:, :] - label_p_pred = self.model_classifier.predict(img_in, verbose=0) + label_p_pred = self.model_zoo.get('col_classifier').predict(img_in, verbose=0) num_col = np.argmax(label_p_pred[0]) + 1 elif (self.num_col_upper and self.num_col_lower) and (self.num_col_upper!=self.num_col_lower): if self.input_binary: @@ -352,7 +338,7 @@ class Enhancer: img_in[0, :, :, 1] = img_1ch[:, :] img_in[0, :, :, 2] = img_1ch[:, :] - label_p_pred = self.model_classifier.predict(img_in, verbose=0) + label_p_pred = self.model_zoo.get('col_classifier').predict(img_in, verbose=0) num_col = np.argmax(label_p_pred[0]) + 1 if num_col > self.num_col_upper: @@ -685,7 +671,7 @@ class Enhancer: t0 = time.time() img_res, is_image_enhanced, num_col_classifier, num_column_is_classified = self.run_enhancement(light_version=False) - return img_res + return img_res, is_image_enhanced def run(self, @@ -723,9 +709,18 @@ class Enhancer: self.logger.warning("will skip input for existing output file '%s'", self.output_filename) continue - image_enhanced = self.run_single() + did_resize = False + image_enhanced, did_enhance = self.run_single() if self.save_org_scale: image_enhanced = resize_image(image_enhanced, self.h_org, self.w_org) + did_resize = True + + self.logger.info( + "Image %s was %senhanced%s.", + img_filename, + '' if did_enhance else 'not ', + 'and resized' if did_resize else '' + ) cv2.imwrite(self.output_filename, image_enhanced) diff --git a/src/eynollah/mb_ro_on_layout.py b/src/eynollah/mb_ro_on_layout.py index 1b991ae..7f065f1 100644 --- a/src/eynollah/mb_ro_on_layout.py +++ b/src/eynollah/mb_ro_on_layout.py @@ -1,8 +1,12 @@ """ -Image enhancer. The output can be written as same scale of input or in new predicted scale. 
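With the Enhancer constructor now keyword-only and model loading delegated to the zoo, setup reduces to the following sketch (the model directory is a placeholder path):

```python
from eynollah.model_zoo import EynollahModelZoo
from eynollah.image_enhancer import Enhancer

zoo = EynollahModelZoo(basedir="/path/to/models")  # placeholder path
enhancer = Enhancer(model_zoo=zoo, save_org_scale=True)
# Note: run_single() now returns a tuple (image_enhanced, did_enhance)
# instead of just the image, which run() uses for the new log message.
```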
+Machine learning based reading order detection """ -from logging import Logger +# pyright: reportCallIssue=false +# pyright: reportUnboundVariable=false +# pyright: reportArgumentType=false + +import logging import os import time from typing import Optional @@ -10,12 +14,12 @@ from pathlib import Path import xml.etree.ElementTree as ET import cv2 +from keras.models import Model import numpy as np -from ocrd_utils import getLogger import statistics import tensorflow as tf -from tensorflow.keras.models import load_model +from .model_zoo import EynollahModelZoo from .utils.resize import resize_image from .utils.contour import ( find_new_features_of_contours, @@ -23,7 +27,6 @@ from .utils.contour import ( return_parent_contours, ) from .utils import is_xml_filename -from .eynollah import PatchEncoder, Patches DPI_THRESHOLD = 298 KERNEL = np.ones((5, 5), np.uint8) @@ -32,12 +35,12 @@ KERNEL = np.ones((5, 5), np.uint8) class machine_based_reading_order_on_layout: def __init__( self, - dir_models : str, - logger : Optional[Logger] = None, + *, + model_zoo: EynollahModelZoo, + logger : Optional[logging.Logger] = None, ): - self.logger = logger if logger else getLogger('mbreorder') - self.dir_models = dir_models - self.model_reading_order_dir = dir_models + "/model_eynollah_reading_order_20250824" + self.logger = logger or logging.getLogger('eynollah.mbreorder') + self.model_zoo = model_zoo try: for device in tf.config.list_physical_devices('GPU'): @@ -45,21 +48,10 @@ class machine_based_reading_order_on_layout: except: self.logger.warning("no GPU device available") - self.model_reading_order = self.our_load_model(self.model_reading_order_dir) + self.model_zoo.load_model('reading_order') + # FIXME: light_version is always true, no need for checks in the code self.light_version = True - @staticmethod - def our_load_model(model_file): - if model_file.endswith('.h5') and Path(model_file[:-3]).exists(): - # prefer SavedModel over HDF5 format if it exists - model_file = model_file[:-3] - try: - model = load_model(model_file, compile=False) - except: - model = load_model(model_file, compile=False, custom_objects={ - "PatchEncoder": PatchEncoder, "Patches": Patches}) - return model - def read_xml(self, xml_file): tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8')) root1=tree1.getroot() @@ -69,6 +61,7 @@ class machine_based_reading_order_on_layout: index_tot_regions = [] tot_region_ref = [] + y_len, x_len = 0, 0 for jj in root1.iter(link+'Page'): y_len=int(jj.attrib['imageHeight']) x_len=int(jj.attrib['imageWidth']) @@ -81,13 +74,13 @@ class machine_based_reading_order_on_layout: co_printspace = [] if link+'PrintSpace' in alltags: region_tags_printspace = np.unique([x for x in alltags if x.endswith('PrintSpace')]) - elif link+'Border' in alltags: + else: region_tags_printspace = np.unique([x for x in alltags if x.endswith('Border')]) for tag in region_tags_printspace: if link+'PrintSpace' in alltags: tag_endings_printspace = ['}PrintSpace','}printspace'] - elif link+'Border' in alltags: + else: tag_endings_printspace = ['}Border','}border'] if tag.endswith(tag_endings_printspace[0]) or tag.endswith(tag_endings_printspace[1]): @@ -683,7 +676,7 @@ class machine_based_reading_order_on_layout: tot_counter += 1 batch.append(j) if tot_counter % inference_bs == 0 or tot_counter == len(ij_list): - y_pr = self.model_reading_order.predict(input_1 , verbose=0) + y_pr = self.model_zoo.get('reading_order', Model).predict(input_1 , verbose='0') for jb, j in enumerate(batch): if y_pr[jb][0]>=0.5: 
post_list.append(j) @@ -802,6 +795,7 @@ class machine_based_reading_order_on_layout: alltags=[elem.tag for elem in root_xml.iter()] ET.register_namespace("",name_space) + assert dir_out tree_xml.write(os.path.join(dir_out, file_name+'.xml'), xml_declaration=True, method='xml', diff --git a/src/eynollah/model_zoo/__init__.py b/src/eynollah/model_zoo/__init__.py new file mode 100644 index 0000000..e1dc985 --- /dev/null +++ b/src/eynollah/model_zoo/__init__.py @@ -0,0 +1,4 @@ +__all__ = [ + 'EynollahModelZoo', +] +from .model_zoo import EynollahModelZoo diff --git a/src/eynollah/model_zoo/default_specs.py b/src/eynollah/model_zoo/default_specs.py new file mode 100644 index 0000000..2bbbf15 --- /dev/null +++ b/src/eynollah/model_zoo/default_specs.py @@ -0,0 +1,313 @@ +from .specs import EynollahModelSpec, EynollahModelSpecSet + +# NOTE: This needs to change whenever models/versions change +ZENODO = "https://zenodo.org/records/17295988/files" +MODELS_VERSION = "v0_7_0" + +def dist_url(dist_name: str) -> str: + return f'{ZENODO}/models_{dist_name}_{MODELS_VERSION}.zip' + +DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ + + EynollahModelSpec( + category="enhancement", + variant='', + filename="models_eynollah/eynollah-enhancement_20210425", + dists=['enhancement', 'layout', 'ci'], + dist_url=dist_url("enhancement"), + type='Keras', + ), + + EynollahModelSpec( + category="binarization", + variant='hybrid', + filename="models_eynollah/eynollah-binarization-hybrid_20230504/model_bin_hybrid_trans_cnn_sbb_ens", + dists=['layout', 'binarization', ], + dist_url=dist_url("binarization"), + type='Keras', + ), + + EynollahModelSpec( + category="binarization", + variant='20210309', + filename="models_eynollah/eynollah-binarization_20210309", + dists=['binarization'], + dist_url=dist_url("binarization"), + type='Keras', + ), + + EynollahModelSpec( + category="binarization", + variant='', + filename="models_eynollah/eynollah-binarization_20210425", + dists=['binarization'], + dist_url=dist_url("binarization"), + type='Keras', + ), + + EynollahModelSpec( + category="binarization_multi_1", + variant='', + filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin1", + dist_url=dist_url("binarization"), + dists=['binarization'], + type='Keras', + ), + + EynollahModelSpec( + category="binarization_multi_2", + variant='', + filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin2", + dist_url=dist_url("binarization"), + dists=['binarization'], + type='Keras', + ), + + EynollahModelSpec( + category="binarization_multi_3", + variant='', + filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin3", + dist_url=dist_url("binarization"), + dists=['binarization'], + type='Keras', + ), + + EynollahModelSpec( + category="binarization_multi_4", + variant='', + filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin4", + dist_url=dist_url("binarization"), + dists=['binarization'], + type='Keras', + ), + + EynollahModelSpec( + category="col_classifier", + variant='', + filename="models_eynollah/eynollah-column-classifier_20210425", + dist_url=dist_url("layout"), + dists=['layout'], + type='Keras', + ), + + EynollahModelSpec( + category="page", + variant='', + filename="models_eynollah/model_eynollah_page_extraction_20250915", + dist_url=dist_url("layout"), + dists=['layout'], + type='Keras', + ), + + EynollahModelSpec( + category="region", + variant='', + filename="models_eynollah/eynollah-main-regions-ensembled_20210425", + dist_url=dist_url("layout"), + 
dists=['layout'], + type='Keras', + ), + + EynollahModelSpec( + category="region", + variant='extract_only_images', + filename="models_eynollah/eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18", + dist_url=dist_url("layout"), + dists=['layout'], + type='Keras', + ), + + EynollahModelSpec( + category="region", + variant='light', + filename="models_eynollah/eynollah-main-regions_20220314", + dist_url=dist_url("layout"), + help="early layout", + dists=['layout'], + type='Keras', + ), + + EynollahModelSpec( + category="region_p2", + variant='', + filename="models_eynollah/eynollah-main-regions-aug-rotation_20210425", + dist_url=dist_url("layout"), + help="early layout, non-light, 2nd part", + dists=['layout'], + type='Keras', + ), + + EynollahModelSpec( + category="region_1_2", + variant='', + #filename="models_eynollah/modelens_12sp_elay_0_3_4__3_6_n", + #filename="models_eynollah/modelens_earlylayout_12spaltige_2_3_5_6_7_8", + #filename="models_eynollah/modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18", + #filename="models_eynollah/modelens_1_2_4_5_early_lay_1_2_spaltige", + #filename="models_eynollah/model_3_eraly_layout_no_patches_1_2_spaltige", + filename="models_eynollah/modelens_e_l_all_sp_0_1_2_3_4_171024", + dist_url=dist_url("layout"), + dists=['layout'], + help="early layout, light, 1-or-2-column", + type='Keras', + ), + + EynollahModelSpec( + category="region_fl_np", + variant='', + #'filename="models_eynollah/modelens_full_lay_1_3_031124", + #'filename="models_eynollah/modelens_full_lay_13__3_19_241024", + #'filename="models_eynollah/model_full_lay_13_241024", + #'filename="models_eynollah/modelens_full_lay_13_17_231024", + #'filename="models_eynollah/modelens_full_lay_1_2_221024", + #'filename="models_eynollah/eynollah-full-regions-1column_20210425", + filename="models_eynollah/modelens_full_lay_1__4_3_091124", + dist_url=dist_url("layout"), + help="full layout / no patches", + dists=['layout'], + type='Keras', + ), + + # FIXME: Why is region_fl and region_fl_np the same model? 
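Each spec ties a (category, variant) pair to a filename, distribution archives and a download URL; the `EynollahModelZoo` defined further below in this diff consumes this set and also accepts per-run overrides. A sketch of the override mechanism, with illustrative paths:

```python
from eynollah.model_zoo import EynollahModelZoo

# Overrides are (category, variant, filename) triples, as accepted by
# override_models(); an absolute filename bypasses the basedir lookup.
zoo = EynollahModelZoo(
    basedir="/path/to/models",  # placeholder
    model_overrides=[("region", "light", "/path/to/custom_region_model")],
)
print(zoo.model_path("region", "light"))  # -> /path/to/custom_region_model
```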
+ EynollahModelSpec( + category="region_fl", + variant='', + # filename="models_eynollah/eynollah-full-regions-3+column_20210425", + # filename="models_eynollah/model_2_full_layout_new_trans", + # filename="models_eynollah/modelens_full_lay_1_3_031124", + # filename="models_eynollah/modelens_full_lay_13__3_19_241024", + # filename="models_eynollah/model_full_lay_13_241024", + # filename="models_eynollah/modelens_full_lay_13_17_231024", + # filename="models_eynollah/modelens_full_lay_1_2_221024", + # filename="models_eynollah/modelens_full_layout_24_till_28", + # filename="models_eynollah/model_2_full_layout_new_trans", + filename="models_eynollah/modelens_full_lay_1__4_3_091124", + dist_url=dist_url("layout"), + help="full layout / with patches", + dists=['layout'], + type='Keras', + ), + + EynollahModelSpec( + category="reading_order", + variant='', + #filename="models_eynollah/model_mb_ro_aug_ens_11", + #filename="models_eynollah/model_step_3200000_mb_ro", + #filename="models_eynollah/model_ens_reading_order_machine_based", + #filename="models_eynollah/model_mb_ro_aug_ens_8", + #filename="models_eynollah/model_ens_reading_order_machine_based", + filename="models_eynollah/model_eynollah_reading_order_20250824", + dist_url=dist_url("reading_order"), + dists=['layout', 'reading_order'], + type='Keras', + ), + + EynollahModelSpec( + category="textline", + variant='', + #filename="models_eynollah/modelens_textline_1_4_16092024", + #filename="models_eynollah/model_textline_ens_3_4_5_6_artificial", + #filename="models_eynollah/modelens_textline_1_3_4_20240915", + #filename="models_eynollah/model_textline_ens_3_4_5_6_artificial", + #filename="models_eynollah/modelens_textline_9_12_13_14_15", + #filename="models_eynollah/eynollah-textline_20210425", + filename="models_eynollah/modelens_textline_0_1__2_4_16092024", + dist_url=dist_url("layout"), + dists=['layout'], + type='Keras', + ), + + EynollahModelSpec( + category="textline", + variant='light', + #filename="models_eynollah/eynollah-textline_light_20210425", + filename="models_eynollah/modelens_textline_0_1__2_4_16092024", + dist_url=dist_url("layout"), + dists=['layout'], + type='Keras', + ), + + EynollahModelSpec( + category="table", + variant='', + filename="models_eynollah/eynollah-tables_20210319", + dist_url=dist_url("layout"), + dists=['layout'], + type='Keras', + ), + + EynollahModelSpec( + category="table", + variant='light', + filename="models_eynollah/modelens_table_0t4_201124", + dist_url=dist_url("layout"), + dists=['layout'], + type='Keras', + ), + + EynollahModelSpec( + category="ocr", + variant='', + filename="models_eynollah/model_eynollah_ocr_cnnrnn_20250930", + dist_url=dist_url("ocr"), + dists=['layout', 'ocr'], + type='Keras', + ), + + EynollahModelSpec( + category="ocr", + variant='degraded', + filename="models_eynollah/model_eynollah_ocr_cnnrnn__degraded_20250805/", + help="slightly better at degraded Fraktur", + dist_url=dist_url("ocr"), + dists=['ocr'], + type='Keras', + ), + + EynollahModelSpec( + category="num_to_char", + variant='', + filename="characters_org.txt", + dist_url=dist_url("ocr"), + dists=['ocr'], + type='decoder', + ), + + EynollahModelSpec( + category="characters", + variant='', + filename="characters_org.txt", + dist_url=dist_url("ocr"), + dists=['ocr'], + type='List[str]', + ), + + EynollahModelSpec( + category="ocr", + variant='tr', + filename="models_eynollah/model_eynollah_ocr_trocr_20250919", + dist_url=dist_url("trocr"), + help='much slower transformer-based', + dists=['trocr'], + type='Keras', 
+ ), + + EynollahModelSpec( + category="trocr_processor", + variant='', + filename="models_eynollah/model_eynollah_ocr_trocr_20250919", + dist_url=dist_url("trocr"), + dists=['trocr'], + type='TrOCRProcessor', + ), + + EynollahModelSpec( + category="trocr_processor", + variant='htr', + filename="models_eynollah/microsoft/trocr-base-handwritten", + dist_url=dist_url("trocr"), + dists=['trocr'], + type='TrOCRProcessor', + ), + +]) diff --git a/src/eynollah/model_zoo/model_zoo.py b/src/eynollah/model_zoo/model_zoo.py new file mode 100644 index 0000000..512bf1a --- /dev/null +++ b/src/eynollah/model_zoo/model_zoo.py @@ -0,0 +1,204 @@ +import json +import logging +from copy import deepcopy +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Type, Union + +from ocrd_utils import tf_disable_interactive_logs +tf_disable_interactive_logs() + +from keras.layers import StringLookup +from keras.models import Model as KerasModel +from keras.models import load_model +from tabulate import tabulate +from ..patch_encoder import PatchEncoder, Patches +from .specs import EynollahModelSpecSet +from .default_specs import DEFAULT_MODEL_SPECS +from .types import AnyModel, T + + +class EynollahModelZoo: + """ + Wrapper class that handles storage and loading of models for all eynollah runners. + """ + + model_basedir: Path + specs: EynollahModelSpecSet + + def __init__( + self, + basedir: str, + model_overrides: Optional[List[Tuple[str, str, str]]] = None, + ) -> None: + self.model_basedir = Path(basedir) + self.logger = logging.getLogger('eynollah.model_zoo') + if not self.model_basedir.exists(): + self.logger.warning(f"Model basedir does not exist: {basedir}. Set eynollah --model-basedir to the correct directory.") + self.specs = deepcopy(DEFAULT_MODEL_SPECS) + self._overrides = [] + if model_overrides: + self.override_models(*model_overrides) + self._loaded: Dict[str, AnyModel] = {} + + @property + def model_overrides(self): + return self._overrides + + def override_models( + self, + *model_overrides: Tuple[str, str, str], + ): + """ + Override the default model versions + """ + for model_category, model_variant, model_filename in model_overrides: + spec = self.specs.get(model_category, model_variant) + self.logger.warning("Overriding filename for model spec %s to %s", spec, model_filename) + self.specs.get(model_category, model_variant).filename = model_filename + self._overrides += model_overrides + + def model_path( + self, + model_category: str, + model_variant: str = '', + absolute: bool = True, + ) -> Path: + """ + Translate model_{type,variant} tuple into an absolute (or relative) Path + """ + spec = self.specs.get(model_category, model_variant) + if spec.category in ('characters', 'num_to_char'): + return self.model_path('ocr') / spec.filename + if not Path(spec.filename).is_absolute() and absolute: + model_path = Path(self.model_basedir).joinpath(spec.filename) + else: + model_path = Path(spec.filename) + return model_path + + def load_models( + self, + *all_load_args: Union[str, Tuple[str], Tuple[str, str], Tuple[str, str, str]], + ) -> Dict: + """ + Load all models by calling load_model and return a dictionary mapping model_category to loaded model + """ + ret = {} + for load_args in all_load_args: + if isinstance(load_args, str): + ret[load_args] = self.load_model(load_args) + else: + ret[load_args[0]] = self.load_model(*load_args) + return ret + + def load_model( + self, + model_category: str, + model_variant: str = '', + model_path_override: Optional[str] = None, + ) -> 
AnyModel: + """ + Load any model + """ + if model_path_override: + self.override_models((model_category, model_variant, model_path_override)) + model_path = self.model_path(model_category, model_variant) + if model_path.suffix == '.h5' and Path(model_path.stem).exists(): + # prefer SavedModel over HDF5 format if it exists + model_path = Path(model_path.stem) + if model_category == 'ocr': + model = self._load_ocr_model(variant=model_variant) + elif model_category == 'num_to_char': + model = self._load_num_to_char() + elif model_category == 'characters': + model = self._load_characters() + elif model_category == 'trocr_processor': + from transformers import TrOCRProcessor + model = TrOCRProcessor.from_pretrained(model_path) + else: + try: + model = load_model(model_path, compile=False) + except Exception as e: + self.logger.exception(e) + model = load_model( + model_path, compile=False, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches} + ) + self._loaded[model_category] = model + return model # type: ignore + + def get(self, model_category: str, model_type: Optional[Type[T]] = None) -> T: + if model_category not in self._loaded: + raise ValueError(f'Model "{model_category} not previously loaded with "load_model(..)"') + ret = self._loaded[model_category] + if model_type: + assert isinstance(ret, model_type) + return ret # type: ignore # FIXME: convince typing that we're returning generic type + + def _load_ocr_model(self, variant: str) -> AnyModel: + """ + Load OCR model + """ + ocr_model_dir = self.model_path('ocr', variant) + if variant == 'tr': + from transformers import VisionEncoderDecoderModel + ret = VisionEncoderDecoderModel.from_pretrained(ocr_model_dir) + assert isinstance(ret, VisionEncoderDecoderModel) + return ret + else: + ocr_model = load_model(ocr_model_dir, compile=False) + assert isinstance(ocr_model, KerasModel) + return KerasModel( + ocr_model.get_layer(name="image").input, # type: ignore + ocr_model.get_layer(name="dense2").output, # type: ignore + ) + + def _load_characters(self) -> List[str]: + """ + Load encoding for OCR + """ + with open(self.model_path('num_to_char'), "r") as config_file: + return json.load(config_file) + + def _load_num_to_char(self) -> StringLookup: + """ + Load decoder for OCR + """ + characters = self._load_characters() + # Mapping characters to integers. + char_to_num = StringLookup(vocabulary=characters, mask_token=None) + # Mapping integers back to original characters. 
+        return StringLookup(vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True)
+
+    def __str__(self):
+        return tabulate(
+            [
+                [
+                    spec.type,
+                    spec.category,
+                    spec.variant,
+                    spec.help,
+                    ', '.join(spec.dists),
+                    f'Yes, at {self.model_path(spec.category, spec.variant)}'
+                    if self.model_path(spec.category, spec.variant).exists()
+                    else f'No, download {spec.dist_url}',
+                    # self.model_path(spec.category, spec.variant),
+                ]
+                for spec in self.specs.specs
+            ],
+            headers=[
+                'Type',
+                'Category',
+                'Variant',
+                'Help',
+                'Used in',
+                'Installed',
+            ],
+            tablefmt='github',
+        )
+
+    def shutdown(self):
+        """
+        Ensure that loaded models are no longer referenced by ``self._loaded``
+        """
+        if hasattr(self, '_loaded') and getattr(self, '_loaded'):
+            for needle in list(self._loaded.keys()):
+                del self._loaded[needle]
diff --git a/src/eynollah/model_zoo/specs.py b/src/eynollah/model_zoo/specs.py
new file mode 100644
index 0000000..54a55f2
--- /dev/null
+++ b/src/eynollah/model_zoo/specs.py
@@ -0,0 +1,54 @@
+from dataclasses import dataclass
+from typing import Dict, List, Set, Tuple
+
+
+@dataclass
+class EynollahModelSpec():
+    """
+    Abstract description of a single model.
+    """
+    category: str
+    # Relative filename to the models_eynollah directory in the dists
+    filename: str
+    # basename of the ZIP files that should contain this model
+    dists: List[str]
+    # URL to the smallest model distribution containing this model (link to Zenodo)
+    dist_url: str
+    type: str
+    variant: str = ''
+    help: str = ''
+
+class EynollahModelSpecSet():
+    """
+    The set of all models used by eynollah.
+    """
+    specs: List[EynollahModelSpec]
+
+    def __init__(self, specs: List[EynollahModelSpec]) -> None:
+        self.specs = sorted(specs, key=lambda x: x.category + '0' + x.variant)
+        self.categories: Set[str] = set([spec.category for spec in self.specs])
+        self.variants: Dict[str, Set[str]] = {
+            spec.category: set([x.variant for x in self.specs if x.category == spec.category])
+            for spec in self.specs
+        }
+        self._index_category_variant: Dict[Tuple[str, str], EynollahModelSpec] = {
+            (spec.category, spec.variant): spec
+            for spec in self.specs
+        }
+
+    def asdict(self) -> Dict[str, Dict[str, str]]:
+        # accumulate all variants per category (a plain dict comprehension
+        # here would keep only the last variant of each category)
+        ret: Dict[str, Dict[str, str]] = {}
+        for spec in self.specs:
+            ret.setdefault(spec.category, {})[spec.variant] = spec.filename
+        return ret
+
+    def get(self, category: str, variant: str) -> EynollahModelSpec:
+        if category not in self.categories:
+            raise ValueError(f"Unknown category '{category}', must be one of {self.categories}")
+        if variant not in self.variants[category]:
+            raise ValueError(f"Unknown variant '{variant}' for category '{category}'. Known variants: {self.variants[category]}")
+        return self._index_category_variant[(category, variant)]
+
+
diff --git a/src/eynollah/model_zoo/types.py b/src/eynollah/model_zoo/types.py
new file mode 100644
index 0000000..43f6859
--- /dev/null
+++ b/src/eynollah/model_zoo/types.py
@@ -0,0 +1,7 @@
+from typing import TypeVar
+
+# NOTE: Creating an actual union type requires loading transformers which is expensive and error-prone
+# from transformers import TrOCRProcessor, VisionEncoderDecoderModel
+# AnyModel = Union[VisionEncoderDecoderModel, TrOCRProcessor, KerasModel, List]
+AnyModel = object
+T = TypeVar('T')
diff --git a/src/eynollah/ocrd-tool.json b/src/eynollah/ocrd-tool.json
index dbbdc3b..3d1193d 100644
--- a/src/eynollah/ocrd-tool.json
+++ b/src/eynollah/ocrd-tool.json
@@ -83,10 +83,10 @@
     },
     "resources": [
       {
-        "url": "https://zenodo.org/records/17194824/files/models_layout_v0_5_0.tar.gz?download=1",
-        "name": "models_layout_v0_5_0",
+        "url": "https://zenodo.org/records/17295988/files/models_layout_v0_6_0.tar.gz?download=1",
+        "name": "models_layout_v0_6_0",
         "type": "archive",
-        "path_in_archive": "models_layout_v0_5_0",
+        "path_in_archive": "models_layout_v0_6_0",
         "size": 3525684179,
         "description": "Models for layout detection, reading order detection, textline detection, page extraction, column classification, table detection, binarization, image enhancement",
         "version_range": ">= v0.5.0"
diff --git a/src/eynollah/ocrd_cli_binarization.py b/src/eynollah/ocrd_cli_binarization.py
index 848bbac..e5f85b1 100644
--- a/src/eynollah/ocrd_cli_binarization.py
+++ b/src/eynollah/ocrd_cli_binarization.py
@@ -34,6 +34,7 @@ class SbbBinarizeProcessor(Processor):
         Set up the model prior to processing.
         """
         # resolve relative path via OCR-D ResourceManager
+        assert isinstance(self.parameter, dict)
         model_path = self.resolve_resource(self.parameter['model'])
         self.binarizer = SbbBinarizer(model_dir=model_path, logger=self.logger)
diff --git a/src/eynollah/patch_encoder.py b/src/eynollah/patch_encoder.py
new file mode 100644
index 0000000..939ad7b
--- /dev/null
+++ b/src/eynollah/patch_encoder.py
@@ -0,0 +1,56 @@
+from keras import layers
+import tensorflow as tf
+
+projection_dim = 64
+patch_size = 1
+num_patches = 21 * 21  # alternatives used previously: 14*14, 28*28
+
+class PatchEncoder(layers.Layer):
+
+    def __init__(self, **kwargs):
+        # discard derived entries stored by get_config(), so cls(**config) deserialization works
+        for key in ('num_patches', 'projection', 'position_embedding'):
+            kwargs.pop(key, None)
+        super().__init__(**kwargs)
+        self.projection = layers.Dense(units=projection_dim)
+        self.position_embedding = layers.Embedding(input_dim=num_patches, output_dim=projection_dim)
+
+    def call(self, patch):
+        positions = tf.range(start=0, limit=num_patches, delta=1)
+        encoded = self.projection(patch) + self.position_embedding(positions)
+        return encoded
+
+    def get_config(self):
+        config = super().get_config().copy()
+        config.update({
+            'num_patches': num_patches,
+            'projection': self.projection,
+            'position_embedding': self.position_embedding,
+        })
+        return config
+
+class Patches(layers.Layer):
+    def __init__(self, **kwargs):
+        size = kwargs.pop('patch_size', patch_size)  # tolerate value from a stored config
+        super().__init__(**kwargs)
+        self.patch_size = size
+
+    def call(self, images):
+        batch_size = tf.shape(images)[0]
+        patches = tf.image.extract_patches(
+            images=images,
+            sizes=[1, self.patch_size, self.patch_size, 1],
+            strides=[1, self.patch_size, self.patch_size, 1],
+            rates=[1, 1, 1, 1],
+            padding="VALID",
+        )
+        patch_dims = patches.shape[-1]
+        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
+        return patches
+
+    def get_config(self):
+        config = super().get_config().copy()
+        config.update({
+            'patch_size': self.patch_size,
+        })
+        return config
diff --git 
a/src/eynollah/plot.py b/src/eynollah/plot.py index c026e94..b1b2359 100644 --- a/src/eynollah/plot.py +++ b/src/eynollah/plot.py @@ -40,8 +40,8 @@ class EynollahPlotter: self.image_filename_stem = image_filename_stem # XXX TODO hacky these cannot be set at init time self.image_org = image_org - self.scale_x = scale_x - self.scale_y = scale_y + self.scale_x : float = scale_x + self.scale_y : float = scale_y def save_plot_of_layout_main(self, text_regions_p, image_page): if self.dir_of_layout is not None: diff --git a/src/eynollah/processor.py b/src/eynollah/processor.py index 12c7356..60c136c 100644 --- a/src/eynollah/processor.py +++ b/src/eynollah/processor.py @@ -32,8 +32,8 @@ class EynollahProcessor(Processor): allow_scaling=self.parameter['allow_scaling'], headers_off=self.parameter['headers_off'], tables=self.parameter['tables'], + logger=self.logger ) - self.eynollah.logger = self.logger self.eynollah.plotter = None def shutdown(self): diff --git a/src/eynollah/sbb_binarize.py b/src/eynollah/sbb_binarize.py index 3716987..77741e9 100644 --- a/src/eynollah/sbb_binarize.py +++ b/src/eynollah/sbb_binarize.py @@ -2,18 +2,24 @@ Tool to load model and binarize a given image. """ -import sys -from glob import glob +# pyright: reportIndexIssue=false +# pyright: reportCallIssue=false +# pyright: reportArgumentType=false +# pyright: reportPossiblyUnboundVariable=false + import os import logging +from pathlib import Path +from typing import Dict, Optional import numpy as np -from PIL import Image import cv2 from ocrd_utils import tf_disable_interactive_logs + +from eynollah.model_zoo import EynollahModelZoo +from eynollah.model_zoo.types import AnyModel tf_disable_interactive_logs() import tensorflow as tf -from tensorflow.keras.models import load_model from tensorflow.python.keras import backend as tensorflow_backend from .utils import is_image_filename @@ -23,40 +29,40 @@ def resize_image(img_in, input_height, input_width): class SbbBinarizer: - def __init__(self, model_dir, logger=None): - self.model_dir = model_dir - self.log = logger if logger else logging.getLogger('SbbBinarizer') - - self.start_new_session() - - self.model_files = glob(self.model_dir+"/*/", recursive = True) - - self.models = [] - for model_file in self.model_files: - self.models.append(self.load_model(model_file)) + def __init__( + self, + *, + model_zoo: EynollahModelZoo, + mode: str, + logger: Optional[logging.Logger] = None, + ): + self.logger = logger if logger else logging.getLogger('eynollah.binarization') + self.model_zoo = model_zoo + self.models = self.setup_models(mode) + self.session = self.start_new_session() def start_new_session(self): config = tf.compat.v1.ConfigProto() config.gpu_options.allow_growth = True - self.session = tf.compat.v1.Session(config=config) # tf.InteractiveSession() - tensorflow_backend.set_session(self.session) + session = tf.compat.v1.Session(config=config) # tf.InteractiveSession() + tensorflow_backend.set_session(session) + return session + + def setup_models(self, mode: str) -> Dict[Path, AnyModel]: + return { + self.model_zoo.model_path(v): self.model_zoo.load_model(v) + for v in (['binarization'] if mode == 'single' else [f'binarization_multi_{i}' for i in range(1, 5)]) + } def end_session(self): tensorflow_backend.clear_session() self.session.close() del self.session - def load_model(self, model_name): - model = load_model(os.path.join(self.model_dir, model_name), compile=False) + def predict(self, model, img, use_patches, n_batch_inference=5): model_height = 
model.layers[len(model.layers)-1].output_shape[1] model_width = model.layers[len(model.layers)-1].output_shape[2] - n_classes = model.layers[len(model.layers)-1].output_shape[3] - return model, model_height, model_width, n_classes - - def predict(self, model_in, img, use_patches, n_batch_inference=5): - tensorflow_backend.set_session(self.session) - model, model_height, model_width, n_classes = model_in img_org_h = img.shape[0] img_org_w = img.shape[1] @@ -324,9 +330,8 @@ class SbbBinarizer: if image_path is not None: image = cv2.imread(image_path) img_last = 0 - for n, (model, model_file) in enumerate(zip(self.models, self.model_files)): - self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files))) - + for n, (model_file, model) in enumerate(self.models.items()): + self.logger.info('Predicting %s with model %s [%s/%s]', image_path if image_path else '[image]', model_file, n + 1, len(self.models.keys())) res = self.predict(model, image, use_patches) img_fin = np.zeros((res.shape[0], res.shape[1], 3)) @@ -345,17 +350,19 @@ class SbbBinarizer: img_last[:, :][img_last[:, :] > 0] = 255 img_last = (img_last[:, :] == 0) * 255 if output: + self.logger.info('Writing binarized image to %s', output) cv2.imwrite(output, img_last) return img_last else: ls_imgs = list(filter(is_image_filename, os.listdir(dir_in))) - for image_name in ls_imgs: + self.logger.info("Found %d image files to binarize in %s", len(ls_imgs), dir_in) + for i, image_name in enumerate(ls_imgs): image_stem = image_name.split('.')[0] - print(image_name,'image_name') + self.logger.info('Binarizing [%3d/%d] %s', i + 1, len(ls_imgs), image_name) image = cv2.imread(os.path.join(dir_in,image_name) ) img_last = 0 - for n, (model, model_file) in enumerate(zip(self.models, self.model_files)): - self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files))) + for n, (model_file, model) in enumerate(self.models.items()): + self.logger.info('Predicting %s with model %s [%s/%s]', image_name, model_file, n + 1, len(self.models.keys())) res = self.predict(model, image, use_patches) @@ -375,4 +382,6 @@ class SbbBinarizer: img_last[:, :][img_last[:, :] > 0] = 255 img_last = (img_last[:, :] == 0) * 255 - cv2.imwrite(os.path.join(output, image_stem + '.png'), img_last) + output_filename = os.path.join(output, image_stem + '.png') + self.logger.info('Writing binarized image to %s', output_filename) + cv2.imwrite(output_filename, img_last) diff --git a/src/eynollah/utils/__init__.py b/src/eynollah/utils/__init__.py index 5ccb2af..29359eb 100644 --- a/src/eynollah/utils/__init__.py +++ b/src/eynollah/utils/__init__.py @@ -19,7 +19,6 @@ from .contour import (contours_in_same_horizon, find_new_features_of_contours, return_contours_of_image, return_parent_contours) - def pairwise(iterable): # pairwise('ABCDEFG') → AB BC CD DE EF FG @@ -393,7 +392,12 @@ def find_num_col_deskew(regions_without_separators, sigma_, multiplier=3.8): z = gaussian_filter1d(regions_without_separators_0, sigma_) return np.std(z) -def find_num_col(regions_without_separators, num_col_classifier, tables, multiplier=3.8): +def find_num_col( + regions_without_separators, + num_col_classifier, + tables, + multiplier=3.8, +): if not regions_without_separators.any(): return 0, [] #plt.imshow(regions_without_separators) diff --git a/src/eynollah/utils/contour.py b/src/eynollah/utils/contour.py index f304db2..6550171 100644 --- a/src/eynollah/utils/contour.py +++ b/src/eynollah/utils/contour.py @@ -357,7 +357,7 @@ def 
join_polygons(polygons: Sequence[Polygon], scale=20) -> Polygon: assert jointp.geom_type == 'Polygon', jointp.wkt # follow-up calculations will necessarily be integer; # so anticipate rounding here and then ensure validity - jointp2 = set_precision(jointp, 1.0) + jointp2 = set_precision(jointp, 1.0, mode="keep_collapsed") if jointp2.geom_type != 'Polygon' or not jointp2.is_valid: jointp2 = Polygon(np.round(jointp.exterior.coords)) jointp2 = make_valid(jointp2) diff --git a/src/eynollah/writer.py b/src/eynollah/writer.py index 9c3456a..38b7b9e 100644 --- a/src/eynollah/writer.py +++ b/src/eynollah/writer.py @@ -2,15 +2,15 @@ # pylint: disable=import-error from pathlib import Path import os.path +from typing import Optional +import logging import xml.etree.ElementTree as ET from .utils.xml import create_page_xml, xml_reading_order from .utils.counter import EynollahIdCounter -from ocrd_utils import getLogger from ocrd_models.ocrd_page import ( BorderType, CoordsType, - PcGtsType, TextLineType, TextEquivType, TextRegionType, @@ -24,7 +24,7 @@ import numpy as np class EynollahXmlWriter: def __init__(self, *, dir_out, image_filename, curved_line,textline_light, pcgts=None): - self.logger = getLogger('eynollah.writer') + self.logger = logging.getLogger('eynollah.writer') self.counter = EynollahIdCounter() self.dir_out = dir_out self.image_filename = image_filename @@ -32,10 +32,10 @@ class EynollahXmlWriter: self.curved_line = curved_line self.textline_light = textline_light self.pcgts = pcgts - self.scale_x = None # XXX set outside __init__ - self.scale_y = None # XXX set outside __init__ - self.height_org = None # XXX set outside __init__ - self.width_org = None # XXX set outside __init__ + self.scale_x: Optional[float] = None # XXX set outside __init__ + self.scale_y: Optional[float] = None # XXX set outside __init__ + self.height_org: Optional[int] = None # XXX set outside __init__ + self.width_org: Optional[int] = None # XXX set outside __init__ @property def image_filename_stem(self): @@ -135,6 +135,7 @@ class EynollahXmlWriter: # create the file structure pcgts = self.pcgts if self.pcgts else create_page_xml(self.image_filename, self.height_org, self.width_org) page = pcgts.get_Page() + assert page page.set_Border(BorderType(Coords=CoordsType(points=self.calculate_page_coords(cont_page)))) counter = EynollahIdCounter() @@ -152,6 +153,7 @@ class EynollahXmlWriter: Coords=CoordsType(points=self.calculate_polygon_coords(region_contour, page_coord, skip_layout_reading_order)) ) + assert textregion.Coords if conf_contours_textregions: textregion.Coords.set_conf(conf_contours_textregions[mm]) page.add_TextRegion(textregion) @@ -168,6 +170,7 @@ class EynollahXmlWriter: id=counter.next_region_id, type_='heading', Coords=CoordsType(points=self.calculate_polygon_coords(region_contour, page_coord)) ) + assert textregion.Coords if conf_contours_textregions_h: textregion.Coords.set_conf(conf_contours_textregions_h[mm]) page.add_TextRegion(textregion) diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/cli_tests/conftest.py b/tests/cli_tests/conftest.py new file mode 100644 index 0000000..601d76b --- /dev/null +++ b/tests/cli_tests/conftest.py @@ -0,0 +1,47 @@ +from typing import List +import pytest +import logging + +from click.testing import CliRunner, Result +from eynollah.cli import main as eynollah_cli + + +@pytest.fixture +def run_eynollah_ok_and_check_logs( + pytestconfig, + caplog, + model_dir, + eynollah_subcommands, + 
eynollah_log_filter,
+):
+    """
+    Returns a runner that invokes the given eynollah subcommand via Click's
+    CliRunner, injecting the model directory and log level into `args`, and
+    checks that every fragment in `expected_logs` occurs in the captured logs
+    """
+
+    def _run_click_ok_logs(
+        subcommand: str,
+        args: List[str],
+        expected_logs: List[str],
+    ) -> Result:
+        assert subcommand in eynollah_subcommands, f'subcommand {subcommand} must be one of {eynollah_subcommands}'
+        args = [
+            '-m', model_dir,
+            subcommand,
+            *args
+        ]
+        if pytestconfig.getoption('verbose') > 0:
+            args = ['-l', 'DEBUG'] + args
+        caplog.set_level(logging.INFO)
+        runner = CliRunner()
+        with caplog.filtering(eynollah_log_filter):
+            result = runner.invoke(eynollah_cli, args, catch_exceptions=False)
+        assert result.exit_code == 0, result.stdout
+        if expected_logs:
+            logmsgs = [logrec.message for logrec in caplog.records]
+            assert all(any(logmsg.startswith(needle) for logmsg in logmsgs) for needle in expected_logs), f'{expected_logs} not all in {logmsgs}'
+        return result
+
+    return _run_click_ok_logs
+
diff --git a/tests/cli_tests/test_binarization.py b/tests/cli_tests/test_binarization.py
new file mode 100644
index 0000000..aa52957
--- /dev/null
+++ b/tests/cli_tests/test_binarization.py
@@ -0,0 +1,53 @@
+import pytest
+from PIL import Image
+
+@pytest.mark.parametrize(
+    "options",
+    [
+        [], # defaults
+        ["--no-patches"],
+    ], ids=str)
+def test_run_eynollah_binarization_filename(
+    tmp_path,
+    run_eynollah_ok_and_check_logs,
+    resources_dir,
+    options,
+):
+    infile = resources_dir / '2files/kant_aufklaerung_1784_0020.tif'
+    outfile = tmp_path / 'kant_aufklaerung_1784_0020.png'
+    run_eynollah_ok_and_check_logs(
+        'binarization',
+        [
+            '-i', str(infile),
+            '-o', str(outfile),
+        ] + options,
+        [
+            'Predicting'
+        ]
+    )
+    assert outfile.exists()
+    with Image.open(infile) as original_img:
+        original_size = original_img.size
+    with Image.open(outfile) as binarized_img:
+        binarized_size = binarized_img.size
+    assert original_size == binarized_size
+
+def test_run_eynollah_binarization_directory(
+    tmp_path,
+    run_eynollah_ok_and_check_logs,
+    resources_dir,
+    image_resources,
+):
+    outdir = tmp_path
+    run_eynollah_ok_and_check_logs(
+        'binarization',
+        [
+            '-di', str(resources_dir / '2files'),
+            '-o', str(outdir),
+        ],
+        [
+            f'Predicting {image_resources[0].name}',
+            f'Predicting {image_resources[1].name}',
+        ]
+    )
+    assert len(list(outdir.iterdir())) == 2
diff --git a/tests/cli_tests/test_enhance.py b/tests/cli_tests/test_enhance.py
new file mode 100644
index 0000000..b994c5d
--- /dev/null
+++ b/tests/cli_tests/test_enhance.py
@@ -0,0 +1,52 @@
+import pytest
+from PIL import Image
+
+@pytest.mark.parametrize(
+    "options",
+    [
+        [], # defaults
+        ["-sos"],
+    ], ids=str)
+def test_run_eynollah_enhancement_filename(
+    tmp_path,
+    resources_dir,
+    run_eynollah_ok_and_check_logs,
+    options,
+):
+    infile = resources_dir / '2files/kant_aufklaerung_1784_0020.tif'
+    outfile = tmp_path / 'kant_aufklaerung_1784_0020.png'
+    run_eynollah_ok_and_check_logs(
+        'enhancement',
+        [
+            '-i', str(infile),
+            '-o', str(outfile.parent),
+        ] + options,
+        [
+            'Image was enhanced',
+        ]
+    )
+    with Image.open(infile) as original_img:
+        original_size = original_img.size
+    with Image.open(outfile) as enhanced_img:
+        enhanced_size = enhanced_img.size
+    assert (original_size == enhanced_size) == ("-sos" in options)
+
+def test_run_eynollah_enhancement_directory(
+    tmp_path,
+    resources_dir,
+    image_resources,
+    run_eynollah_ok_and_check_logs,
+):
+    outdir = tmp_path
+    
run_eynollah_ok_and_check_logs( + 'enhancement', + [ + '-di', str(resources_dir/ '2files'), + '-o', str(outdir), + ], + [ + f'Image {image_resources[0]} was enhanced', + f'Image {image_resources[1]} was enhanced', + ] + ) + assert len(list(outdir.iterdir())) == 2 diff --git a/tests/cli_tests/test_layout.py b/tests/cli_tests/test_layout.py new file mode 100644 index 0000000..a34514e --- /dev/null +++ b/tests/cli_tests/test_layout.py @@ -0,0 +1,128 @@ +import pytest +from ocrd_modelfactory import page_from_file +from ocrd_models.constants import NAMESPACES as NS + +@pytest.mark.parametrize( + "options", + [ + [], # defaults + #["--allow_scaling", "--curved-line"], + ["--allow_scaling", "--curved-line", "--full-layout"], + ["--allow_scaling", "--curved-line", "--full-layout", "--reading_order_machine_based"], + ["--allow_scaling", "--curved-line", "--full-layout", "--reading_order_machine_based", + "--textline_light", "--light_version"], + # -ep ... + # -eoi ... + # FIXME: find out whether OCR extra was installed, otherwise skip these + ["--do_ocr"], + ["--do_ocr", "--light_version", "--textline_light"], + ["--do_ocr", "--transformer_ocr"], + #["--do_ocr", "--transformer_ocr", "--light_version", "--textline_light"], + ["--do_ocr", "--transformer_ocr", "--light_version", "--textline_light", "--full-layout"], + # --skip_layout_and_reading_order + ], ids=str) +def test_run_eynollah_layout_filename( + tmp_path, + run_eynollah_ok_and_check_logs, + resources_dir, + options, +): + infile = resources_dir / '2files/kant_aufklaerung_1784_0020.tif' + outfile = tmp_path / 'kant_aufklaerung_1784_0020.xml' + run_eynollah_ok_and_check_logs( + 'layout', + [ + '-i', str(infile), + '-o', str(outfile.parent), + ] + options, + [ + str(infile) + ] + ) + assert outfile.exists() + tree = page_from_file(str(outfile)).etree + regions = tree.xpath("//page:TextRegion", namespaces=NS) + assert len(regions) >= 2, "result is inaccurate" + regions = tree.xpath("//page:SeparatorRegion", namespaces=NS) + assert len(regions) >= 2, "result is inaccurate" + lines = tree.xpath("//page:TextLine", namespaces=NS) + assert len(lines) == 31, "result is inaccurate" # 29 paragraph lines, 1 page and 1 catch-word line + +@pytest.mark.parametrize( + "options", + [ + ["--tables"], + ["--tables", "--full-layout"], + ["--tables", "--full-layout", "--textline_light", "--light_version"], + ], ids=str) +def test_run_eynollah_layout_filename2( + tmp_path, + resources_dir, + run_eynollah_ok_and_check_logs, + options, +): + infile = resources_dir / '2files/euler_rechenkunst01_1738_0025.tif' + outfile = tmp_path / 'euler_rechenkunst01_1738_0025.xml' + run_eynollah_ok_and_check_logs( + 'layout', + [ + '-i', str(infile), + '-o', str(outfile.parent), + ] + options, + [ + str(infile) + ] + ) + assert outfile.exists() + tree = page_from_file(str(outfile)).etree + regions = tree.xpath("//page:TextRegion", namespaces=NS) + assert len(regions) >= 2, "result is inaccurate" + regions = tree.xpath("//page:TableRegion", namespaces=NS) + # model/decoding is not very precise, so (depending on mode) we can get fractures/splits/FP + assert len(regions) >= 1, "result is inaccurate" + regions = tree.xpath("//page:SeparatorRegion", namespaces=NS) + assert len(regions) >= 2, "result is inaccurate" + lines = tree.xpath("//page:TextLine", namespaces=NS) + assert len(lines) >= 2, "result is inaccurate" # mostly table (if detected correctly), but 1 page and 1 catch-word line + +def test_run_eynollah_layout_directory( + tmp_path, + resources_dir, + 
run_eynollah_ok_and_check_logs, +): + outdir = tmp_path + run_eynollah_ok_and_check_logs( + 'layout', + [ + '-di', str(resources_dir / '2files'), + '-o', str(outdir), + ], + [ + 'Job done in', + 'All jobs done in', + ] + ) + assert len(list(outdir.iterdir())) == 2 + +# def test_run_eynollah_layout_marginalia( +# tmp_path, +# resources_dir, +# run_eynollah_ok_and_check_logs, +# ): +# outdir = tmp_path +# outfile = outdir / 'estor_rechtsgelehrsamkeit02_1758_0880_800px.xml' +# run_eynollah_ok_and_check_logs( +# 'layout', +# [ +# '-i', str(resources_dir / 'estor_rechtsgelehrsamkeit02_1758_0880_800px.jpg'), +# '-o', str(outdir), +# ], +# [ +# 'Job done in', +# 'All jobs done in', +# ] +# ) +# assert outfile.exists() +# tree = page_from_file(str(outfile)).etree +# regions = tree.xpath('//page:TextRegion[type="marginalia"]', namespaces=NS) +# assert len(regions) == 5, "expected 5 marginalia regions" diff --git a/tests/cli_tests/test_mbreorder.py b/tests/cli_tests/test_mbreorder.py new file mode 100644 index 0000000..e429e98 --- /dev/null +++ b/tests/cli_tests/test_mbreorder.py @@ -0,0 +1,47 @@ +from ocrd_modelfactory import page_from_file +from ocrd_models.constants import NAMESPACES as NS + +def test_run_eynollah_mbreorder_filename( + tmp_path, + resources_dir, + run_eynollah_ok_and_check_logs, +): + infile = resources_dir / '2files/kant_aufklaerung_1784_0020.xml' + outfile = tmp_path /'kant_aufklaerung_1784_0020.xml' + run_eynollah_ok_and_check_logs( + 'machine-based-reading-order', + [ + '-i', str(infile), + '-o', str(outfile.parent), + ], + [ + # FIXME: mbreorder has no logging! + ] + ) + assert outfile.exists() + #in_tree = page_from_file(str(infile)).etree + #in_order = in_tree.xpath("//page:OrderedGroup//@regionRef", namespaces=NS) + out_tree = page_from_file(str(outfile)).etree + out_order = out_tree.xpath("//page:OrderedGroup//@regionRef", namespaces=NS) + #assert len(out_order) >= 2, "result is inaccurate" + #assert in_order != out_order + assert out_order == ['r_1_1', 'r_2_1', 'r_2_2', 'r_2_3'] + +def test_run_eynollah_mbreorder_directory( + tmp_path, + resources_dir, + run_eynollah_ok_and_check_logs, +): + outdir = tmp_path + run_eynollah_ok_and_check_logs( + 'machine-based-reading-order', + [ + '-di', str(resources_dir / '2files'), + '-o', str(outdir), + ], + [ + # FIXME: mbreorder has no logging! + ] + ) + assert len(list(outdir.iterdir())) == 2 + diff --git a/tests/cli_tests/test_ocr.py b/tests/cli_tests/test_ocr.py new file mode 100644 index 0000000..6bf3080 --- /dev/null +++ b/tests/cli_tests/test_ocr.py @@ -0,0 +1,64 @@ +import pytest +from ocrd_modelfactory import page_from_file +from ocrd_models.constants import NAMESPACES as NS + +@pytest.mark.parametrize( + "options", + [ + ["-trocr"], + [], # defaults + ["-doit", #str(outrenderfile.parent)], + ], + ], ids=str) +def test_run_eynollah_ocr_filename( + tmp_path, + run_eynollah_ok_and_check_logs, + resources_dir, + options, +): + infile = resources_dir / '2files/kant_aufklaerung_1784_0020.tif' + outfile = tmp_path / 'kant_aufklaerung_1784_0020.xml' + outrenderfile = tmp_path / 'render' / 'kant_aufklaerung_1784_0020.png' + outrenderfile.parent.mkdir() + if "-doit" in options: + options.insert(options.index("-doit") + 1, str(outrenderfile.parent)) + run_eynollah_ok_and_check_logs( + 'ocr', + [ + '-i', str(infile), + '-dx', str(infile.parent), + '-o', str(outfile.parent), + ] + options, + [ + # FIXME: ocr has no logging! 
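+            # (nothing to assert on logs yet, so the checks below inspect the PAGE-XML output instead)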
+        ]
+    )
+    assert outfile.exists()
+    if "-doit" in options:
+        assert outrenderfile.exists()
+    #in_tree = page_from_file(str(infile)).etree
+    #in_order = in_tree.xpath("//page:OrderedGroup//@regionRef", namespaces=NS)
+    out_tree = page_from_file(str(outfile)).etree
+    out_texts = out_tree.xpath("//page:TextLine/page:TextEquiv[last()]/page:Unicode/text()", namespaces=NS)
+    assert len(out_texts) >= 2, ("result is inaccurate", out_texts)
+    assert sum(map(len, out_texts)) > 100, ("result is inaccurate", out_texts)
+
+def test_run_eynollah_ocr_directory(
+    tmp_path,
+    run_eynollah_ok_and_check_logs,
+    resources_dir,
+):
+    outdir = tmp_path
+    run_eynollah_ok_and_check_logs(
+        'ocr',
+        [
+            '-di', str(resources_dir / '2files'),
+            '-dx', str(resources_dir / '2files'),
+            '-o', str(outdir),
+        ],
+        [
+            # FIXME: ocr has no logging!
+        ]
+    )
+    assert len(list(outdir.iterdir())) == 2
+
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..69f3d28
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,37 @@
+from glob import glob
+import os
+import pytest
+from pathlib import Path
+
+
+@pytest.fixture()
+def tests_dir():
+    return Path(__file__).parent.resolve()
+
+@pytest.fixture()
+def model_dir(tests_dir):
+    return os.environ.get('EYNOLLAH_MODELS_DIR', str(tests_dir.joinpath('..').resolve()))
+
+@pytest.fixture()
+def resources_dir(tests_dir):
+    return tests_dir / 'resources'
+
+@pytest.fixture()
+def image_resources(resources_dir):
+    return [Path(x) for x in glob(str(resources_dir / '2files/*.tif'))]
+
+@pytest.fixture()
+def eynollah_log_filter():
+    return lambda logrec: logrec.name.startswith('eynollah')
+
+@pytest.fixture
+def eynollah_subcommands():
+    return [
+        'binarization',
+        'layout',
+        'ocr',
+        'enhancement',
+        'machine-based-reading-order',
+        'models',
+    ]
+
diff --git a/tests/resources/euler_rechenkunst01_1738_0025.tif b/tests/resources/2files/euler_rechenkunst01_1738_0025.tif
similarity index 100%
rename from tests/resources/euler_rechenkunst01_1738_0025.tif
rename to tests/resources/2files/euler_rechenkunst01_1738_0025.tif
diff --git a/tests/resources/euler_rechenkunst01_1738_0025.xml b/tests/resources/2files/euler_rechenkunst01_1738_0025.xml
similarity index 100%
rename from tests/resources/euler_rechenkunst01_1738_0025.xml
rename to tests/resources/2files/euler_rechenkunst01_1738_0025.xml
diff --git a/tests/resources/kant_aufklaerung_1784_0020.tif b/tests/resources/2files/kant_aufklaerung_1784_0020.tif
similarity index 100%
rename from tests/resources/kant_aufklaerung_1784_0020.tif
rename to tests/resources/2files/kant_aufklaerung_1784_0020.tif
diff --git a/tests/resources/kant_aufklaerung_1784_0020.xml b/tests/resources/2files/kant_aufklaerung_1784_0020.xml
similarity index 100%
rename from tests/resources/kant_aufklaerung_1784_0020.xml
rename to tests/resources/2files/kant_aufklaerung_1784_0020.xml
diff --git a/tests/resources/marginalia/estor_rechtsgelehrsamkeit02_1758_0880_800px.jpg b/tests/resources/marginalia/estor_rechtsgelehrsamkeit02_1758_0880_800px.jpg
new file mode 100644
index 0000000..9270508
Binary files /dev/null and b/tests/resources/marginalia/estor_rechtsgelehrsamkeit02_1758_0880_800px.jpg differ
diff --git a/tests/resources/marginalia/estor_rechtsgelehrsamkeit02_1758_0880_800px.xml b/tests/resources/marginalia/estor_rechtsgelehrsamkeit02_1758_0880_800px.xml
new file mode 100644
index 0000000..45240c4
--- /dev/null
+++ b/tests/resources/marginalia/estor_rechtsgelehrsamkeit02_1758_0880_800px.xml
@@ -0,0 +1,235 @@
+[235 lines of PAGE-XML ground truth for the marginalia fixture: Metadata with Creator SBB_QURATOR and Created/LastChange 2025-10-30T16:38:21.180191, followed by page and region elements; the XML markup itself was stripped during text extraction and cannot be reconstructed here]
diff --git a/tests/test_model_zoo.py b/tests/test_model_zoo.py
new file mode 100644
index 0000000..2042b28
--- /dev/null
+++ b/tests/test_model_zoo.py
@@ -0,0 +1,16 @@
+import pytest
+from eynollah.model_zoo import EynollahModelZoo
+
+def test_trocr1(
+    model_dir,
+):
+    # skip (rather than silently pass) when the OCR extra is not installed
+    pytest.importorskip("transformers")
+    from transformers import TrOCRProcessor, VisionEncoderDecoderModel
+    model_zoo = EynollahModelZoo(model_dir)
+    model_zoo.load_model('trocr_processor')
+    proc = model_zoo.get('trocr_processor', TrOCRProcessor)
+    assert isinstance(proc, TrOCRProcessor)
+    model_zoo.load_model('ocr', 'tr')
+    model = model_zoo.get('ocr', VisionEncoderDecoderModel)
+    assert isinstance(model, VisionEncoderDecoderModel)
diff --git a/tests/test_run.py b/tests/test_run.py
deleted file mode 100644
index 79c64c2..0000000
--- a/tests/test_run.py
+++ /dev/null
@@ -1,351 +0,0 @@
-from os import environ
-from pathlib import Path
-import pytest
-import logging
-from PIL import Image
-from eynollah.cli import (
-    layout as layout_cli,
-    binarization as binarization_cli,
-    enhancement as enhancement_cli,
-    machine_based_reading_order as mbreorder_cli,
-    ocr as ocr_cli,
-)
-from click.testing import CliRunner
-from ocrd_modelfactory import page_from_file
-from ocrd_models.constants import NAMESPACES as NS
-
-testdir = Path(__file__).parent.resolve()
-
-MODELS_LAYOUT = environ.get('MODELS_LAYOUT', str(testdir.joinpath('..', 'models_layout_v0_5_0').resolve()))
-MODELS_OCR = environ.get('MODELS_OCR', str(testdir.joinpath('..', 'models_ocr_v0_5_1').resolve()))
-MODELS_BIN = environ.get('MODELS_BIN', str(testdir.joinpath('..', 'default-2021-03-09').resolve()))
-
-@pytest.mark.parametrize(
-    "options",
-    [
-        [], # defaults
-        #["--allow_scaling", "--curved-line"],
-        ["--allow_scaling", "--curved-line", "--full-layout"],
-        ["--allow_scaling", "--curved-line", "--full-layout", "--reading_order_machine_based"],
-        ["--allow_scaling", "--curved-line", "--full-layout", "--reading_order_machine_based",
-         "--textline_light", "--light_version"],
-        # -ep ...
-        # -eoi ...
- # FIXME: find out whether OCR extra was installed, otherwise skip these - ["--do_ocr"], - ["--do_ocr", "--light_version", "--textline_light"], - ["--do_ocr", "--transformer_ocr"], - #["--do_ocr", "--transformer_ocr", "--light_version", "--textline_light"], - ["--do_ocr", "--transformer_ocr", "--light_version", "--textline_light", "--full-layout"], - # --skip_layout_and_reading_order - ], ids=str) -def test_run_eynollah_layout_filename(tmp_path, pytestconfig, caplog, options): - infile = testdir.joinpath('resources/kant_aufklaerung_1784_0020.tif') - outfile = tmp_path / 'kant_aufklaerung_1784_0020.xml' - args = [ - '-m', MODELS_LAYOUT, - '-i', str(infile), - '-o', str(outfile.parent), - ] - if pytestconfig.getoption('verbose') > 0: - args.extend(['-l', 'DEBUG']) - caplog.set_level(logging.INFO) - def only_eynollah(logrec): - return logrec.name == 'eynollah' - runner = CliRunner() - with caplog.filtering(only_eynollah): - result = runner.invoke(layout_cli, args + options, catch_exceptions=False) - assert result.exit_code == 0, result.stdout - logmsgs = [logrec.message for logrec in caplog.records] - assert str(infile) in logmsgs - assert outfile.exists() - tree = page_from_file(str(outfile)).etree - regions = tree.xpath("//page:TextRegion", namespaces=NS) - assert len(regions) >= 2, "result is inaccurate" - regions = tree.xpath("//page:SeparatorRegion", namespaces=NS) - assert len(regions) >= 2, "result is inaccurate" - lines = tree.xpath("//page:TextLine", namespaces=NS) - assert len(lines) == 31, "result is inaccurate" # 29 paragraph lines, 1 page and 1 catch-word line - -@pytest.mark.parametrize( - "options", - [ - ["--tables"], - ["--tables", "--full-layout"], - ["--tables", "--full-layout", "--textline_light", "--light_version"], - ], ids=str) -def test_run_eynollah_layout_filename2(tmp_path, pytestconfig, caplog, options): - infile = testdir.joinpath('resources/euler_rechenkunst01_1738_0025.tif') - outfile = tmp_path / 'euler_rechenkunst01_1738_0025.xml' - args = [ - '-m', MODELS_LAYOUT, - '-i', str(infile), - '-o', str(outfile.parent), - ] - if pytestconfig.getoption('verbose') > 0: - args.extend(['-l', 'DEBUG']) - caplog.set_level(logging.INFO) - def only_eynollah(logrec): - return logrec.name == 'eynollah' - runner = CliRunner() - with caplog.filtering(only_eynollah): - result = runner.invoke(layout_cli, args + options, catch_exceptions=False) - assert result.exit_code == 0, result.stdout - logmsgs = [logrec.message for logrec in caplog.records] - assert str(infile) in logmsgs - assert outfile.exists() - tree = page_from_file(str(outfile)).etree - regions = tree.xpath("//page:TextRegion", namespaces=NS) - assert len(regions) >= 2, "result is inaccurate" - regions = tree.xpath("//page:TableRegion", namespaces=NS) - # model/decoding is not very precise, so (depending on mode) we can get fractures/splits/FP - assert len(regions) >= 1, "result is inaccurate" - regions = tree.xpath("//page:SeparatorRegion", namespaces=NS) - assert len(regions) >= 2, "result is inaccurate" - lines = tree.xpath("//page:TextLine", namespaces=NS) - assert len(lines) >= 2, "result is inaccurate" # mostly table (if detected correctly), but 1 page and 1 catch-word line - -def test_run_eynollah_layout_directory(tmp_path, pytestconfig, caplog): - indir = testdir.joinpath('resources') - outdir = tmp_path - args = [ - '-m', MODELS_LAYOUT, - '-di', str(indir), - '-o', str(outdir), - ] - if pytestconfig.getoption('verbose') > 0: - args.extend(['-l', 'DEBUG']) - caplog.set_level(logging.INFO) - def 
only_eynollah(logrec): - return logrec.name == 'eynollah' - runner = CliRunner() - with caplog.filtering(only_eynollah): - result = runner.invoke(layout_cli, args, catch_exceptions=False) - assert result.exit_code == 0, result.stdout - logmsgs = [logrec.message for logrec in caplog.records] - assert len([logmsg for logmsg in logmsgs if logmsg.startswith('Job done in')]) == 2 - assert any(logmsg for logmsg in logmsgs if logmsg.startswith('All jobs done in')) - assert len(list(outdir.iterdir())) == 2 - -@pytest.mark.parametrize( - "options", - [ - [], # defaults - ["--no-patches"], - ], ids=str) -def test_run_eynollah_binarization_filename(tmp_path, pytestconfig, caplog, options): - infile = testdir.joinpath('resources/kant_aufklaerung_1784_0020.tif') - outfile = tmp_path.joinpath('kant_aufklaerung_1784_0020.png') - args = [ - '-m', MODELS_BIN, - '-i', str(infile), - '-o', str(outfile), - ] - if pytestconfig.getoption('verbose') > 0: - args.extend(['-l', 'DEBUG']) - caplog.set_level(logging.INFO) - def only_eynollah(logrec): - return logrec.name == 'SbbBinarizer' - runner = CliRunner() - with caplog.filtering(only_eynollah): - result = runner.invoke(binarization_cli, args + options, catch_exceptions=False) - assert result.exit_code == 0, result.stdout - logmsgs = [logrec.message for logrec in caplog.records] - assert any(True for logmsg in logmsgs if logmsg.startswith('Predicting')) - assert outfile.exists() - with Image.open(infile) as original_img: - original_size = original_img.size - with Image.open(outfile) as binarized_img: - binarized_size = binarized_img.size - assert original_size == binarized_size - -def test_run_eynollah_binarization_directory(tmp_path, pytestconfig, caplog): - indir = testdir.joinpath('resources') - outdir = tmp_path - args = [ - '-m', MODELS_BIN, - '-di', str(indir), - '-o', str(outdir), - ] - if pytestconfig.getoption('verbose') > 0: - args.extend(['-l', 'DEBUG']) - caplog.set_level(logging.INFO) - def only_eynollah(logrec): - return logrec.name == 'SbbBinarizer' - runner = CliRunner() - with caplog.filtering(only_eynollah): - result = runner.invoke(binarization_cli, args, catch_exceptions=False) - assert result.exit_code == 0, result.stdout - logmsgs = [logrec.message for logrec in caplog.records] - assert len([logmsg for logmsg in logmsgs if logmsg.startswith('Predicting')]) == 2 - assert len(list(outdir.iterdir())) == 2 - -@pytest.mark.parametrize( - "options", - [ - [], # defaults - ["-sos"], - ], ids=str) -def test_run_eynollah_enhancement_filename(tmp_path, pytestconfig, caplog, options): - infile = testdir.joinpath('resources/kant_aufklaerung_1784_0020.tif') - outfile = tmp_path.joinpath('kant_aufklaerung_1784_0020.png') - args = [ - '-m', MODELS_LAYOUT, - '-i', str(infile), - '-o', str(outfile.parent), - ] - if pytestconfig.getoption('verbose') > 0: - args.extend(['-l', 'DEBUG']) - caplog.set_level(logging.INFO) - def only_eynollah(logrec): - return logrec.name == 'enhancement' - runner = CliRunner() - with caplog.filtering(only_eynollah): - result = runner.invoke(enhancement_cli, args + options, catch_exceptions=False) - assert result.exit_code == 0, result.stdout - logmsgs = [logrec.message for logrec in caplog.records] - assert any(True for logmsg in logmsgs if logmsg.startswith('Image was enhanced')), logmsgs - assert outfile.exists() - with Image.open(infile) as original_img: - original_size = original_img.size - with Image.open(outfile) as enhanced_img: - enhanced_size = enhanced_img.size - assert (original_size == enhanced_size) == ("-sos" in 
options) - -def test_run_eynollah_enhancement_directory(tmp_path, pytestconfig, caplog): - indir = testdir.joinpath('resources') - outdir = tmp_path - args = [ - '-m', MODELS_LAYOUT, - '-di', str(indir), - '-o', str(outdir), - ] - if pytestconfig.getoption('verbose') > 0: - args.extend(['-l', 'DEBUG']) - caplog.set_level(logging.INFO) - def only_eynollah(logrec): - return logrec.name == 'enhancement' - runner = CliRunner() - with caplog.filtering(only_eynollah): - result = runner.invoke(enhancement_cli, args, catch_exceptions=False) - assert result.exit_code == 0, result.stdout - logmsgs = [logrec.message for logrec in caplog.records] - assert len([logmsg for logmsg in logmsgs if logmsg.startswith('Image was enhanced')]) == 2 - assert len(list(outdir.iterdir())) == 2 - -def test_run_eynollah_mbreorder_filename(tmp_path, pytestconfig, caplog): - infile = testdir.joinpath('resources/kant_aufklaerung_1784_0020.xml') - outfile = tmp_path.joinpath('kant_aufklaerung_1784_0020.xml') - args = [ - '-m', MODELS_LAYOUT, - '-i', str(infile), - '-o', str(outfile.parent), - ] - if pytestconfig.getoption('verbose') > 0: - args.extend(['-l', 'DEBUG']) - caplog.set_level(logging.INFO) - def only_eynollah(logrec): - return logrec.name == 'mbreorder' - runner = CliRunner() - with caplog.filtering(only_eynollah): - result = runner.invoke(mbreorder_cli, args, catch_exceptions=False) - assert result.exit_code == 0, result.stdout - logmsgs = [logrec.message for logrec in caplog.records] - # FIXME: mbreorder has no logging! - #assert any(True for logmsg in logmsgs if logmsg.startswith('???')), logmsgs - assert outfile.exists() - #in_tree = page_from_file(str(infile)).etree - #in_order = in_tree.xpath("//page:OrderedGroup//@regionRef", namespaces=NS) - out_tree = page_from_file(str(outfile)).etree - out_order = out_tree.xpath("//page:OrderedGroup//@regionRef", namespaces=NS) - #assert len(out_order) >= 2, "result is inaccurate" - #assert in_order != out_order - assert out_order == ['r_1_1', 'r_2_1', 'r_2_2', 'r_2_3'] - -def test_run_eynollah_mbreorder_directory(tmp_path, pytestconfig, caplog): - indir = testdir.joinpath('resources') - outdir = tmp_path - args = [ - '-m', MODELS_LAYOUT, - '-di', str(indir), - '-o', str(outdir), - ] - if pytestconfig.getoption('verbose') > 0: - args.extend(['-l', 'DEBUG']) - caplog.set_level(logging.INFO) - def only_eynollah(logrec): - return logrec.name == 'mbreorder' - runner = CliRunner() - with caplog.filtering(only_eynollah): - result = runner.invoke(mbreorder_cli, args, catch_exceptions=False) - assert result.exit_code == 0, result.stdout - logmsgs = [logrec.message for logrec in caplog.records] - # FIXME: mbreorder has no logging! 
- #assert len([logmsg for logmsg in logmsgs if logmsg.startswith('???')]) == 2 - assert len(list(outdir.iterdir())) == 2 - -@pytest.mark.parametrize( - "options", - [ - [], # defaults - ["-doit", #str(outrenderfile.parent)], - ], - ["-trocr"], - ], ids=str) -def test_run_eynollah_ocr_filename(tmp_path, pytestconfig, caplog, options): - infile = testdir.joinpath('resources/kant_aufklaerung_1784_0020.tif') - outfile = tmp_path.joinpath('kant_aufklaerung_1784_0020.xml') - outrenderfile = tmp_path.joinpath('render').joinpath('kant_aufklaerung_1784_0020.png') - outrenderfile.parent.mkdir() - args = [ - '-m', MODELS_OCR, - '-i', str(infile), - '-dx', str(infile.parent), - '-o', str(outfile.parent), - ] - if pytestconfig.getoption('verbose') > 0: - args.extend(['-l', 'DEBUG']) - caplog.set_level(logging.DEBUG) - def only_eynollah(logrec): - return logrec.name == 'eynollah' - runner = CliRunner() - if "-doit" in options: - options.insert(options.index("-doit") + 1, str(outrenderfile.parent)) - with caplog.filtering(only_eynollah): - result = runner.invoke(ocr_cli, args + options, catch_exceptions=False) - assert result.exit_code == 0, result.stdout - logmsgs = [logrec.message for logrec in caplog.records] - # FIXME: ocr has no logging! - #assert any(True for logmsg in logmsgs if logmsg.startswith('???')), logmsgs - assert outfile.exists() - if "-doit" in options: - assert outrenderfile.exists() - #in_tree = page_from_file(str(infile)).etree - #in_order = in_tree.xpath("//page:OrderedGroup//@regionRef", namespaces=NS) - out_tree = page_from_file(str(outfile)).etree - out_texts = out_tree.xpath("//page:TextLine/page:TextEquiv[last()]/page:Unicode/text()", namespaces=NS) - assert len(out_texts) >= 2, ("result is inaccurate", out_texts) - assert sum(map(len, out_texts)) > 100, ("result is inaccurate", out_texts) - -def test_run_eynollah_ocr_directory(tmp_path, pytestconfig, caplog): - indir = testdir.joinpath('resources') - outdir = tmp_path - args = [ - '-m', MODELS_OCR, - '-di', str(indir), - '-dx', str(indir), - '-o', str(outdir), - ] - if pytestconfig.getoption('verbose') > 0: - args.extend(['-l', 'DEBUG']) - caplog.set_level(logging.INFO) - def only_eynollah(logrec): - return logrec.name == 'eynollah' - runner = CliRunner() - with caplog.filtering(only_eynollah): - result = runner.invoke(ocr_cli, args, catch_exceptions=False) - assert result.exit_code == 0, result.stdout - logmsgs = [logrec.message for logrec in caplog.records] - # FIXME: ocr has no logging! - #assert any(True for logmsg in logmsgs if logmsg.startswith('???')), logmsgs - assert len(list(outdir.iterdir())) == 2 diff --git a/train/README.md b/train/README.md deleted file mode 100644 index 5f6d326..0000000 --- a/train/README.md +++ /dev/null @@ -1,59 +0,0 @@ -# Training eynollah - -This README explains the technical details of how to set up and run training, for detailed information on parameterization, see [`docs/train.md`](../docs/train.md) - -## Introduction - -This folder contains the source code for training an encoder model for document image segmentation. 
- -## Installation - -Clone the repository and install eynollah along with the dependencies necessary for training: - -```sh -git clone https://github.com/qurator-spk/eynollah -cd eynollah -pip install '.[training]' -``` - -### Pretrained encoder - -Download our pretrained weights and add them to a `train/pretrained_model` folder: - -```sh -cd train -wget -O pretrained_model.tar.gz https://zenodo.org/records/17243320/files/pretrained_model_v0_5_1.tar.gz?download=1 -tar xf pretrained_model.tar.gz -``` - -### Binarization training data - -A small sample of training data for binarization experiment can be found [on -zenodo](https://zenodo.org/records/17243320/files/training_data_sample_binarization_v0_5_1.tar.gz?download=1), -which contains `images` and `labels` folders. - -### Helpful tools - -* [`pagexml2img`](https://github.com/qurator-spk/page2img) -> Tool to extract 2-D or 3-D RGB images from PAGE-XML data. In the former case, the output will be 1 2-D image array which each class has filled with a pixel value. In the case of a 3-D RGB image, -each class will be defined with a RGB value and beside images, a text file of classes will also be produced. -* [`cocoSegmentationToPng`](https://github.com/nightrome/cocostuffapi/blob/17acf33aef3c6cc2d6aca46dcf084266c2778cf0/PythonAPI/pycocotools/cocostuffhelper.py#L130) -> Convert COCO GT or results for a single image to a segmentation map and write it to disk. -* [`ocrd-segment-extract-pages`](https://github.com/OCR-D/ocrd_segment/blob/master/ocrd_segment/extract_pages.py) -> Extract region classes and their colours in mask (pseg) images. Allows the color map as free dict parameter, and comes with a default that mimics PageViewer's coloring for quick debugging; it also warns when regions do overlap. - -### Train using Docker - -Build the Docker image: - -```bash -cd train -docker build -t model-training . -``` - -Run Docker image - -```bash -cd train -docker run --gpus all -v $PWD:/entry_point_dir model-training -```