diff --git a/.github/workflows/test-eynollah.yml b/.github/workflows/test-eynollah.yml index 466e690..82de94d 100644 --- a/.github/workflows/test-eynollah.yml +++ b/.github/workflows/test-eynollah.yml @@ -24,61 +24,59 @@ jobs: sudo rm -rf "$AGENT_TOOLSDIRECTORY" df -h - uses: actions/checkout@v4 - - uses: actions/cache/restore@v4 - id: seg_model_cache + + # - name: Lint with ruff + # uses: astral-sh/ruff-action@v3 + # with: + # src: "./src" + + - name: Try to restore models_eynollah + uses: actions/cache/restore@v4 + id: all_model_cache with: - path: models_layout_v0_5_0 - key: seg-models - - uses: actions/cache/restore@v4 - id: ocr_model_cache - with: - path: models_ocr_v0_5_1 - key: ocr-models - - uses: actions/cache/restore@v4 - id: bin_model_cache - with: - path: default-2021-03-09 - key: bin-models + path: models_eynollah + key: models_eynollah-${{ hashFiles('src/eynollah/model_zoo/default_specs.py') }} + - name: Download models - if: steps.seg_model_cache.outputs.cache-hit != 'true' || steps.bin_model_cache.outputs.cache-hit != 'true' || steps.ocr_model_cache.outputs.cache-hit != true - run: make models + if: steps.all_model_cache.outputs.cache-hit != 'true' + run: | + make models + ls -la models_eynollah + - uses: actions/cache/save@v4 - if: steps.seg_model_cache.outputs.cache-hit != 'true' + if: steps.all_model_cache.outputs.cache-hit != 'true' with: - path: models_layout_v0_5_0 - key: seg-models - - uses: actions/cache/save@v4 - if: steps.ocr_model_cache.outputs.cache-hit != 'true' - with: - path: models_ocr_v0_5_1 - key: ocr-models - - uses: actions/cache/save@v4 - if: steps.bin_model_cache.outputs.cache-hit != 'true' - with: - path: default-2021-03-09 - key: bin-models + path: models_eynollah + key: models_eynollah-${{ hashFiles('src/eynollah/model_zoo/default_specs.py') }} + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} + + # - uses: actions/cache@v4 + # with: + # path: | + # path/to/dependencies + # some/other/dependencies + # key: ${{ runner.os }}-${{ hashFiles('**/lockfiles') }} + - name: Install dependencies run: | python -m pip install --upgrade pip make install-dev EXTRAS=OCR,plotting make deps-test EXTRAS=OCR,plotting - ls -l models_* - - name: Lint with ruff - uses: astral-sh/ruff-action@v3 - with: - src: "./src" + - name: Test with pytest run: make coverage PYTEST_ARGS="-vv --junitxml=pytest.xml" + - name: Get coverage results run: | coverage report --format=markdown >> $GITHUB_STEP_SUMMARY coverage html coverage json coverage xml + - name: Store coverage results uses: actions/upload-artifact@v4 with: @@ -88,12 +86,15 @@ jobs: pytest.xml coverage.xml coverage.json + - name: Upload coverage results uses: codecov/codecov-action@v4 with: files: coverage.xml fail_ci_if_error: false + - name: Test standalone CLI run: make smoke-test + - name: Test OCR-D CLI run: make ocrd-test diff --git a/.gitignore b/.gitignore index fd64f0b..49835a7 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ output.html *.tif *.sw? TAGS +uv.lock diff --git a/CHANGELOG.md b/CHANGELOG.md index c2caaa6..2848d21 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,17 @@ Versioned according to [Semantic Versioning](http://semver.org/). ## Unreleased +Added: + + * "Model zoo", central place to describe and load models, #207 + * Training code for the CNN/RNN OCR model + +Changed: + + * Lint training code, #204 + * Update documentation: README, pyproject.toml metadata, guides in `docs/`, #209 + + ## [0.6.0] - 2025-10-17 Added: @@ -307,6 +318,7 @@ Fixed: Initial release +[0.7.0]: ../../compare/v0.7.0...v0.6.0 [0.6.0]: ../../compare/v0.6.0...v0.6.0rc2 [0.6.0rc2]: ../../compare/v0.6.0rc2...v0.6.0rc1 [0.6.0rc1]: ../../compare/v0.6.0rc1...v0.5.0 diff --git a/Makefile b/Makefile index 29dd877..c1458df 100644 --- a/Makefile +++ b/Makefile @@ -6,23 +6,17 @@ EXTRAS ?= DOCKER_BASE_IMAGE ?= docker.io/ocrd/core-cuda-tf2:latest DOCKER_TAG ?= ocrd/eynollah DOCKER ?= docker +WGET = wget -O #SEG_MODEL := https://qurator-data.de/eynollah/2021-04-25/models_eynollah.tar.gz #SEG_MODEL := https://qurator-data.de/eynollah/2022-04-05/models_eynollah_renamed.tar.gz # SEG_MODEL := https://qurator-data.de/eynollah/2022-04-05/models_eynollah.tar.gz #SEG_MODEL := https://github.com/qurator-spk/eynollah/releases/download/v0.3.0/models_eynollah.tar.gz #SEG_MODEL := https://github.com/qurator-spk/eynollah/releases/download/v0.3.1/models_eynollah.tar.gz -SEG_MODEL := https://zenodo.org/records/17194824/files/models_layout_v0_5_0.tar.gz?download=1 -SEG_MODELFILE = $(notdir $(patsubst %?download=1,%,$(SEG_MODEL))) -SEG_MODELNAME = $(SEG_MODELFILE:%.tar.gz=%) - -BIN_MODEL := https://github.com/qurator-spk/sbb_binarization/releases/download/v0.0.11/saved_model_2021_03_09.zip -BIN_MODELFILE = $(notdir $(BIN_MODEL)) -BIN_MODELNAME := default-2021-03-09 - -OCR_MODEL := https://zenodo.org/records/17236998/files/models_ocr_v0_5_1.tar.gz?download=1 -OCR_MODELFILE = $(notdir $(patsubst %?download=1,%,$(OCR_MODEL))) -OCR_MODELNAME = $(OCR_MODELFILE:%.tar.gz=%) +#SEG_MODEL := https://zenodo.org/records/17194824/files/models_layout_v0_5_0.tar.gz?download=1 +EYNOLLAH_MODELS_URL := https://zenodo.org/records/17580627/files/models_all_v0_7_0.zip +EYNOLLAH_MODELS_ZIP = $(notdir $(EYNOLLAH_MODELS_URL)) +EYNOLLAH_MODELS_DIR = $(EYNOLLAH_MODELS_ZIP:%.zip=%) PYTEST_ARGS ?= -vv --isolate @@ -38,7 +32,7 @@ help: @echo " install-dev Install editable with pip" @echo " deps-test Install test dependencies with pip" @echo " models Download and extract models to $(CURDIR):" - @echo " $(BIN_MODELNAME) $(SEG_MODELNAME) $(OCR_MODELNAME)" + @echo " $(EYNOLLAH_MODELS_DIR)" @echo " smoke-test Run simple CLI check" @echo " ocrd-test Run OCR-D CLI check" @echo " test Run unit tests" @@ -47,34 +41,22 @@ help: @echo " EXTRAS comma-separated list of features (like 'OCR,plotting') for 'install' [$(EXTRAS)]" @echo " DOCKER_TAG Docker image tag for 'docker' [$(DOCKER_TAG)]" @echo " PYTEST_ARGS pytest args for 'test' (Set to '-s' to see log output during test execution, '-vv' to see individual tests. [$(PYTEST_ARGS)]" - @echo " SEG_MODEL URL of 'models' archive to download for segmentation 'test' [$(SEG_MODEL)]" - @echo " BIN_MODEL URL of 'models' archive to download for binarization 'test' [$(BIN_MODEL)]" - @echo " OCR_MODEL URL of 'models' archive to download for binarization 'test' [$(OCR_MODEL)]" + @echo " ALL_MODELS URL of archive of all models [$(ALL_MODELS)]" @echo "" # END-EVAL - -# Download and extract models to $(PWD)/models_layout_v0_5_0 -models: $(BIN_MODELNAME) $(SEG_MODELNAME) $(OCR_MODELNAME) +# Download and extract models to $(PWD)/models_layout_v0_6_0 +models: $(EYNOLLAH_MODELS_DIR) # do not download these files if we already have the directories -.INTERMEDIATE: $(BIN_MODELFILE) $(SEG_MODELFILE) $(OCR_MODELFILE) +.INTERMEDIATE: $(EYNOLLAH_MODELS_ZIP) -$(BIN_MODELFILE): - wget -O $@ $(BIN_MODEL) -$(SEG_MODELFILE): - wget -O $@ $(SEG_MODEL) -$(OCR_MODELFILE): - wget -O $@ $(OCR_MODEL) +$(EYNOLLAH_MODELS_ZIP): + $(WGET) $@ $(EYNOLLAH_MODELS_URL) -$(BIN_MODELNAME): $(BIN_MODELFILE) - mkdir $@ - unzip -d $@ $< -$(SEG_MODELNAME): $(SEG_MODELFILE) - tar zxf $< -$(OCR_MODELNAME): $(OCR_MODELFILE) - tar zxf $< +$(EYNOLLAH_MODELS_DIR): $(EYNOLLAH_MODELS_ZIP) + unzip $< build: $(PIP) install build @@ -88,56 +70,48 @@ install: install-dev: $(PIP) install -e .$(and $(EXTRAS),[$(EXTRAS)]) -ifeq (OCR,$(findstring OCR, $(EXTRAS))) -deps-test: $(OCR_MODELNAME) -endif -deps-test: $(BIN_MODELNAME) $(SEG_MODELNAME) +deps-test: $(PIP) install -r requirements-test.txt -ifeq (OCR,$(findstring OCR, $(EXTRAS))) - ln -rs $(OCR_MODELNAME)/* $(SEG_MODELNAME)/ -endif smoke-test: TMPDIR != mktemp -d -smoke-test: tests/resources/kant_aufklaerung_1784_0020.tif +smoke-test: tests/resources/2files/kant_aufklaerung_1784_0020.tif # layout analysis: - eynollah layout -i $< -o $(TMPDIR) -m $(CURDIR)/$(SEG_MODELNAME) + eynollah -m $(CURDIR) layout -i $< -o $(TMPDIR) fgrep -q http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15 $(TMPDIR)/$(basename $( Document Layout Analysis, Binarization and OCR with Deep Learning and Heuristics +[![Python Versions](https://img.shields.io/pypi/pyversions/eynollah.svg)](https://pypi.python.org/pypi/eynollah) [![PyPI Version](https://img.shields.io/pypi/v/eynollah)](https://pypi.org/project/eynollah/) [![GH Actions Test](https://github.com/qurator-spk/eynollah/actions/workflows/test-eynollah.yml/badge.svg)](https://github.com/qurator-spk/eynollah/actions/workflows/test-eynollah.yml) [![GH Actions Deploy](https://github.com/qurator-spk/eynollah/actions/workflows/build-docker.yml/badge.svg)](https://github.com/qurator-spk/eynollah/actions/workflows/build-docker.yml) @@ -11,24 +12,22 @@ ![](https://user-images.githubusercontent.com/952378/102350683-8a74db80-3fa5-11eb-8c7e-f743f7d6eae2.jpg) ## Features -* Support for 10 distinct segmentation classes: +* Document layout analysis using pixelwise segmentation models with support for 10 segmentation classes: * background, [page border](https://ocr-d.de/en/gt-guidelines/trans/lyRand.html), [text region](https://ocr-d.de/en/gt-guidelines/trans/lytextregion.html#textregionen__textregion_), [text line](https://ocr-d.de/en/gt-guidelines/pagexml/pagecontent_xsd_Complex_Type_pc_TextLineType.html), [header](https://ocr-d.de/en/gt-guidelines/trans/lyUeberschrift.html), [image](https://ocr-d.de/en/gt-guidelines/trans/lyBildbereiche.html), [separator](https://ocr-d.de/en/gt-guidelines/trans/lySeparatoren.html), [marginalia](https://ocr-d.de/en/gt-guidelines/trans/lyMarginalie.html), [initial](https://ocr-d.de/en/gt-guidelines/trans/lyInitiale.html), [table](https://ocr-d.de/en/gt-guidelines/trans/lyTabellen.html) -* Support for various image optimization operations: - * cropping (border detection), binarization, deskewing, dewarping, scaling, enhancing, resizing * Textline segmentation to bounding boxes or polygons (contours) including for curved lines and vertical text -* Text recognition (OCR) using either CNN-RNN or Transformer models -* Detection of reading order (left-to-right or right-to-left) using either heuristics or trainable models +* Document image binarization with pixelwise segmentation or hybrid CNN-Transformer models +* Text recognition (OCR) with CNN-RNN or TrOCR models +* Detection of reading order (left-to-right or right-to-left) using heuristics or trainable models * Output in [PAGE-XML](https://github.com/PRImA-Research-Lab/PAGE-XML) * [OCR-D](https://github.com/qurator-spk/eynollah#use-as-ocr-d-processor) interface :warning: Development is focused on achieving the best quality of results for a wide variety of historical -documents and therefore processing can be very slow. We aim to improve this, but contributions are welcome. +documents using a combination of multiple deep learning models and heuristics; therefore processing can be slow. ## Installation - Python `3.8-3.11` with Tensorflow `<2.13` on Linux are currently supported. - -For (limited) GPU support the CUDA toolkit needs to be installed. A known working config is CUDA `11` with cuDNN `8.6`. +For (limited) GPU support the CUDA toolkit needs to be installed. +A working config is CUDA `11.8` with cuDNN `8.6`. You can either install from PyPI @@ -53,31 +52,41 @@ pip install "eynollah[OCR]" make install EXTRAS=OCR ``` +### Docker + +Use + +``` +docker pull ghcr.io/qurator-spk/eynollah:latest +``` + +When using Eynollah with Docker, see [`docker.md`](https://github.com/qurator-spk/eynollah/tree/main/docs/docker.md). + ## Models -Pretrained models can be downloaded from [zenodo](https://zenodo.org/records/17194824) or [huggingface](https://huggingface.co/SBB?search_models=eynollah). +Pretrained models can be downloaded from [Zenodo](https://zenodo.org/records/17194824) or [Hugging Face](https://huggingface.co/SBB?search_models=eynollah). -For documentation on models, have a look at [`models.md`](https://github.com/qurator-spk/eynollah/tree/main/docs/models.md). -Model cards are also provided for our trained models. +For model documentation and model cards, see [`models.md`](https://github.com/qurator-spk/eynollah/tree/main/docs/models.md). ## Training -In case you want to train your own model with Eynollah, see the -documentation in [`train.md`](https://github.com/qurator-spk/eynollah/tree/main/docs/train.md) and use the -tools in the [`train` folder](https://github.com/qurator-spk/eynollah/tree/main/train). +To train your own model with Eynollah, see [`train.md`](https://github.com/qurator-spk/eynollah/tree/main/docs/train.md) and use the tools in the [`train`](https://github.com/qurator-spk/eynollah/tree/main/train) folder. ## Usage -Eynollah supports five use cases: layout analysis (segmentation), binarization, -image enhancement, text recognition (OCR), and reading order detection. +Eynollah supports five use cases: +1. [layout analysis (segmentation)](#layout-analysis), +2. [binarization](#binarization), +3. [image enhancement](#image-enhancement), +4. [text recognition (OCR)](#ocr), and +5. [reading order detection](#reading-order-detection). + +Some example outputs can be found in [`examples.md`](https://github.com/qurator-spk/eynollah/tree/main/docs/examples.md). ### Layout Analysis The layout analysis module is responsible for detecting layout elements, identifying text lines, and determining reading -order using either heuristic methods or a [pretrained reading order detection model](https://github.com/qurator-spk/eynollah#machine-based-reading-order). - -Reading order detection can be performed either as part of layout analysis based on image input, or, currently under -development, based on pre-existing layout analysis results in PAGE-XML format as input. +order using heuristic methods or a [pretrained model](https://github.com/qurator-spk/eynollah#machine-based-reading-order). The command-line interface for layout analysis can be called like this: @@ -91,29 +100,36 @@ eynollah layout \ The following options can be used to further configure the processing: -| option | description | -|-------------------|:-------------------------------------------------------------------------------| -| `-fl` | full layout analysis including all steps and segmentation classes | -| `-light` | lighter and faster but simpler method for main region detection and deskewing | -| `-tll` | this indicates the light textline and should be passed with light version | -| `-tab` | apply table detection | -| `-ae` | apply enhancement (the resulting image is saved to the output directory) | -| `-as` | apply scaling | -| `-cl` | apply contour detection for curved text lines instead of bounding boxes | -| `-ib` | apply binarization (the resulting image is saved to the output directory) | -| `-ep` | enable plotting (MUST always be used with `-sl`, `-sd`, `-sa`, `-si` or `-ae`) | -| `-eoi` | extract only images to output directory (other processing will not be done) | -| `-ho` | ignore headers for reading order dectection | -| `-si ` | save image regions detected to this directory | -| `-sd ` | save deskewed image to this directory | -| `-sl ` | save layout prediction as plot to this directory | -| `-sp ` | save cropped page image to this directory | -| `-sa ` | save all (plot, enhanced/binary image, layout) to this directory | +| option | description | +|-------------------|:--------------------------------------------------------------------------------------------| +| `-fl` | full layout analysis including all steps and segmentation classes (recommended) | +| `-tab` | apply table detection | +| `-ae` | apply enhancement (the resulting image is saved to the output directory) | +| `-as` | apply scaling | +| `-cl` | apply contour detection for curved text lines instead of bounding boxes | +| `-ib` | apply binarization (the resulting image is saved to the output directory) | +| `-ep` | enable plotting (MUST always be used with `-sl`, `-sd`, `-sa`, `-si` or `-ae`) | +| `-ho` | ignore headers for reading order dectection | +| `-si ` | save image regions detected to this directory | +| `-sd ` | save deskewed image to this directory | +| `-sl ` | save layout prediction as plot to this directory | +| `-sp ` | save cropped page image to this directory | +| `-sa ` | save all (plot, enhanced/binary image, layout) to this directory | +| `-thart` | threshold of artifical class in the case of textline detection. The default value is 0.1 | +| `-tharl` | threshold of artifical class in the case of layout detection. The default value is 0.1 | +| `-ncu` | upper limit of columns in document image | +| `-ncl` | lower limit of columns in document image | +| `-slro` | skip layout detection and reading order | +| `-romb` | apply machine based reading order detection | +| `-ipe` | ignore page extraction | + If no further option is set, the tool performs layout detection of main regions (background, text, images, separators and marginals). The best output quality is achieved when RGB images are used as input rather than greyscale or binarized images. +Additional documentation can be found in [`usage.md`](https://github.com/qurator-spk/eynollah/tree/main/docs/usage.md). + ### Binarization The binarization module performs document image binarization using pretrained pixelwise segmentation models. @@ -124,9 +140,12 @@ The command-line interface for binarization can be called like this: eynollah binarization \ -i | -di \ -o \ - -m \ + -m ``` +### Image Enhancement +TODO + ### OCR The OCR module performs text recognition using either a CNN-RNN model or a Transformer model. @@ -138,12 +157,29 @@ eynollah ocr \ -i | -di \ -dx \ -o \ - -m | --model_name \ + -m | --model_name ``` -### Machine-based-reading-order +The following options can be used to further configure the ocr processing: -The machine-based reading-order module employs a pretrained model to identify the reading order from layouts represented in PAGE-XML files. +| option | description | +|-------------------|:-------------------------------------------------------------------------------------------| +| `-dib` | directory of binarized images (file type must be '.png'), prediction with both RGB and bin | +| `-doit` | directory for output images rendered with the predicted text | +| `--model_name` | file path to use specific model for OCR | +| `-trocr` | use transformer ocr model (otherwise cnn_rnn model is used) | +| `-etit` | export textline images and text in xml to output dir (OCR training data) | +| `-nmtc` | cropped textline images will not be masked with textline contour | +| `-bs` | ocr inference batch size. Default batch size is 2 for trocr and 8 for cnn_rnn models | +| `-ds_pref` | add an abbrevation of dataset name to generated training data | +| `-min_conf` | minimum OCR confidence value. OCR with textline conf lower than this will be ignored | + + +### Reading Order Detection +Reading order detection can be performed either as part of layout analysis based on image input, or, currently under +development, based on pre-existing layout analysis data in PAGE-XML format as input. + +The reading order detection module employs a pretrained model to identify the reading order from layouts represented in PAGE-XML files. The command-line interface for machine based reading order can be called like this: @@ -155,36 +191,9 @@ eynollah machine-based-reading-order \ -o ``` -#### Use as OCR-D processor +## Use as OCR-D processor -Eynollah ships with a CLI interface to be used as [OCR-D](https://ocr-d.de) [processor](https://ocr-d.de/en/spec/cli), -formally described in [`ocrd-tool.json`](https://github.com/qurator-spk/eynollah/tree/main/src/eynollah/ocrd-tool.json). - -In this case, the source image file group with (preferably) RGB images should be used as input like this: - - ocrd-eynollah-segment -I OCR-D-IMG -O OCR-D-SEG -P models eynollah_layout_v0_5_0 - -If the input file group is PAGE-XML (from a previous OCR-D workflow step), Eynollah behaves as follows: -- existing regions are kept and ignored (i.e. in effect they might overlap segments from Eynollah results) -- existing annotation (and respective `AlternativeImage`s) are partially _ignored_: - - previous page frame detection (`cropped` images) - - previous derotation (`deskewed` images) - - previous thresholding (`binarized` images) -- if the page-level image nevertheless deviates from the original (`@imageFilename`) - (because some other preprocessing step was in effect like `denoised`), then - the output PAGE-XML will be based on that as new top-level (`@imageFilename`) - - ocrd-eynollah-segment -I OCR-D-XYZ -O OCR-D-SEG -P models eynollah_layout_v0_5_0 - -In general, it makes more sense to add other workflow steps **after** Eynollah. - -There is also an OCR-D processor for binarization: - - ocrd-sbb-binarize -I OCR-D-IMG -O OCR-D-BIN -P models default-2021-03-09 - -#### Additional documentation - -Additional documentation is available in the [docs](https://github.com/qurator-spk/eynollah/tree/main/docs) directory. +See [`ocrd.md`](https://github.com/qurator-spk/eynollah/tree/main/docs/ocrd.md). ## How to cite diff --git a/docs/docker.md b/docs/docker.md new file mode 100644 index 0000000..7965622 --- /dev/null +++ b/docs/docker.md @@ -0,0 +1,43 @@ +## Inference with Docker + + docker pull ghcr.io/qurator-spk/eynollah:latest + +### 1. ocrd resource manager +(just once, to get the models and install them into a named volume for later re-use) + + vol_models=ocrd-resources:/usr/local/share/ocrd-resources + docker run --rm -v $vol_models ocrd/eynollah ocrd resmgr download ocrd-eynollah-segment default + +Now, each time you want to use Eynollah, pass the same resources volume again. +Also, bind-mount some data directory, e.g. current working directory $PWD (/data is default working directory in the container). + +Either use standalone CLI (2) or OCR-D CLI (3): + +### 2. standalone CLI +(follow self-help, cf. readme) + + docker run --rm -v $vol_models -v $PWD:/data ocrd/eynollah eynollah binarization --help + docker run --rm -v $vol_models -v $PWD:/data ocrd/eynollah eynollah layout --help + docker run --rm -v $vol_models -v $PWD:/data ocrd/eynollah eynollah ocr --help + +### 3. OCR-D CLI +(follow self-help, cf. readme and https://ocr-d.de/en/spec/cli) + + docker run --rm -v $vol_models -v $PWD:/data ocrd/eynollah ocrd-eynollah-segment -h + docker run --rm -v $vol_models -v $PWD:/data ocrd/eynollah ocrd-sbb-binarize -h + +Alternatively, just "log in" to the container once and use the commands there: + + docker run --rm -v $vol_models -v $PWD:/data -it ocrd/eynollah bash + +## Training with Docker + +Build the Docker training image + + cd train + docker build -t model-training . + +Run the Docker training image + + cd train + docker run --gpus all -v $PWD:/entry_point_dir model-training diff --git a/docs/examples.md b/docs/examples.md new file mode 100644 index 0000000..24336b3 --- /dev/null +++ b/docs/examples.md @@ -0,0 +1,18 @@ +# Examples + +Example outputs of various Eynollah models + +# Binarisation + + + + +# Reading Order Detection + +Input Image +Output Image + +# OCR + +Input ImageOutput Image +Input ImageOutput Image diff --git a/docs/models.md b/docs/models.md index 3d296d5..b858630 100644 --- a/docs/models.md +++ b/docs/models.md @@ -18,7 +18,8 @@ Two Arabic/Persian terms form the name of the model suite: عين الله, whic See the flowchart below for the different stages and how they interact: -![](https://user-images.githubusercontent.com/952378/100619946-1936f680-331e-11eb-9297-6e8b4cab3c16.png) +eynollah_flowchart + ## Models @@ -151,15 +152,75 @@ This model is used for the task of illustration detection only. Model card: [Reading Order Detection]() -TODO +The model extracts the reading order of text regions from the layout by classifying pairwise relationships between them. A sorting algorithm then determines the overall reading sequence. + +### OCR + +We have trained three OCR models: two CNN-RNN–based models and one transformer-based TrOCR model. The CNN-RNN models are generally faster and provide better results in most cases, though their performance decreases with heavily degraded images. The TrOCR model, on the other hand, is computationally expensive and slower during inference, but it can possibly produce better results on strongly degraded images. + +#### CNN-RNN model: model_eynollah_ocr_cnnrnn_20250805 + +This model is trained on data where most of the samples are in Fraktur german script. + +| Dataset | Input | CER | WER | +|-----------------------|:-------|:-----------|:----------| +| OCR-D-GT-Archiveform | BIN | 0.02147 | 0.05685 | +| OCR-D-GT-Archiveform | RGB | 0.01636 | 0.06285 | + +#### CNN-RNN model: model_eynollah_ocr_cnnrnn_20250904 (Default) + +Compared to the model_eynollah_ocr_cnnrnn_20250805 model, this model is trained on a larger proportion of Antiqua data and achieves superior performance. + +| Dataset | Input | CER | WER | +|-----------------------|:------------|:-----------|:----------| +| OCR-D-GT-Archiveform | BIN | 0.01635 | 0.05410 | +| OCR-D-GT-Archiveform | RGB | 0.01471 | 0.05813 | +| BLN600 | RGB | 0.04409 | 0.08879 | +| BLN600 | Enhanced | 0.03599 | 0.06244 | + + +#### Transformer OCR model: model_eynollah_ocr_trocr_20250919 + +This transformer OCR model is trained on the same data as model_eynollah_ocr_trocr_20250919. + +| Dataset | Input | CER | WER | +|-----------------------|:------------|:-----------|:----------| +| OCR-D-GT-Archiveform | BIN | 0.01841 | 0.05589 | +| OCR-D-GT-Archiveform | RGB | 0.01552 | 0.06177 | +| BLN600 | RGB | 0.06347 | 0.13853 | + +##### Qualitative evaluation of the models + +| | | | | +|:---:|:---:|:---:|:---:| +| Image | cnnrnn_20250805 | cnnrnn_20250904 | trocr_20250919 | + + + +| | | | | +|:---:|:---:|:---:|:---:| +| Image | cnnrnn_20250805 | cnnrnn_20250904 | trocr_20250919 | + + +| | | | | +|:---:|:---:|:---:|:---:| +| Image | cnnrnn_20250805 | cnnrnn_20250904 | trocr_20250919 | + + +| | | | | +|:---:|:---:|:---:|:---:| +| Image | cnnrnn_20250805 | cnnrnn_20250904 | trocr_20250919 | + + ## Heuristic methods Additionally, some heuristic methods are employed to further improve the model predictions: * After border detection, the largest contour is determined by a bounding box, and the image cropped to these coordinates. -* For text region detection, the image is scaled up to make it easier for the model to detect background space between text regions. +* Unlike the non-light version, where the image is scaled up to help the model better detect the background spaces between text regions, the light version uses down-scaled images. In this case, introducing an artificial class along the boundaries of text regions and text lines has helped to isolate and separate the text regions more effectively. * A minimum area is defined for text regions in relation to the overall image dimensions, so that very small regions that are noise can be filtered out. -* Deskewing is applied on the text region level (due to regions having different degrees of skew) in order to improve the textline segmentation result. -* After deskewing, a calculation of the pixel distribution on the X-axis allows the separation of textlines (foreground) and background pixels. -* Finally, using the derived coordinates, bounding boxes are determined for each textline. +* In the non-light version, deskewing is applied at the text-region level (since regions may have different degrees of skew) to improve text-line segmentation results. In contrast, the light version performs deskewing only at the page level to enhance margin detection and heuristic reading-order estimation. +* After deskewing, a calculation of the pixel distribution on the X-axis allows the separation of textlines (foreground) and background pixels (only in non-light version). +* Finally, using the derived coordinates, bounding boxes are determined for each textline (only in non-light version). +* As mentioned above, the reading order can be determined using a model; however, this approach is computationally expensive, time-consuming, and less accurate due to the limited amount of ground-truth data available for training. Therefore, our tool uses a heuristic reading-order detection method as the default. The heuristic approach relies on headers and separators to determine the reading order of text regions. diff --git a/docs/ocrd.md b/docs/ocrd.md new file mode 100644 index 0000000..9e7e268 --- /dev/null +++ b/docs/ocrd.md @@ -0,0 +1,26 @@ +## Use as OCR-D processor + +Eynollah ships with a CLI interface to be used as [OCR-D](https://ocr-d.de) [processor](https://ocr-d.de/en/spec/cli), +formally described in [`ocrd-tool.json`](https://github.com/qurator-spk/eynollah/tree/main/src/eynollah/ocrd-tool.json). + +When using Eynollah in OCR-D, the source image file group with (preferably) RGB images should be used as input like this: + + ocrd-eynollah-segment -I OCR-D-IMG -O OCR-D-SEG -P models eynollah_layout_v0_5_0 + +If the input file group is PAGE-XML (from a previous OCR-D workflow step), Eynollah behaves as follows: +- existing regions are kept and ignored (i.e. in effect they might overlap segments from Eynollah results) +- existing annotation (and respective `AlternativeImage`s) are partially _ignored_: + - previous page frame detection (`cropped` images) + - previous derotation (`deskewed` images) + - previous thresholding (`binarized` images) +- if the page-level image nevertheless deviates from the original (`@imageFilename`) + (because some other preprocessing step was in effect like `denoised`), then + the output PAGE-XML will be based on that as new top-level (`@imageFilename`) + + ocrd-eynollah-segment -I OCR-D-XYZ -O OCR-D-SEG -P models eynollah_layout_v0_5_0 + +In general, it makes more sense to add other workflow steps **after** Eynollah. + +There is also an OCR-D processor for binarization: + + ocrd-sbb-binarize -I OCR-D-IMG -O OCR-D-BIN -P models default-2021-03-09 diff --git a/docs/train.md b/docs/train.md index 3c64ab9..82bb77c 100644 --- a/docs/train.md +++ b/docs/train.md @@ -1,3 +1,41 @@ +# Prerequisistes + +## 1. Install Eynollah with training dependencies + +Clone the repository and install eynollah along with the dependencies necessary for training: + +```sh +git clone https://github.com/qurator-spk/eynollah +cd eynollah +pip install '.[training]' +``` + +## 2. Pretrained encoder + +Download our pretrained weights and add them to a `train/pretrained_model` folder: + +```sh +cd train +wget -O pretrained_model.tar.gz https://zenodo.org/records/17243320/files/pretrained_model_v0_5_1.tar.gz?download=1 +tar xf pretrained_model.tar.gz +``` + +## 3. Example data + +### Binarization +A small sample of training data for binarization experiment can be found on [Zenodo](https://zenodo.org/records/17243320/files/training_data_sample_binarization_v0_5_1.tar.gz?download=1), +which contains `images` and `labels` folders. + +## 4. Helpful tools + +* [`pagexml2img`](https://github.com/qurator-spk/page2img) +> Tool to extract 2-D or 3-D RGB images from PAGE-XML data. In the former case, the output will be 1 2-D image array which each class has filled with a pixel value. In the case of a 3-D RGB image, +each class will be defined with a RGB value and beside images, a text file of classes will also be produced. +* [`cocoSegmentationToPng`](https://github.com/nightrome/cocostuffapi/blob/17acf33aef3c6cc2d6aca46dcf084266c2778cf0/PythonAPI/pycocotools/cocostuffhelper.py#L130) +> Convert COCO GT or results for a single image to a segmentation map and write it to disk. +* [`ocrd-segment-extract-pages`](https://github.com/OCR-D/ocrd_segment/blob/master/ocrd_segment/extract_pages.py) +> Extract region classes and their colours in mask (pseg) images. Allows the color map as free dict parameter, and comes with a default that mimics PageViewer's coloring for quick debugging; it also warns when regions do overlap. + # Training documentation This document aims to assist users in preparing training datasets, training models, and diff --git a/pyproject.toml b/pyproject.toml index e7744a1..e6821a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,12 @@ description = "Document Layout Analysis" readme = "README.md" license.file = "LICENSE" requires-python = ">=3.8" -keywords = ["document layout analysis", "image segmentation"] +keywords = [ + "document layout analysis", + "image segmentation", + "binarization", + "optical character recognition" +] dynamic = [ "dependencies", @@ -25,6 +30,10 @@ classifiers = [ "Intended Audience :: Science/Research", "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3 :: Only", "Topic :: Scientific/Engineering :: Image Processing", ] @@ -58,8 +67,6 @@ source = ["eynollah"] [tool.ruff] line-length = 120 -# TODO: Reenable and fix after release v0.6.0 -exclude = ['src/eynollah/training'] [tool.ruff.lint] ignore = [ diff --git a/requirements.txt b/requirements.txt index 5699566..53d1e39 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ tf-keras # avoid keras 3 (also needs TF_USE_LEGACY_KERAS=1) numba <= 0.58.1 scikit-image biopython +tabulate diff --git a/src/eynollah/cli.py b/src/eynollah/cli.py deleted file mode 100644 index e4a24e4..0000000 --- a/src/eynollah/cli.py +++ /dev/null @@ -1,589 +0,0 @@ -import sys -import click -import logging -from ocrd_utils import initLogging, getLevelName, getLogger -from eynollah.eynollah import Eynollah, Eynollah_ocr -from eynollah.sbb_binarize import SbbBinarizer -from eynollah.image_enhancer import Enhancer -from eynollah.mb_ro_on_layout import machine_based_reading_order_on_layout - -@click.group() -def main(): - pass - -@main.command() -@click.option( - "--input", - "-i", - help="PAGE-XML input filename", - type=click.Path(exists=True, dir_okay=False), -) -@click.option( - "--dir_in", - "-di", - help="directory of PAGE-XML input files (instead of --input)", - type=click.Path(exists=True, file_okay=False), -) -@click.option( - "--out", - "-o", - help="directory for output images", - type=click.Path(exists=True, file_okay=False), - required=True, -) -@click.option( - "--model", - "-m", - help="directory of models", - type=click.Path(exists=True, file_okay=False), - required=True, -) -@click.option( - "--log_level", - "-l", - type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']), - help="Override log level globally to this", -) - -def machine_based_reading_order(input, dir_in, out, model, log_level): - assert bool(input) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both." - orderer = machine_based_reading_order_on_layout(model) - if log_level: - orderer.logger.setLevel(getLevelName(log_level)) - - orderer.run(xml_filename=input, - dir_in=dir_in, - dir_out=out, - ) - - -@main.command() -@click.option('--patches/--no-patches', default=True, help='by enabling this parameter you let the model to see the image in patches.') -@click.option('--model_dir', '-m', type=click.Path(exists=True, file_okay=False), required=True, help='directory containing models for prediction') -@click.option( - "--input-image", "--image", - "-i", - help="input image filename", - type=click.Path(exists=True, dir_okay=False) -) -@click.option( - "--dir_in", - "-di", - help="directory of input images (instead of --image)", - type=click.Path(exists=True, file_okay=False), -) -@click.option( - "--output", - "-o", - help="output image (if using -i) or output image directory (if using -di)", - type=click.Path(file_okay=True, dir_okay=True), - required=True, -) -@click.option( - "--overwrite", - "-O", - help="overwrite (instead of skipping) if output xml exists", - is_flag=True, -) -@click.option( - "--log_level", - "-l", - type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']), - help="Override log level globally to this", -) -def binarization(patches, model_dir, input_image, dir_in, output, overwrite, log_level): - assert bool(input_image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both." - binarizer = SbbBinarizer(model_dir) - if log_level: - binarizer.logger.setLevel(getLevelName(log_level)) - binarizer.run(overwrite=overwrite, - use_patches=patches, - image_path=input_image, - output=output, - dir_in=dir_in) - - -@main.command() -@click.option( - "--image", - "-i", - help="input image filename", - type=click.Path(exists=True, dir_okay=False), -) - -@click.option( - "--out", - "-o", - help="directory for output PAGE-XML files", - type=click.Path(exists=True, file_okay=False), - required=True, -) -@click.option( - "--overwrite", - "-O", - help="overwrite (instead of skipping) if output xml exists", - is_flag=True, -) -@click.option( - "--dir_in", - "-di", - help="directory of input images (instead of --image)", - type=click.Path(exists=True, file_okay=False), -) -@click.option( - "--model", - "-m", - help="directory of models", - type=click.Path(exists=True, file_okay=False), - required=True, -) - -@click.option( - "--num_col_upper", - "-ncu", - help="lower limit of columns in document image", -) -@click.option( - "--num_col_lower", - "-ncl", - help="upper limit of columns in document image", -) -@click.option( - "--save_org_scale/--no_save_org_scale", - "-sos/-nosos", - is_flag=True, - help="if this parameter set to true, this tool will save the enhanced image in org scale.", -) -@click.option( - "--log_level", - "-l", - type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']), - help="Override log level globally to this", -) - -def enhancement(image, out, overwrite, dir_in, model, num_col_upper, num_col_lower, save_org_scale, log_level): - assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both." - initLogging() - enhancer = Enhancer( - model, - num_col_upper=num_col_upper, - num_col_lower=num_col_lower, - save_org_scale=save_org_scale, - ) - if log_level: - enhancer.logger.setLevel(getLevelName(log_level)) - enhancer.run(overwrite=overwrite, - dir_in=dir_in, - image_filename=image, - dir_out=out, - ) - -@main.command() -@click.option( - "--image", - "-i", - help="input image filename", - type=click.Path(exists=True, dir_okay=False), -) - -@click.option( - "--out", - "-o", - help="directory for output PAGE-XML files", - type=click.Path(exists=True, file_okay=False), - required=True, -) -@click.option( - "--overwrite", - "-O", - help="overwrite (instead of skipping) if output xml exists", - is_flag=True, -) -@click.option( - "--dir_in", - "-di", - help="directory of input images (instead of --image)", - type=click.Path(exists=True, file_okay=False), -) -@click.option( - "--model", - "-m", - help="directory of models", - type=click.Path(exists=True, file_okay=False), - required=True, -) -@click.option( - "--model_version", - "-mv", - help="override default versions of model categories", - type=(str, str), - multiple=True, -) -@click.option( - "--save_images", - "-si", - help="if a directory is given, images in documents will be cropped and saved there", - type=click.Path(exists=True, file_okay=False), -) -@click.option( - "--save_layout", - "-sl", - help="if a directory is given, plot of layout will be saved there", - type=click.Path(exists=True, file_okay=False), -) -@click.option( - "--save_deskewed", - "-sd", - help="if a directory is given, deskewed image will be saved there", - type=click.Path(exists=True, file_okay=False), -) -@click.option( - "--save_all", - "-sa", - help="if a directory is given, all plots needed for documentation will be saved there", - type=click.Path(exists=True, file_okay=False), -) -@click.option( - "--save_page", - "-sp", - help="if a directory is given, page crop of image will be saved there", - type=click.Path(exists=True, file_okay=False), -) -@click.option( - "--enable-plotting/--disable-plotting", - "-ep/-noep", - is_flag=True, - help="If set, will plot intermediary files and images", -) -@click.option( - "--extract_only_images/--disable-extracting_only_images", - "-eoi/-noeoi", - is_flag=True, - help="If a directory is given, only images in documents will be cropped and saved there and the other processing will not be done", -) -@click.option( - "--allow-enhancement/--no-allow-enhancement", - "-ae/-noae", - is_flag=True, - help="if this parameter set to true, this tool would check that input image need resizing and enhancement or not. If so output of resized and enhanced image and corresponding layout data will be written in out directory", -) -@click.option( - "--curved-line/--no-curvedline", - "-cl/-nocl", - is_flag=True, - help="if this parameter set to true, this tool will try to return contoure of textlines instead of rectangle bounding box of textline. This should be taken into account that with this option the tool need more time to do process.", -) -@click.option( - "--textline_light/--no-textline_light", - "-tll/-notll", - is_flag=True, - help="if this parameter set to true, this tool will try to return contoure of textlines instead of rectangle bounding box of textline with a faster method.", -) -@click.option( - "--full-layout/--no-full-layout", - "-fl/-nofl", - is_flag=True, - help="if this parameter set to true, this tool will try to return all elements of layout.", -) -@click.option( - "--tables/--no-tables", - "-tab/-notab", - is_flag=True, - help="if this parameter set to true, this tool will try to detect tables.", -) -@click.option( - "--right2left/--left2right", - "-r2l/-l2r", - is_flag=True, - help="if this parameter set to true, this tool will extract right-to-left reading order.", -) -@click.option( - "--input_binary/--input-RGB", - "-ib/-irgb", - is_flag=True, - help="in general, eynollah uses RGB as input but if the input document is strongly dark, bright or for any other reason you can turn binarized input on. This option does not mean that you have to provide a binary image, otherwise this means that the tool itself will binarized the RGB input document.", -) -@click.option( - "--allow_scaling/--no-allow-scaling", - "-as/-noas", - is_flag=True, - help="if this parameter set to true, this tool would check the scale and if needed it will scale it to perform better layout detection", -) -@click.option( - "--headers_off/--headers-on", - "-ho/-noho", - is_flag=True, - help="if this parameter set to true, this tool would ignore headers role in reading order", -) -@click.option( - "--light_version/--original", - "-light/-org", - is_flag=True, - help="if this parameter set to true, this tool would use lighter version", -) -@click.option( - "--ignore_page_extraction/--extract_page_included", - "-ipe/-epi", - is_flag=True, - help="if this parameter set to true, this tool would ignore page extraction", -) -@click.option( - "--reading_order_machine_based/--heuristic_reading_order", - "-romb/-hro", - is_flag=True, - help="if this parameter set to true, this tool would apply machine based reading order detection", -) -@click.option( - "--do_ocr", - "-ocr/-noocr", - is_flag=True, - help="if this parameter set to true, this tool will try to do ocr", -) -@click.option( - "--transformer_ocr", - "-tr/-notr", - is_flag=True, - help="if this parameter set to true, this tool will apply transformer ocr", -) -@click.option( - "--batch_size_ocr", - "-bs_ocr", - help="number of inference batch size of ocr model. Default b_s for trocr and cnn_rnn models are 2 and 8 respectively", -) -@click.option( - "--num_col_upper", - "-ncu", - help="lower limit of columns in document image", -) -@click.option( - "--num_col_lower", - "-ncl", - help="upper limit of columns in document image", -) -@click.option( - "--threshold_art_class_layout", - "-tharl", - help="threshold of artifical class in the case of layout detection. The default value is 0.1", -) -@click.option( - "--threshold_art_class_textline", - "-thart", - help="threshold of artifical class in the case of textline detection. The default value is 0.1", -) -@click.option( - "--skip_layout_and_reading_order", - "-slro/-noslro", - is_flag=True, - help="if this parameter set to true, this tool will ignore layout detection and reading order. It means that textline detection will be done within printspace and contours of textline will be written in xml output file.", -) -# TODO move to top-level CLI context -@click.option( - "--log_level", - "-l", - type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']), - help="Override 'eynollah' log level globally to this", -) -# -@click.option( - "--setup-logging", - is_flag=True, - help="Setup a basic console logger", -) - -def layout(image, out, overwrite, dir_in, model, model_version, save_images, save_layout, save_deskewed, save_all, extract_only_images, save_page, enable_plotting, allow_enhancement, curved_line, textline_light, full_layout, tables, right2left, input_binary, allow_scaling, headers_off, light_version, reading_order_machine_based, do_ocr, transformer_ocr, batch_size_ocr, num_col_upper, num_col_lower, threshold_art_class_textline, threshold_art_class_layout, skip_layout_and_reading_order, ignore_page_extraction, log_level, setup_logging): - if setup_logging: - console_handler = logging.StreamHandler(sys.stdout) - console_handler.setLevel(logging.INFO) - formatter = logging.Formatter('%(message)s') - console_handler.setFormatter(formatter) - getLogger('eynollah').addHandler(console_handler) - getLogger('eynollah').setLevel(logging.INFO) - else: - initLogging() - assert enable_plotting or not save_layout, "Plotting with -sl also requires -ep" - assert enable_plotting or not save_deskewed, "Plotting with -sd also requires -ep" - assert enable_plotting or not save_all, "Plotting with -sa also requires -ep" - assert enable_plotting or not save_page, "Plotting with -sp also requires -ep" - assert enable_plotting or not save_images, "Plotting with -si also requires -ep" - assert enable_plotting or not allow_enhancement, "Plotting with -ae also requires -ep" - assert not enable_plotting or save_layout or save_deskewed or save_all or save_page or save_images or allow_enhancement, \ - "Plotting with -ep also requires -sl, -sd, -sa, -sp, -si or -ae" - assert textline_light == light_version, "Both light textline detection -tll and light version -light must be set or unset equally" - assert not extract_only_images or not allow_enhancement, "Image extraction -eoi can not be set alongside allow_enhancement -ae" - assert not extract_only_images or not allow_scaling, "Image extraction -eoi can not be set alongside allow_scaling -as" - assert not extract_only_images or not light_version, "Image extraction -eoi can not be set alongside light_version -light" - assert not extract_only_images or not curved_line, "Image extraction -eoi can not be set alongside curved_line -cl" - assert not extract_only_images or not textline_light, "Image extraction -eoi can not be set alongside textline_light -tll" - assert not extract_only_images or not full_layout, "Image extraction -eoi can not be set alongside full_layout -fl" - assert not extract_only_images or not tables, "Image extraction -eoi can not be set alongside tables -tab" - assert not extract_only_images or not right2left, "Image extraction -eoi can not be set alongside right2left -r2l" - assert not extract_only_images or not headers_off, "Image extraction -eoi can not be set alongside headers_off -ho" - assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both." - eynollah = Eynollah( - model, - model_versions=model_version, - extract_only_images=extract_only_images, - enable_plotting=enable_plotting, - allow_enhancement=allow_enhancement, - curved_line=curved_line, - textline_light=textline_light, - full_layout=full_layout, - tables=tables, - right2left=right2left, - input_binary=input_binary, - allow_scaling=allow_scaling, - headers_off=headers_off, - light_version=light_version, - ignore_page_extraction=ignore_page_extraction, - reading_order_machine_based=reading_order_machine_based, - do_ocr=do_ocr, - transformer_ocr=transformer_ocr, - batch_size_ocr=batch_size_ocr, - num_col_upper=num_col_upper, - num_col_lower=num_col_lower, - skip_layout_and_reading_order=skip_layout_and_reading_order, - threshold_art_class_textline=threshold_art_class_textline, - threshold_art_class_layout=threshold_art_class_layout, - ) - if log_level: - eynollah.logger.setLevel(getLevelName(log_level)) - eynollah.run(overwrite=overwrite, - image_filename=image, - dir_in=dir_in, - dir_out=out, - dir_of_cropped_images=save_images, - dir_of_layout=save_layout, - dir_of_deskewed=save_deskewed, - dir_of_all=save_all, - dir_save_page=save_page, - ) - -@main.command() -@click.option( - "--image", - "-i", - help="input image filename", - type=click.Path(exists=True, dir_okay=False), -) -@click.option( - "--dir_in", - "-di", - help="directory of input images (instead of --image)", - type=click.Path(exists=True, file_okay=False), -) -@click.option( - "--dir_in_bin", - "-dib", - help="directory of binarized images (in addition to --dir_in for RGB images; filename stems must match the RGB image files, with '.png' suffix).\nPerform prediction using both RGB and binary images. (This does not necessarily improve results, however it may be beneficial for certain document images.)", - type=click.Path(exists=True, file_okay=False), -) -@click.option( - "--dir_xmls", - "-dx", - help="directory of input PAGE-XML files (in addition to --dir_in; filename stems must match the image files, with '.xml' suffix).", - type=click.Path(exists=True, file_okay=False), - required=True, -) -@click.option( - "--out", - "-o", - help="directory for output PAGE-XML files", - type=click.Path(exists=True, file_okay=False), - required=True, -) -@click.option( - "--dir_out_image_text", - "-doit", - help="directory for output images, newly rendered with predicted text", - type=click.Path(exists=True, file_okay=False), -) -@click.option( - "--overwrite", - "-O", - help="overwrite (instead of skipping) if output xml exists", - is_flag=True, -) -@click.option( - "--model", - "-m", - help="directory of models", - type=click.Path(exists=True, file_okay=False), -) -@click.option( - "--model_name", - help="Specific model file path to use for OCR", - type=click.Path(exists=True, file_okay=False), -) -@click.option( - "--tr_ocr", - "-trocr/-notrocr", - is_flag=True, - help="if this parameter set to true, transformer ocr will be applied, otherwise cnn_rnn model.", -) -@click.option( - "--export_textline_images_and_text", - "-etit/-noetit", - is_flag=True, - help="if this parameter set to true, images and text in xml will be exported into output dir. This files can be used for training a OCR engine.", -) -@click.option( - "--do_not_mask_with_textline_contour", - "-nmtc/-mtc", - is_flag=True, - help="if this parameter set to true, cropped textline images will not be masked with textline contour.", -) -@click.option( - "--batch_size", - "-bs", - help="number of inference batch size. Default b_s for trocr and cnn_rnn models are 2 and 8 respectively", -) -@click.option( - "--dataset_abbrevation", - "-ds_pref", - help="in the case of extracting textline and text from a xml GT file user can add an abbrevation of dataset name to generated dataset", -) -@click.option( - "--min_conf_value_of_textline_text", - "-min_conf", - help="minimum OCR confidence value. Text lines with a confidence value lower than this threshold will not be included in the output XML file.", -) -@click.option( - "--log_level", - "-l", - type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']), - help="Override log level globally to this", -) - -def ocr(image, dir_in, dir_in_bin, dir_xmls, out, dir_out_image_text, overwrite, model, model_name, tr_ocr, export_textline_images_and_text, do_not_mask_with_textline_contour, batch_size, dataset_abbrevation, min_conf_value_of_textline_text, log_level): - initLogging() - - assert bool(model) != bool(model_name), "Either -m (model directory) or --model_name (specific model name) must be provided." - assert not export_textline_images_and_text or not tr_ocr, "Exporting textline and text -etit can not be set alongside transformer ocr -tr_ocr" - assert not export_textline_images_and_text or not model, "Exporting textline and text -etit can not be set alongside model -m" - assert not export_textline_images_and_text or not batch_size, "Exporting textline and text -etit can not be set alongside batch size -bs" - assert not export_textline_images_and_text or not dir_in_bin, "Exporting textline and text -etit can not be set alongside directory of bin images -dib" - assert not export_textline_images_and_text or not dir_out_image_text, "Exporting textline and text -etit can not be set alongside directory of images with predicted text -doit" - assert bool(image) != bool(dir_in), "Either -i (single image) or -di (directory) must be provided, but not both." - eynollah_ocr = Eynollah_ocr( - dir_models=model, - model_name=model_name, - tr_ocr=tr_ocr, - export_textline_images_and_text=export_textline_images_and_text, - do_not_mask_with_textline_contour=do_not_mask_with_textline_contour, - batch_size=batch_size, - pref_of_dataset=dataset_abbrevation, - min_conf_value_of_textline_text=min_conf_value_of_textline_text, - ) - if log_level: - eynollah_ocr.logger.setLevel(getLevelName(log_level)) - eynollah_ocr.run(overwrite=overwrite, - dir_in=dir_in, - dir_in_bin=dir_in_bin, - image_filename=image, - dir_xmls=dir_xmls, - dir_out_image_text=dir_out_image_text, - dir_out=out, - ) - -if __name__ == "__main__": - main() diff --git a/src/eynollah/cli/__init__.py b/src/eynollah/cli/__init__.py new file mode 100644 index 0000000..05dafa1 --- /dev/null +++ b/src/eynollah/cli/__init__.py @@ -0,0 +1,22 @@ +# NOTE: For predictable order of imports of torch/shapely/tensorflow +# this must be the first import of the CLI! +from ..eynollah_imports import imported_libs + +from .cli_models import models_cli +from .cli_binarize import binarize_cli + +from .cli import main +from .cli_binarize import binarize_cli +from .cli_enhance import enhance_cli +from .cli_extract_images import extract_images_cli +from .cli_layout import layout_cli +from .cli_ocr import ocr_cli +from .cli_readingorder import readingorder_cli + +main.add_command(binarize_cli, 'binarization') +main.add_command(enhance_cli, 'enhancement') +main.add_command(layout_cli, 'layout') +main.add_command(readingorder_cli, 'machine-based-reading-order') +main.add_command(models_cli, 'models') +main.add_command(ocr_cli, 'ocr') +main.add_command(extract_images_cli, 'extract-images') diff --git a/src/eynollah/cli/cli.py b/src/eynollah/cli/cli.py new file mode 100644 index 0000000..b374fa8 --- /dev/null +++ b/src/eynollah/cli/cli.py @@ -0,0 +1,66 @@ +from dataclasses import dataclass +import logging +import sys +import os +from typing import Union + +import click + +from ..model_zoo import EynollahModelZoo +from .cli_models import models_cli + +@dataclass() +class EynollahCliCtx: + """ + Holds options relevant for all eynollah subcommands + """ + model_zoo: EynollahModelZoo + log_level : Union[str, None] = 'INFO' + + +@click.group() +@click.option( + "--model-basedir", + "-m", + help="directory of models", + # NOTE: not mandatory to exist so --help for subcommands works but will log a warning + # and raise exception when trying to load models in the CLI + # type=click.Path(exists=True), + default=f'{os.getcwd()}/models_eynollah', +) +@click.option( + "--model-overrides", + "-mv", + help="override default versions of model categories, syntax is 'CATEGORY VARIANT PATH', e.g 'region light /path/to/model'. See eynollah list-models for the full list", + type=(str, str, str), + multiple=True, +) +@click.option( + "--log_level", + "-l", + type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']), + help="Override log level globally to this", +) +@click.pass_context +def main(ctx, model_basedir, model_overrides, log_level): + """ + eynollah - Document Layout Analysis, Image Enhancement, OCR + """ + # Initialize logging + console_handler = logging.StreamHandler(sys.stderr) + console_handler.setLevel(logging.NOTSET) + formatter = logging.Formatter('%(asctime)s.%(msecs)03d %(levelname)s %(name)s - %(message)s', datefmt='%H:%M:%S') + console_handler.setFormatter(formatter) + logging.getLogger('eynollah').addHandler(console_handler) + logging.getLogger('eynollah').setLevel(log_level or logging.INFO) + # Initialize model zoo + model_zoo = EynollahModelZoo(basedir=model_basedir, model_overrides=model_overrides) + # Initialize CLI context + ctx.obj = EynollahCliCtx( + model_zoo=model_zoo, + log_level=log_level, + ) + + +if __name__ == "__main__": + main() diff --git a/src/eynollah/cli/cli_binarize.py b/src/eynollah/cli/cli_binarize.py new file mode 100644 index 0000000..aa6cefc --- /dev/null +++ b/src/eynollah/cli/cli_binarize.py @@ -0,0 +1,52 @@ +import click + +@click.command() +@click.option('--patches/--no-patches', default=True, help='by enabling this parameter you let the model to see the image in patches.') +@click.option( + "--input-image", "--image", + "-i", + help="input image filename", + type=click.Path(exists=True, dir_okay=False) +) +@click.option( + "--dir_in", + "-di", + help="directory of input images (instead of --image)", + type=click.Path(exists=True, file_okay=False), +) +@click.option( + "--output", + "-o", + help="output image (if using -i) or output image directory (if using -di)", + type=click.Path(file_okay=True, dir_okay=True), + required=True, +) +@click.option( + "--overwrite", + "-O", + help="overwrite (instead of skipping) if output xml exists", + is_flag=True, +) +@click.pass_context +def binarize_cli( + ctx, + patches, + input_image, + dir_in, + output, + overwrite, +): + """ + Binarize images with a ML model + """ + from ..sbb_binarize import SbbBinarizer + assert bool(input_image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both." + binarizer = SbbBinarizer(model_zoo=ctx.obj.model_zoo) + binarizer.run( + image_path=input_image, + use_patches=patches, + output=output, + dir_in=dir_in, + overwrite=overwrite + ) + diff --git a/src/eynollah/cli/cli_enhance.py b/src/eynollah/cli/cli_enhance.py new file mode 100644 index 0000000..fa4158b --- /dev/null +++ b/src/eynollah/cli/cli_enhance.py @@ -0,0 +1,63 @@ +import click + +@click.command() +@click.option( + "--image", + "-i", + help="input image filename", + type=click.Path(exists=True, dir_okay=False), +) +@click.option( + "--out", + "-o", + help="directory for output PAGE-XML files", + type=click.Path(exists=True, file_okay=False), + required=True, +) +@click.option( + "--overwrite", + "-O", + help="overwrite (instead of skipping) if output xml exists", + is_flag=True, +) +@click.option( + "--dir_in", + "-di", + help="directory of input images (instead of --image)", + type=click.Path(exists=True, file_okay=False), +) +@click.option( + "--num_col_upper", + "-ncu", + help="lower limit of columns in document image", +) +@click.option( + "--num_col_lower", + "-ncl", + help="upper limit of columns in document image", +) +@click.option( + "--save_org_scale/--no_save_org_scale", + "-sos/-nosos", + is_flag=True, + help="if this parameter set to true, this tool will save the enhanced image in org scale.", +) +@click.pass_context +def enhance_cli(ctx, image, out, overwrite, dir_in, num_col_upper, num_col_lower, save_org_scale): + """ + Enhance image + """ + assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both." + from ..image_enhancer import Enhancer + enhancer = Enhancer( + model_zoo=ctx.obj.model_zoo, + num_col_upper=num_col_upper, + num_col_lower=num_col_lower, + save_org_scale=save_org_scale, + ) + enhancer.run(overwrite=overwrite, + dir_in=dir_in, + image_filename=image, + dir_out=out, + ) + diff --git a/src/eynollah/cli/cli_extract_images.py b/src/eynollah/cli/cli_extract_images.py new file mode 100644 index 0000000..0add5b5 --- /dev/null +++ b/src/eynollah/cli/cli_extract_images.py @@ -0,0 +1,100 @@ +import click + +@click.command() +@click.option( + "--image", + "-i", + help="input image filename", + type=click.Path(exists=True, dir_okay=False), +) + +@click.option( + "--out", + "-o", + help="directory for output PAGE-XML files", + type=click.Path(exists=True, file_okay=False), + required=True, +) +@click.option( + "--overwrite", + "-O", + help="overwrite (instead of skipping) if output xml exists", + is_flag=True, +) +@click.option( + "--dir_in", + "-di", + help="directory of input images (instead of --image)", + type=click.Path(exists=True, file_okay=False), +) +@click.option( + "--save_images", + "-si", + help="if a directory is given, images in documents will be cropped and saved there", + type=click.Path(exists=True, file_okay=False), +) +@click.option( + "--enable-plotting/--disable-plotting", + "-ep/-noep", + is_flag=True, + help="If set, will plot intermediary files and images", +) +@click.option( + "--input_binary/--input-RGB", + "-ib/-irgb", + is_flag=True, + help="In general, eynollah uses RGB as input but if the input document is very dark, very bright or for any other reason you can turn on input binarization. When this flag is set, eynollah will binarize the RGB input document, you should always provide RGB images to eynollah.", +) +@click.option( + "--ignore_page_extraction/--extract_page_included", + "-ipe/-epi", + is_flag=True, + help="if this parameter set to true, this tool would ignore page extraction", +) +@click.option( + "--num_col_upper", + "-ncu", + help="lower limit of columns in document image", +) +@click.option( + "--num_col_lower", + "-ncl", + help="upper limit of columns in document image", +) +@click.pass_context +def extract_images_cli( + ctx, + image, + out, + overwrite, + dir_in, + save_images, + enable_plotting, + input_binary, + num_col_upper, + num_col_lower, + ignore_page_extraction, +): + """ + Detect Layout (with optional image enhancement and reading order detection) + """ + assert enable_plotting or not save_images, "Plotting with -si also requires -ep" + assert not enable_plotting or save_images, "Plotting with -ep also requires -si" + assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both." + + from ..extract_images import EynollahImageExtractor + extractor = EynollahImageExtractor( + model_zoo=ctx.obj.model_zoo, + enable_plotting=enable_plotting, + input_binary=input_binary, + ignore_page_extraction=ignore_page_extraction, + num_col_upper=num_col_upper, + num_col_lower=num_col_lower, + ) + extractor.run(overwrite=overwrite, + image_filename=image, + dir_in=dir_in, + dir_out=out, + dir_of_cropped_images=save_images, + ) + diff --git a/src/eynollah/cli/cli_layout.py b/src/eynollah/cli/cli_layout.py new file mode 100644 index 0000000..df66993 --- /dev/null +++ b/src/eynollah/cli/cli_layout.py @@ -0,0 +1,223 @@ +import click + +@click.command() +@click.option( + "--image", + "-i", + help="input image filename", + type=click.Path(exists=True, dir_okay=False), +) + +@click.option( + "--out", + "-o", + help="directory for output PAGE-XML files", + type=click.Path(exists=True, file_okay=False), + required=True, +) +@click.option( + "--overwrite", + "-O", + help="overwrite (instead of skipping) if output xml exists", + is_flag=True, +) +@click.option( + "--dir_in", + "-di", + help="directory of input images (instead of --image)", + type=click.Path(exists=True, file_okay=False), +) +@click.option( + "--save_images", + "-si", + help="if a directory is given, images in documents will be cropped and saved there", + type=click.Path(exists=True, file_okay=False), +) +@click.option( + "--save_layout", + "-sl", + help="if a directory is given, plot of layout will be saved there", + type=click.Path(exists=True, file_okay=False), +) +@click.option( + "--save_deskewed", + "-sd", + help="if a directory is given, deskewed image will be saved there", + type=click.Path(exists=True, file_okay=False), +) +@click.option( + "--save_all", + "-sa", + help="if a directory is given, all plots needed for documentation will be saved there", + type=click.Path(exists=True, file_okay=False), +) +@click.option( + "--save_page", + "-sp", + help="if a directory is given, page crop of image will be saved there", + type=click.Path(exists=True, file_okay=False), +) +@click.option( + "--enable-plotting/--disable-plotting", + "-ep/-noep", + is_flag=True, + help="If set, will plot intermediary files and images", +) +@click.option( + "--allow-enhancement/--no-allow-enhancement", + "-ae/-noae", + is_flag=True, + help="if this parameter set to true, this tool would check that input image need resizing and enhancement or not. If so output of resized and enhanced image and corresponding layout data will be written in out directory", +) +@click.option( + "--curved-line/--no-curvedline", + "-cl/-nocl", + is_flag=True, + help="if this parameter set to true, this tool will try to return contoure of textlines instead of rectangle bounding box of textline. This should be taken into account that with this option the tool need more time to do process.", +) +@click.option( + "--full-layout/--no-full-layout", + "-fl/-nofl", + is_flag=True, + help="if this parameter set to true, this tool will try to return all elements of layout.", +) +@click.option( + "--tables/--no-tables", + "-tab/-notab", + is_flag=True, + help="if this parameter set to true, this tool will try to detect tables.", +) +@click.option( + "--right2left/--left2right", + "-r2l/-l2r", + is_flag=True, + help="if this parameter set to true, this tool will extract right-to-left reading order.", +) +@click.option( + "--input_binary/--input-RGB", + "-ib/-irgb", + is_flag=True, + help="In general, eynollah uses RGB as input but if the input document is very dark, very bright or for any other reason you can turn on input binarization. When this flag is set, eynollah will binarize the RGB input document, you should always provide RGB images to eynollah.", +) +@click.option( + "--allow_scaling/--no-allow-scaling", + "-as/-noas", + is_flag=True, + help="if this parameter set to true, this tool would check the scale and if needed it will scale it to perform better layout detection", +) +@click.option( + "--headers_off/--headers-on", + "-ho/-noho", + is_flag=True, + help="if this parameter set to true, this tool would ignore headers role in reading order", +) +@click.option( + "--ignore_page_extraction/--extract_page_included", + "-ipe/-epi", + is_flag=True, + help="if this parameter set to true, this tool would ignore page extraction", +) +@click.option( + "--reading_order_machine_based/--heuristic_reading_order", + "-romb/-hro", + is_flag=True, + help="if this parameter set to true, this tool would apply machine based reading order detection", +) +@click.option( + "--num_col_upper", + "-ncu", + help="lower limit of columns in document image", +) +@click.option( + "--num_col_lower", + "-ncl", + help="upper limit of columns in document image", +) +@click.option( + "--threshold_art_class_layout", + "-tharl", + help="threshold of artifical class in the case of layout detection. The default value is 0.1", +) +@click.option( + "--threshold_art_class_textline", + "-thart", + help="threshold of artifical class in the case of textline detection. The default value is 0.1", +) +@click.option( + "--skip_layout_and_reading_order", + "-slro/-noslro", + is_flag=True, + help="if this parameter set to true, this tool will ignore layout detection and reading order. It means that textline detection will be done within printspace and contours of textline will be written in xml output file.", +) +@click.pass_context +def layout_cli( + ctx, + image, + out, + overwrite, + dir_in, + save_images, + save_layout, + save_deskewed, + save_all, + save_page, + enable_plotting, + allow_enhancement, + curved_line, + full_layout, + tables, + right2left, + input_binary, + allow_scaling, + headers_off, + reading_order_machine_based, + num_col_upper, + num_col_lower, + threshold_art_class_textline, + threshold_art_class_layout, + skip_layout_and_reading_order, + ignore_page_extraction, +): + """ + Detect Layout (with optional image enhancement and reading order detection) + """ + from ..eynollah import Eynollah + assert enable_plotting or not save_layout, "Plotting with -sl also requires -ep" + assert enable_plotting or not save_deskewed, "Plotting with -sd also requires -ep" + assert enable_plotting or not save_all, "Plotting with -sa also requires -ep" + assert enable_plotting or not save_page, "Plotting with -sp also requires -ep" + assert enable_plotting or not save_images, "Plotting with -si also requires -ep" + assert enable_plotting or not allow_enhancement, "Plotting with -ae also requires -ep" + assert not enable_plotting or save_layout or save_deskewed or save_all or save_page or save_images or allow_enhancement, \ + "Plotting with -ep also requires -sl, -sd, -sa, -sp, -si or -ae" + assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both." + eynollah = Eynollah( + model_zoo=ctx.obj.model_zoo, + enable_plotting=enable_plotting, + allow_enhancement=allow_enhancement, + curved_line=curved_line, + full_layout=full_layout, + tables=tables, + right2left=right2left, + input_binary=input_binary, + allow_scaling=allow_scaling, + headers_off=headers_off, + ignore_page_extraction=ignore_page_extraction, + reading_order_machine_based=reading_order_machine_based, + num_col_upper=num_col_upper, + num_col_lower=num_col_lower, + skip_layout_and_reading_order=skip_layout_and_reading_order, + threshold_art_class_textline=threshold_art_class_textline, + threshold_art_class_layout=threshold_art_class_layout, + ) + eynollah.run(overwrite=overwrite, + image_filename=image, + dir_in=dir_in, + dir_out=out, + dir_of_cropped_images=save_images, + dir_of_layout=save_layout, + dir_of_deskewed=save_deskewed, + dir_of_all=save_all, + dir_save_page=save_page, + ) + diff --git a/src/eynollah/cli/cli_models.py b/src/eynollah/cli/cli_models.py new file mode 100644 index 0000000..f3de596 --- /dev/null +++ b/src/eynollah/cli/cli_models.py @@ -0,0 +1,69 @@ +from pathlib import Path +from typing import Set, Tuple +import click + +from eynollah.model_zoo.default_specs import MODELS_VERSION + +@click.group() +@click.pass_context +def models_cli( + ctx, +): + """ + Organize models for the various runners in eynollah. + """ + assert ctx.obj.model_zoo + + +@models_cli.command('list') +@click.pass_context +def list_models( + ctx, +): + """ + List all the models in the zoo + """ + print(f"Model basedir: {ctx.obj.model_zoo.model_basedir}") + print(f"Model overrides: {ctx.obj.model_zoo.model_overrides}") + print(ctx.obj.model_zoo) + + +@models_cli.command('package') +@click.option( + '--set-version', '-V', 'version', help="Version to use for packaging", default=MODELS_VERSION, show_default=True +) +@click.argument('output_dir') +@click.pass_context +def package( + ctx, + version, + output_dir, +): + """ + Generate shell code to copy all the models in the zoo into properly named folders in OUTPUT_DIR for distribution. + + eynollah models -m SRC package OUTPUT_DIR + + SRC should contain a directory "models_eynollah" containing all the models. + """ + mkdirs: Set[Path] = set([]) + copies: Set[Tuple[Path, Path]] = set([]) + for spec in ctx.obj.model_zoo.specs.specs: + # skip these as they are dependent on the ocr model + if spec.category in ('num_to_char', 'characters'): + continue + src: Path = ctx.obj.model_zoo.model_path(spec.category, spec.variant) + # Only copy the top-most directory relative to models_eynollah + while src.parent.name != 'models_eynollah': + src = src.parent + for dist in spec.dists: + dist_dir = Path(f"{output_dir}/models_{dist}_{version}/models_eynollah") + copies.add((src, dist_dir)) + mkdirs.add(dist_dir) + for dir in mkdirs: + print(f"mkdir -vp {dir}") + for (src, dst) in copies: + print(f"cp -vr {src} {dst}") + for dir in mkdirs: + zip_path = Path(f'../{dir.parent.name}.zip') + print(f"(cd {dir}/..; zip -vr {zip_path} models_eynollah)") diff --git a/src/eynollah/cli/cli_ocr.py b/src/eynollah/cli/cli_ocr.py new file mode 100644 index 0000000..406af61 --- /dev/null +++ b/src/eynollah/cli/cli_ocr.py @@ -0,0 +1,103 @@ +import click + +@click.command() +@click.option( + "--image", + "-i", + help="input image filename", + type=click.Path(exists=True, dir_okay=False), +) +@click.option( + "--dir_in", + "-di", + help="directory of input images (instead of --image)", + type=click.Path(exists=True, file_okay=False), +) +@click.option( + "--dir_in_bin", + "-dib", + help=("directory of binarized images (in addition to --dir_in for RGB images; filename stems must match the RGB image files, with '.png' \n Perform prediction using both RGB and binary images. (This does not necessarily improve results, however it may be beneficial for certain document images."), + type=click.Path(exists=True, file_okay=False), +) +@click.option( + "--dir_xmls", + "-dx", + help="directory of input PAGE-XML files (in addition to --dir_in; filename stems must match the image files, with '.xml' suffix).", + type=click.Path(exists=True, file_okay=False), + required=True, +) +@click.option( + "--out", + "-o", + help="directory for output PAGE-XML files", + type=click.Path(exists=True, file_okay=False), + required=True, +) +@click.option( + "--dir_out_image_text", + "-doit", + help="directory for output images, newly rendered with predicted text", + type=click.Path(exists=True, file_okay=False), +) +@click.option( + "--overwrite", + "-O", + help="overwrite (instead of skipping) if output xml exists", + is_flag=True, +) +@click.option( + "--tr_ocr", + "-trocr/-notrocr", + is_flag=True, + help="if this parameter set to true, transformer ocr will be applied, otherwise cnn_rnn model.", +) +@click.option( + "--do_not_mask_with_textline_contour", + "-nmtc/-mtc", + is_flag=True, + help="if this parameter set to true, cropped textline images will not be masked with textline contour.", +) +@click.option( + "--batch_size", + "-bs", + help="number of inference batch size. Default b_s for trocr and cnn_rnn models are 2 and 8 respectively", +) +@click.option( + "--min_conf_value_of_textline_text", + "-min_conf", + help="minimum OCR confidence value. Text lines with a confidence value lower than this threshold will not be included in the output XML file.", +) +@click.pass_context +def ocr_cli( + ctx, + image, + dir_in, + dir_in_bin, + dir_xmls, + out, + dir_out_image_text, + overwrite, + tr_ocr, + do_not_mask_with_textline_contour, + batch_size, + min_conf_value_of_textline_text, +): + """ + Recognize text with a CNN/RNN or transformer ML model. + """ + assert bool(image) ^ bool(dir_in), "Either -i (single image) or -di (directory) must be provided, but not both." + from ..eynollah_ocr import Eynollah_ocr + eynollah_ocr = Eynollah_ocr( + model_zoo=ctx.obj.model_zoo, + tr_ocr=tr_ocr, + do_not_mask_with_textline_contour=do_not_mask_with_textline_contour, + batch_size=batch_size, + min_conf_value_of_textline_text=min_conf_value_of_textline_text) + eynollah_ocr.run(overwrite=overwrite, + dir_in=dir_in, + dir_in_bin=dir_in_bin, + image_filename=image, + dir_xmls=dir_xmls, + dir_out_image_text=dir_out_image_text, + dir_out=out, + ) diff --git a/src/eynollah/cli/cli_readingorder.py b/src/eynollah/cli/cli_readingorder.py new file mode 100644 index 0000000..0f44b7f --- /dev/null +++ b/src/eynollah/cli/cli_readingorder.py @@ -0,0 +1,35 @@ +import click + +@click.command() +@click.option( + "--input", + "-i", + help="PAGE-XML input filename", + type=click.Path(exists=True, dir_okay=False), +) +@click.option( + "--dir_in", + "-di", + help="directory of PAGE-XML input files (instead of --input)", + type=click.Path(exists=True, file_okay=False), +) +@click.option( + "--out", + "-o", + help="directory for output images", + type=click.Path(exists=True, file_okay=False), + required=True, +) +@click.pass_context +def readingorder_cli(ctx, input, dir_in, out): + """ + Generate ReadingOrder with a ML model + """ + from ..mb_ro_on_layout import machine_based_reading_order_on_layout + assert bool(input) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both." + orderer = machine_based_reading_order_on_layout(model_zoo=ctx.obj.model_zoo) + orderer.run(xml_filename=input, + dir_in=dir_in, + dir_out=out, + ) + diff --git a/src/eynollah/extract_images.py b/src/eynollah/extract_images.py new file mode 100644 index 0000000..7a7e3f6 --- /dev/null +++ b/src/eynollah/extract_images.py @@ -0,0 +1,281 @@ +""" +extract images? +""" + +from concurrent.futures import ProcessPoolExecutor +import logging +from multiprocessing import cpu_count +import os +import time +from typing import Optional +from pathlib import Path +import tensorflow as tf +import numpy as np +import cv2 + +from eynollah.utils.contour import filter_contours_area_of_image, return_contours_of_image, return_contours_of_interested_region +from eynollah.utils.resize import resize_image + +from .model_zoo.model_zoo import EynollahModelZoo +from .eynollah import Eynollah +from .utils import box2rect, is_image_filename +from .plot import EynollahPlotter + +class EynollahImageExtractor(Eynollah): + + def __init__( + self, + *, + model_zoo: EynollahModelZoo, + enable_plotting : bool = False, + input_binary : bool = False, + ignore_page_extraction : bool = False, + num_col_upper : Optional[int] = None, + num_col_lower : Optional[int] = None, + full_layout : bool = False, + tables : bool = False, + curved_line : bool = False, + allow_enhancement : bool = False, + + ): + self.logger = logging.getLogger('eynollah.extract_images') + self.model_zoo = model_zoo + self.plotter = None + self.tables = tables + self.curved_line = curved_line + self.allow_enhancement = allow_enhancement + + self.enable_plotting = enable_plotting + # --input-binary sensible if image is very dark, if layout is not working. + self.input_binary = input_binary + self.ignore_page_extraction = ignore_page_extraction + self.full_layout = full_layout + if num_col_upper: + self.num_col_upper = int(num_col_upper) + else: + self.num_col_upper = num_col_upper + if num_col_lower: + self.num_col_lower = int(num_col_lower) + else: + self.num_col_lower = num_col_lower + + # for parallelization of CPU-intensive tasks: + self.executor = ProcessPoolExecutor(max_workers=cpu_count()) + + t_start = time.time() + + try: + for device in tf.config.list_physical_devices('GPU'): + tf.config.experimental.set_memory_growth(device, True) + except: + self.logger.warning("no GPU device available") + + self.logger.info("Loading models...") + self.setup_models() + self.logger.info(f"Model initialization complete ({time.time() - t_start:.1f}s)") + + def setup_models(self): + + loadable = [ + "col_classifier", + "binarization", + "page", + "extract_images", + ] + self.model_zoo.load_models(*loadable) + + def get_regions_light_v_extract_only_images(self,img, num_col_classifier): + self.logger.debug("enter get_regions_extract_images_only") + erosion_hurts = False + img_org = np.copy(img) + img_height_h = img_org.shape[0] + img_width_h = img_org.shape[1] + + if num_col_classifier == 1: + img_w_new = 700 + elif num_col_classifier == 2: + img_w_new = 900 + elif num_col_classifier == 3: + img_w_new = 1500 + elif num_col_classifier == 4: + img_w_new = 1800 + elif num_col_classifier == 5: + img_w_new = 2200 + elif num_col_classifier == 6: + img_w_new = 2500 + else: + raise ValueError("num_col_classifier must be in range 1..6") + img_h_new = int(img.shape[0] / float(img.shape[1]) * img_w_new) + img_resized = resize_image(img,img_h_new, img_w_new ) + + prediction_regions_org, _ = self.do_prediction_new_concept(True, img_resized, self.model_zoo.get("extract_images")) + + prediction_regions_org = resize_image(prediction_regions_org,img_height_h, img_width_h ) + image_page, page_coord, cont_page = self.extract_page() + + prediction_regions_org = prediction_regions_org[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]] + prediction_regions_org=prediction_regions_org[:,:,0] + + mask_seps_only = (prediction_regions_org[:,:] ==3)*1 + mask_texts_only = (prediction_regions_org[:,:] ==1)*1 + mask_images_only=(prediction_regions_org[:,:] ==2)*1 + + polygons_seplines, hir_seplines = return_contours_of_image(mask_seps_only) + polygons_seplines = filter_contours_area_of_image( + mask_seps_only, polygons_seplines, hir_seplines, max_area=1, min_area=0.00001, dilate=1) + + polygons_of_only_texts = return_contours_of_interested_region(mask_texts_only,1,0.00001) + polygons_of_only_seps = return_contours_of_interested_region(mask_seps_only,1,0.00001) + + text_regions_p_true = np.zeros(prediction_regions_org.shape) + text_regions_p_true = cv2.fillPoly(text_regions_p_true, pts = polygons_of_only_seps, color=(3,3,3)) + + text_regions_p_true[:,:][mask_images_only[:,:] == 1] = 2 + text_regions_p_true = cv2.fillPoly(text_regions_p_true, pts=polygons_of_only_texts, color=(1,1,1)) + + text_regions_p_true[text_regions_p_true.shape[0]-15:text_regions_p_true.shape[0], :] = 0 + text_regions_p_true[:, text_regions_p_true.shape[1]-15:text_regions_p_true.shape[1]] = 0 + + ##polygons_of_images = return_contours_of_interested_region(text_regions_p_true, 2, 0.0001) + polygons_of_images = return_contours_of_interested_region(text_regions_p_true, 2, 0.001) + + polygons_of_images_fin = [] + for ploy_img_ind in polygons_of_images: + box = _, _, w, h = cv2.boundingRect(ploy_img_ind) + if h < 150 or w < 150: + pass + else: + page_coord_img = box2rect(box) # type: ignore + polygons_of_images_fin.append(np.array([[page_coord_img[2], page_coord_img[0]], + [page_coord_img[3], page_coord_img[0]], + [page_coord_img[3], page_coord_img[1]], + [page_coord_img[2], page_coord_img[1]]])) + + self.logger.debug("exit get_regions_extract_images_only") + return (text_regions_p_true, + erosion_hurts, + polygons_seplines, + polygons_of_images_fin, + image_page, + page_coord, + cont_page) + + def run(self, + overwrite: bool = False, + image_filename: Optional[str] = None, + dir_in: Optional[str] = None, + dir_out: Optional[str] = None, + dir_of_cropped_images: Optional[str] = None, + dir_of_layout: Optional[str] = None, + dir_of_deskewed: Optional[str] = None, + dir_of_all: Optional[str] = None, + dir_save_page: Optional[str] = None, + ): + """ + Get image and scales, then extract the page of scanned image + """ + self.logger.debug("enter run") + t0_tot = time.time() + # Log enabled features directly + enabled_modes = [] + if self.full_layout: + enabled_modes.append("Full layout analysis") + if self.tables: + enabled_modes.append("Table detection") + if enabled_modes: + self.logger.info("Enabled modes: " + ", ".join(enabled_modes)) + if self.enable_plotting: + self.logger.info("Saving debug plots") + if dir_of_cropped_images: + self.logger.info(f"Saving cropped images to: {dir_of_cropped_images}") + if dir_of_layout: + self.logger.info(f"Saving layout plots to: {dir_of_layout}") + if dir_of_deskewed: + self.logger.info(f"Saving deskewed images to: {dir_of_deskewed}") + + if dir_in: + ls_imgs = [os.path.join(dir_in, image_filename) + for image_filename in filter(is_image_filename, + os.listdir(dir_in))] + elif image_filename: + ls_imgs = [image_filename] + else: + raise ValueError("run requires either a single image filename or a directory") + + for img_filename in ls_imgs: + self.logger.info(img_filename) + t0 = time.time() + + self.reset_file_name_dir(img_filename, dir_out) + if self.enable_plotting: + self.plotter = EynollahPlotter(dir_out=dir_out, + dir_of_all=dir_of_all, + dir_save_page=dir_save_page, + dir_of_deskewed=dir_of_deskewed, + dir_of_cropped_images=dir_of_cropped_images, + dir_of_layout=dir_of_layout, + image_filename_stem=Path(img_filename).stem) + #print("text region early -11 in %.1fs", time.time() - t0) + if os.path.exists(self.writer.output_filename): + if overwrite: + self.logger.warning("will overwrite existing output file '%s'", self.writer.output_filename) + else: + self.logger.warning("will skip input for existing output file '%s'", self.writer.output_filename) + continue + + pcgts = self.run_single() + self.logger.info("Job done in %.1fs", time.time() - t0) + self.writer.write_pagexml(pcgts) + + if dir_in: + self.logger.info("All jobs done in %.1fs", time.time() - t0_tot) + + def run_single(self): + t0 = time.time() + + self.logger.info(f"Processing file: {self.writer.image_filename}") + self.logger.info("Step 1/5: Image Enhancement") + + img_res, is_image_enhanced, num_col_classifier, _ = \ + self.run_enhancement() + + self.logger.info(f"Image: {self.image.shape[1]}x{self.image.shape[0]}, " + f"{self.dpi} DPI, {num_col_classifier} columns") + if is_image_enhanced: + self.logger.info("Enhancement applied") + + self.logger.info(f"Enhancement complete ({time.time() - t0:.1f}s)") + + + # Image Extraction Mode + self.logger.info("Step 2/5: Image Extraction Mode") + + _, _, _, polygons_of_images, \ + image_page, page_coord, cont_page = \ + self.get_regions_light_v_extract_only_images(img_res, num_col_classifier) + + pcgts = self.writer.build_pagexml_no_full_layout( + found_polygons_text_region=[], + page_coord=page_coord, + order_of_texts=[], + all_found_textline_polygons=[], + all_box_coord=[], + found_polygons_text_region_img=polygons_of_images, + found_polygons_marginals_left=[], + found_polygons_marginals_right=[], + all_found_textline_polygons_marginals_left=[], + all_found_textline_polygons_marginals_right=[], + all_box_coord_marginals_left=[], + all_box_coord_marginals_right=[], + slopes=[], + slopes_marginals_left=[], + slopes_marginals_right=[], + cont_page=cont_page, + polygons_seplines=[], + found_polygons_tables=[], + ) + if self.plotter: + self.plotter.write_images_into_directory(polygons_of_images, image_page) + + self.logger.info("Image extraction complete") + return pcgts diff --git a/src/eynollah/eynollah.py b/src/eynollah/eynollah.py index 4a83c0a..d089511 100644 --- a/src/eynollah/eynollah.py +++ b/src/eynollah/eynollah.py @@ -1,47 +1,47 @@ +""" +document layout analysis (segmentation) with output in PAGE-XML +""" # pylint: disable=no-member,invalid-name,line-too-long,missing-function-docstring,missing-class-docstring,too-many-branches # pylint: disable=too-many-locals,wrong-import-position,too-many-lines,too-many-statements,chained-comparison,fixme,broad-except,c-extension-no-member # pylint: disable=too-many-public-methods,too-many-arguments,too-many-instance-attributes,too-many-public-methods, # pylint: disable=consider-using-enumerate -""" -document layout analysis (segmentation) with output in PAGE-XML -""" +# FIXME: fix all of those... +# pyright: reportUnnecessaryTypeIgnoreComment=true +# pyright: reportPossiblyUnboundVariable=false +# pyright: reportOperatorIssue=false +# pyright: reportUnboundVariable=false +# pyright: reportArgumentType=false +# pyright: reportAttributeAccessIssue=false +# pyright: reportOptionalMemberAccess=false +# pyright: reportGeneralTypeIssues=false +# pyright: reportOptionalSubscript=false -# cannot use importlib.resources until we move to 3.9+ forimportlib.resources.files +import logging import sys -if sys.version_info < (3, 10): - import importlib_resources -else: - import importlib.resources as importlib_resources from difflib import SequenceMatcher as sq -from PIL import Image, ImageDraw, ImageFont import math import os -import sys import time -from typing import Dict, List, Optional, Tuple -import atexit -import warnings +from typing import Optional from functools import partial from pathlib import Path from multiprocessing import cpu_count import gc -import copy -import json from concurrent.futures import ProcessPoolExecutor -import xml.etree.ElementTree as ET import cv2 import numpy as np import shapely.affinity from scipy.signal import find_peaks from scipy.ndimage import gaussian_filter1d -from numba import cuda from skimage.morphology import skeletonize -from ocrd import OcrdPage -from ocrd_utils import getLogger, tf_disable_interactive_logs +from ocrd_utils import tf_disable_interactive_logs import statistics +tf_disable_interactive_logs() + +import tensorflow as tf try: import torch except ImportError: @@ -50,34 +50,18 @@ try: import matplotlib.pyplot as plt except ImportError: plt = None -try: - from transformers import TrOCRProcessor, VisionEncoderDecoderModel -except ImportError: - TrOCRProcessor = VisionEncoderDecoderModel = None - -#os.environ['CUDA_VISIBLE_DEVICES'] = '-1' -os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15 -tf_disable_interactive_logs() -import tensorflow as tf -from tensorflow.keras.models import load_model -tf.get_logger().setLevel("ERROR") -warnings.filterwarnings("ignore") -from tensorflow.keras import layers -from tensorflow.keras.layers import StringLookup +from .model_zoo import EynollahModelZoo from .utils.contour import ( filter_contours_area_of_image, filter_contours_area_of_image_tables, - find_contours_mean_y_diff, find_center_of_contours, find_new_features_of_contours, find_features_of_contours, get_text_region_boxes_by_given_contours, - get_textregion_contours_in_org_image, get_textregion_contours_in_org_image_light, return_contours_of_image, return_contours_of_interested_region, - return_contours_of_interested_textline, return_parent_contours, dilate_textregion_contours, dilate_textline_contours, @@ -87,29 +71,9 @@ from .utils.contour import ( make_intersection, ) from .utils.rotate import rotate_image -from .utils.utils_ocr import ( - return_start_and_end_of_common_text_of_textline_ocr_without_common_section, - return_textline_contour_with_added_box_coordinate, - preprocess_and_resize_image_for_ocrcnn_model, - return_textlines_split_if_needed, - decode_batch_predictions, - return_rnn_cnn_ocr_of_given_textlines, - fit_text_single_line, - break_curved_line_into_small_pieces_and_then_merge, - get_orientation_moments, - rotate_image_with_padding, - get_contours_and_bounding_boxes -) from .utils.separate_lines import ( - separate_lines_new2, return_deskew_slop, - do_work_of_slopes_new, do_work_of_slopes_new_curved, - do_work_of_slopes_new_light, -) -from .utils.drop_capitals import ( - adhere_drop_capital_region_into_corresponding_textline, - filter_small_drop_capitals_from_no_patch_layout ) from .utils.marginals import get_marginals from .utils.resize import resize_image @@ -117,22 +81,19 @@ from .utils.shm import share_ndarray from .utils import ( ensure_array, is_image_filename, - boosting_headers_by_longshot_region_segmentation, + isNaN, crop_image_inside_box, box2rect, - box2slice, find_num_col, otsu_copy_binary, - put_drop_out_from_only_drop_model, putt_bb_of_drop_capitals_of_model_in_patches_in_layout, - check_any_text_region_in_model_one_is_main_or_header, check_any_text_region_in_model_one_is_main_or_header_light, small_textlines_to_parent_adherence2, order_of_regions, find_number_of_columns_in_document, return_boxes_of_images_by_order_of_reading_new ) -from .utils.pil_cv2 import check_dpi, pil2cv +from .utils.pil_cv2 import pil2cv from .plot import EynollahPlotter from .writer import EynollahXmlWriter @@ -148,109 +109,47 @@ patch_size = 1 num_patches =21*21#14*14#28*28#14*14#28*28 -class Patches(layers.Layer): - def __init__(self, **kwargs): - super(Patches, self).__init__() - self.patch_size = patch_size - - def call(self, images): - batch_size = tf.shape(images)[0] - patches = tf.image.extract_patches( - images=images, - sizes=[1, self.patch_size, self.patch_size, 1], - strides=[1, self.patch_size, self.patch_size, 1], - rates=[1, 1, 1, 1], - padding="VALID", - ) - patch_dims = patches.shape[-1] - patches = tf.reshape(patches, [batch_size, -1, patch_dims]) - return patches - def get_config(self): - - config = super().get_config().copy() - config.update({ - 'patch_size': self.patch_size, - }) - return config - -class PatchEncoder(layers.Layer): - def __init__(self, **kwargs): - super(PatchEncoder, self).__init__() - self.num_patches = num_patches - self.projection = layers.Dense(units=projection_dim) - self.position_embedding = layers.Embedding( - input_dim=num_patches, output_dim=projection_dim - ) - - def call(self, patch): - positions = tf.range(start=0, limit=self.num_patches, delta=1) - encoded = self.projection(patch) + self.position_embedding(positions) - return encoded - def get_config(self): - - config = super().get_config().copy() - config.update({ - 'num_patches': self.num_patches, - 'projection': self.projection, - 'position_embedding': self.position_embedding, - }) - return config class Eynollah: def __init__( self, - dir_models : str, - model_versions: List[Tuple[str, str]] = [], - extract_only_images : bool =False, + *, + model_zoo: EynollahModelZoo, enable_plotting : bool = False, allow_enhancement : bool = False, curved_line : bool = False, - textline_light : bool = False, full_layout : bool = False, tables : bool = False, right2left : bool = False, input_binary : bool = False, allow_scaling : bool = False, headers_off : bool = False, - light_version : bool = False, ignore_page_extraction : bool = False, reading_order_machine_based : bool = False, - do_ocr : bool = False, - transformer_ocr: bool = False, - batch_size_ocr: Optional[int] = None, num_col_upper : Optional[int] = None, num_col_lower : Optional[int] = None, threshold_art_class_layout: Optional[float] = None, threshold_art_class_textline: Optional[float] = None, skip_layout_and_reading_order : bool = False, + logger : Optional[logging.Logger] = None, ): - self.logger = getLogger('eynollah') + self.logger = logger or logging.getLogger('eynollah') + self.model_zoo = model_zoo self.plotter = None - if skip_layout_and_reading_order: - textline_light = True - self.light_version = light_version self.reading_order_machine_based = reading_order_machine_based self.enable_plotting = enable_plotting self.allow_enhancement = allow_enhancement self.curved_line = curved_line - self.textline_light = textline_light self.full_layout = full_layout self.tables = tables self.right2left = right2left + # --input-binary sensible if image is very dark, if layout is not working. self.input_binary = input_binary self.allow_scaling = allow_scaling self.headers_off = headers_off - self.light_version = light_version - self.extract_only_images = extract_only_images self.ignore_page_extraction = ignore_page_extraction self.skip_layout_and_reading_order = skip_layout_and_reading_order - self.ocr = do_ocr - self.tr = transformer_ocr - if not batch_size_ocr: - self.b_s_ocr = 8 - else: - self.b_s_ocr = int(batch_size_ocr) if num_col_upper: self.num_col_upper = int(num_col_upper) else: @@ -262,12 +161,12 @@ class Eynollah: # for parallelization of CPU-intensive tasks: self.executor = ProcessPoolExecutor(max_workers=cpu_count()) - + if threshold_art_class_layout: self.threshold_art_class_layout = float(threshold_art_class_layout) else: self.threshold_art_class_layout = 0.1 - + if threshold_art_class_textline: self.threshold_art_class_textline = float(threshold_art_class_textline) else: @@ -280,95 +179,13 @@ class Eynollah: tf.config.experimental.set_memory_growth(device, True) except: self.logger.warning("no GPU device available") - + self.logger.info("Loading models...") - self.setup_models(dir_models, model_versions) + self.setup_models() self.logger.info(f"Model initialization complete ({time.time() - t_start:.1f}s)") - @staticmethod - def our_load_model(model_file, basedir=""): - if basedir: - model_file = os.path.join(basedir, model_file) - if model_file.endswith('.h5') and Path(model_file[:-3]).exists(): - # prefer SavedModel over HDF5 format if it exists - model_file = model_file[:-3] - try: - model = load_model(model_file, compile=False) - except: - model = load_model(model_file, compile=False, custom_objects={ - "PatchEncoder": PatchEncoder, "Patches": Patches}) - return model + def setup_models(self): - def setup_models(self, basedir: Path, model_versions: List[Tuple[str, str]] = []): - self.model_versions = { - "enhancement": "eynollah-enhancement_20210425", - "binarization": "eynollah-binarization_20210425", - "col_classifier": "eynollah-column-classifier_20210425", - "page": "model_eynollah_page_extraction_20250915", - #?: "eynollah-main-regions-aug-scaling_20210425", - "region": ( # early layout - "eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18" if self.extract_only_images else - "eynollah-main-regions_20220314" if self.light_version else - "eynollah-main-regions-ensembled_20210425"), - "region_p2": ( # early layout, non-light, 2nd part - "eynollah-main-regions-aug-rotation_20210425"), - "region_1_2": ( # early layout, light, 1-or-2-column - #"modelens_12sp_elay_0_3_4__3_6_n" - #"modelens_earlylayout_12spaltige_2_3_5_6_7_8" - #"modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18" - #"modelens_1_2_4_5_early_lay_1_2_spaltige" - #"model_3_eraly_layout_no_patches_1_2_spaltige" - "modelens_e_l_all_sp_0_1_2_3_4_171024"), - "region_fl_np": ( # full layout / no patches - #"modelens_full_lay_1_3_031124" - #"modelens_full_lay_13__3_19_241024" - #"model_full_lay_13_241024" - #"modelens_full_lay_13_17_231024" - #"modelens_full_lay_1_2_221024" - #"eynollah-full-regions-1column_20210425" - "modelens_full_lay_1__4_3_091124"), - "region_fl": ( # full layout / with patches - #"eynollah-full-regions-3+column_20210425" - ##"model_2_full_layout_new_trans" - #"modelens_full_lay_1_3_031124" - #"modelens_full_lay_13__3_19_241024" - #"model_full_lay_13_241024" - #"modelens_full_lay_13_17_231024" - #"modelens_full_lay_1_2_221024" - #"modelens_full_layout_24_till_28" - #"model_2_full_layout_new_trans" - "modelens_full_lay_1__4_3_091124"), - "reading_order": ( - #"model_mb_ro_aug_ens_11" - #"model_step_3200000_mb_ro" - #"model_ens_reading_order_machine_based" - #"model_mb_ro_aug_ens_8" - #"model_ens_reading_order_machine_based" - "model_eynollah_reading_order_20250824"), - "textline": ( - #"modelens_textline_1_4_16092024" - #"model_textline_ens_3_4_5_6_artificial" - #"modelens_textline_1_3_4_20240915" - #"model_textline_ens_3_4_5_6_artificial" - #"modelens_textline_9_12_13_14_15" - #"eynollah-textline_light_20210425" - "modelens_textline_0_1__2_4_16092024" if self.textline_light else - #"eynollah-textline_20210425" - "modelens_textline_0_1__2_4_16092024"), - "table": ( - None if not self.tables else - "modelens_table_0t4_201124" if self.light_version else - "eynollah-tables_20210319"), - "ocr": ( - None if not self.ocr else - "model_eynollah_ocr_trocr_20250919" if self.tr else - "model_eynollah_ocr_cnnrnn_20250930") - } - # override defaults from CLI - for key, val in model_versions: - assert key in self.model_versions, "unknown model category '%s'" % key - self.logger.warning("overriding default model %s version %s to %s", key, self.model_versions[key], val) - self.model_versions[key] = val # load models, depending on modes # (note: loading too many models can cause OOM on GPU/CUDA, # thus, we try set up the minimal configuration for the current mode) @@ -378,77 +195,44 @@ class Eynollah: "page", "region" ] - if not self.extract_only_images: - loadable.append("textline") - if self.light_version: - loadable.append("region_1_2") - else: - loadable.append("region_p2") - # if self.allow_enhancement:? - loadable.append("enhancement") - if self.full_layout: - loadable.append("region_fl_np") - #loadable.append("region_fl") - if self.reading_order_machine_based: - loadable.append("reading_order") - if self.tables: - loadable.append("table") + loadable.append(("textline")) + loadable.append("region_1_2") + if self.full_layout: + loadable.append("region_fl_np") + #loadable.append("region_fl") + if self.reading_order_machine_based: + loadable.append("reading_order") + if self.tables: + loadable.append(("table")) - self.models = {name: self.our_load_model(self.model_versions[name], basedir) - for name in loadable - } - - if self.ocr: - ocr_model_dir = os.path.join(basedir, self.model_versions["ocr"]) - if self.tr: - self.models["ocr"] = VisionEncoderDecoderModel.from_pretrained(ocr_model_dir) - if torch.cuda.is_available(): - self.logger.info("Using GPU acceleration") - self.device = torch.device("cuda:0") - else: - self.logger.info("Using CPU processing") - self.device = torch.device("cpu") - #self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten") - self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed") - else: - ocr_model = load_model(ocr_model_dir, compile=False) - self.models["ocr"] = tf.keras.models.Model( - ocr_model.get_layer(name = "image").input, - ocr_model.get_layer(name = "dense2").output) - - with open(os.path.join(ocr_model_dir, "characters_org.txt"), "r") as config_file: - characters = json.load(config_file) - # Mapping characters to integers. - char_to_num = StringLookup(vocabulary=list(characters), mask_token=None) - # Mapping integers back to original characters. - self.num_to_char = StringLookup( - vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True - ) + self.model_zoo.load_models(*loadable) def __del__(self): if hasattr(self, 'executor') and getattr(self, 'executor'): + assert self.executor self.executor.shutdown() self.executor = None - if hasattr(self, 'models') and getattr(self, 'models'): - for model_name in list(self.models): - if self.models[model_name]: - del self.models[model_name] + self.model_zoo.shutdown() + + @property + def device(self): + # TODO why here and why only for tr? + assert torch + if torch.cuda.is_available(): + self.logger.info("Using GPU acceleration") + return torch.device("cuda:0") + self.logger.info("Using CPU processing") + return torch.device("cpu") def cache_images(self, image_filename=None, image_pil=None, dpi=None): ret = {} t_c0 = time.time() if image_filename: ret['img'] = cv2.imread(image_filename) - if self.light_version: - self.dpi = 100 - else: - self.dpi = check_dpi(image_filename) + self.dpi = 100 else: ret['img'] = pil2cv(image_pil) - if self.light_version: - self.dpi = 100 - else: - self.dpi = check_dpi(image_pil) + self.dpi = 100 ret['img_grayscale'] = cv2.cvtColor(ret['img'], cv2.COLOR_BGR2GRAY) for prefix in ('', '_grayscale'): ret[f'img{prefix}_uint8'] = ret[f'img{prefix}'].astype(np.uint8) @@ -462,8 +246,7 @@ class Eynollah: self.writer = EynollahXmlWriter( dir_out=dir_out, image_filename=image_filename, - curved_line=self.curved_line, - textline_light = self.textline_light) + curved_line=self.curved_line) def imread(self, grayscale=False, uint8=True): key = 'img' @@ -473,14 +256,11 @@ class Eynollah: key += '_uint8' return self._imgs[key].copy() - def isNaN(self, num): - return num != num - def predict_enhancement(self, img): self.logger.debug("enter predict_enhancement") - img_height_model = self.models["enhancement"].layers[-1].output_shape[1] - img_width_model = self.models["enhancement"].layers[-1].output_shape[2] + img_height_model = self.model_zoo.get("enhancement").layers[-1].output_shape[1] + img_width_model = self.model_zoo.get("enhancement").layers[-1].output_shape[2] if img.shape[0] < img_height_model: img = cv2.resize(img, (img.shape[1], img_width_model), interpolation=cv2.INTER_NEAREST) if img.shape[1] < img_width_model: @@ -521,7 +301,7 @@ class Eynollah: index_y_d = img_h - img_height_model img_patch = img[np.newaxis, index_y_d:index_y_u, index_x_d:index_x_u, :] - label_p_pred = self.models["enhancement"].predict(img_patch, verbose=0) + label_p_pred = self.model_zoo.get("enhancement").predict(img_patch, verbose=0) seg = label_p_pred[0, :, :, :] * 255 if i == 0 and j == 0: @@ -645,27 +425,6 @@ class Eynollah: return img_new, num_column_is_classified - def calculate_width_height_by_columns_extract_only_images(self, img, num_col, width_early, label_p_pred): - self.logger.debug("enter calculate_width_height_by_columns") - if num_col == 1: - img_w_new = 700 - elif num_col == 2: - img_w_new = 900 - elif num_col == 3: - img_w_new = 1500 - elif num_col == 4: - img_w_new = 1800 - elif num_col == 5: - img_w_new = 2200 - elif num_col == 6: - img_w_new = 2500 - img_h_new = img_w_new * img.shape[0] // img.shape[1] - - img_new = resize_image(img, img_h_new, img_w_new) - num_column_is_classified = True - - return img_new, num_column_is_classified - def resize_image_with_column_classifier(self, is_image_enhanced, img_bin): self.logger.debug("enter resize_image_with_column_classifier") if self.input_binary: @@ -696,7 +455,7 @@ class Eynollah: img_in[0, :, :, 1] = img_1ch[:, :] img_in[0, :, :, 2] = img_1ch[:, :] - label_p_pred = self.models["col_classifier"].predict(img_in, verbose=0) + label_p_pred = self.model_zoo.get("col_classifier").predict(img_in, verbose=0) num_col = np.argmax(label_p_pred[0]) + 1 self.logger.info("Found %s columns (%s)", num_col, label_p_pred) @@ -708,13 +467,13 @@ class Eynollah: return img, img_new, is_image_enhanced - def resize_and_enhance_image_with_column_classifier(self, light_version): + def resize_and_enhance_image_with_column_classifier(self): self.logger.debug("enter resize_and_enhance_image_with_column_classifier") dpi = self.dpi self.logger.info("Detected %s DPI", dpi) if self.input_binary: img = self.imread() - prediction_bin = self.do_prediction(True, img, self.models["binarization"], n_batch_inference=5) + prediction_bin = self.do_prediction(True, img, self.model_zoo.get("binarization"), n_batch_inference=5) prediction_bin = 255 * (prediction_bin[:,:,0] == 0) prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2).astype(np.uint8) img= np.copy(prediction_bin) @@ -754,9 +513,9 @@ class Eynollah: img_in[0, :, :, 1] = img_1ch[:, :] img_in[0, :, :, 2] = img_1ch[:, :] - label_p_pred = self.models["col_classifier"].predict(img_in, verbose=0) + label_p_pred = self.model_zoo.get("col_classifier").predict(img_in, verbose=0) num_col = np.argmax(label_p_pred[0]) + 1 - + elif (self.num_col_upper and self.num_col_lower) and (self.num_col_upper!=self.num_col_lower): if self.input_binary: img_in = np.copy(img) @@ -775,7 +534,7 @@ class Eynollah: img_in[0, :, :, 1] = img_1ch[:, :] img_in[0, :, :, 2] = img_1ch[:, :] - label_p_pred = self.models["col_classifier"].predict(img_in, verbose=0) + label_p_pred = self.model_zoo.get("col_classifier").predict(img_in, verbose=0) num_col = np.argmax(label_p_pred[0]) + 1 if num_col > self.num_col_upper: @@ -789,33 +548,25 @@ class Eynollah: label_p_pred = [np.ones(6)] self.logger.info("Found %d columns (%s)", num_col, np.around(label_p_pred, decimals=5)) - if not self.extract_only_images: - if dpi < DPI_THRESHOLD: - if light_version and num_col in (1,2): - img_new, num_column_is_classified = self.calculate_width_height_by_columns_1_2( - img, num_col, width_early, label_p_pred) - else: - img_new, num_column_is_classified = self.calculate_width_height_by_columns( - img, num_col, width_early, label_p_pred) - if light_version: - image_res = np.copy(img_new) - else: - image_res = self.predict_enhancement(img_new) + if dpi < DPI_THRESHOLD: + if num_col in (1,2): + img_new, num_column_is_classified = self.calculate_width_height_by_columns_1_2( + img, num_col, width_early, label_p_pred) + else: + img_new, num_column_is_classified = self.calculate_width_height_by_columns( + img, num_col, width_early, label_p_pred) + image_res = np.copy(img_new) + is_image_enhanced = True + else: + if num_col in (1,2): + img_new, num_column_is_classified = self.calculate_width_height_by_columns_1_2( + img, num_col, width_early, label_p_pred) + image_res = np.copy(img_new) is_image_enhanced = True else: - if light_version and num_col in (1,2): - img_new, num_column_is_classified = self.calculate_width_height_by_columns_1_2( - img, num_col, width_early, label_p_pred) - image_res = np.copy(img_new) - is_image_enhanced = True - else: - num_column_is_classified = True - image_res = np.copy(img) - is_image_enhanced = False - else: - num_column_is_classified = True - image_res = np.copy(img) - is_image_enhanced = False + num_column_is_classified = True + image_res = np.copy(img) + is_image_enhanced = False self.logger.debug("exit resize_and_enhance_image_with_column_classifier") return is_image_enhanced, img, image_res, num_col, num_column_is_classified, img_bin @@ -830,8 +581,8 @@ class Eynollah: self.img_hight_int = int(self.image.shape[0] * scale) self.img_width_int = int(self.image.shape[1] * scale) - self.scale_y = self.img_hight_int / float(self.image.shape[0]) - self.scale_x = self.img_width_int / float(self.image.shape[1]) + self.scale_y: float = self.img_hight_int / float(self.image.shape[0]) + self.scale_x: float = self.img_width_int / float(self.image.shape[1]) self.image = resize_image(self.image, self.img_hight_int, self.img_width_int) @@ -896,12 +647,12 @@ class Eynollah: seg_art[seg_art0] =1 - + skeleton_art = skeletonize(seg_art) skeleton_art = skeleton_art*1 seg[skeleton_art==1]=2 - + if thresholding_for_fl_light_version: seg_header = label_p_pred[0,:,:,2] @@ -909,7 +660,7 @@ class Eynollah: seg_header[seg_header>0] =1 seg[seg_header==1]=2 - + seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2) prediction_true = resize_image(seg_color, img_h_page, img_w_page).astype(np.uint8) return prediction_true @@ -998,7 +749,7 @@ class Eynollah: indexer_inside_batch = 0 for i_batch, j_batch in zip(list_i_s, list_j_s): seg_in = seg[indexer_inside_batch] - + if thresholding_for_artificial_class_in_light_version: seg_in_art = seg_art[indexer_inside_batch] @@ -1019,7 +770,7 @@ class Eynollah: index_x_d_in + 0:index_x_u_in - margin, 1] = \ seg_in_art[0:-margin or None, 0:-margin or None] - + elif i_batch == nxf - 1 and j_batch == nyf - 1: prediction_true[index_y_d_in + margin:index_y_u_in - 0, index_x_d_in + margin:index_x_u_in - 0] = \ @@ -1031,7 +782,7 @@ class Eynollah: index_x_d_in + margin:index_x_u_in - 0, 1] = \ seg_in_art[margin:, margin:] - + elif i_batch == 0 and j_batch == nyf - 1: prediction_true[index_y_d_in + margin:index_y_u_in - 0, index_x_d_in + 0:index_x_u_in - margin] = \ @@ -1043,7 +794,7 @@ class Eynollah: index_x_d_in + 0:index_x_u_in - margin, 1] = \ seg_in_art[margin:, 0:-margin or None] - + elif i_batch == nxf - 1 and j_batch == 0: prediction_true[index_y_d_in + 0:index_y_u_in - margin, index_x_d_in + margin:index_x_u_in - 0] = \ @@ -1055,7 +806,7 @@ class Eynollah: index_x_d_in + margin:index_x_u_in - 0, 1] = \ seg_in_art[0:-margin or None, margin:] - + elif i_batch == 0 and j_batch != 0 and j_batch != nyf - 1: prediction_true[index_y_d_in + margin:index_y_u_in - margin, index_x_d_in + 0:index_x_u_in - margin] = \ @@ -1067,7 +818,7 @@ class Eynollah: index_x_d_in + 0:index_x_u_in - margin, 1] = \ seg_in_art[margin:-margin or None, 0:-margin or None] - + elif i_batch == nxf - 1 and j_batch != 0 and j_batch != nyf - 1: prediction_true[index_y_d_in + margin:index_y_u_in - margin, index_x_d_in + margin:index_x_u_in - 0] = \ @@ -1079,7 +830,7 @@ class Eynollah: index_x_d_in + margin:index_x_u_in - 0, 1] = \ seg_in_art[margin:-margin or None, margin:] - + elif i_batch != 0 and i_batch != nxf - 1 and j_batch == 0: prediction_true[index_y_d_in + 0:index_y_u_in - margin, index_x_d_in + margin:index_x_u_in - margin] = \ @@ -1091,7 +842,7 @@ class Eynollah: index_x_d_in + margin:index_x_u_in - margin, 1] = \ seg_in_art[0:-margin or None, margin:-margin or None] - + elif i_batch != 0 and i_batch != nxf - 1 and j_batch == nyf - 1: prediction_true[index_y_d_in + margin:index_y_u_in - 0, index_x_d_in + margin:index_x_u_in - margin] = \ @@ -1103,7 +854,7 @@ class Eynollah: index_x_d_in + margin:index_x_u_in - margin, 1] = \ seg_in_art[margin:, margin:-margin or None] - + else: prediction_true[index_y_d_in + margin:index_y_u_in - margin, index_x_d_in + margin:index_x_u_in - margin] = \ @@ -1129,16 +880,16 @@ class Eynollah: img_patch[:] = 0 prediction_true = prediction_true.astype(np.uint8) - + if thresholding_for_artificial_class_in_light_version: kernel_min = np.ones((3, 3), np.uint8) prediction_true[:,:,0][prediction_true[:,:,0]==2] = 0 - + skeleton_art = skeletonize(prediction_true[:,:,1]) skeleton_art = skeleton_art*1 - + skeleton_art = skeleton_art.astype('uint8') - + skeleton_art = cv2.dilate(skeleton_art, kernel_min, iterations=1) prediction_true[:,:,0][skeleton_art==1]=2 @@ -1146,136 +897,6 @@ class Eynollah: gc.collect() return prediction_true - def do_padding_with_scale(self, img, scale): - h_n = int(img.shape[0]*scale) - w_n = int(img.shape[1]*scale) - - channel0_avg = int( np.mean(img[:,:,0]) ) - channel1_avg = int( np.mean(img[:,:,1]) ) - channel2_avg = int( np.mean(img[:,:,2]) ) - - h_diff = img.shape[0] - h_n - w_diff = img.shape[1] - w_n - - h_start = int(0.5 * h_diff) - w_start = int(0.5 * w_diff) - - img_res = resize_image(img, h_n, w_n) - #label_res = resize_image(label, h_n, w_n) - - img_scaled_padded = np.copy(img) - - #label_scaled_padded = np.zeros(label.shape) - - img_scaled_padded[:,:,0] = channel0_avg - img_scaled_padded[:,:,1] = channel1_avg - img_scaled_padded[:,:,2] = channel2_avg - - img_scaled_padded[h_start:h_start+h_n, w_start:w_start+w_n,:] = img_res[:,:,:] - #label_scaled_padded[h_start:h_start+h_n, w_start:w_start+w_n,:] = label_res[:,:,:] - - return img_scaled_padded#, label_scaled_padded - - def do_prediction_new_concept_scatter_nd( - self, patches, img, model, - n_batch_inference=1, marginal_of_patch_percent=0.1, - thresholding_for_some_classes_in_light_version=False, - thresholding_for_artificial_class_in_light_version=False): - - self.logger.debug("enter do_prediction_new_concept") - img_height_model = model.layers[-1].output_shape[1] - img_width_model = model.layers[-1].output_shape[2] - - if not patches: - img_h_page = img.shape[0] - img_w_page = img.shape[1] - img = img / 255.0 - img = resize_image(img, img_height_model, img_width_model) - - label_p_pred = model.predict(img[np.newaxis], verbose=0) - seg = np.argmax(label_p_pred, axis=3)[0] - - if thresholding_for_artificial_class_in_light_version: - #seg_text = label_p_pred[0,:,:,1] - #seg_text[seg_text<0.2] =0 - #seg_text[seg_text>0] =1 - #seg[seg_text==1]=1 - - seg_art = label_p_pred[0,:,:,4] - seg_art[seg_art<0.2] =0 - seg_art[seg_art>0] =1 - seg[seg_art==1]=4 - - seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2) - prediction_true = resize_image(seg_color, img_h_page, img_w_page).astype(np.uint8) - return prediction_true - - if img.shape[0] < img_height_model: - img = resize_image(img, img_height_model, img.shape[1]) - if img.shape[1] < img_width_model: - img = resize_image(img, img.shape[0], img_width_model) - - self.logger.debug("Patch size: %sx%s", img_height_model, img_width_model) - ##margin = int(marginal_of_patch_percent * img_height_model) - #width_mid = img_width_model - 2 * margin - #height_mid = img_height_model - 2 * margin - img = img / 255.0 - img = img.astype(np.float16) - img_h = img.shape[0] - img_w = img.shape[1] - - stride_x = img_width_model - 100 - stride_y = img_height_model - 100 - - one_tensor = tf.ones_like(img) - img_patches, one_patches = tf.image.extract_patches( - images=[img, one_tensor], - sizes=[1, img_height_model, img_width_model, 1], - strides=[1, stride_y, stride_x, 1], - rates=[1, 1, 1, 1], - padding='SAME') - img_patches = tf.squeeze(img_patches) - one_patches = tf.squeeze(one_patches) - img_patches_resh = tf.reshape(img_patches, shape=(img_patches.shape[0] * img_patches.shape[1], - img_height_model, img_width_model, 3)) - pred_patches = model.predict(img_patches_resh, batch_size=n_batch_inference) - one_patches = tf.reshape(one_patches, shape=(img_patches.shape[0] * img_patches.shape[1], - img_height_model, img_width_model, 3)) - x = tf.range(img.shape[1]) - y = tf.range(img.shape[0]) - x, y = tf.meshgrid(x, y) - indices = tf.stack([y, x], axis=-1) - - indices_patches = tf.image.extract_patches( - images=tf.expand_dims(indices, axis=0), - sizes=[1, img_height_model, img_width_model, 1], - strides=[1, stride_y, stride_x, 1], - rates=[1, 1, 1, 1], - padding='SAME') - indices_patches = tf.squeeze(indices_patches) - indices_patches = tf.reshape(indices_patches, shape=(img_patches.shape[0] * img_patches.shape[1], - img_height_model, img_width_model, 2)) - margin_y = int( 0.5 * (img_height_model - stride_y) ) - margin_x = int( 0.5 * (img_width_model - stride_x) ) - - mask_margin = np.zeros((img_height_model, img_width_model)) - mask_margin[margin_y:img_height_model - margin_y, - margin_x:img_width_model - margin_x] = 1 - - indices_patches_array = indices_patches.numpy() - for i in range(indices_patches_array.shape[0]): - indices_patches_array[i,:,:,0] = indices_patches_array[i,:,:,0]*mask_margin - indices_patches_array[i,:,:,1] = indices_patches_array[i,:,:,1]*mask_margin - - reconstructed = tf.scatter_nd( - indices=indices_patches_array, - updates=pred_patches, - shape=(img.shape[0], img.shape[1], pred_patches.shape[-1])).numpy() - - prediction_true = np.argmax(reconstructed, axis=2).astype(np.uint8) - gc.collect() - return np.repeat(prediction_true[:, :, np.newaxis], 3, axis=2) - def do_prediction_new_concept( self, patches, img, model, n_batch_inference=1, marginal_of_patch_percent=0.1, @@ -1299,7 +920,7 @@ class Eynollah: seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2) prediction_true = resize_image(seg_color, img_h_page, img_w_page).astype(np.uint8) - + if thresholding_for_artificial_class_in_light_version: kernel_min = np.ones((3, 3), np.uint8) seg_art = label_p_pred[0,:,:,4] @@ -1307,18 +928,18 @@ class Eynollah: seg_art[seg_art>0] =1 #seg[seg_art==1]=4 seg_art = resize_image(seg_art, img_h_page, img_w_page).astype(np.uint8) - + prediction_true[:,:,0][prediction_true[:,:,0]==4] = 0 - + skeleton_art = skeletonize(seg_art) skeleton_art = skeleton_art*1 - + skeleton_art = skeleton_art.astype('uint8') - + skeleton_art = cv2.dilate(skeleton_art, kernel_min, iterations=1) - + prediction_true[:,:,0][skeleton_art==1] = 4 - + return prediction_true , resize_image(label_p_pred[0, :, :, 1] , img_h_page, img_w_page) if img.shape[0] < img_height_model: @@ -1411,7 +1032,7 @@ class Eynollah: indexer_inside_batch = 0 for i_batch, j_batch in zip(list_i_s, list_j_s): seg_in = seg[indexer_inside_batch] - + if (thresholding_for_artificial_class_in_light_version or thresholding_for_some_classes_in_light_version): seg_in_art = seg_art[indexer_inside_batch] @@ -1439,7 +1060,7 @@ class Eynollah: index_x_d_in + 0:index_x_u_in - margin, 1] = \ seg_in_art[0:-margin or None, 0:-margin or None] - + elif i_batch == nxf - 1 and j_batch == nyf - 1: prediction_true[index_y_d_in + margin:index_y_u_in - 0, index_x_d_in + margin:index_x_u_in - 0] = \ @@ -1457,7 +1078,7 @@ class Eynollah: index_x_d_in + margin:index_x_u_in - 0, 1] = \ seg_in_art[margin:, margin:] - + elif i_batch == 0 and j_batch == nyf - 1: prediction_true[index_y_d_in + margin:index_y_u_in - 0, index_x_d_in + 0:index_x_u_in - margin] = \ @@ -1469,14 +1090,14 @@ class Eynollah: label_p_pred[0, margin:, 0:-margin or None, 1] - + if (thresholding_for_artificial_class_in_light_version or thresholding_for_some_classes_in_light_version): prediction_true[index_y_d_in + margin:index_y_u_in - 0, index_x_d_in + 0:index_x_u_in - margin, 1] = \ seg_in_art[margin:, 0:-margin or None] - + elif i_batch == nxf - 1 and j_batch == 0: prediction_true[index_y_d_in + 0:index_y_u_in - margin, index_x_d_in + margin:index_x_u_in - 0] = \ @@ -1494,7 +1115,7 @@ class Eynollah: index_x_d_in + margin:index_x_u_in - 0, 1] = \ seg_in_art[0:-margin or None, margin:] - + elif i_batch == 0 and j_batch != 0 and j_batch != nyf - 1: prediction_true[index_y_d_in + margin:index_y_u_in - margin, index_x_d_in + 0:index_x_u_in - margin] = \ @@ -1593,29 +1214,29 @@ class Eynollah: img_patch[:] = 0 prediction_true = prediction_true.astype(np.uint8) - + if thresholding_for_artificial_class_in_light_version: kernel_min = np.ones((3, 3), np.uint8) prediction_true[:,:,0][prediction_true[:,:,0]==2] = 0 - + skeleton_art = skeletonize(prediction_true[:,:,1]) skeleton_art = skeleton_art*1 - + skeleton_art = skeleton_art.astype('uint8') - + skeleton_art = cv2.dilate(skeleton_art, kernel_min, iterations=1) prediction_true[:,:,0][skeleton_art==1]=2 - + if thresholding_for_some_classes_in_light_version: kernel_min = np.ones((3, 3), np.uint8) prediction_true[:,:,0][prediction_true[:,:,0]==4] = 0 - + skeleton_art = skeletonize(prediction_true[:,:,1]) skeleton_art = skeleton_art*1 - + skeleton_art = skeleton_art.astype('uint8') - + skeleton_art = cv2.dilate(skeleton_art, kernel_min, iterations=1) prediction_true[:,:,0][skeleton_art==1]=4 @@ -1627,7 +1248,7 @@ class Eynollah: cont_page = [] if not self.ignore_page_extraction: img = np.copy(self.image)#cv2.GaussianBlur(self.image, (5, 5), 0) - img_page_prediction = self.do_prediction(False, img, self.models["page"]) + img_page_prediction = self.do_prediction(False, img, self.model_zoo.get("page")) imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY) _, thresh = cv2.threshold(imgray, 0, 255, 0) ##thresh = cv2.dilate(thresh, KERNEL, iterations=3) @@ -1675,7 +1296,7 @@ class Eynollah: else: img = self.imread() img = cv2.GaussianBlur(img, (5, 5), 0) - img_page_prediction = self.do_prediction(False, img, self.models["page"]) + img_page_prediction = self.do_prediction(False, img, self.model_zoo.get("page")) imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY) _, thresh = cv2.threshold(imgray, 0, 255, 0) @@ -1701,11 +1322,10 @@ class Eynollah: self.logger.debug("enter extract_text_regions") img_height_h = img.shape[0] img_width_h = img.shape[1] - model_region = self.models["region_fl"] if patches else self.models["region_fl_np"] + model_region = self.model_zoo.get("region_fl") if patches else self.model_zoo.get("region_fl_np") - if self.light_version: - thresholding_for_fl_light_version = True - elif not patches: + thresholding_for_fl_light_version = True + if not patches: img = otsu_copy_binary(img).astype(np.uint8) prediction_regions = None thresholding_for_fl_light_version = False @@ -1736,64 +1356,22 @@ class Eynollah: self.logger.debug("enter extract_text_regions") img_height_h = img.shape[0] img_width_h = img.shape[1] - model_region = self.models["region_fl"] if patches else self.models["region_fl_np"] - - if not patches: - img = otsu_copy_binary(img) - img = img.astype(np.uint8) - prediction_regions2 = None - elif cols: - if cols == 1: - img_height_new = int(img_height_h * 0.7) - img_width_new = int(img_width_h * 0.7) - elif cols == 2: - img_height_new = int(img_height_h * 0.4) - img_width_new = int(img_width_h * 0.4) - else: - img_height_new = int(img_height_h * 0.3) - img_width_new = int(img_width_h * 0.3) - img2 = otsu_copy_binary(img) - img2 = img2.astype(np.uint8) - img2 = resize_image(img2, img_height_new, img_width_new) - prediction_regions2 = self.do_prediction(patches, img2, model_region, marginal_of_patch_percent=0.1) - prediction_regions2 = resize_image(prediction_regions2, img_height_h, img_width_h) - - img = otsu_copy_binary(img).astype(np.uint8) - if cols == 1: - img = resize_image(img, int(img_height_h * 0.5), int(img_width_h * 0.5)).astype(np.uint8) - elif cols == 2 and img_width_h >= 2000: - img = resize_image(img, int(img_height_h * 0.9), int(img_width_h * 0.9)).astype(np.uint8) - elif cols == 3 and ((self.scale_x == 1 and img_width_h > 3000) or - (self.scale_x != 1 and img_width_h > 2800)): - img = resize_image(img, 2800 * img_height_h // img_width_h, 2800).astype(np.uint8) - elif cols == 4 and ((self.scale_x == 1 and img_width_h > 4000) or - (self.scale_x != 1 and img_width_h > 3700)): - img = resize_image(img, 3700 * img_height_h // img_width_h, 3700).astype(np.uint8) - elif cols == 4: - img = resize_image(img, int(img_height_h * 0.9), int(img_width_h * 0.9)).astype(np.uint8) - elif cols == 5 and self.scale_x == 1 and img_width_h > 5000: - img = resize_image(img, int(img_height_h * 0.7), int(img_width_h * 0.7)).astype(np.uint8) - elif cols == 5: - img = resize_image(img, int(img_height_h * 0.9), int(img_width_h * 0.9)).astype(np.uint8) - elif img_width_h > 5600: - img = resize_image(img, 5600 * img_height_h // img_width_h, 5600).astype(np.uint8) - else: - img = resize_image(img, int(img_height_h * 0.9), int(img_width_h * 0.9)).astype(np.uint8) + model_region = self.model_zoo.get("region_fl") if patches else self.model_zoo.get("region_fl_np") prediction_regions = self.do_prediction(patches, img, model_region, marginal_of_patch_percent=0.1) prediction_regions = resize_image(prediction_regions, img_height_h, img_width_h) self.logger.debug("exit extract_text_regions") - return prediction_regions, prediction_regions2 - + return prediction_regions, None + def get_textlines_of_a_textregion_sorted(self, textlines_textregion, cx_textline, cy_textline, w_h_textline): N = len(cy_textline) if N==0: return [] - + diff_cy = np.abs( np.diff(sorted(cy_textline)) ) diff_cx = np.abs(np.diff(sorted(cx_textline)) ) - + if len(diff_cy)>0: mean_y_diff = np.mean(diff_cy) mean_x_diff = np.mean(diff_cx) @@ -1805,13 +1383,13 @@ class Eynollah: mean_x_diff = 0 count_hor = 1 count_ver = 0 - + if count_hor >= count_ver: row_threshold = mean_y_diff / 1.5 if mean_y_diff > 0 else 10 indices_sorted_by_y = sorted(range(N), key=lambda i: cy_textline[i]) - + rows = [] current_row = [indices_sorted_by_y[0]] for i in range(1, N): @@ -1880,7 +1458,7 @@ class Eynollah: cx_textline_in, cy_textline_in, w_h_textlines_in) - + all_found_textline_polygons.append(textlines_ins)#[::-1]) slopes.append(slope_deskew) @@ -1891,40 +1469,6 @@ class Eynollah: all_box_coord, slopes) - def get_slopes_and_deskew_new_light(self, contours, contours_par, textline_mask_tot, boxes, slope_deskew): - if not len(contours): - return [], [], [] - self.logger.debug("enter get_slopes_and_deskew_new_light") - with share_ndarray(textline_mask_tot) as textline_mask_tot_shared: - results = self.executor.map(partial(do_work_of_slopes_new_light, - textline_mask_tot_ea=textline_mask_tot_shared, - slope_deskew=slope_deskew, - textline_light=self.textline_light, - logger=self.logger,), - boxes, contours, contours_par) - results = list(results) # exhaust prior to release - #textline_polygons, box_coord, slopes = zip(*results) - self.logger.debug("exit get_slopes_and_deskew_new_light") - return tuple(zip(*results)) - - def get_slopes_and_deskew_new(self, contours, contours_par, textline_mask_tot, boxes, slope_deskew): - if not len(contours): - return [], [], [] - self.logger.debug("enter get_slopes_and_deskew_new") - with share_ndarray(textline_mask_tot) as textline_mask_tot_shared: - results = self.executor.map(partial(do_work_of_slopes_new, - textline_mask_tot_ea=textline_mask_tot_shared, - slope_deskew=slope_deskew, - MAX_SLOPE=MAX_SLOPE, - KERNEL=KERNEL, - logger=self.logger, - plotter=self.plotter,), - boxes, contours, contours_par) - results = list(results) # exhaust prior to release - #textline_polygons, box_coord, slopes = zip(*results) - self.logger.debug("exit get_slopes_and_deskew_new") - return tuple(zip(*results)) - def get_slopes_and_deskew_new_curved(self, contours_par, textline_mask_tot, boxes, mask_texts_only, num_col, scale_par, slope_deskew): if not len(contours_par): @@ -1932,6 +1476,7 @@ class Eynollah: self.logger.debug("enter get_slopes_and_deskew_new_curved") with share_ndarray(textline_mask_tot) as textline_mask_tot_shared: with share_ndarray(mask_texts_only) as mask_texts_only_shared: + assert self.executor results = self.executor.map(partial(do_work_of_slopes_new_curved, textline_mask_tot_ea=textline_mask_tot_shared, mask_texts_only=mask_texts_only_shared, @@ -1957,203 +1502,33 @@ class Eynollah: img_w = img_org.shape[1] img = resize_image(img_org, int(img_org.shape[0] * scaler_h), int(img_org.shape[1] * scaler_w)) - prediction_textline = self.do_prediction(use_patches, img, self.models["textline"], + prediction_textline = self.do_prediction(use_patches, img, self.model_zoo.get("textline"), marginal_of_patch_percent=0.15, n_batch_inference=3, - thresholding_for_artificial_class_in_light_version=self.textline_light, threshold_art_class_textline=self.threshold_art_class_textline) - #if not self.textline_light: - #if num_col_classifier==1: - #prediction_textline_nopatch = self.do_prediction(False, img, self.models["textline"]) - #prediction_textline[:,:][prediction_textline_nopatch[:,:]==0] = 0 prediction_textline = resize_image(prediction_textline, img_h, img_w) textline_mask_tot_ea_art = (prediction_textline[:,:]==2)*1 old_art = np.copy(textline_mask_tot_ea_art) - if not self.textline_light: - textline_mask_tot_ea_art = textline_mask_tot_ea_art.astype('uint8') - #textline_mask_tot_ea_art = cv2.dilate(textline_mask_tot_ea_art, KERNEL, iterations=1) - prediction_textline[:,:][textline_mask_tot_ea_art[:,:]==1]=2 - """ - else: - textline_mask_tot_ea_art = textline_mask_tot_ea_art.astype('uint8') - hor_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (8, 1)) - - kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)) - ##cv2.imwrite('textline_mask_tot_ea_art.png', textline_mask_tot_ea_art) - textline_mask_tot_ea_art = cv2.dilate(textline_mask_tot_ea_art, hor_kernel, iterations=1) - - ###cv2.imwrite('dil_textline_mask_tot_ea_art.png', dil_textline_mask_tot_ea_art) - - textline_mask_tot_ea_art = textline_mask_tot_ea_art.astype('uint8') - - #print(np.shape(dil_textline_mask_tot_ea_art), np.unique(dil_textline_mask_tot_ea_art), 'dil_textline_mask_tot_ea_art') - tsk = time.time() - skeleton_art_textline = skeletonize(textline_mask_tot_ea_art[:,:,0]) - - skeleton_art_textline = skeleton_art_textline*1 - - skeleton_art_textline = skeleton_art_textline.astype('uint8') - - skeleton_art_textline = cv2.dilate(skeleton_art_textline, kernel, iterations=1) - - #print(np.unique(skeleton_art_textline), np.shape(skeleton_art_textline)) - - #print(skeleton_art_textline, np.unique(skeleton_art_textline)) - - #cv2.imwrite('skeleton_art_textline.png', skeleton_art_textline) - - prediction_textline[:,:,0][skeleton_art_textline[:,:]==1]=2 - - #cv2.imwrite('prediction_textline1.png', prediction_textline[:,:,0]) - - ##hor_kernel2 = cv2.getStructuringElement(cv2.MORPH_RECT, (4, 1)) - ##ver_kernel2 = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 3)) - ##textline_mask_tot_ea_main = (prediction_textline[:,:]==1)*1 - ##textline_mask_tot_ea_main = textline_mask_tot_ea_main.astype('uint8') - - ##dil_textline_mask_tot_ea_main = cv2.erode(textline_mask_tot_ea_main, ver_kernel2, iterations=1) - - ##dil_textline_mask_tot_ea_main = cv2.dilate(textline_mask_tot_ea_main, hor_kernel2, iterations=1) - - ##dil_textline_mask_tot_ea_main = cv2.dilate(textline_mask_tot_ea_main, ver_kernel2, iterations=1) - - ##prediction_textline[:,:][dil_textline_mask_tot_ea_main[:,:]==1]=1 - - """ - textline_mask_tot_ea_lines = (prediction_textline[:,:]==1)*1 textline_mask_tot_ea_lines = textline_mask_tot_ea_lines.astype('uint8') - if not self.textline_light: - textline_mask_tot_ea_lines = cv2.dilate(textline_mask_tot_ea_lines, KERNEL, iterations=1) prediction_textline[:,:][textline_mask_tot_ea_lines[:,:]==1]=1 - if not self.textline_light: - prediction_textline[:,:][old_art[:,:]==1]=2 - + #cv2.imwrite('prediction_textline2.png', prediction_textline[:,:,0]) - prediction_textline_longshot = self.do_prediction(False, img, self.models["textline"]) + prediction_textline_longshot = self.do_prediction(False, img, self.model_zoo.get("textline")) prediction_textline_longshot_true_size = resize_image(prediction_textline_longshot, img_h, img_w) - - + + #cv2.imwrite('prediction_textline.png', prediction_textline[:,:,0]) #sys.exit() self.logger.debug('exit textline_contours') return ((prediction_textline[:, :, 0]==1).astype(np.uint8), (prediction_textline_longshot_true_size[:, :, 0]==1).astype(np.uint8)) - - def get_regions_light_v_extract_only_images(self,img,is_image_enhanced, num_col_classifier): - self.logger.debug("enter get_regions_extract_images_only") - erosion_hurts = False - img_org = np.copy(img) - img_height_h = img_org.shape[0] - img_width_h = img_org.shape[1] - - if num_col_classifier == 1: - img_w_new = 700 - elif num_col_classifier == 2: - img_w_new = 900 - elif num_col_classifier == 3: - img_w_new = 1500 - elif num_col_classifier == 4: - img_w_new = 1800 - elif num_col_classifier == 5: - img_w_new = 2200 - elif num_col_classifier == 6: - img_w_new = 2500 - img_h_new = int(img.shape[0] / float(img.shape[1]) * img_w_new) - img_resized = resize_image(img,img_h_new, img_w_new ) - - prediction_regions_org, _ = self.do_prediction_new_concept(True, img_resized, self.models["region"]) - - prediction_regions_org = resize_image(prediction_regions_org,img_height_h, img_width_h ) - image_page, page_coord, cont_page = self.extract_page() - - prediction_regions_org = prediction_regions_org[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]] - prediction_regions_org=prediction_regions_org[:,:,0] - - mask_seps_only = (prediction_regions_org[:,:] == 3)*1 - mask_texts_only = (prediction_regions_org[:,:] ==1)*1 - mask_images_only=(prediction_regions_org[:,:] ==2)*1 - - polygons_seplines, hir_seplines = return_contours_of_image(mask_seps_only) - polygons_seplines = filter_contours_area_of_image( - mask_seps_only, polygons_seplines, hir_seplines, max_area=1, min_area=0.00001, dilate=1) - - polygons_of_only_texts = return_contours_of_interested_region(mask_texts_only,1,0.00001) - polygons_of_only_seps = return_contours_of_interested_region(mask_seps_only,1,0.00001) - - text_regions_p_true = np.zeros(prediction_regions_org.shape) - text_regions_p_true = cv2.fillPoly(text_regions_p_true, pts = polygons_of_only_seps, color=(3,3,3)) - - text_regions_p_true[:,:][mask_images_only[:,:] == 1] = 2 - text_regions_p_true = cv2.fillPoly(text_regions_p_true, pts=polygons_of_only_texts, color=(1,1,1)) - - text_regions_p_true[text_regions_p_true.shape[0]-15:text_regions_p_true.shape[0], :] = 0 - text_regions_p_true[:, text_regions_p_true.shape[1]-15:text_regions_p_true.shape[1]] = 0 - - ##polygons_of_images = return_contours_of_interested_region(text_regions_p_true, 2, 0.0001) - polygons_of_images = return_contours_of_interested_region(text_regions_p_true, 2, 0.001) - image_boundary_of_doc = np.zeros((text_regions_p_true.shape[0], text_regions_p_true.shape[1])) - - ###image_boundary_of_doc[:6, :] = 1 - ###image_boundary_of_doc[text_regions_p_true.shape[0]-6:text_regions_p_true.shape[0], :] = 1 - - ###image_boundary_of_doc[:, :6] = 1 - ###image_boundary_of_doc[:, text_regions_p_true.shape[1]-6:text_regions_p_true.shape[1]] = 1 - - polygons_of_images_fin = [] - for ploy_img_ind in polygons_of_images: - """ - test_poly_image = np.zeros((text_regions_p_true.shape[0], text_regions_p_true.shape[1])) - test_poly_image = cv2.fillPoly(test_poly_image, pts=[ploy_img_ind], color=(1,1,1)) - - test_poly_image = test_poly_image + image_boundary_of_doc - test_poly_image_intersected_area = ( test_poly_image[:,:]==2 )*1 - - test_poly_image_intersected_area = test_poly_image_intersected_area.sum() - - if test_poly_image_intersected_area==0: - ##polygons_of_images_fin.append(ploy_img_ind) - - box = cv2.boundingRect(ploy_img_ind) - page_coord_img = box2rect(box) - # cont_page.append(np.array([[page_coord[2], page_coord[0]], - # [page_coord[3], page_coord[0]], - # [page_coord[3], page_coord[1]], - # [page_coord[2], page_coord[1]]])) - polygons_of_images_fin.append(np.array([[page_coord_img[2], page_coord_img[0]], - [page_coord_img[3], page_coord_img[0]], - [page_coord_img[3], page_coord_img[1]], - [page_coord_img[2], page_coord_img[1]]]) ) - """ - box = x, y, w, h = cv2.boundingRect(ploy_img_ind) - if h < 150 or w < 150: - pass - else: - page_coord_img = box2rect(box) - # cont_page.append(np.array([[page_coord[2], page_coord[0]], - # [page_coord[3], page_coord[0]], - # [page_coord[3], page_coord[1]], - # [page_coord[2], page_coord[1]]])) - polygons_of_images_fin.append(np.array([[page_coord_img[2], page_coord_img[0]], - [page_coord_img[3], page_coord_img[0]], - [page_coord_img[3], page_coord_img[1]], - [page_coord_img[2], page_coord_img[1]]])) - - self.logger.debug("exit get_regions_extract_images_only") - return (text_regions_p_true, - erosion_hurts, - polygons_seplines, - polygons_of_images_fin, - image_page, - page_coord, - cont_page) - def get_regions_light_v(self,img,is_image_enhanced, num_col_classifier): self.logger.debug("enter get_regions_light_v") t_in = time.time() @@ -2184,7 +1559,7 @@ class Eynollah: #if self.input_binary: #img_bin = np.copy(img_resized) ###if (not self.input_binary and self.full_layout) or (not self.input_binary and num_col_classifier >= 30): - ###prediction_bin = self.do_prediction(True, img_resized, self.models["binarization"], n_batch_inference=5) + ###prediction_bin = self.do_prediction(True, img_resized, self.model_zoo.get_model("binarization"), n_batch_inference=5) ####print("inside bin ", time.time()-t_bin) ###prediction_bin=prediction_bin[:,:,0] @@ -2198,15 +1573,7 @@ class Eynollah: ###img_bin = np.copy(prediction_bin) ###else: ###img_bin = np.copy(img_resized) - if (self.ocr and self.tr) and not self.input_binary: - prediction_bin = self.do_prediction(True, img_resized, self.models["binarization"], n_batch_inference=5) - prediction_bin = 255 * (prediction_bin[:,:,0] == 0) - prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2) - prediction_bin = prediction_bin.astype(np.uint16) - #img= np.copy(prediction_bin) - img_bin = np.copy(prediction_bin) - else: - img_bin = np.copy(img_resized) + img_bin = np.copy(img_resized) #print("inside 1 ", time.time()-t_in) ###textline_mask_tot_ea = self.run_textline(img_bin) @@ -2231,14 +1598,14 @@ class Eynollah: self.logger.debug("resized to %dx%d for %d cols", img_resized.shape[1], img_resized.shape[0], num_col_classifier) prediction_regions_org, confidence_matrix = self.do_prediction_new_concept( - True, img_resized, self.models["region_1_2"], n_batch_inference=1, + True, img_resized, self.model_zoo.get("region_1_2"), n_batch_inference=1, thresholding_for_some_classes_in_light_version=True, threshold_art_class_layout=self.threshold_art_class_layout) else: prediction_regions_org = np.zeros((self.image_org.shape[0], self.image_org.shape[1], 3)) confidence_matrix = np.zeros((self.image_org.shape[0], self.image_org.shape[1])) prediction_regions_page, confidence_matrix_page = self.do_prediction_new_concept( - False, self.image_page_org_size, self.models["region_1_2"], n_batch_inference=1, + False, self.image_page_org_size, self.model_zoo.get("region_1_2"), n_batch_inference=1, thresholding_for_artificial_class_in_light_version=True, threshold_art_class_layout=self.threshold_art_class_layout) ys = slice(*self.page_coord[0:2]) @@ -2252,10 +1619,10 @@ class Eynollah: self.logger.debug("resized to %dx%d (new_h=%d) for %d cols", img_resized.shape[1], img_resized.shape[0], new_h, num_col_classifier) prediction_regions_org, confidence_matrix = self.do_prediction_new_concept( - True, img_resized, self.models["region_1_2"], n_batch_inference=2, + True, img_resized, self.model_zoo.get("region_1_2"), n_batch_inference=2, thresholding_for_some_classes_in_light_version=True, threshold_art_class_layout=self.threshold_art_class_layout) - ###prediction_regions_org = self.do_prediction(True, img_bin, self.models["region"], + ###prediction_regions_org = self.do_prediction(True, img_bin, self.model_zoo.get_model("region"), ###n_batch_inference=3, ###thresholding_for_some_classes_in_light_version=True) #print("inside 3 ", time.time()-t_in) @@ -2324,144 +1691,6 @@ class Eynollah: img_bin, confidence_matrix) - def get_regions_from_xy_2models(self,img,is_image_enhanced, num_col_classifier): - self.logger.debug("enter get_regions_from_xy_2models") - erosion_hurts = False - img_org = np.copy(img) - img_height_h = img_org.shape[0] - img_width_h = img_org.shape[1] - - ratio_y=1.3 - ratio_x=1 - - img = resize_image(img_org, int(img_org.shape[0]*ratio_y), int(img_org.shape[1]*ratio_x)) - prediction_regions_org_y = self.do_prediction(True, img, self.models["region"]) - prediction_regions_org_y = resize_image(prediction_regions_org_y, img_height_h, img_width_h ) - - #plt.imshow(prediction_regions_org_y[:,:,0]) - #plt.show() - prediction_regions_org_y = prediction_regions_org_y[:,:,0] - mask_zeros_y = (prediction_regions_org_y[:,:]==0)*1 - - ##img_only_regions_with_sep = ( (prediction_regions_org_y[:,:] != 3) & (prediction_regions_org_y[:,:] != 0) )*1 - img_only_regions_with_sep = (prediction_regions_org_y == 1).astype(np.uint8) - try: - img_only_regions = cv2.erode(img_only_regions_with_sep[:,:], KERNEL, iterations=20) - img = resize_image(img_org, int(img_org.shape[0]), int(img_org.shape[1]*(1.2 if is_image_enhanced else 1))) - - prediction_regions_org = self.do_prediction(True, img, self.models["region"]) - prediction_regions_org = resize_image(prediction_regions_org, img_height_h, img_width_h ) - - prediction_regions_org=prediction_regions_org[:,:,0] - prediction_regions_org[(prediction_regions_org[:,:]==1) & (mask_zeros_y[:,:]==1)]=0 - - img = resize_image(img_org, int(img_org.shape[0]), int(img_org.shape[1])) - - prediction_regions_org2 = self.do_prediction(True, img, self.models["region_p2"], marginal_of_patch_percent=0.2) - prediction_regions_org2=resize_image(prediction_regions_org2, img_height_h, img_width_h ) - - mask_zeros2 = (prediction_regions_org2[:,:,0] == 0) - mask_seps2 = (prediction_regions_org2[:,:,0] == 3) - text_sume_early = (prediction_regions_org[:,:] == 1).sum() - prediction_regions_org_copy = np.copy(prediction_regions_org) - prediction_regions_org_copy[(prediction_regions_org_copy[:,:]==1) & (mask_zeros2[:,:]==1)] = 0 - text_sume_second = ((prediction_regions_org_copy[:,:]==1)*1).sum() - rate_two_models = 100. * text_sume_second / text_sume_early - - self.logger.info("ratio_of_two_models: %s", rate_two_models) - if not(is_image_enhanced and rate_two_models < RATIO_OF_TWO_MODEL_THRESHOLD): - prediction_regions_org = np.copy(prediction_regions_org_copy) - - prediction_regions_org[(mask_seps2[:,:]==1) & (prediction_regions_org[:,:]==0)]=3 - mask_seps_only=(prediction_regions_org[:,:]==3)*1 - prediction_regions_org = cv2.erode(prediction_regions_org[:,:], KERNEL, iterations=2) - prediction_regions_org = cv2.dilate(prediction_regions_org[:,:], KERNEL, iterations=2) - - if rate_two_models<=40: - if self.input_binary: - prediction_bin = np.copy(img_org) - else: - prediction_bin = self.do_prediction(True, img_org, self.models["binarization"], n_batch_inference=5) - prediction_bin = resize_image(prediction_bin, img_height_h, img_width_h ) - prediction_bin = 255 * (prediction_bin[:,:,0]==0) - prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2) - - ratio_y=1 - ratio_x=1 - - img = resize_image(prediction_bin, int(img_org.shape[0]*ratio_y), int(img_org.shape[1]*ratio_x)) - - prediction_regions_org = self.do_prediction(True, img, self.models["region"]) - prediction_regions_org = resize_image(prediction_regions_org, img_height_h, img_width_h ) - prediction_regions_org=prediction_regions_org[:,:,0] - - mask_seps_only=(prediction_regions_org[:,:]==3)*1 - - mask_texts_only=(prediction_regions_org[:,:]==1)*1 - mask_images_only=(prediction_regions_org[:,:]==2)*1 - - polygons_seplines, hir_seplines = return_contours_of_image(mask_seps_only) - polygons_seplines = filter_contours_area_of_image( - mask_seps_only, polygons_seplines, hir_seplines, max_area=1, min_area=0.00001, dilate=1) - - polygons_of_only_texts = return_contours_of_interested_region(mask_texts_only, 1, 0.00001) - polygons_of_only_seps = return_contours_of_interested_region(mask_seps_only, 1, 0.00001) - - text_regions_p_true = np.zeros(prediction_regions_org.shape) - text_regions_p_true = cv2.fillPoly(text_regions_p_true,pts = polygons_of_only_seps, color=(3, 3, 3)) - text_regions_p_true[:,:][mask_images_only[:,:] == 1] = 2 - - text_regions_p_true=cv2.fillPoly(text_regions_p_true,pts=polygons_of_only_texts, color=(1,1,1)) - - self.logger.debug("exit get_regions_from_xy_2models") - return text_regions_p_true, erosion_hurts, polygons_seplines, polygons_of_only_texts - except: - if self.input_binary: - prediction_bin = np.copy(img_org) - prediction_bin = self.do_prediction(True, img_org, self.models["binarization"], n_batch_inference=5) - prediction_bin = resize_image(prediction_bin, img_height_h, img_width_h ) - prediction_bin = 255 * (prediction_bin[:,:,0]==0) - prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2) - else: - prediction_bin = np.copy(img_org) - ratio_y=1 - ratio_x=1 - - - img = resize_image(prediction_bin, int(img_org.shape[0]*ratio_y), int(img_org.shape[1]*ratio_x)) - prediction_regions_org = self.do_prediction(True, img, self.models["region"]) - prediction_regions_org = resize_image(prediction_regions_org, img_height_h, img_width_h ) - prediction_regions_org=prediction_regions_org[:,:,0] - - #mask_seps_only=(prediction_regions_org[:,:]==3)*1 - #img = resize_image(img_org, int(img_org.shape[0]*1), int(img_org.shape[1]*1)) - - #prediction_regions_org = self.do_prediction(True, img, self.models["region"]) - #prediction_regions_org = resize_image(prediction_regions_org, img_height_h, img_width_h ) - #prediction_regions_org = prediction_regions_org[:,:,0] - #prediction_regions_org[(prediction_regions_org[:,:] == 1) & (mask_zeros_y[:,:] == 1)]=0 - - mask_seps_only = (prediction_regions_org == 3)*1 - mask_texts_only = (prediction_regions_org == 1)*1 - mask_images_only= (prediction_regions_org == 2)*1 - - polygons_seplines, hir_seplines = return_contours_of_image(mask_seps_only) - polygons_seplines = filter_contours_area_of_image( - mask_seps_only, polygons_seplines, hir_seplines, max_area=1, min_area=0.00001, dilate=1) - - polygons_of_only_texts = return_contours_of_interested_region(mask_texts_only,1,0.00001) - polygons_of_only_seps = return_contours_of_interested_region(mask_seps_only,1,0.00001) - - text_regions_p_true = np.zeros(prediction_regions_org.shape) - text_regions_p_true = cv2.fillPoly(text_regions_p_true, pts = polygons_of_only_seps, color=(3,3,3)) - - text_regions_p_true[:,:][mask_images_only[:,:] == 1] = 2 - text_regions_p_true = cv2.fillPoly(text_regions_p_true, pts = polygons_of_only_texts, color=(1,1,1)) - - erosion_hurts = True - self.logger.debug("exit get_regions_from_xy_2models") - return text_regions_p_true, erosion_hurts, polygons_seplines, polygons_of_only_texts - def do_order_of_regions( self, contours_only_text_parent, contours_only_text_parent_h, boxes, textline_mask_tot): @@ -2693,7 +1922,7 @@ class Eynollah: img_comm = cv2.fillPoly(img_comm, pts=main_contours, color=indiv) - if not self.isNaN(slope_mean_hor): + if not isNaN(slope_mean_hor): image_revised_last = np.zeros(image_regions_eraly_p.shape[:2]) for i in range(len(boxes)): box_ys = slice(*boxes[i][2:4]) @@ -2792,92 +2021,9 @@ class Eynollah: img_height_h = img_org.shape[0] img_width_h = img_org.shape[1] patches = False - if self.light_version: - prediction_table, _ = self.do_prediction_new_concept(patches, img, self.models["table"]) - prediction_table = prediction_table.astype(np.int16) - return prediction_table[:,:,0] - else: - if num_col_classifier < 4 and num_col_classifier > 2: - prediction_table = self.do_prediction(patches, img, self.models["table"]) - pre_updown = self.do_prediction(patches, cv2.flip(img[:,:,:], -1), self.models["table"]) - pre_updown = cv2.flip(pre_updown, -1) - - prediction_table[:,:,0][pre_updown[:,:,0]==1]=1 - prediction_table = prediction_table.astype(np.int16) - - elif num_col_classifier ==2: - height_ext = 0 # img.shape[0] // 4 - h_start = height_ext // 2 - width_ext = img.shape[1] // 8 - w_start = width_ext // 2 - - img_new = np.zeros((img.shape[0] + height_ext, - img.shape[1] + width_ext, - img.shape[2])).astype(float) - ys = slice(h_start, h_start + img.shape[0]) - xs = slice(w_start, w_start + img.shape[1]) - img_new[ys, xs] = img - - prediction_ext = self.do_prediction(patches, img_new, self.models["table"]) - pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), self.models["table"]) - pre_updown = cv2.flip(pre_updown, -1) - - prediction_table = prediction_ext[ys, xs] - prediction_table_updown = pre_updown[ys, xs] - - prediction_table[:,:,0][prediction_table_updown[:,:,0]==1]=1 - prediction_table = prediction_table.astype(np.int16) - elif num_col_classifier ==1: - height_ext = 0 # img.shape[0] // 4 - h_start = height_ext // 2 - width_ext = img.shape[1] // 4 - w_start = width_ext // 2 - - img_new =np.zeros((img.shape[0] + height_ext, - img.shape[1] + width_ext, - img.shape[2])).astype(float) - ys = slice(h_start, h_start + img.shape[0]) - xs = slice(w_start, w_start + img.shape[1]) - img_new[ys, xs] = img - - prediction_ext = self.do_prediction(patches, img_new, self.models["table"]) - pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), self.models["table"]) - pre_updown = cv2.flip(pre_updown, -1) - - prediction_table = prediction_ext[ys, xs] - prediction_table_updown = pre_updown[ys, xs] - - prediction_table[:,:,0][prediction_table_updown[:,:,0]==1]=1 - prediction_table = prediction_table.astype(np.int16) - else: - prediction_table = np.zeros(img.shape) - img_w_half = img.shape[1] // 2 - - pre1 = self.do_prediction(patches, img[:,0:img_w_half,:], self.models["table"]) - pre2 = self.do_prediction(patches, img[:,img_w_half:,:], self.models["table"]) - pre_full = self.do_prediction(patches, img[:,:,:], self.models["table"]) - pre_updown = self.do_prediction(patches, cv2.flip(img[:,:,:], -1), self.models["table"]) - pre_updown = cv2.flip(pre_updown, -1) - - prediction_table_full_erode = cv2.erode(pre_full[:,:,0], KERNEL, iterations=4) - prediction_table_full_erode = cv2.dilate(prediction_table_full_erode, KERNEL, iterations=4) - - prediction_table_full_updown_erode = cv2.erode(pre_updown[:,:,0], KERNEL, iterations=4) - prediction_table_full_updown_erode = cv2.dilate(prediction_table_full_updown_erode, KERNEL, iterations=4) - - prediction_table[:,0:img_w_half,:] = pre1[:,:,:] - prediction_table[:,img_w_half:,:] = pre2[:,:,:] - - prediction_table[:,:,0][prediction_table_full_erode[:,:]==1]=1 - prediction_table[:,:,0][prediction_table_full_updown_erode[:,:]==1]=1 - prediction_table = prediction_table.astype(np.int16) - - #prediction_table_erode = cv2.erode(prediction_table[:,:,0], self.kernel, iterations=6) - #prediction_table_erode = cv2.dilate(prediction_table_erode, self.kernel, iterations=6) - - prediction_table_erode = cv2.erode(prediction_table[:,:,0], KERNEL, iterations=20) - prediction_table_erode = cv2.dilate(prediction_table_erode, KERNEL, iterations=20) - return prediction_table_erode.astype(np.int16) + prediction_table, _ = self.do_prediction_new_concept(patches, img, self.model_zoo.get("table")) + prediction_table = prediction_table.astype(np.int16) + return prediction_table[:,:,0] def run_graphics_and_columns_light( self, text_regions_p_1, textline_mask_tot_ea, @@ -2903,18 +2049,18 @@ class Eynollah: if self.plotter: self.plotter.save_page_image(image_page) - + if not self.ignore_page_extraction: mask_page = np.zeros((text_regions_p_1.shape[0], text_regions_p_1.shape[1])).astype(np.int8) mask_page = cv2.fillPoly(mask_page, pts=[cont_page[0]], color=(1,)) - + text_regions_p_1[mask_page==0] = 0 textline_mask_tot_ea[mask_page==0] = 0 - + text_regions_p_1 = text_regions_p_1[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]] textline_mask_tot_ea = textline_mask_tot_ea[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]] img_bin_light = img_bin_light[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]] - + ###text_regions_p_1 = text_regions_p_1[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]] ###textline_mask_tot_ea = textline_mask_tot_ea[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]] ###img_bin_light = img_bin_light[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]] @@ -2972,58 +2118,12 @@ class Eynollah: return page_coord, image_page, textline_mask_tot_ea, img_bin_light, cont_page - def run_graphics_and_columns( - self, text_regions_p_1, - num_col_classifier, num_column_is_classified, erosion_hurts): - t_in_gr = time.time() - img_g = self.imread(grayscale=True, uint8=True) - - img_g3 = np.zeros((img_g.shape[0], img_g.shape[1], 3)) - img_g3 = img_g3.astype(np.uint8) - img_g3[:, :, 0] = img_g[:, :] - img_g3[:, :, 1] = img_g[:, :] - img_g3[:, :, 2] = img_g[:, :] - - image_page, page_coord, cont_page = self.extract_page() - - if self.tables: - table_prediction = self.get_tables_from_model(image_page, num_col_classifier) - else: - table_prediction = np.zeros((image_page.shape[0], image_page.shape[1])).astype(np.int16) - - if self.plotter: - self.plotter.save_page_image(image_page) - - text_regions_p_1 = text_regions_p_1[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]] - mask_images = (text_regions_p_1[:, :] == 2) * 1 - mask_images = mask_images.astype(np.uint8) - mask_images = cv2.erode(mask_images[:, :], KERNEL, iterations=10) - mask_seps = (text_regions_p_1[:, :] == 3) * 1 - mask_seps = mask_seps.astype(np.uint8) - img_only_regions_with_sep = ((text_regions_p_1[:, :] != 3) & (text_regions_p_1[:, :] != 0)) * 1 - img_only_regions_with_sep = img_only_regions_with_sep.astype(np.uint8) - - if erosion_hurts: - img_only_regions = np.copy(img_only_regions_with_sep[:,:]) - else: - img_only_regions = cv2.erode(img_only_regions_with_sep[:,:], KERNEL, iterations=6) - try: - num_col, _ = find_num_col(img_only_regions, num_col_classifier, self.tables, multiplier=6.0) - num_col = num_col + 1 - if not num_column_is_classified: - num_col_classifier = num_col + 1 - except Exception as why: - self.logger.exception(why) - num_col = None - return (num_col, num_col_classifier, img_only_regions, page_coord, image_page, mask_images, mask_seps, - text_regions_p_1, cont_page, table_prediction) - - def run_enhancement(self, light_version): + def run_enhancement(self): t_in = time.time() self.logger.info("Resizing and enhancing image...") is_image_enhanced, img_org, img_res, num_col_classifier, num_column_is_classified, img_bin = \ - self.resize_and_enhance_image_with_column_classifier(light_version) + self.resize_and_enhance_image_with_column_classifier() self.logger.info("Image was %senhanced.", '' if is_image_enhanced else 'not ') scale = 1 if is_image_enhanced: @@ -3035,10 +2135,7 @@ class Eynollah: else: self.get_image_and_scales_after_enhancing(img_org, img_res) else: - if self.allow_enhancement: - self.get_image_and_scales(img_org, img_res, scale) - else: - self.get_image_and_scales(img_org, img_res, scale) + self.get_image_and_scales(img_org, img_res, scale) if self.allow_scaling: img_org, img_res, is_image_enhanced = \ self.resize_image_with_column_classifier(is_image_enhanced, img_bin) @@ -3054,8 +2151,7 @@ class Eynollah: scaler_h_textline, scaler_w_textline, num_col_classifier) - if self.textline_light: - textline_mask_tot_ea = textline_mask_tot_ea.astype(np.int16) + textline_mask_tot_ea = textline_mask_tot_ea.astype(np.int16) if self.plotter: self.plotter.save_plot_of_textlines(textline_mask_tot_ea, image_page) @@ -3088,7 +2184,7 @@ class Eynollah: regions_without_separators = regions_without_separators.astype(np.uint8) text_regions_p = get_marginals( rotate_image(regions_without_separators, slope_deskew), text_regions_p, - num_col_classifier, slope_deskew, light_version=self.light_version, kernel=KERNEL) + num_col_classifier, slope_deskew, kernel=KERNEL) except Exception as e: self.logger.error("exception %s", e) @@ -3145,20 +2241,6 @@ class Eynollah: self.logger.debug("len(boxes): %s", len(boxes)) #print(time.time()-t_0_box,'time box in 3.1') - if self.tables: - if self.light_version: - pass - else: - text_regions_p_tables = np.copy(text_regions_p) - text_regions_p_tables[(table_prediction == 1)] = 10 - label_seps = 3 - img_revised_tab2 = self.add_tables_heuristic_to_layout( - text_regions_p_tables, boxes, 0, splitter_y_new, peaks_neg_tot_tables, text_regions_p_tables, - num_col_classifier , 0.000005, label_seps) - #print(time.time()-t_0_box,'time box in 3.2') - img_revised_tab2, contoures_tables = self.check_iou_of_bounding_box_and_contour_for_tables( - img_revised_tab2, table_prediction, 10, num_col_classifier) - #print(time.time()-t_0_box,'time box in 3.3') else: boxes_d, peaks_neg_tot_tables_d = return_boxes_of_images_by_order_of_reading_new( splitter_y_new_d, regions_without_separators_d, text_regions_p_d, matrix_of_seps_ch_d, @@ -3166,63 +2248,24 @@ class Eynollah: boxes = None self.logger.debug("len(boxes): %s", len(boxes_d)) - if self.tables: - if self.light_version: - pass - else: - text_regions_p_tables = np.copy(text_regions_p_d) - text_regions_p_tables = np.round(text_regions_p_tables) - text_regions_p_tables[(text_regions_p_tables != 3) & (table_prediction_n == 1)] = 10 - - label_seps = 3 - img_revised_tab2 = self.add_tables_heuristic_to_layout( - text_regions_p_tables, boxes_d, 0, splitter_y_new_d, - peaks_neg_tot_tables_d, text_regions_p_tables, - num_col_classifier, 0.000005, label_seps) - img_revised_tab2_d,_ = self.check_iou_of_bounding_box_and_contour_for_tables( - img_revised_tab2, table_prediction_n, 10, num_col_classifier) - - img_revised_tab2_d_rotated = rotate_image(img_revised_tab2_d, -slope_deskew) - img_revised_tab2_d_rotated = np.round(img_revised_tab2_d_rotated) - img_revised_tab2_d_rotated = img_revised_tab2_d_rotated.astype(np.int8) - img_revised_tab2_d_rotated = resize_image(img_revised_tab2_d_rotated, - text_regions_p.shape[0], text_regions_p.shape[1]) #print(time.time()-t_0_box,'time box in 4') self.logger.info("detecting boxes took %.1fs", time.time() - t1) if self.tables: - if self.light_version: - text_regions_p[table_prediction == 1] = 10 - img_revised_tab = text_regions_p[:,:] - else: - if np.abs(slope_deskew) < SLOPE_THRESHOLD: - img_revised_tab = np.copy(img_revised_tab2) - img_revised_tab[(text_regions_p == 1) & (img_revised_tab != 10)] = 1 - else: - img_revised_tab = np.copy(text_regions_p) - img_revised_tab[img_revised_tab == 10] = 0 - img_revised_tab[img_revised_tab2_d_rotated == 10] = 10 - - text_regions_p[text_regions_p == 10] = 0 - text_regions_p[img_revised_tab == 10] = 10 + text_regions_p[table_prediction == 1] = 10 + img_revised_tab = text_regions_p[:,:] else: img_revised_tab = text_regions_p[:,:] #img_revised_tab = text_regions_p[:, :] - if self.light_version: - polygons_of_images = return_contours_of_interested_region(text_regions_p, 2) - else: - polygons_of_images = return_contours_of_interested_region(img_revised_tab, 2) + polygons_of_images = return_contours_of_interested_region(text_regions_p, 2) label_marginalia = 4 min_area_mar = 0.00001 - if self.light_version: - marginal_mask = (text_regions_p[:,:]==label_marginalia)*1 - marginal_mask = marginal_mask.astype('uint8') - marginal_mask = cv2.dilate(marginal_mask, KERNEL, iterations=2) + marginal_mask = (text_regions_p[:,:]==label_marginalia)*1 + marginal_mask = marginal_mask.astype('uint8') + marginal_mask = cv2.dilate(marginal_mask, KERNEL, iterations=2) - polygons_of_marginals = return_contours_of_interested_region(marginal_mask, 1, min_area_mar) - else: - polygons_of_marginals = return_contours_of_interested_region(text_regions_p, label_marginalia, min_area_mar) + polygons_of_marginals = return_contours_of_interested_region(marginal_mask, 1, min_area_mar) label_tables = 10 contours_tables = return_contours_of_interested_region(text_regions_p, label_tables, min_area_mar) @@ -3240,117 +2283,33 @@ class Eynollah: self.logger.debug('enter run_boxes_full_layout') t_full0 = time.time() if self.tables: - if self.light_version: - text_regions_p[:,:][table_prediction[:,:]==1] = 10 - img_revised_tab = text_regions_p[:,:] - if np.abs(slope_deskew) >= SLOPE_THRESHOLD: - textline_mask_tot_d = rotate_image(textline_mask_tot, slope_deskew) - text_regions_p_d = rotate_image(text_regions_p, slope_deskew) - table_prediction_n = rotate_image(table_prediction, slope_deskew) - regions_without_separators_d = (text_regions_p_d[:,:] == 1)*1 - regions_without_separators_d[table_prediction_n[:,:] == 1] = 1 - else: - text_regions_p_d = None - textline_mask_tot_d = None - regions_without_separators_d = None - # regions_without_separators = ( text_regions_p[:,:]==1 | text_regions_p[:,:]==2 )*1 - #self.return_regions_without_separators_new(text_regions_p[:,:,0],img_only_regions) - regions_without_separators = (text_regions_p[:,:] == 1)*1 - regions_without_separators[table_prediction == 1] = 1 - + text_regions_p[:,:][table_prediction[:,:]==1] = 10 + img_revised_tab = text_regions_p[:,:] + if np.abs(slope_deskew) >= SLOPE_THRESHOLD: + textline_mask_tot_d = rotate_image(textline_mask_tot, slope_deskew) + text_regions_p_d = rotate_image(text_regions_p, slope_deskew) + table_prediction_n = rotate_image(table_prediction, slope_deskew) + regions_without_separators_d = (text_regions_p_d[:,:] == 1)*1 + regions_without_separators_d[table_prediction_n[:,:] == 1] = 1 else: - if np.abs(slope_deskew) >= SLOPE_THRESHOLD: - textline_mask_tot_d = rotate_image(textline_mask_tot, slope_deskew) - text_regions_p_d = rotate_image(text_regions_p, slope_deskew) - table_prediction_n = rotate_image(table_prediction, slope_deskew) - regions_without_separators_d = (text_regions_p_d[:,:] == 1)*1 - regions_without_separators_d[table_prediction_n[:,:] == 1] = 1 - else: - text_regions_p_d = None - textline_mask_tot_d = None - regions_without_separators_d = None + text_regions_p_d = None + textline_mask_tot_d = None + regions_without_separators_d = None + # regions_without_separators = ( text_regions_p[:,:]==1 | text_regions_p[:,:]==2 )*1 + #self.return_regions_without_separators_new(text_regions_p[:,:,0],img_only_regions) + regions_without_separators = (text_regions_p[:,:] == 1)*1 + regions_without_separators[table_prediction == 1] = 1 - # regions_without_separators = ( text_regions_p[:,:]==1 | text_regions_p[:,:]==2 )*1 - #self.return_regions_without_separators_new(text_regions_p[:,:,0],img_only_regions) - regions_without_separators = (text_regions_p[:,:] == 1)*1 - regions_without_separators[table_prediction == 1] = 1 - label_seps=3 - if np.abs(slope_deskew) < SLOPE_THRESHOLD: - num_col, _, matrix_of_seps_ch, splitter_y_new, _ = find_number_of_columns_in_document( - text_regions_p, num_col_classifier, self.tables, label_seps) - if not erosion_hurts: - regions_without_separators = regions_without_separators.astype(np.uint8) - regions_without_separators = cv2.erode(regions_without_separators[:,:], KERNEL, iterations=6) - - else: - num_col_d, _, matrix_of_seps_ch_d, splitter_y_new_d, _ = find_number_of_columns_in_document( - text_regions_p_d, num_col_classifier, self.tables, label_seps) - if not erosion_hurts: - regions_without_separators_d = regions_without_separators_d.astype(np.uint8) - regions_without_separators_d = cv2.erode(regions_without_separators_d[:,:], KERNEL, iterations=6) - - if np.abs(slope_deskew) < SLOPE_THRESHOLD: - boxes, peaks_neg_tot_tables = return_boxes_of_images_by_order_of_reading_new( - splitter_y_new, regions_without_separators, text_regions_p, matrix_of_seps_ch, - num_col_classifier, erosion_hurts, self.tables, self.right2left) - text_regions_p_tables = np.copy(text_regions_p) - text_regions_p_tables[:,:][(table_prediction[:,:]==1)] = 10 - label_seps = 3 - img_revised_tab2 = self.add_tables_heuristic_to_layout( - text_regions_p_tables, boxes, 0, splitter_y_new, peaks_neg_tot_tables, text_regions_p_tables, - num_col_classifier , 0.000005, label_seps) - - img_revised_tab2,contoures_tables = self.check_iou_of_bounding_box_and_contour_for_tables( - img_revised_tab2, table_prediction, 10, num_col_classifier) - else: - boxes_d, peaks_neg_tot_tables_d = return_boxes_of_images_by_order_of_reading_new( - splitter_y_new_d, regions_without_separators_d, text_regions_p_d, matrix_of_seps_ch_d, - num_col_classifier, erosion_hurts, self.tables, self.right2left) - text_regions_p_tables = np.copy(text_regions_p_d) - text_regions_p_tables = np.round(text_regions_p_tables) - text_regions_p_tables[(text_regions_p_tables != 3) & (table_prediction_n == 1)] = 10 - - label_seps = 3 - img_revised_tab2 = self.add_tables_heuristic_to_layout( - text_regions_p_tables, boxes_d, 0, splitter_y_new_d, - peaks_neg_tot_tables_d, text_regions_p_tables, - num_col_classifier, 0.000005, label_seps) - - img_revised_tab2_d,_ = self.check_iou_of_bounding_box_and_contour_for_tables( - img_revised_tab2, table_prediction_n, 10, num_col_classifier) - img_revised_tab2_d_rotated = rotate_image(img_revised_tab2_d, -slope_deskew) - - img_revised_tab2_d_rotated = np.round(img_revised_tab2_d_rotated) - img_revised_tab2_d_rotated = img_revised_tab2_d_rotated.astype(np.int8) - img_revised_tab2_d_rotated = resize_image(img_revised_tab2_d_rotated, - text_regions_p.shape[0], - text_regions_p.shape[1]) - - if np.abs(slope_deskew) < 0.13: - img_revised_tab = np.copy(img_revised_tab2) - else: - img_revised_tab = np.copy(text_regions_p) - img_revised_tab[img_revised_tab == 10] = 0 - img_revised_tab[img_revised_tab2_d_rotated == 10] = 10 - - ##img_revised_tab = img_revised_tab2[:,:] - #img_revised_tab = text_regions_p[:,:] - text_regions_p[text_regions_p == 10] = 0 - text_regions_p[img_revised_tab == 10] = 10 - #img_revised_tab[img_revised_tab2 == 10] = 10 label_marginalia = 4 min_area_mar = 0.00001 - if self.light_version: - marginal_mask = (text_regions_p[:,:]==label_marginalia)*1 - marginal_mask = marginal_mask.astype('uint8') - marginal_mask = cv2.dilate(marginal_mask, KERNEL, iterations=2) + marginal_mask = (text_regions_p[:,:]==label_marginalia)*1 + marginal_mask = marginal_mask.astype('uint8') + marginal_mask = cv2.dilate(marginal_mask, KERNEL, iterations=2) - polygons_of_marginals = return_contours_of_interested_region(marginal_mask, 1, min_area_mar) - else: - polygons_of_marginals = return_contours_of_interested_region(text_regions_p, label_marginalia, min_area_mar) + polygons_of_marginals = return_contours_of_interested_region(marginal_mask, 1, min_area_mar) label_tables = 10 contours_tables = return_contours_of_interested_region(text_regions_p, label_tables, min_area_mar) @@ -3363,13 +2322,13 @@ class Eynollah: image_page = image_page.astype(np.uint8) #print("full inside 1", time.time()- t_full0) regions_fully, regions_fully_only_drop = self.extract_text_regions_new( - img_bin_light if self.light_version else image_page, + img_bin_light, False, cols=num_col_classifier) #print("full inside 2", time.time()- t_full0) # 6 is the separators lable in old full layout model # 4 is the drop capital class in old full layout model # in the new full layout drop capital is 3 and separators are 5 - + # the separators in full layout will not be written on layout if not self.reading_order_machine_based: text_regions_p[:,:][regions_fully[:,:,0]==5]=6 @@ -3427,7 +2386,7 @@ class Eynollah: polygons_of_marginals, contours_tables) def do_order_of_regions_with_model(self, contours_only_text_parent, contours_only_text_parent_h, text_regions_p): - + height1 =672#448 width1 = 448#224 @@ -3436,33 +2395,33 @@ class Eynollah: height3 =672#448 width3 = 448#224 - + inference_bs = 3 - + ver_kernel = np.ones((5, 1), dtype=np.uint8) hor_kernel = np.ones((1, 5), dtype=np.uint8) - - + + min_cont_size_to_be_dilated = 10 - if len(contours_only_text_parent)>min_cont_size_to_be_dilated and self.light_version: + if len(contours_only_text_parent)>min_cont_size_to_be_dilated: (cx_conts, cy_conts, x_min_conts, x_max_conts, y_min_conts, y_max_conts, _) = find_new_features_of_contours(contours_only_text_parent) args_cont_located = np.array(range(len(contours_only_text_parent))) - + diff_y_conts = np.abs(y_max_conts[:]-y_min_conts) diff_x_conts = np.abs(x_max_conts[:]-x_min_conts) - + mean_x = statistics.mean(diff_x_conts) median_x = statistics.median(diff_x_conts) - - + + diff_x_ratio= diff_x_conts/mean_x - + args_cont_located_excluded = args_cont_located[diff_x_ratio>=1.3] args_cont_located_included = args_cont_located[diff_x_ratio<1.3] - + contours_only_text_parent_excluded = [contours_only_text_parent[ind] #contours_only_text_parent[diff_x_ratio>=1.3] for ind in range(len(contours_only_text_parent)) @@ -3471,7 +2430,7 @@ class Eynollah: #contours_only_text_parent[diff_x_ratio<1.3] for ind in range(len(contours_only_text_parent)) if diff_x_ratio[ind]<1.3] - + cx_conts_excluded = [cx_conts[ind] #cx_conts[diff_x_ratio>=1.3] for ind in range(len(cx_conts)) @@ -3488,43 +2447,43 @@ class Eynollah: #cy_conts[diff_x_ratio<1.3] for ind in range(len(cy_conts)) if diff_x_ratio[ind]<1.3] - + #print(diff_x_ratio, 'ratio') text_regions_p = text_regions_p.astype('uint8') - + if len(contours_only_text_parent_excluded)>0: textregion_par = np.zeros((text_regions_p.shape[0], text_regions_p.shape[1])).astype('uint8') textregion_par = cv2.fillPoly(textregion_par, pts=contours_only_text_parent_included, color=(1,1)) else: textregion_par = (text_regions_p[:,:]==1)*1 textregion_par = textregion_par.astype('uint8') - + text_regions_p_textregions_dilated = cv2.erode(textregion_par , hor_kernel, iterations=2) text_regions_p_textregions_dilated = cv2.dilate(text_regions_p_textregions_dilated , ver_kernel, iterations=4) text_regions_p_textregions_dilated = cv2.erode(text_regions_p_textregions_dilated , hor_kernel, iterations=1) text_regions_p_textregions_dilated = cv2.dilate(text_regions_p_textregions_dilated , ver_kernel, iterations=5) text_regions_p_textregions_dilated[text_regions_p[:,:]>1] = 0 - - + + contours_only_dilated, hir_on_text_dilated = return_contours_of_image(text_regions_p_textregions_dilated) contours_only_dilated = return_parent_contours(contours_only_dilated, hir_on_text_dilated) - + indexes_of_located_cont, center_x_coordinates_of_located, center_y_coordinates_of_located = \ self.return_indexes_of_contours_located_inside_another_list_of_contours( contours_only_dilated, contours_only_text_parent_included, cx_conts_included, cy_conts_included, args_cont_located_included) - - + + if len(args_cont_located_excluded)>0: for ind in args_cont_located_excluded: indexes_of_located_cont.append(np.array([ind])) contours_only_dilated.append(contours_only_text_parent[ind]) center_y_coordinates_of_located.append(0) - + array_list = [np.array([elem]) if isinstance(elem, int) else elem for elem in indexes_of_located_cont] flattened_array = np.concatenate([arr.ravel() for arr in array_list]) #print(len( np.unique(flattened_array)), 'indexes_of_located_cont uniques') - + missing_textregions = list( set(range(len(contours_only_text_parent))) - set(flattened_array) ) #print(missing_textregions, 'missing_textregions') @@ -3532,15 +2491,15 @@ class Eynollah: indexes_of_located_cont.append(np.array([ind])) contours_only_dilated.append(contours_only_text_parent[ind]) center_y_coordinates_of_located.append(0) - - + + if contours_only_text_parent_h: for vi in range(len(contours_only_text_parent_h)): indexes_of_located_cont.append(int(vi+len(contours_only_text_parent))) - + array_list = [np.array([elem]) if isinstance(elem, int) else elem for elem in indexes_of_located_cont] flattened_array = np.concatenate([arr.ravel() for arr in array_list]) - + y_len = text_regions_p.shape[0] x_len = text_regions_p.shape[1] @@ -3549,7 +2508,7 @@ class Eynollah: img_poly[text_regions_p[:,:]==2] = 2 img_poly[text_regions_p[:,:]==3] = 4 img_poly[text_regions_p[:,:]==6] = 5 - + img_header_and_sep = np.zeros((y_len,x_len), dtype='uint8') if contours_only_text_parent_h: _, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, _ = find_new_features_of_contours( @@ -3558,13 +2517,13 @@ class Eynollah: img_header_and_sep[int(y_max_main[j]):int(y_max_main[j])+12, int(x_min_main[j]):int(x_max_main[j])] = 1 co_text_all_org = contours_only_text_parent + contours_only_text_parent_h - if len(contours_only_text_parent)>min_cont_size_to_be_dilated and self.light_version: + if len(contours_only_text_parent)>min_cont_size_to_be_dilated: co_text_all = contours_only_dilated + contours_only_text_parent_h else: co_text_all = contours_only_text_parent + contours_only_text_parent_h else: co_text_all_org = contours_only_text_parent - if len(contours_only_text_parent)>min_cont_size_to_be_dilated and self.light_version: + if len(contours_only_text_parent)>min_cont_size_to_be_dilated: co_text_all = contours_only_dilated else: co_text_all = contours_only_text_parent @@ -3576,9 +2535,9 @@ class Eynollah: co_text_all = [(i/6).astype(int) for i in co_text_all] for i in range(len(co_text_all)): img = labels_con[:,:,i].astype(np.uint8) - + #img = cv2.resize(img, (int(img.shape[1]/6), int(img.shape[0]/6)), interpolation=cv2.INTER_NEAREST) - + cv2.fillPoly(img, pts=[co_text_all[i]], color=(1,)) labels_con[:,:,i] = img @@ -3586,9 +2545,9 @@ class Eynollah: labels_con = resize_image(labels_con.astype(np.uint8), height1, width1).astype(bool) img_header_and_sep = resize_image(img_header_and_sep, height1, width1) img_poly = resize_image(img_poly, height3, width3) - - + + input_1 = np.zeros((inference_bs, height1, width1, 3)) ordered = [list(range(len(co_text_all)))] index_update = 0 @@ -3616,7 +2575,7 @@ class Eynollah: tot_counter += 1 batch.append(j) if tot_counter % inference_bs == 0 or tot_counter == len(ij_list): - y_pr = self.models["reading_order"].predict(input_1 , verbose=0) + y_pr = self.model_zoo.get("reading_order").predict(input_1 , verbose=0) for jb, j in enumerate(batch): if y_pr[jb][0]>=0.5: post_list.append(j) @@ -3638,8 +2597,8 @@ class Eynollah: break ordered = [i[0] for i in ordered] - - if len(contours_only_text_parent)>min_cont_size_to_be_dilated and self.light_version: + + if len(contours_only_text_parent)>min_cont_size_to_be_dilated: org_contours_indexes = [] for ind in range(len(ordered)): region_with_curr_order = ordered[ind] @@ -3652,216 +2611,11 @@ class Eynollah: np.array(indexes_of_located_cont[region_with_curr_order])[arg_sort_located_cont]) else: org_contours_indexes.extend([indexes_of_located_cont[region_with_curr_order]]) - + return org_contours_indexes else: return ordered - def return_start_and_end_of_common_text_of_textline_ocr(self,textline_image, ind_tot): - width = np.shape(textline_image)[1] - height = np.shape(textline_image)[0] - common_window = int(0.2*width) - - width1 = int ( width/2. - common_window ) - width2 = int ( width/2. + common_window ) - - img_sum = np.sum(textline_image[:,:,0], axis=0) - sum_smoothed = gaussian_filter1d(img_sum, 3) - - peaks_real, _ = find_peaks(sum_smoothed, height=0) - - if len(peaks_real)>70: - print(len(peaks_real), 'len(peaks_real)') - peaks_real = peaks_real[(peaks_realwidth1)] - - arg_sort = np.argsort(sum_smoothed[peaks_real]) - arg_sort4 =arg_sort[::-1][:4] - peaks_sort_4 = peaks_real[arg_sort][::-1][:4] - - argsort_sorted = np.argsort(peaks_sort_4) - first_4_sorted = peaks_sort_4[argsort_sorted] - y_4_sorted = sum_smoothed[peaks_real][arg_sort4[argsort_sorted]] - #print(first_4_sorted,'first_4_sorted') - - arg_sortnew = np.argsort(y_4_sorted) - peaks_final =np.sort( first_4_sorted[arg_sortnew][2:] ) - - #plt.figure(ind_tot) - #plt.imshow(textline_image) - #plt.plot([peaks_final[0], peaks_final[0]], [0, height-1]) - #plt.plot([peaks_final[1], peaks_final[1]], [0, height-1]) - #plt.savefig('./'+str(ind_tot)+'.png') - - return peaks_final[0], peaks_final[1] - else: - pass - - def return_start_and_end_of_common_text_of_textline_ocr_new_splitted( - self, peaks_real, sum_smoothed, start_split, end_split): - - peaks_real = peaks_real[(peaks_realstart_split)] - - arg_sort = np.argsort(sum_smoothed[peaks_real]) - arg_sort4 =arg_sort[::-1][:4] - peaks_sort_4 = peaks_real[arg_sort][::-1][:4] - argsort_sorted = np.argsort(peaks_sort_4) - - first_4_sorted = peaks_sort_4[argsort_sorted] - y_4_sorted = sum_smoothed[peaks_real][arg_sort4[argsort_sorted]] - #print(first_4_sorted,'first_4_sorted') - - arg_sortnew = np.argsort(y_4_sorted) - peaks_final =np.sort( first_4_sorted[arg_sortnew][3:] ) - return peaks_final[0] - - def return_start_and_end_of_common_text_of_textline_ocr_new(self, textline_image, ind_tot): - width = np.shape(textline_image)[1] - height = np.shape(textline_image)[0] - common_window = int(0.15*width) - - width1 = int ( width/2. - common_window ) - width2 = int ( width/2. + common_window ) - mid = int(width/2.) - - img_sum = np.sum(textline_image[:,:,0], axis=0) - sum_smoothed = gaussian_filter1d(img_sum, 3) - - peaks_real, _ = find_peaks(sum_smoothed, height=0) - if len(peaks_real)>70: - peak_start = self.return_start_and_end_of_common_text_of_textline_ocr_new_splitted( - peaks_real, sum_smoothed, width1, mid+2) - peak_end = self.return_start_and_end_of_common_text_of_textline_ocr_new_splitted( - peaks_real, sum_smoothed, mid-2, width2) - - #plt.figure(ind_tot) - #plt.imshow(textline_image) - #plt.plot([peak_start, peak_start], [0, height-1]) - #plt.plot([peak_end, peak_end], [0, height-1]) - #plt.savefig('./'+str(ind_tot)+'.png') - - return peak_start, peak_end - else: - pass - - def return_ocr_of_textline_without_common_section( - self, textline_image, model_ocr, processor, device, width_textline, h2w_ratio,ind_tot): - - if h2w_ratio > 0.05: - pixel_values = processor(textline_image, return_tensors="pt").pixel_values - generated_ids = model_ocr.generate(pixel_values.to(device)) - generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] - else: - #width = np.shape(textline_image)[1] - #height = np.shape(textline_image)[0] - #common_window = int(0.3*width) - #width1 = int ( width/2. - common_window ) - #width2 = int ( width/2. + common_window ) - - split_point = return_start_and_end_of_common_text_of_textline_ocr_without_common_section(textline_image) - if split_point: - image1 = textline_image[:, :split_point,:]# image.crop((0, 0, width2, height)) - image2 = textline_image[:, split_point:,:]#image.crop((width1, 0, width, height)) - - #pixel_values1 = processor(image1, return_tensors="pt").pixel_values - #pixel_values2 = processor(image2, return_tensors="pt").pixel_values - - pixel_values_merged = processor([image1,image2], return_tensors="pt").pixel_values - generated_ids_merged = model_ocr.generate(pixel_values_merged.to(device)) - generated_text_merged = processor.batch_decode(generated_ids_merged, skip_special_tokens=True) - - #print(generated_text_merged,'generated_text_merged') - - #generated_ids1 = model_ocr.generate(pixel_values1.to(device)) - #generated_ids2 = model_ocr.generate(pixel_values2.to(device)) - - #generated_text1 = processor.batch_decode(generated_ids1, skip_special_tokens=True)[0] - #generated_text2 = processor.batch_decode(generated_ids2, skip_special_tokens=True)[0] - - #generated_text = generated_text1 + ' ' + generated_text2 - generated_text = generated_text_merged[0] + ' ' + generated_text_merged[1] - - #print(generated_text1,'generated_text1') - #print(generated_text2, 'generated_text2') - #print('########################################') - else: - pixel_values = processor(textline_image, return_tensors="pt").pixel_values - generated_ids = model_ocr.generate(pixel_values.to(device)) - generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] - - #print(generated_text,'generated_text') - #print('########################################') - return generated_text - - def return_ocr_of_textline( - self, textline_image, model_ocr, processor, device, width_textline, h2w_ratio,ind_tot): - - if h2w_ratio > 0.05: - pixel_values = processor(textline_image, return_tensors="pt").pixel_values - generated_ids = model_ocr.generate(pixel_values.to(device)) - generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] - else: - #width = np.shape(textline_image)[1] - #height = np.shape(textline_image)[0] - #common_window = int(0.3*width) - #width1 = int ( width/2. - common_window ) - #width2 = int ( width/2. + common_window ) - - try: - width1, width2 = self.return_start_and_end_of_common_text_of_textline_ocr_new(textline_image, ind_tot) - - image1 = textline_image[:, :width2,:]# image.crop((0, 0, width2, height)) - image2 = textline_image[:, width1:,:]#image.crop((width1, 0, width, height)) - - pixel_values1 = processor(image1, return_tensors="pt").pixel_values - pixel_values2 = processor(image2, return_tensors="pt").pixel_values - - generated_ids1 = model_ocr.generate(pixel_values1.to(device)) - generated_ids2 = model_ocr.generate(pixel_values2.to(device)) - - generated_text1 = processor.batch_decode(generated_ids1, skip_special_tokens=True)[0] - generated_text2 = processor.batch_decode(generated_ids2, skip_special_tokens=True)[0] - #print(generated_text1,'generated_text1') - #print(generated_text2, 'generated_text2') - #print('########################################') - - match = sq(None, generated_text1, generated_text2).find_longest_match( - 0, len(generated_text1), 0, len(generated_text2)) - generated_text = generated_text1 + generated_text2[match.b+match.size:] - except: - pixel_values = processor(textline_image, return_tensors="pt").pixel_values - generated_ids = model_ocr.generate(pixel_values.to(device)) - generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] - - return generated_text - - def return_list_of_contours_with_desired_order(self, ls_cons, sorted_indexes): - return list(np.array(ls_cons)[np.array(sorted_indexes)]) - - def return_it_in_two_groups(self, x_differential): - split = [ind if x_differential[ind]!=x_differential[ind+1] else -1 - for ind in range(len(x_differential)-1)] - split_masked = list( np.array(split[:])[np.array(split[:])!=-1] ) - if 0 not in split_masked: - split_masked.insert(0, -1) - split_masked.append(len(x_differential)-1) - - split_masked = np.array(split_masked) +1 - - sums = [np.sum(x_differential[split_masked[ind]:split_masked[ind+1]]) - for ind in range(len(split_masked)-1)] - - indexes_to_bec_changed = [ind if (np.abs(sums[ind-1]) > np.abs(sums[ind]) and - np.abs(sums[ind+1]) > np.abs(sums[ind])) else -1 - for ind in range(1,len(sums)-1)] - indexes_to_bec_changed_filtered = np.array(indexes_to_bec_changed)[np.array(indexes_to_bec_changed)!=-1] - - x_differential_new = np.copy(x_differential) - for i in indexes_to_bec_changed_filtered: - i_slice = slice(split_masked[i], split_masked[i+1]) - x_differential_new[i_slice] = -1 * np.array(x_differential)[i_slice] - - return x_differential_new - def filter_contours_inside_a_bigger_one(self, contours, contours_d_ordered, image, marginal_cnts=None, type_contour="textregion"): if type_contour == "textregion": @@ -3941,7 +2695,7 @@ class Eynollah: axis=0)) return contours - + def return_indexes_of_contours_located_inside_another_list_of_contours( self, contours, contours_loc, cx_main_loc, cy_main_loc, indexes_loc): indexes_of_located_cont = [] @@ -3951,7 +2705,7 @@ class Eynollah: #for j in range(len(contours_loc))] #cx_main_loc = [(M_main_tot[j]["m10"] / (M_main_tot[j]["m00"] + 1e-32)) for j in range(len(M_main_tot))] #cy_main_loc = [(M_main_tot[j]["m01"] / (M_main_tot[j]["m00"] + 1e-32)) for j in range(len(M_main_tot))] - + for ij in range(len(contours)): results = [cv2.pointPolygonTest(contours[ij], (cx_main_loc[ind], cy_main_loc[ind]), False) for ind in range(len(cy_main_loc)) ] @@ -3963,9 +2717,9 @@ class Eynollah: indexes_of_located_cont.append(indexes) center_x_coordinates_of_located.append(np.array(cx_main_loc)[indexes_in] ) center_y_coordinates_of_located.append(np.array(cy_main_loc)[indexes_in] ) - + return indexes_of_located_cont, center_x_coordinates_of_located, center_y_coordinates_of_located - + def filter_contours_without_textline_inside( self, contours_par, contours_textline, @@ -4006,15 +2760,15 @@ class Eynollah: (all_found_textline_polygons_marginals_left, all_found_textline_polygons_marginals_right) = \ split(all_found_textline_polygons_marginals) - + (all_box_coord_marginals_left, all_box_coord_marginals_right) = \ split(all_box_coord_marginals) - + (slopes_marg_left, slopes_marg_right) = \ split(slopes_marginals) - + (cy_marg_left, cy_marg_right) = \ split(cy_marg) @@ -4025,19 +2779,19 @@ class Eynollah: return list(np.array(lis)[order_left]) def sort_right(lis): return list(np.array(lis)[order_right]) - + ordered_left_marginals = sort_left(poly_marg_left) ordered_right_marginals = sort_right(poly_marg_right) - + ordered_left_marginals_textline = sort_left(all_found_textline_polygons_marginals_left) ordered_right_marginals_textline = sort_right(all_found_textline_polygons_marginals_right) - + ordered_left_marginals_bbox = sort_left(all_box_coord_marginals_left) ordered_right_marginals_bbox = sort_right(all_box_coord_marginals_right) - + ordered_left_slopes_marginals = sort_left(slopes_marg_left) ordered_right_slopes_marginals = sort_right(slopes_marg_right) - + return (ordered_left_marginals, ordered_right_marginals, ordered_left_marginals_textline, @@ -4066,14 +2820,8 @@ class Eynollah: # Log enabled features directly enabled_modes = [] - if self.light_version: - enabled_modes.append("Light version") - if self.textline_light: - enabled_modes.append("Light textline detection") if self.full_layout: enabled_modes.append("Full layout analysis") - if self.ocr: - enabled_modes.append("OCR") if self.tables: enabled_modes.append("Table detection") if enabled_modes: @@ -4126,44 +2874,27 @@ class Eynollah: def run_single(self): t0 = time.time() - + self.logger.info(f"Processing file: {self.writer.image_filename}") self.logger.info("Step 1/5: Image Enhancement") - + img_res, is_image_enhanced, num_col_classifier, num_column_is_classified = \ - self.run_enhancement(self.light_version) - + self.run_enhancement() + self.logger.info(f"Image: {self.image.shape[1]}x{self.image.shape[0]}, " f"scale {self.scale_x:.1f}x{self.scale_y:.1f}, " f"{self.dpi} DPI, {num_col_classifier} columns") if is_image_enhanced: self.logger.info("Enhancement applied") - - self.logger.info(f"Enhancement complete ({time.time() - t0:.1f}s)") - - # Image Extraction Mode - if self.extract_only_images: - self.logger.info("Step 2/5: Image Extraction Mode") - - text_regions_p_1, erosion_hurts, polygons_seplines, polygons_of_images, \ - image_page, page_coord, cont_page = \ - self.get_regions_light_v_extract_only_images(img_res, is_image_enhanced, num_col_classifier) - pcgts = self.writer.build_pagexml_no_full_layout( - [], page_coord, [], [], [], - polygons_of_images, [], [], [], [], [], [], [], [], [], - cont_page, [], []) - if self.plotter: - self.plotter.write_images_into_directory(polygons_of_images, image_page) - - self.logger.info("Image extraction complete") - return pcgts + self.logger.info(f"Enhancement complete ({time.time() - t0:.1f}s)") + # Basic Processing Mode if self.skip_layout_and_reading_order: self.logger.info("Step 2/5: Basic Processing Mode") self.logger.info("Skipping layout analysis and reading order detection") - + _ ,_, _, _, textline_mask_tot_ea, img_bin_light, _ = \ self.get_regions_light_v(img_res, is_image_enhanced, num_col_classifier,) @@ -4175,7 +2906,7 @@ class Eynollah: cnt_clean_rot_raw, hir_on_cnt_clean_rot = return_contours_of_image(textline_mask_tot_ea) all_found_textline_polygons = filter_contours_area_of_image( textline_mask_tot_ea, cnt_clean_rot_raw, hir_on_cnt_clean_rot, max_area=1, min_area=0.00001) - + cx_main_tot, cy_main_tot = find_center_of_contours(all_found_textline_polygons) w_h_textlines = [cv2.boundingRect(polygon)[2:] for polygon in all_found_textline_polygons] @@ -4188,99 +2919,99 @@ class Eynollah: all_found_textline_polygons = dilate_textline_contours(all_found_textline_polygons) all_found_textline_polygons = self.filter_contours_inside_a_bigger_one( all_found_textline_polygons, None, textline_mask_tot_ea, type_contour="textline") - + order_text_new = [0] slopes =[0] conf_contours_textregions =[0] - - if self.ocr and not self.tr: - gc.collect() - ocr_all_textlines = return_rnn_cnn_ocr_of_given_textlines( - image_page, all_found_textline_polygons, np.zeros((len(all_found_textline_polygons), 4)), - self.models["ocr"], self.b_s_ocr, self.num_to_char, textline_light=True) - else: - ocr_all_textlines = None - + pcgts = self.writer.build_pagexml_no_full_layout( - cont_page, page_coord, order_text_new, - all_found_textline_polygons, page_coord, [], - [], [], [], [], [], [], - slopes, [], [], - cont_page, [], [], - ocr_all_textlines=ocr_all_textlines, - conf_contours_textregion=conf_contours_textregions, - skip_layout_reading_order=True) + found_polygons_text_region=cont_page, + page_coord=page_coord, + order_of_texts=order_text_new, + all_found_textline_polygons=all_found_textline_polygons, + all_box_coord=page_coord, + found_polygons_text_region_img=[], + found_polygons_marginals_left=[], + found_polygons_marginals_right=[], + all_found_textline_polygons_marginals_left=[], + all_found_textline_polygons_marginals_right=[], + all_box_coord_marginals_left=[], + all_box_coord_marginals_right=[], + slopes=slopes, + slopes_marginals_left=[], + slopes_marginals_right=[], + cont_page=cont_page, + polygons_seplines=[], + found_polygons_tables=[], + skip_layout_reading_order=True + ) self.logger.info("Basic processing complete") return pcgts #print("text region early -1 in %.1fs", time.time() - t0) t1 = time.time() self.logger.info("Step 2/5: Layout Analysis") - - if self.light_version: - self.logger.info("Using light version processing") - text_regions_p_1 ,erosion_hurts, polygons_seplines, polygons_text_early, \ - textline_mask_tot_ea, img_bin_light, confidence_matrix = \ - self.get_regions_light_v(img_res, is_image_enhanced, num_col_classifier) - #print("text region early -2 in %.1fs", time.time() - t0) - if num_col_classifier == 1 or num_col_classifier ==2: - if num_col_classifier == 1: - img_w_new = 1000 - else: - img_w_new = 1300 - img_h_new = img_w_new * textline_mask_tot_ea.shape[0] // textline_mask_tot_ea.shape[1] + self.logger.info("Using light version processing") + text_regions_p_1 ,erosion_hurts, polygons_seplines, polygons_text_early, \ + textline_mask_tot_ea, img_bin_light, confidence_matrix = \ + self.get_regions_light_v(img_res, is_image_enhanced, num_col_classifier) + #print("text region early -2 in %.1fs", time.time() - t0) - textline_mask_tot_ea_deskew = resize_image(textline_mask_tot_ea,img_h_new, img_w_new ) - slope_deskew = self.run_deskew(textline_mask_tot_ea_deskew) + if num_col_classifier == 1 or num_col_classifier ==2: + if num_col_classifier == 1: + img_w_new = 1000 else: - slope_deskew = self.run_deskew(textline_mask_tot_ea) - #print("text region early -2,5 in %.1fs", time.time() - t0) - #self.logger.info("Textregion detection took %.1fs ", time.time() - t1t) - num_col, num_col_classifier, img_only_regions, page_coord, image_page, mask_images, mask_seps, \ - text_regions_p_1, cont_page, table_prediction, textline_mask_tot_ea, img_bin_light = \ - self.run_graphics_and_columns_light(text_regions_p_1, textline_mask_tot_ea, - num_col_classifier, num_column_is_classified, - erosion_hurts, img_bin_light) - #self.logger.info("run graphics %.1fs ", time.time() - t1t) - #print("text region early -3 in %.1fs", time.time() - t0) - textline_mask_tot_ea_org = np.copy(textline_mask_tot_ea) + img_w_new = 1300 + img_h_new = img_w_new * textline_mask_tot_ea.shape[0] // textline_mask_tot_ea.shape[1] + textline_mask_tot_ea_deskew = resize_image(textline_mask_tot_ea,img_h_new, img_w_new ) + slope_deskew = self.run_deskew(textline_mask_tot_ea_deskew) else: - text_regions_p_1, erosion_hurts, polygons_seplines, polygons_text_early = \ - self.get_regions_from_xy_2models(img_res, is_image_enhanced, - num_col_classifier) - self.logger.info(f"Textregion detection took {time.time() - t1:.1f}s") - confidence_matrix = np.zeros((text_regions_p_1.shape[:2])) + slope_deskew = self.run_deskew(textline_mask_tot_ea) + #print("text region early -2,5 in %.1fs", time.time() - t0) + #self.logger.info("Textregion detection took %.1fs ", time.time() - t1t) + num_col, num_col_classifier, img_only_regions, page_coord, image_page, mask_images, mask_seps, \ + text_regions_p_1, cont_page, table_prediction, textline_mask_tot_ea, img_bin_light = \ + self.run_graphics_and_columns_light(text_regions_p_1, textline_mask_tot_ea, + num_col_classifier, num_column_is_classified, + erosion_hurts, img_bin_light) + #self.logger.info("run graphics %.1fs ", time.time() - t1t) + #print("text region early -3 in %.1fs", time.time() - t0) + textline_mask_tot_ea_org = np.copy(textline_mask_tot_ea) - t1 = time.time() - num_col, num_col_classifier, img_only_regions, page_coord, image_page, mask_images, mask_seps, \ - text_regions_p_1, cont_page, table_prediction = \ - self.run_graphics_and_columns(text_regions_p_1, num_col_classifier, num_column_is_classified, - erosion_hurts) - self.logger.info(f"Graphics detection took {time.time() - t1:.1f}s") - #self.logger.info('cont_page %s', cont_page) #plt.imshow(table_prediction) #plt.show() self.logger.info(f"Layout analysis complete ({time.time() - t1:.1f}s)") if not num_col and len(polygons_text_early) == 0: self.logger.info("No columns detected - generating empty PAGE-XML") - + pcgts = self.writer.build_pagexml_no_full_layout( - [], page_coord, [], [], [], [], [], [], [], [], [], [], [], [], [], - cont_page, [], []) + found_polygons_text_region=[], + page_coord=page_coord, + order_of_texts=[], + all_found_textline_polygons=[], + all_box_coord=[], + found_polygons_text_region_img=[], + found_polygons_marginals_left=[], + found_polygons_marginals_right=[], + all_found_textline_polygons_marginals_left=[], + all_found_textline_polygons_marginals_right=[], + all_box_coord_marginals_left=[], + all_box_coord_marginals_right=[], + slopes=[], + slopes_marginals_left=[], + slopes_marginals_right=[], + cont_page=cont_page, + polygons_seplines=[], + found_polygons_tables=[], + ) return pcgts #print("text region early in %.1fs", time.time() - t0) t1 = time.time() - if not self.light_version: - textline_mask_tot_ea = self.run_textline(image_page) - self.logger.info(f"Textline detection took {time.time() - t1:.1f}s") - t1 = time.time() - slope_deskew = self.run_deskew(textline_mask_tot_ea) - self.logger.info(f"Deskewing took {time.time() - t1:.1f}s") - elif num_col_classifier in (1,2): + if num_col_classifier in (1,2): org_h_l_m = textline_mask_tot_ea.shape[0] org_w_l_m = textline_mask_tot_ea.shape[1] if num_col_classifier == 1: @@ -4316,13 +3047,11 @@ class Eynollah: text_regions_p[text_regions_p == 4] = 1 self.logger.info("Step 3/5: Text Line Detection") - + if self.curved_line: self.logger.info("Mode: Curved line detection") - elif self.textline_light: - self.logger.info("Mode: Light detection") - if self.light_version and num_col_classifier in (1,2): + if num_col_classifier in (1,2): image_page = resize_image(image_page,org_h_l_m, org_w_l_m ) textline_mask_tot_ea = resize_image(textline_mask_tot_ea,org_h_l_m, org_w_l_m ) text_regions_p = resize_image(text_regions_p,org_h_l_m, org_w_l_m ) @@ -4346,11 +3075,10 @@ class Eynollah: regions_fully, regions_without_separators, polygons_of_marginals, contours_tables = \ self.run_boxes_full_layout(image_page, textline_mask_tot, text_regions_p, slope_deskew, num_col_classifier, img_only_regions, table_prediction, erosion_hurts, - img_bin_light if self.light_version else None) + img_bin_light) ###polygons_of_marginals = dilate_textregion_contours(polygons_of_marginals) - if self.light_version: - drop_label_in_full_layout = 4 - textline_mask_tot_ea_org[img_revised_tab==drop_label_in_full_layout] = 0 + drop_label_in_full_layout = 4 + textline_mask_tot_ea_org[img_revised_tab==drop_label_in_full_layout] = 0 text_only = (img_revised_tab[:, :] == 1) * 1 @@ -4517,89 +3245,89 @@ class Eynollah: empty_marginals = [[]] * len(polygons_of_marginals) if self.full_layout: pcgts = self.writer.build_pagexml_full_layout( - [], [], page_coord, [], [], [], [], [], - polygons_of_images, contours_tables, [], - polygons_of_marginals, polygons_of_marginals, - empty_marginals, empty_marginals, - empty_marginals, empty_marginals, - [], [], [], [], - cont_page, polygons_seplines) + found_polygons_text_region=[], + found_polygons_text_region_h=[], + page_coord=page_coord, + order_of_texts=[], + all_found_textline_polygons=[], + all_found_textline_polygons_h=[], + all_box_coord=[], + all_box_coord_h=[], + found_polygons_text_region_img=polygons_of_images, + found_polygons_tables=contours_tables, + found_polygons_drop_capitals=[], + found_polygons_marginals_left=polygons_of_marginals, + found_polygons_marginals_right=polygons_of_marginals, + all_found_textline_polygons_marginals_left=empty_marginals, + all_found_textline_polygons_marginals_right=empty_marginals, + all_box_coord_marginals_left=empty_marginals, + all_box_coord_marginals_right=empty_marginals, + slopes=[], + slopes_h=[], + slopes_marginals_left=[], + slopes_marginals_right=[], + cont_page=cont_page, + polygons_seplines=polygons_seplines + ) else: pcgts = self.writer.build_pagexml_no_full_layout( - [], page_coord, [], [], [], - polygons_of_images, - polygons_of_marginals, polygons_of_marginals, - empty_marginals, empty_marginals, - empty_marginals, empty_marginals, - [], [], [], - cont_page, polygons_seplines, contours_tables) + found_polygons_text_region=[], + page_coord=page_coord, + order_of_texts=[], + all_found_textline_polygons=[], + all_box_coord=[], + found_polygons_text_region_img=polygons_of_images, + found_polygons_marginals_left=polygons_of_marginals, + found_polygons_marginals_right=polygons_of_marginals, + all_found_textline_polygons_marginals_left=empty_marginals, + all_found_textline_polygons_marginals_right=empty_marginals, + all_box_coord_marginals_left=empty_marginals, + all_box_coord_marginals_right=empty_marginals, + slopes=[], + slopes_marginals_left=[], + slopes_marginals_right=[], + cont_page=cont_page, + polygons_seplines=polygons_seplines, + found_polygons_tables=contours_tables + ) return pcgts #print("text region early 3 in %.1fs", time.time() - t0) - if self.light_version: - contours_only_text_parent = dilate_textregion_contours(contours_only_text_parent) - contours_only_text_parent, contours_only_text_parent_d_ordered = \ - self.filter_contours_inside_a_bigger_one( - contours_only_text_parent, contours_only_text_parent_d_ordered, text_only, - marginal_cnts=polygons_of_marginals) - #print("text region early 3.5 in %.1fs", time.time() - t0) - conf_contours_textregions = get_textregion_contours_in_org_image_light( - contours_only_text_parent, self.image, confidence_matrix) - #contours_only_text_parent = dilate_textregion_contours(contours_only_text_parent) - else: - conf_contours_textregions = get_textregion_contours_in_org_image_light( - contours_only_text_parent, self.image, confidence_matrix) + contours_only_text_parent = dilate_textregion_contours(contours_only_text_parent) + contours_only_text_parent , contours_only_text_parent_d_ordered = self.filter_contours_inside_a_bigger_one( + contours_only_text_parent, contours_only_text_parent_d_ordered, text_only, + marginal_cnts=polygons_of_marginals) + #print("text region early 3.5 in %.1fs", time.time() - t0) + conf_contours_textregions = get_textregion_contours_in_org_image_light( + contours_only_text_parent, self.image, confidence_matrix) + #contours_only_text_parent = dilate_textregion_contours(contours_only_text_parent) #print("text region early 4 in %.1fs", time.time() - t0) boxes_text = get_text_region_boxes_by_given_contours(contours_only_text_parent) boxes_marginals = get_text_region_boxes_by_given_contours(polygons_of_marginals) #print("text region early 5 in %.1fs", time.time() - t0) ## birdan sora chock chakir if not self.curved_line: - if self.light_version: - if self.textline_light: - all_found_textline_polygons, \ - all_box_coord, slopes = self.get_slopes_and_deskew_new_light2( - contours_only_text_parent, textline_mask_tot_ea_org, - boxes_text, slope_deskew) - all_found_textline_polygons_marginals, \ - all_box_coord_marginals, slopes_marginals = self.get_slopes_and_deskew_new_light2( - polygons_of_marginals, textline_mask_tot_ea_org, - boxes_marginals, slope_deskew) + all_found_textline_polygons, \ + all_box_coord, slopes = self.get_slopes_and_deskew_new_light2( + contours_only_text_parent, textline_mask_tot_ea_org, + boxes_text, slope_deskew) + all_found_textline_polygons_marginals, \ + all_box_coord_marginals, slopes_marginals = self.get_slopes_and_deskew_new_light2( + polygons_of_marginals, textline_mask_tot_ea_org, + boxes_marginals, slope_deskew) - all_found_textline_polygons = dilate_textline_contours( - all_found_textline_polygons) - all_found_textline_polygons = self.filter_contours_inside_a_bigger_one( - all_found_textline_polygons, None, textline_mask_tot_ea_org, type_contour="textline") - all_found_textline_polygons_marginals = dilate_textline_contours( - all_found_textline_polygons_marginals) - contours_only_text_parent, all_found_textline_polygons, \ - contours_only_text_parent_d_ordered, conf_contours_textregions = \ - self.filter_contours_without_textline_inside( - contours_only_text_parent, all_found_textline_polygons, - contours_only_text_parent_d_ordered, conf_contours_textregions) - else: - textline_mask_tot_ea = cv2.erode(textline_mask_tot_ea, kernel=KERNEL, iterations=1) - all_found_textline_polygons, \ - all_box_coord, slopes = self.get_slopes_and_deskew_new_light( - contours_only_text_parent, contours_only_text_parent, textline_mask_tot_ea, - boxes_text, slope_deskew) - all_found_textline_polygons_marginals, \ - all_box_coord_marginals, slopes_marginals = self.get_slopes_and_deskew_new_light( - polygons_of_marginals, polygons_of_marginals, textline_mask_tot_ea, - boxes_marginals, slope_deskew) - #all_found_textline_polygons = self.filter_contours_inside_a_bigger_one( - # all_found_textline_polygons, textline_mask_tot_ea_org, type_contour="textline") - else: - textline_mask_tot_ea = cv2.erode(textline_mask_tot_ea, kernel=KERNEL, iterations=1) - all_found_textline_polygons, \ - all_box_coord, slopes = self.get_slopes_and_deskew_new( - contours_only_text_parent, contours_only_text_parent, textline_mask_tot_ea, - boxes_text, slope_deskew) - all_found_textline_polygons_marginals, \ - all_box_coord_marginals, slopes_marginals = self.get_slopes_and_deskew_new( - polygons_of_marginals, polygons_of_marginals, textline_mask_tot_ea, - boxes_marginals, slope_deskew) + all_found_textline_polygons = dilate_textline_contours( + all_found_textline_polygons) + all_found_textline_polygons = self.filter_contours_inside_a_bigger_one( + all_found_textline_polygons, None, textline_mask_tot_ea_org, type_contour="textline") + all_found_textline_polygons_marginals = dilate_textline_contours( + all_found_textline_polygons_marginals) + contours_only_text_parent, all_found_textline_polygons, \ + contours_only_text_parent_d_ordered, conf_contours_textregions = \ + self.filter_contours_without_textline_inside( + contours_only_text_parent, all_found_textline_polygons, + contours_only_text_parent_d_ordered, conf_contours_textregions) else: scale_param = 1 textline_mask_tot_ea_erode = cv2.erode(textline_mask_tot_ea, kernel=KERNEL, iterations=2) @@ -4617,7 +3345,7 @@ class Eynollah: num_col_classifier, scale_param, slope_deskew) all_found_textline_polygons_marginals = small_textlines_to_parent_adherence2( all_found_textline_polygons_marginals, textline_mask_tot_ea, num_col_classifier) - + mid_point_of_page_width = text_regions_p.shape[1] / 2. (polygons_of_marginals_left, polygons_of_marginals_right, all_found_textline_polygons_marginals_left, all_found_textline_polygons_marginals_right, @@ -4626,14 +3354,11 @@ class Eynollah: self.separate_marginals_to_left_and_right_and_order_from_top_to_down( polygons_of_marginals, all_found_textline_polygons_marginals, all_box_coord_marginals, slopes_marginals, mid_point_of_page_width) - + #print(len(polygons_of_marginals), len(ordered_left_marginals), len(ordered_right_marginals), 'marginals ordred') if self.full_layout: - if self.light_version: - fun = check_any_text_region_in_model_one_is_main_or_header_light - else: - fun = check_any_text_region_in_model_one_is_main_or_header + fun = check_any_text_region_in_model_one_is_main_or_header_light text_regions_p, contours_only_text_parent, contours_only_text_parent_h, all_box_coord, all_box_coord_h, \ all_found_textline_polygons, all_found_textline_polygons_h, slopes, slopes_h, \ contours_only_text_parent_d_ordered, contours_only_text_parent_h_d_ordered, \ @@ -4652,7 +3377,7 @@ class Eynollah: ##all_found_textline_polygons = adhere_drop_capital_region_into_corresponding_textline( ##text_regions_p, polygons_of_drop_capitals, contours_only_text_parent, contours_only_text_parent_h, ##all_box_coord, all_box_coord_h, all_found_textline_polygons, all_found_textline_polygons_h, - ##kernel=KERNEL, curved_line=self.curved_line, textline_light=self.textline_light) + ##kernel=KERNEL, curved_line=self.curved_line) if not self.reading_order_machine_based: label_seps = 6 @@ -4719,1084 +3444,56 @@ class Eynollah: boxes_d, textline_mask_tot_d) self.logger.info(f"Detection of reading order took {time.time() - t_order:.1f}s") - ocr_all_textlines = None - ocr_all_textlines_marginals_left = None - ocr_all_textlines_marginals_right = None - ocr_all_textlines_h = None - ocr_all_textlines_drop = None - if self.ocr: - self.logger.info("Step 4.5/5: OCR Processing") - - if not self.tr: - gc.collect() - - if len(all_found_textline_polygons): - ocr_all_textlines = return_rnn_cnn_ocr_of_given_textlines( - image_page, all_found_textline_polygons, all_box_coord, - self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line) - - if len(all_found_textline_polygons_marginals_left): - ocr_all_textlines_marginals_left = return_rnn_cnn_ocr_of_given_textlines( - image_page, all_found_textline_polygons_marginals_left, all_box_coord_marginals_left, - self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line) - - if len(all_found_textline_polygons_marginals_right): - ocr_all_textlines_marginals_right = return_rnn_cnn_ocr_of_given_textlines( - image_page, all_found_textline_polygons_marginals_right, all_box_coord_marginals_right, - self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line) - - if self.full_layout and len(all_found_textline_polygons): - ocr_all_textlines_h = return_rnn_cnn_ocr_of_given_textlines( - image_page, all_found_textline_polygons_h, all_box_coord_h, - self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line) - - if self.full_layout and len(polygons_of_drop_capitals): - ocr_all_textlines_drop = return_rnn_cnn_ocr_of_given_textlines( - image_page, polygons_of_drop_capitals, np.zeros((len(polygons_of_drop_capitals), 4)), - self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line) - - else: - if self.light_version: - self.logger.info("Using light version OCR") - if self.textline_light: - self.logger.info("Using light text line detection for OCR") - self.logger.info("Processing text lines...") - - gc.collect() - - torch.cuda.empty_cache() - self.models["ocr"].to(self.device) - - ind_tot = 0 - #cv2.imwrite('./img_out.png', image_page) - ocr_all_textlines = [] - # FIXME: what about lines in marginals / headings / drop-capitals here? - for indexing, ind_poly_first in enumerate(all_found_textline_polygons): - ocr_textline_in_textregion = [] - for indexing2, ind_poly in enumerate(ind_poly_first): - if not (self.textline_light or self.curved_line): - ind_poly = copy.deepcopy(ind_poly) - box_ind = all_box_coord[indexing] - #print(ind_poly,np.shape(ind_poly), 'ind_poly') - #print(box_ind) - ind_poly = return_textline_contour_with_added_box_coordinate(ind_poly, box_ind) - #print(ind_poly_copy) - ind_poly[ind_poly<0] = 0 - x, y, w, h = cv2.boundingRect(ind_poly) - #print(ind_poly_copy, np.shape(ind_poly_copy)) - #print(x, y, w, h, h/float(w),'ratio') - h2w_ratio = h/float(w) - mask_poly = np.zeros(image_page.shape) - if not self.light_version: - img_poly_on_img = np.copy(image_page) - else: - img_poly_on_img = np.copy(img_bin_light) - mask_poly = cv2.fillPoly(mask_poly, pts=[ind_poly], color=(1, 1, 1)) - - if self.textline_light: - mask_poly = cv2.dilate(mask_poly, KERNEL, iterations=1) - img_poly_on_img[:,:,0][mask_poly[:,:,0] ==0] = 255 - img_poly_on_img[:,:,1][mask_poly[:,:,0] ==0] = 255 - img_poly_on_img[:,:,2][mask_poly[:,:,0] ==0] = 255 - - img_croped = img_poly_on_img[y:y+h, x:x+w, :] - #cv2.imwrite('./extracted_lines/'+str(ind_tot)+'.jpg', img_croped) - text_ocr = self.return_ocr_of_textline_without_common_section( - img_croped, self.models["ocr"], self.processor, self.device, w, h2w_ratio, ind_tot) - ocr_textline_in_textregion.append(text_ocr) - ind_tot = ind_tot +1 - ocr_all_textlines.append(ocr_textline_in_textregion) - self.logger.info("Step 5/5: Output Generation") if self.full_layout: pcgts = self.writer.build_pagexml_full_layout( - contours_only_text_parent, contours_only_text_parent_h, page_coord, order_text_new, - all_found_textline_polygons, all_found_textline_polygons_h, all_box_coord, all_box_coord_h, - polygons_of_images, contours_tables, polygons_of_drop_capitals, - polygons_of_marginals_left, polygons_of_marginals_right, - all_found_textline_polygons_marginals_left, all_found_textline_polygons_marginals_right, - all_box_coord_marginals_left, all_box_coord_marginals_right, - slopes, slopes_h, slopes_marginals_left, slopes_marginals_right, - cont_page, polygons_seplines, ocr_all_textlines, ocr_all_textlines_h, - ocr_all_textlines_marginals_left, ocr_all_textlines_marginals_right, - ocr_all_textlines_drop, - conf_contours_textregions, conf_contours_textregions_h) + found_polygons_text_region=contours_only_text_parent, + found_polygons_text_region_h=contours_only_text_parent_h, + page_coord=page_coord, + order_of_texts=order_text_new, + all_found_textline_polygons=all_found_textline_polygons, + all_found_textline_polygons_h=all_found_textline_polygons_h, + all_box_coord=all_box_coord, + all_box_coord_h=all_box_coord_h, + found_polygons_text_region_img=polygons_of_images, + found_polygons_tables=contours_tables, + found_polygons_drop_capitals=polygons_of_drop_capitals, + found_polygons_marginals_left=polygons_of_marginals_left, + found_polygons_marginals_right=polygons_of_marginals_right, + all_found_textline_polygons_marginals_left=all_found_textline_polygons_marginals_left, + all_found_textline_polygons_marginals_right=all_found_textline_polygons_marginals_right, + all_box_coord_marginals_left=all_box_coord_marginals_left, + all_box_coord_marginals_right=all_box_coord_marginals_right, + slopes=slopes, + slopes_h=slopes_h, + slopes_marginals_left=slopes_marginals_left, + slopes_marginals_right=slopes_marginals_right, + cont_page=cont_page, + polygons_seplines=polygons_seplines, + conf_contours_textregions=conf_contours_textregions, + conf_contours_textregions_h=conf_contours_textregions_h + ) else: pcgts = self.writer.build_pagexml_no_full_layout( - contours_only_text_parent, page_coord, order_text_new, - all_found_textline_polygons, all_box_coord, polygons_of_images, - polygons_of_marginals_left, polygons_of_marginals_right, - all_found_textline_polygons_marginals_left, all_found_textline_polygons_marginals_right, - all_box_coord_marginals_left, all_box_coord_marginals_right, - slopes, slopes_marginals_left, slopes_marginals_right, - cont_page, polygons_seplines, contours_tables, - ocr_all_textlines=ocr_all_textlines, - ocr_all_textlines_marginals_left=ocr_all_textlines_marginals_left, - ocr_all_textlines_marginals_right=ocr_all_textlines_marginals_right, - conf_contours_textregions=conf_contours_textregions) - + found_polygons_text_region=contours_only_text_parent, + page_coord=page_coord, + order_of_texts=order_text_new, + all_found_textline_polygons=all_found_textline_polygons, + all_box_coord=all_box_coord, + found_polygons_text_region_img=polygons_of_images, + found_polygons_marginals_left=polygons_of_marginals_left, + found_polygons_marginals_right=polygons_of_marginals_right, + all_found_textline_polygons_marginals_left=all_found_textline_polygons_marginals_left, + all_found_textline_polygons_marginals_right=all_found_textline_polygons_marginals_right, + all_box_coord_marginals_left=all_box_coord_marginals_left, + all_box_coord_marginals_right=all_box_coord_marginals_right, + slopes=slopes, + slopes_marginals_left=slopes_marginals_left, + slopes_marginals_right=slopes_marginals_right, + cont_page=cont_page, + polygons_seplines=polygons_seplines, + found_polygons_tables=contours_tables, + ) + return pcgts - - -class Eynollah_ocr: - def __init__( - self, - dir_models, - model_name=None, - dir_xmls=None, - tr_ocr=False, - batch_size=None, - export_textline_images_and_text=False, - do_not_mask_with_textline_contour=False, - pref_of_dataset=None, - min_conf_value_of_textline_text : Optional[float]=None, - logger=None, - ): - self.model_name = model_name - self.tr_ocr = tr_ocr - self.export_textline_images_and_text = export_textline_images_and_text - self.do_not_mask_with_textline_contour = do_not_mask_with_textline_contour - self.pref_of_dataset = pref_of_dataset - self.logger = logger if logger else getLogger('eynollah') - - if not export_textline_images_and_text: - if min_conf_value_of_textline_text: - self.min_conf_value_of_textline_text = float(min_conf_value_of_textline_text) - else: - self.min_conf_value_of_textline_text = 0.3 - if tr_ocr: - self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed") - self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - if self.model_name: - self.model_ocr_dir = self.model_name - else: - self.model_ocr_dir = dir_models + "/model_eynollah_ocr_trocr_20250919" - self.model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir) - self.model_ocr.to(self.device) - if not batch_size: - self.b_s = 2 - else: - self.b_s = int(batch_size) - - else: - if self.model_name: - self.model_ocr_dir = self.model_name - else: - self.model_ocr_dir = dir_models + "/model_eynollah_ocr_cnnrnn_20250930" - model_ocr = load_model(self.model_ocr_dir , compile=False) - - self.prediction_model = tf.keras.models.Model( - model_ocr.get_layer(name = "image").input, - model_ocr.get_layer(name = "dense2").output) - if not batch_size: - self.b_s = 8 - else: - self.b_s = int(batch_size) - - with open(os.path.join(self.model_ocr_dir, "characters_org.txt"),"r") as config_file: - characters = json.load(config_file) - - AUTOTUNE = tf.data.AUTOTUNE - - # Mapping characters to integers. - char_to_num = StringLookup(vocabulary=list(characters), mask_token=None) - - # Mapping integers back to original characters. - self.num_to_char = StringLookup( - vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True - ) - self.end_character = len(characters) + 2 - - def run(self, overwrite: bool = False, - dir_in: Optional[str] = None, - dir_in_bin: Optional[str] = None, - image_filename: Optional[str] = None, - dir_xmls: Optional[str] = None, - dir_out_image_text: Optional[str] = None, - dir_out: Optional[str] = None, - ): - if dir_in: - ls_imgs = [os.path.join(dir_in, image_filename) - for image_filename in filter(is_image_filename, - os.listdir(dir_in))] - else: - ls_imgs = [image_filename] - - if self.tr_ocr: - tr_ocr_input_height_and_width = 384 - for dir_img in ls_imgs: - file_name = Path(dir_img).stem - dir_xml = os.path.join(dir_xmls, file_name+'.xml') - out_file_ocr = os.path.join(dir_out, file_name+'.xml') - - if os.path.exists(out_file_ocr): - if overwrite: - self.logger.warning("will overwrite existing output file '%s'", out_file_ocr) - else: - self.logger.warning("will skip input for existing output file '%s'", out_file_ocr) - continue - - img = cv2.imread(dir_img) - - if dir_out_image_text: - out_image_with_text = os.path.join(dir_out_image_text, file_name+'.png') - image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white") - draw = ImageDraw.Draw(image_text) - total_bb_coordinates = [] - - ##file_name = Path(dir_xmls).stem - tree1 = ET.parse(dir_xml, parser = ET.XMLParser(encoding="utf-8")) - root1=tree1.getroot() - alltags=[elem.tag for elem in root1.iter()] - link=alltags[0].split('}')[0]+'}' - - name_space = alltags[0].split('}')[0] - name_space = name_space.split('{')[1] - - region_tags=np.unique([x for x in alltags if x.endswith('TextRegion')]) - - - - cropped_lines = [] - cropped_lines_region_indexer = [] - cropped_lines_meging_indexing = [] - - extracted_texts = [] - - indexer_text_region = 0 - indexer_b_s = 0 - - for nn in root1.iter(region_tags): - for child_textregion in nn: - if child_textregion.tag.endswith("TextLine"): - - for child_textlines in child_textregion: - if child_textlines.tag.endswith("Coords"): - cropped_lines_region_indexer.append(indexer_text_region) - p_h=child_textlines.attrib['points'].split(' ') - textline_coords = np.array( [ [int(x.split(',')[0]), - int(x.split(',')[1]) ] - for x in p_h] ) - x,y,w,h = cv2.boundingRect(textline_coords) - - if dir_out_image_text: - total_bb_coordinates.append([x,y,w,h]) - - h2w_ratio = h/float(w) - - img_poly_on_img = np.copy(img) - mask_poly = np.zeros(img.shape) - mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1)) - - mask_poly = mask_poly[y:y+h, x:x+w, :] - img_crop = img_poly_on_img[y:y+h, x:x+w, :] - img_crop[mask_poly==0] = 255 - - self.logger.debug("processing %d lines for '%s'", - len(cropped_lines), nn.attrib['id']) - if h2w_ratio > 0.1: - cropped_lines.append(resize_image(img_crop, - tr_ocr_input_height_and_width, - tr_ocr_input_height_and_width) ) - cropped_lines_meging_indexing.append(0) - indexer_b_s+=1 - if indexer_b_s==self.b_s: - imgs = cropped_lines[:] - cropped_lines = [] - indexer_b_s = 0 - - pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values - generated_ids_merged = self.model_ocr.generate( - pixel_values_merged.to(self.device)) - generated_text_merged = self.processor.batch_decode( - generated_ids_merged, skip_special_tokens=True) - - extracted_texts = extracted_texts + generated_text_merged - - else: - splited_images, _ = return_textlines_split_if_needed(img_crop, None) - #print(splited_images) - if splited_images: - cropped_lines.append(resize_image(splited_images[0], - tr_ocr_input_height_and_width, - tr_ocr_input_height_and_width)) - cropped_lines_meging_indexing.append(1) - indexer_b_s+=1 - - if indexer_b_s==self.b_s: - imgs = cropped_lines[:] - cropped_lines = [] - indexer_b_s = 0 - - pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values - generated_ids_merged = self.model_ocr.generate( - pixel_values_merged.to(self.device)) - generated_text_merged = self.processor.batch_decode( - generated_ids_merged, skip_special_tokens=True) - - extracted_texts = extracted_texts + generated_text_merged - - - cropped_lines.append(resize_image(splited_images[1], - tr_ocr_input_height_and_width, - tr_ocr_input_height_and_width)) - cropped_lines_meging_indexing.append(-1) - indexer_b_s+=1 - - if indexer_b_s==self.b_s: - imgs = cropped_lines[:] - cropped_lines = [] - indexer_b_s = 0 - - pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values - generated_ids_merged = self.model_ocr.generate( - pixel_values_merged.to(self.device)) - generated_text_merged = self.processor.batch_decode( - generated_ids_merged, skip_special_tokens=True) - - extracted_texts = extracted_texts + generated_text_merged - - else: - cropped_lines.append(img_crop) - cropped_lines_meging_indexing.append(0) - indexer_b_s+=1 - - if indexer_b_s==self.b_s: - imgs = cropped_lines[:] - cropped_lines = [] - indexer_b_s = 0 - - pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values - generated_ids_merged = self.model_ocr.generate( - pixel_values_merged.to(self.device)) - generated_text_merged = self.processor.batch_decode( - generated_ids_merged, skip_special_tokens=True) - - extracted_texts = extracted_texts + generated_text_merged - - - - indexer_text_region = indexer_text_region +1 - - if indexer_b_s!=0: - imgs = cropped_lines[:] - cropped_lines = [] - indexer_b_s = 0 - - pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values - generated_ids_merged = self.model_ocr.generate(pixel_values_merged.to(self.device)) - generated_text_merged = self.processor.batch_decode(generated_ids_merged, skip_special_tokens=True) - - extracted_texts = extracted_texts + generated_text_merged - - ####extracted_texts = [] - ####n_iterations = math.ceil(len(cropped_lines) / self.b_s) - - ####for i in range(n_iterations): - ####if i==(n_iterations-1): - ####n_start = i*self.b_s - ####imgs = cropped_lines[n_start:] - ####else: - ####n_start = i*self.b_s - ####n_end = (i+1)*self.b_s - ####imgs = cropped_lines[n_start:n_end] - ####pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values - ####generated_ids_merged = self.model_ocr.generate( - #### pixel_values_merged.to(self.device)) - ####generated_text_merged = self.processor.batch_decode( - #### generated_ids_merged, skip_special_tokens=True) - - ####extracted_texts = extracted_texts + generated_text_merged - - del cropped_lines - gc.collect() - - extracted_texts_merged = [extracted_texts[ind] - if cropped_lines_meging_indexing[ind]==0 - else extracted_texts[ind]+" "+extracted_texts[ind+1] - if cropped_lines_meging_indexing[ind]==1 - else None - for ind in range(len(cropped_lines_meging_indexing))] - - extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None] - #print(extracted_texts_merged, len(extracted_texts_merged)) - - unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer) - - if dir_out_image_text: - - #font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists! - font = importlib_resources.files(__package__) / "Charis-Regular.ttf" - with importlib_resources.as_file(font) as font: - font = ImageFont.truetype(font=font, size=40) - - for indexer_text, bb_ind in enumerate(total_bb_coordinates): - - - x_bb = bb_ind[0] - y_bb = bb_ind[1] - w_bb = bb_ind[2] - h_bb = bb_ind[3] - - font = fit_text_single_line(draw, extracted_texts_merged[indexer_text], - font.path, w_bb, int(h_bb*0.4) ) - - ##draw.rectangle([x_bb, y_bb, x_bb + w_bb, y_bb + h_bb], outline="red", width=2) - - text_bbox = draw.textbbox((0, 0), extracted_texts_merged[indexer_text], font=font) - text_width = text_bbox[2] - text_bbox[0] - text_height = text_bbox[3] - text_bbox[1] - - text_x = x_bb + (w_bb - text_width) // 2 # Center horizontally - text_y = y_bb + (h_bb - text_height) // 2 # Center vertically - - # Draw the text - draw.text((text_x, text_y), extracted_texts_merged[indexer_text], fill="black", font=font) - image_text.save(out_image_with_text) - - #print(len(unique_cropped_lines_region_indexer), 'unique_cropped_lines_region_indexer') - #######text_by_textregion = [] - #######for ind in unique_cropped_lines_region_indexer: - #######ind = np.array(cropped_lines_region_indexer)==ind - #######extracted_texts_merged_un = np.array(extracted_texts_merged)[ind] - #######text_by_textregion.append(" ".join(extracted_texts_merged_un)) - - text_by_textregion = [] - for ind in unique_cropped_lines_region_indexer: - ind = np.array(cropped_lines_region_indexer) == ind - extracted_texts_merged_un = np.array(extracted_texts_merged)[ind] - if len(extracted_texts_merged_un)>1: - text_by_textregion_ind = "" - next_glue = "" - for indt in range(len(extracted_texts_merged_un)): - if (extracted_texts_merged_un[indt].endswith('⸗') or - extracted_texts_merged_un[indt].endswith('-') or - extracted_texts_merged_un[indt].endswith('¬')): - text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt][:-1] - next_glue = "" - else: - text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt] - next_glue = " " - text_by_textregion.append(text_by_textregion_ind) - else: - text_by_textregion.append(" ".join(extracted_texts_merged_un)) - - - indexer = 0 - indexer_textregion = 0 - for nn in root1.iter(region_tags): - #id_textregion = nn.attrib['id'] - #id_textregions.append(id_textregion) - #textregions_by_existing_ids.append(text_by_textregion[indexer_textregion]) - - is_textregion_text = False - for childtest in nn: - if childtest.tag.endswith("TextEquiv"): - is_textregion_text = True - - if not is_textregion_text: - text_subelement_textregion = ET.SubElement(nn, 'TextEquiv') - unicode_textregion = ET.SubElement(text_subelement_textregion, 'Unicode') - - - has_textline = False - for child_textregion in nn: - if child_textregion.tag.endswith("TextLine"): - - is_textline_text = False - for childtest2 in child_textregion: - if childtest2.tag.endswith("TextEquiv"): - is_textline_text = True - - - if not is_textline_text: - text_subelement = ET.SubElement(child_textregion, 'TextEquiv') - ##text_subelement.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}") - unicode_textline = ET.SubElement(text_subelement, 'Unicode') - unicode_textline.text = extracted_texts_merged[indexer] - else: - for childtest3 in child_textregion: - if childtest3.tag.endswith("TextEquiv"): - for child_uc in childtest3: - if child_uc.tag.endswith("Unicode"): - ##childtest3.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}") - child_uc.text = extracted_texts_merged[indexer] - - indexer = indexer + 1 - has_textline = True - if has_textline: - if is_textregion_text: - for child4 in nn: - if child4.tag.endswith("TextEquiv"): - for childtr_uc in child4: - if childtr_uc.tag.endswith("Unicode"): - childtr_uc.text = text_by_textregion[indexer_textregion] - else: - unicode_textregion.text = text_by_textregion[indexer_textregion] - indexer_textregion = indexer_textregion + 1 - - ###sample_order = [(id_to_order[tid], text) - ### for tid, text in zip(id_textregions, textregions_by_existing_ids) - ### if tid in id_to_order] - - ##ordered_texts_sample = [text for _, text in sorted(sample_order)] - ##tot_page_text = ' '.join(ordered_texts_sample) - - ##for page_element in root1.iter(link+'Page'): - ##text_page = ET.SubElement(page_element, 'TextEquiv') - ##unicode_textpage = ET.SubElement(text_page, 'Unicode') - ##unicode_textpage.text = tot_page_text - - ET.register_namespace("",name_space) - tree1.write(out_file_ocr,xml_declaration=True,method='xml',encoding="utf-8",default_namespace=None) - else: - ###max_len = 280#512#280#512 - ###padding_token = 1500#299#1500#299 - image_width = 512#max_len * 4 - image_height = 32 - - - img_size=(image_width, image_height) - - for dir_img in ls_imgs: - file_name = Path(dir_img).stem - dir_xml = os.path.join(dir_xmls, file_name+'.xml') - out_file_ocr = os.path.join(dir_out, file_name+'.xml') - - if os.path.exists(out_file_ocr): - if overwrite: - self.logger.warning("will overwrite existing output file '%s'", out_file_ocr) - else: - self.logger.warning("will skip input for existing output file '%s'", out_file_ocr) - continue - - img = cv2.imread(dir_img) - if dir_in_bin is not None: - cropped_lines_bin = [] - dir_img_bin = os.path.join(dir_in_bin, file_name+'.png') - img_bin = cv2.imread(dir_img_bin) - - if dir_out_image_text: - out_image_with_text = os.path.join(dir_out_image_text, file_name+'.png') - image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white") - draw = ImageDraw.Draw(image_text) - total_bb_coordinates = [] - - tree1 = ET.parse(dir_xml, parser = ET.XMLParser(encoding="utf-8")) - root1=tree1.getroot() - alltags=[elem.tag for elem in root1.iter()] - link=alltags[0].split('}')[0]+'}' - - name_space = alltags[0].split('}')[0] - name_space = name_space.split('{')[1] - - region_tags=np.unique([x for x in alltags if x.endswith('TextRegion')]) - - cropped_lines = [] - cropped_lines_ver_index = [] - cropped_lines_region_indexer = [] - cropped_lines_meging_indexing = [] - - tinl = time.time() - indexer_text_region = 0 - indexer_textlines = 0 - for nn in root1.iter(region_tags): - try: - type_textregion = nn.attrib['type'] - except: - type_textregion = 'paragraph' - for child_textregion in nn: - if child_textregion.tag.endswith("TextLine"): - for child_textlines in child_textregion: - if child_textlines.tag.endswith("Coords"): - cropped_lines_region_indexer.append(indexer_text_region) - p_h=child_textlines.attrib['points'].split(' ') - textline_coords = np.array( [ [int(x.split(',')[0]), - int(x.split(',')[1]) ] - for x in p_h] ) - - x,y,w,h = cv2.boundingRect(textline_coords) - - angle_radians = math.atan2(h, w) - # Convert to degrees - angle_degrees = math.degrees(angle_radians) - if type_textregion=='drop-capital': - angle_degrees = 0 - - if dir_out_image_text: - total_bb_coordinates.append([x,y,w,h]) - - w_scaled = w * image_height/float(h) - - img_poly_on_img = np.copy(img) - if dir_in_bin is not None: - img_poly_on_img_bin = np.copy(img_bin) - img_crop_bin = img_poly_on_img_bin[y:y+h, x:x+w, :] - - mask_poly = np.zeros(img.shape) - mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1)) - - - mask_poly = mask_poly[y:y+h, x:x+w, :] - img_crop = img_poly_on_img[y:y+h, x:x+w, :] - - if self.export_textline_images_and_text: - if not self.do_not_mask_with_textline_contour: - img_crop[mask_poly==0] = 255 - - else: - # print(file_name, angle_degrees, w*h, - # mask_poly[:,:,0].sum(), - # mask_poly[:,:,0].sum() /float(w*h) , - # 'didi') - - if angle_degrees > 3: - better_des_slope = get_orientation_moments(textline_coords) - - img_crop = rotate_image_with_padding(img_crop, better_des_slope) - if dir_in_bin is not None: - img_crop_bin = rotate_image_with_padding(img_crop_bin, better_des_slope) - - mask_poly = rotate_image_with_padding(mask_poly, better_des_slope) - mask_poly = mask_poly.astype('uint8') - - #new bounding box - x_n, y_n, w_n, h_n = get_contours_and_bounding_boxes(mask_poly[:,:,0]) - - mask_poly = mask_poly[y_n:y_n+h_n, x_n:x_n+w_n, :] - img_crop = img_crop[y_n:y_n+h_n, x_n:x_n+w_n, :] - - if not self.do_not_mask_with_textline_contour: - img_crop[mask_poly==0] = 255 - if dir_in_bin is not None: - img_crop_bin = img_crop_bin[y_n:y_n+h_n, x_n:x_n+w_n, :] - if not self.do_not_mask_with_textline_contour: - img_crop_bin[mask_poly==0] = 255 - - if mask_poly[:,:,0].sum() /float(w_n*h_n) < 0.50 and w_scaled > 90: - if dir_in_bin is not None: - img_crop, img_crop_bin = \ - break_curved_line_into_small_pieces_and_then_merge( - img_crop, mask_poly, img_crop_bin) - else: - img_crop, _ = \ - break_curved_line_into_small_pieces_and_then_merge( - img_crop, mask_poly) - - else: - better_des_slope = 0 - if not self.do_not_mask_with_textline_contour: - img_crop[mask_poly==0] = 255 - if dir_in_bin is not None: - if not self.do_not_mask_with_textline_contour: - img_crop_bin[mask_poly==0] = 255 - if type_textregion=='drop-capital': - pass - else: - if mask_poly[:,:,0].sum() /float(w*h) < 0.50 and w_scaled > 90: - if dir_in_bin is not None: - img_crop, img_crop_bin = \ - break_curved_line_into_small_pieces_and_then_merge( - img_crop, mask_poly, img_crop_bin) - else: - img_crop, _ = \ - break_curved_line_into_small_pieces_and_then_merge( - img_crop, mask_poly) - - if not self.export_textline_images_and_text: - if w_scaled < 750:#1.5*image_width: - img_fin = preprocess_and_resize_image_for_ocrcnn_model( - img_crop, image_height, image_width) - cropped_lines.append(img_fin) - if abs(better_des_slope) > 45: - cropped_lines_ver_index.append(1) - else: - cropped_lines_ver_index.append(0) - - cropped_lines_meging_indexing.append(0) - if dir_in_bin is not None: - img_fin = preprocess_and_resize_image_for_ocrcnn_model( - img_crop_bin, image_height, image_width) - cropped_lines_bin.append(img_fin) - else: - splited_images, splited_images_bin = return_textlines_split_if_needed( - img_crop, img_crop_bin if dir_in_bin is not None else None) - if splited_images: - img_fin = preprocess_and_resize_image_for_ocrcnn_model( - splited_images[0], image_height, image_width) - cropped_lines.append(img_fin) - cropped_lines_meging_indexing.append(1) - - if abs(better_des_slope) > 45: - cropped_lines_ver_index.append(1) - else: - cropped_lines_ver_index.append(0) - - img_fin = preprocess_and_resize_image_for_ocrcnn_model( - splited_images[1], image_height, image_width) - - cropped_lines.append(img_fin) - cropped_lines_meging_indexing.append(-1) - - if abs(better_des_slope) > 45: - cropped_lines_ver_index.append(1) - else: - cropped_lines_ver_index.append(0) - - if dir_in_bin is not None: - img_fin = preprocess_and_resize_image_for_ocrcnn_model( - splited_images_bin[0], image_height, image_width) - cropped_lines_bin.append(img_fin) - img_fin = preprocess_and_resize_image_for_ocrcnn_model( - splited_images_bin[1], image_height, image_width) - cropped_lines_bin.append(img_fin) - - else: - img_fin = preprocess_and_resize_image_for_ocrcnn_model( - img_crop, image_height, image_width) - cropped_lines.append(img_fin) - cropped_lines_meging_indexing.append(0) - - if abs(better_des_slope) > 45: - cropped_lines_ver_index.append(1) - else: - cropped_lines_ver_index.append(0) - - if dir_in_bin is not None: - img_fin = preprocess_and_resize_image_for_ocrcnn_model( - img_crop_bin, image_height, image_width) - cropped_lines_bin.append(img_fin) - - if self.export_textline_images_and_text: - if img_crop.shape[0]==0 or img_crop.shape[1]==0: - pass - else: - if child_textlines.tag.endswith("TextEquiv"): - for cheild_text in child_textlines: - if cheild_text.tag.endswith("Unicode"): - textline_text = cheild_text.text - if textline_text: - base_name = os.path.join( - dir_out, file_name + '_line_' + str(indexer_textlines)) - if self.pref_of_dataset: - base_name += '_' + self.pref_of_dataset - if not self.do_not_mask_with_textline_contour: - base_name += '_masked' - - with open(base_name + '.txt', 'w') as text_file: - text_file.write(textline_text) - cv2.imwrite(base_name + '.png', img_crop) - indexer_textlines+=1 - - if not self.export_textline_images_and_text: - indexer_text_region = indexer_text_region +1 - - if not self.export_textline_images_and_text: - extracted_texts = [] - extracted_conf_value = [] - - n_iterations = math.ceil(len(cropped_lines) / self.b_s) - - for i in range(n_iterations): - if i==(n_iterations-1): - n_start = i*self.b_s - imgs = cropped_lines[n_start:] - imgs = np.array(imgs) - imgs = imgs.reshape(imgs.shape[0], image_height, image_width, 3) - - ver_imgs = np.array( cropped_lines_ver_index[n_start:] ) - indices_ver = np.where(ver_imgs == 1)[0] - - #print(indices_ver, 'indices_ver') - if len(indices_ver)>0: - imgs_ver_flipped = imgs[indices_ver, : ,: ,:] - imgs_ver_flipped = imgs_ver_flipped[:,::-1,::-1,:] - #print(imgs_ver_flipped, 'imgs_ver_flipped') - - else: - imgs_ver_flipped = None - - if dir_in_bin is not None: - imgs_bin = cropped_lines_bin[n_start:] - imgs_bin = np.array(imgs_bin) - imgs_bin = imgs_bin.reshape(imgs_bin.shape[0], image_height, image_width, 3) - - if len(indices_ver)>0: - imgs_bin_ver_flipped = imgs_bin[indices_ver, : ,: ,:] - imgs_bin_ver_flipped = imgs_bin_ver_flipped[:,::-1,::-1,:] - #print(imgs_ver_flipped, 'imgs_ver_flipped') - - else: - imgs_bin_ver_flipped = None - else: - n_start = i*self.b_s - n_end = (i+1)*self.b_s - imgs = cropped_lines[n_start:n_end] - imgs = np.array(imgs).reshape(self.b_s, image_height, image_width, 3) - - ver_imgs = np.array( cropped_lines_ver_index[n_start:n_end] ) - indices_ver = np.where(ver_imgs == 1)[0] - #print(indices_ver, 'indices_ver') - - if len(indices_ver)>0: - imgs_ver_flipped = imgs[indices_ver, : ,: ,:] - imgs_ver_flipped = imgs_ver_flipped[:,::-1,::-1,:] - #print(imgs_ver_flipped, 'imgs_ver_flipped') - else: - imgs_ver_flipped = None - - - if dir_in_bin is not None: - imgs_bin = cropped_lines_bin[n_start:n_end] - imgs_bin = np.array(imgs_bin).reshape(self.b_s, image_height, image_width, 3) - - - if len(indices_ver)>0: - imgs_bin_ver_flipped = imgs_bin[indices_ver, : ,: ,:] - imgs_bin_ver_flipped = imgs_bin_ver_flipped[:,::-1,::-1,:] - #print(imgs_ver_flipped, 'imgs_ver_flipped') - else: - imgs_bin_ver_flipped = None - - - self.logger.debug("processing next %d lines", len(imgs)) - preds = self.prediction_model.predict(imgs, verbose=0) - - if len(indices_ver)>0: - preds_flipped = self.prediction_model.predict(imgs_ver_flipped, verbose=0) - preds_max_fliped = np.max(preds_flipped, axis=2 ) - preds_max_args_flipped = np.argmax(preds_flipped, axis=2 ) - pred_max_not_unk_mask_bool_flipped = preds_max_args_flipped[:,:]!=self.end_character - masked_means_flipped = \ - np.sum(preds_max_fliped * pred_max_not_unk_mask_bool_flipped, axis=1) / \ - np.sum(pred_max_not_unk_mask_bool_flipped, axis=1) - masked_means_flipped[np.isnan(masked_means_flipped)] = 0 - - preds_max = np.max(preds, axis=2 ) - preds_max_args = np.argmax(preds, axis=2 ) - pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character - - masked_means = \ - np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \ - np.sum(pred_max_not_unk_mask_bool, axis=1) - masked_means[np.isnan(masked_means)] = 0 - - masked_means_ver = masked_means[indices_ver] - #print(masked_means_ver, 'pred_max_not_unk') - - indices_where_flipped_conf_value_is_higher = \ - np.where(masked_means_flipped > masked_means_ver)[0] - - #print(indices_where_flipped_conf_value_is_higher, 'indices_where_flipped_conf_value_is_higher') - if len(indices_where_flipped_conf_value_is_higher)>0: - indices_to_be_replaced = indices_ver[indices_where_flipped_conf_value_is_higher] - preds[indices_to_be_replaced,:,:] = \ - preds_flipped[indices_where_flipped_conf_value_is_higher, :, :] - if dir_in_bin is not None: - preds_bin = self.prediction_model.predict(imgs_bin, verbose=0) - - if len(indices_ver)>0: - preds_flipped = self.prediction_model.predict(imgs_bin_ver_flipped, verbose=0) - preds_max_fliped = np.max(preds_flipped, axis=2 ) - preds_max_args_flipped = np.argmax(preds_flipped, axis=2 ) - pred_max_not_unk_mask_bool_flipped = preds_max_args_flipped[:,:]!=self.end_character - masked_means_flipped = \ - np.sum(preds_max_fliped * pred_max_not_unk_mask_bool_flipped, axis=1) / \ - np.sum(pred_max_not_unk_mask_bool_flipped, axis=1) - masked_means_flipped[np.isnan(masked_means_flipped)] = 0 - - preds_max = np.max(preds, axis=2 ) - preds_max_args = np.argmax(preds, axis=2 ) - pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character - - masked_means = \ - np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \ - np.sum(pred_max_not_unk_mask_bool, axis=1) - masked_means[np.isnan(masked_means)] = 0 - - masked_means_ver = masked_means[indices_ver] - #print(masked_means_ver, 'pred_max_not_unk') - - indices_where_flipped_conf_value_is_higher = \ - np.where(masked_means_flipped > masked_means_ver)[0] - - #print(indices_where_flipped_conf_value_is_higher, 'indices_where_flipped_conf_value_is_higher') - if len(indices_where_flipped_conf_value_is_higher)>0: - indices_to_be_replaced = indices_ver[indices_where_flipped_conf_value_is_higher] - preds_bin[indices_to_be_replaced,:,:] = \ - preds_flipped[indices_where_flipped_conf_value_is_higher, :, :] - - preds = (preds + preds_bin) / 2. - - pred_texts = decode_batch_predictions(preds, self.num_to_char) - - preds_max = np.max(preds, axis=2 ) - preds_max_args = np.argmax(preds, axis=2 ) - pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character - masked_means = \ - np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \ - np.sum(pred_max_not_unk_mask_bool, axis=1) - - for ib in range(imgs.shape[0]): - pred_texts_ib = pred_texts[ib].replace("[UNK]", "") - if masked_means[ib] >= self.min_conf_value_of_textline_text: - extracted_texts.append(pred_texts_ib) - extracted_conf_value.append(masked_means[ib]) - else: - extracted_texts.append("") - extracted_conf_value.append(0) - del cropped_lines - if dir_in_bin is not None: - del cropped_lines_bin - gc.collect() - - extracted_texts_merged = [extracted_texts[ind] - if cropped_lines_meging_indexing[ind]==0 - else extracted_texts[ind]+" "+extracted_texts[ind+1] - if cropped_lines_meging_indexing[ind]==1 - else None - for ind in range(len(cropped_lines_meging_indexing))] - - extracted_conf_value_merged = [extracted_conf_value[ind] - if cropped_lines_meging_indexing[ind]==0 - else (extracted_conf_value[ind]+extracted_conf_value[ind+1])/2. - if cropped_lines_meging_indexing[ind]==1 - else None - for ind in range(len(cropped_lines_meging_indexing))] - - extracted_conf_value_merged = [extracted_conf_value_merged[ind_cfm] - for ind_cfm in range(len(extracted_texts_merged)) - if extracted_texts_merged[ind_cfm] is not None] - extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None] - unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer) - - if dir_out_image_text: - #font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists! - font = importlib_resources.files(__package__) / "Charis-Regular.ttf" - with importlib_resources.as_file(font) as font: - font = ImageFont.truetype(font=font, size=40) - - for indexer_text, bb_ind in enumerate(total_bb_coordinates): - x_bb = bb_ind[0] - y_bb = bb_ind[1] - w_bb = bb_ind[2] - h_bb = bb_ind[3] - - font = fit_text_single_line(draw, extracted_texts_merged[indexer_text], - font.path, w_bb, int(h_bb*0.4) ) - - ##draw.rectangle([x_bb, y_bb, x_bb + w_bb, y_bb + h_bb], outline="red", width=2) - - text_bbox = draw.textbbox((0, 0), extracted_texts_merged[indexer_text], font=font) - text_width = text_bbox[2] - text_bbox[0] - text_height = text_bbox[3] - text_bbox[1] - - text_x = x_bb + (w_bb - text_width) // 2 # Center horizontally - text_y = y_bb + (h_bb - text_height) // 2 # Center vertically - - # Draw the text - draw.text((text_x, text_y), extracted_texts_merged[indexer_text], fill="black", font=font) - image_text.save(out_image_with_text) - - text_by_textregion = [] - for ind in unique_cropped_lines_region_indexer: - ind = np.array(cropped_lines_region_indexer)==ind - extracted_texts_merged_un = np.array(extracted_texts_merged)[ind] - if len(extracted_texts_merged_un)>1: - text_by_textregion_ind = "" - next_glue = "" - for indt in range(len(extracted_texts_merged_un)): - if (extracted_texts_merged_un[indt].endswith('⸗') or - extracted_texts_merged_un[indt].endswith('-') or - extracted_texts_merged_un[indt].endswith('¬')): - text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt][:-1] - next_glue = "" - else: - text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt] - next_glue = " " - text_by_textregion.append(text_by_textregion_ind) - else: - text_by_textregion.append(" ".join(extracted_texts_merged_un)) - #print(text_by_textregion, 'text_by_textregiontext_by_textregiontext_by_textregiontext_by_textregiontext_by_textregion') - - ###index_tot_regions = [] - ###tot_region_ref = [] - - ###for jj in root1.iter(link+'RegionRefIndexed'): - ###index_tot_regions.append(jj.attrib['index']) - ###tot_region_ref.append(jj.attrib['regionRef']) - - ###id_to_order = {tid: ro for tid, ro in zip(tot_region_ref, index_tot_regions)} - - #id_textregions = [] - #textregions_by_existing_ids = [] - indexer = 0 - indexer_textregion = 0 - for nn in root1.iter(region_tags): - #id_textregion = nn.attrib['id'] - #id_textregions.append(id_textregion) - #textregions_by_existing_ids.append(text_by_textregion[indexer_textregion]) - - is_textregion_text = False - for childtest in nn: - if childtest.tag.endswith("TextEquiv"): - is_textregion_text = True - - if not is_textregion_text: - text_subelement_textregion = ET.SubElement(nn, 'TextEquiv') - unicode_textregion = ET.SubElement(text_subelement_textregion, 'Unicode') - - - has_textline = False - for child_textregion in nn: - if child_textregion.tag.endswith("TextLine"): - - is_textline_text = False - for childtest2 in child_textregion: - if childtest2.tag.endswith("TextEquiv"): - is_textline_text = True - - - if not is_textline_text: - text_subelement = ET.SubElement(child_textregion, 'TextEquiv') - text_subelement.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}") - unicode_textline = ET.SubElement(text_subelement, 'Unicode') - unicode_textline.text = extracted_texts_merged[indexer] - else: - for childtest3 in child_textregion: - if childtest3.tag.endswith("TextEquiv"): - for child_uc in childtest3: - if child_uc.tag.endswith("Unicode"): - childtest3.set('conf', - f"{extracted_conf_value_merged[indexer]:.2f}") - child_uc.text = extracted_texts_merged[indexer] - - indexer = indexer + 1 - has_textline = True - if has_textline: - if is_textregion_text: - for child4 in nn: - if child4.tag.endswith("TextEquiv"): - for childtr_uc in child4: - if childtr_uc.tag.endswith("Unicode"): - childtr_uc.text = text_by_textregion[indexer_textregion] - else: - unicode_textregion.text = text_by_textregion[indexer_textregion] - indexer_textregion = indexer_textregion + 1 - - ###sample_order = [(id_to_order[tid], text) - ### for tid, text in zip(id_textregions, textregions_by_existing_ids) - ### if tid in id_to_order] - - ##ordered_texts_sample = [text for _, text in sorted(sample_order)] - ##tot_page_text = ' '.join(ordered_texts_sample) - - ##for page_element in root1.iter(link+'Page'): - ##text_page = ET.SubElement(page_element, 'TextEquiv') - ##unicode_textpage = ET.SubElement(text_page, 'Unicode') - ##unicode_textpage.text = tot_page_text - - ET.register_namespace("",name_space) - tree1.write(out_file_ocr,xml_declaration=True,method='xml',encoding="utf-8",default_namespace=None) - #print("Job done in %.1fs", time.time() - t0) diff --git a/src/eynollah/eynollah_imports.py b/src/eynollah/eynollah_imports.py new file mode 100644 index 0000000..f04cfdc --- /dev/null +++ b/src/eynollah/eynollah_imports.py @@ -0,0 +1,10 @@ +""" +Load libraries with possible race conditions once. This must be imported as the first module of eynollah. +""" +from ocrd_utils import tf_disable_interactive_logs +from torch import * +tf_disable_interactive_logs() +import tensorflow.keras +from shapely import * +imported_libs = True +__all__ = ['imported_libs'] diff --git a/src/eynollah/eynollah_ocr.py b/src/eynollah/eynollah_ocr.py new file mode 100644 index 0000000..3c918e5 --- /dev/null +++ b/src/eynollah/eynollah_ocr.py @@ -0,0 +1,837 @@ +# FIXME: fix all of those... +# pyright: reportOptionalSubscript=false + +from logging import Logger, getLogger +from typing import List, Optional +from pathlib import Path +import os +import gc +import math +from dataclasses import dataclass + +import cv2 +from cv2.typing import MatLike +from xml.etree import ElementTree as ET +from PIL import Image, ImageDraw +import numpy as np +from eynollah.model_zoo import EynollahModelZoo +from eynollah.utils.font import get_font +from eynollah.utils.xml import etree_namespace_for_element_tag +try: + import torch +except ImportError: + torch = None + + +from .utils import is_image_filename +from .utils.resize import resize_image +from .utils.utils_ocr import ( + break_curved_line_into_small_pieces_and_then_merge, + decode_batch_predictions, + fit_text_single_line, + get_contours_and_bounding_boxes, + get_orientation_moments, + preprocess_and_resize_image_for_ocrcnn_model, + return_textlines_split_if_needed, + rotate_image_with_padding, +) + +# TODO: refine typing +@dataclass +class EynollahOcrResult: + extracted_texts_merged: List + extracted_conf_value_merged: Optional[List] + cropped_lines_region_indexer: List + total_bb_coordinates:List + +class Eynollah_ocr: + def __init__( + self, + *, + model_zoo: EynollahModelZoo, + tr_ocr=False, + batch_size: Optional[int]=None, + do_not_mask_with_textline_contour: bool=False, + min_conf_value_of_textline_text : Optional[float]=None, + logger: Optional[Logger]=None, + ): + self.tr_ocr = tr_ocr + # masking for OCR and GT generation, relevant for skewed lines and bounding boxes + self.do_not_mask_with_textline_contour = do_not_mask_with_textline_contour + self.logger = logger if logger else getLogger('eynollah.ocr') + self.model_zoo = model_zoo + + self.min_conf_value_of_textline_text = min_conf_value_of_textline_text if min_conf_value_of_textline_text else 0.3 + self.b_s = 2 if batch_size is None and tr_ocr else 8 if batch_size is None else batch_size + + if tr_ocr: + self.model_zoo.load_model('trocr_processor') + self.model_zoo.load_model('ocr', 'tr') + self.model_zoo.get('ocr').to(self.device) + else: + self.model_zoo.load_model('ocr', '') + self.model_zoo.load_model('num_to_char') + self.model_zoo.load_model('characters') + self.end_character = len(self.model_zoo.get('characters', list)) + 2 + + @property + def device(self): + assert torch + if torch.cuda.is_available(): + self.logger.info("Using GPU acceleration") + return torch.device("cuda:0") + else: + self.logger.info("Using CPU processing") + return torch.device("cpu") + + def run_trocr( + self, + *, + img: MatLike, + page_tree: ET.ElementTree, + page_ns, + tr_ocr_input_height_and_width, + ) -> EynollahOcrResult: + + total_bb_coordinates = [] + + + cropped_lines = [] + cropped_lines_region_indexer = [] + cropped_lines_meging_indexing = [] + + extracted_texts = [] + + indexer_text_region = 0 + indexer_b_s = 0 + + for nn in page_tree.getroot().iter(f'{{{page_ns}}}TextRegion'): + for child_textregion in nn: + if child_textregion.tag.endswith("TextLine"): + + for child_textlines in child_textregion: + if child_textlines.tag.endswith("Coords"): + cropped_lines_region_indexer.append(indexer_text_region) + p_h=child_textlines.attrib['points'].split(' ') + textline_coords = np.array( [ [int(x.split(',')[0]), + int(x.split(',')[1]) ] + for x in p_h] ) + x,y,w,h = cv2.boundingRect(textline_coords) + + total_bb_coordinates.append([x,y,w,h]) + + h2w_ratio = h/float(w) + + img_poly_on_img = np.copy(img) + mask_poly = np.zeros(img.shape) + mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1)) + + mask_poly = mask_poly[y:y+h, x:x+w, :] + img_crop = img_poly_on_img[y:y+h, x:x+w, :] + img_crop[mask_poly==0] = 255 + + self.logger.debug("processing %d lines for '%s'", + len(cropped_lines), nn.attrib['id']) + if h2w_ratio > 0.1: + cropped_lines.append(resize_image(img_crop, + tr_ocr_input_height_and_width, + tr_ocr_input_height_and_width) ) + cropped_lines_meging_indexing.append(0) + indexer_b_s+=1 + if indexer_b_s==self.b_s: + imgs = cropped_lines[:] + cropped_lines = [] + indexer_b_s = 0 + + pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values + generated_ids_merged = self.model_zoo.get('ocr').generate( + pixel_values_merged.to(self.device)) + generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( + generated_ids_merged, skip_special_tokens=True) + + extracted_texts = extracted_texts + generated_text_merged + + else: + splited_images, _ = return_textlines_split_if_needed(img_crop, None) + #print(splited_images) + if splited_images: + cropped_lines.append(resize_image(splited_images[0], + tr_ocr_input_height_and_width, + tr_ocr_input_height_and_width)) + cropped_lines_meging_indexing.append(1) + indexer_b_s+=1 + + if indexer_b_s==self.b_s: + imgs = cropped_lines[:] + cropped_lines = [] + indexer_b_s = 0 + + pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values + generated_ids_merged = self.model_zoo.get('ocr').generate( + pixel_values_merged.to(self.device)) + generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( + generated_ids_merged, skip_special_tokens=True) + + extracted_texts = extracted_texts + generated_text_merged + + + cropped_lines.append(resize_image(splited_images[1], + tr_ocr_input_height_and_width, + tr_ocr_input_height_and_width)) + cropped_lines_meging_indexing.append(-1) + indexer_b_s+=1 + + if indexer_b_s==self.b_s: + imgs = cropped_lines[:] + cropped_lines = [] + indexer_b_s = 0 + + pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values + generated_ids_merged = self.model_zoo.get('ocr').generate( + pixel_values_merged.to(self.device)) + generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( + generated_ids_merged, skip_special_tokens=True) + + extracted_texts = extracted_texts + generated_text_merged + + else: + cropped_lines.append(img_crop) + cropped_lines_meging_indexing.append(0) + indexer_b_s+=1 + + if indexer_b_s==self.b_s: + imgs = cropped_lines[:] + cropped_lines = [] + indexer_b_s = 0 + + pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values + generated_ids_merged = self.model_zoo.get('ocr').generate( + pixel_values_merged.to(self.device)) + generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( + generated_ids_merged, skip_special_tokens=True) + + extracted_texts = extracted_texts + generated_text_merged + + + + indexer_text_region = indexer_text_region +1 + + if indexer_b_s!=0: + imgs = cropped_lines[:] + cropped_lines = [] + indexer_b_s = 0 + + pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values + generated_ids_merged = self.model_zoo.get('ocr').generate(pixel_values_merged.to(self.device)) + generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(generated_ids_merged, skip_special_tokens=True) + + extracted_texts = extracted_texts + generated_text_merged + + ####extracted_texts = [] + ####n_iterations = math.ceil(len(cropped_lines) / self.b_s) + + ####for i in range(n_iterations): + ####if i==(n_iterations-1): + ####n_start = i*self.b_s + ####imgs = cropped_lines[n_start:] + ####else: + ####n_start = i*self.b_s + ####n_end = (i+1)*self.b_s + ####imgs = cropped_lines[n_start:n_end] + ####pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values + ####generated_ids_merged = self.model_ocr.generate( + #### pixel_values_merged.to(self.device)) + ####generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( + #### generated_ids_merged, skip_special_tokens=True) + + ####extracted_texts = extracted_texts + generated_text_merged + + del cropped_lines + gc.collect() + + extracted_texts_merged = [extracted_texts[ind] + if cropped_lines_meging_indexing[ind]==0 + else extracted_texts[ind]+" "+extracted_texts[ind+1] + if cropped_lines_meging_indexing[ind]==1 + else None + for ind in range(len(cropped_lines_meging_indexing))] + + extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None] + #print(extracted_texts_merged, len(extracted_texts_merged)) + + return EynollahOcrResult( + extracted_texts_merged=extracted_texts_merged, + extracted_conf_value_merged=None, + cropped_lines_region_indexer=cropped_lines_region_indexer, + total_bb_coordinates=total_bb_coordinates, + ) + + def run_cnn( + self, + *, + img: MatLike, + img_bin: Optional[MatLike], + page_tree: ET.ElementTree, + page_ns, + image_width, + image_height, + ) -> EynollahOcrResult: + + total_bb_coordinates = [] + + cropped_lines = [] + img_crop_bin = None + imgs_bin = None + imgs_bin_ver_flipped = None + cropped_lines_bin = [] + cropped_lines_ver_index = [] + cropped_lines_region_indexer = [] + cropped_lines_meging_indexing = [] + + indexer_text_region = 0 + for nn in page_tree.getroot().iter(f'{{{page_ns}}}TextRegion'): + try: + type_textregion = nn.attrib['type'] + except: + type_textregion = 'paragraph' + for child_textregion in nn: + if child_textregion.tag.endswith("TextLine"): + for child_textlines in child_textregion: + if child_textlines.tag.endswith("Coords"): + cropped_lines_region_indexer.append(indexer_text_region) + p_h=child_textlines.attrib['points'].split(' ') + textline_coords = np.array( [ [int(x.split(',')[0]), + int(x.split(',')[1]) ] + for x in p_h] ) + + x,y,w,h = cv2.boundingRect(textline_coords) + + angle_radians = math.atan2(h, w) + # Convert to degrees + angle_degrees = math.degrees(angle_radians) + if type_textregion=='drop-capital': + angle_degrees = 0 + + total_bb_coordinates.append([x,y,w,h]) + + w_scaled = w * image_height/float(h) + + img_poly_on_img = np.copy(img) + if img_bin: + img_poly_on_img_bin = np.copy(img_bin) + img_crop_bin = img_poly_on_img_bin[y:y+h, x:x+w, :] + + mask_poly = np.zeros(img.shape) + mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1)) + + + mask_poly = mask_poly[y:y+h, x:x+w, :] + img_crop = img_poly_on_img[y:y+h, x:x+w, :] + + # print(file_name, angle_degrees, w*h, + # mask_poly[:,:,0].sum(), + # mask_poly[:,:,0].sum() /float(w*h) , + # 'didi') + + if angle_degrees > 3: + better_des_slope = get_orientation_moments(textline_coords) + + img_crop = rotate_image_with_padding(img_crop, better_des_slope) + if img_bin: + img_crop_bin = rotate_image_with_padding(img_crop_bin, better_des_slope) + + mask_poly = rotate_image_with_padding(mask_poly, better_des_slope) + mask_poly = mask_poly.astype('uint8') + + #new bounding box + x_n, y_n, w_n, h_n = get_contours_and_bounding_boxes(mask_poly[:,:,0]) + + mask_poly = mask_poly[y_n:y_n+h_n, x_n:x_n+w_n, :] + img_crop = img_crop[y_n:y_n+h_n, x_n:x_n+w_n, :] + + if not self.do_not_mask_with_textline_contour: + img_crop[mask_poly==0] = 255 + if img_bin: + img_crop_bin = img_crop_bin[y_n:y_n+h_n, x_n:x_n+w_n, :] + if not self.do_not_mask_with_textline_contour: + img_crop_bin[mask_poly==0] = 255 + + if mask_poly[:,:,0].sum() /float(w_n*h_n) < 0.50 and w_scaled > 90: + if img_bin: + img_crop, img_crop_bin = \ + break_curved_line_into_small_pieces_and_then_merge( + img_crop, mask_poly, img_crop_bin) + else: + img_crop, _ = \ + break_curved_line_into_small_pieces_and_then_merge( + img_crop, mask_poly) + + else: + better_des_slope = 0 + if not self.do_not_mask_with_textline_contour: + img_crop[mask_poly==0] = 255 + if img_bin: + if not self.do_not_mask_with_textline_contour: + img_crop_bin[mask_poly==0] = 255 + if type_textregion=='drop-capital': + pass + else: + if mask_poly[:,:,0].sum() /float(w*h) < 0.50 and w_scaled > 90: + if img_bin: + img_crop, img_crop_bin = \ + break_curved_line_into_small_pieces_and_then_merge( + img_crop, mask_poly, img_crop_bin) + else: + img_crop, _ = \ + break_curved_line_into_small_pieces_and_then_merge( + img_crop, mask_poly) + + if w_scaled < 750:#1.5*image_width: + img_fin = preprocess_and_resize_image_for_ocrcnn_model( + img_crop, image_height, image_width) + cropped_lines.append(img_fin) + if abs(better_des_slope) > 45: + cropped_lines_ver_index.append(1) + else: + cropped_lines_ver_index.append(0) + + cropped_lines_meging_indexing.append(0) + if img_bin: + img_fin = preprocess_and_resize_image_for_ocrcnn_model( + img_crop_bin, image_height, image_width) + cropped_lines_bin.append(img_fin) + else: + splited_images, splited_images_bin = return_textlines_split_if_needed( + img_crop, img_crop_bin if img_bin else None) + if splited_images: + img_fin = preprocess_and_resize_image_for_ocrcnn_model( + splited_images[0], image_height, image_width) + cropped_lines.append(img_fin) + cropped_lines_meging_indexing.append(1) + + if abs(better_des_slope) > 45: + cropped_lines_ver_index.append(1) + else: + cropped_lines_ver_index.append(0) + + img_fin = preprocess_and_resize_image_for_ocrcnn_model( + splited_images[1], image_height, image_width) + + cropped_lines.append(img_fin) + cropped_lines_meging_indexing.append(-1) + + if abs(better_des_slope) > 45: + cropped_lines_ver_index.append(1) + else: + cropped_lines_ver_index.append(0) + + if img_bin: + img_fin = preprocess_and_resize_image_for_ocrcnn_model( + splited_images_bin[0], image_height, image_width) + cropped_lines_bin.append(img_fin) + img_fin = preprocess_and_resize_image_for_ocrcnn_model( + splited_images_bin[1], image_height, image_width) + cropped_lines_bin.append(img_fin) + + else: + img_fin = preprocess_and_resize_image_for_ocrcnn_model( + img_crop, image_height, image_width) + cropped_lines.append(img_fin) + cropped_lines_meging_indexing.append(0) + + if abs(better_des_slope) > 45: + cropped_lines_ver_index.append(1) + else: + cropped_lines_ver_index.append(0) + + if img_bin: + img_fin = preprocess_and_resize_image_for_ocrcnn_model( + img_crop_bin, image_height, image_width) + cropped_lines_bin.append(img_fin) + + + indexer_text_region = indexer_text_region +1 + + extracted_texts = [] + extracted_conf_value = [] + + n_iterations = math.ceil(len(cropped_lines) / self.b_s) + + # FIXME: copy pasta + for i in range(n_iterations): + if i==(n_iterations-1): + n_start = i*self.b_s + imgs = cropped_lines[n_start:] + imgs = np.array(imgs) + imgs = imgs.reshape(imgs.shape[0], image_height, image_width, 3) + + ver_imgs = np.array( cropped_lines_ver_index[n_start:] ) + indices_ver = np.where(ver_imgs == 1)[0] + + #print(indices_ver, 'indices_ver') + if len(indices_ver)>0: + imgs_ver_flipped = imgs[indices_ver, : ,: ,:] + imgs_ver_flipped = imgs_ver_flipped[:,::-1,::-1,:] + #print(imgs_ver_flipped, 'imgs_ver_flipped') + + else: + imgs_ver_flipped = None + + if img_bin: + imgs_bin = cropped_lines_bin[n_start:] + imgs_bin = np.array(imgs_bin) + imgs_bin = imgs_bin.reshape(imgs_bin.shape[0], image_height, image_width, 3) + + if len(indices_ver)>0: + imgs_bin_ver_flipped = imgs_bin[indices_ver, : ,: ,:] + imgs_bin_ver_flipped = imgs_bin_ver_flipped[:,::-1,::-1,:] + #print(imgs_ver_flipped, 'imgs_ver_flipped') + + else: + imgs_bin_ver_flipped = None + else: + n_start = i*self.b_s + n_end = (i+1)*self.b_s + imgs = cropped_lines[n_start:n_end] + imgs = np.array(imgs).reshape(self.b_s, image_height, image_width, 3) + + ver_imgs = np.array( cropped_lines_ver_index[n_start:n_end] ) + indices_ver = np.where(ver_imgs == 1)[0] + #print(indices_ver, 'indices_ver') + + if len(indices_ver)>0: + imgs_ver_flipped = imgs[indices_ver, : ,: ,:] + imgs_ver_flipped = imgs_ver_flipped[:,::-1,::-1,:] + #print(imgs_ver_flipped, 'imgs_ver_flipped') + else: + imgs_ver_flipped = None + + + if img_bin: + imgs_bin = cropped_lines_bin[n_start:n_end] + imgs_bin = np.array(imgs_bin).reshape(self.b_s, image_height, image_width, 3) + + + if len(indices_ver)>0: + imgs_bin_ver_flipped = imgs_bin[indices_ver, : ,: ,:] + imgs_bin_ver_flipped = imgs_bin_ver_flipped[:,::-1,::-1,:] + #print(imgs_ver_flipped, 'imgs_ver_flipped') + else: + imgs_bin_ver_flipped = None + + + self.logger.debug("processing next %d lines", len(imgs)) + preds = self.model_zoo.get('ocr').predict(imgs, verbose=0) + + if len(indices_ver)>0: + preds_flipped = self.model_zoo.get('ocr').predict(imgs_ver_flipped, verbose=0) + preds_max_fliped = np.max(preds_flipped, axis=2 ) + preds_max_args_flipped = np.argmax(preds_flipped, axis=2 ) + pred_max_not_unk_mask_bool_flipped = preds_max_args_flipped[:,:]!=self.end_character + masked_means_flipped = \ + np.sum(preds_max_fliped * pred_max_not_unk_mask_bool_flipped, axis=1) / \ + np.sum(pred_max_not_unk_mask_bool_flipped, axis=1) + masked_means_flipped[np.isnan(masked_means_flipped)] = 0 + + preds_max = np.max(preds, axis=2 ) + preds_max_args = np.argmax(preds, axis=2 ) + pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character + + masked_means = \ + np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \ + np.sum(pred_max_not_unk_mask_bool, axis=1) + masked_means[np.isnan(masked_means)] = 0 + + masked_means_ver = masked_means[indices_ver] + #print(masked_means_ver, 'pred_max_not_unk') + + indices_where_flipped_conf_value_is_higher = \ + np.where(masked_means_flipped > masked_means_ver)[0] + + #print(indices_where_flipped_conf_value_is_higher, 'indices_where_flipped_conf_value_is_higher') + if len(indices_where_flipped_conf_value_is_higher)>0: + indices_to_be_replaced = indices_ver[indices_where_flipped_conf_value_is_higher] + preds[indices_to_be_replaced,:,:] = \ + preds_flipped[indices_where_flipped_conf_value_is_higher, :, :] + + if img_bin: + preds_bin = self.model_zoo.get('ocr').predict(imgs_bin, verbose=0) + + if len(indices_ver)>0: + preds_flipped = self.model_zoo.get('ocr').predict(imgs_bin_ver_flipped, verbose=0) + preds_max_fliped = np.max(preds_flipped, axis=2 ) + preds_max_args_flipped = np.argmax(preds_flipped, axis=2 ) + pred_max_not_unk_mask_bool_flipped = preds_max_args_flipped[:,:]!=self.end_character + masked_means_flipped = \ + np.sum(preds_max_fliped * pred_max_not_unk_mask_bool_flipped, axis=1) / \ + np.sum(pred_max_not_unk_mask_bool_flipped, axis=1) + masked_means_flipped[np.isnan(masked_means_flipped)] = 0 + + preds_max = np.max(preds, axis=2 ) + preds_max_args = np.argmax(preds, axis=2 ) + pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character + + masked_means = \ + np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \ + np.sum(pred_max_not_unk_mask_bool, axis=1) + masked_means[np.isnan(masked_means)] = 0 + + masked_means_ver = masked_means[indices_ver] + #print(masked_means_ver, 'pred_max_not_unk') + + indices_where_flipped_conf_value_is_higher = \ + np.where(masked_means_flipped > masked_means_ver)[0] + + #print(indices_where_flipped_conf_value_is_higher, 'indices_where_flipped_conf_value_is_higher') + if len(indices_where_flipped_conf_value_is_higher)>0: + indices_to_be_replaced = indices_ver[indices_where_flipped_conf_value_is_higher] + preds_bin[indices_to_be_replaced,:,:] = \ + preds_flipped[indices_where_flipped_conf_value_is_higher, :, :] + + preds = (preds + preds_bin) / 2. + + pred_texts = decode_batch_predictions(preds, self.model_zoo.get('num_to_char')) + + preds_max = np.max(preds, axis=2 ) + preds_max_args = np.argmax(preds, axis=2 ) + pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character + masked_means = \ + np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \ + np.sum(pred_max_not_unk_mask_bool, axis=1) + + for ib in range(imgs.shape[0]): + pred_texts_ib = pred_texts[ib].replace("[UNK]", "") + if masked_means[ib] >= self.min_conf_value_of_textline_text: + extracted_texts.append(pred_texts_ib) + extracted_conf_value.append(masked_means[ib]) + else: + extracted_texts.append("") + extracted_conf_value.append(0) + del cropped_lines + del cropped_lines_bin + gc.collect() + + extracted_texts_merged = [extracted_texts[ind] + if cropped_lines_meging_indexing[ind]==0 + else extracted_texts[ind]+" "+extracted_texts[ind+1] + if cropped_lines_meging_indexing[ind]==1 + else None + for ind in range(len(cropped_lines_meging_indexing))] + + extracted_conf_value_merged = [extracted_conf_value[ind] # type: ignore + if cropped_lines_meging_indexing[ind]==0 + else (extracted_conf_value[ind]+extracted_conf_value[ind+1])/2. + if cropped_lines_meging_indexing[ind]==1 + else None + for ind in range(len(cropped_lines_meging_indexing))] + + extracted_conf_value_merged: List[float] = [extracted_conf_value_merged[ind_cfm] + for ind_cfm in range(len(extracted_texts_merged)) + if extracted_texts_merged[ind_cfm] is not None] + + extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None] + + return EynollahOcrResult( + extracted_texts_merged=extracted_texts_merged, + extracted_conf_value_merged=extracted_conf_value_merged, + cropped_lines_region_indexer=cropped_lines_region_indexer, + total_bb_coordinates=total_bb_coordinates, + ) + + def write_ocr( + self, + *, + result: EynollahOcrResult, + page_tree: ET.ElementTree, + out_file_ocr, + page_ns, + img, + out_image_with_text, + ): + cropped_lines_region_indexer = result.cropped_lines_region_indexer + total_bb_coordinates = result.total_bb_coordinates + extracted_texts_merged = result.extracted_texts_merged + extracted_conf_value_merged = result.extracted_conf_value_merged + + unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer) + if out_image_with_text: + image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white") + draw = ImageDraw.Draw(image_text) + font = get_font() + + for indexer_text, bb_ind in enumerate(total_bb_coordinates): + x_bb = bb_ind[0] + y_bb = bb_ind[1] + w_bb = bb_ind[2] + h_bb = bb_ind[3] + + font = fit_text_single_line(draw, extracted_texts_merged[indexer_text], + font.path, w_bb, int(h_bb*0.4) ) + + ##draw.rectangle([x_bb, y_bb, x_bb + w_bb, y_bb + h_bb], outline="red", width=2) + + text_bbox = draw.textbbox((0, 0), extracted_texts_merged[indexer_text], font=font) + text_width = text_bbox[2] - text_bbox[0] + text_height = text_bbox[3] - text_bbox[1] + + text_x = x_bb + (w_bb - text_width) // 2 # Center horizontally + text_y = y_bb + (h_bb - text_height) // 2 # Center vertically + + # Draw the text + draw.text((text_x, text_y), extracted_texts_merged[indexer_text], fill="black", font=font) + image_text.save(out_image_with_text) + + text_by_textregion = [] + for ind in unique_cropped_lines_region_indexer: + ind = np.array(cropped_lines_region_indexer)==ind + extracted_texts_merged_un = np.array(extracted_texts_merged)[ind] + if len(extracted_texts_merged_un)>1: + text_by_textregion_ind = "" + next_glue = "" + for indt in range(len(extracted_texts_merged_un)): + if (extracted_texts_merged_un[indt].endswith('⸗') or + extracted_texts_merged_un[indt].endswith('-') or + extracted_texts_merged_un[indt].endswith('¬')): + text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt][:-1] + next_glue = "" + else: + text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt] + next_glue = " " + text_by_textregion.append(text_by_textregion_ind) + else: + text_by_textregion.append(" ".join(extracted_texts_merged_un)) + + indexer = 0 + indexer_textregion = 0 + for nn in page_tree.getroot().iter(f'{{{page_ns}}}TextRegion'): + + is_textregion_text = False + for childtest in nn: + if childtest.tag.endswith("TextEquiv"): + is_textregion_text = True + + if not is_textregion_text: + text_subelement_textregion = ET.SubElement(nn, 'TextEquiv') + unicode_textregion = ET.SubElement(text_subelement_textregion, 'Unicode') + + + has_textline = False + for child_textregion in nn: + if child_textregion.tag.endswith("TextLine"): + + is_textline_text = False + for childtest2 in child_textregion: + if childtest2.tag.endswith("TextEquiv"): + is_textline_text = True + + + if not is_textline_text: + text_subelement = ET.SubElement(child_textregion, 'TextEquiv') + if extracted_conf_value_merged: + text_subelement.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}") + unicode_textline = ET.SubElement(text_subelement, 'Unicode') + unicode_textline.text = extracted_texts_merged[indexer] + else: + for childtest3 in child_textregion: + if childtest3.tag.endswith("TextEquiv"): + for child_uc in childtest3: + if child_uc.tag.endswith("Unicode"): + if extracted_conf_value_merged: + childtest3.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}") + child_uc.text = extracted_texts_merged[indexer] + + indexer = indexer + 1 + has_textline = True + if has_textline: + if is_textregion_text: + for child4 in nn: + if child4.tag.endswith("TextEquiv"): + for childtr_uc in child4: + if childtr_uc.tag.endswith("Unicode"): + childtr_uc.text = text_by_textregion[indexer_textregion] + else: + unicode_textregion.text = text_by_textregion[indexer_textregion] + indexer_textregion = indexer_textregion + 1 + + ET.register_namespace("",page_ns) + page_tree.write(out_file_ocr, xml_declaration=True, method='xml', encoding="utf-8", default_namespace=None) + + def run( + self, + *, + overwrite: bool = False, + dir_in: Optional[str] = None, + dir_in_bin: Optional[str] = None, + image_filename: Optional[str] = None, + dir_xmls: str, + dir_out_image_text: Optional[str] = None, + dir_out: str, + ): + """ + Run OCR. + + Args: + + dir_in_bin (str): Prediction with RGB and binarized images for selected pages, should not be the default + """ + if dir_in: + ls_imgs = [os.path.join(dir_in, image_filename) + for image_filename in filter(is_image_filename, + os.listdir(dir_in))] + else: + assert image_filename + ls_imgs = [image_filename] + + for img_filename in ls_imgs: + file_stem = Path(img_filename).stem + page_file_in = os.path.join(dir_xmls, file_stem+'.xml') + out_file_ocr = os.path.join(dir_out, file_stem+'.xml') + + if os.path.exists(out_file_ocr): + if overwrite: + self.logger.warning("will overwrite existing output file '%s'", out_file_ocr) + else: + self.logger.warning("will skip input for existing output file '%s'", out_file_ocr) + return + + img = cv2.imread(img_filename) + + page_tree = ET.parse(page_file_in, parser = ET.XMLParser(encoding="utf-8")) + page_ns = etree_namespace_for_element_tag(page_tree.getroot().tag) + + out_image_with_text = None + if dir_out_image_text: + out_image_with_text = os.path.join(dir_out_image_text, file_stem + '.png') + + img_bin = None + if dir_in_bin: + img_bin = cv2.imread(os.path.join(dir_in_bin, file_stem+'.png')) + + + if self.tr_ocr: + result = self.run_trocr( + img=img, + page_tree=page_tree, + page_ns=page_ns, + + tr_ocr_input_height_and_width = 384 + ) + else: + result = self.run_cnn( + img=img, + page_tree=page_tree, + page_ns=page_ns, + + img_bin=img_bin, + image_width=512, + image_height=32, + ) + + self.write_ocr( + result=result, + img=img, + page_tree=page_tree, + page_ns=page_ns, + out_file_ocr=out_file_ocr, + out_image_with_text=out_image_with_text, + ) diff --git a/src/eynollah/image_enhancer.py b/src/eynollah/image_enhancer.py index 9247efe..babbd55 100644 --- a/src/eynollah/image_enhancer.py +++ b/src/eynollah/image_enhancer.py @@ -2,7 +2,12 @@ Image enhancer. The output can be written as same scale of input or in new predicted scale. """ -from logging import Logger +# FIXME: fix all of those... +# pyright: reportUnboundVariable=false +# pyright: reportCallIssue=false +# pyright: reportArgumentType=false + +import logging import os import time from typing import Optional @@ -10,19 +15,18 @@ from pathlib import Path import gc import cv2 +from keras.models import Model import numpy as np -from ocrd_utils import getLogger, tf_disable_interactive_logs -import tensorflow as tf +import tensorflow as tf # type: ignore from skimage.morphology import skeletonize -from tensorflow.keras.models import load_model +from .model_zoo import EynollahModelZoo from .utils.resize import resize_image from .utils.pil_cv2 import pil2cv from .utils import ( is_image_filename, crop_image_inside_box ) -from .eynollah import PatchEncoder, Patches DPI_THRESHOLD = 298 KERNEL = np.ones((5, 5), np.uint8) @@ -31,14 +35,13 @@ KERNEL = np.ones((5, 5), np.uint8) class Enhancer: def __init__( self, - dir_models : str, + *, + model_zoo: EynollahModelZoo, num_col_upper : Optional[int] = None, num_col_lower : Optional[int] = None, save_org_scale : bool = False, - logger : Optional[Logger] = None, ): self.input_binary = False - self.light_version = False self.save_org_scale = save_org_scale if num_col_upper: self.num_col_upper = int(num_col_upper) @@ -49,12 +52,10 @@ class Enhancer: else: self.num_col_lower = num_col_lower - self.logger = logger if logger else getLogger('enhancement') - self.dir_models = dir_models - self.model_dir_of_binarization = dir_models + "/eynollah-binarization_20210425" - self.model_dir_of_enhancement = dir_models + "/eynollah-enhancement_20210425" - self.model_dir_of_col_classifier = dir_models + "/eynollah-column-classifier_20210425" - self.model_page_dir = dir_models + "/model_eynollah_page_extraction_20250915" + self.logger = logging.getLogger('eynollah.enhance') + self.model_zoo = model_zoo + for v in ['binarization', 'enhancement', 'col_classifier', 'page']: + self.model_zoo.load_model(v) try: for device in tf.config.list_physical_devices('GPU'): @@ -62,25 +63,14 @@ class Enhancer: except: self.logger.warning("no GPU device available") - self.model_page = self.our_load_model(self.model_page_dir) - self.model_classifier = self.our_load_model(self.model_dir_of_col_classifier) - self.model_enhancement = self.our_load_model(self.model_dir_of_enhancement) - self.model_bin = self.our_load_model(self.model_dir_of_binarization) - def cache_images(self, image_filename=None, image_pil=None, dpi=None): ret = {} if image_filename: ret['img'] = cv2.imread(image_filename) - if self.light_version: - self.dpi = 100 - else: - self.dpi = 0#check_dpi(image_filename) + self.dpi = 100 else: ret['img'] = pil2cv(image_pil) - if self.light_version: - self.dpi = 100 - else: - self.dpi = 0#check_dpi(image_pil) + self.dpi = 100 ret['img_grayscale'] = cv2.cvtColor(ret['img'], cv2.COLOR_BGR2GRAY) for prefix in ('', '_grayscale'): ret[f'img{prefix}_uint8'] = ret[f'img{prefix}'].astype(np.uint8) @@ -100,26 +90,11 @@ class Enhancer: key += '_uint8' return self._imgs[key].copy() - def isNaN(self, num): - return num != num - - @staticmethod - def our_load_model(model_file): - if model_file.endswith('.h5') and Path(model_file[:-3]).exists(): - # prefer SavedModel over HDF5 format if it exists - model_file = model_file[:-3] - try: - model = load_model(model_file, compile=False) - except: - model = load_model(model_file, compile=False, custom_objects={ - "PatchEncoder": PatchEncoder, "Patches": Patches}) - return model - def predict_enhancement(self, img): self.logger.debug("enter predict_enhancement") - img_height_model = self.model_enhancement.layers[-1].output_shape[1] - img_width_model = self.model_enhancement.layers[-1].output_shape[2] + img_height_model = self.model_zoo.get('enhancement', Model).layers[-1].output_shape[1] + img_width_model = self.model_zoo.get('enhancement', Model).layers[-1].output_shape[2] if img.shape[0] < img_height_model: img = cv2.resize(img, (img.shape[1], img_width_model), interpolation=cv2.INTER_NEAREST) if img.shape[1] < img_width_model: @@ -160,7 +135,7 @@ class Enhancer: index_y_d = img_h - img_height_model img_patch = img[np.newaxis, index_y_d:index_y_u, index_x_d:index_x_u, :] - label_p_pred = self.model_enhancement.predict(img_patch, verbose=0) + label_p_pred = self.model_zoo.get('enhancement', Model).predict(img_patch, verbose='0') seg = label_p_pred[0, :, :, :] * 255 if i == 0 and j == 0: @@ -246,7 +221,7 @@ class Enhancer: else: img = self.imread() img = cv2.GaussianBlur(img, (5, 5), 0) - img_page_prediction = self.do_prediction(False, img, self.model_page) + img_page_prediction = self.do_prediction(False, img, self.model_zoo.get('page')) imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY) _, thresh = cv2.threshold(imgray, 0, 255, 0) @@ -285,13 +260,13 @@ class Enhancer: return img_new, num_column_is_classified - def resize_and_enhance_image_with_column_classifier(self, light_version): + def resize_and_enhance_image_with_column_classifier(self): self.logger.debug("enter resize_and_enhance_image_with_column_classifier") dpi = 0#self.dpi self.logger.info("Detected %s DPI", dpi) if self.input_binary: img = self.imread() - prediction_bin = self.do_prediction(True, img, self.model_bin, n_batch_inference=5) + prediction_bin = self.do_prediction(True, img, self.model_zoo.get('binarization'), n_batch_inference=5) prediction_bin = 255 * (prediction_bin[:,:,0]==0) prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2).astype(np.uint8) img= np.copy(prediction_bin) @@ -332,7 +307,7 @@ class Enhancer: img_in[0, :, :, 1] = img_1ch[:, :] img_in[0, :, :, 2] = img_1ch[:, :] - label_p_pred = self.model_classifier.predict(img_in, verbose=0) + label_p_pred = self.model_zoo.get('col_classifier').predict(img_in, verbose=0) num_col = np.argmax(label_p_pred[0]) + 1 elif (self.num_col_upper and self.num_col_lower) and (self.num_col_upper!=self.num_col_lower): if self.input_binary: @@ -352,7 +327,7 @@ class Enhancer: img_in[0, :, :, 1] = img_1ch[:, :] img_in[0, :, :, 2] = img_1ch[:, :] - label_p_pred = self.model_classifier.predict(img_in, verbose=0) + label_p_pred = self.model_zoo.get('col_classifier').predict(img_in, verbose=0) num_col = np.argmax(label_p_pred[0]) + 1 if num_col > self.num_col_upper: @@ -368,16 +343,13 @@ class Enhancer: self.logger.info("Found %d columns (%s)", num_col, np.around(label_p_pred, decimals=5)) if dpi < DPI_THRESHOLD: - if light_version and num_col in (1,2): + if num_col in (1,2): img_new, num_column_is_classified = self.calculate_width_height_by_columns_1_2( img, num_col, width_early, label_p_pred) else: img_new, num_column_is_classified = self.calculate_width_height_by_columns( img, num_col, width_early, label_p_pred) - if light_version: - image_res = np.copy(img_new) - else: - image_res = self.predict_enhancement(img_new) + image_res = np.copy(img_new) is_image_enhanced = True else: @@ -671,11 +643,11 @@ class Enhancer: gc.collect() return prediction_true - def run_enhancement(self, light_version): + def run_enhancement(self): t_in = time.time() self.logger.info("Resizing and enhancing image...") is_image_enhanced, img_org, img_res, num_col_classifier, num_column_is_classified, img_bin = \ - self.resize_and_enhance_image_with_column_classifier(light_version) + self.resize_and_enhance_image_with_column_classifier() self.logger.info("Image was %senhanced.", '' if is_image_enhanced else 'not ') return img_res, is_image_enhanced, num_col_classifier, num_column_is_classified @@ -683,9 +655,9 @@ class Enhancer: def run_single(self): t0 = time.time() - img_res, is_image_enhanced, num_col_classifier, num_column_is_classified = self.run_enhancement(light_version=False) + img_res, is_image_enhanced, num_col_classifier, num_column_is_classified = self.run_enhancement() - return img_res + return img_res, is_image_enhanced def run(self, @@ -723,9 +695,18 @@ class Enhancer: self.logger.warning("will skip input for existing output file '%s'", self.output_filename) continue - image_enhanced = self.run_single() + did_resize = False + image_enhanced, did_enhance = self.run_single() if self.save_org_scale: image_enhanced = resize_image(image_enhanced, self.h_org, self.w_org) + did_resize = True + + self.logger.info( + "Image %s was %senhanced%s.", + img_filename, + '' if did_enhance else 'not ', + 'and resized' if did_resize else '' + ) cv2.imwrite(self.output_filename, image_enhanced) diff --git a/src/eynollah/mb_ro_on_layout.py b/src/eynollah/mb_ro_on_layout.py index 1b991ae..eec544c 100644 --- a/src/eynollah/mb_ro_on_layout.py +++ b/src/eynollah/mb_ro_on_layout.py @@ -1,8 +1,12 @@ """ -Image enhancer. The output can be written as same scale of input or in new predicted scale. +Machine learning based reading order detection """ -from logging import Logger +# pyright: reportCallIssue=false +# pyright: reportUnboundVariable=false +# pyright: reportArgumentType=false + +import logging import os import time from typing import Optional @@ -10,12 +14,12 @@ from pathlib import Path import xml.etree.ElementTree as ET import cv2 +from keras.models import Model import numpy as np -from ocrd_utils import getLogger import statistics import tensorflow as tf -from tensorflow.keras.models import load_model +from .model_zoo import EynollahModelZoo from .utils.resize import resize_image from .utils.contour import ( find_new_features_of_contours, @@ -23,7 +27,6 @@ from .utils.contour import ( return_parent_contours, ) from .utils import is_xml_filename -from .eynollah import PatchEncoder, Patches DPI_THRESHOLD = 298 KERNEL = np.ones((5, 5), np.uint8) @@ -32,12 +35,12 @@ KERNEL = np.ones((5, 5), np.uint8) class machine_based_reading_order_on_layout: def __init__( self, - dir_models : str, - logger : Optional[Logger] = None, + *, + model_zoo: EynollahModelZoo, + logger : Optional[logging.Logger] = None, ): - self.logger = logger if logger else getLogger('mbreorder') - self.dir_models = dir_models - self.model_reading_order_dir = dir_models + "/model_eynollah_reading_order_20250824" + self.logger = logger or logging.getLogger('eynollah.mbreorder') + self.model_zoo = model_zoo try: for device in tf.config.list_physical_devices('GPU'): @@ -45,20 +48,7 @@ class machine_based_reading_order_on_layout: except: self.logger.warning("no GPU device available") - self.model_reading_order = self.our_load_model(self.model_reading_order_dir) - self.light_version = True - - @staticmethod - def our_load_model(model_file): - if model_file.endswith('.h5') and Path(model_file[:-3]).exists(): - # prefer SavedModel over HDF5 format if it exists - model_file = model_file[:-3] - try: - model = load_model(model_file, compile=False) - except: - model = load_model(model_file, compile=False, custom_objects={ - "PatchEncoder": PatchEncoder, "Patches": Patches}) - return model + self.model_zoo.load_model('reading_order') def read_xml(self, xml_file): tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8')) @@ -69,6 +59,7 @@ class machine_based_reading_order_on_layout: index_tot_regions = [] tot_region_ref = [] + y_len, x_len = 0, 0 for jj in root1.iter(link+'Page'): y_len=int(jj.attrib['imageHeight']) x_len=int(jj.attrib['imageWidth']) @@ -81,13 +72,13 @@ class machine_based_reading_order_on_layout: co_printspace = [] if link+'PrintSpace' in alltags: region_tags_printspace = np.unique([x for x in alltags if x.endswith('PrintSpace')]) - elif link+'Border' in alltags: + else: region_tags_printspace = np.unique([x for x in alltags if x.endswith('Border')]) for tag in region_tags_printspace: if link+'PrintSpace' in alltags: tag_endings_printspace = ['}PrintSpace','}printspace'] - elif link+'Border' in alltags: + else: tag_endings_printspace = ['}Border','}border'] if tag.endswith(tag_endings_printspace[0]) or tag.endswith(tag_endings_printspace[1]): @@ -524,7 +515,7 @@ class machine_based_reading_order_on_layout: min_cont_size_to_be_dilated = 10 - if len(contours_only_text_parent)>min_cont_size_to_be_dilated and self.light_version: + if len(contours_only_text_parent)>min_cont_size_to_be_dilated: cx_conts, cy_conts, x_min_conts, x_max_conts, y_min_conts, y_max_conts, _ = find_new_features_of_contours(contours_only_text_parent) args_cont_located = np.array(range(len(contours_only_text_parent))) @@ -624,13 +615,13 @@ class machine_based_reading_order_on_layout: img_header_and_sep[int(y_max_main[j]):int(y_max_main[j])+12, int(x_min_main[j]):int(x_max_main[j])] = 1 co_text_all_org = contours_only_text_parent + contours_only_text_parent_h - if len(contours_only_text_parent)>min_cont_size_to_be_dilated and self.light_version: + if len(contours_only_text_parent)>min_cont_size_to_be_dilated: co_text_all = contours_only_dilated + contours_only_text_parent_h else: co_text_all = contours_only_text_parent + contours_only_text_parent_h else: co_text_all_org = contours_only_text_parent - if len(contours_only_text_parent)>min_cont_size_to_be_dilated and self.light_version: + if len(contours_only_text_parent)>min_cont_size_to_be_dilated: co_text_all = contours_only_dilated else: co_text_all = contours_only_text_parent @@ -683,7 +674,7 @@ class machine_based_reading_order_on_layout: tot_counter += 1 batch.append(j) if tot_counter % inference_bs == 0 or tot_counter == len(ij_list): - y_pr = self.model_reading_order.predict(input_1 , verbose=0) + y_pr = self.model_zoo.get('reading_order', Model).predict(input_1 , verbose='0') for jb, j in enumerate(batch): if y_pr[jb][0]>=0.5: post_list.append(j) @@ -709,7 +700,7 @@ class machine_based_reading_order_on_layout: ##id_all_text = np.array(id_all_text)[index_sort] - if len(contours_only_text_parent)>min_cont_size_to_be_dilated and self.light_version: + if len(contours_only_text_parent)>min_cont_size_to_be_dilated: org_contours_indexes = [] for ind in range(len(ordered)): region_with_curr_order = ordered[ind] @@ -802,6 +793,7 @@ class machine_based_reading_order_on_layout: alltags=[elem.tag for elem in root_xml.iter()] ET.register_namespace("",name_space) + assert dir_out tree_xml.write(os.path.join(dir_out, file_name+'.xml'), xml_declaration=True, method='xml', diff --git a/src/eynollah/model_zoo/__init__.py b/src/eynollah/model_zoo/__init__.py new file mode 100644 index 0000000..e1dc985 --- /dev/null +++ b/src/eynollah/model_zoo/__init__.py @@ -0,0 +1,4 @@ +__all__ = [ + 'EynollahModelZoo', +] +from .model_zoo import EynollahModelZoo diff --git a/src/eynollah/model_zoo/default_specs.py b/src/eynollah/model_zoo/default_specs.py new file mode 100644 index 0000000..b9a1a2c --- /dev/null +++ b/src/eynollah/model_zoo/default_specs.py @@ -0,0 +1,252 @@ +from .specs import EynollahModelSpec, EynollahModelSpecSet + +# NOTE: This needs to change whenever models/versions change +ZENODO = "https://zenodo.org/records/17295988/files" +MODELS_VERSION = "v0_7_0" + +def dist_url(dist_name: str="layout") -> str: + return f'{ZENODO}/models_{dist_name}_{MODELS_VERSION}.zip' + +DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ + + EynollahModelSpec( + category="enhancement", + variant='', + filename="models_eynollah/eynollah-enhancement_20210425", + dist_url=dist_url(), + type='Keras', + ), + + EynollahModelSpec( + category="binarization", + variant='hybrid', + filename="models_eynollah/eynollah-binarization-hybrid_20230504/model_bin_hybrid_trans_cnn_sbb_ens", + dist_url=dist_url(), + type='Keras', + ), + + EynollahModelSpec( + category="binarization", + variant='20210309', + filename="models_eynollah/eynollah-binarization_20210309", + dist_url=dist_url("extra"), + type='Keras', + ), + + EynollahModelSpec( + category="binarization", + variant='', + filename="models_eynollah/eynollah-binarization_20210425", + dist_url=dist_url("extra"), + type='Keras', + ), + + EynollahModelSpec( + category="col_classifier", + variant='', + filename="models_eynollah/eynollah-column-classifier_20210425", + dist_url=dist_url(), + type='Keras', + ), + + EynollahModelSpec( + category="page", + variant='', + filename="models_eynollah/model_eynollah_page_extraction_20250915", + dist_url=dist_url(), + type='Keras', + ), + + EynollahModelSpec( + category="region", + variant='', + filename="models_eynollah/eynollah-main-regions-ensembled_20210425", + dist_url=dist_url(), + type='Keras', + ), + + EynollahModelSpec( + category="extract_images", + variant='', + filename="models_eynollah/eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18", + dist_url=dist_url(), + type='Keras', + ), + + EynollahModelSpec( + category="region", + variant='', + filename="models_eynollah/eynollah-main-regions_20220314", + dist_url=dist_url(), + help="early layout", + type='Keras', + ), + + EynollahModelSpec( + category="region_p2", + variant='non-light', + filename="models_eynollah/eynollah-main-regions-aug-rotation_20210425", + dist_url=dist_url('extra'), + help="early layout, non-light, 2nd part", + type='Keras', + ), + + EynollahModelSpec( + category="region_1_2", + variant='', + #filename="models_eynollah/modelens_12sp_elay_0_3_4__3_6_n", + #filename="models_eynollah/modelens_earlylayout_12spaltige_2_3_5_6_7_8", + #filename="models_eynollah/modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18", + #filename="models_eynollah/modelens_1_2_4_5_early_lay_1_2_spaltige", + #filename="models_eynollah/model_3_eraly_layout_no_patches_1_2_spaltige", + filename="models_eynollah/modelens_e_l_all_sp_0_1_2_3_4_171024", + dist_url=dist_url("layout"), + help="early layout, light, 1-or-2-column", + type='Keras', + ), + + EynollahModelSpec( + category="region_fl_np", + variant='', + #'filename="models_eynollah/modelens_full_lay_1_3_031124", + #'filename="models_eynollah/modelens_full_lay_13__3_19_241024", + #'filename="models_eynollah/model_full_lay_13_241024", + #'filename="models_eynollah/modelens_full_lay_13_17_231024", + #'filename="models_eynollah/modelens_full_lay_1_2_221024", + #'filename="models_eynollah/eynollah-full-regions-1column_20210425", + filename="models_eynollah/modelens_full_lay_1__4_3_091124", + dist_url=dist_url(), + help="full layout / no patches", + type='Keras', + ), + + # FIXME: Why is region_fl and region_fl_np the same model? + EynollahModelSpec( + category="region_fl", + variant='', + # filename="models_eynollah/eynollah-full-regions-3+column_20210425", + # filename="models_eynollah/model_2_full_layout_new_trans", + # filename="models_eynollah/modelens_full_lay_1_3_031124", + # filename="models_eynollah/modelens_full_lay_13__3_19_241024", + # filename="models_eynollah/model_full_lay_13_241024", + # filename="models_eynollah/modelens_full_lay_13_17_231024", + # filename="models_eynollah/modelens_full_lay_1_2_221024", + # filename="models_eynollah/modelens_full_layout_24_till_28", + # filename="models_eynollah/model_2_full_layout_new_trans", + filename="models_eynollah/modelens_full_lay_1__4_3_091124", + dist_url=dist_url(), + help="full layout / with patches", + type='Keras', + ), + + EynollahModelSpec( + category="reading_order", + variant='', + #filename="models_eynollah/model_mb_ro_aug_ens_11", + #filename="models_eynollah/model_step_3200000_mb_ro", + #filename="models_eynollah/model_ens_reading_order_machine_based", + #filename="models_eynollah/model_mb_ro_aug_ens_8", + #filename="models_eynollah/model_ens_reading_order_machine_based", + filename="models_eynollah/model_eynollah_reading_order_20250824", + dist_url=dist_url(), + type='Keras', + ), + + EynollahModelSpec( + category="textline", + variant='non-light', + #filename="models_eynollah/modelens_textline_1_4_16092024", + #filename="models_eynollah/model_textline_ens_3_4_5_6_artificial", + #filename="models_eynollah/modelens_textline_1_3_4_20240915", + #filename="models_eynollah/model_textline_ens_3_4_5_6_artificial", + #filename="models_eynollah/modelens_textline_9_12_13_14_15", + #filename="models_eynollah/eynollah-textline_20210425", + filename="models_eynollah/modelens_textline_0_1__2_4_16092024", + dist_url=dist_url('extra'), + type='Keras', + ), + + EynollahModelSpec( + category="textline", + variant='', + #filename="models_eynollah/eynollah-textline_light_20210425", + filename="models_eynollah/modelens_textline_0_1__2_4_16092024", + dist_url=dist_url(), + type='Keras', + ), + + EynollahModelSpec( + category="table", + variant='non-light', + filename="models_eynollah/eynollah-tables_20210319", + dist_url=dist_url('extra'), + type='Keras', + ), + + EynollahModelSpec( + category="table", + variant='', + filename="models_eynollah/modelens_table_0t4_201124", + dist_url=dist_url(), + type='Keras', + ), + + EynollahModelSpec( + category="ocr", + variant='', + filename="models_eynollah/model_eynollah_ocr_cnnrnn_20250930", + dist_url=dist_url("ocr"), + type='Keras', + ), + + EynollahModelSpec( + category="ocr", + variant='degraded', + filename="models_eynollah/model_eynollah_ocr_cnnrnn__degraded_20250805/", + help="slightly better at degraded Fraktur", + dist_url=dist_url("ocr"), + type='Keras', + ), + + EynollahModelSpec( + category="num_to_char", + variant='', + filename="characters_org.txt", + dist_url=dist_url("ocr"), + type='decoder', + ), + + EynollahModelSpec( + category="characters", + variant='', + filename="characters_org.txt", + dist_url=dist_url("ocr"), + type='List[str]', + ), + + EynollahModelSpec( + category="ocr", + variant='tr', + filename="models_eynollah/model_eynollah_ocr_trocr_20250919", + dist_url=dist_url("ocr"), + help='much slower transformer-based', + type='Keras', + ), + + EynollahModelSpec( + category="trocr_processor", + variant='', + filename="models_eynollah/model_eynollah_ocr_trocr_20250919", + dist_url=dist_url("ocr"), + type='TrOCRProcessor', + ), + + EynollahModelSpec( + category="trocr_processor", + variant='htr', + filename="models_eynollah/microsoft/trocr-base-handwritten", + dist_url=dist_url("extra"), + type='TrOCRProcessor', + ), + +]) diff --git a/src/eynollah/model_zoo/model_zoo.py b/src/eynollah/model_zoo/model_zoo.py new file mode 100644 index 0000000..83068ff --- /dev/null +++ b/src/eynollah/model_zoo/model_zoo.py @@ -0,0 +1,206 @@ +import os +import json +import logging +from copy import deepcopy +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Type, Union + +os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15 +from ocrd_utils import tf_disable_interactive_logs +tf_disable_interactive_logs() + +from tensorflow.keras.layers import StringLookup +from tensorflow.keras.models import Model as KerasModel +from tensorflow.keras.models import load_model +from tabulate import tabulate + +from ..patch_encoder import PatchEncoder, Patches +from .specs import EynollahModelSpecSet +from .default_specs import DEFAULT_MODEL_SPECS +from .types import AnyModel, T + + +class EynollahModelZoo: + """ + Wrapper class that handles storage and loading of models for all eynollah runners. + """ + + model_basedir: Path + specs: EynollahModelSpecSet + + def __init__( + self, + basedir: str, + model_overrides: Optional[List[Tuple[str, str, str]]] = None, + ) -> None: + self.model_basedir = Path(basedir) + self.logger = logging.getLogger('eynollah.model_zoo') + if not self.model_basedir.exists(): + self.logger.warning(f"Model basedir does not exist: {basedir}. Set eynollah --model-basedir to the correct directory.") + self.specs = deepcopy(DEFAULT_MODEL_SPECS) + self._overrides = [] + if model_overrides: + self.override_models(*model_overrides) + self._loaded: Dict[str, AnyModel] = {} + + @property + def model_overrides(self): + return self._overrides + + def override_models( + self, + *model_overrides: Tuple[str, str, str], + ): + """ + Override the default model versions + """ + for model_category, model_variant, model_filename in model_overrides: + spec = self.specs.get(model_category, model_variant) + self.logger.warning("Overriding filename for model spec %s to %s", spec, model_filename) + self.specs.get(model_category, model_variant).filename = model_filename + self._overrides += model_overrides + + def model_path( + self, + model_category: str, + model_variant: str = '', + absolute: bool = True, + ) -> Path: + """ + Translate model_{type,variant} tuple into an absolute (or relative) Path + """ + spec = self.specs.get(model_category, model_variant) + if spec.category in ('characters', 'num_to_char'): + return self.model_path('ocr') / spec.filename + if not Path(spec.filename).is_absolute() and absolute: + model_path = Path(self.model_basedir).joinpath(spec.filename) + else: + model_path = Path(spec.filename) + return model_path + + def load_models( + self, + *all_load_args: Union[str, Tuple[str], Tuple[str, str], Tuple[str, str, str]], + ) -> Dict: + """ + Load all models by calling load_model and return a dictionary mapping model_category to loaded model + """ + ret = {} + for load_args in all_load_args: + if isinstance(load_args, str): + ret[load_args] = self.load_model(load_args) + else: + ret[load_args[0]] = self.load_model(*load_args) + return ret + + def load_model( + self, + model_category: str, + model_variant: str = '', + model_path_override: Optional[str] = None, + ) -> AnyModel: + """ + Load any model + """ + if model_path_override: + self.override_models((model_category, model_variant, model_path_override)) + model_path = self.model_path(model_category, model_variant) + if model_path.suffix == '.h5' and Path(model_path.stem).exists(): + # prefer SavedModel over HDF5 format if it exists + model_path = Path(model_path.stem) + if model_category == 'ocr': + model = self._load_ocr_model(variant=model_variant) + elif model_category == 'num_to_char': + model = self._load_num_to_char() + elif model_category == 'characters': + model = self._load_characters() + elif model_category == 'trocr_processor': + from transformers import TrOCRProcessor + model = TrOCRProcessor.from_pretrained(model_path) + else: + try: + model = load_model(model_path, compile=False) + except Exception as e: + self.logger.exception(e) + model = load_model( + model_path, compile=False, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches} + ) + self._loaded[model_category] = model + return model # type: ignore + + def get(self, model_category: str, model_type: Optional[Type[T]] = None) -> T: + if model_category not in self._loaded: + raise ValueError(f'Model "{model_category} not previously loaded with "load_model(..)"') + ret = self._loaded[model_category] + if model_type: + assert isinstance(ret, model_type) + return ret # type: ignore # FIXME: convince typing that we're returning generic type + + def _load_ocr_model(self, variant: str) -> AnyModel: + """ + Load OCR model + """ + ocr_model_dir = self.model_path('ocr', variant) + if variant == 'tr': + from transformers import VisionEncoderDecoderModel + ret = VisionEncoderDecoderModel.from_pretrained(ocr_model_dir) + assert isinstance(ret, VisionEncoderDecoderModel) + return ret + else: + ocr_model = load_model(ocr_model_dir, compile=False) + assert isinstance(ocr_model, KerasModel) + return KerasModel( + ocr_model.get_layer(name="image").input, # type: ignore + ocr_model.get_layer(name="dense2").output, # type: ignore + ) + + def _load_characters(self) -> List[str]: + """ + Load encoding for OCR + """ + with open(self.model_path('num_to_char'), "r") as config_file: + return json.load(config_file) + + def _load_num_to_char(self) -> StringLookup: + """ + Load decoder for OCR + """ + characters = self._load_characters() + # Mapping characters to integers. + char_to_num = StringLookup(vocabulary=characters, mask_token=None) + # Mapping integers back to original characters. + return StringLookup(vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True) + + def __str__(self): + return tabulate( + [ + [ + spec.type, + spec.category, + spec.variant, + spec.help, + f'Yes, at {self.model_path(spec.category, spec.variant)}' + if self.model_path(spec.category, spec.variant).exists() + else f'No, download {spec.dist_url}', + # self.model_path(spec.category, spec.variant), + ] + for spec in sorted(self.specs.specs, key=lambda x: x.dist_url) + ], + headers=[ + 'Type', + 'Category', + 'Variant', + 'Help', + 'Used in', + 'Installed', + ], + tablefmt='github', + ) + + def shutdown(self): + """ + Ensure that a loaded models is not referenced by ``self._loaded`` anymore + """ + if hasattr(self, '_loaded') and getattr(self, '_loaded'): + for needle in list(self._loaded.keys()): + del self._loaded[needle] diff --git a/src/eynollah/model_zoo/specs.py b/src/eynollah/model_zoo/specs.py new file mode 100644 index 0000000..3c47b7b --- /dev/null +++ b/src/eynollah/model_zoo/specs.py @@ -0,0 +1,52 @@ +from dataclasses import dataclass +from typing import Dict, List, Set, Tuple + + +@dataclass +class EynollahModelSpec(): + """ + Describing a single model abstractly. + """ + category: str + # Relative filename to the models_eynollah directory in the dists + filename: str + # URL to the smallest model distribution containing this model (link to Zenodo) + dist_url: str + type: str + variant: str = '' + help: str = '' + +class EynollahModelSpecSet(): + """ + List of all used models for eynollah. + """ + specs: List[EynollahModelSpec] + + def __init__(self, specs: List[EynollahModelSpec]) -> None: + self.specs = sorted(specs, key=lambda x: x.category + '0' + x.variant) + self.categories: Set[str] = set([spec.category for spec in self.specs]) + self.variants: Dict[str, Set[str]] = { + spec.category: set([x.variant for x in self.specs if x.category == spec.category]) + for spec in self.specs + } + self._index_category_variant: Dict[Tuple[str, str], EynollahModelSpec] = { + (spec.category, spec.variant): spec + for spec in self.specs + } + + def asdict(self) -> Dict[str, Dict[str, str]]: + return { + spec.category: { + spec.variant: spec.filename + } + for spec in self.specs + } + + def get(self, category: str, variant: str) -> EynollahModelSpec: + if category not in self.categories: + raise ValueError(f"Unknown category '{category}', must be one of {self.categories}") + if variant not in self.variants[category]: + raise ValueError(f"Unknown variant {variant} for {category}. Known variants: {self.variants[category]}") + return self._index_category_variant[(category, variant)] + + diff --git a/src/eynollah/model_zoo/types.py b/src/eynollah/model_zoo/types.py new file mode 100644 index 0000000..43f6859 --- /dev/null +++ b/src/eynollah/model_zoo/types.py @@ -0,0 +1,7 @@ +from typing import TypeVar + +# NOTE: Creating an actual union type requires loading transformers which is expensive and error-prone +# from transformers import TrOCRProcessor, VisionEncoderDecoderModel +# AnyModel = Union[VisionEncoderDecoderModel, TrOCRProcessor, KerasModel, List] +AnyModel = object +T = TypeVar('T') diff --git a/src/eynollah/ocrd-tool.json b/src/eynollah/ocrd-tool.json index dbbdc3b..3b500fc 100644 --- a/src/eynollah/ocrd-tool.json +++ b/src/eynollah/ocrd-tool.json @@ -1,5 +1,5 @@ { - "version": "0.6.0", + "version": "0.7.0", "git_url": "https://github.com/qurator-spk/eynollah", "dockerhub": "ocrd/eynollah", "tools": { @@ -29,16 +29,6 @@ "type": "boolean", "default": true, "description": "Try to detect all element subtypes, including drop-caps and headings" - }, - "light_version": { - "type": "boolean", - "default": true, - "description": "Try to detect all element subtypes in light version (faster+simpler method for main region detection and deskewing)" - }, - "textline_light": { - "type": "boolean", - "default": true, - "description": "Light version need textline light. If this parameter set to true, this tool will try to return contoure of textlines instead of rectangle bounding box of textline with a faster method." }, "tables": { "type": "boolean", @@ -83,12 +73,20 @@ }, "resources": [ { - "url": "https://zenodo.org/records/17194824/files/models_layout_v0_5_0.tar.gz?download=1", - "name": "models_layout_v0_5_0", + "url": "https://zenodo.org/records/17580627/files/models_all_v0_7_0.zip?download=1", + "name": "models_layout_v0_7_0", "type": "archive", - "path_in_archive": "models_layout_v0_5_0", + "size": 6119874002, + "description": "Models for layout detection, reading order detection, textline detection, page extraction, column classification, table detection, binarization, image enhancement and OCR", + "version_range": ">= v0.7.0" + }, + { + "url": "https://zenodo.org/records/17295988/files/models_layout_v0_6_0.tar.gz?download=1", + "name": "models_layout_v0_6_0", + "type": "archive", + "path_in_archive": "models_layout_v0_6_0", "size": 3525684179, - "description": "Models for layout detection, reading order detection, textline detection, page extraction, column classification, table detection, binarization, image enhancement", + "description": "Models for layout detection, reading order detection, textline detection, page extraction, column classification, table detection, binarization, image enhancement and OCR", "version_range": ">= v0.5.0" }, { diff --git a/src/eynollah/ocrd_cli.py b/src/eynollah/ocrd_cli.py index 8929927..acd8d4e 100644 --- a/src/eynollah/ocrd_cli.py +++ b/src/eynollah/ocrd_cli.py @@ -1,3 +1,6 @@ +# NOTE: For predictable order of imports of torch/shapely/tensorflow +# this must be the first import of the CLI! +from .eynollah_imports import imported_libs from .processor import EynollahProcessor from click import command from ocrd.decorators import ocrd_cli_options, ocrd_cli_wrap_processor diff --git a/src/eynollah/ocrd_cli_binarization.py b/src/eynollah/ocrd_cli_binarization.py index 6289517..e9059df 100644 --- a/src/eynollah/ocrd_cli_binarization.py +++ b/src/eynollah/ocrd_cli_binarization.py @@ -1,6 +1,8 @@ +from functools import cached_property from typing import Optional from PIL import Image +from frozendict import frozendict import numpy as np import cv2 from click import command @@ -9,6 +11,8 @@ from ocrd import Processor, OcrdPageResult, OcrdPageResultImage from ocrd_models.ocrd_page import OcrdPage, AlternativeImageType from ocrd.decorators import ocrd_cli_options, ocrd_cli_wrap_processor +from eynollah.model_zoo.model_zoo import EynollahModelZoo + from .sbb_binarize import SbbBinarizer @@ -25,7 +29,7 @@ class SbbBinarizeProcessor(Processor): # already employs GPU (without singleton process atm) max_workers = 1 - @property + @cached_property def executable(self): return 'ocrd-sbb-binarize' @@ -34,8 +38,9 @@ class SbbBinarizeProcessor(Processor): Set up the model prior to processing. """ # resolve relative path via OCR-D ResourceManager - model_path = self.resolve_resource(self.parameter['model']) - self.binarizer = SbbBinarizer(model_dir=model_path, logger=self.logger) + assert isinstance(self.parameter, frozendict) + model_zoo = EynollahModelZoo(basedir=self.parameter['model']) + self.binarizer = SbbBinarizer(model_zoo=model_zoo, logger=self.logger) def process_page_pcgts(self, *input_pcgts: Optional[OcrdPage], page_id: Optional[str] = None) -> OcrdPageResult: """ @@ -98,7 +103,7 @@ class SbbBinarizeProcessor(Processor): line_image_bin = cv2pil(self.binarizer.run_single(image=pil2cv(line_image), use_patches=True)) # update PAGE (reference the image file): line_image_ref = AlternativeImageType(comments=line_xywh['features'] + ',binarized') - line.add_AlternativeImage(region_image_ref) + line.add_AlternativeImage(line_image_ref) result.images.append(OcrdPageResultImage(line_image_bin, line.id + '.IMG-BIN', line_image_ref)) return result diff --git a/src/eynollah/patch_encoder.py b/src/eynollah/patch_encoder.py new file mode 100644 index 0000000..dc0a291 --- /dev/null +++ b/src/eynollah/patch_encoder.py @@ -0,0 +1,54 @@ +import os +os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15 +import tensorflow as tf +from tensorflow.keras import layers + +projection_dim = 64 +patch_size = 1 +num_patches =21*21#14*14#28*28#14*14#28*28 + +class PatchEncoder(layers.Layer): + + def __init__(self): + super().__init__() + self.projection = layers.Dense(units=projection_dim) + self.position_embedding = layers.Embedding(input_dim=num_patches, output_dim=projection_dim) + + def call(self, patch): + positions = tf.range(start=0, limit=num_patches, delta=1) + encoded = self.projection(patch) + self.position_embedding(positions) + return encoded + + def get_config(self): + config = super().get_config().copy() + config.update({ + 'num_patches': num_patches, + 'projection': self.projection, + 'position_embedding': self.position_embedding, + }) + return config + +class Patches(layers.Layer): + def __init__(self, **kwargs): + super(Patches, self).__init__() + self.patch_size = patch_size + + def call(self, images): + batch_size = tf.shape(images)[0] + patches = tf.image.extract_patches( + images=images, + sizes=[1, self.patch_size, self.patch_size, 1], + strides=[1, self.patch_size, self.patch_size, 1], + rates=[1, 1, 1, 1], + padding="VALID", + ) + patch_dims = patches.shape[-1] + patches = tf.reshape(patches, [batch_size, -1, patch_dims]) + return patches + def get_config(self): + + config = super().get_config().copy() + config.update({ + 'patch_size': self.patch_size, + }) + return config diff --git a/src/eynollah/plot.py b/src/eynollah/plot.py index c026e94..b1b2359 100644 --- a/src/eynollah/plot.py +++ b/src/eynollah/plot.py @@ -40,8 +40,8 @@ class EynollahPlotter: self.image_filename_stem = image_filename_stem # XXX TODO hacky these cannot be set at init time self.image_org = image_org - self.scale_x = scale_x - self.scale_y = scale_y + self.scale_x : float = scale_x + self.scale_y : float = scale_y def save_plot_of_layout_main(self, text_regions_p, image_page): if self.dir_of_layout is not None: diff --git a/src/eynollah/processor.py b/src/eynollah/processor.py index 12c7356..0addaff 100644 --- a/src/eynollah/processor.py +++ b/src/eynollah/processor.py @@ -3,6 +3,8 @@ from typing import Optional from ocrd_models import OcrdPage from ocrd import OcrdPageResultImage, Processor, OcrdPageResult +from eynollah.model_zoo.model_zoo import EynollahModelZoo + from .eynollah import Eynollah, EynollahXmlWriter class EynollahProcessor(Processor): @@ -16,24 +18,20 @@ class EynollahProcessor(Processor): def setup(self) -> None: assert self.parameter - if self.parameter['textline_light'] != self.parameter['light_version']: - raise ValueError("Error: You must set or unset both parameter 'textline_light' (to enable light textline detection), " - "and parameter 'light_version' (faster+simpler method for main region detection and deskewing)") + model_zoo = EynollahModelZoo(basedir=self.parameter['models']) self.eynollah = Eynollah( - self.resolve_resource(self.parameter['models']), + model_zoo=model_zoo, allow_enhancement=self.parameter['allow_enhancement'], curved_line=self.parameter['curved_line'], right2left=self.parameter['right_to_left'], reading_order_machine_based=self.parameter['reading_order_machine_based'], ignore_page_extraction=self.parameter['ignore_page_extraction'], - light_version=self.parameter['light_version'], - textline_light=self.parameter['textline_light'], full_layout=self.parameter['full_layout'], allow_scaling=self.parameter['allow_scaling'], headers_off=self.parameter['headers_off'], tables=self.parameter['tables'], + logger=self.logger ) - self.eynollah.logger = self.logger self.eynollah.plotter = None def shutdown(self): @@ -90,7 +88,6 @@ class EynollahProcessor(Processor): dir_out=None, image_filename=image_filename, curved_line=self.eynollah.curved_line, - textline_light=self.eynollah.textline_light, pcgts=pcgts) self.eynollah.run_single() return result diff --git a/src/eynollah/sbb_binarize.py b/src/eynollah/sbb_binarize.py index 2ca4a40..fe044c9 100644 --- a/src/eynollah/sbb_binarize.py +++ b/src/eynollah/sbb_binarize.py @@ -2,20 +2,25 @@ Tool to load model and binarize a given image. """ -from glob import glob +# pyright: reportIndexIssue=false +# pyright: reportCallIssue=false +# pyright: reportArgumentType=false +# pyright: reportPossiblyUnboundVariable=false + import os import logging -from PIL import Image +from pathlib import Path +from typing import Optional import numpy as np import cv2 -from ocrd_utils import tf_disable_interactive_logs os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15 +from ocrd_utils import tf_disable_interactive_logs tf_disable_interactive_logs() import tensorflow as tf -from tensorflow.keras.models import load_model +from .model_zoo import EynollahModelZoo from .utils import is_image_filename def resize_image(img_in, input_height, input_width): @@ -23,30 +28,24 @@ def resize_image(img_in, input_height, input_width): class SbbBinarizer: - def __init__(self, model_dir, logger=None): - self.model_dir = model_dir - self.logger = logger if logger else logging.getLogger('SbbBinarizer') - + def __init__( + self, + *, + model_zoo: EynollahModelZoo, + logger: Optional[logging.Logger] = None, + ): + self.logger = logger if logger else logging.getLogger('eynollah.binarization') try: for device in tf.config.list_physical_devices('GPU'): tf.config.experimental.set_memory_growth(device, True) except: self.logger.warning("no GPU device available") + self.models = (model_zoo.model_path('binarization'), model_zoo.load_model('binarization')) + self.logger.info('Loaded model %s [%s]', self.models[1], self.models[0]) - self.model_files = glob(self.model_dir + "/*/", recursive=True) - self.models = [] - for model_file in self.model_files: - self.models.append(self.load_model(model_file)) - - def load_model(self, model_name): - model = load_model(os.path.join(self.model_dir, model_name), compile=False) + def predict(self, model, img, use_patches, n_batch_inference=5): model_height = model.layers[len(model.layers)-1].output_shape[1] model_width = model.layers[len(model.layers)-1].output_shape[2] - n_classes = model.layers[len(model.layers)-1].output_shape[3] - return model, model_height, model_width, n_classes - - def predict(self, model_in, img, use_patches, n_batch_inference=5): - model, model_height, model_width, n_classes = model_in img_org_h = img.shape[0] img_org_w = img.shape[1] @@ -305,44 +304,57 @@ class SbbBinarizer: prediction_true = prediction_true.astype(np.uint8) return prediction_true[:,:,0] - def run(self, image_path=None, output=None, dir_in=None, use_patches=False, overwrite=False): - if dir_in: - ls_imgs = [(os.path.join(dir_in, image_filename), - os.path.join(output, os.path.splitext(image_filename)[0] + '.png')) - for image_filename in filter(is_image_filename, - os.listdir(dir_in))] + def run(self, image=None, image_path=None, output=None, use_patches=False, dir_in=None, overwrite=False): + if not dir_in: + if (image is None) == (image_path is None): + raise ValueError("Must pass either a opencv2 image or an image_path") + if image_path is not None: + image = cv2.imread(image_path) + img_last = self.run_single(image, use_patches) + if output: + if os.path.exists(output): + if overwrite: + self.logger.warning("will overwrite existing output file '%s'", output) + else: + self.logger.warning("output file already exists '%s'", output) + return img_last + self.logger.info('Writing binarized image to %s', output) + cv2.imwrite(output, img_last) + return img_last else: - ls_imgs = [(image_path, output)] - - for input_path, output_path in ls_imgs: - print(input_path, 'image_name') - if os.path.exists(output_path): - if overwrite: - self.logger.warning("will overwrite existing output file '%s'", output_path) - else: - self.logger.warning("will skip input for existing output file '%s'", output_path) - image = cv2.imread(input_path) - result = self.run_single(image, use_patches) - cv2.imwrite(output_path, result) + ls_imgs = list(filter(is_image_filename, os.listdir(dir_in))) + self.logger.info("Found %d image files to binarize in %s", len(ls_imgs), dir_in) + for i, image_path in enumerate(ls_imgs): + image_stem = os.path.splitext(image_path)[0] + output_path = os.path.join(output, image_stem + '.png') + if os.path.exists(output_path): + if overwrite: + self.logger.warning("will overwrite existing output file '%s'", output_path) + else: + self.logger.warning("will skip input for existing output file '%s'", output_path) + continue + self.logger.info('Binarizing [%3d/%d] %s', i + 1, len(ls_imgs), image_path) + image = cv2.imread(os.path.join(dir_in, image_path)) + img_last = self.run_single(image, use_patches) + self.logger.info('Writing binarized image to %s', output_path) + cv2.imwrite(output_path, img_last) def run_single(self, image: np.ndarray, use_patches=False): img_last = 0 - for n, (model, model_file) in enumerate(zip(self.models, self.model_files)): - self.logger.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files))) + model_file, model = self.models + res = self.predict(model, image, use_patches) - res = self.predict(model, image, use_patches) + img_fin = np.zeros((res.shape[0], res.shape[1], 3)) + res[:, :][res[:, :] == 0] = 2 + res = res - 1 + res = res * 255 + img_fin[:, :, 0] = res + img_fin[:, :, 1] = res + img_fin[:, :, 2] = res - img_fin = np.zeros((res.shape[0], res.shape[1], 3)) - res[:, :][res[:, :] == 0] = 2 - res = res - 1 - res = res * 255 - img_fin[:, :, 0] = res - img_fin[:, :, 1] = res - img_fin[:, :, 2] = res - - img_fin = img_fin.astype(np.uint8) - img_fin = (res[:, :] == 0) * 255 - img_last = img_last + img_fin + img_fin = img_fin.astype(np.uint8) + img_fin = (res[:, :] == 0) * 255 + img_last = img_last + img_fin kernel = np.ones((5, 5), np.uint8) img_last[:, :][img_last[:, :] > 0] = 255 diff --git a/src/eynollah/training/cli.py b/src/eynollah/training/cli.py index 8ab754d..3718275 100644 --- a/src/eynollah/training/cli.py +++ b/src/eynollah/training/cli.py @@ -8,6 +8,8 @@ from .build_model_load_pretrained_weights_and_save import build_model_load_pretr from .generate_gt_for_training import main as generate_gt_cli from .inference import main as inference_cli from .train import ex +from .extract_line_gt import linegt_cli +from .weights_ensembling import main as ensemble_cli @click.command(context_settings=dict( ignore_unknown_options=True, @@ -24,3 +26,5 @@ main.add_command(build_model_load_pretrained_weights_and_save) main.add_command(generate_gt_cli, 'generate-gt') main.add_command(inference_cli, 'inference') main.add_command(train_cli, 'train') +main.add_command(linegt_cli, 'export_textline_images_and_text') +main.add_command(ensemble_cli, 'ensembling') diff --git a/src/eynollah/training/extract_line_gt.py b/src/eynollah/training/extract_line_gt.py new file mode 100644 index 0000000..58fc253 --- /dev/null +++ b/src/eynollah/training/extract_line_gt.py @@ -0,0 +1,134 @@ +from logging import Logger, getLogger +from typing import Optional +from pathlib import Path +import os + +import click +import cv2 +import xml.etree.ElementTree as ET +import numpy as np + +from ..utils import is_image_filename + +@click.command() +@click.option( + "--image", + "-i", + help="input image filename", + type=click.Path(exists=True, dir_okay=False), +) +@click.option( + "--dir_in", + "-di", + help="directory of input images (instead of --image)", + type=click.Path(exists=True, file_okay=False), +) +@click.option( + "--dir_xmls", + "-dx", + help="directory of input PAGE-XML files (in addition to --dir_in; filename stems must match the image files, with '.xml' suffix).", + type=click.Path(exists=True, file_okay=False), + required=True, +) +@click.option( + "--out", + "-o", + 'dir_out', + help="directory for output PAGE-XML files", + type=click.Path(exists=True, file_okay=False), + required=True, +) +@click.option( + "--dataset_abbrevation", + "-ds_pref", + 'pref_of_dataset', + help="in the case of extracting textline and text from a xml GT file user can add an abbrevation of dataset name to generated dataset", +) +@click.option( + "--do_not_mask_with_textline_contour", + "-nmtc/-mtc", + is_flag=True, + help="if this parameter set to true, cropped textline images will not be masked with textline contour.", +) +def linegt_cli( + image, + dir_in, + dir_xmls, + dir_out, + pref_of_dataset, + do_not_mask_with_textline_contour, +): + assert bool(dir_in) ^ bool(image), "Set --dir-in or --image-filename, not both" + if dir_in: + ls_imgs = [ + os.path.join(dir_in, image) for image in filter(is_image_filename, os.listdir(dir_in)) + ] + else: + assert image + ls_imgs = [image] + + for dir_img in ls_imgs: + file_name = Path(dir_img).stem + dir_xml = os.path.join(dir_xmls, file_name + '.xml') + + img = cv2.imread(dir_img) + + total_bb_coordinates = [] + + tree1 = ET.parse(dir_xml, parser=ET.XMLParser(encoding="utf-8")) + root1 = tree1.getroot() + alltags = [elem.tag for elem in root1.iter()] + + name_space = alltags[0].split('}')[0] + name_space = name_space.split('{')[1] + + region_tags = np.unique([x for x in alltags if x.endswith('TextRegion')]) + + cropped_lines_region_indexer = [] + + indexer_text_region = 0 + indexer_textlines = 0 + # FIXME: non recursive, use OCR-D PAGE generateDS API. Or use an existing tool for this purpose altogether + for nn in root1.iter(region_tags): + for child_textregion in nn: + if child_textregion.tag.endswith("TextLine"): + for child_textlines in child_textregion: + if child_textlines.tag.endswith("Coords"): + cropped_lines_region_indexer.append(indexer_text_region) + p_h = child_textlines.attrib['points'].split(' ') + textline_coords = np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h]) + + x, y, w, h = cv2.boundingRect(textline_coords) + + total_bb_coordinates.append([x, y, w, h]) + + img_poly_on_img = np.copy(img) + + mask_poly = np.zeros(img.shape) + mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1)) + + mask_poly = mask_poly[y : y + h, x : x + w, :] + img_crop = img_poly_on_img[y : y + h, x : x + w, :] + + if not do_not_mask_with_textline_contour: + img_crop[mask_poly == 0] = 255 + + if img_crop.shape[0] == 0 or img_crop.shape[1] == 0: + continue + if child_textlines.tag.endswith("TextEquiv"): + for cheild_text in child_textlines: + if cheild_text.tag.endswith("Unicode"): + textline_text = cheild_text.text + if textline_text: + base_name = os.path.join( + dir_out, file_name + '_line_' + str(indexer_textlines) + ) + if pref_of_dataset: + base_name += '_' + pref_of_dataset + if not do_not_mask_with_textline_contour: + base_name += '_masked' + + with open(base_name + '.txt', 'w') as text_file: + text_file.write(textline_text) + cv2.imwrite(base_name + '.png', img_crop) + indexer_textlines += 1 diff --git a/src/eynollah/training/generate_gt_for_training.py b/src/eynollah/training/generate_gt_for_training.py index f71614c..2c076d3 100644 --- a/src/eynollah/training/generate_gt_for_training.py +++ b/src/eynollah/training/generate_gt_for_training.py @@ -480,7 +480,7 @@ def visualize_layout_segmentation(xml_file, dir_xml, dir_out, dir_imgs): img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name) img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format)) - co_text, co_graphic, co_sep, co_img, co_table, co_noise, y_len, x_len = get_layout_contours_for_visualization(xml_file) + co_text, co_graphic, co_sep, co_img, co_table, co_map, co_noise, y_len, x_len = get_layout_contours_for_visualization(xml_file) added_image = visualize_image_from_contours_layout(co_text['paragraph'], co_text['header']+co_text['heading'], co_text['drop-capital'], co_sep, co_img, co_text['marginalia'], co_table, img) diff --git a/src/eynollah/training/gt_gen_utils.py b/src/eynollah/training/gt_gen_utils.py index f068afd..8204a8e 100644 --- a/src/eynollah/training/gt_gen_utils.py +++ b/src/eynollah/training/gt_gen_utils.py @@ -18,7 +18,7 @@ with warnings.catch_warnings(): warnings.simplefilter("ignore") -def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_image, co_marginal, co_table, img): +def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_image, co_marginal, co_table, co_map, img): alpha = 0.5 blank_image = np.ones( (img.shape[:]), dtype=np.uint8) * 255 @@ -31,6 +31,7 @@ def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_ col_sep = (255, 0, 0) col_marginal = (106, 90, 205) col_table = (0, 90, 205) + col_map = (90, 90, 205) if len(co_image)>0: cv2.drawContours(blank_image, co_image, -1, col_image, thickness=cv2.FILLED) # Fill the contour @@ -55,6 +56,9 @@ def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_ if len(co_table)>0: cv2.drawContours(blank_image, co_table, -1, col_table, thickness=cv2.FILLED) # Fill the contour + + if len(co_map)>0: + cv2.drawContours(blank_image, co_map, -1, col_map, thickness=cv2.FILLED) # Fill the contour img_final =cv2.cvtColor(blank_image, cv2.COLOR_BGR2RGB) @@ -234,7 +238,12 @@ def update_region_contours(co_text, img_boundary, erosion_rate, dilation_rate, y con_eroded = return_contours_of_interested_region(img_boundary_in,pixel, min_size ) try: - co_text_eroded.append(con_eroded[0]) + if len(con_eroded)>1: + cnt_size = np.array([cv2.contourArea(con_eroded[j]) for j in range(len(con_eroded))]) + cnt = contours[np.argmax(cnt_size)] + co_text_eroded.append(cnt) + else: + co_text_eroded.append(con_eroded[0]) except: co_text_eroded.append(con) @@ -255,6 +264,7 @@ def get_textline_contours_for_visualization(xml_file): + x_len, y_len = 0, 0 for jj in root1.iter(link+'Page'): y_len=int(jj.attrib['imageHeight']) x_len=int(jj.attrib['imageWidth']) @@ -296,6 +306,7 @@ def get_textline_contours_and_ocr_text(xml_file): + x_len, y_len = 0, 0 for jj in root1.iter(link+'Page'): y_len=int(jj.attrib['imageHeight']) x_len=int(jj.attrib['imageWidth']) @@ -365,7 +376,7 @@ def get_layout_contours_for_visualization(xml_file): link=alltags[0].split('}')[0]+'}' - + x_len, y_len = 0, 0 for jj in root1.iter(link+'Page'): y_len=int(jj.attrib['imageHeight']) x_len=int(jj.attrib['imageWidth']) @@ -378,6 +389,7 @@ def get_layout_contours_for_visualization(xml_file): co_sep=[] co_img=[] co_table=[] + co_map=[] co_noise=[] types_text = [] @@ -594,6 +606,31 @@ def get_layout_contours_for_visualization(xml_file): elif vv.tag!=link+'Point' and sumi>=1: break co_table.append(np.array(c_t_in)) + + if tag.endswith('}MapRegion') or tag.endswith('}mapregion'): + #print('sth') + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + sumi+=1 + #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: + break + co_map.append(np.array(c_t_in)) if tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'): @@ -620,7 +657,7 @@ def get_layout_contours_for_visualization(xml_file): elif vv.tag!=link+'Point' and sumi>=1: break co_noise.append(np.array(c_t_in)) - return co_text, co_graphic, co_sep, co_img, co_table, co_noise, y_len, x_len + return co_text, co_graphic, co_sep, co_img, co_table, co_map, co_noise, y_len, x_len def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_file, config_params, printspace, dir_images, dir_out_images): """ @@ -643,24 +680,21 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ link=alltags[0].split('}')[0]+'}' - + x_len, y_len = 0, 0 for jj in root1.iter(link+'Page'): y_len=int(jj.attrib['imageHeight']) x_len=int(jj.attrib['imageWidth']) if 'columns_width' in list(config_params.keys()): columns_width_dict = config_params['columns_width'] + # FIXME: look in /Page/@custom as well metadata_element = root1.find(link+'Metadata') - comment_is_sub_element = False + num_col = None for child in metadata_element: tag2 = child.tag if tag2.endswith('}Comments') or tag2.endswith('}comments'): text_comments = child.text num_col = int(text_comments.split('num_col')[1]) - comment_is_sub_element = True - if not comment_is_sub_element: - # FIXME: look in /Page/@custom as well - num_col = None if num_col: x_new = columns_width_dict[str(num_col)] @@ -812,7 +846,7 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ types_graphic_label = list(types_graphic_dict.values()) - labels_rgb_color = [ (0,0,0), (255,0,0), (255,125,0), (255,0,125), (125,255,125), (125,125,0), (0,125,255), (0,125,0), (125,125,125), (255,0,255), (125,0,125), (0,255,0),(0,0,255), (0,255,255), (255,125,125), (0,125,125), (0,255,125), (255,125,255), (125,255,0)] + labels_rgb_color = [ (0,0,0), (255,0,0), (255,125,0), (255,0,125), (125,255,125), (125,125,0), (0,125,255), (0,125,0), (125,125,125), (255,0,255), (125,0,125), (0,255,0),(0,0,255), (0,255,255), (255,125,125), (0,125,125), (0,255,125), (255,125,255), (125,255,0), (125,255,255)] region_tags=np.unique([x for x in alltags if x.endswith('Region')]) @@ -823,6 +857,7 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ co_sep=[] co_img=[] co_table=[] + co_map=[] co_noise=[] for tag in region_tags: @@ -1033,6 +1068,32 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ elif vv.tag!=link+'Point' and sumi>=1: break co_table.append(np.array(c_t_in)) + + if 'mapregion' in keys: + if tag.endswith('}MapRegion') or tag.endswith('}mapregion'): + #print('sth') + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + sumi+=1 + #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: + break + co_map.append(np.array(c_t_in)) if 'noiseregion' in keys: if tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'): @@ -1106,6 +1167,10 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ erosion_rate = 0#2 dilation_rate = 3#4 co_table, img_boundary = update_region_contours(co_table, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + if "mapregion" in elements_with_artificial_class: + erosion_rate = 0#2 + dilation_rate = 3#4 + co_map, img_boundary = update_region_contours(co_map, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) @@ -1131,6 +1196,8 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ img_poly=cv2.fillPoly(img, pts =co_img, color=labels_rgb_color[ config_params['imageregion']]) if 'tableregion' in keys: img_poly=cv2.fillPoly(img, pts =co_table, color=labels_rgb_color[ config_params['tableregion']]) + if 'mapregion' in keys: + img_poly=cv2.fillPoly(img, pts =co_map, color=labels_rgb_color[ config_params['mapregion']]) if 'noiseregion' in keys: img_poly=cv2.fillPoly(img, pts =co_noise, color=labels_rgb_color[ config_params['noiseregion']]) @@ -1192,6 +1259,9 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ if 'tableregion' in keys: color_label = config_params['tableregion'] img_poly=cv2.fillPoly(img, pts =co_table, color=(color_label,color_label,color_label)) + if 'mapregion' in keys: + color_label = config_params['mapregion'] + img_poly=cv2.fillPoly(img, pts =co_map, color=(color_label,color_label,color_label)) if 'noiseregion' in keys: color_label = config_params['noiseregion'] img_poly=cv2.fillPoly(img, pts =co_noise, color=(color_label,color_label,color_label)) @@ -1690,15 +1760,15 @@ def read_xml(xml_file): index_tot_regions, img_poly) -def bounding_box(cnt,color, corr_order_index ): - x, y, w, h = cv2.boundingRect(cnt) - x = int(x*scale_w) - y = int(y*scale_h) - - w = int(w*scale_w) - h = int(h*scale_h) - - return [x,y,w,h,int(color), int(corr_order_index)+1] +# def bounding_box(cnt,color, corr_order_index ): +# x, y, w, h = cv2.boundingRect(cnt) +# x = int(x*scale_w) +# y = int(y*scale_h) +# +# w = int(w*scale_w) +# h = int(h*scale_h) +# +# return [x,y,w,h,int(color), int(corr_order_index)+1] def resize_image(seg_in,input_height,input_width): return cv2.resize(seg_in,(input_width,input_height),interpolation=cv2.INTER_NEAREST) diff --git a/src/eynollah/training/inference.py b/src/eynollah/training/inference.py index 2ef1a91..454c689 100644 --- a/src/eynollah/training/inference.py +++ b/src/eynollah/training/inference.py @@ -4,17 +4,19 @@ Tool to load model and predict for given image. import sys import os +from typing import Tuple import warnings import json import click import numpy as np +from numpy._typing import NDArray import cv2 +import xml.etree.ElementTree as ET os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15 import tensorflow as tf -from tensorflow.keras.models import load_model -import xml.etree.ElementTree as ET +from tensorflow.keras.models import Model, load_model from .gt_gen_utils import ( filter_contours_area_of_image, @@ -32,6 +34,9 @@ from .metrics import ( weighted_categorical_crossentropy, ) +from.utils import (scale_padd_image_for_ocr) +from eynollah.utils.utils_ocr import (decode_batch_predictions) + with warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -47,9 +52,10 @@ class SBBPredict: save_layout, ground_truth, xml_file, + cpu, out, - min_area): - + min_area, + ): self.image=image self.dir_in=dir_in self.patches=patches @@ -61,6 +67,7 @@ class SBBPredict: self.config_params_model=config_params_model self.xml_file = xml_file self.out = out + self.cpu = cpu if min_area: self.min_area = float(min_area) else: @@ -111,30 +118,35 @@ class SBBPredict: return mIoU def start_new_session_and_model(self): - try: - for device in tf.config.list_physical_devices('GPU'): - tf.config.experimental.set_memory_growth(device, True) - except: - print("no GPU device available", file=sys.stderr) + if self.cpu: + tf.config.set_visible_devices([], 'GPU') + else: + try: + for device in tf.config.list_physical_devices('GPU'): + tf.config.experimental.set_memory_growth(device, True) + except: + print("no GPU device available", file=sys.stderr) - #tensorflow.keras.layers.custom_layer = PatchEncoder - #tensorflow.keras.layers.custom_layer = Patches - self.model = load_model(self.model_dir, compile=False, - custom_objects={"PatchEncoder": PatchEncoder, - "Patches": Patches}) - #keras.losses.custom_loss = weighted_categorical_crossentropy - #self.model = load_model(self.model_dir, compile=False) + if self.task == "cnn-rnn-ocr": + self.model = Model( + self.model.get_layer(name = "image").input, + self.model.get_layer(name = "dense2").output) + else: + self.model = load_model(self.model_dir, compile=False, + custom_objects={"PatchEncoder": PatchEncoder, + "Patches": Patches}) ##if self.weights_dir!=None: ##self.model.load_weights(self.weights_dir) + assert isinstance(self.model, Model) if self.task != 'classification' and self.task != 'reading_order': last = self.model.layers[-1] self.img_height = last.output_shape[1] self.img_width = last.output_shape[2] self.n_classes = last.output_shape[3] - def visualize_model_output(self, prediction, img, task): + def visualize_model_output(self, prediction, img, task) -> Tuple[NDArray, NDArray]: if task == "binarization": prediction = prediction * -1 prediction = prediction + 1 @@ -173,9 +185,12 @@ class SBBPredict: added_image = cv2.addWeighted(img,0.5,layout_only,0.1,0) + assert isinstance(added_image, np.ndarray) + assert isinstance(layout_only, np.ndarray) return added_image, layout_only def predict(self, image_dir): + assert isinstance(self.model, Model) if self.task == 'classification': classes_names = self.config_params_model['classification_classes_name'] img_1ch = cv2.imread(image_dir, 0) / 255.0 @@ -187,11 +202,35 @@ class SBBPredict: img_in[0, :, :, 1] = img_1ch[:, :] img_in[0, :, :, 2] = img_1ch[:, :] - label_p_pred = self.model.predict(img_in, verbose=0) + label_p_pred = self.model.predict(img_in, verbose='0') index_class = np.argmax(label_p_pred[0]) print("Predicted Class: {}".format(classes_names[str(int(index_class))])) + elif self.task == "cnn-rnn-ocr": + img=cv2.imread(image_dir) + img = scale_padd_image_for_ocr(img, self.config_params_model['input_height'], self.config_params_model['input_width']) + + img = img / 255. + + with open(os.path.join(self.model_dir, "characters_org.txt"), 'r') as char_txt_f: + characters = json.load(char_txt_f) + + AUTOTUNE = tf.data.AUTOTUNE + + # Mapping characters to integers. + char_to_num = StringLookup(vocabulary=list(characters), mask_token=None) + + # Mapping integers back to original characters. + num_to_char = StringLookup( + vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True + ) + preds = self.model.predict(img.reshape(1, img.shape[0], img.shape[1], img.shape[2]), verbose=0) + pred_texts = decode_batch_predictions(preds, num_to_char) + pred_texts = pred_texts[0].replace("[UNK]", "") + return pred_texts + + elif self.task == 'reading_order': img_height = self.config_params_model['input_height'] img_width = self.config_params_model['input_width'] @@ -311,7 +350,7 @@ class SBBPredict: #input_1[:,:,1] = img3[:,:,0]/5. if batch_counter==inference_bs or ( (tot_counter//inference_bs)==full_bs_ite and tot_counter%inference_bs==last_bs): - y_pr = self.model.predict(input_1 , verbose=0) + y_pr = self.model.predict(input_1 , verbose='0') scalibility_num = scalibility_num+1 if batch_counter==inference_bs: @@ -345,6 +384,7 @@ class SBBPredict: name_space = name_space.split('{')[1] page_element = root_xml.find(link+'Page') + assert isinstance(page_element, ET.Element) """ ro_subelement = ET.SubElement(page_element, 'ReadingOrder') @@ -439,7 +479,7 @@ class SBBPredict: img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :] label_p_pred = self.model.predict(img_patch.reshape(1, img_patch.shape[0], img_patch.shape[1], img_patch.shape[2]), - verbose=0) + verbose='0') if self.task == 'enhancement': seg = label_p_pred[0, :, :, :] @@ -447,6 +487,8 @@ class SBBPredict: elif self.task == 'segmentation' or self.task == 'binarization': seg = np.argmax(label_p_pred, axis=3)[0] seg = np.repeat(seg[:, :, np.newaxis], 3, axis=2) + else: + raise ValueError(f"Unhandled task {self.task}") if i == 0 and j == 0: @@ -501,6 +543,8 @@ class SBBPredict: elif self.task == 'segmentation' or self.task == 'binarization': seg = np.argmax(label_p_pred, axis=3)[0] seg = np.repeat(seg[:, :, np.newaxis], 3, axis=2) + else: + raise ValueError(f"Unhandled task {self.task}") prediction_true = seg.astype(int) @@ -519,6 +563,8 @@ class SBBPredict: elif self.task == 'enhancement': if self.save: cv2.imwrite(self.save,res) + elif self.task == "cnn-rnn-ocr": + print(f"Detected text: {res}") else: img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task) if self.save: @@ -526,9 +572,9 @@ class SBBPredict: if self.save_layout: cv2.imwrite(self.save_layout, only_layout) - if self.ground_truth: - gt_img=cv2.imread(self.ground_truth) - self.IoU(gt_img[:,:,0],res[:,:,0]) + if self.ground_truth: + gt_img=cv2.imread(self.ground_truth) + self.IoU(gt_img[:,:,0],res[:,:,0]) else: ls_images = os.listdir(self.dir_in) @@ -542,6 +588,8 @@ class SBBPredict: elif self.task == 'enhancement': self.save = os.path.join(self.out, f_name+'.png') cv2.imwrite(self.save,res) + elif self.task == "cnn-rnn-ocr": + print(f"Detected text for file name {f_name} is: {res}") else: img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task) self.save = os.path.join(self.out, f_name+'_overlayed.png') @@ -549,9 +597,9 @@ class SBBPredict: self.save_layout = os.path.join(self.out, f_name+'_layout.png') cv2.imwrite(self.save_layout, only_layout) - if self.ground_truth: - gt_img=cv2.imread(self.ground_truth) - self.IoU(gt_img[:,:,0],res[:,:,0]) + if self.ground_truth: + gt_img=cv2.imread(self.ground_truth) + self.IoU(gt_img[:,:,0],res[:,:,0]) @@ -607,22 +655,27 @@ class SBBPredict: "-xml", help="xml file with layout coordinates that reading order detection will be implemented on. The result will be written in the same xml file.", ) - +@click.option( + "--cpu", + "-cpu", + help="For OCR, the default device is the GPU. If this parameter is set to true, inference will be performed on the CPU", + is_flag=True, +) @click.option( "--min_area", "-min", help="min area size of regions considered for reading order detection. The default value is zero and means that all text regions are considered for reading order.", ) -def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, out, min_area): +def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area): assert image or dir_in, "Either a single image -i or a dir_in -di input is required" with open(os.path.join(model,'config.json')) as f: config_params_model = json.load(f) task = config_params_model['task'] - if task != 'classification' and task != 'reading_order': + if task not in ['classification', 'reading_order', "cnn-rnn-ocr"]: assert not image or save, "For segmentation or binarization, an input single image -i also requires an output filename -s" assert not dir_in or out, "For segmentation or binarization, an input directory -di also requires an output directory -o" x = SBBPredict(image, dir_in, model, task, config_params_model, - patches, save, save_layout, ground_truth, xml_file, out, - min_area) + patches, save, save_layout, ground_truth, xml_file, + cpu, out, min_area) x.run() diff --git a/src/eynollah/training/metrics.py b/src/eynollah/training/metrics.py index cd30b02..a8f47d7 100644 --- a/src/eynollah/training/metrics.py +++ b/src/eynollah/training/metrics.py @@ -147,6 +147,7 @@ def generalized_dice_loss(y_true, y_pred): return 1 - generalized_dice_coeff2(y_true, y_pred) +# TODO: document where this is from def soft_dice_loss(y_true, y_pred, epsilon=1e-6): """ Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions. @@ -175,6 +176,7 @@ def soft_dice_loss(y_true, y_pred, epsilon=1e-6): return 1.00 - K.mean(numerator / (denominator + epsilon)) # average over classes and batch +# TODO: document where this is from def seg_metrics(y_true, y_pred, metric_name, metric_type='standard', drop_last=True, mean_per_class=False, verbose=False): """ @@ -267,6 +269,8 @@ def seg_metrics(y_true, y_pred, metric_name, metric_type='standard', drop_last=T return K.mean(non_zero_sum / non_zero_count) +# TODO: document where this is from +# TODO: Why a different implementation than IoU from utils? def mean_iou(y_true, y_pred, **kwargs): """ Compute mean Intersection over Union of two segmentation masks, via Keras. @@ -311,6 +315,7 @@ def iou_vahid(y_true, y_pred): return K.mean(iou) +# TODO: copy from utils? def IoU_metric(Yi, y_predi): # mean Intersection over Union # Mean IoU = TP/(FN + TP + FP) @@ -337,6 +342,7 @@ def IoU_metric_keras(y_true, y_pred): return IoU_metric(y_true.eval(session=sess), y_pred.eval(session=sess)) +# TODO: unused, remove? def jaccard_distance_loss(y_true, y_pred, smooth=100): """ Jaccard = (|X & Y|)/ (|X|+ |Y| - |X & Y|) diff --git a/src/eynollah/training/models.py b/src/eynollah/training/models.py index f053447..d1148f1 100644 --- a/src/eynollah/training/models.py +++ b/src/eynollah/training/models.py @@ -2,12 +2,36 @@ import os os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15 import tensorflow as tf -from tensorflow import keras -from tensorflow.keras.models import * -from tensorflow.keras.layers import * -from tensorflow.keras import layers +from tensorflow.keras.layers import ( + Activation, + Add, + AveragePooling2D, + BatchNormalization, + Bidirectional, + Conv1D, + Conv2D, + Dense, + Dropout, + Embedding, + Flatten, + Input, + Lambda, + Layer, + LayerNormalization, + LSTM, + MaxPooling2D, + MultiHeadAttention, + Reshape, + UpSampling2D, + ZeroPadding2D, + add, + concatenate +) +from tensorflow.keras.models import Model from tensorflow.keras.regularizers import l2 +from eynollah.patch_encoder import Patches, PatchEncoder + ##mlp_head_units = [512, 256]#[2048, 1024] ###projection_dim = 64 ##transformer_layers = 2#8 @@ -19,96 +43,34 @@ RESNET50_WEIGHTS_URL = ('https://github.com/fchollet/deep-learning-models/releas IMAGE_ORDERING = 'channels_last' MERGE_AXIS = -1 + +class CTCLayer(Layer): + def __init__(self, name=None): + super().__init__(name=name) + self.loss_fn = tf.keras.backend.ctc_batch_cost + + def call(self, y_true, y_pred): + batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64") + input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64") + label_length = tf.cast(tf.shape(y_true)[1], dtype="int64") + + input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64") + label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64") + loss = self.loss_fn(y_true, y_pred, input_length, label_length) + self.add_loss(loss) + + # At test time, just return the computed predictions. + return y_pred + def mlp(x, hidden_units, dropout_rate): for units in hidden_units: - x = layers.Dense(units, activation=tf.nn.gelu)(x) - x = layers.Dropout(dropout_rate)(x) + x = Dense(units, activation=tf.nn.gelu)(x) + x = Dropout(dropout_rate)(x) return x -class Patches(layers.Layer): - def __init__(self, patch_size_x, patch_size_y):#__init__(self, **kwargs):#:__init__(self, patch_size):#__init__(self, **kwargs): - super(Patches, self).__init__() - self.patch_size_x = patch_size_x - self.patch_size_y = patch_size_y - - def call(self, images): - #print(tf.shape(images)[1],'images') - #print(self.patch_size,'self.patch_size') - batch_size = tf.shape(images)[0] - patches = tf.image.extract_patches( - images=images, - sizes=[1, self.patch_size_y, self.patch_size_x, 1], - strides=[1, self.patch_size_y, self.patch_size_x, 1], - rates=[1, 1, 1, 1], - padding="VALID", - ) - #patch_dims = patches.shape[-1] - patch_dims = tf.shape(patches)[-1] - patches = tf.reshape(patches, [batch_size, -1, patch_dims]) - return patches - def get_config(self): - - config = super().get_config().copy() - config.update({ - 'patch_size_x': self.patch_size_x, - 'patch_size_y': self.patch_size_y, - }) - return config - -class Patches_old(layers.Layer): - def __init__(self, patch_size):#__init__(self, **kwargs):#:__init__(self, patch_size):#__init__(self, **kwargs): - super(Patches, self).__init__() - self.patch_size = patch_size - - def call(self, images): - #print(tf.shape(images)[1],'images') - #print(self.patch_size,'self.patch_size') - batch_size = tf.shape(images)[0] - patches = tf.image.extract_patches( - images=images, - sizes=[1, self.patch_size, self.patch_size, 1], - strides=[1, self.patch_size, self.patch_size, 1], - rates=[1, 1, 1, 1], - padding="VALID", - ) - patch_dims = patches.shape[-1] - #print(patches.shape,patch_dims,'patch_dims') - patches = tf.reshape(patches, [batch_size, -1, patch_dims]) - return patches - def get_config(self): - - config = super().get_config().copy() - config.update({ - 'patch_size': self.patch_size, - }) - return config - - -class PatchEncoder(layers.Layer): - def __init__(self, num_patches, projection_dim): - super(PatchEncoder, self).__init__() - self.num_patches = num_patches - self.projection = layers.Dense(units=projection_dim) - self.position_embedding = layers.Embedding( - input_dim=num_patches, output_dim=projection_dim - ) - - def call(self, patch): - positions = tf.range(start=0, limit=self.num_patches, delta=1) - encoded = self.projection(patch) + self.position_embedding(positions) - return encoded - def get_config(self): - - config = super().get_config().copy() - config.update({ - 'num_patches': self.num_patches, - 'projection': self.projection, - 'position_embedding': self.position_embedding, - }) - return config - - def one_side_pad(x): + # rs: fixme: lambda layers are problematic for de/serialization! + # - can we use ZeroPadding1D instead of ZeroPadding2D+Lambda? x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x) if IMAGE_ORDERING == 'channels_first': x = Lambda(lambda x: x[:, :, :-1, :-1])(x) @@ -150,7 +112,7 @@ def identity_block(input_tensor, kernel_size, filters, stage, block): x = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, name=conv_name_base + '2c')(x) x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x) - x = layers.add([x, input_tensor]) + x = add([x, input_tensor]) x = Activation('relu')(x) return x @@ -195,12 +157,12 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)) name=conv_name_base + '1')(input_tensor) shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(shortcut) - x = layers.add([x, shortcut]) + x = add([x, shortcut]) x = Activation('relu')(x) return x -def resnet50_unet_light(n_classes, input_height=224, input_width=224, taks="segmentation", weight_decay=1e-6, pretraining=False): +def resnet50_unet_light(n_classes, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): assert input_height % 32 == 0 assert input_width % 32 == 0 @@ -415,7 +377,7 @@ def vit_resnet50_unet(num_patches, pretraining=False): if transformer_mlp_head_units is None: transformer_mlp_head_units = [128, 64] - inputs = layers.Input(shape=(input_height, input_width, 3)) + inputs = Input(shape=(input_height, input_width, 3)) #transformer_units = [ #projection_dim * 2, @@ -460,27 +422,35 @@ def vit_resnet50_unet(num_patches, model = Model(inputs, x).load_weights(RESNET50_WEIGHTS_PATH) #num_patches = x.shape[1]*x.shape[2] - - patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(x) + + # rs: fixme patch size not configurable anymore... + #patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(inputs) + patches = Patches()(x) + assert transformer_patchsize_x == transformer_patchsize_y == 1 # Encode patches. - encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches) + # rs: fixme num patches and dim not configurable anymore... + #encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches) + encoded_patches = PatchEncoder()(patches) + assert num_patches == 21 * 21 + assert transformer_projection_dim == 64 for _ in range(transformer_layers): # Layer normalization 1. - x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches) + x1 = LayerNormalization(epsilon=1e-6)(encoded_patches) # Create a multi-head attention layer. - attention_output = layers.MultiHeadAttention( + attention_output = MultiHeadAttention( num_heads=transformer_num_heads, key_dim=transformer_projection_dim, dropout=0.1 )(x1, x1) # Skip connection 1. - x2 = layers.Add()([attention_output, encoded_patches]) + x2 = Add()([attention_output, encoded_patches]) # Layer normalization 2. - x3 = layers.LayerNormalization(epsilon=1e-6)(x2) + x3 = LayerNormalization(epsilon=1e-6)(x2) # MLP. x3 = mlp(x3, hidden_units=transformer_mlp_head_units, dropout_rate=0.1) # Skip connection 2. - encoded_patches = layers.Add()([x3, x2]) + encoded_patches = Add()([x3, x2]) + assert isinstance(x, Layer) encoded_patches = tf.reshape(encoded_patches, [-1, x.shape[1], x.shape[2], transformer_projection_dim // (transformer_patchsize_x * @@ -551,7 +521,7 @@ def vit_resnet50_unet_transformer_before_cnn(num_patches, pretraining=False): if transformer_mlp_head_units is None: transformer_mlp_head_units = [128, 64] - inputs = layers.Input(shape=(input_height, input_width, 3)) + inputs = Input(shape=(input_height, input_width, 3)) ##transformer_units = [ ##projection_dim * 2, @@ -560,25 +530,32 @@ def vit_resnet50_unet_transformer_before_cnn(num_patches, IMAGE_ORDERING = 'channels_last' bn_axis=3 - patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(inputs) + # rs: fixme patch size not configurable anymore... + #patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(inputs) + patches = Patches()(inputs) + assert transformer_patchsize_x == transformer_patchsize_y == 1 # Encode patches. - encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches) + # rs: fixme num patches and dim not configurable anymore... + #encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches) + encoded_patches = PatchEncoder()(patches) + assert num_patches == 21 * 21 + assert transformer_projection_dim == 64 for _ in range(transformer_layers): # Layer normalization 1. - x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches) + x1 = LayerNormalization(epsilon=1e-6)(encoded_patches) # Create a multi-head attention layer. - attention_output = layers.MultiHeadAttention( + attention_output = MultiHeadAttention( num_heads=transformer_num_heads, key_dim=transformer_projection_dim, dropout=0.1 )(x1, x1) # Skip connection 1. - x2 = layers.Add()([attention_output, encoded_patches]) + x2 = Add()([attention_output, encoded_patches]) # Layer normalization 2. - x3 = layers.LayerNormalization(epsilon=1e-6)(x2) + x3 = LayerNormalization(epsilon=1e-6)(x2) # MLP. x3 = mlp(x3, hidden_units=transformer_mlp_head_units, dropout_rate=0.1) # Skip connection 2. - encoded_patches = layers.Add()([x3, x2]) + encoded_patches = Add()([x3, x2]) encoded_patches = tf.reshape(encoded_patches, [-1, @@ -734,9 +711,6 @@ def resnet50_classifier(n_classes,input_height=224,input_width=224,weight_decay= x = Dense(n_classes, activation='softmax', name='fc1000')(x) model = Model(img_input, x) - - - return model def machine_based_reading_order_model(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False): @@ -793,3 +767,81 @@ def machine_based_reading_order_model(n_classes,input_height=224,input_width=224 model = Model(img_input , o) return model + +def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_seq=None): + input_img = Input(shape=(image_height, image_width, 3), name="image") + labels = Input(name="label", shape=(None,)) + + x = Conv2D(64,kernel_size=(3,3),padding="same")(input_img) + x = BatchNormalization(name="bn1")(x) + x = Activation("relu", name="relu1")(x) + x = Conv2D(64,kernel_size=(3,3),padding="same")(x) + x = BatchNormalization(name="bn2")(x) + x = Activation("relu", name="relu2")(x) + x = MaxPool2D(pool_size=(1,2),strides=(1,2))(x) + + x = Conv2D(128,kernel_size=(3,3),padding="same")(x) + x = BatchNormalization(name="bn3")(x) + x = Activation("relu", name="relu3")(x) + x = Conv2D(128,kernel_size=(3,3),padding="same")(x) + x = BatchNormalization(name="bn4")(x) + x = Activation("relu", name="relu4")(x) + x = MaxPool2D(pool_size=(1,2),strides=(1,2))(x) + + x = Conv2D(256,kernel_size=(3,3),padding="same")(x) + x = BatchNormalization(name="bn5")(x) + x = Activation("relu", name="relu5")(x) + x = Conv2D(256,kernel_size=(3,3),padding="same")(x) + x = BatchNormalization(name="bn6")(x) + x = Activation("relu", name="relu6")(x) + x = MaxPool2D(pool_size=(2,2),strides=(2,2))(x) + + x = Conv2D(image_width,kernel_size=(3,3),padding="same")(x) + x = BatchNormalization(name="bn7")(x) + x = Activation("relu", name="relu7")(x) + x = Conv2D(image_width,kernel_size=(16,1))(x) + x = BatchNormalization(name="bn8")(x) + x = Activation("relu", name="relu8")(x) + x2d = MaxPool2D(pool_size=(1,2),strides=(1,2))(x) + x4d = MaxPool2D(pool_size=(1,2),strides=(1,2))(x2d) + + + new_shape = (x.shape[1]*x.shape[2], x.shape[3]) + new_shape2 = (x2d.shape[1]*x2d.shape[2], x2d.shape[3]) + new_shape4 = (x4d.shape[1]*x4d.shape[2], x4d.shape[3]) + + x = Reshape(target_shape=new_shape, name="reshape")(x) + x2d = Reshape(target_shape=new_shape2, name="reshape2")(x2d) + x4d = Reshape(target_shape=new_shape4, name="reshape4")(x4d) + + xrnnorg = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x) + xrnn2d = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x2d) + xrnn4d = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x4d) + + xrnn2d = Reshape(target_shape=(1, xrnn2d.shape[1], xrnn2d.shape[2]), name="reshape6")(xrnn2d) + xrnn4d = Reshape(target_shape=(1, xrnn4d.shape[1], xrnn4d.shape[2]), name="reshape8")(xrnn4d) + + + xrnn2dup = UpSampling2D(size=(1, 2), interpolation="nearest")(xrnn2d) + xrnn4dup = UpSampling2D(size=(1, 4), interpolation="nearest")(xrnn4d) + + xrnn2dup = Reshape(target_shape=(xrnn2dup.shape[2], xrnn2dup.shape[3]), name="reshape10")(xrnn2dup) + xrnn4dup = Reshape(target_shape=(xrnn4dup.shape[2], xrnn4dup.shape[3]), name="reshape12")(xrnn4dup) + + addition = Add()([xrnnorg, xrnn2dup, xrnn4dup]) + + addition_rnn = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(addition) + + out = Conv1D(max_seq, 1, data_format="channels_first")(addition_rnn) + out = BatchNormalization(name="bn9")(out) + out = Activation("relu", name="relu9")(out) + #out = Conv1D(n_classes, 1, activation='relu', data_format="channels_last")(out) + + out = Dense(n_classes, activation="softmax", name="dense2")(out) + + # Add CTC layer for calculating CTC loss at each step. + output = CTCLayer(name="ctc_loss")(labels, out) + + model = Model(inputs=[input_img, labels], outputs=output, name="handwriting_recognizer") + + return model diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py index 7cf7536..61dbdf7 100644 --- a/src/eynollah/training/train.py +++ b/src/eynollah/training/train.py @@ -17,6 +17,7 @@ from eynollah.training.models import ( resnet50_unet, vit_resnet50_unet, vit_resnet50_unet_transformer_before_cnn, + cnn_rnn_ocr_model, RESNET50_WEIGHTS_PATH, RESNET50_WEIGHTS_URL ) @@ -25,7 +26,6 @@ from eynollah.training.utils import ( generate_arrays_from_folder_reading_order, get_one_hot, preprocess_imgs, - return_number_of_total_training_data ) os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' @@ -35,11 +35,10 @@ from tensorflow.keras.optimizers import SGD, Adam from tensorflow.keras.metrics import MeanIoU, F1Score from tensorflow.keras.models import load_model from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard +from tensorflow.keras.layers import StringLookup from tensorflow.keras.utils import image_dataset_from_directory from sacred import Experiment from sacred.config import create_captured_function -from tqdm import tqdm -from sklearn.metrics import f1_score import numpy as np import cv2 @@ -68,6 +67,7 @@ class SaveWeightsAfterSteps(ModelCheckpoint): json.dump(self._config, fp) # encode dict into JSON + def configuration(): try: for device in tf.config.list_physical_devices('GPU'): @@ -111,6 +111,9 @@ def config_params(): n_classes = None # Number of classes. In the case of binary classification this should be 2. n_epochs = 1 # Number of epochs to train. n_batch = 1 # Number of images per batch at each iteration. (Try as large as fits on VRAM.) + if task == 'cnn-rnn-ocr': + max_len = None # Maximum sequence length (characters per line) for OCR output. + characters_txt_file = None # Path of JSON file defining character set needed of OCR model. input_height = 224 * 1 # Height of model's input in pixels. input_width = 224 * 1 # Width of model's input in pixels. weight_decay = 1e-6 # Weight decay of l2 regularization of model layers. @@ -124,47 +127,74 @@ def config_params(): patches = False # Divides input image into smaller patches (input size of the model) when set to true. For the model to see the full image, like page extraction, set this to false. augmentation = False # To apply any kind of augmentation, this parameter must be set to true. if augmentation: - flip_aug = False # If true, different types of flipping will be applied to the image. Types of flips are defined with "flip_index" in config_params.json. + flip_aug = False # Whether different types of flipping will be applied to the image. Requires "flip_index" setting. if flip_aug: - flip_index = None # Flip image for augmentation. - blur_aug = False # If true, different types of blurring will be applied to the image. Types of blur are defined with "blur_k" in config_params.json. + flip_index = None # List of codes (as in cv2.flip) for flip augmentation. + blur_aug = False # Whether images will be blurred. Requires "blur_k" setting. if blur_aug: - blur_k = None # Blur image for augmentation. + blur_k = None # Method of blurring (gauss, median or blur). padding_white = False # If true, white padding will be applied to the image. + if padding_white and task == 'cnn-rnn-ocr': + white_padds = None # List of padding sizes. + padd_colors = None # List of padding colors, but only "white" or "black" or both. padding_black = False # If true, black padding will be applied to the image. - scaling = False # If true, scaling will be applied to the image. The amount of scaling is defined with "scales" in config_params.json. - scaling_bluring = False # If true, a combination of scaling and blurring will be applied to the image. - scaling_binarization = False # If true, a combination of scaling and binarization will be applied to the image. - scaling_brightness = False # If true, a combination of scaling and brightening will be applied to the image. - scaling_flip = False # If true, a combination of scaling and flipping will be applied to the image. + scaling = False # Whether images will be scaled up or down. Requires "scales" setting. + scaling_bluring = False # Whether a combination of scaling and blurring will be applied to the image. + scaling_binarization = False # Whether a combination of scaling and binarization will be applied to the image. + scaling_brightness = False # Whether a combination of scaling and brightening will be applied to the image. + scaling_flip = False # Whether a combination of scaling and flipping will be applied to the image. if scaling or scaling_brightness or scaling_bluring or scaling_binarization or scaling_flip: scales = None # Scale patches for augmentation. shifting = False - degrading = False # If true, degrading will be applied to the image. The amount of degrading is defined with "degrade_scales" in config_params.json. - if degrading: - degrade_scales = None # Degrade image for augmentation. - brightening = False # If true, brightening will be applied to the image. The amount of brightening is defined with "brightness" in config_params.json. + brightening = False # Whether images will be brightened. Requires "brightness" setting. if brightening: - brightness = None # Brighten image for augmentation. - binarization = False # If true, Otsu thresholding will be applied to augment the input with binarized images. + brightness = None # List of intensity factors for brightening. + binarization = False # Whether binary images will be used, too. (Will use Otsu thresholding unless supplying precomputed images in "dir_img_bin".) if binarization: dir_img_bin = None # Directory of training dataset subdirectory of binarized images add_red_textlines = False - adding_rgb_background = False + adding_rgb_background = False # Whether texture images will be added as artificial background. if adding_rgb_background: dir_rgb_backgrounds = None # Directory of texture images for synthetic background - adding_rgb_foreground = False + adding_rgb_foreground = False # Whether texture images will be added as artificial foreground. if adding_rgb_foreground: dir_rgb_foregrounds = None # Directory of texture images for synthetic foreground if adding_rgb_background or adding_rgb_foreground: number_of_backgrounds_per_image = 1 + if task == 'cnn-rnn-ocr': + image_inversion = False # Whether the binarized images will be inverted. + textline_skewing_bin = False # Whether binarized textline images will be rotated. + textline_left_in_depth_bin = False # Whether left side of binary textline image will be displayed in depth. + textline_right_in_depth_bin = False # Whether right side of binary textline image will be displayed in depth. + textline_up_in_depth_bin = False # Whether upper side of binary textline image will be displayed in depth. + textline_down_in_depth_bin = False # Whether lower side of binary textline image will be displayed in depth. + pepper_bin_aug = False # Whether pepper noise will be added to binary textline images. + bin_deg = False # Whether a combination of degrading and binarization will be applied to the image. + degrading = False # Whether images will be artificially degraded. Requires the "degrade_scales" setting. + if degrading or binarization and task == 'cnn-rnn-ocr' and bin_deg: + degrade_scales = None # List of quality factors for degradation. channels_shuffling = False # Re-arrange color channels. if channels_shuffling: - shuffle_indexes = None # Which channels to switch between. - rotation = False # If true, a 90 degree rotation will be implemeneted. - rotation_not_90 = False # If true rotation based on provided angles with thetha will be implemeneted. + shuffle_indexes = None # List of channels to switch between. + rotation = False # Whether images will be rotated by 90 degrees. + rotation_not_90 = False # Whether images will be rotated arbitrarily (skewed). Requires "thetha" setting. if rotation_not_90: - thetha = None # Rotate image by these angles for augmentation. + thetha = None # List of rotation angles in degrees. + if task == 'cnn-rnn-ocr': + white_noise_strap = False # Whether white noise will be applied on some straps on the textline image. + textline_skewing = False # Whether textline images will be skewed for augmentation. + if textline_skewing or binarization and textline_skewing_bin: + skewing_amplitudes = None # List of skewing angles in degrees like [5, 8] + textline_left_in_depth = False # If true, left side of textline image will be displayed in depth. + textline_right_in_depth = False # If true, right side of textline image will be displayed in depth. + textline_up_in_depth = False # If true, upper side of textline image will be displayed in depth. + textline_down_in_depth = False # If true, lower side of textline image will be displayed in depth. + pepper_aug = False # Whether pepper noise will be added to textline images. + if pepper_aug or binarization and pepper_bin_aug: + pepper_indexes = None # List of pepper noise factors, e.g. [0.01, 0.005]. + color_padding_rotation = False # Whether images will be rotated with color padding. Requires "thetha_padd" setting. + if color_padding_rotation: + thetha_padd = None # List of angles (in degrees) used for rotation alongside padding. dir_train = None # Directory of training dataset with subdirectories having the names "images" and "labels". dir_eval = None # Directory of validation dataset with subdirectories having the names "images" and "labels". dir_output = None # Directory where the augmented training data and the model checkpoints will be saved. @@ -197,12 +227,15 @@ def run(_config, augmentation, # dependent config keys need a default, # otherwise yields sacred.utils.ConfigAddedError + ## if rotation_not_90 thetha=None, is_loss_soft_dice=False, weighted_loss=False, + ## if continue_training index_start=0, dir_of_start_model=None, backbone_type=None, + ## if backbone_type=transformer transformer_projection_dim=None, transformer_mlp_head_units=None, transformer_layers=None, @@ -211,8 +244,33 @@ def run(_config, transformer_patchsize_x=None, transformer_patchsize_y=None, transformer_num_patches_xy=None, + ## if task=classification f1_threshold_classification=None, classification_classes_name=None, + ## if task=cnn-rnn-ocr + characters_txt_file=None, + color_padding_rotation=False, + thetha_padd=None, + bin_deg=False, + image_inversion=False, + white_noise_strap=False, + textline_skewing=False, + textline_skewing_bin=False, + textline_left_in_depth=False, + textline_left_in_depth_bin=False, + textline_right_in_depth=False, + textline_right_in_depth_bin=False, + textline_up_in_depth=False, + textline_up_in_depth_bin=False, + textline_down_in_depth=False, + textline_down_in_depth_bin=False, + pepper_aug=False, + pepper_bin_aug=False, + pepper_indexes=None, + padd_colors=None, + white_padds=None, + skewing_amplitudes=None, + max_len=None, ): if pretraining and not os.path.isfile(RESNET50_WEIGHTS_PATH): @@ -252,11 +310,11 @@ def run(_config, dir_img, dir_seg = get_dirs_or_files(dir_train) dir_img_val, dir_seg_val = get_dirs_or_files(dir_eval) - imgs_list=np.array(os.listdir(dir_img)) - segs_list=np.array(os.listdir(dir_seg)) + imgs_list = list(os.listdir(dir_img)) + segs_list = list(os.listdir(dir_seg)) - imgs_list_test=np.array(os.listdir(dir_img_val)) - segs_list_test=np.array(os.listdir(dir_seg_val)) + imgs_list_test = list(os.listdir(dir_img_val)) + segs_list_test = list(os.listdir(dir_seg_val)) # writing patches into a sub-folder in order to be flowed from directory. preprocess_imgs(_config, @@ -356,6 +414,7 @@ def run(_config, model_builder.logger = _log model = model_builder(num_patches) + assert model is not None #if you want to see the model structure just uncomment model summary. #model.summary() @@ -412,7 +471,80 @@ def run(_config, #os.system('rm -rf '+dir_eval_flowing) #model.save(dir_output+'/'+'model'+'.h5') + + elif task=="cnn-rnn-ocr": + dir_img, dir_lab = get_dirs_or_files(dir_train) + dir_img_val, dir_lab_val = get_dirs_or_files(dir_eval) + imgs_list = list(os.listdir(dir_img)) + labs_list = list(os.listdir(dir_lab)) + imgs_list_val = list(os.listdir(dir_img_val)) + labs_list_val = list(os.listdir(dir_lab_val)) + + with open(characters_txt_file, 'r') as char_txt_f: + characters = json.load(char_txt_f) + padding_token = len(characters) + 5 + # Mapping characters to integers. + char_to_num = StringLookup(vocabulary=list(characters), mask_token=None) + + # Mapping integers back to original characters. + ##num_to_char = StringLookup( + ##vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True + ##) + n_classes = len(char_to_num.get_vocabulary()) + 2 + + if continue_training: + model = load_model(dir_of_start_model) + else: + index_start = 0 + model = cnn_rnn_ocr_model(image_height=input_height, + image_width=input_width, + n_classes=n_classes, + max_seq=max_len) + #print(model.summary()) + + # todo: use Dataset.map() on Dataset.list_files() + # todo: test_ds + def gen(): + return preprocess_imgs(_config, + imgs_list, + labs_list, + dir_img, + dir_lab, + None, # no file I/O, but in-memory + None, # no file I/O, but in-memory + # extra+overrides + char_to_num=char_to_num, + padding_token=padding_token + ) + train_ds = tf.data.Dataset.from_generator(gen) + train_ds = train_ds.padded_batch(n_batch, + padded_shapes=([image_height, image_width, 3], [None]), + padding_values=(0, padding_token), + drop_remainder=True, + #num_parallel_calls=tf.data.AUTOTUNE, + ) + train_ds = train_ds.repeat().shuffle().prefetch(20) + + #initial_learning_rate = 1e-4 + #decay_steps = int (n_epochs * ( len_dataset / n_batch )) + #alpha = 0.01 + #lr_schedule = 1e-4 + #tf.keras.optimizers.schedules.CosineDecay(initial_learning_rate, decay_steps, alpha) + opt = tf.keras.optimizers.Adam(learning_rate=learning_rate) + model.compile(optimizer=opt) # rs: loss seems to be (ctc_batch_cost) in last layer + + callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False), + SaveWeightsAfterSteps(0, dir_output, _config)] + if save_interval: + callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config)) + model.fit( + train_ds, + #validation_data=test_ds, + epochs=n_epochs, + callbacks=callbacks, + initial_epoch=index_start) + elif task=='classification': if continue_training: model = load_model(dir_of_start_model, compile=False) diff --git a/src/eynollah/training/utils.py b/src/eynollah/training/utils.py index 5b25a4f..56d6bdf 100644 --- a/src/eynollah/training/utils.py +++ b/src/eynollah/training/utils.py @@ -2,6 +2,7 @@ import os import math import random from logging import getLogger +from pathlib import Path import cv2 import numpy as np @@ -10,8 +11,218 @@ from scipy.ndimage.interpolation import map_coordinates from scipy.ndimage.filters import gaussian_filter from tqdm import tqdm import imutils -from tensorflow.keras.utils import to_categorical -from PIL import Image, ImageEnhance +import tensorflow as tf + +from PIL import Image, ImageFile, ImageEnhance + +ImageFile.LOAD_TRUNCATED_IMAGES = True + + +def vectorize_label(label, char_to_num, padding_token, max_len): + label = char_to_num(tf.strings.unicode_split(label, input_encoding="UTF-8")) + length = tf.shape(label)[0] + pad_amount = max_len - length + label = tf.pad(label, paddings=[[0, pad_amount]], constant_values=padding_token) + return label + +def scale_padd_image_for_ocr(img, height, width): + ratio = height /float(img.shape[0]) + + w_ratio = int(ratio * img.shape[1]) + + if w_ratio<=width: + width_new = w_ratio + else: + width_new = width + + if width_new <= 0: + width_new = width + + img_res= resize_image (img, height, width_new) + img_fin = np.ones((height, width, 3))*255 + + img_fin[:,:width_new,:] = img_res[:,:,:] + return img_fin + +# TODO: document where this is from +def add_salt_and_pepper_noise(img, salt_prob, pepper_prob): + """ + Add salt-and-pepper noise to an image. + + Parameters: + image: ndarray + Input image. + salt_prob: float + Probability of salt noise. + pepper_prob: float + Probability of pepper noise. + + Returns: + noisy_image: ndarray + Image with salt-and-pepper noise. + """ + # Make a copy of the image + noisy_image = np.copy(img) + + # Generate random noise + total_pixels = img.size + num_salt = int(salt_prob * total_pixels) + num_pepper = int(pepper_prob * total_pixels) + + # Add salt noise + coords = [np.random.randint(0, i - 1, num_salt) for i in img.shape[:2]] + noisy_image[coords[0], coords[1]] = 255 # white pixels + + # Add pepper noise + coords = [np.random.randint(0, i - 1, num_pepper) for i in img.shape[:2]] + noisy_image[coords[0], coords[1]] = 0 # black pixels + + return noisy_image + +def invert_image(img): + img_inv = 255 - img + return img_inv + +def return_image_with_strapped_white_noises(img): + img_w_noised = np.copy(img) + img_h, img_width = img.shape[0], img.shape[1] + n = 9 + p = 0.3 + num_windows = np.random.binomial(n, p, 1)[0] + + if num_windows<1: + num_windows = 1 + + loc_of_windows = np.random.uniform(0,img_width,num_windows).astype(np.int64) + width_windows = np.random.uniform(10,50,num_windows).astype(np.int64) + + for i, loc in enumerate(loc_of_windows): + noise = np.random.normal(0, 50, (img_h, width_windows[i], 3)) + + try: + img_w_noised[:, loc:loc+width_windows[i], : ] = noise[:,:,:] + except: + pass + return img_w_noised + +def do_padding_for_ocr(img, percent_height, padding_color): + padding_size = int( img.shape[0]*percent_height/2. ) + height_new = img.shape[0] + 2*padding_size + width_new = img.shape[1] + 2*padding_size + + h_start = padding_size + w_start = padding_size + + if padding_color == 'white': + img_new = np.ones((height_new, width_new, img.shape[2])).astype(float) * 255 + elif padding_color == 'black': + img_new = np.zeros((height_new, width_new, img.shape[2])).astype(float) + else: + raise ValueError("padding_color must be 'white' or 'black'") + + img_new[h_start:h_start + img.shape[0], w_start:w_start + img.shape[1], :] = np.copy(img[:, :, :]) + + + return img_new + +# TODO: document where this is from +def do_deskewing(img, amplitude): + height, width = img.shape[:2] + + # Generate sinusoidal wave distortion with reduced amplitude + #amplitude = 8 # 5 # Reduce the amplitude for less curvature + frequency = 300 # Increase frequency to stretch the curve + x_indices = np.tile(np.arange(width), (height, 1)) + y_indices = np.arange(height).reshape(-1, 1) + amplitude * np.sin(2 * np.pi * x_indices / frequency) + + # Convert indices to float32 for remapping + map_x = x_indices.astype(np.float32) + map_y = y_indices.astype(np.float32) + + # Apply the remap to create the curve + curved_image = cv2.remap(img, map_x, map_y, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT) + return curved_image + +# TODO: document where this is from +def do_direction_in_depth(img, direction: str): + height, width = img.shape[:2] + + if direction == 'left': + # Define the original corner points of the image + src_points = np.float32([ + [0, 0], # Top-left corner + [width, 0], # Top-right corner + [0, height], # Bottom-left corner + [width, height] # Bottom-right corner + ]) + + # Define the new corner points for a subtle right-to-left tilt + dst_points = np.float32([ + [2, 13], # Slight inward shift for top-left + [width, 0], # Slight downward shift for top-right + [2, height-13], # Slight inward shift for bottom-left + [width, height] # Slight upward shift for bottom-right + ]) + elif direction == 'right': + # Define the original corner points of the image + src_points = np.float32([ + [0, 0], # Top-left corner + [width, 0], # Top-right corner + [0, height], # Bottom-left corner + [width, height] # Bottom-right corner + ]) + + # Define the new corner points for a subtle right-to-left tilt + dst_points = np.float32([ + [0, 0], # Slight inward shift for top-left + [width, 13], # Slight downward shift for top-right + [0, height], # Slight inward shift for bottom-left + [width, height - 13] # Slight upward shift for bottom-right + ]) + + elif direction == 'up': + # Define the original corner points of the image + src_points = np.float32([ + [0, 0], # Top-left corner + [width, 0], # Top-right corner + [0, height], # Bottom-left corner + [width, height] # Bottom-right corner + ]) + + # Define the new corner points to simulate a tilted perspective + # Make the top part appear closer and the bottom part farther + dst_points = np.float32([ + [50, 0], # Top-left moved inward + [width - 50, 0], # Top-right moved inward + [0, height], # Bottom-left remains the same + [width, height] # Bottom-right remains the same + ]) + elif direction == 'down': + # Define the original corner points of the image + src_points = np.float32([ + [0, 0], # Top-left corner + [width, 0], # Top-right corner + [0, height], # Bottom-left corner + [width, height] # Bottom-right corner + ]) + + # Define the new corner points to simulate a tilted perspective + # Make the top part appear closer and the bottom part farther + dst_points = np.float32([ + [0, 0], # Top-left moved inward + [width, 0], # Top-right moved inward + [50, height], # Bottom-left remains the same + [width - 50, height] # Bottom-right remains the same + ]) + else: + raise ValueError("direction must be 'left', 'right', 'up' or 'down'") + + # Compute the perspective transformation matrix + matrix = cv2.getPerspectiveTransform(src_points, dst_points) + + # Apply the perspective warp + warped_image = cv2.warpPerspective(img, matrix, (width, height)) + return warped_image def return_shuffled_channels(img, channels_order): @@ -25,6 +236,7 @@ def return_shuffled_channels(img, channels_order): img_sh[:,:,2]= img[:,:,channels_order[2]] return img_sh +# TODO: Refactor into one {{{ def return_binary_image_with_red_textlines(img_bin): img_red = np.copy(img_bin) @@ -79,6 +291,8 @@ def return_image_with_red_elements(img, img_bin): img_final[:,:,1][img_bin[:,:,0]==0] = 0 img_final[:,:,2][img_bin[:,:,0]==0] = 255 return img_final + +# }}} def shift_image_and_label(img, label, type_shift): h_n = int(img.shape[0]*1.06) @@ -164,64 +378,6 @@ def return_number_of_total_training_data(path_classes): n_tot = n_tot + len(sub_files) return n_tot - - -def generate_data_from_folder(path_classes, batchsize, height, width, n_classes, list_classes, shuffle=False): - #sub_classes = os.listdir(path_classes) - #n_classes = len(sub_classes) - - all_imgs = [] - labels = [] - #dicts =dict() - #indexer= 0 - for indexer, sub_c in enumerate(list_classes): - sub_files = os.listdir(os.path.join(path_classes,sub_c )) - sub_files = [os.path.join(path_classes,sub_c )+'/' + x for x in sub_files] - #print( os.listdir(os.path.join(path_classes,sub_c )) ) - all_imgs = all_imgs + sub_files - sub_labels = list( np.zeros( len(sub_files) ) +indexer ) - - #print( len(sub_labels) ) - labels = labels + sub_labels - #dicts[sub_c] = indexer - #indexer +=1 - - if shuffle: - ids = np.array(range(len(labels))) - random.shuffle(ids) - labels = np.array(labels)[ids] - all_imgs = np.array(all_imgs)[ids] - - categories = to_categorical(range(n_classes)).astype(np.int16)#[ [1 , 0, 0 , 0 , 0 , 0] , [0 , 1, 0 , 0 , 0 , 0] , [0 , 0, 1 , 0 , 0 , 0] , [0 , 0, 0 , 1 , 0 , 0] , [0 , 0, 0 , 0 , 1 , 0] , [0 , 0, 0 , 0 , 0 , 1] ] - ret_x= np.zeros((batchsize, height,width, 3)).astype(np.uint8) - ret_y= np.zeros((batchsize, n_classes)).astype(float) - batchcount = 0 - while True: - for lab, img in zip(labels, all_imgs): - ###img = cv2.imread(img, 0) - ###img= resize_image (img, height, width) - ###img = img.astype(np.uint16) - ###ret_x[batchcount, :,:,0] = img[:,:] - ###ret_x[batchcount, :,:,1] = img[:,:] - ###ret_x[batchcount, :,:,2] = img[:,:] - - img = cv2.imread(img) - img= resize_image (img, height, width) - img = img.astype(np.uint16) - ret_x[batchcount, :,:,:] = img[:,:,:] - - #print(int(shuffled_labels[i]) ) - #print( categories[int(shuffled_labels[i])] ) - ret_y[batchcount, :] = categories[int(lab)][:] - - batchcount+=1 - - if batchcount>=batchsize: - ret_x = ret_x//255 - yield ret_x, ret_y - ret_x[:] = 0 - ret_y[:] = 0 - batchcount = 0 def do_brightening(img, factor): img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) @@ -241,9 +397,12 @@ def bluring(img_in, kind): img_blur = cv2.medianBlur(img_in, 5) elif kind == 'blur': img_blur = cv2.blur(img_in, (5, 5)) + else: + raise ValueError("kind must be 'gauss', 'median' or 'blur'") return img_blur +# TODO: document where this is from def elastic_transform(image, alpha, sigma, seedj, random_state=None): """Elastic deformation of images as described in [Simard2003]_. .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for @@ -266,6 +425,7 @@ def elastic_transform(image, alpha, sigma, seedj, random_state=None): return distored_image.reshape(image.shape) +# TODO: Use one of the utils/rotate.py functions for this def rotation_90(img): img_rot = np.zeros((img.shape[1], img.shape[0], img.shape[2])) img_rot[:, :, 0] = img[:, :, 0].T @@ -274,6 +434,8 @@ def rotation_90(img): return img_rot +# TODO: document where this is from +# TODO: Use one of the utils/rotate.py functions for this def rotatedRectWithMaxArea(w, h, angle): """ Given a rectangle of size wxh that has been rotated by 'angle' (in @@ -302,6 +464,7 @@ def rotatedRectWithMaxArea(w, h, angle): return wr, hr +# TODO: Use one of the utils/rotate.py functions for this def rotate_max_area(image, rotated, rotated_label, angle): """ image: cv2 image matrix object angle: in degree @@ -315,6 +478,7 @@ def rotate_max_area(image, rotated, rotated_label, angle): x2 = x1 + int(wr) return rotated[y1:y2, x1:x2], rotated_label[y1:y2, x1:x2] +# TODO: Use one of the utils/rotate.py functions for this def rotate_max_area_single_image(image, rotated, angle): """ image: cv2 image matrix object angle: in degree @@ -328,12 +492,14 @@ def rotate_max_area_single_image(image, rotated, angle): x2 = x1 + int(wr) return rotated[y1:y2, x1:x2] +# TODO: Use one of the utils/rotate.py functions for this def rotation_not_90_func(img, label, thetha): rotated = imutils.rotate(img, thetha) rotated_label = imutils.rotate(label, thetha) return rotate_max_area(img, rotated, rotated_label, thetha) +# TODO: Use one of the utils/rotate.py functions for this def rotation_not_90_func_single_image(img, thetha): rotated = imutils.rotate(img, thetha) return rotate_max_area_single_image(img, rotated, thetha) @@ -356,6 +522,7 @@ def color_images(seg, n_classes): return seg_img +# TODO: use resize_image from utils def resize_image(seg_in, input_height, input_width): return cv2.resize(seg_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST) @@ -368,6 +535,7 @@ def get_one_hot(seg, input_height, input_width, n_classes): return seg_f +# TODO: document where this is from def IoU(Yi, y_predi): ## mean Intersection over Union ## Mean IoU = TP/(FN + TP + FP) @@ -386,10 +554,10 @@ def IoU(Yi, y_predi): #print("Mean IoU: {:4.3f}".format(mIoU)) return mIoU -def generate_arrays_from_folder_reading_order(classes_file_dir, modal_dir, batchsize, height, width, n_classes, thetha, augmentation=False): +def generate_arrays_from_folder_reading_order(classes_file_dir, modal_dir, n_batch, height, width, n_classes, thetha, augmentation=False): all_labels_files = os.listdir(classes_file_dir) - ret_x= np.zeros((batchsize, height, width, 3))#.astype(np.int16) - ret_y= np.zeros((batchsize, n_classes)).astype(np.int16) + ret_x= np.zeros((n_batch, height, width, 3))#.astype(np.int16) + ret_y= np.zeros((n_batch, n_classes)).astype(np.int16) batchcount = 0 while True: for i in all_labels_files: @@ -404,10 +572,10 @@ def generate_arrays_from_folder_reading_order(classes_file_dir, modal_dir, batch ret_y[batchcount, :] = label_class batchcount+=1 - if batchcount>=batchsize: + if batchcount>=n_batch: yield ret_x, ret_y - ret_x= np.zeros((batchsize, height, width, 3))#.astype(np.int16) - ret_y= np.zeros((batchsize, n_classes)).astype(np.int16) + ret_x= np.zeros((n_batch, height, width, 3))#.astype(np.int16) + ret_y= np.zeros((n_batch, n_classes)).astype(np.int16) batchcount = 0 if augmentation: @@ -422,10 +590,10 @@ def generate_arrays_from_folder_reading_order(classes_file_dir, modal_dir, batch ret_y[batchcount, :] = label_class batchcount+=1 - if batchcount>=batchsize: + if batchcount>=n_batch: yield ret_x, ret_y - ret_x= np.zeros((batchsize, height, width, 3))#.astype(np.int16) - ret_y= np.zeros((batchsize, n_classes)).astype(np.int16) + ret_x= np.zeros((n_batch, height, width, 3))#.astype(np.int16) + ret_y= np.zeros((n_batch, n_classes)).astype(np.int16) batchcount = 0 def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_classes, task='segmentation'): @@ -467,6 +635,7 @@ def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_c yield img, mask +# TODO: Use otsu_copy from utils def otsu_copy(img): img_r = np.zeros(img.shape) img1 = img[:, :, 0] @@ -481,7 +650,7 @@ def otsu_copy(img): return img_r -def get_patches(dir_img_f, dir_seg_f, img, label, height, width, indexer): +def get_patches(img, label, height, width): if img.shape[0] < height or img.shape[1] < width: img, label = do_padding(img, label, height, width) @@ -517,21 +686,16 @@ def get_patches(dir_img_f, dir_seg_f, img, label, height, width, indexer): img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :] label_patch = label[index_y_d:index_y_u, index_x_d:index_x_u, :] - cv2.imwrite(dir_img_f + '/img_' + str(indexer) + '.png', img_patch) - cv2.imwrite(dir_seg_f + '/img_' + str(indexer) + '.png', label_patch) - indexer += 1 - - return indexer + yield img_patch, label_patch -def do_padding_white(img): - img_org_h = img.shape[0] - img_org_w = img.shape[1] - +def do_padding_with_color(img, padding_color='black'): index_start_h = 4 index_start_w = 4 - img_padded = np.zeros((img.shape[0] + 2*index_start_h, img.shape[1]+ 2*index_start_w, img.shape[2])) + 255 + img_padded = np.zeros((img.shape[0] + 2*index_start_h, img.shape[1]+ 2*index_start_w, img.shape[2])) + if padding_color == 'white': + img_padded += 255 img_padded[index_start_h: index_start_h + img.shape[0], index_start_w: index_start_w + img.shape[1], :] = img[:, :, :] return img_padded.astype(float) @@ -545,20 +709,7 @@ def do_degrading(img, scale): return resize_image(img_res, img_org_h, img_org_w) - -def do_padding_black(img): - img_org_h = img.shape[0] - img_org_w = img.shape[1] - - index_start_h = 4 - index_start_w = 4 - - img_padded = np.zeros((img.shape[0] + 2*index_start_h, img.shape[1] + 2*index_start_w, img.shape[2])) - img_padded[index_start_h: index_start_h + img.shape[0], index_start_w: index_start_w + img.shape[1], :] = img[:, :, :] - - return img_padded.astype(float) - - +# TODO: How is this different from do_padding_black? def do_padding_label(img): img_org_h = img.shape[0] img_org_w = img.shape[1] @@ -595,58 +746,7 @@ def do_padding(img, label, height, width): return img_new,label_new -def get_patches_num_scale(dir_img_f, dir_seg_f, img, label, height, width, indexer, n_patches, scaler): - if img.shape[0] < height or img.shape[1] < width: - img, label = do_padding(img, label, height, width) - - img_h = img.shape[0] - img_w = img.shape[1] - - height_scale = int(height * scaler) - width_scale = int(width * scaler) - - - nxf = img_w / float(width_scale) - nyf = img_h / float(height_scale) - - if nxf > int(nxf): - nxf = int(nxf) + 1 - if nyf > int(nyf): - nyf = int(nyf) + 1 - - nxf = int(nxf) - nyf = int(nyf) - - for i in range(nxf): - for j in range(nyf): - index_x_d = i * width_scale - index_x_u = (i + 1) * width_scale - - index_y_d = j * height_scale - index_y_u = (j + 1) * height_scale - - if index_x_u > img_w: - index_x_u = img_w - index_x_d = img_w - width_scale - if index_y_u > img_h: - index_y_u = img_h - index_y_d = img_h - height_scale - - - img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :] - label_patch = label[index_y_d:index_y_u, index_x_d:index_x_u, :] - - img_patch = resize_image(img_patch, height, width) - label_patch = resize_image(label_patch, height, width) - - cv2.imwrite(dir_img_f + '/img_' + str(indexer) + '.png', img_patch) - cv2.imwrite(dir_seg_f + '/img_' + str(indexer) + '.png', label_patch) - indexer += 1 - - return indexer - - -def get_patches_num_scale_new(dir_img_f, dir_seg_f, img, label, height, width, indexer, scaler): +def get_patches_num_scale_new(img, label, height, width, scaler=1.0): img = resize_image(img, int(img.shape[0] * scaler), int(img.shape[1] * scaler)) label = resize_image(label, int(label.shape[0] * scaler), int(label.shape[1] * scaler)) @@ -688,20 +788,17 @@ def get_patches_num_scale_new(dir_img_f, dir_seg_f, img, label, height, width, i img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :] label_patch = label[index_y_d:index_y_u, index_x_d:index_x_u, :] - cv2.imwrite(dir_img_f + '/img_' + str(indexer) + '.png', img_patch) - cv2.imwrite(dir_seg_f + '/img_' + str(indexer) + '.png', label_patch) - indexer += 1 - - return indexer + yield img_patch, label_patch +# TODO: refactor to combine with data_gen_ocr def preprocess_imgs(config, imgs_list, - segs_list, + labs_list, dir_img, - dir_seg, + dir_lab, dir_flow_imgs, - dir_flow_labels, + dir_flow_lbls, logger=None, **kwargs, ): @@ -720,32 +817,49 @@ def preprocess_imgs(config, # override keys from call config.update(kwargs) + seed = random.random() + random.shuffle(imgs_list, random=lambda: seed) + random.shuffle(labs_list, random=lambda: seed) + + # labs_list not used because stem matching more robust indexer = 0 - for im, seg_i in tqdm(zip(imgs_list, segs_list)): - img = cv2.imread(os.path.join(dir_img, im)) - img_name = os.path.splitext(im)[0] + for img, lab in tqdm(zip(imgs_list, labs_list)): + img = cv2.imread(os.path.join(dir_img, img)) + img_name = os.path.splitext(img)[0] if config['task'] in ["segmentation", "binarization"]: - lab = cv2.imread(os.path.join(dir_seg, img_name + '.png')) + # assert lab == img_name + '.png' + lab = cv2.imread(os.path.join(dir_lab, img_name + '.png')) elif config['task'] == "enhancement": - lab = cv2.imread(os.path.join(dir_seg, im)) + lab = cv2.imread(os.path.join(dir_lab, img)) + elif config['task'] == "cnn-rnn-ocr": + # assert lab == 'img_name + '.txt' + with open(os.path.join(dir_lab, img_name + '.txt'), 'r') as f: + lab = f.read().split('\n')[0] else: lab = None try: - indexer = preprocess_img(indexer, img, img_name, lab, - dir_flow_imgs, - dir_flow_labels, - **config) - + if config['task'] == "cnn-rnn-ocr": + yield from preprocess_img_ocr(img, img_name, lab, + **config) + continue + for img, lab in preprocess_img(img, img_name, lab, + **config): + cv2.imwrite(os.path.join(dir_flow_imgs, '/img_%d.png' % indexer), + resize_image(img, + config['input_height'], + config['input_width'])) + cv2.imwrite(os.path.join(dir_flow_lbls, '/img_%d.png' % indexer), + resize_image(lab, + config['input_height'], + config['input_width'])) + indexer += 1 except: logger.exception("skipping image %s", img_name) -def preprocess_img(indexer, - img, +def preprocess_img(img, img_name, lab, - dir_flow_train_imgs, - dir_flow_train_labels, input_height=None, input_width=None, augmentation=False, @@ -785,128 +899,39 @@ def preprocess_img(indexer, **kwargs, ): if not patches: - cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', - resize_image(img, - input_height, - input_width)) - cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', - resize_image(lab, - input_height, - input_width)) - indexer += 1 + yield img, lab if augmentation: if flip_aug: for f_i in flip_index: - cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', - resize_image(cv2.flip(img, f_i), - input_height, - input_width)) - cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', - resize_image(cv2.flip(lab, f_i), - input_height, - input_width)) - indexer += 1 + yield cv2.flip(img, f_i), cv2.flip(lab, f_i) if blur_aug: for blur_i in blur_k: - cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', - (resize_image(bluring(img, blur_i), - input_height, - input_width))) - cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', - resize_image(lab, - input_height, - input_width)) - indexer += 1 + yield bluring(img, blur_i), lab if brightening: for factor in brightness: - cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', - (resize_image(do_brightening(img, factor), - input_height, - input_width))) - cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', - resize_image(lab, - input_height, - input_width)) - indexer += 1 + yield do_brightening(img, factor), lab if binarization: if dir_img_bin: img_bin_corr = cv2.imread(dir_img_bin + '/' + img_name+'.png') - cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', - resize_image(img_bin_corr, - input_height, - input_width)) else: - cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', - resize_image(otsu_copy(img), - input_height, - input_width)) - cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', - resize_image(lab, - input_height, - input_width)) - indexer += 1 + img_bin_corr = otsu_copy(img) + yield img_bin_corr, lab if degrading: for degrade_scale_ind in degrade_scales: - cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', - (resize_image(do_degrading(img, degrade_scale_ind), - input_height, - input_width))) - cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', - resize_image(lab, - input_height, - input_width)) - indexer += 1 + yield do_degrading(img, degrade_scale_ind), lab if rotation_not_90: for thetha_i in thetha: - img_max_rotated, label_max_rotated = \ - rotation_not_90_func(img, lab, thetha_i) - cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', - resize_image(img_max_rotated, - input_height, - input_width)) - cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', - resize_image(label_max_rotated, - input_height, - input_width)) - indexer += 1 + yield rotation_not_90_func(img, lab, thetha_i) if channels_shuffling: for shuffle_index in shuffle_indexes: - cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', - (resize_image(return_shuffled_channels(img, shuffle_index), - input_height, - input_width))) - cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', - resize_image(lab, - input_height, - input_width)) - indexer += 1 + yield return_shuffled_channels(img, shuffle_index), lab if scaling: for sc_ind in scales: - img_scaled, label_scaled = \ - scale_image_for_no_patch(img, lab, sc_ind) - cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', - resize_image(img_scaled, - input_height, - input_width)) - cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', - resize_image(label_scaled, - input_height, - input_width)) - indexer += 1 + yield scale_image_for_no_patch(img, lab, sc_ind) if shifting: shift_types = ['xpos', 'xmin', 'ypos', 'ymin', 'xypos', 'xymin'] for st_ind in shift_types: - img_shifted, label_shifted = \ - shift_image_and_label(img, lab, st_ind) - cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', - resize_image(img_shifted, - input_height, - input_width)) - cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', - resize_image(label_shifted, - input_height, - input_width)) - indexer += 1 + yield shift_image_and_label(img, lab, st_ind) if adding_rgb_background: img_bin_corr = cv2.imread(dir_img_bin + '/' + img_name+'.png') for i_n in range(number_of_backgrounds_per_image): @@ -916,15 +941,7 @@ def preprocess_img(indexer, img_with_overlayed_background = \ return_binary_image_with_given_rgb_background( img_bin_corr, img_rgb_background_chosen) - cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', - resize_image(img_with_overlayed_background, - input_height, - input_width)) - cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', - resize_image(lab, - input_height, - input_width)) - indexer += 1 + yield img_with_overlayed_background, lab if adding_rgb_foreground: img_bin_corr = cv2.imread(dir_img_bin + '/' + img_name+'.png') for i_n in range(number_of_backgrounds_per_image): @@ -937,67 +954,37 @@ def preprocess_img(indexer, img_with_overlayed_background = \ return_binary_image_with_given_rgb_background_and_given_foreground_rgb( img_bin_corr, img_rgb_background_chosen, foreground_rgb_chosen) - cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', - resize_image(img_with_overlayed_background, - input_height, - input_width)) - cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', - resize_image(lab, - input_height, - input_width)) - indexer += 1 + yield img_with_overlayed_background, lab if add_red_textlines: img_bin_corr = cv2.imread(dir_img_bin + '/' + img_name+'.png') - img_red_context = \ - return_image_with_red_elements(img, img_bin_corr) - cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', - resize_image(img_red_context, - input_height, - input_width)) - cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', - resize_image(lab, - input_height, - input_width)) - indexer += 1 + yield return_image_with_red_elements(img, img_bin_corr), lab else: - indexer = get_patches(dir_flow_train_imgs, - dir_flow_train_labels, - img, - lab, - input_height, - input_width, - indexer=indexer) + yield from get_patches(img, + lab, + input_height, + input_width) if augmentation: if rotation: - indexer = get_patches(dir_flow_train_imgs, - dir_flow_train_labels, - rotation_90(img), - rotation_90(lab), - input_height, - input_width, - indexer=indexer) + yield from get_patches(rotation_90(img), + rotation_90(lab), + input_height, + input_width) if rotation_not_90: for thetha_i in thetha: img_max_rotated, label_max_rotated = \ rotation_not_90_func(img, lab, thetha_i) - indexer = get_patches(dir_flow_train_imgs, - dir_flow_train_labels, - img_max_rotated, - label_max_rotated, - input_height, - input_width, - indexer=indexer) + yield from get_patches(img_max_rotated, + label_max_rotated, + input_height, + input_width) if channels_shuffling: for shuffle_index in shuffle_indexes: img_shuffled = \ return_shuffled_channels(img, shuffle_index), - indexer = get_patches(dir_flow_train_imgs, - dir_flow_train_labels, - img_shuffled, - lab, - input_height, - input_width, - indexer=indexer) + yield from get_patches(img_shuffled, + lab, + input_height, + input_width) if adding_rgb_background: img_bin_corr = cv2.imread(dir_img_bin + '/' + img_name+'.png') for i_n in range(number_of_backgrounds_per_image): @@ -1007,13 +994,10 @@ def preprocess_img(indexer, img_with_overlayed_background = \ return_binary_image_with_given_rgb_background( img_bin_corr, img_rgb_background_chosen) - indexer = get_patches(dir_flow_train_imgs, - dir_flow_train_labels, - img_with_overlayed_background, - lab, - input_height, - input_width, - indexer=indexer) + yield from get_patches(img_with_overlayed_background, + lab, + input_height, + input_width) if adding_rgb_foreground: img_bin_corr = cv2.imread(dir_img_bin + '/' + img_name+'.png') for i_n in range(number_of_backgrounds_per_image): @@ -1026,155 +1010,280 @@ def preprocess_img(indexer, img_with_overlayed_background = \ return_binary_image_with_given_rgb_background_and_given_foreground_rgb( img_bin_corr, img_rgb_background_chosen, foreground_rgb_chosen) - indexer = get_patches(dir_flow_train_imgs, - dir_flow_train_labels, - img_with_overlayed_background, - lab, - input_height, - input_width, - indexer=indexer) + yield from get_patches(img_with_overlayed_background, + lab, + input_height, + input_width) if add_red_textlines: - img_bin_corr = cv2.imread(dir_img_bin + '/' + img_name+'.png') + img_bin_corr = cv2.imread(os.path.join(dir_img_bin, img_name + '.png')) img_red_context = \ return_image_with_red_elements(img, img_bin_corr) - indexer = get_patches(dir_flow_train_imgs, - dir_flow_train_labels, - img_red_context, - lab, - input_height, - input_width, - indexer=indexer) + yield from get_patches(img_red_context, + lab, + input_height, + input_width) if flip_aug: for f_i in flip_index: - indexer = get_patches(dir_flow_train_imgs, - dir_flow_train_labels, - cv2.flip(img, f_i), - cv2.flip(lab, f_i), - input_height, - input_width, - indexer=indexer) + yield from get_patches(cv2.flip(img, f_i), + cv2.flip(lab, f_i), + input_height, + input_width) if blur_aug: for blur_i in blur_k: - indexer = get_patches(dir_flow_train_imgs, - dir_flow_train_labels, - bluring(img, blur_i), - lab, - input_height, - input_width, - indexer=indexer) + yield from get_patches(bluring(img, blur_i), + lab, + input_height, + input_width) if padding_black: - indexer = get_patches(dir_flow_train_imgs, - dir_flow_train_labels, - do_padding_black(img), - do_padding_label(lab), - input_height, - input_width, - indexer=indexer) + yield from get_patches(do_padding_black(img), + do_padding_label(lab), + input_height, + input_width) if padding_white: - indexer = get_patches(dir_flow_train_imgs, - dir_flow_train_labels, - do_padding_white(img), - do_padding_label(lab), - input_height, - input_width, - indexer=indexer) + yield from get_patches(do_padding_white(img), + do_padding_label(lab), + input_height, + input_width) if brightening: for factor in brightness: - indexer = get_patches(dir_flow_train_imgs, - dir_flow_train_labels, - do_brightening(img, factor), - lab, - input_height, - input_width, - indexer=indexer) + yield from get_patches(do_brightening(img, factor), + lab, + input_height, + input_width) if scaling: for sc_ind in scales: - indexer = get_patches_num_scale_new( - dir_flow_train_imgs, - dir_flow_train_labels, - img , - lab, - input_height, - input_width, - indexer=indexer, - scaler=sc_ind) + yield from get_patches_num_scale_new(img, + lab, + input_height, + input_width, + scaler=sc_ind) if degrading: for degrade_scale_ind in degrade_scales: img_deg = \ do_degrading(img, degrade_scale_ind), - indexer = get_patches(dir_flow_train_imgs, - dir_flow_train_labels, - img_deg, - lab, - input_height, - input_width, - indexer=indexer) + yield from get_patches(img_deg, + lab, + input_height, + input_width) if binarization: if dir_img_bin: - img_bin_corr = cv2.imread(dir_img_bin + '/' + img_name+'.png') - indexer = get_patches(dir_flow_train_imgs, - dir_flow_train_labels, - img_bin_corr, - lab, - input_height, - input_width, - indexer=indexer) + img_bin_corr = cv2.imread(os.path.join(dir_img_bin, img_name + '.png')) else: - indexer = get_patches(dir_flow_train_imgs, - dir_flow_train_labels, - otsu_copy(img), - lab, - input_height, - input_width, - indexer=indexer) + img_bin_corr = otsu_copy(img) + yield from get_patches(img_bin_corr, + lab, + input_height, + input_width) if scaling_brightness: for sc_ind in scales: for factor in brightness: img_bright = do_brightening(img, factor) - indexer = get_patches_num_scale_new( - dir_flow_train_imgs, - dir_flow_train_labels, - img_bright, - lab, - input_height, - input_width, - indexer=indexer, - scaler=sc_ind) + yield from get_patches_num_scale_new(img_bright, + lab, + input_height, + input_width, + scaler=sc_ind) if scaling_bluring: for sc_ind in scales: for blur_i in blur_k: img_blur = bluring(img, blur_i), - indexer = get_patches_num_scale_new( - dir_flow_train_imgs, - dir_flow_train_labels, - img_blur, - lab, - input_height, - input_width, - indexer=indexer, - scaler=sc_ind) + yield from get_patches_num_scale_new(img_blur, + lab, + input_height, + input_width, + scaler=sc_ind) if scaling_binarization: for sc_ind in scales: img_bin = otsu_copy(img), - indexer = get_patches_num_scale_new( - dir_flow_train_imgs, - dir_flow_train_labels, - img_bin, - lab, - input_height, - input_width, - indexer=indexer, - scaler=sc_ind) + yield from get_patches_num_scale_new(img_bin, + lab, + input_height, + input_width, + scaler=sc_ind) if scaling_flip: for sc_ind in scales: for f_i in flip_index: - indexer = get_patches_num_scale_new( - dir_flow_train_imgs, - dir_flow_train_labels, - cv2.flip(img, f_i), - cv2.flip(lab, f_i), - input_height, - input_width, - indexer=indexer, - scaler=sc_ind) - return indexer + yield from get_patches_num_scale_new(cv2.flip(img, f_i), + cv2.flip(lab, f_i), + input_height, + input_width, + scaler=sc_ind) + +def preprocess_img_ocr( + img, + img_name, + lab, + char_to_num=None, + padding_token=-1, + max_len=500, + n_batch=1, + input_height=None, + input_width=None, + augmentation=False, + color_padding_rotation=None, + thetha_padd=None, + padd_colors=None, + rotation_not_90=None, + thetha=None, + padding_white=None, + white_padds=None, + degrading=False, + bin_deg=None, + degrade_scales=None, + blur_aug=False, + blur_k=None, + brightening=False, + brightness=None, + binarization=False, + image_inversion=False, + channels_shuffling=False, + shuffle_indexes=None, + white_noise_strap=False, + textline_skewing=False, + textline_skewing_bin=False, + skewing_amplitudes=None, + textline_left_in_depth=False, + textline_left_in_depth_bin=False, + textline_right_in_depth=False, + textline_right_in_depth_bin=False, + textline_up_in_depth=False, + textline_up_in_depth_bin=False, + textline_down_in_depth=False, + textline_down_in_depth_bin=False, + pepper_aug=False, + pepper_bin_aug=False, + pepper_indexes=None, + dir_img_bin=None, + add_red_textlines=False, + adding_rgb_background=False, + dir_rgb_backgrounds=None, + adding_rgb_foreground=False, + dir_rgb_foregrounds=None, + number_of_backgrounds_per_image=None, + list_all_possible_background_images=None, + list_all_possible_foreground_rgbs=None, +): + def scale_image(img): + return scale_padd_image_for_ocr(img, input_height, input_width).astype(np.float32) / 255. + #lab = vectorize_label(lab, char_to_num, padding_token, max_len) + # now padded at Dataset.padded_batch + lab = char_to_num(tf.strings.unicode_split(label, input_encoding="UTF-8")) + yield scale_image(img), lab + #to_yield = {"image": ret_x, "label": ret_y} + + if dir_img_bin: + img_bin_corr = cv2.imread(os.path.join(dir_img_bin, img_name + '.png')) + else: + img_bin_corr = None + + if not augmentation: + return + + if color_padding_rotation: + for thetha_ind in thetha_padd: + for padd_col in padd_colors: + img_pad = do_padding_for_ocr(img, 1.2, padd_col) + img_rot = rotation_not_90_func_single_image(img_pad, thetha_ind) + yield scale_image(img_rot), lab + if rotation_not_90: + for thetha_ind in thetha: + img_rot = rotation_not_90_func_single_image(img, thetha_ind) + yield scale_image(img_rot), lab + if blur_aug: + for blur_type in blur_k: + img_blur = bluring(img, blur_type) + yield scale_image(img_blur), lab + if degrading: + for deg_scale_ind in degrade_scales: + img_deg = do_degrading(img, deg_scale_ind) + yield scale_image(img_deg), lab + if bin_deg: + for deg_scale_ind in degrade_scales: + img_deg = do_degrading(img_bin_corr, deg_scale_ind) + yield scale_image(img_deg), lab + if brightening: + for bright_scale_ind in brightness: + img_bright = do_brightening(img, bright_scale_ind) + yield scale_image(img_bright), lab + if padding_white: + for padding_size in white_padds: + for padd_col in padd_colors: + img_pad = do_padding_for_ocr(img, padding_size, padd_col) + yield scale_image(img_pad), lab + if adding_rgb_foreground: + for i_n in range(number_of_backgrounds_per_image): + background_image_chosen_name = random.choice(list_all_possible_background_images) + foreground_rgb_chosen_name = random.choice(list_all_possible_foreground_rgbs) + + img_rgb_background_chosen = \ + cv2.imread(dir_rgb_backgrounds + '/' + background_image_chosen_name) + foreground_rgb_chosen = \ + np.load(dir_rgb_foregrounds + '/' + foreground_rgb_chosen_name) + + img_fg = \ + return_binary_image_with_given_rgb_background_and_given_foreground_rgb( + img_bin_corr, img_rgb_background_chosen, foreground_rgb_chosen) + yield scale_image(img_fg), lab + if adding_rgb_background: + for i_n in range(number_of_backgrounds_per_image): + background_image_chosen_name = random.choice(list_all_possible_background_images) + img_rgb_background_chosen = \ + cv2.imread(dir_rgb_backgrounds + '/' + background_image_chosen_name) + img_bg = \ + return_binary_image_with_given_rgb_background(img_bin_corr, img_rgb_background_chosen) + yield scale_image(img_bg), lab + if binarization: + yield scale_image(img_bin_corr), lab + if image_inversion: + img_inv = invert_image(img_bin_corr) + yield scale_image(img_inv), lab + if channels_shuffling: + for shuffle_index in shuffle_indexes: + img_shuf = return_shuffled_channels(img, shuffle_index) + yield scale_image(img_shuf), lab + if add_red_textlines: + img_red = return_image_with_red_elements(img, img_bin_corr) + yield scale_image(img_red), lab + if white_noise_strap: + img_noisy = return_image_with_strapped_white_noises(img) + yield scale_image(img_noisy), lab + if textline_skewing: + for des_scale_ind in skewing_amplitudes: + img_rot = do_deskewing(img, des_scale_ind) + yield scale_image(img_rot), lab + if textline_skewing_bin: + for des_scale_ind in skewing_amplitudes: + img_rot = do_deskewing(img_bin_corr, des_scale_ind) + yield scale_image(img_rot), lab + if textline_left_in_depth: + img_warp = do_direction_in_depth(img, 'left') + yield scale_image(img_warp), lab + if textline_left_in_depth_bin: + img_warp = do_direction_in_depth(img_bin_corr, 'left') + yield scale_image(img_warp), lab + if textline_right_in_depth: + img_warp = do_direction_in_depth(img, 'right') + yield scale_image(img_warp), lab + if textline_right_in_depth_bin: + img_warp = do_direction_in_depth(img_bin_corr, 'right') + yield scale_image(img_warp), lab + if textline_up_in_depth: + img_warp = do_direction_in_depth(img, 'up') + yield scale_image(img_warp), lab + if textline_up_in_depth_bin: + img_warp = do_direction_in_depth(img_bin_corr, 'up') + yield scale_image(img_warp), lab + if textline_down_in_depth: + img_warp = do_direction_in_depth(img, 'down') + yield scale_image(img_warp), lab + if textline_down_in_depth_bin: + img_warp = do_direction_in_depth(img_bin_corr, 'down') + yield scale_image(img_warp), lab + if pepper_aug: + for pepper_ind in pepper_indexes: + img_noisy = add_salt_and_pepper_noise(img, pepper_ind, pepper_ind) + yield scale_image(img_noisy), lab + if pepper_bin_aug: + for pepper_ind in pepper_indexes: + img_noisy = add_salt_and_pepper_noise(img_bin_corr, pepper_ind, pepper_ind) + yield scale_image(img_noisy), lab diff --git a/src/eynollah/training/weights_ensembling.py b/src/eynollah/training/weights_ensembling.py new file mode 100644 index 0000000..6dce7fd --- /dev/null +++ b/src/eynollah/training/weights_ensembling.py @@ -0,0 +1,136 @@ +import sys +from glob import glob +from os import environ, devnull +from os.path import join +from warnings import catch_warnings, simplefilter +import os + +import numpy as np +from PIL import Image +import cv2 +environ['TF_CPP_MIN_LOG_LEVEL'] = '3' +stderr = sys.stderr +sys.stderr = open(devnull, 'w') +import tensorflow as tf +from tensorflow.keras.models import load_model +from tensorflow.python.keras import backend as tensorflow_backend +sys.stderr = stderr +from tensorflow.keras import layers +import tensorflow.keras.losses +from tensorflow.keras.layers import * +import click +import logging + + +class Patches(layers.Layer): + def __init__(self, patch_size_x, patch_size_y): + super(Patches, self).__init__() + self.patch_size_x = patch_size_x + self.patch_size_y = patch_size_y + + def call(self, images): + #print(tf.shape(images)[1],'images') + #print(self.patch_size,'self.patch_size') + batch_size = tf.shape(images)[0] + patches = tf.image.extract_patches( + images=images, + sizes=[1, self.patch_size_y, self.patch_size_x, 1], + strides=[1, self.patch_size_y, self.patch_size_x, 1], + rates=[1, 1, 1, 1], + padding="VALID", + ) + #patch_dims = patches.shape[-1] + patch_dims = tf.shape(patches)[-1] + patches = tf.reshape(patches, [batch_size, -1, patch_dims]) + return patches + def get_config(self): + + config = super().get_config().copy() + config.update({ + 'patch_size_x': self.patch_size_x, + 'patch_size_y': self.patch_size_y, + }) + return config + + + +class PatchEncoder(layers.Layer): + def __init__(self, **kwargs): + super(PatchEncoder, self).__init__() + self.num_patches = num_patches + self.projection = layers.Dense(units=projection_dim) + self.position_embedding = layers.Embedding( + input_dim=num_patches, output_dim=projection_dim + ) + + def call(self, patch): + positions = tf.range(start=0, limit=self.num_patches, delta=1) + encoded = self.projection(patch) + self.position_embedding(positions) + return encoded + def get_config(self): + + config = super().get_config().copy() + config.update({ + 'num_patches': self.num_patches, + 'projection': self.projection, + 'position_embedding': self.position_embedding, + }) + return config + + +def start_new_session(): + ###config = tf.compat.v1.ConfigProto() + ###config.gpu_options.allow_growth = True + + ###self.session = tf.compat.v1.Session(config=config) # tf.InteractiveSession() + ###tensorflow_backend.set_session(self.session) + + config = tf.compat.v1.ConfigProto() + config.gpu_options.allow_growth = True + + session = tf.compat.v1.Session(config=config) # tf.InteractiveSession() + tensorflow_backend.set_session(session) + return session + +def run_ensembling(dir_models, out): + ls_models = os.listdir(dir_models) + + + weights=[] + + for model_name in ls_models: + model = load_model(os.path.join(dir_models,model_name) , compile=False, custom_objects={'PatchEncoder':PatchEncoder, 'Patches': Patches}) + weights.append(model.get_weights()) + + new_weights = list() + + for weights_list_tuple in zip(*weights): + new_weights.append( + [np.array(weights_).mean(axis=0)\ + for weights_ in zip(*weights_list_tuple)]) + + + + new_weights = [np.array(x) for x in new_weights] + + model.set_weights(new_weights) + model.save(out) + os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config.json ")+out) + +@click.command() +@click.option( + "--dir_models", + "-dm", + help="directory of models", + type=click.Path(exists=True, file_okay=False), +) +@click.option( + "--out", + "-o", + help="output directory where ensembled model will be written.", + type=click.Path(exists=False, file_okay=False), +) + +def main(dir_models, out): + run_ensembling(dir_models, out) + diff --git a/src/eynollah/utils/__init__.py b/src/eynollah/utils/__init__.py index b839385..38a50be 100644 --- a/src/eynollah/utils/__init__.py +++ b/src/eynollah/utils/__init__.py @@ -241,7 +241,14 @@ def find_num_col_deskew(regions_without_separators, sigma_, multiplier=3.8): z = gaussian_filter1d(regions_without_separators_0, sigma_) return np.std(z) -def find_num_col(regions_without_separators, num_col_classifier, tables, multiplier=3.8, unbalanced=False, vertical_separators=None): +def find_num_col( + regions_without_separators, + num_col_classifier, + tables, + multiplier=3.8, + unbalanced=False, + vertical_separators=None +): if not regions_without_separators.any(): return 0, [] if vertical_separators is None: diff --git a/src/eynollah/utils/contour.py b/src/eynollah/utils/contour.py index c8caca9..3a67c65 100644 --- a/src/eynollah/utils/contour.py +++ b/src/eynollah/utils/contour.py @@ -356,7 +356,7 @@ def join_polygons(polygons: Sequence[Polygon], scale=20) -> Polygon: assert jointp.geom_type == 'Polygon', jointp.wkt # follow-up calculations will necessarily be integer; # so anticipate rounding here and then ensure validity - jointp2 = set_precision(jointp, 1.0) + jointp2 = set_precision(jointp, 1.0, mode="keep_collapsed") if jointp2.geom_type != 'Polygon' or not jointp2.is_valid: jointp2 = Polygon(np.round(jointp.exterior.coords)) jointp2 = make_valid(jointp2) diff --git a/src/eynollah/utils/drop_capitals.py b/src/eynollah/utils/drop_capitals.py index 9f82fac..228a6d9 100644 --- a/src/eynollah/utils/drop_capitals.py +++ b/src/eynollah/utils/drop_capitals.py @@ -19,7 +19,6 @@ def adhere_drop_capital_region_into_corresponding_textline( all_found_textline_polygons_h, kernel=None, curved_line=False, - textline_light=False, ): # print(np.shape(all_found_textline_polygons),np.shape(all_found_textline_polygons[3]),'all_found_textline_polygonsshape') # print(all_found_textline_polygons[3]) @@ -79,7 +78,7 @@ def adhere_drop_capital_region_into_corresponding_textline( # region_with_intersected_drop=region_with_intersected_drop/3 region_with_intersected_drop = region_with_intersected_drop.astype(np.uint8) # print(np.unique(img_con_all_copy[:,:,0])) - if curved_line or textline_light: + if curved_line: if len(region_with_intersected_drop) > 1: sum_pixels_of_intersection = [] diff --git a/src/eynollah/utils/font.py b/src/eynollah/utils/font.py new file mode 100644 index 0000000..939933e --- /dev/null +++ b/src/eynollah/utils/font.py @@ -0,0 +1,16 @@ + +# cannot use importlib.resources until we move to 3.9+ forimportlib.resources.files +import sys +from PIL import ImageFont + +if sys.version_info < (3, 10): + import importlib_resources +else: + import importlib.resources as importlib_resources + + +def get_font(): + #font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists! + font = importlib_resources.files(__package__) / "../Charis-Regular.ttf" + with importlib_resources.as_file(font) as font: + return ImageFont.truetype(font=font, size=40) diff --git a/src/eynollah/utils/marginals.py b/src/eynollah/utils/marginals.py index eaf0048..9f76fb7 100644 --- a/src/eynollah/utils/marginals.py +++ b/src/eynollah/utils/marginals.py @@ -6,7 +6,7 @@ from .contour import find_new_features_of_contours, return_contours_of_intereste from .resize import resize_image from .rotate import rotate_image -def get_marginals(text_with_lines, text_regions, num_col, slope_deskew, light_version=False, kernel=None): +def get_marginals(text_with_lines, text_regions, num_col, slope_deskew, kernel=None): mask_marginals=np.zeros((text_with_lines.shape[0],text_with_lines.shape[1])) mask_marginals=mask_marginals.astype(np.uint8) @@ -27,9 +27,8 @@ def get_marginals(text_with_lines, text_regions, num_col, slope_deskew, light_ve text_with_lines=resize_image(text_with_lines,text_with_lines_eroded.shape[0],text_with_lines_eroded.shape[1]) - if light_version: - kernel_hor = np.ones((1, 5), dtype=np.uint8) - text_with_lines = cv2.erode(text_with_lines,kernel_hor,iterations=6) + kernel_hor = np.ones((1, 5), dtype=np.uint8) + text_with_lines = cv2.erode(text_with_lines,kernel_hor,iterations=6) text_with_lines_y=text_with_lines.sum(axis=0) text_with_lines_y_eroded=text_with_lines_eroded.sum(axis=0) @@ -43,10 +42,7 @@ def get_marginals(text_with_lines, text_regions, num_col, slope_deskew, light_ve elif thickness_along_y_percent>=30 and thickness_along_y_percent<50: min_textline_thickness=20 else: - if light_version: - min_textline_thickness=45 - else: - min_textline_thickness=40 + min_textline_thickness=45 if thickness_along_y_percent>=14: @@ -128,92 +124,39 @@ def get_marginals(text_with_lines, text_regions, num_col, slope_deskew, light_ve if max_point_of_right_marginal>=text_regions.shape[1]: max_point_of_right_marginal=text_regions.shape[1]-1 - if light_version: - text_regions_org = np.copy(text_regions) - text_regions[text_regions[:,:]==1]=4 - - pixel_img=4 - min_area_text=0.00001 - - polygon_mask_marginals_rotated = return_contours_of_interested_region(mask_marginals,1,min_area_text) - - polygon_mask_marginals_rotated = polygon_mask_marginals_rotated[0] + text_regions_org = np.copy(text_regions) + text_regions[text_regions[:,:]==1]=4 + + pixel_img=4 + min_area_text=0.00001 + + polygon_mask_marginals_rotated = return_contours_of_interested_region(mask_marginals,1,min_area_text) + + polygon_mask_marginals_rotated = polygon_mask_marginals_rotated[0] - polygons_of_marginals=return_contours_of_interested_region(text_regions,pixel_img,min_area_text) + polygons_of_marginals=return_contours_of_interested_region(text_regions,pixel_img,min_area_text) - cx_text_only,cy_text_only ,x_min_text_only,x_max_text_only, y_min_text_only ,y_max_text_only,y_cor_x_min_main=find_new_features_of_contours(polygons_of_marginals) + cx_text_only,cy_text_only ,x_min_text_only,x_max_text_only, y_min_text_only ,y_max_text_only,y_cor_x_min_main=find_new_features_of_contours(polygons_of_marginals) - text_regions[(text_regions[:,:]==4)]=1 + text_regions[(text_regions[:,:]==4)]=1 - marginlas_should_be_main_text=[] + marginlas_should_be_main_text=[] - x_min_marginals_left=[] - x_min_marginals_right=[] + x_min_marginals_left=[] + x_min_marginals_right=[] - for i in range(len(cx_text_only)): - results = cv2.pointPolygonTest(polygon_mask_marginals_rotated, (cx_text_only[i], cy_text_only[i]), False) + for i in range(len(cx_text_only)): + results = cv2.pointPolygonTest(polygon_mask_marginals_rotated, (cx_text_only[i], cy_text_only[i]), False) - if results == -1: - marginlas_should_be_main_text.append(polygons_of_marginals[i]) + if results == -1: + marginlas_should_be_main_text.append(polygons_of_marginals[i]) - text_regions_org=cv2.fillPoly(text_regions_org, pts =marginlas_should_be_main_text, color=(4,4)) - text_regions = np.copy(text_regions_org) + text_regions_org=cv2.fillPoly(text_regions_org, pts =marginlas_should_be_main_text, color=(4,4)) + text_regions = np.copy(text_regions_org) - else: - - text_regions[(mask_marginals_rotated[:,:]!=1) & (text_regions[:,:]==1)]=4 - - pixel_img=4 - min_area_text=0.00001 - - polygons_of_marginals=return_contours_of_interested_region(text_regions,pixel_img,min_area_text) - - cx_text_only,cy_text_only ,x_min_text_only,x_max_text_only, y_min_text_only ,y_max_text_only,y_cor_x_min_main=find_new_features_of_contours(polygons_of_marginals) - - text_regions[(text_regions[:,:]==4)]=1 - - marginlas_should_be_main_text=[] - - x_min_marginals_left=[] - x_min_marginals_right=[] - - for i in range(len(cx_text_only)): - x_width_mar=abs(x_min_text_only[i]-x_max_text_only[i]) - y_height_mar=abs(y_min_text_only[i]-y_max_text_only[i]) - - if x_width_mar>16 and y_height_mar/x_width_mar<18: - marginlas_should_be_main_text.append(polygons_of_marginals[i]) - if x_min_text_only[i]<(mid_point-one_third_left): - x_min_marginals_left_new=x_min_text_only[i] - if len(x_min_marginals_left)==0: - x_min_marginals_left.append(x_min_marginals_left_new) - else: - x_min_marginals_left[0]=min(x_min_marginals_left[0],x_min_marginals_left_new) - else: - x_min_marginals_right_new=x_min_text_only[i] - if len(x_min_marginals_right)==0: - x_min_marginals_right.append(x_min_marginals_right_new) - else: - x_min_marginals_right[0]=min(x_min_marginals_right[0],x_min_marginals_right_new) - - if len(x_min_marginals_left)==0: - x_min_marginals_left=[0] - if len(x_min_marginals_right)==0: - x_min_marginals_right=[text_regions.shape[1]-1] - - - text_regions=cv2.fillPoly(text_regions, pts =marginlas_should_be_main_text, color=(4,4)) - - - #text_regions[:,:int(x_min_marginals_left[0])][text_regions[:,:int(x_min_marginals_left[0])]==1]=0 - #text_regions[:,int(x_min_marginals_right[0]):][text_regions[:,int(x_min_marginals_right[0]):]==1]=0 - - - text_regions[:,:int(min_point_of_left_marginal)][text_regions[:,:int(min_point_of_left_marginal)]==1]=0 - text_regions[:,int(max_point_of_right_marginal):][text_regions[:,int(max_point_of_right_marginal):]==1]=0 ###text_regions[:,0:point_left][text_regions[:,0:point_left]==1]=4 diff --git a/src/eynollah/utils/separate_lines.py b/src/eynollah/utils/separate_lines.py index 830dd8d..869cd23 100644 --- a/src/eynollah/utils/separate_lines.py +++ b/src/eynollah/utils/separate_lines.py @@ -5,8 +5,6 @@ import numpy as np import cv2 from scipy.signal import find_peaks from scipy.ndimage import gaussian_filter1d -from multiprocessing import Process, Queue, cpu_count -from multiprocessing import Pool from .rotate import rotate_image from .resize import resize_image from .contour import ( @@ -20,9 +18,7 @@ from .contour import ( from .shm import share_ndarray, wrap_ndarray_shared from . import ( find_num_col_deskew, - crop_image_inside_box, box2rect, - box2slice, ) def dedup_separate_lines(img_patch, contour_text_interest, thetha, axis): @@ -1593,65 +1589,6 @@ def get_smallest_skew(img, sigma_des, angles, logger=None, plotter=None, map=map var = 0 return angle, var -@wrap_ndarray_shared(kw='textline_mask_tot_ea') -def do_work_of_slopes_new( - box_text, contour, contour_par, - textline_mask_tot_ea=None, slope_deskew=0.0, - logger=None, MAX_SLOPE=999, KERNEL=None, plotter=None -): - if KERNEL is None: - KERNEL = np.ones((5, 5), np.uint8) - if logger is None: - logger = getLogger(__package__) - logger.debug('enter do_work_of_slopes_new') - - x, y, w, h = box_text - crop_coor = box2rect(box_text) - mask_textline = np.zeros(textline_mask_tot_ea.shape) - mask_textline = cv2.fillPoly(mask_textline, pts=[contour], color=(1,1,1)) - all_text_region_raw = textline_mask_tot_ea * mask_textline - all_text_region_raw = all_text_region_raw[y: y + h, x: x + w].astype(np.uint8) - img_int_p = all_text_region_raw[:,:] - img_int_p = cv2.erode(img_int_p, KERNEL, iterations=2) - - if not np.prod(img_int_p.shape) or img_int_p.shape[0] /img_int_p.shape[1] < 0.1: - slope = 0 - slope_for_all = slope_deskew - all_text_region_raw = textline_mask_tot_ea[y: y + h, x: x + w] - cnt_clean_rot = textline_contours_postprocessing(all_text_region_raw, slope_for_all, contour_par, box_text, 0) - else: - try: - textline_con, hierarchy = return_contours_of_image(img_int_p) - textline_con_fil = filter_contours_area_of_image(img_int_p, textline_con, - hierarchy, - max_area=1, min_area=0.00008) - y_diff_mean = find_contours_mean_y_diff(textline_con_fil) if len(textline_con_fil) > 1 else np.NaN - if np.isnan(y_diff_mean): - slope_for_all = MAX_SLOPE - else: - sigma_des = max(1, int(y_diff_mean * (4.0 / 40.0))) - img_int_p[img_int_p > 0] = 1 - slope_for_all = return_deskew_slop(img_int_p, sigma_des, logger=logger, plotter=plotter) - if abs(slope_for_all) <= 0.5: - slope_for_all = slope_deskew - except: - logger.exception("cannot determine angle of contours") - slope_for_all = MAX_SLOPE - - if slope_for_all == MAX_SLOPE: - slope_for_all = slope_deskew - slope = slope_for_all - mask_only_con_region = np.zeros(textline_mask_tot_ea.shape) - mask_only_con_region = cv2.fillPoly(mask_only_con_region, pts=[contour_par], color=(1, 1, 1)) - - all_text_region_raw = textline_mask_tot_ea[y: y + h, x: x + w].copy() - mask_only_con_region = mask_only_con_region[y: y + h, x: x + w] - - all_text_region_raw[mask_only_con_region == 0] = 0 - cnt_clean_rot = textline_contours_postprocessing(all_text_region_raw, slope_for_all, contour_par, box_text) - - return cnt_clean_rot, crop_coor, slope - @wrap_ndarray_shared(kw='textline_mask_tot_ea') @wrap_ndarray_shared(kw='mask_texts_only') def do_work_of_slopes_new_curved( @@ -1751,7 +1688,7 @@ def do_work_of_slopes_new_curved( @wrap_ndarray_shared(kw='textline_mask_tot_ea') def do_work_of_slopes_new_light( box_text, contour, contour_par, - textline_mask_tot_ea=None, slope_deskew=0, textline_light=True, + textline_mask_tot_ea=None, slope_deskew=0, logger=None ): if logger is None: @@ -1768,16 +1705,10 @@ def do_work_of_slopes_new_light( mask_only_con_region = np.zeros(textline_mask_tot_ea.shape) mask_only_con_region = cv2.fillPoly(mask_only_con_region, pts=[contour_par], color=(1, 1, 1)) - if textline_light: - all_text_region_raw = np.copy(textline_mask_tot_ea) - all_text_region_raw[mask_only_con_region == 0] = 0 - cnt_clean_rot_raw, hir_on_cnt_clean_rot = return_contours_of_image(all_text_region_raw) - cnt_clean_rot = filter_contours_area_of_image(all_text_region_raw, cnt_clean_rot_raw, hir_on_cnt_clean_rot, - max_area=1, min_area=0.00001) - else: - all_text_region_raw = np.copy(textline_mask_tot_ea[y: y + h, x: x + w]) - mask_only_con_region = mask_only_con_region[y: y + h, x: x + w] - all_text_region_raw[mask_only_con_region == 0] = 0 - cnt_clean_rot = textline_contours_postprocessing(all_text_region_raw, slope_deskew, contour_par, box_text) + all_text_region_raw = np.copy(textline_mask_tot_ea) + all_text_region_raw[mask_only_con_region == 0] = 0 + cnt_clean_rot_raw, hir_on_cnt_clean_rot = return_contours_of_image(all_text_region_raw) + cnt_clean_rot = filter_contours_area_of_image(all_text_region_raw, cnt_clean_rot_raw, hir_on_cnt_clean_rot, + max_area=1, min_area=0.00001) return cnt_clean_rot, crop_coor, slope_deskew diff --git a/src/eynollah/utils/utils_ocr.py b/src/eynollah/utils/utils_ocr.py index fbe3611..928c164 100644 --- a/src/eynollah/utils/utils_ocr.py +++ b/src/eynollah/utils/utils_ocr.py @@ -128,6 +128,7 @@ def return_textlines_split_if_needed(textline_image, textline_image_bin=None): return [image1, image2], None else: return None, None + def preprocess_and_resize_image_for_ocrcnn_model(img, image_height, image_width): if img.shape[0]==0 or img.shape[1]==0: img_fin = np.ones((image_height, image_width, 3)) @@ -379,7 +380,6 @@ def return_rnn_cnn_ocr_of_given_textlines(image, all_box_coord, prediction_model, b_s_ocr, num_to_char, - textline_light=False, curved_line=False): max_len = 512 padding_token = 299 @@ -404,7 +404,7 @@ def return_rnn_cnn_ocr_of_given_textlines(image, else: for indexing2, ind_poly in enumerate(ind_poly_first): cropped_lines_region_indexer.append(indexer_text_region) - if not (textline_light or curved_line): + if not curved_line: ind_poly = copy.deepcopy(ind_poly) box_ind = all_box_coord[indexing] diff --git a/src/eynollah/utils/xml.py b/src/eynollah/utils/xml.py index 88d1df8..ded098e 100644 --- a/src/eynollah/utils/xml.py +++ b/src/eynollah/utils/xml.py @@ -88,3 +88,7 @@ def order_and_id_of_texts(found_polygons_text_region, found_polygons_text_region order_of_texts.append(interest) return order_of_texts, id_of_texts + +def etree_namespace_for_element_tag(tag: str): + right = tag.find('}') + return tag[1:right] diff --git a/src/eynollah/writer.py b/src/eynollah/writer.py index 2e9c895..4f0827f 100644 --- a/src/eynollah/writer.py +++ b/src/eynollah/writer.py @@ -2,15 +2,15 @@ # pylint: disable=import-error from pathlib import Path import os.path -import xml.etree.ElementTree as ET +import logging +from typing import Optional import numpy as np from shapely import affinity, clip_by_rect -from ocrd_utils import getLogger, points_from_polygon +from ocrd_utils import points_from_polygon from ocrd_models.ocrd_page import ( BorderType, CoordsType, - PcGtsType, TextLineType, TextEquivType, TextRegionType, @@ -26,19 +26,18 @@ from .utils.contour import contour2polygon, make_valid class EynollahXmlWriter: - def __init__(self, *, dir_out, image_filename, curved_line,textline_light, pcgts=None): - self.logger = getLogger('eynollah.writer') + def __init__(self, *, dir_out, image_filename, curved_line, pcgts=None): + self.logger = logging.getLogger('eynollah.writer') self.counter = EynollahIdCounter() self.dir_out = dir_out self.image_filename = image_filename self.output_filename = os.path.join(self.dir_out or "", self.image_filename_stem) + ".xml" self.curved_line = curved_line - self.textline_light = textline_light self.pcgts = pcgts - self.scale_x = None # XXX set outside __init__ - self.scale_y = None # XXX set outside __init__ - self.height_org = None # XXX set outside __init__ - self.width_org = None # XXX set outside __init__ + self.scale_x: Optional[float] = None # XXX set outside __init__ + self.scale_y: Optional[float] = None # XXX set outside __init__ + self.height_org: Optional[int] = None # XXX set outside __init__ + self.width_org: Optional[int] = None # XXX set outside __init__ @property def image_filename_stem(self): @@ -65,8 +64,8 @@ class EynollahXmlWriter: text_region.set_orientation(-slopes[region_idx]) region_bboxes = all_box_coord[region_idx] offset = [page_coord[2], page_coord[0]] - # FIXME: or actually... not self.textline_light and not self.curved_line or np.abs(slopes[region_idx]) > 45? - if not self.textline_light and not (self.curved_line and np.abs(slopes[region_idx]) <= 45): + # FIXME: or actually... self.curved_line or np.abs(slopes[region_idx]) > 45? + if self.curved_line and np.abs(slopes[region_idx]) > 45: offset[0] += region_bboxes[2] offset[1] += region_bboxes[0] coords.set_points(self.calculate_points(polygon_textline, offset)) @@ -77,48 +76,88 @@ class EynollahXmlWriter: f.write(to_xml(pcgts)) def build_pagexml_no_full_layout( - self, found_polygons_text_region, - page_coord, order_of_texts, - all_found_textline_polygons, - all_box_coord, - found_polygons_text_region_img, - found_polygons_marginals_left, found_polygons_marginals_right, - all_found_textline_polygons_marginals_left, all_found_textline_polygons_marginals_right, - all_box_coord_marginals_left, all_box_coord_marginals_right, - slopes, slopes_marginals_left, slopes_marginals_right, - cont_page, polygons_seplines, - found_polygons_tables, - **kwargs): + self, + *, + found_polygons_text_region, + page_coord, + order_of_texts, + all_found_textline_polygons, + all_box_coord, + found_polygons_text_region_img, + found_polygons_marginals_left, + found_polygons_marginals_right, + all_found_textline_polygons_marginals_left, + all_found_textline_polygons_marginals_right, + all_box_coord_marginals_left, + all_box_coord_marginals_right, + slopes, + slopes_marginals_left, + slopes_marginals_right, + cont_page, + polygons_seplines, + found_polygons_tables, + ): return self.build_pagexml_full_layout( - found_polygons_text_region, [], - page_coord, order_of_texts, - all_found_textline_polygons, [], - all_box_coord, [], - found_polygons_text_region_img, found_polygons_tables, [], - found_polygons_marginals_left, found_polygons_marginals_right, - all_found_textline_polygons_marginals_left, all_found_textline_polygons_marginals_right, - all_box_coord_marginals_left, all_box_coord_marginals_right, - slopes, [], slopes_marginals_left, slopes_marginals_right, - cont_page, polygons_seplines, - **kwargs) + found_polygons_text_region=found_polygons_text_region, + found_polygons_text_region_h=[], + page_coord=page_coord, + order_of_texts=order_of_texts, + all_found_textline_polygons=all_found_textline_polygons, + all_found_textline_polygons_h=[], + all_box_coord=all_box_coord, + all_box_coord_h=[], + found_polygons_text_region_img=found_polygons_text_region_img, + found_polygons_tables=found_polygons_tables, + found_polygons_drop_capitals=[], + found_polygons_marginals_left=found_polygons_marginals_left, + found_polygons_marginals_right=found_polygons_marginals_right, + all_found_textline_polygons_marginals_left=all_found_textline_polygons_marginals_left, + all_found_textline_polygons_marginals_right=all_found_textline_polygons_marginals_right, + all_box_coord_marginals_left=all_box_coord_marginals_left, + all_box_coord_marginals_right=all_box_coord_marginals_right, + slopes=slopes, + slopes_h=[], + slopes_marginals_left=slopes_marginals_left, + slopes_marginals_right=slopes_marginals_right, + cont_page=cont_page, + polygons_seplines=polygons_seplines, + ) def build_pagexml_full_layout( - self, - found_polygons_text_region, found_polygons_text_region_h, - page_coord, order_of_texts, - all_found_textline_polygons, all_found_textline_polygons_h, - all_box_coord, all_box_coord_h, - found_polygons_text_region_img, found_polygons_tables, found_polygons_drop_capitals, - found_polygons_marginals_left,found_polygons_marginals_right, - all_found_textline_polygons_marginals_left, all_found_textline_polygons_marginals_right, - all_box_coord_marginals_left, all_box_coord_marginals_right, - slopes, slopes_h, slopes_marginals_left, slopes_marginals_right, - cont_page, polygons_seplines, - ocr_all_textlines=None, ocr_all_textlines_h=None, - ocr_all_textlines_marginals_left=None, ocr_all_textlines_marginals_right=None, - ocr_all_textlines_drop=None, - conf_contours_textregions=None, conf_contours_textregions_h=None, - skip_layout_reading_order=False): + self, + *, + found_polygons_text_region, + found_polygons_text_region_h, + page_coord, + order_of_texts, + all_found_textline_polygons, + all_found_textline_polygons_h, + all_box_coord, + all_box_coord_h, + found_polygons_text_region_img, + found_polygons_tables, + found_polygons_drop_capitals, + found_polygons_marginals_left, + found_polygons_marginals_right, + all_found_textline_polygons_marginals_left, + all_found_textline_polygons_marginals_right, + all_box_coord_marginals_left, + all_box_coord_marginals_right, + slopes, + slopes_h, + slopes_marginals_left, + slopes_marginals_right, + cont_page, + polygons_seplines, + ocr_all_textlines=None, + ocr_all_textlines_h=None, + ocr_all_textlines_marginals_left=None, + ocr_all_textlines_marginals_right=None, + ocr_all_textlines_drop=None, + conf_contours_textregions=None, + conf_contours_textregions_h=None, + skip_layout_reading_order=False, + ): self.logger.debug('enter build_pagexml') # create the file structure @@ -145,6 +184,7 @@ class EynollahXmlWriter: id=counter.next_region_id, type_='paragraph', Coords=CoordsType(points=self.calculate_points(region_contour, offset)) ) + assert textregion.Coords if conf_contours_textregions: textregion.Coords.set_conf(conf_contours_textregions[mm]) page.add_TextRegion(textregion) @@ -161,6 +201,7 @@ class EynollahXmlWriter: id=counter.next_region_id, type_='heading', Coords=CoordsType(points=self.calculate_points(region_contour, offset)) ) + assert textregion.Coords if conf_contours_textregions_h: textregion.Coords.set_conf(conf_contours_textregions_h[mm]) page.add_TextRegion(textregion) diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/cli_tests/conftest.py b/tests/cli_tests/conftest.py new file mode 100644 index 0000000..601d76b --- /dev/null +++ b/tests/cli_tests/conftest.py @@ -0,0 +1,47 @@ +from typing import List +import pytest +import logging + +from click.testing import CliRunner, Result +from eynollah.cli import main as eynollah_cli + + +@pytest.fixture +def run_eynollah_ok_and_check_logs( + pytestconfig, + caplog, + model_dir, + eynollah_subcommands, + eynollah_log_filter, +): + """ + Generates a Click Runner for `cli`, injects model_path and logging level + to `args`, runs the command and checks whether the logs generated contain + every fragment in `expected_logs` + """ + + def _run_click_ok_logs( + subcommand: 'str', + args: List[str], + expected_logs: List[str], + ) -> Result: + assert subcommand in eynollah_subcommands, f'subcommand {subcommand} must be one of {eynollah_subcommands}' + args = [ + '-m', model_dir, + subcommand, + *args + ] + if pytestconfig.getoption('verbose') > 0: + args = ['-l', 'DEBUG'] + args + caplog.set_level(logging.INFO) + runner = CliRunner() + with caplog.filtering(eynollah_log_filter): + result = runner.invoke(eynollah_cli, args, catch_exceptions=False) + assert result.exit_code == 0, result.stdout + if expected_logs: + logmsgs = [logrec.message for logrec in caplog.records] + assert any(logmsg.startswith(needle) for needle in expected_logs for logmsg in logmsgs), f'{expected_logs} not in {logmsgs}' + return result + + return _run_click_ok_logs + diff --git a/tests/cli_tests/test_binarization.py b/tests/cli_tests/test_binarization.py new file mode 100644 index 0000000..1287ffa --- /dev/null +++ b/tests/cli_tests/test_binarization.py @@ -0,0 +1,53 @@ +import pytest +from PIL import Image + +@pytest.mark.parametrize( + "options", + [ + [], # defaults + ["--no-patches"], + ], ids=str) +def test_run_eynollah_binarization_filename( + tmp_path, + run_eynollah_ok_and_check_logs, + resources_dir, + options, +): + infile = resources_dir / '2files/kant_aufklaerung_1784_0020.tif' + outfile = tmp_path / 'kant_aufklaerung_1784_0020.png' + run_eynollah_ok_and_check_logs( + 'binarization', + [ + '-i', str(infile), + '-o', str(outfile), + ] + options, + [ + 'Loaded model' + ] + ) + assert outfile.exists() + with Image.open(infile) as original_img: + original_size = original_img.size + with Image.open(outfile) as binarized_img: + binarized_size = binarized_img.size + assert original_size == binarized_size + +def test_run_eynollah_binarization_directory( + tmp_path, + run_eynollah_ok_and_check_logs, + resources_dir, + image_resources, +): + outdir = tmp_path + run_eynollah_ok_and_check_logs( + 'binarization', + [ + '-di', str(resources_dir / '2files'), + '-o', str(outdir), + ], + [ + f'Binarizing [ 1/2] {image_resources[0].name}', + f'Binarizing [ 2/2] {image_resources[1].name}', + ] + ) + assert len(list(outdir.iterdir())) == 2 diff --git a/tests/cli_tests/test_enhance.py b/tests/cli_tests/test_enhance.py new file mode 100644 index 0000000..b994c5d --- /dev/null +++ b/tests/cli_tests/test_enhance.py @@ -0,0 +1,52 @@ +import pytest +from PIL import Image + +@pytest.mark.parametrize( + "options", + [ + [], # defaults + ["-sos"], + ], ids=str) +def test_run_eynollah_enhancement_filename( + tmp_path, + resources_dir, + run_eynollah_ok_and_check_logs, + options, +): + infile = resources_dir / '2files/kant_aufklaerung_1784_0020.tif' + outfile = tmp_path / 'kant_aufklaerung_1784_0020.png' + run_eynollah_ok_and_check_logs( + 'enhancement', + [ + '-i', str(infile), + '-o', str(outfile.parent), + ] + options, + [ + 'Image was enhanced', + ] + ) + with Image.open(infile) as original_img: + original_size = original_img.size + with Image.open(outfile) as enhanced_img: + enhanced_size = enhanced_img.size + assert (original_size == enhanced_size) == ("-sos" in options) + +def test_run_eynollah_enhancement_directory( + tmp_path, + resources_dir, + image_resources, + run_eynollah_ok_and_check_logs, +): + outdir = tmp_path + run_eynollah_ok_and_check_logs( + 'enhancement', + [ + '-di', str(resources_dir/ '2files'), + '-o', str(outdir), + ], + [ + f'Image {image_resources[0]} was enhanced', + f'Image {image_resources[1]} was enhanced', + ] + ) + assert len(list(outdir.iterdir())) == 2 diff --git a/tests/cli_tests/test_layout.py b/tests/cli_tests/test_layout.py new file mode 100644 index 0000000..7cbe013 --- /dev/null +++ b/tests/cli_tests/test_layout.py @@ -0,0 +1,119 @@ +import pytest +from ocrd_modelfactory import page_from_file +from ocrd_models.constants import NAMESPACES as NS + +@pytest.mark.parametrize( + "options", + [ + [], # defaults + #["--allow_scaling", "--curved-line"], + ["--allow_scaling", "--curved-line", "--full-layout"], + ["--allow_scaling", "--curved-line", "--full-layout", "--reading_order_machine_based"], + # -ep ... + # -eoi ... + # --skip_layout_and_reading_order + ], ids=str) +def test_run_eynollah_layout_filename( + tmp_path, + run_eynollah_ok_and_check_logs, + resources_dir, + options, +): + infile = resources_dir / '2files/kant_aufklaerung_1784_0020.tif' + outfile = tmp_path / 'kant_aufklaerung_1784_0020.xml' + run_eynollah_ok_and_check_logs( + 'layout', + [ + '-i', str(infile), + '-o', str(outfile.parent), + ] + options, + [ + str(infile) + ] + ) + assert outfile.exists() + tree = page_from_file(str(outfile)).etree + regions = tree.xpath("//page:TextRegion", namespaces=NS) + assert len(regions) >= 2, "result is inaccurate" + regions = tree.xpath("//page:SeparatorRegion", namespaces=NS) + assert len(regions) >= 2, "result is inaccurate" + lines = tree.xpath("//page:TextLine", namespaces=NS) + assert len(lines) == 31, "result is inaccurate" # 29 paragraph lines, 1 page and 1 catch-word line + +@pytest.mark.parametrize( + "options", + [ + ["--tables"], + ["--tables", "--full-layout"], + ], ids=str) +def test_run_eynollah_layout_filename2( + tmp_path, + resources_dir, + run_eynollah_ok_and_check_logs, + options, +): + infile = resources_dir / '2files/euler_rechenkunst01_1738_0025.tif' + outfile = tmp_path / 'euler_rechenkunst01_1738_0025.xml' + run_eynollah_ok_and_check_logs( + 'layout', + [ + '-i', str(infile), + '-o', str(outfile.parent), + ] + options, + [ + str(infile) + ] + ) + assert outfile.exists() + tree = page_from_file(str(outfile)).etree + regions = tree.xpath("//page:TextRegion", namespaces=NS) + assert len(regions) >= 2, "result is inaccurate" + regions = tree.xpath("//page:TableRegion", namespaces=NS) + # model/decoding is not very precise, so (depending on mode) we can get fractures/splits/FP + assert len(regions) >= 1, "result is inaccurate" + regions = tree.xpath("//page:SeparatorRegion", namespaces=NS) + assert len(regions) >= 2, "result is inaccurate" + lines = tree.xpath("//page:TextLine", namespaces=NS) + assert len(lines) >= 2, "result is inaccurate" # mostly table (if detected correctly), but 1 page and 1 catch-word line + +def test_run_eynollah_layout_directory( + tmp_path, + resources_dir, + run_eynollah_ok_and_check_logs, +): + outdir = tmp_path + run_eynollah_ok_and_check_logs( + 'layout', + [ + '-di', str(resources_dir / '2files'), + '-o', str(outdir), + ], + [ + 'Job done in', + 'All jobs done in', + ] + ) + assert len(list(outdir.iterdir())) == 2 + +# def test_run_eynollah_layout_marginalia( +# tmp_path, +# resources_dir, +# run_eynollah_ok_and_check_logs, +# ): +# outdir = tmp_path +# outfile = outdir / 'estor_rechtsgelehrsamkeit02_1758_0880_800px.xml' +# run_eynollah_ok_and_check_logs( +# 'layout', +# [ +# '-i', str(resources_dir / 'estor_rechtsgelehrsamkeit02_1758_0880_800px.jpg'), +# '-o', str(outdir), +# ], +# [ +# 'Job done in', +# 'All jobs done in', +# ] +# ) +# assert outfile.exists() +# tree = page_from_file(str(outfile)).etree +# regions = tree.xpath('//page:TextRegion[type="marginalia"]', namespaces=NS) +# assert len(regions) == 5, "expected 5 marginalia regions" diff --git a/tests/cli_tests/test_mbreorder.py b/tests/cli_tests/test_mbreorder.py new file mode 100644 index 0000000..e429e98 --- /dev/null +++ b/tests/cli_tests/test_mbreorder.py @@ -0,0 +1,47 @@ +from ocrd_modelfactory import page_from_file +from ocrd_models.constants import NAMESPACES as NS + +def test_run_eynollah_mbreorder_filename( + tmp_path, + resources_dir, + run_eynollah_ok_and_check_logs, +): + infile = resources_dir / '2files/kant_aufklaerung_1784_0020.xml' + outfile = tmp_path /'kant_aufklaerung_1784_0020.xml' + run_eynollah_ok_and_check_logs( + 'machine-based-reading-order', + [ + '-i', str(infile), + '-o', str(outfile.parent), + ], + [ + # FIXME: mbreorder has no logging! + ] + ) + assert outfile.exists() + #in_tree = page_from_file(str(infile)).etree + #in_order = in_tree.xpath("//page:OrderedGroup//@regionRef", namespaces=NS) + out_tree = page_from_file(str(outfile)).etree + out_order = out_tree.xpath("//page:OrderedGroup//@regionRef", namespaces=NS) + #assert len(out_order) >= 2, "result is inaccurate" + #assert in_order != out_order + assert out_order == ['r_1_1', 'r_2_1', 'r_2_2', 'r_2_3'] + +def test_run_eynollah_mbreorder_directory( + tmp_path, + resources_dir, + run_eynollah_ok_and_check_logs, +): + outdir = tmp_path + run_eynollah_ok_and_check_logs( + 'machine-based-reading-order', + [ + '-di', str(resources_dir / '2files'), + '-o', str(outdir), + ], + [ + # FIXME: mbreorder has no logging! + ] + ) + assert len(list(outdir.iterdir())) == 2 + diff --git a/tests/cli_tests/test_ocr.py b/tests/cli_tests/test_ocr.py new file mode 100644 index 0000000..6bf3080 --- /dev/null +++ b/tests/cli_tests/test_ocr.py @@ -0,0 +1,64 @@ +import pytest +from ocrd_modelfactory import page_from_file +from ocrd_models.constants import NAMESPACES as NS + +@pytest.mark.parametrize( + "options", + [ + ["-trocr"], + [], # defaults + ["-doit", #str(outrenderfile.parent)], + ], + ], ids=str) +def test_run_eynollah_ocr_filename( + tmp_path, + run_eynollah_ok_and_check_logs, + resources_dir, + options, +): + infile = resources_dir / '2files/kant_aufklaerung_1784_0020.tif' + outfile = tmp_path / 'kant_aufklaerung_1784_0020.xml' + outrenderfile = tmp_path / 'render' / 'kant_aufklaerung_1784_0020.png' + outrenderfile.parent.mkdir() + if "-doit" in options: + options.insert(options.index("-doit") + 1, str(outrenderfile.parent)) + run_eynollah_ok_and_check_logs( + 'ocr', + [ + '-i', str(infile), + '-dx', str(infile.parent), + '-o', str(outfile.parent), + ] + options, + [ + # FIXME: ocr has no logging! + ] + ) + assert outfile.exists() + if "-doit" in options: + assert outrenderfile.exists() + #in_tree = page_from_file(str(infile)).etree + #in_order = in_tree.xpath("//page:OrderedGroup//@regionRef", namespaces=NS) + out_tree = page_from_file(str(outfile)).etree + out_texts = out_tree.xpath("//page:TextLine/page:TextEquiv[last()]/page:Unicode/text()", namespaces=NS) + assert len(out_texts) >= 2, ("result is inaccurate", out_texts) + assert sum(map(len, out_texts)) > 100, ("result is inaccurate", out_texts) + +def test_run_eynollah_ocr_directory( + tmp_path, + run_eynollah_ok_and_check_logs, + resources_dir, +): + outdir = tmp_path + run_eynollah_ok_and_check_logs( + 'ocr', + [ + '-di', str(resources_dir / '2files'), + '-dx', str(resources_dir / '2files'), + '-o', str(outdir), + ], + [ + # FIXME: ocr has no logging! + ] + ) + assert len(list(outdir.iterdir())) == 2 + diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..69f3d28 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,37 @@ +from glob import glob +import os +import pytest +from pathlib import Path + + +@pytest.fixture() +def tests_dir(): + return Path(__file__).parent.resolve() + +@pytest.fixture() +def model_dir(tests_dir): + return os.environ.get('EYNOLLAH_MODELS_DIR', str(tests_dir.joinpath('..').resolve())) + +@pytest.fixture() +def resources_dir(tests_dir): + return tests_dir / 'resources' + +@pytest.fixture() +def image_resources(resources_dir): + return [Path(x) for x in glob(str(resources_dir / '2files/*.tif'))] + +@pytest.fixture() +def eynollah_log_filter(): + return lambda logrec: logrec.name.startswith('eynollah') + +@pytest.fixture +def eynollah_subcommands(): + return [ + 'binarization', + 'layout', + 'ocr', + 'enhancement', + 'machine-based-reading-order', + 'models', + ] + diff --git a/tests/resources/euler_rechenkunst01_1738_0025.tif b/tests/resources/2files/euler_rechenkunst01_1738_0025.tif similarity index 100% rename from tests/resources/euler_rechenkunst01_1738_0025.tif rename to tests/resources/2files/euler_rechenkunst01_1738_0025.tif diff --git a/tests/resources/euler_rechenkunst01_1738_0025.xml b/tests/resources/2files/euler_rechenkunst01_1738_0025.xml similarity index 100% rename from tests/resources/euler_rechenkunst01_1738_0025.xml rename to tests/resources/2files/euler_rechenkunst01_1738_0025.xml diff --git a/tests/resources/kant_aufklaerung_1784_0020.tif b/tests/resources/2files/kant_aufklaerung_1784_0020.tif similarity index 100% rename from tests/resources/kant_aufklaerung_1784_0020.tif rename to tests/resources/2files/kant_aufklaerung_1784_0020.tif diff --git a/tests/resources/kant_aufklaerung_1784_0020.xml b/tests/resources/2files/kant_aufklaerung_1784_0020.xml similarity index 100% rename from tests/resources/kant_aufklaerung_1784_0020.xml rename to tests/resources/2files/kant_aufklaerung_1784_0020.xml diff --git a/tests/resources/marginalia/estor_rechtsgelehrsamkeit02_1758_0880_800px.jpg b/tests/resources/marginalia/estor_rechtsgelehrsamkeit02_1758_0880_800px.jpg new file mode 100644 index 0000000..9270508 Binary files /dev/null and b/tests/resources/marginalia/estor_rechtsgelehrsamkeit02_1758_0880_800px.jpg differ diff --git a/tests/resources/marginalia/estor_rechtsgelehrsamkeit02_1758_0880_800px.xml b/tests/resources/marginalia/estor_rechtsgelehrsamkeit02_1758_0880_800px.xml new file mode 100644 index 0000000..45240c4 --- /dev/null +++ b/tests/resources/marginalia/estor_rechtsgelehrsamkeit02_1758_0880_800px.xml @@ -0,0 +1,235 @@ + + + + SBB_QURATOR + 2025-10-30T16:38:21.180191 + 2025-10-30T16:38:21.180191 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/test_model_zoo.py b/tests/test_model_zoo.py new file mode 100644 index 0000000..2042b28 --- /dev/null +++ b/tests/test_model_zoo.py @@ -0,0 +1,16 @@ +from eynollah.model_zoo import EynollahModelZoo + +def test_trocr1( + model_dir, +): + model_zoo = EynollahModelZoo(model_dir) + try: + from transformers import TrOCRProcessor, VisionEncoderDecoderModel + model_zoo.load_model('trocr_processor') + proc = model_zoo.get('trocr_processor', TrOCRProcessor) + assert isinstance(proc, TrOCRProcessor) + model_zoo.load_model('ocr', 'tr') + model = model_zoo.get('ocr', VisionEncoderDecoderModel) + assert isinstance(model, VisionEncoderDecoderModel) + except ImportError: + pass diff --git a/tests/test_run.py b/tests/test_run.py deleted file mode 100644 index 79c64c2..0000000 --- a/tests/test_run.py +++ /dev/null @@ -1,351 +0,0 @@ -from os import environ -from pathlib import Path -import pytest -import logging -from PIL import Image -from eynollah.cli import ( - layout as layout_cli, - binarization as binarization_cli, - enhancement as enhancement_cli, - machine_based_reading_order as mbreorder_cli, - ocr as ocr_cli, -) -from click.testing import CliRunner -from ocrd_modelfactory import page_from_file -from ocrd_models.constants import NAMESPACES as NS - -testdir = Path(__file__).parent.resolve() - -MODELS_LAYOUT = environ.get('MODELS_LAYOUT', str(testdir.joinpath('..', 'models_layout_v0_5_0').resolve())) -MODELS_OCR = environ.get('MODELS_OCR', str(testdir.joinpath('..', 'models_ocr_v0_5_1').resolve())) -MODELS_BIN = environ.get('MODELS_BIN', str(testdir.joinpath('..', 'default-2021-03-09').resolve())) - -@pytest.mark.parametrize( - "options", - [ - [], # defaults - #["--allow_scaling", "--curved-line"], - ["--allow_scaling", "--curved-line", "--full-layout"], - ["--allow_scaling", "--curved-line", "--full-layout", "--reading_order_machine_based"], - ["--allow_scaling", "--curved-line", "--full-layout", "--reading_order_machine_based", - "--textline_light", "--light_version"], - # -ep ... - # -eoi ... - # FIXME: find out whether OCR extra was installed, otherwise skip these - ["--do_ocr"], - ["--do_ocr", "--light_version", "--textline_light"], - ["--do_ocr", "--transformer_ocr"], - #["--do_ocr", "--transformer_ocr", "--light_version", "--textline_light"], - ["--do_ocr", "--transformer_ocr", "--light_version", "--textline_light", "--full-layout"], - # --skip_layout_and_reading_order - ], ids=str) -def test_run_eynollah_layout_filename(tmp_path, pytestconfig, caplog, options): - infile = testdir.joinpath('resources/kant_aufklaerung_1784_0020.tif') - outfile = tmp_path / 'kant_aufklaerung_1784_0020.xml' - args = [ - '-m', MODELS_LAYOUT, - '-i', str(infile), - '-o', str(outfile.parent), - ] - if pytestconfig.getoption('verbose') > 0: - args.extend(['-l', 'DEBUG']) - caplog.set_level(logging.INFO) - def only_eynollah(logrec): - return logrec.name == 'eynollah' - runner = CliRunner() - with caplog.filtering(only_eynollah): - result = runner.invoke(layout_cli, args + options, catch_exceptions=False) - assert result.exit_code == 0, result.stdout - logmsgs = [logrec.message for logrec in caplog.records] - assert str(infile) in logmsgs - assert outfile.exists() - tree = page_from_file(str(outfile)).etree - regions = tree.xpath("//page:TextRegion", namespaces=NS) - assert len(regions) >= 2, "result is inaccurate" - regions = tree.xpath("//page:SeparatorRegion", namespaces=NS) - assert len(regions) >= 2, "result is inaccurate" - lines = tree.xpath("//page:TextLine", namespaces=NS) - assert len(lines) == 31, "result is inaccurate" # 29 paragraph lines, 1 page and 1 catch-word line - -@pytest.mark.parametrize( - "options", - [ - ["--tables"], - ["--tables", "--full-layout"], - ["--tables", "--full-layout", "--textline_light", "--light_version"], - ], ids=str) -def test_run_eynollah_layout_filename2(tmp_path, pytestconfig, caplog, options): - infile = testdir.joinpath('resources/euler_rechenkunst01_1738_0025.tif') - outfile = tmp_path / 'euler_rechenkunst01_1738_0025.xml' - args = [ - '-m', MODELS_LAYOUT, - '-i', str(infile), - '-o', str(outfile.parent), - ] - if pytestconfig.getoption('verbose') > 0: - args.extend(['-l', 'DEBUG']) - caplog.set_level(logging.INFO) - def only_eynollah(logrec): - return logrec.name == 'eynollah' - runner = CliRunner() - with caplog.filtering(only_eynollah): - result = runner.invoke(layout_cli, args + options, catch_exceptions=False) - assert result.exit_code == 0, result.stdout - logmsgs = [logrec.message for logrec in caplog.records] - assert str(infile) in logmsgs - assert outfile.exists() - tree = page_from_file(str(outfile)).etree - regions = tree.xpath("//page:TextRegion", namespaces=NS) - assert len(regions) >= 2, "result is inaccurate" - regions = tree.xpath("//page:TableRegion", namespaces=NS) - # model/decoding is not very precise, so (depending on mode) we can get fractures/splits/FP - assert len(regions) >= 1, "result is inaccurate" - regions = tree.xpath("//page:SeparatorRegion", namespaces=NS) - assert len(regions) >= 2, "result is inaccurate" - lines = tree.xpath("//page:TextLine", namespaces=NS) - assert len(lines) >= 2, "result is inaccurate" # mostly table (if detected correctly), but 1 page and 1 catch-word line - -def test_run_eynollah_layout_directory(tmp_path, pytestconfig, caplog): - indir = testdir.joinpath('resources') - outdir = tmp_path - args = [ - '-m', MODELS_LAYOUT, - '-di', str(indir), - '-o', str(outdir), - ] - if pytestconfig.getoption('verbose') > 0: - args.extend(['-l', 'DEBUG']) - caplog.set_level(logging.INFO) - def only_eynollah(logrec): - return logrec.name == 'eynollah' - runner = CliRunner() - with caplog.filtering(only_eynollah): - result = runner.invoke(layout_cli, args, catch_exceptions=False) - assert result.exit_code == 0, result.stdout - logmsgs = [logrec.message for logrec in caplog.records] - assert len([logmsg for logmsg in logmsgs if logmsg.startswith('Job done in')]) == 2 - assert any(logmsg for logmsg in logmsgs if logmsg.startswith('All jobs done in')) - assert len(list(outdir.iterdir())) == 2 - -@pytest.mark.parametrize( - "options", - [ - [], # defaults - ["--no-patches"], - ], ids=str) -def test_run_eynollah_binarization_filename(tmp_path, pytestconfig, caplog, options): - infile = testdir.joinpath('resources/kant_aufklaerung_1784_0020.tif') - outfile = tmp_path.joinpath('kant_aufklaerung_1784_0020.png') - args = [ - '-m', MODELS_BIN, - '-i', str(infile), - '-o', str(outfile), - ] - if pytestconfig.getoption('verbose') > 0: - args.extend(['-l', 'DEBUG']) - caplog.set_level(logging.INFO) - def only_eynollah(logrec): - return logrec.name == 'SbbBinarizer' - runner = CliRunner() - with caplog.filtering(only_eynollah): - result = runner.invoke(binarization_cli, args + options, catch_exceptions=False) - assert result.exit_code == 0, result.stdout - logmsgs = [logrec.message for logrec in caplog.records] - assert any(True for logmsg in logmsgs if logmsg.startswith('Predicting')) - assert outfile.exists() - with Image.open(infile) as original_img: - original_size = original_img.size - with Image.open(outfile) as binarized_img: - binarized_size = binarized_img.size - assert original_size == binarized_size - -def test_run_eynollah_binarization_directory(tmp_path, pytestconfig, caplog): - indir = testdir.joinpath('resources') - outdir = tmp_path - args = [ - '-m', MODELS_BIN, - '-di', str(indir), - '-o', str(outdir), - ] - if pytestconfig.getoption('verbose') > 0: - args.extend(['-l', 'DEBUG']) - caplog.set_level(logging.INFO) - def only_eynollah(logrec): - return logrec.name == 'SbbBinarizer' - runner = CliRunner() - with caplog.filtering(only_eynollah): - result = runner.invoke(binarization_cli, args, catch_exceptions=False) - assert result.exit_code == 0, result.stdout - logmsgs = [logrec.message for logrec in caplog.records] - assert len([logmsg for logmsg in logmsgs if logmsg.startswith('Predicting')]) == 2 - assert len(list(outdir.iterdir())) == 2 - -@pytest.mark.parametrize( - "options", - [ - [], # defaults - ["-sos"], - ], ids=str) -def test_run_eynollah_enhancement_filename(tmp_path, pytestconfig, caplog, options): - infile = testdir.joinpath('resources/kant_aufklaerung_1784_0020.tif') - outfile = tmp_path.joinpath('kant_aufklaerung_1784_0020.png') - args = [ - '-m', MODELS_LAYOUT, - '-i', str(infile), - '-o', str(outfile.parent), - ] - if pytestconfig.getoption('verbose') > 0: - args.extend(['-l', 'DEBUG']) - caplog.set_level(logging.INFO) - def only_eynollah(logrec): - return logrec.name == 'enhancement' - runner = CliRunner() - with caplog.filtering(only_eynollah): - result = runner.invoke(enhancement_cli, args + options, catch_exceptions=False) - assert result.exit_code == 0, result.stdout - logmsgs = [logrec.message for logrec in caplog.records] - assert any(True for logmsg in logmsgs if logmsg.startswith('Image was enhanced')), logmsgs - assert outfile.exists() - with Image.open(infile) as original_img: - original_size = original_img.size - with Image.open(outfile) as enhanced_img: - enhanced_size = enhanced_img.size - assert (original_size == enhanced_size) == ("-sos" in options) - -def test_run_eynollah_enhancement_directory(tmp_path, pytestconfig, caplog): - indir = testdir.joinpath('resources') - outdir = tmp_path - args = [ - '-m', MODELS_LAYOUT, - '-di', str(indir), - '-o', str(outdir), - ] - if pytestconfig.getoption('verbose') > 0: - args.extend(['-l', 'DEBUG']) - caplog.set_level(logging.INFO) - def only_eynollah(logrec): - return logrec.name == 'enhancement' - runner = CliRunner() - with caplog.filtering(only_eynollah): - result = runner.invoke(enhancement_cli, args, catch_exceptions=False) - assert result.exit_code == 0, result.stdout - logmsgs = [logrec.message for logrec in caplog.records] - assert len([logmsg for logmsg in logmsgs if logmsg.startswith('Image was enhanced')]) == 2 - assert len(list(outdir.iterdir())) == 2 - -def test_run_eynollah_mbreorder_filename(tmp_path, pytestconfig, caplog): - infile = testdir.joinpath('resources/kant_aufklaerung_1784_0020.xml') - outfile = tmp_path.joinpath('kant_aufklaerung_1784_0020.xml') - args = [ - '-m', MODELS_LAYOUT, - '-i', str(infile), - '-o', str(outfile.parent), - ] - if pytestconfig.getoption('verbose') > 0: - args.extend(['-l', 'DEBUG']) - caplog.set_level(logging.INFO) - def only_eynollah(logrec): - return logrec.name == 'mbreorder' - runner = CliRunner() - with caplog.filtering(only_eynollah): - result = runner.invoke(mbreorder_cli, args, catch_exceptions=False) - assert result.exit_code == 0, result.stdout - logmsgs = [logrec.message for logrec in caplog.records] - # FIXME: mbreorder has no logging! - #assert any(True for logmsg in logmsgs if logmsg.startswith('???')), logmsgs - assert outfile.exists() - #in_tree = page_from_file(str(infile)).etree - #in_order = in_tree.xpath("//page:OrderedGroup//@regionRef", namespaces=NS) - out_tree = page_from_file(str(outfile)).etree - out_order = out_tree.xpath("//page:OrderedGroup//@regionRef", namespaces=NS) - #assert len(out_order) >= 2, "result is inaccurate" - #assert in_order != out_order - assert out_order == ['r_1_1', 'r_2_1', 'r_2_2', 'r_2_3'] - -def test_run_eynollah_mbreorder_directory(tmp_path, pytestconfig, caplog): - indir = testdir.joinpath('resources') - outdir = tmp_path - args = [ - '-m', MODELS_LAYOUT, - '-di', str(indir), - '-o', str(outdir), - ] - if pytestconfig.getoption('verbose') > 0: - args.extend(['-l', 'DEBUG']) - caplog.set_level(logging.INFO) - def only_eynollah(logrec): - return logrec.name == 'mbreorder' - runner = CliRunner() - with caplog.filtering(only_eynollah): - result = runner.invoke(mbreorder_cli, args, catch_exceptions=False) - assert result.exit_code == 0, result.stdout - logmsgs = [logrec.message for logrec in caplog.records] - # FIXME: mbreorder has no logging! - #assert len([logmsg for logmsg in logmsgs if logmsg.startswith('???')]) == 2 - assert len(list(outdir.iterdir())) == 2 - -@pytest.mark.parametrize( - "options", - [ - [], # defaults - ["-doit", #str(outrenderfile.parent)], - ], - ["-trocr"], - ], ids=str) -def test_run_eynollah_ocr_filename(tmp_path, pytestconfig, caplog, options): - infile = testdir.joinpath('resources/kant_aufklaerung_1784_0020.tif') - outfile = tmp_path.joinpath('kant_aufklaerung_1784_0020.xml') - outrenderfile = tmp_path.joinpath('render').joinpath('kant_aufklaerung_1784_0020.png') - outrenderfile.parent.mkdir() - args = [ - '-m', MODELS_OCR, - '-i', str(infile), - '-dx', str(infile.parent), - '-o', str(outfile.parent), - ] - if pytestconfig.getoption('verbose') > 0: - args.extend(['-l', 'DEBUG']) - caplog.set_level(logging.DEBUG) - def only_eynollah(logrec): - return logrec.name == 'eynollah' - runner = CliRunner() - if "-doit" in options: - options.insert(options.index("-doit") + 1, str(outrenderfile.parent)) - with caplog.filtering(only_eynollah): - result = runner.invoke(ocr_cli, args + options, catch_exceptions=False) - assert result.exit_code == 0, result.stdout - logmsgs = [logrec.message for logrec in caplog.records] - # FIXME: ocr has no logging! - #assert any(True for logmsg in logmsgs if logmsg.startswith('???')), logmsgs - assert outfile.exists() - if "-doit" in options: - assert outrenderfile.exists() - #in_tree = page_from_file(str(infile)).etree - #in_order = in_tree.xpath("//page:OrderedGroup//@regionRef", namespaces=NS) - out_tree = page_from_file(str(outfile)).etree - out_texts = out_tree.xpath("//page:TextLine/page:TextEquiv[last()]/page:Unicode/text()", namespaces=NS) - assert len(out_texts) >= 2, ("result is inaccurate", out_texts) - assert sum(map(len, out_texts)) > 100, ("result is inaccurate", out_texts) - -def test_run_eynollah_ocr_directory(tmp_path, pytestconfig, caplog): - indir = testdir.joinpath('resources') - outdir = tmp_path - args = [ - '-m', MODELS_OCR, - '-di', str(indir), - '-dx', str(indir), - '-o', str(outdir), - ] - if pytestconfig.getoption('verbose') > 0: - args.extend(['-l', 'DEBUG']) - caplog.set_level(logging.INFO) - def only_eynollah(logrec): - return logrec.name == 'eynollah' - runner = CliRunner() - with caplog.filtering(only_eynollah): - result = runner.invoke(ocr_cli, args, catch_exceptions=False) - assert result.exit_code == 0, result.stdout - logmsgs = [logrec.message for logrec in caplog.records] - # FIXME: ocr has no logging! - #assert any(True for logmsg in logmsgs if logmsg.startswith('???')), logmsgs - assert len(list(outdir.iterdir())) == 2 diff --git a/train/README.md b/train/README.md deleted file mode 100644 index 5f6d326..0000000 --- a/train/README.md +++ /dev/null @@ -1,59 +0,0 @@ -# Training eynollah - -This README explains the technical details of how to set up and run training, for detailed information on parameterization, see [`docs/train.md`](../docs/train.md) - -## Introduction - -This folder contains the source code for training an encoder model for document image segmentation. - -## Installation - -Clone the repository and install eynollah along with the dependencies necessary for training: - -```sh -git clone https://github.com/qurator-spk/eynollah -cd eynollah -pip install '.[training]' -``` - -### Pretrained encoder - -Download our pretrained weights and add them to a `train/pretrained_model` folder: - -```sh -cd train -wget -O pretrained_model.tar.gz https://zenodo.org/records/17243320/files/pretrained_model_v0_5_1.tar.gz?download=1 -tar xf pretrained_model.tar.gz -``` - -### Binarization training data - -A small sample of training data for binarization experiment can be found [on -zenodo](https://zenodo.org/records/17243320/files/training_data_sample_binarization_v0_5_1.tar.gz?download=1), -which contains `images` and `labels` folders. - -### Helpful tools - -* [`pagexml2img`](https://github.com/qurator-spk/page2img) -> Tool to extract 2-D or 3-D RGB images from PAGE-XML data. In the former case, the output will be 1 2-D image array which each class has filled with a pixel value. In the case of a 3-D RGB image, -each class will be defined with a RGB value and beside images, a text file of classes will also be produced. -* [`cocoSegmentationToPng`](https://github.com/nightrome/cocostuffapi/blob/17acf33aef3c6cc2d6aca46dcf084266c2778cf0/PythonAPI/pycocotools/cocostuffhelper.py#L130) -> Convert COCO GT or results for a single image to a segmentation map and write it to disk. -* [`ocrd-segment-extract-pages`](https://github.com/OCR-D/ocrd_segment/blob/master/ocrd_segment/extract_pages.py) -> Extract region classes and their colours in mask (pseg) images. Allows the color map as free dict parameter, and comes with a default that mimics PageViewer's coloring for quick debugging; it also warns when regions do overlap. - -### Train using Docker - -Build the Docker image: - -```bash -cd train -docker build -t model-training . -``` - -Run Docker image - -```bash -cd train -docker run --gpus all -v $PWD:/entry_point_dir model-training -``` diff --git a/train/config_params.json b/train/config_params.json index 1db8026..b01ac08 100644 --- a/train/config_params.json +++ b/train/config_params.json @@ -1,31 +1,50 @@ { "backbone_type" : "transformer", - "task": "segmentation", + "task": "cnn-rnn-ocr", "n_classes" : 2, - "n_epochs" : 0, - "input_height" : 448, - "input_width" : 448, + "max_len": 280, + "n_epochs" : 3, + "input_height" : 32, + "input_width" : 512, "weight_decay" : 1e-6, - "n_batch" : 1, - "learning_rate": 1e-4, + "n_batch" : 4, + "learning_rate": 1e-5, + "save_interval": 1500, "patches" : false, "pretraining" : true, "augmentation" : true, "flip_aug" : false, - "blur_aug" : false, + "blur_aug" : true, "scaling" : false, "adding_rgb_background": true, "adding_rgb_foreground": true, - "add_red_textlines": false, - "channels_shuffling": false, - "degrading": false, - "brightening": false, + "add_red_textlines": true, + "white_noise_strap": true, + "textline_right_in_depth": true, + "textline_left_in_depth": true, + "textline_up_in_depth": true, + "textline_down_in_depth": true, + "textline_right_in_depth_bin": true, + "textline_left_in_depth_bin": true, + "textline_up_in_depth_bin": true, + "textline_down_in_depth_bin": true, + "bin_deg": true, + "textline_skewing": true, + "textline_skewing_bin": true, + "channels_shuffling": true, + "degrading": true, + "brightening": true, "binarization" : true, + "pepper_aug": true, + "pepper_bin_aug": true, + "image_inversion": true, "scaling_bluring" : false, "scaling_binarization" : false, "scaling_flip" : false, "rotation": false, - "rotation_not_90": false, + "color_padding_rotation": true, + "padding_white": true, + "rotation_not_90": true, "transformer_num_patches_xy": [56, 56], "transformer_patchsize_x": 4, "transformer_patchsize_y": 4, @@ -34,13 +53,18 @@ "transformer_layers": 1, "transformer_num_heads": 1, "transformer_cnn_first": false, - "blur_k" : ["blur","guass","median"], + "blur_k" : ["blur","gauss","median"], + "padd_colors" : ["white", "black"], "scales" : [0.6, 0.7, 0.8, 0.9], "brightness" : [1.3, 1.5, 1.7, 2], "degrade_scales" : [0.2, 0.4], + "pepper_indexes": [0.01, 0.005], + "skewing_amplitudes" : [5, 8], "flip_index" : [0, 1, -1], "shuffle_indexes" : [ [0,2,1], [1,2,0], [1,0,2] , [2,1,0]], - "thetha" : [5, -5], + "thetha" : [0.1, 0.2, -0.1, -0.2], + "thetha_padd": [-0.6, -1, -1.4, -1.8, 0.6, 1, 1.4, 1.8], + "white_padds" : [0.1, 0.3, 0.5, 0.7, 0.9], "number_of_backgrounds_per_image": 2, "continue_training": false, "index_start" : 0, @@ -48,11 +72,12 @@ "weighted_loss": false, "is_loss_soft_dice": false, "data_is_provided": false, - "dir_train": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/train_new", + "dir_train": "/home/vahid/extracted_lines/1919_bin/train", "dir_eval": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/eval_new", - "dir_output": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/output_new", + "dir_output": "/home/vahid/extracted_lines/1919_bin/output", "dir_rgb_backgrounds": "/home/vahid/Documents/1_2_test_eynollah/set_rgb_background", "dir_rgb_foregrounds": "/home/vahid/Documents/1_2_test_eynollah/out_set_rgb_foreground", - "dir_img_bin": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/train_new/images_bin" + "dir_img_bin": "/home/vahid/extracted_lines/1919_bin/images_bin", + "characters_txt_file":"/home/vahid/Downloads/models_eynollah/model_eynollah_ocr_cnnrnn_20250930/characters_org.txt" }