mirror of
https://github.com/qurator-spk/eynollah.git
synced 2025-10-08 07:29:55 +02:00
Merge 8a9b4f8f55
into 5725e4fd1f
This commit is contained in:
commit
f9274990bf
33 changed files with 6188 additions and 196 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -6,3 +6,4 @@ output.html
|
||||||
/build
|
/build
|
||||||
/dist
|
/dist
|
||||||
*.tif
|
*.tif
|
||||||
|
*.sw?
|
||||||
|
|
|
@ -22,7 +22,7 @@ Added:
|
||||||
Fixed:
|
Fixed:
|
||||||
|
|
||||||
* allow empty imports for optional dependencies
|
* allow empty imports for optional dependencies
|
||||||
* avoid Numpy warnings (empty slices etc)
|
* avoid Numpy warnings (empty slices etc.)
|
||||||
* remove deprecated Numpy types
|
* remove deprecated Numpy types
|
||||||
* binarization CLI: make `dir_in` usable again
|
* binarization CLI: make `dir_in` usable again
|
||||||
|
|
||||||
|
|
55
README.md
55
README.md
|
@ -11,23 +11,24 @@
|
||||||

|

|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
* Support for up to 10 segmentation classes:
|
* Support for 10 distinct segmentation classes:
|
||||||
* background, [page border](https://ocr-d.de/en/gt-guidelines/trans/lyRand.html), [text region](https://ocr-d.de/en/gt-guidelines/trans/lytextregion.html#textregionen__textregion_), [text line](https://ocr-d.de/en/gt-guidelines/pagexml/pagecontent_xsd_Complex_Type_pc_TextLineType.html), [header](https://ocr-d.de/en/gt-guidelines/trans/lyUeberschrift.html), [image](https://ocr-d.de/en/gt-guidelines/trans/lyBildbereiche.html), [separator](https://ocr-d.de/en/gt-guidelines/trans/lySeparatoren.html), [marginalia](https://ocr-d.de/en/gt-guidelines/trans/lyMarginalie.html), [initial](https://ocr-d.de/en/gt-guidelines/trans/lyInitiale.html), [table](https://ocr-d.de/en/gt-guidelines/trans/lyTabellen.html)
|
* background, [page border](https://ocr-d.de/en/gt-guidelines/trans/lyRand.html), [text region](https://ocr-d.de/en/gt-guidelines/trans/lytextregion.html#textregionen__textregion_), [text line](https://ocr-d.de/en/gt-guidelines/pagexml/pagecontent_xsd_Complex_Type_pc_TextLineType.html), [header](https://ocr-d.de/en/gt-guidelines/trans/lyUeberschrift.html), [image](https://ocr-d.de/en/gt-guidelines/trans/lyBildbereiche.html), [separator](https://ocr-d.de/en/gt-guidelines/trans/lySeparatoren.html), [marginalia](https://ocr-d.de/en/gt-guidelines/trans/lyMarginalie.html), [initial](https://ocr-d.de/en/gt-guidelines/trans/lyInitiale.html), [table](https://ocr-d.de/en/gt-guidelines/trans/lyTabellen.html)
|
||||||
* Support for various image optimization operations:
|
* Support for various image optimization operations:
|
||||||
* cropping (border detection), binarization, deskewing, dewarping, scaling, enhancing, resizing
|
* cropping (border detection), binarization, deskewing, dewarping, scaling, enhancing, resizing
|
||||||
* Text line segmentation to bounding boxes or polygons (contours) including for curved lines and vertical text
|
* Textline segmentation to bounding boxes or polygons (contours) including for curved lines and vertical text
|
||||||
* Detection of reading order (left-to-right or right-to-left)
|
* Text recognition (OCR) using either CNN-RNN or Transformer models
|
||||||
|
* Detection of reading order (left-to-right or right-to-left) using either heuristics or trainable models
|
||||||
* Output in [PAGE-XML](https://github.com/PRImA-Research-Lab/PAGE-XML)
|
* Output in [PAGE-XML](https://github.com/PRImA-Research-Lab/PAGE-XML)
|
||||||
* [OCR-D](https://github.com/qurator-spk/eynollah#use-as-ocr-d-processor) interface
|
* [OCR-D](https://github.com/qurator-spk/eynollah#use-as-ocr-d-processor) interface
|
||||||
|
|
||||||
:warning: Development is currently focused on achieving the best possible quality of results for a wide variety of
|
:warning: Development is focused on achieving the best quality of results for a wide variety of historical
|
||||||
historical documents and therefore processing can be very slow. We aim to improve this, but contributions are welcome.
|
documents and therefore processing can be very slow. We aim to improve this, but contributions are welcome.
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
Python `3.8-3.11` with Tensorflow `<2.13` on Linux are currently supported.
|
Python `3.8-3.11` with Tensorflow `<2.13` on Linux are currently supported.
|
||||||
|
|
||||||
For (limited) GPU support the CUDA toolkit needs to be installed.
|
For (limited) GPU support the CUDA toolkit needs to be installed. A known working config is CUDA `11` with cuDNN `8.6`.
|
||||||
|
|
||||||
You can either install from PyPI
|
You can either install from PyPI
|
||||||
|
|
||||||
|
@ -53,26 +54,30 @@ make install EXTRAS=OCR
|
||||||
```
|
```
|
||||||
|
|
||||||
## Models
|
## Models
|
||||||
|
|
||||||
Pretrained models can be downloaded from [zenodo](https://zenodo.org/records/17194824) or [huggingface](https://huggingface.co/SBB?search_models=eynollah).
|
Pretrained models can be downloaded from [zenodo](https://zenodo.org/records/17194824) or [huggingface](https://huggingface.co/SBB?search_models=eynollah).
|
||||||
|
|
||||||
For documentation on methods and models, have a look at [`models.md`](https://github.com/qurator-spk/eynollah/tree/main/docs/models.md).
|
For documentation on models, have a look at [`models.md`](https://github.com/qurator-spk/eynollah/tree/main/docs/models.md).
|
||||||
|
Model cards are also provided for our trained models.
|
||||||
|
|
||||||
## Train
|
## Training
|
||||||
|
|
||||||
In case you want to train your own model with Eynollah, have a look at [`train.md`](https://github.com/qurator-spk/eynollah/tree/main/docs/train.md).
|
In case you want to train your own model with Eynollah, see the
|
||||||
|
documentation in [`train.md`](https://github.com/qurator-spk/eynollah/tree/main/docs/train.md) and use the
|
||||||
|
tools in the [`train` folder](https://github.com/qurator-spk/eynollah/tree/main/train).
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
Eynollah supports five use cases: layout analysis (segmentation), binarization,
|
Eynollah supports five use cases: layout analysis (segmentation), binarization,
|
||||||
image enhancement, text recognition (OCR), and (trainable) reading order detection.
|
image enhancement, text recognition (OCR), and reading order detection.
|
||||||
|
|
||||||
### Layout Analysis
|
### Layout Analysis
|
||||||
|
|
||||||
The layout analysis module is responsible for detecting layouts, identifying text lines, and determining reading order
|
The layout analysis module is responsible for detecting layout elements, identifying text lines, and determining reading
|
||||||
using both heuristic methods or a machine-based reading order detection model.
|
order using either heuristic methods or a [pretrained reading order detection model](https://github.com/qurator-spk/eynollah#machine-based-reading-order).
|
||||||
|
|
||||||
Note that there are currently two supported ways for reading order detection: either as part of layout analysis based
|
Reading order detection can be performed either as part of layout analysis based on image input, or, currently under
|
||||||
on image input, or, currently under development, for given layout analysis results based on PAGE-XML data as input.
|
development, based on pre-existing layout analysis results in PAGE-XML format as input.
|
||||||
|
|
||||||
The command-line interface for layout analysis can be called like this:
|
The command-line interface for layout analysis can be called like this:
|
||||||
|
|
||||||
|
@ -105,15 +110,15 @@ The following options can be used to further configure the processing:
|
||||||
| `-sp <directory>` | save cropped page image to this directory |
|
| `-sp <directory>` | save cropped page image to this directory |
|
||||||
| `-sa <directory>` | save all (plot, enhanced/binary image, layout) to this directory |
|
| `-sa <directory>` | save all (plot, enhanced/binary image, layout) to this directory |
|
||||||
|
|
||||||
If no option is set, the tool performs layout detection of main regions (background, text, images, separators
|
If no further option is set, the tool performs layout detection of main regions (background, text, images, separators
|
||||||
and marginals).
|
and marginals).
|
||||||
The best output quality is produced when RGB images are used as input rather than greyscale or binarized images.
|
The best output quality is achieved when RGB images are used as input rather than greyscale or binarized images.
|
||||||
|
|
||||||
### Binarization
|
### Binarization
|
||||||
|
|
||||||
The binarization module performs document image binarization using pretrained pixelwise segmentation models.
|
The binarization module performs document image binarization using pretrained pixelwise segmentation models.
|
||||||
|
|
||||||
The command-line interface for binarization of single image can be called like this:
|
The command-line interface for binarization can be called like this:
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
eynollah binarization \
|
eynollah binarization \
|
||||||
|
@ -124,16 +129,16 @@ eynollah binarization \
|
||||||
|
|
||||||
### OCR
|
### OCR
|
||||||
|
|
||||||
The OCR module performs text recognition from images using two main families of pretrained models: CNN-RNN–based OCR and Transformer-based OCR.
|
The OCR module performs text recognition using either a CNN-RNN model or a Transformer model.
|
||||||
|
|
||||||
The command-line interface for ocr can be called like this:
|
The command-line interface for OCR can be called like this:
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
eynollah ocr \
|
eynollah ocr \
|
||||||
-i <single image file> | -di <directory containing image files> \
|
-i <single image file> | -di <directory containing image files> \
|
||||||
-dx <directory of xmls> \
|
-dx <directory of xmls> \
|
||||||
-o <output directory> \
|
-o <output directory> \
|
||||||
-m <path to directory containing model files> | --model_name <path to specific model> \
|
-m <directory containing model files> | --model_name <path to specific model> \
|
||||||
```
|
```
|
||||||
|
|
||||||
### Machine-based-reading-order
|
### Machine-based-reading-order
|
||||||
|
@ -169,22 +174,20 @@ If the input file group is PAGE-XML (from a previous OCR-D workflow step), Eynol
|
||||||
(because some other preprocessing step was in effect like `denoised`), then
|
(because some other preprocessing step was in effect like `denoised`), then
|
||||||
the output PAGE-XML will be based on that as new top-level (`@imageFilename`)
|
the output PAGE-XML will be based on that as new top-level (`@imageFilename`)
|
||||||
|
|
||||||
ocrd-eynollah-segment -I OCR-D-XYZ -O OCR-D-SEG -P models eynollah_layout_v0_5_0
|
ocrd-eynollah-segment -I OCR-D-XYZ -O OCR-D-SEG -P models eynollah_layout_v0_5_0
|
||||||
|
|
||||||
Still, in general, it makes more sense to add other workflow steps **after** Eynollah.
|
In general, it makes more sense to add other workflow steps **after** Eynollah.
|
||||||
|
|
||||||
There is also an OCR-D processor for the binarization:
|
There is also an OCR-D processor for binarization:
|
||||||
|
|
||||||
ocrd-sbb-binarize -I OCR-D-IMG -O OCR-D-BIN -P models default-2021-03-09
|
ocrd-sbb-binarize -I OCR-D-IMG -O OCR-D-BIN -P models default-2021-03-09
|
||||||
|
|
||||||
#### Additional documentation
|
#### Additional documentation
|
||||||
|
|
||||||
Please check the [wiki](https://github.com/qurator-spk/eynollah/wiki).
|
Additional documentation is available in the [docs](https://github.com/qurator-spk/eynollah/tree/main/docs) directory.
|
||||||
|
|
||||||
## How to cite
|
## How to cite
|
||||||
|
|
||||||
If you find this tool useful in your work, please consider citing our paper:
|
|
||||||
|
|
||||||
```bibtex
|
```bibtex
|
||||||
@inproceedings{hip23rezanezhad,
|
@inproceedings{hip23rezanezhad,
|
||||||
title = {Document Layout Analysis with Deep Learning and Heuristics},
|
title = {Document Layout Analysis with Deep Learning and Heuristics},
|
||||||
|
|
156
docs/train.md
156
docs/train.md
|
@ -1,16 +1,24 @@
|
||||||
# Training documentation
|
# Training documentation
|
||||||
|
|
||||||
This aims to assist users in preparing training datasets, training models, and performing inference with trained models.
|
This document aims to assist users in preparing training datasets, training models, and
|
||||||
We cover various use cases including pixel-wise segmentation, image classification, image enhancement, and machine-based
|
performing inference with trained models. We cover various use cases including
|
||||||
reading order detection. For each use case, we provide guidance on how to generate the corresponding training dataset.
|
pixel-wise segmentation, image classification, image enhancement, and
|
||||||
|
machine-based reading order detection. For each use case, we provide guidance
|
||||||
|
on how to generate the corresponding training dataset.
|
||||||
|
|
||||||
The following three tasks can all be accomplished using the code in the
|
The following three tasks can all be accomplished using the code in the
|
||||||
[`train`](https://github.com/qurator-spk/sbb_pixelwise_segmentation/tree/unifying-training-models) directory:
|
[`train`](https://github.com/qurator-spk/eynollah/tree/main/train) directory:
|
||||||
|
|
||||||
* generate training dataset
|
* generate training dataset
|
||||||
* train a model
|
* train a model
|
||||||
* inference with the trained model
|
* inference with the trained model
|
||||||
|
|
||||||
|
## Training , evaluation and output
|
||||||
|
|
||||||
|
The train and evaluation folders should contain subfolders of `images` and `labels`.
|
||||||
|
|
||||||
|
The output folder should be an empty folder where the output model will be written to.
|
||||||
|
|
||||||
## Generate training dataset
|
## Generate training dataset
|
||||||
|
|
||||||
The script `generate_gt_for_training.py` is used for generating training datasets. As the results of the following
|
The script `generate_gt_for_training.py` is used for generating training datasets. As the results of the following
|
||||||
|
@ -64,7 +72,7 @@ to the image area, with a default value of zero. To run the dataset generator, u
|
||||||
python generate_gt_for_training.py machine-based-reading-order \
|
python generate_gt_for_training.py machine-based-reading-order \
|
||||||
-dx "dir of GT xml files" \
|
-dx "dir of GT xml files" \
|
||||||
-domi "dir where output images will be written" \
|
-domi "dir where output images will be written" \
|
||||||
-docl "dir where the labels will be written" \
|
"" -docl "dir where the labels will be written" \
|
||||||
-ih "height" \
|
-ih "height" \
|
||||||
-iw "width" \
|
-iw "width" \
|
||||||
-min "min area ratio"
|
-min "min area ratio"
|
||||||
|
@ -310,54 +318,59 @@ The following parameter configuration can be applied to all segmentation use cas
|
||||||
its sub-parameters, and continued training are defined only for segmentation use cases and enhancements, not for
|
its sub-parameters, and continued training are defined only for segmentation use cases and enhancements, not for
|
||||||
classification and machine-based reading order, as you can see in their example config files.
|
classification and machine-based reading order, as you can see in their example config files.
|
||||||
|
|
||||||
* backbone_type: For segmentation tasks (such as text line, binarization, and layout detection) and enhancement, we
|
* `backbone_type`: For segmentation tasks (such as text line, binarization, and layout detection) and enhancement, we
|
||||||
* offer two backbone options: a "nontransformer" and a "transformer" backbone. For the "transformer" backbone, we first
|
offer two backbone options: a "nontransformer" and a "transformer" backbone. For the "transformer" backbone, we first
|
||||||
* apply a CNN followed by a transformer. In contrast, the "nontransformer" backbone utilizes only a CNN ResNet-50.
|
apply a CNN followed by a transformer. In contrast, the "nontransformer" backbone utilizes only a CNN ResNet-50.
|
||||||
* task : The task parameter can have values such as "segmentation", "enhancement", "classification", and "reading_order".
|
* `task`: The task parameter can have values such as "segmentation", "enhancement", "classification", and "reading_order".
|
||||||
* patches: If you want to break input images into smaller patches (input size of the model) you need to set this
|
* `patches`: If you want to break input images into smaller patches (input size of the model) you need to set this
|
||||||
* parameter to ``true``. In the case that the model should see the image once, like page extraction, patches should be
|
* parameter to `true`. In the case that the model should see the image once, like page extraction, patches should be
|
||||||
* set to ``false``.
|
set to ``false``.
|
||||||
* n_batch: Number of batches at each iteration.
|
* `n_batch`: Number of batches at each iteration.
|
||||||
* n_classes: Number of classes. In the case of binary classification this should be 2. In the case of reading_order it
|
* `n_classes`: Number of classes. In the case of binary classification this should be 2. In the case of reading_order it
|
||||||
* should set to 1. And for the case of layout detection just the unique number of classes should be given.
|
should set to 1. And for the case of layout detection just the unique number of classes should be given.
|
||||||
* n_epochs: Number of epochs.
|
* `n_epochs`: Number of epochs.
|
||||||
* input_height: This indicates the height of model's input.
|
* `input_height`: This indicates the height of model's input.
|
||||||
* input_width: This indicates the width of model's input.
|
* `input_width`: This indicates the width of model's input.
|
||||||
* weight_decay: Weight decay of l2 regularization of model layers.
|
* `weight_decay`: Weight decay of l2 regularization of model layers.
|
||||||
* pretraining: Set to ``true`` to load pretrained weights of ResNet50 encoder. The downloaded weights should be saved
|
* `pretraining`: Set to `true` to load pretrained weights of ResNet50 encoder. The downloaded weights should be saved
|
||||||
* in a folder named "pretrained_model" in the same directory of "train.py" script.
|
in a folder named "pretrained_model" in the same directory of "train.py" script.
|
||||||
* augmentation: If you want to apply any kind of augmentation this parameter should first set to ``true``.
|
* `augmentation`: If you want to apply any kind of augmentation this parameter should first set to `true`.
|
||||||
* flip_aug: If ``true``, different types of filp will be applied on image. Type of flips is given with "flip_index" parameter.
|
* `flip_aug`: If `true`, different types of filp will be applied on image. Type of flips is given with "flip_index" parameter.
|
||||||
* blur_aug: If ``true``, different types of blurring will be applied on image. Type of blurrings is given with "blur_k" parameter.
|
* `blur_aug`: If `true`, different types of blurring will be applied on image. Type of blurrings is given with "blur_k" parameter.
|
||||||
* scaling: If ``true``, scaling will be applied on image. Scale of scaling is given with "scales" parameter.
|
* `scaling`: If `true`, scaling will be applied on image. Scale of scaling is given with "scales" parameter.
|
||||||
* degrading: If ``true``, degrading will be applied to the image. The amount of degrading is defined with "degrade_scales" parameter.
|
* `degrading`: If `true`, degrading will be applied to the image. The amount of degrading is defined with "degrade_scales" parameter.
|
||||||
* brightening: If ``true``, brightening will be applied to the image. The amount of brightening is defined with "brightness" parameter.
|
* `brightening`: If `true`, brightening will be applied to the image. The amount of brightening is defined with "brightness" parameter.
|
||||||
* rotation_not_90: If ``true``, rotation (not 90 degree) will be applied on image. Rotation angles are given with "thetha" parameter.
|
* `rotation_not_90`: If `true`, rotation (not 90 degree) will be applied on image. Rotation angles are given with "thetha" parameter.
|
||||||
* rotation: If ``true``, 90 degree rotation will be applied on image.
|
* `rotation`: If `true`, 90 degree rotation will be applied on image.
|
||||||
* binarization: If ``true``,Otsu thresholding will be applied to augment the input data with binarized images.
|
* `binarization`: If `true`,Otsu thresholding will be applied to augment the input data with binarized images.
|
||||||
* scaling_bluring: If ``true``, combination of scaling and blurring will be applied on image.
|
* `scaling_bluring`: If `true`, combination of scaling and blurring will be applied on image.
|
||||||
* scaling_binarization: If ``true``, combination of scaling and binarization will be applied on image.
|
* `scaling_binarization`: If `true`, combination of scaling and binarization will be applied on image.
|
||||||
* scaling_flip: If ``true``, combination of scaling and flip will be applied on image.
|
* `scaling_flip`: If `true`, combination of scaling and flip will be applied on image.
|
||||||
* flip_index: Type of flips.
|
* `flip_index`: Type of flips.
|
||||||
* blur_k: Type of blurrings.
|
* `blur_k`: Type of blurrings.
|
||||||
* scales: Scales of scaling.
|
* `scales`: Scales of scaling.
|
||||||
* brightness: The amount of brightenings.
|
* `brightness`: The amount of brightenings.
|
||||||
* thetha: Rotation angles.
|
* `thetha`: Rotation angles.
|
||||||
* degrade_scales: The amount of degradings.
|
* `degrade_scales`: The amount of degradings.
|
||||||
* continue_training: If ``true``, it means that you have already trained a model and you would like to continue the training. So it is needed to provide the dir of trained model with "dir_of_start_model" and index for naming the models. For example if you have already trained for 3 epochs then your last index is 2 and if you want to continue from model_1.h5, you can set ``index_start`` to 3 to start naming model with index 3.
|
* `continue_training`: If `true`, it means that you have already trained a model and you would like to continue the
|
||||||
* weighted_loss: If ``true``, this means that you want to apply weighted categorical_crossentropy as loss fucntion. Be carefull if you set to ``true``the parameter "is_loss_soft_dice" should be ``false``
|
training. So it is needed to providethe dir of trained model with "dir_of_start_model" and index for naming
|
||||||
* data_is_provided: If you have already provided the input data you can set this to ``true``. Be sure that the train and eval data are in "dir_output". Since when once we provide training data we resize and augment them and then we write them in sub-directories train and eval in "dir_output".
|
themodels. For example if you have already trained for 3 epochs then your lastindex is 2 and if you want to continue
|
||||||
* dir_train: This is the directory of "images" and "labels" (dir_train should include two subdirectories with names of images and labels ) for raw images and labels. Namely they are not prepared (not resized and not augmented) yet for training the model. When we run this tool these raw data will be transformed to suitable size needed for the model and they will be written in "dir_output" in train and eval directories. Each of train and eval include "images" and "labels" sub-directories.
|
from model_1.h5, you can set `index_start` to 3 to start naming model with index 3.
|
||||||
* index_start: Starting index for saved models in the case that "continue_training" is ``true``.
|
* `weighted_loss`: If `true`, this means that you want to apply weighted categorical_crossentropy as loss fucntion. Be carefull if you set to `true`the parameter "is_loss_soft_dice" should be ``false``
|
||||||
* dir_of_start_model: Directory containing pretrained model to continue training the model in the case that "continue_training" is ``true``.
|
* `data_is_provided`: If you have already provided the input data you can set this to `true`. Be sure that the train
|
||||||
* transformer_num_patches_xy: Number of patches for vision transformer in x and y direction respectively.
|
and eval data are in"dir_output".Since when once we provide training data we resize and augmentthem and then wewrite
|
||||||
* transformer_patchsize_x: Patch size of vision transformer patches in x direction.
|
them in sub-directories train and eval in "dir_output".
|
||||||
* transformer_patchsize_y: Patch size of vision transformer patches in y direction.
|
* `dir_train`: This is the directory of "images" and "labels" (dir_train should include two subdirectories with names of images and labels ) for raw images and labels. Namely they are not prepared (not resized and not augmented) yet for training the model. When we run this tool these raw data will be transformed to suitable size needed for the model and they will be written in "dir_output" in train and eval directories. Each of train and eval include "images" and "labels" sub-directories.
|
||||||
* transformer_projection_dim: Transformer projection dimension. Default value is 64.
|
* `index_start`: Starting index for saved models in the case that "continue_training" is `true`.
|
||||||
* transformer_mlp_head_units: Transformer Multilayer Perceptron (MLP) head units. Default value is [128, 64].
|
* `dir_of_start_model`: Directory containing pretrained model to continue training the model in the case that "continue_training" is `true`.
|
||||||
* transformer_layers: transformer layers. Default value is 8.
|
* `transformer_num_patches_xy`: Number of patches for vision transformer in x and y direction respectively.
|
||||||
* transformer_num_heads: Transformer number of heads. Default value is 4.
|
* `transformer_patchsize_x`: Patch size of vision transformer patches in x direction.
|
||||||
* transformer_cnn_first: We have two types of vision transformers. In one type, a CNN is applied first, followed by a transformer. In the other type, this order is reversed. If transformer_cnn_first is true, it means the CNN will be applied before the transformer. Default value is true.
|
* `transformer_patchsize_y`: Patch size of vision transformer patches in y direction.
|
||||||
|
* `transformer_projection_dim`: Transformer projection dimension. Default value is 64.
|
||||||
|
* `transformer_mlp_head_units`: Transformer Multilayer Perceptron (MLP) head units. Default value is [128, 64].
|
||||||
|
* `transformer_layers`: transformer layers. Default value is 8.
|
||||||
|
* `transformer_num_heads`: Transformer number of heads. Default value is 4.
|
||||||
|
* `transformer_cnn_first`: We have two types of vision transformers. In one type, a CNN is applied first, followed by a transformer. In the other type, this order is reversed. If transformer_cnn_first is true, it means the CNN will be applied before the transformer. Default value is true.
|
||||||
|
|
||||||
In the case of segmentation and enhancement the train and evaluation directory should be as following.
|
In the case of segmentation and enhancement the train and evaluation directory should be as following.
|
||||||
|
|
||||||
|
@ -386,6 +399,30 @@ command, similar to the process for classification and reading order:
|
||||||
|
|
||||||
#### Binarization
|
#### Binarization
|
||||||
|
|
||||||
|
### Ground truth format
|
||||||
|
|
||||||
|
Lables for each pixel are identified by a number. So if you have a
|
||||||
|
binary case, ``n_classes`` should be set to ``2`` and labels should
|
||||||
|
be ``0`` and ``1`` for each class and pixel.
|
||||||
|
|
||||||
|
In the case of multiclass, just set ``n_classes`` to the number of classes
|
||||||
|
you have and the try to produce the labels by pixels set from ``0 , 1 ,2 .., n_classes-1``.
|
||||||
|
The labels format should be png.
|
||||||
|
Our lables are 3 channel png images but only information of first channel is used.
|
||||||
|
If you have an image label with height and width of 10, for a binary case the first channel should look like this:
|
||||||
|
|
||||||
|
Label: [ [1, 0, 0, 1, 1, 0, 0, 1, 0, 0],
|
||||||
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||||
|
...,
|
||||||
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||||
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0] ]
|
||||||
|
|
||||||
|
This means that you have an image by `10*10*3` and `pixel[0,0]` belongs
|
||||||
|
to class `1` and `pixel[0,1]` belongs to class `0`.
|
||||||
|
|
||||||
|
A small sample of training data for binarization experiment can be found here, [Training data sample](https://qurator-data.de/~vahid.rezanezhad/binarization_training_data_sample/), which contains images and lables folders.
|
||||||
|
|
||||||
|
|
||||||
An example config json file for binarization can be like this:
|
An example config json file for binarization can be like this:
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
|
@ -577,8 +614,8 @@ image.
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
For page segmentation (or printspace or border segmentation), the model needs to view the input image in its entirety,
|
For page segmentation (or print space or border segmentation), the model needs to view the input image in its
|
||||||
hence the patches parameter should be set to false.
|
entirety,hence the patches parameter should be set to false.
|
||||||
|
|
||||||
#### layout segmentation
|
#### layout segmentation
|
||||||
|
|
||||||
|
@ -650,9 +687,8 @@ This will straightforwardly return the class of the image.
|
||||||
### machine based reading order
|
### machine based reading order
|
||||||
|
|
||||||
To infer the reading order using a reading order model, we need a page XML file containing layout information but
|
To infer the reading order using a reading order model, we need a page XML file containing layout information but
|
||||||
without the reading order. We simply need to provide the model directory, the XML file, and the output directory.
|
without the reading order. We simply need to provide the model directory, the XML file, and the output directory. The
|
||||||
The new XML file with the added reading order will be written to the output directory with the same name.
|
new XML file with the added reading order will be written to the output directory with the same name. We need to run:
|
||||||
We need to run:
|
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
python inference.py \
|
python inference.py \
|
||||||
|
@ -662,8 +698,8 @@ python inference.py \
|
||||||
```
|
```
|
||||||
|
|
||||||
### Segmentation (Textline, Binarization, Page extraction and layout) and enhancement
|
### Segmentation (Textline, Binarization, Page extraction and layout) and enhancement
|
||||||
For conducting inference with a trained model for segmentation and enhancement you need to run the following command
|
|
||||||
line:
|
For conducting inference with a trained model for segmentation and enhancement you need to run the following command line:
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
python inference.py \
|
python inference.py \
|
||||||
|
|
|
@ -13,7 +13,11 @@ license.file = "LICENSE"
|
||||||
requires-python = ">=3.8"
|
requires-python = ">=3.8"
|
||||||
keywords = ["document layout analysis", "image segmentation"]
|
keywords = ["document layout analysis", "image segmentation"]
|
||||||
|
|
||||||
dynamic = ["dependencies", "version"]
|
dynamic = [
|
||||||
|
"dependencies",
|
||||||
|
"optional-dependencies",
|
||||||
|
"version"
|
||||||
|
]
|
||||||
|
|
||||||
classifiers = [
|
classifiers = [
|
||||||
"Development Status :: 4 - Beta",
|
"Development Status :: 4 - Beta",
|
||||||
|
@ -25,10 +29,6 @@ classifiers = [
|
||||||
"Topic :: Scientific/Engineering :: Image Processing",
|
"Topic :: Scientific/Engineering :: Image Processing",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
|
||||||
OCR = ["torch <= 2.0.1", "transformers <= 4.30.2"]
|
|
||||||
plotting = ["matplotlib"]
|
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
eynollah = "eynollah.cli:main"
|
eynollah = "eynollah.cli:main"
|
||||||
ocrd-eynollah-segment = "eynollah.ocrd_cli:main"
|
ocrd-eynollah-segment = "eynollah.ocrd_cli:main"
|
||||||
|
@ -41,6 +41,9 @@ Repository = "https://github.com/qurator-spk/eynollah.git"
|
||||||
[tool.setuptools.dynamic]
|
[tool.setuptools.dynamic]
|
||||||
dependencies = {file = ["requirements.txt"]}
|
dependencies = {file = ["requirements.txt"]}
|
||||||
optional-dependencies.test = {file = ["requirements-test.txt"]}
|
optional-dependencies.test = {file = ["requirements-test.txt"]}
|
||||||
|
optional-dependencies.OCR = {file = ["requirements-ocr.txt"]}
|
||||||
|
optional-dependencies.plotting = {file = ["requirements-plotting.txt"]}
|
||||||
|
optional-dependencies.training = {file = ["requirements-training.txt"]}
|
||||||
|
|
||||||
[tool.setuptools.packages.find]
|
[tool.setuptools.packages.find]
|
||||||
where = ["src"]
|
where = ["src"]
|
||||||
|
|
2
requirements-ocr.txt
Normal file
2
requirements-ocr.txt
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
torch <= 2.0.1
|
||||||
|
transformers <= 4.30.2
|
1
requirements-plotting.txt
Normal file
1
requirements-plotting.txt
Normal file
|
@ -0,0 +1 @@
|
||||||
|
matplotlib
|
1
requirements-training.txt
Symbolic link
1
requirements-training.txt
Symbolic link
|
@ -0,0 +1 @@
|
||||||
|
train/requirements.txt
|
|
@ -4894,9 +4894,9 @@ class Eynollah:
|
||||||
textline_mask_tot_ea_org[img_revised_tab==drop_label_in_full_layout] = 0
|
textline_mask_tot_ea_org[img_revised_tab==drop_label_in_full_layout] = 0
|
||||||
|
|
||||||
|
|
||||||
text_only = ((img_revised_tab[:, :] == 1)) * 1
|
text_only = (img_revised_tab[:, :] == 1) * 1
|
||||||
if np.abs(slope_deskew) >= SLOPE_THRESHOLD:
|
if np.abs(slope_deskew) >= SLOPE_THRESHOLD:
|
||||||
text_only_d = ((text_regions_p_1_n[:, :] == 1)) * 1
|
text_only_d = (text_regions_p_1_n[:, :] == 1) * 1
|
||||||
|
|
||||||
#print("text region early 2 in %.1fs", time.time() - t0)
|
#print("text region early 2 in %.1fs", time.time() - t0)
|
||||||
###min_con_area = 0.000005
|
###min_con_area = 0.000005
|
||||||
|
|
|
@ -12,7 +12,7 @@ from .utils import crop_image_inside_box
|
||||||
from .utils.rotate import rotate_image_different
|
from .utils.rotate import rotate_image_different
|
||||||
from .utils.resize import resize_image
|
from .utils.resize import resize_image
|
||||||
|
|
||||||
class EynollahPlotter():
|
class EynollahPlotter:
|
||||||
"""
|
"""
|
||||||
Class collecting all the plotting and image writing methods
|
Class collecting all the plotting and image writing methods
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -138,8 +138,7 @@ def return_x_start_end_mothers_childs_and_type_of_reading_order(
|
||||||
min_ys=np.min(y_sep)
|
min_ys=np.min(y_sep)
|
||||||
max_ys=np.max(y_sep)
|
max_ys=np.max(y_sep)
|
||||||
|
|
||||||
y_mains=[]
|
y_mains= [min_ys]
|
||||||
y_mains.append(min_ys)
|
|
||||||
y_mains_sep_ohne_grenzen=[]
|
y_mains_sep_ohne_grenzen=[]
|
||||||
|
|
||||||
for ii in range(len(new_main_sep_y)):
|
for ii in range(len(new_main_sep_y)):
|
||||||
|
@ -493,8 +492,7 @@ def find_num_col(regions_without_separators, num_col_classifier, tables, multipl
|
||||||
# print(forest[np.argmin(z[forest]) ] )
|
# print(forest[np.argmin(z[forest]) ] )
|
||||||
if not isNaN(forest[np.argmin(z[forest])]):
|
if not isNaN(forest[np.argmin(z[forest])]):
|
||||||
peaks_neg_true.append(forest[np.argmin(z[forest])])
|
peaks_neg_true.append(forest[np.argmin(z[forest])])
|
||||||
forest = []
|
forest = [peaks_neg_fin[i + 1]]
|
||||||
forest.append(peaks_neg_fin[i + 1])
|
|
||||||
if i == (len(peaks_neg_fin) - 1):
|
if i == (len(peaks_neg_fin) - 1):
|
||||||
# print(print(forest[np.argmin(z[forest]) ] ))
|
# print(print(forest[np.argmin(z[forest]) ] ))
|
||||||
if not isNaN(forest[np.argmin(z[forest])]):
|
if not isNaN(forest[np.argmin(z[forest])]):
|
||||||
|
@ -662,8 +660,7 @@ def find_num_col_only_image(regions_without_separators, multiplier=3.8):
|
||||||
# print(forest[np.argmin(z[forest]) ] )
|
# print(forest[np.argmin(z[forest]) ] )
|
||||||
if not isNaN(forest[np.argmin(z[forest])]):
|
if not isNaN(forest[np.argmin(z[forest])]):
|
||||||
peaks_neg_true.append(forest[np.argmin(z[forest])])
|
peaks_neg_true.append(forest[np.argmin(z[forest])])
|
||||||
forest = []
|
forest = [peaks_neg_fin[i + 1]]
|
||||||
forest.append(peaks_neg_fin[i + 1])
|
|
||||||
if i == (len(peaks_neg_fin) - 1):
|
if i == (len(peaks_neg_fin) - 1):
|
||||||
# print(print(forest[np.argmin(z[forest]) ] ))
|
# print(print(forest[np.argmin(z[forest]) ] ))
|
||||||
if not isNaN(forest[np.argmin(z[forest])]):
|
if not isNaN(forest[np.argmin(z[forest])]):
|
||||||
|
@ -1211,7 +1208,7 @@ def order_of_regions(textline_mask, contours_main, contours_header, y_ref):
|
||||||
|
|
||||||
##plt.plot(z)
|
##plt.plot(z)
|
||||||
##plt.show()
|
##plt.show()
|
||||||
if contours_main != None:
|
if contours_main is not None:
|
||||||
areas_main = np.array([cv2.contourArea(contours_main[j]) for j in range(len(contours_main))])
|
areas_main = np.array([cv2.contourArea(contours_main[j]) for j in range(len(contours_main))])
|
||||||
M_main = [cv2.moments(contours_main[j]) for j in range(len(contours_main))]
|
M_main = [cv2.moments(contours_main[j]) for j in range(len(contours_main))]
|
||||||
cx_main = [(M_main[j]["m10"] / (M_main[j]["m00"] + 1e-32)) for j in range(len(M_main))]
|
cx_main = [(M_main[j]["m10"] / (M_main[j]["m00"] + 1e-32)) for j in range(len(M_main))]
|
||||||
|
@ -1222,7 +1219,7 @@ def order_of_regions(textline_mask, contours_main, contours_header, y_ref):
|
||||||
y_min_main = np.array([np.min(contours_main[j][:, 0, 1]) for j in range(len(contours_main))])
|
y_min_main = np.array([np.min(contours_main[j][:, 0, 1]) for j in range(len(contours_main))])
|
||||||
y_max_main = np.array([np.max(contours_main[j][:, 0, 1]) for j in range(len(contours_main))])
|
y_max_main = np.array([np.max(contours_main[j][:, 0, 1]) for j in range(len(contours_main))])
|
||||||
|
|
||||||
if len(contours_header) != None:
|
if len(contours_header) is not None:
|
||||||
areas_header = np.array([cv2.contourArea(contours_header[j]) for j in range(len(contours_header))])
|
areas_header = np.array([cv2.contourArea(contours_header[j]) for j in range(len(contours_header))])
|
||||||
M_header = [cv2.moments(contours_header[j]) for j in range(len(contours_header))]
|
M_header = [cv2.moments(contours_header[j]) for j in range(len(contours_header))]
|
||||||
cx_header = [(M_header[j]["m10"] / (M_header[j]["m00"] + 1e-32)) for j in range(len(M_header))]
|
cx_header = [(M_header[j]["m10"] / (M_header[j]["m00"] + 1e-32)) for j in range(len(M_header))]
|
||||||
|
@ -1235,17 +1232,16 @@ def order_of_regions(textline_mask, contours_main, contours_header, y_ref):
|
||||||
y_max_header = np.array([np.max(contours_header[j][:, 0, 1]) for j in range(len(contours_header))])
|
y_max_header = np.array([np.max(contours_header[j][:, 0, 1]) for j in range(len(contours_header))])
|
||||||
# print(cy_main,'mainy')
|
# print(cy_main,'mainy')
|
||||||
|
|
||||||
peaks_neg_new = []
|
peaks_neg_new = [0 + y_ref]
|
||||||
peaks_neg_new.append(0 + y_ref)
|
|
||||||
for iii in range(len(peaks_neg)):
|
for iii in range(len(peaks_neg)):
|
||||||
peaks_neg_new.append(peaks_neg[iii] + y_ref)
|
peaks_neg_new.append(peaks_neg[iii] + y_ref)
|
||||||
peaks_neg_new.append(textline_mask.shape[0] + y_ref)
|
peaks_neg_new.append(textline_mask.shape[0] + y_ref)
|
||||||
|
|
||||||
if len(cy_main) > 0 and np.max(cy_main) > np.max(peaks_neg_new):
|
if len(cy_main) > 0 and np.max(cy_main) > np.max(peaks_neg_new):
|
||||||
cy_main = np.array(cy_main) * (np.max(peaks_neg_new) / np.max(cy_main)) - 10
|
cy_main = np.array(cy_main) * (np.max(peaks_neg_new) / np.max(cy_main)) - 10
|
||||||
if contours_main != None:
|
if contours_main is not None:
|
||||||
indexer_main = np.arange(len(contours_main))
|
indexer_main = np.arange(len(contours_main))
|
||||||
if contours_main != None:
|
if contours_main is not None:
|
||||||
len_main = len(contours_main)
|
len_main = len(contours_main)
|
||||||
else:
|
else:
|
||||||
len_main = 0
|
len_main = 0
|
||||||
|
@ -1271,11 +1267,11 @@ def order_of_regions(textline_mask, contours_main, contours_header, y_ref):
|
||||||
top = peaks_neg_new[i]
|
top = peaks_neg_new[i]
|
||||||
down = peaks_neg_new[i + 1]
|
down = peaks_neg_new[i + 1]
|
||||||
indexes_in = matrix_of_orders[:, 0][(matrix_of_orders[:, 3] >= top) &
|
indexes_in = matrix_of_orders[:, 0][(matrix_of_orders[:, 3] >= top) &
|
||||||
((matrix_of_orders[:, 3] < down))]
|
(matrix_of_orders[:, 3] < down)]
|
||||||
cxs_in = matrix_of_orders[:, 2][(matrix_of_orders[:, 3] >= top) &
|
cxs_in = matrix_of_orders[:, 2][(matrix_of_orders[:, 3] >= top) &
|
||||||
((matrix_of_orders[:, 3] < down))]
|
(matrix_of_orders[:, 3] < down)]
|
||||||
cys_in = matrix_of_orders[:, 3][(matrix_of_orders[:, 3] >= top) &
|
cys_in = matrix_of_orders[:, 3][(matrix_of_orders[:, 3] >= top) &
|
||||||
((matrix_of_orders[:, 3] < down))]
|
(matrix_of_orders[:, 3] < down)]
|
||||||
types_of_text = matrix_of_orders[:, 1][(matrix_of_orders[:, 3] >= top) &
|
types_of_text = matrix_of_orders[:, 1][(matrix_of_orders[:, 3] >= top) &
|
||||||
(matrix_of_orders[:, 3] < down)]
|
(matrix_of_orders[:, 3] < down)]
|
||||||
index_types_of_text = matrix_of_orders[:, 4][(matrix_of_orders[:, 3] >= top) &
|
index_types_of_text = matrix_of_orders[:, 4][(matrix_of_orders[:, 3] >= top) &
|
||||||
|
@ -1404,8 +1400,7 @@ def combine_hor_lines_and_delete_cross_points_and_get_lines_features_back_new(
|
||||||
return img_p_in[:,:,0], special_separators
|
return img_p_in[:,:,0], special_separators
|
||||||
|
|
||||||
def return_points_with_boundies(peaks_neg_fin, first_point, last_point):
|
def return_points_with_boundies(peaks_neg_fin, first_point, last_point):
|
||||||
peaks_neg_tot = []
|
peaks_neg_tot = [first_point]
|
||||||
peaks_neg_tot.append(first_point)
|
|
||||||
for ii in range(len(peaks_neg_fin)):
|
for ii in range(len(peaks_neg_fin)):
|
||||||
peaks_neg_tot.append(peaks_neg_fin[ii])
|
peaks_neg_tot.append(peaks_neg_fin[ii])
|
||||||
peaks_neg_tot.append(last_point)
|
peaks_neg_tot.append(last_point)
|
||||||
|
@ -1413,7 +1408,7 @@ def return_points_with_boundies(peaks_neg_fin, first_point, last_point):
|
||||||
|
|
||||||
def find_number_of_columns_in_document(region_pre_p, num_col_classifier, tables, pixel_lines, contours_h=None):
|
def find_number_of_columns_in_document(region_pre_p, num_col_classifier, tables, pixel_lines, contours_h=None):
|
||||||
t_ins_c0 = time.time()
|
t_ins_c0 = time.time()
|
||||||
separators_closeup=( (region_pre_p[:,:,:]==pixel_lines))*1
|
separators_closeup= (region_pre_p[:, :, :] == pixel_lines) * 1
|
||||||
separators_closeup[0:110,:,:]=0
|
separators_closeup[0:110,:,:]=0
|
||||||
separators_closeup[separators_closeup.shape[0]-150:,:,:]=0
|
separators_closeup[separators_closeup.shape[0]-150:,:,:]=0
|
||||||
|
|
||||||
|
@ -1452,7 +1447,7 @@ def find_number_of_columns_in_document(region_pre_p, num_col_classifier, tables,
|
||||||
gray = cv2.bitwise_not(separators_closeup_n_binary)
|
gray = cv2.bitwise_not(separators_closeup_n_binary)
|
||||||
gray=gray.astype(np.uint8)
|
gray=gray.astype(np.uint8)
|
||||||
|
|
||||||
bw = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C, \
|
bw = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C,
|
||||||
cv2.THRESH_BINARY, 15, -2)
|
cv2.THRESH_BINARY, 15, -2)
|
||||||
horizontal = np.copy(bw)
|
horizontal = np.copy(bw)
|
||||||
vertical = np.copy(bw)
|
vertical = np.copy(bw)
|
||||||
|
@ -1588,8 +1583,7 @@ def find_number_of_columns_in_document(region_pre_p, num_col_classifier, tables,
|
||||||
args_cy_splitter=np.argsort(cy_main_splitters)
|
args_cy_splitter=np.argsort(cy_main_splitters)
|
||||||
cy_main_splitters_sort=cy_main_splitters[args_cy_splitter]
|
cy_main_splitters_sort=cy_main_splitters[args_cy_splitter]
|
||||||
|
|
||||||
splitter_y_new=[]
|
splitter_y_new= [0]
|
||||||
splitter_y_new.append(0)
|
|
||||||
for i in range(len(cy_main_splitters_sort)):
|
for i in range(len(cy_main_splitters_sort)):
|
||||||
splitter_y_new.append( cy_main_splitters_sort[i] )
|
splitter_y_new.append( cy_main_splitters_sort[i] )
|
||||||
splitter_y_new.append(region_pre_p.shape[0])
|
splitter_y_new.append(region_pre_p.shape[0])
|
||||||
|
@ -1663,8 +1657,7 @@ def return_boxes_of_images_by_order_of_reading_new(
|
||||||
num_col, peaks_neg_fin = find_num_col(
|
num_col, peaks_neg_fin = find_num_col(
|
||||||
regions_without_separators[int(splitter_y_new[i]):int(splitter_y_new[i+1]),:],
|
regions_without_separators[int(splitter_y_new[i]):int(splitter_y_new[i+1]),:],
|
||||||
num_col_classifier, tables, multiplier=3.)
|
num_col_classifier, tables, multiplier=3.)
|
||||||
peaks_neg_fin_early=[]
|
peaks_neg_fin_early= [0]
|
||||||
peaks_neg_fin_early.append(0)
|
|
||||||
#print(peaks_neg_fin,'peaks_neg_fin')
|
#print(peaks_neg_fin,'peaks_neg_fin')
|
||||||
for p_n in peaks_neg_fin:
|
for p_n in peaks_neg_fin:
|
||||||
peaks_neg_fin_early.append(p_n)
|
peaks_neg_fin_early.append(p_n)
|
||||||
|
|
|
@ -239,8 +239,7 @@ def do_back_rotation_and_get_cnt_back(contour_par, index_r_con, img, slope_first
|
||||||
|
|
||||||
cont_int, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
|
cont_int, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
|
||||||
if len(cont_int)==0:
|
if len(cont_int)==0:
|
||||||
cont_int = []
|
cont_int = [contour_par]
|
||||||
cont_int.append(contour_par)
|
|
||||||
confidence_contour = 0
|
confidence_contour = 0
|
||||||
else:
|
else:
|
||||||
cont_int[0][:, 0, 0] = cont_int[0][:, 0, 0] + np.abs(img_copy.shape[1] - img.shape[1])
|
cont_int[0][:, 0, 0] = cont_int[0][:, 0, 0] + np.abs(img_copy.shape[1] - img.shape[1])
|
||||||
|
|
|
@ -3,7 +3,7 @@ from collections import Counter
|
||||||
REGION_ID_TEMPLATE = 'region_%04d'
|
REGION_ID_TEMPLATE = 'region_%04d'
|
||||||
LINE_ID_TEMPLATE = 'region_%04d_line_%04d'
|
LINE_ID_TEMPLATE = 'region_%04d_line_%04d'
|
||||||
|
|
||||||
class EynollahIdCounter():
|
class EynollahIdCounter:
|
||||||
|
|
||||||
def __init__(self, region_idx=0, line_idx=0):
|
def __init__(self, region_idx=0, line_idx=0):
|
||||||
self._counter = Counter()
|
self._counter = Counter()
|
||||||
|
|
|
@ -76,7 +76,7 @@ def get_marginals(text_with_lines, text_regions, num_col, slope_deskew, light_ve
|
||||||
|
|
||||||
peaks, _ = find_peaks(text_with_lines_y_rev, height=0)
|
peaks, _ = find_peaks(text_with_lines_y_rev, height=0)
|
||||||
peaks=np.array(peaks)
|
peaks=np.array(peaks)
|
||||||
peaks=peaks[(peaks>first_nonzero) & ((peaks<last_nonzero))]
|
peaks=peaks[(peaks>first_nonzero) & (peaks < last_nonzero)]
|
||||||
peaks=peaks[region_sum_0[peaks]<min_textline_thickness ]
|
peaks=peaks[region_sum_0[peaks]<min_textline_thickness ]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1174,8 +1174,7 @@ def separate_lines_new_inside_tiles(img_path, thetha):
|
||||||
if diff_peaks[i] > cut_off:
|
if diff_peaks[i] > cut_off:
|
||||||
if not np.isnan(forest[np.argmin(z[forest])]):
|
if not np.isnan(forest[np.argmin(z[forest])]):
|
||||||
peaks_neg_true.append(forest[np.argmin(z[forest])])
|
peaks_neg_true.append(forest[np.argmin(z[forest])])
|
||||||
forest = []
|
forest = [peaks_neg[i + 1]]
|
||||||
forest.append(peaks_neg[i + 1])
|
|
||||||
if i == (len(peaks_neg) - 1):
|
if i == (len(peaks_neg) - 1):
|
||||||
if not np.isnan(forest[np.argmin(z[forest])]):
|
if not np.isnan(forest[np.argmin(z[forest])]):
|
||||||
peaks_neg_true.append(forest[np.argmin(z[forest])])
|
peaks_neg_true.append(forest[np.argmin(z[forest])])
|
||||||
|
@ -1195,8 +1194,7 @@ def separate_lines_new_inside_tiles(img_path, thetha):
|
||||||
if diff_peaks_pos[i] > cut_off:
|
if diff_peaks_pos[i] > cut_off:
|
||||||
if not np.isnan(forest[np.argmax(z[forest])]):
|
if not np.isnan(forest[np.argmax(z[forest])]):
|
||||||
peaks_pos_true.append(forest[np.argmax(z[forest])])
|
peaks_pos_true.append(forest[np.argmax(z[forest])])
|
||||||
forest = []
|
forest = [peaks[i + 1]]
|
||||||
forest.append(peaks[i + 1])
|
|
||||||
if i == (len(peaks) - 1):
|
if i == (len(peaks) - 1):
|
||||||
if not np.isnan(forest[np.argmax(z[forest])]):
|
if not np.isnan(forest[np.argmax(z[forest])]):
|
||||||
peaks_pos_true.append(forest[np.argmax(z[forest])])
|
peaks_pos_true.append(forest[np.argmax(z[forest])])
|
||||||
|
@ -1430,9 +1428,9 @@ def separate_lines_new2(img_path, thetha, num_col, slope_region, logger=None, pl
|
||||||
img_int = np.zeros((img_xline.shape[0], img_xline.shape[1]))
|
img_int = np.zeros((img_xline.shape[0], img_xline.shape[1]))
|
||||||
img_int[:, :] = img_xline[:, :] # img_patch_org[:,:,0]
|
img_int[:, :] = img_xline[:, :] # img_patch_org[:,:,0]
|
||||||
|
|
||||||
img_resized = np.zeros((int(img_int.shape[0] * (1.2)), int(img_int.shape[1] * (3))))
|
img_resized = np.zeros((int(img_int.shape[0] * 1.2), int(img_int.shape[1] * 3)))
|
||||||
img_resized[int(img_int.shape[0] * (0.1)) : int(img_int.shape[0] * (0.1)) + img_int.shape[0],
|
img_resized[int(img_int.shape[0] * 0.1): int(img_int.shape[0] * 0.1) + img_int.shape[0],
|
||||||
int(img_int.shape[1] * (1.0)) : int(img_int.shape[1] * (1.0)) + img_int.shape[1]] = img_int[:, :]
|
int(img_int.shape[1] * 1.0): int(img_int.shape[1] * 1.0) + img_int.shape[1]] = img_int[:, :]
|
||||||
# plt.imshow(img_xline)
|
# plt.imshow(img_xline)
|
||||||
# plt.show()
|
# plt.show()
|
||||||
img_line_rotated = rotate_image(img_resized, slopes_tile_wise[i])
|
img_line_rotated = rotate_image(img_resized, slopes_tile_wise[i])
|
||||||
|
@ -1444,8 +1442,8 @@ def separate_lines_new2(img_path, thetha, num_col, slope_region, logger=None, pl
|
||||||
img_patch_separated_returned[:, :][img_patch_separated_returned[:, :] != 0] = 1
|
img_patch_separated_returned[:, :][img_patch_separated_returned[:, :] != 0] = 1
|
||||||
|
|
||||||
img_patch_separated_returned_true_size = img_patch_separated_returned[
|
img_patch_separated_returned_true_size = img_patch_separated_returned[
|
||||||
int(img_int.shape[0] * (0.1)) : int(img_int.shape[0] * (0.1)) + img_int.shape[0],
|
int(img_int.shape[0] * 0.1): int(img_int.shape[0] * 0.1) + img_int.shape[0],
|
||||||
int(img_int.shape[1] * (1.0)) : int(img_int.shape[1] * (1.0)) + img_int.shape[1]]
|
int(img_int.shape[1] * 1.0): int(img_int.shape[1] * 1.0) + img_int.shape[1]]
|
||||||
|
|
||||||
img_patch_separated_returned_true_size = img_patch_separated_returned_true_size[:, margin : length_x - margin]
|
img_patch_separated_returned_true_size = img_patch_separated_returned_true_size[:, margin : length_x - margin]
|
||||||
img_patch_ineterst_revised[:, index_x_d + margin : index_x_u - margin] = img_patch_separated_returned_true_size
|
img_patch_ineterst_revised[:, index_x_d + margin : index_x_u - margin] = img_patch_separated_returned_true_size
|
||||||
|
@ -1473,7 +1471,7 @@ def return_deskew_slop(img_patch_org, sigma_des,n_tot_angles=100,
|
||||||
img_int[:,:]=img_patch_org[:,:]#img_patch_org[:,:,0]
|
img_int[:,:]=img_patch_org[:,:]#img_patch_org[:,:,0]
|
||||||
|
|
||||||
max_shape=np.max(img_int.shape)
|
max_shape=np.max(img_int.shape)
|
||||||
img_resized=np.zeros((int( max_shape*(1.1) ) , int( max_shape*(1.1) ) ))
|
img_resized=np.zeros((int(max_shape * 1.1) , int(max_shape * 1.1)))
|
||||||
|
|
||||||
onset_x=int((img_resized.shape[1]-img_int.shape[1])/2.)
|
onset_x=int((img_resized.shape[1]-img_int.shape[1])/2.)
|
||||||
onset_y=int((img_resized.shape[0]-img_int.shape[0])/2.)
|
onset_y=int((img_resized.shape[0]-img_int.shape[0])/2.)
|
||||||
|
@ -1538,7 +1536,7 @@ def return_deskew_slop_old_mp(img_patch_org, sigma_des,n_tot_angles=100,
|
||||||
img_int[:,:]=img_patch_org[:,:]#img_patch_org[:,:,0]
|
img_int[:,:]=img_patch_org[:,:]#img_patch_org[:,:,0]
|
||||||
|
|
||||||
max_shape=np.max(img_int.shape)
|
max_shape=np.max(img_int.shape)
|
||||||
img_resized=np.zeros((int( max_shape*(1.1) ) , int( max_shape*(1.1) ) ))
|
img_resized=np.zeros((int(max_shape * 1.1) , int(max_shape * 1.1)))
|
||||||
|
|
||||||
onset_x=int((img_resized.shape[1]-img_int.shape[1])/2.)
|
onset_x=int((img_resized.shape[1]-img_int.shape[1])/2.)
|
||||||
onset_y=int((img_resized.shape[0]-img_int.shape[0])/2.)
|
onset_y=int((img_resized.shape[0]-img_int.shape[0])/2.)
|
||||||
|
|
|
@ -21,7 +21,7 @@ from ocrd_models.ocrd_page import (
|
||||||
)
|
)
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
class EynollahXmlWriter():
|
class EynollahXmlWriter:
|
||||||
|
|
||||||
def __init__(self, *, dir_out, image_filename, curved_line,textline_light, pcgts=None):
|
def __init__(self, *, dir_out, image_filename, curved_line,textline_light, pcgts=None):
|
||||||
self.logger = getLogger('eynollah.writer')
|
self.logger = getLogger('eynollah.writer')
|
||||||
|
|
0
train/.gitkeep
Normal file
0
train/.gitkeep
Normal file
29
train/Dockerfile
Normal file
29
train/Dockerfile
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
# Use NVIDIA base image
|
||||||
|
FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04
|
||||||
|
|
||||||
|
# Set the working directory
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
|
||||||
|
# Set environment variable for GitPython
|
||||||
|
ENV GIT_PYTHON_REFRESH=quiet
|
||||||
|
|
||||||
|
# Install Python and pip
|
||||||
|
RUN apt-get update && apt-get install -y --fix-broken && \
|
||||||
|
apt-get install -y \
|
||||||
|
python3 \
|
||||||
|
python3-pip \
|
||||||
|
python3-distutils \
|
||||||
|
python3-setuptools \
|
||||||
|
python3-wheel && \
|
||||||
|
rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Copy and install Python dependencies
|
||||||
|
COPY requirements.txt .
|
||||||
|
RUN pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
# Copy the rest of the application
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
# Specify the entry point
|
||||||
|
CMD ["python3", "train.py", "with", "config_params_docker.json"]
|
59
train/README.md
Normal file
59
train/README.md
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
# Training eynollah
|
||||||
|
|
||||||
|
This README explains the technical details of how to set up and run training, for detailed information on parameterization, see [`docs/train.md`](../docs/train.md)
|
||||||
|
|
||||||
|
## Introduction
|
||||||
|
|
||||||
|
This folder contains the source code for training an encoder model for document image segmentation.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
Clone the repository and install eynollah along with the dependencies necessary for training:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
git clone https://github.com/qurator-spk/eynollah
|
||||||
|
cd eynollah
|
||||||
|
pip install '.[training]'
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pretrained encoder
|
||||||
|
|
||||||
|
Download our pretrained weights and add them to a `train/pretrained_model` folder:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
cd train
|
||||||
|
wget -O pretrained_model.tar.gz https://zenodo.org/records/17243320/files/pretrained_model_v0_5_1.tar.gz?download=1
|
||||||
|
tar xf pretrained_model.tar.gz
|
||||||
|
```
|
||||||
|
|
||||||
|
### Binarization training data
|
||||||
|
|
||||||
|
A small sample of training data for binarization experiment can be found [on
|
||||||
|
zenodo](https://zenodo.org/records/17243320/files/training_data_sample_binarization_v0_5_1.tar.gz?download=1),
|
||||||
|
which contains `images` and `labels` folders.
|
||||||
|
|
||||||
|
### Helpful tools
|
||||||
|
|
||||||
|
* [`pagexml2img`](https://github.com/qurator-spk/page2img)
|
||||||
|
> Tool to extract 2-D or 3-D RGB images from PAGE-XML data. In the former case, the output will be 1 2-D image array which each class has filled with a pixel value. In the case of a 3-D RGB image,
|
||||||
|
each class will be defined with a RGB value and beside images, a text file of classes will also be produced.
|
||||||
|
* [`cocoSegmentationToPng`](https://github.com/nightrome/cocostuffapi/blob/17acf33aef3c6cc2d6aca46dcf084266c2778cf0/PythonAPI/pycocotools/cocostuffhelper.py#L130)
|
||||||
|
> Convert COCO GT or results for a single image to a segmentation map and write it to disk.
|
||||||
|
* [`ocrd-segment-extract-pages`](https://github.com/OCR-D/ocrd_segment/blob/master/ocrd_segment/extract_pages.py)
|
||||||
|
> Extract region classes and their colours in mask (pseg) images. Allows the color map as free dict parameter, and comes with a default that mimics PageViewer's coloring for quick debugging; it also warns when regions do overlap.
|
||||||
|
|
||||||
|
### Train using Docker
|
||||||
|
|
||||||
|
Build the Docker image:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd train
|
||||||
|
docker build -t model-training .
|
||||||
|
```
|
||||||
|
|
||||||
|
Run Docker image
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd train
|
||||||
|
docker run --gpus all -v $PWD:/entry_point_dir model-training
|
||||||
|
```
|
0
train/__init__.py
Normal file
0
train/__init__.py
Normal file
29
train/build_model_load_pretrained_weights_and_save.py
Normal file
29
train/build_model_load_pretrained_weights_and_save.py
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import tensorflow as tf
|
||||||
|
import warnings
|
||||||
|
from tensorflow.keras.optimizers import *
|
||||||
|
from sacred import Experiment
|
||||||
|
from models import *
|
||||||
|
from utils import *
|
||||||
|
from metrics import *
|
||||||
|
|
||||||
|
|
||||||
|
def configuration():
|
||||||
|
gpu_options = tf.compat.v1.GPUOptions(allow_growth=True)
|
||||||
|
session = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(gpu_options=gpu_options))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
n_classes = 2
|
||||||
|
input_height = 224
|
||||||
|
input_width = 448
|
||||||
|
weight_decay = 1e-6
|
||||||
|
pretraining = False
|
||||||
|
dir_of_weights = 'model_bin_sbb_ens.h5'
|
||||||
|
|
||||||
|
# configuration()
|
||||||
|
|
||||||
|
model = resnet50_unet(n_classes, input_height, input_width, weight_decay, pretraining)
|
||||||
|
model.load_weights(dir_of_weights)
|
||||||
|
model.save('./name_in_another_python_version.h5')
|
58
train/config_params.json
Normal file
58
train/config_params.json
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
{
|
||||||
|
"backbone_type" : "transformer",
|
||||||
|
"task": "segmentation",
|
||||||
|
"n_classes" : 2,
|
||||||
|
"n_epochs" : 0,
|
||||||
|
"input_height" : 448,
|
||||||
|
"input_width" : 448,
|
||||||
|
"weight_decay" : 1e-6,
|
||||||
|
"n_batch" : 1,
|
||||||
|
"learning_rate": 1e-4,
|
||||||
|
"patches" : false,
|
||||||
|
"pretraining" : true,
|
||||||
|
"augmentation" : true,
|
||||||
|
"flip_aug" : false,
|
||||||
|
"blur_aug" : false,
|
||||||
|
"scaling" : false,
|
||||||
|
"adding_rgb_background": true,
|
||||||
|
"adding_rgb_foreground": true,
|
||||||
|
"add_red_textlines": false,
|
||||||
|
"channels_shuffling": false,
|
||||||
|
"degrading": false,
|
||||||
|
"brightening": false,
|
||||||
|
"binarization" : true,
|
||||||
|
"scaling_bluring" : false,
|
||||||
|
"scaling_binarization" : false,
|
||||||
|
"scaling_flip" : false,
|
||||||
|
"rotation": false,
|
||||||
|
"rotation_not_90": false,
|
||||||
|
"transformer_num_patches_xy": [56, 56],
|
||||||
|
"transformer_patchsize_x": 4,
|
||||||
|
"transformer_patchsize_y": 4,
|
||||||
|
"transformer_projection_dim": 64,
|
||||||
|
"transformer_mlp_head_units": [128, 64],
|
||||||
|
"transformer_layers": 1,
|
||||||
|
"transformer_num_heads": 1,
|
||||||
|
"transformer_cnn_first": false,
|
||||||
|
"blur_k" : ["blur","guass","median"],
|
||||||
|
"scales" : [0.6, 0.7, 0.8, 0.9],
|
||||||
|
"brightness" : [1.3, 1.5, 1.7, 2],
|
||||||
|
"degrade_scales" : [0.2, 0.4],
|
||||||
|
"flip_index" : [0, 1, -1],
|
||||||
|
"shuffle_indexes" : [ [0,2,1], [1,2,0], [1,0,2] , [2,1,0]],
|
||||||
|
"thetha" : [5, -5],
|
||||||
|
"number_of_backgrounds_per_image": 2,
|
||||||
|
"continue_training": false,
|
||||||
|
"index_start" : 0,
|
||||||
|
"dir_of_start_model" : " ",
|
||||||
|
"weighted_loss": false,
|
||||||
|
"is_loss_soft_dice": false,
|
||||||
|
"data_is_provided": false,
|
||||||
|
"dir_train": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/train_new",
|
||||||
|
"dir_eval": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/eval_new",
|
||||||
|
"dir_output": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/output_new",
|
||||||
|
"dir_rgb_backgrounds": "/home/vahid/Documents/1_2_test_eynollah/set_rgb_background",
|
||||||
|
"dir_rgb_foregrounds": "/home/vahid/Documents/1_2_test_eynollah/out_set_rgb_foreground",
|
||||||
|
"dir_img_bin": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/train_new/images_bin"
|
||||||
|
|
||||||
|
}
|
54
train/config_params_docker.json
Normal file
54
train/config_params_docker.json
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
{
|
||||||
|
"backbone_type" : "nontransformer",
|
||||||
|
"task": "segmentation",
|
||||||
|
"n_classes" : 3,
|
||||||
|
"n_epochs" : 1,
|
||||||
|
"input_height" : 672,
|
||||||
|
"input_width" : 448,
|
||||||
|
"weight_decay" : 1e-6,
|
||||||
|
"n_batch" : 4,
|
||||||
|
"learning_rate": 1e-4,
|
||||||
|
"patches" : false,
|
||||||
|
"pretraining" : true,
|
||||||
|
"augmentation" : false,
|
||||||
|
"flip_aug" : false,
|
||||||
|
"blur_aug" : true,
|
||||||
|
"scaling" : true,
|
||||||
|
"adding_rgb_background": false,
|
||||||
|
"adding_rgb_foreground": false,
|
||||||
|
"add_red_textlines": false,
|
||||||
|
"channels_shuffling": true,
|
||||||
|
"degrading": true,
|
||||||
|
"brightening": true,
|
||||||
|
"binarization" : false,
|
||||||
|
"scaling_bluring" : false,
|
||||||
|
"scaling_binarization" : false,
|
||||||
|
"scaling_flip" : false,
|
||||||
|
"rotation": false,
|
||||||
|
"rotation_not_90": true,
|
||||||
|
"transformer_num_patches_xy": [14, 21],
|
||||||
|
"transformer_patchsize_x": 1,
|
||||||
|
"transformer_patchsize_y": 1,
|
||||||
|
"transformer_projection_dim": 64,
|
||||||
|
"transformer_mlp_head_units": [128, 64],
|
||||||
|
"transformer_layers": 1,
|
||||||
|
"transformer_num_heads": 1,
|
||||||
|
"transformer_cnn_first": true,
|
||||||
|
"blur_k" : ["blur","gauss","median"],
|
||||||
|
"scales" : [0.6, 0.7, 0.8, 0.9],
|
||||||
|
"brightness" : [1.3, 1.5, 1.7, 2],
|
||||||
|
"degrade_scales" : [0.2, 0.4],
|
||||||
|
"flip_index" : [0, 1, -1],
|
||||||
|
"shuffle_indexes" : [ [0,2,1], [1,2,0], [1,0,2] , [2,1,0]],
|
||||||
|
"thetha" : [5, -5],
|
||||||
|
"number_of_backgrounds_per_image": 2,
|
||||||
|
"continue_training": false,
|
||||||
|
"index_start" : 0,
|
||||||
|
"dir_of_start_model" : " ",
|
||||||
|
"weighted_loss": false,
|
||||||
|
"is_loss_soft_dice": true,
|
||||||
|
"data_is_provided": false,
|
||||||
|
"dir_train": "/entry_point_dir/train",
|
||||||
|
"dir_eval": "/entry_point_dir/eval",
|
||||||
|
"dir_output": "/entry_point_dir/output"
|
||||||
|
}
|
8
train/custom_config_page2label.json
Normal file
8
train/custom_config_page2label.json
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
{
|
||||||
|
"use_case": "textline",
|
||||||
|
"textregions":{ "rest_as_paragraph": 1, "header":2 , "heading":2 , "marginalia":3 },
|
||||||
|
"imageregion":4,
|
||||||
|
"separatorregion":5,
|
||||||
|
"graphicregions" :{"rest_as_decoration":6},
|
||||||
|
"columns_width":{"1":1000, "2":1300, "3":1600, "4":2000, "5":2300, "6":2500}
|
||||||
|
}
|
567
train/generate_gt_for_training.py
Normal file
567
train/generate_gt_for_training.py
Normal file
|
@ -0,0 +1,567 @@
|
||||||
|
import click
|
||||||
|
import json
|
||||||
|
from gt_gen_utils import *
|
||||||
|
from tqdm import tqdm
|
||||||
|
from pathlib import Path
|
||||||
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
|
||||||
|
@click.group()
|
||||||
|
def main():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@main.command()
|
||||||
|
@click.option(
|
||||||
|
"--dir_xml",
|
||||||
|
"-dx",
|
||||||
|
help="directory of GT page-xml files",
|
||||||
|
type=click.Path(exists=True, file_okay=False),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--dir_images",
|
||||||
|
"-di",
|
||||||
|
help="directory of org images. If print space cropping or scaling is needed for labels it would be great to provide the original images to apply the same function on them. So if -ps is not set true or in config files no columns_width key is given this argumnet can be ignored. File stems in this directory should be the same as those in dir_xml.",
|
||||||
|
type=click.Path(exists=True, file_okay=False),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--dir_out_images",
|
||||||
|
"-doi",
|
||||||
|
help="directory where the output org images after undergoing a process (like print space cropping or scaling) will be written.",
|
||||||
|
type=click.Path(exists=True, file_okay=False),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--dir_out",
|
||||||
|
"-do",
|
||||||
|
help="directory where ground truth label images would be written",
|
||||||
|
type=click.Path(exists=True, file_okay=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
@click.option(
|
||||||
|
"--config",
|
||||||
|
"-cfg",
|
||||||
|
help="config file of prefered layout or use case.",
|
||||||
|
type=click.Path(exists=True, dir_okay=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
@click.option(
|
||||||
|
"--type_output",
|
||||||
|
"-to",
|
||||||
|
help="this defines how output should be. A 2d image array or a 3d image array encoded with RGB color. Just pass 2d or 3d. The file will be saved one directory up. 2D image array is 3d but only information of one channel would be enough since all channels have the same values.",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--printspace",
|
||||||
|
"-ps",
|
||||||
|
is_flag=True,
|
||||||
|
help="if this parameter set to true, generated labels and in the case of provided org images cropping will be imposed and cropped labels and images will be written in output directories.",
|
||||||
|
)
|
||||||
|
|
||||||
|
def pagexml2label(dir_xml,dir_out,type_output,config, printspace, dir_images, dir_out_images):
|
||||||
|
if config:
|
||||||
|
with open(config) as f:
|
||||||
|
config_params = json.load(f)
|
||||||
|
else:
|
||||||
|
print("passed")
|
||||||
|
config_params = None
|
||||||
|
gt_list = get_content_of_dir(dir_xml)
|
||||||
|
get_images_of_ground_truth(gt_list,dir_xml,dir_out,type_output, config, config_params, printspace, dir_images, dir_out_images)
|
||||||
|
|
||||||
|
@main.command()
|
||||||
|
@click.option(
|
||||||
|
"--dir_imgs",
|
||||||
|
"-dis",
|
||||||
|
help="directory of images with high resolution.",
|
||||||
|
type=click.Path(exists=True, file_okay=False),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--dir_out_images",
|
||||||
|
"-dois",
|
||||||
|
help="directory where degraded images will be written.",
|
||||||
|
type=click.Path(exists=True, file_okay=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
@click.option(
|
||||||
|
"--dir_out_labels",
|
||||||
|
"-dols",
|
||||||
|
help="directory where original images will be written as labels.",
|
||||||
|
type=click.Path(exists=True, file_okay=False),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--scales",
|
||||||
|
"-scs",
|
||||||
|
help="json dictionary where the scales are written.",
|
||||||
|
type=click.Path(exists=True, dir_okay=False),
|
||||||
|
)
|
||||||
|
def image_enhancement(dir_imgs, dir_out_images, dir_out_labels, scales):
|
||||||
|
ls_imgs = os.listdir(dir_imgs)
|
||||||
|
with open(scales) as f:
|
||||||
|
scale_dict = json.load(f)
|
||||||
|
ls_scales = scale_dict['scales']
|
||||||
|
|
||||||
|
for img in tqdm(ls_imgs):
|
||||||
|
img_name = img.split('.')[0]
|
||||||
|
img_type = img.split('.')[1]
|
||||||
|
image = cv2.imread(os.path.join(dir_imgs, img))
|
||||||
|
for i, scale in enumerate(ls_scales):
|
||||||
|
height_sc = int(image.shape[0]*scale)
|
||||||
|
width_sc = int(image.shape[1]*scale)
|
||||||
|
|
||||||
|
image_down_scaled = resize_image(image, height_sc, width_sc)
|
||||||
|
image_back_to_org_scale = resize_image(image_down_scaled, image.shape[0], image.shape[1])
|
||||||
|
|
||||||
|
cv2.imwrite(os.path.join(dir_out_images, img_name+'_'+str(i)+'.'+img_type), image_back_to_org_scale)
|
||||||
|
cv2.imwrite(os.path.join(dir_out_labels, img_name+'_'+str(i)+'.'+img_type), image)
|
||||||
|
|
||||||
|
|
||||||
|
@main.command()
|
||||||
|
@click.option(
|
||||||
|
"--dir_xml",
|
||||||
|
"-dx",
|
||||||
|
help="directory of GT page-xml files",
|
||||||
|
type=click.Path(exists=True, file_okay=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
@click.option(
|
||||||
|
"--dir_out_modal_image",
|
||||||
|
"-domi",
|
||||||
|
help="directory where ground truth images would be written",
|
||||||
|
type=click.Path(exists=True, file_okay=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
@click.option(
|
||||||
|
"--dir_out_classes",
|
||||||
|
"-docl",
|
||||||
|
help="directory where ground truth classes would be written",
|
||||||
|
type=click.Path(exists=True, file_okay=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
@click.option(
|
||||||
|
"--input_height",
|
||||||
|
"-ih",
|
||||||
|
help="input height",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--input_width",
|
||||||
|
"-iw",
|
||||||
|
help="input width",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--min_area_size",
|
||||||
|
"-min",
|
||||||
|
help="min area size of regions considered for reading order training.",
|
||||||
|
)
|
||||||
|
|
||||||
|
@click.option(
|
||||||
|
"--min_area_early",
|
||||||
|
"-min_early",
|
||||||
|
help="If you have already generated a training dataset using a specific minimum area value and now wish to create a dataset with a smaller minimum area value, you can avoid regenerating the previous dataset by providing the earlier minimum area value. This will ensure that only the missing data is generated.",
|
||||||
|
)
|
||||||
|
|
||||||
|
def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, input_height, input_width, min_area_size, min_area_early):
|
||||||
|
xml_files_ind = os.listdir(dir_xml)
|
||||||
|
xml_files_ind = [ind_xml for ind_xml in xml_files_ind if ind_xml.endswith('.xml')]
|
||||||
|
input_height = int(input_height)
|
||||||
|
input_width = int(input_width)
|
||||||
|
min_area = float(min_area_size)
|
||||||
|
if min_area_early:
|
||||||
|
min_area_early = float(min_area_early)
|
||||||
|
|
||||||
|
|
||||||
|
indexer_start= 0#55166
|
||||||
|
max_area = 1
|
||||||
|
#min_area = 0.0001
|
||||||
|
|
||||||
|
for ind_xml in tqdm(xml_files_ind):
|
||||||
|
indexer = 0
|
||||||
|
#print(ind_xml)
|
||||||
|
#print('########################')
|
||||||
|
xml_file = os.path.join(dir_xml,ind_xml )
|
||||||
|
f_name = ind_xml.split('.')[0]
|
||||||
|
_, _, _, file_name, id_paragraph, id_header,co_text_paragraph,co_text_header,tot_region_ref,x_len, y_len,index_tot_regions,img_poly = read_xml(xml_file)
|
||||||
|
|
||||||
|
id_all_text = id_paragraph + id_header
|
||||||
|
co_text_all = co_text_paragraph + co_text_header
|
||||||
|
|
||||||
|
|
||||||
|
_, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, _ = find_new_features_of_contours(co_text_header)
|
||||||
|
|
||||||
|
img_header_and_sep = np.zeros((y_len,x_len), dtype='uint8')
|
||||||
|
|
||||||
|
for j in range(len(cy_main)):
|
||||||
|
img_header_and_sep[int(y_max_main[j]):int(y_max_main[j])+12,int(x_min_main[j]):int(x_max_main[j]) ] = 1
|
||||||
|
|
||||||
|
|
||||||
|
texts_corr_order_index = [index_tot_regions[tot_region_ref.index(i)] for i in id_all_text ]
|
||||||
|
texts_corr_order_index_int = [int(x) for x in texts_corr_order_index]
|
||||||
|
|
||||||
|
|
||||||
|
co_text_all, texts_corr_order_index_int, regions_ar_less_than_early_min = filter_contours_area_of_image(img_poly, co_text_all, texts_corr_order_index_int, max_area, min_area, min_area_early)
|
||||||
|
|
||||||
|
|
||||||
|
arg_array = np.array(range(len(texts_corr_order_index_int)))
|
||||||
|
|
||||||
|
labels_con = np.zeros((y_len,x_len,len(arg_array)),dtype='uint8')
|
||||||
|
for i in range(len(co_text_all)):
|
||||||
|
img_label = np.zeros((y_len,x_len,3),dtype='uint8')
|
||||||
|
img_label=cv2.fillPoly(img_label, pts =[co_text_all[i]], color=(1,1,1))
|
||||||
|
|
||||||
|
img_label[:,:,0][img_poly[:,:,0]==5] = 2
|
||||||
|
img_label[:,:,0][img_header_and_sep[:,:]==1] = 3
|
||||||
|
|
||||||
|
labels_con[:,:,i] = img_label[:,:,0]
|
||||||
|
|
||||||
|
labels_con = resize_image(labels_con, input_height, input_width)
|
||||||
|
img_poly = resize_image(img_poly, input_height, input_width)
|
||||||
|
|
||||||
|
|
||||||
|
for i in range(len(texts_corr_order_index_int)):
|
||||||
|
for j in range(len(texts_corr_order_index_int)):
|
||||||
|
if i!=j:
|
||||||
|
if regions_ar_less_than_early_min:
|
||||||
|
if regions_ar_less_than_early_min[i]==1:
|
||||||
|
input_multi_visual_modal = np.zeros((input_height,input_width,3)).astype(np.int8)
|
||||||
|
final_f_name = f_name+'_'+str(indexer+indexer_start)
|
||||||
|
order_class_condition = texts_corr_order_index_int[i]-texts_corr_order_index_int[j]
|
||||||
|
if order_class_condition<0:
|
||||||
|
class_type = 1
|
||||||
|
else:
|
||||||
|
class_type = 0
|
||||||
|
|
||||||
|
input_multi_visual_modal[:,:,0] = labels_con[:,:,i]
|
||||||
|
input_multi_visual_modal[:,:,1] = img_poly[:,:,0]
|
||||||
|
input_multi_visual_modal[:,:,2] = labels_con[:,:,j]
|
||||||
|
|
||||||
|
np.save(os.path.join(dir_out_classes,final_f_name+'_missed.npy' ), class_type)
|
||||||
|
|
||||||
|
cv2.imwrite(os.path.join(dir_out_modal_image,final_f_name+'_missed.png' ), input_multi_visual_modal)
|
||||||
|
indexer = indexer+1
|
||||||
|
|
||||||
|
else:
|
||||||
|
input_multi_visual_modal = np.zeros((input_height,input_width,3)).astype(np.int8)
|
||||||
|
final_f_name = f_name+'_'+str(indexer+indexer_start)
|
||||||
|
order_class_condition = texts_corr_order_index_int[i]-texts_corr_order_index_int[j]
|
||||||
|
if order_class_condition<0:
|
||||||
|
class_type = 1
|
||||||
|
else:
|
||||||
|
class_type = 0
|
||||||
|
|
||||||
|
input_multi_visual_modal[:,:,0] = labels_con[:,:,i]
|
||||||
|
input_multi_visual_modal[:,:,1] = img_poly[:,:,0]
|
||||||
|
input_multi_visual_modal[:,:,2] = labels_con[:,:,j]
|
||||||
|
|
||||||
|
np.save(os.path.join(dir_out_classes,final_f_name+'.npy' ), class_type)
|
||||||
|
|
||||||
|
cv2.imwrite(os.path.join(dir_out_modal_image,final_f_name+'.png' ), input_multi_visual_modal)
|
||||||
|
indexer = indexer+1
|
||||||
|
|
||||||
|
|
||||||
|
@main.command()
|
||||||
|
@click.option(
|
||||||
|
"--xml_file",
|
||||||
|
"-xml",
|
||||||
|
help="xml filename",
|
||||||
|
type=click.Path(exists=True, dir_okay=False),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--dir_xml",
|
||||||
|
"-dx",
|
||||||
|
help="directory of GT page-xml files",
|
||||||
|
type=click.Path(exists=True, file_okay=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
@click.option(
|
||||||
|
"--dir_out",
|
||||||
|
"-o",
|
||||||
|
help="directory where plots will be written",
|
||||||
|
type=click.Path(exists=True, file_okay=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
@click.option(
|
||||||
|
"--dir_imgs",
|
||||||
|
"-di",
|
||||||
|
help="directory where the overlayed plots will be written", )
|
||||||
|
|
||||||
|
def visualize_reading_order(xml_file, dir_xml, dir_out, dir_imgs):
|
||||||
|
assert xml_file or dir_xml, "A single xml file -xml or a dir of xml files -dx is required not both of them"
|
||||||
|
|
||||||
|
if dir_xml:
|
||||||
|
xml_files_ind = os.listdir(dir_xml)
|
||||||
|
xml_files_ind = [ind_xml for ind_xml in xml_files_ind if ind_xml.endswith('.xml')]
|
||||||
|
else:
|
||||||
|
xml_files_ind = [xml_file]
|
||||||
|
|
||||||
|
indexer_start= 0#55166
|
||||||
|
#min_area = 0.0001
|
||||||
|
|
||||||
|
for ind_xml in tqdm(xml_files_ind):
|
||||||
|
indexer = 0
|
||||||
|
#print(ind_xml)
|
||||||
|
#print('########################')
|
||||||
|
#xml_file = os.path.join(dir_xml,ind_xml )
|
||||||
|
|
||||||
|
if dir_xml:
|
||||||
|
xml_file = os.path.join(dir_xml,ind_xml )
|
||||||
|
f_name = Path(ind_xml).stem
|
||||||
|
else:
|
||||||
|
xml_file = os.path.join(ind_xml )
|
||||||
|
f_name = Path(ind_xml).stem
|
||||||
|
print(f_name, 'f_name')
|
||||||
|
|
||||||
|
#f_name = ind_xml.split('.')[0]
|
||||||
|
_, _, _, file_name, id_paragraph, id_header,co_text_paragraph,co_text_header,tot_region_ref,x_len, y_len,index_tot_regions,img_poly = read_xml(xml_file)
|
||||||
|
|
||||||
|
id_all_text = id_paragraph + id_header
|
||||||
|
co_text_all = co_text_paragraph + co_text_header
|
||||||
|
|
||||||
|
|
||||||
|
cx_main, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, _ = find_new_features_of_contours(co_text_all)
|
||||||
|
|
||||||
|
texts_corr_order_index = [int(index_tot_regions[tot_region_ref.index(i)]) for i in id_all_text ]
|
||||||
|
#texts_corr_order_index_int = [int(x) for x in texts_corr_order_index]
|
||||||
|
|
||||||
|
|
||||||
|
#cx_ordered = np.array(cx_main)[np.array(texts_corr_order_index)]
|
||||||
|
#cx_ordered = cx_ordered.astype(np.int32)
|
||||||
|
|
||||||
|
cx_ordered = [int(val) for (_, val) in sorted(zip(texts_corr_order_index, cx_main), key=lambda x: \
|
||||||
|
x[0], reverse=False)]
|
||||||
|
#cx_ordered = cx_ordered.astype(np.int32)
|
||||||
|
|
||||||
|
cy_ordered = [int(val) for (_, val) in sorted(zip(texts_corr_order_index, cy_main), key=lambda x: \
|
||||||
|
x[0], reverse=False)]
|
||||||
|
#cy_ordered = cy_ordered.astype(np.int32)
|
||||||
|
|
||||||
|
|
||||||
|
color = (0, 0, 255)
|
||||||
|
thickness = 20
|
||||||
|
if dir_imgs:
|
||||||
|
layout = np.zeros( (y_len,x_len,3) )
|
||||||
|
layout = cv2.fillPoly(layout, pts =co_text_all, color=(1,1,1))
|
||||||
|
|
||||||
|
img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name)
|
||||||
|
img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format))
|
||||||
|
|
||||||
|
overlayed = overlay_layout_on_image(layout, img, cx_ordered, cy_ordered, color, thickness)
|
||||||
|
cv2.imwrite(os.path.join(dir_out, f_name+'.png'), overlayed)
|
||||||
|
|
||||||
|
else:
|
||||||
|
img = np.zeros( (y_len,x_len,3) )
|
||||||
|
img = cv2.fillPoly(img, pts =co_text_all, color=(255,0,0))
|
||||||
|
for i in range(len(cx_ordered)-1):
|
||||||
|
start_point = (int(cx_ordered[i]), int(cy_ordered[i]))
|
||||||
|
end_point = (int(cx_ordered[i+1]), int(cy_ordered[i+1]))
|
||||||
|
img = cv2.arrowedLine(img, start_point, end_point,
|
||||||
|
color, thickness, tipLength = 0.03)
|
||||||
|
|
||||||
|
cv2.imwrite(os.path.join(dir_out, f_name+'.png'), img)
|
||||||
|
|
||||||
|
|
||||||
|
@main.command()
|
||||||
|
@click.option(
|
||||||
|
"--xml_file",
|
||||||
|
"-xml",
|
||||||
|
help="xml filename",
|
||||||
|
type=click.Path(exists=True, dir_okay=False),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--dir_xml",
|
||||||
|
"-dx",
|
||||||
|
help="directory of GT page-xml files",
|
||||||
|
type=click.Path(exists=True, file_okay=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
@click.option(
|
||||||
|
"--dir_out",
|
||||||
|
"-o",
|
||||||
|
help="directory where plots will be written",
|
||||||
|
type=click.Path(exists=True, file_okay=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
@click.option(
|
||||||
|
"--dir_imgs",
|
||||||
|
"-di",
|
||||||
|
help="directory of images where textline segmentation will be overlayed", )
|
||||||
|
|
||||||
|
def visualize_textline_segmentation(xml_file, dir_xml, dir_out, dir_imgs):
|
||||||
|
assert xml_file or dir_xml, "A single xml file -xml or a dir of xml files -dx is required not both of them"
|
||||||
|
if dir_xml:
|
||||||
|
xml_files_ind = os.listdir(dir_xml)
|
||||||
|
xml_files_ind = [ind_xml for ind_xml in xml_files_ind if ind_xml.endswith('.xml')]
|
||||||
|
else:
|
||||||
|
xml_files_ind = [xml_file]
|
||||||
|
|
||||||
|
for ind_xml in tqdm(xml_files_ind):
|
||||||
|
indexer = 0
|
||||||
|
#print(ind_xml)
|
||||||
|
#print('########################')
|
||||||
|
xml_file = os.path.join(dir_xml,ind_xml )
|
||||||
|
f_name = Path(ind_xml).stem
|
||||||
|
|
||||||
|
img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name)
|
||||||
|
img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format))
|
||||||
|
|
||||||
|
co_tetxlines, y_len, x_len = get_textline_contours_for_visualization(xml_file)
|
||||||
|
|
||||||
|
added_image = visualize_image_from_contours(co_tetxlines, img)
|
||||||
|
|
||||||
|
cv2.imwrite(os.path.join(dir_out, f_name+'.png'), added_image)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@main.command()
|
||||||
|
@click.option(
|
||||||
|
"--xml_file",
|
||||||
|
"-xml",
|
||||||
|
help="xml filename",
|
||||||
|
type=click.Path(exists=True, dir_okay=False),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--dir_xml",
|
||||||
|
"-dx",
|
||||||
|
help="directory of GT page-xml files",
|
||||||
|
type=click.Path(exists=True, file_okay=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
@click.option(
|
||||||
|
"--dir_out",
|
||||||
|
"-o",
|
||||||
|
help="directory where plots will be written",
|
||||||
|
type=click.Path(exists=True, file_okay=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
@click.option(
|
||||||
|
"--dir_imgs",
|
||||||
|
"-di",
|
||||||
|
help="directory of images where textline segmentation will be overlayed", )
|
||||||
|
|
||||||
|
def visualize_layout_segmentation(xml_file, dir_xml, dir_out, dir_imgs):
|
||||||
|
assert xml_file or dir_xml, "A single xml file -xml or a dir of xml files -dx is required not both of them"
|
||||||
|
if dir_xml:
|
||||||
|
xml_files_ind = os.listdir(dir_xml)
|
||||||
|
xml_files_ind = [ind_xml for ind_xml in xml_files_ind if ind_xml.endswith('.xml')]
|
||||||
|
else:
|
||||||
|
xml_files_ind = [xml_file]
|
||||||
|
|
||||||
|
for ind_xml in tqdm(xml_files_ind):
|
||||||
|
indexer = 0
|
||||||
|
#print(ind_xml)
|
||||||
|
#print('########################')
|
||||||
|
if dir_xml:
|
||||||
|
xml_file = os.path.join(dir_xml,ind_xml )
|
||||||
|
f_name = Path(ind_xml).stem
|
||||||
|
else:
|
||||||
|
xml_file = os.path.join(ind_xml )
|
||||||
|
f_name = Path(ind_xml).stem
|
||||||
|
print(f_name, 'f_name')
|
||||||
|
|
||||||
|
img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name)
|
||||||
|
img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format))
|
||||||
|
|
||||||
|
co_text, co_graphic, co_sep, co_img, co_table, co_noise, y_len, x_len = get_layout_contours_for_visualization(xml_file)
|
||||||
|
|
||||||
|
|
||||||
|
added_image = visualize_image_from_contours_layout(co_text['paragraph'], co_text['header']+co_text['heading'], co_text['drop-capital'], co_sep, co_img, co_text['marginalia'], co_table, img)
|
||||||
|
|
||||||
|
cv2.imwrite(os.path.join(dir_out, f_name+'.png'), added_image)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@main.command()
|
||||||
|
@click.option(
|
||||||
|
"--xml_file",
|
||||||
|
"-xml",
|
||||||
|
help="xml filename",
|
||||||
|
type=click.Path(exists=True, dir_okay=False),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--dir_xml",
|
||||||
|
"-dx",
|
||||||
|
help="directory of GT page-xml files",
|
||||||
|
type=click.Path(exists=True, file_okay=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
@click.option(
|
||||||
|
"--dir_out",
|
||||||
|
"-o",
|
||||||
|
help="directory where plots will be written",
|
||||||
|
type=click.Path(exists=True, file_okay=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def visualize_ocr_text(xml_file, dir_xml, dir_out):
|
||||||
|
assert xml_file or dir_xml, "A single xml file -xml or a dir of xml files -dx is required not both of them"
|
||||||
|
if dir_xml:
|
||||||
|
xml_files_ind = os.listdir(dir_xml)
|
||||||
|
xml_files_ind = [ind_xml for ind_xml in xml_files_ind if ind_xml.endswith('.xml')]
|
||||||
|
else:
|
||||||
|
xml_files_ind = [xml_file]
|
||||||
|
|
||||||
|
font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
|
||||||
|
font = ImageFont.truetype(font_path, 40)
|
||||||
|
|
||||||
|
for ind_xml in tqdm(xml_files_ind):
|
||||||
|
indexer = 0
|
||||||
|
#print(ind_xml)
|
||||||
|
#print('########################')
|
||||||
|
if dir_xml:
|
||||||
|
xml_file = os.path.join(dir_xml,ind_xml )
|
||||||
|
f_name = Path(ind_xml).stem
|
||||||
|
else:
|
||||||
|
xml_file = os.path.join(ind_xml )
|
||||||
|
f_name = Path(ind_xml).stem
|
||||||
|
print(f_name, 'f_name')
|
||||||
|
|
||||||
|
co_tetxlines, y_len, x_len, ocr_texts = get_textline_contours_and_ocr_text(xml_file)
|
||||||
|
|
||||||
|
total_bb_coordinates = []
|
||||||
|
|
||||||
|
image_text = Image.new("RGB", (x_len, y_len), "white")
|
||||||
|
draw = ImageDraw.Draw(image_text)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
for index, cnt in enumerate(co_tetxlines):
|
||||||
|
x,y,w,h = cv2.boundingRect(cnt)
|
||||||
|
#total_bb_coordinates.append([x,y,w,h])
|
||||||
|
|
||||||
|
#fit_text_single_line
|
||||||
|
|
||||||
|
#x_bb = bb_ind[0]
|
||||||
|
#y_bb = bb_ind[1]
|
||||||
|
#w_bb = bb_ind[2]
|
||||||
|
#h_bb = bb_ind[3]
|
||||||
|
if ocr_texts[index]:
|
||||||
|
|
||||||
|
|
||||||
|
is_vertical = h > 2*w # Check orientation
|
||||||
|
font = fit_text_single_line(draw, ocr_texts[index], font_path, w, int(h*0.4) )
|
||||||
|
|
||||||
|
if is_vertical:
|
||||||
|
|
||||||
|
vertical_font = fit_text_single_line(draw, ocr_texts[index], font_path, h, int(w * 0.8))
|
||||||
|
|
||||||
|
text_img = Image.new("RGBA", (h, w), (255, 255, 255, 0)) # Note: dimensions are swapped
|
||||||
|
text_draw = ImageDraw.Draw(text_img)
|
||||||
|
text_draw.text((0, 0), ocr_texts[index], font=vertical_font, fill="black")
|
||||||
|
|
||||||
|
# Rotate text image by 90 degrees
|
||||||
|
rotated_text = text_img.rotate(90, expand=1)
|
||||||
|
|
||||||
|
# Calculate paste position (centered in bbox)
|
||||||
|
paste_x = x + (w - rotated_text.width) // 2
|
||||||
|
paste_y = y + (h - rotated_text.height) // 2
|
||||||
|
|
||||||
|
image_text.paste(rotated_text, (paste_x, paste_y), rotated_text) # Use rotated image as mask
|
||||||
|
else:
|
||||||
|
text_bbox = draw.textbbox((0, 0), ocr_texts[index], font=font)
|
||||||
|
text_width = text_bbox[2] - text_bbox[0]
|
||||||
|
text_height = text_bbox[3] - text_bbox[1]
|
||||||
|
|
||||||
|
text_x = x + (w - text_width) // 2 # Center horizontally
|
||||||
|
text_y = y + (h - text_height) // 2 # Center vertically
|
||||||
|
|
||||||
|
# Draw the text
|
||||||
|
draw.text((text_x, text_y), ocr_texts[index], fill="black", font=font)
|
||||||
|
image_text.save(os.path.join(dir_out, f_name+'.png'))
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
1838
train/gt_gen_utils.py
Normal file
1838
train/gt_gen_utils.py
Normal file
File diff suppressed because it is too large
Load diff
682
train/inference.py
Normal file
682
train/inference.py
Normal file
|
@ -0,0 +1,682 @@
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import warnings
|
||||||
|
import cv2
|
||||||
|
import seaborn as sns
|
||||||
|
from tensorflow.keras.models import load_model
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.keras import backend as K
|
||||||
|
from tensorflow.keras import layers
|
||||||
|
import tensorflow.keras.losses
|
||||||
|
from tensorflow.keras.layers import *
|
||||||
|
from models import *
|
||||||
|
from gt_gen_utils import *
|
||||||
|
import click
|
||||||
|
import json
|
||||||
|
from tensorflow.python.keras import backend as tensorflow_backend
|
||||||
|
import xml.etree.ElementTree as ET
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore")
|
||||||
|
|
||||||
|
__doc__=\
|
||||||
|
"""
|
||||||
|
Tool to load model and predict for given image.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class sbb_predict:
|
||||||
|
def __init__(self,image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area):
|
||||||
|
self.image=image
|
||||||
|
self.dir_in=dir_in
|
||||||
|
self.patches=patches
|
||||||
|
self.save=save
|
||||||
|
self.save_layout=save_layout
|
||||||
|
self.model_dir=model
|
||||||
|
self.ground_truth=ground_truth
|
||||||
|
self.task=task
|
||||||
|
self.config_params_model=config_params_model
|
||||||
|
self.xml_file = xml_file
|
||||||
|
self.out = out
|
||||||
|
if min_area:
|
||||||
|
self.min_area = float(min_area)
|
||||||
|
else:
|
||||||
|
self.min_area = 0
|
||||||
|
|
||||||
|
def resize_image(self,img_in,input_height,input_width):
|
||||||
|
return cv2.resize( img_in, ( input_width,input_height) ,interpolation=cv2.INTER_NEAREST)
|
||||||
|
|
||||||
|
|
||||||
|
def color_images(self,seg):
|
||||||
|
ann_u=range(self.n_classes)
|
||||||
|
if len(np.shape(seg))==3:
|
||||||
|
seg=seg[:,:,0]
|
||||||
|
|
||||||
|
seg_img=np.zeros((np.shape(seg)[0],np.shape(seg)[1],3)).astype(np.uint8)
|
||||||
|
colors=sns.color_palette("hls", self.n_classes)
|
||||||
|
|
||||||
|
for c in ann_u:
|
||||||
|
c=int(c)
|
||||||
|
segl=(seg==c)
|
||||||
|
seg_img[:,:,0][seg==c]=c
|
||||||
|
seg_img[:,:,1][seg==c]=c
|
||||||
|
seg_img[:,:,2][seg==c]=c
|
||||||
|
return seg_img
|
||||||
|
|
||||||
|
def otsu_copy_binary(self,img):
|
||||||
|
img_r=np.zeros((img.shape[0],img.shape[1],3))
|
||||||
|
img1=img[:,:,0]
|
||||||
|
|
||||||
|
#print(img.min())
|
||||||
|
#print(img[:,:,0].min())
|
||||||
|
#blur = cv2.GaussianBlur(img,(5,5))
|
||||||
|
#ret3,th3 = cv2.threshold(blur,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU)
|
||||||
|
retval1, threshold1 = cv2.threshold(img1, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
img_r[:,:,0]=threshold1
|
||||||
|
img_r[:,:,1]=threshold1
|
||||||
|
img_r[:,:,2]=threshold1
|
||||||
|
#img_r=img_r/float(np.max(img_r))*255
|
||||||
|
return img_r
|
||||||
|
|
||||||
|
def otsu_copy(self,img):
|
||||||
|
img_r=np.zeros((img.shape[0],img.shape[1],3))
|
||||||
|
#img1=img[:,:,0]
|
||||||
|
|
||||||
|
#print(img.min())
|
||||||
|
#print(img[:,:,0].min())
|
||||||
|
#blur = cv2.GaussianBlur(img,(5,5))
|
||||||
|
#ret3,th3 = cv2.threshold(blur,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU)
|
||||||
|
_, threshold1 = cv2.threshold(img[:,:,0], 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
|
||||||
|
_, threshold2 = cv2.threshold(img[:,:,1], 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
|
||||||
|
_, threshold3 = cv2.threshold(img[:,:,2], 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
img_r[:,:,0]=threshold1
|
||||||
|
img_r[:,:,1]=threshold2
|
||||||
|
img_r[:,:,2]=threshold3
|
||||||
|
###img_r=img_r/float(np.max(img_r))*255
|
||||||
|
return img_r
|
||||||
|
|
||||||
|
def soft_dice_loss(self,y_true, y_pred, epsilon=1e-6):
|
||||||
|
|
||||||
|
axes = tuple(range(1, len(y_pred.shape)-1))
|
||||||
|
|
||||||
|
numerator = 2. * K.sum(y_pred * y_true, axes)
|
||||||
|
|
||||||
|
denominator = K.sum(K.square(y_pred) + K.square(y_true), axes)
|
||||||
|
return 1.00 - K.mean(numerator / (denominator + epsilon)) # average over classes and batch
|
||||||
|
|
||||||
|
def weighted_categorical_crossentropy(self,weights=None):
|
||||||
|
|
||||||
|
def loss(y_true, y_pred):
|
||||||
|
labels_floats = tf.cast(y_true, tf.float32)
|
||||||
|
per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats,logits=y_pred)
|
||||||
|
|
||||||
|
if weights is not None:
|
||||||
|
weight_mask = tf.maximum(tf.reduce_max(tf.constant(
|
||||||
|
np.array(weights, dtype=np.float32)[None, None, None])
|
||||||
|
* labels_floats, axis=-1), 1.0)
|
||||||
|
per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None]
|
||||||
|
return tf.reduce_mean(per_pixel_loss)
|
||||||
|
return self.loss
|
||||||
|
|
||||||
|
|
||||||
|
def IoU(self,Yi,y_predi):
|
||||||
|
## mean Intersection over Union
|
||||||
|
## Mean IoU = TP/(FN + TP + FP)
|
||||||
|
|
||||||
|
IoUs = []
|
||||||
|
Nclass = np.unique(Yi)
|
||||||
|
for c in Nclass:
|
||||||
|
TP = np.sum( (Yi == c)&(y_predi==c) )
|
||||||
|
FP = np.sum( (Yi != c)&(y_predi==c) )
|
||||||
|
FN = np.sum( (Yi == c)&(y_predi != c))
|
||||||
|
IoU = TP/float(TP + FP + FN)
|
||||||
|
if self.n_classes>2:
|
||||||
|
print("class {:02.0f}: #TP={:6.0f}, #FP={:6.0f}, #FN={:5.0f}, IoU={:4.3f}".format(c,TP,FP,FN,IoU))
|
||||||
|
IoUs.append(IoU)
|
||||||
|
if self.n_classes>2:
|
||||||
|
mIoU = np.mean(IoUs)
|
||||||
|
print("_________________")
|
||||||
|
print("Mean IoU: {:4.3f}".format(mIoU))
|
||||||
|
return mIoU
|
||||||
|
elif self.n_classes==2:
|
||||||
|
mIoU = IoUs[1]
|
||||||
|
print("_________________")
|
||||||
|
print("IoU: {:4.3f}".format(mIoU))
|
||||||
|
return mIoU
|
||||||
|
|
||||||
|
def start_new_session_and_model(self):
|
||||||
|
|
||||||
|
config = tf.compat.v1.ConfigProto()
|
||||||
|
config.gpu_options.allow_growth = True
|
||||||
|
|
||||||
|
session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
|
||||||
|
tensorflow_backend.set_session(session)
|
||||||
|
#tensorflow.keras.layers.custom_layer = PatchEncoder
|
||||||
|
#tensorflow.keras.layers.custom_layer = Patches
|
||||||
|
self.model = load_model(self.model_dir , compile=False,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
|
||||||
|
#config = tf.ConfigProto()
|
||||||
|
#config.gpu_options.allow_growth=True
|
||||||
|
|
||||||
|
#self.session = tf.InteractiveSession()
|
||||||
|
#keras.losses.custom_loss = self.weighted_categorical_crossentropy
|
||||||
|
#self.model = load_model(self.model_dir , compile=False)
|
||||||
|
|
||||||
|
|
||||||
|
##if self.weights_dir!=None:
|
||||||
|
##self.model.load_weights(self.weights_dir)
|
||||||
|
|
||||||
|
if self.task != 'classification' and self.task != 'reading_order':
|
||||||
|
self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1]
|
||||||
|
self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2]
|
||||||
|
self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3]
|
||||||
|
|
||||||
|
def visualize_model_output(self, prediction, img, task):
|
||||||
|
if task == "binarization":
|
||||||
|
prediction = prediction * -1
|
||||||
|
prediction = prediction + 1
|
||||||
|
added_image = prediction * 255
|
||||||
|
layout_only = None
|
||||||
|
else:
|
||||||
|
unique_classes = np.unique(prediction[:,:,0])
|
||||||
|
rgb_colors = {'0' : [255, 255, 255],
|
||||||
|
'1' : [255, 0, 0],
|
||||||
|
'2' : [255, 125, 0],
|
||||||
|
'3' : [255, 0, 125],
|
||||||
|
'4' : [125, 125, 125],
|
||||||
|
'5' : [125, 125, 0],
|
||||||
|
'6' : [0, 125, 255],
|
||||||
|
'7' : [0, 125, 0],
|
||||||
|
'8' : [125, 125, 125],
|
||||||
|
'9' : [0, 125, 255],
|
||||||
|
'10' : [125, 0, 125],
|
||||||
|
'11' : [0, 255, 0],
|
||||||
|
'12' : [0, 0, 255],
|
||||||
|
'13' : [0, 255, 255],
|
||||||
|
'14' : [255, 125, 125],
|
||||||
|
'15' : [255, 0, 255]}
|
||||||
|
|
||||||
|
layout_only = np.zeros(prediction.shape)
|
||||||
|
|
||||||
|
for unq_class in unique_classes:
|
||||||
|
rgb_class_unique = rgb_colors[str(int(unq_class))]
|
||||||
|
layout_only[:,:,0][prediction[:,:,0]==unq_class] = rgb_class_unique[0]
|
||||||
|
layout_only[:,:,1][prediction[:,:,0]==unq_class] = rgb_class_unique[1]
|
||||||
|
layout_only[:,:,2][prediction[:,:,0]==unq_class] = rgb_class_unique[2]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
img = self.resize_image(img, layout_only.shape[0], layout_only.shape[1])
|
||||||
|
|
||||||
|
layout_only = layout_only.astype(np.int32)
|
||||||
|
img = img.astype(np.int32)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
added_image = cv2.addWeighted(img,0.5,layout_only,0.1,0)
|
||||||
|
|
||||||
|
return added_image, layout_only
|
||||||
|
|
||||||
|
def predict(self, image_dir):
|
||||||
|
if self.task == 'classification':
|
||||||
|
classes_names = self.config_params_model['classification_classes_name']
|
||||||
|
img_1ch = img=cv2.imread(image_dir, 0)
|
||||||
|
|
||||||
|
img_1ch = img_1ch / 255.0
|
||||||
|
img_1ch = cv2.resize(img_1ch, (self.config_params_model['input_height'], self.config_params_model['input_width']), interpolation=cv2.INTER_NEAREST)
|
||||||
|
img_in = np.zeros((1, img_1ch.shape[0], img_1ch.shape[1], 3))
|
||||||
|
img_in[0, :, :, 0] = img_1ch[:, :]
|
||||||
|
img_in[0, :, :, 1] = img_1ch[:, :]
|
||||||
|
img_in[0, :, :, 2] = img_1ch[:, :]
|
||||||
|
|
||||||
|
label_p_pred = self.model.predict(img_in, verbose=0)
|
||||||
|
index_class = np.argmax(label_p_pred[0])
|
||||||
|
|
||||||
|
print("Predicted Class: {}".format(classes_names[str(int(index_class))]))
|
||||||
|
elif self.task == 'reading_order':
|
||||||
|
img_height = self.config_params_model['input_height']
|
||||||
|
img_width = self.config_params_model['input_width']
|
||||||
|
|
||||||
|
tree_xml, root_xml, bb_coord_printspace, file_name, id_paragraph, id_header, co_text_paragraph, co_text_header, tot_region_ref, x_len, y_len, index_tot_regions, img_poly = read_xml(self.xml_file)
|
||||||
|
_, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, _ = find_new_features_of_contours(co_text_header)
|
||||||
|
|
||||||
|
img_header_and_sep = np.zeros((y_len,x_len), dtype='uint8')
|
||||||
|
|
||||||
|
|
||||||
|
for j in range(len(cy_main)):
|
||||||
|
img_header_and_sep[int(y_max_main[j]):int(y_max_main[j])+12,int(x_min_main[j]):int(x_max_main[j]) ] = 1
|
||||||
|
|
||||||
|
co_text_all = co_text_paragraph + co_text_header
|
||||||
|
id_all_text = id_paragraph + id_header
|
||||||
|
|
||||||
|
|
||||||
|
##texts_corr_order_index = [index_tot_regions[tot_region_ref.index(i)] for i in id_all_text ]
|
||||||
|
##texts_corr_order_index_int = [int(x) for x in texts_corr_order_index]
|
||||||
|
texts_corr_order_index_int = list(np.array(range(len(co_text_all))))
|
||||||
|
|
||||||
|
#print(texts_corr_order_index_int)
|
||||||
|
|
||||||
|
max_area = 1
|
||||||
|
#print(np.shape(co_text_all[0]), len( np.shape(co_text_all[0]) ),'co_text_all')
|
||||||
|
#co_text_all = filter_contours_area_of_image_tables(img_poly, co_text_all, _, max_area, min_area)
|
||||||
|
#print(co_text_all,'co_text_all')
|
||||||
|
co_text_all, texts_corr_order_index_int, _ = filter_contours_area_of_image(img_poly, co_text_all, texts_corr_order_index_int, max_area, self.min_area)
|
||||||
|
|
||||||
|
#print(texts_corr_order_index_int)
|
||||||
|
|
||||||
|
#co_text_all = [co_text_all[index] for index in texts_corr_order_index_int]
|
||||||
|
id_all_text = [id_all_text[index] for index in texts_corr_order_index_int]
|
||||||
|
|
||||||
|
labels_con = np.zeros((y_len,x_len,len(co_text_all)),dtype='uint8')
|
||||||
|
for i in range(len(co_text_all)):
|
||||||
|
img_label = np.zeros((y_len,x_len,3),dtype='uint8')
|
||||||
|
img_label=cv2.fillPoly(img_label, pts =[co_text_all[i]], color=(1,1,1))
|
||||||
|
labels_con[:,:,i] = img_label[:,:,0]
|
||||||
|
|
||||||
|
if bb_coord_printspace:
|
||||||
|
#bb_coord_printspace[x,y,w,h,_,_]
|
||||||
|
x = bb_coord_printspace[0]
|
||||||
|
y = bb_coord_printspace[1]
|
||||||
|
w = bb_coord_printspace[2]
|
||||||
|
h = bb_coord_printspace[3]
|
||||||
|
labels_con = labels_con[y:y+h, x:x+w, :]
|
||||||
|
img_poly = img_poly[y:y+h, x:x+w, :]
|
||||||
|
img_header_and_sep = img_header_and_sep[y:y+h, x:x+w]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
img3= np.copy(img_poly)
|
||||||
|
labels_con = resize_image(labels_con, img_height, img_width)
|
||||||
|
|
||||||
|
img_header_and_sep = resize_image(img_header_and_sep, img_height, img_width)
|
||||||
|
|
||||||
|
img3= resize_image (img3, img_height, img_width)
|
||||||
|
img3 = img3.astype(np.uint16)
|
||||||
|
|
||||||
|
inference_bs = 1#4
|
||||||
|
|
||||||
|
input_1= np.zeros( (inference_bs, img_height, img_width,3))
|
||||||
|
|
||||||
|
|
||||||
|
starting_list_of_regions = [list(range(labels_con.shape[2]))]
|
||||||
|
|
||||||
|
index_update = 0
|
||||||
|
index_selected = starting_list_of_regions[0]
|
||||||
|
|
||||||
|
scalibility_num = 0
|
||||||
|
while index_update>=0:
|
||||||
|
ij_list = starting_list_of_regions[index_update]
|
||||||
|
i = ij_list[0]
|
||||||
|
ij_list.pop(0)
|
||||||
|
|
||||||
|
|
||||||
|
pr_list = []
|
||||||
|
post_list = []
|
||||||
|
|
||||||
|
batch_counter = 0
|
||||||
|
tot_counter = 1
|
||||||
|
|
||||||
|
tot_iteration = len(ij_list)
|
||||||
|
full_bs_ite= tot_iteration//inference_bs
|
||||||
|
last_bs = tot_iteration % inference_bs
|
||||||
|
|
||||||
|
jbatch_indexer =[]
|
||||||
|
for j in ij_list:
|
||||||
|
img1= np.repeat(labels_con[:,:,i][:, :, np.newaxis], 3, axis=2)
|
||||||
|
img2 = np.repeat(labels_con[:,:,j][:, :, np.newaxis], 3, axis=2)
|
||||||
|
|
||||||
|
|
||||||
|
img2[:,:,0][img3[:,:,0]==5] = 2
|
||||||
|
img2[:,:,0][img_header_and_sep[:,:]==1] = 3
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
img1[:,:,0][img3[:,:,0]==5] = 2
|
||||||
|
img1[:,:,0][img_header_and_sep[:,:]==1] = 3
|
||||||
|
|
||||||
|
#input_1= np.zeros( (height1, width1,3))
|
||||||
|
|
||||||
|
|
||||||
|
jbatch_indexer.append(j)
|
||||||
|
|
||||||
|
input_1[batch_counter,:,:,0] = img1[:,:,0]/3.
|
||||||
|
input_1[batch_counter,:,:,2] = img2[:,:,0]/3.
|
||||||
|
input_1[batch_counter,:,:,1] = img3[:,:,0]/5.
|
||||||
|
#input_1[batch_counter,:,:,:]= np.zeros( (batch_counter, height1, width1,3))
|
||||||
|
batch_counter = batch_counter+1
|
||||||
|
|
||||||
|
#input_1[:,:,0] = img1[:,:,0]/3.
|
||||||
|
#input_1[:,:,2] = img2[:,:,0]/3.
|
||||||
|
#input_1[:,:,1] = img3[:,:,0]/5.
|
||||||
|
|
||||||
|
if batch_counter==inference_bs or ( (tot_counter//inference_bs)==full_bs_ite and tot_counter%inference_bs==last_bs):
|
||||||
|
y_pr = self.model.predict(input_1 , verbose=0)
|
||||||
|
scalibility_num = scalibility_num+1
|
||||||
|
|
||||||
|
if batch_counter==inference_bs:
|
||||||
|
iteration_batches = inference_bs
|
||||||
|
else:
|
||||||
|
iteration_batches = last_bs
|
||||||
|
for jb in range(iteration_batches):
|
||||||
|
if y_pr[jb][0]>=0.5:
|
||||||
|
post_list.append(jbatch_indexer[jb])
|
||||||
|
else:
|
||||||
|
pr_list.append(jbatch_indexer[jb])
|
||||||
|
|
||||||
|
batch_counter = 0
|
||||||
|
jbatch_indexer = []
|
||||||
|
|
||||||
|
tot_counter = tot_counter+1
|
||||||
|
|
||||||
|
starting_list_of_regions, index_update = update_list_and_return_first_with_length_bigger_than_one(index_update, i, pr_list, post_list,starting_list_of_regions)
|
||||||
|
|
||||||
|
|
||||||
|
index_sort = [i[0] for i in starting_list_of_regions ]
|
||||||
|
|
||||||
|
id_all_text = np.array(id_all_text)[index_sort]
|
||||||
|
|
||||||
|
alltags=[elem.tag for elem in root_xml.iter()]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
link=alltags[0].split('}')[0]+'}'
|
||||||
|
name_space = alltags[0].split('}')[0]
|
||||||
|
name_space = name_space.split('{')[1]
|
||||||
|
|
||||||
|
page_element = root_xml.find(link+'Page')
|
||||||
|
|
||||||
|
"""
|
||||||
|
ro_subelement = ET.SubElement(page_element, 'ReadingOrder')
|
||||||
|
#print(page_element, 'page_element')
|
||||||
|
|
||||||
|
#new_element = ET.Element('ReadingOrder')
|
||||||
|
|
||||||
|
new_element_element = ET.Element('OrderedGroup')
|
||||||
|
new_element_element.set('id', "ro357564684568544579089")
|
||||||
|
|
||||||
|
for index, id_text in enumerate(id_all_text):
|
||||||
|
new_element_2 = ET.Element('RegionRefIndexed')
|
||||||
|
new_element_2.set('regionRef', id_all_text[index])
|
||||||
|
new_element_2.set('index', str(index_sort[index]))
|
||||||
|
|
||||||
|
new_element_element.append(new_element_2)
|
||||||
|
|
||||||
|
ro_subelement.append(new_element_element)
|
||||||
|
"""
|
||||||
|
##ro_subelement = ET.SubElement(page_element, 'ReadingOrder')
|
||||||
|
|
||||||
|
ro_subelement = ET.Element('ReadingOrder')
|
||||||
|
|
||||||
|
ro_subelement2 = ET.SubElement(ro_subelement, 'OrderedGroup')
|
||||||
|
ro_subelement2.set('id', "ro357564684568544579089")
|
||||||
|
|
||||||
|
for index, id_text in enumerate(id_all_text):
|
||||||
|
new_element_2 = ET.SubElement(ro_subelement2, 'RegionRefIndexed')
|
||||||
|
new_element_2.set('regionRef', id_all_text[index])
|
||||||
|
new_element_2.set('index', str(index))
|
||||||
|
|
||||||
|
if (link+'PrintSpace' in alltags) or (link+'Border' in alltags):
|
||||||
|
page_element.insert(1, ro_subelement)
|
||||||
|
else:
|
||||||
|
page_element.insert(0, ro_subelement)
|
||||||
|
|
||||||
|
alltags=[elem.tag for elem in root_xml.iter()]
|
||||||
|
|
||||||
|
ET.register_namespace("",name_space)
|
||||||
|
tree_xml.write(os.path.join(self.out, file_name+'.xml'),xml_declaration=True,method='xml',encoding="utf8",default_namespace=None)
|
||||||
|
#tree_xml.write('library2.xml')
|
||||||
|
|
||||||
|
else:
|
||||||
|
if self.patches:
|
||||||
|
#def textline_contours(img,input_width,input_height,n_classes,model):
|
||||||
|
|
||||||
|
img=cv2.imread(image_dir)
|
||||||
|
self.img_org = np.copy(img)
|
||||||
|
|
||||||
|
if img.shape[0] < self.img_height:
|
||||||
|
img = self.resize_image(img, self.img_height, img.shape[1])
|
||||||
|
|
||||||
|
if img.shape[1] < self.img_width:
|
||||||
|
img = self.resize_image(img, img.shape[0], self.img_width)
|
||||||
|
|
||||||
|
margin = int(0.1 * self.img_width)
|
||||||
|
width_mid = self.img_width - 2 * margin
|
||||||
|
height_mid = self.img_height - 2 * margin
|
||||||
|
img = img / float(255.0)
|
||||||
|
|
||||||
|
img_h = img.shape[0]
|
||||||
|
img_w = img.shape[1]
|
||||||
|
|
||||||
|
prediction_true = np.zeros((img_h, img_w, 3))
|
||||||
|
nxf = img_w / float(width_mid)
|
||||||
|
nyf = img_h / float(height_mid)
|
||||||
|
|
||||||
|
nxf = int(nxf) + 1 if nxf > int(nxf) else int(nxf)
|
||||||
|
nyf = int(nyf) + 1 if nyf > int(nyf) else int(nyf)
|
||||||
|
|
||||||
|
for i in range(nxf):
|
||||||
|
for j in range(nyf):
|
||||||
|
if i == 0:
|
||||||
|
index_x_d = i * width_mid
|
||||||
|
index_x_u = index_x_d + self.img_width
|
||||||
|
else:
|
||||||
|
index_x_d = i * width_mid
|
||||||
|
index_x_u = index_x_d + self.img_width
|
||||||
|
if j == 0:
|
||||||
|
index_y_d = j * height_mid
|
||||||
|
index_y_u = index_y_d + self.img_height
|
||||||
|
else:
|
||||||
|
index_y_d = j * height_mid
|
||||||
|
index_y_u = index_y_d + self.img_height
|
||||||
|
|
||||||
|
if index_x_u > img_w:
|
||||||
|
index_x_u = img_w
|
||||||
|
index_x_d = img_w - self.img_width
|
||||||
|
if index_y_u > img_h:
|
||||||
|
index_y_u = img_h
|
||||||
|
index_y_d = img_h - self.img_height
|
||||||
|
|
||||||
|
img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
|
||||||
|
label_p_pred = self.model.predict(img_patch.reshape(1, img_patch.shape[0], img_patch.shape[1], img_patch.shape[2]),
|
||||||
|
verbose=0)
|
||||||
|
|
||||||
|
if self.task == 'enhancement':
|
||||||
|
seg = label_p_pred[0, :, :, :]
|
||||||
|
seg = seg * 255
|
||||||
|
elif self.task == 'segmentation' or self.task == 'binarization':
|
||||||
|
seg = np.argmax(label_p_pred, axis=3)[0]
|
||||||
|
seg = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
|
||||||
|
|
||||||
|
|
||||||
|
if i == 0 and j == 0:
|
||||||
|
seg = seg[0 : seg.shape[0] - margin, 0 : seg.shape[1] - margin]
|
||||||
|
prediction_true[index_y_d + 0 : index_y_u - margin, index_x_d + 0 : index_x_u - margin, :] = seg
|
||||||
|
elif i == nxf - 1 and j == nyf - 1:
|
||||||
|
seg = seg[margin : seg.shape[0] - 0, margin : seg.shape[1] - 0]
|
||||||
|
prediction_true[index_y_d + margin : index_y_u - 0, index_x_d + margin : index_x_u - 0, :] = seg
|
||||||
|
elif i == 0 and j == nyf - 1:
|
||||||
|
seg = seg[margin : seg.shape[0] - 0, 0 : seg.shape[1] - margin]
|
||||||
|
prediction_true[index_y_d + margin : index_y_u - 0, index_x_d + 0 : index_x_u - margin, :] = seg
|
||||||
|
elif i == nxf - 1 and j == 0:
|
||||||
|
seg = seg[0 : seg.shape[0] - margin, margin : seg.shape[1] - 0]
|
||||||
|
prediction_true[index_y_d + 0 : index_y_u - margin, index_x_d + margin : index_x_u - 0, :] = seg
|
||||||
|
elif i == 0 and j != 0 and j != nyf - 1:
|
||||||
|
seg = seg[margin : seg.shape[0] - margin, 0 : seg.shape[1] - margin]
|
||||||
|
prediction_true[index_y_d + margin : index_y_u - margin, index_x_d + 0 : index_x_u - margin, :] = seg
|
||||||
|
elif i == nxf - 1 and j != 0 and j != nyf - 1:
|
||||||
|
seg = seg[margin : seg.shape[0] - margin, margin : seg.shape[1] - 0]
|
||||||
|
prediction_true[index_y_d + margin : index_y_u - margin, index_x_d + margin : index_x_u - 0, :] = seg
|
||||||
|
elif i != 0 and i != nxf - 1 and j == 0:
|
||||||
|
seg = seg[0 : seg.shape[0] - margin, margin : seg.shape[1] - margin]
|
||||||
|
prediction_true[index_y_d + 0 : index_y_u - margin, index_x_d + margin : index_x_u - margin, :] = seg
|
||||||
|
elif i != 0 and i != nxf - 1 and j == nyf - 1:
|
||||||
|
seg = seg[margin : seg.shape[0] - 0, margin : seg.shape[1] - margin]
|
||||||
|
prediction_true[index_y_d + margin : index_y_u - 0, index_x_d + margin : index_x_u - margin, :] = seg
|
||||||
|
else:
|
||||||
|
seg = seg[margin : seg.shape[0] - margin, margin : seg.shape[1] - margin]
|
||||||
|
prediction_true[index_y_d + margin : index_y_u - margin, index_x_d + margin : index_x_u - margin, :] = seg
|
||||||
|
prediction_true = prediction_true.astype(int)
|
||||||
|
prediction_true = cv2.resize(prediction_true, (self.img_org.shape[1], self.img_org.shape[0]), interpolation=cv2.INTER_NEAREST)
|
||||||
|
return prediction_true
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
img=cv2.imread(image_dir)
|
||||||
|
self.img_org = np.copy(img)
|
||||||
|
|
||||||
|
width=self.img_width
|
||||||
|
height=self.img_height
|
||||||
|
|
||||||
|
img=img/255.0
|
||||||
|
img=self.resize_image(img,self.img_height,self.img_width)
|
||||||
|
|
||||||
|
|
||||||
|
label_p_pred=self.model.predict(
|
||||||
|
img.reshape(1,img.shape[0],img.shape[1],img.shape[2]))
|
||||||
|
|
||||||
|
if self.task == 'enhancement':
|
||||||
|
seg = label_p_pred[0, :, :, :]
|
||||||
|
seg = seg * 255
|
||||||
|
elif self.task == 'segmentation' or self.task == 'binarization':
|
||||||
|
seg = np.argmax(label_p_pred, axis=3)[0]
|
||||||
|
seg = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
|
||||||
|
|
||||||
|
prediction_true = seg.astype(int)
|
||||||
|
|
||||||
|
prediction_true = cv2.resize(prediction_true, (self.img_org.shape[1], self.img_org.shape[0]), interpolation=cv2.INTER_NEAREST)
|
||||||
|
return prediction_true
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
self.start_new_session_and_model()
|
||||||
|
if self.image:
|
||||||
|
res=self.predict(image_dir = self.image)
|
||||||
|
|
||||||
|
if self.task == 'classification' or self.task == 'reading_order':
|
||||||
|
pass
|
||||||
|
elif self.task == 'enhancement':
|
||||||
|
if self.save:
|
||||||
|
cv2.imwrite(self.save,res)
|
||||||
|
else:
|
||||||
|
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
|
||||||
|
if self.save:
|
||||||
|
cv2.imwrite(self.save,img_seg_overlayed)
|
||||||
|
if self.save_layout:
|
||||||
|
cv2.imwrite(self.save_layout, only_layout)
|
||||||
|
|
||||||
|
if self.ground_truth:
|
||||||
|
gt_img=cv2.imread(self.ground_truth)
|
||||||
|
self.IoU(gt_img[:,:,0],res[:,:,0])
|
||||||
|
|
||||||
|
else:
|
||||||
|
ls_images = os.listdir(self.dir_in)
|
||||||
|
for ind_image in ls_images:
|
||||||
|
f_name = ind_image.split('.')[0]
|
||||||
|
image_dir = os.path.join(self.dir_in, ind_image)
|
||||||
|
res=self.predict(image_dir)
|
||||||
|
|
||||||
|
if self.task == 'classification' or self.task == 'reading_order':
|
||||||
|
pass
|
||||||
|
elif self.task == 'enhancement':
|
||||||
|
self.save = os.path.join(self.out, f_name+'.png')
|
||||||
|
cv2.imwrite(self.save,res)
|
||||||
|
else:
|
||||||
|
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
|
||||||
|
self.save = os.path.join(self.out, f_name+'_overlayed.png')
|
||||||
|
cv2.imwrite(self.save,img_seg_overlayed)
|
||||||
|
self.save_layout = os.path.join(self.out, f_name+'_layout.png')
|
||||||
|
cv2.imwrite(self.save_layout, only_layout)
|
||||||
|
|
||||||
|
if self.ground_truth:
|
||||||
|
gt_img=cv2.imread(self.ground_truth)
|
||||||
|
self.IoU(gt_img[:,:,0],res[:,:,0])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@click.command()
|
||||||
|
@click.option(
|
||||||
|
"--image",
|
||||||
|
"-i",
|
||||||
|
help="image filename",
|
||||||
|
type=click.Path(exists=True, dir_okay=False),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--dir_in",
|
||||||
|
"-di",
|
||||||
|
help="directory of images",
|
||||||
|
type=click.Path(exists=True, file_okay=False),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--out",
|
||||||
|
"-o",
|
||||||
|
help="output directory where xml with detected reading order will be written.",
|
||||||
|
type=click.Path(exists=True, file_okay=False),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--patches/--no-patches",
|
||||||
|
"-p/-nop",
|
||||||
|
is_flag=True,
|
||||||
|
help="if this parameter set to true, this tool will try to do inference in patches.",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--save",
|
||||||
|
"-s",
|
||||||
|
help="save prediction as a png file in current folder.",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--save_layout",
|
||||||
|
"-sl",
|
||||||
|
help="save layout prediction only as a png file in current folder.",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--model",
|
||||||
|
"-m",
|
||||||
|
help="directory of models",
|
||||||
|
type=click.Path(exists=True, file_okay=False),
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--ground_truth",
|
||||||
|
"-gt",
|
||||||
|
help="ground truth directory if you want to see the iou of prediction.",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--xml_file",
|
||||||
|
"-xml",
|
||||||
|
help="xml file with layout coordinates that reading order detection will be implemented on. The result will be written in the same xml file.",
|
||||||
|
)
|
||||||
|
|
||||||
|
@click.option(
|
||||||
|
"--min_area",
|
||||||
|
"-min",
|
||||||
|
help="min area size of regions considered for reading order detection. The default value is zero and means that all text regions are considered for reading order.",
|
||||||
|
)
|
||||||
|
def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, out, min_area):
|
||||||
|
assert image or dir_in, "Either a single image -i or a dir_in -di is required"
|
||||||
|
with open(os.path.join(model,'config.json')) as f:
|
||||||
|
config_params_model = json.load(f)
|
||||||
|
task = config_params_model['task']
|
||||||
|
if task != 'classification' and task != 'reading_order':
|
||||||
|
if image and not save:
|
||||||
|
print("Error: You used one of segmentation or binarization task with image input but not set -s, you need a filename to save visualized output with -s")
|
||||||
|
sys.exit(1)
|
||||||
|
if dir_in and not out:
|
||||||
|
print("Error: You used one of segmentation or binarization task with dir_in but not set -out")
|
||||||
|
sys.exit(1)
|
||||||
|
x=sbb_predict(image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area)
|
||||||
|
x.run()
|
||||||
|
|
||||||
|
if __name__=="__main__":
|
||||||
|
main()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
357
train/metrics.py
Normal file
357
train/metrics.py
Normal file
|
@ -0,0 +1,357 @@
|
||||||
|
from tensorflow.keras import backend as K
|
||||||
|
import tensorflow as tf
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def focal_loss(gamma=2., alpha=4.):
|
||||||
|
gamma = float(gamma)
|
||||||
|
alpha = float(alpha)
|
||||||
|
|
||||||
|
def focal_loss_fixed(y_true, y_pred):
|
||||||
|
"""Focal loss for multi-classification
|
||||||
|
FL(p_t)=-alpha(1-p_t)^{gamma}ln(p_t)
|
||||||
|
Notice: y_pred is probability after softmax
|
||||||
|
gradient is d(Fl)/d(p_t) not d(Fl)/d(x) as described in paper
|
||||||
|
d(Fl)/d(p_t) * [p_t(1-p_t)] = d(Fl)/d(x)
|
||||||
|
Focal Loss for Dense Object Detection
|
||||||
|
https://arxiv.org/abs/1708.02002
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
y_true {tensor} -- ground truth labels, shape of [batch_size, num_cls]
|
||||||
|
y_pred {tensor} -- model's output, shape of [batch_size, num_cls]
|
||||||
|
|
||||||
|
Keyword Arguments:
|
||||||
|
gamma {float} -- (default: {2.0})
|
||||||
|
alpha {float} -- (default: {4.0})
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[tensor] -- loss.
|
||||||
|
"""
|
||||||
|
epsilon = 1.e-9
|
||||||
|
y_true = tf.convert_to_tensor(y_true, tf.float32)
|
||||||
|
y_pred = tf.convert_to_tensor(y_pred, tf.float32)
|
||||||
|
|
||||||
|
model_out = tf.add(y_pred, epsilon)
|
||||||
|
ce = tf.multiply(y_true, -tf.log(model_out))
|
||||||
|
weight = tf.multiply(y_true, tf.pow(tf.subtract(1., model_out), gamma))
|
||||||
|
fl = tf.multiply(alpha, tf.multiply(weight, ce))
|
||||||
|
reduced_fl = tf.reduce_max(fl, axis=1)
|
||||||
|
return tf.reduce_mean(reduced_fl)
|
||||||
|
|
||||||
|
return focal_loss_fixed
|
||||||
|
|
||||||
|
|
||||||
|
def weighted_categorical_crossentropy(weights=None):
|
||||||
|
""" weighted_categorical_crossentropy
|
||||||
|
|
||||||
|
Args:
|
||||||
|
* weights<ktensor|nparray|list>: crossentropy weights
|
||||||
|
Returns:
|
||||||
|
* weighted categorical crossentropy function
|
||||||
|
"""
|
||||||
|
|
||||||
|
def loss(y_true, y_pred):
|
||||||
|
labels_floats = tf.cast(y_true, tf.float32)
|
||||||
|
per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats, logits=y_pred)
|
||||||
|
|
||||||
|
if weights is not None:
|
||||||
|
weight_mask = tf.maximum(tf.reduce_max(tf.constant(
|
||||||
|
np.array(weights, dtype=np.float32)[None, None, None])
|
||||||
|
* labels_floats, axis=-1), 1.0)
|
||||||
|
per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None]
|
||||||
|
return tf.reduce_mean(per_pixel_loss)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def image_categorical_cross_entropy(y_true, y_pred, weights=None):
|
||||||
|
"""
|
||||||
|
:param y_true: tensor of shape (batch_size, height, width) representing the ground truth.
|
||||||
|
:param y_pred: tensor of shape (batch_size, height, width) representing the prediction.
|
||||||
|
:return: The mean cross-entropy on softmaxed tensors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
labels_floats = tf.cast(y_true, tf.float32)
|
||||||
|
per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats, logits=y_pred)
|
||||||
|
|
||||||
|
if weights is not None:
|
||||||
|
weight_mask = tf.maximum(
|
||||||
|
tf.reduce_max(tf.constant(
|
||||||
|
np.array(weights, dtype=np.float32)[None, None, None])
|
||||||
|
* labels_floats, axis=-1), 1.0)
|
||||||
|
per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None]
|
||||||
|
|
||||||
|
return tf.reduce_mean(per_pixel_loss)
|
||||||
|
|
||||||
|
|
||||||
|
def class_tversky(y_true, y_pred):
|
||||||
|
smooth = 1.0 # 1.00
|
||||||
|
|
||||||
|
y_true = K.permute_dimensions(y_true, (3, 1, 2, 0))
|
||||||
|
y_pred = K.permute_dimensions(y_pred, (3, 1, 2, 0))
|
||||||
|
|
||||||
|
y_true_pos = K.batch_flatten(y_true)
|
||||||
|
y_pred_pos = K.batch_flatten(y_pred)
|
||||||
|
true_pos = K.sum(y_true_pos * y_pred_pos, 1)
|
||||||
|
false_neg = K.sum(y_true_pos * (1 - y_pred_pos), 1)
|
||||||
|
false_pos = K.sum((1 - y_true_pos) * y_pred_pos, 1)
|
||||||
|
alpha = 0.2 # 0.5
|
||||||
|
beta = 0.8
|
||||||
|
return (true_pos + smooth) / (true_pos + alpha * false_neg + beta * false_pos + smooth)
|
||||||
|
|
||||||
|
|
||||||
|
def focal_tversky_loss(y_true, y_pred):
|
||||||
|
pt_1 = class_tversky(y_true, y_pred)
|
||||||
|
gamma = 1.3 # 4./3.0#1.3#4.0/3.00# 0.75
|
||||||
|
return K.sum(K.pow((1 - pt_1), gamma))
|
||||||
|
|
||||||
|
|
||||||
|
def generalized_dice_coeff2(y_true, y_pred):
|
||||||
|
n_el = 1
|
||||||
|
for dim in y_true.shape:
|
||||||
|
n_el *= int(dim)
|
||||||
|
n_cl = y_true.shape[-1]
|
||||||
|
w = K.zeros(shape=(n_cl,))
|
||||||
|
w = (K.sum(y_true, axis=(0, 1, 2))) / n_el
|
||||||
|
w = 1 / (w ** 2 + 0.000001)
|
||||||
|
numerator = y_true * y_pred
|
||||||
|
numerator = w * K.sum(numerator, (0, 1, 2))
|
||||||
|
numerator = K.sum(numerator)
|
||||||
|
denominator = y_true + y_pred
|
||||||
|
denominator = w * K.sum(denominator, (0, 1, 2))
|
||||||
|
denominator = K.sum(denominator)
|
||||||
|
return 2 * numerator / denominator
|
||||||
|
|
||||||
|
|
||||||
|
def generalized_dice_coeff(y_true, y_pred):
|
||||||
|
axes = tuple(range(1, len(y_pred.shape) - 1))
|
||||||
|
Ncl = y_pred.shape[-1]
|
||||||
|
w = K.zeros(shape=(Ncl,))
|
||||||
|
w = K.sum(y_true, axis=axes)
|
||||||
|
w = 1 / (w ** 2 + 0.000001)
|
||||||
|
# Compute gen dice coef:
|
||||||
|
numerator = y_true * y_pred
|
||||||
|
numerator = w * K.sum(numerator, axes)
|
||||||
|
numerator = K.sum(numerator)
|
||||||
|
|
||||||
|
denominator = y_true + y_pred
|
||||||
|
denominator = w * K.sum(denominator, axes)
|
||||||
|
denominator = K.sum(denominator)
|
||||||
|
|
||||||
|
gen_dice_coef = 2 * numerator / denominator
|
||||||
|
|
||||||
|
return gen_dice_coef
|
||||||
|
|
||||||
|
|
||||||
|
def generalized_dice_loss(y_true, y_pred):
|
||||||
|
return 1 - generalized_dice_coeff2(y_true, y_pred)
|
||||||
|
|
||||||
|
|
||||||
|
def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
|
||||||
|
"""
|
||||||
|
Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions.
|
||||||
|
Assumes the `channels_last` format.
|
||||||
|
|
||||||
|
# Arguments
|
||||||
|
y_true: b x X x Y( x Z...) x c One hot encoding of ground truth
|
||||||
|
y_pred: b x X x Y( x Z...) x c Network output, must sum to 1 over c channel (such as after softmax)
|
||||||
|
epsilon: Used for numerical stability to avoid divide by zero errors
|
||||||
|
|
||||||
|
# References
|
||||||
|
V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation
|
||||||
|
https://arxiv.org/abs/1606.04797
|
||||||
|
More details on Dice loss formulation
|
||||||
|
https://mediatum.ub.tum.de/doc/1395260/1395260.pdf (page 72)
|
||||||
|
|
||||||
|
Adapted from https://github.com/Lasagne/Recipes/issues/99#issuecomment-347775022
|
||||||
|
"""
|
||||||
|
|
||||||
|
# skip the batch and class axis for calculating Dice score
|
||||||
|
axes = tuple(range(1, len(y_pred.shape) - 1))
|
||||||
|
|
||||||
|
numerator = 2. * K.sum(y_pred * y_true, axes)
|
||||||
|
|
||||||
|
denominator = K.sum(K.square(y_pred) + K.square(y_true), axes)
|
||||||
|
return 1.00 - K.mean(numerator / (denominator + epsilon)) # average over classes and batch
|
||||||
|
|
||||||
|
|
||||||
|
def seg_metrics(y_true, y_pred, metric_name, metric_type='standard', drop_last=True, mean_per_class=False,
|
||||||
|
verbose=False):
|
||||||
|
"""
|
||||||
|
Compute mean metrics of two segmentation masks, via Keras.
|
||||||
|
|
||||||
|
IoU(A,B) = |A & B| / (| A U B|)
|
||||||
|
Dice(A,B) = 2*|A & B| / (|A| + |B|)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y_true: true masks, one-hot encoded.
|
||||||
|
y_pred: predicted masks, either softmax outputs, or one-hot encoded.
|
||||||
|
metric_name: metric to be computed, either 'iou' or 'dice'.
|
||||||
|
metric_type: one of 'standard' (default), 'soft', 'naive'.
|
||||||
|
In the standard version, y_pred is one-hot encoded and the mean
|
||||||
|
is taken only over classes that are present (in y_true or y_pred).
|
||||||
|
The 'soft' version of the metrics are computed without one-hot
|
||||||
|
encoding y_pred.
|
||||||
|
The 'naive' version return mean metrics where absent classes contribute
|
||||||
|
to the class mean as 1.0 (instead of being dropped from the mean).
|
||||||
|
drop_last = True: boolean flag to drop last class (usually reserved
|
||||||
|
for background class in semantic segmentation)
|
||||||
|
mean_per_class = False: return mean along batch axis for each class.
|
||||||
|
verbose = False: print intermediate results such as intersection, union
|
||||||
|
(as number of pixels).
|
||||||
|
Returns:
|
||||||
|
IoU/Dice of y_true and y_pred, as a float, unless mean_per_class == True
|
||||||
|
in which case it returns the per-class metric, averaged over the batch.
|
||||||
|
|
||||||
|
Inputs are B*W*H*N tensors, with
|
||||||
|
B = batch size,
|
||||||
|
W = width,
|
||||||
|
H = height,
|
||||||
|
N = number of classes
|
||||||
|
"""
|
||||||
|
|
||||||
|
flag_soft = (metric_type == 'soft')
|
||||||
|
flag_naive_mean = (metric_type == 'naive')
|
||||||
|
|
||||||
|
# always assume one or more classes
|
||||||
|
num_classes = K.shape(y_true)[-1]
|
||||||
|
|
||||||
|
if not flag_soft:
|
||||||
|
# get one-hot encoded masks from y_pred (true masks should already be one-hot)
|
||||||
|
y_pred = K.one_hot(K.argmax(y_pred), num_classes)
|
||||||
|
y_true = K.one_hot(K.argmax(y_true), num_classes)
|
||||||
|
|
||||||
|
# if already one-hot, could have skipped above command
|
||||||
|
# keras uses float32 instead of float64, would give error down (but numpy arrays or keras.to_categorical gives float64)
|
||||||
|
y_true = K.cast(y_true, 'float32')
|
||||||
|
y_pred = K.cast(y_pred, 'float32')
|
||||||
|
|
||||||
|
# intersection and union shapes are batch_size * n_classes (values = area in pixels)
|
||||||
|
axes = (1, 2) # W,H axes of each image
|
||||||
|
intersection = K.sum(K.abs(y_true * y_pred), axis=axes)
|
||||||
|
mask_sum = K.sum(K.abs(y_true), axis=axes) + K.sum(K.abs(y_pred), axis=axes)
|
||||||
|
union = mask_sum - intersection # or, np.logical_or(y_pred, y_true) for one-hot
|
||||||
|
|
||||||
|
smooth = .001
|
||||||
|
iou = (intersection + smooth) / (union + smooth)
|
||||||
|
dice = 2 * (intersection + smooth) / (mask_sum + smooth)
|
||||||
|
|
||||||
|
metric = {'iou': iou, 'dice': dice}[metric_name]
|
||||||
|
|
||||||
|
# define mask to be 0 when no pixels are present in either y_true or y_pred, 1 otherwise
|
||||||
|
mask = K.cast(K.not_equal(union, 0), 'float32')
|
||||||
|
|
||||||
|
if drop_last:
|
||||||
|
metric = metric[:, :-1]
|
||||||
|
mask = mask[:, :-1]
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print('intersection, union')
|
||||||
|
print(K.eval(intersection), K.eval(union))
|
||||||
|
print(K.eval(intersection / union))
|
||||||
|
|
||||||
|
# return mean metrics: remaining axes are (batch, classes)
|
||||||
|
if flag_naive_mean:
|
||||||
|
return K.mean(metric)
|
||||||
|
|
||||||
|
# take mean only over non-absent classes
|
||||||
|
class_count = K.sum(mask, axis=0)
|
||||||
|
non_zero = tf.greater(class_count, 0)
|
||||||
|
non_zero_sum = tf.boolean_mask(K.sum(metric * mask, axis=0), non_zero)
|
||||||
|
non_zero_count = tf.boolean_mask(class_count, non_zero)
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print('Counts of inputs with class present, metrics for non-absent classes')
|
||||||
|
print(K.eval(class_count), K.eval(non_zero_sum / non_zero_count))
|
||||||
|
|
||||||
|
return K.mean(non_zero_sum / non_zero_count)
|
||||||
|
|
||||||
|
|
||||||
|
def mean_iou(y_true, y_pred, **kwargs):
|
||||||
|
"""
|
||||||
|
Compute mean Intersection over Union of two segmentation masks, via Keras.
|
||||||
|
|
||||||
|
Calls metrics_k(y_true, y_pred, metric_name='iou'), see there for allowed kwargs.
|
||||||
|
"""
|
||||||
|
return seg_metrics(y_true, y_pred, metric_name='iou', **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def Mean_IOU(y_true, y_pred):
|
||||||
|
nb_classes = K.int_shape(y_pred)[-1]
|
||||||
|
iou = []
|
||||||
|
true_pixels = K.argmax(y_true, axis=-1)
|
||||||
|
pred_pixels = K.argmax(y_pred, axis=-1)
|
||||||
|
void_labels = K.equal(K.sum(y_true, axis=-1), 0)
|
||||||
|
for i in range(0, nb_classes): # exclude first label (background) and last label (void)
|
||||||
|
true_labels = K.equal(true_pixels, i) # & ~void_labels
|
||||||
|
pred_labels = K.equal(pred_pixels, i) # & ~void_labels
|
||||||
|
inter = tf.to_int32(true_labels & pred_labels)
|
||||||
|
union = tf.to_int32(true_labels | pred_labels)
|
||||||
|
legal_batches = K.sum(tf.to_int32(true_labels), axis=1) > 0
|
||||||
|
ious = K.sum(inter, axis=1) / K.sum(union, axis=1)
|
||||||
|
iou.append(K.mean(tf.gather(ious, indices=tf.where(legal_batches)))) # returns average IoU of the same objects
|
||||||
|
iou = tf.stack(iou)
|
||||||
|
legal_labels = ~tf.debugging.is_nan(iou)
|
||||||
|
iou = tf.gather(iou, indices=tf.where(legal_labels))
|
||||||
|
return K.mean(iou)
|
||||||
|
|
||||||
|
|
||||||
|
def iou_vahid(y_true, y_pred):
|
||||||
|
nb_classes = tf.shape(y_true)[-1] + tf.to_int32(1)
|
||||||
|
true_pixels = K.argmax(y_true, axis=-1)
|
||||||
|
pred_pixels = K.argmax(y_pred, axis=-1)
|
||||||
|
iou = []
|
||||||
|
|
||||||
|
for i in tf.range(nb_classes):
|
||||||
|
tp = K.sum(tf.to_int32(K.equal(true_pixels, i) & K.equal(pred_pixels, i)))
|
||||||
|
fp = K.sum(tf.to_int32(K.not_equal(true_pixels, i) & K.equal(pred_pixels, i)))
|
||||||
|
fn = K.sum(tf.to_int32(K.equal(true_pixels, i) & K.not_equal(pred_pixels, i)))
|
||||||
|
iouh = tp / (tp + fp + fn)
|
||||||
|
iou.append(iouh)
|
||||||
|
return K.mean(iou)
|
||||||
|
|
||||||
|
|
||||||
|
def IoU_metric(Yi, y_predi):
|
||||||
|
# mean Intersection over Union
|
||||||
|
# Mean IoU = TP/(FN + TP + FP)
|
||||||
|
y_predi = np.argmax(y_predi, axis=3)
|
||||||
|
y_testi = np.argmax(Yi, axis=3)
|
||||||
|
IoUs = []
|
||||||
|
Nclass = int(np.max(Yi)) + 1
|
||||||
|
for c in range(Nclass):
|
||||||
|
TP = np.sum((Yi == c) & (y_predi == c))
|
||||||
|
FP = np.sum((Yi != c) & (y_predi == c))
|
||||||
|
FN = np.sum((Yi == c) & (y_predi != c))
|
||||||
|
IoU = TP / float(TP + FP + FN)
|
||||||
|
IoUs.append(IoU)
|
||||||
|
return K.cast(np.mean(IoUs), dtype='float32')
|
||||||
|
|
||||||
|
|
||||||
|
def IoU_metric_keras(y_true, y_pred):
|
||||||
|
# mean Intersection over Union
|
||||||
|
# Mean IoU = TP/(FN + TP + FP)
|
||||||
|
init = tf.global_variables_initializer()
|
||||||
|
sess = tf.Session()
|
||||||
|
sess.run(init)
|
||||||
|
|
||||||
|
return IoU_metric(y_true.eval(session=sess), y_pred.eval(session=sess))
|
||||||
|
|
||||||
|
|
||||||
|
def jaccard_distance_loss(y_true, y_pred, smooth=100):
|
||||||
|
"""
|
||||||
|
Jaccard = (|X & Y|)/ (|X|+ |Y| - |X & Y|)
|
||||||
|
= sum(|A*B|)/(sum(|A|)+sum(|B|)-sum(|A*B|))
|
||||||
|
|
||||||
|
The jaccard distance loss is usefull for unbalanced datasets. This has been
|
||||||
|
shifted so it converges on 0 and is smoothed to avoid exploding or disapearing
|
||||||
|
gradient.
|
||||||
|
|
||||||
|
Ref: https://en.wikipedia.org/wiki/Jaccard_index
|
||||||
|
|
||||||
|
@url: https://gist.github.com/wassname/f1452b748efcbeb4cb9b1d059dce6f96
|
||||||
|
@author: wassname
|
||||||
|
"""
|
||||||
|
intersection = K.sum(K.abs(y_true * y_pred), axis=-1)
|
||||||
|
sum_ = K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1)
|
||||||
|
jac = (intersection + smooth) / (sum_ - intersection + smooth)
|
||||||
|
return (1 - jac) * smooth
|
760
train/models.py
Normal file
760
train/models.py
Normal file
|
@ -0,0 +1,760 @@
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow import keras
|
||||||
|
from tensorflow.keras.models import *
|
||||||
|
from tensorflow.keras.layers import *
|
||||||
|
from tensorflow.keras import layers
|
||||||
|
from tensorflow.keras.regularizers import l2
|
||||||
|
|
||||||
|
##mlp_head_units = [512, 256]#[2048, 1024]
|
||||||
|
###projection_dim = 64
|
||||||
|
##transformer_layers = 2#8
|
||||||
|
##num_heads = 1#4
|
||||||
|
resnet50_Weights_path = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
|
||||||
|
IMAGE_ORDERING = 'channels_last'
|
||||||
|
MERGE_AXIS = -1
|
||||||
|
|
||||||
|
def mlp(x, hidden_units, dropout_rate):
|
||||||
|
for units in hidden_units:
|
||||||
|
x = layers.Dense(units, activation=tf.nn.gelu)(x)
|
||||||
|
x = layers.Dropout(dropout_rate)(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class Patches(layers.Layer):
|
||||||
|
def __init__(self, patch_size_x, patch_size_y):#__init__(self, **kwargs):#:__init__(self, patch_size):#__init__(self, **kwargs):
|
||||||
|
super(Patches, self).__init__()
|
||||||
|
self.patch_size_x = patch_size_x
|
||||||
|
self.patch_size_y = patch_size_y
|
||||||
|
|
||||||
|
def call(self, images):
|
||||||
|
#print(tf.shape(images)[1],'images')
|
||||||
|
#print(self.patch_size,'self.patch_size')
|
||||||
|
batch_size = tf.shape(images)[0]
|
||||||
|
patches = tf.image.extract_patches(
|
||||||
|
images=images,
|
||||||
|
sizes=[1, self.patch_size_y, self.patch_size_x, 1],
|
||||||
|
strides=[1, self.patch_size_y, self.patch_size_x, 1],
|
||||||
|
rates=[1, 1, 1, 1],
|
||||||
|
padding="VALID",
|
||||||
|
)
|
||||||
|
#patch_dims = patches.shape[-1]
|
||||||
|
patch_dims = tf.shape(patches)[-1]
|
||||||
|
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
|
||||||
|
return patches
|
||||||
|
def get_config(self):
|
||||||
|
|
||||||
|
config = super().get_config().copy()
|
||||||
|
config.update({
|
||||||
|
'patch_size_x': self.patch_size_x,
|
||||||
|
'patch_size_y': self.patch_size_y,
|
||||||
|
})
|
||||||
|
return config
|
||||||
|
|
||||||
|
class Patches_old(layers.Layer):
|
||||||
|
def __init__(self, patch_size):#__init__(self, **kwargs):#:__init__(self, patch_size):#__init__(self, **kwargs):
|
||||||
|
super(Patches, self).__init__()
|
||||||
|
self.patch_size = patch_size
|
||||||
|
|
||||||
|
def call(self, images):
|
||||||
|
#print(tf.shape(images)[1],'images')
|
||||||
|
#print(self.patch_size,'self.patch_size')
|
||||||
|
batch_size = tf.shape(images)[0]
|
||||||
|
patches = tf.image.extract_patches(
|
||||||
|
images=images,
|
||||||
|
sizes=[1, self.patch_size, self.patch_size, 1],
|
||||||
|
strides=[1, self.patch_size, self.patch_size, 1],
|
||||||
|
rates=[1, 1, 1, 1],
|
||||||
|
padding="VALID",
|
||||||
|
)
|
||||||
|
patch_dims = patches.shape[-1]
|
||||||
|
#print(patches.shape,patch_dims,'patch_dims')
|
||||||
|
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
|
||||||
|
return patches
|
||||||
|
def get_config(self):
|
||||||
|
|
||||||
|
config = super().get_config().copy()
|
||||||
|
config.update({
|
||||||
|
'patch_size': self.patch_size,
|
||||||
|
})
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
class PatchEncoder(layers.Layer):
|
||||||
|
def __init__(self, num_patches, projection_dim):
|
||||||
|
super(PatchEncoder, self).__init__()
|
||||||
|
self.num_patches = num_patches
|
||||||
|
self.projection = layers.Dense(units=projection_dim)
|
||||||
|
self.position_embedding = layers.Embedding(
|
||||||
|
input_dim=num_patches, output_dim=projection_dim
|
||||||
|
)
|
||||||
|
|
||||||
|
def call(self, patch):
|
||||||
|
positions = tf.range(start=0, limit=self.num_patches, delta=1)
|
||||||
|
encoded = self.projection(patch) + self.position_embedding(positions)
|
||||||
|
return encoded
|
||||||
|
def get_config(self):
|
||||||
|
|
||||||
|
config = super().get_config().copy()
|
||||||
|
config.update({
|
||||||
|
'num_patches': self.num_patches,
|
||||||
|
'projection': self.projection,
|
||||||
|
'position_embedding': self.position_embedding,
|
||||||
|
})
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def one_side_pad(x):
|
||||||
|
x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
|
||||||
|
if IMAGE_ORDERING == 'channels_first':
|
||||||
|
x = Lambda(lambda x: x[:, :, :-1, :-1])(x)
|
||||||
|
elif IMAGE_ORDERING == 'channels_last':
|
||||||
|
x = Lambda(lambda x: x[:, :-1, :-1, :])(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def identity_block(input_tensor, kernel_size, filters, stage, block):
|
||||||
|
"""The identity block is the block that has no conv layer at shortcut.
|
||||||
|
# Arguments
|
||||||
|
input_tensor: input tensor
|
||||||
|
kernel_size: defualt 3, the kernel size of middle conv layer at main path
|
||||||
|
filters: list of integers, the filterss of 3 conv layer at main path
|
||||||
|
stage: integer, current stage label, used for generating layer names
|
||||||
|
block: 'a','b'..., current block label, used for generating layer names
|
||||||
|
# Returns
|
||||||
|
Output tensor for the block.
|
||||||
|
"""
|
||||||
|
filters1, filters2, filters3 = filters
|
||||||
|
|
||||||
|
if IMAGE_ORDERING == 'channels_last':
|
||||||
|
bn_axis = 3
|
||||||
|
else:
|
||||||
|
bn_axis = 1
|
||||||
|
|
||||||
|
conv_name_base = 'res' + str(stage) + block + '_branch'
|
||||||
|
bn_name_base = 'bn' + str(stage) + block + '_branch'
|
||||||
|
|
||||||
|
x = Conv2D(filters1, (1, 1), data_format=IMAGE_ORDERING, name=conv_name_base + '2a')(input_tensor)
|
||||||
|
x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
|
||||||
|
x = Activation('relu')(x)
|
||||||
|
|
||||||
|
x = Conv2D(filters2, kernel_size, data_format=IMAGE_ORDERING,
|
||||||
|
padding='same', name=conv_name_base + '2b')(x)
|
||||||
|
x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
|
||||||
|
x = Activation('relu')(x)
|
||||||
|
|
||||||
|
x = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, name=conv_name_base + '2c')(x)
|
||||||
|
x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)
|
||||||
|
|
||||||
|
x = layers.add([x, input_tensor])
|
||||||
|
x = Activation('relu')(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):
|
||||||
|
"""conv_block is the block that has a conv layer at shortcut
|
||||||
|
# Arguments
|
||||||
|
input_tensor: input tensor
|
||||||
|
kernel_size: defualt 3, the kernel size of middle conv layer at main path
|
||||||
|
filters: list of integers, the filterss of 3 conv layer at main path
|
||||||
|
stage: integer, current stage label, used for generating layer names
|
||||||
|
block: 'a','b'..., current block label, used for generating layer names
|
||||||
|
# Returns
|
||||||
|
Output tensor for the block.
|
||||||
|
Note that from stage 3, the first conv layer at main path is with strides=(2,2)
|
||||||
|
And the shortcut should have strides=(2,2) as well
|
||||||
|
"""
|
||||||
|
filters1, filters2, filters3 = filters
|
||||||
|
|
||||||
|
if IMAGE_ORDERING == 'channels_last':
|
||||||
|
bn_axis = 3
|
||||||
|
else:
|
||||||
|
bn_axis = 1
|
||||||
|
|
||||||
|
conv_name_base = 'res' + str(stage) + block + '_branch'
|
||||||
|
bn_name_base = 'bn' + str(stage) + block + '_branch'
|
||||||
|
|
||||||
|
x = Conv2D(filters1, (1, 1), data_format=IMAGE_ORDERING, strides=strides,
|
||||||
|
name=conv_name_base + '2a')(input_tensor)
|
||||||
|
x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
|
||||||
|
x = Activation('relu')(x)
|
||||||
|
|
||||||
|
x = Conv2D(filters2, kernel_size, data_format=IMAGE_ORDERING, padding='same',
|
||||||
|
name=conv_name_base + '2b')(x)
|
||||||
|
x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
|
||||||
|
x = Activation('relu')(x)
|
||||||
|
|
||||||
|
x = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, name=conv_name_base + '2c')(x)
|
||||||
|
x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)
|
||||||
|
|
||||||
|
shortcut = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, strides=strides,
|
||||||
|
name=conv_name_base + '1')(input_tensor)
|
||||||
|
shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(shortcut)
|
||||||
|
|
||||||
|
x = layers.add([x, shortcut])
|
||||||
|
x = Activation('relu')(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def resnet50_unet_light(n_classes, input_height=224, input_width=224, taks="segmentation", weight_decay=1e-6, pretraining=False):
|
||||||
|
assert input_height % 32 == 0
|
||||||
|
assert input_width % 32 == 0
|
||||||
|
|
||||||
|
img_input = Input(shape=(input_height, input_width, 3))
|
||||||
|
|
||||||
|
if IMAGE_ORDERING == 'channels_last':
|
||||||
|
bn_axis = 3
|
||||||
|
else:
|
||||||
|
bn_axis = 1
|
||||||
|
|
||||||
|
x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input)
|
||||||
|
x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2), kernel_regularizer=l2(weight_decay),
|
||||||
|
name='conv1')(x)
|
||||||
|
f1 = x
|
||||||
|
|
||||||
|
x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
|
||||||
|
x = Activation('relu')(x)
|
||||||
|
x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x)
|
||||||
|
|
||||||
|
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
|
||||||
|
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
|
||||||
|
x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
|
||||||
|
f2 = one_side_pad(x)
|
||||||
|
|
||||||
|
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
|
||||||
|
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
|
||||||
|
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
|
||||||
|
x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
|
||||||
|
f3 = x
|
||||||
|
|
||||||
|
x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
|
||||||
|
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
|
||||||
|
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
|
||||||
|
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
|
||||||
|
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
|
||||||
|
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
|
||||||
|
f4 = x
|
||||||
|
|
||||||
|
x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
|
||||||
|
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
|
||||||
|
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
|
||||||
|
f5 = x
|
||||||
|
|
||||||
|
if pretraining:
|
||||||
|
model = Model(img_input, x).load_weights(resnet50_Weights_path)
|
||||||
|
|
||||||
|
v512_2048 = Conv2D(512, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f5)
|
||||||
|
v512_2048 = (BatchNormalization(axis=bn_axis))(v512_2048)
|
||||||
|
v512_2048 = Activation('relu')(v512_2048)
|
||||||
|
|
||||||
|
v512_1024 = Conv2D(512, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f4)
|
||||||
|
v512_1024 = (BatchNormalization(axis=bn_axis))(v512_1024)
|
||||||
|
v512_1024 = Activation('relu')(v512_1024)
|
||||||
|
|
||||||
|
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(v512_2048)
|
||||||
|
o = (concatenate([o, v512_1024], axis=MERGE_AXIS))
|
||||||
|
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
|
||||||
|
o = (BatchNormalization(axis=bn_axis))(o)
|
||||||
|
o = Activation('relu')(o)
|
||||||
|
|
||||||
|
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (concatenate([o, f3], axis=MERGE_AXIS))
|
||||||
|
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
|
||||||
|
o = (BatchNormalization(axis=bn_axis))(o)
|
||||||
|
o = Activation('relu')(o)
|
||||||
|
|
||||||
|
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (concatenate([o, f2], axis=MERGE_AXIS))
|
||||||
|
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
|
||||||
|
o = (BatchNormalization(axis=bn_axis))(o)
|
||||||
|
o = Activation('relu')(o)
|
||||||
|
|
||||||
|
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (concatenate([o, f1], axis=MERGE_AXIS))
|
||||||
|
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
|
||||||
|
o = (BatchNormalization(axis=bn_axis))(o)
|
||||||
|
o = Activation('relu')(o)
|
||||||
|
|
||||||
|
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (concatenate([o, img_input], axis=MERGE_AXIS))
|
||||||
|
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
|
||||||
|
o = (BatchNormalization(axis=bn_axis))(o)
|
||||||
|
o = Activation('relu')(o)
|
||||||
|
|
||||||
|
o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
|
||||||
|
if task == "segmentation":
|
||||||
|
o = (BatchNormalization(axis=bn_axis))(o)
|
||||||
|
o = (Activation('softmax'))(o)
|
||||||
|
else:
|
||||||
|
o = (Activation('sigmoid'))(o)
|
||||||
|
|
||||||
|
model = Model(img_input, o)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
|
||||||
|
assert input_height % 32 == 0
|
||||||
|
assert input_width % 32 == 0
|
||||||
|
|
||||||
|
img_input = Input(shape=(input_height, input_width, 3))
|
||||||
|
|
||||||
|
if IMAGE_ORDERING == 'channels_last':
|
||||||
|
bn_axis = 3
|
||||||
|
else:
|
||||||
|
bn_axis = 1
|
||||||
|
|
||||||
|
x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input)
|
||||||
|
x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2), kernel_regularizer=l2(weight_decay),
|
||||||
|
name='conv1')(x)
|
||||||
|
f1 = x
|
||||||
|
|
||||||
|
x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
|
||||||
|
x = Activation('relu')(x)
|
||||||
|
x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x)
|
||||||
|
|
||||||
|
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
|
||||||
|
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
|
||||||
|
x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
|
||||||
|
f2 = one_side_pad(x)
|
||||||
|
|
||||||
|
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
|
||||||
|
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
|
||||||
|
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
|
||||||
|
x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
|
||||||
|
f3 = x
|
||||||
|
|
||||||
|
x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
|
||||||
|
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
|
||||||
|
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
|
||||||
|
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
|
||||||
|
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
|
||||||
|
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
|
||||||
|
f4 = x
|
||||||
|
|
||||||
|
x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
|
||||||
|
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
|
||||||
|
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
|
||||||
|
f5 = x
|
||||||
|
|
||||||
|
if pretraining:
|
||||||
|
Model(img_input, x).load_weights(resnet50_Weights_path)
|
||||||
|
|
||||||
|
v1024_2048 = Conv2D(1024, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(
|
||||||
|
f5)
|
||||||
|
v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048)
|
||||||
|
v1024_2048 = Activation('relu')(v1024_2048)
|
||||||
|
|
||||||
|
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(v1024_2048)
|
||||||
|
o = (concatenate([o, f4], axis=MERGE_AXIS))
|
||||||
|
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
|
||||||
|
o = (BatchNormalization(axis=bn_axis))(o)
|
||||||
|
o = Activation('relu')(o)
|
||||||
|
|
||||||
|
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (concatenate([o, f3], axis=MERGE_AXIS))
|
||||||
|
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
|
||||||
|
o = (BatchNormalization(axis=bn_axis))(o)
|
||||||
|
o = Activation('relu')(o)
|
||||||
|
|
||||||
|
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (concatenate([o, f2], axis=MERGE_AXIS))
|
||||||
|
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
|
||||||
|
o = (BatchNormalization(axis=bn_axis))(o)
|
||||||
|
o = Activation('relu')(o)
|
||||||
|
|
||||||
|
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (concatenate([o, f1], axis=MERGE_AXIS))
|
||||||
|
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
|
||||||
|
o = (BatchNormalization(axis=bn_axis))(o)
|
||||||
|
o = Activation('relu')(o)
|
||||||
|
|
||||||
|
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (concatenate([o, img_input], axis=MERGE_AXIS))
|
||||||
|
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
|
||||||
|
o = (BatchNormalization(axis=bn_axis))(o)
|
||||||
|
o = Activation('relu')(o)
|
||||||
|
|
||||||
|
o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
|
||||||
|
if task == "segmentation":
|
||||||
|
o = (BatchNormalization(axis=bn_axis))(o)
|
||||||
|
o = (Activation('softmax'))(o)
|
||||||
|
else:
|
||||||
|
o = (Activation('sigmoid'))(o)
|
||||||
|
|
||||||
|
model = Model(img_input, o)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=None, transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
|
||||||
|
if mlp_head_units is None:
|
||||||
|
mlp_head_units = [128, 64]
|
||||||
|
inputs = layers.Input(shape=(input_height, input_width, 3))
|
||||||
|
|
||||||
|
#transformer_units = [
|
||||||
|
#projection_dim * 2,
|
||||||
|
#projection_dim,
|
||||||
|
#] # Size of the transformer layers
|
||||||
|
IMAGE_ORDERING = 'channels_last'
|
||||||
|
bn_axis=3
|
||||||
|
|
||||||
|
x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(inputs)
|
||||||
|
x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x)
|
||||||
|
f1 = x
|
||||||
|
|
||||||
|
x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
|
||||||
|
x = Activation('relu')(x)
|
||||||
|
x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x)
|
||||||
|
|
||||||
|
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
|
||||||
|
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
|
||||||
|
x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
|
||||||
|
f2 = one_side_pad(x)
|
||||||
|
|
||||||
|
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
|
||||||
|
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
|
||||||
|
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
|
||||||
|
x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
|
||||||
|
f3 = x
|
||||||
|
|
||||||
|
x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
|
||||||
|
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
|
||||||
|
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
|
||||||
|
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
|
||||||
|
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
|
||||||
|
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
|
||||||
|
f4 = x
|
||||||
|
|
||||||
|
x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
|
||||||
|
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
|
||||||
|
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
|
||||||
|
f5 = x
|
||||||
|
|
||||||
|
if pretraining:
|
||||||
|
model = Model(inputs, x).load_weights(resnet50_Weights_path)
|
||||||
|
|
||||||
|
#num_patches = x.shape[1]*x.shape[2]
|
||||||
|
|
||||||
|
#patch_size_y = input_height / x.shape[1]
|
||||||
|
#patch_size_x = input_width / x.shape[2]
|
||||||
|
#patch_size = patch_size_x * patch_size_y
|
||||||
|
patches = Patches(patch_size_x, patch_size_y)(x)
|
||||||
|
# Encode patches.
|
||||||
|
encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
|
||||||
|
|
||||||
|
for _ in range(transformer_layers):
|
||||||
|
# Layer normalization 1.
|
||||||
|
x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
|
||||||
|
# Create a multi-head attention layer.
|
||||||
|
attention_output = layers.MultiHeadAttention(
|
||||||
|
num_heads=num_heads, key_dim=projection_dim, dropout=0.1
|
||||||
|
)(x1, x1)
|
||||||
|
# Skip connection 1.
|
||||||
|
x2 = layers.Add()([attention_output, encoded_patches])
|
||||||
|
# Layer normalization 2.
|
||||||
|
x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
|
||||||
|
# MLP.
|
||||||
|
x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1)
|
||||||
|
# Skip connection 2.
|
||||||
|
encoded_patches = layers.Add()([x3, x2])
|
||||||
|
|
||||||
|
encoded_patches = tf.reshape(encoded_patches, [-1, x.shape[1], x.shape[2] , int( projection_dim / (patch_size_x * patch_size_y) )])
|
||||||
|
|
||||||
|
v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(encoded_patches)
|
||||||
|
v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048)
|
||||||
|
v1024_2048 = Activation('relu')(v1024_2048)
|
||||||
|
|
||||||
|
o = (UpSampling2D( (2, 2), data_format=IMAGE_ORDERING))(v1024_2048)
|
||||||
|
o = (concatenate([o, f4],axis=MERGE_AXIS))
|
||||||
|
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||||
|
o = (BatchNormalization(axis=bn_axis))(o)
|
||||||
|
o = Activation('relu')(o)
|
||||||
|
|
||||||
|
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (concatenate([o ,f3], axis=MERGE_AXIS))
|
||||||
|
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||||
|
o = (BatchNormalization(axis=bn_axis))(o)
|
||||||
|
o = Activation('relu')(o)
|
||||||
|
|
||||||
|
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (concatenate([o, f2], axis=MERGE_AXIS))
|
||||||
|
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||||
|
o = (BatchNormalization(axis=bn_axis))(o)
|
||||||
|
o = Activation('relu')(o)
|
||||||
|
|
||||||
|
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (concatenate([o, f1], axis=MERGE_AXIS))
|
||||||
|
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||||
|
o = (BatchNormalization(axis=bn_axis))(o)
|
||||||
|
o = Activation('relu')(o)
|
||||||
|
|
||||||
|
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (concatenate([o, inputs],axis=MERGE_AXIS))
|
||||||
|
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||||
|
o = (BatchNormalization(axis=bn_axis))(o)
|
||||||
|
o = Activation('relu')(o)
|
||||||
|
|
||||||
|
o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(o)
|
||||||
|
if task == "segmentation":
|
||||||
|
o = (BatchNormalization(axis=bn_axis))(o)
|
||||||
|
o = (Activation('softmax'))(o)
|
||||||
|
else:
|
||||||
|
o = (Activation('sigmoid'))(o)
|
||||||
|
|
||||||
|
model = Model(inputs=inputs, outputs=o)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def vit_resnet50_unet_transformer_before_cnn(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=None, transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
|
||||||
|
if mlp_head_units is None:
|
||||||
|
mlp_head_units = [128, 64]
|
||||||
|
inputs = layers.Input(shape=(input_height, input_width, 3))
|
||||||
|
|
||||||
|
##transformer_units = [
|
||||||
|
##projection_dim * 2,
|
||||||
|
##projection_dim,
|
||||||
|
##] # Size of the transformer layers
|
||||||
|
IMAGE_ORDERING = 'channels_last'
|
||||||
|
bn_axis=3
|
||||||
|
|
||||||
|
patches = Patches(patch_size_x, patch_size_y)(inputs)
|
||||||
|
# Encode patches.
|
||||||
|
encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
|
||||||
|
|
||||||
|
for _ in range(transformer_layers):
|
||||||
|
# Layer normalization 1.
|
||||||
|
x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
|
||||||
|
# Create a multi-head attention layer.
|
||||||
|
attention_output = layers.MultiHeadAttention(
|
||||||
|
num_heads=num_heads, key_dim=projection_dim, dropout=0.1
|
||||||
|
)(x1, x1)
|
||||||
|
# Skip connection 1.
|
||||||
|
x2 = layers.Add()([attention_output, encoded_patches])
|
||||||
|
# Layer normalization 2.
|
||||||
|
x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
|
||||||
|
# MLP.
|
||||||
|
x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1)
|
||||||
|
# Skip connection 2.
|
||||||
|
encoded_patches = layers.Add()([x3, x2])
|
||||||
|
|
||||||
|
encoded_patches = tf.reshape(encoded_patches, [-1, input_height, input_width , int( projection_dim / (patch_size_x * patch_size_y) )])
|
||||||
|
|
||||||
|
encoded_patches = Conv2D(3, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay), name='convinput')(encoded_patches)
|
||||||
|
|
||||||
|
x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(encoded_patches)
|
||||||
|
x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x)
|
||||||
|
f1 = x
|
||||||
|
|
||||||
|
x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
|
||||||
|
x = Activation('relu')(x)
|
||||||
|
x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x)
|
||||||
|
|
||||||
|
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
|
||||||
|
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
|
||||||
|
x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
|
||||||
|
f2 = one_side_pad(x)
|
||||||
|
|
||||||
|
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
|
||||||
|
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
|
||||||
|
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
|
||||||
|
x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
|
||||||
|
f3 = x
|
||||||
|
|
||||||
|
x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
|
||||||
|
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
|
||||||
|
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
|
||||||
|
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
|
||||||
|
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
|
||||||
|
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
|
||||||
|
f4 = x
|
||||||
|
|
||||||
|
x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
|
||||||
|
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
|
||||||
|
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
|
||||||
|
f5 = x
|
||||||
|
|
||||||
|
if pretraining:
|
||||||
|
model = Model(encoded_patches, x).load_weights(resnet50_Weights_path)
|
||||||
|
|
||||||
|
v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(x)
|
||||||
|
v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048)
|
||||||
|
v1024_2048 = Activation('relu')(v1024_2048)
|
||||||
|
|
||||||
|
o = (UpSampling2D( (2, 2), data_format=IMAGE_ORDERING))(v1024_2048)
|
||||||
|
o = (concatenate([o, f4],axis=MERGE_AXIS))
|
||||||
|
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||||
|
o = (BatchNormalization(axis=bn_axis))(o)
|
||||||
|
o = Activation('relu')(o)
|
||||||
|
|
||||||
|
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (concatenate([o ,f3], axis=MERGE_AXIS))
|
||||||
|
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||||
|
o = (BatchNormalization(axis=bn_axis))(o)
|
||||||
|
o = Activation('relu')(o)
|
||||||
|
|
||||||
|
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (concatenate([o, f2], axis=MERGE_AXIS))
|
||||||
|
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||||
|
o = (BatchNormalization(axis=bn_axis))(o)
|
||||||
|
o = Activation('relu')(o)
|
||||||
|
|
||||||
|
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (concatenate([o, f1], axis=MERGE_AXIS))
|
||||||
|
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||||
|
o = (BatchNormalization(axis=bn_axis))(o)
|
||||||
|
o = Activation('relu')(o)
|
||||||
|
|
||||||
|
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (concatenate([o, inputs],axis=MERGE_AXIS))
|
||||||
|
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||||
|
o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||||
|
o = (BatchNormalization(axis=bn_axis))(o)
|
||||||
|
o = Activation('relu')(o)
|
||||||
|
|
||||||
|
o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(o)
|
||||||
|
if task == "segmentation":
|
||||||
|
o = (BatchNormalization(axis=bn_axis))(o)
|
||||||
|
o = (Activation('softmax'))(o)
|
||||||
|
else:
|
||||||
|
o = (Activation('sigmoid'))(o)
|
||||||
|
|
||||||
|
model = Model(inputs=inputs, outputs=o)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def resnet50_classifier(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False):
|
||||||
|
include_top=True
|
||||||
|
assert input_height%32 == 0
|
||||||
|
assert input_width%32 == 0
|
||||||
|
|
||||||
|
|
||||||
|
img_input = Input(shape=(input_height,input_width , 3 ))
|
||||||
|
|
||||||
|
if IMAGE_ORDERING == 'channels_last':
|
||||||
|
bn_axis = 3
|
||||||
|
else:
|
||||||
|
bn_axis = 1
|
||||||
|
|
||||||
|
x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input)
|
||||||
|
x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x)
|
||||||
|
f1 = x
|
||||||
|
|
||||||
|
x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
|
||||||
|
x = Activation('relu')(x)
|
||||||
|
x = MaxPooling2D((3, 3) , data_format=IMAGE_ORDERING , strides=(2, 2))(x)
|
||||||
|
|
||||||
|
|
||||||
|
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
|
||||||
|
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
|
||||||
|
x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
|
||||||
|
f2 = one_side_pad(x )
|
||||||
|
|
||||||
|
|
||||||
|
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
|
||||||
|
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
|
||||||
|
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
|
||||||
|
x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
|
||||||
|
f3 = x
|
||||||
|
|
||||||
|
x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
|
||||||
|
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
|
||||||
|
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
|
||||||
|
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
|
||||||
|
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
|
||||||
|
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
|
||||||
|
f4 = x
|
||||||
|
|
||||||
|
x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
|
||||||
|
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
|
||||||
|
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
|
||||||
|
f5 = x
|
||||||
|
|
||||||
|
if pretraining:
|
||||||
|
Model(img_input, x).load_weights(resnet50_Weights_path)
|
||||||
|
|
||||||
|
x = AveragePooling2D((7, 7), name='avg_pool')(x)
|
||||||
|
x = Flatten()(x)
|
||||||
|
|
||||||
|
##
|
||||||
|
x = Dense(256, activation='relu', name='fc512')(x)
|
||||||
|
x=Dropout(0.2)(x)
|
||||||
|
##
|
||||||
|
x = Dense(n_classes, activation='softmax', name='fc1000')(x)
|
||||||
|
model = Model(img_input, x)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def machine_based_reading_order_model(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False):
|
||||||
|
assert input_height%32 == 0
|
||||||
|
assert input_width%32 == 0
|
||||||
|
|
||||||
|
img_input = Input(shape=(input_height,input_width , 3 ))
|
||||||
|
|
||||||
|
if IMAGE_ORDERING == 'channels_last':
|
||||||
|
bn_axis = 3
|
||||||
|
else:
|
||||||
|
bn_axis = 1
|
||||||
|
|
||||||
|
x1 = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input)
|
||||||
|
x1 = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x1)
|
||||||
|
|
||||||
|
x1 = BatchNormalization(axis=bn_axis, name='bn_conv1')(x1)
|
||||||
|
x1 = Activation('relu')(x1)
|
||||||
|
x1 = MaxPooling2D((3, 3) , data_format=IMAGE_ORDERING , strides=(2, 2))(x1)
|
||||||
|
|
||||||
|
x1 = conv_block(x1, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
|
||||||
|
x1 = identity_block(x1, 3, [64, 64, 256], stage=2, block='b')
|
||||||
|
x1 = identity_block(x1, 3, [64, 64, 256], stage=2, block='c')
|
||||||
|
|
||||||
|
x1 = conv_block(x1, 3, [128, 128, 512], stage=3, block='a')
|
||||||
|
x1 = identity_block(x1, 3, [128, 128, 512], stage=3, block='b')
|
||||||
|
x1 = identity_block(x1, 3, [128, 128, 512], stage=3, block='c')
|
||||||
|
x1 = identity_block(x1, 3, [128, 128, 512], stage=3, block='d')
|
||||||
|
|
||||||
|
x1 = conv_block(x1, 3, [256, 256, 1024], stage=4, block='a')
|
||||||
|
x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='b')
|
||||||
|
x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='c')
|
||||||
|
x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='d')
|
||||||
|
x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='e')
|
||||||
|
x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='f')
|
||||||
|
|
||||||
|
x1 = conv_block(x1, 3, [512, 512, 2048], stage=5, block='a')
|
||||||
|
x1 = identity_block(x1, 3, [512, 512, 2048], stage=5, block='b')
|
||||||
|
x1 = identity_block(x1, 3, [512, 512, 2048], stage=5, block='c')
|
||||||
|
|
||||||
|
if pretraining:
|
||||||
|
Model(img_input , x1).load_weights(resnet50_Weights_path)
|
||||||
|
|
||||||
|
x1 = AveragePooling2D((7, 7), name='avg_pool1')(x1)
|
||||||
|
flattened = Flatten()(x1)
|
||||||
|
|
||||||
|
o = Dense(256, activation='relu', name='fc512')(flattened)
|
||||||
|
o=Dropout(0.2)(o)
|
||||||
|
|
||||||
|
o = Dense(256, activation='relu', name='fc512a')(o)
|
||||||
|
o=Dropout(0.2)(o)
|
||||||
|
|
||||||
|
o = Dense(n_classes, activation='sigmoid', name='fc1000')(o)
|
||||||
|
model = Model(img_input , o)
|
||||||
|
|
||||||
|
return model
|
5
train/requirements.txt
Normal file
5
train/requirements.txt
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
sacred
|
||||||
|
seaborn
|
||||||
|
tqdm
|
||||||
|
imutils
|
||||||
|
scipy
|
3
train/scales_enhancement.json
Normal file
3
train/scales_enhancement.json
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
{
|
||||||
|
"scales" : [ 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9]
|
||||||
|
}
|
450
train/train.py
Normal file
450
train/train.py
Normal file
|
@ -0,0 +1,450 @@
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.compat.v1.keras.backend import set_session
|
||||||
|
import warnings
|
||||||
|
from tensorflow.keras.optimizers import *
|
||||||
|
from sacred import Experiment
|
||||||
|
from models import *
|
||||||
|
from utils import *
|
||||||
|
from metrics import *
|
||||||
|
from tensorflow.keras.models import load_model
|
||||||
|
from tqdm import tqdm
|
||||||
|
import json
|
||||||
|
from sklearn.metrics import f1_score
|
||||||
|
from tensorflow.keras.callbacks import Callback
|
||||||
|
|
||||||
|
class SaveWeightsAfterSteps(Callback):
|
||||||
|
def __init__(self, save_interval, save_path, _config):
|
||||||
|
super(SaveWeightsAfterSteps, self).__init__()
|
||||||
|
self.save_interval = save_interval
|
||||||
|
self.save_path = save_path
|
||||||
|
self.step_count = 0
|
||||||
|
self._config = _config
|
||||||
|
|
||||||
|
def on_train_batch_end(self, batch, logs=None):
|
||||||
|
self.step_count += 1
|
||||||
|
|
||||||
|
if self.step_count % self.save_interval ==0:
|
||||||
|
save_file = f"{self.save_path}/model_step_{self.step_count}"
|
||||||
|
#os.system('mkdir '+save_file)
|
||||||
|
|
||||||
|
self.model.save(save_file)
|
||||||
|
|
||||||
|
with open(os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"config.json"), "w") as fp:
|
||||||
|
json.dump(self._config, fp) # encode dict into JSON
|
||||||
|
print(f"saved model as steps {self.step_count} to {save_file}")
|
||||||
|
|
||||||
|
|
||||||
|
def configuration():
|
||||||
|
config = tf.compat.v1.ConfigProto()
|
||||||
|
config.gpu_options.allow_growth = True
|
||||||
|
session = tf.compat.v1.Session(config=config)
|
||||||
|
set_session(session)
|
||||||
|
|
||||||
|
|
||||||
|
def get_dirs_or_files(input_data):
|
||||||
|
if os.path.isdir(input_data):
|
||||||
|
image_input, labels_input = os.path.join(input_data, 'images/'), os.path.join(input_data, 'labels/')
|
||||||
|
# Check if training dir exists
|
||||||
|
assert os.path.isdir(image_input), "{} is not a directory".format(image_input)
|
||||||
|
assert os.path.isdir(labels_input), "{} is not a directory".format(labels_input)
|
||||||
|
return image_input, labels_input
|
||||||
|
|
||||||
|
|
||||||
|
ex = Experiment(save_git_info=False)
|
||||||
|
|
||||||
|
|
||||||
|
@ex.config
|
||||||
|
def config_params():
|
||||||
|
n_classes = None # Number of classes. In the case of binary classification this should be 2.
|
||||||
|
n_epochs = 1 # Number of epochs.
|
||||||
|
input_height = 224 * 1 # Height of model's input in pixels.
|
||||||
|
input_width = 224 * 1 # Width of model's input in pixels.
|
||||||
|
weight_decay = 1e-6 # Weight decay of l2 regularization of model layers.
|
||||||
|
n_batch = 1 # Number of batches at each iteration.
|
||||||
|
learning_rate = 1e-4 # Set the learning rate.
|
||||||
|
patches = False # Divides input image into smaller patches (input size of the model) when set to true. For the model to see the full image, like page extraction, set this to false.
|
||||||
|
augmentation = False # To apply any kind of augmentation, this parameter must be set to true.
|
||||||
|
flip_aug = False # If true, different types of flipping will be applied to the image. Types of flips are defined with "flip_index" in config_params.json.
|
||||||
|
blur_aug = False # If true, different types of blurring will be applied to the image. Types of blur are defined with "blur_k" in config_params.json.
|
||||||
|
padding_white = False # If true, white padding will be applied to the image.
|
||||||
|
padding_black = False # If true, black padding will be applied to the image.
|
||||||
|
scaling = False # If true, scaling will be applied to the image. The amount of scaling is defined with "scales" in config_params.json.
|
||||||
|
shifting = False
|
||||||
|
degrading = False # If true, degrading will be applied to the image. The amount of degrading is defined with "degrade_scales" in config_params.json.
|
||||||
|
brightening = False # If true, brightening will be applied to the image. The amount of brightening is defined with "brightness" in config_params.json.
|
||||||
|
binarization = False # If true, Otsu thresholding will be applied to augment the input with binarized images.
|
||||||
|
adding_rgb_background = False
|
||||||
|
adding_rgb_foreground = False
|
||||||
|
add_red_textlines = False
|
||||||
|
channels_shuffling = False
|
||||||
|
dir_train = None # Directory of training dataset with subdirectories having the names "images" and "labels".
|
||||||
|
dir_eval = None # Directory of validation dataset with subdirectories having the names "images" and "labels".
|
||||||
|
dir_output = None # Directory where the output model will be saved.
|
||||||
|
pretraining = False # Set to true to load pretrained weights of ResNet50 encoder.
|
||||||
|
scaling_bluring = False # If true, a combination of scaling and blurring will be applied to the image.
|
||||||
|
scaling_binarization = False # If true, a combination of scaling and binarization will be applied to the image.
|
||||||
|
rotation = False # If true, a 90 degree rotation will be implemeneted.
|
||||||
|
rotation_not_90 = False # If true rotation based on provided angles with thetha will be implemeneted.
|
||||||
|
scaling_brightness = False # If true, a combination of scaling and brightening will be applied to the image.
|
||||||
|
scaling_flip = False # If true, a combination of scaling and flipping will be applied to the image.
|
||||||
|
thetha = None # Rotate image by these angles for augmentation.
|
||||||
|
shuffle_indexes = None
|
||||||
|
blur_k = None # Blur image for augmentation.
|
||||||
|
scales = None # Scale patches for augmentation.
|
||||||
|
degrade_scales = None # Degrade image for augmentation.
|
||||||
|
brightness = None # Brighten image for augmentation.
|
||||||
|
flip_index = None # Flip image for augmentation.
|
||||||
|
continue_training = False # Set to true if you would like to continue training an already trained a model.
|
||||||
|
transformer_patchsize_x = None # Patch size of vision transformer patches in x direction.
|
||||||
|
transformer_patchsize_y = None # Patch size of vision transformer patches in y direction.
|
||||||
|
transformer_num_patches_xy = None # Number of patches for vision transformer in x and y direction respectively.
|
||||||
|
transformer_projection_dim = 64 # Transformer projection dimension. Default value is 64.
|
||||||
|
transformer_mlp_head_units = [128, 64] # Transformer Multilayer Perceptron (MLP) head units. Default value is [128, 64]
|
||||||
|
transformer_layers = 8 # transformer layers. Default value is 8.
|
||||||
|
transformer_num_heads = 4 # Transformer number of heads. Default value is 4.
|
||||||
|
transformer_cnn_first = True # We have two types of vision transformers. In one type, a CNN is applied first, followed by a transformer. In the other type, this order is reversed. If transformer_cnn_first is true, it means the CNN will be applied before the transformer. Default value is true.
|
||||||
|
index_start = 0 # Index of model to continue training from. E.g. if you trained for 3 epochs and last index is 2, to continue from model_1.h5, set "index_start" to 3 to start naming model with index 3.
|
||||||
|
dir_of_start_model = '' # Directory containing pretrained encoder to continue training the model.
|
||||||
|
is_loss_soft_dice = False # Use soft dice as loss function. When set to true, "weighted_loss" must be false.
|
||||||
|
weighted_loss = False # Use weighted categorical cross entropy as loss fucntion. When set to true, "is_loss_soft_dice" must be false.
|
||||||
|
data_is_provided = False # Only set this to true when you have already provided the input data and the train and eval data are in "dir_output".
|
||||||
|
task = "segmentation" # This parameter defines task of model which can be segmentation, enhancement or classification.
|
||||||
|
f1_threshold_classification = None # This threshold is used to consider models with an evaluation f1 scores bigger than it. The selected model weights undergo a weights ensembling. And avreage ensembled model will be written to output.
|
||||||
|
classification_classes_name = None # Dictionary of classification classes names.
|
||||||
|
backbone_type = None # As backbone we have 2 types of backbones. A vision transformer alongside a CNN and we call it "transformer" and only CNN called "nontransformer"
|
||||||
|
save_interval = None
|
||||||
|
dir_img_bin = None
|
||||||
|
number_of_backgrounds_per_image = 1
|
||||||
|
dir_rgb_backgrounds = None
|
||||||
|
dir_rgb_foregrounds = None
|
||||||
|
|
||||||
|
|
||||||
|
@ex.automain
|
||||||
|
def run(_config, n_classes, n_epochs, input_height,
|
||||||
|
input_width, weight_decay, weighted_loss,
|
||||||
|
index_start, dir_of_start_model, is_loss_soft_dice,
|
||||||
|
n_batch, patches, augmentation, flip_aug,
|
||||||
|
blur_aug, padding_white, padding_black, scaling, shifting, degrading,channels_shuffling,
|
||||||
|
brightening, binarization, adding_rgb_background, adding_rgb_foreground, add_red_textlines, blur_k, scales, degrade_scales,shuffle_indexes,
|
||||||
|
brightness, dir_train, data_is_provided, scaling_bluring,
|
||||||
|
scaling_brightness, scaling_binarization, rotation, rotation_not_90,
|
||||||
|
thetha, scaling_flip, continue_training, transformer_projection_dim,
|
||||||
|
transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_cnn_first,
|
||||||
|
transformer_patchsize_x, transformer_patchsize_y,
|
||||||
|
transformer_num_patches_xy, backbone_type, save_interval, flip_index, dir_eval, dir_output,
|
||||||
|
pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds, dir_rgb_foregrounds):
|
||||||
|
|
||||||
|
if dir_rgb_backgrounds:
|
||||||
|
list_all_possible_background_images = os.listdir(dir_rgb_backgrounds)
|
||||||
|
else:
|
||||||
|
list_all_possible_background_images = None
|
||||||
|
|
||||||
|
if dir_rgb_foregrounds:
|
||||||
|
list_all_possible_foreground_rgbs = os.listdir(dir_rgb_foregrounds)
|
||||||
|
else:
|
||||||
|
list_all_possible_foreground_rgbs = None
|
||||||
|
|
||||||
|
if task == "segmentation" or task == "enhancement" or task == "binarization":
|
||||||
|
if data_is_provided:
|
||||||
|
dir_train_flowing = os.path.join(dir_output, 'train')
|
||||||
|
dir_eval_flowing = os.path.join(dir_output, 'eval')
|
||||||
|
|
||||||
|
|
||||||
|
dir_flow_train_imgs = os.path.join(dir_train_flowing, 'images')
|
||||||
|
dir_flow_train_labels = os.path.join(dir_train_flowing, 'labels')
|
||||||
|
|
||||||
|
dir_flow_eval_imgs = os.path.join(dir_eval_flowing, 'images')
|
||||||
|
dir_flow_eval_labels = os.path.join(dir_eval_flowing, 'labels')
|
||||||
|
|
||||||
|
configuration()
|
||||||
|
|
||||||
|
else:
|
||||||
|
dir_img, dir_seg = get_dirs_or_files(dir_train)
|
||||||
|
dir_img_val, dir_seg_val = get_dirs_or_files(dir_eval)
|
||||||
|
|
||||||
|
# make first a directory in output for both training and evaluations in order to flow data from these directories.
|
||||||
|
dir_train_flowing = os.path.join(dir_output, 'train')
|
||||||
|
dir_eval_flowing = os.path.join(dir_output, 'eval')
|
||||||
|
|
||||||
|
dir_flow_train_imgs = os.path.join(dir_train_flowing, 'images/')
|
||||||
|
dir_flow_train_labels = os.path.join(dir_train_flowing, 'labels/')
|
||||||
|
|
||||||
|
dir_flow_eval_imgs = os.path.join(dir_eval_flowing, 'images/')
|
||||||
|
dir_flow_eval_labels = os.path.join(dir_eval_flowing, 'labels/')
|
||||||
|
|
||||||
|
if os.path.isdir(dir_train_flowing):
|
||||||
|
os.system('rm -rf ' + dir_train_flowing)
|
||||||
|
os.makedirs(dir_train_flowing)
|
||||||
|
else:
|
||||||
|
os.makedirs(dir_train_flowing)
|
||||||
|
|
||||||
|
if os.path.isdir(dir_eval_flowing):
|
||||||
|
os.system('rm -rf ' + dir_eval_flowing)
|
||||||
|
os.makedirs(dir_eval_flowing)
|
||||||
|
else:
|
||||||
|
os.makedirs(dir_eval_flowing)
|
||||||
|
|
||||||
|
os.mkdir(dir_flow_train_imgs)
|
||||||
|
os.mkdir(dir_flow_train_labels)
|
||||||
|
|
||||||
|
os.mkdir(dir_flow_eval_imgs)
|
||||||
|
os.mkdir(dir_flow_eval_labels)
|
||||||
|
|
||||||
|
# set the gpu configuration
|
||||||
|
configuration()
|
||||||
|
|
||||||
|
imgs_list=np.array(os.listdir(dir_img))
|
||||||
|
segs_list=np.array(os.listdir(dir_seg))
|
||||||
|
|
||||||
|
imgs_list_test=np.array(os.listdir(dir_img_val))
|
||||||
|
segs_list_test=np.array(os.listdir(dir_seg_val))
|
||||||
|
|
||||||
|
# writing patches into a sub-folder in order to be flowed from directory.
|
||||||
|
provide_patches(imgs_list, segs_list, dir_img, dir_seg, dir_flow_train_imgs,
|
||||||
|
dir_flow_train_labels, input_height, input_width, blur_k,
|
||||||
|
blur_aug, padding_white, padding_black, flip_aug, binarization, adding_rgb_background,adding_rgb_foreground, add_red_textlines, channels_shuffling,
|
||||||
|
scaling, shifting, degrading, brightening, scales, degrade_scales, brightness,
|
||||||
|
flip_index,shuffle_indexes, scaling_bluring, scaling_brightness, scaling_binarization,
|
||||||
|
rotation, rotation_not_90, thetha, scaling_flip, task, augmentation=augmentation,
|
||||||
|
patches=patches, dir_img_bin=dir_img_bin,number_of_backgrounds_per_image=number_of_backgrounds_per_image,list_all_possible_background_images=list_all_possible_background_images, dir_rgb_backgrounds=dir_rgb_backgrounds, dir_rgb_foregrounds=dir_rgb_foregrounds,list_all_possible_foreground_rgbs=list_all_possible_foreground_rgbs)
|
||||||
|
|
||||||
|
provide_patches(imgs_list_test, segs_list_test, dir_img_val, dir_seg_val,
|
||||||
|
dir_flow_eval_imgs, dir_flow_eval_labels, input_height, input_width,
|
||||||
|
blur_k, blur_aug, padding_white, padding_black, flip_aug, binarization, adding_rgb_background, adding_rgb_foreground, add_red_textlines, channels_shuffling,
|
||||||
|
scaling, shifting, degrading, brightening, scales, degrade_scales, brightness,
|
||||||
|
flip_index, shuffle_indexes, scaling_bluring, scaling_brightness, scaling_binarization,
|
||||||
|
rotation, rotation_not_90, thetha, scaling_flip, task, augmentation=False, patches=patches,dir_img_bin=dir_img_bin,number_of_backgrounds_per_image=number_of_backgrounds_per_image,list_all_possible_background_images=list_all_possible_background_images, dir_rgb_backgrounds=dir_rgb_backgrounds,dir_rgb_foregrounds=dir_rgb_foregrounds,list_all_possible_foreground_rgbs=list_all_possible_foreground_rgbs )
|
||||||
|
|
||||||
|
if weighted_loss:
|
||||||
|
weights = np.zeros(n_classes)
|
||||||
|
if data_is_provided:
|
||||||
|
for obj in os.listdir(dir_flow_train_labels):
|
||||||
|
try:
|
||||||
|
label_obj = cv2.imread(dir_flow_train_labels + '/' + obj)
|
||||||
|
label_obj_one_hot = get_one_hot(label_obj, label_obj.shape[0], label_obj.shape[1], n_classes)
|
||||||
|
weights += (label_obj_one_hot.sum(axis=0)).sum(axis=0)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
|
||||||
|
for obj in os.listdir(dir_seg):
|
||||||
|
try:
|
||||||
|
label_obj = cv2.imread(dir_seg + '/' + obj)
|
||||||
|
label_obj_one_hot = get_one_hot(label_obj, label_obj.shape[0], label_obj.shape[1], n_classes)
|
||||||
|
weights += (label_obj_one_hot.sum(axis=0)).sum(axis=0)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
weights = 1.00 / weights
|
||||||
|
|
||||||
|
weights = weights / float(np.sum(weights))
|
||||||
|
weights = weights / float(np.min(weights))
|
||||||
|
weights = weights / float(np.sum(weights))
|
||||||
|
|
||||||
|
if continue_training:
|
||||||
|
if backbone_type=='nontransformer':
|
||||||
|
if is_loss_soft_dice and (task == "segmentation" or task == "binarization"):
|
||||||
|
model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss})
|
||||||
|
if weighted_loss and (task == "segmentation" or task == "binarization"):
|
||||||
|
model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
|
||||||
|
if not is_loss_soft_dice and not weighted_loss:
|
||||||
|
model = load_model(dir_of_start_model , compile=True)
|
||||||
|
elif backbone_type=='transformer':
|
||||||
|
if is_loss_soft_dice and (task == "segmentation" or task == "binarization"):
|
||||||
|
model = load_model(dir_of_start_model, compile=True, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches,'soft_dice_loss': soft_dice_loss})
|
||||||
|
if weighted_loss and (task == "segmentation" or task == "binarization"):
|
||||||
|
model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
|
||||||
|
if not is_loss_soft_dice and not weighted_loss:
|
||||||
|
model = load_model(dir_of_start_model , compile=True,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
|
||||||
|
else:
|
||||||
|
index_start = 0
|
||||||
|
if backbone_type=='nontransformer':
|
||||||
|
model = resnet50_unet(n_classes, input_height, input_width, task, weight_decay, pretraining)
|
||||||
|
elif backbone_type=='transformer':
|
||||||
|
num_patches_x = transformer_num_patches_xy[0]
|
||||||
|
num_patches_y = transformer_num_patches_xy[1]
|
||||||
|
num_patches = num_patches_x * num_patches_y
|
||||||
|
|
||||||
|
if transformer_cnn_first:
|
||||||
|
if input_height != (num_patches_y * transformer_patchsize_y * 32):
|
||||||
|
print("Error: transformer_patchsize_y or transformer_num_patches_xy height value error . input_height should be equal to ( transformer_num_patches_xy height value * transformer_patchsize_y * 32)")
|
||||||
|
sys.exit(1)
|
||||||
|
if input_width != (num_patches_x * transformer_patchsize_x * 32):
|
||||||
|
print("Error: transformer_patchsize_x or transformer_num_patches_xy width value error . input_width should be equal to ( transformer_num_patches_xy width value * transformer_patchsize_x * 32)")
|
||||||
|
sys.exit(1)
|
||||||
|
if (transformer_projection_dim % (transformer_patchsize_y * transformer_patchsize_x)) != 0:
|
||||||
|
print("Error: transformer_projection_dim error. The remainder when parameter transformer_projection_dim is divided by (transformer_patchsize_y*transformer_patchsize_x) should be zero")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
model = vit_resnet50_unet(n_classes, transformer_patchsize_x, transformer_patchsize_y, num_patches, transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_projection_dim, input_height, input_width, task, weight_decay, pretraining)
|
||||||
|
else:
|
||||||
|
if input_height != (num_patches_y * transformer_patchsize_y):
|
||||||
|
print("Error: transformer_patchsize_y or transformer_num_patches_xy height value error . input_height should be equal to ( transformer_num_patches_xy height value * transformer_patchsize_y)")
|
||||||
|
sys.exit(1)
|
||||||
|
if input_width != (num_patches_x * transformer_patchsize_x):
|
||||||
|
print("Error: transformer_patchsize_x or transformer_num_patches_xy width value error . input_width should be equal to ( transformer_num_patches_xy width value * transformer_patchsize_x)")
|
||||||
|
sys.exit(1)
|
||||||
|
if (transformer_projection_dim % (transformer_patchsize_y * transformer_patchsize_x)) != 0:
|
||||||
|
print("Error: transformer_projection_dim error. The remainder when parameter transformer_projection_dim is divided by (transformer_patchsize_y*transformer_patchsize_x) should be zero")
|
||||||
|
sys.exit(1)
|
||||||
|
model = vit_resnet50_unet_transformer_before_cnn(n_classes, transformer_patchsize_x, transformer_patchsize_y, num_patches, transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_projection_dim, input_height, input_width, task, weight_decay, pretraining)
|
||||||
|
|
||||||
|
#if you want to see the model structure just uncomment model summary.
|
||||||
|
model.summary()
|
||||||
|
|
||||||
|
|
||||||
|
if task == "segmentation" or task == "binarization":
|
||||||
|
if not is_loss_soft_dice and not weighted_loss:
|
||||||
|
model.compile(loss='categorical_crossentropy',
|
||||||
|
optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
|
||||||
|
if is_loss_soft_dice:
|
||||||
|
model.compile(loss=soft_dice_loss,
|
||||||
|
optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
|
||||||
|
if weighted_loss:
|
||||||
|
model.compile(loss=weighted_categorical_crossentropy(weights),
|
||||||
|
optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
|
||||||
|
elif task == "enhancement":
|
||||||
|
model.compile(loss='mean_squared_error',
|
||||||
|
optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
|
||||||
|
|
||||||
|
|
||||||
|
# generating train and evaluation data
|
||||||
|
train_gen = data_gen(dir_flow_train_imgs, dir_flow_train_labels, batch_size=n_batch,
|
||||||
|
input_height=input_height, input_width=input_width, n_classes=n_classes, task=task)
|
||||||
|
val_gen = data_gen(dir_flow_eval_imgs, dir_flow_eval_labels, batch_size=n_batch,
|
||||||
|
input_height=input_height, input_width=input_width, n_classes=n_classes, task=task)
|
||||||
|
|
||||||
|
##img_validation_patches = os.listdir(dir_flow_eval_imgs)
|
||||||
|
##score_best=[]
|
||||||
|
##score_best.append(0)
|
||||||
|
|
||||||
|
if save_interval:
|
||||||
|
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config)
|
||||||
|
|
||||||
|
|
||||||
|
for i in tqdm(range(index_start, n_epochs + index_start)):
|
||||||
|
if save_interval:
|
||||||
|
model.fit(
|
||||||
|
train_gen,
|
||||||
|
steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1,
|
||||||
|
validation_data=val_gen,
|
||||||
|
validation_steps=1,
|
||||||
|
epochs=1, callbacks=[save_weights_callback])
|
||||||
|
else:
|
||||||
|
model.fit(
|
||||||
|
train_gen,
|
||||||
|
steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1,
|
||||||
|
validation_data=val_gen,
|
||||||
|
validation_steps=1,
|
||||||
|
epochs=1)
|
||||||
|
|
||||||
|
model.save(os.path.join(dir_output,'model_'+str(i)))
|
||||||
|
|
||||||
|
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
|
||||||
|
json.dump(_config, fp) # encode dict into JSON
|
||||||
|
|
||||||
|
#os.system('rm -rf '+dir_train_flowing)
|
||||||
|
#os.system('rm -rf '+dir_eval_flowing)
|
||||||
|
|
||||||
|
#model.save(dir_output+'/'+'model'+'.h5')
|
||||||
|
elif task=='classification':
|
||||||
|
configuration()
|
||||||
|
model = resnet50_classifier(n_classes, input_height, input_width, weight_decay, pretraining)
|
||||||
|
|
||||||
|
opt_adam = Adam(learning_rate=0.001)
|
||||||
|
model.compile(loss='categorical_crossentropy',
|
||||||
|
optimizer = opt_adam,metrics=['accuracy'])
|
||||||
|
|
||||||
|
|
||||||
|
list_classes = list(classification_classes_name.values())
|
||||||
|
testX, testY = generate_data_from_folder_evaluation(dir_eval, input_height, input_width, n_classes, list_classes)
|
||||||
|
|
||||||
|
y_tot=np.zeros((testX.shape[0],n_classes))
|
||||||
|
|
||||||
|
score_best= [0]
|
||||||
|
|
||||||
|
num_rows = return_number_of_total_training_data(dir_train)
|
||||||
|
weights=[]
|
||||||
|
|
||||||
|
for i in range(n_epochs):
|
||||||
|
history = model.fit( generate_data_from_folder_training(dir_train, n_batch , input_height, input_width, n_classes, list_classes), steps_per_epoch=num_rows / n_batch, verbose=1)#,class_weight=weights)
|
||||||
|
|
||||||
|
y_pr_class = []
|
||||||
|
for jj in range(testY.shape[0]):
|
||||||
|
y_pr=model.predict(testX[jj,:,:,:].reshape(1,input_height,input_width,3), verbose=0)
|
||||||
|
y_pr_ind= np.argmax(y_pr,axis=1)
|
||||||
|
y_pr_class.append(y_pr_ind)
|
||||||
|
|
||||||
|
y_pr_class = np.array(y_pr_class)
|
||||||
|
f1score=f1_score(np.argmax(testY,axis=1), y_pr_class, average='macro')
|
||||||
|
print(i,f1score)
|
||||||
|
|
||||||
|
if f1score>score_best[0]:
|
||||||
|
score_best[0]=f1score
|
||||||
|
model.save(os.path.join(dir_output,'model_best'))
|
||||||
|
|
||||||
|
if f1score > f1_threshold_classification:
|
||||||
|
weights.append(model.get_weights() )
|
||||||
|
|
||||||
|
|
||||||
|
if len(weights) >= 1:
|
||||||
|
new_weights=list()
|
||||||
|
for weights_list_tuple in zip(*weights):
|
||||||
|
new_weights.append( [np.array(weights_).mean(axis=0) for weights_ in zip(*weights_list_tuple)] )
|
||||||
|
|
||||||
|
new_weights = [np.array(x) for x in new_weights]
|
||||||
|
model_weight_averaged=tf.keras.models.clone_model(model)
|
||||||
|
model_weight_averaged.set_weights(new_weights)
|
||||||
|
|
||||||
|
model_weight_averaged.save(os.path.join(dir_output,'model_ens_avg'))
|
||||||
|
with open(os.path.join( os.path.join(dir_output,'model_ens_avg'), "config.json"), "w") as fp:
|
||||||
|
json.dump(_config, fp) # encode dict into JSON
|
||||||
|
|
||||||
|
with open(os.path.join( os.path.join(dir_output,'model_best'), "config.json"), "w") as fp:
|
||||||
|
json.dump(_config, fp) # encode dict into JSON
|
||||||
|
|
||||||
|
elif task=='reading_order':
|
||||||
|
configuration()
|
||||||
|
model = machine_based_reading_order_model(n_classes,input_height,input_width,weight_decay,pretraining)
|
||||||
|
|
||||||
|
dir_flow_train_imgs = os.path.join(dir_train, 'images')
|
||||||
|
dir_flow_train_labels = os.path.join(dir_train, 'labels')
|
||||||
|
|
||||||
|
classes = os.listdir(dir_flow_train_labels)
|
||||||
|
if augmentation:
|
||||||
|
num_rows = len(classes)*(len(thetha) + 1)
|
||||||
|
else:
|
||||||
|
num_rows = len(classes)
|
||||||
|
#ls_test = os.listdir(dir_flow_train_labels)
|
||||||
|
|
||||||
|
#f1score_tot = [0]
|
||||||
|
indexer_start = 0
|
||||||
|
opt = SGD(learning_rate=0.01, momentum=0.9)
|
||||||
|
opt_adam = tf.keras.optimizers.Adam(learning_rate=0.0001)
|
||||||
|
model.compile(loss="binary_crossentropy",
|
||||||
|
optimizer = opt_adam,metrics=['accuracy'])
|
||||||
|
|
||||||
|
if save_interval:
|
||||||
|
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config)
|
||||||
|
|
||||||
|
for i in range(n_epochs):
|
||||||
|
if save_interval:
|
||||||
|
history = model.fit(generate_arrays_from_folder_reading_order(dir_flow_train_labels, dir_flow_train_imgs, n_batch, input_height, input_width, n_classes, thetha, augmentation), steps_per_epoch=num_rows / n_batch, verbose=1, callbacks=[save_weights_callback])
|
||||||
|
else:
|
||||||
|
history = model.fit(generate_arrays_from_folder_reading_order(dir_flow_train_labels, dir_flow_train_imgs, n_batch, input_height, input_width, n_classes, thetha, augmentation), steps_per_epoch=num_rows / n_batch, verbose=1)
|
||||||
|
model.save( os.path.join(dir_output,'model_'+str(i+indexer_start) ))
|
||||||
|
|
||||||
|
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
|
||||||
|
json.dump(_config, fp) # encode dict into JSON
|
||||||
|
'''
|
||||||
|
if f1score>f1score_tot[0]:
|
||||||
|
f1score_tot[0] = f1score
|
||||||
|
model_dir = os.path.join(dir_out,'model_best')
|
||||||
|
model.save(model_dir)
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
1056
train/utils.py
Normal file
1056
train/utils.py
Normal file
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue