mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-03-24 08:02:45 +01:00
Merge 6d55f297a5 into c9f6aa35b2
This commit is contained in:
commit
876467b78d
37 changed files with 4640 additions and 5503 deletions
283
docs/train.md
283
docs/train.md
|
|
@ -47,9 +47,9 @@ on how to generate the corresponding training dataset.
|
|||
The following three tasks can all be accomplished using the code in the
|
||||
[`train`](https://github.com/qurator-spk/eynollah/tree/main/train) directory:
|
||||
|
||||
* generate training dataset
|
||||
* train a model
|
||||
* inference with the trained model
|
||||
* [Generate training dataset](#generate-training-dataset)
|
||||
* [Train a model](#train-a-model)
|
||||
* [Inference with the trained model](#inference-with-the-trained-model)
|
||||
|
||||
## Training, evaluation and output
|
||||
|
||||
|
|
@ -101,7 +101,7 @@ serve as labels. The enhancement model can be trained with this generated datase
|
|||
For machine-based reading order, we aim to determine the reading priority between two sets of text regions. The model's
|
||||
input is a three-channel image: the first and last channels contain information about each of the two text regions,
|
||||
while the middle channel encodes prominent layout elements necessary for reading order, such as separators and headers.
|
||||
To generate the training dataset, our script requires a page XML file that specifies the image layout with the correct
|
||||
To generate the training dataset, our script requires a PAGE XML file that specifies the image layout with the correct
|
||||
reading order.
|
||||
|
||||
For output images, it is necessary to specify the width and height. Additionally, a minimum text region size can be set
|
||||
|
|
@ -120,8 +120,14 @@ eynollah-training generate-gt machine-based-reading-order \
|
|||
|
||||
### pagexml2label
|
||||
|
||||
pagexml2label is designed to generate labels from GT page XML files for various pixel-wise segmentation use cases,
|
||||
including 'layout,' 'textline,' 'printspace,' 'glyph,' and 'word' segmentation.
|
||||
`pagexml2label` is designed to generate labels from PAGE XML GT files for various pixel-wise segmentation use cases,
|
||||
including:
|
||||
- `printspace` (i.e. page frame),
|
||||
- `layout` (i.e. regions),
|
||||
- `textline`,
|
||||
- `word`, and
|
||||
- `glyph`.
|
||||
|
||||
To train a pixel-wise segmentation model, we require images along with their corresponding labels. Our training script
|
||||
expects a PNG image where each pixel corresponds to a label, represented by an integer. The background is always labeled
|
||||
as zero, while other elements are assigned different integers. For instance, if we have ground truth data with four
|
||||
|
|
@ -131,7 +137,7 @@ In binary segmentation scenarios such as textline or page extraction, the backgr
|
|||
element is automatically encoded as 1 in the PNG label.
|
||||
|
||||
To specify the desired use case and the elements to be extracted in the PNG labels, a custom JSON file can be passed.
|
||||
For example, in the case of 'textline' detection, the JSON file would resemble this:
|
||||
For example, in the case of textline detection, the JSON contents could be this:
|
||||
|
||||
```yaml
|
||||
{
|
||||
|
|
@ -139,61 +145,77 @@ For example, in the case of 'textline' detection, the JSON file would resemble t
|
|||
}
|
||||
```
|
||||
|
||||
In the case of layout segmentation a custom config json file can look like this:
|
||||
In the case of layout segmentation, the config JSON file might look like this:
|
||||
|
||||
```yaml
|
||||
{
|
||||
"use_case": "layout",
|
||||
"textregions":{"rest_as_paragraph":1 , "drop-capital": 1, "header":2, "heading":2, "marginalia":3},
|
||||
"imageregion":4,
|
||||
"separatorregion":5,
|
||||
"graphicregions" :{"rest_as_decoration":6 ,"stamp":7}
|
||||
"textregions": {"rest_as_paragraph": 1, "drop-capital": 1, "header": 2, "heading": 2, "marginalia": 3},
|
||||
"imageregion": 4,
|
||||
"separatorregion": 5,
|
||||
"graphicregions": {"rest_as_decoration": 6, "stamp": 7}
|
||||
}
|
||||
```
|
||||
|
||||
A possible custom config json file for layout segmentation where the "printspace" is a class:
|
||||
The same example if `PrintSpace` (or `Border`) should be represented as a unique class:
|
||||
|
||||
```yaml
|
||||
{
|
||||
"use_case": "layout",
|
||||
"textregions":{"rest_as_paragraph":1 , "drop-capital": 1, "header":2, "heading":2, "marginalia":3},
|
||||
"imageregion":4,
|
||||
"separatorregion":5,
|
||||
"graphicregions" :{"rest_as_decoration":6 ,"stamp":7}
|
||||
"printspace_as_class_in_layout" : 8
|
||||
"textregions": {"rest_as_paragraph": 1, "drop-capital": 1, "header": 2, "heading": 2, "marginalia": 3},
|
||||
"imageregion": 4,
|
||||
"separatorregion": 5,
|
||||
"graphicregions": {"rest_as_decoration": 6, "stamp": 7}
|
||||
"printspace_as_class_in_layout": 8
|
||||
}
|
||||
```
|
||||
|
||||
For the layout use case, it is beneficial to first understand the structure of the page XML file and its elements.
|
||||
In a given image, the annotations of elements are recorded in a page XML file, including their contours and classes.
|
||||
For an image document, the known regions are 'textregion', 'separatorregion', 'imageregion', 'graphicregion',
|
||||
'noiseregion', and 'tableregion'.
|
||||
In the `layout` use-case, it is beneficial to first understand the structure of the PAGE XML file and its elements.
|
||||
For a given page image, the visible segments are annotated in XML with their polygon coordinates and types.
|
||||
On the region level, available segment types include `TextRegion`, `SeparatorRegion`, `ImageRegion`, `GraphicRegion`,
|
||||
`NoiseRegion` and `TableRegion`.
|
||||
|
||||
Text regions and graphic regions also have their own specific types. The known types for text regions are 'paragraph',
|
||||
'header', 'heading', 'marginalia', 'drop-capital', 'footnote', 'footnote-continued', 'signature-mark', 'page-number',
|
||||
and 'catch-word'. The known types for graphic regions are 'handwritten-annotation', 'decoration', 'stamp', and
|
||||
'signature'.
|
||||
Since we don't know all types of text and graphic regions, unknown cases can arise. To handle these, we have defined
|
||||
two additional types, "rest_as_paragraph" and "rest_as_decoration", to ensure that no unknown types are missed.
|
||||
This way, users can extract all known types from the labels and be confident that no unknown types are overlooked.
|
||||
Moreover, text regions and graphic regions in particular are subdivided via `@type`:
|
||||
- The allowed subtypes for text regions are `paragraph`, `heading`, `marginalia`, `drop-capital`, `header`, `footnote`,
|
||||
`footnote-continued`, `signature-mark`, `page-number` and `catch-word`.
|
||||
- The known subtypes for graphic regions are `handwritten-annotation`, `decoration`, `stamp` and `signature`.
|
||||
|
||||
In the custom JSON file shown above, "header" and "heading" are extracted as the same class, while "marginalia" is shown
|
||||
as a different class. All other text region types, including "drop-capital," are grouped into the same class. For the
|
||||
graphic region, "stamp" has its own class, while all other types are classified together. "Image region" and "separator
|
||||
region" are also present in the label. However, other regions like "noise region" and "table region" will not be
|
||||
included in the label PNG file, even if they have information in the page XML files, as we chose not to include them.
|
||||
These types and subtypes must be mapped to classes for the segmentation model. However, sometimes these fine-grained
|
||||
distinctions are not useful or the existing annotations are not very usable (too scarce or too unreliable).
|
||||
In that case, instead of these subtypes with a specific mapping, they can be pooled together by using the two special
|
||||
types:
|
||||
- `rest_as_paragraph` (mapping missing TextRegion subtypes and `paragraph`)
|
||||
- `rest_as_decoration` (mapping missing GraphicRegion subtypes and `decoration`)
|
||||
|
||||
(That way, users can extract all known types from the labels and be confident that no subtypes are overlooked.)
|
||||
|
||||
In the custom JSON example shown above, `header` and `heading` are extracted as the same class,
|
||||
while `marginalia` is modelled as a different class. All other text region types, including `drop-capital`,
|
||||
are grouped into the same class. For graphic regions, `stamp` has its own class, while all other types
|
||||
are classified together. `ImageRegion` and `SeparatorRegion` will also represented with a class label in the
|
||||
training data. However, other regions like `NoiseRegion` or `TableRegion` will not be included in the PNG files,
|
||||
even if they were present in the PAGE XML.
|
||||
|
||||
The tool expects various command-line options:
|
||||
|
||||
```sh
|
||||
eynollah-training generate-gt pagexml2label \
|
||||
-dx "dir of GT xml files" \
|
||||
-do "dir where output label png files will be written" \
|
||||
-cfg "custom config json file" \
|
||||
-to "output type which has 2d and 3d. 2d is used for training and 3d is just to visualise the labels"
|
||||
-dx "dir of input PAGE XML files" \
|
||||
-do "dir of output label PNG files" \
|
||||
-cfg "custom config JSON file" \
|
||||
-to "output type (2d or 3d)"
|
||||
```
|
||||
|
||||
We have also defined an artificial class that can be added to the boundary of text region types or text lines. This key
|
||||
is called "artificial_class_on_boundary." If users want to apply this to certain text regions in the layout use case,
|
||||
the example JSON config file should look like this:
|
||||
As output type, use
|
||||
- `2d` for training,
|
||||
- `3d` to just visualise the labels.
|
||||
|
||||
We have also defined an artificial class that can be added to (rendered around) the boundary
|
||||
of text region types or text lines in order to make separation of neighbouring segments more
|
||||
reliable. The key is called `artificial_class_on_boundary`, and it takes a list of text region
|
||||
types to be applied to.
|
||||
|
||||
Our example JSON config file could then look like this:
|
||||
|
||||
```yaml
|
||||
{
|
||||
|
|
@ -215,14 +237,15 @@ the example JSON config file should look like this:
|
|||
}
|
||||
```
|
||||
|
||||
This implies that the artificial class label, denoted by 7, will be present on PNG files and will only be added to the
|
||||
elements labeled as "paragraph," "header," "heading," and "marginalia."
|
||||
This implies that the artificial class label (denoted by 7) will be present in the generated PNG files
|
||||
and will only be added around segments labeled `paragraph`, `header`, `heading` or `marginalia`. (This
|
||||
class will be handled specially during decoding at inference, and not show up in final results.)
|
||||
|
||||
For "textline", "word", and "glyph", the artificial class on the boundaries will be activated only if the
|
||||
"artificial_class_label" key is specified in the config file. Its value should be set as 2 since these elements
|
||||
represent binary cases. For example, if the background and textline are denoted as 0 and 1 respectively, then the
|
||||
artificial class should be assigned the value 2. The example JSON config file should look like this for "textline" use
|
||||
case:
|
||||
For `printspace`, `textline`, `word`, and `glyph` segmentation use-cases, there is no `artificial_class_on_boundary` key,
|
||||
but `artificial_class_label` is available. If specified in the config file, then its value should be set at 2, because
|
||||
these elements represent binary classification problems (with background represented as 0, and segments as 1, respectively).
|
||||
|
||||
For example, the JSON config for textline detection could look as follows:
|
||||
|
||||
```yaml
|
||||
{
|
||||
|
|
@ -231,33 +254,38 @@ case:
|
|||
}
|
||||
```
|
||||
|
||||
If the coordinates of "PrintSpace" or "Border" are present in the page XML ground truth files, and the user wishes to
|
||||
crop only the print space area, this can be achieved by activating the "-ps" argument. However, it should be noted that
|
||||
in this scenario, since cropping will be applied to the label files, the directory of the original images must be
|
||||
provided to ensure that they are cropped in sync with the labels. This ensures that the correct images and labels
|
||||
required for training are obtained. The command should resemble the following:
|
||||
If the coordinates of `PrintSpace` (or `Border`) are present in the PAGE XML ground truth files,
|
||||
and one wishes to crop images to only cover the print space bounding box, this can be achieved
|
||||
by passing the `-ps` option. Note that in this scenario, the directory of the original images
|
||||
must also be provided, to ensure that the images are cropped in sync with the labels. The command
|
||||
line would then resemble this:
|
||||
|
||||
```sh
|
||||
eynollah-training generate-gt pagexml2label \
|
||||
-dx "dir of GT xml files" \
|
||||
-do "dir where output label png files will be written" \
|
||||
-cfg "custom config json file" \
|
||||
-to "output type which has 2d and 3d. 2d is used for training and 3d is just to visualise the labels" \
|
||||
-dx "dir of input PAGE XML files" \
|
||||
-do "dir of output label PNG files" \
|
||||
-cfg "custom config JSON file" \
|
||||
-to "output type (2d or 3d)" \
|
||||
-ps \
|
||||
-di "dir where the org images are located" \
|
||||
-doi "dir where the cropped output images will be written"
|
||||
-di "dir of input original images" \
|
||||
-doi "dir of output cropped images"
|
||||
```
|
||||
|
||||
Also, note that it can be detrimental to layout training if there are visible segments which
|
||||
the annotation does not account for (and thus the model must learn to ignore). So if the images
|
||||
are not cropped, the `-ps` _should_ be used. If a PAGE XML file is missing `PrintSpace` (or `Border`)
|
||||
annotations, use `-mps` to either `skip` these or `project` (i.e. crop from existing segments).
|
||||
|
||||
## Train a model
|
||||
|
||||
### classification
|
||||
|
||||
For the classification use case, we haven't provided a ground truth generator, as it's unnecessary. For classification,
|
||||
all we require is a training directory with subdirectories, each containing images of its respective classes. We need
|
||||
For the image classification use-case, we have not provided a ground truth generator, as it is unnecessary.
|
||||
All we require is a training directory with subdirectories, each containing images of its respective classes. We need
|
||||
separate directories for training and evaluation, and the class names (subdirectories) must be consistent across both
|
||||
directories. Additionally, the class names should be specified in the config JSON file, as shown in the following
|
||||
example. If, for instance, we aim to classify "apple" and "orange," with a total of 2 classes, the
|
||||
"classification_classes_name" key in the config file should appear as follows:
|
||||
`classification_classes_name` key in the config file should appear as follows:
|
||||
|
||||
```yaml
|
||||
{
|
||||
|
|
@ -279,7 +307,7 @@ example. If, for instance, we aim to classify "apple" and "orange," with a total
|
|||
}
|
||||
```
|
||||
|
||||
The "dir_train" should be like this:
|
||||
Then `dir_train` should be like this:
|
||||
|
||||
```
|
||||
.
|
||||
|
|
@ -288,7 +316,7 @@ The "dir_train" should be like this:
|
|||
└── orange # directory of images for orange class
|
||||
```
|
||||
|
||||
And the "dir_eval" the same structure as train directory:
|
||||
And `dir_eval` analogously:
|
||||
|
||||
```
|
||||
.
|
||||
|
|
@ -348,7 +376,7 @@ And the "dir_eval" the same structure as train directory:
|
|||
└── labels # directory of labels
|
||||
```
|
||||
|
||||
The classification model can be trained like the classification case command line.
|
||||
The reading-order model can be trained like the classification case command line.
|
||||
|
||||
### Segmentation (Textline, Binarization, Page extraction and layout) and enhancement
|
||||
|
||||
|
|
@ -358,51 +386,17 @@ The following parameter configuration can be applied to all segmentation use cas
|
|||
its sub-parameters, and continued training are defined only for segmentation use cases and enhancements, not for
|
||||
classification and machine-based reading order, as you can see in their example config files.
|
||||
|
||||
* `backbone_type`: For segmentation tasks (such as text line, binarization, and layout detection) and enhancement, we
|
||||
offer two backbone options: a "nontransformer" and a "transformer" backbone. For the "transformer" backbone, we first
|
||||
apply a CNN followed by a transformer. In contrast, the "nontransformer" backbone utilizes only a CNN ResNet-50.
|
||||
* `task`: The task parameter can have values such as "segmentation", "enhancement", "classification", and "reading_order".
|
||||
* `patches`: If you want to break input images into smaller patches (input size of the model) you need to set this
|
||||
* parameter to `true`. In the case that the model should see the image once, like page extraction, patches should be
|
||||
set to ``false``.
|
||||
* `n_batch`: Number of batches at each iteration.
|
||||
* `n_classes`: Number of classes. In the case of binary classification this should be 2. In the case of reading_order it
|
||||
should set to 1. And for the case of layout detection just the unique number of classes should be given.
|
||||
* `n_epochs`: Number of epochs.
|
||||
* `input_height`: This indicates the height of model's input.
|
||||
* `input_width`: This indicates the width of model's input.
|
||||
* `weight_decay`: Weight decay of l2 regularization of model layers.
|
||||
* `pretraining`: Set to `true` to load pretrained weights of ResNet50 encoder. The downloaded weights should be saved
|
||||
in a folder named "pretrained_model" in the same directory of "train.py" script.
|
||||
* `augmentation`: If you want to apply any kind of augmentation this parameter should first set to `true`.
|
||||
* `flip_aug`: If `true`, different types of filp will be applied on image. Type of flips is given with "flip_index" parameter.
|
||||
* `blur_aug`: If `true`, different types of blurring will be applied on image. Type of blurrings is given with "blur_k" parameter.
|
||||
* `scaling`: If `true`, scaling will be applied on image. Scale of scaling is given with "scales" parameter.
|
||||
* `degrading`: If `true`, degrading will be applied to the image. The amount of degrading is defined with "degrade_scales" parameter.
|
||||
* `brightening`: If `true`, brightening will be applied to the image. The amount of brightening is defined with "brightness" parameter.
|
||||
* `rotation_not_90`: If `true`, rotation (not 90 degree) will be applied on image. Rotation angles are given with "thetha" parameter.
|
||||
* `rotation`: If `true`, 90 degree rotation will be applied on image.
|
||||
* `binarization`: If `true`,Otsu thresholding will be applied to augment the input data with binarized images.
|
||||
* `scaling_bluring`: If `true`, combination of scaling and blurring will be applied on image.
|
||||
* `scaling_binarization`: If `true`, combination of scaling and binarization will be applied on image.
|
||||
* `scaling_flip`: If `true`, combination of scaling and flip will be applied on image.
|
||||
* `flip_index`: Type of flips.
|
||||
* `blur_k`: Type of blurrings.
|
||||
* `scales`: Scales of scaling.
|
||||
* `brightness`: The amount of brightenings.
|
||||
* `thetha`: Rotation angles.
|
||||
* `degrade_scales`: The amount of degradings.
|
||||
* `continue_training`: If `true`, it means that you have already trained a model and you would like to continue the
|
||||
training. So it is needed to providethe dir of trained model with "dir_of_start_model" and index for naming
|
||||
themodels. For example if you have already trained for 3 epochs then your lastindex is 2 and if you want to continue
|
||||
from model_1.h5, you can set `index_start` to 3 to start naming model with index 3.
|
||||
* `weighted_loss`: If `true`, this means that you want to apply weighted categorical_crossentropy as loss fucntion. Be carefull if you set to `true`the parameter "is_loss_soft_dice" should be ``false``
|
||||
* `data_is_provided`: If you have already provided the input data you can set this to `true`. Be sure that the train
|
||||
and eval data are in"dir_output".Since when once we provide training data we resize and augmentthem and then wewrite
|
||||
them in sub-directories train and eval in "dir_output".
|
||||
* `dir_train`: This is the directory of "images" and "labels" (dir_train should include two subdirectories with names of images and labels ) for raw images and labels. Namely they are not prepared (not resized and not augmented) yet for training the model. When we run this tool these raw data will be transformed to suitable size needed for the model and they will be written in "dir_output" in train and eval directories. Each of train and eval include "images" and "labels" sub-directories.
|
||||
* `index_start`: Starting index for saved models in the case that "continue_training" is `true`.
|
||||
* `dir_of_start_model`: Directory containing pretrained model to continue training the model in the case that "continue_training" is `true`.
|
||||
* `task`: The task parameter must be one of the following values:
|
||||
- `binarization`,
|
||||
- `enhancement`,
|
||||
- `segmentation`,
|
||||
- `classification`,
|
||||
- `reading_order`.
|
||||
* `backbone_type`: For the tasks `segmentation` (such as text line, and region layout detection),
|
||||
`binarization` and `enhancement`, we offer two backbone options:
|
||||
- `nontransformer` (only a CNN ResNet-50).
|
||||
- `transformer` (first apply a CNN, followed by a transformer)
|
||||
* `transformer_cnn_first`: Whether to apply the CNN first (followed by the transformer) when using `transformer` backbone.
|
||||
* `transformer_num_patches_xy`: Number of patches for vision transformer in x and y direction respectively.
|
||||
* `transformer_patchsize_x`: Patch size of vision transformer patches in x direction.
|
||||
* `transformer_patchsize_y`: Patch size of vision transformer patches in y direction.
|
||||
|
|
@ -410,11 +404,63 @@ classification and machine-based reading order, as you can see in their example
|
|||
* `transformer_mlp_head_units`: Transformer Multilayer Perceptron (MLP) head units. Default value is [128, 64].
|
||||
* `transformer_layers`: transformer layers. Default value is 8.
|
||||
* `transformer_num_heads`: Transformer number of heads. Default value is 4.
|
||||
* `transformer_cnn_first`: We have two types of vision transformers. In one type, a CNN is applied first, followed by a transformer. In the other type, this order is reversed. If transformer_cnn_first is true, it means the CNN will be applied before the transformer. Default value is true.
|
||||
* `patches`: Whether to break up (tile) input images into smaller patches (input size of the model).
|
||||
If `false`, the model will see the image once (resized to the input size of the model).
|
||||
Should be set to `false` for cases like page extraction.
|
||||
* `n_batch`: Number of batches at each iteration.
|
||||
* `n_classes`: Number of classes. In the case of binary classification this should be 2. In the case of reading_order it
|
||||
should set to 1. And for the case of layout detection just the unique number of classes should be given.
|
||||
* `n_epochs`: Number of epochs (iterations over the data) to train.
|
||||
* `input_height`: the image height for the model's input.
|
||||
* `input_width`: the image width for the model's input.
|
||||
* `weight_decay`: Weight decay of l2 regularization of model layers.
|
||||
* `weighted_loss`: If `true`, this means that you want to apply weighted categorical crossentropy as loss function.
|
||||
(Mutually exclusive with `is_loss_soft_dice`, and only applies for `segmentation` and `binarization` tasks.)
|
||||
* `pretraining`: Set to `true` to (download and) initialise pretrained weights of ResNet50 encoder.
|
||||
* `dir_train`: Path to directory of raw training data (as extracted via `pagexml2labels`, i.e. with subdirectories
|
||||
`images` and `labels` for input images and output labels.
|
||||
(These are not prepared for training the model, yet. Upon first run, the raw data will be transformed to suitable size
|
||||
needed for the model, and written in `dir_output` under `train` and `eval` subdirectories. See `data_is_provided`.)
|
||||
* `dir_eval`: Ditto for raw evaluation data.
|
||||
* `dir_output`: Directory to write model checkpoints, logs (for Tensorboard) and precomputed images to.
|
||||
* `data_is_provided`: If you have already trained at least one complete epoch (using the same data settings) before,
|
||||
you can set this to `true` to avoid computing the resized / patched / augmented image files again.
|
||||
Be sure that there are subdirectories `train` and `eval` data are in `dir_output` (each with subdirectories `images`
|
||||
and `labels`, respectively).
|
||||
* `continue_training`: If `true`, continue training a model checkpoint from a previous run.
|
||||
This requires providing the directory of the model checkpoint to load via `dir_of_start_model`
|
||||
and setting `index_start` counter for naming new checkpoints.
|
||||
For example if you have already trained for 3 epochs, then your last index is 2, so if you want
|
||||
to continue with `model_04`, `model_05` etc., set `index_start=3`.
|
||||
* `index_start`: Starting index for saving models in the case that `continue_training` is `true`.
|
||||
(Existing checkpoints above this will be overwritten.)
|
||||
* `dir_of_start_model`: Directory containing existing model checkpoint to initialise model weights from when `continue_training=true`.
|
||||
(Can be an epoch-interval checkpoint, or batch-interval checkpoint from `save_interval`.)
|
||||
* `augmentation`: If you want to apply any kind of augmentation this parameter should first set to `true`.
|
||||
The remaining settings pertain to that...
|
||||
* `flip_aug`: If `true`, different types of flipping over the image arrays. Requires `flip_index` parameter.
|
||||
* `flip_index`: List of flip codes (as in `cv2.flip`, i.e. 0 for vertical, positive for horizontal shift, negative for vertical and horizontal shift).
|
||||
* `blur_aug`: If `true`, different types of blurring will be applied on image. Requires `blur_k` parameter.
|
||||
* `blur_k`: Method of blurring (`gauss`, `median` or `blur`).
|
||||
* `scaling`: If `true`, scaling will be applied on image. Requires `scales` parameter.
|
||||
* `scales`: List of scale factors for scaling.
|
||||
* `scaling_bluring`: If `true`, combination of scaling and blurring will be applied on image.
|
||||
* `scaling_binarization`: If `true`, combination of scaling and binarization will be applied on image.
|
||||
* `scaling_flip`: If `true`, combination of scaling and flip will be applied on image.
|
||||
* `degrading`: If `true`, degrading will be applied to the image. Requires `degrade_scales` parameter.
|
||||
* `degrade_scales`: List of intensity factors for degrading.
|
||||
* `brightening`: If `true`, brightening will be applied to the image. Requires `brightness` parameter.
|
||||
* `brightness`: List of intensity factors for brightening.
|
||||
* `binarization`: If `true`, Otsu thresholding will be applied to augment the input data with binarized images.
|
||||
* `dir_img_bin`: With `binarization`, use this directory to read precomputed binarized images instead of ad-hoc Otsu.
|
||||
(Base names should correspond to the files in `dir_train/images`.)
|
||||
* `rotation`: If `true`, 90° rotation will be applied on images.
|
||||
* `rotation_not_90`: If `true`, random rotation (other than 90°) will be applied on image. Requires `thetha` parameter.
|
||||
* `thetha`: List of rotation angles (in degrees).
|
||||
|
||||
In the case of segmentation and enhancement the train and evaluation directory should be as following.
|
||||
In case of segmentation and enhancement the train and evaluation data should be organised as follows.
|
||||
|
||||
The "dir_train" should be like this:
|
||||
The "dir_train" directory should be like this:
|
||||
|
||||
```
|
||||
.
|
||||
|
|
@ -432,11 +478,12 @@ And the "dir_eval" the same structure as train directory:
|
|||
└── labels # directory of labels
|
||||
```
|
||||
|
||||
After configuring the JSON file for segmentation or enhancement, training can be initiated by running the following
|
||||
command, similar to the process for classification and reading order:
|
||||
After configuring the JSON file for segmentation or enhancement,
|
||||
training can be initiated by running the following command line,
|
||||
similar to classification and reading-order model training:
|
||||
|
||||
```
|
||||
eynollah-training train with config_classification.json`
|
||||
```sh
|
||||
eynollah-training train with config_classification.json
|
||||
```
|
||||
|
||||
#### Binarization
|
||||
|
|
@ -728,7 +775,7 @@ This will straightforwardly return the class of the image.
|
|||
|
||||
### machine based reading order
|
||||
|
||||
To infer the reading order using a reading order model, we need a page XML file containing layout information but
|
||||
To infer the reading order using a reading order model, we need a PAGE XML file containing layout information but
|
||||
without the reading order. We simply need to provide the model directory, the XML file, and the output directory. The
|
||||
new XML file with the added reading order will be written to the output directory with the same name. We need to run:
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
# ocrd includes opencv, numpy, shapely, click
|
||||
ocrd >= 3.3.0
|
||||
numpy <1.24.0
|
||||
numpy < 2.0
|
||||
scikit-learn >= 0.23.2
|
||||
tensorflow < 2.13
|
||||
tensorflow
|
||||
tf-keras # avoid keras 3 (also needs TF_USE_LEGACY_KERAS=1)
|
||||
numba <= 0.58.1
|
||||
scikit-image
|
||||
biopython
|
||||
|
|
|
|||
|
|
@ -2,14 +2,12 @@
|
|||
# this must be the first import of the CLI!
|
||||
from ..eynollah_imports import imported_libs
|
||||
|
||||
from .cli_models import models_cli
|
||||
from .cli_binarize import binarize_cli
|
||||
|
||||
from .cli import main
|
||||
from .cli_binarize import binarize_cli
|
||||
from .cli_enhance import enhance_cli
|
||||
from .cli_extract_images import extract_images_cli
|
||||
from .cli_layout import layout_cli
|
||||
from .cli_models import models_cli
|
||||
from .cli_ocr import ocr_cli
|
||||
from .cli_readingorder import readingorder_cli
|
||||
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ class EynollahCliCtx:
|
|||
# NOTE: not mandatory to exist so --help for subcommands works but will log a warning
|
||||
# and raise exception when trying to load models in the CLI
|
||||
# type=click.Path(exists=True),
|
||||
default=f'{os.getcwd()}/models_eynollah',
|
||||
default=os.getcwd(),
|
||||
)
|
||||
@click.option(
|
||||
"--model-overrides",
|
||||
|
|
|
|||
|
|
@ -21,13 +21,20 @@ import click
|
|||
type=click.Path(file_okay=True, dir_okay=True),
|
||||
required=True,
|
||||
)
|
||||
@click.option(
|
||||
"--overwrite",
|
||||
"-O",
|
||||
help="overwrite (instead of skipping) if output xml exists",
|
||||
is_flag=True,
|
||||
)
|
||||
@click.pass_context
|
||||
def binarize_cli(
|
||||
ctx,
|
||||
patches,
|
||||
input_image,
|
||||
dir_in,
|
||||
output,
|
||||
ctx,
|
||||
patches,
|
||||
input_image,
|
||||
dir_in,
|
||||
output,
|
||||
overwrite,
|
||||
):
|
||||
"""
|
||||
Binarize images with a ML model
|
||||
|
|
@ -39,6 +46,7 @@ def binarize_cli(
|
|||
image_path=input_image,
|
||||
use_patches=patches,
|
||||
output=output,
|
||||
dir_in=dir_in
|
||||
dir_in=dir_in,
|
||||
overwrite=overwrite
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -187,7 +187,6 @@ def layout_cli(
|
|||
assert enable_plotting or not save_all, "Plotting with -sa also requires -ep"
|
||||
assert enable_plotting or not save_page, "Plotting with -sp also requires -ep"
|
||||
assert enable_plotting or not save_images, "Plotting with -si also requires -ep"
|
||||
assert enable_plotting or not allow_enhancement, "Plotting with -ae also requires -ep"
|
||||
assert not enable_plotting or save_layout or save_deskewed or save_all or save_page or save_images or allow_enhancement, \
|
||||
"Plotting with -ep also requires -sl, -sd, -sa, -sp, -si or -ae"
|
||||
assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
|
||||
|
|
|
|||
|
|
@ -116,19 +116,19 @@ class EynollahImageExtractor(Eynollah):
|
|||
prediction_regions_org = prediction_regions_org[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]]
|
||||
prediction_regions_org=prediction_regions_org[:,:,0]
|
||||
|
||||
mask_lines_only = (prediction_regions_org[:,:] ==3)*1
|
||||
mask_seps_only = (prediction_regions_org[:,:] ==3)*1
|
||||
mask_texts_only = (prediction_regions_org[:,:] ==1)*1
|
||||
mask_images_only=(prediction_regions_org[:,:] ==2)*1
|
||||
|
||||
polygons_seplines, hir_seplines = return_contours_of_image(mask_lines_only)
|
||||
polygons_seplines, hir_seplines = return_contours_of_image(mask_seps_only)
|
||||
polygons_seplines = filter_contours_area_of_image(
|
||||
mask_lines_only, polygons_seplines, hir_seplines, max_area=1, min_area=0.00001, dilate=1)
|
||||
mask_seps_only, polygons_seplines, hir_seplines, max_area=1, min_area=0.00001, dilate=1)
|
||||
|
||||
polygons_of_only_texts = return_contours_of_interested_region(mask_texts_only,1,0.00001)
|
||||
polygons_of_only_lines = return_contours_of_interested_region(mask_lines_only,1,0.00001)
|
||||
polygons_of_only_seps = return_contours_of_interested_region(mask_seps_only,1,0.00001)
|
||||
|
||||
text_regions_p_true = np.zeros(prediction_regions_org.shape)
|
||||
text_regions_p_true = cv2.fillPoly(text_regions_p_true, pts = polygons_of_only_lines, color=(3,3,3))
|
||||
text_regions_p_true = cv2.fillPoly(text_regions_p_true, pts = polygons_of_only_seps, color=(3,3,3))
|
||||
|
||||
text_regions_p_true[:,:][mask_images_only[:,:] == 1] = 2
|
||||
text_regions_p_true = cv2.fillPoly(text_regions_p_true, pts=polygons_of_only_texts, color=(1,1,1))
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -1,6 +1,9 @@
|
|||
"""
|
||||
Load libraries with possible race conditions once. This must be imported as the first module of eynollah.
|
||||
"""
|
||||
import os
|
||||
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
||||
|
||||
from ocrd_utils import tf_disable_interactive_logs
|
||||
from torch import *
|
||||
tf_disable_interactive_logs()
|
||||
|
|
|
|||
|
|
@ -15,11 +15,13 @@ from pathlib import Path
|
|||
import gc
|
||||
|
||||
import cv2
|
||||
from keras.models import Model
|
||||
import numpy as np
|
||||
import tensorflow as tf # type: ignore
|
||||
from skimage.morphology import skeletonize
|
||||
|
||||
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
||||
import tensorflow as tf # type: ignore
|
||||
from tensorflow.keras.models import Model
|
||||
|
||||
from .model_zoo import EynollahModelZoo
|
||||
from .utils.resize import resize_image
|
||||
from .utils.pil_cv2 import pil2cv
|
||||
|
|
|
|||
|
|
@ -14,10 +14,12 @@ from pathlib import Path
|
|||
import xml.etree.ElementTree as ET
|
||||
|
||||
import cv2
|
||||
from keras.models import Model
|
||||
import numpy as np
|
||||
import statistics
|
||||
|
||||
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras.models import Model
|
||||
|
||||
from .model_zoo import EynollahModelZoo
|
||||
from .utils.resize import resize_image
|
||||
|
|
|
|||
|
|
@ -1,17 +1,25 @@
|
|||
import os
|
||||
import json
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
||||
from ocrd_utils import tf_disable_interactive_logs
|
||||
tf_disable_interactive_logs()
|
||||
|
||||
from keras.layers import StringLookup
|
||||
from keras.models import Model as KerasModel
|
||||
from keras.models import load_model
|
||||
from tensorflow.keras.layers import StringLookup
|
||||
from tensorflow.keras.models import Model as KerasModel
|
||||
from tensorflow.keras.models import load_model
|
||||
from tabulate import tabulate
|
||||
from ..patch_encoder import PatchEncoder, Patches
|
||||
|
||||
from ..patch_encoder import (
|
||||
PatchEncoder,
|
||||
Patches,
|
||||
wrap_layout_model_patched,
|
||||
wrap_layout_model_resized,
|
||||
)
|
||||
from .specs import EynollahModelSpecSet
|
||||
from .default_specs import DEFAULT_MODEL_SPECS
|
||||
from .types import AnyModel, T
|
||||
|
|
@ -30,7 +38,7 @@ class EynollahModelZoo:
|
|||
basedir: str,
|
||||
model_overrides: Optional[List[Tuple[str, str, str]]] = None,
|
||||
) -> None:
|
||||
self.model_basedir = Path(basedir)
|
||||
self.model_basedir = Path(basedir).resolve()
|
||||
self.logger = logging.getLogger('eynollah.model_zoo')
|
||||
if not self.model_basedir.exists():
|
||||
self.logger.warning(f"Model basedir does not exist: {basedir}. Set eynollah --model-basedir to the correct directory.")
|
||||
|
|
@ -54,7 +62,7 @@ class EynollahModelZoo:
|
|||
for model_category, model_variant, model_filename in model_overrides:
|
||||
spec = self.specs.get(model_category, model_variant)
|
||||
self.logger.warning("Overriding filename for model spec %s to %s", spec, model_filename)
|
||||
self.specs.get(model_category, model_variant).filename = model_filename
|
||||
self.specs.get(model_category, model_variant).filename = str(Path(model_filename).resolve())
|
||||
self._overrides += model_overrides
|
||||
|
||||
def model_path(
|
||||
|
|
@ -82,6 +90,17 @@ class EynollahModelZoo:
|
|||
"""
|
||||
Load all models by calling load_model and return a dictionary mapping model_category to loaded model
|
||||
"""
|
||||
import tensorflow as tf
|
||||
cuda = False
|
||||
try:
|
||||
for device in tf.config.list_physical_devices('GPU'):
|
||||
tf.config.experimental.set_memory_growth(device, True)
|
||||
cuda = True
|
||||
self.logger.info("using GPU %s", device.name)
|
||||
except RuntimeError:
|
||||
self.logger.exception("cannot configure GPU devices")
|
||||
if not cuda:
|
||||
self.logger.warning("no GPU device available")
|
||||
ret = {}
|
||||
for load_args in all_load_args:
|
||||
if isinstance(load_args, str):
|
||||
|
|
@ -122,7 +141,12 @@ class EynollahModelZoo:
|
|||
model = load_model(
|
||||
model_path, compile=False, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches}
|
||||
)
|
||||
model._name = model_category
|
||||
self._loaded[model_category] = model
|
||||
if model_category in ['region_1_2', 'table', 'region_fl_np']:
|
||||
self._loaded[model_category + '_resized'] = wrap_layout_model_resized(model)
|
||||
if model_category in ['region_1_2', 'textline']:
|
||||
self._loaded[model_category + '_patched'] = wrap_layout_model_patched(model)
|
||||
return model # type: ignore
|
||||
|
||||
def get(self, model_category: str, model_type: Optional[Type[T]] = None) -> T:
|
||||
|
|
|
|||
|
|
@ -28,7 +28,19 @@
|
|||
"full_layout": {
|
||||
"type": "boolean",
|
||||
"default": true,
|
||||
"description": "Try to detect all element subtypes, including drop-caps and headings"
|
||||
"description": "Try to detect all region subtypes, including drop-capital and heading"
|
||||
},
|
||||
"light_version": {
|
||||
"type": "boolean",
|
||||
"default": true,
|
||||
"enum": [true],
|
||||
"description": "ignored (only for backwards-compatibility)"
|
||||
},
|
||||
"textline_light": {
|
||||
"type": "boolean",
|
||||
"default": true,
|
||||
"enum": [true],
|
||||
"description": "ignored (only for backwards-compatibility)"
|
||||
},
|
||||
"tables": {
|
||||
"type": "boolean",
|
||||
|
|
@ -38,12 +50,12 @@
|
|||
"curved_line": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "try to return contour of textlines instead of just rectangle bounding box. Needs more processing time"
|
||||
"description": "retrieve textline polygons independent of each other (needs more processing time)"
|
||||
},
|
||||
"ignore_page_extraction": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "if this parameter set to true, this tool would ignore page extraction"
|
||||
"description": "if true, do not attempt page frame detection (cropping)"
|
||||
},
|
||||
"allow_scaling": {
|
||||
"type": "boolean",
|
||||
|
|
@ -58,7 +70,7 @@
|
|||
"right_to_left": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "if this parameter set to true, this tool will extract right-to-left reading order."
|
||||
"description": "if true, return reading order in right-to-left reading direction."
|
||||
},
|
||||
"headers_off": {
|
||||
"type": "boolean",
|
||||
|
|
@ -123,13 +135,22 @@
|
|||
}
|
||||
},
|
||||
"resources": [
|
||||
{
|
||||
"url": "https://zenodo.org/records/17580627/files/models_all_v0_7_0.zip?download=1",
|
||||
"name": "models_layout_v0_7_0",
|
||||
"type": "archive",
|
||||
"size": 6119874002,
|
||||
"description": "Models for layout detection, reading order detection, textline detection, page extraction, column classification, table detection, binarization, image enhancement and OCR",
|
||||
"version_range": ">= v0.7.0"
|
||||
},
|
||||
{
|
||||
"url": "https://github.com/qurator-spk/sbb_binarization/releases/download/v0.0.11/saved_model_2020_01_16.zip",
|
||||
"name": "default",
|
||||
"type": "archive",
|
||||
"path_in_archive": "saved_model_2020_01_16",
|
||||
"size": 563147331,
|
||||
"description": "default models provided by github.com/qurator-spk (SavedModel format)"
|
||||
"description": "default models provided by github.com/qurator-spk (SavedModel format)",
|
||||
"version_range": "< v0.7.0"
|
||||
},
|
||||
{
|
||||
"url": "https://github.com/qurator-spk/sbb_binarization/releases/download/v0.0.11/saved_model_2021_03_09.zip",
|
||||
|
|
@ -137,7 +158,8 @@
|
|||
"type": "archive",
|
||||
"path_in_archive": ".",
|
||||
"size": 133230419,
|
||||
"description": "updated default models provided by github.com/qurator-spk (SavedModel format)"
|
||||
"description": "updated default models provided by github.com/qurator-spk (SavedModel format)",
|
||||
"version_range": "< v0.7.0"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -75,7 +75,7 @@ class SbbBinarizeProcessor(Processor):
|
|||
|
||||
if oplevel == 'page':
|
||||
self.logger.info("Binarizing on 'page' level in page '%s'", page_id)
|
||||
page_image_bin = cv2pil(self.binarizer.run(image=pil2cv(page_image), use_patches=True))
|
||||
page_image_bin = cv2pil(self.binarizer.run_single(image=pil2cv(page_image), use_patches=True))
|
||||
# update PAGE (reference the image file):
|
||||
page_image_ref = AlternativeImageType(comments=page_xywh['features'] + ',binarized,clipped')
|
||||
page.add_AlternativeImage(page_image_ref)
|
||||
|
|
@ -88,7 +88,7 @@ class SbbBinarizeProcessor(Processor):
|
|||
for region in regions:
|
||||
region_image, region_xywh = self.workspace.image_from_segment(
|
||||
region, page_image, page_xywh, feature_filter='binarized')
|
||||
region_image_bin = cv2pil(self.binarizer.run(image=pil2cv(region_image), use_patches=True))
|
||||
region_image_bin = cv2pil(self.binarizer.run_single(image=pil2cv(region_image), use_patches=True))
|
||||
# update PAGE (reference the image file):
|
||||
region_image_ref = AlternativeImageType(comments=region_xywh['features'] + ',binarized')
|
||||
region.add_AlternativeImage(region_image_ref)
|
||||
|
|
@ -100,7 +100,7 @@ class SbbBinarizeProcessor(Processor):
|
|||
self.logger.warning("Page '%s' contains no text lines", page_id)
|
||||
for line in lines:
|
||||
line_image, line_xywh = self.workspace.image_from_segment(line, page_image, page_xywh, feature_filter='binarized')
|
||||
line_image_bin = cv2pil(self.binarizer.run(image=pil2cv(line_image), use_patches=True))
|
||||
line_image_bin = cv2pil(self.binarizer.run_single(image=pil2cv(line_image), use_patches=True))
|
||||
# update PAGE (reference the image file):
|
||||
line_image_ref = AlternativeImageType(comments=line_xywh['features'] + ',binarized')
|
||||
line.add_AlternativeImage(line_image_ref)
|
||||
|
|
|
|||
|
|
@ -1,52 +1,160 @@
|
|||
from keras import layers
|
||||
import os
|
||||
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
||||
import tensorflow as tf
|
||||
|
||||
projection_dim = 64
|
||||
patch_size = 1
|
||||
num_patches =21*21#14*14#28*28#14*14#28*28
|
||||
from tensorflow.keras import layers, models
|
||||
|
||||
class PatchEncoder(layers.Layer):
|
||||
|
||||
def __init__(self):
|
||||
# 441=21*21 # 14*14 # 28*28
|
||||
def __init__(self, num_patches=441, projection_dim=64):
|
||||
super().__init__()
|
||||
self.projection = layers.Dense(units=projection_dim)
|
||||
self.position_embedding = layers.Embedding(input_dim=num_patches, output_dim=projection_dim)
|
||||
self.num_patches = num_patches
|
||||
self.projection_dim = projection_dim
|
||||
self.projection = layers.Dense(self.projection_dim)
|
||||
self.position_embedding = layers.Embedding(self.num_patches, self.projection_dim)
|
||||
|
||||
def call(self, patch):
|
||||
positions = tf.range(start=0, limit=num_patches, delta=1)
|
||||
encoded = self.projection(patch) + self.position_embedding(positions)
|
||||
return encoded
|
||||
positions = tf.range(start=0, limit=self.num_patches, delta=1)
|
||||
return self.projection(patch) + self.position_embedding(positions)
|
||||
|
||||
def get_config(self):
|
||||
config = super().get_config().copy()
|
||||
config.update({
|
||||
'num_patches': num_patches,
|
||||
'projection': self.projection,
|
||||
'position_embedding': self.position_embedding,
|
||||
})
|
||||
return config
|
||||
return dict(num_patches=self.num_patches,
|
||||
projection_dim=self.projection_dim,
|
||||
**super().get_config())
|
||||
|
||||
class Patches(layers.Layer):
|
||||
def __init__(self, **kwargs):
|
||||
super(Patches, self).__init__()
|
||||
self.patch_size = patch_size
|
||||
def __init__(self, patch_size_x=1, patch_size_y=1):
|
||||
super().__init__()
|
||||
self.patch_size_x = patch_size_x
|
||||
self.patch_size_y = patch_size_y
|
||||
|
||||
def call(self, images):
|
||||
batch_size = tf.shape(images)[0]
|
||||
patches = tf.image.extract_patches(
|
||||
images=images,
|
||||
sizes=[1, self.patch_size, self.patch_size, 1],
|
||||
strides=[1, self.patch_size, self.patch_size, 1],
|
||||
sizes=[1, self.patch_size_y, self.patch_size_x, 1],
|
||||
strides=[1, self.patch_size_y, self.patch_size_x, 1],
|
||||
rates=[1, 1, 1, 1],
|
||||
padding="VALID",
|
||||
)
|
||||
patch_dims = patches.shape[-1]
|
||||
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
|
||||
return patches
|
||||
def get_config(self):
|
||||
return tf.reshape(patches, [batch_size, -1, patch_dims])
|
||||
|
||||
config = super().get_config().copy()
|
||||
config.update({
|
||||
'patch_size': self.patch_size,
|
||||
})
|
||||
return config
|
||||
def get_config(self):
|
||||
return dict(patch_size_x=self.patch_size_x,
|
||||
patch_size_y=self.patch_size_y,
|
||||
**super().get_config())
|
||||
|
||||
class wrap_layout_model_resized(models.Model):
|
||||
"""
|
||||
replacement for layout model using resizing to model width/height and back
|
||||
|
||||
(accepts arbitrary width/height input [B, H, W, 3], returns same size segmentation [B, H, W, C])
|
||||
"""
|
||||
def __init__(self, model):
|
||||
super().__init__(name=model.name + '_resized')
|
||||
self.model = model
|
||||
self.height = model.layers[-1].output_shape[1]
|
||||
self.width = model.layers[-1].output_shape[2]
|
||||
|
||||
@tf.function(reduce_retracing=True,
|
||||
#jit_compile=True, (ScaleAndTranslate is not supported by XLA)
|
||||
input_signature=[tf.TensorSpec([1, None, None, 3],
|
||||
dtype=tf.float32)])
|
||||
def call(self, img, training=False):
|
||||
height = tf.shape(img)[1]
|
||||
width = tf.shape(img)[2]
|
||||
img_resized = tf.image.resize(img,
|
||||
(self.height, self.width),
|
||||
antialias=True)
|
||||
pred_resized = self.model(img_resized)
|
||||
pred = tf.image.resize(pred_resized,
|
||||
(height, width))
|
||||
return pred
|
||||
|
||||
def predict(self, x, verbose=0):
|
||||
return self(x).numpy()
|
||||
|
||||
class wrap_layout_model_patched(models.Model):
|
||||
"""
|
||||
replacement for layout model using sliding window for patches
|
||||
|
||||
(accepts arbitrary width/height input [B, H, W, 3], returns same size segmentation [B, H, W, C])
|
||||
"""
|
||||
def __init__(self, model):
|
||||
super().__init__(name=model.name + '_patched')
|
||||
self.model = model
|
||||
self.height = model.layers[-1].output_shape[1]
|
||||
self.width = model.layers[-1].output_shape[2]
|
||||
self.classes = model.layers[-1].output_shape[3]
|
||||
# equivalent of marginal_of_patch_percent=0.1 ...
|
||||
self.stride_x = int(self.width * (1 - 0.1))
|
||||
self.stride_y = int(self.height * (1 - 0.1))
|
||||
offset_height = (self.height - self.stride_y) // 2
|
||||
offset_width = (self.width - self.stride_x) // 2
|
||||
window = tf.image.pad_to_bounding_box(
|
||||
tf.ones((self.stride_y, self.stride_x, 1), dtype=tf.int32),
|
||||
offset_height, offset_width,
|
||||
self.height, self.width)
|
||||
self.window = tf.expand_dims(window, axis=0)
|
||||
|
||||
@tf.function(reduce_retracing=True,
|
||||
#jit_compile=True, (ScaleAndTranslate and ExtractImagePatches not supported by XLA)
|
||||
input_signature=[tf.TensorSpec([1, None, None, 3],
|
||||
dtype=tf.float32)])
|
||||
def call(self, img, training=False):
|
||||
height = tf.shape(img)[1]
|
||||
width = tf.shape(img)[2]
|
||||
if (height < self.height or
|
||||
width < self.width):
|
||||
img_resized = tf.image.resize(img,
|
||||
(self.height, self.width),
|
||||
antialias=True)
|
||||
pred_resized = self.model(img_resized)
|
||||
pred = tf.image.resize(pred_resized,
|
||||
(height, width))
|
||||
return pred
|
||||
|
||||
img_patches = tf.image.extract_patches(
|
||||
images=img,
|
||||
sizes=[1, self.height, self.width, 1],
|
||||
strides=[1, self.stride_y, self.stride_x, 1],
|
||||
rates=[1, 1, 1, 1],
|
||||
padding='SAME')
|
||||
img_patches = tf.squeeze(img_patches)
|
||||
new_shape = (-1, self.height, self.width, 3)
|
||||
img_patches = tf.reshape(img_patches, shape=new_shape)
|
||||
# may be too large:
|
||||
#pred_patches = self.model(img_patches)
|
||||
# so rebatch to fit in memory:
|
||||
img_patches = tf.expand_dims(img_patches, 1)
|
||||
pred_patches = tf.map_fn(self.model, img_patches,
|
||||
parallel_iterations=1,
|
||||
infer_shape=False)
|
||||
pred_patches = tf.squeeze(pred_patches, 1)
|
||||
# calculate corresponding indexes for reconstruction
|
||||
x = tf.range(width)
|
||||
y = tf.range(height)
|
||||
x, y = tf.meshgrid(x, y)
|
||||
indices = tf.stack([y, x], axis=-1)
|
||||
indices_patches = tf.image.extract_patches(
|
||||
images=tf.expand_dims(indices, axis=0),
|
||||
sizes=[1, self.height, self.width, 1],
|
||||
strides=[1, self.stride_y, self.stride_x, 1],
|
||||
rates=[1, 1, 1, 1],
|
||||
padding='SAME')
|
||||
indices_patches = tf.squeeze(indices_patches)
|
||||
indices_patches = tf.reshape(indices_patches, shape=new_shape[:-1] + (2,))
|
||||
|
||||
# use margins for sliding window approach
|
||||
indices_patches = indices_patches * self.window
|
||||
|
||||
pred = tf.scatter_nd(
|
||||
indices_patches,
|
||||
pred_patches,
|
||||
(height, width, self.classes))
|
||||
pred = tf.expand_dims(pred, axis=0)
|
||||
return pred
|
||||
|
||||
def predict(self, x, verbose=0):
|
||||
return self(x).numpy()
|
||||
|
|
|
|||
|
|
@ -26,10 +26,6 @@ class EynollahPlotter:
|
|||
dir_of_deskewed,
|
||||
dir_of_layout,
|
||||
dir_of_cropped_images,
|
||||
image_filename_stem,
|
||||
image_org=None,
|
||||
scale_x=1,
|
||||
scale_y=1,
|
||||
):
|
||||
self.dir_out = dir_out
|
||||
self.dir_of_all = dir_of_all
|
||||
|
|
@ -37,13 +33,8 @@ class EynollahPlotter:
|
|||
self.dir_of_layout = dir_of_layout
|
||||
self.dir_of_cropped_images = dir_of_cropped_images
|
||||
self.dir_of_deskewed = dir_of_deskewed
|
||||
self.image_filename_stem = image_filename_stem
|
||||
# XXX TODO hacky these cannot be set at init time
|
||||
self.image_org = image_org
|
||||
self.scale_x : float = scale_x
|
||||
self.scale_y : float = scale_y
|
||||
|
||||
def save_plot_of_layout_main(self, text_regions_p, image_page):
|
||||
def save_plot_of_layout_main(self, text_regions_p, image_page, name=None):
|
||||
if self.dir_of_layout is not None:
|
||||
values = np.unique(text_regions_p[:, :])
|
||||
# pixels=['Background' , 'Main text' , 'Heading' , 'Marginalia' ,'Drop capitals' , 'Images' , 'Seperators' , 'Tables', 'Graphics']
|
||||
|
|
@ -55,10 +46,10 @@ class EynollahPlotter:
|
|||
colors = [im.cmap(im.norm(value)) for value in values]
|
||||
patches = [mpatches.Patch(color=colors[np.where(values == i)[0][0]], label="{l}".format(l=pixels[int(np.where(values_indexes == i)[0][0])])) for i in values]
|
||||
plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0, fontsize=40)
|
||||
plt.savefig(os.path.join(self.dir_of_layout, self.image_filename_stem + "_layout_main.png"))
|
||||
plt.savefig(os.path.join(self.dir_of_layout,
|
||||
(name or "page") + "_layout_main.png"))
|
||||
|
||||
|
||||
def save_plot_of_layout_main_all(self, text_regions_p, image_page):
|
||||
def save_plot_of_layout_main_all(self, text_regions_p, image_page, name=None):
|
||||
if self.dir_of_all is not None:
|
||||
values = np.unique(text_regions_p[:, :])
|
||||
# pixels=['Background' , 'Main text' , 'Heading' , 'Marginalia' ,'Drop capitals' , 'Images' , 'Seperators' , 'Tables', 'Graphics']
|
||||
|
|
@ -73,9 +64,10 @@ class EynollahPlotter:
|
|||
colors = [im.cmap(im.norm(value)) for value in values]
|
||||
patches = [mpatches.Patch(color=colors[np.where(values == i)[0][0]], label="{l}".format(l=pixels[int(np.where(values_indexes == i)[0][0])])) for i in values]
|
||||
plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0, fontsize=60)
|
||||
plt.savefig(os.path.join(self.dir_of_all, self.image_filename_stem + "_layout_main_and_page.png"))
|
||||
plt.savefig(os.path.join(self.dir_of_all,
|
||||
(name or "page") + "_layout_main_and_page.png"))
|
||||
|
||||
def save_plot_of_layout(self, text_regions_p, image_page):
|
||||
def save_plot_of_layout(self, text_regions_p, image_page, name=None):
|
||||
if self.dir_of_layout is not None:
|
||||
values = np.unique(text_regions_p[:, :])
|
||||
# pixels=['Background' , 'Main text' , 'Heading' , 'Marginalia' ,'Drop capitals' , 'Images' , 'Seperators' , 'Tables', 'Graphics']
|
||||
|
|
@ -87,9 +79,10 @@ class EynollahPlotter:
|
|||
colors = [im.cmap(im.norm(value)) for value in values]
|
||||
patches = [mpatches.Patch(color=colors[np.where(values == i)[0][0]], label="{l}".format(l=pixels[int(np.where(values_indexes == i)[0][0])])) for i in values]
|
||||
plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0, fontsize=40)
|
||||
plt.savefig(os.path.join(self.dir_of_layout, self.image_filename_stem + "_layout.png"))
|
||||
plt.savefig(os.path.join(self.dir_of_layout,
|
||||
(name or "page") + "_layout.png"))
|
||||
|
||||
def save_plot_of_layout_all(self, text_regions_p, image_page):
|
||||
def save_plot_of_layout_all(self, text_regions_p, image_page, name=None):
|
||||
if self.dir_of_all is not None:
|
||||
values = np.unique(text_regions_p[:, :])
|
||||
# pixels=['Background' , 'Main text' , 'Heading' , 'Marginalia' ,'Drop capitals' , 'Images' , 'Seperators' , 'Tables', 'Graphics']
|
||||
|
|
@ -104,9 +97,10 @@ class EynollahPlotter:
|
|||
colors = [im.cmap(im.norm(value)) for value in values]
|
||||
patches = [mpatches.Patch(color=colors[np.where(values == i)[0][0]], label="{l}".format(l=pixels[int(np.where(values_indexes == i)[0][0])])) for i in values]
|
||||
plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0, fontsize=60)
|
||||
plt.savefig(os.path.join(self.dir_of_all, self.image_filename_stem + "_layout_and_page.png"))
|
||||
plt.savefig(os.path.join(self.dir_of_all,
|
||||
(name or "page") + "_layout_and_page.png"))
|
||||
|
||||
def save_plot_of_textlines(self, textline_mask_tot_ea, image_page):
|
||||
def save_plot_of_textlines(self, textline_mask_tot_ea, image_page, name=None):
|
||||
if self.dir_of_all is not None:
|
||||
values = np.unique(textline_mask_tot_ea[:, :])
|
||||
pixels = ["Background", "Textlines"]
|
||||
|
|
@ -120,24 +114,31 @@ class EynollahPlotter:
|
|||
colors = [im.cmap(im.norm(value)) for value in values]
|
||||
patches = [mpatches.Patch(color=colors[np.where(values == i)[0][0]], label="{l}".format(l=pixels[int(np.where(values_indexes == i)[0][0])])) for i in values]
|
||||
plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0, fontsize=60)
|
||||
plt.savefig(os.path.join(self.dir_of_all, self.image_filename_stem + "_textline_and_page.png"))
|
||||
plt.savefig(os.path.join(self.dir_of_all,
|
||||
(name or "page") + "_textline_and_page.png"))
|
||||
|
||||
def save_deskewed_image(self, slope_deskew):
|
||||
def save_deskewed_image(self, slope_deskew, image_org, name=None):
|
||||
if self.dir_of_all is not None:
|
||||
cv2.imwrite(os.path.join(self.dir_of_all, self.image_filename_stem + "_org.png"), self.image_org)
|
||||
cv2.imwrite(os.path.join(self.dir_of_all,
|
||||
(name or "page") + "_org.png"), image_org)
|
||||
if self.dir_of_deskewed is not None:
|
||||
img_rotated = rotate_image_different(self.image_org, slope_deskew)
|
||||
cv2.imwrite(os.path.join(self.dir_of_deskewed, self.image_filename_stem + "_deskewed.png"), img_rotated)
|
||||
img_rotated = rotate_image_different(image_org, slope_deskew)
|
||||
cv2.imwrite(os.path.join(self.dir_of_deskewed,
|
||||
(name or "page") + "_deskewed.png"), img_rotated)
|
||||
|
||||
def save_page_image(self, image_page):
|
||||
def save_page_image(self, image_page, name=None):
|
||||
if self.dir_of_all is not None:
|
||||
cv2.imwrite(os.path.join(self.dir_of_all, self.image_filename_stem + "_page.png"), image_page)
|
||||
cv2.imwrite(os.path.join(self.dir_of_all,
|
||||
(name or "page") + "_page.png"), image_page)
|
||||
if self.dir_save_page is not None:
|
||||
cv2.imwrite(os.path.join(self.dir_save_page, self.image_filename_stem + "_page.png"), image_page)
|
||||
def save_enhanced_image(self, img_res):
|
||||
cv2.imwrite(os.path.join(self.dir_out, self.image_filename_stem + "_enhanced.png"), img_res)
|
||||
cv2.imwrite(os.path.join(self.dir_save_page,
|
||||
(name or "page") + "_page.png"), image_page)
|
||||
|
||||
def save_plot_of_textline_density(self, img_patch_org):
|
||||
def save_enhanced_image(self, img_res, name=None):
|
||||
cv2.imwrite(os.path.join(self.dir_out,
|
||||
(name or "page") + "_enhanced.png"), img_res)
|
||||
|
||||
def save_plot_of_textline_density(self, img_patch_org, name=None):
|
||||
if self.dir_of_all is not None:
|
||||
plt.figure(figsize=(80,40))
|
||||
plt.rcParams['font.size']='50'
|
||||
|
|
@ -149,9 +150,10 @@ class EynollahPlotter:
|
|||
plt.ylabel('Height',fontsize=60)
|
||||
plt.yticks([0,len(gaussian_filter1d(img_patch_org.sum(axis=1), 3))])
|
||||
plt.gca().invert_yaxis()
|
||||
plt.savefig(os.path.join(self.dir_of_all, self.image_filename_stem+'_density_of_textline.png'))
|
||||
plt.savefig(os.path.join(self.dir_of_all,
|
||||
(name or "page") + '_density_of_textline.png'))
|
||||
|
||||
def save_plot_of_rotation_angle(self, angels, var_res):
|
||||
def save_plot_of_rotation_angle(self, angels, var_res, name=None):
|
||||
if self.dir_of_all is not None:
|
||||
plt.figure(figsize=(60,30))
|
||||
plt.rcParams['font.size']='50'
|
||||
|
|
@ -160,19 +162,20 @@ class EynollahPlotter:
|
|||
plt.ylabel('variance of sum of rotated textline in direction of x axis',fontsize=50)
|
||||
plt.plot(angels[np.argmax(var_res)],var_res[np.argmax(np.array(var_res))] ,'*',markersize=50,label='Angle of deskewing=' +str("{:.2f}".format(angels[np.argmax(var_res)]))+r'$\degree$')
|
||||
plt.legend(loc='best')
|
||||
plt.savefig(os.path.join(self.dir_of_all, self.image_filename_stem+'_rotation_angle.png'))
|
||||
plt.savefig(os.path.join(self.dir_of_all,
|
||||
(name or "page") + '_rotation_angle.png'))
|
||||
|
||||
def write_images_into_directory(self, img_contours, image_page):
|
||||
def write_images_into_directory(self, img_contours, image_page, scale_x=1.0, scale_y=1.0, name=None):
|
||||
if self.dir_of_cropped_images is not None:
|
||||
index = 0
|
||||
for cont_ind in img_contours:
|
||||
x, y, w, h = cv2.boundingRect(cont_ind)
|
||||
box = [x, y, w, h]
|
||||
croped_page, page_coord = crop_image_inside_box(box, image_page)
|
||||
|
||||
croped_page = resize_image(croped_page, int(croped_page.shape[0] / self.scale_y), int(croped_page.shape[1] / self.scale_x))
|
||||
|
||||
path = os.path.join(self.dir_of_cropped_images, self.image_filename_stem + "_" + str(index) + ".jpg")
|
||||
cv2.imwrite(path, croped_page)
|
||||
image, _ = crop_image_inside_box(box, image_page)
|
||||
image = resize_image(image,
|
||||
int(image.shape[0] / scale_y),
|
||||
int(image.shape[1] / scale_x))
|
||||
cv2.imwrite(os.path.join(self.dir_of_cropped_images,
|
||||
(name or "page") + f"_{index:03d}.jpg"), image)
|
||||
index += 1
|
||||
|
||||
|
|
|
|||
167
src/eynollah/predictor.py
Normal file
167
src/eynollah/predictor.py
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
import threading
|
||||
from contextlib import ExitStack
|
||||
from functools import lru_cache
|
||||
from typing import List
|
||||
import logging
|
||||
import logging.handlers
|
||||
import multiprocessing as mp
|
||||
import numpy as np
|
||||
|
||||
from .utils.shm import share_ndarray, ndarray_shared
|
||||
|
||||
QSIZE = 200
|
||||
|
||||
|
||||
class Predictor(mp.context.SpawnProcess):
|
||||
"""
|
||||
singleton subprocess solely responsible for prediction with TensorFlow,
|
||||
communicates with any number of worker processes,
|
||||
acts as a shallow replacement for EynollahModelZoo
|
||||
"""
|
||||
class SingleModelPredictor:
|
||||
"""
|
||||
acts as a shallow replacement for EynollahModelZoo
|
||||
"""
|
||||
def __init__(self, predictor: 'Predictor', model: str):
|
||||
self.predictor = predictor
|
||||
self.model = model
|
||||
@property
|
||||
def name(self):
|
||||
return self.model
|
||||
@property
|
||||
def output_shape(self):
|
||||
return self.predictor(self.model, {})
|
||||
def predict(self, data: dict, verbose=0):
|
||||
return self.predictor(self.model, data)
|
||||
|
||||
def __init__(self, logger, model_zoo):
|
||||
self.logger = logger
|
||||
self.model_zoo = model_zoo
|
||||
ctxt = mp.get_context('spawn')
|
||||
self.taskq = ctxt.Queue(maxsize=QSIZE)
|
||||
self.resultq = ctxt.Queue(maxsize=QSIZE)
|
||||
self.logq = ctxt.Queue(maxsize=QSIZE * 100)
|
||||
logging.handlers.QueueListener(
|
||||
self.logq, *(
|
||||
# as per ocrd_utils.initLogging():
|
||||
logging.root.handlers +
|
||||
# as per eynollah_cli.main():
|
||||
self.logger.handlers
|
||||
), respect_handler_level=False).start()
|
||||
self.stopped = ctxt.Event()
|
||||
self.closable = ctxt.Manager().list()
|
||||
super().__init__(name="EynollahPredictor", daemon=True)
|
||||
|
||||
@lru_cache
|
||||
def get(self, model: str):
|
||||
return Predictor.SingleModelPredictor(self, model)
|
||||
|
||||
def __call__(self, model: str, data: dict):
|
||||
# unusable as per python/cpython#79967
|
||||
#with self.jobid.get_lock():
|
||||
# would work, but not public:
|
||||
#with self.jobid._mutex:
|
||||
with self.joblock:
|
||||
self.jobid.value += 1
|
||||
jobid = self.jobid.value
|
||||
if not len(data):
|
||||
self.taskq.put((jobid, model, data))
|
||||
#self.logger.debug("sent shape query task '%d' for model '%s'", jobid, model)
|
||||
return self.result(jobid)
|
||||
with share_ndarray(data) as shared_data:
|
||||
self.taskq.put((jobid, model, shared_data))
|
||||
#self.logger.debug("sent prediction task '%d' for model '%s': %s", jobid, model, shared_data)
|
||||
return self.result(jobid)
|
||||
|
||||
def result(self, jobid):
|
||||
while not self.stopped.is_set():
|
||||
if jobid in self.results:
|
||||
#self.logger.debug("received result for '%d'", jobid)
|
||||
result = self.results.pop(jobid)
|
||||
if isinstance(result, Exception):
|
||||
raise Exception(f"predictor failed for {jobid}") from result
|
||||
elif isinstance(result, dict):
|
||||
with ndarray_shared(result) as shared_result:
|
||||
result = np.copy(shared_result)
|
||||
self.closable.append(jobid)
|
||||
return result
|
||||
try:
|
||||
jobid0, result = self.resultq.get(timeout=0.7)
|
||||
except mp.queues.Empty:
|
||||
continue
|
||||
#self.logger.debug("storing results for '%d': '%s'", jobid0, result)
|
||||
self.results[jobid0] = result
|
||||
raise Exception(f"predictor terminated while waiting on results for {jobid}")
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
self.setup() # fill model_zoo etc
|
||||
except Exception as e:
|
||||
self.logger.exception("setup failed")
|
||||
self.stopped.set()
|
||||
closing = {}
|
||||
def close_all():
|
||||
for jobid in list(self.closable):
|
||||
self.closable.remove(jobid)
|
||||
closing.pop(jobid).close()
|
||||
#self.logger.debug("closed shm for '%d'", jobid)
|
||||
while not self.stopped.is_set():
|
||||
close_all()
|
||||
try:
|
||||
jobid, model, shared_data = self.taskq.get(timeout=1.1)
|
||||
except mp.queues.Empty:
|
||||
continue
|
||||
try:
|
||||
model = self.model_zoo.get(model)
|
||||
if not len(shared_data):
|
||||
#self.logger.debug("getting '%d' output shape of model '%s'", jobid, model)
|
||||
result = model.output_shape
|
||||
else:
|
||||
#self.logger.debug("predicting '%d' with model '%s': ", jobid, model, shared_data)
|
||||
with ndarray_shared(shared_data) as data:
|
||||
result = model.predict(data, verbose=0)
|
||||
#self.logger.debug("sharing result array for '%d'", jobid)
|
||||
with ExitStack() as stack:
|
||||
# we don't know when the result will be received,
|
||||
# but don't want to wait either, so
|
||||
result = stack.enter_context(share_ndarray(result))
|
||||
closing[jobid] = stack.pop_all()
|
||||
except Exception as e:
|
||||
self.logger.error("prediction '%d' failed: %s", jobid, e.__class__.__name__)
|
||||
result = e
|
||||
self.resultq.put((jobid, result))
|
||||
#self.logger.debug("sent result for '%d': %s", jobid, result)
|
||||
close_all()
|
||||
#self.logger.debug("predictor terminated")
|
||||
|
||||
def load_models(self, *loadable: List[str]):
|
||||
self.loadable = loadable
|
||||
self.start() # call run() in subprocess
|
||||
# parent context here
|
||||
del self.model_zoo # only in subprocess
|
||||
ctxt = mp.get_context('fork') # ocrd.Processor will fork workers
|
||||
mngr = ctxt.Manager()
|
||||
self.jobid = mngr.Value('i', 0)
|
||||
self.joblock = mngr.Lock()
|
||||
self.results = mngr.dict()
|
||||
|
||||
def setup(self):
|
||||
logging.root.handlers = [logging.handlers.QueueHandler(self.logq)]
|
||||
self.model_zoo.load_models(*self.loadable)
|
||||
|
||||
def shutdown(self):
|
||||
# do not terminate from forked processor instances
|
||||
if mp.parent_process() is None:
|
||||
self.stopped.set()
|
||||
self.terminate()
|
||||
self.logq.close()
|
||||
self.taskq.close()
|
||||
self.taskq.cancel_join_thread()
|
||||
self.resultq.close()
|
||||
self.resultq.cancel_join_thread()
|
||||
else:
|
||||
self.model_zoo.shutdown()
|
||||
|
||||
def __del__(self):
|
||||
#self.logger.debug(f"deinit of {self} in {mp.current_process().name}")
|
||||
self.shutdown()
|
||||
|
|
@ -8,10 +8,6 @@ from eynollah.model_zoo.model_zoo import EynollahModelZoo
|
|||
from .eynollah import Eynollah, EynollahXmlWriter
|
||||
|
||||
class EynollahProcessor(Processor):
|
||||
# already employs background CPU multiprocessing per page
|
||||
# already employs GPU (without singleton process atm)
|
||||
max_workers = 1
|
||||
|
||||
@cached_property
|
||||
def executable(self) -> str:
|
||||
return 'ocrd-eynollah-segment'
|
||||
|
|
@ -80,14 +76,8 @@ class EynollahProcessor(Processor):
|
|||
image_filename = "dummy" # will be replaced by ocrd.Processor.process_page_file
|
||||
result.images.append(OcrdPageResultImage(page_image, '.IMG', page)) # mark as new original
|
||||
# FIXME: mask out already existing regions (incremental segmentation)
|
||||
self.eynollah.cache_images(
|
||||
image_pil=page_image,
|
||||
dpi=self.parameter['dpi'],
|
||||
)
|
||||
self.eynollah.writer = EynollahXmlWriter(
|
||||
dir_out=None,
|
||||
image_filename=image_filename,
|
||||
curved_line=self.eynollah.curved_line,
|
||||
pcgts=pcgts)
|
||||
self.eynollah.run_single()
|
||||
self.eynollah.run_single(image_filename,
|
||||
img_pil=page_image, pcgts=pcgts,
|
||||
# ocrd.Processor will handle OCRD_EXISTING_OUTPUT more flexibly
|
||||
overwrite=True)
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -9,17 +9,18 @@ Tool to load model and binarize a given image.
|
|||
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
from ocrd_utils import tf_disable_interactive_logs
|
||||
|
||||
from eynollah.model_zoo import EynollahModelZoo
|
||||
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
||||
from ocrd_utils import tf_disable_interactive_logs
|
||||
tf_disable_interactive_logs()
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.keras import backend as tensorflow_backend
|
||||
from pathlib import Path
|
||||
|
||||
from .model_zoo import EynollahModelZoo
|
||||
from .utils import is_image_filename
|
||||
|
||||
def resize_image(img_in, input_height, input_width):
|
||||
|
|
@ -34,21 +35,13 @@ class SbbBinarizer:
|
|||
logger: Optional[logging.Logger] = None,
|
||||
):
|
||||
self.logger = logger if logger else logging.getLogger('eynollah.binarization')
|
||||
try:
|
||||
for device in tf.config.list_physical_devices('GPU'):
|
||||
tf.config.experimental.set_memory_growth(device, True)
|
||||
except:
|
||||
self.logger.warning("no GPU device available")
|
||||
self.models = (model_zoo.model_path('binarization'), model_zoo.load_model('binarization'))
|
||||
self.session = self.start_new_session()
|
||||
|
||||
def start_new_session(self):
|
||||
config = tf.compat.v1.ConfigProto()
|
||||
config.gpu_options.allow_growth = True
|
||||
|
||||
session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
|
||||
tensorflow_backend.set_session(session)
|
||||
return session
|
||||
|
||||
def end_session(self):
|
||||
tensorflow_backend.clear_session()
|
||||
self.session.close()
|
||||
del self.session
|
||||
self.logger.info('Loaded model %s [%s]', self.models[1], self.models[0])
|
||||
|
||||
def predict(self, model, img, use_patches, n_batch_inference=5):
|
||||
model_height = model.layers[len(model.layers)-1].output_shape[1]
|
||||
|
|
@ -311,34 +304,20 @@ class SbbBinarizer:
|
|||
prediction_true = prediction_true.astype(np.uint8)
|
||||
return prediction_true[:,:,0]
|
||||
|
||||
def run(self, image=None, image_path=None, output=None, use_patches=False, dir_in=None):
|
||||
# print(dir_in,'dir_in')
|
||||
def run(self, image=None, image_path=None, output=None, use_patches=False, dir_in=None, overwrite=False):
|
||||
if not dir_in:
|
||||
if (image is not None and image_path is not None) or \
|
||||
(image is None and image_path is None):
|
||||
if (image is None) == (image_path is None):
|
||||
raise ValueError("Must pass either a opencv2 image or an image_path")
|
||||
if image_path is not None:
|
||||
image = cv2.imread(image_path)
|
||||
img_last = 0
|
||||
model_file, model = self.models
|
||||
self.logger.info('Predicting %s with model %s', image_path if image_path else '[image]', model_file)
|
||||
res = self.predict(model, image, use_patches)
|
||||
|
||||
img_fin = np.zeros((res.shape[0], res.shape[1], 3))
|
||||
res[:, :][res[:, :] == 0] = 2
|
||||
res = res - 1
|
||||
res = res * 255
|
||||
img_fin[:, :, 0] = res
|
||||
img_fin[:, :, 1] = res
|
||||
img_fin[:, :, 2] = res
|
||||
|
||||
img_fin = img_fin.astype(np.uint8)
|
||||
img_fin = (res[:, :] == 0) * 255
|
||||
img_last = img_last + img_fin
|
||||
|
||||
img_last[:, :][img_last[:, :] > 0] = 255
|
||||
img_last = (img_last[:, :] == 0) * 255
|
||||
img_last = self.run_single(image, use_patches)
|
||||
if output:
|
||||
if os.path.exists(output):
|
||||
if overwrite:
|
||||
self.logger.warning("will overwrite existing output file '%s'", output)
|
||||
else:
|
||||
self.logger.warning("output file already exists '%s'", output)
|
||||
return img_last
|
||||
self.logger.info('Writing binarized image to %s', output)
|
||||
cv2.imwrite(output, img_last)
|
||||
return img_last
|
||||
|
|
@ -346,29 +325,38 @@ class SbbBinarizer:
|
|||
ls_imgs = list(filter(is_image_filename, os.listdir(dir_in)))
|
||||
self.logger.info("Found %d image files to binarize in %s", len(ls_imgs), dir_in)
|
||||
for i, image_path in enumerate(ls_imgs):
|
||||
image_stem = os.path.splitext(image_path)[0]
|
||||
output_path = os.path.join(output, image_stem + '.png')
|
||||
if os.path.exists(output_path):
|
||||
if overwrite:
|
||||
self.logger.warning("will overwrite existing output file '%s'", output_path)
|
||||
else:
|
||||
self.logger.warning("will skip input for existing output file '%s'", output_path)
|
||||
continue
|
||||
self.logger.info('Binarizing [%3d/%d] %s', i + 1, len(ls_imgs), image_path)
|
||||
image_stem = Path(image_path).stem
|
||||
image = cv2.imread(os.path.join(dir_in,image_path) )
|
||||
img_last = 0
|
||||
model_file, model = self.models
|
||||
self.logger.info('Predicting %s with model %s', image_path if image_path else '[image]', model_file)
|
||||
res = self.predict(model, image, use_patches)
|
||||
image = cv2.imread(os.path.join(dir_in, image_path))
|
||||
img_last = self.run_single(image, use_patches)
|
||||
self.logger.info('Writing binarized image to %s', output_path)
|
||||
cv2.imwrite(output_path, img_last)
|
||||
|
||||
img_fin = np.zeros((res.shape[0], res.shape[1], 3))
|
||||
res[:, :][res[:, :] == 0] = 2
|
||||
res = res - 1
|
||||
res = res * 255
|
||||
img_fin[:, :, 0] = res
|
||||
img_fin[:, :, 1] = res
|
||||
img_fin[:, :, 2] = res
|
||||
def run_single(self, image: np.ndarray, use_patches=False):
|
||||
img_last = 0
|
||||
model_file, model = self.models
|
||||
res = self.predict(model, image, use_patches)
|
||||
|
||||
img_fin = img_fin.astype(np.uint8)
|
||||
img_fin = (res[:, :] == 0) * 255
|
||||
img_last = img_last + img_fin
|
||||
img_fin = np.zeros((res.shape[0], res.shape[1], 3))
|
||||
res[:, :][res[:, :] == 0] = 2
|
||||
res = res - 1
|
||||
res = res * 255
|
||||
img_fin[:, :, 0] = res
|
||||
img_fin[:, :, 1] = res
|
||||
img_fin[:, :, 2] = res
|
||||
|
||||
img_last[:, :][img_last[:, :] > 0] = 255
|
||||
img_last = (img_last[:, :] == 0) * 255
|
||||
img_fin = img_fin.astype(np.uint8)
|
||||
img_fin = (res[:, :] == 0) * 255
|
||||
img_last = img_last + img_fin
|
||||
|
||||
output_filename = os.path.join(output, image_stem + '.png')
|
||||
self.logger.info('Writing binarized image to %s', output_filename)
|
||||
cv2.imwrite(output_filename, img_last)
|
||||
kernel = np.ones((5, 5), np.uint8)
|
||||
img_last[:, :][img_last[:, :] > 0] = 255
|
||||
img_last = (img_last[:, :] == 0) * 255
|
||||
return img_last
|
||||
|
|
|
|||
|
|
@ -1,13 +1,9 @@
|
|||
import sys
|
||||
import click
|
||||
import tensorflow as tf
|
||||
|
||||
from .models import resnet50_unet
|
||||
|
||||
|
||||
def configuration():
|
||||
gpu_options = tf.compat.v1.GPUOptions(allow_growth=True)
|
||||
session = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(gpu_options=gpu_options))
|
||||
|
||||
@click.command()
|
||||
def build_model_load_pretrained_weights_and_save():
|
||||
n_classes = 2
|
||||
|
|
@ -17,8 +13,6 @@ def build_model_load_pretrained_weights_and_save():
|
|||
pretraining = False
|
||||
dir_of_weights = 'model_bin_sbb_ens.h5'
|
||||
|
||||
# configuration()
|
||||
|
||||
model = resnet50_unet(n_classes, input_height, input_width, weight_decay, pretraining)
|
||||
model.load_weights(dir_of_weights)
|
||||
model.save('./name_in_another_python_version.h5')
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from .generate_gt_for_training import main as generate_gt_cli
|
|||
from .inference import main as inference_cli
|
||||
from .train import ex
|
||||
from .extract_line_gt import linegt_cli
|
||||
from .weights_ensembling import main as ensemble_cli
|
||||
from .weights_ensembling import ensemble_cli
|
||||
|
||||
@click.command(context_settings=dict(
|
||||
ignore_unknown_options=True,
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from PIL import Image, ImageDraw, ImageFont
|
|||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from eynollah.training.gt_gen_utils import (
|
||||
from .gt_gen_utils import (
|
||||
filter_contours_area_of_image,
|
||||
find_format_of_given_filename_in_dir,
|
||||
find_new_features_of_contours,
|
||||
|
|
@ -26,32 +26,37 @@ from eynollah.training.gt_gen_utils import (
|
|||
|
||||
@click.group()
|
||||
def main():
|
||||
"""
|
||||
extract GT data suitable for model training for various tasks
|
||||
"""
|
||||
pass
|
||||
|
||||
@main.command()
|
||||
@click.option(
|
||||
"--dir_xml",
|
||||
"-dx",
|
||||
help="directory of GT page-xml files",
|
||||
help="input directory of GT PAGE-XML files",
|
||||
type=click.Path(exists=True, file_okay=False),
|
||||
required=True,
|
||||
)
|
||||
@click.option(
|
||||
"--dir_images",
|
||||
"-di",
|
||||
help="directory of org images. If print space cropping or scaling is needed for labels it would be great to provide the original images to apply the same function on them. So if -ps is not set true or in config files no columns_width key is given this argumnet can be ignored. File stems in this directory should be the same as those in dir_xml.",
|
||||
help="input directory of GT image files (only needed for '--printspace' or scaling configured via 'columns_width'; filename stems should match those in --dir_xml)",
|
||||
type=click.Path(exists=True, file_okay=False),
|
||||
)
|
||||
@click.option(
|
||||
"--dir_out_images",
|
||||
"-doi",
|
||||
help="directory where the output org images after undergoing a process (like print space cropping or scaling) will be written.",
|
||||
help="output directory for training image files (for printspace cropping or scaling)",
|
||||
type=click.Path(exists=True, file_okay=False),
|
||||
)
|
||||
@click.option(
|
||||
"--dir_out",
|
||||
"-do",
|
||||
help="directory where ground truth label images would be written",
|
||||
help="output directory for training label files",
|
||||
type=click.Path(exists=True, file_okay=False),
|
||||
required=True,
|
||||
)
|
||||
|
||||
@click.option(
|
||||
|
|
@ -64,24 +69,45 @@ def main():
|
|||
@click.option(
|
||||
"--type_output",
|
||||
"-to",
|
||||
help="this defines how output should be. A 2d image array or a 3d image array encoded with RGB color. Just pass 2d or 3d. The file will be saved one directory up. 2D image array is 3d but only information of one channel would be enough since all channels have the same values.",
|
||||
type=click.Choice(["2d", "3d"]),
|
||||
default="2d",
|
||||
help="generate labels as [H, W] array pseudo index-color images for training ('2d') or [H, W, C] array RGB color images for plotting ('3d')",
|
||||
)
|
||||
@click.option(
|
||||
"--printspace",
|
||||
"-ps",
|
||||
is_flag=True,
|
||||
help="if this parameter set to true, generated labels and in the case of provided org images cropping will be imposed and cropped labels and images will be written in output directories.",
|
||||
help="crop pages from annotated PrintSpace or Border to generate labels and images (will also require -di for so original images so output images are cropped along with labels)",
|
||||
)
|
||||
@click.option(
|
||||
"--missing-printspace",
|
||||
"-mps",
|
||||
type=click.Choice(["full", "skip", "project"]),
|
||||
default="full",
|
||||
help="if -ps is set, what to do in case a PAGE-XML has no PrintSpace or Border annotation: keep entire page ('full'), ignore file ('skip') or crop artificially from outer hull of all segments ('project')",
|
||||
)
|
||||
|
||||
def pagexml2label(dir_xml,dir_out,type_output,config, printspace, dir_images, dir_out_images):
|
||||
def pagexml2label(dir_xml, dir_out, type_output, config, printspace, missing_printspace, dir_images, dir_out_images):
|
||||
"""
|
||||
extract PAGE-XML GT data suitable for model training for segmentation tasks
|
||||
"""
|
||||
if config:
|
||||
with open(config) as f:
|
||||
config_params = json.load(f)
|
||||
else:
|
||||
print("passed")
|
||||
config_params = None
|
||||
gt_list = get_content_of_dir(dir_xml)
|
||||
get_images_of_ground_truth(gt_list,dir_xml,dir_out,type_output, config, config_params, printspace, dir_images, dir_out_images)
|
||||
get_images_of_ground_truth(get_content_of_dir(dir_xml),
|
||||
dir_xml,
|
||||
dir_out,
|
||||
type_output,
|
||||
config,
|
||||
config_params,
|
||||
printspace,
|
||||
missing_printspace,
|
||||
dir_images,
|
||||
dir_out_images
|
||||
)
|
||||
|
||||
@main.command()
|
||||
@click.option(
|
||||
|
|
@ -110,6 +136,9 @@ def pagexml2label(dir_xml,dir_out,type_output,config, printspace, dir_images, di
|
|||
type=click.Path(exists=True, dir_okay=False),
|
||||
)
|
||||
def image_enhancement(dir_imgs, dir_out_images, dir_out_labels, scales):
|
||||
"""
|
||||
extract image GT data suitable for model training for image enhancement tasks
|
||||
"""
|
||||
ls_imgs = os.listdir(dir_imgs)
|
||||
with open(scales) as f:
|
||||
scale_dict = json.load(f)
|
||||
|
|
@ -175,6 +204,9 @@ def image_enhancement(dir_imgs, dir_out_images, dir_out_labels, scales):
|
|||
)
|
||||
|
||||
def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, input_height, input_width, min_area_size, min_area_early):
|
||||
"""
|
||||
extract PAGE-XML GT data suitable for model training for reading-order task
|
||||
"""
|
||||
xml_files_ind = os.listdir(dir_xml)
|
||||
xml_files_ind = [ind_xml for ind_xml in xml_files_ind if ind_xml.endswith('.xml')]
|
||||
input_height = int(input_height)
|
||||
|
|
@ -205,14 +237,20 @@ def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, i
|
|||
img_header_and_sep = np.zeros((y_len,x_len), dtype='uint8')
|
||||
|
||||
for j in range(len(cy_main)):
|
||||
img_header_and_sep[int(y_max_main[j]):int(y_max_main[j])+12,int(x_min_main[j]):int(x_max_main[j]) ] = 1
|
||||
img_header_and_sep[int(y_max_main[j]):int(y_max_main[j])+12,
|
||||
int(x_min_main[j]):int(x_max_main[j]) ] = 1
|
||||
|
||||
|
||||
texts_corr_order_index = [index_tot_regions[tot_region_ref.index(i)] for i in id_all_text ]
|
||||
texts_corr_order_index_int = [int(x) for x in texts_corr_order_index]
|
||||
try:
|
||||
texts_corr_order_index_int = [int(index_tot_regions[tot_region_ref.index(i)])
|
||||
for i in id_all_text]
|
||||
except ValueError as e:
|
||||
print("incomplete ReadingOrder in", xml_file, "- skipping:", str(e))
|
||||
continue
|
||||
|
||||
|
||||
co_text_all, texts_corr_order_index_int, regions_ar_less_than_early_min = filter_contours_area_of_image(img_poly, co_text_all, texts_corr_order_index_int, max_area, min_area, min_area_early)
|
||||
co_text_all, texts_corr_order_index_int, regions_ar_less_than_early_min = \
|
||||
filter_contours_area_of_image(img_poly, co_text_all, texts_corr_order_index_int,
|
||||
max_area, min_area, min_area_early)
|
||||
|
||||
|
||||
arg_array = np.array(range(len(texts_corr_order_index_int)))
|
||||
|
|
|
|||
|
|
@ -1,15 +1,18 @@
|
|||
import os
|
||||
import numpy as np
|
||||
import warnings
|
||||
import xml.etree.ElementTree as ET
|
||||
from lxml import etree as ET
|
||||
from tqdm import tqdm
|
||||
import cv2
|
||||
from shapely import geometry
|
||||
from pathlib import Path
|
||||
from PIL import ImageFont
|
||||
from ocrd_utils import bbox_from_points
|
||||
|
||||
|
||||
KERNEL = np.ones((5, 5), np.uint8)
|
||||
NS = { 'pc': 'http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15'
|
||||
}
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
|
|
@ -235,12 +238,11 @@ def update_region_contours(co_text, img_boundary, erosion_rate, dilation_rate, y
|
|||
con_eroded = return_contours_of_interested_region(img_boundary_in,pixel, min_size )
|
||||
|
||||
try:
|
||||
if len(con_eroded)>1:
|
||||
cnt_size = np.array([cv2.contourArea(con_eroded[j]) for j in range(len(con_eroded))])
|
||||
cnt = contours[np.argmax(cnt_size)]
|
||||
co_text_eroded.append(cnt)
|
||||
if len(con_eroded) > 1:
|
||||
largest = np.argmax(list(map(cv2.contourArea, con_eroded)))
|
||||
else:
|
||||
co_text_eroded.append(con_eroded[0])
|
||||
largest = 0
|
||||
co_text_eroded.append(con_eroded[largest])
|
||||
except:
|
||||
co_text_eroded.append(con)
|
||||
|
||||
|
|
@ -656,7 +658,18 @@ def get_layout_contours_for_visualization(xml_file):
|
|||
co_noise.append(np.array(c_t_in))
|
||||
return co_text, co_graphic, co_sep, co_img, co_table, co_map, co_noise, y_len, x_len
|
||||
|
||||
def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_file, config_params, printspace, dir_images, dir_out_images):
|
||||
def get_images_of_ground_truth(
|
||||
gt_list,
|
||||
dir_in,
|
||||
output_dir,
|
||||
output_type,
|
||||
config_file,
|
||||
config_params,
|
||||
printspace,
|
||||
missing_printspace,
|
||||
dir_images,
|
||||
dir_out_images
|
||||
):
|
||||
"""
|
||||
Reading the page xml files and write the ground truth images into given output directory.
|
||||
"""
|
||||
|
|
@ -664,7 +677,10 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
|
|||
|
||||
if dir_images:
|
||||
ls_org_imgs = os.listdir(dir_images)
|
||||
ls_org_imgs_stem = [os.path.splitext(item)[0] for item in ls_org_imgs]
|
||||
ls_org_imgs = {os.path.splitext(item)[0]: item
|
||||
for item in ls_org_imgs
|
||||
if not item.endswith('.xml')}
|
||||
|
||||
for index in tqdm(range(len(gt_list))):
|
||||
#try:
|
||||
print(gt_list[index])
|
||||
|
|
@ -681,6 +697,7 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
|
|||
|
||||
if 'columns_width' in list(config_params.keys()):
|
||||
columns_width_dict = config_params['columns_width']
|
||||
# FIXME: look in /Page/@custom as well
|
||||
metadata_element = root1.find(link+'Metadata')
|
||||
num_col = None
|
||||
for child in metadata_element:
|
||||
|
|
@ -694,55 +711,27 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
|
|||
y_new = int ( x_new * (y_len / float(x_len)) )
|
||||
|
||||
if printspace or "printspace_as_class_in_layout" in list(config_params.keys()):
|
||||
region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace') or x.endswith('Border')])
|
||||
co_use_case = []
|
||||
|
||||
for tag in region_tags:
|
||||
tag_endings = ['}PrintSpace','}Border']
|
||||
|
||||
if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]):
|
||||
for nn in root1.iter(tag):
|
||||
c_t_in = []
|
||||
sumi = 0
|
||||
for vv in nn.iter():
|
||||
# check the format of coords
|
||||
if vv.tag == link + 'Coords':
|
||||
coords = bool(vv.attrib)
|
||||
if coords:
|
||||
p_h = vv.attrib['points'].split(' ')
|
||||
c_t_in.append(
|
||||
np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h]))
|
||||
break
|
||||
else:
|
||||
pass
|
||||
|
||||
if vv.tag == link + 'Point':
|
||||
c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))])
|
||||
sumi += 1
|
||||
elif vv.tag != link + 'Point' and sumi >= 1:
|
||||
break
|
||||
co_use_case.append(np.array(c_t_in))
|
||||
|
||||
img = np.zeros((y_len, x_len, 3))
|
||||
|
||||
img_poly = cv2.fillPoly(img, pts=co_use_case, color=(1, 1, 1))
|
||||
|
||||
img_poly = img_poly.astype(np.uint8)
|
||||
|
||||
imgray = cv2.cvtColor(img_poly, cv2.COLOR_BGR2GRAY)
|
||||
_, thresh = cv2.threshold(imgray, 0, 255, 0)
|
||||
|
||||
contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
cnt_size = np.array([cv2.contourArea(contours[j]) for j in range(len(contours))])
|
||||
|
||||
try:
|
||||
cnt = contours[np.argmax(cnt_size)]
|
||||
x, y, w, h = cv2.boundingRect(cnt)
|
||||
except:
|
||||
x, y , w, h = 0, 0, x_len, y_len
|
||||
|
||||
bb_xywh = [x, y, w, h]
|
||||
ps = (root1.xpath('/pc:PcGts/pc:Page/pc:Border', namespaces=NS) +
|
||||
root1.xpath('/pc:PcGts/pc:Page/pc:PrintSpace', namespaces=NS))
|
||||
coords = root1.xpath('//pc:Coords/@points', namespaces=NS)
|
||||
if len(ps):
|
||||
points = ps[0].find('pc:Coords', NS).get('points')
|
||||
ps_bbox = bbox_from_points(points)
|
||||
elif missing_printspace == 'skip':
|
||||
print(gt_list[index], "has no Border or PrintSpace - skipping file")
|
||||
continue
|
||||
elif missing_printspace == 'project' and len(coords):
|
||||
print(gt_list[index], "has no Border or PrintSpace - projecting hull of segments")
|
||||
bboxes = list(map(bbox_from_points, coords))
|
||||
left, top, right, bottom = zip(*bboxes)
|
||||
left = max(0, min(left) - 5)
|
||||
top = max(0, min(top) - 5)
|
||||
right = min(x_len, max(right) + 5)
|
||||
bottom = min(y_len, max(bottom) + 5)
|
||||
ps_bbox = [left, top, right, bottom]
|
||||
else:
|
||||
print(gt_list[index], "has no Border or PrintSpace - using full page")
|
||||
ps_bbox = [0, 0, None, None]
|
||||
|
||||
|
||||
if config_file and (config_params['use_case']=='textline' or config_params['use_case']=='word' or config_params['use_case']=='glyph' or config_params['use_case']=='printspace'):
|
||||
|
|
@ -824,7 +813,8 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
|
|||
|
||||
|
||||
if printspace and config_params['use_case']!='printspace':
|
||||
img_poly = img_poly[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :]
|
||||
img_poly = img_poly[ps_bbox[1]:ps_bbox[3],
|
||||
ps_bbox[0]:ps_bbox[2], :]
|
||||
|
||||
|
||||
if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace':
|
||||
|
|
@ -838,11 +828,18 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
|
|||
cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly)
|
||||
|
||||
if dir_images:
|
||||
org_image_name = ls_org_imgs[ls_org_imgs_stem.index(xml_file_stem)]
|
||||
org_image_name = ls_org_imgs[xml_file_stem]
|
||||
if not org_image_name:
|
||||
print("image file for XML stem", xml_file_stem, "is missing")
|
||||
continue
|
||||
if not os.path.isfile(os.path.join(dir_images, org_image_name)):
|
||||
print("image file for XML stem", xml_file_stem, "is not readable")
|
||||
continue
|
||||
img_org = cv2.imread(os.path.join(dir_images, org_image_name))
|
||||
|
||||
if printspace and config_params['use_case']!='printspace':
|
||||
img_org = img_org[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :]
|
||||
img_org = img_org[ps_bbox[1]:ps_bbox[3],
|
||||
ps_bbox[0]:ps_bbox[2], :]
|
||||
|
||||
if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace':
|
||||
img_org = resize_image(img_org, y_new, x_new)
|
||||
|
|
@ -1254,7 +1251,8 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
|
|||
|
||||
if "printspace_as_class_in_layout" in list(config_params.keys()):
|
||||
printspace_mask = np.zeros((img_poly.shape[0], img_poly.shape[1]))
|
||||
printspace_mask[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2]] = 1
|
||||
printspace_mask[ps_bbox[1]:ps_bbox[3],
|
||||
ps_bbox[0]:ps_bbox[2]] = 1
|
||||
|
||||
img_poly[:,:,0][printspace_mask[:,:] == 0] = printspace_class_rgb_color[0]
|
||||
img_poly[:,:,1][printspace_mask[:,:] == 0] = printspace_class_rgb_color[1]
|
||||
|
|
@ -1315,7 +1313,8 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
|
|||
|
||||
if "printspace_as_class_in_layout" in list(config_params.keys()):
|
||||
printspace_mask = np.zeros((img_poly.shape[0], img_poly.shape[1]))
|
||||
printspace_mask[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2]] = 1
|
||||
printspace_mask[ps_bbox[1]:ps_bbox[3],
|
||||
ps_bbox[0]:ps_bbox[2]] = 1
|
||||
|
||||
img_poly[:,:,0][printspace_mask[:,:] == 0] = printspace_class_label
|
||||
img_poly[:,:,1][printspace_mask[:,:] == 0] = printspace_class_label
|
||||
|
|
@ -1324,7 +1323,8 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
|
|||
|
||||
|
||||
if printspace:
|
||||
img_poly = img_poly[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :]
|
||||
img_poly = img_poly[ps_bbox[1]:ps_bbox[3],
|
||||
ps_bbox[0]:ps_bbox[2], :]
|
||||
|
||||
if 'columns_width' in list(config_params.keys()) and num_col:
|
||||
img_poly = resize_image(img_poly, y_new, x_new)
|
||||
|
|
@ -1338,11 +1338,18 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
|
|||
|
||||
|
||||
if dir_images:
|
||||
org_image_name = ls_org_imgs[ls_org_imgs_stem.index(xml_file_stem)]
|
||||
org_image_name = ls_org_imgs[xml_file_stem]
|
||||
if not org_image_name:
|
||||
print("image file for XML stem", xml_file_stem, "is missing")
|
||||
continue
|
||||
if not os.path.isfile(os.path.join(dir_images, org_image_name)):
|
||||
print("image file for XML stem", xml_file_stem, "is not readable")
|
||||
continue
|
||||
img_org = cv2.imread(os.path.join(dir_images, org_image_name))
|
||||
|
||||
if printspace:
|
||||
img_org = img_org[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :]
|
||||
img_org = img_org[ps_bbox[1]:ps_bbox[3],
|
||||
ps_bbox[0]:ps_bbox[2], :]
|
||||
|
||||
if 'columns_width' in list(config_params.keys()) and num_col:
|
||||
img_org = resize_image(img_org, y_new, x_new)
|
||||
|
|
@ -1383,6 +1390,7 @@ def find_new_features_of_contours(contours_main):
|
|||
y_max_main = np.array([np.max(contours_main[j][:, 1]) for j in range(len(contours_main))])
|
||||
|
||||
return cx_main, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, y_corr_x_min_from_argmin
|
||||
|
||||
def read_xml(xml_file):
|
||||
file_name = Path(xml_file).stem
|
||||
tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8'))
|
||||
|
|
@ -1401,57 +1409,13 @@ def read_xml(xml_file):
|
|||
index_tot_regions.append(jj.attrib['index'])
|
||||
tot_region_ref.append(jj.attrib['regionRef'])
|
||||
|
||||
if (link+'PrintSpace' in alltags) or (link+'Border' in alltags):
|
||||
co_printspace = []
|
||||
if link+'PrintSpace' in alltags:
|
||||
region_tags_printspace = np.unique([x for x in alltags if x.endswith('PrintSpace')])
|
||||
elif link+'Border' in alltags:
|
||||
region_tags_printspace = np.unique([x for x in alltags if x.endswith('Border')])
|
||||
|
||||
for tag in region_tags_printspace:
|
||||
if link+'PrintSpace' in alltags:
|
||||
tag_endings_printspace = ['}PrintSpace','}printspace']
|
||||
elif link+'Border' in alltags:
|
||||
tag_endings_printspace = ['}Border','}border']
|
||||
|
||||
if tag.endswith(tag_endings_printspace[0]) or tag.endswith(tag_endings_printspace[1]):
|
||||
for nn in root1.iter(tag):
|
||||
c_t_in = []
|
||||
sumi = 0
|
||||
for vv in nn.iter():
|
||||
# check the format of coords
|
||||
if vv.tag == link + 'Coords':
|
||||
coords = bool(vv.attrib)
|
||||
if coords:
|
||||
p_h = vv.attrib['points'].split(' ')
|
||||
c_t_in.append(
|
||||
np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h]))
|
||||
break
|
||||
else:
|
||||
pass
|
||||
|
||||
if vv.tag == link + 'Point':
|
||||
c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))])
|
||||
sumi += 1
|
||||
elif vv.tag != link + 'Point' and sumi >= 1:
|
||||
break
|
||||
co_printspace.append(np.array(c_t_in))
|
||||
img_printspace = np.zeros( (y_len,x_len,3) )
|
||||
img_printspace=cv2.fillPoly(img_printspace, pts =co_printspace, color=(1,1,1))
|
||||
img_printspace = img_printspace.astype(np.uint8)
|
||||
|
||||
imgray = cv2.cvtColor(img_printspace, cv2.COLOR_BGR2GRAY)
|
||||
_, thresh = cv2.threshold(imgray, 0, 255, 0)
|
||||
contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
|
||||
cnt_size = np.array([cv2.contourArea(contours[j]) for j in range(len(contours))])
|
||||
cnt = contours[np.argmax(cnt_size)]
|
||||
x, y, w, h = cv2.boundingRect(cnt)
|
||||
|
||||
bb_coord_printspace = [x, y, w, h]
|
||||
|
||||
ps = (root1.xpath('/pc:PcGts/pc:Page/pc:Border', namespaces=NS) +
|
||||
root1.xpath('/pc:PcGts/pc:Page/pc:PrintSpace', namespaces=NS))
|
||||
if len(ps):
|
||||
points = ps[0].find('pc:Coords', NS).get('points')
|
||||
ps_bbox = bbox_from_points(points)
|
||||
else:
|
||||
bb_coord_printspace = None
|
||||
|
||||
ps_bbox = [0, 0, None, None]
|
||||
|
||||
region_tags=np.unique([x for x in alltags if x.endswith('Region')])
|
||||
co_text_paragraph=[]
|
||||
|
|
@ -1806,11 +1770,19 @@ def read_xml(xml_file):
|
|||
img_poly=cv2.fillPoly(img, pts =co_img, color=(4,4,4))
|
||||
img_poly=cv2.fillPoly(img, pts =co_sep, color=(5,5,5))
|
||||
|
||||
return tree1, root1, bb_coord_printspace, file_name, id_paragraph, id_header+id_heading, co_text_paragraph, co_text_header+co_text_heading,\
|
||||
tot_region_ref,x_len, y_len,index_tot_regions, img_poly
|
||||
|
||||
|
||||
|
||||
return (tree1,
|
||||
root1,
|
||||
ps_bbox,
|
||||
file_name,
|
||||
id_paragraph,
|
||||
id_header + id_heading,
|
||||
co_text_paragraph,
|
||||
co_text_header + co_text_heading,
|
||||
tot_region_ref,
|
||||
x_len,
|
||||
y_len,
|
||||
index_tot_regions,
|
||||
img_poly)
|
||||
|
||||
# def bounding_box(cnt,color, corr_order_index ):
|
||||
# x, y, w, h = cv2.boundingRect(cnt)
|
||||
|
|
|
|||
|
|
@ -1,19 +1,24 @@
|
|||
"""
|
||||
Tool to load model and predict for given image.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from typing import Tuple
|
||||
import warnings
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
from numpy._typing import NDArray
|
||||
import tensorflow as tf
|
||||
from keras.models import Model, load_model
|
||||
from keras import backend as K
|
||||
import click
|
||||
from tensorflow.python.keras import backend as tensorflow_backend
|
||||
import numpy as np
|
||||
from numpy._typing import NDArray
|
||||
import cv2
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras.models import Model, load_model
|
||||
from tensorflow.keras.layers import StringLookup
|
||||
|
||||
from .gt_gen_utils import (
|
||||
filter_contours_area_of_image,
|
||||
find_new_features_of_contours,
|
||||
|
|
@ -21,24 +26,37 @@ from .gt_gen_utils import (
|
|||
resize_image,
|
||||
update_list_and_return_first_with_length_bigger_than_one
|
||||
)
|
||||
from .models import (
|
||||
from ..patch_encoder import (
|
||||
PatchEncoder,
|
||||
Patches
|
||||
)
|
||||
from .metrics import (
|
||||
soft_dice_loss,
|
||||
weighted_categorical_crossentropy,
|
||||
)
|
||||
from.utils import scale_padd_image_for_ocr
|
||||
from ..utils.utils_ocr import decode_batch_predictions
|
||||
|
||||
from.utils import (scale_padd_image_for_ocr)
|
||||
from eynollah.utils.utils_ocr import (decode_batch_predictions)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
__doc__=\
|
||||
"""
|
||||
Tool to load model and predict for given image.
|
||||
"""
|
||||
|
||||
class sbb_predict:
|
||||
def __init__(self,image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area):
|
||||
class SBBPredict:
|
||||
def __init__(self,
|
||||
image,
|
||||
dir_in,
|
||||
model,
|
||||
task,
|
||||
config_params_model,
|
||||
patches,
|
||||
save,
|
||||
save_layout,
|
||||
ground_truth,
|
||||
xml_file,
|
||||
cpu,
|
||||
out,
|
||||
min_area,
|
||||
):
|
||||
self.image=image
|
||||
self.dir_in=dir_in
|
||||
self.patches=patches
|
||||
|
|
@ -57,8 +75,9 @@ class sbb_predict:
|
|||
self.min_area = 0
|
||||
|
||||
def resize_image(self,img_in,input_height,input_width):
|
||||
return cv2.resize( img_in, ( input_width,input_height) ,interpolation=cv2.INTER_NEAREST)
|
||||
|
||||
return cv2.resize(img_in, (input_width,
|
||||
input_height),
|
||||
interpolation=cv2.INTER_NEAREST)
|
||||
|
||||
def color_images(self,seg):
|
||||
ann_u=range(self.n_classes)
|
||||
|
|
@ -74,68 +93,6 @@ class sbb_predict:
|
|||
seg_img[:,:,2][seg==c]=c
|
||||
return seg_img
|
||||
|
||||
def otsu_copy_binary(self,img):
|
||||
img_r=np.zeros((img.shape[0],img.shape[1],3))
|
||||
img1=img[:,:,0]
|
||||
|
||||
#print(img.min())
|
||||
#print(img[:,:,0].min())
|
||||
#blur = cv2.GaussianBlur(img,(5,5))
|
||||
#ret3,th3 = cv2.threshold(blur,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU)
|
||||
_, threshold1 = cv2.threshold(img1, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
|
||||
|
||||
|
||||
|
||||
img_r[:,:,0]=threshold1
|
||||
img_r[:,:,1]=threshold1
|
||||
img_r[:,:,2]=threshold1
|
||||
#img_r=img_r/float(np.max(img_r))*255
|
||||
return img_r
|
||||
|
||||
def otsu_copy(self,img):
|
||||
img_r=np.zeros((img.shape[0],img.shape[1],3))
|
||||
#img1=img[:,:,0]
|
||||
|
||||
#print(img.min())
|
||||
#print(img[:,:,0].min())
|
||||
#blur = cv2.GaussianBlur(img,(5,5))
|
||||
#ret3,th3 = cv2.threshold(blur,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU)
|
||||
_, threshold1 = cv2.threshold(img[:,:,0], 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
|
||||
_, threshold2 = cv2.threshold(img[:,:,1], 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
|
||||
_, threshold3 = cv2.threshold(img[:,:,2], 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
|
||||
|
||||
|
||||
|
||||
img_r[:,:,0]=threshold1
|
||||
img_r[:,:,1]=threshold2
|
||||
img_r[:,:,2]=threshold3
|
||||
###img_r=img_r/float(np.max(img_r))*255
|
||||
return img_r
|
||||
|
||||
def soft_dice_loss(self,y_true, y_pred, epsilon=1e-6):
|
||||
|
||||
axes = tuple(range(1, len(y_pred.shape)-1))
|
||||
|
||||
numerator = 2. * K.sum(y_pred * y_true, axes)
|
||||
|
||||
denominator = K.sum(K.square(y_pred) + K.square(y_true), axes)
|
||||
return 1.00 - K.mean(numerator / (denominator + epsilon)) # average over classes and batch
|
||||
|
||||
# def weighted_categorical_crossentropy(self,weights=None):
|
||||
#
|
||||
# def loss(y_true, y_pred):
|
||||
# labels_floats = tf.cast(y_true, tf.float32)
|
||||
# per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats,logits=y_pred)
|
||||
#
|
||||
# if weights is not None:
|
||||
# weight_mask = tf.maximum(tf.reduce_max(tf.constant(
|
||||
# np.array(weights, dtype=np.float32)[None, None, None])
|
||||
# * labels_floats, axis=-1), 1.0)
|
||||
# per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None]
|
||||
# return tf.reduce_mean(per_pixel_loss)
|
||||
# return self.loss
|
||||
|
||||
|
||||
def IoU(self,Yi,y_predi):
|
||||
## mean Intersection over Union
|
||||
## Mean IoU = TP/(FN + TP + FP)
|
||||
|
|
@ -162,29 +119,33 @@ class sbb_predict:
|
|||
return mIoU
|
||||
|
||||
def start_new_session_and_model(self):
|
||||
if self.task == "cnn-rnn-ocr":
|
||||
if self.cpu:
|
||||
os.environ['CUDA_VISIBLE_DEVICES']='-1'
|
||||
self.model = load_model(self.model_dir)
|
||||
self.model = tf.keras.models.Model(
|
||||
self.model.get_layer(name = "image").input,
|
||||
self.model.get_layer(name = "dense2").output)
|
||||
if self.cpu:
|
||||
tf.config.set_visible_devices([], 'GPU')
|
||||
else:
|
||||
config = tf.compat.v1.ConfigProto()
|
||||
config.gpu_options.allow_growth = True
|
||||
|
||||
session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
|
||||
tensorflow_backend.set_session(session)
|
||||
try:
|
||||
for device in tf.config.list_physical_devices('GPU'):
|
||||
tf.config.experimental.set_memory_growth(device, True)
|
||||
except:
|
||||
print("no GPU device available", file=sys.stderr)
|
||||
|
||||
if self.task == "cnn-rnn-ocr":
|
||||
self.model = Model(
|
||||
self.model.get_layer(name = "image").input,
|
||||
self.model.get_layer(name = "dense2").output)
|
||||
else:
|
||||
self.model = load_model(self.model_dir, compile=False,
|
||||
custom_objects={"PatchEncoder": PatchEncoder,
|
||||
"Patches": Patches})
|
||||
|
||||
##if self.weights_dir!=None:
|
||||
##self.model.load_weights(self.weights_dir)
|
||||
|
||||
assert isinstance(self.model, Model)
|
||||
if self.task != 'classification' and self.task != 'reading_order':
|
||||
self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1]
|
||||
self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2]
|
||||
self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3]
|
||||
last = self.model.layers[-1]
|
||||
self.img_height = last.output_shape[1]
|
||||
self.img_width = last.output_shape[2]
|
||||
self.n_classes = last.output_shape[3]
|
||||
|
||||
def visualize_model_output(self, prediction, img, task) -> Tuple[NDArray, NDArray]:
|
||||
if task == "binarization":
|
||||
|
|
@ -212,22 +173,17 @@ class sbb_predict:
|
|||
'15' : [255, 0, 255]}
|
||||
|
||||
layout_only = np.zeros(prediction.shape)
|
||||
|
||||
for unq_class in unique_classes:
|
||||
where = prediction[:,:,0]==unq_class
|
||||
rgb_class_unique = rgb_colors[str(int(unq_class))]
|
||||
layout_only[:,:,0][prediction[:,:,0]==unq_class] = rgb_class_unique[0]
|
||||
layout_only[:,:,1][prediction[:,:,0]==unq_class] = rgb_class_unique[1]
|
||||
layout_only[:,:,2][prediction[:,:,0]==unq_class] = rgb_class_unique[2]
|
||||
|
||||
|
||||
layout_only[:,:,0][where] = rgb_class_unique[0]
|
||||
layout_only[:,:,1][where] = rgb_class_unique[1]
|
||||
layout_only[:,:,2][where] = rgb_class_unique[2]
|
||||
layout_only = layout_only.astype(np.int32)
|
||||
|
||||
img = self.resize_image(img, layout_only.shape[0], layout_only.shape[1])
|
||||
|
||||
layout_only = layout_only.astype(np.int32)
|
||||
img = img.astype(np.int32)
|
||||
|
||||
|
||||
|
||||
added_image = cv2.addWeighted(img,0.5,layout_only,0.1,0)
|
||||
|
||||
assert isinstance(added_image, np.ndarray)
|
||||
|
|
@ -238,10 +194,10 @@ class sbb_predict:
|
|||
assert isinstance(self.model, Model)
|
||||
if self.task == 'classification':
|
||||
classes_names = self.config_params_model['classification_classes_name']
|
||||
img_1ch = img=cv2.imread(image_dir, 0)
|
||||
|
||||
img_1ch = img_1ch / 255.0
|
||||
img_1ch = cv2.resize(img_1ch, (self.config_params_model['input_height'], self.config_params_model['input_width']), interpolation=cv2.INTER_NEAREST)
|
||||
img_1ch = cv2.imread(image_dir, 0) / 255.0
|
||||
img_1ch = cv2.resize(img_1ch, (self.config_params_model['input_height'],
|
||||
self.config_params_model['input_width']),
|
||||
interpolation=cv2.INTER_NEAREST)
|
||||
img_in = np.zeros((1, img_1ch.shape[0], img_1ch.shape[1], 3))
|
||||
img_in[0, :, :, 0] = img_1ch[:, :]
|
||||
img_in[0, :, :, 1] = img_1ch[:, :]
|
||||
|
|
@ -251,6 +207,7 @@ class sbb_predict:
|
|||
index_class = np.argmax(label_p_pred[0])
|
||||
|
||||
print("Predicted Class: {}".format(classes_names[str(int(index_class))]))
|
||||
|
||||
elif self.task == "cnn-rnn-ocr":
|
||||
img=cv2.imread(image_dir)
|
||||
img = scale_padd_image_for_ocr(img, self.config_params_model['input_height'], self.config_params_model['input_width'])
|
||||
|
|
@ -279,19 +236,22 @@ class sbb_predict:
|
|||
img_height = self.config_params_model['input_height']
|
||||
img_width = self.config_params_model['input_width']
|
||||
|
||||
tree_xml, root_xml, bb_coord_printspace, file_name, id_paragraph, id_header, co_text_paragraph, co_text_header, tot_region_ref, x_len, y_len, index_tot_regions, img_poly = read_xml(self.xml_file)
|
||||
_, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, _ = find_new_features_of_contours(co_text_header)
|
||||
tree_xml, root_xml, ps_bbox, file_name, \
|
||||
id_paragraph, id_header, \
|
||||
co_text_paragraph, co_text_header, \
|
||||
tot_region_ref, x_len, y_len, index_tot_regions, \
|
||||
img_poly = read_xml(self.xml_file)
|
||||
_, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, _ = \
|
||||
find_new_features_of_contours(co_text_header)
|
||||
|
||||
img_header_and_sep = np.zeros((y_len,x_len), dtype='uint8')
|
||||
|
||||
|
||||
for j in range(len(cy_main)):
|
||||
img_header_and_sep[int(y_max_main[j]):int(y_max_main[j])+12,int(x_min_main[j]):int(x_max_main[j]) ] = 1
|
||||
img_header_and_sep[int(y_max_main[j]): int(y_max_main[j]) + 12,
|
||||
int(x_min_main[j]): int(x_max_main[j])] = 1
|
||||
|
||||
co_text_all = co_text_paragraph + co_text_header
|
||||
id_all_text = id_paragraph + id_header
|
||||
|
||||
|
||||
##texts_corr_order_index = [index_tot_regions[tot_region_ref.index(i)] for i in id_all_text ]
|
||||
##texts_corr_order_index_int = [int(x) for x in texts_corr_order_index]
|
||||
texts_corr_order_index_int = list(np.array(range(len(co_text_all))))
|
||||
|
|
@ -302,7 +262,8 @@ class sbb_predict:
|
|||
#print(np.shape(co_text_all[0]), len( np.shape(co_text_all[0]) ),'co_text_all')
|
||||
#co_text_all = filter_contours_area_of_image_tables(img_poly, co_text_all, _, max_area, min_area)
|
||||
#print(co_text_all,'co_text_all')
|
||||
co_text_all, texts_corr_order_index_int, _ = filter_contours_area_of_image(img_poly, co_text_all, texts_corr_order_index_int, max_area, self.min_area)
|
||||
co_text_all, texts_corr_order_index_int, _ = filter_contours_area_of_image(
|
||||
img_poly, co_text_all, texts_corr_order_index_int, max_area, self.min_area)
|
||||
|
||||
#print(texts_corr_order_index_int)
|
||||
|
||||
|
|
@ -315,15 +276,13 @@ class sbb_predict:
|
|||
img_label=cv2.fillPoly(img_label, pts =[co_text_all[i]], color=(1,1,1))
|
||||
labels_con[:,:,i] = img_label[:,:,0]
|
||||
|
||||
if bb_coord_printspace:
|
||||
#bb_coord_printspace[x,y,w,h,_,_]
|
||||
x = bb_coord_printspace[0]
|
||||
y = bb_coord_printspace[1]
|
||||
w = bb_coord_printspace[2]
|
||||
h = bb_coord_printspace[3]
|
||||
labels_con = labels_con[y:y+h, x:x+w, :]
|
||||
img_poly = img_poly[y:y+h, x:x+w, :]
|
||||
img_header_and_sep = img_header_and_sep[y:y+h, x:x+w]
|
||||
if ps_bbox:
|
||||
labels_con = labels_con[ps_bbox[1]:ps_bbox[3],
|
||||
ps_bbox[0]:ps_bbox[2], :]
|
||||
img_poly = img_poly[ps_bbox[1]:ps_bbox[3],
|
||||
ps_bbox[0]:ps_bbox[2], :]
|
||||
img_header_and_sep = img_header_and_sep[ps_bbox[1]:ps_bbox[3],
|
||||
ps_bbox[0]:ps_bbox[2]]
|
||||
|
||||
|
||||
|
||||
|
|
@ -709,17 +668,15 @@ class sbb_predict:
|
|||
help="min area size of regions considered for reading order detection. The default value is zero and means that all text regions are considered for reading order.",
|
||||
)
|
||||
def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area):
|
||||
assert image or dir_in, "Either a single image -i or a dir_in -di is required"
|
||||
assert image or dir_in, "Either a single image -i or a dir_in -di input is required"
|
||||
with open(os.path.join(model,'config.json')) as f:
|
||||
config_params_model = json.load(f)
|
||||
task = config_params_model['task']
|
||||
if task != 'classification' and task != 'reading_order' and task != "cnn-rnn-ocr":
|
||||
if image and not save:
|
||||
print("Error: You used one of segmentation or binarization task with image input but not set -s, you need a filename to save visualized output with -s")
|
||||
sys.exit(1)
|
||||
if dir_in and not out:
|
||||
print("Error: You used one of segmentation or binarization task with dir_in but not set -out")
|
||||
sys.exit(1)
|
||||
x=sbb_predict(image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area)
|
||||
if task not in ['classification', 'reading_order', "cnn-rnn-ocr"]:
|
||||
assert not image or save, "For segmentation or binarization, an input single image -i also requires an output filename -s"
|
||||
assert not dir_in or out, "For segmentation or binarization, an input directory -di also requires an output directory -o"
|
||||
x = SBBPredict(image, dir_in, model, task, config_params_model,
|
||||
patches, save, save_layout, ground_truth, xml_file,
|
||||
cpu, out, min_area)
|
||||
x.run()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,17 @@
|
|||
from tensorflow.keras import backend as K
|
||||
import os
|
||||
|
||||
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras import backend as K
|
||||
from tensorflow.keras.metrics import Metric, MeanMetricWrapper, get
|
||||
from tensorflow.keras.initializers import Zeros
|
||||
from tensorflow_addons.image import connected_components
|
||||
import numpy as np
|
||||
|
||||
|
||||
def focal_loss(gamma=2., alpha=4.):
|
||||
EPS = K.epsilon()
|
||||
|
||||
def focal_loss(gamma=2., alpha=4., epsilon=EPS):
|
||||
gamma = float(gamma)
|
||||
alpha = float(alpha)
|
||||
|
||||
|
|
@ -27,7 +35,6 @@ def focal_loss(gamma=2., alpha=4.):
|
|||
Returns:
|
||||
[tensor] -- loss.
|
||||
"""
|
||||
epsilon = 1.e-9
|
||||
y_true = tf.convert_to_tensor(y_true, tf.float32)
|
||||
y_pred = tf.convert_to_tensor(y_pred, tf.float32)
|
||||
|
||||
|
|
@ -148,7 +155,7 @@ def generalized_dice_loss(y_true, y_pred):
|
|||
|
||||
|
||||
# TODO: document where this is from
|
||||
def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
|
||||
def soft_dice_loss(y_true, y_pred, epsilon=EPS):
|
||||
"""
|
||||
Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions.
|
||||
Assumes the `channels_last` format.
|
||||
|
|
@ -361,3 +368,159 @@ def jaccard_distance_loss(y_true, y_pred, smooth=100):
|
|||
sum_ = K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1)
|
||||
jac = (intersection + smooth) / (sum_ - intersection + smooth)
|
||||
return (1 - jac) * smooth
|
||||
|
||||
|
||||
def metrics_superposition(*metrics, weights=None):
|
||||
"""
|
||||
return a single metric derived by adding all given metrics
|
||||
|
||||
default weights are uniform
|
||||
"""
|
||||
if weights is None:
|
||||
weights = len(metrics) * [tf.constant(1.0)]
|
||||
def mixed(y_true, y_pred):
|
||||
results = []
|
||||
for metric, weight in zip(metrics, weights):
|
||||
results.append(metric(y_true, y_pred) * weight)
|
||||
return tf.reduce_mean(tf.stack(results), 0)
|
||||
mixed.__name__ = '/'.join(m.__name__ for m in metrics)
|
||||
return mixed
|
||||
|
||||
|
||||
class Superposition(MeanMetricWrapper):
|
||||
def __init__(self, metrics, weights=None, dtype=None):
|
||||
self._metrics = metrics
|
||||
self._weights = weights
|
||||
mixed = metrics_superposition(*metrics, weights=weights)
|
||||
super().__init__(mixed, name=mixed.__name__, dtype=dtype)
|
||||
def get_config(self):
|
||||
return dict(metrics=self._metrics,
|
||||
weights=self._weights,
|
||||
**super().get_config())
|
||||
|
||||
class ConfusionMatrix(Metric):
|
||||
def __init__(self, nlabels=None, nrm="all", name="confusion_matrix", dtype=tf.float32):
|
||||
super().__init__(name=name, dtype=dtype)
|
||||
assert nlabels is not None
|
||||
self._nlabels = nlabels
|
||||
self._shape = (self._nlabels, self._nlabels)
|
||||
self._matrix = self.add_weight(name, shape=self._shape,
|
||||
initializer=Zeros)
|
||||
assert nrm in ("all", "true", "pred", "none")
|
||||
self._nrm = nrm
|
||||
|
||||
def update_state(self, y_true, y_pred, sample_weight=None):
|
||||
y_pred = tf.math.argmax(y_pred, axis=-1)
|
||||
y_true = tf.math.argmax(y_true, axis=-1)
|
||||
|
||||
y_pred = tf.reshape(y_pred, shape=(-1,))
|
||||
y_true = tf.reshape(y_true, shape=(-1,))
|
||||
|
||||
y_pred.shape.assert_is_compatible_with(y_true.shape)
|
||||
confusion = tf.math.confusion_matrix(y_true, y_pred, num_classes=self._nlabels, dtype=self._dtype)
|
||||
|
||||
return self._matrix.assign_add(confusion)
|
||||
|
||||
def result(self):
|
||||
"""normalize"""
|
||||
if self._nrm == "all":
|
||||
denom = tf.math.reduce_sum(self._matrix, axis=(0, 1))
|
||||
elif self._nrm == "true":
|
||||
denom = tf.math.reduce_sum(self._matrix, axis=1, keepdims=True)
|
||||
elif self._nrm == "pred":
|
||||
denom = tf.math.reduce_sum(self._matrix, axis=0, keepdims=True)
|
||||
else:
|
||||
denom = tf.constant(1.0)
|
||||
return tf.math.divide_no_nan(self._matrix, denom)
|
||||
|
||||
def reset_state(self):
|
||||
for v in self.variables:
|
||||
v.assign(tf.zeros(shape=self._shape))
|
||||
|
||||
def get_config(self):
|
||||
return dict(nlabels=self._nlabels,
|
||||
**super().get_config())
|
||||
|
||||
def connected_components_loss(artificial=0):
|
||||
"""
|
||||
metric/loss function capturing the separability of segmentation maps
|
||||
|
||||
For both sides (true and predicted, resp.), computes
|
||||
1. the argmax() of class-wise softmax input (i.e. the segmentation map)
|
||||
2. the connected components (i.e. the instance label map)
|
||||
3. the max() (i.e. the highest label = nr of components)
|
||||
|
||||
The original idea was to then calculate a regression formula
|
||||
between those two targets. But it is insufficient to just
|
||||
approximate the same number of components, for they might be
|
||||
completely different (true components being merged, predicted
|
||||
components splitting others). We really want to capture the
|
||||
correspondence between those labels, which is localised.
|
||||
|
||||
For that we now calculate the label pairs and their counts.
|
||||
Looking at the M,N incidence matrix, we want those counts
|
||||
to be distributed orthogonally (ideally). So we compute a
|
||||
singular value decomposition and compare the sum total of
|
||||
singular values to the sum total of all label counts. The
|
||||
rate of the two determines a measure of congruence.
|
||||
|
||||
Moreover, for the case of artificial boundary segments around
|
||||
regions, optionally introduced by the training extractor to
|
||||
represent segment identity in the loss (and removed at runtime):
|
||||
Reduce this class to background as well.
|
||||
"""
|
||||
def metric(y_true, y_pred):
|
||||
if artificial:
|
||||
# convert artificial border class to background
|
||||
y_true = y_true[:, :, :, :artificial]
|
||||
y_pred = y_pred[:, :, :, :artificial]
|
||||
# [B, H, W, C]
|
||||
l_true = tf.math.argmax(y_true, axis=-1)
|
||||
l_pred = tf.math.argmax(y_pred, axis=-1)
|
||||
# [B, H, W]
|
||||
c_true = tf.cast(connected_components(l_true), tf.int64)
|
||||
c_pred = tf.cast(connected_components(l_pred), tf.int64)
|
||||
# [B, H, W]
|
||||
n_batch = y_true.shape[0]
|
||||
C_true = tf.math.reduce_max(c_true, (1, 2)) + 1
|
||||
C_pred = tf.math.reduce_max(c_pred, (1, 2)) + 1
|
||||
MODULUS = tf.constant(2**22, tf.int64)
|
||||
tf.debugging.assert_less(C_true, MODULUS,
|
||||
message="cannot compare segments: too many connected components in GT")
|
||||
tf.debugging.assert_less(C_pred, MODULUS,
|
||||
message="cannot compare segments: too many connected components in prediction")
|
||||
c_comb = MODULUS * c_pred + c_true
|
||||
tf.debugging.assert_greater_equal(c_comb, tf.constant(0, tf.int64),
|
||||
message="overflow pairing components")
|
||||
# [B, H, W]
|
||||
# tf.unique does not support batch dim, so...
|
||||
results = []
|
||||
for c_comb, C_true, C_pred in zip(
|
||||
tf.unstack(c_comb, num=n_batch),
|
||||
tf.unstack(C_true, num=n_batch),
|
||||
tf.unstack(C_pred, num=n_batch),
|
||||
):
|
||||
prod, _, count = tf.unique_with_counts(tf.reshape(c_comb, (-1,)))
|
||||
# [L]
|
||||
#corr = tf.zeros([C_pred, C_true], tf.int32)
|
||||
#corr[prod // 2**24, prod % 2**24] = count
|
||||
corr = tf.scatter_nd(tf.stack([prod // MODULUS, prod % MODULUS], axis=1),
|
||||
count, (C_pred, C_true))
|
||||
corr = tf.cast(corr, tf.float32)
|
||||
# [Cpred, Ctrue]
|
||||
sgv = tf.linalg.svd(corr, compute_uv=False)
|
||||
results.append(tf.reduce_sum(sgv) / tf.reduce_sum(corr))
|
||||
return 1.0 - tf.reduce_mean(tf.stack(results), 0)
|
||||
# c_true = tf.reshape(c_true, (n_batch, -1))
|
||||
# c_pred = tf.reshape(c_pred, (n_batch, -1))
|
||||
# # [B, H*W]
|
||||
# n_true = tf.math.reduce_max(c_true, axis=1)
|
||||
# n_pred = tf.math.reduce_max(c_pred, axis=1)
|
||||
# # [B]
|
||||
# diff = tf.cast(n_true - n_pred, tf.float32)
|
||||
# return tf.reduce_mean(tf.math.abs(diff) + alpha * diff, axis=-1)
|
||||
|
||||
metric.__name__ = 'nCC'
|
||||
metric._direction = 'down'
|
||||
return metric
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,14 @@
|
|||
from tensorflow import keras
|
||||
from keras.layers import (
|
||||
import os
|
||||
|
||||
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras.layers import (
|
||||
Activation,
|
||||
Add,
|
||||
AveragePooling2D,
|
||||
BatchNormalization,
|
||||
Bidirectional,
|
||||
Conv1D,
|
||||
Conv2D,
|
||||
Dense,
|
||||
Dropout,
|
||||
|
|
@ -13,34 +18,34 @@ from keras.layers import (
|
|||
Lambda,
|
||||
Layer,
|
||||
LayerNormalization,
|
||||
LSTM,
|
||||
MaxPooling2D,
|
||||
MultiHeadAttention,
|
||||
Reshape,
|
||||
UpSampling2D,
|
||||
ZeroPadding2D,
|
||||
add,
|
||||
concatenate
|
||||
)
|
||||
from keras.models import Model
|
||||
import tensorflow as tf
|
||||
# from keras import layers, models
|
||||
from keras.regularizers import l2
|
||||
from tensorflow.keras.models import Model
|
||||
from tensorflow.keras.regularizers import l2
|
||||
from tensorflow.keras.backend import ctc_batch_cost
|
||||
|
||||
from eynollah.patch_encoder import Patches, PatchEncoder
|
||||
from ..patch_encoder import Patches, PatchEncoder
|
||||
|
||||
##mlp_head_units = [512, 256]#[2048, 1024]
|
||||
###projection_dim = 64
|
||||
##transformer_layers = 2#8
|
||||
##num_heads = 1#4
|
||||
resnet50_Weights_path = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
|
||||
RESNET50_WEIGHTS_PATH = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
|
||||
RESNET50_WEIGHTS_URL = ('https://github.com/fchollet/deep-learning-models/releases/download/v0.2/'
|
||||
'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5')
|
||||
|
||||
IMAGE_ORDERING = 'channels_last'
|
||||
MERGE_AXIS = -1
|
||||
|
||||
|
||||
class CTCLayer(tf.keras.layers.Layer):
|
||||
def __init__(self, name=None):
|
||||
super().__init__(name=name)
|
||||
self.loss_fn = tf.keras.backend.ctc_batch_cost
|
||||
|
||||
class CTCLayer(Layer):
|
||||
def call(self, y_true, y_pred):
|
||||
batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
|
||||
input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
|
||||
|
|
@ -48,7 +53,7 @@ class CTCLayer(tf.keras.layers.Layer):
|
|||
|
||||
input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
|
||||
label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")
|
||||
loss = self.loss_fn(y_true, y_pred, input_length, label_length)
|
||||
loss = ctc_batch_cost(y_true, y_pred, input_length, label_length)
|
||||
self.add_loss(loss)
|
||||
|
||||
# At test time, just return the computed predictions.
|
||||
|
|
@ -61,14 +66,9 @@ def mlp(x, hidden_units, dropout_rate):
|
|||
return x
|
||||
|
||||
def one_side_pad(x):
|
||||
x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
|
||||
if IMAGE_ORDERING == 'channels_first':
|
||||
x = Lambda(lambda x: x[:, :, :-1, :-1])(x)
|
||||
elif IMAGE_ORDERING == 'channels_last':
|
||||
x = Lambda(lambda x: x[:, :-1, :-1, :])(x)
|
||||
x = ZeroPadding2D(((1, 0), (1, 0)), data_format=IMAGE_ORDERING)(x)
|
||||
return x
|
||||
|
||||
|
||||
def identity_block(input_tensor, kernel_size, filters, stage, block):
|
||||
"""The identity block is the block that has no conv layer at shortcut.
|
||||
# Arguments
|
||||
|
|
@ -151,6 +151,116 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2))
|
|||
x = Activation('relu')(x)
|
||||
return x
|
||||
|
||||
def resnet50(inputs, weight_decay=1e-6, pretraining=False):
|
||||
if IMAGE_ORDERING == 'channels_last':
|
||||
bn_axis = 3
|
||||
else:
|
||||
bn_axis = 1
|
||||
|
||||
x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(inputs)
|
||||
x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2), kernel_regularizer=l2(weight_decay),
|
||||
name='conv1')(x)
|
||||
f1 = x
|
||||
|
||||
x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
|
||||
x = Activation('relu')(x)
|
||||
x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x)
|
||||
|
||||
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
|
||||
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
|
||||
x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
|
||||
f2 = one_side_pad(x)
|
||||
|
||||
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
|
||||
f3 = x
|
||||
|
||||
x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
|
||||
f4 = x
|
||||
|
||||
x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
|
||||
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
|
||||
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
|
||||
f5 = x
|
||||
|
||||
if pretraining:
|
||||
model = Model(inputs, x).load_weights(RESNET50_WEIGHTS_PATH)
|
||||
|
||||
return f1, f2, f3, f4, f5
|
||||
|
||||
def unet_decoder(img, f1, f2, f3, f4, f5, n_classes, light=False, task="segmentation", weight_decay=1e-6):
|
||||
if IMAGE_ORDERING == 'channels_last':
|
||||
bn_axis = 3
|
||||
else:
|
||||
bn_axis = 1
|
||||
|
||||
o = Conv2D(512 if light else 1024, (1, 1), padding='same',
|
||||
data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f5)
|
||||
o = BatchNormalization(axis=bn_axis)(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
if light:
|
||||
f4 = Conv2D(512, (1, 1), padding='same',
|
||||
data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f4)
|
||||
f4 = BatchNormalization(axis=bn_axis)(f4)
|
||||
f4 = Activation('relu')(f4)
|
||||
|
||||
o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING, interpolation="bilinear")(o)
|
||||
o = concatenate([o, f4], axis=MERGE_AXIS)
|
||||
o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
|
||||
o = Conv2D(512, (3, 3), padding='valid',
|
||||
data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
|
||||
o = BatchNormalization(axis=bn_axis)(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING, interpolation="bilinear")(o)
|
||||
o = concatenate([o, f3], axis=MERGE_AXIS)
|
||||
o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
|
||||
o = Conv2D(256, (3, 3), padding='valid',
|
||||
data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
|
||||
o = BatchNormalization(axis=bn_axis)(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING, interpolation="bilinear")(o)
|
||||
o = concatenate([o, f2], axis=MERGE_AXIS)
|
||||
o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
|
||||
o = Conv2D(128, (3, 3), padding='valid',
|
||||
data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
|
||||
o = BatchNormalization(axis=bn_axis)(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING, interpolation="bilinear")(o)
|
||||
o = concatenate([o, f1], axis=MERGE_AXIS)
|
||||
o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
|
||||
o = Conv2D(64, (3, 3), padding='valid',
|
||||
data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
|
||||
o = BatchNormalization(axis=bn_axis)(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING, interpolation="bilinear")(o)
|
||||
o = concatenate([o, img], axis=MERGE_AXIS)
|
||||
o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
|
||||
o = Conv2D(32, (3, 3), padding='valid',
|
||||
data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
|
||||
o = BatchNormalization(axis=bn_axis)(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = Conv2D(n_classes, (1, 1), padding='same',
|
||||
data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
|
||||
if task == "segmentation":
|
||||
o = BatchNormalization(axis=bn_axis)(o)
|
||||
o = Activation('softmax')(o)
|
||||
else:
|
||||
o = Activation('sigmoid')(o)
|
||||
|
||||
return Model(img, o)
|
||||
|
||||
def resnet50_unet_light(n_classes, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
|
||||
assert input_height % 32 == 0
|
||||
|
|
@ -158,100 +268,9 @@ def resnet50_unet_light(n_classes, input_height=224, input_width=224, task="segm
|
|||
|
||||
img_input = Input(shape=(input_height, input_width, 3))
|
||||
|
||||
if IMAGE_ORDERING == 'channels_last':
|
||||
bn_axis = 3
|
||||
else:
|
||||
bn_axis = 1
|
||||
|
||||
x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input)
|
||||
x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2), kernel_regularizer=l2(weight_decay),
|
||||
name='conv1')(x)
|
||||
f1 = x
|
||||
|
||||
x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
|
||||
x = Activation('relu')(x)
|
||||
x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x)
|
||||
|
||||
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
|
||||
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
|
||||
x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
|
||||
f2 = one_side_pad(x)
|
||||
|
||||
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
|
||||
f3 = x
|
||||
|
||||
x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
|
||||
f4 = x
|
||||
|
||||
x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
|
||||
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
|
||||
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
|
||||
f5 = x
|
||||
|
||||
if pretraining:
|
||||
model = Model(img_input, x).load_weights(resnet50_Weights_path)
|
||||
|
||||
v512_2048 = Conv2D(512, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f5)
|
||||
v512_2048 = (BatchNormalization(axis=bn_axis))(v512_2048)
|
||||
v512_2048 = Activation('relu')(v512_2048)
|
||||
|
||||
v512_1024 = Conv2D(512, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f4)
|
||||
v512_1024 = (BatchNormalization(axis=bn_axis))(v512_1024)
|
||||
v512_1024 = Activation('relu')(v512_1024)
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(v512_2048)
|
||||
o = (concatenate([o, v512_1024], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, f3], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, f2], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, f1], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, img_input], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
|
||||
if task == "segmentation":
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = (Activation('softmax'))(o)
|
||||
else:
|
||||
o = (Activation('sigmoid'))(o)
|
||||
|
||||
model = Model(img_input, o)
|
||||
return model
|
||||
features = resnet50(img_input, weight_decay=weight_decay, pretraining=pretraining)
|
||||
|
||||
return unet_decoder(img_input, *features, n_classes, light=True, task=task, weight_decay=weight_decay)
|
||||
|
||||
def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
|
||||
assert input_height % 32 == 0
|
||||
|
|
@ -259,162 +278,29 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentati
|
|||
|
||||
img_input = Input(shape=(input_height, input_width, 3))
|
||||
|
||||
if IMAGE_ORDERING == 'channels_last':
|
||||
bn_axis = 3
|
||||
else:
|
||||
bn_axis = 1
|
||||
features = resnet50(img_input, weight_decay=weight_decay, pretraining=pretraining)
|
||||
|
||||
x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input)
|
||||
x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2), kernel_regularizer=l2(weight_decay),
|
||||
name='conv1')(x)
|
||||
f1 = x
|
||||
return unet_decoder(img_input, *features, n_classes, light=False, task=task, weight_decay=weight_decay)
|
||||
|
||||
x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
|
||||
x = Activation('relu')(x)
|
||||
x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x)
|
||||
|
||||
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
|
||||
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
|
||||
x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
|
||||
f2 = one_side_pad(x)
|
||||
|
||||
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
|
||||
f3 = x
|
||||
|
||||
x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
|
||||
f4 = x
|
||||
|
||||
x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
|
||||
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
|
||||
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
|
||||
f5 = x
|
||||
|
||||
if pretraining:
|
||||
Model(img_input, x).load_weights(resnet50_Weights_path)
|
||||
|
||||
v1024_2048 = Conv2D(1024, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(
|
||||
f5)
|
||||
v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048)
|
||||
v1024_2048 = Activation('relu')(v1024_2048)
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(v1024_2048)
|
||||
o = (concatenate([o, f4], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, f3], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, f2], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, f1], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, img_input], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
|
||||
if task == "segmentation":
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = (Activation('softmax'))(o)
|
||||
else:
|
||||
o = (Activation('sigmoid'))(o)
|
||||
|
||||
model = Model(img_input, o)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=None, transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
|
||||
if mlp_head_units is None:
|
||||
mlp_head_units = [128, 64]
|
||||
inputs = Input(shape=(input_height, input_width, 3))
|
||||
|
||||
#transformer_units = [
|
||||
#projection_dim * 2,
|
||||
#projection_dim,
|
||||
#] # Size of the transformer layers
|
||||
IMAGE_ORDERING = 'channels_last'
|
||||
bn_axis=3
|
||||
|
||||
x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(inputs)
|
||||
x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x)
|
||||
f1 = x
|
||||
|
||||
x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
|
||||
x = Activation('relu')(x)
|
||||
x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x)
|
||||
|
||||
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
|
||||
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
|
||||
x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
|
||||
f2 = one_side_pad(x)
|
||||
|
||||
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
|
||||
f3 = x
|
||||
|
||||
x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
|
||||
f4 = x
|
||||
|
||||
x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
|
||||
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
|
||||
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
|
||||
f5 = x
|
||||
|
||||
if pretraining:
|
||||
model = Model(inputs, x).load_weights(resnet50_Weights_path)
|
||||
|
||||
#num_patches = x.shape[1]*x.shape[2]
|
||||
|
||||
#patch_size_y = input_height / x.shape[1]
|
||||
#patch_size_x = input_width / x.shape[2]
|
||||
#patch_size = patch_size_x * patch_size_y
|
||||
patches = Patches(patch_size_x, patch_size_y)(x)
|
||||
def transformer_block(img,
|
||||
num_patches,
|
||||
patchsize_x,
|
||||
patchsize_y,
|
||||
mlp_head_units,
|
||||
n_layers,
|
||||
num_heads,
|
||||
projection_dim):
|
||||
patches = Patches(patchsize_x, patchsize_y)(img)
|
||||
# Encode patches.
|
||||
encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
|
||||
|
||||
for _ in range(transformer_layers):
|
||||
for _ in range(n_layers):
|
||||
# Layer normalization 1.
|
||||
x1 = LayerNormalization(epsilon=1e-6)(encoded_patches)
|
||||
# Create a multi-head attention layer.
|
||||
attention_output = MultiHeadAttention(
|
||||
num_heads=num_heads, key_dim=projection_dim, dropout=0.1
|
||||
)(x1, x1)
|
||||
attention_output = MultiHeadAttention(num_heads=num_heads,
|
||||
key_dim=projection_dim,
|
||||
dropout=0.1)(x1, x1)
|
||||
# Skip connection 1.
|
||||
x2 = Add()([attention_output, encoded_patches])
|
||||
# Layer normalization 2.
|
||||
|
|
@ -424,179 +310,75 @@ def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_he
|
|||
# Skip connection 2.
|
||||
encoded_patches = Add()([x3, x2])
|
||||
|
||||
assert isinstance(x, Layer)
|
||||
encoded_patches = tf.reshape(encoded_patches, [-1, x.shape[1], x.shape[2] , int( projection_dim / (patch_size_x * patch_size_y) )])
|
||||
encoded_patches = tf.reshape(encoded_patches,
|
||||
[-1,
|
||||
img.shape[1],
|
||||
img.shape[2],
|
||||
projection_dim // (patchsize_x * patchsize_y)])
|
||||
return encoded_patches
|
||||
|
||||
v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(encoded_patches)
|
||||
v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048)
|
||||
v1024_2048 = Activation('relu')(v1024_2048)
|
||||
|
||||
o = (UpSampling2D( (2, 2), data_format=IMAGE_ORDERING))(v1024_2048)
|
||||
o = (concatenate([o, f4],axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o ,f3], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, f2], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, f1], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, inputs],axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(o)
|
||||
if task == "segmentation":
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = (Activation('softmax'))(o)
|
||||
else:
|
||||
o = (Activation('sigmoid'))(o)
|
||||
|
||||
model = Model(inputs=inputs, outputs=o)
|
||||
|
||||
return model
|
||||
|
||||
def vit_resnet50_unet_transformer_before_cnn(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=None, transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
|
||||
if mlp_head_units is None:
|
||||
mlp_head_units = [128, 64]
|
||||
def vit_resnet50_unet(num_patches,
|
||||
n_classes,
|
||||
transformer_patchsize_x,
|
||||
transformer_patchsize_y,
|
||||
transformer_mlp_head_units=None,
|
||||
transformer_layers=8,
|
||||
transformer_num_heads=4,
|
||||
transformer_projection_dim=64,
|
||||
input_height=224,
|
||||
input_width=224,
|
||||
task="segmentation",
|
||||
weight_decay=1e-6,
|
||||
pretraining=False):
|
||||
if transformer_mlp_head_units is None:
|
||||
transformer_mlp_head_units = [128, 64]
|
||||
inputs = Input(shape=(input_height, input_width, 3))
|
||||
|
||||
##transformer_units = [
|
||||
##projection_dim * 2,
|
||||
##projection_dim,
|
||||
##] # Size of the transformer layers
|
||||
IMAGE_ORDERING = 'channels_last'
|
||||
bn_axis=3
|
||||
features = list(resnet50(inputs, weight_decay=weight_decay, pretraining=pretraining))
|
||||
|
||||
patches = Patches(patch_size_x, patch_size_y)(inputs)
|
||||
# Encode patches.
|
||||
encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
|
||||
features[-1] = transformer_block(features[-1],
|
||||
num_patches,
|
||||
transformer_patchsize_x,
|
||||
transformer_patchsize_y,
|
||||
transformer_mlp_head_units,
|
||||
transformer_layers,
|
||||
transformer_num_heads,
|
||||
transformer_projection_dim)
|
||||
|
||||
for _ in range(transformer_layers):
|
||||
# Layer normalization 1.
|
||||
x1 = LayerNormalization(epsilon=1e-6)(encoded_patches)
|
||||
# Create a multi-head attention layer.
|
||||
attention_output = MultiHeadAttention(
|
||||
num_heads=num_heads, key_dim=projection_dim, dropout=0.1
|
||||
)(x1, x1)
|
||||
# Skip connection 1.
|
||||
x2 = Add()([attention_output, encoded_patches])
|
||||
# Layer normalization 2.
|
||||
x3 = LayerNormalization(epsilon=1e-6)(x2)
|
||||
# MLP.
|
||||
x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1)
|
||||
# Skip connection 2.
|
||||
encoded_patches = Add()([x3, x2])
|
||||
return unet_decoder(inputs, *features, n_classes, task=task, weight_decay=weight_decay)
|
||||
|
||||
encoded_patches = tf.reshape(encoded_patches, [-1, input_height, input_width , int( projection_dim / (patch_size_x * patch_size_y) )])
|
||||
def vit_resnet50_unet_transformer_before_cnn(num_patches,
|
||||
n_classes,
|
||||
transformer_patchsize_x,
|
||||
transformer_patchsize_y,
|
||||
transformer_mlp_head_units=None,
|
||||
transformer_layers=8,
|
||||
transformer_num_heads=4,
|
||||
transformer_projection_dim=64,
|
||||
input_height=224,
|
||||
input_width=224,
|
||||
task="segmentation",
|
||||
weight_decay=1e-6,
|
||||
pretraining=False):
|
||||
if transformer_mlp_head_units is None:
|
||||
transformer_mlp_head_units = [128, 64]
|
||||
inputs = Input(shape=(input_height, input_width, 3))
|
||||
|
||||
encoded_patches = Conv2D(3, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay), name='convinput')(encoded_patches)
|
||||
encoded_patches = transformer_block(inputs,
|
||||
num_patches,
|
||||
transformer_patchsize_x,
|
||||
transformer_patchsize_y,
|
||||
transformer_mlp_head_units,
|
||||
transformer_layers,
|
||||
transformer_num_heads,
|
||||
transformer_projection_dim)
|
||||
encoded_patches = Conv2D(3, (1, 1), padding='same',
|
||||
data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay),
|
||||
name='convinput')(encoded_patches)
|
||||
|
||||
x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(encoded_patches)
|
||||
x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x)
|
||||
f1 = x
|
||||
features = resnet50(encoded_patches, weight_decay=weight_decay, pretraining=pretraining)
|
||||
|
||||
x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
|
||||
x = Activation('relu')(x)
|
||||
x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x)
|
||||
|
||||
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
|
||||
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
|
||||
x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
|
||||
f2 = one_side_pad(x)
|
||||
|
||||
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
|
||||
f3 = x
|
||||
|
||||
x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
|
||||
f4 = x
|
||||
|
||||
x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
|
||||
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
|
||||
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
|
||||
f5 = x
|
||||
|
||||
if pretraining:
|
||||
model = Model(encoded_patches, x).load_weights(resnet50_Weights_path)
|
||||
|
||||
v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(x)
|
||||
v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048)
|
||||
v1024_2048 = Activation('relu')(v1024_2048)
|
||||
|
||||
o = (UpSampling2D( (2, 2), data_format=IMAGE_ORDERING))(v1024_2048)
|
||||
o = (concatenate([o, f4],axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o ,f3], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, f2], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, f1], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, inputs],axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(o)
|
||||
if task == "segmentation":
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = (Activation('softmax'))(o)
|
||||
else:
|
||||
o = (Activation('sigmoid'))(o)
|
||||
|
||||
model = Model(inputs=inputs, outputs=o)
|
||||
|
||||
return model
|
||||
return unet_decoder(inputs, *features, n_classes, task=task, weight_decay=weight_decay)
|
||||
|
||||
def resnet50_classifier(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False):
|
||||
include_top=True
|
||||
|
|
@ -606,47 +388,7 @@ def resnet50_classifier(n_classes,input_height=224,input_width=224,weight_decay=
|
|||
|
||||
img_input = Input(shape=(input_height,input_width , 3 ))
|
||||
|
||||
if IMAGE_ORDERING == 'channels_last':
|
||||
bn_axis = 3
|
||||
else:
|
||||
bn_axis = 1
|
||||
|
||||
x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input)
|
||||
x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x)
|
||||
f1 = x
|
||||
|
||||
x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
|
||||
x = Activation('relu')(x)
|
||||
x = MaxPooling2D((3, 3) , data_format=IMAGE_ORDERING , strides=(2, 2))(x)
|
||||
|
||||
|
||||
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
|
||||
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
|
||||
x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
|
||||
f2 = one_side_pad(x )
|
||||
|
||||
|
||||
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
|
||||
f3 = x
|
||||
|
||||
x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
|
||||
f4 = x
|
||||
|
||||
x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
|
||||
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
|
||||
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
|
||||
f5 = x
|
||||
|
||||
if pretraining:
|
||||
Model(img_input, x).load_weights(resnet50_Weights_path)
|
||||
_, _, _, _, x = resnet50(img_input, weight_decay, pretraining)
|
||||
|
||||
x = AveragePooling2D((7, 7), name='avg_pool')(x)
|
||||
x = Flatten()(x)
|
||||
|
|
@ -658,9 +400,6 @@ def resnet50_classifier(n_classes,input_height=224,input_width=224,weight_decay=
|
|||
x = Dense(n_classes, activation='softmax', name='fc1000')(x)
|
||||
model = Model(img_input, x)
|
||||
|
||||
|
||||
|
||||
|
||||
return model
|
||||
|
||||
def machine_based_reading_order_model(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False):
|
||||
|
|
@ -669,43 +408,10 @@ def machine_based_reading_order_model(n_classes,input_height=224,input_width=224
|
|||
|
||||
img_input = Input(shape=(input_height,input_width , 3 ))
|
||||
|
||||
if IMAGE_ORDERING == 'channels_last':
|
||||
bn_axis = 3
|
||||
else:
|
||||
bn_axis = 1
|
||||
_, _, _, _, x = resnet50(img_input, weight_decay, pretraining)
|
||||
|
||||
x1 = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input)
|
||||
x1 = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x1)
|
||||
|
||||
x1 = BatchNormalization(axis=bn_axis, name='bn_conv1')(x1)
|
||||
x1 = Activation('relu')(x1)
|
||||
x1 = MaxPooling2D((3, 3) , data_format=IMAGE_ORDERING , strides=(2, 2))(x1)
|
||||
|
||||
x1 = conv_block(x1, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
|
||||
x1 = identity_block(x1, 3, [64, 64, 256], stage=2, block='b')
|
||||
x1 = identity_block(x1, 3, [64, 64, 256], stage=2, block='c')
|
||||
|
||||
x1 = conv_block(x1, 3, [128, 128, 512], stage=3, block='a')
|
||||
x1 = identity_block(x1, 3, [128, 128, 512], stage=3, block='b')
|
||||
x1 = identity_block(x1, 3, [128, 128, 512], stage=3, block='c')
|
||||
x1 = identity_block(x1, 3, [128, 128, 512], stage=3, block='d')
|
||||
|
||||
x1 = conv_block(x1, 3, [256, 256, 1024], stage=4, block='a')
|
||||
x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='b')
|
||||
x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='c')
|
||||
x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='d')
|
||||
x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='e')
|
||||
x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='f')
|
||||
|
||||
x1 = conv_block(x1, 3, [512, 512, 2048], stage=5, block='a')
|
||||
x1 = identity_block(x1, 3, [512, 512, 2048], stage=5, block='b')
|
||||
x1 = identity_block(x1, 3, [512, 512, 2048], stage=5, block='c')
|
||||
|
||||
if pretraining:
|
||||
Model(img_input , x1).load_weights(resnet50_Weights_path)
|
||||
|
||||
x1 = AveragePooling2D((7, 7), name='avg_pool1')(x1)
|
||||
flattened = Flatten()(x1)
|
||||
x = AveragePooling2D((7, 7), name='avg_pool1')(x)
|
||||
flattened = Flatten()(x)
|
||||
|
||||
o = Dense(256, activation='relu', name='fc512')(flattened)
|
||||
o=Dropout(0.2)(o)
|
||||
|
|
@ -719,83 +425,79 @@ def machine_based_reading_order_model(n_classes,input_height=224,input_width=224
|
|||
return model
|
||||
|
||||
def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_seq=None):
|
||||
input_img = tf.keras.Input(shape=(image_height, image_width, 3), name="image")
|
||||
labels = tf.keras.layers.Input(name="label", shape=(None,))
|
||||
input_img = Input(shape=(image_height, image_width, 3), name="image")
|
||||
labels = Input(name="label", shape=(None,))
|
||||
|
||||
x = tf.keras.layers.Conv2D(64,kernel_size=(3,3),padding="same")(input_img)
|
||||
x = tf.keras.layers.BatchNormalization(name="bn1")(x)
|
||||
x = tf.keras.layers.Activation("relu", name="relu1")(x)
|
||||
x = tf.keras.layers.Conv2D(64,kernel_size=(3,3),padding="same")(x)
|
||||
x = tf.keras.layers.BatchNormalization(name="bn2")(x)
|
||||
x = tf.keras.layers.Activation("relu", name="relu2")(x)
|
||||
x = tf.keras.layers.MaxPool2D(pool_size=(1,2),strides=(1,2))(x)
|
||||
x = Conv2D(64,kernel_size=(3,3),padding="same")(input_img)
|
||||
x = BatchNormalization(name="bn1")(x)
|
||||
x = Activation("relu", name="relu1")(x)
|
||||
x = Conv2D(64,kernel_size=(3,3),padding="same")(x)
|
||||
x = BatchNormalization(name="bn2")(x)
|
||||
x = Activation("relu", name="relu2")(x)
|
||||
x = MaxPooling2D(pool_size=(1,2),strides=(1,2))(x)
|
||||
|
||||
x = tf.keras.layers.Conv2D(128,kernel_size=(3,3),padding="same")(x)
|
||||
x = tf.keras.layers.BatchNormalization(name="bn3")(x)
|
||||
x = tf.keras.layers.Activation("relu", name="relu3")(x)
|
||||
x = tf.keras.layers.Conv2D(128,kernel_size=(3,3),padding="same")(x)
|
||||
x = tf.keras.layers.BatchNormalization(name="bn4")(x)
|
||||
x = tf.keras.layers.Activation("relu", name="relu4")(x)
|
||||
x = tf.keras.layers.MaxPool2D(pool_size=(1,2),strides=(1,2))(x)
|
||||
x = Conv2D(128,kernel_size=(3,3),padding="same")(x)
|
||||
x = BatchNormalization(name="bn3")(x)
|
||||
x = Activation("relu", name="relu3")(x)
|
||||
x = Conv2D(128,kernel_size=(3,3),padding="same")(x)
|
||||
x = BatchNormalization(name="bn4")(x)
|
||||
x = Activation("relu", name="relu4")(x)
|
||||
x = MaxPooling2D(pool_size=(1,2),strides=(1,2))(x)
|
||||
|
||||
x = tf.keras.layers.Conv2D(256,kernel_size=(3,3),padding="same")(x)
|
||||
x = tf.keras.layers.BatchNormalization(name="bn5")(x)
|
||||
x = tf.keras.layers.Activation("relu", name="relu5")(x)
|
||||
x = tf.keras.layers.Conv2D(256,kernel_size=(3,3),padding="same")(x)
|
||||
x = tf.keras.layers.BatchNormalization(name="bn6")(x)
|
||||
x = tf.keras.layers.Activation("relu", name="relu6")(x)
|
||||
x = tf.keras.layers.MaxPool2D(pool_size=(2,2),strides=(2,2))(x)
|
||||
x = Conv2D(256,kernel_size=(3,3),padding="same")(x)
|
||||
x = BatchNormalization(name="bn5")(x)
|
||||
x = Activation("relu", name="relu5")(x)
|
||||
x = Conv2D(256,kernel_size=(3,3),padding="same")(x)
|
||||
x = BatchNormalization(name="bn6")(x)
|
||||
x = Activation("relu", name="relu6")(x)
|
||||
x = MaxPooling2D(pool_size=(2,2),strides=(2,2))(x)
|
||||
|
||||
x = tf.keras.layers.Conv2D(image_width,kernel_size=(3,3),padding="same")(x)
|
||||
x = tf.keras.layers.BatchNormalization(name="bn7")(x)
|
||||
x = tf.keras.layers.Activation("relu", name="relu7")(x)
|
||||
x = tf.keras.layers.Conv2D(image_width,kernel_size=(16,1))(x)
|
||||
x = tf.keras.layers.BatchNormalization(name="bn8")(x)
|
||||
x = tf.keras.layers.Activation("relu", name="relu8")(x)
|
||||
x2d = tf.keras.layers.MaxPool2D(pool_size=(1,2),strides=(1,2))(x)
|
||||
x4d = tf.keras.layers.MaxPool2D(pool_size=(1,2),strides=(1,2))(x2d)
|
||||
x = Conv2D(image_width,kernel_size=(3,3),padding="same")(x)
|
||||
x = BatchNormalization(name="bn7")(x)
|
||||
x = Activation("relu", name="relu7")(x)
|
||||
x = Conv2D(image_width,kernel_size=(16,1))(x)
|
||||
x = BatchNormalization(name="bn8")(x)
|
||||
x = Activation("relu", name="relu8")(x)
|
||||
x2d = MaxPooling2D(pool_size=(1,2),strides=(1,2))(x)
|
||||
x4d = MaxPooling2D(pool_size=(1,2),strides=(1,2))(x2d)
|
||||
|
||||
|
||||
new_shape = (x.shape[1]*x.shape[2], x.shape[3])
|
||||
new_shape2 = (x2d.shape[1]*x2d.shape[2], x2d.shape[3])
|
||||
new_shape4 = (x4d.shape[1]*x4d.shape[2], x4d.shape[3])
|
||||
|
||||
x = tf.keras.layers.Reshape(target_shape=new_shape, name="reshape")(x)
|
||||
x2d = tf.keras.layers.Reshape(target_shape=new_shape2, name="reshape2")(x2d)
|
||||
x4d = tf.keras.layers.Reshape(target_shape=new_shape4, name="reshape4")(x4d)
|
||||
x = Reshape(target_shape=new_shape, name="reshape")(x)
|
||||
x2d = Reshape(target_shape=new_shape2, name="reshape2")(x2d)
|
||||
x4d = Reshape(target_shape=new_shape4, name="reshape4")(x4d)
|
||||
|
||||
xrnnorg = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x)
|
||||
xrnn2d = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x2d)
|
||||
xrnn4d = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x4d)
|
||||
|
||||
xrnn2d = Reshape(target_shape=(1, xrnn2d.shape[1], xrnn2d.shape[2]), name="reshape6")(xrnn2d)
|
||||
xrnn4d = Reshape(target_shape=(1, xrnn4d.shape[1], xrnn4d.shape[2]), name="reshape8")(xrnn4d)
|
||||
|
||||
|
||||
xrnnorg = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(image_width, return_sequences=True, dropout=0.25))(x)
|
||||
xrnn2d = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(image_width, return_sequences=True, dropout=0.25))(x2d)
|
||||
xrnn4d = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(image_width, return_sequences=True, dropout=0.25))(x4d)
|
||||
xrnn2dup = UpSampling2D(size=(1, 2), interpolation="nearest")(xrnn2d)
|
||||
xrnn4dup = UpSampling2D(size=(1, 4), interpolation="nearest")(xrnn4d)
|
||||
|
||||
xrnn2d = tf.keras.layers.Reshape(target_shape=(1, xrnn2d.shape[1], xrnn2d.shape[2]), name="reshape6")(xrnn2d)
|
||||
xrnn4d = tf.keras.layers.Reshape(target_shape=(1, xrnn4d.shape[1], xrnn4d.shape[2]), name="reshape8")(xrnn4d)
|
||||
xrnn2dup = Reshape(target_shape=(xrnn2dup.shape[2], xrnn2dup.shape[3]), name="reshape10")(xrnn2dup)
|
||||
xrnn4dup = Reshape(target_shape=(xrnn4dup.shape[2], xrnn4dup.shape[3]), name="reshape12")(xrnn4dup)
|
||||
|
||||
addition = Add()([xrnnorg, xrnn2dup, xrnn4dup])
|
||||
|
||||
xrnn2dup = tf.keras.layers.UpSampling2D(size=(1, 2), interpolation="nearest")(xrnn2d)
|
||||
xrnn4dup = tf.keras.layers.UpSampling2D(size=(1, 4), interpolation="nearest")(xrnn4d)
|
||||
addition_rnn = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(addition)
|
||||
|
||||
xrnn2dup = tf.keras.layers.Reshape(target_shape=(xrnn2dup.shape[2], xrnn2dup.shape[3]), name="reshape10")(xrnn2dup)
|
||||
xrnn4dup = tf.keras.layers.Reshape(target_shape=(xrnn4dup.shape[2], xrnn4dup.shape[3]), name="reshape12")(xrnn4dup)
|
||||
out = Conv1D(max_seq, 1, data_format="channels_first")(addition_rnn)
|
||||
out = BatchNormalization(name="bn9")(out)
|
||||
out = Activation("relu", name="relu9")(out)
|
||||
#out = Conv1D(n_classes, 1, activation='relu', data_format="channels_last")(out)
|
||||
|
||||
addition = tf.keras.layers.Add()([xrnnorg, xrnn2dup, xrnn4dup])
|
||||
|
||||
addition_rnn = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(image_width, return_sequences=True, dropout=0.25))(addition)
|
||||
|
||||
out = tf.keras.layers.Conv1D(max_seq, 1, data_format="channels_first")(addition_rnn)
|
||||
out = tf.keras.layers.BatchNormalization(name="bn9")(out)
|
||||
out = tf.keras.layers.Activation("relu", name="relu9")(out)
|
||||
#out = tf.keras.layers.Conv1D(n_classes, 1, activation='relu', data_format="channels_last")(out)
|
||||
|
||||
out = tf.keras.layers.Dense(
|
||||
n_classes, activation="softmax", name="dense2"
|
||||
)(out)
|
||||
out = Dense(n_classes, activation="softmax", name="dense2")(out)
|
||||
|
||||
# Add CTC layer for calculating CTC loss at each step.
|
||||
output = CTCLayer(name="ctc_loss")(labels, out)
|
||||
|
||||
model = tf.keras.models.Model(inputs=[input_img, labels], outputs=output, name="handwriting_recognizer")
|
||||
model = Model(inputs=(input_img, labels), outputs=output, name="handwriting_recognizer")
|
||||
|
||||
return model
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
|
@ -1,136 +1,66 @@
|
|||
import sys
|
||||
from glob import glob
|
||||
from os import environ, devnull
|
||||
from os.path import join
|
||||
from warnings import catch_warnings, simplefilter
|
||||
import os
|
||||
from warnings import catch_warnings, simplefilter
|
||||
|
||||
import click
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import cv2
|
||||
environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||
stderr = sys.stderr
|
||||
sys.stderr = open(devnull, 'w')
|
||||
|
||||
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||
|
||||
from ocrd_utils import tf_disable_interactive_logs
|
||||
tf_disable_interactive_logs()
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras.models import load_model
|
||||
from tensorflow.python.keras import backend as tensorflow_backend
|
||||
sys.stderr = stderr
|
||||
from tensorflow.keras import layers
|
||||
import tensorflow.keras.losses
|
||||
from tensorflow.keras.layers import *
|
||||
import click
|
||||
import logging
|
||||
|
||||
from ..patch_encoder import (
|
||||
PatchEncoder,
|
||||
Patches,
|
||||
)
|
||||
|
||||
class Patches(layers.Layer):
|
||||
def __init__(self, patch_size_x, patch_size_y):
|
||||
super(Patches, self).__init__()
|
||||
self.patch_size_x = patch_size_x
|
||||
self.patch_size_y = patch_size_y
|
||||
def run_ensembling(model_dirs, out_dir):
|
||||
all_weights = []
|
||||
|
||||
def call(self, images):
|
||||
#print(tf.shape(images)[1],'images')
|
||||
#print(self.patch_size,'self.patch_size')
|
||||
batch_size = tf.shape(images)[0]
|
||||
patches = tf.image.extract_patches(
|
||||
images=images,
|
||||
sizes=[1, self.patch_size_y, self.patch_size_x, 1],
|
||||
strides=[1, self.patch_size_y, self.patch_size_x, 1],
|
||||
rates=[1, 1, 1, 1],
|
||||
padding="VALID",
|
||||
)
|
||||
#patch_dims = patches.shape[-1]
|
||||
patch_dims = tf.shape(patches)[-1]
|
||||
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
|
||||
return patches
|
||||
def get_config(self):
|
||||
for model_dir in model_dirs:
|
||||
assert os.path.isdir(model_dir), model_dir
|
||||
model = load_model(model_dir, compile=False,
|
||||
custom_objects=dict(PatchEncoder=PatchEncoder,
|
||||
Patches=Patches))
|
||||
all_weights.append(model.get_weights())
|
||||
|
||||
config = super().get_config().copy()
|
||||
config.update({
|
||||
'patch_size_x': self.patch_size_x,
|
||||
'patch_size_y': self.patch_size_y,
|
||||
})
|
||||
return config
|
||||
|
||||
|
||||
|
||||
class PatchEncoder(layers.Layer):
|
||||
def __init__(self, **kwargs):
|
||||
super(PatchEncoder, self).__init__()
|
||||
self.num_patches = num_patches
|
||||
self.projection = layers.Dense(units=projection_dim)
|
||||
self.position_embedding = layers.Embedding(
|
||||
input_dim=num_patches, output_dim=projection_dim
|
||||
)
|
||||
|
||||
def call(self, patch):
|
||||
positions = tf.range(start=0, limit=self.num_patches, delta=1)
|
||||
encoded = self.projection(patch) + self.position_embedding(positions)
|
||||
return encoded
|
||||
def get_config(self):
|
||||
|
||||
config = super().get_config().copy()
|
||||
config.update({
|
||||
'num_patches': self.num_patches,
|
||||
'projection': self.projection,
|
||||
'position_embedding': self.position_embedding,
|
||||
})
|
||||
return config
|
||||
|
||||
|
||||
def start_new_session():
|
||||
###config = tf.compat.v1.ConfigProto()
|
||||
###config.gpu_options.allow_growth = True
|
||||
|
||||
###self.session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
|
||||
###tensorflow_backend.set_session(self.session)
|
||||
|
||||
config = tf.compat.v1.ConfigProto()
|
||||
config.gpu_options.allow_growth = True
|
||||
|
||||
session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
|
||||
tensorflow_backend.set_session(session)
|
||||
return session
|
||||
|
||||
def run_ensembling(dir_models, out):
|
||||
ls_models = os.listdir(dir_models)
|
||||
|
||||
|
||||
weights=[]
|
||||
|
||||
for model_name in ls_models:
|
||||
model = load_model(os.path.join(dir_models,model_name) , compile=False, custom_objects={'PatchEncoder':PatchEncoder, 'Patches': Patches})
|
||||
weights.append(model.get_weights())
|
||||
|
||||
new_weights = list()
|
||||
|
||||
for weights_list_tuple in zip(*weights):
|
||||
new_weights.append(
|
||||
[np.array(weights_).mean(axis=0)\
|
||||
for weights_ in zip(*weights_list_tuple)])
|
||||
|
||||
|
||||
|
||||
new_weights = [np.array(x) for x in new_weights]
|
||||
new_weights = []
|
||||
for layer_weights in zip(*all_weights):
|
||||
layer_weights = np.array([np.array(weights).mean(axis=0)
|
||||
for weights in zip(*layer_weights)])
|
||||
new_weights.append(layer_weights)
|
||||
|
||||
#model = tf.keras.models.clone_model(model)
|
||||
model.set_weights(new_weights)
|
||||
model.save(out)
|
||||
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config.json ")+out)
|
||||
|
||||
model.save(out_dir)
|
||||
os.system('cp ' + os.path.join(model_dirs[0], "config.json ") + out_dir + "/")
|
||||
|
||||
@click.command()
|
||||
@click.option(
|
||||
"--dir_models",
|
||||
"-dm",
|
||||
help="directory of models",
|
||||
"--in",
|
||||
"-i",
|
||||
help="input directory of checkpoint models to be read",
|
||||
multiple=True,
|
||||
required=True,
|
||||
type=click.Path(exists=True, file_okay=False),
|
||||
)
|
||||
@click.option(
|
||||
"--out",
|
||||
"-o",
|
||||
help="output directory where ensembled model will be written.",
|
||||
required=True,
|
||||
type=click.Path(exists=False, file_okay=False),
|
||||
)
|
||||
def ensemble_cli(in_, out):
|
||||
"""
|
||||
mix multiple model weights
|
||||
|
||||
def main(dir_models, out):
|
||||
run_ensembling(dir_models, out)
|
||||
Load a sequence of models and mix them into a single ensemble model
|
||||
by averaging their weights. Write the resulting model.
|
||||
"""
|
||||
run_ensembling(in_, out)
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -14,21 +14,16 @@ from shapely.ops import unary_union, nearest_points
|
|||
from .rotate import rotate_image, rotation_image_new
|
||||
|
||||
def contours_in_same_horizon(cy_main_hor):
|
||||
X1 = np.zeros((len(cy_main_hor), len(cy_main_hor)))
|
||||
X2 = np.zeros((len(cy_main_hor), len(cy_main_hor)))
|
||||
|
||||
X1[0::1, :] = cy_main_hor[:]
|
||||
X2 = X1.T
|
||||
|
||||
X_dif = np.abs(X2 - X1)
|
||||
args_help = np.array(range(len(cy_main_hor)))
|
||||
all_args = []
|
||||
for i in range(len(cy_main_hor)):
|
||||
list_h = list(args_help[X_dif[i, :] <= 20])
|
||||
list_h.append(i)
|
||||
if len(list_h) > 1:
|
||||
all_args.append(list(set(list_h)))
|
||||
return np.unique(np.array(all_args, dtype=object))
|
||||
"""
|
||||
Takes an array of y coords, identifies all pairs among them
|
||||
which are close to each other, and returns all such pairs
|
||||
by index into the array.
|
||||
"""
|
||||
sort = np.argsort(cy_main_hor)
|
||||
same = np.diff(cy_main_hor[sort]) <= 20
|
||||
# groups = np.split(sort, np.arange(len(cy_main_hor) - 1)[~same] + 1)
|
||||
same = np.flatnonzero(same)
|
||||
return np.stack((sort[:-1][same], sort[1:][same])).T
|
||||
|
||||
def find_contours_mean_y_diff(contours_main):
|
||||
M_main = [cv2.moments(contours_main[j]) for j in range(len(contours_main))]
|
||||
|
|
@ -175,7 +170,7 @@ def get_textregion_contours_in_org_image(cnts, img, slope_first):
|
|||
|
||||
return cnts_org
|
||||
|
||||
def get_textregion_contours_in_org_image_light_old(cnts, img, slope_first):
|
||||
def get_textregion_confidences_old(cnts, img, slope_first):
|
||||
zoom = 3
|
||||
img = cv2.resize(img, (img.shape[1] // zoom,
|
||||
img.shape[0] // zoom),
|
||||
|
|
@ -213,16 +208,17 @@ def do_back_rotation_and_get_cnt_back(contour_par, index_r_con, img, slope_first
|
|||
cont_int[0][:, 0, 1] = cont_int[0][:, 0, 1] + np.abs(img_copy.shape[0] - img.shape[0])
|
||||
return cont_int[0], index_r_con, confidence_contour
|
||||
|
||||
def get_textregion_contours_in_org_image_light(cnts, img, confidence_matrix):
|
||||
def get_textregion_confidences(cnts, confidence_matrix):
|
||||
if not len(cnts):
|
||||
return []
|
||||
|
||||
height, width = confidence_matrix.shape
|
||||
confidence_matrix = cv2.resize(confidence_matrix,
|
||||
(img.shape[1] // 6, img.shape[0] // 6),
|
||||
(width // 6, height // 6),
|
||||
interpolation=cv2.INTER_NEAREST)
|
||||
confs = []
|
||||
for cnt in cnts:
|
||||
cnt_mask = np.zeros(confidence_matrix.shape)
|
||||
cnt_mask = np.zeros_like(confidence_matrix)
|
||||
cnt_mask = cv2.fillPoly(cnt_mask, pts=[cnt // 6], color=1.0)
|
||||
confs.append(np.sum(confidence_matrix * cnt_mask) / np.sum(cnt_mask))
|
||||
return confs
|
||||
|
|
@ -253,13 +249,17 @@ def return_contours_of_image(image):
|
|||
return contours, hierarchy
|
||||
|
||||
def dilate_textline_contours(all_found_textline_polygons):
|
||||
return [[polygon2contour(contour2polygon(contour, dilate=6))
|
||||
for contour in region]
|
||||
from . import ensure_array
|
||||
return [ensure_array(
|
||||
[polygon2contour(contour2polygon(contour, dilate=6))
|
||||
for contour in region])
|
||||
for region in all_found_textline_polygons]
|
||||
|
||||
def dilate_textregion_contours(all_found_textline_polygons):
|
||||
return [polygon2contour(contour2polygon(contour, dilate=6))
|
||||
for contour in all_found_textline_polygons]
|
||||
def dilate_textregion_contours(all_found_textregion_polygons):
|
||||
from . import ensure_array
|
||||
return ensure_array(
|
||||
[polygon2contour(contour2polygon(contour, dilate=6))
|
||||
for contour in all_found_textregion_polygons])
|
||||
|
||||
def contour2polygon(contour: Union[np.ndarray, Sequence[Sequence[Sequence[Number]]]], dilate=0):
|
||||
polygon = Polygon([point[0] for point in contour])
|
||||
|
|
|
|||
|
|
@ -501,7 +501,7 @@ def adhere_drop_capital_region_into_corresponding_textline(
|
|||
|
||||
def filter_small_drop_capitals_from_no_patch_layout(layout_no_patch, layout1):
|
||||
|
||||
drop_only = (layout_no_patch[:, :, 0] == 4) * 1
|
||||
drop_only = (layout_no_patch == 4) * 1
|
||||
contours_drop, hir_on_drop = return_contours_of_image(drop_only)
|
||||
contours_drop_parent = return_parent_contours(contours_drop, hir_on_drop)
|
||||
|
||||
|
|
@ -529,9 +529,8 @@ def filter_small_drop_capitals_from_no_patch_layout(layout_no_patch, layout1):
|
|||
if (((map_of_drop_contour_bb == 1) * 1).sum() / float(((map_of_drop_contour_bb == 5) * 1).sum()) * 100) >= 15:
|
||||
contours_drop_parent_final.append(contours_drop_parent[jj])
|
||||
|
||||
layout_no_patch[:, :, 0][layout_no_patch[:, :, 0] == 4] = 0
|
||||
|
||||
layout_no_patch = cv2.fillPoly(layout_no_patch, pts=contours_drop_parent_final, color=(4, 4, 4))
|
||||
layout_no_patch[layout_no_patch == 4] = 0
|
||||
layout_no_patch = cv2.fillPoly(layout_no_patch, pts=contours_drop_parent_final, color=4)
|
||||
|
||||
return layout_no_patch
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ from .contour import (
|
|||
return_contours_of_interested_textline,
|
||||
find_contours_mean_y_diff,
|
||||
)
|
||||
from .shm import share_ndarray, wrap_ndarray_shared
|
||||
from . import (
|
||||
find_num_col_deskew,
|
||||
box2rect,
|
||||
|
|
@ -399,14 +398,14 @@ def separate_lines(img_patch, contour_text_interest, thetha, x_help, y_help):
|
|||
point_down_rot3=point_down_rot3-y_help
|
||||
point_down_rot4=point_down_rot4-y_help
|
||||
|
||||
textline_boxes_rot.append(np.array([[int(x_min_rot1), int(point_up_rot1)],
|
||||
[int(x_max_rot2), int(point_up_rot2)],
|
||||
[int(x_max_rot3), int(point_down_rot3)],
|
||||
[int(x_min_rot4), int(point_down_rot4)]]))
|
||||
textline_boxes.append(np.array([[int(x_min), int(point_up)],
|
||||
[int(x_max), int(point_up)],
|
||||
[int(x_max), int(point_down)],
|
||||
[int(x_min), int(point_down)]]))
|
||||
textline_boxes_rot.append(np.array([[[int(x_min_rot1), int(point_up_rot1)]],
|
||||
[[int(x_max_rot2), int(point_up_rot2)]],
|
||||
[[int(x_max_rot3), int(point_down_rot3)]],
|
||||
[[int(x_min_rot4), int(point_down_rot4)]]]))
|
||||
textline_boxes.append(np.array([[[int(x_min), int(point_up)]],
|
||||
[[int(x_max), int(point_up)]],
|
||||
[[int(x_max), int(point_down)]],
|
||||
[[int(x_min), int(point_down)]]]))
|
||||
elif len(peaks) < 1:
|
||||
pass
|
||||
|
||||
|
|
@ -458,14 +457,14 @@ def separate_lines(img_patch, contour_text_interest, thetha, x_help, y_help):
|
|||
point_down_rot3=point_down_rot3-y_help
|
||||
point_down_rot4=point_down_rot4-y_help
|
||||
|
||||
textline_boxes_rot.append(np.array([[int(x_min_rot1), int(point_up_rot1)],
|
||||
[int(x_max_rot2), int(point_up_rot2)],
|
||||
[int(x_max_rot3), int(point_down_rot3)],
|
||||
[int(x_min_rot4), int(point_down_rot4)]]))
|
||||
textline_boxes.append(np.array([[int(x_min), int(y_min)],
|
||||
[int(x_max), int(y_min)],
|
||||
[int(x_max), int(y_max)],
|
||||
[int(x_min), int(y_max)]]))
|
||||
textline_boxes_rot.append(np.array([[[int(x_min_rot1), int(point_up_rot1)]],
|
||||
[[int(x_max_rot2), int(point_up_rot2)]],
|
||||
[[int(x_max_rot3), int(point_down_rot3)]],
|
||||
[[int(x_min_rot4), int(point_down_rot4)]]]))
|
||||
textline_boxes.append(np.array([[[int(x_min), int(y_min)]],
|
||||
[[int(x_max), int(y_min)]],
|
||||
[[int(x_max), int(y_max)]],
|
||||
[[int(x_min), int(y_max)]]]))
|
||||
elif len(peaks) == 2:
|
||||
dis_to_next = np.abs(peaks[1] - peaks[0])
|
||||
for jj in range(len(peaks)):
|
||||
|
|
@ -526,14 +525,14 @@ def separate_lines(img_patch, contour_text_interest, thetha, x_help, y_help):
|
|||
point_down_rot3=point_down_rot3-y_help
|
||||
point_down_rot4=point_down_rot4-y_help
|
||||
|
||||
textline_boxes_rot.append(np.array([[int(x_min_rot1), int(point_up_rot1)],
|
||||
[int(x_max_rot2), int(point_up_rot2)],
|
||||
[int(x_max_rot3), int(point_down_rot3)],
|
||||
[int(x_min_rot4), int(point_down_rot4)]]))
|
||||
textline_boxes.append(np.array([[int(x_min), int(point_up)],
|
||||
[int(x_max), int(point_up)],
|
||||
[int(x_max), int(point_down)],
|
||||
[int(x_min), int(point_down)]]))
|
||||
textline_boxes_rot.append(np.array([[[int(x_min_rot1), int(point_up_rot1)]],
|
||||
[[int(x_max_rot2), int(point_up_rot2)]],
|
||||
[[int(x_max_rot3), int(point_down_rot3)]],
|
||||
[[int(x_min_rot4), int(point_down_rot4)]]]))
|
||||
textline_boxes.append(np.array([[[int(x_min), int(point_up)]],
|
||||
[[int(x_max), int(point_up)]],
|
||||
[[int(x_max), int(point_down)]],
|
||||
[[int(x_min), int(point_down)]]]))
|
||||
else:
|
||||
for jj in range(len(peaks)):
|
||||
if jj == 0:
|
||||
|
|
@ -602,14 +601,14 @@ def separate_lines(img_patch, contour_text_interest, thetha, x_help, y_help):
|
|||
point_down_rot3=point_down_rot3-y_help
|
||||
point_down_rot4=point_down_rot4-y_help
|
||||
|
||||
textline_boxes_rot.append(np.array([[int(x_min_rot1), int(point_up_rot1)],
|
||||
[int(x_max_rot2), int(point_up_rot2)],
|
||||
[int(x_max_rot3), int(point_down_rot3)],
|
||||
[int(x_min_rot4), int(point_down_rot4)]]))
|
||||
textline_boxes.append(np.array([[int(x_min), int(point_up)],
|
||||
[int(x_max), int(point_up)],
|
||||
[int(x_max), int(point_down)],
|
||||
[int(x_min), int(point_down)]]))
|
||||
textline_boxes_rot.append(np.array([[[int(x_min_rot1), int(point_up_rot1)]],
|
||||
[[int(x_max_rot2), int(point_up_rot2)]],
|
||||
[[int(x_max_rot3), int(point_down_rot3)]],
|
||||
[[int(x_min_rot4), int(point_down_rot4)]]]))
|
||||
textline_boxes.append(np.array([[[int(x_min), int(point_up)]],
|
||||
[[int(x_max), int(point_up)]],
|
||||
[[int(x_max), int(point_down)]],
|
||||
[[int(x_min), int(point_down)]]]))
|
||||
return peaks, textline_boxes_rot
|
||||
|
||||
def separate_lines_vertical(img_patch, contour_text_interest, thetha):
|
||||
|
|
@ -781,14 +780,14 @@ def separate_lines_vertical(img_patch, contour_text_interest, thetha):
|
|||
if point_up_rot2 < 0:
|
||||
point_up_rot2 = 0
|
||||
|
||||
textline_boxes_rot.append(np.array([[int(x_min_rot1), int(point_up_rot1)],
|
||||
[int(x_max_rot2), int(point_up_rot2)],
|
||||
[int(x_max_rot3), int(point_down_rot3)],
|
||||
[int(x_min_rot4), int(point_down_rot4)]]))
|
||||
textline_boxes.append(np.array([[int(x_min), int(point_up)],
|
||||
[int(x_max), int(point_up)],
|
||||
[int(x_max), int(point_down)],
|
||||
[int(x_min), int(point_down)]]))
|
||||
textline_boxes_rot.append(np.array([[[int(x_min_rot1), int(point_up_rot1)]],
|
||||
[[int(x_max_rot2), int(point_up_rot2)]],
|
||||
[[int(x_max_rot3), int(point_down_rot3)]],
|
||||
[[int(x_min_rot4), int(point_down_rot4)]]]))
|
||||
textline_boxes.append(np.array([[[int(x_min), int(point_up)]],
|
||||
[[int(x_max), int(point_up)]],
|
||||
[[int(x_max), int(point_down)]],
|
||||
[[int(x_min), int(point_down)]]]))
|
||||
elif len(peaks) < 1:
|
||||
pass
|
||||
elif len(peaks) == 1:
|
||||
|
|
@ -817,14 +816,14 @@ def separate_lines_vertical(img_patch, contour_text_interest, thetha):
|
|||
if point_up_rot2 < 0:
|
||||
point_up_rot2 = 0
|
||||
|
||||
textline_boxes_rot.append(np.array([[int(x_min_rot1), int(point_up_rot1)],
|
||||
[int(x_max_rot2), int(point_up_rot2)],
|
||||
[int(x_max_rot3), int(point_down_rot3)],
|
||||
[int(x_min_rot4), int(point_down_rot4)]]))
|
||||
textline_boxes.append(np.array([[int(x_min), int(y_min)],
|
||||
[int(x_max), int(y_min)],
|
||||
[int(x_max), int(y_max)],
|
||||
[int(x_min), int(y_max)]]))
|
||||
textline_boxes_rot.append(np.array([[[int(x_min_rot1), int(point_up_rot1)]],
|
||||
[[int(x_max_rot2), int(point_up_rot2)]],
|
||||
[[int(x_max_rot3), int(point_down_rot3)]],
|
||||
[[int(x_min_rot4), int(point_down_rot4)]]]))
|
||||
textline_boxes.append(np.array([[[int(x_min), int(y_min)]],
|
||||
[[int(x_max), int(y_min)]],
|
||||
[[int(x_max), int(y_max)]],
|
||||
[[int(x_min), int(y_max)]]]))
|
||||
elif len(peaks) == 2:
|
||||
dis_to_next = np.abs(peaks[1] - peaks[0])
|
||||
for jj in range(len(peaks)):
|
||||
|
|
@ -872,14 +871,14 @@ def separate_lines_vertical(img_patch, contour_text_interest, thetha):
|
|||
if point_up_rot2 < 0:
|
||||
point_up_rot2 = 0
|
||||
|
||||
textline_boxes_rot.append(np.array([[int(x_min_rot1), int(point_up_rot1)],
|
||||
[int(x_max_rot2), int(point_up_rot2)],
|
||||
[int(x_max_rot3), int(point_down_rot3)],
|
||||
[int(x_min_rot4), int(point_down_rot4)]]))
|
||||
textline_boxes.append(np.array([[int(x_min), int(point_up)],
|
||||
[int(x_max), int(point_up)],
|
||||
[int(x_max), int(point_down)],
|
||||
[int(x_min), int(point_down)]]))
|
||||
textline_boxes_rot.append(np.array([[[int(x_min_rot1), int(point_up_rot1)]],
|
||||
[[int(x_max_rot2), int(point_up_rot2)]],
|
||||
[[int(x_max_rot3), int(point_down_rot3)]],
|
||||
[[int(x_min_rot4), int(point_down_rot4)]]]))
|
||||
textline_boxes.append(np.array([[[int(x_min), int(point_up)]],
|
||||
[[int(x_max), int(point_up)]],
|
||||
[[int(x_max), int(point_down)]],
|
||||
[[int(x_min), int(point_down)]]]))
|
||||
else:
|
||||
for jj in range(len(peaks)):
|
||||
if jj == 0:
|
||||
|
|
@ -938,14 +937,14 @@ def separate_lines_vertical(img_patch, contour_text_interest, thetha):
|
|||
if point_up_rot2 < 0:
|
||||
point_up_rot2 = 0
|
||||
|
||||
textline_boxes_rot.append(np.array([[int(x_min_rot1), int(point_up_rot1)],
|
||||
[int(x_max_rot2), int(point_up_rot2)],
|
||||
[int(x_max_rot3), int(point_down_rot3)],
|
||||
[int(x_min_rot4), int(point_down_rot4)]]))
|
||||
textline_boxes.append(np.array([[int(x_min), int(point_up)],
|
||||
[int(x_max), int(point_up)],
|
||||
[int(x_max), int(point_down)],
|
||||
[int(x_min), int(point_down)]]))
|
||||
textline_boxes_rot.append(np.array([[[int(x_min_rot1), int(point_up_rot1)]],
|
||||
[[int(x_max_rot2), int(point_up_rot2)]],
|
||||
[[int(x_max_rot3), int(point_down_rot3)]],
|
||||
[[int(x_min_rot4), int(point_down_rot4)]]]))
|
||||
textline_boxes.append(np.array([[[int(x_min), int(point_up)]],
|
||||
[[int(x_max), int(point_up)]],
|
||||
[[int(x_max), int(point_down)]],
|
||||
[[int(x_min), int(point_down)]]]))
|
||||
return peaks, textline_boxes_rot
|
||||
|
||||
def separate_lines_new_inside_tiles2(img_patch, thetha):
|
||||
|
|
@ -1493,7 +1492,6 @@ def separate_lines_new2(img_crop, thetha, num_col, slope_region, logger=None, pl
|
|||
|
||||
return img_patch_interest_revised
|
||||
|
||||
@wrap_ndarray_shared(kw='img')
|
||||
def do_image_rotation(angle, img=None, sigma_des=1.0, logger=None):
|
||||
if logger is None:
|
||||
logger = getLogger(__package__)
|
||||
|
|
@ -1507,9 +1505,9 @@ def do_image_rotation(angle, img=None, sigma_des=1.0, logger=None):
|
|||
return var
|
||||
|
||||
def return_deskew_slop(img_patch_org, sigma_des,n_tot_angles=100,
|
||||
main_page=False, logger=None, plotter=None, map=None):
|
||||
main_page=False, logger=None, plotter=None, name=None):
|
||||
if main_page and plotter:
|
||||
plotter.save_plot_of_textline_density(img_patch_org)
|
||||
plotter.save_plot_of_textline_density(img_patch_org, name)
|
||||
|
||||
img_int=np.zeros((img_patch_org.shape[0],img_patch_org.shape[1]))
|
||||
img_int[:,:]=img_patch_org[:,:]#img_patch_org[:,:,0]
|
||||
|
|
@ -1527,16 +1525,16 @@ def return_deskew_slop(img_patch_org, sigma_des,n_tot_angles=100,
|
|||
|
||||
if main_page and img_patch_org.shape[1] > img_patch_org.shape[0]:
|
||||
angles = np.array([-45, 0, 45, 90,])
|
||||
angle, _ = get_smallest_skew(img_resized, sigma_des, angles, map=map, logger=logger, plotter=plotter)
|
||||
angle, _ = get_smallest_skew(img_resized, sigma_des, angles, logger=logger, name=name, plotter=plotter)
|
||||
|
||||
angles = np.linspace(angle - 22.5, angle + 22.5, n_tot_angles)
|
||||
angle, _ = get_smallest_skew(img_resized, sigma_des, angles, map=map, logger=logger, plotter=plotter)
|
||||
angle, _ = get_smallest_skew(img_resized, sigma_des, angles, logger=logger, name=name, plotter=plotter)
|
||||
elif main_page:
|
||||
#angles = np.linspace(-12, 12, n_tot_angles)#np.array([0 , 45 , 90 , -45])
|
||||
angles = np.concatenate((np.linspace(-12, -7, n_tot_angles // 4),
|
||||
np.linspace(-6, 6, n_tot_angles // 2),
|
||||
np.linspace(7, 12, n_tot_angles // 4)))
|
||||
angle, var = get_smallest_skew(img_resized, sigma_des, angles, map=map, logger=logger, plotter=plotter)
|
||||
angle, var = get_smallest_skew(img_resized, sigma_des, angles, logger=logger, name=name, plotter=plotter)
|
||||
|
||||
early_slope_edge=11
|
||||
if abs(angle) > early_slope_edge:
|
||||
|
|
@ -1544,12 +1542,12 @@ def return_deskew_slop(img_patch_org, sigma_des,n_tot_angles=100,
|
|||
angles2 = np.linspace(-90, -12, n_tot_angles)
|
||||
else:
|
||||
angles2 = np.linspace(90, 12, n_tot_angles)
|
||||
angle2, var2 = get_smallest_skew(img_resized, sigma_des, angles2, map=map, logger=logger, plotter=plotter)
|
||||
angle2, var2 = get_smallest_skew(img_resized, sigma_des, angles2, logger=logger, name=name, plotter=plotter)
|
||||
if var2 > var:
|
||||
angle = angle2
|
||||
else:
|
||||
angles = np.linspace(-25, 25, int(0.5 * n_tot_angles) + 10)
|
||||
angle, var = get_smallest_skew(img_resized, sigma_des, angles, map=map, logger=logger, plotter=plotter)
|
||||
angle, var = get_smallest_skew(img_resized, sigma_des, angles, logger=logger, name=name, plotter=plotter)
|
||||
|
||||
early_slope_edge=22
|
||||
if abs(angle) > early_slope_edge:
|
||||
|
|
@ -1557,23 +1555,21 @@ def return_deskew_slop(img_patch_org, sigma_des,n_tot_angles=100,
|
|||
angles2 = np.linspace(-90, -25, int(0.5 * n_tot_angles) + 10)
|
||||
else:
|
||||
angles2 = np.linspace(90, 25, int(0.5 * n_tot_angles) + 10)
|
||||
angle2, var2 = get_smallest_skew(img_resized, sigma_des, angles2, map=map, logger=logger, plotter=plotter)
|
||||
angle2, var2 = get_smallest_skew(img_resized, sigma_des, angles2, logger=logger, name=name, plotter=plotter)
|
||||
if var2 > var:
|
||||
angle = angle2
|
||||
# precision stage:
|
||||
angles = np.linspace(angle - 2.5, angle + 2.5, n_tot_angles // 2)
|
||||
angle, _ = get_smallest_skew(img_resized, sigma_des, angles, logger=logger, name=name, plotter=plotter)
|
||||
return angle
|
||||
|
||||
def get_smallest_skew(img, sigma_des, angles, logger=None, plotter=None, map=map):
|
||||
def get_smallest_skew(img, sigma_des, angles, logger=None, plotter=None, name=None):
|
||||
if logger is None:
|
||||
logger = getLogger(__package__)
|
||||
if map is None:
|
||||
results = [do_image_rotation.__wrapped__(angle, img=img, sigma_des=sigma_des, logger=logger)
|
||||
for angle in angles]
|
||||
else:
|
||||
with share_ndarray(img) as img_shared:
|
||||
results = list(map(partial(do_image_rotation, img=img_shared, sigma_des=sigma_des, logger=None),
|
||||
angles))
|
||||
results = [do_image_rotation(angle, img=img, sigma_des=sigma_des, logger=logger)
|
||||
for angle in angles]
|
||||
if plotter:
|
||||
plotter.save_plot_of_rotation_angle(angles, results)
|
||||
plotter.save_plot_of_rotation_angle(angles, results, name)
|
||||
try:
|
||||
var_res = np.array(results)
|
||||
assert var_res.any()
|
||||
|
|
@ -1586,13 +1582,11 @@ def get_smallest_skew(img, sigma_des, angles, logger=None, plotter=None, map=map
|
|||
var = 0
|
||||
return angle, var
|
||||
|
||||
@wrap_ndarray_shared(kw='textline_mask_tot_ea')
|
||||
@wrap_ndarray_shared(kw='mask_texts_only')
|
||||
def do_work_of_slopes_new_curved(
|
||||
box_text, contour_par,
|
||||
textline_mask_tot_ea=None, mask_texts_only=None,
|
||||
num_col=1, scale_par=1.0, slope_deskew=0.0,
|
||||
logger=None, MAX_SLOPE=999, KERNEL=None, plotter=None
|
||||
logger=None, MAX_SLOPE=999, KERNEL=None, plotter=None, name=None
|
||||
):
|
||||
if KERNEL is None:
|
||||
KERNEL = np.ones((5, 5), np.uint8)
|
||||
|
|
@ -1623,7 +1617,7 @@ def do_work_of_slopes_new_curved(
|
|||
else:
|
||||
sigma_des = max(1, int(y_diff_mean * (4.0 / 40.0)))
|
||||
img_int_p[img_int_p > 0] = 1
|
||||
slope_for_all = return_deskew_slop(img_int_p, sigma_des, logger=logger, plotter=plotter)
|
||||
slope_for_all = return_deskew_slop(img_int_p, sigma_des, logger=logger, name=name, plotter=plotter)
|
||||
if abs(slope_for_all) < 0.5:
|
||||
slope_for_all = slope_deskew
|
||||
except:
|
||||
|
|
@ -1682,7 +1676,6 @@ def do_work_of_slopes_new_curved(
|
|||
|
||||
return textlines_cnt_per_region[::-1], crop_coor, slope
|
||||
|
||||
@wrap_ndarray_shared(kw='textline_mask_tot_ea')
|
||||
def do_work_of_slopes_new_light(
|
||||
box_text, contour, contour_par,
|
||||
textline_mask_tot_ea=None, slope_deskew=0,
|
||||
|
|
|
|||
|
|
@ -370,8 +370,8 @@ def break_curved_line_into_small_pieces_and_then_merge(img_curved, mask_curved,
|
|||
return img_curved, img_bin_curved
|
||||
|
||||
def return_textline_contour_with_added_box_coordinate(textline_contour, box_ind):
|
||||
textline_contour[:,0] = textline_contour[:,0] + box_ind[2]
|
||||
textline_contour[:,1] = textline_contour[:,1] + box_ind[0]
|
||||
textline_contour[:,:,0] += box_ind[2]
|
||||
textline_contour[:,:,1] += box_ind[0]
|
||||
return textline_contour
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,11 +2,12 @@
|
|||
# pylint: disable=import-error
|
||||
from pathlib import Path
|
||||
import os.path
|
||||
from typing import Optional
|
||||
import logging
|
||||
from .utils.xml import create_page_xml, xml_reading_order
|
||||
from .utils.counter import EynollahIdCounter
|
||||
from typing import Optional
|
||||
import numpy as np
|
||||
from shapely import affinity, clip_by_rect
|
||||
|
||||
from ocrd_utils import points_from_polygon
|
||||
from ocrd_models.ocrd_page import (
|
||||
BorderType,
|
||||
CoordsType,
|
||||
|
|
@ -19,9 +20,13 @@ from ocrd_models.ocrd_page import (
|
|||
to_xml
|
||||
)
|
||||
|
||||
from .utils.xml import create_page_xml, xml_reading_order
|
||||
from .utils.counter import EynollahIdCounter
|
||||
from .utils.contour import contour2polygon, make_valid
|
||||
|
||||
class EynollahXmlWriter:
|
||||
|
||||
def __init__(self, *, dir_out, image_filename, curved_line, pcgts=None):
|
||||
def __init__(self, *, dir_out, image_filename, image_width, image_height, curved_line, pcgts=None):
|
||||
self.logger = logging.getLogger('eynollah.writer')
|
||||
self.counter = EynollahIdCounter()
|
||||
self.dir_out = dir_out
|
||||
|
|
@ -29,29 +34,23 @@ class EynollahXmlWriter:
|
|||
self.output_filename = os.path.join(self.dir_out or "", self.image_filename_stem) + ".xml"
|
||||
self.curved_line = curved_line
|
||||
self.pcgts = pcgts
|
||||
self.scale_x: Optional[float] = None # XXX set outside __init__
|
||||
self.scale_y: Optional[float] = None # XXX set outside __init__
|
||||
self.height_org: Optional[int] = None # XXX set outside __init__
|
||||
self.width_org: Optional[int] = None # XXX set outside __init__
|
||||
self.image_height = image_height
|
||||
self.image_width = image_width
|
||||
self.scale_x = 1.0
|
||||
self.scale_y = 1.0
|
||||
|
||||
@property
|
||||
def image_filename_stem(self):
|
||||
return Path(Path(self.image_filename).name).stem
|
||||
|
||||
def calculate_page_coords(self, cont_page):
|
||||
self.logger.debug('enter calculate_page_coords')
|
||||
points_page_print = ""
|
||||
for _, contour in enumerate(cont_page[0]):
|
||||
if len(contour) == 2:
|
||||
points_page_print += str(int((contour[0]) / self.scale_x))
|
||||
points_page_print += ','
|
||||
points_page_print += str(int((contour[1]) / self.scale_y))
|
||||
else:
|
||||
points_page_print += str(int((contour[0][0]) / self.scale_x))
|
||||
points_page_print += ','
|
||||
points_page_print += str(int((contour[0][1] ) / self.scale_y))
|
||||
points_page_print = points_page_print + ' '
|
||||
return points_page_print[:-1]
|
||||
def calculate_points(self, contour, offset=None):
|
||||
self.logger.debug('enter calculate_points')
|
||||
poly = contour2polygon(contour)
|
||||
if offset is not None:
|
||||
poly = affinity.translate(poly, *offset)
|
||||
poly = affinity.scale(poly, xfact=1 / self.scale_x, yfact=1 / self.scale_y, origin=(0, 0))
|
||||
poly = make_valid(clip_by_rect(poly, 0, 0, self.image_width, self.image_height))
|
||||
return points_from_polygon(poly.exterior.coords[:-1])
|
||||
|
||||
def serialize_lines_in_region(self, text_region, all_found_textline_polygons, region_idx, page_coord, all_box_coord, slopes, counter, ocr_all_textlines_textregion):
|
||||
self.logger.debug('enter serialize_lines_in_region')
|
||||
|
|
@ -64,16 +63,12 @@ class EynollahXmlWriter:
|
|||
text_region.add_TextLine(textline)
|
||||
text_region.set_orientation(-slopes[region_idx])
|
||||
region_bboxes = all_box_coord[region_idx]
|
||||
points_co = ''
|
||||
for point in polygon_textline:
|
||||
if len(point) != 2:
|
||||
point = point[0]
|
||||
point_x = point[0] + page_coord[2]
|
||||
point_y = point[1] + page_coord[0]
|
||||
point_x = max(0, int(point_x / self.scale_x))
|
||||
point_y = max(0, int(point_y / self.scale_y))
|
||||
points_co += f'{point_x},{point_y} '
|
||||
coords.set_points(points_co[:-1])
|
||||
offset = [page_coord[2], page_coord[0]]
|
||||
# FIXME: or actually... self.curved_line or np.abs(slopes[region_idx]) > 45?
|
||||
if self.curved_line and np.abs(slopes[region_idx]) > 45:
|
||||
offset[0] += region_bboxes[2]
|
||||
offset[1] += region_bboxes[0]
|
||||
coords.set_points(self.calculate_points(polygon_textline, offset))
|
||||
|
||||
def write_pagexml(self, pcgts):
|
||||
self.logger.info("output filename: '%s'", self.output_filename)
|
||||
|
|
@ -166,11 +161,16 @@ class EynollahXmlWriter:
|
|||
self.logger.debug('enter build_pagexml')
|
||||
|
||||
# create the file structure
|
||||
pcgts = self.pcgts if self.pcgts else create_page_xml(self.image_filename, self.height_org, self.width_org)
|
||||
pcgts = self.pcgts if self.pcgts else create_page_xml(
|
||||
self.image_filename, self.image_height, self.image_width)
|
||||
page = pcgts.get_Page()
|
||||
assert page
|
||||
page.set_Border(BorderType(Coords=CoordsType(points=self.calculate_page_coords(cont_page))))
|
||||
if len(cont_page):
|
||||
page.set_Border(BorderType(Coords=CoordsType(points=self.calculate_points(cont_page[0]))))
|
||||
|
||||
if skip_layout_reading_order:
|
||||
offset = None
|
||||
else:
|
||||
offset = [page_coord[2], page_coord[0]]
|
||||
counter = EynollahIdCounter()
|
||||
if len(order_of_texts):
|
||||
_counter_marginals = EynollahIdCounter(region_idx=len(order_of_texts))
|
||||
|
|
@ -183,8 +183,7 @@ class EynollahXmlWriter:
|
|||
for mm, region_contour in enumerate(found_polygons_text_region):
|
||||
textregion = TextRegionType(
|
||||
id=counter.next_region_id, type_='paragraph',
|
||||
Coords=CoordsType(points=self.calculate_polygon_coords(region_contour, page_coord,
|
||||
skip_layout_reading_order))
|
||||
Coords=CoordsType(points=self.calculate_points(region_contour, offset))
|
||||
)
|
||||
assert textregion.Coords
|
||||
if conf_contours_textregions:
|
||||
|
|
@ -201,7 +200,7 @@ class EynollahXmlWriter:
|
|||
for mm, region_contour in enumerate(found_polygons_text_region_h):
|
||||
textregion = TextRegionType(
|
||||
id=counter.next_region_id, type_='heading',
|
||||
Coords=CoordsType(points=self.calculate_polygon_coords(region_contour, page_coord))
|
||||
Coords=CoordsType(points=self.calculate_points(region_contour, offset))
|
||||
)
|
||||
assert textregion.Coords
|
||||
if conf_contours_textregions_h:
|
||||
|
|
@ -217,7 +216,7 @@ class EynollahXmlWriter:
|
|||
for mm, region_contour in enumerate(found_polygons_marginals_left):
|
||||
marginal = TextRegionType(
|
||||
id=counter.next_region_id, type_='marginalia',
|
||||
Coords=CoordsType(points=self.calculate_polygon_coords(region_contour, page_coord))
|
||||
Coords=CoordsType(points=self.calculate_points(region_contour, offset))
|
||||
)
|
||||
page.add_TextRegion(marginal)
|
||||
if ocr_all_textlines_marginals_left:
|
||||
|
|
@ -229,7 +228,7 @@ class EynollahXmlWriter:
|
|||
for mm, region_contour in enumerate(found_polygons_marginals_right):
|
||||
marginal = TextRegionType(
|
||||
id=counter.next_region_id, type_='marginalia',
|
||||
Coords=CoordsType(points=self.calculate_polygon_coords(region_contour, page_coord))
|
||||
Coords=CoordsType(points=self.calculate_points(region_contour, offset))
|
||||
)
|
||||
page.add_TextRegion(marginal)
|
||||
if ocr_all_textlines_marginals_right:
|
||||
|
|
@ -242,7 +241,7 @@ class EynollahXmlWriter:
|
|||
for mm, region_contour in enumerate(found_polygons_drop_capitals):
|
||||
dropcapital = TextRegionType(
|
||||
id=counter.next_region_id, type_='drop-capital',
|
||||
Coords=CoordsType(points=self.calculate_polygon_coords(region_contour, page_coord))
|
||||
Coords=CoordsType(points=self.calculate_points(region_contour, offset))
|
||||
)
|
||||
page.add_TextRegion(dropcapital)
|
||||
all_box_coord_drop = [[0, 0, 0, 0]]
|
||||
|
|
@ -257,33 +256,17 @@ class EynollahXmlWriter:
|
|||
for region_contour in found_polygons_text_region_img:
|
||||
page.add_ImageRegion(
|
||||
ImageRegionType(id=counter.next_region_id,
|
||||
Coords=CoordsType(points=self.calculate_polygon_coords(region_contour, page_coord))))
|
||||
Coords=CoordsType(points=self.calculate_points(region_contour, offset))))
|
||||
|
||||
for region_contour in polygons_seplines:
|
||||
page.add_SeparatorRegion(
|
||||
SeparatorRegionType(id=counter.next_region_id,
|
||||
Coords=CoordsType(points=self.calculate_polygon_coords(region_contour, [0, 0, 0, 0]))))
|
||||
Coords=CoordsType(points=self.calculate_points(region_contour, None))))
|
||||
|
||||
for region_contour in found_polygons_tables:
|
||||
page.add_TableRegion(
|
||||
TableRegionType(id=counter.next_region_id,
|
||||
Coords=CoordsType(points=self.calculate_polygon_coords(region_contour, page_coord))))
|
||||
Coords=CoordsType(points=self.calculate_points(region_contour, offset))))
|
||||
|
||||
return pcgts
|
||||
|
||||
def calculate_polygon_coords(self, contour, page_coord, skip_layout_reading_order=False):
|
||||
self.logger.debug('enter calculate_polygon_coords')
|
||||
coords = ''
|
||||
for point in contour:
|
||||
if len(point) != 2:
|
||||
point = point[0]
|
||||
point_x = point[0]
|
||||
point_y = point[1]
|
||||
if not skip_layout_reading_order:
|
||||
point_x += page_coord[2]
|
||||
point_y += page_coord[0]
|
||||
point_x = int(point_x / self.scale_x)
|
||||
point_y = int(point_y / self.scale_y)
|
||||
coords += str(point_x) + ',' + str(point_y) + ' '
|
||||
return coords[:-1]
|
||||
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ def test_run_eynollah_binarization_filename(
|
|||
'-o', str(outfile),
|
||||
] + options,
|
||||
[
|
||||
'Predicting'
|
||||
'Loaded model'
|
||||
]
|
||||
)
|
||||
assert outfile.exists()
|
||||
|
|
@ -46,8 +46,8 @@ def test_run_eynollah_binarization_directory(
|
|||
'-o', str(outdir),
|
||||
],
|
||||
[
|
||||
f'Predicting {image_resources[0].name}',
|
||||
f'Predicting {image_resources[1].name}',
|
||||
f'Binarizing [ 1/2] {image_resources[0].name}',
|
||||
f'Binarizing [ 2/2] {image_resources[1].name}',
|
||||
]
|
||||
)
|
||||
assert len(list(outdir.iterdir())) == 2
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
sacred
|
||||
seaborn
|
||||
numpy <1.24.0
|
||||
numpy
|
||||
tqdm
|
||||
imutils
|
||||
scipy
|
||||
tensorflow-addons # for connected_components
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue