This commit is contained in:
Robert Sachunsky 2026-02-17 17:39:56 +00:00 committed by GitHub
commit 8162f64297
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
30 changed files with 3173 additions and 4200 deletions

View file

@ -47,9 +47,9 @@ on how to generate the corresponding training dataset.
The following three tasks can all be accomplished using the code in the The following three tasks can all be accomplished using the code in the
[`train`](https://github.com/qurator-spk/eynollah/tree/main/train) directory: [`train`](https://github.com/qurator-spk/eynollah/tree/main/train) directory:
* generate training dataset * [Generate training dataset](#generate-training-dataset)
* train a model * [Train a model](#train-a-model)
* inference with the trained model * [Inference with the trained model](#inference-with-the-trained-model)
## Training, evaluation and output ## Training, evaluation and output
@ -101,7 +101,7 @@ serve as labels. The enhancement model can be trained with this generated datase
For machine-based reading order, we aim to determine the reading priority between two sets of text regions. The model's For machine-based reading order, we aim to determine the reading priority between two sets of text regions. The model's
input is a three-channel image: the first and last channels contain information about each of the two text regions, input is a three-channel image: the first and last channels contain information about each of the two text regions,
while the middle channel encodes prominent layout elements necessary for reading order, such as separators and headers. while the middle channel encodes prominent layout elements necessary for reading order, such as separators and headers.
To generate the training dataset, our script requires a page XML file that specifies the image layout with the correct To generate the training dataset, our script requires a PAGE XML file that specifies the image layout with the correct
reading order. reading order.
For output images, it is necessary to specify the width and height. Additionally, a minimum text region size can be set For output images, it is necessary to specify the width and height. Additionally, a minimum text region size can be set
@ -120,8 +120,14 @@ eynollah-training generate-gt machine-based-reading-order \
### pagexml2label ### pagexml2label
pagexml2label is designed to generate labels from GT page XML files for various pixel-wise segmentation use cases, `pagexml2label` is designed to generate labels from PAGE XML GT files for various pixel-wise segmentation use cases,
including 'layout,' 'textline,' 'printspace,' 'glyph,' and 'word' segmentation. including:
- `printspace` (i.e. page frame),
- `layout` (i.e. regions),
- `textline`,
- `word`, and
- `glyph`.
To train a pixel-wise segmentation model, we require images along with their corresponding labels. Our training script To train a pixel-wise segmentation model, we require images along with their corresponding labels. Our training script
expects a PNG image where each pixel corresponds to a label, represented by an integer. The background is always labeled expects a PNG image where each pixel corresponds to a label, represented by an integer. The background is always labeled
as zero, while other elements are assigned different integers. For instance, if we have ground truth data with four as zero, while other elements are assigned different integers. For instance, if we have ground truth data with four
@ -131,7 +137,7 @@ In binary segmentation scenarios such as textline or page extraction, the backgr
element is automatically encoded as 1 in the PNG label. element is automatically encoded as 1 in the PNG label.
To specify the desired use case and the elements to be extracted in the PNG labels, a custom JSON file can be passed. To specify the desired use case and the elements to be extracted in the PNG labels, a custom JSON file can be passed.
For example, in the case of 'textline' detection, the JSON file would resemble this: For example, in the case of textline detection, the JSON contents could be this:
```yaml ```yaml
{ {
@ -139,61 +145,77 @@ For example, in the case of 'textline' detection, the JSON file would resemble t
} }
``` ```
In the case of layout segmentation a custom config json file can look like this: In the case of layout segmentation, the config JSON file might look like this:
```yaml ```yaml
{ {
"use_case": "layout", "use_case": "layout",
"textregions":{"rest_as_paragraph":1 , "drop-capital": 1, "header":2, "heading":2, "marginalia":3}, "textregions": {"rest_as_paragraph": 1, "drop-capital": 1, "header": 2, "heading": 2, "marginalia": 3},
"imageregion":4, "imageregion": 4,
"separatorregion":5, "separatorregion": 5,
"graphicregions" :{"rest_as_decoration":6 ,"stamp":7} "graphicregions": {"rest_as_decoration": 6, "stamp": 7}
} }
``` ```
A possible custom config json file for layout segmentation where the "printspace" is a class: The same example if `PrintSpace` (or `Border`) should be represented as a unique class:
```yaml ```yaml
{ {
"use_case": "layout", "use_case": "layout",
"textregions":{"rest_as_paragraph":1 , "drop-capital": 1, "header":2, "heading":2, "marginalia":3}, "textregions": {"rest_as_paragraph": 1, "drop-capital": 1, "header": 2, "heading": 2, "marginalia": 3},
"imageregion":4, "imageregion": 4,
"separatorregion":5, "separatorregion": 5,
"graphicregions" :{"rest_as_decoration":6 ,"stamp":7} "graphicregions": {"rest_as_decoration": 6, "stamp": 7}
"printspace_as_class_in_layout" : 8 "printspace_as_class_in_layout": 8
} }
``` ```
For the layout use case, it is beneficial to first understand the structure of the page XML file and its elements. In the `layout` use-case, it is beneficial to first understand the structure of the PAGE XML file and its elements.
In a given image, the annotations of elements are recorded in a page XML file, including their contours and classes. For a given page image, the visible segments are annotated in XML with their polygon coordinates and types.
For an image document, the known regions are 'textregion', 'separatorregion', 'imageregion', 'graphicregion', On the region level, available segment types include `TextRegion`, `SeparatorRegion`, `ImageRegion`, `GraphicRegion`,
'noiseregion', and 'tableregion'. `NoiseRegion` and `TableRegion`.
Text regions and graphic regions also have their own specific types. The known types for text regions are 'paragraph', Moreover, text regions and graphic regions in particular are subdivided via `@type`:
'header', 'heading', 'marginalia', 'drop-capital', 'footnote', 'footnote-continued', 'signature-mark', 'page-number', - The allowed subtypes for text regions are `paragraph`, `heading`, `marginalia`, `drop-capital`, `header`, `footnote`,
and 'catch-word'. The known types for graphic regions are 'handwritten-annotation', 'decoration', 'stamp', and `footnote-continued`, `signature-mark`, `page-number` and `catch-word`.
'signature'. - The known subtypes for graphic regions are `handwritten-annotation`, `decoration`, `stamp` and `signature`.
Since we don't know all types of text and graphic regions, unknown cases can arise. To handle these, we have defined
two additional types, "rest_as_paragraph" and "rest_as_decoration", to ensure that no unknown types are missed.
This way, users can extract all known types from the labels and be confident that no unknown types are overlooked.
In the custom JSON file shown above, "header" and "heading" are extracted as the same class, while "marginalia" is shown These types and subtypes must be mapped to classes for the segmentation model. However, sometimes these fine-grained
as a different class. All other text region types, including "drop-capital," are grouped into the same class. For the distinctions are not useful or the existing annotations are not very usable (too scarce or too unreliable).
graphic region, "stamp" has its own class, while all other types are classified together. "Image region" and "separator In that case, instead of these subtypes with a specific mapping, they can be pooled together by using the two special
region" are also present in the label. However, other regions like "noise region" and "table region" will not be types:
included in the label PNG file, even if they have information in the page XML files, as we chose not to include them. - `rest_as_paragraph` (mapping missing TextRegion subtypes and `paragraph`)
- `rest_as_decoration` (mapping missing GraphicRegion subtypes and `decoration`)
(That way, users can extract all known types from the labels and be confident that no subtypes are overlooked.)
In the custom JSON example shown above, `header` and `heading` are extracted as the same class,
while `marginalia` is modelled as a different class. All other text region types, including `drop-capital`,
are grouped into the same class. For graphic regions, `stamp` has its own class, while all other types
are classified together. `ImageRegion` and `SeparatorRegion` will also represented with a class label in the
training data. However, other regions like `NoiseRegion` or `TableRegion` will not be included in the PNG files,
even if they were present in the PAGE XML.
The tool expects various command-line options:
```sh ```sh
eynollah-training generate-gt pagexml2label \ eynollah-training generate-gt pagexml2label \
-dx "dir of GT xml files" \ -dx "dir of input PAGE XML files" \
-do "dir where output label png files will be written" \ -do "dir of output label PNG files" \
-cfg "custom config json file" \ -cfg "custom config JSON file" \
-to "output type which has 2d and 3d. 2d is used for training and 3d is just to visualise the labels" -to "output type (2d or 3d)"
``` ```
We have also defined an artificial class that can be added to the boundary of text region types or text lines. This key As output type, use
is called "artificial_class_on_boundary." If users want to apply this to certain text regions in the layout use case, - `2d` for training,
the example JSON config file should look like this: - `3d` to just visualise the labels.
We have also defined an artificial class that can be added to (rendered around) the boundary
of text region types or text lines in order to make separation of neighbouring segments more
reliable. The key is called `artificial_class_on_boundary`, and it takes a list of text region
types to be applied to.
Our example JSON config file could then look like this:
```yaml ```yaml
{ {
@ -215,14 +237,15 @@ the example JSON config file should look like this:
} }
``` ```
This implies that the artificial class label, denoted by 7, will be present on PNG files and will only be added to the This implies that the artificial class label (denoted by 7) will be present in the generated PNG files
elements labeled as "paragraph," "header," "heading," and "marginalia." and will only be added around segments labeled `paragraph`, `header`, `heading` or `marginalia`. (This
class will be handled specially during decoding at inference, and not show up in final results.)
For "textline", "word", and "glyph", the artificial class on the boundaries will be activated only if the For `printspace`, `textline`, `word`, and `glyph` segmentation use-cases, there is no `artificial_class_on_boundary` key,
"artificial_class_label" key is specified in the config file. Its value should be set as 2 since these elements but `artificial_class_label` is available. If specified in the config file, then its value should be set at 2, because
represent binary cases. For example, if the background and textline are denoted as 0 and 1 respectively, then the these elements represent binary classification problems (with background represented as 0, and segments as 1, respectively).
artificial class should be assigned the value 2. The example JSON config file should look like this for "textline" use
case: For example, the JSON config for textline detection could look as follows:
```yaml ```yaml
{ {
@ -231,33 +254,33 @@ case:
} }
``` ```
If the coordinates of "PrintSpace" or "Border" are present in the page XML ground truth files, and the user wishes to If the coordinates of `PrintSpace` (or `Border`) are present in the PAGE XML ground truth files,
crop only the print space area, this can be achieved by activating the "-ps" argument. However, it should be noted that and one wishes to crop images to only cover the print space bounding box, this can be achieved
in this scenario, since cropping will be applied to the label files, the directory of the original images must be by passing the `-ps` option. Note that in this scenario, the directory of the original images
provided to ensure that they are cropped in sync with the labels. This ensures that the correct images and labels must also be provided, to ensure that the images are cropped in sync with the labels. The command
required for training are obtained. The command should resemble the following: line would then resemble this:
```sh ```sh
eynollah-training generate-gt pagexml2label \ eynollah-training generate-gt pagexml2label \
-dx "dir of GT xml files" \ -dx "dir of input PAGE XML files" \
-do "dir where output label png files will be written" \ -do "dir of output label PNG files" \
-cfg "custom config json file" \ -cfg "custom config JSON file" \
-to "output type which has 2d and 3d. 2d is used for training and 3d is just to visualise the labels" \ -to "output type (2d or 3d)" \
-ps \ -ps \
-di "dir where the org images are located" \ -di "dir of input original images" \
-doi "dir where the cropped output images will be written" -doi "dir of output cropped images"
``` ```
## Train a model ## Train a model
### classification ### classification
For the classification use case, we haven't provided a ground truth generator, as it's unnecessary. For classification, For the image classification use-case, we have not provided a ground truth generator, as it is unnecessary.
all we require is a training directory with subdirectories, each containing images of its respective classes. We need All we require is a training directory with subdirectories, each containing images of its respective classes. We need
separate directories for training and evaluation, and the class names (subdirectories) must be consistent across both separate directories for training and evaluation, and the class names (subdirectories) must be consistent across both
directories. Additionally, the class names should be specified in the config JSON file, as shown in the following directories. Additionally, the class names should be specified in the config JSON file, as shown in the following
example. If, for instance, we aim to classify "apple" and "orange," with a total of 2 classes, the example. If, for instance, we aim to classify "apple" and "orange," with a total of 2 classes, the
"classification_classes_name" key in the config file should appear as follows: `classification_classes_name` key in the config file should appear as follows:
```yaml ```yaml
{ {
@ -279,7 +302,7 @@ example. If, for instance, we aim to classify "apple" and "orange," with a total
} }
``` ```
The "dir_train" should be like this: Then `dir_train` should be like this:
``` ```
. .
@ -288,7 +311,7 @@ The "dir_train" should be like this:
└── orange # directory of images for orange class └── orange # directory of images for orange class
``` ```
And the "dir_eval" the same structure as train directory: And `dir_eval` analogously:
``` ```
. .
@ -348,7 +371,7 @@ And the "dir_eval" the same structure as train directory:
└── labels # directory of labels └── labels # directory of labels
``` ```
The classification model can be trained like the classification case command line. The reading-order model can be trained like the classification case command line.
### Segmentation (Textline, Binarization, Page extraction and layout) and enhancement ### Segmentation (Textline, Binarization, Page extraction and layout) and enhancement
@ -358,51 +381,17 @@ The following parameter configuration can be applied to all segmentation use cas
its sub-parameters, and continued training are defined only for segmentation use cases and enhancements, not for its sub-parameters, and continued training are defined only for segmentation use cases and enhancements, not for
classification and machine-based reading order, as you can see in their example config files. classification and machine-based reading order, as you can see in their example config files.
* `backbone_type`: For segmentation tasks (such as text line, binarization, and layout detection) and enhancement, we * `task`: The task parameter must be one of the following values:
offer two backbone options: a "nontransformer" and a "transformer" backbone. For the "transformer" backbone, we first - `binarization`,
apply a CNN followed by a transformer. In contrast, the "nontransformer" backbone utilizes only a CNN ResNet-50. - `enhancement`,
* `task`: The task parameter can have values such as "segmentation", "enhancement", "classification", and "reading_order". - `segmentation`,
* `patches`: If you want to break input images into smaller patches (input size of the model) you need to set this - `classification`,
* parameter to `true`. In the case that the model should see the image once, like page extraction, patches should be - `reading_order`.
set to ``false``. * `backbone_type`: For the tasks `segmentation` (such as text line, and region layout detection),
* `n_batch`: Number of batches at each iteration. `binarization` and `enhancement`, we offer two backbone options:
* `n_classes`: Number of classes. In the case of binary classification this should be 2. In the case of reading_order it - `nontransformer` (only a CNN ResNet-50).
should set to 1. And for the case of layout detection just the unique number of classes should be given. - `transformer` (first apply a CNN, followed by a transformer)
* `n_epochs`: Number of epochs. * `transformer_cnn_first`: Whether to apply the CNN first (followed by the transformer) when using `transformer` backbone.
* `input_height`: This indicates the height of model's input.
* `input_width`: This indicates the width of model's input.
* `weight_decay`: Weight decay of l2 regularization of model layers.
* `pretraining`: Set to `true` to load pretrained weights of ResNet50 encoder. The downloaded weights should be saved
in a folder named "pretrained_model" in the same directory of "train.py" script.
* `augmentation`: If you want to apply any kind of augmentation this parameter should first set to `true`.
* `flip_aug`: If `true`, different types of filp will be applied on image. Type of flips is given with "flip_index" parameter.
* `blur_aug`: If `true`, different types of blurring will be applied on image. Type of blurrings is given with "blur_k" parameter.
* `scaling`: If `true`, scaling will be applied on image. Scale of scaling is given with "scales" parameter.
* `degrading`: If `true`, degrading will be applied to the image. The amount of degrading is defined with "degrade_scales" parameter.
* `brightening`: If `true`, brightening will be applied to the image. The amount of brightening is defined with "brightness" parameter.
* `rotation_not_90`: If `true`, rotation (not 90 degree) will be applied on image. Rotation angles are given with "thetha" parameter.
* `rotation`: If `true`, 90 degree rotation will be applied on image.
* `binarization`: If `true`,Otsu thresholding will be applied to augment the input data with binarized images.
* `scaling_bluring`: If `true`, combination of scaling and blurring will be applied on image.
* `scaling_binarization`: If `true`, combination of scaling and binarization will be applied on image.
* `scaling_flip`: If `true`, combination of scaling and flip will be applied on image.
* `flip_index`: Type of flips.
* `blur_k`: Type of blurrings.
* `scales`: Scales of scaling.
* `brightness`: The amount of brightenings.
* `thetha`: Rotation angles.
* `degrade_scales`: The amount of degradings.
* `continue_training`: If `true`, it means that you have already trained a model and you would like to continue the
training. So it is needed to providethe dir of trained model with "dir_of_start_model" and index for naming
themodels. For example if you have already trained for 3 epochs then your lastindex is 2 and if you want to continue
from model_1.h5, you can set `index_start` to 3 to start naming model with index 3.
* `weighted_loss`: If `true`, this means that you want to apply weighted categorical_crossentropy as loss fucntion. Be carefull if you set to `true`the parameter "is_loss_soft_dice" should be ``false``
* `data_is_provided`: If you have already provided the input data you can set this to `true`. Be sure that the train
and eval data are in"dir_output".Since when once we provide training data we resize and augmentthem and then wewrite
them in sub-directories train and eval in "dir_output".
* `dir_train`: This is the directory of "images" and "labels" (dir_train should include two subdirectories with names of images and labels ) for raw images and labels. Namely they are not prepared (not resized and not augmented) yet for training the model. When we run this tool these raw data will be transformed to suitable size needed for the model and they will be written in "dir_output" in train and eval directories. Each of train and eval include "images" and "labels" sub-directories.
* `index_start`: Starting index for saved models in the case that "continue_training" is `true`.
* `dir_of_start_model`: Directory containing pretrained model to continue training the model in the case that "continue_training" is `true`.
* `transformer_num_patches_xy`: Number of patches for vision transformer in x and y direction respectively. * `transformer_num_patches_xy`: Number of patches for vision transformer in x and y direction respectively.
* `transformer_patchsize_x`: Patch size of vision transformer patches in x direction. * `transformer_patchsize_x`: Patch size of vision transformer patches in x direction.
* `transformer_patchsize_y`: Patch size of vision transformer patches in y direction. * `transformer_patchsize_y`: Patch size of vision transformer patches in y direction.
@ -410,11 +399,63 @@ classification and machine-based reading order, as you can see in their example
* `transformer_mlp_head_units`: Transformer Multilayer Perceptron (MLP) head units. Default value is [128, 64]. * `transformer_mlp_head_units`: Transformer Multilayer Perceptron (MLP) head units. Default value is [128, 64].
* `transformer_layers`: transformer layers. Default value is 8. * `transformer_layers`: transformer layers. Default value is 8.
* `transformer_num_heads`: Transformer number of heads. Default value is 4. * `transformer_num_heads`: Transformer number of heads. Default value is 4.
* `transformer_cnn_first`: We have two types of vision transformers. In one type, a CNN is applied first, followed by a transformer. In the other type, this order is reversed. If transformer_cnn_first is true, it means the CNN will be applied before the transformer. Default value is true. * `patches`: Whether to break up (tile) input images into smaller patches (input size of the model).
If `false`, the model will see the image once (resized to the input size of the model).
Should be set to `false` for cases like page extraction.
* `n_batch`: Number of batches at each iteration.
* `n_classes`: Number of classes. In the case of binary classification this should be 2. In the case of reading_order it
should set to 1. And for the case of layout detection just the unique number of classes should be given.
* `n_epochs`: Number of epochs (iterations over the data) to train.
* `input_height`: the image height for the model's input.
* `input_width`: the image width for the model's input.
* `weight_decay`: Weight decay of l2 regularization of model layers.
* `weighted_loss`: If `true`, this means that you want to apply weighted categorical crossentropy as loss function.
(Mutually exclusive with `is_loss_soft_dice`, and only applies for `segmentation` and `binarization` tasks.)
* `pretraining`: Set to `true` to (download and) initialise pretrained weights of ResNet50 encoder.
* `dir_train`: Path to directory of raw training data (as extracted via `pagexml2labels`, i.e. with subdirectories
`images` and `labels` for input images and output labels.
(These are not prepared for training the model, yet. Upon first run, the raw data will be transformed to suitable size
needed for the model, and written in `dir_output` under `train` and `eval` subdirectories. See `data_is_provided`.)
* `dir_eval`: Ditto for raw evaluation data.
* `dir_output`: Directory to write model checkpoints, logs (for Tensorboard) and precomputed images to.
* `data_is_provided`: If you have already trained at least one complete epoch (using the same data settings) before,
you can set this to `true` to avoid computing the resized / patched / augmented image files again.
Be sure that there are subdirectories `train` and `eval` data are in `dir_output` (each with subdirectories `images`
and `labels`, respectively).
* `continue_training`: If `true`, continue training a model checkpoint from a previous run.
This requires providing the directory of the model checkpoint to load via `dir_of_start_model`
and setting `index_start` counter for naming new checkpoints.
For example if you have already trained for 3 epochs, then your last index is 2, so if you want
to continue with `model_04`, `model_05` etc., set `index_start=3`.
* `index_start`: Starting index for saving models in the case that `continue_training` is `true`.
(Existing checkpoints above this will be overwritten.)
* `dir_of_start_model`: Directory containing existing model checkpoint to initialise model weights from when `continue_training=true`.
(Can be an epoch-interval checkpoint, or batch-interval checkpoint from `save_interval`.)
* `augmentation`: If you want to apply any kind of augmentation this parameter should first set to `true`.
The remaining settings pertain to that...
* `flip_aug`: If `true`, different types of flipping over the image arrays. Requires `flip_index` parameter.
* `flip_index`: List of flip codes (as in `cv2.flip`, i.e. 0 for vertical, positive for horizontal shift, negative for vertical and horizontal shift).
* `blur_aug`: If `true`, different types of blurring will be applied on image. Requires `blur_k` parameter.
* `blur_k`: Method of blurring (`gauss`, `median` or `blur`).
* `scaling`: If `true`, scaling will be applied on image. Requires `scales` parameter.
* `scales`: List of scale factors for scaling.
* `scaling_bluring`: If `true`, combination of scaling and blurring will be applied on image.
* `scaling_binarization`: If `true`, combination of scaling and binarization will be applied on image.
* `scaling_flip`: If `true`, combination of scaling and flip will be applied on image.
* `degrading`: If `true`, degrading will be applied to the image. Requires `degrade_scales` parameter.
* `degrade_scales`: List of intensity factors for degrading.
* `brightening`: If `true`, brightening will be applied to the image. Requires `brightness` parameter.
* `brightness`: List of intensity factors for brightening.
* `binarization`: If `true`, Otsu thresholding will be applied to augment the input data with binarized images.
* `dir_img_bin`: With `binarization`, use this directory to read precomputed binarized images instead of ad-hoc Otsu.
(Base names should correspond to the files in `dir_train/images`.)
* `rotation`: If `true`, 90° rotation will be applied on images.
* `rotation_not_90`: If `true`, random rotation (other than 90°) will be applied on image. Requires `thetha` parameter.
* `thetha`: List of rotation angles (in degrees).
In the case of segmentation and enhancement the train and evaluation directory should be as following. In case of segmentation and enhancement the train and evaluation data should be organised as follows.
The "dir_train" should be like this: The "dir_train" directory should be like this:
``` ```
. .
@ -432,11 +473,12 @@ And the "dir_eval" the same structure as train directory:
└── labels # directory of labels └── labels # directory of labels
``` ```
After configuring the JSON file for segmentation or enhancement, training can be initiated by running the following After configuring the JSON file for segmentation or enhancement,
command, similar to the process for classification and reading order: training can be initiated by running the following command line,
similar to classification and reading-order model training:
``` ```sh
eynollah-training train with config_classification.json` eynollah-training train with config_classification.json
``` ```
#### Binarization #### Binarization
@ -728,7 +770,7 @@ This will straightforwardly return the class of the image.
### machine based reading order ### machine based reading order
To infer the reading order using a reading order model, we need a page XML file containing layout information but To infer the reading order using a reading order model, we need a PAGE XML file containing layout information but
without the reading order. We simply need to provide the model directory, the XML file, and the output directory. The without the reading order. We simply need to provide the model directory, the XML file, and the output directory. The
new XML file with the added reading order will be written to the output directory with the same name. We need to run: new XML file with the added reading order will be written to the output directory with the same name. We need to run:

View file

@ -1,8 +1,9 @@
# ocrd includes opencv, numpy, shapely, click # ocrd includes opencv, numpy, shapely, click
ocrd >= 3.3.0 ocrd >= 3.3.0
numpy <1.24.0 numpy < 2.0
scikit-learn >= 0.23.2 scikit-learn >= 0.23.2
tensorflow < 2.13 tensorflow
tf-keras # avoid keras 3 (also needs TF_USE_LEGACY_KERAS=1)
numba <= 0.58.1 numba <= 0.58.1
scikit-image scikit-image
biopython biopython

View file

@ -2,14 +2,12 @@
# this must be the first import of the CLI! # this must be the first import of the CLI!
from ..eynollah_imports import imported_libs from ..eynollah_imports import imported_libs
from .cli_models import models_cli
from .cli_binarize import binarize_cli
from .cli import main from .cli import main
from .cli_binarize import binarize_cli from .cli_binarize import binarize_cli
from .cli_enhance import enhance_cli from .cli_enhance import enhance_cli
from .cli_extract_images import extract_images_cli from .cli_extract_images import extract_images_cli
from .cli_layout import layout_cli from .cli_layout import layout_cli
from .cli_models import models_cli
from .cli_ocr import ocr_cli from .cli_ocr import ocr_cli
from .cli_readingorder import readingorder_cli from .cli_readingorder import readingorder_cli

View file

@ -21,13 +21,20 @@ import click
type=click.Path(file_okay=True, dir_okay=True), type=click.Path(file_okay=True, dir_okay=True),
required=True, required=True,
) )
@click.option(
"--overwrite",
"-O",
help="overwrite (instead of skipping) if output xml exists",
is_flag=True,
)
@click.pass_context @click.pass_context
def binarize_cli( def binarize_cli(
ctx, ctx,
patches, patches,
input_image, input_image,
dir_in, dir_in,
output, output,
overwrite,
): ):
""" """
Binarize images with a ML model Binarize images with a ML model
@ -39,6 +46,7 @@ def binarize_cli(
image_path=input_image, image_path=input_image,
use_patches=patches, use_patches=patches,
output=output, output=output,
dir_in=dir_in dir_in=dir_in,
overwrite=overwrite
) )

View file

@ -116,19 +116,19 @@ class EynollahImageExtractor(Eynollah):
prediction_regions_org = prediction_regions_org[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]] prediction_regions_org = prediction_regions_org[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]]
prediction_regions_org=prediction_regions_org[:,:,0] prediction_regions_org=prediction_regions_org[:,:,0]
mask_lines_only = (prediction_regions_org[:,:] ==3)*1 mask_seps_only = (prediction_regions_org[:,:] ==3)*1
mask_texts_only = (prediction_regions_org[:,:] ==1)*1 mask_texts_only = (prediction_regions_org[:,:] ==1)*1
mask_images_only=(prediction_regions_org[:,:] ==2)*1 mask_images_only=(prediction_regions_org[:,:] ==2)*1
polygons_seplines, hir_seplines = return_contours_of_image(mask_lines_only) polygons_seplines, hir_seplines = return_contours_of_image(mask_seps_only)
polygons_seplines = filter_contours_area_of_image( polygons_seplines = filter_contours_area_of_image(
mask_lines_only, polygons_seplines, hir_seplines, max_area=1, min_area=0.00001, dilate=1) mask_seps_only, polygons_seplines, hir_seplines, max_area=1, min_area=0.00001, dilate=1)
polygons_of_only_texts = return_contours_of_interested_region(mask_texts_only,1,0.00001) polygons_of_only_texts = return_contours_of_interested_region(mask_texts_only,1,0.00001)
polygons_of_only_lines = return_contours_of_interested_region(mask_lines_only,1,0.00001) polygons_of_only_seps = return_contours_of_interested_region(mask_seps_only,1,0.00001)
text_regions_p_true = np.zeros(prediction_regions_org.shape) text_regions_p_true = np.zeros(prediction_regions_org.shape)
text_regions_p_true = cv2.fillPoly(text_regions_p_true, pts = polygons_of_only_lines, color=(3,3,3)) text_regions_p_true = cv2.fillPoly(text_regions_p_true, pts = polygons_of_only_seps, color=(3,3,3))
text_regions_p_true[:,:][mask_images_only[:,:] == 1] = 2 text_regions_p_true[:,:][mask_images_only[:,:] == 1] = 2
text_regions_p_true = cv2.fillPoly(text_regions_p_true, pts=polygons_of_only_texts, color=(1,1,1)) text_regions_p_true = cv2.fillPoly(text_regions_p_true, pts=polygons_of_only_texts, color=(1,1,1))
@ -255,24 +255,24 @@ class EynollahImageExtractor(Eynollah):
self.get_regions_light_v_extract_only_images(img_res, num_col_classifier) self.get_regions_light_v_extract_only_images(img_res, num_col_classifier)
pcgts = self.writer.build_pagexml_no_full_layout( pcgts = self.writer.build_pagexml_no_full_layout(
found_polygons_text_region=[], found_polygons_text_region=[],
page_coord=page_coord, page_coord=page_coord,
order_of_texts=[], order_of_texts=[],
all_found_textline_polygons=[], all_found_textline_polygons=[],
all_box_coord=[], all_box_coord=[],
found_polygons_text_region_img=polygons_of_images, found_polygons_text_region_img=polygons_of_images,
found_polygons_marginals_left=[], found_polygons_marginals_left=[],
found_polygons_marginals_right=[], found_polygons_marginals_right=[],
all_found_textline_polygons_marginals_left=[], all_found_textline_polygons_marginals_left=[],
all_found_textline_polygons_marginals_right=[], all_found_textline_polygons_marginals_right=[],
all_box_coord_marginals_left=[], all_box_coord_marginals_left=[],
all_box_coord_marginals_right=[], all_box_coord_marginals_right=[],
slopes=[], slopes=[],
slopes_marginals_left=[], slopes_marginals_left=[],
slopes_marginals_right=[], slopes_marginals_right=[],
cont_page=cont_page, cont_page=cont_page,
polygons_seplines=[], polygons_seplines=[],
found_polygons_tables=[], found_polygons_tables=[],
) )
if self.plotter: if self.plotter:
self.plotter.write_images_into_directory(polygons_of_images, image_page) self.plotter.write_images_into_directory(polygons_of_images, image_page)

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,9 @@
""" """
Load libraries with possible race conditions once. This must be imported as the first module of eynollah. Load libraries with possible race conditions once. This must be imported as the first module of eynollah.
""" """
import os
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
from ocrd_utils import tf_disable_interactive_logs from ocrd_utils import tf_disable_interactive_logs
from torch import * from torch import *
tf_disable_interactive_logs() tf_disable_interactive_logs()

View file

@ -15,11 +15,13 @@ from pathlib import Path
import gc import gc
import cv2 import cv2
from keras.models import Model
import numpy as np import numpy as np
import tensorflow as tf # type: ignore
from skimage.morphology import skeletonize from skimage.morphology import skeletonize
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
import tensorflow as tf # type: ignore
from tensorflow.keras.models import Model
from .model_zoo import EynollahModelZoo from .model_zoo import EynollahModelZoo
from .utils.resize import resize_image from .utils.resize import resize_image
from .utils.pil_cv2 import pil2cv from .utils.pil_cv2 import pil2cv

View file

@ -14,10 +14,12 @@ from pathlib import Path
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
import cv2 import cv2
from keras.models import Model
import numpy as np import numpy as np
import statistics import statistics
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
import tensorflow as tf import tensorflow as tf
from tensorflow.keras.models import Model
from .model_zoo import EynollahModelZoo from .model_zoo import EynollahModelZoo
from .utils.resize import resize_image from .utils.resize import resize_image

View file

@ -1,16 +1,19 @@
import os
import json import json
import logging import logging
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type, Union from typing import Dict, List, Optional, Tuple, Type, Union
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
from ocrd_utils import tf_disable_interactive_logs from ocrd_utils import tf_disable_interactive_logs
tf_disable_interactive_logs() tf_disable_interactive_logs()
from keras.layers import StringLookup from tensorflow.keras.layers import StringLookup
from keras.models import Model as KerasModel from tensorflow.keras.models import Model as KerasModel
from keras.models import load_model from tensorflow.keras.models import load_model
from tabulate import tabulate from tabulate import tabulate
from ..patch_encoder import PatchEncoder, Patches from ..patch_encoder import PatchEncoder, Patches
from .specs import EynollahModelSpecSet from .specs import EynollahModelSpecSet
from .default_specs import DEFAULT_MODEL_SPECS from .default_specs import DEFAULT_MODEL_SPECS

View file

@ -28,7 +28,19 @@
"full_layout": { "full_layout": {
"type": "boolean", "type": "boolean",
"default": true, "default": true,
"description": "Try to detect all element subtypes, including drop-caps and headings" "description": "Try to detect all region subtypes, including drop-capital and heading"
},
"light_version": {
"type": "boolean",
"default": true,
"enum": [true],
"description": "ignored (only for backwards-compatibility)"
},
"textline_light": {
"type": "boolean",
"default": true,
"enum": [true],
"description": "ignored (only for backwards-compatibility)"
}, },
"tables": { "tables": {
"type": "boolean", "type": "boolean",
@ -38,12 +50,12 @@
"curved_line": { "curved_line": {
"type": "boolean", "type": "boolean",
"default": false, "default": false,
"description": "try to return contour of textlines instead of just rectangle bounding box. Needs more processing time" "description": "retrieve textline polygons independent of each other (needs more processing time)"
}, },
"ignore_page_extraction": { "ignore_page_extraction": {
"type": "boolean", "type": "boolean",
"default": false, "default": false,
"description": "if this parameter set to true, this tool would ignore page extraction" "description": "if true, do not attempt page frame detection (cropping)"
}, },
"allow_scaling": { "allow_scaling": {
"type": "boolean", "type": "boolean",
@ -58,7 +70,7 @@
"right_to_left": { "right_to_left": {
"type": "boolean", "type": "boolean",
"default": false, "default": false,
"description": "if this parameter set to true, this tool will extract right-to-left reading order." "description": "if true, return reading order in right-to-left reading direction."
}, },
"headers_off": { "headers_off": {
"type": "boolean", "type": "boolean",
@ -123,13 +135,22 @@
} }
}, },
"resources": [ "resources": [
{
"url": "https://zenodo.org/records/17580627/files/models_all_v0_7_0.zip?download=1",
"name": "models_layout_v0_7_0",
"type": "archive",
"size": 6119874002,
"description": "Models for layout detection, reading order detection, textline detection, page extraction, column classification, table detection, binarization, image enhancement and OCR",
"version_range": ">= v0.7.0"
},
{ {
"url": "https://github.com/qurator-spk/sbb_binarization/releases/download/v0.0.11/saved_model_2020_01_16.zip", "url": "https://github.com/qurator-spk/sbb_binarization/releases/download/v0.0.11/saved_model_2020_01_16.zip",
"name": "default", "name": "default",
"type": "archive", "type": "archive",
"path_in_archive": "saved_model_2020_01_16", "path_in_archive": "saved_model_2020_01_16",
"size": 563147331, "size": 563147331,
"description": "default models provided by github.com/qurator-spk (SavedModel format)" "description": "default models provided by github.com/qurator-spk (SavedModel format)",
"version_range": "< v0.7.0"
}, },
{ {
"url": "https://github.com/qurator-spk/sbb_binarization/releases/download/v0.0.11/saved_model_2021_03_09.zip", "url": "https://github.com/qurator-spk/sbb_binarization/releases/download/v0.0.11/saved_model_2021_03_09.zip",
@ -137,7 +158,8 @@
"type": "archive", "type": "archive",
"path_in_archive": ".", "path_in_archive": ".",
"size": 133230419, "size": 133230419,
"description": "updated default models provided by github.com/qurator-spk (SavedModel format)" "description": "updated default models provided by github.com/qurator-spk (SavedModel format)",
"version_range": "< v0.7.0"
} }
] ]
} }

View file

@ -75,7 +75,7 @@ class SbbBinarizeProcessor(Processor):
if oplevel == 'page': if oplevel == 'page':
self.logger.info("Binarizing on 'page' level in page '%s'", page_id) self.logger.info("Binarizing on 'page' level in page '%s'", page_id)
page_image_bin = cv2pil(self.binarizer.run(image=pil2cv(page_image), use_patches=True)) page_image_bin = cv2pil(self.binarizer.run_single(image=pil2cv(page_image), use_patches=True))
# update PAGE (reference the image file): # update PAGE (reference the image file):
page_image_ref = AlternativeImageType(comments=page_xywh['features'] + ',binarized,clipped') page_image_ref = AlternativeImageType(comments=page_xywh['features'] + ',binarized,clipped')
page.add_AlternativeImage(page_image_ref) page.add_AlternativeImage(page_image_ref)
@ -88,7 +88,7 @@ class SbbBinarizeProcessor(Processor):
for region in regions: for region in regions:
region_image, region_xywh = self.workspace.image_from_segment( region_image, region_xywh = self.workspace.image_from_segment(
region, page_image, page_xywh, feature_filter='binarized') region, page_image, page_xywh, feature_filter='binarized')
region_image_bin = cv2pil(self.binarizer.run(image=pil2cv(region_image), use_patches=True)) region_image_bin = cv2pil(self.binarizer.run_single(image=pil2cv(region_image), use_patches=True))
# update PAGE (reference the image file): # update PAGE (reference the image file):
region_image_ref = AlternativeImageType(comments=region_xywh['features'] + ',binarized') region_image_ref = AlternativeImageType(comments=region_xywh['features'] + ',binarized')
region.add_AlternativeImage(region_image_ref) region.add_AlternativeImage(region_image_ref)
@ -100,7 +100,7 @@ class SbbBinarizeProcessor(Processor):
self.logger.warning("Page '%s' contains no text lines", page_id) self.logger.warning("Page '%s' contains no text lines", page_id)
for line in lines: for line in lines:
line_image, line_xywh = self.workspace.image_from_segment(line, page_image, page_xywh, feature_filter='binarized') line_image, line_xywh = self.workspace.image_from_segment(line, page_image, page_xywh, feature_filter='binarized')
line_image_bin = cv2pil(self.binarizer.run(image=pil2cv(line_image), use_patches=True)) line_image_bin = cv2pil(self.binarizer.run_single(image=pil2cv(line_image), use_patches=True))
# update PAGE (reference the image file): # update PAGE (reference the image file):
line_image_ref = AlternativeImageType(comments=line_xywh['features'] + ',binarized') line_image_ref = AlternativeImageType(comments=line_xywh['features'] + ',binarized')
line.add_AlternativeImage(line_image_ref) line.add_AlternativeImage(line_image_ref)

View file

@ -1,52 +1,46 @@
from keras import layers import os
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
import tensorflow as tf import tensorflow as tf
from tensorflow.keras import layers
projection_dim = 64
patch_size = 1
num_patches =21*21#14*14#28*28#14*14#28*28
class PatchEncoder(layers.Layer): class PatchEncoder(layers.Layer):
def __init__(self): # 441=21*21 # 14*14 # 28*28
def __init__(self, num_patches=441, projection_dim=64):
super().__init__() super().__init__()
self.projection = layers.Dense(units=projection_dim) self.num_patches = num_patches
self.position_embedding = layers.Embedding(input_dim=num_patches, output_dim=projection_dim) self.projection_dim = projection_dim
self.projection = layers.Dense(self.projection_dim)
self.position_embedding = layers.Embedding(self.num_patches, self.projection_dim)
def call(self, patch): def call(self, patch):
positions = tf.range(start=0, limit=num_patches, delta=1) positions = tf.range(start=0, limit=self.num_patches, delta=1)
encoded = self.projection(patch) + self.position_embedding(positions) return self.projection(patch) + self.position_embedding(positions)
return encoded
def get_config(self): def get_config(self):
config = super().get_config().copy() return dict(num_patches=self.num_patches,
config.update({ projection_dim=self.projection_dim,
'num_patches': num_patches, **super().get_config())
'projection': self.projection,
'position_embedding': self.position_embedding,
})
return config
class Patches(layers.Layer): class Patches(layers.Layer):
def __init__(self, **kwargs): def __init__(self, patch_size_x=1, patch_size_y=1):
super(Patches, self).__init__() super().__init__()
self.patch_size = patch_size self.patch_size_x = patch_size_x
self.patch_size_y = patch_size_y
def call(self, images): def call(self, images):
batch_size = tf.shape(images)[0] batch_size = tf.shape(images)[0]
patches = tf.image.extract_patches( patches = tf.image.extract_patches(
images=images, images=images,
sizes=[1, self.patch_size, self.patch_size, 1], sizes=[1, self.patch_size_y, self.patch_size_x, 1],
strides=[1, self.patch_size, self.patch_size, 1], strides=[1, self.patch_size_y, self.patch_size_x, 1],
rates=[1, 1, 1, 1], rates=[1, 1, 1, 1],
padding="VALID", padding="VALID",
) )
patch_dims = patches.shape[-1] patch_dims = patches.shape[-1]
patches = tf.reshape(patches, [batch_size, -1, patch_dims]) return tf.reshape(patches, [batch_size, -1, patch_dims])
return patches
def get_config(self):
config = super().get_config().copy() def get_config(self):
config.update({ return dict(patch_size_x=self.patch_size_x,
'patch_size': self.patch_size, patch_size_y=self.patch_size_y,
}) **super().get_config())
return config

View file

@ -9,17 +9,18 @@ Tool to load model and binarize a given image.
import os import os
import logging import logging
from pathlib import Path
from typing import Optional from typing import Optional
import numpy as np import numpy as np
import cv2 import cv2
from ocrd_utils import tf_disable_interactive_logs
from eynollah.model_zoo import EynollahModelZoo os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
from ocrd_utils import tf_disable_interactive_logs
tf_disable_interactive_logs() tf_disable_interactive_logs()
import tensorflow as tf import tensorflow as tf
from tensorflow.python.keras import backend as tensorflow_backend
from pathlib import Path from .model_zoo import EynollahModelZoo
from .utils import is_image_filename from .utils import is_image_filename
def resize_image(img_in, input_height, input_width): def resize_image(img_in, input_height, input_width):
@ -34,21 +35,13 @@ class SbbBinarizer:
logger: Optional[logging.Logger] = None, logger: Optional[logging.Logger] = None,
): ):
self.logger = logger if logger else logging.getLogger('eynollah.binarization') self.logger = logger if logger else logging.getLogger('eynollah.binarization')
try:
for device in tf.config.list_physical_devices('GPU'):
tf.config.experimental.set_memory_growth(device, True)
except:
self.logger.warning("no GPU device available")
self.models = (model_zoo.model_path('binarization'), model_zoo.load_model('binarization')) self.models = (model_zoo.model_path('binarization'), model_zoo.load_model('binarization'))
self.session = self.start_new_session() self.logger.info('Loaded model %s [%s]', self.models[1], self.models[0])
def start_new_session(self):
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
tensorflow_backend.set_session(session)
return session
def end_session(self):
tensorflow_backend.clear_session()
self.session.close()
del self.session
def predict(self, model, img, use_patches, n_batch_inference=5): def predict(self, model, img, use_patches, n_batch_inference=5):
model_height = model.layers[len(model.layers)-1].output_shape[1] model_height = model.layers[len(model.layers)-1].output_shape[1]
@ -311,34 +304,20 @@ class SbbBinarizer:
prediction_true = prediction_true.astype(np.uint8) prediction_true = prediction_true.astype(np.uint8)
return prediction_true[:,:,0] return prediction_true[:,:,0]
def run(self, image=None, image_path=None, output=None, use_patches=False, dir_in=None): def run(self, image=None, image_path=None, output=None, use_patches=False, dir_in=None, overwrite=False):
# print(dir_in,'dir_in')
if not dir_in: if not dir_in:
if (image is not None and image_path is not None) or \ if (image is None) == (image_path is None):
(image is None and image_path is None):
raise ValueError("Must pass either a opencv2 image or an image_path") raise ValueError("Must pass either a opencv2 image or an image_path")
if image_path is not None: if image_path is not None:
image = cv2.imread(image_path) image = cv2.imread(image_path)
img_last = 0 img_last = self.run_single(image, use_patches)
model_file, model = self.models
self.logger.info('Predicting %s with model %s', image_path if image_path else '[image]', model_file)
res = self.predict(model, image, use_patches)
img_fin = np.zeros((res.shape[0], res.shape[1], 3))
res[:, :][res[:, :] == 0] = 2
res = res - 1
res = res * 255
img_fin[:, :, 0] = res
img_fin[:, :, 1] = res
img_fin[:, :, 2] = res
img_fin = img_fin.astype(np.uint8)
img_fin = (res[:, :] == 0) * 255
img_last = img_last + img_fin
img_last[:, :][img_last[:, :] > 0] = 255
img_last = (img_last[:, :] == 0) * 255
if output: if output:
if os.path.exists(output):
if overwrite:
self.logger.warning("will overwrite existing output file '%s'", output)
else:
self.logger.warning("output file already exists '%s'", output)
return img_last
self.logger.info('Writing binarized image to %s', output) self.logger.info('Writing binarized image to %s', output)
cv2.imwrite(output, img_last) cv2.imwrite(output, img_last)
return img_last return img_last
@ -346,29 +325,38 @@ class SbbBinarizer:
ls_imgs = list(filter(is_image_filename, os.listdir(dir_in))) ls_imgs = list(filter(is_image_filename, os.listdir(dir_in)))
self.logger.info("Found %d image files to binarize in %s", len(ls_imgs), dir_in) self.logger.info("Found %d image files to binarize in %s", len(ls_imgs), dir_in)
for i, image_path in enumerate(ls_imgs): for i, image_path in enumerate(ls_imgs):
image_stem = os.path.splitext(image_path)[0]
output_path = os.path.join(output, image_stem + '.png')
if os.path.exists(output_path):
if overwrite:
self.logger.warning("will overwrite existing output file '%s'", output_path)
else:
self.logger.warning("will skip input for existing output file '%s'", output_path)
continue
self.logger.info('Binarizing [%3d/%d] %s', i + 1, len(ls_imgs), image_path) self.logger.info('Binarizing [%3d/%d] %s', i + 1, len(ls_imgs), image_path)
image_stem = Path(image_path).stem image = cv2.imread(os.path.join(dir_in, image_path))
image = cv2.imread(os.path.join(dir_in,image_path) ) img_last = self.run_single(image, use_patches)
img_last = 0 self.logger.info('Writing binarized image to %s', output_path)
model_file, model = self.models cv2.imwrite(output_path, img_last)
self.logger.info('Predicting %s with model %s', image_path if image_path else '[image]', model_file)
res = self.predict(model, image, use_patches)
img_fin = np.zeros((res.shape[0], res.shape[1], 3)) def run_single(self, image: np.ndarray, use_patches=False):
res[:, :][res[:, :] == 0] = 2 img_last = 0
res = res - 1 model_file, model = self.models
res = res * 255 res = self.predict(model, image, use_patches)
img_fin[:, :, 0] = res
img_fin[:, :, 1] = res
img_fin[:, :, 2] = res
img_fin = img_fin.astype(np.uint8) img_fin = np.zeros((res.shape[0], res.shape[1], 3))
img_fin = (res[:, :] == 0) * 255 res[:, :][res[:, :] == 0] = 2
img_last = img_last + img_fin res = res - 1
res = res * 255
img_fin[:, :, 0] = res
img_fin[:, :, 1] = res
img_fin[:, :, 2] = res
img_last[:, :][img_last[:, :] > 0] = 255 img_fin = img_fin.astype(np.uint8)
img_last = (img_last[:, :] == 0) * 255 img_fin = (res[:, :] == 0) * 255
img_last = img_last + img_fin
output_filename = os.path.join(output, image_stem + '.png')
self.logger.info('Writing binarized image to %s', output_filename) kernel = np.ones((5, 5), np.uint8)
cv2.imwrite(output_filename, img_last) img_last[:, :][img_last[:, :] > 0] = 255
img_last = (img_last[:, :] == 0) * 255
return img_last

View file

@ -1,13 +1,9 @@
import sys
import click import click
import tensorflow as tf
from .models import resnet50_unet from .models import resnet50_unet
def configuration():
gpu_options = tf.compat.v1.GPUOptions(allow_growth=True)
session = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(gpu_options=gpu_options))
@click.command() @click.command()
def build_model_load_pretrained_weights_and_save(): def build_model_load_pretrained_weights_and_save():
n_classes = 2 n_classes = 2
@ -17,8 +13,6 @@ def build_model_load_pretrained_weights_and_save():
pretraining = False pretraining = False
dir_of_weights = 'model_bin_sbb_ens.h5' dir_of_weights = 'model_bin_sbb_ens.h5'
# configuration()
model = resnet50_unet(n_classes, input_height, input_width, weight_decay, pretraining) model = resnet50_unet(n_classes, input_height, input_width, weight_decay, pretraining)
model.load_weights(dir_of_weights) model.load_weights(dir_of_weights)
model.save('./name_in_another_python_version.h5') model.save('./name_in_another_python_version.h5')

View file

@ -9,7 +9,7 @@ from .generate_gt_for_training import main as generate_gt_cli
from .inference import main as inference_cli from .inference import main as inference_cli
from .train import ex from .train import ex
from .extract_line_gt import linegt_cli from .extract_line_gt import linegt_cli
from .weights_ensembling import main as ensemble_cli from .weights_ensembling import ensemble_cli
@click.command(context_settings=dict( @click.command(context_settings=dict(
ignore_unknown_options=True, ignore_unknown_options=True,

View file

@ -7,7 +7,7 @@ from PIL import Image, ImageDraw, ImageFont
import cv2 import cv2
import numpy as np import numpy as np
from eynollah.training.gt_gen_utils import ( from .gt_gen_utils import (
filter_contours_area_of_image, filter_contours_area_of_image,
find_format_of_given_filename_in_dir, find_format_of_given_filename_in_dir,
find_new_features_of_contours, find_new_features_of_contours,
@ -26,6 +26,9 @@ from eynollah.training.gt_gen_utils import (
@click.group() @click.group()
def main(): def main():
"""
extract GT data suitable for model training for various tasks
"""
pass pass
@main.command() @main.command()
@ -74,6 +77,9 @@ def main():
) )
def pagexml2label(dir_xml,dir_out,type_output,config, printspace, dir_images, dir_out_images): def pagexml2label(dir_xml,dir_out,type_output,config, printspace, dir_images, dir_out_images):
"""
extract PAGE-XML GT data suitable for model training for segmentation tasks
"""
if config: if config:
with open(config) as f: with open(config) as f:
config_params = json.load(f) config_params = json.load(f)
@ -110,6 +116,9 @@ def pagexml2label(dir_xml,dir_out,type_output,config, printspace, dir_images, di
type=click.Path(exists=True, dir_okay=False), type=click.Path(exists=True, dir_okay=False),
) )
def image_enhancement(dir_imgs, dir_out_images, dir_out_labels, scales): def image_enhancement(dir_imgs, dir_out_images, dir_out_labels, scales):
"""
extract image GT data suitable for model training for image enhancement tasks
"""
ls_imgs = os.listdir(dir_imgs) ls_imgs = os.listdir(dir_imgs)
with open(scales) as f: with open(scales) as f:
scale_dict = json.load(f) scale_dict = json.load(f)
@ -175,6 +184,9 @@ def image_enhancement(dir_imgs, dir_out_images, dir_out_labels, scales):
) )
def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, input_height, input_width, min_area_size, min_area_early): def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, input_height, input_width, min_area_size, min_area_early):
"""
extract PAGE-XML GT data suitable for model training for reading-order task
"""
xml_files_ind = os.listdir(dir_xml) xml_files_ind = os.listdir(dir_xml)
xml_files_ind = [ind_xml for ind_xml in xml_files_ind if ind_xml.endswith('.xml')] xml_files_ind = [ind_xml for ind_xml in xml_files_ind if ind_xml.endswith('.xml')]
input_height = int(input_height) input_height = int(input_height)
@ -205,14 +217,20 @@ def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, i
img_header_and_sep = np.zeros((y_len,x_len), dtype='uint8') img_header_and_sep = np.zeros((y_len,x_len), dtype='uint8')
for j in range(len(cy_main)): for j in range(len(cy_main)):
img_header_and_sep[int(y_max_main[j]):int(y_max_main[j])+12,int(x_min_main[j]):int(x_max_main[j]) ] = 1 img_header_and_sep[int(y_max_main[j]):int(y_max_main[j])+12,
int(x_min_main[j]):int(x_max_main[j]) ] = 1
texts_corr_order_index = [index_tot_regions[tot_region_ref.index(i)] for i in id_all_text ] try:
texts_corr_order_index_int = [int(x) for x in texts_corr_order_index] texts_corr_order_index_int = [int(index_tot_regions[tot_region_ref.index(i)])
for i in id_all_text]
except ValueError as e:
print("incomplete ReadingOrder in", xml_file, "- skipping:", str(e))
continue
co_text_all, texts_corr_order_index_int, regions_ar_less_than_early_min = filter_contours_area_of_image(img_poly, co_text_all, texts_corr_order_index_int, max_area, min_area, min_area_early) co_text_all, texts_corr_order_index_int, regions_ar_less_than_early_min = \
filter_contours_area_of_image(img_poly, co_text_all, texts_corr_order_index_int,
max_area, min_area, min_area_early)
arg_array = np.array(range(len(texts_corr_order_index_int))) arg_array = np.array(range(len(texts_corr_order_index_int)))

View file

@ -1,15 +1,18 @@
import os import os
import numpy as np import numpy as np
import warnings import warnings
import xml.etree.ElementTree as ET from lxml import etree as ET
from tqdm import tqdm from tqdm import tqdm
import cv2 import cv2
from shapely import geometry from shapely import geometry
from pathlib import Path from pathlib import Path
from PIL import ImageFont from PIL import ImageFont
from ocrd_utils import bbox_from_points
KERNEL = np.ones((5, 5), np.uint8) KERNEL = np.ones((5, 5), np.uint8)
NS = { 'pc': 'http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15'
}
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
@ -235,12 +238,11 @@ def update_region_contours(co_text, img_boundary, erosion_rate, dilation_rate, y
con_eroded = return_contours_of_interested_region(img_boundary_in,pixel, min_size ) con_eroded = return_contours_of_interested_region(img_boundary_in,pixel, min_size )
try: try:
if len(con_eroded)>1: if len(con_eroded) > 1:
cnt_size = np.array([cv2.contourArea(con_eroded[j]) for j in range(len(con_eroded))]) largest = np.argmax(list(map(cv2.contourArea, con_eroded)))
cnt = contours[np.argmax(cnt_size)]
co_text_eroded.append(cnt)
else: else:
co_text_eroded.append(con_eroded[0]) largest = 0
co_text_eroded.append(con_eroded[largest])
except: except:
co_text_eroded.append(con) co_text_eroded.append(con)
@ -664,7 +666,10 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
if dir_images: if dir_images:
ls_org_imgs = os.listdir(dir_images) ls_org_imgs = os.listdir(dir_images)
ls_org_imgs_stem = [os.path.splitext(item)[0] for item in ls_org_imgs] ls_org_imgs = {os.path.splitext(item)[0]: item
for item in ls_org_imgs
if not item.endswith('.xml')}
for index in tqdm(range(len(gt_list))): for index in tqdm(range(len(gt_list))):
#try: #try:
print(gt_list[index]) print(gt_list[index])
@ -681,6 +686,7 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
if 'columns_width' in list(config_params.keys()): if 'columns_width' in list(config_params.keys()):
columns_width_dict = config_params['columns_width'] columns_width_dict = config_params['columns_width']
# FIXME: look in /Page/@custom as well
metadata_element = root1.find(link+'Metadata') metadata_element = root1.find(link+'Metadata')
num_col = None num_col = None
for child in metadata_element: for child in metadata_element:
@ -694,55 +700,13 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
y_new = int ( x_new * (y_len / float(x_len)) ) y_new = int ( x_new * (y_len / float(x_len)) )
if printspace or "printspace_as_class_in_layout" in list(config_params.keys()): if printspace or "printspace_as_class_in_layout" in list(config_params.keys()):
region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace') or x.endswith('Border')]) ps = (root1.xpath('/pc:PcGts/pc:Page/pc:Border', namespaces=NS) +
co_use_case = [] root1.xpath('/pc:PcGts/pc:Page/pc:PrintSpace', namespaces=NS))
if len(ps):
for tag in region_tags: points = ps[0].find('pc:Coords', NS).get('points')
tag_endings = ['}PrintSpace','}Border'] ps_bbox = bbox_from_points(points)
else:
if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]): ps_bbox = [0, 0, None, None]
for nn in root1.iter(tag):
c_t_in = []
sumi = 0
for vv in nn.iter():
# check the format of coords
if vv.tag == link + 'Coords':
coords = bool(vv.attrib)
if coords:
p_h = vv.attrib['points'].split(' ')
c_t_in.append(
np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h]))
break
else:
pass
if vv.tag == link + 'Point':
c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))])
sumi += 1
elif vv.tag != link + 'Point' and sumi >= 1:
break
co_use_case.append(np.array(c_t_in))
img = np.zeros((y_len, x_len, 3))
img_poly = cv2.fillPoly(img, pts=co_use_case, color=(1, 1, 1))
img_poly = img_poly.astype(np.uint8)
imgray = cv2.cvtColor(img_poly, cv2.COLOR_BGR2GRAY)
_, thresh = cv2.threshold(imgray, 0, 255, 0)
contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
cnt_size = np.array([cv2.contourArea(contours[j]) for j in range(len(contours))])
try:
cnt = contours[np.argmax(cnt_size)]
x, y, w, h = cv2.boundingRect(cnt)
except:
x, y , w, h = 0, 0, x_len, y_len
bb_xywh = [x, y, w, h]
if config_file and (config_params['use_case']=='textline' or config_params['use_case']=='word' or config_params['use_case']=='glyph' or config_params['use_case']=='printspace'): if config_file and (config_params['use_case']=='textline' or config_params['use_case']=='word' or config_params['use_case']=='glyph' or config_params['use_case']=='printspace'):
@ -824,7 +788,8 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
if printspace and config_params['use_case']!='printspace': if printspace and config_params['use_case']!='printspace':
img_poly = img_poly[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :] img_poly = img_poly[ps_bbox[1]:ps_bbox[3],
ps_bbox[0]:ps_bbox[2], :]
if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace': if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace':
@ -838,11 +803,18 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly) cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly)
if dir_images: if dir_images:
org_image_name = ls_org_imgs[ls_org_imgs_stem.index(xml_file_stem)] org_image_name = ls_org_imgs[xml_file_stem]
if not org_image_name:
print("image file for XML stem", xml_file_stem, "is missing")
continue
if not os.path.isfile(os.path.join(dir_images, org_image_name)):
print("image file for XML stem", xml_file_stem, "is not readable")
continue
img_org = cv2.imread(os.path.join(dir_images, org_image_name)) img_org = cv2.imread(os.path.join(dir_images, org_image_name))
if printspace and config_params['use_case']!='printspace': if printspace and config_params['use_case']!='printspace':
img_org = img_org[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :] img_org = img_org[ps_bbox[1]:ps_bbox[3],
ps_bbox[0]:ps_bbox[2], :]
if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace': if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace':
img_org = resize_image(img_org, y_new, x_new) img_org = resize_image(img_org, y_new, x_new)
@ -1254,7 +1226,8 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
if "printspace_as_class_in_layout" in list(config_params.keys()): if "printspace_as_class_in_layout" in list(config_params.keys()):
printspace_mask = np.zeros((img_poly.shape[0], img_poly.shape[1])) printspace_mask = np.zeros((img_poly.shape[0], img_poly.shape[1]))
printspace_mask[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2]] = 1 printspace_mask[ps_bbox[1]:ps_bbox[3],
ps_bbox[0]:ps_bbox[2]] = 1
img_poly[:,:,0][printspace_mask[:,:] == 0] = printspace_class_rgb_color[0] img_poly[:,:,0][printspace_mask[:,:] == 0] = printspace_class_rgb_color[0]
img_poly[:,:,1][printspace_mask[:,:] == 0] = printspace_class_rgb_color[1] img_poly[:,:,1][printspace_mask[:,:] == 0] = printspace_class_rgb_color[1]
@ -1315,7 +1288,8 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
if "printspace_as_class_in_layout" in list(config_params.keys()): if "printspace_as_class_in_layout" in list(config_params.keys()):
printspace_mask = np.zeros((img_poly.shape[0], img_poly.shape[1])) printspace_mask = np.zeros((img_poly.shape[0], img_poly.shape[1]))
printspace_mask[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2]] = 1 printspace_mask[ps_bbox[1]:ps_bbox[3],
ps_bbox[0]:ps_bbox[2]] = 1
img_poly[:,:,0][printspace_mask[:,:] == 0] = printspace_class_label img_poly[:,:,0][printspace_mask[:,:] == 0] = printspace_class_label
img_poly[:,:,1][printspace_mask[:,:] == 0] = printspace_class_label img_poly[:,:,1][printspace_mask[:,:] == 0] = printspace_class_label
@ -1324,7 +1298,8 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
if printspace: if printspace:
img_poly = img_poly[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :] img_poly = img_poly[ps_bbox[1]:ps_bbox[3],
ps_bbox[0]:ps_bbox[2], :]
if 'columns_width' in list(config_params.keys()) and num_col: if 'columns_width' in list(config_params.keys()) and num_col:
img_poly = resize_image(img_poly, y_new, x_new) img_poly = resize_image(img_poly, y_new, x_new)
@ -1338,11 +1313,18 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
if dir_images: if dir_images:
org_image_name = ls_org_imgs[ls_org_imgs_stem.index(xml_file_stem)] org_image_name = ls_org_imgs[xml_file_stem]
if not org_image_name:
print("image file for XML stem", xml_file_stem, "is missing")
continue
if not os.path.isfile(os.path.join(dir_images, org_image_name)):
print("image file for XML stem", xml_file_stem, "is not readable")
continue
img_org = cv2.imread(os.path.join(dir_images, org_image_name)) img_org = cv2.imread(os.path.join(dir_images, org_image_name))
if printspace: if printspace:
img_org = img_org[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :] img_org = img_org[ps_bbox[1]:ps_bbox[3],
ps_bbox[0]:ps_bbox[2], :]
if 'columns_width' in list(config_params.keys()) and num_col: if 'columns_width' in list(config_params.keys()) and num_col:
img_org = resize_image(img_org, y_new, x_new) img_org = resize_image(img_org, y_new, x_new)
@ -1383,6 +1365,7 @@ def find_new_features_of_contours(contours_main):
y_max_main = np.array([np.max(contours_main[j][:, 1]) for j in range(len(contours_main))]) y_max_main = np.array([np.max(contours_main[j][:, 1]) for j in range(len(contours_main))])
return cx_main, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, y_corr_x_min_from_argmin return cx_main, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, y_corr_x_min_from_argmin
def read_xml(xml_file): def read_xml(xml_file):
file_name = Path(xml_file).stem file_name = Path(xml_file).stem
tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8')) tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8'))
@ -1401,57 +1384,13 @@ def read_xml(xml_file):
index_tot_regions.append(jj.attrib['index']) index_tot_regions.append(jj.attrib['index'])
tot_region_ref.append(jj.attrib['regionRef']) tot_region_ref.append(jj.attrib['regionRef'])
if (link+'PrintSpace' in alltags) or (link+'Border' in alltags): ps = (root1.xpath('/pc:PcGts/pc:Page/pc:Border', namespaces=NS) +
co_printspace = [] root1.xpath('/pc:PcGts/pc:Page/pc:PrintSpace', namespaces=NS))
if link+'PrintSpace' in alltags: if len(ps):
region_tags_printspace = np.unique([x for x in alltags if x.endswith('PrintSpace')]) points = ps[0].find('pc:Coords', NS).get('points')
elif link+'Border' in alltags: ps_bbox = bbox_from_points(points)
region_tags_printspace = np.unique([x for x in alltags if x.endswith('Border')])
for tag in region_tags_printspace:
if link+'PrintSpace' in alltags:
tag_endings_printspace = ['}PrintSpace','}printspace']
elif link+'Border' in alltags:
tag_endings_printspace = ['}Border','}border']
if tag.endswith(tag_endings_printspace[0]) or tag.endswith(tag_endings_printspace[1]):
for nn in root1.iter(tag):
c_t_in = []
sumi = 0
for vv in nn.iter():
# check the format of coords
if vv.tag == link + 'Coords':
coords = bool(vv.attrib)
if coords:
p_h = vv.attrib['points'].split(' ')
c_t_in.append(
np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h]))
break
else:
pass
if vv.tag == link + 'Point':
c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))])
sumi += 1
elif vv.tag != link + 'Point' and sumi >= 1:
break
co_printspace.append(np.array(c_t_in))
img_printspace = np.zeros( (y_len,x_len,3) )
img_printspace=cv2.fillPoly(img_printspace, pts =co_printspace, color=(1,1,1))
img_printspace = img_printspace.astype(np.uint8)
imgray = cv2.cvtColor(img_printspace, cv2.COLOR_BGR2GRAY)
_, thresh = cv2.threshold(imgray, 0, 255, 0)
contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
cnt_size = np.array([cv2.contourArea(contours[j]) for j in range(len(contours))])
cnt = contours[np.argmax(cnt_size)]
x, y, w, h = cv2.boundingRect(cnt)
bb_coord_printspace = [x, y, w, h]
else: else:
bb_coord_printspace = None ps_bbox = [0, 0, None, None]
region_tags=np.unique([x for x in alltags if x.endswith('Region')]) region_tags=np.unique([x for x in alltags if x.endswith('Region')])
co_text_paragraph=[] co_text_paragraph=[]
@ -1806,11 +1745,19 @@ def read_xml(xml_file):
img_poly=cv2.fillPoly(img, pts =co_img, color=(4,4,4)) img_poly=cv2.fillPoly(img, pts =co_img, color=(4,4,4))
img_poly=cv2.fillPoly(img, pts =co_sep, color=(5,5,5)) img_poly=cv2.fillPoly(img, pts =co_sep, color=(5,5,5))
return tree1, root1, bb_coord_printspace, file_name, id_paragraph, id_header+id_heading, co_text_paragraph, co_text_header+co_text_heading,\ return (tree1,
tot_region_ref,x_len, y_len,index_tot_regions, img_poly root1,
ps_bbox,
file_name,
id_paragraph,
id_header + id_heading,
co_text_paragraph,
co_text_header + co_text_heading,
tot_region_ref,
x_len,
y_len,
index_tot_regions,
img_poly)
# def bounding_box(cnt,color, corr_order_index ): # def bounding_box(cnt,color, corr_order_index ):
# x, y, w, h = cv2.boundingRect(cnt) # x, y, w, h = cv2.boundingRect(cnt)

View file

@ -1,19 +1,24 @@
"""
Tool to load model and predict for given image.
"""
import sys import sys
import os import os
from typing import Tuple from typing import Tuple
import warnings import warnings
import json import json
import numpy as np
import cv2
from numpy._typing import NDArray
import tensorflow as tf
from keras.models import Model, load_model
from keras import backend as K
import click import click
from tensorflow.python.keras import backend as tensorflow_backend import numpy as np
from numpy._typing import NDArray
import cv2
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import StringLookup
from .gt_gen_utils import ( from .gt_gen_utils import (
filter_contours_area_of_image, filter_contours_area_of_image,
find_new_features_of_contours, find_new_features_of_contours,
@ -21,24 +26,37 @@ from .gt_gen_utils import (
resize_image, resize_image,
update_list_and_return_first_with_length_bigger_than_one update_list_and_return_first_with_length_bigger_than_one
) )
from .models import ( from ..patch_encoder import (
PatchEncoder, PatchEncoder,
Patches Patches
) )
from .metrics import (
soft_dice_loss,
weighted_categorical_crossentropy,
)
from.utils import scale_padd_image_for_ocr
from ..utils.utils_ocr import decode_batch_predictions
from.utils import (scale_padd_image_for_ocr)
from eynollah.utils.utils_ocr import (decode_batch_predictions)
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
__doc__=\ class SBBPredict:
""" def __init__(self,
Tool to load model and predict for given image. image,
""" dir_in,
model,
class sbb_predict: task,
def __init__(self,image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area): config_params_model,
patches,
save,
save_layout,
ground_truth,
xml_file,
cpu,
out,
min_area,
):
self.image=image self.image=image
self.dir_in=dir_in self.dir_in=dir_in
self.patches=patches self.patches=patches
@ -57,8 +75,9 @@ class sbb_predict:
self.min_area = 0 self.min_area = 0
def resize_image(self,img_in,input_height,input_width): def resize_image(self,img_in,input_height,input_width):
return cv2.resize( img_in, ( input_width,input_height) ,interpolation=cv2.INTER_NEAREST) return cv2.resize(img_in, (input_width,
input_height),
interpolation=cv2.INTER_NEAREST)
def color_images(self,seg): def color_images(self,seg):
ann_u=range(self.n_classes) ann_u=range(self.n_classes)
@ -74,68 +93,6 @@ class sbb_predict:
seg_img[:,:,2][seg==c]=c seg_img[:,:,2][seg==c]=c
return seg_img return seg_img
def otsu_copy_binary(self,img):
img_r=np.zeros((img.shape[0],img.shape[1],3))
img1=img[:,:,0]
#print(img.min())
#print(img[:,:,0].min())
#blur = cv2.GaussianBlur(img,(5,5))
#ret3,th3 = cv2.threshold(blur,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU)
_, threshold1 = cv2.threshold(img1, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
img_r[:,:,0]=threshold1
img_r[:,:,1]=threshold1
img_r[:,:,2]=threshold1
#img_r=img_r/float(np.max(img_r))*255
return img_r
def otsu_copy(self,img):
img_r=np.zeros((img.shape[0],img.shape[1],3))
#img1=img[:,:,0]
#print(img.min())
#print(img[:,:,0].min())
#blur = cv2.GaussianBlur(img,(5,5))
#ret3,th3 = cv2.threshold(blur,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU)
_, threshold1 = cv2.threshold(img[:,:,0], 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
_, threshold2 = cv2.threshold(img[:,:,1], 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
_, threshold3 = cv2.threshold(img[:,:,2], 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
img_r[:,:,0]=threshold1
img_r[:,:,1]=threshold2
img_r[:,:,2]=threshold3
###img_r=img_r/float(np.max(img_r))*255
return img_r
def soft_dice_loss(self,y_true, y_pred, epsilon=1e-6):
axes = tuple(range(1, len(y_pred.shape)-1))
numerator = 2. * K.sum(y_pred * y_true, axes)
denominator = K.sum(K.square(y_pred) + K.square(y_true), axes)
return 1.00 - K.mean(numerator / (denominator + epsilon)) # average over classes and batch
# def weighted_categorical_crossentropy(self,weights=None):
#
# def loss(y_true, y_pred):
# labels_floats = tf.cast(y_true, tf.float32)
# per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats,logits=y_pred)
#
# if weights is not None:
# weight_mask = tf.maximum(tf.reduce_max(tf.constant(
# np.array(weights, dtype=np.float32)[None, None, None])
# * labels_floats, axis=-1), 1.0)
# per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None]
# return tf.reduce_mean(per_pixel_loss)
# return self.loss
def IoU(self,Yi,y_predi): def IoU(self,Yi,y_predi):
## mean Intersection over Union ## mean Intersection over Union
## Mean IoU = TP/(FN + TP + FP) ## Mean IoU = TP/(FN + TP + FP)
@ -162,29 +119,33 @@ class sbb_predict:
return mIoU return mIoU
def start_new_session_and_model(self): def start_new_session_and_model(self):
if self.task == "cnn-rnn-ocr": if self.cpu:
if self.cpu: tf.config.set_visible_devices([], 'GPU')
os.environ['CUDA_VISIBLE_DEVICES']='-1'
self.model = load_model(self.model_dir)
self.model = tf.keras.models.Model(
self.model.get_layer(name = "image").input,
self.model.get_layer(name = "dense2").output)
else: else:
config = tf.compat.v1.ConfigProto() try:
config.gpu_options.allow_growth = True for device in tf.config.list_physical_devices('GPU'):
tf.config.experimental.set_memory_growth(device, True)
except:
print("no GPU device available", file=sys.stderr)
session = tf.compat.v1.Session(config=config) # tf.InteractiveSession() if self.task == "cnn-rnn-ocr":
tensorflow_backend.set_session(session) self.model = Model(
self.model.get_layer(name = "image").input,
self.model.get_layer(name = "dense2").output)
else:
self.model = load_model(self.model_dir, compile=False,
custom_objects={"PatchEncoder": PatchEncoder,
"Patches": Patches})
##if self.weights_dir!=None: ##if self.weights_dir!=None:
##self.model.load_weights(self.weights_dir) ##self.model.load_weights(self.weights_dir)
assert isinstance(self.model, Model) assert isinstance(self.model, Model)
if self.task != 'classification' and self.task != 'reading_order': if self.task != 'classification' and self.task != 'reading_order':
self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1] last = self.model.layers[-1]
self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2] self.img_height = last.output_shape[1]
self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3] self.img_width = last.output_shape[2]
self.n_classes = last.output_shape[3]
def visualize_model_output(self, prediction, img, task) -> Tuple[NDArray, NDArray]: def visualize_model_output(self, prediction, img, task) -> Tuple[NDArray, NDArray]:
if task == "binarization": if task == "binarization":
@ -212,21 +173,16 @@ class sbb_predict:
'15' : [255, 0, 255]} '15' : [255, 0, 255]}
layout_only = np.zeros(prediction.shape) layout_only = np.zeros(prediction.shape)
for unq_class in unique_classes: for unq_class in unique_classes:
where = prediction[:,:,0]==unq_class
rgb_class_unique = rgb_colors[str(int(unq_class))] rgb_class_unique = rgb_colors[str(int(unq_class))]
layout_only[:,:,0][prediction[:,:,0]==unq_class] = rgb_class_unique[0] layout_only[:,:,0][where] = rgb_class_unique[0]
layout_only[:,:,1][prediction[:,:,0]==unq_class] = rgb_class_unique[1] layout_only[:,:,1][where] = rgb_class_unique[1]
layout_only[:,:,2][prediction[:,:,0]==unq_class] = rgb_class_unique[2] layout_only[:,:,2][where] = rgb_class_unique[2]
layout_only = layout_only.astype(np.int32)
img = self.resize_image(img, layout_only.shape[0], layout_only.shape[1]) img = self.resize_image(img, layout_only.shape[0], layout_only.shape[1])
layout_only = layout_only.astype(np.int32)
img = img.astype(np.int32) img = img.astype(np.int32)
added_image = cv2.addWeighted(img,0.5,layout_only,0.1,0) added_image = cv2.addWeighted(img,0.5,layout_only,0.1,0)
@ -238,10 +194,10 @@ class sbb_predict:
assert isinstance(self.model, Model) assert isinstance(self.model, Model)
if self.task == 'classification': if self.task == 'classification':
classes_names = self.config_params_model['classification_classes_name'] classes_names = self.config_params_model['classification_classes_name']
img_1ch = img=cv2.imread(image_dir, 0) img_1ch = cv2.imread(image_dir, 0) / 255.0
img_1ch = cv2.resize(img_1ch, (self.config_params_model['input_height'],
img_1ch = img_1ch / 255.0 self.config_params_model['input_width']),
img_1ch = cv2.resize(img_1ch, (self.config_params_model['input_height'], self.config_params_model['input_width']), interpolation=cv2.INTER_NEAREST) interpolation=cv2.INTER_NEAREST)
img_in = np.zeros((1, img_1ch.shape[0], img_1ch.shape[1], 3)) img_in = np.zeros((1, img_1ch.shape[0], img_1ch.shape[1], 3))
img_in[0, :, :, 0] = img_1ch[:, :] img_in[0, :, :, 0] = img_1ch[:, :]
img_in[0, :, :, 1] = img_1ch[:, :] img_in[0, :, :, 1] = img_1ch[:, :]
@ -251,6 +207,7 @@ class sbb_predict:
index_class = np.argmax(label_p_pred[0]) index_class = np.argmax(label_p_pred[0])
print("Predicted Class: {}".format(classes_names[str(int(index_class))])) print("Predicted Class: {}".format(classes_names[str(int(index_class))]))
elif self.task == "cnn-rnn-ocr": elif self.task == "cnn-rnn-ocr":
img=cv2.imread(image_dir) img=cv2.imread(image_dir)
img = scale_padd_image_for_ocr(img, self.config_params_model['input_height'], self.config_params_model['input_width']) img = scale_padd_image_for_ocr(img, self.config_params_model['input_height'], self.config_params_model['input_width'])
@ -279,19 +236,22 @@ class sbb_predict:
img_height = self.config_params_model['input_height'] img_height = self.config_params_model['input_height']
img_width = self.config_params_model['input_width'] img_width = self.config_params_model['input_width']
tree_xml, root_xml, bb_coord_printspace, file_name, id_paragraph, id_header, co_text_paragraph, co_text_header, tot_region_ref, x_len, y_len, index_tot_regions, img_poly = read_xml(self.xml_file) tree_xml, root_xml, ps_bbox, file_name, \
_, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, _ = find_new_features_of_contours(co_text_header) id_paragraph, id_header, \
co_text_paragraph, co_text_header, \
tot_region_ref, x_len, y_len, index_tot_regions, \
img_poly = read_xml(self.xml_file)
_, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, _ = \
find_new_features_of_contours(co_text_header)
img_header_and_sep = np.zeros((y_len,x_len), dtype='uint8') img_header_and_sep = np.zeros((y_len,x_len), dtype='uint8')
for j in range(len(cy_main)): for j in range(len(cy_main)):
img_header_and_sep[int(y_max_main[j]):int(y_max_main[j])+12,int(x_min_main[j]):int(x_max_main[j]) ] = 1 img_header_and_sep[int(y_max_main[j]): int(y_max_main[j]) + 12,
int(x_min_main[j]): int(x_max_main[j])] = 1
co_text_all = co_text_paragraph + co_text_header co_text_all = co_text_paragraph + co_text_header
id_all_text = id_paragraph + id_header id_all_text = id_paragraph + id_header
##texts_corr_order_index = [index_tot_regions[tot_region_ref.index(i)] for i in id_all_text ] ##texts_corr_order_index = [index_tot_regions[tot_region_ref.index(i)] for i in id_all_text ]
##texts_corr_order_index_int = [int(x) for x in texts_corr_order_index] ##texts_corr_order_index_int = [int(x) for x in texts_corr_order_index]
texts_corr_order_index_int = list(np.array(range(len(co_text_all)))) texts_corr_order_index_int = list(np.array(range(len(co_text_all))))
@ -302,7 +262,8 @@ class sbb_predict:
#print(np.shape(co_text_all[0]), len( np.shape(co_text_all[0]) ),'co_text_all') #print(np.shape(co_text_all[0]), len( np.shape(co_text_all[0]) ),'co_text_all')
#co_text_all = filter_contours_area_of_image_tables(img_poly, co_text_all, _, max_area, min_area) #co_text_all = filter_contours_area_of_image_tables(img_poly, co_text_all, _, max_area, min_area)
#print(co_text_all,'co_text_all') #print(co_text_all,'co_text_all')
co_text_all, texts_corr_order_index_int, _ = filter_contours_area_of_image(img_poly, co_text_all, texts_corr_order_index_int, max_area, self.min_area) co_text_all, texts_corr_order_index_int, _ = filter_contours_area_of_image(
img_poly, co_text_all, texts_corr_order_index_int, max_area, self.min_area)
#print(texts_corr_order_index_int) #print(texts_corr_order_index_int)
@ -315,15 +276,13 @@ class sbb_predict:
img_label=cv2.fillPoly(img_label, pts =[co_text_all[i]], color=(1,1,1)) img_label=cv2.fillPoly(img_label, pts =[co_text_all[i]], color=(1,1,1))
labels_con[:,:,i] = img_label[:,:,0] labels_con[:,:,i] = img_label[:,:,0]
if bb_coord_printspace: if ps_bbox:
#bb_coord_printspace[x,y,w,h,_,_] labels_con = labels_con[ps_bbox[1]:ps_bbox[3],
x = bb_coord_printspace[0] ps_bbox[0]:ps_bbox[2], :]
y = bb_coord_printspace[1] img_poly = img_poly[ps_bbox[1]:ps_bbox[3],
w = bb_coord_printspace[2] ps_bbox[0]:ps_bbox[2], :]
h = bb_coord_printspace[3] img_header_and_sep = img_header_and_sep[ps_bbox[1]:ps_bbox[3],
labels_con = labels_con[y:y+h, x:x+w, :] ps_bbox[0]:ps_bbox[2]]
img_poly = img_poly[y:y+h, x:x+w, :]
img_header_and_sep = img_header_and_sep[y:y+h, x:x+w]
@ -709,17 +668,15 @@ class sbb_predict:
help="min area size of regions considered for reading order detection. The default value is zero and means that all text regions are considered for reading order.", help="min area size of regions considered for reading order detection. The default value is zero and means that all text regions are considered for reading order.",
) )
def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area): def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area):
assert image or dir_in, "Either a single image -i or a dir_in -di is required" assert image or dir_in, "Either a single image -i or a dir_in -di input is required"
with open(os.path.join(model,'config.json')) as f: with open(os.path.join(model,'config.json')) as f:
config_params_model = json.load(f) config_params_model = json.load(f)
task = config_params_model['task'] task = config_params_model['task']
if task != 'classification' and task != 'reading_order' and task != "cnn-rnn-ocr": if task not in ['classification', 'reading_order', "cnn-rnn-ocr"]:
if image and not save: assert not image or save, "For segmentation or binarization, an input single image -i also requires an output filename -s"
print("Error: You used one of segmentation or binarization task with image input but not set -s, you need a filename to save visualized output with -s") assert not dir_in or out, "For segmentation or binarization, an input directory -di also requires an output directory -o"
sys.exit(1) x = SBBPredict(image, dir_in, model, task, config_params_model,
if dir_in and not out: patches, save, save_layout, ground_truth, xml_file,
print("Error: You used one of segmentation or binarization task with dir_in but not set -out") cpu, out, min_area)
sys.exit(1)
x=sbb_predict(image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area)
x.run() x.run()

View file

@ -1,9 +1,14 @@
from tensorflow import keras import os
from keras.layers import (
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
import tensorflow as tf
from tensorflow.keras.layers import (
Activation, Activation,
Add, Add,
AveragePooling2D, AveragePooling2D,
BatchNormalization, BatchNormalization,
Bidirectional,
Conv1D,
Conv2D, Conv2D,
Dense, Dense,
Dropout, Dropout,
@ -13,30 +18,33 @@ from keras.layers import (
Lambda, Lambda,
Layer, Layer,
LayerNormalization, LayerNormalization,
LSTM,
MaxPooling2D, MaxPooling2D,
MultiHeadAttention, MultiHeadAttention,
Reshape,
UpSampling2D, UpSampling2D,
ZeroPadding2D, ZeroPadding2D,
add, add,
concatenate concatenate
) )
from keras.models import Model from tensorflow.keras.models import Model
import tensorflow as tf from tensorflow.keras.regularizers import l2
# from keras import layers, models
from keras.regularizers import l2
from eynollah.patch_encoder import Patches, PatchEncoder from ..patch_encoder import Patches, PatchEncoder
##mlp_head_units = [512, 256]#[2048, 1024] ##mlp_head_units = [512, 256]#[2048, 1024]
###projection_dim = 64 ###projection_dim = 64
##transformer_layers = 2#8 ##transformer_layers = 2#8
##num_heads = 1#4 ##num_heads = 1#4
resnet50_Weights_path = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5' RESNET50_WEIGHTS_PATH = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
RESNET50_WEIGHTS_URL = ('https://github.com/fchollet/deep-learning-models/releases/download/v0.2/'
'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5')
IMAGE_ORDERING = 'channels_last' IMAGE_ORDERING = 'channels_last'
MERGE_AXIS = -1 MERGE_AXIS = -1
class CTCLayer(tf.keras.layers.Layer): class CTCLayer(Layer):
def __init__(self, name=None): def __init__(self, name=None):
super().__init__(name=name) super().__init__(name=name)
self.loss_fn = tf.keras.backend.ctc_batch_cost self.loss_fn = tf.keras.backend.ctc_batch_cost
@ -61,14 +69,9 @@ def mlp(x, hidden_units, dropout_rate):
return x return x
def one_side_pad(x): def one_side_pad(x):
x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x) x = ZeroPadding2D(((1, 0), (1, 0)), data_format=IMAGE_ORDERING)(x)
if IMAGE_ORDERING == 'channels_first':
x = Lambda(lambda x: x[:, :, :-1, :-1])(x)
elif IMAGE_ORDERING == 'channels_last':
x = Lambda(lambda x: x[:, :-1, :-1, :])(x)
return x return x
def identity_block(input_tensor, kernel_size, filters, stage, block): def identity_block(input_tensor, kernel_size, filters, stage, block):
"""The identity block is the block that has no conv layer at shortcut. """The identity block is the block that has no conv layer at shortcut.
# Arguments # Arguments
@ -151,19 +154,13 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2))
x = Activation('relu')(x) x = Activation('relu')(x)
return x return x
def resnet50(inputs, weight_decay=1e-6, pretraining=False):
def resnet50_unet_light(n_classes, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
assert input_height % 32 == 0
assert input_width % 32 == 0
img_input = Input(shape=(input_height, input_width, 3))
if IMAGE_ORDERING == 'channels_last': if IMAGE_ORDERING == 'channels_last':
bn_axis = 3 bn_axis = 3
else: else:
bn_axis = 1 bn_axis = 1
x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input) x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(inputs)
x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2), kernel_regularizer=l2(weight_decay), x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2), kernel_regularizer=l2(weight_decay),
name='conv1')(x) name='conv1')(x)
f1 = x f1 = x
@ -197,61 +194,86 @@ def resnet50_unet_light(n_classes, input_height=224, input_width=224, task="segm
f5 = x f5 = x
if pretraining: if pretraining:
model = Model(img_input, x).load_weights(resnet50_Weights_path) model = Model(inputs, x).load_weights(RESNET50_WEIGHTS_PATH)
v512_2048 = Conv2D(512, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f5) return f1, f2, f3, f4, f5
v512_2048 = (BatchNormalization(axis=bn_axis))(v512_2048)
v512_2048 = Activation('relu')(v512_2048)
v512_1024 = Conv2D(512, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f4) def unet_decoder(img, f1, f2, f3, f4, f5, n_classes, light=False, task="segmentation", weight_decay=1e-6):
v512_1024 = (BatchNormalization(axis=bn_axis))(v512_1024) if IMAGE_ORDERING == 'channels_last':
v512_1024 = Activation('relu')(v512_1024) bn_axis = 3
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(v512_2048)
o = (concatenate([o, v512_1024], axis=MERGE_AXIS))
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
o = (BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
o = (concatenate([o, f3], axis=MERGE_AXIS))
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
o = (BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
o = (concatenate([o, f2], axis=MERGE_AXIS))
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
o = (BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
o = (concatenate([o, f1], axis=MERGE_AXIS))
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
o = (BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
o = (concatenate([o, img_input], axis=MERGE_AXIS))
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
o = (BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
if task == "segmentation":
o = (BatchNormalization(axis=bn_axis))(o)
o = (Activation('softmax'))(o)
else: else:
o = (Activation('sigmoid'))(o) bn_axis = 1
model = Model(img_input, o) o = Conv2D(512 if light else 1024, (1, 1), padding='same',
return model data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f5)
o = BatchNormalization(axis=bn_axis)(o)
o = Activation('relu')(o)
if light:
f4 = Conv2D(512, (1, 1), padding='same',
data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f4)
f4 = BatchNormalization(axis=bn_axis)(f4)
f4 = Activation('relu')(f4)
o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o)
o = concatenate([o, f4], axis=MERGE_AXIS)
o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
o = Conv2D(512, (3, 3), padding='valid',
data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
o = BatchNormalization(axis=bn_axis)(o)
o = Activation('relu')(o)
o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o)
o = concatenate([o, f3], axis=MERGE_AXIS)
o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
o = Conv2D(256, (3, 3), padding='valid',
data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
o = BatchNormalization(axis=bn_axis)(o)
o = Activation('relu')(o)
o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o)
o = concatenate([o, f2], axis=MERGE_AXIS)
o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
o = Conv2D(128, (3, 3), padding='valid',
data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
o = BatchNormalization(axis=bn_axis)(o)
o = Activation('relu')(o)
o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o)
o = concatenate([o, f1], axis=MERGE_AXIS)
o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
o = Conv2D(64, (3, 3), padding='valid',
data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
o = BatchNormalization(axis=bn_axis)(o)
o = Activation('relu')(o)
o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o)
o = concatenate([o, img], axis=MERGE_AXIS)
o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
o = Conv2D(32, (3, 3), padding='valid',
data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
o = BatchNormalization(axis=bn_axis)(o)
o = Activation('relu')(o)
o = Conv2D(n_classes, (1, 1), padding='same',
data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
if task == "segmentation":
o = BatchNormalization(axis=bn_axis)(o)
o = Activation('softmax')(o)
else:
o = Activation('sigmoid')(o)
return Model(img, o)
def resnet50_unet_light(n_classes, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
assert input_height % 32 == 0
assert input_width % 32 == 0
img_input = Input(shape=(input_height, input_width, 3))
features = resnet50(img_input, weight_decay=weight_decay, pretraining=pretraining)
return unet_decoder(img_input, *features, n_classes, light=True, task=task, weight_decay=weight_decay)
def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
assert input_height % 32 == 0 assert input_height % 32 == 0
@ -259,162 +281,29 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentati
img_input = Input(shape=(input_height, input_width, 3)) img_input = Input(shape=(input_height, input_width, 3))
if IMAGE_ORDERING == 'channels_last': features = resnet50(img_input, weight_decay=weight_decay, pretraining=pretraining)
bn_axis = 3
else:
bn_axis = 1
x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input) return unet_decoder(img_input, *features, n_classes, light=False, task=task, weight_decay=weight_decay)
x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2), kernel_regularizer=l2(weight_decay),
name='conv1')(x)
f1 = x
x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x) def transformer_block(img,
x = Activation('relu')(x) num_patches,
x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x) patchsize_x,
patchsize_y,
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) mlp_head_units,
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b') n_layers,
x = identity_block(x, 3, [64, 64, 256], stage=2, block='c') num_heads,
f2 = one_side_pad(x) projection_dim):
patches = Patches(patchsize_x, patchsize_y)(img)
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
f3 = x
x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
f4 = x
x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
f5 = x
if pretraining:
Model(img_input, x).load_weights(resnet50_Weights_path)
v1024_2048 = Conv2D(1024, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(
f5)
v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048)
v1024_2048 = Activation('relu')(v1024_2048)
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(v1024_2048)
o = (concatenate([o, f4], axis=MERGE_AXIS))
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
o = (BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
o = (concatenate([o, f3], axis=MERGE_AXIS))
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
o = (BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
o = (concatenate([o, f2], axis=MERGE_AXIS))
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
o = (BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
o = (concatenate([o, f1], axis=MERGE_AXIS))
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
o = (BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
o = (concatenate([o, img_input], axis=MERGE_AXIS))
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
o = (BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
if task == "segmentation":
o = (BatchNormalization(axis=bn_axis))(o)
o = (Activation('softmax'))(o)
else:
o = (Activation('sigmoid'))(o)
model = Model(img_input, o)
return model
def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=None, transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
if mlp_head_units is None:
mlp_head_units = [128, 64]
inputs = Input(shape=(input_height, input_width, 3))
#transformer_units = [
#projection_dim * 2,
#projection_dim,
#] # Size of the transformer layers
IMAGE_ORDERING = 'channels_last'
bn_axis=3
x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(inputs)
x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x)
f1 = x
x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
x = Activation('relu')(x)
x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x)
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
f2 = one_side_pad(x)
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
f3 = x
x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
f4 = x
x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
f5 = x
if pretraining:
model = Model(inputs, x).load_weights(resnet50_Weights_path)
#num_patches = x.shape[1]*x.shape[2]
#patch_size_y = input_height / x.shape[1]
#patch_size_x = input_width / x.shape[2]
#patch_size = patch_size_x * patch_size_y
patches = Patches(patch_size_x, patch_size_y)(x)
# Encode patches. # Encode patches.
encoded_patches = PatchEncoder(num_patches, projection_dim)(patches) encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
for _ in range(transformer_layers): for _ in range(n_layers):
# Layer normalization 1. # Layer normalization 1.
x1 = LayerNormalization(epsilon=1e-6)(encoded_patches) x1 = LayerNormalization(epsilon=1e-6)(encoded_patches)
# Create a multi-head attention layer. # Create a multi-head attention layer.
attention_output = MultiHeadAttention( attention_output = MultiHeadAttention(num_heads=num_heads,
num_heads=num_heads, key_dim=projection_dim, dropout=0.1 key_dim=projection_dim,
)(x1, x1) dropout=0.1)(x1, x1)
# Skip connection 1. # Skip connection 1.
x2 = Add()([attention_output, encoded_patches]) x2 = Add()([attention_output, encoded_patches])
# Layer normalization 2. # Layer normalization 2.
@ -423,180 +312,80 @@ def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_he
x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1) x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1)
# Skip connection 2. # Skip connection 2.
encoded_patches = Add()([x3, x2]) encoded_patches = Add()([x3, x2])
assert isinstance(x, Layer)
encoded_patches = tf.reshape(encoded_patches, [-1, x.shape[1], x.shape[2] , int( projection_dim / (patch_size_x * patch_size_y) )])
v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(encoded_patches) encoded_patches = tf.reshape(encoded_patches,
v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048) [-1,
v1024_2048 = Activation('relu')(v1024_2048) img.shape[1],
img.shape[2],
o = (UpSampling2D( (2, 2), data_format=IMAGE_ORDERING))(v1024_2048) projection_dim // (patchsize_x * patchsize_y)])
o = (concatenate([o, f4],axis=MERGE_AXIS)) return encoded_patches
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
o = (BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
o = (concatenate([o ,f3], axis=MERGE_AXIS))
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
o = (BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
o = (concatenate([o, f2], axis=MERGE_AXIS))
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
o = (BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
o = (concatenate([o, f1], axis=MERGE_AXIS))
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
o = (BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
o = (concatenate([o, inputs],axis=MERGE_AXIS))
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
o = (BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(o)
if task == "segmentation":
o = (BatchNormalization(axis=bn_axis))(o)
o = (Activation('softmax'))(o)
else:
o = (Activation('sigmoid'))(o)
model = Model(inputs=inputs, outputs=o) def vit_resnet50_unet(num_patches,
n_classes,
return model transformer_patchsize_x,
transformer_patchsize_y,
def vit_resnet50_unet_transformer_before_cnn(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=None, transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): transformer_mlp_head_units=None,
if mlp_head_units is None: transformer_layers=8,
mlp_head_units = [128, 64] transformer_num_heads=4,
transformer_projection_dim=64,
input_height=224,
input_width=224,
task="segmentation",
weight_decay=1e-6,
pretraining=False):
if transformer_mlp_head_units is None:
transformer_mlp_head_units = [128, 64]
inputs = Input(shape=(input_height, input_width, 3)) inputs = Input(shape=(input_height, input_width, 3))
##transformer_units = [ features = resnet50(inputs, weight_decay=weight_decay, pretraining=pretraining)
##projection_dim * 2,
##projection_dim,
##] # Size of the transformer layers
IMAGE_ORDERING = 'channels_last'
bn_axis=3
patches = Patches(patch_size_x, patch_size_y)(inputs)
# Encode patches.
encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
for _ in range(transformer_layers):
# Layer normalization 1.
x1 = LayerNormalization(epsilon=1e-6)(encoded_patches)
# Create a multi-head attention layer.
attention_output = MultiHeadAttention(
num_heads=num_heads, key_dim=projection_dim, dropout=0.1
)(x1, x1)
# Skip connection 1.
x2 = Add()([attention_output, encoded_patches])
# Layer normalization 2.
x3 = LayerNormalization(epsilon=1e-6)(x2)
# MLP.
x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1)
# Skip connection 2.
encoded_patches = Add()([x3, x2])
encoded_patches = tf.reshape(encoded_patches, [-1, input_height, input_width , int( projection_dim / (patch_size_x * patch_size_y) )])
encoded_patches = Conv2D(3, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay), name='convinput')(encoded_patches)
x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(encoded_patches)
x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x)
f1 = x
x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x) features[-1] = transformer_block(features[-1],
x = Activation('relu')(x) num_patches,
x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x) transformer_patchsize_x,
transformer_patchsize_y,
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) transformer_mlp_head_units,
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b') transformer_layers,
x = identity_block(x, 3, [64, 64, 256], stage=2, block='c') transformer_num_heads,
f2 = one_side_pad(x) transformer_projection_dim)
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a') o = unet_decoder(inputs, *features, n_classes, task=task, weight_decay=weight_decay)
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
f3 = x
x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a') return Model(inputs, o)
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
f4 = x
x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a') def vit_resnet50_unet_transformer_before_cnn(num_patches,
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b') n_classes,
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c') transformer_patchsize_x,
f5 = x transformer_patchsize_y,
transformer_mlp_head_units=None,
if pretraining: transformer_layers=8,
model = Model(encoded_patches, x).load_weights(resnet50_Weights_path) transformer_num_heads=4,
transformer_projection_dim=64,
input_height=224,
input_width=224,
task="segmentation",
weight_decay=1e-6,
pretraining=False):
if transformer_mlp_head_units is None:
transformer_mlp_head_units = [128, 64]
inputs = Input(shape=(input_height, input_width, 3))
v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(x) encoded_patches = transformer_block(inputs,
v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048) num_patches,
v1024_2048 = Activation('relu')(v1024_2048) transformer_patchsize_x,
transformer_patchsize_y,
o = (UpSampling2D( (2, 2), data_format=IMAGE_ORDERING))(v1024_2048) transformer_mlp_head_units,
o = (concatenate([o, f4],axis=MERGE_AXIS)) transformer_layers,
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) transformer_num_heads,
o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) transformer_projection_dim)
o = (BatchNormalization(axis=bn_axis))(o) encoded_patches = Conv2D(3, (1, 1), padding='same',
o = Activation('relu')(o) data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay),
name='convinput')(encoded_patches)
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
o = (concatenate([o ,f3], axis=MERGE_AXIS))
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
o = (BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
o = (concatenate([o, f2], axis=MERGE_AXIS))
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
o = (BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
o = (concatenate([o, f1], axis=MERGE_AXIS))
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
o = (BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
o = (concatenate([o, inputs],axis=MERGE_AXIS))
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
o = (BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(o)
if task == "segmentation":
o = (BatchNormalization(axis=bn_axis))(o)
o = (Activation('softmax'))(o)
else:
o = (Activation('sigmoid'))(o)
model = Model(inputs=inputs, outputs=o) features = resnet50(encoded_patches, weight_decay=weight_decay, pretraining=pretraining)
o = unet_decoder(inputs, *features, n_classes, task=task, weight_decay=weight_decay)
return model return Model(inputs, o)
def resnet50_classifier(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False): def resnet50_classifier(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False):
include_top=True include_top=True
@ -606,47 +395,7 @@ def resnet50_classifier(n_classes,input_height=224,input_width=224,weight_decay=
img_input = Input(shape=(input_height,input_width , 3 )) img_input = Input(shape=(input_height,input_width , 3 ))
if IMAGE_ORDERING == 'channels_last': _, _, _, _, x = resnet50(img_input, weight_decay, pretraining)
bn_axis = 3
else:
bn_axis = 1
x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input)
x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x)
f1 = x
x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
x = Activation('relu')(x)
x = MaxPooling2D((3, 3) , data_format=IMAGE_ORDERING , strides=(2, 2))(x)
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
f2 = one_side_pad(x )
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
f3 = x
x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
f4 = x
x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
f5 = x
if pretraining:
Model(img_input, x).load_weights(resnet50_Weights_path)
x = AveragePooling2D((7, 7), name='avg_pool')(x) x = AveragePooling2D((7, 7), name='avg_pool')(x)
x = Flatten()(x) x = Flatten()(x)
@ -658,9 +407,6 @@ def resnet50_classifier(n_classes,input_height=224,input_width=224,weight_decay=
x = Dense(n_classes, activation='softmax', name='fc1000')(x) x = Dense(n_classes, activation='softmax', name='fc1000')(x)
model = Model(img_input, x) model = Model(img_input, x)
return model return model
def machine_based_reading_order_model(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False): def machine_based_reading_order_model(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False):
@ -669,43 +415,10 @@ def machine_based_reading_order_model(n_classes,input_height=224,input_width=224
img_input = Input(shape=(input_height,input_width , 3 )) img_input = Input(shape=(input_height,input_width , 3 ))
if IMAGE_ORDERING == 'channels_last': _, _, _, _, x = resnet50(img_input, weight_decay, pretraining)
bn_axis = 3
else:
bn_axis = 1
x1 = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input)
x1 = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x1)
x1 = BatchNormalization(axis=bn_axis, name='bn_conv1')(x1)
x1 = Activation('relu')(x1)
x1 = MaxPooling2D((3, 3) , data_format=IMAGE_ORDERING , strides=(2, 2))(x1)
x1 = conv_block(x1, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) x = AveragePooling2D((7, 7), name='avg_pool1')(x)
x1 = identity_block(x1, 3, [64, 64, 256], stage=2, block='b') flattened = Flatten()(x)
x1 = identity_block(x1, 3, [64, 64, 256], stage=2, block='c')
x1 = conv_block(x1, 3, [128, 128, 512], stage=3, block='a')
x1 = identity_block(x1, 3, [128, 128, 512], stage=3, block='b')
x1 = identity_block(x1, 3, [128, 128, 512], stage=3, block='c')
x1 = identity_block(x1, 3, [128, 128, 512], stage=3, block='d')
x1 = conv_block(x1, 3, [256, 256, 1024], stage=4, block='a')
x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='b')
x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='c')
x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='d')
x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='e')
x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='f')
x1 = conv_block(x1, 3, [512, 512, 2048], stage=5, block='a')
x1 = identity_block(x1, 3, [512, 512, 2048], stage=5, block='b')
x1 = identity_block(x1, 3, [512, 512, 2048], stage=5, block='c')
if pretraining:
Model(img_input , x1).load_weights(resnet50_Weights_path)
x1 = AveragePooling2D((7, 7), name='avg_pool1')(x1)
flattened = Flatten()(x1)
o = Dense(256, activation='relu', name='fc512')(flattened) o = Dense(256, activation='relu', name='fc512')(flattened)
o=Dropout(0.2)(o) o=Dropout(0.2)(o)
@ -719,83 +432,79 @@ def machine_based_reading_order_model(n_classes,input_height=224,input_width=224
return model return model
def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_seq=None): def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_seq=None):
input_img = tf.keras.Input(shape=(image_height, image_width, 3), name="image") input_img = Input(shape=(image_height, image_width, 3), name="image")
labels = tf.keras.layers.Input(name="label", shape=(None,)) labels = Input(name="label", shape=(None,))
x = tf.keras.layers.Conv2D(64,kernel_size=(3,3),padding="same")(input_img) x = Conv2D(64,kernel_size=(3,3),padding="same")(input_img)
x = tf.keras.layers.BatchNormalization(name="bn1")(x) x = BatchNormalization(name="bn1")(x)
x = tf.keras.layers.Activation("relu", name="relu1")(x) x = Activation("relu", name="relu1")(x)
x = tf.keras.layers.Conv2D(64,kernel_size=(3,3),padding="same")(x) x = Conv2D(64,kernel_size=(3,3),padding="same")(x)
x = tf.keras.layers.BatchNormalization(name="bn2")(x) x = BatchNormalization(name="bn2")(x)
x = tf.keras.layers.Activation("relu", name="relu2")(x) x = Activation("relu", name="relu2")(x)
x = tf.keras.layers.MaxPool2D(pool_size=(1,2),strides=(1,2))(x) x = MaxPooling2D(pool_size=(1,2),strides=(1,2))(x)
x = tf.keras.layers.Conv2D(128,kernel_size=(3,3),padding="same")(x) x = Conv2D(128,kernel_size=(3,3),padding="same")(x)
x = tf.keras.layers.BatchNormalization(name="bn3")(x) x = BatchNormalization(name="bn3")(x)
x = tf.keras.layers.Activation("relu", name="relu3")(x) x = Activation("relu", name="relu3")(x)
x = tf.keras.layers.Conv2D(128,kernel_size=(3,3),padding="same")(x) x = Conv2D(128,kernel_size=(3,3),padding="same")(x)
x = tf.keras.layers.BatchNormalization(name="bn4")(x) x = BatchNormalization(name="bn4")(x)
x = tf.keras.layers.Activation("relu", name="relu4")(x) x = Activation("relu", name="relu4")(x)
x = tf.keras.layers.MaxPool2D(pool_size=(1,2),strides=(1,2))(x) x = MaxPooling2D(pool_size=(1,2),strides=(1,2))(x)
x = tf.keras.layers.Conv2D(256,kernel_size=(3,3),padding="same")(x) x = Conv2D(256,kernel_size=(3,3),padding="same")(x)
x = tf.keras.layers.BatchNormalization(name="bn5")(x) x = BatchNormalization(name="bn5")(x)
x = tf.keras.layers.Activation("relu", name="relu5")(x) x = Activation("relu", name="relu5")(x)
x = tf.keras.layers.Conv2D(256,kernel_size=(3,3),padding="same")(x) x = Conv2D(256,kernel_size=(3,3),padding="same")(x)
x = tf.keras.layers.BatchNormalization(name="bn6")(x) x = BatchNormalization(name="bn6")(x)
x = tf.keras.layers.Activation("relu", name="relu6")(x) x = Activation("relu", name="relu6")(x)
x = tf.keras.layers.MaxPool2D(pool_size=(2,2),strides=(2,2))(x) x = MaxPooling2D(pool_size=(2,2),strides=(2,2))(x)
x = tf.keras.layers.Conv2D(image_width,kernel_size=(3,3),padding="same")(x) x = Conv2D(image_width,kernel_size=(3,3),padding="same")(x)
x = tf.keras.layers.BatchNormalization(name="bn7")(x) x = BatchNormalization(name="bn7")(x)
x = tf.keras.layers.Activation("relu", name="relu7")(x) x = Activation("relu", name="relu7")(x)
x = tf.keras.layers.Conv2D(image_width,kernel_size=(16,1))(x) x = Conv2D(image_width,kernel_size=(16,1))(x)
x = tf.keras.layers.BatchNormalization(name="bn8")(x) x = BatchNormalization(name="bn8")(x)
x = tf.keras.layers.Activation("relu", name="relu8")(x) x = Activation("relu", name="relu8")(x)
x2d = tf.keras.layers.MaxPool2D(pool_size=(1,2),strides=(1,2))(x) x2d = MaxPooling2D(pool_size=(1,2),strides=(1,2))(x)
x4d = tf.keras.layers.MaxPool2D(pool_size=(1,2),strides=(1,2))(x2d) x4d = MaxPooling2D(pool_size=(1,2),strides=(1,2))(x2d)
new_shape = (x.shape[1]*x.shape[2], x.shape[3]) new_shape = (x.shape[1]*x.shape[2], x.shape[3])
new_shape2 = (x2d.shape[1]*x2d.shape[2], x2d.shape[3]) new_shape2 = (x2d.shape[1]*x2d.shape[2], x2d.shape[3])
new_shape4 = (x4d.shape[1]*x4d.shape[2], x4d.shape[3]) new_shape4 = (x4d.shape[1]*x4d.shape[2], x4d.shape[3])
x = tf.keras.layers.Reshape(target_shape=new_shape, name="reshape")(x) x = Reshape(target_shape=new_shape, name="reshape")(x)
x2d = tf.keras.layers.Reshape(target_shape=new_shape2, name="reshape2")(x2d) x2d = Reshape(target_shape=new_shape2, name="reshape2")(x2d)
x4d = tf.keras.layers.Reshape(target_shape=new_shape4, name="reshape4")(x4d) x4d = Reshape(target_shape=new_shape4, name="reshape4")(x4d)
xrnnorg = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x)
xrnn2d = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x2d)
xrnn4d = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x4d)
xrnn2d = Reshape(target_shape=(1, xrnn2d.shape[1], xrnn2d.shape[2]), name="reshape6")(xrnn2d)
xrnn4d = Reshape(target_shape=(1, xrnn4d.shape[1], xrnn4d.shape[2]), name="reshape8")(xrnn4d)
xrnnorg = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(image_width, return_sequences=True, dropout=0.25))(x) xrnn2dup = UpSampling2D(size=(1, 2), interpolation="nearest")(xrnn2d)
xrnn2d = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(image_width, return_sequences=True, dropout=0.25))(x2d) xrnn4dup = UpSampling2D(size=(1, 4), interpolation="nearest")(xrnn4d)
xrnn4d = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(image_width, return_sequences=True, dropout=0.25))(x4d)
xrnn2d = tf.keras.layers.Reshape(target_shape=(1, xrnn2d.shape[1], xrnn2d.shape[2]), name="reshape6")(xrnn2d)
xrnn4d = tf.keras.layers.Reshape(target_shape=(1, xrnn4d.shape[1], xrnn4d.shape[2]), name="reshape8")(xrnn4d)
xrnn2dup = Reshape(target_shape=(xrnn2dup.shape[2], xrnn2dup.shape[3]), name="reshape10")(xrnn2dup)
xrnn4dup = Reshape(target_shape=(xrnn4dup.shape[2], xrnn4dup.shape[3]), name="reshape12")(xrnn4dup)
xrnn2dup = tf.keras.layers.UpSampling2D(size=(1, 2), interpolation="nearest")(xrnn2d) addition = Add()([xrnnorg, xrnn2dup, xrnn4dup])
xrnn4dup = tf.keras.layers.UpSampling2D(size=(1, 4), interpolation="nearest")(xrnn4d)
xrnn2dup = tf.keras.layers.Reshape(target_shape=(xrnn2dup.shape[2], xrnn2dup.shape[3]), name="reshape10")(xrnn2dup) addition_rnn = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(addition)
xrnn4dup = tf.keras.layers.Reshape(target_shape=(xrnn4dup.shape[2], xrnn4dup.shape[3]), name="reshape12")(xrnn4dup)
out = Conv1D(max_seq, 1, data_format="channels_first")(addition_rnn)
out = BatchNormalization(name="bn9")(out)
out = Activation("relu", name="relu9")(out)
#out = Conv1D(n_classes, 1, activation='relu', data_format="channels_last")(out)
addition = tf.keras.layers.Add()([xrnnorg, xrnn2dup, xrnn4dup]) out = Dense(n_classes, activation="softmax", name="dense2")(out)
addition_rnn = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(image_width, return_sequences=True, dropout=0.25))(addition)
out = tf.keras.layers.Conv1D(max_seq, 1, data_format="channels_first")(addition_rnn)
out = tf.keras.layers.BatchNormalization(name="bn9")(out)
out = tf.keras.layers.Activation("relu", name="relu9")(out)
#out = tf.keras.layers.Conv1D(n_classes, 1, activation='relu', data_format="channels_last")(out)
out = tf.keras.layers.Dense(
n_classes, activation="softmax", name="dense2"
)(out)
# Add CTC layer for calculating CTC loss at each step. # Add CTC layer for calculating CTC loss at each step.
output = CTCLayer(name="ctc_loss")(labels, out) output = CTCLayer(name="ctc_loss")(labels, out)
model = tf.keras.models.Model(inputs=[input_img, labels], outputs=output, name="handwriting_recognizer") model = Model(inputs=[input_img, labels], outputs=output, name="handwriting_recognizer")
return model return model

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -1,136 +1,66 @@
import sys
from glob import glob
from os import environ, devnull
from os.path import join
from warnings import catch_warnings, simplefilter
import os import os
from warnings import catch_warnings, simplefilter
import click
import numpy as np import numpy as np
from PIL import Image
import cv2 os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
environ['TF_CPP_MIN_LOG_LEVEL'] = '3' os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
stderr = sys.stderr
sys.stderr = open(devnull, 'w') from ocrd_utils import tf_disable_interactive_logs
tf_disable_interactive_logs()
import tensorflow as tf import tensorflow as tf
from tensorflow.keras.models import load_model from tensorflow.keras.models import load_model
from tensorflow.python.keras import backend as tensorflow_backend
sys.stderr = stderr
from tensorflow.keras import layers
import tensorflow.keras.losses
from tensorflow.keras.layers import *
import click
import logging
from ..patch_encoder import (
class Patches(layers.Layer): PatchEncoder,
def __init__(self, patch_size_x, patch_size_y): Patches,
super(Patches, self).__init__() )
self.patch_size_x = patch_size_x
self.patch_size_y = patch_size_y
def call(self, images):
#print(tf.shape(images)[1],'images')
#print(self.patch_size,'self.patch_size')
batch_size = tf.shape(images)[0]
patches = tf.image.extract_patches(
images=images,
sizes=[1, self.patch_size_y, self.patch_size_x, 1],
strides=[1, self.patch_size_y, self.patch_size_x, 1],
rates=[1, 1, 1, 1],
padding="VALID",
)
#patch_dims = patches.shape[-1]
patch_dims = tf.shape(patches)[-1]
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
return patches
def get_config(self):
config = super().get_config().copy()
config.update({
'patch_size_x': self.patch_size_x,
'patch_size_y': self.patch_size_y,
})
return config
class PatchEncoder(layers.Layer):
def __init__(self, **kwargs):
super(PatchEncoder, self).__init__()
self.num_patches = num_patches
self.projection = layers.Dense(units=projection_dim)
self.position_embedding = layers.Embedding(
input_dim=num_patches, output_dim=projection_dim
)
def call(self, patch):
positions = tf.range(start=0, limit=self.num_patches, delta=1)
encoded = self.projection(patch) + self.position_embedding(positions)
return encoded
def get_config(self):
config = super().get_config().copy()
config.update({
'num_patches': self.num_patches,
'projection': self.projection,
'position_embedding': self.position_embedding,
})
return config
def run_ensembling(model_dirs, out_dir):
def start_new_session(): all_weights = []
###config = tf.compat.v1.ConfigProto()
###config.gpu_options.allow_growth = True
###self.session = tf.compat.v1.Session(config=config) # tf.InteractiveSession() for model_dir in model_dirs:
###tensorflow_backend.set_session(self.session) assert os.path.isdir(model_dir), model_dir
model = load_model(model_dir, compile=False,
config = tf.compat.v1.ConfigProto() custom_objects=dict(PatchEncoder=PatchEncoder,
config.gpu_options.allow_growth = True Patches=Patches))
all_weights.append(model.get_weights())
session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
tensorflow_backend.set_session(session)
return session
def run_ensembling(dir_models, out):
ls_models = os.listdir(dir_models)
weights=[]
for model_name in ls_models:
model = load_model(os.path.join(dir_models,model_name) , compile=False, custom_objects={'PatchEncoder':PatchEncoder, 'Patches': Patches})
weights.append(model.get_weights())
new_weights = list() new_weights = []
for layer_weights in zip(*all_weights):
layer_weights = np.array([np.array(weights).mean(axis=0)
for weights in zip(*layer_weights)])
new_weights.append(layer_weights)
for weights_list_tuple in zip(*weights): #model = tf.keras.models.clone_model(model)
new_weights.append(
[np.array(weights_).mean(axis=0)\
for weights_ in zip(*weights_list_tuple)])
new_weights = [np.array(x) for x in new_weights]
model.set_weights(new_weights) model.set_weights(new_weights)
model.save(out)
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config.json ")+out) model.save(out_dir)
os.system('cp ' + os.path.join(model_dirs[0], "config.json ") + out_dir + "/")
@click.command() @click.command()
@click.option( @click.option(
"--dir_models", "--in",
"-dm", "-i",
help="directory of models", help="input directory of checkpoint models to be read",
multiple=True,
required=True,
type=click.Path(exists=True, file_okay=False), type=click.Path(exists=True, file_okay=False),
) )
@click.option( @click.option(
"--out", "--out",
"-o", "-o",
help="output directory where ensembled model will be written.", help="output directory where ensembled model will be written.",
required=True,
type=click.Path(exists=False, file_okay=False), type=click.Path(exists=False, file_okay=False),
) )
def ensemble_cli(in_, out):
"""
mix multiple model weights
Load a sequence of models and mix them into a single ensemble model
by averaging their weights. Write the resulting model.
"""
run_ensembling(in_, out)
def main(dir_models, out):
run_ensembling(dir_models, out)

File diff suppressed because it is too large Load diff

View file

@ -14,21 +14,16 @@ from shapely.ops import unary_union, nearest_points
from .rotate import rotate_image, rotation_image_new from .rotate import rotate_image, rotation_image_new
def contours_in_same_horizon(cy_main_hor): def contours_in_same_horizon(cy_main_hor):
X1 = np.zeros((len(cy_main_hor), len(cy_main_hor))) """
X2 = np.zeros((len(cy_main_hor), len(cy_main_hor))) Takes an array of y coords, identifies all pairs among them
which are close to each other, and returns all such pairs
X1[0::1, :] = cy_main_hor[:] by index into the array.
X2 = X1.T """
sort = np.argsort(cy_main_hor)
X_dif = np.abs(X2 - X1) same = np.diff(cy_main_hor[sort]) <= 20
args_help = np.array(range(len(cy_main_hor))) # groups = np.split(sort, np.arange(len(cy_main_hor) - 1)[~same] + 1)
all_args = [] same = np.flatnonzero(same)
for i in range(len(cy_main_hor)): return np.stack((sort[:-1][same], sort[1:][same])).T
list_h = list(args_help[X_dif[i, :] <= 20])
list_h.append(i)
if len(list_h) > 1:
all_args.append(list(set(list_h)))
return np.unique(np.array(all_args, dtype=object))
def find_contours_mean_y_diff(contours_main): def find_contours_mean_y_diff(contours_main):
M_main = [cv2.moments(contours_main[j]) for j in range(len(contours_main))] M_main = [cv2.moments(contours_main[j]) for j in range(len(contours_main))]
@ -253,13 +248,17 @@ def return_contours_of_image(image):
return contours, hierarchy return contours, hierarchy
def dilate_textline_contours(all_found_textline_polygons): def dilate_textline_contours(all_found_textline_polygons):
return [[polygon2contour(contour2polygon(contour, dilate=6)) from . import ensure_array
for contour in region] return [ensure_array(
[polygon2contour(contour2polygon(contour, dilate=6))
for contour in region])
for region in all_found_textline_polygons] for region in all_found_textline_polygons]
def dilate_textregion_contours(all_found_textline_polygons): def dilate_textregion_contours(all_found_textregion_polygons):
return [polygon2contour(contour2polygon(contour, dilate=6)) from . import ensure_array
for contour in all_found_textline_polygons] return ensure_array(
[polygon2contour(contour2polygon(contour, dilate=6))
for contour in all_found_textregion_polygons])
def contour2polygon(contour: Union[np.ndarray, Sequence[Sequence[Sequence[Number]]]], dilate=0): def contour2polygon(contour: Union[np.ndarray, Sequence[Sequence[Sequence[Number]]]], dilate=0):
polygon = Polygon([point[0] for point in contour]) polygon = Polygon([point[0] for point in contour])

View file

@ -399,14 +399,14 @@ def separate_lines(img_patch, contour_text_interest, thetha, x_help, y_help):
point_down_rot3=point_down_rot3-y_help point_down_rot3=point_down_rot3-y_help
point_down_rot4=point_down_rot4-y_help point_down_rot4=point_down_rot4-y_help
textline_boxes_rot.append(np.array([[int(x_min_rot1), int(point_up_rot1)], textline_boxes_rot.append(np.array([[[int(x_min_rot1), int(point_up_rot1)]],
[int(x_max_rot2), int(point_up_rot2)], [[int(x_max_rot2), int(point_up_rot2)]],
[int(x_max_rot3), int(point_down_rot3)], [[int(x_max_rot3), int(point_down_rot3)]],
[int(x_min_rot4), int(point_down_rot4)]])) [[int(x_min_rot4), int(point_down_rot4)]]]))
textline_boxes.append(np.array([[int(x_min), int(point_up)], textline_boxes.append(np.array([[[int(x_min), int(point_up)]],
[int(x_max), int(point_up)], [[int(x_max), int(point_up)]],
[int(x_max), int(point_down)], [[int(x_max), int(point_down)]],
[int(x_min), int(point_down)]])) [[int(x_min), int(point_down)]]]))
elif len(peaks) < 1: elif len(peaks) < 1:
pass pass
@ -458,14 +458,14 @@ def separate_lines(img_patch, contour_text_interest, thetha, x_help, y_help):
point_down_rot3=point_down_rot3-y_help point_down_rot3=point_down_rot3-y_help
point_down_rot4=point_down_rot4-y_help point_down_rot4=point_down_rot4-y_help
textline_boxes_rot.append(np.array([[int(x_min_rot1), int(point_up_rot1)], textline_boxes_rot.append(np.array([[[int(x_min_rot1), int(point_up_rot1)]],
[int(x_max_rot2), int(point_up_rot2)], [[int(x_max_rot2), int(point_up_rot2)]],
[int(x_max_rot3), int(point_down_rot3)], [[int(x_max_rot3), int(point_down_rot3)]],
[int(x_min_rot4), int(point_down_rot4)]])) [[int(x_min_rot4), int(point_down_rot4)]]]))
textline_boxes.append(np.array([[int(x_min), int(y_min)], textline_boxes.append(np.array([[[int(x_min), int(y_min)]],
[int(x_max), int(y_min)], [[int(x_max), int(y_min)]],
[int(x_max), int(y_max)], [[int(x_max), int(y_max)]],
[int(x_min), int(y_max)]])) [[int(x_min), int(y_max)]]]))
elif len(peaks) == 2: elif len(peaks) == 2:
dis_to_next = np.abs(peaks[1] - peaks[0]) dis_to_next = np.abs(peaks[1] - peaks[0])
for jj in range(len(peaks)): for jj in range(len(peaks)):
@ -526,14 +526,14 @@ def separate_lines(img_patch, contour_text_interest, thetha, x_help, y_help):
point_down_rot3=point_down_rot3-y_help point_down_rot3=point_down_rot3-y_help
point_down_rot4=point_down_rot4-y_help point_down_rot4=point_down_rot4-y_help
textline_boxes_rot.append(np.array([[int(x_min_rot1), int(point_up_rot1)], textline_boxes_rot.append(np.array([[[int(x_min_rot1), int(point_up_rot1)]],
[int(x_max_rot2), int(point_up_rot2)], [[int(x_max_rot2), int(point_up_rot2)]],
[int(x_max_rot3), int(point_down_rot3)], [[int(x_max_rot3), int(point_down_rot3)]],
[int(x_min_rot4), int(point_down_rot4)]])) [[int(x_min_rot4), int(point_down_rot4)]]]))
textline_boxes.append(np.array([[int(x_min), int(point_up)], textline_boxes.append(np.array([[[int(x_min), int(point_up)]],
[int(x_max), int(point_up)], [[int(x_max), int(point_up)]],
[int(x_max), int(point_down)], [[int(x_max), int(point_down)]],
[int(x_min), int(point_down)]])) [[int(x_min), int(point_down)]]]))
else: else:
for jj in range(len(peaks)): for jj in range(len(peaks)):
if jj == 0: if jj == 0:
@ -602,14 +602,14 @@ def separate_lines(img_patch, contour_text_interest, thetha, x_help, y_help):
point_down_rot3=point_down_rot3-y_help point_down_rot3=point_down_rot3-y_help
point_down_rot4=point_down_rot4-y_help point_down_rot4=point_down_rot4-y_help
textline_boxes_rot.append(np.array([[int(x_min_rot1), int(point_up_rot1)], textline_boxes_rot.append(np.array([[[int(x_min_rot1), int(point_up_rot1)]],
[int(x_max_rot2), int(point_up_rot2)], [[int(x_max_rot2), int(point_up_rot2)]],
[int(x_max_rot3), int(point_down_rot3)], [[int(x_max_rot3), int(point_down_rot3)]],
[int(x_min_rot4), int(point_down_rot4)]])) [[int(x_min_rot4), int(point_down_rot4)]]]))
textline_boxes.append(np.array([[int(x_min), int(point_up)], textline_boxes.append(np.array([[[int(x_min), int(point_up)]],
[int(x_max), int(point_up)], [[int(x_max), int(point_up)]],
[int(x_max), int(point_down)], [[int(x_max), int(point_down)]],
[int(x_min), int(point_down)]])) [[int(x_min), int(point_down)]]]))
return peaks, textline_boxes_rot return peaks, textline_boxes_rot
def separate_lines_vertical(img_patch, contour_text_interest, thetha): def separate_lines_vertical(img_patch, contour_text_interest, thetha):
@ -781,14 +781,14 @@ def separate_lines_vertical(img_patch, contour_text_interest, thetha):
if point_up_rot2 < 0: if point_up_rot2 < 0:
point_up_rot2 = 0 point_up_rot2 = 0
textline_boxes_rot.append(np.array([[int(x_min_rot1), int(point_up_rot1)], textline_boxes_rot.append(np.array([[[int(x_min_rot1), int(point_up_rot1)]],
[int(x_max_rot2), int(point_up_rot2)], [[int(x_max_rot2), int(point_up_rot2)]],
[int(x_max_rot3), int(point_down_rot3)], [[int(x_max_rot3), int(point_down_rot3)]],
[int(x_min_rot4), int(point_down_rot4)]])) [[int(x_min_rot4), int(point_down_rot4)]]]))
textline_boxes.append(np.array([[int(x_min), int(point_up)], textline_boxes.append(np.array([[[int(x_min), int(point_up)]],
[int(x_max), int(point_up)], [[int(x_max), int(point_up)]],
[int(x_max), int(point_down)], [[int(x_max), int(point_down)]],
[int(x_min), int(point_down)]])) [[int(x_min), int(point_down)]]]))
elif len(peaks) < 1: elif len(peaks) < 1:
pass pass
elif len(peaks) == 1: elif len(peaks) == 1:
@ -817,14 +817,14 @@ def separate_lines_vertical(img_patch, contour_text_interest, thetha):
if point_up_rot2 < 0: if point_up_rot2 < 0:
point_up_rot2 = 0 point_up_rot2 = 0
textline_boxes_rot.append(np.array([[int(x_min_rot1), int(point_up_rot1)], textline_boxes_rot.append(np.array([[[int(x_min_rot1), int(point_up_rot1)]],
[int(x_max_rot2), int(point_up_rot2)], [[int(x_max_rot2), int(point_up_rot2)]],
[int(x_max_rot3), int(point_down_rot3)], [[int(x_max_rot3), int(point_down_rot3)]],
[int(x_min_rot4), int(point_down_rot4)]])) [[int(x_min_rot4), int(point_down_rot4)]]]))
textline_boxes.append(np.array([[int(x_min), int(y_min)], textline_boxes.append(np.array([[[int(x_min), int(y_min)]],
[int(x_max), int(y_min)], [[int(x_max), int(y_min)]],
[int(x_max), int(y_max)], [[int(x_max), int(y_max)]],
[int(x_min), int(y_max)]])) [[int(x_min), int(y_max)]]]))
elif len(peaks) == 2: elif len(peaks) == 2:
dis_to_next = np.abs(peaks[1] - peaks[0]) dis_to_next = np.abs(peaks[1] - peaks[0])
for jj in range(len(peaks)): for jj in range(len(peaks)):
@ -872,14 +872,14 @@ def separate_lines_vertical(img_patch, contour_text_interest, thetha):
if point_up_rot2 < 0: if point_up_rot2 < 0:
point_up_rot2 = 0 point_up_rot2 = 0
textline_boxes_rot.append(np.array([[int(x_min_rot1), int(point_up_rot1)], textline_boxes_rot.append(np.array([[[int(x_min_rot1), int(point_up_rot1)]],
[int(x_max_rot2), int(point_up_rot2)], [[int(x_max_rot2), int(point_up_rot2)]],
[int(x_max_rot3), int(point_down_rot3)], [[int(x_max_rot3), int(point_down_rot3)]],
[int(x_min_rot4), int(point_down_rot4)]])) [[int(x_min_rot4), int(point_down_rot4)]]]))
textline_boxes.append(np.array([[int(x_min), int(point_up)], textline_boxes.append(np.array([[[int(x_min), int(point_up)]],
[int(x_max), int(point_up)], [[int(x_max), int(point_up)]],
[int(x_max), int(point_down)], [[int(x_max), int(point_down)]],
[int(x_min), int(point_down)]])) [[int(x_min), int(point_down)]]]))
else: else:
for jj in range(len(peaks)): for jj in range(len(peaks)):
if jj == 0: if jj == 0:
@ -938,14 +938,14 @@ def separate_lines_vertical(img_patch, contour_text_interest, thetha):
if point_up_rot2 < 0: if point_up_rot2 < 0:
point_up_rot2 = 0 point_up_rot2 = 0
textline_boxes_rot.append(np.array([[int(x_min_rot1), int(point_up_rot1)], textline_boxes_rot.append(np.array([[[int(x_min_rot1), int(point_up_rot1)]],
[int(x_max_rot2), int(point_up_rot2)], [[int(x_max_rot2), int(point_up_rot2)]],
[int(x_max_rot3), int(point_down_rot3)], [[int(x_max_rot3), int(point_down_rot3)]],
[int(x_min_rot4), int(point_down_rot4)]])) [[int(x_min_rot4), int(point_down_rot4)]]]))
textline_boxes.append(np.array([[int(x_min), int(point_up)], textline_boxes.append(np.array([[[int(x_min), int(point_up)]],
[int(x_max), int(point_up)], [[int(x_max), int(point_up)]],
[int(x_max), int(point_down)], [[int(x_max), int(point_down)]],
[int(x_min), int(point_down)]])) [[int(x_min), int(point_down)]]]))
return peaks, textline_boxes_rot return peaks, textline_boxes_rot
def separate_lines_new_inside_tiles2(img_patch, thetha): def separate_lines_new_inside_tiles2(img_patch, thetha):
@ -1560,6 +1560,9 @@ def return_deskew_slop(img_patch_org, sigma_des,n_tot_angles=100,
angle2, var2 = get_smallest_skew(img_resized, sigma_des, angles2, map=map, logger=logger, plotter=plotter) angle2, var2 = get_smallest_skew(img_resized, sigma_des, angles2, map=map, logger=logger, plotter=plotter)
if var2 > var: if var2 > var:
angle = angle2 angle = angle2
# precision stage:
angles = np.linspace(angle - 2.5, angle + 2.5, n_tot_angles // 2)
angle, _ = get_smallest_skew(img_resized, sigma_des, angles, map=map, logger=logger, plotter=plotter)
return angle return angle
def get_smallest_skew(img, sigma_des, angles, logger=None, plotter=None, map=map): def get_smallest_skew(img, sigma_des, angles, logger=None, plotter=None, map=map):

View file

@ -370,8 +370,8 @@ def break_curved_line_into_small_pieces_and_then_merge(img_curved, mask_curved,
return img_curved, img_bin_curved return img_curved, img_bin_curved
def return_textline_contour_with_added_box_coordinate(textline_contour, box_ind): def return_textline_contour_with_added_box_coordinate(textline_contour, box_ind):
textline_contour[:,0] = textline_contour[:,0] + box_ind[2] textline_contour[:,:,0] += box_ind[2]
textline_contour[:,1] = textline_contour[:,1] + box_ind[0] textline_contour[:,:,1] += box_ind[0]
return textline_contour return textline_contour

View file

@ -2,11 +2,12 @@
# pylint: disable=import-error # pylint: disable=import-error
from pathlib import Path from pathlib import Path
import os.path import os.path
from typing import Optional
import logging import logging
from .utils.xml import create_page_xml, xml_reading_order from typing import Optional
from .utils.counter import EynollahIdCounter import numpy as np
from shapely import affinity, clip_by_rect
from ocrd_utils import points_from_polygon
from ocrd_models.ocrd_page import ( from ocrd_models.ocrd_page import (
BorderType, BorderType,
CoordsType, CoordsType,
@ -19,6 +20,10 @@ from ocrd_models.ocrd_page import (
to_xml to_xml
) )
from .utils.xml import create_page_xml, xml_reading_order
from .utils.counter import EynollahIdCounter
from .utils.contour import contour2polygon, make_valid
class EynollahXmlWriter: class EynollahXmlWriter:
def __init__(self, *, dir_out, image_filename, curved_line, pcgts=None): def __init__(self, *, dir_out, image_filename, curved_line, pcgts=None):
@ -38,20 +43,14 @@ class EynollahXmlWriter:
def image_filename_stem(self): def image_filename_stem(self):
return Path(Path(self.image_filename).name).stem return Path(Path(self.image_filename).name).stem
def calculate_page_coords(self, cont_page): def calculate_points(self, contour, offset=None):
self.logger.debug('enter calculate_page_coords') self.logger.debug('enter calculate_points')
points_page_print = "" poly = contour2polygon(contour)
for _, contour in enumerate(cont_page[0]): if offset is not None:
if len(contour) == 2: poly = affinity.translate(poly, *offset)
points_page_print += str(int((contour[0]) / self.scale_x)) poly = affinity.scale(poly, xfact=1 / self.scale_x, yfact=1 / self.scale_y, origin=(0, 0))
points_page_print += ',' poly = make_valid(clip_by_rect(poly, 0, 0, self.width_org, self.height_org))
points_page_print += str(int((contour[1]) / self.scale_y)) return points_from_polygon(poly.exterior.coords[:-1])
else:
points_page_print += str(int((contour[0][0]) / self.scale_x))
points_page_print += ','
points_page_print += str(int((contour[0][1] ) / self.scale_y))
points_page_print = points_page_print + ' '
return points_page_print[:-1]
def serialize_lines_in_region(self, text_region, all_found_textline_polygons, region_idx, page_coord, all_box_coord, slopes, counter, ocr_all_textlines_textregion): def serialize_lines_in_region(self, text_region, all_found_textline_polygons, region_idx, page_coord, all_box_coord, slopes, counter, ocr_all_textlines_textregion):
self.logger.debug('enter serialize_lines_in_region') self.logger.debug('enter serialize_lines_in_region')
@ -64,16 +63,12 @@ class EynollahXmlWriter:
text_region.add_TextLine(textline) text_region.add_TextLine(textline)
text_region.set_orientation(-slopes[region_idx]) text_region.set_orientation(-slopes[region_idx])
region_bboxes = all_box_coord[region_idx] region_bboxes = all_box_coord[region_idx]
points_co = '' offset = [page_coord[2], page_coord[0]]
for point in polygon_textline: # FIXME: or actually... self.curved_line or np.abs(slopes[region_idx]) > 45?
if len(point) != 2: if self.curved_line and np.abs(slopes[region_idx]) > 45:
point = point[0] offset[0] += region_bboxes[2]
point_x = point[0] + page_coord[2] offset[1] += region_bboxes[0]
point_y = point[1] + page_coord[0] coords.set_points(self.calculate_points(polygon_textline, offset))
point_x = max(0, int(point_x / self.scale_x))
point_y = max(0, int(point_y / self.scale_y))
points_co += f'{point_x},{point_y} '
coords.set_points(points_co[:-1])
def write_pagexml(self, pcgts): def write_pagexml(self, pcgts):
self.logger.info("output filename: '%s'", self.output_filename) self.logger.info("output filename: '%s'", self.output_filename)
@ -168,9 +163,13 @@ class EynollahXmlWriter:
# create the file structure # create the file structure
pcgts = self.pcgts if self.pcgts else create_page_xml(self.image_filename, self.height_org, self.width_org) pcgts = self.pcgts if self.pcgts else create_page_xml(self.image_filename, self.height_org, self.width_org)
page = pcgts.get_Page() page = pcgts.get_Page()
assert page if len(cont_page):
page.set_Border(BorderType(Coords=CoordsType(points=self.calculate_page_coords(cont_page)))) page.set_Border(BorderType(Coords=CoordsType(points=self.calculate_points(cont_page[0]))))
if skip_layout_reading_order:
offset = None
else:
offset = [page_coord[2], page_coord[0]]
counter = EynollahIdCounter() counter = EynollahIdCounter()
if len(order_of_texts): if len(order_of_texts):
_counter_marginals = EynollahIdCounter(region_idx=len(order_of_texts)) _counter_marginals = EynollahIdCounter(region_idx=len(order_of_texts))
@ -183,8 +182,7 @@ class EynollahXmlWriter:
for mm, region_contour in enumerate(found_polygons_text_region): for mm, region_contour in enumerate(found_polygons_text_region):
textregion = TextRegionType( textregion = TextRegionType(
id=counter.next_region_id, type_='paragraph', id=counter.next_region_id, type_='paragraph',
Coords=CoordsType(points=self.calculate_polygon_coords(region_contour, page_coord, Coords=CoordsType(points=self.calculate_points(region_contour, offset))
skip_layout_reading_order))
) )
assert textregion.Coords assert textregion.Coords
if conf_contours_textregions: if conf_contours_textregions:
@ -201,7 +199,7 @@ class EynollahXmlWriter:
for mm, region_contour in enumerate(found_polygons_text_region_h): for mm, region_contour in enumerate(found_polygons_text_region_h):
textregion = TextRegionType( textregion = TextRegionType(
id=counter.next_region_id, type_='heading', id=counter.next_region_id, type_='heading',
Coords=CoordsType(points=self.calculate_polygon_coords(region_contour, page_coord)) Coords=CoordsType(points=self.calculate_points(region_contour, offset))
) )
assert textregion.Coords assert textregion.Coords
if conf_contours_textregions_h: if conf_contours_textregions_h:
@ -217,7 +215,7 @@ class EynollahXmlWriter:
for mm, region_contour in enumerate(found_polygons_marginals_left): for mm, region_contour in enumerate(found_polygons_marginals_left):
marginal = TextRegionType( marginal = TextRegionType(
id=counter.next_region_id, type_='marginalia', id=counter.next_region_id, type_='marginalia',
Coords=CoordsType(points=self.calculate_polygon_coords(region_contour, page_coord)) Coords=CoordsType(points=self.calculate_points(region_contour, offset))
) )
page.add_TextRegion(marginal) page.add_TextRegion(marginal)
if ocr_all_textlines_marginals_left: if ocr_all_textlines_marginals_left:
@ -229,7 +227,7 @@ class EynollahXmlWriter:
for mm, region_contour in enumerate(found_polygons_marginals_right): for mm, region_contour in enumerate(found_polygons_marginals_right):
marginal = TextRegionType( marginal = TextRegionType(
id=counter.next_region_id, type_='marginalia', id=counter.next_region_id, type_='marginalia',
Coords=CoordsType(points=self.calculate_polygon_coords(region_contour, page_coord)) Coords=CoordsType(points=self.calculate_points(region_contour, offset))
) )
page.add_TextRegion(marginal) page.add_TextRegion(marginal)
if ocr_all_textlines_marginals_right: if ocr_all_textlines_marginals_right:
@ -242,7 +240,7 @@ class EynollahXmlWriter:
for mm, region_contour in enumerate(found_polygons_drop_capitals): for mm, region_contour in enumerate(found_polygons_drop_capitals):
dropcapital = TextRegionType( dropcapital = TextRegionType(
id=counter.next_region_id, type_='drop-capital', id=counter.next_region_id, type_='drop-capital',
Coords=CoordsType(points=self.calculate_polygon_coords(region_contour, page_coord)) Coords=CoordsType(points=self.calculate_points(region_contour, offset))
) )
page.add_TextRegion(dropcapital) page.add_TextRegion(dropcapital)
all_box_coord_drop = [[0, 0, 0, 0]] all_box_coord_drop = [[0, 0, 0, 0]]
@ -257,33 +255,17 @@ class EynollahXmlWriter:
for region_contour in found_polygons_text_region_img: for region_contour in found_polygons_text_region_img:
page.add_ImageRegion( page.add_ImageRegion(
ImageRegionType(id=counter.next_region_id, ImageRegionType(id=counter.next_region_id,
Coords=CoordsType(points=self.calculate_polygon_coords(region_contour, page_coord)))) Coords=CoordsType(points=self.calculate_points(region_contour, offset))))
for region_contour in polygons_seplines: for region_contour in polygons_seplines:
page.add_SeparatorRegion( page.add_SeparatorRegion(
SeparatorRegionType(id=counter.next_region_id, SeparatorRegionType(id=counter.next_region_id,
Coords=CoordsType(points=self.calculate_polygon_coords(region_contour, [0, 0, 0, 0])))) Coords=CoordsType(points=self.calculate_points(region_contour, None))))
for region_contour in found_polygons_tables: for region_contour in found_polygons_tables:
page.add_TableRegion( page.add_TableRegion(
TableRegionType(id=counter.next_region_id, TableRegionType(id=counter.next_region_id,
Coords=CoordsType(points=self.calculate_polygon_coords(region_contour, page_coord)))) Coords=CoordsType(points=self.calculate_points(region_contour, offset))))
return pcgts return pcgts
def calculate_polygon_coords(self, contour, page_coord, skip_layout_reading_order=False):
self.logger.debug('enter calculate_polygon_coords')
coords = ''
for point in contour:
if len(point) != 2:
point = point[0]
point_x = point[0]
point_y = point[1]
if not skip_layout_reading_order:
point_x += page_coord[2]
point_y += page_coord[0]
point_x = int(point_x / self.scale_x)
point_y = int(point_y / self.scale_y)
coords += str(point_x) + ',' + str(point_y) + ' '
return coords[:-1]

View file

@ -22,7 +22,7 @@ def test_run_eynollah_binarization_filename(
'-o', str(outfile), '-o', str(outfile),
] + options, ] + options,
[ [
'Predicting' 'Loaded model'
] ]
) )
assert outfile.exists() assert outfile.exists()
@ -46,8 +46,8 @@ def test_run_eynollah_binarization_directory(
'-o', str(outdir), '-o', str(outdir),
], ],
[ [
f'Predicting {image_resources[0].name}', f'Binarizing [ 1/2] {image_resources[0].name}',
f'Predicting {image_resources[1].name}', f'Binarizing [ 2/2] {image_resources[1].name}',
] ]
) )
assert len(list(outdir.iterdir())) == 2 assert len(list(outdir.iterdir())) == 2

View file

@ -1,6 +1,6 @@
sacred sacred
seaborn seaborn
numpy <1.24.0 numpy
tqdm tqdm
imutils imutils
scipy scipy